diff --git a/.github/workflows/integration_tests.yml b/.github/workflows/integration_tests.yml index 017336ba5..2ca9aed23 100644 --- a/.github/workflows/integration_tests.yml +++ b/.github/workflows/integration_tests.yml @@ -33,7 +33,7 @@ env: PR_TITLE: ${{ github.event.pull_request.title }} jobs: - example-counter: + example-counter-and-nba-proxy: env: OUTPUT_FILE: import_benchmark.json timeout-minutes: 30 @@ -119,6 +119,26 @@ jobs: --benchmark-json "./reflex-examples/counter/${{ env.OUTPUT_FILE }}" --branch-name "${{ github.head_ref || github.ref_name }}" --pr-id "${{ github.event.pull_request.id }}" --app-name "counter" + - name: Install requirements for nba proxy example + working-directory: ./reflex-examples/nba-proxy + run: | + poetry run uv pip install -r requirements.txt + - name: Install additional dependencies for DB access + run: poetry run uv pip install psycopg + - name: Check export --backend-only before init for nba-proxy example + working-directory: ./reflex-examples/nba-proxy + run: | + poetry run reflex export --backend-only + - name: Init Website for nba-proxy example + working-directory: ./reflex-examples/nba-proxy + run: | + poetry run reflex init --loglevel debug + - name: Run Website and Check for errors + run: | + # Check that npm is home + npm -v + poetry run bash scripts/integration.sh ./reflex-examples/nba-proxy dev + reflex-web: strategy: diff --git a/pyproject.toml b/pyproject.toml index eccf21230..d1ae1dcf0 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -86,8 +86,8 @@ build-backend = "poetry.core.masonry.api" target-version = "py39" output-format = "concise" lint.isort.split-on-trailing-comma = false -lint.select = ["B", "C4", "D", "E", "ERA", "F", "FURB", "I", "PERF", "PTH", "RUF", "SIM", "T", "W"] -lint.ignore = ["B008", "D205", "E501", "F403", "SIM115", "RUF006", "RUF012"] +lint.select = ["B", "C4", "D", "E", "ERA", "F", "FURB", "I", "PERF", "PTH", "RUF", "SIM", "T", "TRY", "W"] +lint.ignore = ["B008", "D205", "E501", "F403", "SIM115", "RUF006", "RUF012", "TRY0"] lint.pydocstyle.convention = "google" [tool.ruff.lint.per-file-ignores] diff --git a/reflex/.templates/web/utils/state.js b/reflex/.templates/web/utils/state.js index 41dbee446..5b8046347 100644 --- a/reflex/.templates/web/utils/state.js +++ b/reflex/.templates/web/utils/state.js @@ -3,6 +3,7 @@ import axios from "axios"; import io from "socket.io-client"; import JSON5 from "json5"; import env from "$/env.json"; +import reflexEnvironment from "$/reflex.json"; import Cookies from "universal-cookie"; import { useEffect, useRef, useState } from "react"; import Router, { useRouter } from "next/router"; @@ -407,10 +408,18 @@ export const connect = async ( socket.current = io(endpoint.href, { path: endpoint["pathname"], transports: transports, + protocols: env.TEST_MODE ? undefined : [reflexEnvironment.version], autoUnref: false, }); // Ensure undefined fields in events are sent as null instead of removed - socket.current.io.encoder.replacer = (k, v) => (v === undefined ? null : v) + socket.current.io.encoder.replacer = (k, v) => (v === undefined ? null : v); + socket.current.io.decoder.tryParse = (str) => { + try { + return JSON5.parse(str); + } catch (e) { + return false; + } + }; function checkVisibility() { if (document.visibilityState === "visible") { diff --git a/reflex/app.py b/reflex/app.py index 712dcee9f..e729eeefd 100644 --- a/reflex/app.py +++ b/reflex/app.py @@ -495,14 +495,8 @@ class App(MiddlewareMixin, LifespanMixin): Returns: The generated component. - - Raises: - exceptions.MatchTypeError: If the return types of match cases in rx.match are different. """ - try: - return component if isinstance(component, Component) else component() - except exceptions.MatchTypeError: - raise + return component if isinstance(component, Component) else component() def add_page( self, @@ -596,11 +590,12 @@ class App(MiddlewareMixin, LifespanMixin): meta=meta, ) - def _compile_page(self, route: str): + def _compile_page(self, route: str, save_page: bool = True): """Compile a page. Args: route: The route of the page to compile. + save_page: If True, the compiled page is saved to self.pages. """ component, enable_state = compiler.compile_unevaluated_page( route, self._unevaluated_pages[route], self._state, self.style, self.theme @@ -611,7 +606,8 @@ class App(MiddlewareMixin, LifespanMixin): # Add the page. self._check_routes_conflict(route) - self._pages[route] = component + if save_page: + self.pages[route] = component def get_load_events(self, route: str) -> list[IndividualEventType[[], Any]]: """Get the load events for a route. @@ -914,14 +910,16 @@ class App(MiddlewareMixin, LifespanMixin): # If a theme component was provided, wrap the app with it app_wrappers[(20, "Theme")] = self.theme - for route in self._unevaluated_pages: + should_compile = self._should_compile() + + for route in self.unevaluated_pages: console.debug(f"Evaluating page: {route}") - self._compile_page(route) + self._compile_page(route, save_page=should_compile) # Add the optional endpoints (_upload) self._add_optional_endpoints() - if not self._should_compile(): + if not should_compile: return self._validate_var_dependencies() @@ -1565,7 +1563,11 @@ class EventNamespace(AsyncNamespace): sid: The Socket.IO session id. environ: The request information, including HTTP headers. """ - pass + subprotocol = environ.get("HTTP_SEC_WEBSOCKET_PROTOCOL", None) + if subprotocol and subprotocol != constants.Reflex.VERSION: + console.warn( + f"Frontend version {subprotocol} for session {sid} does not match the backend version {constants.Reflex.VERSION}." + ) def on_disconnect(self, sid): """Event for when the websocket disconnects. @@ -1598,10 +1600,36 @@ class EventNamespace(AsyncNamespace): Args: sid: The Socket.IO session id. data: The event data. + + Raises: + EventDeserializationError: If the event data is not a dictionary. """ fields = data - # Get the event. - event = Event(**{k: v for k, v in fields.items() if k in _EVENT_FIELDS}) + + if isinstance(fields, str): + console.warn( + "Received event data as a string. This generally should not happen and may indicate a bug." + f" Event data: {fields}" + ) + try: + fields = json.loads(fields) + except json.JSONDecodeError as ex: + raise exceptions.EventDeserializationError( + f"Failed to deserialize event data: {fields}." + ) from ex + + if not isinstance(fields, dict): + raise exceptions.EventDeserializationError( + f"Event data must be a dictionary, but received {fields} of type {type(fields)}." + ) + + try: + # Get the event. + event = Event(**{k: v for k, v in fields.items() if k in _EVENT_FIELDS}) + except (TypeError, ValueError) as ex: + raise exceptions.EventDeserializationError( + f"Failed to deserialize event data: {fields}." + ) from ex self.token_to_sid[event.token] = sid self.sid_to_token[sid] = event.token diff --git a/reflex/app_module_for_backend.py b/reflex/app_module_for_backend.py index 8109fc3d6..b0ae0a29f 100644 --- a/reflex/app_module_for_backend.py +++ b/reflex/app_module_for_backend.py @@ -7,14 +7,13 @@ from concurrent.futures import ThreadPoolExecutor from reflex import constants from reflex.utils import telemetry from reflex.utils.exec import is_prod_mode -from reflex.utils.prerequisites import get_app +from reflex.utils.prerequisites import get_and_validate_app if constants.CompileVars.APP != "app": raise AssertionError("unexpected variable name for 'app'") telemetry.send("compile") -app_module = get_app(reload=False) -app = getattr(app_module, constants.CompileVars.APP) +app, app_module = get_and_validate_app(reload=False) # For py3.9 compatibility when redis is used, we MUST add any decorator pages # before compiling the app in a thread to avoid event loop error (REF-2172). app._apply_decorated_pages() @@ -30,7 +29,7 @@ if is_prod_mode(): # ensure only "app" is exposed. del app_module del compile_future -del get_app +del get_and_validate_app del is_prod_mode del telemetry del constants diff --git a/reflex/components/component.py b/reflex/components/component.py index 8649b593d..ed90a0f24 100644 --- a/reflex/components/component.py +++ b/reflex/components/component.py @@ -429,20 +429,22 @@ class Component(BaseComponent, ABC): else: continue + def determine_key(value): + # Try to create a var from the value + key = value if isinstance(value, Var) else LiteralVar.create(value) + + # Check that the var type is not None. + if key is None: + raise TypeError + + return key + # Check whether the key is a component prop. if types._issubclass(field_type, Var): # Used to store the passed types if var type is a union. passed_types = None try: - # Try to create a var from the value. - if isinstance(value, Var): - kwargs[key] = value - else: - kwargs[key] = LiteralVar.create(value) - - # Check that the var type is not None. - if kwargs[key] is None: - raise TypeError + kwargs[key] = determine_key(value) expected_type = fields[key].outer_type_.__args__[0] # validate literal fields. @@ -740,22 +742,21 @@ class Component(BaseComponent, ABC): # Import here to avoid circular imports. from reflex.components.base.bare import Bare from reflex.components.base.fragment import Fragment - from reflex.utils.exceptions import ComponentTypeError + from reflex.utils.exceptions import ChildrenTypeError # Filter out None props props = {key: value for key, value in props.items() if value is not None} def validate_children(children): for child in children: - if isinstance(child, tuple): + if isinstance(child, (tuple, list)): validate_children(child) + # Make sure the child is a valid type. - if not types._isinstance(child, ComponentChild): - raise ComponentTypeError( - "Children of Reflex components must be other components, " - "state vars, or primitive Python types. " - f"Got child {child} of type {type(child)}.", - ) + if isinstance(child, dict) or not types._isinstance( + child, ComponentChild + ): + raise ChildrenTypeError(component=cls.__name__, child=child) # Validate all the children. validate_children(children) diff --git a/reflex/components/dynamic.py b/reflex/components/dynamic.py index fbfc55f97..806d610df 100644 --- a/reflex/components/dynamic.py +++ b/reflex/components/dynamic.py @@ -136,6 +136,23 @@ def load_dynamic_serializer(): module_code_lines.insert(0, "const React = window.__reflex.react;") + function_line = next( + index + for index, line in enumerate(module_code_lines) + if line.startswith("export default function") + ) + + module_code_lines = [ + line + for _, line in sorted( + enumerate(module_code_lines), + key=lambda x: ( + not (x[1].startswith("import ") and x[0] < function_line), + x[0], + ), + ) + ] + return "\n".join( [ "//__reflex_evaluate", diff --git a/reflex/components/lucide/icon.py b/reflex/components/lucide/icon.py index 04410ac56..6c7cbede7 100644 --- a/reflex/components/lucide/icon.py +++ b/reflex/components/lucide/icon.py @@ -2,13 +2,15 @@ from reflex.components.component import Component from reflex.utils import format -from reflex.vars.base import Var +from reflex.utils.imports import ImportVar +from reflex.vars.base import LiteralVar, Var +from reflex.vars.sequence import LiteralStringVar class LucideIconComponent(Component): """Lucide Icon Component.""" - library = "lucide-react@0.469.0" + library = "lucide-react@0.471.1" class Icon(LucideIconComponent): @@ -32,6 +34,7 @@ class Icon(LucideIconComponent): Raises: AttributeError: The errors tied to bad usage of the Icon component. ValueError: If the icon tag is invalid. + TypeError: If the icon name is not a string. Returns: The created component. @@ -39,7 +42,6 @@ class Icon(LucideIconComponent): if children: if len(children) == 1 and isinstance(children[0], str): props["tag"] = children[0] - children = [] else: raise AttributeError( f"Passing multiple children to Icon component is not allowed: remove positional arguments {children[1:]} to fix" @@ -47,24 +49,46 @@ class Icon(LucideIconComponent): if "tag" not in props: raise AttributeError("Missing 'tag' keyword-argument for Icon") + tag: str | Var | LiteralVar = props.pop("tag") + if isinstance(tag, LiteralVar): + if isinstance(tag, LiteralStringVar): + tag = tag._var_value + else: + raise TypeError(f"Icon name must be a string, got {type(tag)}") + elif isinstance(tag, Var): + return DynamicIcon.create(name=tag, **props) + if ( - not isinstance(props["tag"], str) - or format.to_snake_case(props["tag"]) not in LUCIDE_ICON_LIST + not isinstance(tag, str) + or format.to_snake_case(tag) not in LUCIDE_ICON_LIST ): raise ValueError( - f"Invalid icon tag: {props['tag']}. Please use one of the following: {', '.join(LUCIDE_ICON_LIST[0:25])}, ..." + f"Invalid icon tag: {tag}. Please use one of the following: {', '.join(LUCIDE_ICON_LIST[0:25])}, ..." "\nSee full list at https://lucide.dev/icons." ) - if props["tag"] in LUCIDE_ICON_MAPPING_OVERRIDE: - props["tag"] = LUCIDE_ICON_MAPPING_OVERRIDE[props["tag"]] + if tag in LUCIDE_ICON_MAPPING_OVERRIDE: + props["tag"] = LUCIDE_ICON_MAPPING_OVERRIDE[tag] else: - props["tag"] = ( - format.to_title_case(format.to_snake_case(props["tag"])) + "Icon" - ) + props["tag"] = format.to_title_case(format.to_snake_case(tag)) + "Icon" props["alias"] = f"Lucide{props['tag']}" props.setdefault("color", "var(--current-color)") - return super().create(*children, **props) + return super().create(**props) + + +class DynamicIcon(LucideIconComponent): + """A DynamicIcon component.""" + + tag = "DynamicIcon" + + name: Var[str] + + def _get_imports(self): + _imports = super()._get_imports() + if self.library: + _imports.pop(self.library) + _imports["lucide-react/dynamic"] = [ImportVar("DynamicIcon", install=False)] + return _imports LUCIDE_ICON_LIST = [ @@ -846,6 +870,7 @@ LUCIDE_ICON_LIST = [ "house", "house_plug", "house_plus", + "house_wifi", "ice_cream_bowl", "ice_cream_cone", "id_card", @@ -1534,6 +1559,7 @@ LUCIDE_ICON_LIST = [ "trending_up_down", "triangle", "triangle_alert", + "triangle_dashed", "triangle_right", "trophy", "truck", diff --git a/reflex/components/lucide/icon.pyi b/reflex/components/lucide/icon.pyi index 39a1da0e6..6094cfd87 100644 --- a/reflex/components/lucide/icon.pyi +++ b/reflex/components/lucide/icon.pyi @@ -104,12 +104,60 @@ class Icon(LucideIconComponent): Raises: AttributeError: The errors tied to bad usage of the Icon component. ValueError: If the icon tag is invalid. + TypeError: If the icon name is not a string. Returns: The created component. """ ... +class DynamicIcon(LucideIconComponent): + @overload + @classmethod + def create( # type: ignore + cls, + *children, + name: Optional[Union[Var[str], str]] = None, + style: Optional[Style] = None, + key: Optional[Any] = None, + id: Optional[Any] = None, + class_name: Optional[Any] = None, + autofocus: Optional[bool] = None, + custom_attrs: Optional[Dict[str, Union[Var, Any]]] = None, + on_blur: Optional[EventType[[], BASE_STATE]] = None, + on_click: Optional[EventType[[], BASE_STATE]] = None, + on_context_menu: Optional[EventType[[], BASE_STATE]] = None, + on_double_click: Optional[EventType[[], BASE_STATE]] = None, + on_focus: Optional[EventType[[], BASE_STATE]] = None, + on_mount: Optional[EventType[[], BASE_STATE]] = None, + on_mouse_down: Optional[EventType[[], BASE_STATE]] = None, + on_mouse_enter: Optional[EventType[[], BASE_STATE]] = None, + on_mouse_leave: Optional[EventType[[], BASE_STATE]] = None, + on_mouse_move: Optional[EventType[[], BASE_STATE]] = None, + on_mouse_out: Optional[EventType[[], BASE_STATE]] = None, + on_mouse_over: Optional[EventType[[], BASE_STATE]] = None, + on_mouse_up: Optional[EventType[[], BASE_STATE]] = None, + on_scroll: Optional[EventType[[], BASE_STATE]] = None, + on_unmount: Optional[EventType[[], BASE_STATE]] = None, + **props, + ) -> "DynamicIcon": + """Create the component. + + Args: + *children: The children of the component. + style: The style of the component. + key: A unique key for the component. + id: The id for the component. + class_name: The class name for the component. + autofocus: Whether the component should take the focus once the page is loaded + custom_attrs: custom attribute + **props: The props of the component. + + Returns: + The component. + """ + ... + LUCIDE_ICON_LIST = [ "a_arrow_down", "a_arrow_up", @@ -889,6 +937,7 @@ LUCIDE_ICON_LIST = [ "house", "house_plug", "house_plus", + "house_wifi", "ice_cream_bowl", "ice_cream_cone", "id_card", @@ -1577,6 +1626,7 @@ LUCIDE_ICON_LIST = [ "trending_up_down", "triangle", "triangle_alert", + "triangle_dashed", "triangle_right", "trophy", "truck", diff --git a/reflex/config.py b/reflex/config.py index 0579b019f..8511694fb 100644 --- a/reflex/config.py +++ b/reflex/config.py @@ -12,6 +12,7 @@ import threading import urllib.parse from importlib.util import find_spec from pathlib import Path +from types import ModuleType from typing import ( TYPE_CHECKING, Any, @@ -567,6 +568,9 @@ class EnvironmentVariables: # The maximum size of the reflex state in kilobytes. REFLEX_STATE_SIZE_LIMIT: EnvVar[int] = env_var(1000) + # Whether to use the turbopack bundler. + REFLEX_USE_TURBOPACK: EnvVar[bool] = env_var(True) + environment = EnvironmentVariables() @@ -604,6 +608,9 @@ class Config(Base): # The name of the app (should match the name of the app directory). app_name: str + # The path to the app module. + app_module_import: Optional[str] = None + # The log level to use. loglevel: constants.LogLevel = constants.LogLevel.DEFAULT @@ -726,6 +733,19 @@ class Config(Base): "REDIS_URL is required when using the redis state manager." ) + @property + def app_module(self) -> ModuleType | None: + """Return the app module if `app_module_import` is set. + + Returns: + The app module. + """ + return ( + importlib.import_module(self.app_module_import) + if self.app_module_import + else None + ) + @property def module(self) -> str: """Get the module name of the app. @@ -733,6 +753,8 @@ class Config(Base): Returns: The module name. """ + if self.app_module is not None: + return self.app_module.__name__ return ".".join([self.app_name, self.app_name]) def update_from_env(self) -> dict[str, Any]: @@ -871,7 +893,7 @@ def get_config(reload: bool = False) -> Config: return cached_rxconfig.config with _config_lock: - sys_path = sys.path.copy() + orig_sys_path = sys.path.copy() sys.path.clear() sys.path.append(str(Path.cwd())) try: @@ -879,9 +901,14 @@ def get_config(reload: bool = False) -> Config: return _get_config() except Exception: # If the module import fails, try to import with the original sys.path. - sys.path.extend(sys_path) + sys.path.extend(orig_sys_path) return _get_config() finally: + # Find any entries added to sys.path by rxconfig.py itself. + extra_paths = [ + p for p in sys.path if p not in orig_sys_path and p != str(Path.cwd()) + ] # Restore the original sys.path. sys.path.clear() - sys.path.extend(sys_path) + sys.path.extend(extra_paths) + sys.path.extend(orig_sys_path) diff --git a/reflex/constants/__init__.py b/reflex/constants/__init__.py index e816da0f7..f5946bf5e 100644 --- a/reflex/constants/__init__.py +++ b/reflex/constants/__init__.py @@ -1,6 +1,7 @@ """The constants package.""" from .base import ( + APP_HARNESS_FLAG, COOKIES, IS_LINUX, IS_MACOS, diff --git a/reflex/constants/base.py b/reflex/constants/base.py index af96583ad..f737858c0 100644 --- a/reflex/constants/base.py +++ b/reflex/constants/base.py @@ -257,6 +257,7 @@ SESSION_STORAGE = "session_storage" # Testing variables. # Testing os env set by pytest when running a test case. PYTEST_CURRENT_TEST = "PYTEST_CURRENT_TEST" +APP_HARNESS_FLAG = "APP_HARNESS_FLAG" REFLEX_VAR_OPENING_TAG = "" REFLEX_VAR_CLOSING_TAG = "" diff --git a/reflex/constants/installer.py b/reflex/constants/installer.py index 0b45586dd..f9dd26b5a 100644 --- a/reflex/constants/installer.py +++ b/reflex/constants/installer.py @@ -182,7 +182,7 @@ class PackageJson(SimpleNamespace): "@emotion/react": "11.13.3", "axios": "1.7.7", "json5": "2.2.3", - "next": "14.2.16", + "next": "15.1.4", "next-sitemap": "4.2.3", "next-themes": "0.4.3", "react": "18.3.1", diff --git a/reflex/custom_components/custom_components.py b/reflex/custom_components/custom_components.py index 4a169802f..8000e7f4c 100644 --- a/reflex/custom_components/custom_components.py +++ b/reflex/custom_components/custom_components.py @@ -421,12 +421,13 @@ def _run_commands_in_subprocess(cmds: list[str]) -> bool: console.debug(f"Running command: {' '.join(cmds)}") try: result = subprocess.run(cmds, capture_output=True, text=True, check=True) - console.debug(result.stdout) - return True except subprocess.CalledProcessError as cpe: console.error(cpe.stdout) console.error(cpe.stderr) return False + else: + console.debug(result.stdout) + return True def _make_pyi_files(): @@ -931,10 +932,11 @@ def _get_file_from_prompt_in_loop() -> Tuple[bytes, str] | None: file_extension = image_filepath.suffix try: image_file = image_filepath.read_bytes() - return image_file, file_extension except OSError as ose: console.error(f"Unable to read the {file_extension} file due to {ose}") raise typer.Exit(code=1) from ose + else: + return image_file, file_extension console.debug(f"File extension detected: {file_extension}") return None diff --git a/reflex/event.py b/reflex/event.py index 28852fde5..886a306c1 100644 --- a/reflex/event.py +++ b/reflex/event.py @@ -1591,7 +1591,7 @@ def get_handler_args( def fix_events( - events: list[EventHandler | EventSpec] | None, + events: list[EventSpec | EventHandler] | None, token: str, router_data: dict[str, Any] | None = None, ) -> list[Event]: diff --git a/reflex/reflex.py b/reflex/reflex.py index 22fcb9fb8..2d6ebc30c 100644 --- a/reflex/reflex.py +++ b/reflex/reflex.py @@ -440,7 +440,11 @@ def deploy( config.app_name, "--app-name", help="The name of the App to deploy under.", - hidden=True, + ), + app_id: str = typer.Option( + None, + "--app-id", + help="The ID of the App to deploy over.", ), regions: List[str] = typer.Option( [], @@ -480,6 +484,11 @@ def deploy( "--project", help="project id to deploy to", ), + project_name: Optional[str] = typer.Option( + None, + "--project-name", + help="The name of the project to deploy to.", + ), token: Optional[str] = typer.Option( None, "--token", @@ -503,13 +512,6 @@ def deploy( # Set the log level. console.set_log_level(loglevel) - if not token: - # make sure user is logged in. - if interactive: - hosting_cli.login() - else: - raise SystemExit("Token is required for non-interactive mode.") - # Only check requirements if interactive. # There is user interaction for requirements update. if interactive: @@ -519,9 +521,12 @@ def deploy( if prerequisites.needs_reinit(frontend=True): _init(name=config.app_name, loglevel=loglevel) prerequisites.check_latest_package_version(constants.ReflexHostingCLI.MODULE_NAME) - + extra: dict[str, str] = ( + {"config_path": config_path} if config_path is not None else {} + ) hosting_cli.deploy( app_name=app_name, + app_id=app_id, export_fn=lambda zip_dest_dir, api_url, deploy_url, @@ -546,6 +551,8 @@ def deploy( token=token, project=project, config_path=config_path, + project_name=project_name, + **extra, ) diff --git a/reflex/state.py b/reflex/state.py index a31aae032..66098d232 100644 --- a/reflex/state.py +++ b/reflex/state.py @@ -104,6 +104,7 @@ from reflex.utils.exceptions import ( LockExpiredError, ReflexRuntimeError, SetUndefinedStateVarError, + StateMismatchError, StateSchemaMismatchError, StateSerializationError, StateTooLargeError, @@ -1199,7 +1200,6 @@ class BaseState(Base, ABC, extra=pydantic.Extra.allow): fget=func, auto_deps=False, deps=["router"], - cache=True, _js_expr=param, _var_data=VarData.from_state(cls), ) @@ -1543,7 +1543,7 @@ class BaseState(Base, ABC, extra=pydantic.Extra.allow): # Return the direct parent of target_state_cls for subsequent linking. return parent_state - def _get_state_from_cache(self, state_cls: Type[BaseState]) -> BaseState: + def _get_state_from_cache(self, state_cls: Type[T_STATE]) -> T_STATE: """Get a state instance from the cache. Args: @@ -1551,11 +1551,19 @@ class BaseState(Base, ABC, extra=pydantic.Extra.allow): Returns: The instance of state_cls associated with this state's client_token. + + Raises: + StateMismatchError: If the state instance is not of the expected type. """ root_state = self._get_root_state() - return root_state.get_substate(state_cls.get_full_name().split(".")) + substate = root_state.get_substate(state_cls.get_full_name().split(".")) + if not isinstance(substate, state_cls): + raise StateMismatchError( + f"Searched for state {state_cls.get_full_name()} but found {substate}." + ) + return substate - async def _get_state_from_redis(self, state_cls: Type[BaseState]) -> BaseState: + async def _get_state_from_redis(self, state_cls: Type[T_STATE]) -> T_STATE: """Get a state instance from redis. Args: @@ -1566,6 +1574,7 @@ class BaseState(Base, ABC, extra=pydantic.Extra.allow): Raises: RuntimeError: If redis is not used in this backend process. + StateMismatchError: If the state instance is not of the expected type. """ # Fetch all missing parent states from redis. parent_state_of_state_cls = await self._populate_parent_states(state_cls) @@ -1577,14 +1586,22 @@ class BaseState(Base, ABC, extra=pydantic.Extra.allow): f"Requested state {state_cls.get_full_name()} is not cached and cannot be accessed without redis. " "(All states should already be available -- this is likely a bug).", ) - return await state_manager.get_state( + + state_in_redis = await state_manager.get_state( token=_substate_key(self.router.session.client_token, state_cls), top_level=False, get_substates=True, parent_state=parent_state_of_state_cls, ) - async def get_state(self, state_cls: Type[BaseState]) -> BaseState: + if not isinstance(state_in_redis, state_cls): + raise StateMismatchError( + f"Searched for state {state_cls.get_full_name()} but found {state_in_redis}." + ) + + return state_in_redis + + async def get_state(self, state_cls: Type[T_STATE]) -> T_STATE: """Get an instance of the state associated with this token. Allows for arbitrary access to sibling states from within an event handler. @@ -1759,9 +1776,9 @@ class BaseState(Base, ABC, extra=pydantic.Extra.allow): except Exception as ex: state._clean() - app_instance = getattr(prerequisites.get_app(), constants.CompileVars.APP) - - event_specs = app_instance.backend_exception_handler(ex) + event_specs = ( + prerequisites.get_and_validate_app().app.backend_exception_handler(ex) + ) if event_specs is None: return StateUpdate() @@ -1871,9 +1888,9 @@ class BaseState(Base, ABC, extra=pydantic.Extra.allow): except Exception as ex: telemetry.send_error(ex, context="backend") - app_instance = getattr(prerequisites.get_app(), constants.CompileVars.APP) - - event_specs = app_instance.backend_exception_handler(ex) + event_specs = ( + prerequisites.get_and_validate_app().app.backend_exception_handler(ex) + ) yield state._as_state_update( handler, @@ -2316,6 +2333,9 @@ class BaseState(Base, ABC, extra=pydantic.Extra.allow): return state +T_STATE = TypeVar("T_STATE", bound=BaseState) + + class State(BaseState): """The app Base State.""" @@ -2383,8 +2403,9 @@ class FrontendEventExceptionState(State): component_stack: The stack trace of the component where the exception occurred. """ - app_instance = getattr(prerequisites.get_app(), constants.CompileVars.APP) - app_instance.frontend_exception_handler(Exception(stack)) + prerequisites.get_and_validate_app().app.frontend_exception_handler( + Exception(stack) + ) class UpdateVarsInternalState(State): @@ -2422,15 +2443,16 @@ class OnLoadInternalState(State): The list of events to queue for on load handling. """ # Do not app._compile()! It should be already compiled by now. - app = getattr(prerequisites.get_app(), constants.CompileVars.APP) - load_events = app.get_load_events(self.router.page.path) + load_events = prerequisites.get_and_validate_app().app.get_load_events( + self.router.page.path + ) if not load_events: self.is_hydrated = True return # Fast path for navigation with no on_load events defined. self.is_hydrated = False return [ *fix_events( - load_events, + cast(list[Union[EventSpec, EventHandler]], load_events), self.router.session.client_token, router_data=self.router_data, ), @@ -2589,7 +2611,7 @@ class StateProxy(wrapt.ObjectProxy): """ super().__init__(state_instance) # compile is not relevant to backend logic - self._self_app = getattr(prerequisites.get_app(), constants.CompileVars.APP) + self._self_app = prerequisites.get_and_validate_app().app self._self_substate_path = tuple(state_instance.get_full_name().split(".")) self._self_actx = None self._self_mutable = False @@ -3682,8 +3704,7 @@ def get_state_manager() -> StateManager: Returns: The state manager. """ - app = getattr(prerequisites.get_app(), constants.CompileVars.APP) - return app.state_manager + return prerequisites.get_and_validate_app().app.state_manager class MutableProxy(wrapt.ObjectProxy): diff --git a/reflex/testing.py b/reflex/testing.py index f9bef2c09..027648f97 100644 --- a/reflex/testing.py +++ b/reflex/testing.py @@ -282,6 +282,7 @@ class AppHarness: before_decorated_pages = reflex.app.DECORATED_PAGES[self.app_name].copy() # Ensure the AppHarness test does not skip State assignment due to running via pytest os.environ.pop(reflex.constants.PYTEST_CURRENT_TEST, None) + os.environ[reflex.constants.APP_HARNESS_FLAG] = "true" self.app_module = reflex.utils.prerequisites.get_compiled_app( # Do not reload the module for pre-existing apps (only apps generated from source) reload=self.app_source is not None diff --git a/reflex/utils/build.py b/reflex/utils/build.py index e263374e1..9ea941792 100644 --- a/reflex/utils/build.py +++ b/reflex/utils/build.py @@ -13,13 +13,17 @@ from rich.progress import MofNCompleteColumn, Progress, TimeElapsedColumn from reflex import constants from reflex.config import get_config from reflex.utils import console, path_ops, prerequisites, processes +from reflex.utils.exec import is_in_app_harness def set_env_json(): """Write the upload url to a REFLEX_JSON.""" path_ops.update_json_file( str(prerequisites.get_web_dir() / constants.Dirs.ENV_JSON), - {endpoint.name: endpoint.get_url() for endpoint in constants.Endpoint}, + { + **{endpoint.name: endpoint.get_url() for endpoint in constants.Endpoint}, + "TEST_MODE": is_in_app_harness(), + }, ) diff --git a/reflex/utils/console.py b/reflex/utils/console.py index be545140a..8929b63b6 100644 --- a/reflex/utils/console.py +++ b/reflex/utils/console.py @@ -2,6 +2,11 @@ from __future__ import annotations +import inspect +import shutil +from pathlib import Path +from types import FrameType + from rich.console import Console from rich.progress import MofNCompleteColumn, Progress, TimeElapsedColumn from rich.prompt import Prompt @@ -188,6 +193,33 @@ def warn(msg: str, dedupe: bool = False, **kwargs): print(f"[orange1]Warning: {msg}[/orange1]", **kwargs) +def _get_first_non_framework_frame() -> FrameType | None: + import click + import typer + import typing_extensions + + import reflex as rx + + # Exclude utility modules that should never be the source of deprecated reflex usage. + exclude_modules = [click, rx, typer, typing_extensions] + exclude_roots = [ + p.parent.resolve() + if (p := Path(m.__file__)).name == "__init__.py" + else p.resolve() + for m in exclude_modules + ] + # Specifically exclude the reflex cli module. + if reflex_bin := shutil.which(b"reflex"): + exclude_roots.append(Path(reflex_bin.decode())) + + frame = inspect.currentframe() + while frame := frame and frame.f_back: + frame_path = Path(inspect.getfile(frame)).resolve() + if not any(frame_path.is_relative_to(root) for root in exclude_roots): + break + return frame + + def deprecate( feature_name: str, reason: str, @@ -206,15 +238,27 @@ def deprecate( dedupe: If True, suppress multiple console logs of deprecation message. kwargs: Keyword arguments to pass to the print function. """ - if feature_name not in _EMITTED_DEPRECATION_WARNINGS: + dedupe_key = feature_name + loc = "" + + # See if we can find where the deprecation exists in "user code" + origin_frame = _get_first_non_framework_frame() + if origin_frame is not None: + filename = Path(origin_frame.f_code.co_filename) + if filename.is_relative_to(Path.cwd()): + filename = filename.relative_to(Path.cwd()) + loc = f"{filename}:{origin_frame.f_lineno}" + dedupe_key = f"{dedupe_key} {loc}" + + if dedupe_key not in _EMITTED_DEPRECATION_WARNINGS: msg = ( f"{feature_name} has been deprecated in version {deprecation_version} {reason.rstrip('.')}. It will be completely " - f"removed in {removal_version}" + f"removed in {removal_version}. ({loc})" ) if _LOG_LEVEL <= LogLevel.WARNING: print(f"[yellow]DeprecationWarning: {msg}[/yellow]", **kwargs) if dedupe: - _EMITTED_DEPRECATION_WARNINGS.add(feature_name) + _EMITTED_DEPRECATION_WARNINGS.add(dedupe_key) def error(msg: str, dedupe: bool = False, **kwargs): diff --git a/reflex/utils/exceptions.py b/reflex/utils/exceptions.py index bceadc977..37a68e420 100644 --- a/reflex/utils/exceptions.py +++ b/reflex/utils/exceptions.py @@ -1,6 +1,6 @@ """Custom Exceptions.""" -from typing import NoReturn +from typing import Any, NoReturn class ReflexError(Exception): @@ -31,6 +31,22 @@ class ComponentTypeError(ReflexError, TypeError): """Custom TypeError for component related errors.""" +class ChildrenTypeError(ComponentTypeError): + """Raised when the children prop of a component is not a valid type.""" + + def __init__(self, component: str, child: Any): + """Initialize the exception. + + Args: + component: The name of the component. + child: The child that caused the error. + """ + super().__init__( + f"Component {component} received child {child} of type {type(child)}. " + "Accepted types are other components, state vars, or primitive Python types (dict excluded)." + ) + + class EventHandlerTypeError(ReflexError, TypeError): """Custom TypeError for event handler related errors.""" @@ -163,10 +179,18 @@ class StateSerializationError(ReflexError): """Raised when the state cannot be serialized.""" +class StateMismatchError(ReflexError, ValueError): + """Raised when the state retrieved does not match the expected state.""" + + class SystemPackageMissingError(ReflexError): """Raised when a system package is missing.""" +class EventDeserializationError(ReflexError, ValueError): + """Raised when an event cannot be deserialized.""" + + def raise_system_package_missing_error(package: str) -> NoReturn: """Raise a SystemPackageMissingError. diff --git a/reflex/utils/exec.py b/reflex/utils/exec.py index 621c4a608..c10b6b856 100644 --- a/reflex/utils/exec.py +++ b/reflex/utils/exec.py @@ -240,6 +240,28 @@ def run_backend( run_uvicorn_backend(host, port, loglevel) +def get_reload_dirs() -> list[str]: + """Get the reload directories for the backend. + + Returns: + The reload directories for the backend. + """ + config = get_config() + reload_dirs = [config.app_name] + if config.app_module is not None and config.app_module.__file__: + module_path = Path(config.app_module.__file__).resolve().parent + while module_path.parent.name: + for parent_file in module_path.parent.iterdir(): + if parent_file == "__init__.py": + # go up a level to find dir without `__init__.py` + module_path = module_path.parent + break + else: + break + reload_dirs.append(str(module_path)) + return reload_dirs + + def run_uvicorn_backend(host, port, loglevel: LogLevel): """Run the backend in development mode using Uvicorn. @@ -256,7 +278,7 @@ def run_uvicorn_backend(host, port, loglevel: LogLevel): port=port, log_level=loglevel.value, reload=True, - reload_dirs=[get_config().app_name], + reload_dirs=get_reload_dirs(), ) @@ -281,7 +303,7 @@ def run_granian_backend(host, port, loglevel: LogLevel): interface=Interfaces.ASGI, log_level=LogLevels(loglevel.value), reload=True, - reload_paths=[Path(get_config().app_name)], + reload_paths=get_reload_dirs(), reload_ignore_dirs=[".web"], ).serve() except ImportError: @@ -487,6 +509,15 @@ def is_testing_env() -> bool: return constants.PYTEST_CURRENT_TEST in os.environ +def is_in_app_harness() -> bool: + """Whether the app is running in the app harness. + + Returns: + True if the app is running in the app harness. + """ + return constants.APP_HARNESS_FLAG in os.environ + + def is_prod_mode() -> bool: """Check if the app is running in production mode. diff --git a/reflex/utils/prerequisites.py b/reflex/utils/prerequisites.py index d838c0eea..4f9cc0c14 100644 --- a/reflex/utils/prerequisites.py +++ b/reflex/utils/prerequisites.py @@ -17,11 +17,12 @@ import stat import sys import tempfile import time +import typing import zipfile from datetime import datetime from pathlib import Path from types import ModuleType -from typing import Callable, List, Optional +from typing import Callable, List, NamedTuple, Optional import httpx import typer @@ -42,9 +43,19 @@ from reflex.utils.exceptions import ( from reflex.utils.format import format_library_name from reflex.utils.registry import _get_npm_registry +if typing.TYPE_CHECKING: + from reflex.app import App + CURRENTLY_INSTALLING_NODE = False +class AppInfo(NamedTuple): + """A tuple containing the app instance and module.""" + + app: App + module: ModuleType + + @dataclasses.dataclass(frozen=True) class Template: """A template for a Reflex app.""" @@ -267,6 +278,22 @@ def windows_npm_escape_hatch() -> bool: return environment.REFLEX_USE_NPM.get() +def _check_app_name(config: Config): + """Check if the app name is set in the config. + + Args: + config: The config object. + + Raises: + RuntimeError: If the app name is not set in the config. + """ + if not config.app_name: + raise RuntimeError( + "Cannot get the app module because `app_name` is not set in rxconfig! " + "If this error occurs in a reflex test case, ensure that `get_app` is mocked." + ) + + def get_app(reload: bool = False) -> ModuleType: """Get the app module based on the default config. @@ -277,22 +304,23 @@ def get_app(reload: bool = False) -> ModuleType: The app based on the default config. Raises: - RuntimeError: If the app name is not set in the config. + Exception: If an error occurs while getting the app module. """ from reflex.utils import telemetry try: environment.RELOAD_CONFIG.set(reload) config = get_config() - if not config.app_name: - raise RuntimeError( - "Cannot get the app module because `app_name` is not set in rxconfig! " - "If this error occurs in a reflex test case, ensure that `get_app` is mocked." - ) + + _check_app_name(config) + module = config.module sys.path.insert(0, str(Path.cwd())) - app = __import__(module, fromlist=(constants.CompileVars.APP,)) - + app = ( + __import__(module, fromlist=(constants.CompileVars.APP,)) + if not config.app_module + else config.app_module + ) if reload: from reflex.state import reload_state_module @@ -301,11 +329,34 @@ def get_app(reload: bool = False) -> ModuleType: # Reload the app module. importlib.reload(app) - - return app except Exception as ex: telemetry.send_error(ex, context="frontend") raise + else: + return app + + +def get_and_validate_app(reload: bool = False) -> AppInfo: + """Get the app instance based on the default config and validate it. + + Args: + reload: Re-import the app module from disk + + Returns: + The app instance and the app module. + + Raises: + RuntimeError: If the app instance is not an instance of rx.App. + """ + from reflex.app import App + + app_module = get_app(reload=reload) + app = getattr(app_module, constants.CompileVars.APP) + if not isinstance(app, App): + raise RuntimeError( + "The app instance in the specified app_module_import in rxconfig must be an instance of rx.App." + ) + return AppInfo(app=app, module=app_module) def get_compiled_app(reload: bool = False, export: bool = False) -> ModuleType: @@ -318,8 +369,7 @@ def get_compiled_app(reload: bool = False, export: bool = False) -> ModuleType: Returns: The compiled app based on the default config. """ - app_module = get_app(reload=reload) - app = getattr(app_module, constants.CompileVars.APP) + app, app_module = get_and_validate_app(reload=reload) # For py3.9 compatibility when redis is used, we MUST add any decorator pages # before compiling the app in a thread to avoid event loop error (REF-2172). app._apply_decorated_pages() @@ -610,10 +660,14 @@ def initialize_web_directory(): init_reflex_json(project_hash=project_hash) +def _turbopack_flag() -> str: + return " --turbopack" if environment.REFLEX_USE_TURBOPACK.get() else "" + + def _compile_package_json(): return templates.PACKAGE_JSON.render( scripts={ - "dev": constants.PackageJson.Commands.DEV, + "dev": constants.PackageJson.Commands.DEV + _turbopack_flag(), "export": constants.PackageJson.Commands.EXPORT, "export_sitemap": constants.PackageJson.Commands.EXPORT_SITEMAP, "prod": constants.PackageJson.Commands.PROD, @@ -1149,11 +1203,12 @@ def ensure_reflex_installation_id() -> Optional[int]: if installation_id is None: installation_id = random.getrandbits(128) installation_id_file.write_text(str(installation_id)) - # If we get here, installation_id is definitely set - return installation_id except Exception as e: console.debug(f"Failed to ensure reflex installation id: {e}") return None + else: + # If we get here, installation_id is definitely set + return installation_id def initialize_reflex_user_directory(): @@ -1367,19 +1422,22 @@ def create_config_init_app_from_remote_template(app_name: str, template_url: str except OSError as ose: console.error(f"Failed to create temp directory for extracting zip: {ose}") raise typer.Exit(1) from ose + try: zipfile.ZipFile(zip_file_path).extractall(path=unzip_dir) # The zip file downloaded from github looks like: # repo-name-branch/**/*, so we need to remove the top level directory. - if len(subdirs := os.listdir(unzip_dir)) != 1: - console.error(f"Expected one directory in the zip, found {subdirs}") - raise typer.Exit(1) - template_dir = unzip_dir / subdirs[0] - console.debug(f"Template folder is located at {template_dir}") except Exception as uze: console.error(f"Failed to unzip the template: {uze}") raise typer.Exit(1) from uze + if len(subdirs := os.listdir(unzip_dir)) != 1: + console.error(f"Expected one directory in the zip, found {subdirs}") + raise typer.Exit(1) + + template_dir = unzip_dir / subdirs[0] + console.debug(f"Template folder is located at {template_dir}") + # Move the rxconfig file here first. path_ops.mv(str(template_dir / constants.Config.FILE), constants.Config.FILE) new_config = get_config(reload=True) diff --git a/reflex/utils/processes.py b/reflex/utils/processes.py index 871b5f323..3673b36b2 100644 --- a/reflex/utils/processes.py +++ b/reflex/utils/processes.py @@ -17,6 +17,7 @@ import typer from redis.exceptions import RedisError from reflex import constants +from reflex.config import environment from reflex.utils import console, path_ops, prerequisites @@ -156,24 +157,30 @@ def new_process(args, run: bool = False, show_logs: bool = False, **kwargs): Raises: Exit: When attempting to run a command with a None value. """ - node_bin_path = str(path_ops.get_node_bin_path()) - if not node_bin_path and not prerequisites.CURRENTLY_INSTALLING_NODE: - console.warn( - "The path to the Node binary could not be found. Please ensure that Node is properly " - "installed and added to your system's PATH environment variable or try running " - "`reflex init` again." - ) + # Check for invalid command first. if None in args: console.error(f"Invalid command: {args}") raise typer.Exit(1) - # Add the node bin path to the PATH environment variable. + + path_env: str = os.environ.get("PATH", "") + + # Add node_bin_path to the PATH environment variable. + if not environment.REFLEX_BACKEND_ONLY.get(): + node_bin_path = str(path_ops.get_node_bin_path()) + if not node_bin_path and not prerequisites.CURRENTLY_INSTALLING_NODE: + console.warn( + "The path to the Node binary could not be found. Please ensure that Node is properly " + "installed and added to your system's PATH environment variable or try running " + "`reflex init` again." + ) + path_env = os.pathsep.join([node_bin_path, path_env]) + env: dict[str, str] = { **os.environ, - "PATH": os.pathsep.join( - [node_bin_path if node_bin_path else "", os.environ["PATH"]] - ), # type: ignore + "PATH": path_env, **kwargs.pop("env", {}), } + kwargs = { "env": env, "stderr": None if show_logs else subprocess.STDOUT, diff --git a/reflex/utils/telemetry.py b/reflex/utils/telemetry.py index fc90932a6..8e9130b09 100644 --- a/reflex/utils/telemetry.py +++ b/reflex/utils/telemetry.py @@ -156,9 +156,10 @@ def _prepare_event(event: str, **kwargs) -> dict: def _send_event(event_data: dict) -> bool: try: httpx.post(POSTHOG_API_URL, json=event_data) - return True except Exception: return False + else: + return True def _send(event, telemetry_enabled, **kwargs): diff --git a/reflex/utils/types.py b/reflex/utils/types.py index b8bcbf2d6..ac30342e2 100644 --- a/reflex/utils/types.py +++ b/reflex/utils/types.py @@ -829,6 +829,22 @@ StateBases = get_base_class(StateVar) StateIterBases = get_base_class(StateIterVar) +def safe_issubclass(cls: Type, cls_check: Type | Tuple[Type, ...]): + """Check if a class is a subclass of another class. Returns False if internal error occurs. + + Args: + cls: The class to check. + cls_check: The class to check against. + + Returns: + Whether the class is a subclass of the other class. + """ + try: + return issubclass(cls, cls_check) + except TypeError: + return False + + def typehint_issubclass(possible_subclass: Any, possible_superclass: Any) -> bool: """Check if a type hint is a subclass of another type hint. diff --git a/reflex/vars/base.py b/reflex/vars/base.py index 0a93901cd..122545187 100644 --- a/reflex/vars/base.py +++ b/reflex/vars/base.py @@ -26,6 +26,7 @@ from typing import ( Iterable, List, Literal, + Mapping, NoReturn, Optional, Set, @@ -64,6 +65,7 @@ from reflex.utils.types import ( _isinstance, get_origin, has_args, + safe_issubclass, unionize, ) @@ -127,7 +129,7 @@ class VarData: state: str = "", field_name: str = "", imports: ImportDict | ParsedImportDict | None = None, - hooks: dict[str, VarData | None] | None = None, + hooks: Mapping[str, VarData | None] | None = None, deps: list[Var] | None = None, position: Hooks.HookPosition | None = None, ): @@ -561,7 +563,7 @@ class Var(Generic[VAR_TYPE]): if _var_is_local is not None: console.deprecate( feature_name="_var_is_local", - reason="The _var_is_local argument is not supported for Var." + reason="The _var_is_local argument is not supported for Var. " "If you want to create a Var from a raw Javascript expression, use the constructor directly", deprecation_version="0.6.0", removal_version="0.7.0", @@ -569,7 +571,7 @@ class Var(Generic[VAR_TYPE]): if _var_is_string is not None: console.deprecate( feature_name="_var_is_string", - reason="The _var_is_string argument is not supported for Var." + reason="The _var_is_string argument is not supported for Var. " "If you want to create a Var from a raw Javascript expression, use the constructor directly", deprecation_version="0.6.0", removal_version="0.7.0", @@ -643,8 +645,8 @@ class Var(Generic[VAR_TYPE]): @overload def to( self, - output: type[dict], - ) -> ObjectVar[dict]: ... + output: type[Mapping], + ) -> ObjectVar[Mapping]: ... @overload def to( @@ -686,7 +688,9 @@ class Var(Generic[VAR_TYPE]): # If the first argument is a python type, we map it to the corresponding Var type. for var_subclass in _var_subclasses[::-1]: - if fixed_output_type in var_subclass.python_types: + if fixed_output_type in var_subclass.python_types or safe_issubclass( + fixed_output_type, var_subclass.python_types + ): return self.to(var_subclass.var_subclass, output) if fixed_output_type is None: @@ -820,7 +824,7 @@ class Var(Generic[VAR_TYPE]): return False if issubclass(type_, list): return [] - if issubclass(type_, dict): + if issubclass(type_, Mapping): return {} if issubclass(type_, tuple): return () @@ -1026,7 +1030,7 @@ class Var(Generic[VAR_TYPE]): f"$/{constants.Dirs.STATE_PATH}": [imports.ImportVar(tag="refs")] } ), - ).to(ObjectVar, Dict[str, str]) + ).to(ObjectVar, Mapping[str, str]) return refs[LiteralVar.create(str(self))] @deprecated("Use `.js_type()` instead.") @@ -1373,7 +1377,7 @@ class LiteralVar(Var): serialized_value = serializers.serialize(value) if serialized_value is not None: - if isinstance(serialized_value, dict): + if isinstance(serialized_value, Mapping): return LiteralObjectVar.create( serialized_value, _var_type=type(value), @@ -1498,7 +1502,7 @@ def var_operation( ) -> Callable[P, ArrayVar[LIST_T]]: ... -OBJECT_TYPE = TypeVar("OBJECT_TYPE", bound=Dict) +OBJECT_TYPE = TypeVar("OBJECT_TYPE", bound=Mapping) @overload @@ -1573,8 +1577,8 @@ def figure_out_type(value: Any) -> types.GenericType: return Set[unionize(*(figure_out_type(v) for v in value))] if isinstance(value, tuple): return Tuple[unionize(*(figure_out_type(v) for v in value)), ...] - if isinstance(value, dict): - return Dict[ + if isinstance(value, Mapping): + return Mapping[ unionize(*(figure_out_type(k) for k in value)), unionize(*(figure_out_type(v) for v in value.values())), ] @@ -1838,7 +1842,7 @@ class ComputedVar(Var[RETURN_TYPE]): self, fget: Callable[[BASE_STATE], RETURN_TYPE], initial_value: RETURN_TYPE | types.Unset = types.Unset(), - cache: bool = False, + cache: bool = True, deps: Optional[List[Union[str, Var]]] = None, auto_deps: bool = True, interval: Optional[Union[int, datetime.timedelta]] = None, @@ -2002,10 +2006,10 @@ class ComputedVar(Var[RETURN_TYPE]): @overload def __get__( - self: ComputedVar[dict[DICT_KEY, DICT_VAL]], + self: ComputedVar[Mapping[DICT_KEY, DICT_VAL]], instance: None, owner: Type, - ) -> ObjectVar[dict[DICT_KEY, DICT_VAL]]: ... + ) -> ObjectVar[Mapping[DICT_KEY, DICT_VAL]]: ... @overload def __get__( @@ -2253,7 +2257,7 @@ if TYPE_CHECKING: def computed_var( fget: None = None, initial_value: Any | types.Unset = types.Unset(), - cache: bool = False, + cache: bool = True, deps: Optional[List[Union[str, Var]]] = None, auto_deps: bool = True, interval: Optional[Union[datetime.timedelta, int]] = None, @@ -2266,7 +2270,7 @@ def computed_var( def computed_var( fget: Callable[[BASE_STATE], RETURN_TYPE], initial_value: RETURN_TYPE | types.Unset = types.Unset(), - cache: bool = False, + cache: bool = True, deps: Optional[List[Union[str, Var]]] = None, auto_deps: bool = True, interval: Optional[Union[datetime.timedelta, int]] = None, @@ -2278,7 +2282,7 @@ def computed_var( def computed_var( fget: Callable[[BASE_STATE], Any] | None = None, initial_value: Any | types.Unset = types.Unset(), - cache: Optional[bool] = None, + cache: bool = True, deps: Optional[List[Union[str, Var]]] = None, auto_deps: bool = True, interval: Optional[Union[datetime.timedelta, int]] = None, @@ -2304,15 +2308,6 @@ def computed_var( ValueError: If caching is disabled and an update interval is set. VarDependencyError: If user supplies dependencies without caching. """ - if cache is None: - cache = False - console.deprecate( - "Default non-cached rx.var", - "the default value will be `@rx.var(cache=True)` in a future release. " - "To retain uncached var, explicitly pass `@rx.var(cache=False)`", - deprecation_version="0.6.8", - removal_version="0.7.0", - ) if cache is False and interval is not None: raise ValueError("Cannot set update interval without caching.") @@ -2924,11 +2919,14 @@ V = TypeVar("V") BASE_TYPE = TypeVar("BASE_TYPE", bound=Base) +FIELD_TYPE = TypeVar("FIELD_TYPE") +MAPPING_TYPE = TypeVar("MAPPING_TYPE", bound=Mapping) -class Field(Generic[T]): + +class Field(Generic[FIELD_TYPE]): """Shadow class for Var to allow for type hinting in the IDE.""" - def __set__(self, instance, value: T): + def __set__(self, instance, value: FIELD_TYPE): """Set the Var. Args: @@ -2940,7 +2938,9 @@ class Field(Generic[T]): def __get__(self: Field[bool], instance: None, owner) -> BooleanVar: ... @overload - def __get__(self: Field[int], instance: None, owner) -> NumberVar: ... + def __get__( + self: Field[int] | Field[float] | Field[int | float], instance: None, owner + ) -> NumberVar: ... @overload def __get__(self: Field[str], instance: None, owner) -> StringVar: ... @@ -2957,8 +2957,8 @@ class Field(Generic[T]): @overload def __get__( - self: Field[Dict[str, V]], instance: None, owner - ) -> ObjectVar[Dict[str, V]]: ... + self: Field[MAPPING_TYPE], instance: None, owner + ) -> ObjectVar[MAPPING_TYPE]: ... @overload def __get__( @@ -2966,10 +2966,10 @@ class Field(Generic[T]): ) -> ObjectVar[BASE_TYPE]: ... @overload - def __get__(self, instance: None, owner) -> Var[T]: ... + def __get__(self, instance: None, owner) -> Var[FIELD_TYPE]: ... @overload - def __get__(self, instance, owner) -> T: ... + def __get__(self, instance, owner) -> FIELD_TYPE: ... def __get__(self, instance, owner): # type: ignore """Get the Var. @@ -2980,7 +2980,7 @@ class Field(Generic[T]): """ -def field(value: T) -> Field[T]: +def field(value: FIELD_TYPE) -> Field[FIELD_TYPE]: """Create a Field with a value. Args: diff --git a/reflex/vars/function.py b/reflex/vars/function.py index 2a7d50e1b..131f15b9f 100644 --- a/reflex/vars/function.py +++ b/reflex/vars/function.py @@ -390,6 +390,7 @@ class ArgsFunctionOperation(CachedVarOperation, FunctionVar): Returns: The function var. """ + return_expr = Var.create(return_expr) return cls( _js_expr="", _var_type=_var_type, @@ -445,6 +446,7 @@ class ArgsFunctionOperationBuilder(CachedVarOperation, BuilderFunctionVar): Returns: The function var. """ + return_expr = Var.create(return_expr) return cls( _js_expr="", _var_type=_var_type, diff --git a/reflex/vars/number.py b/reflex/vars/number.py index d04aded35..a2a0293d5 100644 --- a/reflex/vars/number.py +++ b/reflex/vars/number.py @@ -20,7 +20,6 @@ from typing import ( from reflex.constants.base import Dirs from reflex.utils.exceptions import PrimitiveUnserializableToJSON, VarTypeError from reflex.utils.imports import ImportDict, ImportVar -from reflex.utils.types import is_optional from .base import ( CustomVarOperationReturn, @@ -431,7 +430,7 @@ class NumberVar(Var[NUMBER_T], python_types=(int, float)): """ if not isinstance(other, NUMBER_TYPES): raise_unsupported_operand_types("<", (type(self), type(other))) - return less_than_operation(self, +other) + return less_than_operation(+self, +other) @overload def __le__(self, other: number_types) -> BooleanVar: ... @@ -450,7 +449,7 @@ class NumberVar(Var[NUMBER_T], python_types=(int, float)): """ if not isinstance(other, NUMBER_TYPES): raise_unsupported_operand_types("<=", (type(self), type(other))) - return less_than_or_equal_operation(self, +other) + return less_than_or_equal_operation(+self, +other) def __eq__(self, other: Any): """Equal comparison. @@ -462,7 +461,7 @@ class NumberVar(Var[NUMBER_T], python_types=(int, float)): The result of the comparison. """ if isinstance(other, NUMBER_TYPES): - return equal_operation(self, +other) + return equal_operation(+self, +other) return equal_operation(self, other) def __ne__(self, other: Any): @@ -475,7 +474,7 @@ class NumberVar(Var[NUMBER_T], python_types=(int, float)): The result of the comparison. """ if isinstance(other, NUMBER_TYPES): - return not_equal_operation(self, +other) + return not_equal_operation(+self, +other) return not_equal_operation(self, other) @overload @@ -495,7 +494,7 @@ class NumberVar(Var[NUMBER_T], python_types=(int, float)): """ if not isinstance(other, NUMBER_TYPES): raise_unsupported_operand_types(">", (type(self), type(other))) - return greater_than_operation(self, +other) + return greater_than_operation(+self, +other) @overload def __ge__(self, other: number_types) -> BooleanVar: ... @@ -514,17 +513,7 @@ class NumberVar(Var[NUMBER_T], python_types=(int, float)): """ if not isinstance(other, NUMBER_TYPES): raise_unsupported_operand_types(">=", (type(self), type(other))) - return greater_than_or_equal_operation(self, +other) - - def bool(self): - """Boolean conversion. - - Returns: - The boolean value of the number. - """ - if is_optional(self._var_type): - return boolify((self != None) & (self != 0)) # noqa: E711 - return self != 0 + return greater_than_or_equal_operation(+self, +other) def _is_strict_float(self) -> bool: """Check if the number is a float. diff --git a/reflex/vars/object.py b/reflex/vars/object.py index 5de431f5a..7b951c559 100644 --- a/reflex/vars/object.py +++ b/reflex/vars/object.py @@ -8,8 +8,8 @@ import typing from inspect import isclass from typing import ( Any, - Dict, List, + Mapping, NoReturn, Tuple, Type, @@ -19,6 +19,8 @@ from typing import ( overload, ) +from typing_extensions import is_typeddict + from reflex.utils import types from reflex.utils.exceptions import VarAttributeError from reflex.utils.types import GenericType, get_attribute_access_type, get_origin @@ -36,7 +38,7 @@ from .base import ( from .number import BooleanVar, NumberVar, raise_unsupported_operand_types from .sequence import ArrayVar, StringVar -OBJECT_TYPE = TypeVar("OBJECT_TYPE") +OBJECT_TYPE = TypeVar("OBJECT_TYPE", covariant=True) KEY_TYPE = TypeVar("KEY_TYPE") VALUE_TYPE = TypeVar("VALUE_TYPE") @@ -46,7 +48,7 @@ ARRAY_INNER_TYPE = TypeVar("ARRAY_INNER_TYPE") OTHER_KEY_TYPE = TypeVar("OTHER_KEY_TYPE") -class ObjectVar(Var[OBJECT_TYPE], python_types=dict): +class ObjectVar(Var[OBJECT_TYPE], python_types=Mapping): """Base class for immutable object vars.""" def _key_type(self) -> Type: @@ -59,7 +61,7 @@ class ObjectVar(Var[OBJECT_TYPE], python_types=dict): @overload def _value_type( - self: ObjectVar[Dict[Any, VALUE_TYPE]], + self: ObjectVar[Mapping[Any, VALUE_TYPE]], ) -> Type[VALUE_TYPE]: ... @overload @@ -74,7 +76,7 @@ class ObjectVar(Var[OBJECT_TYPE], python_types=dict): fixed_type = get_origin(self._var_type) or self._var_type if not isclass(fixed_type): return Any - args = get_args(self._var_type) if issubclass(fixed_type, dict) else () + args = get_args(self._var_type) if issubclass(fixed_type, Mapping) else () return args[1] if args else Any def keys(self) -> ArrayVar[List[str]]: @@ -87,7 +89,7 @@ class ObjectVar(Var[OBJECT_TYPE], python_types=dict): @overload def values( - self: ObjectVar[Dict[Any, VALUE_TYPE]], + self: ObjectVar[Mapping[Any, VALUE_TYPE]], ) -> ArrayVar[List[VALUE_TYPE]]: ... @overload @@ -103,7 +105,7 @@ class ObjectVar(Var[OBJECT_TYPE], python_types=dict): @overload def entries( - self: ObjectVar[Dict[Any, VALUE_TYPE]], + self: ObjectVar[Mapping[Any, VALUE_TYPE]], ) -> ArrayVar[List[Tuple[str, VALUE_TYPE]]]: ... @overload @@ -133,49 +135,55 @@ class ObjectVar(Var[OBJECT_TYPE], python_types=dict): # NoReturn is used here to catch when key value is Any @overload def __getitem__( - self: ObjectVar[Dict[Any, NoReturn]], + self: ObjectVar[Mapping[Any, NoReturn]], key: Var | Any, ) -> Var: ... + @overload + def __getitem__( + self: (ObjectVar[Mapping[Any, bool]]), + key: Var | Any, + ) -> BooleanVar: ... + @overload def __getitem__( self: ( - ObjectVar[Dict[Any, int]] - | ObjectVar[Dict[Any, float]] - | ObjectVar[Dict[Any, int | float]] + ObjectVar[Mapping[Any, int]] + | ObjectVar[Mapping[Any, float]] + | ObjectVar[Mapping[Any, int | float]] ), key: Var | Any, ) -> NumberVar: ... @overload def __getitem__( - self: ObjectVar[Dict[Any, str]], + self: ObjectVar[Mapping[Any, str]], key: Var | Any, ) -> StringVar: ... @overload def __getitem__( - self: ObjectVar[Dict[Any, list[ARRAY_INNER_TYPE]]], + self: ObjectVar[Mapping[Any, list[ARRAY_INNER_TYPE]]], key: Var | Any, ) -> ArrayVar[list[ARRAY_INNER_TYPE]]: ... @overload def __getitem__( - self: ObjectVar[Dict[Any, set[ARRAY_INNER_TYPE]]], + self: ObjectVar[Mapping[Any, set[ARRAY_INNER_TYPE]]], key: Var | Any, ) -> ArrayVar[set[ARRAY_INNER_TYPE]]: ... @overload def __getitem__( - self: ObjectVar[Dict[Any, tuple[ARRAY_INNER_TYPE, ...]]], + self: ObjectVar[Mapping[Any, tuple[ARRAY_INNER_TYPE, ...]]], key: Var | Any, ) -> ArrayVar[tuple[ARRAY_INNER_TYPE, ...]]: ... @overload def __getitem__( - self: ObjectVar[Dict[Any, dict[OTHER_KEY_TYPE, VALUE_TYPE]]], + self: ObjectVar[Mapping[Any, Mapping[OTHER_KEY_TYPE, VALUE_TYPE]]], key: Var | Any, - ) -> ObjectVar[dict[OTHER_KEY_TYPE, VALUE_TYPE]]: ... + ) -> ObjectVar[Mapping[OTHER_KEY_TYPE, VALUE_TYPE]]: ... def __getitem__(self, key: Var | Any) -> Var: """Get an item from the object. @@ -195,49 +203,49 @@ class ObjectVar(Var[OBJECT_TYPE], python_types=dict): # NoReturn is used here to catch when key value is Any @overload def __getattr__( - self: ObjectVar[Dict[Any, NoReturn]], + self: ObjectVar[Mapping[Any, NoReturn]], name: str, ) -> Var: ... @overload def __getattr__( self: ( - ObjectVar[Dict[Any, int]] - | ObjectVar[Dict[Any, float]] - | ObjectVar[Dict[Any, int | float]] + ObjectVar[Mapping[Any, int]] + | ObjectVar[Mapping[Any, float]] + | ObjectVar[Mapping[Any, int | float]] ), name: str, ) -> NumberVar: ... @overload def __getattr__( - self: ObjectVar[Dict[Any, str]], + self: ObjectVar[Mapping[Any, str]], name: str, ) -> StringVar: ... @overload def __getattr__( - self: ObjectVar[Dict[Any, list[ARRAY_INNER_TYPE]]], + self: ObjectVar[Mapping[Any, list[ARRAY_INNER_TYPE]]], name: str, ) -> ArrayVar[list[ARRAY_INNER_TYPE]]: ... @overload def __getattr__( - self: ObjectVar[Dict[Any, set[ARRAY_INNER_TYPE]]], + self: ObjectVar[Mapping[Any, set[ARRAY_INNER_TYPE]]], name: str, ) -> ArrayVar[set[ARRAY_INNER_TYPE]]: ... @overload def __getattr__( - self: ObjectVar[Dict[Any, tuple[ARRAY_INNER_TYPE, ...]]], + self: ObjectVar[Mapping[Any, tuple[ARRAY_INNER_TYPE, ...]]], name: str, ) -> ArrayVar[tuple[ARRAY_INNER_TYPE, ...]]: ... @overload def __getattr__( - self: ObjectVar[Dict[Any, dict[OTHER_KEY_TYPE, VALUE_TYPE]]], + self: ObjectVar[Mapping[Any, Mapping[OTHER_KEY_TYPE, VALUE_TYPE]]], name: str, - ) -> ObjectVar[dict[OTHER_KEY_TYPE, VALUE_TYPE]]: ... + ) -> ObjectVar[Mapping[OTHER_KEY_TYPE, VALUE_TYPE]]: ... @overload def __getattr__( @@ -266,8 +274,11 @@ class ObjectVar(Var[OBJECT_TYPE], python_types=dict): var_type = get_args(var_type)[0] fixed_type = var_type if isclass(var_type) else get_origin(var_type) - if (isclass(fixed_type) and not issubclass(fixed_type, dict)) or ( - fixed_type in types.UnionTypes + + if ( + (isclass(fixed_type) and not issubclass(fixed_type, Mapping)) + or (fixed_type in types.UnionTypes) + or is_typeddict(fixed_type) ): attribute_type = get_attribute_access_type(var_type, name) if attribute_type is None: @@ -299,7 +310,7 @@ class ObjectVar(Var[OBJECT_TYPE], python_types=dict): class LiteralObjectVar(CachedVarOperation, ObjectVar[OBJECT_TYPE], LiteralVar): """Base class for immutable literal object vars.""" - _var_value: Dict[Union[Var, Any], Union[Var, Any]] = dataclasses.field( + _var_value: Mapping[Union[Var, Any], Union[Var, Any]] = dataclasses.field( default_factory=dict ) @@ -383,7 +394,7 @@ class LiteralObjectVar(CachedVarOperation, ObjectVar[OBJECT_TYPE], LiteralVar): @classmethod def create( cls, - _var_value: dict, + _var_value: Mapping, _var_type: Type[OBJECT_TYPE] | None = None, _var_data: VarData | None = None, ) -> LiteralObjectVar[OBJECT_TYPE]: @@ -466,7 +477,7 @@ def object_merge_operation(lhs: ObjectVar, rhs: ObjectVar): """ return var_operation_return( js_expression=f"({{...{lhs}, ...{rhs}}})", - var_type=Dict[ + var_type=Mapping[ Union[lhs._key_type(), rhs._key_type()], Union[lhs._value_type(), rhs._value_type()], ], diff --git a/reflex/vars/sequence.py b/reflex/vars/sequence.py index 5864e70b9..1b11f50e6 100644 --- a/reflex/vars/sequence.py +++ b/reflex/vars/sequence.py @@ -987,7 +987,7 @@ class ArrayVar(Var[ARRAY_VAR_TYPE], python_types=(list, tuple, set)): raise_unsupported_operand_types("[]", (type(self), type(i))) return array_item_operation(self, i) - def length(self) -> NumberVar: + def length(self) -> NumberVar[int]: """Get the length of the array. Returns: diff --git a/tests/integration/test_computed_vars.py b/tests/integration/test_computed_vars.py index 03aaf18b4..f56001ea8 100644 --- a/tests/integration/test_computed_vars.py +++ b/tests/integration/test_computed_vars.py @@ -22,22 +22,22 @@ def ComputedVars(): count: int = 0 # cached var with dep on count - @rx.var(cache=True, interval=15) + @rx.var(interval=15) def count1(self) -> int: return self.count # cached backend var with dep on count - @rx.var(cache=True, interval=15, backend=True) + @rx.var(interval=15, backend=True) def count1_backend(self) -> int: return self.count # same as above but implicit backend with `_` prefix - @rx.var(cache=True, interval=15) + @rx.var(interval=15) def _count1_backend(self) -> int: return self.count # explicit disabled auto_deps - @rx.var(interval=15, cache=True, auto_deps=False) + @rx.var(interval=15, auto_deps=False) def count3(self) -> int: # this will not add deps, because auto_deps is False print(self.count1) @@ -45,19 +45,27 @@ def ComputedVars(): return self.count # explicit dependency on count var - @rx.var(cache=True, deps=["count"], auto_deps=False) + @rx.var(deps=["count"], auto_deps=False) def depends_on_count(self) -> int: return self.count # explicit dependency on count1 var - @rx.var(cache=True, deps=[count1], auto_deps=False) + @rx.var(deps=[count1], auto_deps=False) def depends_on_count1(self) -> int: return self.count - @rx.var(deps=[count3], auto_deps=False, cache=True) + @rx.var( + deps=[count3], + auto_deps=False, + ) def depends_on_count3(self) -> int: return self.count + # special floats should be properly decoded on the frontend + @rx.var(cache=True, initial_value=[]) + def special_floats(self) -> list[float]: + return [42.9, float("nan"), float("inf"), float("-inf")] + @rx.event def increment(self): self.count += 1 @@ -103,6 +111,11 @@ def ComputedVars(): State.depends_on_count3, id="depends_on_count3", ), + rx.text("special_floats:"), + rx.text( + State.special_floats.join(", "), + id="special_floats", + ), ), ) @@ -224,6 +237,10 @@ async def test_computed_vars( assert depends_on_count3 assert depends_on_count3.text == "0" + special_floats = driver.find_element(By.ID, "special_floats") + assert special_floats + assert special_floats.text == "42.9, NaN, Infinity, -Infinity" + increment = driver.find_element(By.ID, "increment") assert increment.is_enabled() diff --git a/tests/integration/test_connection_banner.py b/tests/integration/test_connection_banner.py index 735dbd243..c2a912af6 100644 --- a/tests/integration/test_connection_banner.py +++ b/tests/integration/test_connection_banner.py @@ -71,9 +71,10 @@ def has_error_modal(driver: WebDriver) -> bool: """ try: driver.find_element(By.XPATH, CONNECTION_ERROR_XPATH) - return True except NoSuchElementException: return False + else: + return True @pytest.mark.asyncio diff --git a/tests/integration/test_dynamic_routes.py b/tests/integration/test_dynamic_routes.py index 50d9f23b1..9032fd84c 100644 --- a/tests/integration/test_dynamic_routes.py +++ b/tests/integration/test_dynamic_routes.py @@ -74,16 +74,16 @@ def DynamicRoute(): class ArgState(rx.State): """The app state.""" - @rx.var + @rx.var(cache=False) def arg(self) -> int: return int(self.arg_str or 0) class ArgSubState(ArgState): - @rx.var(cache=True) + @rx.var def cached_arg(self) -> int: return self.arg - @rx.var(cache=True) + @rx.var def cached_arg_str(self) -> str: return self.arg_str diff --git a/tests/integration/test_lifespan.py b/tests/integration/test_lifespan.py index 0fa4a7e92..d79273fbc 100644 --- a/tests/integration/test_lifespan.py +++ b/tests/integration/test_lifespan.py @@ -36,7 +36,7 @@ def LifespanApp(): print("Lifespan global started.") try: while True: - lifespan_task_global += inc # pyright: ignore[reportUnboundVariable] + lifespan_task_global += inc # pyright: ignore[reportUnboundVariable, reportPossiblyUnboundVariable] await asyncio.sleep(0.1) except asyncio.CancelledError as ce: print(f"Lifespan global cancelled: {ce}.") @@ -45,11 +45,11 @@ def LifespanApp(): class LifespanState(rx.State): interval: int = 100 - @rx.var + @rx.var(cache=False) def task_global(self) -> int: return lifespan_task_global - @rx.var + @rx.var(cache=False) def context_global(self) -> int: return lifespan_context_global diff --git a/tests/integration/test_media.py b/tests/integration/test_media.py index 10af26591..649038a7e 100644 --- a/tests/integration/test_media.py +++ b/tests/integration/test_media.py @@ -22,31 +22,31 @@ def MediaApp(): img.format = format # type: ignore return img - @rx.var(cache=True) + @rx.var def img_default(self) -> Image.Image: return self._blue() - @rx.var(cache=True) + @rx.var def img_bmp(self) -> Image.Image: return self._blue(format="BMP") - @rx.var(cache=True) + @rx.var def img_jpg(self) -> Image.Image: return self._blue(format="JPEG") - @rx.var(cache=True) + @rx.var def img_png(self) -> Image.Image: return self._blue(format="PNG") - @rx.var(cache=True) + @rx.var def img_gif(self) -> Image.Image: return self._blue(format="GIF") - @rx.var(cache=True) + @rx.var def img_webp(self) -> Image.Image: return self._blue(format="WEBP") - @rx.var(cache=True) + @rx.var def img_from_url(self) -> Image.Image: img_url = "https://picsum.photos/id/1/200/300" img_resp = httpx.get(img_url, follow_redirects=True) diff --git a/tests/units/components/core/test_match.py b/tests/units/components/core/test_match.py index f09e800e5..eaf98414d 100644 --- a/tests/units/components/core/test_match.py +++ b/tests/units/components/core/test_match.py @@ -1,4 +1,4 @@ -from typing import Dict, List, Tuple +from typing import List, Mapping, Tuple import pytest @@ -67,7 +67,7 @@ def test_match_components(): assert fourth_return_value_render["children"][0]["contents"] == '{"fourth value"}' assert match_cases[4][0]._js_expr == '({ ["foo"] : "bar" })' - assert match_cases[4][0]._var_type == Dict[str, str] + assert match_cases[4][0]._var_type == Mapping[str, str] fifth_return_value_render = match_cases[4][1].render() assert fifth_return_value_render["name"] == "RadixThemesText" assert fifth_return_value_render["children"][0]["contents"] == '{"fifth value"}' diff --git a/tests/units/components/test_component.py b/tests/units/components/test_component.py index 674873b69..6396e4322 100644 --- a/tests/units/components/test_component.py +++ b/tests/units/components/test_component.py @@ -27,7 +27,7 @@ from reflex.event import ( from reflex.state import BaseState from reflex.style import Style from reflex.utils import imports -from reflex.utils.exceptions import EventFnArgMismatch +from reflex.utils.exceptions import ChildrenTypeError, EventFnArgMismatch from reflex.utils.imports import ImportDict, ImportVar, ParsedImportDict, parse_imports from reflex.vars import VarData from reflex.vars.base import LiteralVar, Var @@ -645,14 +645,17 @@ def test_create_filters_none_props(test_component): assert str(component.style["text-align"]) == '"center"' -@pytest.mark.parametrize("children", [((None,),), ("foo", ("bar", (None,)))]) +@pytest.mark.parametrize( + "children", + [ + ((None,),), + ("foo", ("bar", (None,))), + ({"foo": "bar"},), + ], +) def test_component_create_unallowed_types(children, test_component): - with pytest.raises(TypeError) as err: + with pytest.raises(ChildrenTypeError): test_component.create(*children) - assert ( - err.value.args[0] - == "Children of Reflex components must be other components, state vars, or primitive Python types. Got child None of type ." - ) @pytest.mark.parametrize( diff --git a/tests/units/test_app.py b/tests/units/test_app.py index a09fde972..80e0be5fd 100644 --- a/tests/units/test_app.py +++ b/tests/units/test_app.py @@ -908,7 +908,7 @@ class DynamicState(BaseState): """Increment the counter var.""" self.counter = self.counter + 1 - @computed_var(cache=True) + @computed_var def comp_dynamic(self) -> str: """A computed var that depends on the dynamic var. @@ -1549,11 +1549,11 @@ def test_app_with_valid_var_dependencies(compilable_app: tuple[App, Path]): base: int = 0 _backend: int = 0 - @computed_var(cache=True) + @computed_var() def foo(self) -> str: return "foo" - @computed_var(deps=["_backend", "base", foo], cache=True) + @computed_var(deps=["_backend", "base", foo]) def bar(self) -> str: return "bar" @@ -1565,7 +1565,7 @@ def test_app_with_invalid_var_dependencies(compilable_app: tuple[App, Path]): app, _ = compilable_app class InvalidDepState(BaseState): - @computed_var(deps=["foolksjdf"], cache=True) + @computed_var(deps=["foolksjdf"]) def bar(self) -> str: return "bar" diff --git a/tests/units/test_state.py b/tests/units/test_state.py index cf3363770..05633b75f 100644 --- a/tests/units/test_state.py +++ b/tests/units/test_state.py @@ -202,7 +202,7 @@ class GrandchildState(ChildState): class GrandchildState2(ChildState2): """A grandchild state fixture.""" - @rx.var(cache=True) + @rx.var def cached(self) -> str: """A cached var. @@ -215,7 +215,7 @@ class GrandchildState2(ChildState2): class GrandchildState3(ChildState3): """A great grandchild state fixture.""" - @rx.var + @rx.var(cache=False) def computed(self) -> str: """A computed var. @@ -796,7 +796,7 @@ async def test_process_event_simple(test_state): # The delta should contain the changes, including computed vars. assert update.delta == { - TestState.get_full_name(): {"num1": 69, "sum": 72.14, "upper": ""}, + TestState.get_full_name(): {"num1": 69, "sum": 72.14}, GrandchildState3.get_full_name(): {"computed": ""}, } assert update.events == [] @@ -823,7 +823,7 @@ async def test_process_event_substate(test_state, child_state, grandchild_state) assert child_state.value == "HI" assert child_state.count == 24 assert update.delta == { - TestState.get_full_name(): {"sum": 3.14, "upper": ""}, + # TestState.get_full_name(): {"sum": 3.14, "upper": ""}, ChildState.get_full_name(): {"value": "HI", "count": 24}, GrandchildState3.get_full_name(): {"computed": ""}, } @@ -839,7 +839,7 @@ async def test_process_event_substate(test_state, child_state, grandchild_state) update = await test_state._process(event).__anext__() assert grandchild_state.value2 == "new" assert update.delta == { - TestState.get_full_name(): {"sum": 3.14, "upper": ""}, + # TestState.get_full_name(): {"sum": 3.14, "upper": ""}, GrandchildState.get_full_name(): {"value2": "new"}, GrandchildState3.get_full_name(): {"computed": ""}, } @@ -989,7 +989,7 @@ class InterdependentState(BaseState): v1: int = 0 _v2: int = 1 - @rx.var(cache=True) + @rx.var def v1x2(self) -> int: """Depends on var v1. @@ -998,7 +998,7 @@ class InterdependentState(BaseState): """ return self.v1 * 2 - @rx.var(cache=True) + @rx.var def v2x2(self) -> int: """Depends on backend var _v2. @@ -1007,7 +1007,7 @@ class InterdependentState(BaseState): """ return self._v2 * 2 - @rx.var(cache=True, backend=True) + @rx.var(backend=True) def v2x2_backend(self) -> int: """Depends on backend var _v2. @@ -1016,7 +1016,7 @@ class InterdependentState(BaseState): """ return self._v2 * 2 - @rx.var(cache=True) + @rx.var def v1x2x2(self) -> int: """Depends on ComputedVar v1x2. @@ -1025,7 +1025,7 @@ class InterdependentState(BaseState): """ return self.v1x2 * 2 # type: ignore - @rx.var(cache=True) + @rx.var def _v3(self) -> int: """Depends on backend var _v2. @@ -1034,7 +1034,7 @@ class InterdependentState(BaseState): """ return self._v2 - @rx.var(cache=True) + @rx.var def v3x2(self) -> int: """Depends on ComputedVar _v3. @@ -1239,7 +1239,7 @@ def test_computed_var_cached(): class ComputedState(BaseState): v: int = 0 - @rx.var(cache=True) + @rx.var def comp_v(self) -> int: nonlocal comp_v_calls comp_v_calls += 1 @@ -1264,15 +1264,15 @@ def test_computed_var_cached_depends_on_non_cached(): class ComputedState(BaseState): v: int = 0 - @rx.var + @rx.var(cache=False) def no_cache_v(self) -> int: return self.v - @rx.var(cache=True) + @rx.var def dep_v(self) -> int: return self.no_cache_v # type: ignore - @rx.var(cache=True) + @rx.var def comp_v(self) -> int: return self.v @@ -1304,14 +1304,14 @@ def test_computed_var_depends_on_parent_non_cached(): counter = 0 class ParentState(BaseState): - @rx.var + @rx.var(cache=False) def no_cache_v(self) -> int: nonlocal counter counter += 1 return counter class ChildState(ParentState): - @rx.var(cache=True) + @rx.var def dep_v(self) -> int: return self.no_cache_v # type: ignore @@ -1357,7 +1357,7 @@ def test_cached_var_depends_on_event_handler(use_partial: bool): def handler(self): self.x = self.x + 1 - @rx.var(cache=True) + @rx.var def cached_x_side_effect(self) -> int: self.handler() nonlocal counter @@ -1393,7 +1393,7 @@ def test_computed_var_dependencies(): def testprop(self) -> int: return self.v - @rx.var(cache=True) + @rx.var def comp_v(self) -> int: """Direct access. @@ -1402,7 +1402,7 @@ def test_computed_var_dependencies(): """ return self.v - @rx.var(cache=True, backend=True) + @rx.var(backend=True) def comp_v_backend(self) -> int: """Direct access backend var. @@ -1411,7 +1411,7 @@ def test_computed_var_dependencies(): """ return self.v - @rx.var(cache=True) + @rx.var def comp_v_via_property(self) -> int: """Access v via property. @@ -1420,7 +1420,7 @@ def test_computed_var_dependencies(): """ return self.testprop - @rx.var(cache=True) + @rx.var def comp_w(self): """Nested lambda. @@ -1429,7 +1429,7 @@ def test_computed_var_dependencies(): """ return lambda: self.w - @rx.var(cache=True) + @rx.var def comp_x(self): """Nested function. @@ -1442,7 +1442,7 @@ def test_computed_var_dependencies(): return _ - @rx.var(cache=True) + @rx.var def comp_y(self) -> List[int]: """Comprehension iterating over attribute. @@ -1451,7 +1451,7 @@ def test_computed_var_dependencies(): """ return [round(y) for y in self.y] - @rx.var(cache=True) + @rx.var def comp_z(self) -> List[bool]: """Comprehension accesses attribute. @@ -2027,10 +2027,6 @@ async def test_state_proxy(grandchild_state: GrandchildState, mock_app: rx.App): assert mcall.args[0] == str(SocketEvent.EVENT) assert mcall.args[1] == StateUpdate( delta={ - parent_state.get_full_name(): { - "upper": "", - "sum": 3.14, - }, grandchild_state.get_full_name(): { "value2": "42", }, @@ -2053,7 +2049,7 @@ class BackgroundTaskState(BaseState): super().__init__(**kwargs) self.router_data = {"simulate": "hydrate"} - @rx.var + @rx.var(cache=False) def computed_order(self) -> List[str]: """Get the order as a computed var. @@ -3040,10 +3036,6 @@ async def test_get_state(mock_app: rx.App, token: str): grandchild_state.value2 = "set_value" assert test_state.get_delta() == { - TestState.get_full_name(): { - "sum": 3.14, - "upper": "", - }, GrandchildState.get_full_name(): { "value2": "set_value", }, @@ -3081,10 +3073,6 @@ async def test_get_state(mock_app: rx.App, token: str): child_state2.value = "set_c2_value" assert new_test_state.get_delta() == { - TestState.get_full_name(): { - "sum": 3.14, - "upper": "", - }, ChildState2.get_full_name(): { "value": "set_c2_value", }, @@ -3139,7 +3127,7 @@ async def test_get_state_from_sibling_not_cached(mock_app: rx.App, token: str): child3_var: int = 0 - @rx.var + @rx.var(cache=False) def v(self): pass @@ -3210,8 +3198,8 @@ def test_potentially_dirty_substates(): def bar(self) -> str: return "" - assert RxState._potentially_dirty_substates() == {State} - assert State._potentially_dirty_substates() == {C1} + assert RxState._potentially_dirty_substates() == set() + assert State._potentially_dirty_substates() == set() assert C1._potentially_dirty_substates() == set() @@ -3226,7 +3214,7 @@ def test_router_var_dep() -> None: class RouterVarDepState(RouterVarParentState): """A state with a router var dependency.""" - @rx.var(cache=True) + @rx.var def foo(self) -> str: return self.router.page.params.get("foo", "") @@ -3421,7 +3409,7 @@ class MixinState(State, mixin=True): _backend: int = 0 _backend_no_default: dict - @rx.var(cache=True) + @rx.var def computed(self) -> str: """A computed var on mixin state. diff --git a/tests/units/test_state_tree.py b/tests/units/test_state_tree.py index 44ff58818..70ef71cb8 100644 --- a/tests/units/test_state_tree.py +++ b/tests/units/test_state_tree.py @@ -42,7 +42,7 @@ class SubA_A_A_A(SubA_A_A): class SubA_A_A_B(SubA_A_A): """SubA_A_A_B is a child of SubA_A_A.""" - @rx.var(cache=True) + @rx.var def sub_a_a_a_cached(self) -> int: """A cached var. @@ -117,7 +117,7 @@ class TreeD(Root): d: int - @rx.var + @rx.var(cache=False) def d_var(self) -> int: """A computed var. @@ -156,7 +156,7 @@ class SubE_A_A_A_A(SubE_A_A_A): sub_e_a_a_a_a: int - @rx.var + @rx.var(cache=False) def sub_e_a_a_a_a_var(self) -> int: """A computed var. @@ -183,7 +183,7 @@ class SubE_A_A_A_D(SubE_A_A_A): sub_e_a_a_a_d: int - @rx.var(cache=True) + @rx.var def sub_e_a_a_a_d_var(self) -> int: """A computed var. diff --git a/tests/units/test_style.py b/tests/units/test_style.py index e1d652798..bb585fd22 100644 --- a/tests/units/test_style.py +++ b/tests/units/test_style.py @@ -1,6 +1,6 @@ from __future__ import annotations -from typing import Any, Dict +from typing import Any, Mapping import pytest @@ -379,7 +379,7 @@ class StyleState(rx.State): { "css": Var( _js_expr=f'({{ ["color"] : ("dark"+{StyleState.color}) }})' - ).to(Dict[str, str]) + ).to(Mapping[str, str]) }, ), ( diff --git a/tests/units/test_var.py b/tests/units/test_var.py index bfa8aa35a..a8e9cd88c 100644 --- a/tests/units/test_var.py +++ b/tests/units/test_var.py @@ -2,7 +2,7 @@ import json import math import sys import typing -from typing import Dict, List, Optional, Set, Tuple, Union, cast +from typing import Dict, List, Mapping, Optional, Set, Tuple, Union, cast import pytest from pandas import DataFrame @@ -270,7 +270,7 @@ def test_get_setter(prop: Var, expected): ([1, 2, 3], Var(_js_expr="[1, 2, 3]", _var_type=List[int])), ( {"a": 1, "b": 2}, - Var(_js_expr='({ ["a"] : 1, ["b"] : 2 })', _var_type=Dict[str, int]), + Var(_js_expr='({ ["a"] : 1, ["b"] : 2 })', _var_type=Mapping[str, int]), ), ], ) @@ -1004,7 +1004,7 @@ def test_all_number_operations(): assert ( str(even_more_complicated_number) - == "!(((Math.abs(Math.floor(((Math.floor(((-((-5.4 + 1)) * 2) / 3) / 2) % 3) ** 2))) || (2 && Math.round(((Math.floor(((-((-5.4 + 1)) * 2) / 3) / 2) % 3) ** 2)))) !== 0))" + == "!(isTrue((Math.abs(Math.floor(((Math.floor(((-((-5.4 + 1)) * 2) / 3) / 2) % 3) ** 2))) || (2 && Math.round(((Math.floor(((-((-5.4 + 1)) * 2) / 3) / 2) % 3) ** 2))))))" ) assert str(LiteralNumberVar.create(5) > False) == "(5 > 0)" @@ -1814,10 +1814,7 @@ def cv_fget(state: BaseState) -> int: ], ) def test_computed_var_deps(deps: List[Union[str, Var]], expected: Set[str]): - @computed_var( - deps=deps, - cache=True, - ) + @computed_var(deps=deps) def test_var(state) -> int: return 1 @@ -1835,10 +1832,7 @@ def test_computed_var_deps(deps: List[Union[str, Var]], expected: Set[str]): def test_invalid_computed_var_deps(deps: List): with pytest.raises(TypeError): - @computed_var( - deps=deps, - cache=True, - ) + @computed_var(deps=deps) def test_var(state) -> int: return 1 diff --git a/tests/units/vars/test_base.py b/tests/units/vars/test_base.py index 68bc0c38e..e4ae7327a 100644 --- a/tests/units/vars/test_base.py +++ b/tests/units/vars/test_base.py @@ -1,4 +1,4 @@ -from typing import Dict, List, Union +from typing import List, Mapping, Union import pytest @@ -37,12 +37,12 @@ class ChildGenericDict(GenericDict): ("a", str), ([1, 2, 3], List[int]), ([1, 2.0, "a"], List[Union[int, float, str]]), - ({"a": 1, "b": 2}, Dict[str, int]), - ({"a": 1, 2: "b"}, Dict[Union[int, str], Union[str, int]]), + ({"a": 1, "b": 2}, Mapping[str, int]), + ({"a": 1, 2: "b"}, Mapping[Union[int, str], Union[str, int]]), (CustomDict(), CustomDict), (ChildCustomDict(), ChildCustomDict), - (GenericDict({1: 1}), Dict[int, int]), - (ChildGenericDict({1: 1}), Dict[int, int]), + (GenericDict({1: 1}), Mapping[int, int]), + (ChildGenericDict({1: 1}), Mapping[int, int]), ], ) def test_figure_out_type(value, expected):