diff --git a/.github/workflows/integration_app_harness.yml b/.github/workflows/integration_app_harness.yml index 6148ecd1a..5e88fc412 100644 --- a/.github/workflows/integration_app_harness.yml +++ b/.github/workflows/integration_app_harness.yml @@ -47,14 +47,14 @@ jobs: python-version: ${{ matrix.python-version }} run-poetry-install: true create-venv-at-path: .venv - - run: poetry run uv pip install pyvirtualdisplay pillow pytest-split + - run: poetry run uv pip install pyvirtualdisplay pillow pytest-split pytest-retry - name: Run app harness tests env: SCREENSHOT_DIR: /tmp/screenshots/${{ matrix.state_manager }}/${{ matrix.python-version }}/${{ matrix.split_index }} REDIS_URL: ${{ matrix.state_manager == 'redis' && 'redis://localhost:6379' || '' }} run: | poetry run playwright install chromium - poetry run pytest tests/integration --splits 2 --group ${{matrix.split_index}} + poetry run pytest tests/integration --retries 3 --maxfail=5 --splits 2 --group ${{matrix.split_index}} - uses: actions/upload-artifact@v4 name: Upload failed test screenshots if: always() diff --git a/reflex/.templates/web/utils/state.js b/reflex/.templates/web/utils/state.js index ec603fd13..5b8046347 100644 --- a/reflex/.templates/web/utils/state.js +++ b/reflex/.templates/web/utils/state.js @@ -3,6 +3,7 @@ import axios from "axios"; import io from "socket.io-client"; import JSON5 from "json5"; import env from "$/env.json"; +import reflexEnvironment from "$/reflex.json"; import Cookies from "universal-cookie"; import { useEffect, useRef, useState } from "react"; import Router, { useRouter } from "next/router"; @@ -407,6 +408,7 @@ export const connect = async ( socket.current = io(endpoint.href, { path: endpoint["pathname"], transports: transports, + protocols: env.TEST_MODE ? undefined : [reflexEnvironment.version], autoUnref: false, }); // Ensure undefined fields in events are sent as null instead of removed diff --git a/reflex/app.py b/reflex/app.py index 0d672e4c0..7e868e730 100644 --- a/reflex/app.py +++ b/reflex/app.py @@ -558,11 +558,12 @@ class App(MiddlewareMixin, LifespanMixin): meta=meta, ) - def _compile_page(self, route: str): + def _compile_page(self, route: str, save_page: bool = True): """Compile a page. Args: route: The route of the page to compile. + save_page: If True, the compiled page is saved to self.pages. """ component, enable_state = compiler.compile_unevaluated_page( route, self.unevaluated_pages[route], self.state, self.style, self.theme @@ -573,7 +574,8 @@ class App(MiddlewareMixin, LifespanMixin): # Add the page. self._check_routes_conflict(route) - self.pages[route] = component + if save_page: + self.pages[route] = component def get_load_events(self, route: str) -> list[IndividualEventType[[], Any]]: """Get the load events for a route. @@ -873,14 +875,16 @@ class App(MiddlewareMixin, LifespanMixin): # If a theme component was provided, wrap the app with it app_wrappers[(20, "Theme")] = self.theme + should_compile = self._should_compile() + for route in self.unevaluated_pages: console.debug(f"Evaluating page: {route}") - self._compile_page(route) + self._compile_page(route, save_page=should_compile) # Add the optional endpoints (_upload) self._add_optional_endpoints() - if not self._should_compile(): + if not should_compile: return self._validate_var_dependencies() @@ -1524,7 +1528,11 @@ class EventNamespace(AsyncNamespace): sid: The Socket.IO session id. environ: The request information, including HTTP headers. """ - pass + subprotocol = environ.get("HTTP_SEC_WEBSOCKET_PROTOCOL", None) + if subprotocol and subprotocol != constants.Reflex.VERSION: + console.warn( + f"Frontend version {subprotocol} for session {sid} does not match the backend version {constants.Reflex.VERSION}." + ) def on_disconnect(self, sid): """Event for when the websocket disconnects. diff --git a/reflex/components/recharts/__init__.py b/reflex/components/recharts/__init__.py index 5e9e6fc14..6495c6583 100644 --- a/reflex/components/recharts/__init__.py +++ b/reflex/components/recharts/__init__.py @@ -70,6 +70,8 @@ _SUBMOD_ATTRS: dict = { "Label", "label_list", "LabelList", + "cell", + "Cell", ], "polar": [ "pie", diff --git a/reflex/components/recharts/__init__.pyi b/reflex/components/recharts/__init__.pyi index 8f6c4673f..61fc9b1e7 100644 --- a/reflex/components/recharts/__init__.pyi +++ b/reflex/components/recharts/__init__.pyi @@ -53,11 +53,13 @@ from .charts import radar_chart as radar_chart from .charts import radial_bar_chart as radial_bar_chart from .charts import scatter_chart as scatter_chart from .charts import treemap as treemap +from .general import Cell as Cell from .general import GraphingTooltip as GraphingTooltip from .general import Label as Label from .general import LabelList as LabelList from .general import Legend as Legend from .general import ResponsiveContainer as ResponsiveContainer +from .general import cell as cell from .general import graphing_tooltip as graphing_tooltip from .general import label as label from .general import label_list as label_list diff --git a/reflex/components/recharts/general.py b/reflex/components/recharts/general.py index 1769ea125..123c7708a 100644 --- a/reflex/components/recharts/general.py +++ b/reflex/components/recharts/general.py @@ -242,8 +242,23 @@ class LabelList(Recharts): stroke: Var[Union[str, Color]] = LiteralVar.create("none") +class Cell(Recharts): + """A Cell component in Recharts.""" + + tag = "Cell" + + alias = "RechartsCell" + + # The presentation attribute of a rectangle in bar or a sector in pie. + fill: Var[str] + + # The presentation attribute of a rectangle in bar or a sector in pie. + stroke: Var[str] + + responsive_container = ResponsiveContainer.create legend = Legend.create graphing_tooltip = GraphingTooltip.create label = Label.create label_list = LabelList.create +cell = Cell.create diff --git a/reflex/components/recharts/general.pyi b/reflex/components/recharts/general.pyi index 823a50fce..74a65c277 100644 --- a/reflex/components/recharts/general.pyi +++ b/reflex/components/recharts/general.pyi @@ -482,8 +482,59 @@ class LabelList(Recharts): """ ... +class Cell(Recharts): + @overload + @classmethod + def create( # type: ignore + cls, + *children, + fill: Optional[Union[Var[str], str]] = None, + stroke: Optional[Union[Var[str], str]] = None, + style: Optional[Style] = None, + key: Optional[Any] = None, + id: Optional[Any] = None, + class_name: Optional[Any] = None, + autofocus: Optional[bool] = None, + custom_attrs: Optional[Dict[str, Union[Var, Any]]] = None, + on_blur: Optional[EventType[[], BASE_STATE]] = None, + on_click: Optional[EventType[[], BASE_STATE]] = None, + on_context_menu: Optional[EventType[[], BASE_STATE]] = None, + on_double_click: Optional[EventType[[], BASE_STATE]] = None, + on_focus: Optional[EventType[[], BASE_STATE]] = None, + on_mount: Optional[EventType[[], BASE_STATE]] = None, + on_mouse_down: Optional[EventType[[], BASE_STATE]] = None, + on_mouse_enter: Optional[EventType[[], BASE_STATE]] = None, + on_mouse_leave: Optional[EventType[[], BASE_STATE]] = None, + on_mouse_move: Optional[EventType[[], BASE_STATE]] = None, + on_mouse_out: Optional[EventType[[], BASE_STATE]] = None, + on_mouse_over: Optional[EventType[[], BASE_STATE]] = None, + on_mouse_up: Optional[EventType[[], BASE_STATE]] = None, + on_scroll: Optional[EventType[[], BASE_STATE]] = None, + on_unmount: Optional[EventType[[], BASE_STATE]] = None, + **props, + ) -> "Cell": + """Create the component. + + Args: + *children: The children of the component. + fill: The presentation attribute of a rectangle in bar or a sector in pie. + stroke: The presentation attribute of a rectangle in bar or a sector in pie. + style: The style of the component. + key: A unique key for the component. + id: The id for the component. + class_name: The class name for the component. + autofocus: Whether the component should take the focus once the page is loaded + custom_attrs: custom attribute + **props: The props of the component. + + Returns: + The component. + """ + ... + responsive_container = ResponsiveContainer.create legend = Legend.create graphing_tooltip = GraphingTooltip.create label = Label.create label_list = LabelList.create +cell = Cell.create diff --git a/reflex/constants/__init__.py b/reflex/constants/__init__.py index e816da0f7..f5946bf5e 100644 --- a/reflex/constants/__init__.py +++ b/reflex/constants/__init__.py @@ -1,6 +1,7 @@ """The constants package.""" from .base import ( + APP_HARNESS_FLAG, COOKIES, IS_LINUX, IS_MACOS, diff --git a/reflex/constants/base.py b/reflex/constants/base.py index af96583ad..f737858c0 100644 --- a/reflex/constants/base.py +++ b/reflex/constants/base.py @@ -257,6 +257,7 @@ SESSION_STORAGE = "session_storage" # Testing variables. # Testing os env set by pytest when running a test case. PYTEST_CURRENT_TEST = "PYTEST_CURRENT_TEST" +APP_HARNESS_FLAG = "APP_HARNESS_FLAG" REFLEX_VAR_OPENING_TAG = "" REFLEX_VAR_CLOSING_TAG = "" diff --git a/reflex/reflex.py b/reflex/reflex.py index b0f4ccd91..2d6ebc30c 100644 --- a/reflex/reflex.py +++ b/reflex/reflex.py @@ -440,7 +440,11 @@ def deploy( config.app_name, "--app-name", help="The name of the App to deploy under.", - hidden=True, + ), + app_id: str = typer.Option( + None, + "--app-id", + help="The ID of the App to deploy over.", ), regions: List[str] = typer.Option( [], @@ -480,6 +484,11 @@ def deploy( "--project", help="project id to deploy to", ), + project_name: Optional[str] = typer.Option( + None, + "--project-name", + help="The name of the project to deploy to.", + ), token: Optional[str] = typer.Option( None, "--token", @@ -503,13 +512,6 @@ def deploy( # Set the log level. console.set_log_level(loglevel) - if not token: - # make sure user is logged in. - if interactive: - hosting_cli.login() - else: - raise SystemExit("Token is required for non-interactive mode.") - # Only check requirements if interactive. # There is user interaction for requirements update. if interactive: @@ -524,6 +526,7 @@ def deploy( ) hosting_cli.deploy( app_name=app_name, + app_id=app_id, export_fn=lambda zip_dest_dir, api_url, deploy_url, @@ -547,6 +550,8 @@ def deploy( loglevel=type(loglevel).INFO, # type: ignore token=token, project=project, + config_path=config_path, + project_name=project_name, **extra, ) diff --git a/reflex/testing.py b/reflex/testing.py index b3dedf398..a1083d250 100644 --- a/reflex/testing.py +++ b/reflex/testing.py @@ -282,6 +282,7 @@ class AppHarness: before_decorated_pages = reflex.app.DECORATED_PAGES[self.app_name].copy() # Ensure the AppHarness test does not skip State assignment due to running via pytest os.environ.pop(reflex.constants.PYTEST_CURRENT_TEST, None) + os.environ[reflex.constants.APP_HARNESS_FLAG] = "true" self.app_module = reflex.utils.prerequisites.get_compiled_app( # Do not reload the module for pre-existing apps (only apps generated from source) reload=self.app_source is not None diff --git a/reflex/utils/build.py b/reflex/utils/build.py index e263374e1..9ea941792 100644 --- a/reflex/utils/build.py +++ b/reflex/utils/build.py @@ -13,13 +13,17 @@ from rich.progress import MofNCompleteColumn, Progress, TimeElapsedColumn from reflex import constants from reflex.config import get_config from reflex.utils import console, path_ops, prerequisites, processes +from reflex.utils.exec import is_in_app_harness def set_env_json(): """Write the upload url to a REFLEX_JSON.""" path_ops.update_json_file( str(prerequisites.get_web_dir() / constants.Dirs.ENV_JSON), - {endpoint.name: endpoint.get_url() for endpoint in constants.Endpoint}, + { + **{endpoint.name: endpoint.get_url() for endpoint in constants.Endpoint}, + "TEST_MODE": is_in_app_harness(), + }, ) diff --git a/reflex/utils/exec.py b/reflex/utils/exec.py index 6087818d9..c10b6b856 100644 --- a/reflex/utils/exec.py +++ b/reflex/utils/exec.py @@ -509,6 +509,15 @@ def is_testing_env() -> bool: return constants.PYTEST_CURRENT_TEST in os.environ +def is_in_app_harness() -> bool: + """Whether the app is running in the app harness. + + Returns: + True if the app is running in the app harness. + """ + return constants.APP_HARNESS_FLAG in os.environ + + def is_prod_mode() -> bool: """Check if the app is running in production mode. diff --git a/reflex/utils/path_ops.py b/reflex/utils/path_ops.py index 38560977e..b447718d2 100644 --- a/reflex/utils/path_ops.py +++ b/reflex/utils/path_ops.py @@ -174,7 +174,7 @@ def get_node_path() -> str | None: return str(node_path) -def get_npm_path() -> str | None: +def get_npm_path() -> Path | None: """Get npm binary path. Returns: @@ -183,8 +183,8 @@ def get_npm_path() -> str | None: npm_path = Path(constants.Node.NPM_PATH) if use_system_node() or not npm_path.exists(): system_npm_path = which("npm") - return str(system_npm_path) if system_npm_path else None - return str(npm_path) + npm_path = Path(system_npm_path) if system_npm_path else None + return npm_path.absolute() if npm_path else None def update_json_file(file_path: str | Path, update_dict: dict[str, int | str]): diff --git a/reflex/utils/prerequisites.py b/reflex/utils/prerequisites.py index 4f9cc0c14..415519c8f 100644 --- a/reflex/utils/prerequisites.py +++ b/reflex/utils/prerequisites.py @@ -254,7 +254,7 @@ def get_package_manager(on_failure_return_none: bool = False) -> str | None: """ npm_path = path_ops.get_npm_path() if npm_path is not None: - return str(Path(npm_path).resolve()) + return str(npm_path) if on_failure_return_none: return None raise FileNotFoundError("NPM not found. You may need to run `reflex init`.") diff --git a/reflex/utils/processes.py b/reflex/utils/processes.py index 3673b36b2..575688eda 100644 --- a/reflex/utils/processes.py +++ b/reflex/utils/processes.py @@ -9,7 +9,6 @@ import os import signal import subprocess from concurrent import futures -from pathlib import Path from typing import Callable, Generator, List, Optional, Tuple, Union import psutil @@ -368,7 +367,7 @@ def get_command_with_loglevel(command: list[str]) -> list[str]: The updated command list """ npm_path = path_ops.get_npm_path() - npm_path = str(Path(npm_path).resolve()) if npm_path else npm_path + npm_path = str(npm_path) if npm_path else None if command[0] == npm_path: return [*command, "--loglevel", "silly"] diff --git a/reflex/utils/types.py b/reflex/utils/types.py index b8bcbf2d6..ac30342e2 100644 --- a/reflex/utils/types.py +++ b/reflex/utils/types.py @@ -829,6 +829,22 @@ StateBases = get_base_class(StateVar) StateIterBases = get_base_class(StateIterVar) +def safe_issubclass(cls: Type, cls_check: Type | Tuple[Type, ...]): + """Check if a class is a subclass of another class. Returns False if internal error occurs. + + Args: + cls: The class to check. + cls_check: The class to check against. + + Returns: + Whether the class is a subclass of the other class. + """ + try: + return issubclass(cls, cls_check) + except TypeError: + return False + + def typehint_issubclass(possible_subclass: Any, possible_superclass: Any) -> bool: """Check if a type hint is a subclass of another type hint. diff --git a/reflex/vars/base.py b/reflex/vars/base.py index 2892d004d..122545187 100644 --- a/reflex/vars/base.py +++ b/reflex/vars/base.py @@ -26,6 +26,7 @@ from typing import ( Iterable, List, Literal, + Mapping, NoReturn, Optional, Set, @@ -64,6 +65,7 @@ from reflex.utils.types import ( _isinstance, get_origin, has_args, + safe_issubclass, unionize, ) @@ -127,7 +129,7 @@ class VarData: state: str = "", field_name: str = "", imports: ImportDict | ParsedImportDict | None = None, - hooks: dict[str, VarData | None] | None = None, + hooks: Mapping[str, VarData | None] | None = None, deps: list[Var] | None = None, position: Hooks.HookPosition | None = None, ): @@ -643,8 +645,8 @@ class Var(Generic[VAR_TYPE]): @overload def to( self, - output: type[dict], - ) -> ObjectVar[dict]: ... + output: type[Mapping], + ) -> ObjectVar[Mapping]: ... @overload def to( @@ -686,7 +688,9 @@ class Var(Generic[VAR_TYPE]): # If the first argument is a python type, we map it to the corresponding Var type. for var_subclass in _var_subclasses[::-1]: - if fixed_output_type in var_subclass.python_types: + if fixed_output_type in var_subclass.python_types or safe_issubclass( + fixed_output_type, var_subclass.python_types + ): return self.to(var_subclass.var_subclass, output) if fixed_output_type is None: @@ -820,7 +824,7 @@ class Var(Generic[VAR_TYPE]): return False if issubclass(type_, list): return [] - if issubclass(type_, dict): + if issubclass(type_, Mapping): return {} if issubclass(type_, tuple): return () @@ -1026,7 +1030,7 @@ class Var(Generic[VAR_TYPE]): f"$/{constants.Dirs.STATE_PATH}": [imports.ImportVar(tag="refs")] } ), - ).to(ObjectVar, Dict[str, str]) + ).to(ObjectVar, Mapping[str, str]) return refs[LiteralVar.create(str(self))] @deprecated("Use `.js_type()` instead.") @@ -1373,7 +1377,7 @@ class LiteralVar(Var): serialized_value = serializers.serialize(value) if serialized_value is not None: - if isinstance(serialized_value, dict): + if isinstance(serialized_value, Mapping): return LiteralObjectVar.create( serialized_value, _var_type=type(value), @@ -1498,7 +1502,7 @@ def var_operation( ) -> Callable[P, ArrayVar[LIST_T]]: ... -OBJECT_TYPE = TypeVar("OBJECT_TYPE", bound=Dict) +OBJECT_TYPE = TypeVar("OBJECT_TYPE", bound=Mapping) @overload @@ -1573,8 +1577,8 @@ def figure_out_type(value: Any) -> types.GenericType: return Set[unionize(*(figure_out_type(v) for v in value))] if isinstance(value, tuple): return Tuple[unionize(*(figure_out_type(v) for v in value)), ...] - if isinstance(value, dict): - return Dict[ + if isinstance(value, Mapping): + return Mapping[ unionize(*(figure_out_type(k) for k in value)), unionize(*(figure_out_type(v) for v in value.values())), ] @@ -2002,10 +2006,10 @@ class ComputedVar(Var[RETURN_TYPE]): @overload def __get__( - self: ComputedVar[dict[DICT_KEY, DICT_VAL]], + self: ComputedVar[Mapping[DICT_KEY, DICT_VAL]], instance: None, owner: Type, - ) -> ObjectVar[dict[DICT_KEY, DICT_VAL]]: ... + ) -> ObjectVar[Mapping[DICT_KEY, DICT_VAL]]: ... @overload def __get__( @@ -2915,11 +2919,14 @@ V = TypeVar("V") BASE_TYPE = TypeVar("BASE_TYPE", bound=Base) +FIELD_TYPE = TypeVar("FIELD_TYPE") +MAPPING_TYPE = TypeVar("MAPPING_TYPE", bound=Mapping) -class Field(Generic[T]): + +class Field(Generic[FIELD_TYPE]): """Shadow class for Var to allow for type hinting in the IDE.""" - def __set__(self, instance, value: T): + def __set__(self, instance, value: FIELD_TYPE): """Set the Var. Args: @@ -2931,7 +2938,9 @@ class Field(Generic[T]): def __get__(self: Field[bool], instance: None, owner) -> BooleanVar: ... @overload - def __get__(self: Field[int], instance: None, owner) -> NumberVar: ... + def __get__( + self: Field[int] | Field[float] | Field[int | float], instance: None, owner + ) -> NumberVar: ... @overload def __get__(self: Field[str], instance: None, owner) -> StringVar: ... @@ -2948,8 +2957,8 @@ class Field(Generic[T]): @overload def __get__( - self: Field[Dict[str, V]], instance: None, owner - ) -> ObjectVar[Dict[str, V]]: ... + self: Field[MAPPING_TYPE], instance: None, owner + ) -> ObjectVar[MAPPING_TYPE]: ... @overload def __get__( @@ -2957,10 +2966,10 @@ class Field(Generic[T]): ) -> ObjectVar[BASE_TYPE]: ... @overload - def __get__(self, instance: None, owner) -> Var[T]: ... + def __get__(self, instance: None, owner) -> Var[FIELD_TYPE]: ... @overload - def __get__(self, instance, owner) -> T: ... + def __get__(self, instance, owner) -> FIELD_TYPE: ... def __get__(self, instance, owner): # type: ignore """Get the Var. @@ -2971,7 +2980,7 @@ class Field(Generic[T]): """ -def field(value: T) -> Field[T]: +def field(value: FIELD_TYPE) -> Field[FIELD_TYPE]: """Create a Field with a value. Args: diff --git a/reflex/vars/object.py b/reflex/vars/object.py index 5de431f5a..7b951c559 100644 --- a/reflex/vars/object.py +++ b/reflex/vars/object.py @@ -8,8 +8,8 @@ import typing from inspect import isclass from typing import ( Any, - Dict, List, + Mapping, NoReturn, Tuple, Type, @@ -19,6 +19,8 @@ from typing import ( overload, ) +from typing_extensions import is_typeddict + from reflex.utils import types from reflex.utils.exceptions import VarAttributeError from reflex.utils.types import GenericType, get_attribute_access_type, get_origin @@ -36,7 +38,7 @@ from .base import ( from .number import BooleanVar, NumberVar, raise_unsupported_operand_types from .sequence import ArrayVar, StringVar -OBJECT_TYPE = TypeVar("OBJECT_TYPE") +OBJECT_TYPE = TypeVar("OBJECT_TYPE", covariant=True) KEY_TYPE = TypeVar("KEY_TYPE") VALUE_TYPE = TypeVar("VALUE_TYPE") @@ -46,7 +48,7 @@ ARRAY_INNER_TYPE = TypeVar("ARRAY_INNER_TYPE") OTHER_KEY_TYPE = TypeVar("OTHER_KEY_TYPE") -class ObjectVar(Var[OBJECT_TYPE], python_types=dict): +class ObjectVar(Var[OBJECT_TYPE], python_types=Mapping): """Base class for immutable object vars.""" def _key_type(self) -> Type: @@ -59,7 +61,7 @@ class ObjectVar(Var[OBJECT_TYPE], python_types=dict): @overload def _value_type( - self: ObjectVar[Dict[Any, VALUE_TYPE]], + self: ObjectVar[Mapping[Any, VALUE_TYPE]], ) -> Type[VALUE_TYPE]: ... @overload @@ -74,7 +76,7 @@ class ObjectVar(Var[OBJECT_TYPE], python_types=dict): fixed_type = get_origin(self._var_type) or self._var_type if not isclass(fixed_type): return Any - args = get_args(self._var_type) if issubclass(fixed_type, dict) else () + args = get_args(self._var_type) if issubclass(fixed_type, Mapping) else () return args[1] if args else Any def keys(self) -> ArrayVar[List[str]]: @@ -87,7 +89,7 @@ class ObjectVar(Var[OBJECT_TYPE], python_types=dict): @overload def values( - self: ObjectVar[Dict[Any, VALUE_TYPE]], + self: ObjectVar[Mapping[Any, VALUE_TYPE]], ) -> ArrayVar[List[VALUE_TYPE]]: ... @overload @@ -103,7 +105,7 @@ class ObjectVar(Var[OBJECT_TYPE], python_types=dict): @overload def entries( - self: ObjectVar[Dict[Any, VALUE_TYPE]], + self: ObjectVar[Mapping[Any, VALUE_TYPE]], ) -> ArrayVar[List[Tuple[str, VALUE_TYPE]]]: ... @overload @@ -133,49 +135,55 @@ class ObjectVar(Var[OBJECT_TYPE], python_types=dict): # NoReturn is used here to catch when key value is Any @overload def __getitem__( - self: ObjectVar[Dict[Any, NoReturn]], + self: ObjectVar[Mapping[Any, NoReturn]], key: Var | Any, ) -> Var: ... + @overload + def __getitem__( + self: (ObjectVar[Mapping[Any, bool]]), + key: Var | Any, + ) -> BooleanVar: ... + @overload def __getitem__( self: ( - ObjectVar[Dict[Any, int]] - | ObjectVar[Dict[Any, float]] - | ObjectVar[Dict[Any, int | float]] + ObjectVar[Mapping[Any, int]] + | ObjectVar[Mapping[Any, float]] + | ObjectVar[Mapping[Any, int | float]] ), key: Var | Any, ) -> NumberVar: ... @overload def __getitem__( - self: ObjectVar[Dict[Any, str]], + self: ObjectVar[Mapping[Any, str]], key: Var | Any, ) -> StringVar: ... @overload def __getitem__( - self: ObjectVar[Dict[Any, list[ARRAY_INNER_TYPE]]], + self: ObjectVar[Mapping[Any, list[ARRAY_INNER_TYPE]]], key: Var | Any, ) -> ArrayVar[list[ARRAY_INNER_TYPE]]: ... @overload def __getitem__( - self: ObjectVar[Dict[Any, set[ARRAY_INNER_TYPE]]], + self: ObjectVar[Mapping[Any, set[ARRAY_INNER_TYPE]]], key: Var | Any, ) -> ArrayVar[set[ARRAY_INNER_TYPE]]: ... @overload def __getitem__( - self: ObjectVar[Dict[Any, tuple[ARRAY_INNER_TYPE, ...]]], + self: ObjectVar[Mapping[Any, tuple[ARRAY_INNER_TYPE, ...]]], key: Var | Any, ) -> ArrayVar[tuple[ARRAY_INNER_TYPE, ...]]: ... @overload def __getitem__( - self: ObjectVar[Dict[Any, dict[OTHER_KEY_TYPE, VALUE_TYPE]]], + self: ObjectVar[Mapping[Any, Mapping[OTHER_KEY_TYPE, VALUE_TYPE]]], key: Var | Any, - ) -> ObjectVar[dict[OTHER_KEY_TYPE, VALUE_TYPE]]: ... + ) -> ObjectVar[Mapping[OTHER_KEY_TYPE, VALUE_TYPE]]: ... def __getitem__(self, key: Var | Any) -> Var: """Get an item from the object. @@ -195,49 +203,49 @@ class ObjectVar(Var[OBJECT_TYPE], python_types=dict): # NoReturn is used here to catch when key value is Any @overload def __getattr__( - self: ObjectVar[Dict[Any, NoReturn]], + self: ObjectVar[Mapping[Any, NoReturn]], name: str, ) -> Var: ... @overload def __getattr__( self: ( - ObjectVar[Dict[Any, int]] - | ObjectVar[Dict[Any, float]] - | ObjectVar[Dict[Any, int | float]] + ObjectVar[Mapping[Any, int]] + | ObjectVar[Mapping[Any, float]] + | ObjectVar[Mapping[Any, int | float]] ), name: str, ) -> NumberVar: ... @overload def __getattr__( - self: ObjectVar[Dict[Any, str]], + self: ObjectVar[Mapping[Any, str]], name: str, ) -> StringVar: ... @overload def __getattr__( - self: ObjectVar[Dict[Any, list[ARRAY_INNER_TYPE]]], + self: ObjectVar[Mapping[Any, list[ARRAY_INNER_TYPE]]], name: str, ) -> ArrayVar[list[ARRAY_INNER_TYPE]]: ... @overload def __getattr__( - self: ObjectVar[Dict[Any, set[ARRAY_INNER_TYPE]]], + self: ObjectVar[Mapping[Any, set[ARRAY_INNER_TYPE]]], name: str, ) -> ArrayVar[set[ARRAY_INNER_TYPE]]: ... @overload def __getattr__( - self: ObjectVar[Dict[Any, tuple[ARRAY_INNER_TYPE, ...]]], + self: ObjectVar[Mapping[Any, tuple[ARRAY_INNER_TYPE, ...]]], name: str, ) -> ArrayVar[tuple[ARRAY_INNER_TYPE, ...]]: ... @overload def __getattr__( - self: ObjectVar[Dict[Any, dict[OTHER_KEY_TYPE, VALUE_TYPE]]], + self: ObjectVar[Mapping[Any, Mapping[OTHER_KEY_TYPE, VALUE_TYPE]]], name: str, - ) -> ObjectVar[dict[OTHER_KEY_TYPE, VALUE_TYPE]]: ... + ) -> ObjectVar[Mapping[OTHER_KEY_TYPE, VALUE_TYPE]]: ... @overload def __getattr__( @@ -266,8 +274,11 @@ class ObjectVar(Var[OBJECT_TYPE], python_types=dict): var_type = get_args(var_type)[0] fixed_type = var_type if isclass(var_type) else get_origin(var_type) - if (isclass(fixed_type) and not issubclass(fixed_type, dict)) or ( - fixed_type in types.UnionTypes + + if ( + (isclass(fixed_type) and not issubclass(fixed_type, Mapping)) + or (fixed_type in types.UnionTypes) + or is_typeddict(fixed_type) ): attribute_type = get_attribute_access_type(var_type, name) if attribute_type is None: @@ -299,7 +310,7 @@ class ObjectVar(Var[OBJECT_TYPE], python_types=dict): class LiteralObjectVar(CachedVarOperation, ObjectVar[OBJECT_TYPE], LiteralVar): """Base class for immutable literal object vars.""" - _var_value: Dict[Union[Var, Any], Union[Var, Any]] = dataclasses.field( + _var_value: Mapping[Union[Var, Any], Union[Var, Any]] = dataclasses.field( default_factory=dict ) @@ -383,7 +394,7 @@ class LiteralObjectVar(CachedVarOperation, ObjectVar[OBJECT_TYPE], LiteralVar): @classmethod def create( cls, - _var_value: dict, + _var_value: Mapping, _var_type: Type[OBJECT_TYPE] | None = None, _var_data: VarData | None = None, ) -> LiteralObjectVar[OBJECT_TYPE]: @@ -466,7 +477,7 @@ def object_merge_operation(lhs: ObjectVar, rhs: ObjectVar): """ return var_operation_return( js_expression=f"({{...{lhs}, ...{rhs}}})", - var_type=Dict[ + var_type=Mapping[ Union[lhs._key_type(), rhs._key_type()], Union[lhs._value_type(), rhs._value_type()], ], diff --git a/reflex/vars/sequence.py b/reflex/vars/sequence.py index 5864e70b9..1b11f50e6 100644 --- a/reflex/vars/sequence.py +++ b/reflex/vars/sequence.py @@ -987,7 +987,7 @@ class ArrayVar(Var[ARRAY_VAR_TYPE], python_types=(list, tuple, set)): raise_unsupported_operand_types("[]", (type(self), type(i))) return array_item_operation(self, i) - def length(self) -> NumberVar: + def length(self) -> NumberVar[int]: """Get the length of the array. Returns: diff --git a/tests/units/components/core/test_match.py b/tests/units/components/core/test_match.py index f09e800e5..eaf98414d 100644 --- a/tests/units/components/core/test_match.py +++ b/tests/units/components/core/test_match.py @@ -1,4 +1,4 @@ -from typing import Dict, List, Tuple +from typing import List, Mapping, Tuple import pytest @@ -67,7 +67,7 @@ def test_match_components(): assert fourth_return_value_render["children"][0]["contents"] == '{"fourth value"}' assert match_cases[4][0]._js_expr == '({ ["foo"] : "bar" })' - assert match_cases[4][0]._var_type == Dict[str, str] + assert match_cases[4][0]._var_type == Mapping[str, str] fifth_return_value_render = match_cases[4][1].render() assert fifth_return_value_render["name"] == "RadixThemesText" assert fifth_return_value_render["children"][0]["contents"] == '{"fifth value"}' diff --git a/tests/units/test_style.py b/tests/units/test_style.py index e1d652798..bb585fd22 100644 --- a/tests/units/test_style.py +++ b/tests/units/test_style.py @@ -1,6 +1,6 @@ from __future__ import annotations -from typing import Any, Dict +from typing import Any, Mapping import pytest @@ -379,7 +379,7 @@ class StyleState(rx.State): { "css": Var( _js_expr=f'({{ ["color"] : ("dark"+{StyleState.color}) }})' - ).to(Dict[str, str]) + ).to(Mapping[str, str]) }, ), ( diff --git a/tests/units/test_var.py b/tests/units/test_var.py index 6ad82a761..a8e9cd88c 100644 --- a/tests/units/test_var.py +++ b/tests/units/test_var.py @@ -2,7 +2,7 @@ import json import math import sys import typing -from typing import Dict, List, Optional, Set, Tuple, Union, cast +from typing import Dict, List, Mapping, Optional, Set, Tuple, Union, cast import pytest from pandas import DataFrame @@ -270,7 +270,7 @@ def test_get_setter(prop: Var, expected): ([1, 2, 3], Var(_js_expr="[1, 2, 3]", _var_type=List[int])), ( {"a": 1, "b": 2}, - Var(_js_expr='({ ["a"] : 1, ["b"] : 2 })', _var_type=Dict[str, int]), + Var(_js_expr='({ ["a"] : 1, ["b"] : 2 })', _var_type=Mapping[str, int]), ), ], ) diff --git a/tests/units/vars/test_base.py b/tests/units/vars/test_base.py index 68bc0c38e..e4ae7327a 100644 --- a/tests/units/vars/test_base.py +++ b/tests/units/vars/test_base.py @@ -1,4 +1,4 @@ -from typing import Dict, List, Union +from typing import List, Mapping, Union import pytest @@ -37,12 +37,12 @@ class ChildGenericDict(GenericDict): ("a", str), ([1, 2, 3], List[int]), ([1, 2.0, "a"], List[Union[int, float, str]]), - ({"a": 1, "b": 2}, Dict[str, int]), - ({"a": 1, 2: "b"}, Dict[Union[int, str], Union[str, int]]), + ({"a": 1, "b": 2}, Mapping[str, int]), + ({"a": 1, 2: "b"}, Mapping[Union[int, str], Union[str, int]]), (CustomDict(), CustomDict), (ChildCustomDict(), ChildCustomDict), - (GenericDict({1: 1}), Dict[int, int]), - (ChildGenericDict({1: 1}), Dict[int, int]), + (GenericDict({1: 1}), Mapping[int, int]), + (ChildGenericDict({1: 1}), Mapping[int, int]), ], ) def test_figure_out_type(value, expected):