Merge branch 'main' into release/masenf/bump-deps

This commit is contained in:
Lendemor 2025-01-22 17:29:38 +01:00
commit 2152da1633
24 changed files with 222 additions and 86 deletions

View File

@ -47,14 +47,14 @@ jobs:
python-version: ${{ matrix.python-version }}
run-poetry-install: true
create-venv-at-path: .venv
- run: poetry run uv pip install pyvirtualdisplay pillow pytest-split
- run: poetry run uv pip install pyvirtualdisplay pillow pytest-split pytest-retry
- name: Run app harness tests
env:
SCREENSHOT_DIR: /tmp/screenshots/${{ matrix.state_manager }}/${{ matrix.python-version }}/${{ matrix.split_index }}
REDIS_URL: ${{ matrix.state_manager == 'redis' && 'redis://localhost:6379' || '' }}
run: |
poetry run playwright install chromium
poetry run pytest tests/integration --splits 2 --group ${{matrix.split_index}}
poetry run pytest tests/integration --retries 3 --maxfail=5 --splits 2 --group ${{matrix.split_index}}
- uses: actions/upload-artifact@v4
name: Upload failed test screenshots
if: always()

View File

@ -3,6 +3,7 @@ import axios from "axios";
import io from "socket.io-client";
import JSON5 from "json5";
import env from "$/env.json";
import reflexEnvironment from "$/reflex.json";
import Cookies from "universal-cookie";
import { useEffect, useRef, useState } from "react";
import Router, { useRouter } from "next/router";
@ -407,6 +408,7 @@ export const connect = async (
socket.current = io(endpoint.href, {
path: endpoint["pathname"],
transports: transports,
protocols: env.TEST_MODE ? undefined : [reflexEnvironment.version],
autoUnref: false,
});
// Ensure undefined fields in events are sent as null instead of removed

View File

@ -558,11 +558,12 @@ class App(MiddlewareMixin, LifespanMixin):
meta=meta,
)
def _compile_page(self, route: str):
def _compile_page(self, route: str, save_page: bool = True):
"""Compile a page.
Args:
route: The route of the page to compile.
save_page: If True, the compiled page is saved to self.pages.
"""
component, enable_state = compiler.compile_unevaluated_page(
route, self.unevaluated_pages[route], self.state, self.style, self.theme
@ -573,7 +574,8 @@ class App(MiddlewareMixin, LifespanMixin):
# Add the page.
self._check_routes_conflict(route)
self.pages[route] = component
if save_page:
self.pages[route] = component
def get_load_events(self, route: str) -> list[IndividualEventType[[], Any]]:
"""Get the load events for a route.
@ -873,14 +875,16 @@ class App(MiddlewareMixin, LifespanMixin):
# If a theme component was provided, wrap the app with it
app_wrappers[(20, "Theme")] = self.theme
should_compile = self._should_compile()
for route in self.unevaluated_pages:
console.debug(f"Evaluating page: {route}")
self._compile_page(route)
self._compile_page(route, save_page=should_compile)
# Add the optional endpoints (_upload)
self._add_optional_endpoints()
if not self._should_compile():
if not should_compile:
return
self._validate_var_dependencies()
@ -1524,7 +1528,11 @@ class EventNamespace(AsyncNamespace):
sid: The Socket.IO session id.
environ: The request information, including HTTP headers.
"""
pass
subprotocol = environ.get("HTTP_SEC_WEBSOCKET_PROTOCOL", None)
if subprotocol and subprotocol != constants.Reflex.VERSION:
console.warn(
f"Frontend version {subprotocol} for session {sid} does not match the backend version {constants.Reflex.VERSION}."
)
def on_disconnect(self, sid):
"""Event for when the websocket disconnects.

View File

@ -70,6 +70,8 @@ _SUBMOD_ATTRS: dict = {
"Label",
"label_list",
"LabelList",
"cell",
"Cell",
],
"polar": [
"pie",

View File

@ -53,11 +53,13 @@ from .charts import radar_chart as radar_chart
from .charts import radial_bar_chart as radial_bar_chart
from .charts import scatter_chart as scatter_chart
from .charts import treemap as treemap
from .general import Cell as Cell
from .general import GraphingTooltip as GraphingTooltip
from .general import Label as Label
from .general import LabelList as LabelList
from .general import Legend as Legend
from .general import ResponsiveContainer as ResponsiveContainer
from .general import cell as cell
from .general import graphing_tooltip as graphing_tooltip
from .general import label as label
from .general import label_list as label_list

View File

@ -242,8 +242,23 @@ class LabelList(Recharts):
stroke: Var[Union[str, Color]] = LiteralVar.create("none")
class Cell(Recharts):
"""A Cell component in Recharts."""
tag = "Cell"
alias = "RechartsCell"
# The presentation attribute of a rectangle in bar or a sector in pie.
fill: Var[str]
# The presentation attribute of a rectangle in bar or a sector in pie.
stroke: Var[str]
responsive_container = ResponsiveContainer.create
legend = Legend.create
graphing_tooltip = GraphingTooltip.create
label = Label.create
label_list = LabelList.create
cell = Cell.create

View File

@ -482,8 +482,59 @@ class LabelList(Recharts):
"""
...
class Cell(Recharts):
@overload
@classmethod
def create( # type: ignore
cls,
*children,
fill: Optional[Union[Var[str], str]] = None,
stroke: Optional[Union[Var[str], str]] = None,
style: Optional[Style] = None,
key: Optional[Any] = None,
id: Optional[Any] = None,
class_name: Optional[Any] = None,
autofocus: Optional[bool] = None,
custom_attrs: Optional[Dict[str, Union[Var, Any]]] = None,
on_blur: Optional[EventType[[], BASE_STATE]] = None,
on_click: Optional[EventType[[], BASE_STATE]] = None,
on_context_menu: Optional[EventType[[], BASE_STATE]] = None,
on_double_click: Optional[EventType[[], BASE_STATE]] = None,
on_focus: Optional[EventType[[], BASE_STATE]] = None,
on_mount: Optional[EventType[[], BASE_STATE]] = None,
on_mouse_down: Optional[EventType[[], BASE_STATE]] = None,
on_mouse_enter: Optional[EventType[[], BASE_STATE]] = None,
on_mouse_leave: Optional[EventType[[], BASE_STATE]] = None,
on_mouse_move: Optional[EventType[[], BASE_STATE]] = None,
on_mouse_out: Optional[EventType[[], BASE_STATE]] = None,
on_mouse_over: Optional[EventType[[], BASE_STATE]] = None,
on_mouse_up: Optional[EventType[[], BASE_STATE]] = None,
on_scroll: Optional[EventType[[], BASE_STATE]] = None,
on_unmount: Optional[EventType[[], BASE_STATE]] = None,
**props,
) -> "Cell":
"""Create the component.
Args:
*children: The children of the component.
fill: The presentation attribute of a rectangle in bar or a sector in pie.
stroke: The presentation attribute of a rectangle in bar or a sector in pie.
style: The style of the component.
key: A unique key for the component.
id: The id for the component.
class_name: The class name for the component.
autofocus: Whether the component should take the focus once the page is loaded
custom_attrs: custom attribute
**props: The props of the component.
Returns:
The component.
"""
...
responsive_container = ResponsiveContainer.create
legend = Legend.create
graphing_tooltip = GraphingTooltip.create
label = Label.create
label_list = LabelList.create
cell = Cell.create

View File

@ -1,6 +1,7 @@
"""The constants package."""
from .base import (
APP_HARNESS_FLAG,
COOKIES,
IS_LINUX,
IS_MACOS,

View File

@ -257,6 +257,7 @@ SESSION_STORAGE = "session_storage"
# Testing variables.
# Testing os env set by pytest when running a test case.
PYTEST_CURRENT_TEST = "PYTEST_CURRENT_TEST"
APP_HARNESS_FLAG = "APP_HARNESS_FLAG"
REFLEX_VAR_OPENING_TAG = "<reflex.Var>"
REFLEX_VAR_CLOSING_TAG = "</reflex.Var>"

View File

@ -440,7 +440,11 @@ def deploy(
config.app_name,
"--app-name",
help="The name of the App to deploy under.",
hidden=True,
),
app_id: str = typer.Option(
None,
"--app-id",
help="The ID of the App to deploy over.",
),
regions: List[str] = typer.Option(
[],
@ -480,6 +484,11 @@ def deploy(
"--project",
help="project id to deploy to",
),
project_name: Optional[str] = typer.Option(
None,
"--project-name",
help="The name of the project to deploy to.",
),
token: Optional[str] = typer.Option(
None,
"--token",
@ -503,13 +512,6 @@ def deploy(
# Set the log level.
console.set_log_level(loglevel)
if not token:
# make sure user is logged in.
if interactive:
hosting_cli.login()
else:
raise SystemExit("Token is required for non-interactive mode.")
# Only check requirements if interactive.
# There is user interaction for requirements update.
if interactive:
@ -524,6 +526,7 @@ def deploy(
)
hosting_cli.deploy(
app_name=app_name,
app_id=app_id,
export_fn=lambda zip_dest_dir,
api_url,
deploy_url,
@ -547,6 +550,8 @@ def deploy(
loglevel=type(loglevel).INFO, # type: ignore
token=token,
project=project,
config_path=config_path,
project_name=project_name,
**extra,
)

View File

@ -282,6 +282,7 @@ class AppHarness:
before_decorated_pages = reflex.app.DECORATED_PAGES[self.app_name].copy()
# Ensure the AppHarness test does not skip State assignment due to running via pytest
os.environ.pop(reflex.constants.PYTEST_CURRENT_TEST, None)
os.environ[reflex.constants.APP_HARNESS_FLAG] = "true"
self.app_module = reflex.utils.prerequisites.get_compiled_app(
# Do not reload the module for pre-existing apps (only apps generated from source)
reload=self.app_source is not None

View File

@ -13,13 +13,17 @@ from rich.progress import MofNCompleteColumn, Progress, TimeElapsedColumn
from reflex import constants
from reflex.config import get_config
from reflex.utils import console, path_ops, prerequisites, processes
from reflex.utils.exec import is_in_app_harness
def set_env_json():
"""Write the upload url to a REFLEX_JSON."""
path_ops.update_json_file(
str(prerequisites.get_web_dir() / constants.Dirs.ENV_JSON),
{endpoint.name: endpoint.get_url() for endpoint in constants.Endpoint},
{
**{endpoint.name: endpoint.get_url() for endpoint in constants.Endpoint},
"TEST_MODE": is_in_app_harness(),
},
)

View File

@ -509,6 +509,15 @@ def is_testing_env() -> bool:
return constants.PYTEST_CURRENT_TEST in os.environ
def is_in_app_harness() -> bool:
"""Whether the app is running in the app harness.
Returns:
True if the app is running in the app harness.
"""
return constants.APP_HARNESS_FLAG in os.environ
def is_prod_mode() -> bool:
"""Check if the app is running in production mode.

View File

@ -174,7 +174,7 @@ def get_node_path() -> str | None:
return str(node_path)
def get_npm_path() -> str | None:
def get_npm_path() -> Path | None:
"""Get npm binary path.
Returns:
@ -183,8 +183,8 @@ def get_npm_path() -> str | None:
npm_path = Path(constants.Node.NPM_PATH)
if use_system_node() or not npm_path.exists():
system_npm_path = which("npm")
return str(system_npm_path) if system_npm_path else None
return str(npm_path)
npm_path = Path(system_npm_path) if system_npm_path else None
return npm_path.absolute() if npm_path else None
def update_json_file(file_path: str | Path, update_dict: dict[str, int | str]):

View File

@ -254,7 +254,7 @@ def get_package_manager(on_failure_return_none: bool = False) -> str | None:
"""
npm_path = path_ops.get_npm_path()
if npm_path is not None:
return str(Path(npm_path).resolve())
return str(npm_path)
if on_failure_return_none:
return None
raise FileNotFoundError("NPM not found. You may need to run `reflex init`.")

View File

@ -9,7 +9,6 @@ import os
import signal
import subprocess
from concurrent import futures
from pathlib import Path
from typing import Callable, Generator, List, Optional, Tuple, Union
import psutil
@ -368,7 +367,7 @@ def get_command_with_loglevel(command: list[str]) -> list[str]:
The updated command list
"""
npm_path = path_ops.get_npm_path()
npm_path = str(Path(npm_path).resolve()) if npm_path else npm_path
npm_path = str(npm_path) if npm_path else None
if command[0] == npm_path:
return [*command, "--loglevel", "silly"]

View File

@ -829,6 +829,22 @@ StateBases = get_base_class(StateVar)
StateIterBases = get_base_class(StateIterVar)
def safe_issubclass(cls: Type, cls_check: Type | Tuple[Type, ...]):
"""Check if a class is a subclass of another class. Returns False if internal error occurs.
Args:
cls: The class to check.
cls_check: The class to check against.
Returns:
Whether the class is a subclass of the other class.
"""
try:
return issubclass(cls, cls_check)
except TypeError:
return False
def typehint_issubclass(possible_subclass: Any, possible_superclass: Any) -> bool:
"""Check if a type hint is a subclass of another type hint.

View File

@ -26,6 +26,7 @@ from typing import (
Iterable,
List,
Literal,
Mapping,
NoReturn,
Optional,
Set,
@ -64,6 +65,7 @@ from reflex.utils.types import (
_isinstance,
get_origin,
has_args,
safe_issubclass,
unionize,
)
@ -127,7 +129,7 @@ class VarData:
state: str = "",
field_name: str = "",
imports: ImportDict | ParsedImportDict | None = None,
hooks: dict[str, VarData | None] | None = None,
hooks: Mapping[str, VarData | None] | None = None,
deps: list[Var] | None = None,
position: Hooks.HookPosition | None = None,
):
@ -643,8 +645,8 @@ class Var(Generic[VAR_TYPE]):
@overload
def to(
self,
output: type[dict],
) -> ObjectVar[dict]: ...
output: type[Mapping],
) -> ObjectVar[Mapping]: ...
@overload
def to(
@ -686,7 +688,9 @@ class Var(Generic[VAR_TYPE]):
# If the first argument is a python type, we map it to the corresponding Var type.
for var_subclass in _var_subclasses[::-1]:
if fixed_output_type in var_subclass.python_types:
if fixed_output_type in var_subclass.python_types or safe_issubclass(
fixed_output_type, var_subclass.python_types
):
return self.to(var_subclass.var_subclass, output)
if fixed_output_type is None:
@ -820,7 +824,7 @@ class Var(Generic[VAR_TYPE]):
return False
if issubclass(type_, list):
return []
if issubclass(type_, dict):
if issubclass(type_, Mapping):
return {}
if issubclass(type_, tuple):
return ()
@ -1026,7 +1030,7 @@ class Var(Generic[VAR_TYPE]):
f"$/{constants.Dirs.STATE_PATH}": [imports.ImportVar(tag="refs")]
}
),
).to(ObjectVar, Dict[str, str])
).to(ObjectVar, Mapping[str, str])
return refs[LiteralVar.create(str(self))]
@deprecated("Use `.js_type()` instead.")
@ -1373,7 +1377,7 @@ class LiteralVar(Var):
serialized_value = serializers.serialize(value)
if serialized_value is not None:
if isinstance(serialized_value, dict):
if isinstance(serialized_value, Mapping):
return LiteralObjectVar.create(
serialized_value,
_var_type=type(value),
@ -1498,7 +1502,7 @@ def var_operation(
) -> Callable[P, ArrayVar[LIST_T]]: ...
OBJECT_TYPE = TypeVar("OBJECT_TYPE", bound=Dict)
OBJECT_TYPE = TypeVar("OBJECT_TYPE", bound=Mapping)
@overload
@ -1573,8 +1577,8 @@ def figure_out_type(value: Any) -> types.GenericType:
return Set[unionize(*(figure_out_type(v) for v in value))]
if isinstance(value, tuple):
return Tuple[unionize(*(figure_out_type(v) for v in value)), ...]
if isinstance(value, dict):
return Dict[
if isinstance(value, Mapping):
return Mapping[
unionize(*(figure_out_type(k) for k in value)),
unionize(*(figure_out_type(v) for v in value.values())),
]
@ -2002,10 +2006,10 @@ class ComputedVar(Var[RETURN_TYPE]):
@overload
def __get__(
self: ComputedVar[dict[DICT_KEY, DICT_VAL]],
self: ComputedVar[Mapping[DICT_KEY, DICT_VAL]],
instance: None,
owner: Type,
) -> ObjectVar[dict[DICT_KEY, DICT_VAL]]: ...
) -> ObjectVar[Mapping[DICT_KEY, DICT_VAL]]: ...
@overload
def __get__(
@ -2915,11 +2919,14 @@ V = TypeVar("V")
BASE_TYPE = TypeVar("BASE_TYPE", bound=Base)
FIELD_TYPE = TypeVar("FIELD_TYPE")
MAPPING_TYPE = TypeVar("MAPPING_TYPE", bound=Mapping)
class Field(Generic[T]):
class Field(Generic[FIELD_TYPE]):
"""Shadow class for Var to allow for type hinting in the IDE."""
def __set__(self, instance, value: T):
def __set__(self, instance, value: FIELD_TYPE):
"""Set the Var.
Args:
@ -2931,7 +2938,9 @@ class Field(Generic[T]):
def __get__(self: Field[bool], instance: None, owner) -> BooleanVar: ...
@overload
def __get__(self: Field[int], instance: None, owner) -> NumberVar: ...
def __get__(
self: Field[int] | Field[float] | Field[int | float], instance: None, owner
) -> NumberVar: ...
@overload
def __get__(self: Field[str], instance: None, owner) -> StringVar: ...
@ -2948,8 +2957,8 @@ class Field(Generic[T]):
@overload
def __get__(
self: Field[Dict[str, V]], instance: None, owner
) -> ObjectVar[Dict[str, V]]: ...
self: Field[MAPPING_TYPE], instance: None, owner
) -> ObjectVar[MAPPING_TYPE]: ...
@overload
def __get__(
@ -2957,10 +2966,10 @@ class Field(Generic[T]):
) -> ObjectVar[BASE_TYPE]: ...
@overload
def __get__(self, instance: None, owner) -> Var[T]: ...
def __get__(self, instance: None, owner) -> Var[FIELD_TYPE]: ...
@overload
def __get__(self, instance, owner) -> T: ...
def __get__(self, instance, owner) -> FIELD_TYPE: ...
def __get__(self, instance, owner): # type: ignore
"""Get the Var.
@ -2971,7 +2980,7 @@ class Field(Generic[T]):
"""
def field(value: T) -> Field[T]:
def field(value: FIELD_TYPE) -> Field[FIELD_TYPE]:
"""Create a Field with a value.
Args:

View File

@ -8,8 +8,8 @@ import typing
from inspect import isclass
from typing import (
Any,
Dict,
List,
Mapping,
NoReturn,
Tuple,
Type,
@ -19,6 +19,8 @@ from typing import (
overload,
)
from typing_extensions import is_typeddict
from reflex.utils import types
from reflex.utils.exceptions import VarAttributeError
from reflex.utils.types import GenericType, get_attribute_access_type, get_origin
@ -36,7 +38,7 @@ from .base import (
from .number import BooleanVar, NumberVar, raise_unsupported_operand_types
from .sequence import ArrayVar, StringVar
OBJECT_TYPE = TypeVar("OBJECT_TYPE")
OBJECT_TYPE = TypeVar("OBJECT_TYPE", covariant=True)
KEY_TYPE = TypeVar("KEY_TYPE")
VALUE_TYPE = TypeVar("VALUE_TYPE")
@ -46,7 +48,7 @@ ARRAY_INNER_TYPE = TypeVar("ARRAY_INNER_TYPE")
OTHER_KEY_TYPE = TypeVar("OTHER_KEY_TYPE")
class ObjectVar(Var[OBJECT_TYPE], python_types=dict):
class ObjectVar(Var[OBJECT_TYPE], python_types=Mapping):
"""Base class for immutable object vars."""
def _key_type(self) -> Type:
@ -59,7 +61,7 @@ class ObjectVar(Var[OBJECT_TYPE], python_types=dict):
@overload
def _value_type(
self: ObjectVar[Dict[Any, VALUE_TYPE]],
self: ObjectVar[Mapping[Any, VALUE_TYPE]],
) -> Type[VALUE_TYPE]: ...
@overload
@ -74,7 +76,7 @@ class ObjectVar(Var[OBJECT_TYPE], python_types=dict):
fixed_type = get_origin(self._var_type) or self._var_type
if not isclass(fixed_type):
return Any
args = get_args(self._var_type) if issubclass(fixed_type, dict) else ()
args = get_args(self._var_type) if issubclass(fixed_type, Mapping) else ()
return args[1] if args else Any
def keys(self) -> ArrayVar[List[str]]:
@ -87,7 +89,7 @@ class ObjectVar(Var[OBJECT_TYPE], python_types=dict):
@overload
def values(
self: ObjectVar[Dict[Any, VALUE_TYPE]],
self: ObjectVar[Mapping[Any, VALUE_TYPE]],
) -> ArrayVar[List[VALUE_TYPE]]: ...
@overload
@ -103,7 +105,7 @@ class ObjectVar(Var[OBJECT_TYPE], python_types=dict):
@overload
def entries(
self: ObjectVar[Dict[Any, VALUE_TYPE]],
self: ObjectVar[Mapping[Any, VALUE_TYPE]],
) -> ArrayVar[List[Tuple[str, VALUE_TYPE]]]: ...
@overload
@ -133,49 +135,55 @@ class ObjectVar(Var[OBJECT_TYPE], python_types=dict):
# NoReturn is used here to catch when key value is Any
@overload
def __getitem__(
self: ObjectVar[Dict[Any, NoReturn]],
self: ObjectVar[Mapping[Any, NoReturn]],
key: Var | Any,
) -> Var: ...
@overload
def __getitem__(
self: (ObjectVar[Mapping[Any, bool]]),
key: Var | Any,
) -> BooleanVar: ...
@overload
def __getitem__(
self: (
ObjectVar[Dict[Any, int]]
| ObjectVar[Dict[Any, float]]
| ObjectVar[Dict[Any, int | float]]
ObjectVar[Mapping[Any, int]]
| ObjectVar[Mapping[Any, float]]
| ObjectVar[Mapping[Any, int | float]]
),
key: Var | Any,
) -> NumberVar: ...
@overload
def __getitem__(
self: ObjectVar[Dict[Any, str]],
self: ObjectVar[Mapping[Any, str]],
key: Var | Any,
) -> StringVar: ...
@overload
def __getitem__(
self: ObjectVar[Dict[Any, list[ARRAY_INNER_TYPE]]],
self: ObjectVar[Mapping[Any, list[ARRAY_INNER_TYPE]]],
key: Var | Any,
) -> ArrayVar[list[ARRAY_INNER_TYPE]]: ...
@overload
def __getitem__(
self: ObjectVar[Dict[Any, set[ARRAY_INNER_TYPE]]],
self: ObjectVar[Mapping[Any, set[ARRAY_INNER_TYPE]]],
key: Var | Any,
) -> ArrayVar[set[ARRAY_INNER_TYPE]]: ...
@overload
def __getitem__(
self: ObjectVar[Dict[Any, tuple[ARRAY_INNER_TYPE, ...]]],
self: ObjectVar[Mapping[Any, tuple[ARRAY_INNER_TYPE, ...]]],
key: Var | Any,
) -> ArrayVar[tuple[ARRAY_INNER_TYPE, ...]]: ...
@overload
def __getitem__(
self: ObjectVar[Dict[Any, dict[OTHER_KEY_TYPE, VALUE_TYPE]]],
self: ObjectVar[Mapping[Any, Mapping[OTHER_KEY_TYPE, VALUE_TYPE]]],
key: Var | Any,
) -> ObjectVar[dict[OTHER_KEY_TYPE, VALUE_TYPE]]: ...
) -> ObjectVar[Mapping[OTHER_KEY_TYPE, VALUE_TYPE]]: ...
def __getitem__(self, key: Var | Any) -> Var:
"""Get an item from the object.
@ -195,49 +203,49 @@ class ObjectVar(Var[OBJECT_TYPE], python_types=dict):
# NoReturn is used here to catch when key value is Any
@overload
def __getattr__(
self: ObjectVar[Dict[Any, NoReturn]],
self: ObjectVar[Mapping[Any, NoReturn]],
name: str,
) -> Var: ...
@overload
def __getattr__(
self: (
ObjectVar[Dict[Any, int]]
| ObjectVar[Dict[Any, float]]
| ObjectVar[Dict[Any, int | float]]
ObjectVar[Mapping[Any, int]]
| ObjectVar[Mapping[Any, float]]
| ObjectVar[Mapping[Any, int | float]]
),
name: str,
) -> NumberVar: ...
@overload
def __getattr__(
self: ObjectVar[Dict[Any, str]],
self: ObjectVar[Mapping[Any, str]],
name: str,
) -> StringVar: ...
@overload
def __getattr__(
self: ObjectVar[Dict[Any, list[ARRAY_INNER_TYPE]]],
self: ObjectVar[Mapping[Any, list[ARRAY_INNER_TYPE]]],
name: str,
) -> ArrayVar[list[ARRAY_INNER_TYPE]]: ...
@overload
def __getattr__(
self: ObjectVar[Dict[Any, set[ARRAY_INNER_TYPE]]],
self: ObjectVar[Mapping[Any, set[ARRAY_INNER_TYPE]]],
name: str,
) -> ArrayVar[set[ARRAY_INNER_TYPE]]: ...
@overload
def __getattr__(
self: ObjectVar[Dict[Any, tuple[ARRAY_INNER_TYPE, ...]]],
self: ObjectVar[Mapping[Any, tuple[ARRAY_INNER_TYPE, ...]]],
name: str,
) -> ArrayVar[tuple[ARRAY_INNER_TYPE, ...]]: ...
@overload
def __getattr__(
self: ObjectVar[Dict[Any, dict[OTHER_KEY_TYPE, VALUE_TYPE]]],
self: ObjectVar[Mapping[Any, Mapping[OTHER_KEY_TYPE, VALUE_TYPE]]],
name: str,
) -> ObjectVar[dict[OTHER_KEY_TYPE, VALUE_TYPE]]: ...
) -> ObjectVar[Mapping[OTHER_KEY_TYPE, VALUE_TYPE]]: ...
@overload
def __getattr__(
@ -266,8 +274,11 @@ class ObjectVar(Var[OBJECT_TYPE], python_types=dict):
var_type = get_args(var_type)[0]
fixed_type = var_type if isclass(var_type) else get_origin(var_type)
if (isclass(fixed_type) and not issubclass(fixed_type, dict)) or (
fixed_type in types.UnionTypes
if (
(isclass(fixed_type) and not issubclass(fixed_type, Mapping))
or (fixed_type in types.UnionTypes)
or is_typeddict(fixed_type)
):
attribute_type = get_attribute_access_type(var_type, name)
if attribute_type is None:
@ -299,7 +310,7 @@ class ObjectVar(Var[OBJECT_TYPE], python_types=dict):
class LiteralObjectVar(CachedVarOperation, ObjectVar[OBJECT_TYPE], LiteralVar):
"""Base class for immutable literal object vars."""
_var_value: Dict[Union[Var, Any], Union[Var, Any]] = dataclasses.field(
_var_value: Mapping[Union[Var, Any], Union[Var, Any]] = dataclasses.field(
default_factory=dict
)
@ -383,7 +394,7 @@ class LiteralObjectVar(CachedVarOperation, ObjectVar[OBJECT_TYPE], LiteralVar):
@classmethod
def create(
cls,
_var_value: dict,
_var_value: Mapping,
_var_type: Type[OBJECT_TYPE] | None = None,
_var_data: VarData | None = None,
) -> LiteralObjectVar[OBJECT_TYPE]:
@ -466,7 +477,7 @@ def object_merge_operation(lhs: ObjectVar, rhs: ObjectVar):
"""
return var_operation_return(
js_expression=f"({{...{lhs}, ...{rhs}}})",
var_type=Dict[
var_type=Mapping[
Union[lhs._key_type(), rhs._key_type()],
Union[lhs._value_type(), rhs._value_type()],
],

View File

@ -987,7 +987,7 @@ class ArrayVar(Var[ARRAY_VAR_TYPE], python_types=(list, tuple, set)):
raise_unsupported_operand_types("[]", (type(self), type(i)))
return array_item_operation(self, i)
def length(self) -> NumberVar:
def length(self) -> NumberVar[int]:
"""Get the length of the array.
Returns:

View File

@ -1,4 +1,4 @@
from typing import Dict, List, Tuple
from typing import List, Mapping, Tuple
import pytest
@ -67,7 +67,7 @@ def test_match_components():
assert fourth_return_value_render["children"][0]["contents"] == '{"fourth value"}'
assert match_cases[4][0]._js_expr == '({ ["foo"] : "bar" })'
assert match_cases[4][0]._var_type == Dict[str, str]
assert match_cases[4][0]._var_type == Mapping[str, str]
fifth_return_value_render = match_cases[4][1].render()
assert fifth_return_value_render["name"] == "RadixThemesText"
assert fifth_return_value_render["children"][0]["contents"] == '{"fifth value"}'

View File

@ -1,6 +1,6 @@
from __future__ import annotations
from typing import Any, Dict
from typing import Any, Mapping
import pytest
@ -379,7 +379,7 @@ class StyleState(rx.State):
{
"css": Var(
_js_expr=f'({{ ["color"] : ("dark"+{StyleState.color}) }})'
).to(Dict[str, str])
).to(Mapping[str, str])
},
),
(

View File

@ -2,7 +2,7 @@ import json
import math
import sys
import typing
from typing import Dict, List, Optional, Set, Tuple, Union, cast
from typing import Dict, List, Mapping, Optional, Set, Tuple, Union, cast
import pytest
from pandas import DataFrame
@ -270,7 +270,7 @@ def test_get_setter(prop: Var, expected):
([1, 2, 3], Var(_js_expr="[1, 2, 3]", _var_type=List[int])),
(
{"a": 1, "b": 2},
Var(_js_expr='({ ["a"] : 1, ["b"] : 2 })', _var_type=Dict[str, int]),
Var(_js_expr='({ ["a"] : 1, ["b"] : 2 })', _var_type=Mapping[str, int]),
),
],
)

View File

@ -1,4 +1,4 @@
from typing import Dict, List, Union
from typing import List, Mapping, Union
import pytest
@ -37,12 +37,12 @@ class ChildGenericDict(GenericDict):
("a", str),
([1, 2, 3], List[int]),
([1, 2.0, "a"], List[Union[int, float, str]]),
({"a": 1, "b": 2}, Dict[str, int]),
({"a": 1, 2: "b"}, Dict[Union[int, str], Union[str, int]]),
({"a": 1, "b": 2}, Mapping[str, int]),
({"a": 1, 2: "b"}, Mapping[Union[int, str], Union[str, int]]),
(CustomDict(), CustomDict),
(ChildCustomDict(), ChildCustomDict),
(GenericDict({1: 1}), Dict[int, int]),
(ChildGenericDict({1: 1}), Dict[int, int]),
(GenericDict({1: 1}), Mapping[int, int]),
(ChildGenericDict({1: 1}), Mapping[int, int]),
],
)
def test_figure_out_type(value, expected):