This commit is contained in:
Khaleel Al-Adhami 2024-10-24 14:36:40 -07:00
commit dc9ce91717
19 changed files with 293 additions and 141 deletions

16
poetry.lock generated
View File

@ -1977,20 +1977,6 @@ files = [
[package.dependencies] [package.dependencies]
six = ">=1.5" six = ">=1.5"
[[package]]
name = "python-dotenv"
version = "1.0.1"
description = "Read key-value pairs from a .env file and set them as environment variables"
optional = false
python-versions = ">=3.8"
files = [
{file = "python-dotenv-1.0.1.tar.gz", hash = "sha256:e324ee90a023d808f1959c46bcbc04446a10ced277783dc6ee09987c37ec10ca"},
{file = "python_dotenv-1.0.1-py3-none-any.whl", hash = "sha256:f7b63ef50f1b690dddf550d03497b66d609393b40b564ed0d674909a68ebf16a"},
]
[package.extras]
cli = ["click (>=5.0)"]
[[package]] [[package]]
name = "python-engineio" name = "python-engineio"
version = "4.10.1" version = "4.10.1"
@ -3047,4 +3033,4 @@ type = ["pytest-mypy"]
[metadata] [metadata]
lock-version = "2.0" lock-version = "2.0"
python-versions = "^3.9" python-versions = "^3.9"
content-hash = "c5da15520cef58124f6699007c81158036840469d4f9972592d72bd456c45e7e" content-hash = "8090ccaeca173bd8612e17a0b8d157d7492618e49450abd1c8373e2976349db0"

View File

@ -33,7 +33,6 @@ jinja2 = ">=3.1.2,<4.0"
psutil = ">=5.9.4,<7.0" psutil = ">=5.9.4,<7.0"
pydantic = ">=1.10.2,<3.0" pydantic = ">=1.10.2,<3.0"
python-multipart = ">=0.0.5,<0.1" python-multipart = ">=0.0.5,<0.1"
python-dotenv = ">=1.0.1"
python-socketio = ">=5.7.0,<6.0" python-socketio = ">=5.7.0,<6.0"
redis = ">=4.3.5,<6.0" redis = ">=4.3.5,<6.0"
rich = ">=13.0.0,<14.0" rich = ">=13.0.0,<14.0"

View File

@ -4,7 +4,7 @@ import {
ColorModeContext, ColorModeContext,
defaultColorMode, defaultColorMode,
isDevMode, isDevMode,
lastCompiledTimeStamp lastCompiledTimeStamp,
} from "$/utils/context.js"; } from "$/utils/context.js";
export default function RadixThemesColorModeProvider({ children }) { export default function RadixThemesColorModeProvider({ children }) {
@ -37,7 +37,7 @@ export default function RadixThemesColorModeProvider({ children }) {
const allowedModes = ["light", "dark", "system"]; const allowedModes = ["light", "dark", "system"];
if (!allowedModes.includes(mode)) { if (!allowedModes.includes(mode)) {
console.error( console.error(
`Invalid color mode "${mode}". Defaulting to "${defaultColorMode}".`, `Invalid color mode "${mode}". Defaulting to "${defaultColorMode}".`
); );
mode = defaultColorMode; mode = defaultColorMode;
} }

View File

@ -315,7 +315,7 @@ def _compile_stateful_components(
# Don't import from the file that we're about to create. # Don't import from the file that we're about to create.
all_imports = utils.merge_imports(*all_import_dicts) all_imports = utils.merge_imports(*all_import_dicts)
all_imports.pop( all_imports.pop(
f"/{constants.Dirs.UTILS}/{constants.PageNames.STATEFUL_COMPONENTS}", None f"$/{constants.Dirs.UTILS}/{constants.PageNames.STATEFUL_COMPONENTS}", None
) )
return templates.STATEFUL_COMPONENTS.render( return templates.STATEFUL_COMPONENTS.render(

View File

@ -83,6 +83,12 @@ def validate_imports(import_dict: ParsedImportDict):
f"{_import.tag}/{_import.alias}" if _import.alias else _import.tag f"{_import.tag}/{_import.alias}" if _import.alias else _import.tag
) )
if import_name in used_tags: if import_name in used_tags:
already_imported = used_tags[import_name]
if (already_imported[0] == "$" and already_imported[1:] == lib) or (
lib[0] == "$" and lib[1:] == already_imported
):
used_tags[import_name] = lib if lib[0] == "$" else already_imported
continue
raise ValueError( raise ValueError(
f"Can not compile, the tag {import_name} is used multiple time from {lib} and {used_tags[import_name]}" f"Can not compile, the tag {import_name} is used multiple time from {lib} and {used_tags[import_name]}"
) )

View File

@ -39,6 +39,7 @@ from reflex.constants import (
) )
from reflex.constants.compiler import SpecialAttributes from reflex.constants.compiler import SpecialAttributes
from reflex.event import ( from reflex.event import (
EventCallback,
EventChain, EventChain,
EventChainVar, EventChainVar,
EventHandler, EventHandler,
@ -1135,6 +1136,8 @@ class Component(BaseComponent, ABC):
for trigger in self.event_triggers.values(): for trigger in self.event_triggers.values():
if isinstance(trigger, EventChain): if isinstance(trigger, EventChain):
for event in trigger.events: for event in trigger.events:
if isinstance(event, EventCallback):
continue
if isinstance(event, EventSpec): if isinstance(event, EventSpec):
if event.handler.state_full_name: if event.handler.state_full_name:
return True return True
@ -2241,7 +2244,7 @@ class StatefulComponent(BaseComponent):
""" """
if self.rendered_as_shared: if self.rendered_as_shared:
return { return {
f"/{Dirs.UTILS}/{PageNames.STATEFUL_COMPONENTS}": [ f"$/{Dirs.UTILS}/{PageNames.STATEFUL_COMPONENTS}": [
ImportVar(tag=self.tag) ImportVar(tag=self.tag)
] ]
} }

View File

@ -118,7 +118,7 @@ class DebounceInput(Component):
_var_type=Type[Component], _var_type=Type[Component],
_var_data=VarData( _var_data=VarData(
imports=child._get_imports(), imports=child._get_imports(),
hooks=child._get_hooks_internal(), hooks=child._get_all_hooks(),
), ),
), ),
) )
@ -128,6 +128,10 @@ class DebounceInput(Component):
component.event_triggers.update(child.event_triggers) component.event_triggers.update(child.event_triggers)
component.children = child.children component.children = child.children
component._rename_props = child._rename_props component._rename_props = child._rename_props
outer_get_all_custom_code = component._get_all_custom_code
component._get_all_custom_code = lambda: outer_get_all_custom_code().union(
child._get_all_custom_code()
)
return component return component
def _render(self): def _render(self):

View File

@ -93,7 +93,7 @@ def load_dynamic_serializer():
for lib, names in component._get_all_imports().items(): for lib, names in component._get_all_imports().items():
formatted_lib_name = format_library_name(lib) formatted_lib_name = format_library_name(lib)
if ( if (
not lib.startswith((".", "$/")) not lib.startswith((".", "/", "$/"))
and not lib.startswith("http") and not lib.startswith("http")
and formatted_lib_name not in libs_in_window and formatted_lib_name not in libs_in_window
): ):
@ -109,7 +109,7 @@ def load_dynamic_serializer():
# Rewrite imports from `/` to destructure from window # Rewrite imports from `/` to destructure from window
for ix, line in enumerate(module_code_lines[:]): for ix, line in enumerate(module_code_lines[:]):
if line.startswith("import "): if line.startswith("import "):
if 'from "$/' in line: if 'from "$/' in line or 'from "/' in line:
module_code_lines[ix] = ( module_code_lines[ix] = (
line.replace("import ", "const ", 1).replace( line.replace("import ", "const ", 1).replace(
" from ", " = window['__reflex'][", 1 " from ", " = window['__reflex'][", 1

View File

@ -615,6 +615,42 @@ class Textarea(BaseHTML):
# Fired when a key is released # Fired when a key is released
on_key_up: EventHandler[key_event] on_key_up: EventHandler[key_event]
@classmethod
def create(cls, *children, **props):
"""Create a textarea component.
Args:
*children: The children of the textarea.
**props: The properties of the textarea.
Returns:
The textarea component.
Raises:
ValueError: when `enter_key_submit` is combined with `on_key_down`.
"""
enter_key_submit = props.get("enter_key_submit")
auto_height = props.get("auto_height")
custom_attrs = props.setdefault("custom_attrs", {})
if enter_key_submit is not None:
enter_key_submit = Var.create(enter_key_submit)
if "on_key_down" in props:
raise ValueError(
"Cannot combine `enter_key_submit` with `on_key_down`.",
)
custom_attrs["on_key_down"] = Var(
_js_expr=f"(e) => enterKeySubmitOnKeyDown(e, {str(enter_key_submit)})",
_var_data=VarData.merge(enter_key_submit._get_all_var_data()),
)
if auto_height is not None:
auto_height = Var.create(auto_height)
custom_attrs["on_input"] = Var(
_js_expr=f"(e) => autoHeightOnInput(e, {str(auto_height)})",
_var_data=VarData.merge(auto_height._get_all_var_data()),
)
return super().create(*children, **props)
def _exclude_props(self) -> list[str]: def _exclude_props(self) -> list[str]:
return super()._exclude_props() + [ return super()._exclude_props() + [
"auto_height", "auto_height",
@ -634,28 +670,6 @@ class Textarea(BaseHTML):
custom_code.add(ENTER_KEY_SUBMIT_JS) custom_code.add(ENTER_KEY_SUBMIT_JS)
return custom_code return custom_code
def _render(self) -> Tag:
tag = super()._render()
if self.enter_key_submit is not None:
if "on_key_down" in self.event_triggers:
raise ValueError(
"Cannot combine `enter_key_submit` with `on_key_down`.",
)
tag.add_props(
on_key_down=Var(
_js_expr=f"(e) => enterKeySubmitOnKeyDown(e, {str(self.enter_key_submit)})",
_var_data=VarData.merge(self.enter_key_submit._get_all_var_data()),
)
)
if self.auto_height is not None:
tag.add_props(
on_input=Var(
_js_expr=f"(e) => autoHeightOnInput(e, {str(self.auto_height)})",
_var_data=VarData.merge(self.auto_height._get_all_var_data()),
)
)
return tag
button = Button.create button = Button.create
fieldset = Fieldset.create fieldset = Fieldset.create

View File

@ -1376,10 +1376,10 @@ class Textarea(BaseHTML):
on_unmount: Optional[EventType[[]]] = None, on_unmount: Optional[EventType[[]]] = None,
**props, **props,
) -> "Textarea": ) -> "Textarea":
"""Create the component. """Create a textarea component.
Args: Args:
*children: The children of the component. *children: The children of the textarea.
auto_complete: Whether the form control should have autocomplete enabled auto_complete: Whether the form control should have autocomplete enabled
auto_focus: Automatically focuses the textarea when the page loads auto_focus: Automatically focuses the textarea when the page loads
auto_height: Automatically fit the content height to the text (use min-height with this prop) auto_height: Automatically fit the content height to the text (use min-height with this prop)
@ -1419,10 +1419,13 @@ class Textarea(BaseHTML):
class_name: The class name for the component. class_name: The class name for the component.
autofocus: Whether the component should take the focus once the page is loaded autofocus: Whether the component should take the focus once the page is loaded
custom_attrs: custom attribute custom_attrs: custom attribute
**props: The props of the component. **props: The properties of the textarea.
Returns: Returns:
The component. The textarea component.
Raises:
ValueError: when `enter_key_submit` is combined with `on_key_down`.
""" """
... ...

View File

@ -8,12 +8,12 @@ import os
import sys import sys
import urllib.parse import urllib.parse
from pathlib import Path from pathlib import Path
from typing import Any, Dict, List, Optional, Set, Union from typing import Any, Dict, List, Optional, Set
from typing_extensions import get_type_hints from typing_extensions import get_type_hints
from reflex.utils.exceptions import ConfigError, EnvironmentVarValueError from reflex.utils.exceptions import ConfigError, EnvironmentVarValueError
from reflex.utils.types import value_inside_optional from reflex.utils.types import GenericType, is_union, value_inside_optional
try: try:
import pydantic.v1 as pydantic import pydantic.v1 as pydantic
@ -157,11 +157,13 @@ def get_default_value_for_field(field: dataclasses.Field) -> Any:
) )
def interpret_boolean_env(value: str) -> bool: # TODO: Change all interpret_.* signatures to value: str, field: dataclasses.Field once we migrate rx.Config to dataclasses
def interpret_boolean_env(value: str, field_name: str) -> bool:
"""Interpret a boolean environment variable value. """Interpret a boolean environment variable value.
Args: Args:
value: The environment variable value. value: The environment variable value.
field_name: The field name.
Returns: Returns:
The interpreted value. The interpreted value.
@ -176,14 +178,15 @@ def interpret_boolean_env(value: str) -> bool:
return True return True
elif value.lower() in false_values: elif value.lower() in false_values:
return False return False
raise EnvironmentVarValueError(f"Invalid boolean value: {value}") raise EnvironmentVarValueError(f"Invalid boolean value: {value} for {field_name}")
def interpret_int_env(value: str) -> int: def interpret_int_env(value: str, field_name: str) -> int:
"""Interpret an integer environment variable value. """Interpret an integer environment variable value.
Args: Args:
value: The environment variable value. value: The environment variable value.
field_name: The field name.
Returns: Returns:
The interpreted value. The interpreted value.
@ -194,14 +197,17 @@ def interpret_int_env(value: str) -> int:
try: try:
return int(value) return int(value)
except ValueError as ve: except ValueError as ve:
raise EnvironmentVarValueError(f"Invalid integer value: {value}") from ve raise EnvironmentVarValueError(
f"Invalid integer value: {value} for {field_name}"
) from ve
def interpret_path_env(value: str) -> Path: def interpret_path_env(value: str, field_name: str) -> Path:
"""Interpret a path environment variable value. """Interpret a path environment variable value.
Args: Args:
value: The environment variable value. value: The environment variable value.
field_name: The field name.
Returns: Returns:
The interpreted value. The interpreted value.
@ -211,16 +217,19 @@ def interpret_path_env(value: str) -> Path:
""" """
path = Path(value) path = Path(value)
if not path.exists(): if not path.exists():
raise EnvironmentVarValueError(f"Path does not exist: {path}") raise EnvironmentVarValueError(f"Path does not exist: {path} for {field_name}")
return path return path
def interpret_env_var_value(value: str, field: dataclasses.Field) -> Any: def interpret_env_var_value(
value: str, field_type: GenericType, field_name: str
) -> Any:
"""Interpret an environment variable value based on the field type. """Interpret an environment variable value based on the field type.
Args: Args:
value: The environment variable value. value: The environment variable value.
field: The field. field_type: The field type.
field_name: The field name.
Returns: Returns:
The interpreted value. The interpreted value.
@ -228,20 +237,25 @@ def interpret_env_var_value(value: str, field: dataclasses.Field) -> Any:
Raises: Raises:
ValueError: If the value is invalid. ValueError: If the value is invalid.
""" """
field_type = value_inside_optional(field.type) field_type = value_inside_optional(field_type)
if is_union(field_type):
raise ValueError(
f"Union types are not supported for environment variables: {field_name}."
)
if field_type is bool: if field_type is bool:
return interpret_boolean_env(value) return interpret_boolean_env(value, field_name)
elif field_type is str: elif field_type is str:
return value return value
elif field_type is int: elif field_type is int:
return interpret_int_env(value) return interpret_int_env(value, field_name)
elif field_type is Path: elif field_type is Path:
return interpret_path_env(value) return interpret_path_env(value, field_name)
else: else:
raise ValueError( raise ValueError(
f"Invalid type for environment variable {field.name}: {field_type}. This is probably an issue in Reflex." f"Invalid type for environment variable {field_name}: {field_type}. This is probably an issue in Reflex."
) )
@ -316,7 +330,7 @@ class EnvironmentVariables:
field.type = type_hints.get(field.name) or field.type field.type = type_hints.get(field.name) or field.type
value = ( value = (
interpret_env_var_value(raw_value, field) interpret_env_var_value(raw_value, field.type, field.name)
if raw_value is not None if raw_value is not None
else get_default_value_for_field(field) else get_default_value_for_field(field)
) )
@ -387,7 +401,7 @@ class Config(Base):
telemetry_enabled: bool = True telemetry_enabled: bool = True
# The bun path # The bun path
bun_path: Union[str, Path] = constants.Bun.DEFAULT_PATH bun_path: Path = constants.Bun.DEFAULT_PATH
# List of origins that are allowed to connect to the backend API. # List of origins that are allowed to connect to the backend API.
cors_allowed_origins: List[str] = ["*"] cors_allowed_origins: List[str] = ["*"]
@ -484,17 +498,17 @@ class Config(Base):
Returns: Returns:
The updated config values. The updated config values.
Raises:
EnvVarValueError: If an environment variable is set to an invalid type.
""" """
from reflex.utils.exceptions import EnvVarValueError
if self.env_file: if self.env_file:
from dotenv import load_dotenv try:
from dotenv import load_dotenv # type: ignore
# load env file if exists # load env file if exists
load_dotenv(self.env_file, override=True) load_dotenv(self.env_file, override=True)
except ImportError:
console.error(
"""The `python-dotenv` package is required to load environment variables from a file. Run `pip install "python-dotenv>=1.0.1"`."""
)
updated_values = {} updated_values = {}
# Iterate over the fields. # Iterate over the fields.
@ -510,21 +524,11 @@ class Config(Base):
dedupe=True, dedupe=True,
) )
# Convert the env var to the expected type. # Interpret the value.
try: value = interpret_env_var_value(env_var, field.type_, field.name)
if issubclass(field.type_, bool):
# special handling for bool values
env_var = env_var.lower() in ["true", "1", "yes"]
else:
env_var = field.type_(env_var)
except ValueError as ve:
console.error(
f"Could not convert {key.upper()}={env_var} to type {field.type_}"
)
raise EnvVarValueError from ve
# Set the value. # Set the value.
updated_values[key] = env_var updated_values[key] = value
return updated_values return updated_values

View File

@ -16,6 +16,7 @@ from typing import (
Generic, Generic,
List, List,
Optional, Optional,
Sequence,
Tuple, Tuple,
Type, Type,
TypeVar, TypeVar,
@ -389,7 +390,9 @@ class CallableEventSpec(EventSpec):
class EventChain(EventActionsMixin): class EventChain(EventActionsMixin):
"""Container for a chain of events that will be executed in order.""" """Container for a chain of events that will be executed in order."""
events: List[Union[EventSpec, EventVar]] = dataclasses.field(default_factory=list) events: Sequence[Union[EventSpec, EventVar, EventCallback]] = dataclasses.field(
default_factory=list
)
args_spec: Optional[Callable] = dataclasses.field(default=None) args_spec: Optional[Callable] = dataclasses.field(default=None)
@ -1445,13 +1448,8 @@ class LiteralEventChainVar(ArgsFunctionOperation, LiteralVar, EventChainVar):
) )
G = ParamSpec("G")
IndividualEventType = Union[EventSpec, EventHandler, Callable[G, Any], Var[Any]]
EventType = Union[IndividualEventType[G], List[IndividualEventType[G]]]
P = ParamSpec("P") P = ParamSpec("P")
Q = ParamSpec("Q")
T = TypeVar("T") T = TypeVar("T")
V = TypeVar("V") V = TypeVar("V")
V2 = TypeVar("V2") V2 = TypeVar("V2")
@ -1473,55 +1471,73 @@ if sys.version_info >= (3, 10):
""" """
self.func = func self.func = func
@property
def prevent_default(self):
"""Prevent default behavior.
Returns:
The event callback with prevent default behavior.
"""
return self
@property
def stop_propagation(self):
"""Stop event propagation.
Returns:
The event callback with stop propagation behavior.
"""
return self
@overload @overload
def __get__( def __call__(
self: EventCallback[[V], T], instance: None, owner self: EventCallback[Concatenate[V, Q], T], value: V | Var[V]
) -> Callable[[Union[Var[V], V]], EventSpec]: ... ) -> EventCallback[Q, T]: ...
@overload
def __call__(
self: EventCallback[Concatenate[V, V2, Q], T],
value: V | Var[V],
value2: V2 | Var[V2],
) -> EventCallback[Q, T]: ...
@overload
def __call__(
self: EventCallback[Concatenate[V, V2, V3, Q], T],
value: V | Var[V],
value2: V2 | Var[V2],
value3: V3 | Var[V3],
) -> EventCallback[Q, T]: ...
@overload
def __call__(
self: EventCallback[Concatenate[V, V2, V3, V4, Q], T],
value: V | Var[V],
value2: V2 | Var[V2],
value3: V3 | Var[V3],
value4: V4 | Var[V4],
) -> EventCallback[Q, T]: ...
def __call__(self, *values) -> EventCallback: # type: ignore
"""Call the function with the values.
Args:
*values: The values to call the function with.
Returns:
The function with the values.
"""
return self.func(*values) # type: ignore
@overload @overload
def __get__( def __get__(
self: EventCallback[[V, V2], T], instance: None, owner self: EventCallback[P, T], instance: None, owner
) -> Callable[[Union[Var[V], V], Union[Var[V2], V2]], EventSpec]: ... ) -> EventCallback[P, T]: ...
@overload
def __get__(
self: EventCallback[[V, V2, V3], T], instance: None, owner
) -> Callable[
[Union[Var[V], V], Union[Var[V2], V2], Union[Var[V3], V3]],
EventSpec,
]: ...
@overload
def __get__(
self: EventCallback[[V, V2, V3, V4], T], instance: None, owner
) -> Callable[
[
Union[Var[V], V],
Union[Var[V2], V2],
Union[Var[V3], V3],
Union[Var[V4], V4],
],
EventSpec,
]: ...
@overload
def __get__(
self: EventCallback[[V, V2, V3, V4, V5], T], instance: None, owner
) -> Callable[
[
Union[Var[V], V],
Union[Var[V2], V2],
Union[Var[V3], V3],
Union[Var[V4], V4],
Union[Var[V5], V5],
],
EventSpec,
]: ...
@overload @overload
def __get__(self, instance, owner) -> Callable[P, T]: ... def __get__(self, instance, owner) -> Callable[P, T]: ...
def __get__(self, instance, owner) -> Callable: def __get__(self, instance, owner) -> Callable: # type: ignore
"""Get the function with the instance bound to it. """Get the function with the instance bound to it.
Args: Args:
@ -1548,6 +1564,9 @@ if sys.version_info >= (3, 10):
return func # type: ignore return func # type: ignore
else: else:
class EventCallback(Generic[P, T]):
"""A descriptor that wraps a function to be used as an event."""
def event_handler(func: Callable[P, T]) -> Callable[P, T]: def event_handler(func: Callable[P, T]) -> Callable[P, T]:
"""Wrap a function to be used as an event. """Wrap a function to be used as an event.
@ -1560,6 +1579,17 @@ else:
return func return func
G = ParamSpec("G")
IndividualEventType = Union[
EventSpec, EventHandler, Callable[G, Any], EventCallback[G, Any], Var[Any]
]
ItemOrList = Union[V, List[V]]
EventType = ItemOrList[IndividualEventType[G]]
class EventNamespace(types.SimpleNamespace): class EventNamespace(types.SimpleNamespace):
"""A namespace for event related classes.""" """A namespace for event related classes."""

View File

@ -38,7 +38,7 @@ def get_engine(url: str | None = None) -> sqlalchemy.engine.Engine:
url = url or conf.db_url url = url or conf.db_url
if url is None: if url is None:
raise ValueError("No database url configured") raise ValueError("No database url configured")
if environment.ALEMBIC_CONFIG.exists(): if not environment.ALEMBIC_CONFIG.exists():
console.warn( console.warn(
"Database is not initialized, run [bold]reflex db init[/bold] first." "Database is not initialized, run [bold]reflex db init[/bold] first."
) )

View File

@ -220,6 +220,7 @@ class EventHandlerSetVar(EventHandler):
Raises: Raises:
AttributeError: If the given Var name does not exist on the state. AttributeError: If the given Var name does not exist on the state.
EventHandlerValueError: If the given Var name is not a str EventHandlerValueError: If the given Var name is not a str
NotImplementedError: If the setter for the given Var is async
""" """
from reflex.utils.exceptions import EventHandlerValueError from reflex.utils.exceptions import EventHandlerValueError
@ -228,11 +229,20 @@ class EventHandlerSetVar(EventHandler):
raise EventHandlerValueError( raise EventHandlerValueError(
f"Var name must be passed as a string, got {args[0]!r}" f"Var name must be passed as a string, got {args[0]!r}"
) )
handler = getattr(self.state_cls, constants.SETTER_PREFIX + args[0], None)
# Check that the requested Var setter exists on the State at compile time. # Check that the requested Var setter exists on the State at compile time.
if getattr(self.state_cls, constants.SETTER_PREFIX + args[0], None) is None: if handler is None:
raise AttributeError( raise AttributeError(
f"Variable `{args[0]}` cannot be set on `{self.state_cls.get_full_name()}`" f"Variable `{args[0]}` cannot be set on `{self.state_cls.get_full_name()}`"
) )
if asyncio.iscoroutinefunction(handler.fn):
raise NotImplementedError(
f"Setter for {args[0]} is async, which is not supported."
)
return super().__call__(*args) return super().__call__(*args)
@ -2895,9 +2905,13 @@ class StateManagerDisk(StateManager):
for substate in state.get_substates(): for substate in state.get_substates():
substate_token = _substate_key(client_token, substate) substate_token = _substate_key(client_token, substate)
fresh_instance = await root_state.get_state(substate)
instance = await self.load_state(substate_token) instance = await self.load_state(substate_token)
if instance is None: if instance is not None:
instance = await root_state.get_state(substate) # Ensure all substates exist, even if they weren't serialized previously.
instance.substates = fresh_instance.substates
else:
instance = fresh_instance
state.substates[substate.get_name()] = instance state.substates[substate.get_name()] = instance
instance.parent_state = state instance.parent_state = state

View File

@ -249,7 +249,8 @@ class AppHarness:
return textwrap.dedent(source) return textwrap.dedent(source)
def _initialize_app(self): def _initialize_app(self):
os.environ["TELEMETRY_ENABLED"] = "" # disable telemetry reporting for tests # disable telemetry reporting for tests
os.environ["TELEMETRY_ENABLED"] = "false"
self.app_path.mkdir(parents=True, exist_ok=True) self.app_path.mkdir(parents=True, exist_ok=True)
if self.app_source is not None: if self.app_source is not None:
app_globals = self._get_globals_from_signature(self.app_source) app_globals = self._get_globals_from_signature(self.app_source)

View File

@ -23,6 +23,12 @@ def merge_imports(
for lib, fields in ( for lib, fields in (
import_dict if isinstance(import_dict, tuple) else import_dict.items() import_dict if isinstance(import_dict, tuple) else import_dict.items()
): ):
# If the lib is an absolute path, we need to prefix it with a $
lib = (
"$" + lib
if lib.startswith(("/utils/", "/components/", "/styles/", "/public/"))
else lib
)
if isinstance(fields, (list, tuple, set)): if isinstance(fields, (list, tuple, set)):
all_imports[lib].extend( all_imports[lib].extend(
( (

View File

@ -1,5 +1,7 @@
import multiprocessing import multiprocessing
import os import os
from pathlib import Path
from typing import Any, Dict
import pytest import pytest
@ -42,7 +44,12 @@ def test_set_app_name(base_config_values):
("TELEMETRY_ENABLED", True), ("TELEMETRY_ENABLED", True),
], ],
) )
def test_update_from_env(base_config_values, monkeypatch, env_var, value): def test_update_from_env(
base_config_values: Dict[str, Any],
monkeypatch: pytest.MonkeyPatch,
env_var: str,
value: Any,
):
"""Test that environment variables override config values. """Test that environment variables override config values.
Args: Args:
@ -57,6 +64,29 @@ def test_update_from_env(base_config_values, monkeypatch, env_var, value):
assert getattr(config, env_var.lower()) == value assert getattr(config, env_var.lower()) == value
def test_update_from_env_path(
base_config_values: Dict[str, Any],
monkeypatch: pytest.MonkeyPatch,
tmp_path: Path,
):
"""Test that environment variables override config values.
Args:
base_config_values: Config values.
monkeypatch: The pytest monkeypatch object.
tmp_path: The pytest tmp_path fixture object.
"""
monkeypatch.setenv("BUN_PATH", "/test")
assert os.environ.get("BUN_PATH") == "/test"
with pytest.raises(ValueError):
rx.Config(**base_config_values)
monkeypatch.setenv("BUN_PATH", str(tmp_path))
assert os.environ.get("BUN_PATH") == str(tmp_path)
config = rx.Config(**base_config_values)
assert config.bun_path == tmp_path
@pytest.mark.parametrize( @pytest.mark.parametrize(
"kwargs, expected", "kwargs, expected",
[ [

View File

@ -106,6 +106,7 @@ class TestState(BaseState):
fig: Figure = Figure() fig: Figure = Figure()
dt: datetime.datetime = datetime.datetime.fromisoformat("1989-11-09T18:53:00+01:00") dt: datetime.datetime = datetime.datetime.fromisoformat("1989-11-09T18:53:00+01:00")
_backend: int = 0 _backend: int = 0
asynctest: int = 0
@ComputedVar @ComputedVar
def sum(self) -> float: def sum(self) -> float:
@ -129,6 +130,14 @@ class TestState(BaseState):
"""Do something.""" """Do something."""
pass pass
async def set_asynctest(self, value: int):
"""Set the asynctest value. Intentionally overwrite the default setter with an async one.
Args:
value: The new value.
"""
self.asynctest = value
class ChildState(TestState): class ChildState(TestState):
"""A child state fixture.""" """A child state fixture."""
@ -313,6 +322,7 @@ def test_class_vars(test_state):
"upper", "upper",
"fig", "fig",
"dt", "dt",
"asynctest",
} }
@ -733,6 +743,7 @@ def test_reset(test_state, child_state):
"mapping", "mapping",
"dt", "dt",
"_backend", "_backend",
"asynctest",
} }
# The dirty vars should be reset. # The dirty vars should be reset.
@ -3179,6 +3190,13 @@ async def test_setvar(mock_app: rx.App, token: str):
TestState.setvar(42, 42) TestState.setvar(42, 42)
@pytest.mark.asyncio
async def test_setvar_async_setter():
"""Test that overridden async setters raise Exception when used with setvar."""
with pytest.raises(NotImplementedError):
TestState.setvar("asynctest", 42)
@pytest.mark.skipif("REDIS_URL" not in os.environ, reason="Test requires redis") @pytest.mark.skipif("REDIS_URL" not in os.environ, reason="Test requires redis")
@pytest.mark.parametrize( @pytest.mark.parametrize(
"expiration_kwargs, expected_values", "expiration_kwargs, expected_values",
@ -3313,3 +3331,36 @@ def test_assignment_to_undeclared_vars():
state.handle_supported_regular_vars() state.handle_supported_regular_vars()
state.handle_non_var() state.handle_non_var()
@pytest.mark.asyncio
async def test_deserialize_gc_state_disk(token):
"""Test that a state can be deserialized from disk with a grandchild state.
Args:
token: A token.
"""
class Root(BaseState):
pass
class State(Root):
num: int = 42
class Child(State):
foo: str = "bar"
dsm = StateManagerDisk(state=Root)
async with dsm.modify_state(token) as root:
s = await root.get_state(State)
s.num += 1
c = await root.get_state(Child)
assert s._get_was_touched()
assert not c._get_was_touched()
dsm2 = StateManagerDisk(state=Root)
root = await dsm2.get_state(token)
s = await root.get_state(State)
assert s.num == 43
c = await root.get_state(Child)
assert c.foo == "bar"

View File

@ -601,6 +601,7 @@ formatted_router = {
"sum": 3.14, "sum": 3.14,
"upper": "", "upper": "",
"router": formatted_router, "router": formatted_router,
"asynctest": 0,
}, },
ChildState.get_full_name(): { ChildState.get_full_name(): {
"count": 23, "count": 23,