Merge remote-tracking branch 'origin/main' into masenf/proxy

This commit is contained in:
Masen Furer 2024-12-16 20:00:08 -08:00
commit ce153ee8f7
No known key found for this signature in database
GPG Key ID: 2AE2BD5531FF94F4
61 changed files with 1277 additions and 332 deletions

View File

@ -5,6 +5,7 @@ from __future__ import annotations
import argparse
import json
import os
from pathlib import Path
from utils import send_data_to_posthog
@ -18,7 +19,7 @@ def extract_stats_from_json(json_file: str) -> list[dict]:
Returns:
list[dict]: The stats for each test.
"""
with open(json_file, "r") as file:
with Path(json_file).open() as file:
json_data = json.load(file)
# Load the JSON data if it is a string, otherwise assume it's already a dictionary

View File

@ -5,6 +5,7 @@ from __future__ import annotations
import argparse
import json
import os
from pathlib import Path
from utils import send_data_to_posthog
@ -18,7 +19,7 @@ def extract_stats_from_json(json_file: str) -> dict:
Returns:
dict: The stats for each test.
"""
with open(json_file, "r") as file:
with Path(json_file).open() as file:
json_data = json.load(file)
# Load the JSON data if it is a string, otherwise assume it's already a dictionary

View File

@ -4,26 +4,19 @@ version = "0.6.7dev1"
description = "Web apps in pure Python."
license = "Apache-2.0"
authors = [
"Nikhil Rao <nikhil@reflex.dev>",
"Alek Petuskey <alek@reflex.dev>",
"Masen Furer <masen@reflex.dev>",
"Elijah Ahianyo <elijah@reflex.dev>",
"Thomas Brandého <thomas@reflex.dev>",
"Nikhil Rao <nikhil@reflex.dev>",
"Alek Petuskey <alek@reflex.dev>",
"Masen Furer <masen@reflex.dev>",
"Elijah Ahianyo <elijah@reflex.dev>",
"Thomas Brandého <thomas@reflex.dev>",
]
readme = "README.md"
homepage = "https://reflex.dev"
repository = "https://github.com/reflex-dev/reflex"
documentation = "https://reflex.dev/docs/getting-started/introduction"
keywords = [
"web",
"framework",
]
classifiers = [
"Development Status :: 4 - Beta",
]
packages = [
{include = "reflex"}
]
keywords = ["web", "framework"]
classifiers = ["Development Status :: 4 - Beta"]
packages = [{ include = "reflex" }]
[tool.poetry.dependencies]
python = "^3.9"
@ -42,11 +35,11 @@ uvicorn = ">=0.20.0"
starlette-admin = ">=0.11.0,<1.0"
alembic = ">=1.11.1,<2.0"
platformdirs = ">=3.10.0,<5.0"
distro = {version = ">=1.8.0,<2.0", platform = "linux"}
distro = { version = ">=1.8.0,<2.0", platform = "linux" }
python-engineio = "!=4.6.0"
wrapt = [
{version = ">=1.14.0,<2.0", python = ">=3.11"},
{version = ">=1.11.0,<2.0", python = "<3.11"},
{ version = ">=1.14.0,<2.0", python = ">=3.11" },
{ version = ">=1.11.0,<2.0", python = "<3.11" },
]
packaging = ">=23.1,<25.0"
reflex-hosting-cli = ">=0.1.29,<2.0"
@ -97,14 +90,15 @@ build-backend = "poetry.core.masonry.api"
[tool.ruff]
target-version = "py39"
output-format = "concise"
lint.isort.split-on-trailing-comma = false
lint.select = ["B", "D", "E", "F", "I", "SIM", "W", "RUF", "FURB", "ERA"]
lint.select = ["B", "C4", "D", "E", "ERA", "F", "FURB", "I", "PERF", "PTH", "RUF", "SIM", "W"]
lint.ignore = ["B008", "D205", "E501", "F403", "SIM115", "RUF006", "RUF012"]
lint.pydocstyle.convention = "google"
[tool.ruff.lint.per-file-ignores]
"__init__.py" = ["F401"]
"tests/*.py" = ["D100", "D103", "D104", "B018"]
"tests/*.py" = ["D100", "D103", "D104", "B018", "PERF"]
"reflex/.templates/*.py" = ["D100", "D103", "D104"]
"*.pyi" = ["D301", "D415", "D417", "D418", "E742"]
"*/blank.py" = ["I001"]

View File

@ -5,11 +5,15 @@ export function {{tag_name}} () {
{{ hook }}
{% endfor %}
{% for hook, data in component._get_all_hooks().items() if not data.position or data.position == const.hook_position.PRE_TRIGGER %}
{{ hook }}
{% endfor %}
{% for hook in memo_trigger_hooks %}
{{ hook }}
{% endfor %}
{% for hook in component._get_all_hooks() %}
{% for hook, data in component._get_all_hooks().items() if data.position and data.position == const.hook_position.POST_TRIGGER %}
{{ hook }}
{% endfor %}

View File

@ -442,7 +442,7 @@ class App(MiddlewareMixin, LifespanMixin):
allow_credentials=True,
allow_methods=["*"],
allow_headers=["*"],
allow_origins=["*"],
allow_origins=get_config().cors_allowed_origins,
)
@property
@ -1301,7 +1301,7 @@ async def process(
await asyncio.create_task(
app.event_namespace.emit(
"reload",
data=format.json_dumps(event),
data=event,
to=sid,
)
)

View File

@ -30,15 +30,16 @@ def validate_field_name(bases: List[Type["BaseModel"]], field_name: str) -> None
# can't use reflex.config.environment here cause of circular import
reload = os.getenv("__RELOAD_CONFIG", "").lower() == "true"
for base in bases:
try:
base = None
try:
for base in bases:
if not reload and getattr(base, field_name, None):
pass
except TypeError as te:
raise VarNameError(
f'State var "{field_name}" in {base} has been shadowed by a substate var; '
f'use a different field name instead".'
) from te
except TypeError as te:
raise VarNameError(
f'State var "{field_name}" in {base} has been shadowed by a substate var; '
f'use a different field name instead".'
) from te
# monkeypatch pydantic validate_field_name method to skip validating

View File

@ -45,6 +45,7 @@ class ReflexJinjaEnvironment(Environment):
"on_load_internal": constants.CompileVars.ON_LOAD_INTERNAL,
"update_vars_internal": constants.CompileVars.UPDATE_VARS_INTERNAL,
"frontend_exception_state": constants.CompileVars.FRONTEND_EXCEPTION_STATE_FULL,
"hook_position": constants.Hooks.HookPosition,
}

View File

@ -115,7 +115,7 @@ def compile_imports(import_dict: ParsedImportDict) -> list[dict]:
default, rest = compile_import_statement(fields)
# prevent lib from being rendered on the page if all imports are non rendered kind
if not any({f.render for f in fields}): # type: ignore
if not any(f.render for f in fields): # type: ignore
continue
if not lib:
@ -123,8 +123,7 @@ def compile_imports(import_dict: ParsedImportDict) -> list[dict]:
raise ValueError("No default field allowed for empty library.")
if rest is None or len(rest) == 0:
raise ValueError("No fields to import.")
for module in sorted(rest):
import_dicts.append(get_import_dict(module))
import_dicts.extend(get_import_dict(module) for module in sorted(rest))
continue
# remove the version before rendering the package imports

View File

@ -1208,7 +1208,7 @@ class Component(BaseComponent, ABC):
Yields:
The parent classes that define the method (differently than the base).
"""
seen_methods = set([getattr(Component, method)])
seen_methods = {getattr(Component, method)}
for clz in cls.mro():
if clz is Component:
break
@ -1368,7 +1368,9 @@ class Component(BaseComponent, ABC):
if user_hooks_data is not None:
other_imports.append(user_hooks_data.imports)
other_imports.extend(
hook_imports for hook_imports in self._get_added_hooks().values()
hook_vardata.imports
for hook_vardata in self._get_added_hooks().values()
if hook_vardata is not None
)
return imports.merge_imports(_imports, *other_imports)
@ -1390,15 +1392,9 @@ class Component(BaseComponent, ABC):
# Collect imports from Vars used directly by this component.
var_datas = [var._get_all_var_data() for var in self._get_vars()]
var_imports: List[ImmutableParsedImportDict] = list(
map(
lambda var_data: var_data.imports,
filter(
None,
var_datas,
),
)
)
var_imports: List[ImmutableParsedImportDict] = [
var_data.imports for var_data in var_datas if var_data is not None
]
added_import_dicts: list[ParsedImportDict] = []
for clz in self._iter_parent_classes_with_method("add_imports"):
@ -1407,8 +1403,9 @@ class Component(BaseComponent, ABC):
if not isinstance(list_of_import_dict, list):
list_of_import_dict = [list_of_import_dict]
for import_dict in list_of_import_dict:
added_import_dicts.append(parse_imports(import_dict))
added_import_dicts.extend(
[parse_imports(import_dict) for import_dict in list_of_import_dict]
)
return imports.merge_imports(
*self._get_props_imports(),
@ -1522,7 +1519,7 @@ class Component(BaseComponent, ABC):
**self._get_special_hooks(),
}
def _get_added_hooks(self) -> dict[str, ImportDict]:
def _get_added_hooks(self) -> dict[str, VarData | None]:
"""Get the hooks added via `add_hooks` method.
Returns:
@ -1531,17 +1528,15 @@ class Component(BaseComponent, ABC):
code = {}
def extract_var_hooks(hook: Var):
_imports = {}
var_data = VarData.merge(hook._get_all_var_data())
if var_data is not None:
for sub_hook in var_data.hooks:
code[sub_hook] = {}
if var_data.imports:
_imports = var_data.imports
code[sub_hook] = None
if str(hook) in code:
code[str(hook)] = imports.merge_imports(code[str(hook)], _imports)
code[str(hook)] = VarData.merge(var_data, code[str(hook)])
else:
code[str(hook)] = _imports
code[str(hook)] = var_data
# Add the hook code from add_hooks for each parent class (this is reversed to preserve
# the order of the hooks in the final output)
@ -1550,7 +1545,7 @@ class Component(BaseComponent, ABC):
if isinstance(hook, Var):
extract_var_hooks(hook)
else:
code[hook] = {}
code[hook] = None
return code
@ -1592,8 +1587,7 @@ class Component(BaseComponent, ABC):
if hooks is not None:
code[hooks] = None
for hook in self._get_added_hooks():
code[hook] = None
code.update(self._get_added_hooks())
# Add the hook code for the children.
for child in self.children:
@ -2195,6 +2189,31 @@ class StatefulComponent(BaseComponent):
]
return [var_name]
@staticmethod
def _get_deps_from_event_trigger(event: EventChain | EventSpec | Var) -> set[str]:
"""Get the dependencies accessed by event triggers.
Args:
event: The event trigger to extract deps from.
Returns:
The dependencies accessed by the event triggers.
"""
events: list = [event]
deps = set()
if isinstance(event, EventChain):
events.extend(event.events)
for ev in events:
if isinstance(ev, EventSpec):
for arg in ev.args:
for a in arg:
var_datas = VarData.merge(a._get_all_var_data())
if var_datas and var_datas.deps is not None:
deps |= {str(dep) for dep in var_datas.deps}
return deps
@classmethod
def _get_memoized_event_triggers(
cls,
@ -2231,6 +2250,11 @@ class StatefulComponent(BaseComponent):
# Calculate Var dependencies accessed by the handler for useCallback dep array.
var_deps = ["addEvents", "Event"]
# Get deps from event trigger var data.
var_deps.extend(cls._get_deps_from_event_trigger(event))
# Get deps from hooks.
for arg in event_args:
var_data = arg._get_all_var_data()
if var_data is None:

View File

@ -6,11 +6,12 @@ from typing import Dict, List, Tuple, Union
from reflex.components.base.fragment import Fragment
from reflex.components.tags.tag import Tag
from reflex.constants.compiler import Hooks
from reflex.event import EventChain, EventHandler, passthrough_event_spec
from reflex.utils.format import format_prop, wrap
from reflex.utils.imports import ImportVar
from reflex.vars import get_unique_variable_name
from reflex.vars.base import Var
from reflex.vars.base import Var, VarData
class Clipboard(Fragment):
@ -72,7 +73,7 @@ class Clipboard(Fragment):
),
}
def add_hooks(self) -> list[str]:
def add_hooks(self) -> list[str | Var[str]]:
"""Add hook to register paste event listener.
Returns:
@ -83,13 +84,14 @@ class Clipboard(Fragment):
return []
if isinstance(on_paste, EventChain):
on_paste = wrap(str(format_prop(on_paste)).strip("{}"), "(")
hook_expr = f"usePasteHandler({self.targets!s}, {self.on_paste_event_actions!s}, {on_paste!s})"
return [
"usePasteHandler(%s, %s, %s)"
% (
str(self.targets),
str(self.on_paste_event_actions),
on_paste,
)
Var(
hook_expr,
_var_type="str",
_var_data=VarData(position=Hooks.HookPosition.POST_TRIGGER),
),
]

View File

@ -71,6 +71,6 @@ class Clipboard(Fragment):
...
def add_imports(self) -> dict[str, ImportVar]: ...
def add_hooks(self) -> list[str]: ...
def add_hooks(self) -> list[str | Var[str]]: ...
clipboard = Clipboard.create

View File

@ -339,8 +339,11 @@ class DataEditor(NoSSRComponent):
editor_id = get_unique_variable_name()
# Define the name of the getData callback associated with this component and assign to get_cell_content.
data_callback = f"getData_{editor_id}"
self.get_cell_content = Var(_js_expr=data_callback) # type: ignore
if self.get_cell_content is not None:
data_callback = self.get_cell_content._js_expr
else:
data_callback = f"getData_{editor_id}"
self.get_cell_content = Var(_js_expr=data_callback) # type: ignore
code = [f"function {data_callback}([col, row])" "{"]

View File

@ -127,7 +127,7 @@ _MAPPING = {
EXCLUDE = ["del_", "Del", "image"]
for _, v in _MAPPING.items():
for v in _MAPPING.values():
v.extend([mod.capitalize() for mod in v if mod not in EXCLUDE])
_SUBMOD_ATTRS: dict[str, list[str]] = _MAPPING

View File

@ -339,5 +339,5 @@ _MAPPING = {
],
}
EXCLUDE = ["del_", "Del", "image"]
for _, v in _MAPPING.items():
for v in _MAPPING.values():
v.extend([mod.capitalize() for mod in v if mod not in EXCLUDE])

View File

@ -18,6 +18,7 @@ from reflex.event import (
prevent_default,
)
from reflex.utils.imports import ImportDict
from reflex.utils.types import is_optional
from reflex.vars import VarData
from reflex.vars.base import LiteralVar, Var
@ -382,6 +383,33 @@ class Input(BaseHTML):
# Fired when a key is released
on_key_up: EventHandler[key_event]
@classmethod
def create(cls, *children, **props):
"""Create an Input component.
Args:
*children: The children of the component.
**props: The properties of the component.
Returns:
The component.
"""
from reflex.vars.number import ternary_operation
value = props.get("value")
# React expects an empty string(instead of null) for controlled inputs.
if value is not None and is_optional(
(value_var := Var.create(value))._var_type
):
props["value"] = ternary_operation(
(value_var != Var.create(None)) # pyright: ignore [reportGeneralTypeIssues]
& (value_var != Var(_js_expr="undefined")),
value,
Var.create(""),
)
return super().create(*children, **props)
class Label(BaseHTML):
"""Display the label element."""

View File

@ -512,7 +512,7 @@ class Input(BaseHTML):
on_unmount: Optional[EventType[[], BASE_STATE]] = None,
**props,
) -> "Input":
"""Create the component.
"""Create an Input component.
Args:
*children: The children of the component.
@ -576,7 +576,7 @@ class Input(BaseHTML):
class_name: The class name for the component.
autofocus: Whether the component should take the focus once the page is loaded
custom_attrs: custom attribute
**props: The props of the component.
**props: The properties of the component.
Returns:
The component.

View File

@ -8,6 +8,7 @@ from reflex.event import EventHandler, no_args_event_spec, passthrough_event_spe
from reflex.vars.base import Var
from ..base import LiteralAccentColor, RadixThemesComponent
from .checkbox import Checkbox
LiteralDirType = Literal["ltr", "rtl"]
@ -232,6 +233,15 @@ class ContextMenuSeparator(RadixThemesComponent):
tag = "ContextMenu.Separator"
class ContextMenuCheckbox(Checkbox):
"""The component that contains the checkbox."""
tag = "ContextMenu.CheckboxItem"
# Text to render as shortcut.
shortcut: Var[str]
class ContextMenu(ComponentNamespace):
"""Menu representing a set of actions, displayed at the origin of a pointer right-click or long-press."""
@ -243,6 +253,7 @@ class ContextMenu(ComponentNamespace):
sub_content = staticmethod(ContextMenuSubContent.create)
item = staticmethod(ContextMenuItem.create)
separator = staticmethod(ContextMenuSeparator.create)
checkbox = staticmethod(ContextMenuCheckbox.create)
context_menu = ContextMenu()

View File

@ -12,6 +12,7 @@ from reflex.style import Style
from reflex.vars.base import Var
from ..base import RadixThemesComponent
from .checkbox import Checkbox
LiteralDirType = Literal["ltr", "rtl"]
LiteralSizeType = Literal["1", "2"]
@ -672,6 +673,159 @@ class ContextMenuSeparator(RadixThemesComponent):
"""
...
class ContextMenuCheckbox(Checkbox):
@overload
@classmethod
def create( # type: ignore
cls,
*children,
shortcut: Optional[Union[Var[str], str]] = None,
as_child: Optional[Union[Var[bool], bool]] = None,
size: Optional[
Union[
Breakpoints[str, Literal["1", "2", "3"]],
Literal["1", "2", "3"],
Var[
Union[
Breakpoints[str, Literal["1", "2", "3"]], Literal["1", "2", "3"]
]
],
]
] = None,
variant: Optional[
Union[
Literal["classic", "soft", "surface"],
Var[Literal["classic", "soft", "surface"]],
]
] = None,
color_scheme: Optional[
Union[
Literal[
"amber",
"blue",
"bronze",
"brown",
"crimson",
"cyan",
"gold",
"grass",
"gray",
"green",
"indigo",
"iris",
"jade",
"lime",
"mint",
"orange",
"pink",
"plum",
"purple",
"red",
"ruby",
"sky",
"teal",
"tomato",
"violet",
"yellow",
],
Var[
Literal[
"amber",
"blue",
"bronze",
"brown",
"crimson",
"cyan",
"gold",
"grass",
"gray",
"green",
"indigo",
"iris",
"jade",
"lime",
"mint",
"orange",
"pink",
"plum",
"purple",
"red",
"ruby",
"sky",
"teal",
"tomato",
"violet",
"yellow",
]
],
]
] = None,
high_contrast: Optional[Union[Var[bool], bool]] = None,
default_checked: Optional[Union[Var[bool], bool]] = None,
checked: Optional[Union[Var[bool], bool]] = None,
disabled: Optional[Union[Var[bool], bool]] = None,
required: Optional[Union[Var[bool], bool]] = None,
name: Optional[Union[Var[str], str]] = None,
value: Optional[Union[Var[str], str]] = None,
style: Optional[Style] = None,
key: Optional[Any] = None,
id: Optional[Any] = None,
class_name: Optional[Any] = None,
autofocus: Optional[bool] = None,
custom_attrs: Optional[Dict[str, Union[Var, Any]]] = None,
on_blur: Optional[EventType[[], BASE_STATE]] = None,
on_change: Optional[
Union[EventType[[], BASE_STATE], EventType[[bool], BASE_STATE]]
] = None,
on_click: Optional[EventType[[], BASE_STATE]] = None,
on_context_menu: Optional[EventType[[], BASE_STATE]] = None,
on_double_click: Optional[EventType[[], BASE_STATE]] = None,
on_focus: Optional[EventType[[], BASE_STATE]] = None,
on_mount: Optional[EventType[[], BASE_STATE]] = None,
on_mouse_down: Optional[EventType[[], BASE_STATE]] = None,
on_mouse_enter: Optional[EventType[[], BASE_STATE]] = None,
on_mouse_leave: Optional[EventType[[], BASE_STATE]] = None,
on_mouse_move: Optional[EventType[[], BASE_STATE]] = None,
on_mouse_out: Optional[EventType[[], BASE_STATE]] = None,
on_mouse_over: Optional[EventType[[], BASE_STATE]] = None,
on_mouse_up: Optional[EventType[[], BASE_STATE]] = None,
on_scroll: Optional[EventType[[], BASE_STATE]] = None,
on_unmount: Optional[EventType[[], BASE_STATE]] = None,
**props,
) -> "ContextMenuCheckbox":
"""Create a new component instance.
Will prepend "RadixThemes" to the component tag to avoid conflicts with
other UI libraries for common names, like Text and Button.
Args:
*children: Child components.
shortcut: Text to render as shortcut.
as_child: Change the default rendered element for the one passed as a child, merging their props and behavior.
size: Checkbox size "1" - "3"
variant: Variant of checkbox: "classic" | "surface" | "soft"
color_scheme: Override theme color for checkbox
high_contrast: Whether to render the checkbox with higher contrast color against background
default_checked: Whether the checkbox is checked by default
checked: Whether the checkbox is checked
disabled: Whether the checkbox is disabled
required: Whether the checkbox is required
name: The name of the checkbox control when submitting the form.
value: The value of the checkbox control when submitting the form.
on_change: Fired when the checkbox is checked or unchecked.
style: The style of the component.
key: A unique key for the component.
id: The id for the component.
class_name: The class name for the component.
autofocus: Whether the component should take the focus once the page is loaded
custom_attrs: custom attribute
**props: Component properties.
Returns:
A new component instance.
"""
...
class ContextMenu(ComponentNamespace):
root = staticmethod(ContextMenuRoot.create)
trigger = staticmethod(ContextMenuTrigger.create)
@ -681,5 +835,6 @@ class ContextMenu(ComponentNamespace):
sub_content = staticmethod(ContextMenuSubContent.create)
item = staticmethod(ContextMenuItem.create)
separator = staticmethod(ContextMenuSeparator.create)
checkbox = staticmethod(ContextMenuCheckbox.create)
context_menu = ContextMenu()

View File

@ -79,7 +79,7 @@ class IconButton(elements.Button, RadixLoadingProp, RadixThemesComponent):
else:
size_map_var = Match.create(
props["size"],
*[(size, px) for size, px in RADIX_TO_LUCIDE_SIZE.items()],
*list(RADIX_TO_LUCIDE_SIZE.items()),
12,
)
if not isinstance(size_map_var, Var):

View File

@ -9,7 +9,9 @@ from reflex.components.core.breakpoints import Responsive
from reflex.components.core.debounce import DebounceInput
from reflex.components.el import elements
from reflex.event import EventHandler, input_event, key_event
from reflex.utils.types import is_optional
from reflex.vars.base import Var
from reflex.vars.number import ternary_operation
from ..base import LiteralAccentColor, LiteralRadius, RadixThemesComponent
@ -17,7 +19,7 @@ LiteralTextFieldSize = Literal["1", "2", "3"]
LiteralTextFieldVariant = Literal["classic", "surface", "soft"]
class TextFieldRoot(elements.Div, RadixThemesComponent):
class TextFieldRoot(elements.Input, RadixThemesComponent):
"""Captures user input with an optional slot for buttons and icons."""
tag = "TextField.Root"
@ -96,6 +98,19 @@ class TextFieldRoot(elements.Div, RadixThemesComponent):
Returns:
The component.
"""
value = props.get("value")
# React expects an empty string(instead of null) for controlled inputs.
if value is not None and is_optional(
(value_var := Var.create(value))._var_type
):
props["value"] = ternary_operation(
(value_var != Var.create(None)) # pyright: ignore [reportGeneralTypeIssues]
& (value_var != Var(_js_expr="undefined")),
value,
Var.create(""),
)
component = super().create(*children, **props)
if props.get("value") is not None and props.get("on_change") is not None:
# create a debounced input if the user requests full control to avoid typing jank

View File

@ -17,7 +17,7 @@ from ..base import RadixThemesComponent
LiteralTextFieldSize = Literal["1", "2", "3"]
LiteralTextFieldVariant = Literal["classic", "surface", "soft"]
class TextFieldRoot(elements.Div, RadixThemesComponent):
class TextFieldRoot(elements.Input, RadixThemesComponent):
@overload
@classmethod
def create( # type: ignore
@ -120,6 +120,30 @@ class TextFieldRoot(elements.Div, RadixThemesComponent):
type: Optional[Union[Var[str], str]] = None,
value: Optional[Union[Var[Union[float, int, str]], float, int, str]] = None,
list: Optional[Union[Var[Union[bool, int, str]], bool, int, str]] = None,
accept: Optional[Union[Var[Union[bool, int, str]], bool, int, str]] = None,
alt: Optional[Union[Var[Union[bool, int, str]], bool, int, str]] = None,
auto_focus: Optional[Union[Var[Union[bool, int, str]], bool, int, str]] = None,
capture: Optional[Union[Var[Union[bool, int, str]], bool, int, str]] = None,
checked: Optional[Union[Var[Union[bool, int, str]], bool, int, str]] = None,
default_checked: Optional[Union[Var[bool], bool]] = None,
dirname: Optional[Union[Var[Union[bool, int, str]], bool, int, str]] = None,
form: Optional[Union[Var[Union[bool, int, str]], bool, int, str]] = None,
form_action: Optional[Union[Var[Union[bool, int, str]], bool, int, str]] = None,
form_enc_type: Optional[
Union[Var[Union[bool, int, str]], bool, int, str]
] = None,
form_method: Optional[Union[Var[Union[bool, int, str]], bool, int, str]] = None,
form_no_validate: Optional[
Union[Var[Union[bool, int, str]], bool, int, str]
] = None,
form_target: Optional[Union[Var[Union[bool, int, str]], bool, int, str]] = None,
max: Optional[Union[Var[Union[bool, int, str]], bool, int, str]] = None,
min: Optional[Union[Var[Union[bool, int, str]], bool, int, str]] = None,
multiple: Optional[Union[Var[Union[bool, int, str]], bool, int, str]] = None,
pattern: Optional[Union[Var[Union[bool, int, str]], bool, int, str]] = None,
src: Optional[Union[Var[Union[bool, int, str]], bool, int, str]] = None,
step: Optional[Union[Var[Union[bool, int, str]], bool, int, str]] = None,
use_map: Optional[Union[Var[Union[bool, int, str]], bool, int, str]] = None,
access_key: Optional[Union[Var[Union[bool, int, str]], bool, int, str]] = None,
auto_capitalize: Optional[
Union[Var[Union[bool, int, str]], bool, int, str]
@ -192,12 +216,12 @@ class TextFieldRoot(elements.Div, RadixThemesComponent):
Args:
*children: The children of the component.
size: Text field size "1" - "3"
size: Specifies the visible width of a text control
variant: Variant of text field: "classic" | "surface" | "soft"
color_scheme: Override theme color for text field
radius: Override theme radius for text field: "none" | "small" | "medium" | "large" | "full"
auto_complete: Whether the input should have autocomplete enabled
default_value: The value of the input when initially rendered.
default_value: The initial value for a text field
disabled: Disables the input
max_length: Specifies the maximum number of characters allowed in the input
min_length: Specifies the minimum number of characters required in the input
@ -208,11 +232,31 @@ class TextFieldRoot(elements.Div, RadixThemesComponent):
type: Specifies the type of input
value: Value of the input
list: References a datalist for suggested options
on_change: Fired when the value of the textarea changes.
on_focus: Fired when the textarea is focused.
on_blur: Fired when the textarea is blurred.
on_key_down: Fired when a key is pressed down.
on_key_up: Fired when a key is released.
on_change: Fired when the input value changes
on_focus: Fired when the input gains focus
on_blur: Fired when the input loses focus
on_key_down: Fired when a key is pressed down
on_key_up: Fired when a key is released
accept: Accepted types of files when the input is file type
alt: Alternate text for input type="image"
auto_focus: Automatically focuses the input when the page loads
capture: Captures media from the user (camera or microphone)
checked: Indicates whether the input is checked (for checkboxes and radio buttons)
default_checked: The initial value (for checkboxes and radio buttons)
dirname: Name part of the input to submit in 'dir' and 'name' pair when form is submitted
form: Associates the input with a form (by id)
form_action: URL to send the form data to (for type="submit" buttons)
form_enc_type: How the form data should be encoded when submitting to the server (for type="submit" buttons)
form_method: HTTP method to use for sending form data (for type="submit" buttons)
form_no_validate: Bypasses form validation when submitting (for type="submit" buttons)
form_target: Specifies where to display the response after submitting the form (for type="submit" buttons)
max: Specifies the maximum value for the input
min: Specifies the minimum value for the input
multiple: Indicates whether multiple values can be entered in an input of the type email or file
pattern: Regex pattern the input's value must match to be valid
src: URL for image inputs
step: Specifies the legal number intervals for an input
use_map: Name of the image map used with the input
access_key: Provides a hint for generating a keyboard shortcut for the current element.
auto_capitalize: Controls whether and how text input is automatically capitalized as it is entered/edited by the user.
content_editable: Indicates whether the element's content is editable.
@ -457,6 +501,30 @@ class TextField(ComponentNamespace):
type: Optional[Union[Var[str], str]] = None,
value: Optional[Union[Var[Union[float, int, str]], float, int, str]] = None,
list: Optional[Union[Var[Union[bool, int, str]], bool, int, str]] = None,
accept: Optional[Union[Var[Union[bool, int, str]], bool, int, str]] = None,
alt: Optional[Union[Var[Union[bool, int, str]], bool, int, str]] = None,
auto_focus: Optional[Union[Var[Union[bool, int, str]], bool, int, str]] = None,
capture: Optional[Union[Var[Union[bool, int, str]], bool, int, str]] = None,
checked: Optional[Union[Var[Union[bool, int, str]], bool, int, str]] = None,
default_checked: Optional[Union[Var[bool], bool]] = None,
dirname: Optional[Union[Var[Union[bool, int, str]], bool, int, str]] = None,
form: Optional[Union[Var[Union[bool, int, str]], bool, int, str]] = None,
form_action: Optional[Union[Var[Union[bool, int, str]], bool, int, str]] = None,
form_enc_type: Optional[
Union[Var[Union[bool, int, str]], bool, int, str]
] = None,
form_method: Optional[Union[Var[Union[bool, int, str]], bool, int, str]] = None,
form_no_validate: Optional[
Union[Var[Union[bool, int, str]], bool, int, str]
] = None,
form_target: Optional[Union[Var[Union[bool, int, str]], bool, int, str]] = None,
max: Optional[Union[Var[Union[bool, int, str]], bool, int, str]] = None,
min: Optional[Union[Var[Union[bool, int, str]], bool, int, str]] = None,
multiple: Optional[Union[Var[Union[bool, int, str]], bool, int, str]] = None,
pattern: Optional[Union[Var[Union[bool, int, str]], bool, int, str]] = None,
src: Optional[Union[Var[Union[bool, int, str]], bool, int, str]] = None,
step: Optional[Union[Var[Union[bool, int, str]], bool, int, str]] = None,
use_map: Optional[Union[Var[Union[bool, int, str]], bool, int, str]] = None,
access_key: Optional[Union[Var[Union[bool, int, str]], bool, int, str]] = None,
auto_capitalize: Optional[
Union[Var[Union[bool, int, str]], bool, int, str]
@ -529,12 +597,12 @@ class TextField(ComponentNamespace):
Args:
*children: The children of the component.
size: Text field size "1" - "3"
size: Specifies the visible width of a text control
variant: Variant of text field: "classic" | "surface" | "soft"
color_scheme: Override theme color for text field
radius: Override theme radius for text field: "none" | "small" | "medium" | "large" | "full"
auto_complete: Whether the input should have autocomplete enabled
default_value: The value of the input when initially rendered.
default_value: The initial value for a text field
disabled: Disables the input
max_length: Specifies the maximum number of characters allowed in the input
min_length: Specifies the minimum number of characters required in the input
@ -545,11 +613,31 @@ class TextField(ComponentNamespace):
type: Specifies the type of input
value: Value of the input
list: References a datalist for suggested options
on_change: Fired when the value of the textarea changes.
on_focus: Fired when the textarea is focused.
on_blur: Fired when the textarea is blurred.
on_key_down: Fired when a key is pressed down.
on_key_up: Fired when a key is released.
on_change: Fired when the input value changes
on_focus: Fired when the input gains focus
on_blur: Fired when the input loses focus
on_key_down: Fired when a key is pressed down
on_key_up: Fired when a key is released
accept: Accepted types of files when the input is file type
alt: Alternate text for input type="image"
auto_focus: Automatically focuses the input when the page loads
capture: Captures media from the user (camera or microphone)
checked: Indicates whether the input is checked (for checkboxes and radio buttons)
default_checked: The initial value (for checkboxes and radio buttons)
dirname: Name part of the input to submit in 'dir' and 'name' pair when form is submitted
form: Associates the input with a form (by id)
form_action: URL to send the form data to (for type="submit" buttons)
form_enc_type: How the form data should be encoded when submitting to the server (for type="submit" buttons)
form_method: HTTP method to use for sending form data (for type="submit" buttons)
form_no_validate: Bypasses form validation when submitting (for type="submit" buttons)
form_target: Specifies where to display the response after submitting the form (for type="submit" buttons)
max: Specifies the maximum value for the input
min: Specifies the minimum value for the input
multiple: Indicates whether multiple values can be entered in an input of the type email or file
pattern: Regex pattern the input's value must match to be valid
src: URL for image inputs
step: Specifies the legal number intervals for an input
use_map: Name of the image map used with the input
access_key: Provides a hint for generating a keyboard shortcut for the current element.
auto_capitalize: Controls whether and how text input is automatically capitalized as it is entered/edited by the user.
content_editable: Indicates whether the element's content is editable.

View File

@ -84,10 +84,10 @@ class ChartBase(RechartsCharts):
cls._ensure_valid_dimension("width", width)
cls._ensure_valid_dimension("height", height)
dim_props = dict(
width=width or "100%",
height=height or "100%",
)
dim_props = {
"width": width or "100%",
"height": height or "100%",
}
# Provide min dimensions so the graph always appears, even if the outer container is zero-size.
if width is None:
dim_props["min_width"] = 200

View File

@ -684,6 +684,9 @@ class Config(Base):
# Maximum expiration lock time for redis state manager
redis_lock_expiration: int = constants.Expiration.LOCK
# Maximum lock time before warning for redis state manager.
redis_lock_warning_threshold: int = constants.Expiration.LOCK_WARNING_THRESHOLD
# Token expiration time for redis state manager
redis_token_expiration: int = constants.Expiration.TOKEN
@ -870,7 +873,7 @@ def get_config(reload: bool = False) -> Config:
with _config_lock:
sys_path = sys.path.copy()
sys.path.clear()
sys.path.append(os.getcwd())
sys.path.append(str(Path.cwd()))
try:
# Try to import the module with only the current directory in the path.
return _get_config()

View File

@ -132,6 +132,12 @@ class Hooks(SimpleNamespace):
}
})"""
class HookPosition(enum.Enum):
"""The position of the hook in the component."""
PRE_TRIGGER = "pre_trigger"
POST_TRIGGER = "post_trigger"
class MemoizationDisposition(enum.Enum):
"""The conditions under which a component should be memoized."""

View File

@ -29,6 +29,8 @@ class Expiration(SimpleNamespace):
LOCK = 10000
# The PING timeout
PING = 120
# The maximum time in milliseconds to hold a lock before throwing a warning.
LOCK_WARNING_THRESHOLD = 1000
class GitIgnore(SimpleNamespace):

View File

@ -10,7 +10,7 @@ class CustomComponents(SimpleNamespace):
"""Constants for the custom components."""
# The name of the custom components source directory.
SRC_DIR = "custom_components"
SRC_DIR = Path("custom_components")
# The name of the custom components pyproject.toml file.
PYPROJECT_TOML = Path("pyproject.toml")
# The name of the custom components package README file.

View File

@ -31,7 +31,7 @@ class RouteVar(SimpleNamespace):
# This subset of router_data is included in chained on_load events.
ROUTER_DATA_INCLUDE = set((RouteVar.PATH, RouteVar.ORIGIN, RouteVar.QUERY))
ROUTER_DATA_INCLUDE = {RouteVar.PATH, RouteVar.ORIGIN, RouteVar.QUERY}
class RouteRegex(SimpleNamespace):

View File

@ -150,27 +150,27 @@ def _populate_demo_app(name_variants: NameVariants):
from reflex.compiler import templates
from reflex.reflex import _init
demo_app_dir = name_variants.demo_app_dir
demo_app_dir = Path(name_variants.demo_app_dir)
demo_app_name = name_variants.demo_app_name
console.info(f"Creating app for testing: {demo_app_dir}")
console.info(f"Creating app for testing: {demo_app_dir!s}")
os.makedirs(demo_app_dir)
demo_app_dir.mkdir(exist_ok=True)
with set_directory(demo_app_dir):
# We start with the blank template as basis.
_init(name=demo_app_name, template=constants.Templates.DEFAULT)
# Then overwrite the app source file with the one we want for testing custom components.
# This source file is rendered using jinja template file.
with open(f"{demo_app_name}/{demo_app_name}.py", "w") as f:
f.write(
templates.CUSTOM_COMPONENTS_DEMO_APP.render(
custom_component_module_dir=name_variants.custom_component_module_dir,
module_name=name_variants.module_name,
)
demo_file = Path(f"{demo_app_name}/{demo_app_name}.py")
demo_file.write_text(
templates.CUSTOM_COMPONENTS_DEMO_APP.render(
custom_component_module_dir=name_variants.custom_component_module_dir,
module_name=name_variants.module_name,
)
)
# Append the custom component package to the requirements.txt file.
with open(f"{constants.RequirementsTxt.FILE}", "a") as f:
with Path(f"{constants.RequirementsTxt.FILE}").open(mode="a") as f:
f.write(f"{name_variants.package_name}\n")
@ -296,13 +296,14 @@ def _populate_custom_component_project(name_variants: NameVariants):
)
console.info(
f"Initializing the component directory: {CustomComponents.SRC_DIR}/{name_variants.custom_component_module_dir}"
f"Initializing the component directory: {CustomComponents.SRC_DIR / name_variants.custom_component_module_dir}"
)
os.makedirs(CustomComponents.SRC_DIR)
CustomComponents.SRC_DIR.mkdir(exist_ok=True)
with set_directory(CustomComponents.SRC_DIR):
os.makedirs(name_variants.custom_component_module_dir)
module_dir = Path(name_variants.custom_component_module_dir)
module_dir.mkdir(exist_ok=True, parents=True)
_write_source_and_init_py(
custom_component_src_dir=name_variants.custom_component_module_dir,
custom_component_src_dir=module_dir,
component_class_name=name_variants.component_class_name,
module_name=name_variants.module_name,
)
@ -814,7 +815,7 @@ def _validate_project_info():
)
pyproject_toml["project"] = project
try:
with open(CustomComponents.PYPROJECT_TOML, "w") as f:
with CustomComponents.PYPROJECT_TOML.open("w") as f:
tomlkit.dump(pyproject_toml, f)
except (OSError, TOMLKitError) as ex:
console.error(f"Unable to write to pyproject.toml due to {ex}")
@ -922,16 +923,15 @@ def _validate_url_with_protocol_prefix(url: str | None) -> bool:
def _get_file_from_prompt_in_loop() -> Tuple[bytes, str] | None:
image_file = file_extension = None
while image_file is None:
image_filepath = console.ask(
"Upload a preview image of your demo app (enter to skip)"
image_filepath = Path(
console.ask("Upload a preview image of your demo app (enter to skip)")
)
if not image_filepath:
break
file_extension = image_filepath.split(".")[-1]
file_extension = image_filepath.suffix
try:
with open(image_filepath, "rb") as f:
image_file = f.read()
return image_file, file_extension
image_file = image_filepath.read_bytes()
return image_file, file_extension
except OSError as ose:
console.error(f"Unable to read the {file_extension} file due to {ose}")
raise typer.Exit(code=1) from ose

View File

@ -25,6 +25,7 @@ from typing import (
overload,
)
import typing_extensions
from typing_extensions import (
Concatenate,
ParamSpec,
@ -296,7 +297,7 @@ class EventSpec(EventActionsMixin):
handler: EventHandler,
event_actions: Dict[str, Union[bool, int]] | None = None,
client_handler_name: str = "",
args: Tuple[Tuple[Var, Var], ...] = tuple(),
args: Tuple[Tuple[Var, Var], ...] = (),
):
"""Initialize an EventSpec.
@ -311,7 +312,7 @@ class EventSpec(EventActionsMixin):
object.__setattr__(self, "event_actions", event_actions)
object.__setattr__(self, "handler", handler)
object.__setattr__(self, "client_handler_name", client_handler_name)
object.__setattr__(self, "args", args or tuple())
object.__setattr__(self, "args", args or ())
def with_args(self, args: Tuple[Tuple[Var, Var], ...]) -> EventSpec:
"""Copy the event spec, with updated args.
@ -349,13 +350,14 @@ class EventSpec(EventActionsMixin):
# Construct the payload.
values = []
for arg in args:
try:
values.append(LiteralVar.create(arg))
except TypeError as e:
raise EventHandlerTypeError(
f"Arguments to event handlers must be Vars or JSON-serializable. Got {arg} of type {type(arg)}."
) from e
arg = None
try:
for arg in args:
values.append(LiteralVar.create(value=arg)) # noqa: PERF401
except TypeError as e:
raise EventHandlerTypeError(
f"Arguments to event handlers must be Vars or JSON-serializable. Got {arg} of type {type(arg)}."
) from e
new_payload = tuple(zip(fn_args, values))
return self.with_args(self.args + new_payload)
@ -513,7 +515,7 @@ def no_args_event_spec() -> Tuple[()]:
Returns:
An empty tuple.
"""
return tuple() # type: ignore
return () # type: ignore
# These chains can be used for their side effects when no other events are desired.
@ -714,26 +716,61 @@ def server_side(name: str, sig: inspect.Signature, **kwargs) -> EventSpec:
)
@overload
def redirect(
path: str | Var[str],
external: Optional[bool] = False,
replace: Optional[bool] = False,
is_external: Optional[bool] = None,
replace: bool = False,
) -> EventSpec: ...
@overload
@typing_extensions.deprecated("`external` is deprecated use `is_external` instead")
def redirect(
path: str | Var[str],
is_external: Optional[bool] = None,
replace: bool = False,
external: Optional[bool] = None,
) -> EventSpec: ...
def redirect(
path: str | Var[str],
is_external: Optional[bool] = None,
replace: bool = False,
external: Optional[bool] = None,
) -> EventSpec:
"""Redirect to a new path.
Args:
path: The path to redirect to.
external: Whether to open in new tab or not.
is_external: Whether to open in new tab or not.
replace: If True, the current page will not create a new history entry.
external(Deprecated): Whether to open in new tab or not.
Returns:
An event to redirect to the path.
"""
if external is not None:
console.deprecate(
"The `external` prop in `rx.redirect`",
"use `is_external` instead.",
"0.6.6",
"0.7.0",
)
# is_external should take precedence over external.
is_external = (
(False if external is None else external)
if is_external is None
else is_external
)
return server_side(
"_redirect",
get_fn_signature(redirect),
path=path,
external=external,
external=is_external,
replace=replace,
)
@ -1101,9 +1138,7 @@ def run_script(
Var(javascript_code) if isinstance(javascript_code, str) else javascript_code
)
return call_function(
ArgsFunctionOperation.create(tuple(), javascript_code), callback
)
return call_function(ArgsFunctionOperation.create((), javascript_code), callback)
def get_event(state, event):
@ -1455,7 +1490,7 @@ def get_handler_args(
"""
args = inspect.getfullargspec(event_spec.handler.fn).args
return event_spec.args if len(args) > 1 else tuple()
return event_spec.args if len(args) > 1 else ()
def fix_events(

View File

@ -4,6 +4,7 @@ from __future__ import annotations
import re
from collections import defaultdict
from contextlib import suppress
from typing import Any, ClassVar, Optional, Type, Union
import alembic.autogenerate
@ -52,12 +53,12 @@ def get_engine_args(url: str | None = None) -> dict[str, Any]:
Returns:
The database engine arguments as a dict.
"""
kwargs: dict[str, Any] = dict(
kwargs: dict[str, Any] = {
# Print the SQL queries if the log level is INFO or lower.
echo=environment.SQLALCHEMY_ECHO.get(),
"echo": environment.SQLALCHEMY_ECHO.get(),
# Check connections before returning them.
pool_pre_ping=environment.SQLALCHEMY_POOL_PRE_PING.get(),
)
"pool_pre_ping": environment.SQLALCHEMY_POOL_PRE_PING.get(),
}
conf = get_config()
url = url or conf.db_url
if url is not None and url.startswith("sqlite"):
@ -290,11 +291,10 @@ class Model(Base, sqlmodel.SQLModel): # pyright: ignore [reportGeneralTypeIssue
relationships = {}
# SQLModel relationships do not appear in __fields__, but should be included if present.
for name in self.__sqlmodel_relationships__:
try:
with suppress(
sqlalchemy.orm.exc.DetachedInstanceError # This happens when the relationship was never loaded and the session is closed.
):
relationships[name] = self._dict_recursive(getattr(self, name))
except sqlalchemy.orm.exc.DetachedInstanceError:
# This happens when the relationship was never loaded and the session is closed.
continue
return {
**base_fields,
**relationships,

View File

@ -3,7 +3,6 @@
from __future__ import annotations
import atexit
import os
from pathlib import Path
from typing import List, Optional
@ -298,7 +297,7 @@ def export(
True, "--frontend-only", help="Export only frontend.", show_default=False
),
zip_dest_dir: str = typer.Option(
os.getcwd(),
str(Path.cwd()),
help="The directory to export the zip files to.",
show_default=False,
),
@ -443,13 +442,13 @@ def deploy(
hidden=True,
),
regions: List[str] = typer.Option(
list(),
[],
"-r",
"--region",
help="The regions to deploy to. `reflex cloud regions` For multiple envs, repeat this option, e.g. --region sjc --region iad",
),
envs: List[str] = typer.Option(
list(),
[],
"--env",
help="The environment variables to set: <key>=<value>. For multiple envs, repeat this option, e.g. --env k1=v2 --env k2=v2.",
),

View File

@ -71,6 +71,11 @@ try:
except ModuleNotFoundError:
BaseModelV1 = BaseModelV2
try:
from pydantic.v1 import validator
except ModuleNotFoundError:
from pydantic import validator
import wrapt
from redis.asyncio import Redis
from redis.exceptions import ResponseError
@ -94,6 +99,7 @@ from reflex.utils.exceptions import (
DynamicRouteArgShadowsStateVar,
EventHandlerShadowsBuiltInStateMethod,
ImmutableStateError,
InvalidLockWarningThresholdError,
InvalidStateManagerMode,
LockExpiredError,
ReflexRuntimeError,
@ -431,9 +437,7 @@ class BaseState(Base, ABC, extra=pydantic.Extra.allow):
)
# Create a fresh copy of the backend variables for this instance
self._backend_vars = copy.deepcopy(
{name: item for name, item in self.backend_vars.items()}
)
self._backend_vars = copy.deepcopy(self.backend_vars)
def __repr__(self) -> str:
"""Get the string representation of the state.
@ -517,9 +521,7 @@ class BaseState(Base, ABC, extra=pydantic.Extra.allow):
cls.inherited_backend_vars = parent_state.backend_vars
# Check if another substate class with the same name has already been defined.
if cls.get_name() in set(
c.get_name() for c in parent_state.class_subclasses
):
if cls.get_name() in {c.get_name() for c in parent_state.class_subclasses}:
# This should not happen, since we have added module prefix to state names in #3214
raise StateValueError(
f"The substate class '{cls.get_name()}' has been defined multiple times. "
@ -782,11 +784,11 @@ class BaseState(Base, ABC, extra=pydantic.Extra.allow):
)
# ComputedVar with cache=False always need to be recomputed
cls._always_dirty_computed_vars = set(
cls._always_dirty_computed_vars = {
cvar_name
for cvar_name, cvar in cls.computed_vars.items()
if not cvar._cache
)
}
# Any substate containing a ComputedVar with cache=False always needs to be recomputed
if cls._always_dirty_computed_vars:
@ -1305,9 +1307,6 @@ class BaseState(Base, ABC, extra=pydantic.Extra.allow):
return
if name in self.backend_vars:
# abort if unchanged
if self._backend_vars.get(name) == value:
return
self._backend_vars.__setitem__(name, value)
self.dirty_vars.add(name)
self._mark_dirty()
@ -1856,11 +1855,11 @@ class BaseState(Base, ABC, extra=pydantic.Extra.allow):
Returns:
Set of computed vars to include in the delta.
"""
return set(
return {
cvar
for cvar in self.computed_vars
if self.computed_vars[cvar].needs_update(instance=self)
)
}
def _dirty_computed_vars(
self, from_vars: set[str] | None = None, include_backend: bool = True
@ -1874,12 +1873,12 @@ class BaseState(Base, ABC, extra=pydantic.Extra.allow):
Returns:
Set of computed vars to include in the delta.
"""
return set(
return {
cvar
for dirty_var in from_vars or self.dirty_vars
for cvar in self._computed_var_dependencies[dirty_var]
if include_backend or not self.computed_vars[cvar]._backend
)
}
@classmethod
def _potentially_dirty_substates(cls) -> set[Type[BaseState]]:
@ -1889,16 +1888,16 @@ class BaseState(Base, ABC, extra=pydantic.Extra.allow):
Set of State classes that may need to be fetched to recalc computed vars.
"""
# _always_dirty_substates need to be fetched to recalc computed vars.
fetch_substates = set(
fetch_substates = {
cls.get_class_substate((cls.get_name(), *substate_name.split(".")))
for substate_name in cls._always_dirty_substates
)
}
for dependent_substates in cls._substate_var_dependencies.values():
fetch_substates.update(
set(
{
cls.get_class_substate((cls.get_name(), *substate_name.split(".")))
for substate_name in dependent_substates
)
}
)
return fetch_substates
@ -2200,7 +2199,7 @@ class BaseState(Base, ABC, extra=pydantic.Extra.allow):
return md5(
pickle.dumps(
list(sorted(_field_tuple(field_name) for field_name in cls.base_vars))
sorted(_field_tuple(field_name) for field_name in cls.base_vars)
)
).hexdigest()
@ -2834,6 +2833,7 @@ class StateManager(Base, ABC):
redis=redis,
token_expiration=config.redis_token_expiration,
lock_expiration=config.redis_lock_expiration,
lock_warning_threshold=config.redis_lock_warning_threshold,
)
raise InvalidStateManagerMode(
f"Expected one of: DISK, MEMORY, REDIS, got {config.state_manager_mode}"
@ -3203,6 +3203,15 @@ def _default_lock_expiration() -> int:
return get_config().redis_lock_expiration
def _default_lock_warning_threshold() -> int:
"""Get the default lock warning threshold.
Returns:
The default lock warning threshold.
"""
return get_config().redis_lock_warning_threshold
class StateManagerRedis(StateManager):
"""A state manager that stores states in redis."""
@ -3215,6 +3224,11 @@ class StateManagerRedis(StateManager):
# The maximum time to hold a lock (ms).
lock_expiration: int = pydantic.Field(default_factory=_default_lock_expiration)
# The maximum time to hold a lock (ms) before warning.
lock_warning_threshold: int = pydantic.Field(
default_factory=_default_lock_warning_threshold
)
# The keyspace subscription string when redis is waiting for lock to be released
_redis_notify_keyspace_events: str = (
"K" # Enable keyspace notifications (target a particular key)
@ -3333,7 +3347,7 @@ class StateManagerRedis(StateManager):
state_cls = self.state.get_class_substate(state_path)
else:
raise RuntimeError(
"StateManagerRedis requires token to be specified in the form of {token}_{state_full_name}"
f"StateManagerRedis requires token to be specified in the form of {{token}}_{{state_full_name}}, but got {token}"
)
# The deserialized or newly created (sub)state instance.
@ -3402,6 +3416,17 @@ class StateManagerRedis(StateManager):
f"`app.state_manager.lock_expiration` (currently {self.lock_expiration}) "
"or use `@rx.event(background=True)` decorator for long-running tasks."
)
elif lock_id is not None:
time_taken = self.lock_expiration / 1000 - (
await self.redis.ttl(self._lock_key(token))
)
if time_taken > self.lock_warning_threshold / 1000:
console.warn(
f"Lock for token {token} was held too long {time_taken=}s, "
f"use `@rx.event(background=True)` decorator for long-running tasks.",
dedupe=True,
)
client_token, substate_name = _split_substate_key(token)
# If the substate name on the token doesn't match the instance name, it cannot have a parent.
if state.parent_state is not None and state.get_full_name() != substate_name:
@ -3410,17 +3435,16 @@ class StateManagerRedis(StateManager):
)
# Recursively set_state on all known substates.
tasks = []
for substate in state.substates.values():
tasks.append(
asyncio.create_task(
self.set_state(
token=_substate_key(client_token, substate),
state=substate,
lock_id=lock_id,
)
tasks = [
asyncio.create_task(
self.set_state(
_substate_key(client_token, substate),
substate,
lock_id,
)
)
for substate in state.substates.values()
]
# Persist only the given state (parents or substates are excluded by BaseState.__getstate__).
if state._get_was_touched():
pickle_state = state._serialize()
@ -3451,6 +3475,27 @@ class StateManagerRedis(StateManager):
yield state
await self.set_state(token, state, lock_id)
@validator("lock_warning_threshold")
@classmethod
def validate_lock_warning_threshold(cls, lock_warning_threshold: int, values):
"""Validate the lock warning threshold.
Args:
lock_warning_threshold: The lock warning threshold.
values: The validated attributes.
Returns:
The lock warning threshold.
Raises:
InvalidLockWarningThresholdError: If the lock warning threshold is invalid.
"""
if lock_warning_threshold >= (lock_expiration := values["lock_expiration"]):
raise InvalidLockWarningThresholdError(
f"The lock warning threshold({lock_warning_threshold}) must be less than the lock expiration time({lock_expiration})."
)
return lock_warning_threshold
@staticmethod
def _lock_key(token: str) -> bytes:
"""Get the redis key for a token's lock.
@ -3601,33 +3646,30 @@ class MutableProxy(wrapt.ObjectProxy):
"""A proxy for a mutable object that tracks changes."""
# Methods on wrapped objects which should mark the state as dirty.
__mark_dirty_attrs__ = set(
[
"add",
"append",
"clear",
"difference_update",
"discard",
"extend",
"insert",
"intersection_update",
"pop",
"popitem",
"remove",
"reverse",
"setdefault",
"sort",
"symmetric_difference_update",
"update",
]
)
__mark_dirty_attrs__ = {
"add",
"append",
"clear",
"difference_update",
"discard",
"extend",
"insert",
"intersection_update",
"pop",
"popitem",
"remove",
"reverse",
"setdefault",
"sort",
"symmetric_difference_update",
"update",
}
# Methods on wrapped objects might return mutable objects that should be tracked.
__wrap_mutable_attrs__ = set(
[
"get",
"setdefault",
]
)
__wrap_mutable_attrs__ = {
"get",
"setdefault",
}
# These internal attributes on rx.Base should NOT be wrapped in a MutableProxy.
__never_wrap_base_attrs__ = set(Base.__dict__) - {"set"} | set(
@ -3670,7 +3712,7 @@ class MutableProxy(wrapt.ObjectProxy):
self,
wrapped=None,
instance=None,
args=tuple(),
args=(),
kwargs=None,
) -> Any:
"""Mark the state as dirty, then call a wrapped function.
@ -3926,7 +3968,7 @@ class ImmutableMutableProxy(MutableProxy):
self,
wrapped=None,
instance=None,
args=tuple(),
args=(),
kwargs=None,
) -> Any:
"""Raise an exception when an attempt is made to modify the object.

View File

@ -8,7 +8,6 @@ import dataclasses
import functools
import inspect
import os
import pathlib
import platform
import re
import signal
@ -20,6 +19,7 @@ import threading
import time
import types
from http.server import SimpleHTTPRequestHandler
from pathlib import Path
from typing import (
TYPE_CHECKING,
Any,
@ -101,7 +101,7 @@ class chdir(contextlib.AbstractContextManager):
def __enter__(self):
"""Save current directory and perform chdir."""
self._old_cwd.append(os.getcwd())
self._old_cwd.append(Path.cwd())
os.chdir(self.path)
def __exit__(self, *excinfo):
@ -121,8 +121,8 @@ class AppHarness:
app_source: Optional[
Callable[[], None] | types.ModuleType | str | functools.partial[Any]
]
app_path: pathlib.Path
app_module_path: pathlib.Path
app_path: Path
app_module_path: Path
app_module: Optional[types.ModuleType] = None
app_instance: Optional[reflex.App] = None
frontend_process: Optional[subprocess.Popen] = None
@ -137,7 +137,7 @@ class AppHarness:
@classmethod
def create(
cls,
root: pathlib.Path,
root: Path,
app_source: Optional[
Callable[[], None] | types.ModuleType | str | functools.partial[Any]
] = None,
@ -822,7 +822,7 @@ class AppHarness:
class SimpleHTTPRequestHandlerCustomErrors(SimpleHTTPRequestHandler):
"""SimpleHTTPRequestHandler with custom error page handling."""
def __init__(self, *args, error_page_map: dict[int, pathlib.Path], **kwargs):
def __init__(self, *args, error_page_map: dict[int, Path], **kwargs):
"""Initialize the handler.
Args:
@ -865,8 +865,8 @@ class Subdir404TCPServer(socketserver.TCPServer):
def __init__(
self,
*args,
root: pathlib.Path,
error_page_map: dict[int, pathlib.Path] | None,
root: Path,
error_page_map: dict[int, Path] | None,
**kwargs,
):
"""Initialize the server.

View File

@ -150,7 +150,7 @@ def zip_app(
_zip(
component_name=constants.ComponentName.BACKEND,
target=zip_dest_dir / constants.ComponentName.BACKEND.zip(),
root_dir=Path("."),
root_dir=Path.cwd(),
dirs_to_exclude={"__pycache__"},
files_to_exclude=files_to_exclude,
top_level_dirs_to_exclude={"assets"},

View File

@ -20,6 +20,24 @@ _EMITTED_DEPRECATION_WARNINGS = set()
# Info messages which have been printed.
_EMITTED_INFO = set()
# Warnings which have been printed.
_EMIITED_WARNINGS = set()
# Errors which have been printed.
_EMITTED_ERRORS = set()
# Success messages which have been printed.
_EMITTED_SUCCESS = set()
# Debug messages which have been printed.
_EMITTED_DEBUG = set()
# Logs which have been printed.
_EMITTED_LOGS = set()
# Prints which have been printed.
_EMITTED_PRINTS = set()
def set_log_level(log_level: LogLevel):
"""Set the log level.
@ -55,25 +73,37 @@ def is_debug() -> bool:
return _LOG_LEVEL <= LogLevel.DEBUG
def print(msg: str, **kwargs):
def print(msg: str, dedupe: bool = False, **kwargs):
"""Print a message.
Args:
msg: The message to print.
dedupe: If True, suppress multiple console logs of print message.
kwargs: Keyword arguments to pass to the print function.
"""
if dedupe:
if msg in _EMITTED_PRINTS:
return
else:
_EMITTED_PRINTS.add(msg)
_console.print(msg, **kwargs)
def debug(msg: str, **kwargs):
def debug(msg: str, dedupe: bool = False, **kwargs):
"""Print a debug message.
Args:
msg: The debug message.
dedupe: If True, suppress multiple console logs of debug message.
kwargs: Keyword arguments to pass to the print function.
"""
if is_debug():
msg_ = f"[purple]Debug: {msg}[/purple]"
if dedupe:
if msg_ in _EMITTED_DEBUG:
return
else:
_EMITTED_DEBUG.add(msg_)
if progress := kwargs.pop("progress", None):
progress.console.print(msg_, **kwargs)
else:
@ -97,25 +127,37 @@ def info(msg: str, dedupe: bool = False, **kwargs):
print(f"[cyan]Info: {msg}[/cyan]", **kwargs)
def success(msg: str, **kwargs):
def success(msg: str, dedupe: bool = False, **kwargs):
"""Print a success message.
Args:
msg: The success message.
dedupe: If True, suppress multiple console logs of success message.
kwargs: Keyword arguments to pass to the print function.
"""
if _LOG_LEVEL <= LogLevel.INFO:
if dedupe:
if msg in _EMITTED_SUCCESS:
return
else:
_EMITTED_SUCCESS.add(msg)
print(f"[green]Success: {msg}[/green]", **kwargs)
def log(msg: str, **kwargs):
def log(msg: str, dedupe: bool = False, **kwargs):
"""Takes a string and logs it to the console.
Args:
msg: The message to log.
dedupe: If True, suppress multiple console logs of log message.
kwargs: Keyword arguments to pass to the print function.
"""
if _LOG_LEVEL <= LogLevel.INFO:
if dedupe:
if msg in _EMITTED_LOGS:
return
else:
_EMITTED_LOGS.add(msg)
_console.log(msg, **kwargs)
@ -129,14 +171,20 @@ def rule(title: str, **kwargs):
_console.rule(title, **kwargs)
def warn(msg: str, **kwargs):
def warn(msg: str, dedupe: bool = False, **kwargs):
"""Print a warning message.
Args:
msg: The warning message.
dedupe: If True, suppress multiple console logs of warning message.
kwargs: Keyword arguments to pass to the print function.
"""
if _LOG_LEVEL <= LogLevel.WARNING:
if dedupe:
if msg in _EMIITED_WARNINGS:
return
else:
_EMIITED_WARNINGS.add(msg)
print(f"[orange1]Warning: {msg}[/orange1]", **kwargs)
@ -169,14 +217,20 @@ def deprecate(
_EMITTED_DEPRECATION_WARNINGS.add(feature_name)
def error(msg: str, **kwargs):
def error(msg: str, dedupe: bool = False, **kwargs):
"""Print an error message.
Args:
msg: The error message.
dedupe: If True, suppress multiple console logs of error message.
kwargs: Keyword arguments to pass to the print function.
"""
if _LOG_LEVEL <= LogLevel.ERROR:
if dedupe:
if msg in _EMITTED_ERRORS:
return
else:
_EMITTED_ERRORS.add(msg)
print(f"[red]{msg}[/red]", **kwargs)

View File

@ -183,3 +183,7 @@ def raise_system_package_missing_error(package: str) -> NoReturn:
" Please install it through your system package manager."
+ (f" You can do so by running 'brew install {package}'." if IS_MACOS else "")
)
class InvalidLockWarningThresholdError(ReflexError):
"""Raised when an invalid lock warning threshold is provided."""

View File

@ -24,7 +24,7 @@ from reflex.utils.prerequisites import get_web_dir
frontend_process = None
def detect_package_change(json_file_path: str) -> str:
def detect_package_change(json_file_path: Path) -> str:
"""Calculates the SHA-256 hash of a JSON file and returns it as a hexadecimal string.
Args:
@ -37,7 +37,7 @@ def detect_package_change(json_file_path: str) -> str:
>>> detect_package_change("package.json")
'a1b2c3d4e5f6g7h8i9j0k1l2m3n4o5p6q7r8s9t0u1v2w3x4y5z6a7b8c9d0e1f2'
"""
with open(json_file_path, "r") as file:
with json_file_path.open("r") as file:
json_data = json.load(file)
# Calculate the hash
@ -81,7 +81,7 @@ def run_process_and_launch_url(run_command: list[str], backend_present=True):
from reflex.utils import processes
json_file_path = get_web_dir() / constants.PackageJson.PATH
last_hash = detect_package_change(str(json_file_path))
last_hash = detect_package_change(json_file_path)
process = None
first_run = True
@ -117,14 +117,14 @@ def run_process_and_launch_url(run_command: list[str], backend_present=True):
console.print("New packages detected: Updating app...")
else:
if any(
[x in line for x in ("bin executable does not exist on disk",)]
x in line for x in ("bin executable does not exist on disk",)
):
console.error(
"Try setting `REFLEX_USE_NPM=1` and re-running `reflex init` and `reflex run` to use npm instead of bun:\n"
"`REFLEX_USE_NPM=1 reflex init`\n"
"`REFLEX_USE_NPM=1 reflex run`"
)
new_hash = detect_package_change(str(json_file_path))
new_hash = detect_package_change(json_file_path)
if new_hash != last_hash:
last_hash = new_hash
kill(process.pid)

View File

@ -1,6 +1,5 @@
"""Export utilities."""
import os
from pathlib import Path
from typing import Optional
@ -15,7 +14,7 @@ def export(
zipping: bool = True,
frontend: bool = True,
backend: bool = True,
zip_dest_dir: str = os.getcwd(),
zip_dest_dir: str = str(Path.cwd()),
upload_db_file: bool = False,
api_url: Optional[str] = None,
deploy_url: Optional[str] = None,

View File

@ -205,14 +205,14 @@ def update_json_file(file_path: str | Path, update_dict: dict[str, int | str]):
# Read the existing json object from the file.
json_object = {}
if fp.stat().st_size:
with open(fp) as f:
with fp.open() as f:
json_object = json.load(f)
# Update the json object with the new data.
json_object.update(update_dict)
# Write the updated json object to the file
with open(fp, "w") as f:
with fp.open("w") as f:
json.dump(json_object, f, ensure_ascii=False)

View File

@ -290,7 +290,7 @@ def get_app(reload: bool = False) -> ModuleType:
"If this error occurs in a reflex test case, ensure that `get_app` is mocked."
)
module = config.module
sys.path.insert(0, os.getcwd())
sys.path.insert(0, str(Path.cwd()))
app = __import__(module, fromlist=(constants.CompileVars.APP,))
if reload:
@ -438,9 +438,11 @@ def create_config(app_name: str):
from reflex.compiler import templates
config_name = f"{re.sub(r'[^a-zA-Z]', '', app_name).capitalize()}Config"
with open(constants.Config.FILE, "w") as f:
console.debug(f"Creating {constants.Config.FILE}")
f.write(templates.RXCONFIG.render(app_name=app_name, config_name=config_name))
console.debug(f"Creating {constants.Config.FILE}")
constants.Config.FILE.write_text(
templates.RXCONFIG.render(app_name=app_name, config_name=config_name)
)
def initialize_gitignore(
@ -494,14 +496,14 @@ def initialize_requirements_txt():
console.debug(f"Detected encoding for {fp} as {encoding}.")
try:
other_requirements_exist = False
with open(fp, "r", encoding=encoding) as f:
with fp.open("r", encoding=encoding) as f:
for req in f:
# Check if we have a package name that is reflex
if re.match(r"^reflex[^a-zA-Z0-9]", req):
console.debug(f"{fp} already has reflex as dependency.")
return
other_requirements_exist = True
with open(fp, "a", encoding=encoding) as f:
with fp.open("a", encoding=encoding) as f:
preceding_newline = "\n" if other_requirements_exist else ""
f.write(
f"{preceding_newline}{constants.RequirementsTxt.DEFAULTS_STUB}{constants.Reflex.VERSION}\n"
@ -699,7 +701,7 @@ def _update_next_config(
}
if transpile_packages:
next_config["transpilePackages"] = list(
set((format_library_name(p) for p in transpile_packages))
{format_library_name(p) for p in transpile_packages}
)
if export:
next_config["output"] = "export"
@ -732,13 +734,13 @@ def download_and_run(url: str, *args, show_status: bool = False, **env):
response.raise_for_status()
# Save the script to a temporary file.
script = tempfile.NamedTemporaryFile()
with open(script.name, "w") as f:
f.write(response.text)
script = Path(tempfile.NamedTemporaryFile().name)
script.write_text(response.text)
# Run the script.
env = {**os.environ, **env}
process = processes.new_process(["bash", f.name, *args], env=env)
process = processes.new_process(["bash", str(script), *args], env=env)
show = processes.show_status if show_status else processes.show_logs
show(f"Installing {url}", process)
@ -752,14 +754,14 @@ def download_and_extract_fnm_zip():
# Download the zip file
url = constants.Fnm.INSTALL_URL
console.debug(f"Downloading {url}")
fnm_zip_file = constants.Fnm.DIR / f"{constants.Fnm.FILENAME}.zip"
fnm_zip_file: Path = constants.Fnm.DIR / f"{constants.Fnm.FILENAME}.zip"
# Function to download and extract the FNM zip release.
try:
# Download the FNM zip release.
# TODO: show progress to improve UX
response = net.get(url, follow_redirects=True)
response.raise_for_status()
with open(fnm_zip_file, "wb") as output_file:
with fnm_zip_file.open("wb") as output_file:
for chunk in response.iter_bytes():
output_file.write(chunk)
@ -807,7 +809,7 @@ def install_node():
)
else: # All other platforms (Linux, MacOS).
# Add execute permissions to fnm executable.
os.chmod(constants.Fnm.EXE, stat.S_IXUSR)
constants.Fnm.EXE.chmod(stat.S_IXUSR)
# Install node.
# Specify arm64 arch explicitly for M1s and M2s.
architecture_arg = (
@ -925,7 +927,7 @@ def cached_procedure(cache_file: str, payload_fn: Callable[..., str]):
@cached_procedure(
cache_file=str(get_web_dir() / "reflex.install_frontend_packages.cached"),
payload_fn=lambda p, c: f"{sorted(list(p))!r},{c.json()}",
payload_fn=lambda p, c: f"{sorted(p)!r},{c.json()}",
)
def install_frontend_packages(packages: set[str], config: Config):
"""Installs the base and custom frontend packages.
@ -1300,7 +1302,7 @@ def fetch_app_templates(version: str) -> dict[str, Template]:
for tp in templates_data:
if tp["hidden"] or tp["code_url"] is None:
continue
known_fields = set(f.name for f in dataclasses.fields(Template))
known_fields = {f.name for f in dataclasses.fields(Template)}
filtered_templates[tp["name"]] = Template(
**{k: v for k, v in tp.items() if k in known_fields}
)
@ -1326,7 +1328,7 @@ def create_config_init_app_from_remote_template(app_name: str, template_url: str
raise typer.Exit(1) from ose
# Use httpx GET with redirects to download the zip file.
zip_file_path = Path(temp_dir) / "template.zip"
zip_file_path: Path = Path(temp_dir) / "template.zip"
try:
# Note: following redirects can be risky. We only allow this for reflex built templates at the moment.
response = net.get(template_url, follow_redirects=True)
@ -1336,9 +1338,8 @@ def create_config_init_app_from_remote_template(app_name: str, template_url: str
console.error(f"Failed to download the template: {he}")
raise typer.Exit(1) from he
try:
with open(zip_file_path, "wb") as f:
f.write(response.content)
console.debug(f"Downloaded the zip to {zip_file_path}")
zip_file_path.write_bytes(response.content)
console.debug(f"Downloaded the zip to {zip_file_path}")
except OSError as ose:
console.error(f"Unable to write the downloaded zip to disk {ose}")
raise typer.Exit(1) from ose

View File

@ -58,7 +58,9 @@ def get_process_on_port(port) -> Optional[psutil.Process]:
The process on the given port.
"""
for proc in psutil.process_iter(["pid", "name", "cmdline"]):
try:
with contextlib.suppress(
psutil.NoSuchProcess, psutil.AccessDenied, psutil.ZombieProcess
):
if importlib.metadata.version("psutil") >= "6.0.0":
conns = proc.net_connections(kind="inet") # type: ignore
else:
@ -66,8 +68,6 @@ def get_process_on_port(port) -> Optional[psutil.Process]:
for conn in conns:
if conn.laddr.port == int(port):
return proc
except (psutil.NoSuchProcess, psutil.AccessDenied, psutil.ZombieProcess):
pass
return None

View File

@ -287,10 +287,9 @@ def _generate_docstrings(clzs: list[Type[Component]], props: list[str]) -> str:
for line in (clz.create.__doc__ or "").splitlines():
if "**" in line:
indent = line.split("**")[0]
for nline in [
f"{indent}{n}:{' '.join(c)}" for n, c in props_comments.items()
]:
new_docstring.append(nline)
new_docstring.extend(
[f"{indent}{n}:{' '.join(c)}" for n, c in props_comments.items()]
)
new_docstring.append(line)
return "\n".join(new_docstring)

View File

@ -9,6 +9,7 @@ from .base import get_unique_variable_name as get_unique_variable_name
from .base import get_uuid_string_var as get_uuid_string_var
from .base import var_operation as var_operation
from .base import var_operation_return as var_operation_return
from .datetime import DateTimeVar as DateTimeVar
from .function import FunctionStringVar as FunctionStringVar
from .function import FunctionVar as FunctionVar
from .function import VarOperationCall as VarOperationCall

View File

@ -42,7 +42,8 @@ from typing_extensions import ParamSpec, TypeGuard, deprecated, get_type_hints,
from reflex import constants
from reflex.base import Base
from reflex.utils import console, imports, serializers, types
from reflex.constants.compiler import Hooks
from reflex.utils import console, exceptions, imports, serializers, types
from reflex.utils.exceptions import (
VarAttributeError,
VarDependencyError,
@ -115,12 +116,20 @@ class VarData:
# Hooks that need to be present in the component to render this var
hooks: Tuple[str, ...] = dataclasses.field(default_factory=tuple)
# Dependencies of the var
deps: Tuple[Var, ...] = dataclasses.field(default_factory=tuple)
# Position of the hook in the component
position: Hooks.HookPosition | None = None
def __init__(
self,
state: str = "",
field_name: str = "",
imports: ImportDict | ParsedImportDict | None = None,
hooks: dict[str, None] | None = None,
deps: list[Var] | None = None,
position: Hooks.HookPosition | None = None,
):
"""Initialize the var data.
@ -129,6 +138,8 @@ class VarData:
field_name: The name of the field in the state.
imports: Imports needed to render this var.
hooks: Hooks that need to be present in the component to render this var.
deps: Dependencies of the var for useCallback.
position: Position of the hook in the component.
"""
immutable_imports: ImmutableParsedImportDict = tuple(
sorted(
@ -139,6 +150,8 @@ class VarData:
object.__setattr__(self, "field_name", field_name)
object.__setattr__(self, "imports", immutable_imports)
object.__setattr__(self, "hooks", tuple(hooks or {}))
object.__setattr__(self, "deps", tuple(deps or []))
object.__setattr__(self, "position", position or None)
def old_school_imports(self) -> ImportDict:
"""Return the imports as a mutable dict.
@ -146,7 +159,7 @@ class VarData:
Returns:
The imports as a mutable dict.
"""
return dict((k, list(v)) for k, v in self.imports)
return {k: list(v) for k, v in self.imports}
def merge(*all: VarData | None) -> VarData | None:
"""Merge multiple var data objects.
@ -154,6 +167,9 @@ class VarData:
Args:
*all: The var data objects to merge.
Raises:
ReflexError: If trying to merge VarData with different positions.
Returns:
The merged var data object.
@ -184,12 +200,32 @@ class VarData:
*(var_data.imports for var_data in all_var_datas)
)
if state or _imports or hooks or field_name:
deps = [dep for var_data in all_var_datas for dep in var_data.deps]
positions = list(
{
var_data.position
for var_data in all_var_datas
if var_data.position is not None
}
)
if positions:
if len(positions) > 1:
raise exceptions.ReflexError(
f"Cannot merge var data with different positions: {positions}"
)
position = positions[0]
else:
position = None
if state or _imports or hooks or field_name or deps or position:
return VarData(
state=state,
field_name=field_name,
imports=_imports,
hooks=hooks,
deps=deps,
position=position,
)
return None
@ -200,7 +236,14 @@ class VarData:
Returns:
True if any field is set to a non-default value.
"""
return bool(self.state or self.imports or self.hooks or self.field_name)
return bool(
self.state
or self.imports
or self.hooks
or self.field_name
or self.deps
or self.position
)
@classmethod
def from_state(cls, state: Type[BaseState] | str, field_name: str = "") -> VarData:
@ -480,7 +523,6 @@ class Var(Generic[VAR_TYPE]):
raise TypeError(
"The _var_full_name_needs_state_prefix argument is not supported for Var."
)
value_with_replaced = dataclasses.replace(
self,
_var_type=_var_type or self._var_type,
@ -1591,14 +1633,12 @@ class CachedVarOperation:
The cached VarData.
"""
return VarData.merge(
*map(
lambda value: (
value._get_all_var_data() if isinstance(value, Var) else None
),
map(
lambda field: getattr(self, field.name),
dataclasses.fields(self), # type: ignore
),
*(
value._get_all_var_data() if isinstance(value, Var) else None
for value in (
getattr(self, field.name)
for field in dataclasses.fields(self) # type: ignore
)
),
self._var_data,
)
@ -1889,20 +1929,20 @@ class ComputedVar(Var[RETURN_TYPE]):
Raises:
TypeError: If kwargs contains keys that are not allowed.
"""
field_values = dict(
fget=kwargs.pop("fget", self._fget),
initial_value=kwargs.pop("initial_value", self._initial_value),
cache=kwargs.pop("cache", self._cache),
deps=kwargs.pop("deps", self._static_deps),
auto_deps=kwargs.pop("auto_deps", self._auto_deps),
interval=kwargs.pop("interval", self._update_interval),
backend=kwargs.pop("backend", self._backend),
_js_expr=kwargs.pop("_js_expr", self._js_expr),
_var_type=kwargs.pop("_var_type", self._var_type),
_var_data=kwargs.pop(
field_values = {
"fget": kwargs.pop("fget", self._fget),
"initial_value": kwargs.pop("initial_value", self._initial_value),
"cache": kwargs.pop("cache", self._cache),
"deps": kwargs.pop("deps", self._static_deps),
"auto_deps": kwargs.pop("auto_deps", self._auto_deps),
"interval": kwargs.pop("interval", self._update_interval),
"backend": kwargs.pop("backend", self._backend),
"_js_expr": kwargs.pop("_js_expr", self._js_expr),
"_var_type": kwargs.pop("_var_type", self._var_type),
"_var_data": kwargs.pop(
"_var_data", VarData.merge(self._var_data, merge_var_data)
),
)
}
if kwargs:
unexpected_kwargs = ", ".join(kwargs.keys())
@ -2371,10 +2411,7 @@ class CustomVarOperation(CachedVarOperation, Var[T]):
The cached VarData.
"""
return VarData.merge(
*map(
lambda arg: arg[1]._get_all_var_data(),
self._args,
),
*(arg[1]._get_all_var_data() for arg in self._args),
self._return._get_all_var_data(),
self._var_data,
)

222
reflex/vars/datetime.py Normal file
View File

@ -0,0 +1,222 @@
"""Immutable datetime and date vars."""
from __future__ import annotations
import dataclasses
import sys
from datetime import date, datetime
from typing import Any, NoReturn, TypeVar, Union, overload
from reflex.utils.exceptions import VarTypeError
from reflex.vars.number import BooleanVar
from .base import (
CustomVarOperationReturn,
LiteralVar,
Var,
VarData,
var_operation,
var_operation_return,
)
DATETIME_T = TypeVar("DATETIME_T", datetime, date)
datetime_types = Union[datetime, date]
def raise_var_type_error():
"""Raise a VarTypeError.
Raises:
VarTypeError: Cannot compare a datetime object with a non-datetime object.
"""
raise VarTypeError("Cannot compare a datetime object with a non-datetime object.")
class DateTimeVar(Var[DATETIME_T], python_types=(datetime, date)):
"""A variable that holds a datetime or date object."""
@overload
def __lt__(self, other: datetime_types) -> BooleanVar: ...
@overload
def __lt__(self, other: NoReturn) -> NoReturn: ...
def __lt__(self, other: Any):
"""Less than comparison.
Args:
other: The other datetime to compare.
Returns:
The result of the comparison.
"""
if not isinstance(other, DATETIME_TYPES):
raise_var_type_error()
return date_lt_operation(self, other)
@overload
def __le__(self, other: datetime_types) -> BooleanVar: ...
@overload
def __le__(self, other: NoReturn) -> NoReturn: ...
def __le__(self, other: Any):
"""Less than or equal comparison.
Args:
other: The other datetime to compare.
Returns:
The result of the comparison.
"""
if not isinstance(other, DATETIME_TYPES):
raise_var_type_error()
return date_le_operation(self, other)
@overload
def __gt__(self, other: datetime_types) -> BooleanVar: ...
@overload
def __gt__(self, other: NoReturn) -> NoReturn: ...
def __gt__(self, other: Any):
"""Greater than comparison.
Args:
other: The other datetime to compare.
Returns:
The result of the comparison.
"""
if not isinstance(other, DATETIME_TYPES):
raise_var_type_error()
return date_gt_operation(self, other)
@overload
def __ge__(self, other: datetime_types) -> BooleanVar: ...
@overload
def __ge__(self, other: NoReturn) -> NoReturn: ...
def __ge__(self, other: Any):
"""Greater than or equal comparison.
Args:
other: The other datetime to compare.
Returns:
The result of the comparison.
"""
if not isinstance(other, DATETIME_TYPES):
raise_var_type_error()
return date_ge_operation(self, other)
@var_operation
def date_gt_operation(lhs: Var | Any, rhs: Var | Any) -> CustomVarOperationReturn:
"""Greater than comparison.
Args:
lhs: The left-hand side of the operation.
rhs: The right-hand side of the operation.
Returns:
The result of the operation.
"""
return date_compare_operation(rhs, lhs, strict=True)
@var_operation
def date_lt_operation(lhs: Var | Any, rhs: Var | Any) -> CustomVarOperationReturn:
"""Less than comparison.
Args:
lhs: The left-hand side of the operation.
rhs: The right-hand side of the operation.
Returns:
The result of the operation.
"""
return date_compare_operation(lhs, rhs, strict=True)
@var_operation
def date_le_operation(lhs: Var | Any, rhs: Var | Any) -> CustomVarOperationReturn:
"""Less than or equal comparison.
Args:
lhs: The left-hand side of the operation.
rhs: The right-hand side of the operation.
Returns:
The result of the operation.
"""
return date_compare_operation(lhs, rhs)
@var_operation
def date_ge_operation(lhs: Var | Any, rhs: Var | Any) -> CustomVarOperationReturn:
"""Greater than or equal comparison.
Args:
lhs: The left-hand side of the operation.
rhs: The right-hand side of the operation.
Returns:
The result of the operation.
"""
return date_compare_operation(rhs, lhs)
def date_compare_operation(
lhs: DateTimeVar[DATETIME_T] | Any,
rhs: DateTimeVar[DATETIME_T] | Any,
strict: bool = False,
) -> CustomVarOperationReturn:
"""Check if the value is less than the other value.
Args:
lhs: The left-hand side of the operation.
rhs: The right-hand side of the operation.
strict: Whether to use strict comparison.
Returns:
The result of the operation.
"""
return var_operation_return(
f"({lhs} { '<' if strict else '<='} {rhs})",
bool,
)
@dataclasses.dataclass(
eq=False,
frozen=True,
**{"slots": True} if sys.version_info >= (3, 10) else {},
)
class LiteralDatetimeVar(LiteralVar, DateTimeVar):
"""Base class for immutable datetime and date vars."""
_var_value: datetime | date = dataclasses.field(default=datetime.now())
@classmethod
def create(cls, value: datetime | date, _var_data: VarData | None = None):
"""Create a new instance of the class.
Args:
value: The value to set.
Returns:
LiteralDatetimeVar: The new instance of the class.
"""
js_expr = f'"{value!s}"'
return cls(
_js_expr=js_expr,
_var_type=type(value),
_var_value=value,
_var_data=_var_data,
)
DATETIME_TYPES = (datetime, date, DateTimeVar)

View File

@ -292,7 +292,7 @@ class VarOperationCall(Generic[P, R], CachedVarOperation, Var[R]):
class DestructuredArg:
"""Class for destructured arguments."""
fields: Tuple[str, ...] = tuple()
fields: Tuple[str, ...] = ()
rest: Optional[str] = None
def to_javascript(self) -> str:
@ -314,7 +314,7 @@ class DestructuredArg:
class FunctionArgs:
"""Class for function arguments."""
args: Tuple[Union[str, DestructuredArg], ...] = tuple()
args: Tuple[Union[str, DestructuredArg], ...] = ()
rest: Optional[str] = None

View File

@ -51,7 +51,7 @@ def raise_unsupported_operand_types(
VarTypeError: The operand types are unsupported.
"""
raise VarTypeError(
f"Unsupported Operand type(s) for {operator}: {', '.join(map(lambda t: t.__name__, operands_types))}"
f"Unsupported Operand type(s) for {operator}: {', '.join(t.__name__ for t in operands_types)}"
)

View File

@ -1177,7 +1177,7 @@ class ArrayVar(Var[ARRAY_VAR_TYPE], python_types=(list, tuple, set)):
if num_args == 0:
return_value = fn()
function_var = ArgsFunctionOperation.create(tuple(), return_value)
function_var = ArgsFunctionOperation.create((), return_value)
else:
# generic number var
number_var = Var("").to(NumberVar, int)

View File

@ -51,11 +51,10 @@ def main():
parser.add_argument("--server-pid", type=int)
args = parser.parse_args()
executor = ThreadPoolExecutor(max_workers=len(args.port))
futures = []
for p in args.port:
futures.append(
executor.submit(_wait_for_port, p, args.server_pid, args.timeout)
)
futures = [
executor.submit(_wait_for_port, p, args.server_pid, args.timeout)
for p in args.port
]
for f in as_completed(futures):
ok, msg = f.result()
if ok:

View File

@ -6,6 +6,7 @@ from pathlib import Path
import pytest
import reflex.app
from reflex.config import environment
from reflex.testing import AppHarness, AppHarnessProd
@ -76,3 +77,25 @@ def app_harness_env(request):
The AppHarness class to use for the test.
"""
return request.param
@pytest.fixture(autouse=True)
def raise_console_error(request, mocker):
"""Spy on calls to `console.error` used by the framework.
Help catch spurious error conditions that might otherwise go unnoticed.
If a test is marked with `ignore_console_error`, the spy will be ignored
after the test.
Args:
request: The pytest request object.
mocker: The pytest mocker object.
Yields:
control to the test function.
"""
spy = mocker.spy(reflex.app.console, "error")
yield
if "ignore_console_error" not in request.keywords:
spy.assert_not_called()

View File

@ -15,6 +15,7 @@ from .utils import SessionStorage
def CallScript():
"""A test app for browser javascript integration."""
from pathlib import Path
from typing import Dict, List, Optional, Union
import reflex as rx
@ -186,8 +187,7 @@ def CallScript():
self.reset()
app = rx.App(state=rx.State)
with open("assets/external.js", "w") as f:
f.write(external_scripts)
Path("assets/external.js").write_text(external_scripts)
@app.add_page
def index():

View File

@ -637,8 +637,7 @@ async def test_client_side_state(
assert await AppHarness._poll_for_async(poll_for_not_hydrated)
# Trigger event to get a new instance of the state since the old was expired.
state_var_input = driver.find_element(By.ID, "state_var")
state_var_input.send_keys("re-triggering")
set_sub("c1", "c1 post expire")
# get new references to all cookie and local storage elements (again)
c1 = driver.find_element(By.ID, "c1")
@ -659,7 +658,7 @@ async def test_client_side_state(
l1s = driver.find_element(By.ID, "l1s")
s1s = driver.find_element(By.ID, "s1s")
assert c1.text == "c1 value"
assert c1.text == "c1 post expire"
assert c2.text == "c2 value"
assert c3.text == "" # temporary cookie expired after reset state!
assert c4.text == "c4 value"
@ -690,11 +689,11 @@ async def test_client_side_state(
async def poll_for_c1_set():
sub_state = await get_sub_state()
return sub_state.c1 == "c1 value"
return sub_state.c1 == "c1 post expire"
assert await AppHarness._poll_for_async(poll_for_c1_set)
sub_state = await get_sub_state()
assert sub_state.c1 == "c1 value"
assert sub_state.c1 == "c1 post expire"
assert sub_state.c2 == "c2 value"
assert sub_state.c3 == ""
assert sub_state.c4 == "c4 value"

View File

@ -13,6 +13,8 @@ from selenium.webdriver.support.ui import WebDriverWait
from reflex.testing import AppHarness, AppHarnessProd
pytestmark = [pytest.mark.ignore_console_error]
def TestApp():
"""A test app for event exception handler integration."""

View File

@ -0,0 +1,87 @@
from typing import Generator
import pytest
from playwright.sync_api import Page, expect
from reflex.testing import AppHarness
def DatetimeOperationsApp():
from datetime import datetime
import reflex as rx
class DtOperationsState(rx.State):
date1: datetime = datetime(2021, 1, 1)
date2: datetime = datetime(2031, 1, 1)
date3: datetime = datetime(2021, 1, 1)
app = rx.App(state=DtOperationsState)
@app.add_page
def index():
return rx.vstack(
rx.text(DtOperationsState.date1, id="date1"),
rx.text(DtOperationsState.date2, id="date2"),
rx.text(DtOperationsState.date3, id="date3"),
rx.text("Operations between date1 and date2"),
rx.text(DtOperationsState.date1 == DtOperationsState.date2, id="1_eq_2"),
rx.text(DtOperationsState.date1 != DtOperationsState.date2, id="1_neq_2"),
rx.text(DtOperationsState.date1 < DtOperationsState.date2, id="1_lt_2"),
rx.text(DtOperationsState.date1 <= DtOperationsState.date2, id="1_le_2"),
rx.text(DtOperationsState.date1 > DtOperationsState.date2, id="1_gt_2"),
rx.text(DtOperationsState.date1 >= DtOperationsState.date2, id="1_ge_2"),
rx.text("Operations between date1 and date3"),
rx.text(DtOperationsState.date1 == DtOperationsState.date3, id="1_eq_3"),
rx.text(DtOperationsState.date1 != DtOperationsState.date3, id="1_neq_3"),
rx.text(DtOperationsState.date1 < DtOperationsState.date3, id="1_lt_3"),
rx.text(DtOperationsState.date1 <= DtOperationsState.date3, id="1_le_3"),
rx.text(DtOperationsState.date1 > DtOperationsState.date3, id="1_gt_3"),
rx.text(DtOperationsState.date1 >= DtOperationsState.date3, id="1_ge_3"),
)
@pytest.fixture()
def datetime_operations_app(tmp_path_factory) -> Generator[AppHarness, None, None]:
"""Start Table app at tmp_path via AppHarness.
Args:
tmp_path_factory: pytest tmp_path_factory fixture
Yields:
running AppHarness instance
"""
with AppHarness.create(
root=tmp_path_factory.mktemp("datetime_operations_app"),
app_source=DatetimeOperationsApp, # type: ignore
) as harness:
assert harness.app_instance is not None, "app is not running"
yield harness
def test_datetime_operations(datetime_operations_app: AppHarness, page: Page):
assert datetime_operations_app.frontend_url is not None
page.goto(datetime_operations_app.frontend_url)
expect(page).to_have_url(datetime_operations_app.frontend_url + "/")
# Check the actual values
expect(page.locator("id=date1")).to_have_text("2021-01-01 00:00:00")
expect(page.locator("id=date2")).to_have_text("2031-01-01 00:00:00")
expect(page.locator("id=date3")).to_have_text("2021-01-01 00:00:00")
# Check the operations between date1 and date2
expect(page.locator("id=1_eq_2")).to_have_text("false")
expect(page.locator("id=1_neq_2")).to_have_text("true")
expect(page.locator("id=1_lt_2")).to_have_text("true")
expect(page.locator("id=1_le_2")).to_have_text("true")
expect(page.locator("id=1_gt_2")).to_have_text("false")
expect(page.locator("id=1_ge_2")).to_have_text("false")
# Check the operations between date1 and date3
expect(page.locator("id=1_eq_3")).to_have_text("true")
expect(page.locator("id=1_neq_3")).to_have_text("false")
expect(page.locator("id=1_lt_3")).to_have_text("false")
expect(page.locator("id=1_le_3")).to_have_text("true")
expect(page.locator("id=1_gt_3")).to_have_text("false")
expect(page.locator("id=1_ge_3")).to_have_text("true")

View File

@ -12,7 +12,7 @@ def test_websocket_target_url():
url = WebsocketTargetURL.create()
var_data = url._get_all_var_data()
assert var_data is not None
assert sorted(tuple((key for key, _ in var_data.imports))) == sorted(
assert sorted(key for key, _ in var_data.imports) == sorted(
("$/utils/state", "$/env.json")
)
@ -20,7 +20,7 @@ def test_websocket_target_url():
def test_connection_banner():
banner = ConnectionBanner.create()
_imports = banner._get_all_imports(collapse=True)
assert sorted(tuple(_imports)) == sorted(
assert sorted(_imports) == sorted(
(
"react",
"$/utils/context",
@ -38,7 +38,7 @@ def test_connection_banner():
def test_connection_modal():
modal = ConnectionModal.create()
_imports = modal._get_all_imports(collapse=True)
assert sorted(tuple(_imports)) == sorted(
assert sorted(_imports) == sorted(
(
"react",
"$/utils/context",

View File

@ -61,14 +61,13 @@ class FileUploadState(State):
"""
for file in files:
upload_data = await file.read()
outfile = f"{self._tmp_path}/{file.filename}"
assert file.filename is not None
outfile = self._tmp_path / file.filename
# Save the file.
with open(outfile, "wb") as file_object:
file_object.write(upload_data)
outfile.write_bytes(upload_data)
# Update the img var.
assert file.filename is not None
self.img_list.append(file.filename)
@rx.event(background=True)
@ -109,14 +108,13 @@ class ChildFileUploadState(FileStateBase1):
"""
for file in files:
upload_data = await file.read()
outfile = f"{self._tmp_path}/{file.filename}"
assert file.filename is not None
outfile = self._tmp_path / file.filename
# Save the file.
with open(outfile, "wb") as file_object:
file_object.write(upload_data)
outfile.write_bytes(upload_data)
# Update the img var.
assert file.filename is not None
self.img_list.append(file.filename)
@rx.event(background=True)
@ -157,14 +155,13 @@ class GrandChildFileUploadState(FileStateBase2):
"""
for file in files:
upload_data = await file.read()
outfile = f"{self._tmp_path}/{file.filename}"
assert file.filename is not None
outfile = self._tmp_path / file.filename
# Save the file.
with open(outfile, "wb") as file_object:
file_object.write(upload_data)
outfile.write_bytes(upload_data)
# Update the img var.
assert file.filename is not None
self.img_list.append(file.filename)
@rx.event(background=True)

View File

@ -105,8 +105,8 @@ def test_initialize_requirements_txt_no_op(mocker):
return_value=Mock(best=lambda: Mock(encoding="utf-8")),
)
mock_fp_touch = mocker.patch("pathlib.Path.touch")
open_mock = mock_open(read_data="reflex==0.2.9")
mocker.patch("builtins.open", open_mock)
open_mock = mock_open(read_data="reflex==0.6.7")
mocker.patch("pathlib.Path.open", open_mock)
initialize_requirements_txt()
assert open_mock.call_count == 1
assert open_mock.call_args.kwargs["encoding"] == "utf-8"
@ -122,7 +122,7 @@ def test_initialize_requirements_txt_missing_reflex(mocker):
return_value=Mock(best=lambda: Mock(encoding="utf-8")),
)
open_mock = mock_open(read_data="random-package=1.2.3")
mocker.patch("builtins.open", open_mock)
mocker.patch("pathlib.Path.open", open_mock)
initialize_requirements_txt()
# Currently open for read, then open for append
assert open_mock.call_count == 2
@ -138,7 +138,7 @@ def test_initialize_requirements_txt_not_exist(mocker):
# File does not exist, create file with reflex
mocker.patch("pathlib.Path.exists", return_value=False)
open_mock = mock_open()
mocker.patch("builtins.open", open_mock)
mocker.patch("pathlib.Path.open", open_mock)
initialize_requirements_txt()
assert open_mock.call_count == 2
# By default, use utf-8 encoding
@ -170,7 +170,7 @@ def test_requirements_txt_other_encoding(mocker):
)
initialize_requirements_txt()
open_mock = mock_open(read_data="random-package=1.2.3")
mocker.patch("builtins.open", open_mock)
mocker.patch("pathlib.Path.open", open_mock)
initialize_requirements_txt()
# Currently open for read, then open for append
assert open_mock.call_count == 2

View File

@ -56,6 +56,7 @@ from reflex.state import (
from reflex.testing import chdir
from reflex.utils import format, prerequisites, types
from reflex.utils.exceptions import (
InvalidLockWarningThresholdError,
ReflexRuntimeError,
SetUndefinedStateVarError,
StateSerializationError,
@ -67,7 +68,9 @@ from tests.units.states.mutation import MutableSQLAModel, MutableTestState
from .states import GenState
CI = bool(os.environ.get("CI", False))
LOCK_EXPIRATION = 2000 if CI else 300
LOCK_EXPIRATION = 2500 if CI else 300
LOCK_WARNING_THRESHOLD = 1000 if CI else 100
LOCK_WARN_SLEEP = 1.5 if CI else 0.15
LOCK_EXPIRE_SLEEP = 2.5 if CI else 0.4
@ -1787,6 +1790,7 @@ async def test_state_manager_lock_expire(
substate_token_redis: A token + substate name for looking up in state manager.
"""
state_manager_redis.lock_expiration = LOCK_EXPIRATION
state_manager_redis.lock_warning_threshold = LOCK_WARNING_THRESHOLD
async with state_manager_redis.modify_state(substate_token_redis):
await asyncio.sleep(0.01)
@ -1811,6 +1815,7 @@ async def test_state_manager_lock_expire_contend(
unexp_num1 = 666
state_manager_redis.lock_expiration = LOCK_EXPIRATION
state_manager_redis.lock_warning_threshold = LOCK_WARNING_THRESHOLD
order = []
@ -1840,6 +1845,39 @@ async def test_state_manager_lock_expire_contend(
assert (await state_manager_redis.get_state(substate_token_redis)).num1 == exp_num1
@pytest.mark.asyncio
async def test_state_manager_lock_warning_threshold_contend(
state_manager_redis: StateManager, token: str, substate_token_redis: str, mocker
):
"""Test that the state manager triggers a warning when lock contention exceeds the warning threshold.
Args:
state_manager_redis: A state manager instance.
token: A token.
substate_token_redis: A token + substate name for looking up in state manager.
mocker: Pytest mocker object.
"""
console_warn = mocker.patch("reflex.utils.console.warn")
state_manager_redis.lock_expiration = LOCK_EXPIRATION
state_manager_redis.lock_warning_threshold = LOCK_WARNING_THRESHOLD
order = []
async def _coro_blocker():
async with state_manager_redis.modify_state(substate_token_redis):
order.append("blocker")
await asyncio.sleep(LOCK_WARN_SLEEP)
tasks = [
asyncio.create_task(_coro_blocker()),
]
await tasks[0]
console_warn.assert_called()
assert console_warn.call_count == 7
class CopyingAsyncMock(AsyncMock):
"""An AsyncMock, but deepcopy the args and kwargs first."""
@ -3253,12 +3291,42 @@ async def test_setvar_async_setter():
@pytest.mark.parametrize(
"expiration_kwargs, expected_values",
[
({"redis_lock_expiration": 20000}, (20000, constants.Expiration.TOKEN)),
(
{"redis_lock_expiration": 20000},
(
20000,
constants.Expiration.TOKEN,
constants.Expiration.LOCK_WARNING_THRESHOLD,
),
),
(
{"redis_lock_expiration": 50000, "redis_token_expiration": 5600},
(50000, 5600),
(50000, 5600, constants.Expiration.LOCK_WARNING_THRESHOLD),
),
(
{"redis_token_expiration": 7600},
(
constants.Expiration.LOCK,
7600,
constants.Expiration.LOCK_WARNING_THRESHOLD,
),
),
(
{"redis_lock_expiration": 50000, "redis_lock_warning_threshold": 1500},
(50000, constants.Expiration.TOKEN, 1500),
),
(
{"redis_token_expiration": 5600, "redis_lock_warning_threshold": 3000},
(constants.Expiration.LOCK, 5600, 3000),
),
(
{
"redis_lock_expiration": 50000,
"redis_token_expiration": 5600,
"redis_lock_warning_threshold": 2000,
},
(50000, 5600, 2000),
),
({"redis_token_expiration": 7600}, (constants.Expiration.LOCK, 7600)),
],
)
def test_redis_state_manager_config_knobs(tmp_path, expiration_kwargs, expected_values):
@ -3288,6 +3356,44 @@ config = rx.Config(
state_manager = StateManager.create(state=State)
assert state_manager.lock_expiration == expected_values[0] # type: ignore
assert state_manager.token_expiration == expected_values[1] # type: ignore
assert state_manager.lock_warning_threshold == expected_values[2] # type: ignore
@pytest.mark.skipif("REDIS_URL" not in os.environ, reason="Test requires redis")
@pytest.mark.parametrize(
"redis_lock_expiration, redis_lock_warning_threshold",
[
(10000, 10000),
(20000, 30000),
],
)
def test_redis_state_manager_config_knobs_invalid_lock_warning_threshold(
tmp_path, redis_lock_expiration, redis_lock_warning_threshold
):
proj_root = tmp_path / "project1"
proj_root.mkdir()
config_string = f"""
import reflex as rx
config = rx.Config(
app_name="project1",
redis_url="redis://localhost:6379",
state_manager_mode="redis",
redis_lock_expiration = {redis_lock_expiration},
redis_lock_warning_threshold = {redis_lock_warning_threshold},
)
"""
(proj_root / "rxconfig.py").write_text(dedent(config_string))
with chdir(proj_root):
# reload config for each parameter to avoid stale values
reflex.config.get_config(reload=True)
from reflex.state import State, StateManager
with pytest.raises(InvalidLockWarningThresholdError):
StateManager.create(state=State)
del sys.modules[constants.Config.MODULE]
class MixinState(State, mixin=True):

View File

@ -372,7 +372,7 @@ def test_basic_operations(TestObj):
"var, expected",
[
(v([1, 2, 3]), "[1, 2, 3]"),
(v(set([1, 2, 3])), "[1, 2, 3]"),
(v({1, 2, 3}), "[1, 2, 3]"),
(v(["1", "2", "3"]), '["1", "2", "3"]'),
(
Var(_js_expr="foo")._var_set_state("state").to(list),
@ -903,7 +903,7 @@ def test_literal_var():
True,
False,
None,
set([1, 2, 3]),
{1, 2, 3},
]
)
assert (

View File

@ -222,9 +222,10 @@ def test_serialize(value: Any, expected: str):
'"2021-01-01 01:01:01.000001"',
True,
),
(datetime.date(2021, 1, 1), '"2021-01-01"', True),
(Color(color="slate", shade=1), '"var(--slate-1)"', True),
(BaseSubclass, '"BaseSubclass"', True),
(Path("."), '"."', True),
(Path(), '"."', True),
],
)
def test_serialize_var_to_str(value: Any, expected: str, exp_var_is_string: bool):

View File

@ -270,7 +270,7 @@ def test_unsupported_literals(cls: type):
("appname2.io", "AppnameioConfig"),
],
)
def test_create_config(app_name, expected_config_name, mocker):
def test_create_config(app_name: str, expected_config_name: str, mocker):
"""Test templates.RXCONFIG is formatted with correct app name and config class name.
Args:
@ -278,7 +278,7 @@ def test_create_config(app_name, expected_config_name, mocker):
expected_config_name: Expected config name.
mocker: Mocker object.
"""
mocker.patch("builtins.open")
mocker.patch("pathlib.Path.write_text")
tmpl_mock = mocker.patch("reflex.compiler.templates.RXCONFIG")
prerequisites.create_config(app_name)
tmpl_mock.render.assert_called_with(
@ -464,7 +464,7 @@ def test_node_install_unix(tmp_path, mocker, machine, system):
mocker.patch("httpx.stream", return_value=Resp())
download = mocker.patch("reflex.utils.prerequisites.download_and_extract_fnm_zip")
process = mocker.patch("reflex.utils.processes.new_process")
chmod = mocker.patch("reflex.utils.prerequisites.os.chmod")
chmod = mocker.patch("pathlib.Path.chmod")
mocker.patch("reflex.utils.processes.stream_logs")
prerequisites.install_node()