diff --git a/poetry.lock b/poetry.lock index 7161e6af7..ff9323ad4 100644 --- a/poetry.lock +++ b/poetry.lock @@ -1977,20 +1977,6 @@ files = [ [package.dependencies] six = ">=1.5" -[[package]] -name = "python-dotenv" -version = "1.0.1" -description = "Read key-value pairs from a .env file and set them as environment variables" -optional = false -python-versions = ">=3.8" -files = [ - {file = "python-dotenv-1.0.1.tar.gz", hash = "sha256:e324ee90a023d808f1959c46bcbc04446a10ced277783dc6ee09987c37ec10ca"}, - {file = "python_dotenv-1.0.1-py3-none-any.whl", hash = "sha256:f7b63ef50f1b690dddf550d03497b66d609393b40b564ed0d674909a68ebf16a"}, -] - -[package.extras] -cli = ["click (>=5.0)"] - [[package]] name = "python-engineio" version = "4.10.1" @@ -3047,4 +3033,4 @@ type = ["pytest-mypy"] [metadata] lock-version = "2.0" python-versions = "^3.9" -content-hash = "c5da15520cef58124f6699007c81158036840469d4f9972592d72bd456c45e7e" +content-hash = "8090ccaeca173bd8612e17a0b8d157d7492618e49450abd1c8373e2976349db0" diff --git a/pyproject.toml b/pyproject.toml index 3fe6041f0..93f3c5d50 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -33,7 +33,6 @@ jinja2 = ">=3.1.2,<4.0" psutil = ">=5.9.4,<7.0" pydantic = ">=1.10.2,<3.0" python-multipart = ">=0.0.5,<0.1" -python-dotenv = ">=1.0.1" python-socketio = ">=5.7.0,<6.0" redis = ">=4.3.5,<6.0" rich = ">=13.0.0,<14.0" diff --git a/reflex/.templates/web/components/reflex/radix_themes_color_mode_provider.js b/reflex/.templates/web/components/reflex/radix_themes_color_mode_provider.js index 1e6ea1a56..dd7886c89 100644 --- a/reflex/.templates/web/components/reflex/radix_themes_color_mode_provider.js +++ b/reflex/.templates/web/components/reflex/radix_themes_color_mode_provider.js @@ -4,7 +4,7 @@ import { ColorModeContext, defaultColorMode, isDevMode, - lastCompiledTimeStamp + lastCompiledTimeStamp, } from "$/utils/context.js"; export default function RadixThemesColorModeProvider({ children }) { @@ -37,7 +37,7 @@ export default function RadixThemesColorModeProvider({ children }) { const allowedModes = ["light", "dark", "system"]; if (!allowedModes.includes(mode)) { console.error( - `Invalid color mode "${mode}". Defaulting to "${defaultColorMode}".`, + `Invalid color mode "${mode}". Defaulting to "${defaultColorMode}".` ); mode = defaultColorMode; } diff --git a/reflex/compiler/compiler.py b/reflex/compiler/compiler.py index 21b83af9f..4db4679d5 100644 --- a/reflex/compiler/compiler.py +++ b/reflex/compiler/compiler.py @@ -315,7 +315,7 @@ def _compile_stateful_components( # Don't import from the file that we're about to create. all_imports = utils.merge_imports(*all_import_dicts) all_imports.pop( - f"/{constants.Dirs.UTILS}/{constants.PageNames.STATEFUL_COMPONENTS}", None + f"$/{constants.Dirs.UTILS}/{constants.PageNames.STATEFUL_COMPONENTS}", None ) return templates.STATEFUL_COMPONENTS.render( diff --git a/reflex/compiler/utils.py b/reflex/compiler/utils.py index 40640faa6..29398da87 100644 --- a/reflex/compiler/utils.py +++ b/reflex/compiler/utils.py @@ -83,6 +83,12 @@ def validate_imports(import_dict: ParsedImportDict): f"{_import.tag}/{_import.alias}" if _import.alias else _import.tag ) if import_name in used_tags: + already_imported = used_tags[import_name] + if (already_imported[0] == "$" and already_imported[1:] == lib) or ( + lib[0] == "$" and lib[1:] == already_imported + ): + used_tags[import_name] = lib if lib[0] == "$" else already_imported + continue raise ValueError( f"Can not compile, the tag {import_name} is used multiple time from {lib} and {used_tags[import_name]}" ) diff --git a/reflex/components/component.py b/reflex/components/component.py index 7407080ef..ea647a884 100644 --- a/reflex/components/component.py +++ b/reflex/components/component.py @@ -39,6 +39,7 @@ from reflex.constants import ( ) from reflex.constants.compiler import SpecialAttributes from reflex.event import ( + EventCallback, EventChain, EventChainVar, EventHandler, @@ -1135,6 +1136,8 @@ class Component(BaseComponent, ABC): for trigger in self.event_triggers.values(): if isinstance(trigger, EventChain): for event in trigger.events: + if isinstance(event, EventCallback): + continue if isinstance(event, EventSpec): if event.handler.state_full_name: return True @@ -2241,7 +2244,7 @@ class StatefulComponent(BaseComponent): """ if self.rendered_as_shared: return { - f"/{Dirs.UTILS}/{PageNames.STATEFUL_COMPONENTS}": [ + f"$/{Dirs.UTILS}/{PageNames.STATEFUL_COMPONENTS}": [ ImportVar(tag=self.tag) ] } diff --git a/reflex/components/core/debounce.py b/reflex/components/core/debounce.py index 07ad1d69e..86efb7dcd 100644 --- a/reflex/components/core/debounce.py +++ b/reflex/components/core/debounce.py @@ -118,7 +118,7 @@ class DebounceInput(Component): _var_type=Type[Component], _var_data=VarData( imports=child._get_imports(), - hooks=child._get_hooks_internal(), + hooks=child._get_all_hooks(), ), ), ) @@ -128,6 +128,10 @@ class DebounceInput(Component): component.event_triggers.update(child.event_triggers) component.children = child.children component._rename_props = child._rename_props + outer_get_all_custom_code = component._get_all_custom_code + component._get_all_custom_code = lambda: outer_get_all_custom_code().union( + child._get_all_custom_code() + ) return component def _render(self): diff --git a/reflex/components/dynamic.py b/reflex/components/dynamic.py index 4df89330a..c0e172224 100644 --- a/reflex/components/dynamic.py +++ b/reflex/components/dynamic.py @@ -93,7 +93,7 @@ def load_dynamic_serializer(): for lib, names in component._get_all_imports().items(): formatted_lib_name = format_library_name(lib) if ( - not lib.startswith((".", "$/")) + not lib.startswith((".", "/", "$/")) and not lib.startswith("http") and formatted_lib_name not in libs_in_window ): @@ -109,7 +109,7 @@ def load_dynamic_serializer(): # Rewrite imports from `/` to destructure from window for ix, line in enumerate(module_code_lines[:]): if line.startswith("import "): - if 'from "$/' in line: + if 'from "$/' in line or 'from "/' in line: module_code_lines[ix] = ( line.replace("import ", "const ", 1).replace( " from ", " = window['__reflex'][", 1 diff --git a/reflex/components/el/elements/forms.py b/reflex/components/el/elements/forms.py index 084aa8f8e..7cb776ee9 100644 --- a/reflex/components/el/elements/forms.py +++ b/reflex/components/el/elements/forms.py @@ -615,6 +615,42 @@ class Textarea(BaseHTML): # Fired when a key is released on_key_up: EventHandler[key_event] + @classmethod + def create(cls, *children, **props): + """Create a textarea component. + + Args: + *children: The children of the textarea. + **props: The properties of the textarea. + + Returns: + The textarea component. + + Raises: + ValueError: when `enter_key_submit` is combined with `on_key_down`. + """ + enter_key_submit = props.get("enter_key_submit") + auto_height = props.get("auto_height") + custom_attrs = props.setdefault("custom_attrs", {}) + + if enter_key_submit is not None: + enter_key_submit = Var.create(enter_key_submit) + if "on_key_down" in props: + raise ValueError( + "Cannot combine `enter_key_submit` with `on_key_down`.", + ) + custom_attrs["on_key_down"] = Var( + _js_expr=f"(e) => enterKeySubmitOnKeyDown(e, {str(enter_key_submit)})", + _var_data=VarData.merge(enter_key_submit._get_all_var_data()), + ) + if auto_height is not None: + auto_height = Var.create(auto_height) + custom_attrs["on_input"] = Var( + _js_expr=f"(e) => autoHeightOnInput(e, {str(auto_height)})", + _var_data=VarData.merge(auto_height._get_all_var_data()), + ) + return super().create(*children, **props) + def _exclude_props(self) -> list[str]: return super()._exclude_props() + [ "auto_height", @@ -634,28 +670,6 @@ class Textarea(BaseHTML): custom_code.add(ENTER_KEY_SUBMIT_JS) return custom_code - def _render(self) -> Tag: - tag = super()._render() - if self.enter_key_submit is not None: - if "on_key_down" in self.event_triggers: - raise ValueError( - "Cannot combine `enter_key_submit` with `on_key_down`.", - ) - tag.add_props( - on_key_down=Var( - _js_expr=f"(e) => enterKeySubmitOnKeyDown(e, {str(self.enter_key_submit)})", - _var_data=VarData.merge(self.enter_key_submit._get_all_var_data()), - ) - ) - if self.auto_height is not None: - tag.add_props( - on_input=Var( - _js_expr=f"(e) => autoHeightOnInput(e, {str(self.auto_height)})", - _var_data=VarData.merge(self.auto_height._get_all_var_data()), - ) - ) - return tag - button = Button.create fieldset = Fieldset.create diff --git a/reflex/components/el/elements/forms.pyi b/reflex/components/el/elements/forms.pyi index 135291213..2f889bc38 100644 --- a/reflex/components/el/elements/forms.pyi +++ b/reflex/components/el/elements/forms.pyi @@ -1376,10 +1376,10 @@ class Textarea(BaseHTML): on_unmount: Optional[EventType[[]]] = None, **props, ) -> "Textarea": - """Create the component. + """Create a textarea component. Args: - *children: The children of the component. + *children: The children of the textarea. auto_complete: Whether the form control should have autocomplete enabled auto_focus: Automatically focuses the textarea when the page loads auto_height: Automatically fit the content height to the text (use min-height with this prop) @@ -1419,10 +1419,13 @@ class Textarea(BaseHTML): class_name: The class name for the component. autofocus: Whether the component should take the focus once the page is loaded custom_attrs: custom attribute - **props: The props of the component. + **props: The properties of the textarea. Returns: - The component. + The textarea component. + + Raises: + ValueError: when `enter_key_submit` is combined with `on_key_down`. """ ... diff --git a/reflex/config.py b/reflex/config.py index eb319c5dd..ba86d911d 100644 --- a/reflex/config.py +++ b/reflex/config.py @@ -8,12 +8,12 @@ import os import sys import urllib.parse from pathlib import Path -from typing import Any, Dict, List, Optional, Set, Union +from typing import Any, Dict, List, Optional, Set from typing_extensions import get_type_hints from reflex.utils.exceptions import ConfigError, EnvironmentVarValueError -from reflex.utils.types import value_inside_optional +from reflex.utils.types import GenericType, is_union, value_inside_optional try: import pydantic.v1 as pydantic @@ -157,11 +157,13 @@ def get_default_value_for_field(field: dataclasses.Field) -> Any: ) -def interpret_boolean_env(value: str) -> bool: +# TODO: Change all interpret_.* signatures to value: str, field: dataclasses.Field once we migrate rx.Config to dataclasses +def interpret_boolean_env(value: str, field_name: str) -> bool: """Interpret a boolean environment variable value. Args: value: The environment variable value. + field_name: The field name. Returns: The interpreted value. @@ -176,14 +178,15 @@ def interpret_boolean_env(value: str) -> bool: return True elif value.lower() in false_values: return False - raise EnvironmentVarValueError(f"Invalid boolean value: {value}") + raise EnvironmentVarValueError(f"Invalid boolean value: {value} for {field_name}") -def interpret_int_env(value: str) -> int: +def interpret_int_env(value: str, field_name: str) -> int: """Interpret an integer environment variable value. Args: value: The environment variable value. + field_name: The field name. Returns: The interpreted value. @@ -194,14 +197,17 @@ def interpret_int_env(value: str) -> int: try: return int(value) except ValueError as ve: - raise EnvironmentVarValueError(f"Invalid integer value: {value}") from ve + raise EnvironmentVarValueError( + f"Invalid integer value: {value} for {field_name}" + ) from ve -def interpret_path_env(value: str) -> Path: +def interpret_path_env(value: str, field_name: str) -> Path: """Interpret a path environment variable value. Args: value: The environment variable value. + field_name: The field name. Returns: The interpreted value. @@ -211,16 +217,19 @@ def interpret_path_env(value: str) -> Path: """ path = Path(value) if not path.exists(): - raise EnvironmentVarValueError(f"Path does not exist: {path}") + raise EnvironmentVarValueError(f"Path does not exist: {path} for {field_name}") return path -def interpret_env_var_value(value: str, field: dataclasses.Field) -> Any: +def interpret_env_var_value( + value: str, field_type: GenericType, field_name: str +) -> Any: """Interpret an environment variable value based on the field type. Args: value: The environment variable value. - field: The field. + field_type: The field type. + field_name: The field name. Returns: The interpreted value. @@ -228,20 +237,25 @@ def interpret_env_var_value(value: str, field: dataclasses.Field) -> Any: Raises: ValueError: If the value is invalid. """ - field_type = value_inside_optional(field.type) + field_type = value_inside_optional(field_type) + + if is_union(field_type): + raise ValueError( + f"Union types are not supported for environment variables: {field_name}." + ) if field_type is bool: - return interpret_boolean_env(value) + return interpret_boolean_env(value, field_name) elif field_type is str: return value elif field_type is int: - return interpret_int_env(value) + return interpret_int_env(value, field_name) elif field_type is Path: - return interpret_path_env(value) + return interpret_path_env(value, field_name) else: raise ValueError( - f"Invalid type for environment variable {field.name}: {field_type}. This is probably an issue in Reflex." + f"Invalid type for environment variable {field_name}: {field_type}. This is probably an issue in Reflex." ) @@ -316,7 +330,7 @@ class EnvironmentVariables: field.type = type_hints.get(field.name) or field.type value = ( - interpret_env_var_value(raw_value, field) + interpret_env_var_value(raw_value, field.type, field.name) if raw_value is not None else get_default_value_for_field(field) ) @@ -387,7 +401,7 @@ class Config(Base): telemetry_enabled: bool = True # The bun path - bun_path: Union[str, Path] = constants.Bun.DEFAULT_PATH + bun_path: Path = constants.Bun.DEFAULT_PATH # List of origins that are allowed to connect to the backend API. cors_allowed_origins: List[str] = ["*"] @@ -484,17 +498,17 @@ class Config(Base): Returns: The updated config values. - - Raises: - EnvVarValueError: If an environment variable is set to an invalid type. """ - from reflex.utils.exceptions import EnvVarValueError - if self.env_file: - from dotenv import load_dotenv + try: + from dotenv import load_dotenv # type: ignore - # load env file if exists - load_dotenv(self.env_file, override=True) + # load env file if exists + load_dotenv(self.env_file, override=True) + except ImportError: + console.error( + """The `python-dotenv` package is required to load environment variables from a file. Run `pip install "python-dotenv>=1.0.1"`.""" + ) updated_values = {} # Iterate over the fields. @@ -510,21 +524,11 @@ class Config(Base): dedupe=True, ) - # Convert the env var to the expected type. - try: - if issubclass(field.type_, bool): - # special handling for bool values - env_var = env_var.lower() in ["true", "1", "yes"] - else: - env_var = field.type_(env_var) - except ValueError as ve: - console.error( - f"Could not convert {key.upper()}={env_var} to type {field.type_}" - ) - raise EnvVarValueError from ve + # Interpret the value. + value = interpret_env_var_value(env_var, field.type_, field.name) # Set the value. - updated_values[key] = env_var + updated_values[key] = value return updated_values diff --git a/reflex/event.py b/reflex/event.py index cfe40ef9a..8b75fc01f 100644 --- a/reflex/event.py +++ b/reflex/event.py @@ -16,6 +16,7 @@ from typing import ( Generic, List, Optional, + Sequence, Tuple, Type, TypeVar, @@ -389,7 +390,9 @@ class CallableEventSpec(EventSpec): class EventChain(EventActionsMixin): """Container for a chain of events that will be executed in order.""" - events: List[Union[EventSpec, EventVar]] = dataclasses.field(default_factory=list) + events: Sequence[Union[EventSpec, EventVar, EventCallback]] = dataclasses.field( + default_factory=list + ) args_spec: Optional[Callable] = dataclasses.field(default=None) @@ -1445,13 +1448,8 @@ class LiteralEventChainVar(ArgsFunctionOperation, LiteralVar, EventChainVar): ) -G = ParamSpec("G") - -IndividualEventType = Union[EventSpec, EventHandler, Callable[G, Any], Var[Any]] - -EventType = Union[IndividualEventType[G], List[IndividualEventType[G]]] - P = ParamSpec("P") +Q = ParamSpec("Q") T = TypeVar("T") V = TypeVar("V") V2 = TypeVar("V2") @@ -1473,55 +1471,73 @@ if sys.version_info >= (3, 10): """ self.func = func + @property + def prevent_default(self): + """Prevent default behavior. + + Returns: + The event callback with prevent default behavior. + """ + return self + + @property + def stop_propagation(self): + """Stop event propagation. + + Returns: + The event callback with stop propagation behavior. + """ + return self + @overload - def __get__( - self: EventCallback[[V], T], instance: None, owner - ) -> Callable[[Union[Var[V], V]], EventSpec]: ... + def __call__( + self: EventCallback[Concatenate[V, Q], T], value: V | Var[V] + ) -> EventCallback[Q, T]: ... + + @overload + def __call__( + self: EventCallback[Concatenate[V, V2, Q], T], + value: V | Var[V], + value2: V2 | Var[V2], + ) -> EventCallback[Q, T]: ... + + @overload + def __call__( + self: EventCallback[Concatenate[V, V2, V3, Q], T], + value: V | Var[V], + value2: V2 | Var[V2], + value3: V3 | Var[V3], + ) -> EventCallback[Q, T]: ... + + @overload + def __call__( + self: EventCallback[Concatenate[V, V2, V3, V4, Q], T], + value: V | Var[V], + value2: V2 | Var[V2], + value3: V3 | Var[V3], + value4: V4 | Var[V4], + ) -> EventCallback[Q, T]: ... + + def __call__(self, *values) -> EventCallback: # type: ignore + """Call the function with the values. + + Args: + *values: The values to call the function with. + + Returns: + The function with the values. + """ + return self.func(*values) # type: ignore @overload def __get__( - self: EventCallback[[V, V2], T], instance: None, owner - ) -> Callable[[Union[Var[V], V], Union[Var[V2], V2]], EventSpec]: ... - - @overload - def __get__( - self: EventCallback[[V, V2, V3], T], instance: None, owner - ) -> Callable[ - [Union[Var[V], V], Union[Var[V2], V2], Union[Var[V3], V3]], - EventSpec, - ]: ... - - @overload - def __get__( - self: EventCallback[[V, V2, V3, V4], T], instance: None, owner - ) -> Callable[ - [ - Union[Var[V], V], - Union[Var[V2], V2], - Union[Var[V3], V3], - Union[Var[V4], V4], - ], - EventSpec, - ]: ... - - @overload - def __get__( - self: EventCallback[[V, V2, V3, V4, V5], T], instance: None, owner - ) -> Callable[ - [ - Union[Var[V], V], - Union[Var[V2], V2], - Union[Var[V3], V3], - Union[Var[V4], V4], - Union[Var[V5], V5], - ], - EventSpec, - ]: ... + self: EventCallback[P, T], instance: None, owner + ) -> EventCallback[P, T]: ... @overload def __get__(self, instance, owner) -> Callable[P, T]: ... - def __get__(self, instance, owner) -> Callable: + def __get__(self, instance, owner) -> Callable: # type: ignore """Get the function with the instance bound to it. Args: @@ -1548,6 +1564,9 @@ if sys.version_info >= (3, 10): return func # type: ignore else: + class EventCallback(Generic[P, T]): + """A descriptor that wraps a function to be used as an event.""" + def event_handler(func: Callable[P, T]) -> Callable[P, T]: """Wrap a function to be used as an event. @@ -1560,6 +1579,17 @@ else: return func +G = ParamSpec("G") + +IndividualEventType = Union[ + EventSpec, EventHandler, Callable[G, Any], EventCallback[G, Any], Var[Any] +] + +ItemOrList = Union[V, List[V]] + +EventType = ItemOrList[IndividualEventType[G]] + + class EventNamespace(types.SimpleNamespace): """A namespace for event related classes.""" diff --git a/reflex/model.py b/reflex/model.py index b269044c5..5f5e8647d 100644 --- a/reflex/model.py +++ b/reflex/model.py @@ -38,7 +38,7 @@ def get_engine(url: str | None = None) -> sqlalchemy.engine.Engine: url = url or conf.db_url if url is None: raise ValueError("No database url configured") - if environment.ALEMBIC_CONFIG.exists(): + if not environment.ALEMBIC_CONFIG.exists(): console.warn( "Database is not initialized, run [bold]reflex db init[/bold] first." ) diff --git a/reflex/state.py b/reflex/state.py index 208c23a0f..e3e189b22 100644 --- a/reflex/state.py +++ b/reflex/state.py @@ -220,6 +220,7 @@ class EventHandlerSetVar(EventHandler): Raises: AttributeError: If the given Var name does not exist on the state. EventHandlerValueError: If the given Var name is not a str + NotImplementedError: If the setter for the given Var is async """ from reflex.utils.exceptions import EventHandlerValueError @@ -228,11 +229,20 @@ class EventHandlerSetVar(EventHandler): raise EventHandlerValueError( f"Var name must be passed as a string, got {args[0]!r}" ) + + handler = getattr(self.state_cls, constants.SETTER_PREFIX + args[0], None) + # Check that the requested Var setter exists on the State at compile time. - if getattr(self.state_cls, constants.SETTER_PREFIX + args[0], None) is None: + if handler is None: raise AttributeError( f"Variable `{args[0]}` cannot be set on `{self.state_cls.get_full_name()}`" ) + + if asyncio.iscoroutinefunction(handler.fn): + raise NotImplementedError( + f"Setter for {args[0]} is async, which is not supported." + ) + return super().__call__(*args) @@ -2895,9 +2905,13 @@ class StateManagerDisk(StateManager): for substate in state.get_substates(): substate_token = _substate_key(client_token, substate) + fresh_instance = await root_state.get_state(substate) instance = await self.load_state(substate_token) - if instance is None: - instance = await root_state.get_state(substate) + if instance is not None: + # Ensure all substates exist, even if they weren't serialized previously. + instance.substates = fresh_instance.substates + else: + instance = fresh_instance state.substates[substate.get_name()] = instance instance.parent_state = state diff --git a/reflex/testing.py b/reflex/testing.py index 6a45c51eb..b41e56884 100644 --- a/reflex/testing.py +++ b/reflex/testing.py @@ -249,7 +249,8 @@ class AppHarness: return textwrap.dedent(source) def _initialize_app(self): - os.environ["TELEMETRY_ENABLED"] = "" # disable telemetry reporting for tests + # disable telemetry reporting for tests + os.environ["TELEMETRY_ENABLED"] = "false" self.app_path.mkdir(parents=True, exist_ok=True) if self.app_source is not None: app_globals = self._get_globals_from_signature(self.app_source) diff --git a/reflex/utils/imports.py b/reflex/utils/imports.py index 8f53ed07a..bd422ecc0 100644 --- a/reflex/utils/imports.py +++ b/reflex/utils/imports.py @@ -23,6 +23,12 @@ def merge_imports( for lib, fields in ( import_dict if isinstance(import_dict, tuple) else import_dict.items() ): + # If the lib is an absolute path, we need to prefix it with a $ + lib = ( + "$" + lib + if lib.startswith(("/utils/", "/components/", "/styles/", "/public/")) + else lib + ) if isinstance(fields, (list, tuple, set)): all_imports[lib].extend( ( diff --git a/tests/units/test_config.py b/tests/units/test_config.py index c4a143bc5..1027042c9 100644 --- a/tests/units/test_config.py +++ b/tests/units/test_config.py @@ -1,5 +1,7 @@ import multiprocessing import os +from pathlib import Path +from typing import Any, Dict import pytest @@ -42,7 +44,12 @@ def test_set_app_name(base_config_values): ("TELEMETRY_ENABLED", True), ], ) -def test_update_from_env(base_config_values, monkeypatch, env_var, value): +def test_update_from_env( + base_config_values: Dict[str, Any], + monkeypatch: pytest.MonkeyPatch, + env_var: str, + value: Any, +): """Test that environment variables override config values. Args: @@ -57,6 +64,29 @@ def test_update_from_env(base_config_values, monkeypatch, env_var, value): assert getattr(config, env_var.lower()) == value +def test_update_from_env_path( + base_config_values: Dict[str, Any], + monkeypatch: pytest.MonkeyPatch, + tmp_path: Path, +): + """Test that environment variables override config values. + + Args: + base_config_values: Config values. + monkeypatch: The pytest monkeypatch object. + tmp_path: The pytest tmp_path fixture object. + """ + monkeypatch.setenv("BUN_PATH", "/test") + assert os.environ.get("BUN_PATH") == "/test" + with pytest.raises(ValueError): + rx.Config(**base_config_values) + + monkeypatch.setenv("BUN_PATH", str(tmp_path)) + assert os.environ.get("BUN_PATH") == str(tmp_path) + config = rx.Config(**base_config_values) + assert config.bun_path == tmp_path + + @pytest.mark.parametrize( "kwargs, expected", [ diff --git a/tests/units/test_state.py b/tests/units/test_state.py index 610d69110..544ddc606 100644 --- a/tests/units/test_state.py +++ b/tests/units/test_state.py @@ -106,6 +106,7 @@ class TestState(BaseState): fig: Figure = Figure() dt: datetime.datetime = datetime.datetime.fromisoformat("1989-11-09T18:53:00+01:00") _backend: int = 0 + asynctest: int = 0 @ComputedVar def sum(self) -> float: @@ -129,6 +130,14 @@ class TestState(BaseState): """Do something.""" pass + async def set_asynctest(self, value: int): + """Set the asynctest value. Intentionally overwrite the default setter with an async one. + + Args: + value: The new value. + """ + self.asynctest = value + class ChildState(TestState): """A child state fixture.""" @@ -313,6 +322,7 @@ def test_class_vars(test_state): "upper", "fig", "dt", + "asynctest", } @@ -733,6 +743,7 @@ def test_reset(test_state, child_state): "mapping", "dt", "_backend", + "asynctest", } # The dirty vars should be reset. @@ -3179,6 +3190,13 @@ async def test_setvar(mock_app: rx.App, token: str): TestState.setvar(42, 42) +@pytest.mark.asyncio +async def test_setvar_async_setter(): + """Test that overridden async setters raise Exception when used with setvar.""" + with pytest.raises(NotImplementedError): + TestState.setvar("asynctest", 42) + + @pytest.mark.skipif("REDIS_URL" not in os.environ, reason="Test requires redis") @pytest.mark.parametrize( "expiration_kwargs, expected_values", @@ -3313,3 +3331,36 @@ def test_assignment_to_undeclared_vars(): state.handle_supported_regular_vars() state.handle_non_var() + + +@pytest.mark.asyncio +async def test_deserialize_gc_state_disk(token): + """Test that a state can be deserialized from disk with a grandchild state. + + Args: + token: A token. + """ + + class Root(BaseState): + pass + + class State(Root): + num: int = 42 + + class Child(State): + foo: str = "bar" + + dsm = StateManagerDisk(state=Root) + async with dsm.modify_state(token) as root: + s = await root.get_state(State) + s.num += 1 + c = await root.get_state(Child) + assert s._get_was_touched() + assert not c._get_was_touched() + + dsm2 = StateManagerDisk(state=Root) + root = await dsm2.get_state(token) + s = await root.get_state(State) + assert s.num == 43 + c = await root.get_state(Child) + assert c.foo == "bar" diff --git a/tests/units/utils/test_format.py b/tests/units/utils/test_format.py index 17485d52e..f8b605541 100644 --- a/tests/units/utils/test_format.py +++ b/tests/units/utils/test_format.py @@ -601,6 +601,7 @@ formatted_router = { "sum": 3.14, "upper": "", "router": formatted_router, + "asynctest": 0, }, ChildState.get_full_name(): { "count": 23,