diff --git a/.github/workflows/integration_tests.yml b/.github/workflows/integration_tests.yml index 39d005611..b02604fd6 100644 --- a/.github/workflows/integration_tests.yml +++ b/.github/workflows/integration_tests.yml @@ -33,7 +33,7 @@ env: PR_TITLE: ${{ github.event.pull_request.title }} jobs: - example-counter: + example-counter-and-nba-proxy: env: OUTPUT_FILE: import_benchmark.json timeout-minutes: 30 @@ -114,6 +114,26 @@ jobs: --benchmark-json "./reflex-examples/counter/${{ env.OUTPUT_FILE }}" --branch-name "${{ github.head_ref || github.ref_name }}" --pr-id "${{ github.event.pull_request.id }}" --app-name "counter" + - name: Install requirements for nba proxy example + working-directory: ./reflex-examples/nba-proxy + run: | + poetry run uv pip install -r requirements.txt + - name: Install additional dependencies for DB access + run: poetry run uv pip install psycopg + - name: Check export --backend-only before init for nba-proxy example + working-directory: ./reflex-examples/nba-proxy + run: | + poetry run reflex export --backend-only + - name: Init Website for nba-proxy example + working-directory: ./reflex-examples/nba-proxy + run: | + poetry run reflex init --loglevel debug + - name: Run Website and Check for errors + run: | + # Check that npm is home + npm -v + poetry run bash scripts/integration.sh ./reflex-examples/nba-proxy dev + reflex-web: strategy: diff --git a/pyproject.toml b/pyproject.toml index 8691f04bb..160cfc535 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -85,8 +85,8 @@ build-backend = "poetry.core.masonry.api" target-version = "py39" output-format = "concise" lint.isort.split-on-trailing-comma = false -lint.select = ["B", "C4", "D", "E", "ERA", "F", "FURB", "I", "PERF", "PTH", "RUF", "SIM", "T", "W"] -lint.ignore = ["B008", "D205", "E501", "F403", "SIM115", "RUF006", "RUF012"] +lint.select = ["B", "C4", "D", "E", "ERA", "F", "FURB", "I", "PERF", "PTH", "RUF", "SIM", "T", "TRY", "W"] +lint.ignore = ["B008", "D205", "E501", "F403", "SIM115", "RUF006", "RUF012", "TRY0"] lint.pydocstyle.convention = "google" [tool.ruff.lint.per-file-ignores] diff --git a/reflex/.templates/web/utils/state.js b/reflex/.templates/web/utils/state.js index ec603fd13..5b8046347 100644 --- a/reflex/.templates/web/utils/state.js +++ b/reflex/.templates/web/utils/state.js @@ -3,6 +3,7 @@ import axios from "axios"; import io from "socket.io-client"; import JSON5 from "json5"; import env from "$/env.json"; +import reflexEnvironment from "$/reflex.json"; import Cookies from "universal-cookie"; import { useEffect, useRef, useState } from "react"; import Router, { useRouter } from "next/router"; @@ -407,6 +408,7 @@ export const connect = async ( socket.current = io(endpoint.href, { path: endpoint["pathname"], transports: transports, + protocols: env.TEST_MODE ? undefined : [reflexEnvironment.version], autoUnref: false, }); // Ensure undefined fields in events are sent as null instead of removed diff --git a/reflex/app.py b/reflex/app.py index 08cb4314e..7e868e730 100644 --- a/reflex/app.py +++ b/reflex/app.py @@ -463,14 +463,8 @@ class App(MiddlewareMixin, LifespanMixin): Returns: The generated component. - - Raises: - exceptions.MatchTypeError: If the return types of match cases in rx.match are different. """ - try: - return component if isinstance(component, Component) else component() - except exceptions.MatchTypeError: - raise + return component if isinstance(component, Component) else component() def add_page( self, @@ -564,11 +558,12 @@ class App(MiddlewareMixin, LifespanMixin): meta=meta, ) - def _compile_page(self, route: str): + def _compile_page(self, route: str, save_page: bool = True): """Compile a page. Args: route: The route of the page to compile. + save_page: If True, the compiled page is saved to self.pages. """ component, enable_state = compiler.compile_unevaluated_page( route, self.unevaluated_pages[route], self.state, self.style, self.theme @@ -579,7 +574,8 @@ class App(MiddlewareMixin, LifespanMixin): # Add the page. self._check_routes_conflict(route) - self.pages[route] = component + if save_page: + self.pages[route] = component def get_load_events(self, route: str) -> list[IndividualEventType[[], Any]]: """Get the load events for a route. @@ -879,14 +875,16 @@ class App(MiddlewareMixin, LifespanMixin): # If a theme component was provided, wrap the app with it app_wrappers[(20, "Theme")] = self.theme + should_compile = self._should_compile() + for route in self.unevaluated_pages: console.debug(f"Evaluating page: {route}") - self._compile_page(route) + self._compile_page(route, save_page=should_compile) # Add the optional endpoints (_upload) self._add_optional_endpoints() - if not self._should_compile(): + if not should_compile: return self._validate_var_dependencies() @@ -1530,7 +1528,11 @@ class EventNamespace(AsyncNamespace): sid: The Socket.IO session id. environ: The request information, including HTTP headers. """ - pass + subprotocol = environ.get("HTTP_SEC_WEBSOCKET_PROTOCOL", None) + if subprotocol and subprotocol != constants.Reflex.VERSION: + console.warn( + f"Frontend version {subprotocol} for session {sid} does not match the backend version {constants.Reflex.VERSION}." + ) def on_disconnect(self, sid): """Event for when the websocket disconnects. @@ -1563,10 +1565,36 @@ class EventNamespace(AsyncNamespace): Args: sid: The Socket.IO session id. data: The event data. + + Raises: + EventDeserializationError: If the event data is not a dictionary. """ fields = data - # Get the event. - event = Event(**{k: v for k, v in fields.items() if k in _EVENT_FIELDS}) + + if isinstance(fields, str): + console.warn( + "Received event data as a string. This generally should not happen and may indicate a bug." + f" Event data: {fields}" + ) + try: + fields = json.loads(fields) + except json.JSONDecodeError as ex: + raise exceptions.EventDeserializationError( + f"Failed to deserialize event data: {fields}." + ) from ex + + if not isinstance(fields, dict): + raise exceptions.EventDeserializationError( + f"Event data must be a dictionary, but received {fields} of type {type(fields)}." + ) + + try: + # Get the event. + event = Event(**{k: v for k, v in fields.items() if k in _EVENT_FIELDS}) + except (TypeError, ValueError) as ex: + raise exceptions.EventDeserializationError( + f"Failed to deserialize event data: {fields}." + ) from ex self.token_to_sid[event.token] = sid self.sid_to_token[sid] = event.token diff --git a/reflex/app_module_for_backend.py b/reflex/app_module_for_backend.py index 8109fc3d6..b0ae0a29f 100644 --- a/reflex/app_module_for_backend.py +++ b/reflex/app_module_for_backend.py @@ -7,14 +7,13 @@ from concurrent.futures import ThreadPoolExecutor from reflex import constants from reflex.utils import telemetry from reflex.utils.exec import is_prod_mode -from reflex.utils.prerequisites import get_app +from reflex.utils.prerequisites import get_and_validate_app if constants.CompileVars.APP != "app": raise AssertionError("unexpected variable name for 'app'") telemetry.send("compile") -app_module = get_app(reload=False) -app = getattr(app_module, constants.CompileVars.APP) +app, app_module = get_and_validate_app(reload=False) # For py3.9 compatibility when redis is used, we MUST add any decorator pages # before compiling the app in a thread to avoid event loop error (REF-2172). app._apply_decorated_pages() @@ -30,7 +29,7 @@ if is_prod_mode(): # ensure only "app" is exposed. del app_module del compile_future -del get_app +del get_and_validate_app del is_prod_mode del telemetry del constants diff --git a/reflex/components/component.py b/reflex/components/component.py index f39956608..58bdfe7c2 100644 --- a/reflex/components/component.py +++ b/reflex/components/component.py @@ -426,20 +426,22 @@ class Component(BaseComponent, ABC): else: continue + def determine_key(value): + # Try to create a var from the value + key = value if isinstance(value, Var) else LiteralVar.create(value) + + # Check that the var type is not None. + if key is None: + raise TypeError + + return key + # Check whether the key is a component prop. if types._issubclass(field_type, Var): # Used to store the passed types if var type is a union. passed_types = None try: - # Try to create a var from the value. - if isinstance(value, Var): - kwargs[key] = value - else: - kwargs[key] = LiteralVar.create(value) - - # Check that the var type is not None. - if kwargs[key] is None: - raise TypeError + kwargs[key] = determine_key(value) expected_type = fields[key].outer_type_.__args__[0] # validate literal fields. @@ -702,22 +704,21 @@ class Component(BaseComponent, ABC): # Import here to avoid circular imports. from reflex.components.base.bare import Bare from reflex.components.base.fragment import Fragment - from reflex.utils.exceptions import ComponentTypeError + from reflex.utils.exceptions import ChildrenTypeError # Filter out None props props = {key: value for key, value in props.items() if value is not None} def validate_children(children): for child in children: - if isinstance(child, tuple): + if isinstance(child, (tuple, list)): validate_children(child) + # Make sure the child is a valid type. - if not types._isinstance(child, ComponentChild): - raise ComponentTypeError( - "Children of Reflex components must be other components, " - "state vars, or primitive Python types. " - f"Got child {child} of type {type(child)}.", - ) + if isinstance(child, dict) or not types._isinstance( + child, ComponentChild + ): + raise ChildrenTypeError(component=cls.__name__, child=child) # Validate all the children. validate_children(children) diff --git a/reflex/components/lucide/icon.py b/reflex/components/lucide/icon.py index 04410ac56..6c7cbede7 100644 --- a/reflex/components/lucide/icon.py +++ b/reflex/components/lucide/icon.py @@ -2,13 +2,15 @@ from reflex.components.component import Component from reflex.utils import format -from reflex.vars.base import Var +from reflex.utils.imports import ImportVar +from reflex.vars.base import LiteralVar, Var +from reflex.vars.sequence import LiteralStringVar class LucideIconComponent(Component): """Lucide Icon Component.""" - library = "lucide-react@0.469.0" + library = "lucide-react@0.471.1" class Icon(LucideIconComponent): @@ -32,6 +34,7 @@ class Icon(LucideIconComponent): Raises: AttributeError: The errors tied to bad usage of the Icon component. ValueError: If the icon tag is invalid. + TypeError: If the icon name is not a string. Returns: The created component. @@ -39,7 +42,6 @@ class Icon(LucideIconComponent): if children: if len(children) == 1 and isinstance(children[0], str): props["tag"] = children[0] - children = [] else: raise AttributeError( f"Passing multiple children to Icon component is not allowed: remove positional arguments {children[1:]} to fix" @@ -47,24 +49,46 @@ class Icon(LucideIconComponent): if "tag" not in props: raise AttributeError("Missing 'tag' keyword-argument for Icon") + tag: str | Var | LiteralVar = props.pop("tag") + if isinstance(tag, LiteralVar): + if isinstance(tag, LiteralStringVar): + tag = tag._var_value + else: + raise TypeError(f"Icon name must be a string, got {type(tag)}") + elif isinstance(tag, Var): + return DynamicIcon.create(name=tag, **props) + if ( - not isinstance(props["tag"], str) - or format.to_snake_case(props["tag"]) not in LUCIDE_ICON_LIST + not isinstance(tag, str) + or format.to_snake_case(tag) not in LUCIDE_ICON_LIST ): raise ValueError( - f"Invalid icon tag: {props['tag']}. Please use one of the following: {', '.join(LUCIDE_ICON_LIST[0:25])}, ..." + f"Invalid icon tag: {tag}. Please use one of the following: {', '.join(LUCIDE_ICON_LIST[0:25])}, ..." "\nSee full list at https://lucide.dev/icons." ) - if props["tag"] in LUCIDE_ICON_MAPPING_OVERRIDE: - props["tag"] = LUCIDE_ICON_MAPPING_OVERRIDE[props["tag"]] + if tag in LUCIDE_ICON_MAPPING_OVERRIDE: + props["tag"] = LUCIDE_ICON_MAPPING_OVERRIDE[tag] else: - props["tag"] = ( - format.to_title_case(format.to_snake_case(props["tag"])) + "Icon" - ) + props["tag"] = format.to_title_case(format.to_snake_case(tag)) + "Icon" props["alias"] = f"Lucide{props['tag']}" props.setdefault("color", "var(--current-color)") - return super().create(*children, **props) + return super().create(**props) + + +class DynamicIcon(LucideIconComponent): + """A DynamicIcon component.""" + + tag = "DynamicIcon" + + name: Var[str] + + def _get_imports(self): + _imports = super()._get_imports() + if self.library: + _imports.pop(self.library) + _imports["lucide-react/dynamic"] = [ImportVar("DynamicIcon", install=False)] + return _imports LUCIDE_ICON_LIST = [ @@ -846,6 +870,7 @@ LUCIDE_ICON_LIST = [ "house", "house_plug", "house_plus", + "house_wifi", "ice_cream_bowl", "ice_cream_cone", "id_card", @@ -1534,6 +1559,7 @@ LUCIDE_ICON_LIST = [ "trending_up_down", "triangle", "triangle_alert", + "triangle_dashed", "triangle_right", "trophy", "truck", diff --git a/reflex/components/lucide/icon.pyi b/reflex/components/lucide/icon.pyi index 39a1da0e6..6094cfd87 100644 --- a/reflex/components/lucide/icon.pyi +++ b/reflex/components/lucide/icon.pyi @@ -104,12 +104,60 @@ class Icon(LucideIconComponent): Raises: AttributeError: The errors tied to bad usage of the Icon component. ValueError: If the icon tag is invalid. + TypeError: If the icon name is not a string. Returns: The created component. """ ... +class DynamicIcon(LucideIconComponent): + @overload + @classmethod + def create( # type: ignore + cls, + *children, + name: Optional[Union[Var[str], str]] = None, + style: Optional[Style] = None, + key: Optional[Any] = None, + id: Optional[Any] = None, + class_name: Optional[Any] = None, + autofocus: Optional[bool] = None, + custom_attrs: Optional[Dict[str, Union[Var, Any]]] = None, + on_blur: Optional[EventType[[], BASE_STATE]] = None, + on_click: Optional[EventType[[], BASE_STATE]] = None, + on_context_menu: Optional[EventType[[], BASE_STATE]] = None, + on_double_click: Optional[EventType[[], BASE_STATE]] = None, + on_focus: Optional[EventType[[], BASE_STATE]] = None, + on_mount: Optional[EventType[[], BASE_STATE]] = None, + on_mouse_down: Optional[EventType[[], BASE_STATE]] = None, + on_mouse_enter: Optional[EventType[[], BASE_STATE]] = None, + on_mouse_leave: Optional[EventType[[], BASE_STATE]] = None, + on_mouse_move: Optional[EventType[[], BASE_STATE]] = None, + on_mouse_out: Optional[EventType[[], BASE_STATE]] = None, + on_mouse_over: Optional[EventType[[], BASE_STATE]] = None, + on_mouse_up: Optional[EventType[[], BASE_STATE]] = None, + on_scroll: Optional[EventType[[], BASE_STATE]] = None, + on_unmount: Optional[EventType[[], BASE_STATE]] = None, + **props, + ) -> "DynamicIcon": + """Create the component. + + Args: + *children: The children of the component. + style: The style of the component. + key: A unique key for the component. + id: The id for the component. + class_name: The class name for the component. + autofocus: Whether the component should take the focus once the page is loaded + custom_attrs: custom attribute + **props: The props of the component. + + Returns: + The component. + """ + ... + LUCIDE_ICON_LIST = [ "a_arrow_down", "a_arrow_up", @@ -889,6 +937,7 @@ LUCIDE_ICON_LIST = [ "house", "house_plug", "house_plus", + "house_wifi", "ice_cream_bowl", "ice_cream_cone", "id_card", @@ -1577,6 +1626,7 @@ LUCIDE_ICON_LIST = [ "trending_up_down", "triangle", "triangle_alert", + "triangle_dashed", "triangle_right", "trophy", "truck", diff --git a/reflex/config.py b/reflex/config.py index 7614417d5..8511694fb 100644 --- a/reflex/config.py +++ b/reflex/config.py @@ -12,6 +12,7 @@ import threading import urllib.parse from importlib.util import find_spec from pathlib import Path +from types import ModuleType from typing import ( TYPE_CHECKING, Any, @@ -607,6 +608,9 @@ class Config(Base): # The name of the app (should match the name of the app directory). app_name: str + # The path to the app module. + app_module_import: Optional[str] = None + # The log level to use. loglevel: constants.LogLevel = constants.LogLevel.DEFAULT @@ -729,6 +733,19 @@ class Config(Base): "REDIS_URL is required when using the redis state manager." ) + @property + def app_module(self) -> ModuleType | None: + """Return the app module if `app_module_import` is set. + + Returns: + The app module. + """ + return ( + importlib.import_module(self.app_module_import) + if self.app_module_import + else None + ) + @property def module(self) -> str: """Get the module name of the app. @@ -736,6 +753,8 @@ class Config(Base): Returns: The module name. """ + if self.app_module is not None: + return self.app_module.__name__ return ".".join([self.app_name, self.app_name]) def update_from_env(self) -> dict[str, Any]: @@ -874,7 +893,7 @@ def get_config(reload: bool = False) -> Config: return cached_rxconfig.config with _config_lock: - sys_path = sys.path.copy() + orig_sys_path = sys.path.copy() sys.path.clear() sys.path.append(str(Path.cwd())) try: @@ -882,9 +901,14 @@ def get_config(reload: bool = False) -> Config: return _get_config() except Exception: # If the module import fails, try to import with the original sys.path. - sys.path.extend(sys_path) + sys.path.extend(orig_sys_path) return _get_config() finally: + # Find any entries added to sys.path by rxconfig.py itself. + extra_paths = [ + p for p in sys.path if p not in orig_sys_path and p != str(Path.cwd()) + ] # Restore the original sys.path. sys.path.clear() - sys.path.extend(sys_path) + sys.path.extend(extra_paths) + sys.path.extend(orig_sys_path) diff --git a/reflex/constants/__init__.py b/reflex/constants/__init__.py index e816da0f7..f5946bf5e 100644 --- a/reflex/constants/__init__.py +++ b/reflex/constants/__init__.py @@ -1,6 +1,7 @@ """The constants package.""" from .base import ( + APP_HARNESS_FLAG, COOKIES, IS_LINUX, IS_MACOS, diff --git a/reflex/constants/base.py b/reflex/constants/base.py index af96583ad..f737858c0 100644 --- a/reflex/constants/base.py +++ b/reflex/constants/base.py @@ -257,6 +257,7 @@ SESSION_STORAGE = "session_storage" # Testing variables. # Testing os env set by pytest when running a test case. PYTEST_CURRENT_TEST = "PYTEST_CURRENT_TEST" +APP_HARNESS_FLAG = "APP_HARNESS_FLAG" REFLEX_VAR_OPENING_TAG = "" REFLEX_VAR_CLOSING_TAG = "" diff --git a/reflex/custom_components/custom_components.py b/reflex/custom_components/custom_components.py index 4a169802f..8000e7f4c 100644 --- a/reflex/custom_components/custom_components.py +++ b/reflex/custom_components/custom_components.py @@ -421,12 +421,13 @@ def _run_commands_in_subprocess(cmds: list[str]) -> bool: console.debug(f"Running command: {' '.join(cmds)}") try: result = subprocess.run(cmds, capture_output=True, text=True, check=True) - console.debug(result.stdout) - return True except subprocess.CalledProcessError as cpe: console.error(cpe.stdout) console.error(cpe.stderr) return False + else: + console.debug(result.stdout) + return True def _make_pyi_files(): @@ -931,10 +932,11 @@ def _get_file_from_prompt_in_loop() -> Tuple[bytes, str] | None: file_extension = image_filepath.suffix try: image_file = image_filepath.read_bytes() - return image_file, file_extension except OSError as ose: console.error(f"Unable to read the {file_extension} file due to {ose}") raise typer.Exit(code=1) from ose + else: + return image_file, file_extension console.debug(f"File extension detected: {file_extension}") return None diff --git a/reflex/event.py b/reflex/event.py index 9ade75fa3..f44fa2b1c 100644 --- a/reflex/event.py +++ b/reflex/event.py @@ -1534,7 +1534,7 @@ def get_handler_args( def fix_events( - events: list[EventHandler | EventSpec] | None, + events: list[EventSpec | EventHandler] | None, token: str, router_data: dict[str, Any] | None = None, ) -> list[Event]: diff --git a/reflex/reflex.py b/reflex/reflex.py index 83697e821..11cc3b597 100644 --- a/reflex/reflex.py +++ b/reflex/reflex.py @@ -440,7 +440,11 @@ def deploy( config.app_name, "--app-name", help="The name of the App to deploy under.", - hidden=True, + ), + app_id: str = typer.Option( + None, + "--app-id", + help="The ID of the App to deploy over.", ), regions: List[str] = typer.Option( [], @@ -480,6 +484,11 @@ def deploy( "--project", help="project id to deploy to", ), + project_name: Optional[str] = typer.Option( + None, + "--project-name", + help="The name of the project to deploy to.", + ), token: Optional[str] = typer.Option( None, "--token", @@ -503,13 +512,6 @@ def deploy( # Set the log level. console.set_log_level(loglevel) - if not token: - # make sure user is logged in. - if interactive: - hosting_cli.login() - else: - raise SystemExit("Token is required for non-interactive mode.") - # Only check requirements if interactive. # There is user interaction for requirements update. if interactive: @@ -526,6 +528,7 @@ def deploy( hosting_cli.deploy( app_name=app_name, + app_id=app_id, export_fn=lambda zip_dest_dir, api_url, deploy_url, @@ -549,6 +552,8 @@ def deploy( loglevel=type(loglevel).INFO, # type: ignore token=token, project=project, + config_path=config_path, + project_name=project_name, **extra, ) diff --git a/reflex/state.py b/reflex/state.py index d11e1fa3f..cb5b2b706 100644 --- a/reflex/state.py +++ b/reflex/state.py @@ -1773,9 +1773,9 @@ class BaseState(Base, ABC, extra=pydantic.Extra.allow): except Exception as ex: state._clean() - app_instance = getattr(prerequisites.get_app(), constants.CompileVars.APP) - - event_specs = app_instance.backend_exception_handler(ex) + event_specs = ( + prerequisites.get_and_validate_app().app.backend_exception_handler(ex) + ) if event_specs is None: return StateUpdate() @@ -1885,9 +1885,9 @@ class BaseState(Base, ABC, extra=pydantic.Extra.allow): except Exception as ex: telemetry.send_error(ex, context="backend") - app_instance = getattr(prerequisites.get_app(), constants.CompileVars.APP) - - event_specs = app_instance.backend_exception_handler(ex) + event_specs = ( + prerequisites.get_and_validate_app().app.backend_exception_handler(ex) + ) yield state._as_state_update( handler, @@ -2400,8 +2400,9 @@ class FrontendEventExceptionState(State): component_stack: The stack trace of the component where the exception occurred. """ - app_instance = getattr(prerequisites.get_app(), constants.CompileVars.APP) - app_instance.frontend_exception_handler(Exception(stack)) + prerequisites.get_and_validate_app().app.frontend_exception_handler( + Exception(stack) + ) class UpdateVarsInternalState(State): @@ -2439,15 +2440,16 @@ class OnLoadInternalState(State): The list of events to queue for on load handling. """ # Do not app._compile()! It should be already compiled by now. - app = getattr(prerequisites.get_app(), constants.CompileVars.APP) - load_events = app.get_load_events(self.router.page.path) + load_events = prerequisites.get_and_validate_app().app.get_load_events( + self.router.page.path + ) if not load_events: self.is_hydrated = True return # Fast path for navigation with no on_load events defined. self.is_hydrated = False return [ *fix_events( - load_events, + cast(list[Union[EventSpec, EventHandler]], load_events), self.router.session.client_token, router_data=self.router_data, ), @@ -2606,7 +2608,7 @@ class StateProxy(wrapt.ObjectProxy): """ super().__init__(state_instance) # compile is not relevant to backend logic - self._self_app = getattr(prerequisites.get_app(), constants.CompileVars.APP) + self._self_app = prerequisites.get_and_validate_app().app self._self_substate_path = tuple(state_instance.get_full_name().split(".")) self._self_actx = None self._self_mutable = False @@ -3699,8 +3701,7 @@ def get_state_manager() -> StateManager: Returns: The state manager. """ - app = getattr(prerequisites.get_app(), constants.CompileVars.APP) - return app.state_manager + return prerequisites.get_and_validate_app().app.state_manager class MutableProxy(wrapt.ObjectProxy): diff --git a/reflex/testing.py b/reflex/testing.py index b3dedf398..a1083d250 100644 --- a/reflex/testing.py +++ b/reflex/testing.py @@ -282,6 +282,7 @@ class AppHarness: before_decorated_pages = reflex.app.DECORATED_PAGES[self.app_name].copy() # Ensure the AppHarness test does not skip State assignment due to running via pytest os.environ.pop(reflex.constants.PYTEST_CURRENT_TEST, None) + os.environ[reflex.constants.APP_HARNESS_FLAG] = "true" self.app_module = reflex.utils.prerequisites.get_compiled_app( # Do not reload the module for pre-existing apps (only apps generated from source) reload=self.app_source is not None diff --git a/reflex/utils/build.py b/reflex/utils/build.py index e263374e1..9ea941792 100644 --- a/reflex/utils/build.py +++ b/reflex/utils/build.py @@ -13,13 +13,17 @@ from rich.progress import MofNCompleteColumn, Progress, TimeElapsedColumn from reflex import constants from reflex.config import get_config from reflex.utils import console, path_ops, prerequisites, processes +from reflex.utils.exec import is_in_app_harness def set_env_json(): """Write the upload url to a REFLEX_JSON.""" path_ops.update_json_file( str(prerequisites.get_web_dir() / constants.Dirs.ENV_JSON), - {endpoint.name: endpoint.get_url() for endpoint in constants.Endpoint}, + { + **{endpoint.name: endpoint.get_url() for endpoint in constants.Endpoint}, + "TEST_MODE": is_in_app_harness(), + }, ) diff --git a/reflex/utils/exceptions.py b/reflex/utils/exceptions.py index aa6c51430..d35643753 100644 --- a/reflex/utils/exceptions.py +++ b/reflex/utils/exceptions.py @@ -1,5 +1,7 @@ """Custom Exceptions.""" +from typing import Any + class ReflexError(Exception): """Base exception for all Reflex exceptions.""" @@ -29,6 +31,22 @@ class ComponentTypeError(ReflexError, TypeError): """Custom TypeError for component related errors.""" +class ChildrenTypeError(ComponentTypeError): + """Raised when the children prop of a component is not a valid type.""" + + def __init__(self, component: str, child: Any): + """Initialize the exception. + + Args: + component: The name of the component. + child: The child that caused the error. + """ + super().__init__( + f"Component {component} received child {child} of type {type(child)}. " + "Accepted types are other components, state vars, or primitive Python types (dict excluded)." + ) + + class EventHandlerTypeError(ReflexError, TypeError): """Custom TypeError for event handler related errors.""" @@ -209,6 +227,10 @@ class SystemPackageMissingError(ReflexError): ) +class EventDeserializationError(ReflexError, ValueError): + """Raised when an event cannot be deserialized.""" + + class InvalidLockWarningThresholdError(ReflexError): """Raised when an invalid lock warning threshold is provided.""" diff --git a/reflex/utils/exec.py b/reflex/utils/exec.py index 0f6a80a0d..036f052f0 100644 --- a/reflex/utils/exec.py +++ b/reflex/utils/exec.py @@ -240,6 +240,28 @@ def run_backend( run_uvicorn_backend(host, port, loglevel) +def get_reload_dirs() -> list[str]: + """Get the reload directories for the backend. + + Returns: + The reload directories for the backend. + """ + config = get_config() + reload_dirs = [config.app_name] + if config.app_module is not None and config.app_module.__file__: + module_path = Path(config.app_module.__file__).resolve().parent + while module_path.parent.name: + for parent_file in module_path.parent.iterdir(): + if parent_file == "__init__.py": + # go up a level to find dir without `__init__.py` + module_path = module_path.parent + break + else: + break + reload_dirs.append(str(module_path)) + return reload_dirs + + def run_uvicorn_backend(host, port, loglevel: LogLevel): """Run the backend in development mode using Uvicorn. @@ -256,7 +278,7 @@ def run_uvicorn_backend(host, port, loglevel: LogLevel): port=port, log_level=loglevel.value, reload=True, - reload_dirs=[get_config().app_name], + reload_dirs=get_reload_dirs(), ) @@ -281,7 +303,7 @@ def run_granian_backend(host, port, loglevel: LogLevel): interface=Interfaces.ASGI, log_level=LogLevels(loglevel.value), reload=True, - reload_paths=[Path(get_config().app_name)], + reload_paths=get_reload_dirs(), reload_ignore_dirs=[".web"], ).serve() except ImportError: @@ -487,6 +509,15 @@ def is_testing_env() -> bool: return constants.PYTEST_CURRENT_TEST in os.environ +def is_in_app_harness() -> bool: + """Whether the app is running in the app harness. + + Returns: + True if the app is running in the app harness. + """ + return constants.APP_HARNESS_FLAG in os.environ + + def is_prod_mode() -> bool: """Check if the app is running in production mode. diff --git a/reflex/utils/prerequisites.py b/reflex/utils/prerequisites.py index f0a5ef575..7fd5b5075 100644 --- a/reflex/utils/prerequisites.py +++ b/reflex/utils/prerequisites.py @@ -17,11 +17,12 @@ import stat import sys import tempfile import time +import typing import zipfile from datetime import datetime from pathlib import Path from types import ModuleType -from typing import Callable, List, Optional +from typing import Callable, List, NamedTuple, Optional import httpx import typer @@ -42,9 +43,19 @@ from reflex.utils.exceptions import ( from reflex.utils.format import format_library_name from reflex.utils.registry import _get_npm_registry +if typing.TYPE_CHECKING: + from reflex.app import App + CURRENTLY_INSTALLING_NODE = False +class AppInfo(NamedTuple): + """A tuple containing the app instance and module.""" + + app: App + module: ModuleType + + @dataclasses.dataclass(frozen=True) class Template: """A template for a Reflex app.""" @@ -253,6 +264,22 @@ def windows_npm_escape_hatch() -> bool: return environment.REFLEX_USE_NPM.get() +def _check_app_name(config: Config): + """Check if the app name is set in the config. + + Args: + config: The config object. + + Raises: + RuntimeError: If the app name is not set in the config. + """ + if not config.app_name: + raise RuntimeError( + "Cannot get the app module because `app_name` is not set in rxconfig! " + "If this error occurs in a reflex test case, ensure that `get_app` is mocked." + ) + + def get_app(reload: bool = False) -> ModuleType: """Get the app module based on the default config. @@ -263,22 +290,23 @@ def get_app(reload: bool = False) -> ModuleType: The app based on the default config. Raises: - RuntimeError: If the app name is not set in the config. + Exception: If an error occurs while getting the app module. """ from reflex.utils import telemetry try: environment.RELOAD_CONFIG.set(reload) config = get_config() - if not config.app_name: - raise RuntimeError( - "Cannot get the app module because `app_name` is not set in rxconfig! " - "If this error occurs in a reflex test case, ensure that `get_app` is mocked." - ) + + _check_app_name(config) + module = config.module sys.path.insert(0, str(Path.cwd())) - app = __import__(module, fromlist=(constants.CompileVars.APP,)) - + app = ( + __import__(module, fromlist=(constants.CompileVars.APP,)) + if not config.app_module + else config.app_module + ) if reload: from reflex.state import reload_state_module @@ -287,11 +315,34 @@ def get_app(reload: bool = False) -> ModuleType: # Reload the app module. importlib.reload(app) - - return app except Exception as ex: telemetry.send_error(ex, context="frontend") raise + else: + return app + + +def get_and_validate_app(reload: bool = False) -> AppInfo: + """Get the app instance based on the default config and validate it. + + Args: + reload: Re-import the app module from disk + + Returns: + The app instance and the app module. + + Raises: + RuntimeError: If the app instance is not an instance of rx.App. + """ + from reflex.app import App + + app_module = get_app(reload=reload) + app = getattr(app_module, constants.CompileVars.APP) + if not isinstance(app, App): + raise RuntimeError( + "The app instance in the specified app_module_import in rxconfig must be an instance of rx.App." + ) + return AppInfo(app=app, module=app_module) def get_compiled_app(reload: bool = False, export: bool = False) -> ModuleType: @@ -304,8 +355,7 @@ def get_compiled_app(reload: bool = False, export: bool = False) -> ModuleType: Returns: The compiled app based on the default config. """ - app_module = get_app(reload=reload) - app = getattr(app_module, constants.CompileVars.APP) + app, app_module = get_and_validate_app(reload=reload) # For py3.9 compatibility when redis is used, we MUST add any decorator pages # before compiling the app in a thread to avoid event loop error (REF-2172). app._apply_decorated_pages() @@ -1143,11 +1193,12 @@ def ensure_reflex_installation_id() -> Optional[int]: if installation_id is None: installation_id = random.getrandbits(128) installation_id_file.write_text(str(installation_id)) - # If we get here, installation_id is definitely set - return installation_id except Exception as e: console.debug(f"Failed to ensure reflex installation id: {e}") return None + else: + # If we get here, installation_id is definitely set + return installation_id def initialize_reflex_user_directory(): @@ -1361,19 +1412,22 @@ def create_config_init_app_from_remote_template(app_name: str, template_url: str except OSError as ose: console.error(f"Failed to create temp directory for extracting zip: {ose}") raise typer.Exit(1) from ose + try: zipfile.ZipFile(zip_file_path).extractall(path=unzip_dir) # The zip file downloaded from github looks like: # repo-name-branch/**/*, so we need to remove the top level directory. - if len(subdirs := os.listdir(unzip_dir)) != 1: - console.error(f"Expected one directory in the zip, found {subdirs}") - raise typer.Exit(1) - template_dir = unzip_dir / subdirs[0] - console.debug(f"Template folder is located at {template_dir}") except Exception as uze: console.error(f"Failed to unzip the template: {uze}") raise typer.Exit(1) from uze + if len(subdirs := os.listdir(unzip_dir)) != 1: + console.error(f"Expected one directory in the zip, found {subdirs}") + raise typer.Exit(1) + + template_dir = unzip_dir / subdirs[0] + console.debug(f"Template folder is located at {template_dir}") + # Move the rxconfig file here first. path_ops.mv(str(template_dir / constants.Config.FILE), constants.Config.FILE) new_config = get_config(reload=True) diff --git a/reflex/utils/telemetry.py b/reflex/utils/telemetry.py index 99bb6dd37..67ec45b82 100644 --- a/reflex/utils/telemetry.py +++ b/reflex/utils/telemetry.py @@ -156,9 +156,10 @@ def _prepare_event(event: str, **kwargs) -> dict: def _send_event(event_data: dict) -> bool: try: httpx.post(POSTHOG_API_URL, json=event_data) - return True except Exception: return False + else: + return True def _send(event, telemetry_enabled, **kwargs): diff --git a/reflex/utils/types.py b/reflex/utils/types.py index b8bcbf2d6..ac30342e2 100644 --- a/reflex/utils/types.py +++ b/reflex/utils/types.py @@ -829,6 +829,22 @@ StateBases = get_base_class(StateVar) StateIterBases = get_base_class(StateIterVar) +def safe_issubclass(cls: Type, cls_check: Type | Tuple[Type, ...]): + """Check if a class is a subclass of another class. Returns False if internal error occurs. + + Args: + cls: The class to check. + cls_check: The class to check against. + + Returns: + Whether the class is a subclass of the other class. + """ + try: + return issubclass(cls, cls_check) + except TypeError: + return False + + def typehint_issubclass(possible_subclass: Any, possible_superclass: Any) -> bool: """Check if a type hint is a subclass of another type hint. diff --git a/reflex/vars/base.py b/reflex/vars/base.py index fb0d2babd..1bb346c65 100644 --- a/reflex/vars/base.py +++ b/reflex/vars/base.py @@ -26,6 +26,7 @@ from typing import ( Iterable, List, Literal, + Mapping, NoReturn, Optional, Set, @@ -65,6 +66,7 @@ from reflex.utils.types import ( _isinstance, get_origin, has_args, + safe_issubclass, unionize, ) @@ -128,7 +130,7 @@ class VarData: state: str = "", field_name: str = "", imports: ImportDict | ParsedImportDict | None = None, - hooks: dict[str, VarData | None] | None = None, + hooks: Mapping[str, VarData | None] | None = None, deps: list[Var] | None = None, position: Hooks.HookPosition | None = None, ): @@ -613,8 +615,8 @@ class Var(Generic[VAR_TYPE]): @overload def to( self, - output: type[dict], - ) -> ObjectVar[dict]: ... + output: type[Mapping], + ) -> ObjectVar[Mapping]: ... @overload def to( @@ -656,7 +658,9 @@ class Var(Generic[VAR_TYPE]): # If the first argument is a python type, we map it to the corresponding Var type. for var_subclass in _var_subclasses[::-1]: - if fixed_output_type in var_subclass.python_types: + if fixed_output_type in var_subclass.python_types or safe_issubclass( + fixed_output_type, var_subclass.python_types + ): return self.to(var_subclass.var_subclass, output) if fixed_output_type is None: @@ -790,7 +794,7 @@ class Var(Generic[VAR_TYPE]): return False if issubclass(type_, list): return [] - if issubclass(type_, dict): + if issubclass(type_, Mapping): return {} if issubclass(type_, tuple): return () @@ -996,7 +1000,7 @@ class Var(Generic[VAR_TYPE]): f"$/{constants.Dirs.STATE_PATH}": [imports.ImportVar(tag="refs")] } ), - ).to(ObjectVar, Dict[str, str]) + ).to(ObjectVar, Mapping[str, str]) return refs[LiteralVar.create(str(self))] @deprecated("Use `.js_type()` instead.") @@ -1343,7 +1347,7 @@ class LiteralVar(Var): serialized_value = serializers.serialize(value) if serialized_value is not None: - if isinstance(serialized_value, dict): + if isinstance(serialized_value, Mapping): return LiteralObjectVar.create( serialized_value, _var_type=type(value), @@ -1468,7 +1472,7 @@ def var_operation( ) -> Callable[P, ArrayVar[LIST_T]]: ... -OBJECT_TYPE = TypeVar("OBJECT_TYPE", bound=Dict) +OBJECT_TYPE = TypeVar("OBJECT_TYPE", bound=Mapping) @overload @@ -1543,8 +1547,8 @@ def figure_out_type(value: Any) -> types.GenericType: return Set[unionize(*(figure_out_type(v) for v in value))] if isinstance(value, tuple): return Tuple[unionize(*(figure_out_type(v) for v in value)), ...] - if isinstance(value, dict): - return Dict[ + if isinstance(value, Mapping): + return Mapping[ unionize(*(figure_out_type(k) for k in value)), unionize(*(figure_out_type(v) for v in value.values())), ] @@ -1968,10 +1972,10 @@ class ComputedVar(Var[RETURN_TYPE]): @overload def __get__( - self: ComputedVar[dict[DICT_KEY, DICT_VAL]], + self: ComputedVar[Mapping[DICT_KEY, DICT_VAL]], instance: None, owner: Type, - ) -> ObjectVar[dict[DICT_KEY, DICT_VAL]]: ... + ) -> ObjectVar[Mapping[DICT_KEY, DICT_VAL]]: ... @overload def __get__( @@ -2878,11 +2882,14 @@ V = TypeVar("V") BASE_TYPE = TypeVar("BASE_TYPE", bound=Base) +FIELD_TYPE = TypeVar("FIELD_TYPE") +MAPPING_TYPE = TypeVar("MAPPING_TYPE", bound=Mapping) -class Field(Generic[T]): + +class Field(Generic[FIELD_TYPE]): """Shadow class for Var to allow for type hinting in the IDE.""" - def __set__(self, instance, value: T): + def __set__(self, instance, value: FIELD_TYPE): """Set the Var. Args: @@ -2894,7 +2901,9 @@ class Field(Generic[T]): def __get__(self: Field[bool], instance: None, owner) -> BooleanVar: ... @overload - def __get__(self: Field[int], instance: None, owner) -> NumberVar: ... + def __get__( + self: Field[int] | Field[float] | Field[int | float], instance: None, owner + ) -> NumberVar: ... @overload def __get__(self: Field[str], instance: None, owner) -> StringVar: ... @@ -2911,8 +2920,8 @@ class Field(Generic[T]): @overload def __get__( - self: Field[Dict[str, V]], instance: None, owner - ) -> ObjectVar[Dict[str, V]]: ... + self: Field[MAPPING_TYPE], instance: None, owner + ) -> ObjectVar[MAPPING_TYPE]: ... @overload def __get__( @@ -2920,10 +2929,10 @@ class Field(Generic[T]): ) -> ObjectVar[BASE_TYPE]: ... @overload - def __get__(self, instance: None, owner) -> Var[T]: ... + def __get__(self, instance: None, owner) -> Var[FIELD_TYPE]: ... @overload - def __get__(self, instance, owner) -> T: ... + def __get__(self, instance, owner) -> FIELD_TYPE: ... def __get__(self, instance, owner): # type: ignore """Get the Var. @@ -2934,7 +2943,7 @@ class Field(Generic[T]): """ -def field(value: T) -> Field[T]: +def field(value: FIELD_TYPE) -> Field[FIELD_TYPE]: """Create a Field with a value. Args: diff --git a/reflex/vars/object.py b/reflex/vars/object.py index 5de431f5a..7b951c559 100644 --- a/reflex/vars/object.py +++ b/reflex/vars/object.py @@ -8,8 +8,8 @@ import typing from inspect import isclass from typing import ( Any, - Dict, List, + Mapping, NoReturn, Tuple, Type, @@ -19,6 +19,8 @@ from typing import ( overload, ) +from typing_extensions import is_typeddict + from reflex.utils import types from reflex.utils.exceptions import VarAttributeError from reflex.utils.types import GenericType, get_attribute_access_type, get_origin @@ -36,7 +38,7 @@ from .base import ( from .number import BooleanVar, NumberVar, raise_unsupported_operand_types from .sequence import ArrayVar, StringVar -OBJECT_TYPE = TypeVar("OBJECT_TYPE") +OBJECT_TYPE = TypeVar("OBJECT_TYPE", covariant=True) KEY_TYPE = TypeVar("KEY_TYPE") VALUE_TYPE = TypeVar("VALUE_TYPE") @@ -46,7 +48,7 @@ ARRAY_INNER_TYPE = TypeVar("ARRAY_INNER_TYPE") OTHER_KEY_TYPE = TypeVar("OTHER_KEY_TYPE") -class ObjectVar(Var[OBJECT_TYPE], python_types=dict): +class ObjectVar(Var[OBJECT_TYPE], python_types=Mapping): """Base class for immutable object vars.""" def _key_type(self) -> Type: @@ -59,7 +61,7 @@ class ObjectVar(Var[OBJECT_TYPE], python_types=dict): @overload def _value_type( - self: ObjectVar[Dict[Any, VALUE_TYPE]], + self: ObjectVar[Mapping[Any, VALUE_TYPE]], ) -> Type[VALUE_TYPE]: ... @overload @@ -74,7 +76,7 @@ class ObjectVar(Var[OBJECT_TYPE], python_types=dict): fixed_type = get_origin(self._var_type) or self._var_type if not isclass(fixed_type): return Any - args = get_args(self._var_type) if issubclass(fixed_type, dict) else () + args = get_args(self._var_type) if issubclass(fixed_type, Mapping) else () return args[1] if args else Any def keys(self) -> ArrayVar[List[str]]: @@ -87,7 +89,7 @@ class ObjectVar(Var[OBJECT_TYPE], python_types=dict): @overload def values( - self: ObjectVar[Dict[Any, VALUE_TYPE]], + self: ObjectVar[Mapping[Any, VALUE_TYPE]], ) -> ArrayVar[List[VALUE_TYPE]]: ... @overload @@ -103,7 +105,7 @@ class ObjectVar(Var[OBJECT_TYPE], python_types=dict): @overload def entries( - self: ObjectVar[Dict[Any, VALUE_TYPE]], + self: ObjectVar[Mapping[Any, VALUE_TYPE]], ) -> ArrayVar[List[Tuple[str, VALUE_TYPE]]]: ... @overload @@ -133,49 +135,55 @@ class ObjectVar(Var[OBJECT_TYPE], python_types=dict): # NoReturn is used here to catch when key value is Any @overload def __getitem__( - self: ObjectVar[Dict[Any, NoReturn]], + self: ObjectVar[Mapping[Any, NoReturn]], key: Var | Any, ) -> Var: ... + @overload + def __getitem__( + self: (ObjectVar[Mapping[Any, bool]]), + key: Var | Any, + ) -> BooleanVar: ... + @overload def __getitem__( self: ( - ObjectVar[Dict[Any, int]] - | ObjectVar[Dict[Any, float]] - | ObjectVar[Dict[Any, int | float]] + ObjectVar[Mapping[Any, int]] + | ObjectVar[Mapping[Any, float]] + | ObjectVar[Mapping[Any, int | float]] ), key: Var | Any, ) -> NumberVar: ... @overload def __getitem__( - self: ObjectVar[Dict[Any, str]], + self: ObjectVar[Mapping[Any, str]], key: Var | Any, ) -> StringVar: ... @overload def __getitem__( - self: ObjectVar[Dict[Any, list[ARRAY_INNER_TYPE]]], + self: ObjectVar[Mapping[Any, list[ARRAY_INNER_TYPE]]], key: Var | Any, ) -> ArrayVar[list[ARRAY_INNER_TYPE]]: ... @overload def __getitem__( - self: ObjectVar[Dict[Any, set[ARRAY_INNER_TYPE]]], + self: ObjectVar[Mapping[Any, set[ARRAY_INNER_TYPE]]], key: Var | Any, ) -> ArrayVar[set[ARRAY_INNER_TYPE]]: ... @overload def __getitem__( - self: ObjectVar[Dict[Any, tuple[ARRAY_INNER_TYPE, ...]]], + self: ObjectVar[Mapping[Any, tuple[ARRAY_INNER_TYPE, ...]]], key: Var | Any, ) -> ArrayVar[tuple[ARRAY_INNER_TYPE, ...]]: ... @overload def __getitem__( - self: ObjectVar[Dict[Any, dict[OTHER_KEY_TYPE, VALUE_TYPE]]], + self: ObjectVar[Mapping[Any, Mapping[OTHER_KEY_TYPE, VALUE_TYPE]]], key: Var | Any, - ) -> ObjectVar[dict[OTHER_KEY_TYPE, VALUE_TYPE]]: ... + ) -> ObjectVar[Mapping[OTHER_KEY_TYPE, VALUE_TYPE]]: ... def __getitem__(self, key: Var | Any) -> Var: """Get an item from the object. @@ -195,49 +203,49 @@ class ObjectVar(Var[OBJECT_TYPE], python_types=dict): # NoReturn is used here to catch when key value is Any @overload def __getattr__( - self: ObjectVar[Dict[Any, NoReturn]], + self: ObjectVar[Mapping[Any, NoReturn]], name: str, ) -> Var: ... @overload def __getattr__( self: ( - ObjectVar[Dict[Any, int]] - | ObjectVar[Dict[Any, float]] - | ObjectVar[Dict[Any, int | float]] + ObjectVar[Mapping[Any, int]] + | ObjectVar[Mapping[Any, float]] + | ObjectVar[Mapping[Any, int | float]] ), name: str, ) -> NumberVar: ... @overload def __getattr__( - self: ObjectVar[Dict[Any, str]], + self: ObjectVar[Mapping[Any, str]], name: str, ) -> StringVar: ... @overload def __getattr__( - self: ObjectVar[Dict[Any, list[ARRAY_INNER_TYPE]]], + self: ObjectVar[Mapping[Any, list[ARRAY_INNER_TYPE]]], name: str, ) -> ArrayVar[list[ARRAY_INNER_TYPE]]: ... @overload def __getattr__( - self: ObjectVar[Dict[Any, set[ARRAY_INNER_TYPE]]], + self: ObjectVar[Mapping[Any, set[ARRAY_INNER_TYPE]]], name: str, ) -> ArrayVar[set[ARRAY_INNER_TYPE]]: ... @overload def __getattr__( - self: ObjectVar[Dict[Any, tuple[ARRAY_INNER_TYPE, ...]]], + self: ObjectVar[Mapping[Any, tuple[ARRAY_INNER_TYPE, ...]]], name: str, ) -> ArrayVar[tuple[ARRAY_INNER_TYPE, ...]]: ... @overload def __getattr__( - self: ObjectVar[Dict[Any, dict[OTHER_KEY_TYPE, VALUE_TYPE]]], + self: ObjectVar[Mapping[Any, Mapping[OTHER_KEY_TYPE, VALUE_TYPE]]], name: str, - ) -> ObjectVar[dict[OTHER_KEY_TYPE, VALUE_TYPE]]: ... + ) -> ObjectVar[Mapping[OTHER_KEY_TYPE, VALUE_TYPE]]: ... @overload def __getattr__( @@ -266,8 +274,11 @@ class ObjectVar(Var[OBJECT_TYPE], python_types=dict): var_type = get_args(var_type)[0] fixed_type = var_type if isclass(var_type) else get_origin(var_type) - if (isclass(fixed_type) and not issubclass(fixed_type, dict)) or ( - fixed_type in types.UnionTypes + + if ( + (isclass(fixed_type) and not issubclass(fixed_type, Mapping)) + or (fixed_type in types.UnionTypes) + or is_typeddict(fixed_type) ): attribute_type = get_attribute_access_type(var_type, name) if attribute_type is None: @@ -299,7 +310,7 @@ class ObjectVar(Var[OBJECT_TYPE], python_types=dict): class LiteralObjectVar(CachedVarOperation, ObjectVar[OBJECT_TYPE], LiteralVar): """Base class for immutable literal object vars.""" - _var_value: Dict[Union[Var, Any], Union[Var, Any]] = dataclasses.field( + _var_value: Mapping[Union[Var, Any], Union[Var, Any]] = dataclasses.field( default_factory=dict ) @@ -383,7 +394,7 @@ class LiteralObjectVar(CachedVarOperation, ObjectVar[OBJECT_TYPE], LiteralVar): @classmethod def create( cls, - _var_value: dict, + _var_value: Mapping, _var_type: Type[OBJECT_TYPE] | None = None, _var_data: VarData | None = None, ) -> LiteralObjectVar[OBJECT_TYPE]: @@ -466,7 +477,7 @@ def object_merge_operation(lhs: ObjectVar, rhs: ObjectVar): """ return var_operation_return( js_expression=f"({{...{lhs}, ...{rhs}}})", - var_type=Dict[ + var_type=Mapping[ Union[lhs._key_type(), rhs._key_type()], Union[lhs._value_type(), rhs._value_type()], ], diff --git a/reflex/vars/sequence.py b/reflex/vars/sequence.py index 5864e70b9..1b11f50e6 100644 --- a/reflex/vars/sequence.py +++ b/reflex/vars/sequence.py @@ -987,7 +987,7 @@ class ArrayVar(Var[ARRAY_VAR_TYPE], python_types=(list, tuple, set)): raise_unsupported_operand_types("[]", (type(self), type(i))) return array_item_operation(self, i) - def length(self) -> NumberVar: + def length(self) -> NumberVar[int]: """Get the length of the array. Returns: diff --git a/tests/integration/test_connection_banner.py b/tests/integration/test_connection_banner.py index 44187c8ba..18259fe3f 100644 --- a/tests/integration/test_connection_banner.py +++ b/tests/integration/test_connection_banner.py @@ -71,9 +71,10 @@ def has_error_modal(driver: WebDriver) -> bool: """ try: driver.find_element(By.XPATH, CONNECTION_ERROR_XPATH) - return True except NoSuchElementException: return False + else: + return True @pytest.mark.asyncio diff --git a/tests/units/components/core/test_match.py b/tests/units/components/core/test_match.py index f09e800e5..eaf98414d 100644 --- a/tests/units/components/core/test_match.py +++ b/tests/units/components/core/test_match.py @@ -1,4 +1,4 @@ -from typing import Dict, List, Tuple +from typing import List, Mapping, Tuple import pytest @@ -67,7 +67,7 @@ def test_match_components(): assert fourth_return_value_render["children"][0]["contents"] == '{"fourth value"}' assert match_cases[4][0]._js_expr == '({ ["foo"] : "bar" })' - assert match_cases[4][0]._var_type == Dict[str, str] + assert match_cases[4][0]._var_type == Mapping[str, str] fifth_return_value_render = match_cases[4][1].render() assert fifth_return_value_render["name"] == "RadixThemesText" assert fifth_return_value_render["children"][0]["contents"] == '{"fifth value"}' diff --git a/tests/units/components/test_component.py b/tests/units/components/test_component.py index 01e199958..5bad24499 100644 --- a/tests/units/components/test_component.py +++ b/tests/units/components/test_component.py @@ -29,6 +29,7 @@ from reflex.state import BaseState from reflex.style import Style from reflex.utils import imports from reflex.utils.exceptions import ( + ChildrenTypeError, EventFnArgMismatchError, EventHandlerArgTypeMismatchError, ) @@ -652,14 +653,17 @@ def test_create_filters_none_props(test_component): assert str(component.style["text-align"]) == '"center"' -@pytest.mark.parametrize("children", [((None,),), ("foo", ("bar", (None,)))]) +@pytest.mark.parametrize( + "children", + [ + ((None,),), + ("foo", ("bar", (None,))), + ({"foo": "bar"},), + ], +) def test_component_create_unallowed_types(children, test_component): - with pytest.raises(TypeError) as err: + with pytest.raises(ChildrenTypeError): test_component.create(*children) - assert ( - err.value.args[0] - == "Children of Reflex components must be other components, state vars, or primitive Python types. Got child None of type ." - ) @pytest.mark.parametrize( diff --git a/tests/units/test_style.py b/tests/units/test_style.py index e1d652798..bb585fd22 100644 --- a/tests/units/test_style.py +++ b/tests/units/test_style.py @@ -1,6 +1,6 @@ from __future__ import annotations -from typing import Any, Dict +from typing import Any, Mapping import pytest @@ -379,7 +379,7 @@ class StyleState(rx.State): { "css": Var( _js_expr=f'({{ ["color"] : ("dark"+{StyleState.color}) }})' - ).to(Dict[str, str]) + ).to(Mapping[str, str]) }, ), ( diff --git a/tests/units/test_var.py b/tests/units/test_var.py index a2700f0e7..d071d1dba 100644 --- a/tests/units/test_var.py +++ b/tests/units/test_var.py @@ -2,7 +2,7 @@ import json import math import sys import typing -from typing import Dict, List, Optional, Set, Tuple, Union, cast +from typing import Dict, List, Mapping, Optional, Set, Tuple, Union, cast import pytest from pandas import DataFrame @@ -273,7 +273,7 @@ def test_get_setter(prop: Var, expected): ([1, 2, 3], Var(_js_expr="[1, 2, 3]", _var_type=List[int])), ( {"a": 1, "b": 2}, - Var(_js_expr='({ ["a"] : 1, ["b"] : 2 })', _var_type=Dict[str, int]), + Var(_js_expr='({ ["a"] : 1, ["b"] : 2 })', _var_type=Mapping[str, int]), ), ], ) diff --git a/tests/units/vars/test_base.py b/tests/units/vars/test_base.py index 68bc0c38e..e4ae7327a 100644 --- a/tests/units/vars/test_base.py +++ b/tests/units/vars/test_base.py @@ -1,4 +1,4 @@ -from typing import Dict, List, Union +from typing import List, Mapping, Union import pytest @@ -37,12 +37,12 @@ class ChildGenericDict(GenericDict): ("a", str), ([1, 2, 3], List[int]), ([1, 2.0, "a"], List[Union[int, float, str]]), - ({"a": 1, "b": 2}, Dict[str, int]), - ({"a": 1, 2: "b"}, Dict[Union[int, str], Union[str, int]]), + ({"a": 1, "b": 2}, Mapping[str, int]), + ({"a": 1, 2: "b"}, Mapping[Union[int, str], Union[str, int]]), (CustomDict(), CustomDict), (ChildCustomDict(), ChildCustomDict), - (GenericDict({1: 1}), Dict[int, int]), - (ChildGenericDict({1: 1}), Dict[int, int]), + (GenericDict({1: 1}), Mapping[int, int]), + (ChildGenericDict({1: 1}), Mapping[int, int]), ], ) def test_figure_out_type(value, expected):