Merge remote-tracking branch 'origin/main' into masenf/fix-default-color-mode

This commit is contained in:
Masen Furer 2025-01-21 13:16:03 -08:00
commit d7dec3935b
No known key found for this signature in database
GPG Key ID: B0008AD22B3B3A95
46 changed files with 734 additions and 305 deletions

View File

@ -33,7 +33,7 @@ env:
PR_TITLE: ${{ github.event.pull_request.title }}
jobs:
example-counter:
example-counter-and-nba-proxy:
env:
OUTPUT_FILE: import_benchmark.json
timeout-minutes: 30
@ -119,6 +119,26 @@ jobs:
--benchmark-json "./reflex-examples/counter/${{ env.OUTPUT_FILE }}"
--branch-name "${{ github.head_ref || github.ref_name }}" --pr-id "${{ github.event.pull_request.id }}"
--app-name "counter"
- name: Install requirements for nba proxy example
working-directory: ./reflex-examples/nba-proxy
run: |
poetry run uv pip install -r requirements.txt
- name: Install additional dependencies for DB access
run: poetry run uv pip install psycopg
- name: Check export --backend-only before init for nba-proxy example
working-directory: ./reflex-examples/nba-proxy
run: |
poetry run reflex export --backend-only
- name: Init Website for nba-proxy example
working-directory: ./reflex-examples/nba-proxy
run: |
poetry run reflex init --loglevel debug
- name: Run Website and Check for errors
run: |
# Check that npm is home
npm -v
poetry run bash scripts/integration.sh ./reflex-examples/nba-proxy dev
reflex-web:
strategy:

View File

@ -86,8 +86,8 @@ build-backend = "poetry.core.masonry.api"
target-version = "py39"
output-format = "concise"
lint.isort.split-on-trailing-comma = false
lint.select = ["B", "C4", "D", "E", "ERA", "F", "FURB", "I", "PERF", "PTH", "RUF", "SIM", "T", "W"]
lint.ignore = ["B008", "D205", "E501", "F403", "SIM115", "RUF006", "RUF012"]
lint.select = ["B", "C4", "D", "E", "ERA", "F", "FURB", "I", "PERF", "PTH", "RUF", "SIM", "T", "TRY", "W"]
lint.ignore = ["B008", "D205", "E501", "F403", "SIM115", "RUF006", "RUF012", "TRY0"]
lint.pydocstyle.convention = "google"
[tool.ruff.lint.per-file-ignores]

View File

@ -3,6 +3,7 @@ import axios from "axios";
import io from "socket.io-client";
import JSON5 from "json5";
import env from "$/env.json";
import reflexEnvironment from "$/reflex.json";
import Cookies from "universal-cookie";
import { useEffect, useRef, useState } from "react";
import Router, { useRouter } from "next/router";
@ -407,10 +408,18 @@ export const connect = async (
socket.current = io(endpoint.href, {
path: endpoint["pathname"],
transports: transports,
protocols: env.TEST_MODE ? undefined : [reflexEnvironment.version],
autoUnref: false,
});
// Ensure undefined fields in events are sent as null instead of removed
socket.current.io.encoder.replacer = (k, v) => (v === undefined ? null : v)
socket.current.io.encoder.replacer = (k, v) => (v === undefined ? null : v);
socket.current.io.decoder.tryParse = (str) => {
try {
return JSON5.parse(str);
} catch (e) {
return false;
}
};
function checkVisibility() {
if (document.visibilityState === "visible") {

View File

@ -463,14 +463,8 @@ class App(MiddlewareMixin, LifespanMixin):
Returns:
The generated component.
Raises:
exceptions.MatchTypeError: If the return types of match cases in rx.match are different.
"""
try:
return component if isinstance(component, Component) else component()
except exceptions.MatchTypeError:
raise
return component if isinstance(component, Component) else component()
def add_page(
self,
@ -564,11 +558,12 @@ class App(MiddlewareMixin, LifespanMixin):
meta=meta,
)
def _compile_page(self, route: str):
def _compile_page(self, route: str, save_page: bool = True):
"""Compile a page.
Args:
route: The route of the page to compile.
save_page: If True, the compiled page is saved to self.pages.
"""
component, enable_state = compiler.compile_unevaluated_page(
route, self.unevaluated_pages[route], self.state, self.style, self.theme
@ -579,7 +574,8 @@ class App(MiddlewareMixin, LifespanMixin):
# Add the page.
self._check_routes_conflict(route)
self.pages[route] = component
if save_page:
self.pages[route] = component
def get_load_events(self, route: str) -> list[IndividualEventType[[], Any]]:
"""Get the load events for a route.
@ -879,14 +875,16 @@ class App(MiddlewareMixin, LifespanMixin):
# If a theme component was provided, wrap the app with it
app_wrappers[(20, "Theme")] = self.theme
should_compile = self._should_compile()
for route in self.unevaluated_pages:
console.debug(f"Evaluating page: {route}")
self._compile_page(route)
self._compile_page(route, save_page=should_compile)
# Add the optional endpoints (_upload)
self._add_optional_endpoints()
if not self._should_compile():
if not should_compile:
return
self._validate_var_dependencies()
@ -1530,7 +1528,11 @@ class EventNamespace(AsyncNamespace):
sid: The Socket.IO session id.
environ: The request information, including HTTP headers.
"""
pass
subprotocol = environ.get("HTTP_SEC_WEBSOCKET_PROTOCOL", None)
if subprotocol and subprotocol != constants.Reflex.VERSION:
console.warn(
f"Frontend version {subprotocol} for session {sid} does not match the backend version {constants.Reflex.VERSION}."
)
def on_disconnect(self, sid):
"""Event for when the websocket disconnects.
@ -1563,10 +1565,36 @@ class EventNamespace(AsyncNamespace):
Args:
sid: The Socket.IO session id.
data: The event data.
Raises:
EventDeserializationError: If the event data is not a dictionary.
"""
fields = data
# Get the event.
event = Event(**{k: v for k, v in fields.items() if k in _EVENT_FIELDS})
if isinstance(fields, str):
console.warn(
"Received event data as a string. This generally should not happen and may indicate a bug."
f" Event data: {fields}"
)
try:
fields = json.loads(fields)
except json.JSONDecodeError as ex:
raise exceptions.EventDeserializationError(
f"Failed to deserialize event data: {fields}."
) from ex
if not isinstance(fields, dict):
raise exceptions.EventDeserializationError(
f"Event data must be a dictionary, but received {fields} of type {type(fields)}."
)
try:
# Get the event.
event = Event(**{k: v for k, v in fields.items() if k in _EVENT_FIELDS})
except (TypeError, ValueError) as ex:
raise exceptions.EventDeserializationError(
f"Failed to deserialize event data: {fields}."
) from ex
self.token_to_sid[event.token] = sid
self.sid_to_token[sid] = event.token

View File

@ -7,14 +7,13 @@ from concurrent.futures import ThreadPoolExecutor
from reflex import constants
from reflex.utils import telemetry
from reflex.utils.exec import is_prod_mode
from reflex.utils.prerequisites import get_app
from reflex.utils.prerequisites import get_and_validate_app
if constants.CompileVars.APP != "app":
raise AssertionError("unexpected variable name for 'app'")
telemetry.send("compile")
app_module = get_app(reload=False)
app = getattr(app_module, constants.CompileVars.APP)
app, app_module = get_and_validate_app(reload=False)
# For py3.9 compatibility when redis is used, we MUST add any decorator pages
# before compiling the app in a thread to avoid event loop error (REF-2172).
app._apply_decorated_pages()
@ -30,7 +29,7 @@ if is_prod_mode():
# ensure only "app" is exposed.
del app_module
del compile_future
del get_app
del get_and_validate_app
del is_prod_mode
del telemetry
del constants

View File

@ -429,20 +429,22 @@ class Component(BaseComponent, ABC):
else:
continue
def determine_key(value):
# Try to create a var from the value
key = value if isinstance(value, Var) else LiteralVar.create(value)
# Check that the var type is not None.
if key is None:
raise TypeError
return key
# Check whether the key is a component prop.
if types._issubclass(field_type, Var):
# Used to store the passed types if var type is a union.
passed_types = None
try:
# Try to create a var from the value.
if isinstance(value, Var):
kwargs[key] = value
else:
kwargs[key] = LiteralVar.create(value)
# Check that the var type is not None.
if kwargs[key] is None:
raise TypeError
kwargs[key] = determine_key(value)
expected_type = fields[key].outer_type_.__args__[0]
# validate literal fields.
@ -740,22 +742,21 @@ class Component(BaseComponent, ABC):
# Import here to avoid circular imports.
from reflex.components.base.bare import Bare
from reflex.components.base.fragment import Fragment
from reflex.utils.exceptions import ComponentTypeError
from reflex.utils.exceptions import ChildrenTypeError
# Filter out None props
props = {key: value for key, value in props.items() if value is not None}
def validate_children(children):
for child in children:
if isinstance(child, tuple):
if isinstance(child, (tuple, list)):
validate_children(child)
# Make sure the child is a valid type.
if not types._isinstance(child, ComponentChild):
raise ComponentTypeError(
"Children of Reflex components must be other components, "
"state vars, or primitive Python types. "
f"Got child {child} of type {type(child)}.",
)
if isinstance(child, dict) or not types._isinstance(
child, ComponentChild
):
raise ChildrenTypeError(component=cls.__name__, child=child)
# Validate all the children.
validate_children(children)

View File

@ -136,6 +136,23 @@ def load_dynamic_serializer():
module_code_lines.insert(0, "const React = window.__reflex.react;")
function_line = next(
index
for index, line in enumerate(module_code_lines)
if line.startswith("export default function")
)
module_code_lines = [
line
for _, line in sorted(
enumerate(module_code_lines),
key=lambda x: (
not (x[1].startswith("import ") and x[0] < function_line),
x[0],
),
)
]
return "\n".join(
[
"//__reflex_evaluate",

View File

@ -2,13 +2,15 @@
from reflex.components.component import Component
from reflex.utils import format
from reflex.vars.base import Var
from reflex.utils.imports import ImportVar
from reflex.vars.base import LiteralVar, Var
from reflex.vars.sequence import LiteralStringVar
class LucideIconComponent(Component):
"""Lucide Icon Component."""
library = "lucide-react@0.469.0"
library = "lucide-react@0.471.1"
class Icon(LucideIconComponent):
@ -32,6 +34,7 @@ class Icon(LucideIconComponent):
Raises:
AttributeError: The errors tied to bad usage of the Icon component.
ValueError: If the icon tag is invalid.
TypeError: If the icon name is not a string.
Returns:
The created component.
@ -39,7 +42,6 @@ class Icon(LucideIconComponent):
if children:
if len(children) == 1 and isinstance(children[0], str):
props["tag"] = children[0]
children = []
else:
raise AttributeError(
f"Passing multiple children to Icon component is not allowed: remove positional arguments {children[1:]} to fix"
@ -47,19 +49,46 @@ class Icon(LucideIconComponent):
if "tag" not in props:
raise AttributeError("Missing 'tag' keyword-argument for Icon")
tag: str | Var | LiteralVar = props.pop("tag")
if isinstance(tag, LiteralVar):
if isinstance(tag, LiteralStringVar):
tag = tag._var_value
else:
raise TypeError(f"Icon name must be a string, got {type(tag)}")
elif isinstance(tag, Var):
return DynamicIcon.create(name=tag, **props)
if (
not isinstance(props["tag"], str)
or format.to_snake_case(props["tag"]) not in LUCIDE_ICON_LIST
not isinstance(tag, str)
or format.to_snake_case(tag) not in LUCIDE_ICON_LIST
):
raise ValueError(
f"Invalid icon tag: {props['tag']}. Please use one of the following: {', '.join(LUCIDE_ICON_LIST[0:25])}, ..."
f"Invalid icon tag: {tag}. Please use one of the following: {', '.join(LUCIDE_ICON_LIST[0:25])}, ..."
"\nSee full list at https://lucide.dev/icons."
)
props["tag"] = format.to_title_case(format.to_snake_case(props["tag"])) + "Icon"
if tag in LUCIDE_ICON_MAPPING_OVERRIDE:
props["tag"] = LUCIDE_ICON_MAPPING_OVERRIDE[tag]
else:
props["tag"] = format.to_title_case(format.to_snake_case(tag)) + "Icon"
props["alias"] = f"Lucide{props['tag']}"
props.setdefault("color", "var(--current-color)")
return super().create(*children, **props)
return super().create(**props)
class DynamicIcon(LucideIconComponent):
"""A DynamicIcon component."""
tag = "DynamicIcon"
name: Var[str]
def _get_imports(self):
_imports = super()._get_imports()
if self.library:
_imports.pop(self.library)
_imports["lucide-react/dynamic"] = [ImportVar("DynamicIcon", install=False)]
return _imports
LUCIDE_ICON_LIST = [
@ -841,6 +870,7 @@ LUCIDE_ICON_LIST = [
"house",
"house_plug",
"house_plus",
"house_wifi",
"ice_cream_bowl",
"ice_cream_cone",
"id_card",
@ -1529,6 +1559,7 @@ LUCIDE_ICON_LIST = [
"trending_up_down",
"triangle",
"triangle_alert",
"triangle_dashed",
"triangle_right",
"trophy",
"truck",
@ -1634,3 +1665,10 @@ LUCIDE_ICON_LIST = [
"zoom_in",
"zoom_out",
]
# The default transformation of some icon names doesn't match how the
# icons are exported from Lucide. Manual overrides can go here.
LUCIDE_ICON_MAPPING_OVERRIDE = {
"grid_2x_2_check": "Grid2x2Check",
"grid_2x_2_x": "Grid2x2X",
}

View File

@ -104,12 +104,60 @@ class Icon(LucideIconComponent):
Raises:
AttributeError: The errors tied to bad usage of the Icon component.
ValueError: If the icon tag is invalid.
TypeError: If the icon name is not a string.
Returns:
The created component.
"""
...
class DynamicIcon(LucideIconComponent):
@overload
@classmethod
def create( # type: ignore
cls,
*children,
name: Optional[Union[Var[str], str]] = None,
style: Optional[Style] = None,
key: Optional[Any] = None,
id: Optional[Any] = None,
class_name: Optional[Any] = None,
autofocus: Optional[bool] = None,
custom_attrs: Optional[Dict[str, Union[Var, Any]]] = None,
on_blur: Optional[EventType[[], BASE_STATE]] = None,
on_click: Optional[EventType[[], BASE_STATE]] = None,
on_context_menu: Optional[EventType[[], BASE_STATE]] = None,
on_double_click: Optional[EventType[[], BASE_STATE]] = None,
on_focus: Optional[EventType[[], BASE_STATE]] = None,
on_mount: Optional[EventType[[], BASE_STATE]] = None,
on_mouse_down: Optional[EventType[[], BASE_STATE]] = None,
on_mouse_enter: Optional[EventType[[], BASE_STATE]] = None,
on_mouse_leave: Optional[EventType[[], BASE_STATE]] = None,
on_mouse_move: Optional[EventType[[], BASE_STATE]] = None,
on_mouse_out: Optional[EventType[[], BASE_STATE]] = None,
on_mouse_over: Optional[EventType[[], BASE_STATE]] = None,
on_mouse_up: Optional[EventType[[], BASE_STATE]] = None,
on_scroll: Optional[EventType[[], BASE_STATE]] = None,
on_unmount: Optional[EventType[[], BASE_STATE]] = None,
**props,
) -> "DynamicIcon":
"""Create the component.
Args:
*children: The children of the component.
style: The style of the component.
key: A unique key for the component.
id: The id for the component.
class_name: The class name for the component.
autofocus: Whether the component should take the focus once the page is loaded
custom_attrs: custom attribute
**props: The props of the component.
Returns:
The component.
"""
...
LUCIDE_ICON_LIST = [
"a_arrow_down",
"a_arrow_up",
@ -889,6 +937,7 @@ LUCIDE_ICON_LIST = [
"house",
"house_plug",
"house_plus",
"house_wifi",
"ice_cream_bowl",
"ice_cream_cone",
"id_card",
@ -1577,6 +1626,7 @@ LUCIDE_ICON_LIST = [
"trending_up_down",
"triangle",
"triangle_alert",
"triangle_dashed",
"triangle_right",
"trophy",
"truck",
@ -1682,3 +1732,7 @@ LUCIDE_ICON_LIST = [
"zoom_in",
"zoom_out",
]
LUCIDE_ICON_MAPPING_OVERRIDE = {
"grid_2x_2_check": "Grid2x2Check",
"grid_2x_2_x": "Grid2x2X",
}

View File

@ -151,8 +151,8 @@ class ColorModeIconButton(IconButton):
dropdown_menu.trigger(
super().create(
ColorModeIcon.create(),
**props,
)
),
**props,
),
dropdown_menu.content(
color_mode_item("light"),

View File

@ -12,6 +12,7 @@ import threading
import urllib.parse
from importlib.util import find_spec
from pathlib import Path
from types import ModuleType
from typing import (
TYPE_CHECKING,
Any,
@ -567,6 +568,9 @@ class EnvironmentVariables:
# The maximum size of the reflex state in kilobytes.
REFLEX_STATE_SIZE_LIMIT: EnvVar[int] = env_var(1000)
# Whether to use the turbopack bundler.
REFLEX_USE_TURBOPACK: EnvVar[bool] = env_var(True)
environment = EnvironmentVariables()
@ -604,6 +608,9 @@ class Config(Base):
# The name of the app (should match the name of the app directory).
app_name: str
# The path to the app module.
app_module_import: Optional[str] = None
# The log level to use.
loglevel: constants.LogLevel = constants.LogLevel.DEFAULT
@ -726,6 +733,19 @@ class Config(Base):
"REDIS_URL is required when using the redis state manager."
)
@property
def app_module(self) -> ModuleType | None:
"""Return the app module if `app_module_import` is set.
Returns:
The app module.
"""
return (
importlib.import_module(self.app_module_import)
if self.app_module_import
else None
)
@property
def module(self) -> str:
"""Get the module name of the app.
@ -733,6 +753,8 @@ class Config(Base):
Returns:
The module name.
"""
if self.app_module is not None:
return self.app_module.__name__
return ".".join([self.app_name, self.app_name])
def update_from_env(self) -> dict[str, Any]:
@ -871,7 +893,7 @@ def get_config(reload: bool = False) -> Config:
return cached_rxconfig.config
with _config_lock:
sys_path = sys.path.copy()
orig_sys_path = sys.path.copy()
sys.path.clear()
sys.path.append(str(Path.cwd()))
try:
@ -879,9 +901,14 @@ def get_config(reload: bool = False) -> Config:
return _get_config()
except Exception:
# If the module import fails, try to import with the original sys.path.
sys.path.extend(sys_path)
sys.path.extend(orig_sys_path)
return _get_config()
finally:
# Find any entries added to sys.path by rxconfig.py itself.
extra_paths = [
p for p in sys.path if p not in orig_sys_path and p != str(Path.cwd())
]
# Restore the original sys.path.
sys.path.clear()
sys.path.extend(sys_path)
sys.path.extend(extra_paths)
sys.path.extend(orig_sys_path)

View File

@ -1,6 +1,7 @@
"""The constants package."""
from .base import (
APP_HARNESS_FLAG,
COOKIES,
IS_LINUX,
IS_MACOS,

View File

@ -257,6 +257,7 @@ SESSION_STORAGE = "session_storage"
# Testing variables.
# Testing os env set by pytest when running a test case.
PYTEST_CURRENT_TEST = "PYTEST_CURRENT_TEST"
APP_HARNESS_FLAG = "APP_HARNESS_FLAG"
REFLEX_VAR_OPENING_TAG = "<reflex.Var>"
REFLEX_VAR_CLOSING_TAG = "</reflex.Var>"

View File

@ -182,7 +182,7 @@ class PackageJson(SimpleNamespace):
"@emotion/react": "11.13.3",
"axios": "1.7.7",
"json5": "2.2.3",
"next": "14.2.16",
"next": "15.1.4",
"next-sitemap": "4.2.3",
"next-themes": "0.4.3",
"react": "18.3.1",

View File

@ -421,12 +421,13 @@ def _run_commands_in_subprocess(cmds: list[str]) -> bool:
console.debug(f"Running command: {' '.join(cmds)}")
try:
result = subprocess.run(cmds, capture_output=True, text=True, check=True)
console.debug(result.stdout)
return True
except subprocess.CalledProcessError as cpe:
console.error(cpe.stdout)
console.error(cpe.stderr)
return False
else:
console.debug(result.stdout)
return True
def _make_pyi_files():
@ -931,10 +932,11 @@ def _get_file_from_prompt_in_loop() -> Tuple[bytes, str] | None:
file_extension = image_filepath.suffix
try:
image_file = image_filepath.read_bytes()
return image_file, file_extension
except OSError as ose:
console.error(f"Unable to read the {file_extension} file due to {ose}")
raise typer.Exit(code=1) from ose
else:
return image_file, file_extension
console.debug(f"File extension detected: {file_extension}")
return None

View File

@ -1591,7 +1591,7 @@ def get_handler_args(
def fix_events(
events: list[EventHandler | EventSpec] | None,
events: list[EventSpec | EventHandler] | None,
token: str,
router_data: dict[str, Any] | None = None,
) -> list[Event]:

View File

@ -440,7 +440,11 @@ def deploy(
config.app_name,
"--app-name",
help="The name of the App to deploy under.",
hidden=True,
),
app_id: str = typer.Option(
None,
"--app-id",
help="The ID of the App to deploy over.",
),
regions: List[str] = typer.Option(
[],
@ -480,6 +484,11 @@ def deploy(
"--project",
help="project id to deploy to",
),
project_name: Optional[str] = typer.Option(
None,
"--project-name",
help="The name of the project to deploy to.",
),
token: Optional[str] = typer.Option(
None,
"--token",
@ -519,9 +528,12 @@ def deploy(
if prerequisites.needs_reinit(frontend=True):
_init(name=config.app_name, loglevel=loglevel)
prerequisites.check_latest_package_version(constants.ReflexHostingCLI.MODULE_NAME)
extra: dict[str, str] = (
{"config_path": config_path} if config_path is not None else {}
)
hosting_cli.deploy(
app_name=app_name,
app_id=app_id,
export_fn=lambda zip_dest_dir,
api_url,
deploy_url,
@ -546,6 +558,8 @@ def deploy(
token=token,
project=project,
config_path=config_path,
project_name=project_name,
**extra,
)

View File

@ -104,6 +104,7 @@ from reflex.utils.exceptions import (
LockExpiredError,
ReflexRuntimeError,
SetUndefinedStateVarError,
StateMismatchError,
StateSchemaMismatchError,
StateSerializationError,
StateTooLargeError,
@ -1199,7 +1200,6 @@ class BaseState(Base, ABC, extra=pydantic.Extra.allow):
fget=func,
auto_deps=False,
deps=["router"],
cache=True,
_js_expr=param,
_var_data=VarData.from_state(cls),
)
@ -1543,7 +1543,7 @@ class BaseState(Base, ABC, extra=pydantic.Extra.allow):
# Return the direct parent of target_state_cls for subsequent linking.
return parent_state
def _get_state_from_cache(self, state_cls: Type[BaseState]) -> BaseState:
def _get_state_from_cache(self, state_cls: Type[T_STATE]) -> T_STATE:
"""Get a state instance from the cache.
Args:
@ -1551,11 +1551,19 @@ class BaseState(Base, ABC, extra=pydantic.Extra.allow):
Returns:
The instance of state_cls associated with this state's client_token.
Raises:
StateMismatchError: If the state instance is not of the expected type.
"""
root_state = self._get_root_state()
return root_state.get_substate(state_cls.get_full_name().split("."))
substate = root_state.get_substate(state_cls.get_full_name().split("."))
if not isinstance(substate, state_cls):
raise StateMismatchError(
f"Searched for state {state_cls.get_full_name()} but found {substate}."
)
return substate
async def _get_state_from_redis(self, state_cls: Type[BaseState]) -> BaseState:
async def _get_state_from_redis(self, state_cls: Type[T_STATE]) -> T_STATE:
"""Get a state instance from redis.
Args:
@ -1566,6 +1574,7 @@ class BaseState(Base, ABC, extra=pydantic.Extra.allow):
Raises:
RuntimeError: If redis is not used in this backend process.
StateMismatchError: If the state instance is not of the expected type.
"""
# Fetch all missing parent states from redis.
parent_state_of_state_cls = await self._populate_parent_states(state_cls)
@ -1577,14 +1586,22 @@ class BaseState(Base, ABC, extra=pydantic.Extra.allow):
f"Requested state {state_cls.get_full_name()} is not cached and cannot be accessed without redis. "
"(All states should already be available -- this is likely a bug).",
)
return await state_manager.get_state(
state_in_redis = await state_manager.get_state(
token=_substate_key(self.router.session.client_token, state_cls),
top_level=False,
get_substates=True,
parent_state=parent_state_of_state_cls,
)
async def get_state(self, state_cls: Type[BaseState]) -> BaseState:
if not isinstance(state_in_redis, state_cls):
raise StateMismatchError(
f"Searched for state {state_cls.get_full_name()} but found {state_in_redis}."
)
return state_in_redis
async def get_state(self, state_cls: Type[T_STATE]) -> T_STATE:
"""Get an instance of the state associated with this token.
Allows for arbitrary access to sibling states from within an event handler.
@ -1759,9 +1776,9 @@ class BaseState(Base, ABC, extra=pydantic.Extra.allow):
except Exception as ex:
state._clean()
app_instance = getattr(prerequisites.get_app(), constants.CompileVars.APP)
event_specs = app_instance.backend_exception_handler(ex)
event_specs = (
prerequisites.get_and_validate_app().app.backend_exception_handler(ex)
)
if event_specs is None:
return StateUpdate()
@ -1871,9 +1888,9 @@ class BaseState(Base, ABC, extra=pydantic.Extra.allow):
except Exception as ex:
telemetry.send_error(ex, context="backend")
app_instance = getattr(prerequisites.get_app(), constants.CompileVars.APP)
event_specs = app_instance.backend_exception_handler(ex)
event_specs = (
prerequisites.get_and_validate_app().app.backend_exception_handler(ex)
)
yield state._as_state_update(
handler,
@ -2316,6 +2333,9 @@ class BaseState(Base, ABC, extra=pydantic.Extra.allow):
return state
T_STATE = TypeVar("T_STATE", bound=BaseState)
class State(BaseState):
"""The app Base State."""
@ -2383,8 +2403,9 @@ class FrontendEventExceptionState(State):
component_stack: The stack trace of the component where the exception occurred.
"""
app_instance = getattr(prerequisites.get_app(), constants.CompileVars.APP)
app_instance.frontend_exception_handler(Exception(stack))
prerequisites.get_and_validate_app().app.frontend_exception_handler(
Exception(stack)
)
class UpdateVarsInternalState(State):
@ -2422,15 +2443,16 @@ class OnLoadInternalState(State):
The list of events to queue for on load handling.
"""
# Do not app._compile()! It should be already compiled by now.
app = getattr(prerequisites.get_app(), constants.CompileVars.APP)
load_events = app.get_load_events(self.router.page.path)
load_events = prerequisites.get_and_validate_app().app.get_load_events(
self.router.page.path
)
if not load_events:
self.is_hydrated = True
return # Fast path for navigation with no on_load events defined.
self.is_hydrated = False
return [
*fix_events(
load_events,
cast(list[Union[EventSpec, EventHandler]], load_events),
self.router.session.client_token,
router_data=self.router_data,
),
@ -2589,7 +2611,7 @@ class StateProxy(wrapt.ObjectProxy):
"""
super().__init__(state_instance)
# compile is not relevant to backend logic
self._self_app = getattr(prerequisites.get_app(), constants.CompileVars.APP)
self._self_app = prerequisites.get_and_validate_app().app
self._self_substate_path = tuple(state_instance.get_full_name().split("."))
self._self_actx = None
self._self_mutable = False
@ -3682,8 +3704,7 @@ def get_state_manager() -> StateManager:
Returns:
The state manager.
"""
app = getattr(prerequisites.get_app(), constants.CompileVars.APP)
return app.state_manager
return prerequisites.get_and_validate_app().app.state_manager
class MutableProxy(wrapt.ObjectProxy):

View File

@ -282,6 +282,7 @@ class AppHarness:
before_decorated_pages = reflex.app.DECORATED_PAGES[self.app_name].copy()
# Ensure the AppHarness test does not skip State assignment due to running via pytest
os.environ.pop(reflex.constants.PYTEST_CURRENT_TEST, None)
os.environ[reflex.constants.APP_HARNESS_FLAG] = "true"
self.app_module = reflex.utils.prerequisites.get_compiled_app(
# Do not reload the module for pre-existing apps (only apps generated from source)
reload=self.app_source is not None

View File

@ -13,13 +13,17 @@ from rich.progress import MofNCompleteColumn, Progress, TimeElapsedColumn
from reflex import constants
from reflex.config import get_config
from reflex.utils import console, path_ops, prerequisites, processes
from reflex.utils.exec import is_in_app_harness
def set_env_json():
"""Write the upload url to a REFLEX_JSON."""
path_ops.update_json_file(
str(prerequisites.get_web_dir() / constants.Dirs.ENV_JSON),
{endpoint.name: endpoint.get_url() for endpoint in constants.Endpoint},
{
**{endpoint.name: endpoint.get_url() for endpoint in constants.Endpoint},
"TEST_MODE": is_in_app_harness(),
},
)

View File

@ -2,6 +2,11 @@
from __future__ import annotations
import inspect
import shutil
from pathlib import Path
from types import FrameType
from rich.console import Console
from rich.progress import MofNCompleteColumn, Progress, TimeElapsedColumn
from rich.prompt import Prompt
@ -188,6 +193,33 @@ def warn(msg: str, dedupe: bool = False, **kwargs):
print(f"[orange1]Warning: {msg}[/orange1]", **kwargs)
def _get_first_non_framework_frame() -> FrameType | None:
import click
import typer
import typing_extensions
import reflex as rx
# Exclude utility modules that should never be the source of deprecated reflex usage.
exclude_modules = [click, rx, typer, typing_extensions]
exclude_roots = [
p.parent.resolve()
if (p := Path(m.__file__)).name == "__init__.py"
else p.resolve()
for m in exclude_modules
]
# Specifically exclude the reflex cli module.
if reflex_bin := shutil.which(b"reflex"):
exclude_roots.append(Path(reflex_bin.decode()))
frame = inspect.currentframe()
while frame := frame and frame.f_back:
frame_path = Path(inspect.getfile(frame)).resolve()
if not any(frame_path.is_relative_to(root) for root in exclude_roots):
break
return frame
def deprecate(
feature_name: str,
reason: str,
@ -206,15 +238,27 @@ def deprecate(
dedupe: If True, suppress multiple console logs of deprecation message.
kwargs: Keyword arguments to pass to the print function.
"""
if feature_name not in _EMITTED_DEPRECATION_WARNINGS:
dedupe_key = feature_name
loc = ""
# See if we can find where the deprecation exists in "user code"
origin_frame = _get_first_non_framework_frame()
if origin_frame is not None:
filename = Path(origin_frame.f_code.co_filename)
if filename.is_relative_to(Path.cwd()):
filename = filename.relative_to(Path.cwd())
loc = f"{filename}:{origin_frame.f_lineno}"
dedupe_key = f"{dedupe_key} {loc}"
if dedupe_key not in _EMITTED_DEPRECATION_WARNINGS:
msg = (
f"{feature_name} has been deprecated in version {deprecation_version} {reason.rstrip('.')}. It will be completely "
f"removed in {removal_version}"
f"removed in {removal_version}. ({loc})"
)
if _LOG_LEVEL <= LogLevel.WARNING:
print(f"[yellow]DeprecationWarning: {msg}[/yellow]", **kwargs)
if dedupe:
_EMITTED_DEPRECATION_WARNINGS.add(feature_name)
_EMITTED_DEPRECATION_WARNINGS.add(dedupe_key)
def error(msg: str, dedupe: bool = False, **kwargs):

View File

@ -1,6 +1,6 @@
"""Custom Exceptions."""
from typing import NoReturn
from typing import Any, NoReturn
class ReflexError(Exception):
@ -31,6 +31,22 @@ class ComponentTypeError(ReflexError, TypeError):
"""Custom TypeError for component related errors."""
class ChildrenTypeError(ComponentTypeError):
"""Raised when the children prop of a component is not a valid type."""
def __init__(self, component: str, child: Any):
"""Initialize the exception.
Args:
component: The name of the component.
child: The child that caused the error.
"""
super().__init__(
f"Component {component} received child {child} of type {type(child)}. "
"Accepted types are other components, state vars, or primitive Python types (dict excluded)."
)
class EventHandlerTypeError(ReflexError, TypeError):
"""Custom TypeError for event handler related errors."""
@ -163,10 +179,18 @@ class StateSerializationError(ReflexError):
"""Raised when the state cannot be serialized."""
class StateMismatchError(ReflexError, ValueError):
"""Raised when the state retrieved does not match the expected state."""
class SystemPackageMissingError(ReflexError):
"""Raised when a system package is missing."""
class EventDeserializationError(ReflexError, ValueError):
"""Raised when an event cannot be deserialized."""
def raise_system_package_missing_error(package: str) -> NoReturn:
"""Raise a SystemPackageMissingError.

View File

@ -240,6 +240,28 @@ def run_backend(
run_uvicorn_backend(host, port, loglevel)
def get_reload_dirs() -> list[str]:
"""Get the reload directories for the backend.
Returns:
The reload directories for the backend.
"""
config = get_config()
reload_dirs = [config.app_name]
if config.app_module is not None and config.app_module.__file__:
module_path = Path(config.app_module.__file__).resolve().parent
while module_path.parent.name:
for parent_file in module_path.parent.iterdir():
if parent_file == "__init__.py":
# go up a level to find dir without `__init__.py`
module_path = module_path.parent
break
else:
break
reload_dirs.append(str(module_path))
return reload_dirs
def run_uvicorn_backend(host, port, loglevel: LogLevel):
"""Run the backend in development mode using Uvicorn.
@ -256,7 +278,7 @@ def run_uvicorn_backend(host, port, loglevel: LogLevel):
port=port,
log_level=loglevel.value,
reload=True,
reload_dirs=[get_config().app_name],
reload_dirs=get_reload_dirs(),
)
@ -281,7 +303,7 @@ def run_granian_backend(host, port, loglevel: LogLevel):
interface=Interfaces.ASGI,
log_level=LogLevels(loglevel.value),
reload=True,
reload_paths=[Path(get_config().app_name)],
reload_paths=get_reload_dirs(),
reload_ignore_dirs=[".web"],
).serve()
except ImportError:
@ -487,6 +509,15 @@ def is_testing_env() -> bool:
return constants.PYTEST_CURRENT_TEST in os.environ
def is_in_app_harness() -> bool:
"""Whether the app is running in the app harness.
Returns:
True if the app is running in the app harness.
"""
return constants.APP_HARNESS_FLAG in os.environ
def is_prod_mode() -> bool:
"""Check if the app is running in production mode.

View File

@ -17,11 +17,12 @@ import stat
import sys
import tempfile
import time
import typing
import zipfile
from datetime import datetime
from pathlib import Path
from types import ModuleType
from typing import Callable, List, Optional
from typing import Callable, List, NamedTuple, Optional
import httpx
import typer
@ -42,9 +43,19 @@ from reflex.utils.exceptions import (
from reflex.utils.format import format_library_name
from reflex.utils.registry import _get_npm_registry
if typing.TYPE_CHECKING:
from reflex.app import App
CURRENTLY_INSTALLING_NODE = False
class AppInfo(NamedTuple):
"""A tuple containing the app instance and module."""
app: App
module: ModuleType
@dataclasses.dataclass(frozen=True)
class Template:
"""A template for a Reflex app."""
@ -267,6 +278,22 @@ def windows_npm_escape_hatch() -> bool:
return environment.REFLEX_USE_NPM.get()
def _check_app_name(config: Config):
"""Check if the app name is set in the config.
Args:
config: The config object.
Raises:
RuntimeError: If the app name is not set in the config.
"""
if not config.app_name:
raise RuntimeError(
"Cannot get the app module because `app_name` is not set in rxconfig! "
"If this error occurs in a reflex test case, ensure that `get_app` is mocked."
)
def get_app(reload: bool = False) -> ModuleType:
"""Get the app module based on the default config.
@ -277,22 +304,23 @@ def get_app(reload: bool = False) -> ModuleType:
The app based on the default config.
Raises:
RuntimeError: If the app name is not set in the config.
Exception: If an error occurs while getting the app module.
"""
from reflex.utils import telemetry
try:
environment.RELOAD_CONFIG.set(reload)
config = get_config()
if not config.app_name:
raise RuntimeError(
"Cannot get the app module because `app_name` is not set in rxconfig! "
"If this error occurs in a reflex test case, ensure that `get_app` is mocked."
)
_check_app_name(config)
module = config.module
sys.path.insert(0, str(Path.cwd()))
app = __import__(module, fromlist=(constants.CompileVars.APP,))
app = (
__import__(module, fromlist=(constants.CompileVars.APP,))
if not config.app_module
else config.app_module
)
if reload:
from reflex.state import reload_state_module
@ -301,11 +329,34 @@ def get_app(reload: bool = False) -> ModuleType:
# Reload the app module.
importlib.reload(app)
return app
except Exception as ex:
telemetry.send_error(ex, context="frontend")
raise
else:
return app
def get_and_validate_app(reload: bool = False) -> AppInfo:
"""Get the app instance based on the default config and validate it.
Args:
reload: Re-import the app module from disk
Returns:
The app instance and the app module.
Raises:
RuntimeError: If the app instance is not an instance of rx.App.
"""
from reflex.app import App
app_module = get_app(reload=reload)
app = getattr(app_module, constants.CompileVars.APP)
if not isinstance(app, App):
raise RuntimeError(
"The app instance in the specified app_module_import in rxconfig must be an instance of rx.App."
)
return AppInfo(app=app, module=app_module)
def get_compiled_app(reload: bool = False, export: bool = False) -> ModuleType:
@ -318,8 +369,7 @@ def get_compiled_app(reload: bool = False, export: bool = False) -> ModuleType:
Returns:
The compiled app based on the default config.
"""
app_module = get_app(reload=reload)
app = getattr(app_module, constants.CompileVars.APP)
app, app_module = get_and_validate_app(reload=reload)
# For py3.9 compatibility when redis is used, we MUST add any decorator pages
# before compiling the app in a thread to avoid event loop error (REF-2172).
app._apply_decorated_pages()
@ -610,10 +660,14 @@ def initialize_web_directory():
init_reflex_json(project_hash=project_hash)
def _turbopack_flag() -> str:
return " --turbopack" if environment.REFLEX_USE_TURBOPACK.get() else ""
def _compile_package_json():
return templates.PACKAGE_JSON.render(
scripts={
"dev": constants.PackageJson.Commands.DEV,
"dev": constants.PackageJson.Commands.DEV + _turbopack_flag(),
"export": constants.PackageJson.Commands.EXPORT,
"export_sitemap": constants.PackageJson.Commands.EXPORT_SITEMAP,
"prod": constants.PackageJson.Commands.PROD,
@ -1149,11 +1203,12 @@ def ensure_reflex_installation_id() -> Optional[int]:
if installation_id is None:
installation_id = random.getrandbits(128)
installation_id_file.write_text(str(installation_id))
# If we get here, installation_id is definitely set
return installation_id
except Exception as e:
console.debug(f"Failed to ensure reflex installation id: {e}")
return None
else:
# If we get here, installation_id is definitely set
return installation_id
def initialize_reflex_user_directory():
@ -1367,19 +1422,22 @@ def create_config_init_app_from_remote_template(app_name: str, template_url: str
except OSError as ose:
console.error(f"Failed to create temp directory for extracting zip: {ose}")
raise typer.Exit(1) from ose
try:
zipfile.ZipFile(zip_file_path).extractall(path=unzip_dir)
# The zip file downloaded from github looks like:
# repo-name-branch/**/*, so we need to remove the top level directory.
if len(subdirs := os.listdir(unzip_dir)) != 1:
console.error(f"Expected one directory in the zip, found {subdirs}")
raise typer.Exit(1)
template_dir = unzip_dir / subdirs[0]
console.debug(f"Template folder is located at {template_dir}")
except Exception as uze:
console.error(f"Failed to unzip the template: {uze}")
raise typer.Exit(1) from uze
if len(subdirs := os.listdir(unzip_dir)) != 1:
console.error(f"Expected one directory in the zip, found {subdirs}")
raise typer.Exit(1)
template_dir = unzip_dir / subdirs[0]
console.debug(f"Template folder is located at {template_dir}")
# Move the rxconfig file here first.
path_ops.mv(str(template_dir / constants.Config.FILE), constants.Config.FILE)
new_config = get_config(reload=True)

View File

@ -17,6 +17,7 @@ import typer
from redis.exceptions import RedisError
from reflex import constants
from reflex.config import environment
from reflex.utils import console, path_ops, prerequisites
@ -156,24 +157,30 @@ def new_process(args, run: bool = False, show_logs: bool = False, **kwargs):
Raises:
Exit: When attempting to run a command with a None value.
"""
node_bin_path = str(path_ops.get_node_bin_path())
if not node_bin_path and not prerequisites.CURRENTLY_INSTALLING_NODE:
console.warn(
"The path to the Node binary could not be found. Please ensure that Node is properly "
"installed and added to your system's PATH environment variable or try running "
"`reflex init` again."
)
# Check for invalid command first.
if None in args:
console.error(f"Invalid command: {args}")
raise typer.Exit(1)
# Add the node bin path to the PATH environment variable.
path_env: str = os.environ.get("PATH", "")
# Add node_bin_path to the PATH environment variable.
if not environment.REFLEX_BACKEND_ONLY.get():
node_bin_path = str(path_ops.get_node_bin_path())
if not node_bin_path and not prerequisites.CURRENTLY_INSTALLING_NODE:
console.warn(
"The path to the Node binary could not be found. Please ensure that Node is properly "
"installed and added to your system's PATH environment variable or try running "
"`reflex init` again."
)
path_env = os.pathsep.join([node_bin_path, path_env])
env: dict[str, str] = {
**os.environ,
"PATH": os.pathsep.join(
[node_bin_path if node_bin_path else "", os.environ["PATH"]]
), # type: ignore
"PATH": path_env,
**kwargs.pop("env", {}),
}
kwargs = {
"env": env,
"stderr": None if show_logs else subprocess.STDOUT,

View File

@ -156,9 +156,10 @@ def _prepare_event(event: str, **kwargs) -> dict:
def _send_event(event_data: dict) -> bool:
try:
httpx.post(POSTHOG_API_URL, json=event_data)
return True
except Exception:
return False
else:
return True
def _send(event, telemetry_enabled, **kwargs):

View File

@ -829,6 +829,22 @@ StateBases = get_base_class(StateVar)
StateIterBases = get_base_class(StateIterVar)
def safe_issubclass(cls: Type, cls_check: Type | Tuple[Type, ...]):
"""Check if a class is a subclass of another class. Returns False if internal error occurs.
Args:
cls: The class to check.
cls_check: The class to check against.
Returns:
Whether the class is a subclass of the other class.
"""
try:
return issubclass(cls, cls_check)
except TypeError:
return False
def typehint_issubclass(possible_subclass: Any, possible_superclass: Any) -> bool:
"""Check if a type hint is a subclass of another type hint.

View File

@ -26,6 +26,7 @@ from typing import (
Iterable,
List,
Literal,
Mapping,
NoReturn,
Optional,
Set,
@ -64,6 +65,7 @@ from reflex.utils.types import (
_isinstance,
get_origin,
has_args,
safe_issubclass,
unionize,
)
@ -127,7 +129,7 @@ class VarData:
state: str = "",
field_name: str = "",
imports: ImportDict | ParsedImportDict | None = None,
hooks: dict[str, VarData | None] | None = None,
hooks: Mapping[str, VarData | None] | None = None,
deps: list[Var] | None = None,
position: Hooks.HookPosition | None = None,
):
@ -561,7 +563,7 @@ class Var(Generic[VAR_TYPE]):
if _var_is_local is not None:
console.deprecate(
feature_name="_var_is_local",
reason="The _var_is_local argument is not supported for Var."
reason="The _var_is_local argument is not supported for Var. "
"If you want to create a Var from a raw Javascript expression, use the constructor directly",
deprecation_version="0.6.0",
removal_version="0.7.0",
@ -569,7 +571,7 @@ class Var(Generic[VAR_TYPE]):
if _var_is_string is not None:
console.deprecate(
feature_name="_var_is_string",
reason="The _var_is_string argument is not supported for Var."
reason="The _var_is_string argument is not supported for Var. "
"If you want to create a Var from a raw Javascript expression, use the constructor directly",
deprecation_version="0.6.0",
removal_version="0.7.0",
@ -643,8 +645,8 @@ class Var(Generic[VAR_TYPE]):
@overload
def to(
self,
output: type[dict],
) -> ObjectVar[dict]: ...
output: type[Mapping],
) -> ObjectVar[Mapping]: ...
@overload
def to(
@ -686,7 +688,9 @@ class Var(Generic[VAR_TYPE]):
# If the first argument is a python type, we map it to the corresponding Var type.
for var_subclass in _var_subclasses[::-1]:
if fixed_output_type in var_subclass.python_types:
if fixed_output_type in var_subclass.python_types or safe_issubclass(
fixed_output_type, var_subclass.python_types
):
return self.to(var_subclass.var_subclass, output)
if fixed_output_type is None:
@ -820,7 +824,7 @@ class Var(Generic[VAR_TYPE]):
return False
if issubclass(type_, list):
return []
if issubclass(type_, dict):
if issubclass(type_, Mapping):
return {}
if issubclass(type_, tuple):
return ()
@ -1026,7 +1030,7 @@ class Var(Generic[VAR_TYPE]):
f"$/{constants.Dirs.STATE_PATH}": [imports.ImportVar(tag="refs")]
}
),
).to(ObjectVar, Dict[str, str])
).to(ObjectVar, Mapping[str, str])
return refs[LiteralVar.create(str(self))]
@deprecated("Use `.js_type()` instead.")
@ -1373,7 +1377,7 @@ class LiteralVar(Var):
serialized_value = serializers.serialize(value)
if serialized_value is not None:
if isinstance(serialized_value, dict):
if isinstance(serialized_value, Mapping):
return LiteralObjectVar.create(
serialized_value,
_var_type=type(value),
@ -1498,7 +1502,7 @@ def var_operation(
) -> Callable[P, ArrayVar[LIST_T]]: ...
OBJECT_TYPE = TypeVar("OBJECT_TYPE", bound=Dict)
OBJECT_TYPE = TypeVar("OBJECT_TYPE", bound=Mapping)
@overload
@ -1573,8 +1577,8 @@ def figure_out_type(value: Any) -> types.GenericType:
return Set[unionize(*(figure_out_type(v) for v in value))]
if isinstance(value, tuple):
return Tuple[unionize(*(figure_out_type(v) for v in value)), ...]
if isinstance(value, dict):
return Dict[
if isinstance(value, Mapping):
return Mapping[
unionize(*(figure_out_type(k) for k in value)),
unionize(*(figure_out_type(v) for v in value.values())),
]
@ -1838,7 +1842,7 @@ class ComputedVar(Var[RETURN_TYPE]):
self,
fget: Callable[[BASE_STATE], RETURN_TYPE],
initial_value: RETURN_TYPE | types.Unset = types.Unset(),
cache: bool = False,
cache: bool = True,
deps: Optional[List[Union[str, Var]]] = None,
auto_deps: bool = True,
interval: Optional[Union[int, datetime.timedelta]] = None,
@ -2002,10 +2006,10 @@ class ComputedVar(Var[RETURN_TYPE]):
@overload
def __get__(
self: ComputedVar[dict[DICT_KEY, DICT_VAL]],
self: ComputedVar[Mapping[DICT_KEY, DICT_VAL]],
instance: None,
owner: Type,
) -> ObjectVar[dict[DICT_KEY, DICT_VAL]]: ...
) -> ObjectVar[Mapping[DICT_KEY, DICT_VAL]]: ...
@overload
def __get__(
@ -2253,7 +2257,7 @@ if TYPE_CHECKING:
def computed_var(
fget: None = None,
initial_value: Any | types.Unset = types.Unset(),
cache: bool = False,
cache: bool = True,
deps: Optional[List[Union[str, Var]]] = None,
auto_deps: bool = True,
interval: Optional[Union[datetime.timedelta, int]] = None,
@ -2266,7 +2270,7 @@ def computed_var(
def computed_var(
fget: Callable[[BASE_STATE], RETURN_TYPE],
initial_value: RETURN_TYPE | types.Unset = types.Unset(),
cache: bool = False,
cache: bool = True,
deps: Optional[List[Union[str, Var]]] = None,
auto_deps: bool = True,
interval: Optional[Union[datetime.timedelta, int]] = None,
@ -2278,7 +2282,7 @@ def computed_var(
def computed_var(
fget: Callable[[BASE_STATE], Any] | None = None,
initial_value: Any | types.Unset = types.Unset(),
cache: Optional[bool] = None,
cache: bool = True,
deps: Optional[List[Union[str, Var]]] = None,
auto_deps: bool = True,
interval: Optional[Union[datetime.timedelta, int]] = None,
@ -2304,15 +2308,6 @@ def computed_var(
ValueError: If caching is disabled and an update interval is set.
VarDependencyError: If user supplies dependencies without caching.
"""
if cache is None:
cache = False
console.deprecate(
"Default non-cached rx.var",
"the default value will be `@rx.var(cache=True)` in a future release. "
"To retain uncached var, explicitly pass `@rx.var(cache=False)`",
deprecation_version="0.6.8",
removal_version="0.7.0",
)
if cache is False and interval is not None:
raise ValueError("Cannot set update interval without caching.")
@ -2924,11 +2919,14 @@ V = TypeVar("V")
BASE_TYPE = TypeVar("BASE_TYPE", bound=Base)
FIELD_TYPE = TypeVar("FIELD_TYPE")
MAPPING_TYPE = TypeVar("MAPPING_TYPE", bound=Mapping)
class Field(Generic[T]):
class Field(Generic[FIELD_TYPE]):
"""Shadow class for Var to allow for type hinting in the IDE."""
def __set__(self, instance, value: T):
def __set__(self, instance, value: FIELD_TYPE):
"""Set the Var.
Args:
@ -2940,7 +2938,9 @@ class Field(Generic[T]):
def __get__(self: Field[bool], instance: None, owner) -> BooleanVar: ...
@overload
def __get__(self: Field[int], instance: None, owner) -> NumberVar: ...
def __get__(
self: Field[int] | Field[float] | Field[int | float], instance: None, owner
) -> NumberVar: ...
@overload
def __get__(self: Field[str], instance: None, owner) -> StringVar: ...
@ -2957,8 +2957,8 @@ class Field(Generic[T]):
@overload
def __get__(
self: Field[Dict[str, V]], instance: None, owner
) -> ObjectVar[Dict[str, V]]: ...
self: Field[MAPPING_TYPE], instance: None, owner
) -> ObjectVar[MAPPING_TYPE]: ...
@overload
def __get__(
@ -2966,10 +2966,10 @@ class Field(Generic[T]):
) -> ObjectVar[BASE_TYPE]: ...
@overload
def __get__(self, instance: None, owner) -> Var[T]: ...
def __get__(self, instance: None, owner) -> Var[FIELD_TYPE]: ...
@overload
def __get__(self, instance, owner) -> T: ...
def __get__(self, instance, owner) -> FIELD_TYPE: ...
def __get__(self, instance, owner): # type: ignore
"""Get the Var.
@ -2980,7 +2980,7 @@ class Field(Generic[T]):
"""
def field(value: T) -> Field[T]:
def field(value: FIELD_TYPE) -> Field[FIELD_TYPE]:
"""Create a Field with a value.
Args:

View File

@ -390,6 +390,7 @@ class ArgsFunctionOperation(CachedVarOperation, FunctionVar):
Returns:
The function var.
"""
return_expr = Var.create(return_expr)
return cls(
_js_expr="",
_var_type=_var_type,
@ -445,6 +446,7 @@ class ArgsFunctionOperationBuilder(CachedVarOperation, BuilderFunctionVar):
Returns:
The function var.
"""
return_expr = Var.create(return_expr)
return cls(
_js_expr="",
_var_type=_var_type,

View File

@ -20,7 +20,6 @@ from typing import (
from reflex.constants.base import Dirs
from reflex.utils.exceptions import PrimitiveUnserializableToJSON, VarTypeError
from reflex.utils.imports import ImportDict, ImportVar
from reflex.utils.types import is_optional
from .base import (
CustomVarOperationReturn,
@ -431,7 +430,7 @@ class NumberVar(Var[NUMBER_T], python_types=(int, float)):
"""
if not isinstance(other, NUMBER_TYPES):
raise_unsupported_operand_types("<", (type(self), type(other)))
return less_than_operation(self, +other)
return less_than_operation(+self, +other)
@overload
def __le__(self, other: number_types) -> BooleanVar: ...
@ -450,7 +449,7 @@ class NumberVar(Var[NUMBER_T], python_types=(int, float)):
"""
if not isinstance(other, NUMBER_TYPES):
raise_unsupported_operand_types("<=", (type(self), type(other)))
return less_than_or_equal_operation(self, +other)
return less_than_or_equal_operation(+self, +other)
def __eq__(self, other: Any):
"""Equal comparison.
@ -462,7 +461,7 @@ class NumberVar(Var[NUMBER_T], python_types=(int, float)):
The result of the comparison.
"""
if isinstance(other, NUMBER_TYPES):
return equal_operation(self, +other)
return equal_operation(+self, +other)
return equal_operation(self, other)
def __ne__(self, other: Any):
@ -475,7 +474,7 @@ class NumberVar(Var[NUMBER_T], python_types=(int, float)):
The result of the comparison.
"""
if isinstance(other, NUMBER_TYPES):
return not_equal_operation(self, +other)
return not_equal_operation(+self, +other)
return not_equal_operation(self, other)
@overload
@ -495,7 +494,7 @@ class NumberVar(Var[NUMBER_T], python_types=(int, float)):
"""
if not isinstance(other, NUMBER_TYPES):
raise_unsupported_operand_types(">", (type(self), type(other)))
return greater_than_operation(self, +other)
return greater_than_operation(+self, +other)
@overload
def __ge__(self, other: number_types) -> BooleanVar: ...
@ -514,17 +513,7 @@ class NumberVar(Var[NUMBER_T], python_types=(int, float)):
"""
if not isinstance(other, NUMBER_TYPES):
raise_unsupported_operand_types(">=", (type(self), type(other)))
return greater_than_or_equal_operation(self, +other)
def bool(self):
"""Boolean conversion.
Returns:
The boolean value of the number.
"""
if is_optional(self._var_type):
return boolify((self != None) & (self != 0)) # noqa: E711
return self != 0
return greater_than_or_equal_operation(+self, +other)
def _is_strict_float(self) -> bool:
"""Check if the number is a float.

View File

@ -8,8 +8,8 @@ import typing
from inspect import isclass
from typing import (
Any,
Dict,
List,
Mapping,
NoReturn,
Tuple,
Type,
@ -19,6 +19,8 @@ from typing import (
overload,
)
from typing_extensions import is_typeddict
from reflex.utils import types
from reflex.utils.exceptions import VarAttributeError
from reflex.utils.types import GenericType, get_attribute_access_type, get_origin
@ -36,7 +38,7 @@ from .base import (
from .number import BooleanVar, NumberVar, raise_unsupported_operand_types
from .sequence import ArrayVar, StringVar
OBJECT_TYPE = TypeVar("OBJECT_TYPE")
OBJECT_TYPE = TypeVar("OBJECT_TYPE", covariant=True)
KEY_TYPE = TypeVar("KEY_TYPE")
VALUE_TYPE = TypeVar("VALUE_TYPE")
@ -46,7 +48,7 @@ ARRAY_INNER_TYPE = TypeVar("ARRAY_INNER_TYPE")
OTHER_KEY_TYPE = TypeVar("OTHER_KEY_TYPE")
class ObjectVar(Var[OBJECT_TYPE], python_types=dict):
class ObjectVar(Var[OBJECT_TYPE], python_types=Mapping):
"""Base class for immutable object vars."""
def _key_type(self) -> Type:
@ -59,7 +61,7 @@ class ObjectVar(Var[OBJECT_TYPE], python_types=dict):
@overload
def _value_type(
self: ObjectVar[Dict[Any, VALUE_TYPE]],
self: ObjectVar[Mapping[Any, VALUE_TYPE]],
) -> Type[VALUE_TYPE]: ...
@overload
@ -74,7 +76,7 @@ class ObjectVar(Var[OBJECT_TYPE], python_types=dict):
fixed_type = get_origin(self._var_type) or self._var_type
if not isclass(fixed_type):
return Any
args = get_args(self._var_type) if issubclass(fixed_type, dict) else ()
args = get_args(self._var_type) if issubclass(fixed_type, Mapping) else ()
return args[1] if args else Any
def keys(self) -> ArrayVar[List[str]]:
@ -87,7 +89,7 @@ class ObjectVar(Var[OBJECT_TYPE], python_types=dict):
@overload
def values(
self: ObjectVar[Dict[Any, VALUE_TYPE]],
self: ObjectVar[Mapping[Any, VALUE_TYPE]],
) -> ArrayVar[List[VALUE_TYPE]]: ...
@overload
@ -103,7 +105,7 @@ class ObjectVar(Var[OBJECT_TYPE], python_types=dict):
@overload
def entries(
self: ObjectVar[Dict[Any, VALUE_TYPE]],
self: ObjectVar[Mapping[Any, VALUE_TYPE]],
) -> ArrayVar[List[Tuple[str, VALUE_TYPE]]]: ...
@overload
@ -133,49 +135,55 @@ class ObjectVar(Var[OBJECT_TYPE], python_types=dict):
# NoReturn is used here to catch when key value is Any
@overload
def __getitem__(
self: ObjectVar[Dict[Any, NoReturn]],
self: ObjectVar[Mapping[Any, NoReturn]],
key: Var | Any,
) -> Var: ...
@overload
def __getitem__(
self: (ObjectVar[Mapping[Any, bool]]),
key: Var | Any,
) -> BooleanVar: ...
@overload
def __getitem__(
self: (
ObjectVar[Dict[Any, int]]
| ObjectVar[Dict[Any, float]]
| ObjectVar[Dict[Any, int | float]]
ObjectVar[Mapping[Any, int]]
| ObjectVar[Mapping[Any, float]]
| ObjectVar[Mapping[Any, int | float]]
),
key: Var | Any,
) -> NumberVar: ...
@overload
def __getitem__(
self: ObjectVar[Dict[Any, str]],
self: ObjectVar[Mapping[Any, str]],
key: Var | Any,
) -> StringVar: ...
@overload
def __getitem__(
self: ObjectVar[Dict[Any, list[ARRAY_INNER_TYPE]]],
self: ObjectVar[Mapping[Any, list[ARRAY_INNER_TYPE]]],
key: Var | Any,
) -> ArrayVar[list[ARRAY_INNER_TYPE]]: ...
@overload
def __getitem__(
self: ObjectVar[Dict[Any, set[ARRAY_INNER_TYPE]]],
self: ObjectVar[Mapping[Any, set[ARRAY_INNER_TYPE]]],
key: Var | Any,
) -> ArrayVar[set[ARRAY_INNER_TYPE]]: ...
@overload
def __getitem__(
self: ObjectVar[Dict[Any, tuple[ARRAY_INNER_TYPE, ...]]],
self: ObjectVar[Mapping[Any, tuple[ARRAY_INNER_TYPE, ...]]],
key: Var | Any,
) -> ArrayVar[tuple[ARRAY_INNER_TYPE, ...]]: ...
@overload
def __getitem__(
self: ObjectVar[Dict[Any, dict[OTHER_KEY_TYPE, VALUE_TYPE]]],
self: ObjectVar[Mapping[Any, Mapping[OTHER_KEY_TYPE, VALUE_TYPE]]],
key: Var | Any,
) -> ObjectVar[dict[OTHER_KEY_TYPE, VALUE_TYPE]]: ...
) -> ObjectVar[Mapping[OTHER_KEY_TYPE, VALUE_TYPE]]: ...
def __getitem__(self, key: Var | Any) -> Var:
"""Get an item from the object.
@ -195,49 +203,49 @@ class ObjectVar(Var[OBJECT_TYPE], python_types=dict):
# NoReturn is used here to catch when key value is Any
@overload
def __getattr__(
self: ObjectVar[Dict[Any, NoReturn]],
self: ObjectVar[Mapping[Any, NoReturn]],
name: str,
) -> Var: ...
@overload
def __getattr__(
self: (
ObjectVar[Dict[Any, int]]
| ObjectVar[Dict[Any, float]]
| ObjectVar[Dict[Any, int | float]]
ObjectVar[Mapping[Any, int]]
| ObjectVar[Mapping[Any, float]]
| ObjectVar[Mapping[Any, int | float]]
),
name: str,
) -> NumberVar: ...
@overload
def __getattr__(
self: ObjectVar[Dict[Any, str]],
self: ObjectVar[Mapping[Any, str]],
name: str,
) -> StringVar: ...
@overload
def __getattr__(
self: ObjectVar[Dict[Any, list[ARRAY_INNER_TYPE]]],
self: ObjectVar[Mapping[Any, list[ARRAY_INNER_TYPE]]],
name: str,
) -> ArrayVar[list[ARRAY_INNER_TYPE]]: ...
@overload
def __getattr__(
self: ObjectVar[Dict[Any, set[ARRAY_INNER_TYPE]]],
self: ObjectVar[Mapping[Any, set[ARRAY_INNER_TYPE]]],
name: str,
) -> ArrayVar[set[ARRAY_INNER_TYPE]]: ...
@overload
def __getattr__(
self: ObjectVar[Dict[Any, tuple[ARRAY_INNER_TYPE, ...]]],
self: ObjectVar[Mapping[Any, tuple[ARRAY_INNER_TYPE, ...]]],
name: str,
) -> ArrayVar[tuple[ARRAY_INNER_TYPE, ...]]: ...
@overload
def __getattr__(
self: ObjectVar[Dict[Any, dict[OTHER_KEY_TYPE, VALUE_TYPE]]],
self: ObjectVar[Mapping[Any, Mapping[OTHER_KEY_TYPE, VALUE_TYPE]]],
name: str,
) -> ObjectVar[dict[OTHER_KEY_TYPE, VALUE_TYPE]]: ...
) -> ObjectVar[Mapping[OTHER_KEY_TYPE, VALUE_TYPE]]: ...
@overload
def __getattr__(
@ -266,8 +274,11 @@ class ObjectVar(Var[OBJECT_TYPE], python_types=dict):
var_type = get_args(var_type)[0]
fixed_type = var_type if isclass(var_type) else get_origin(var_type)
if (isclass(fixed_type) and not issubclass(fixed_type, dict)) or (
fixed_type in types.UnionTypes
if (
(isclass(fixed_type) and not issubclass(fixed_type, Mapping))
or (fixed_type in types.UnionTypes)
or is_typeddict(fixed_type)
):
attribute_type = get_attribute_access_type(var_type, name)
if attribute_type is None:
@ -299,7 +310,7 @@ class ObjectVar(Var[OBJECT_TYPE], python_types=dict):
class LiteralObjectVar(CachedVarOperation, ObjectVar[OBJECT_TYPE], LiteralVar):
"""Base class for immutable literal object vars."""
_var_value: Dict[Union[Var, Any], Union[Var, Any]] = dataclasses.field(
_var_value: Mapping[Union[Var, Any], Union[Var, Any]] = dataclasses.field(
default_factory=dict
)
@ -383,7 +394,7 @@ class LiteralObjectVar(CachedVarOperation, ObjectVar[OBJECT_TYPE], LiteralVar):
@classmethod
def create(
cls,
_var_value: dict,
_var_value: Mapping,
_var_type: Type[OBJECT_TYPE] | None = None,
_var_data: VarData | None = None,
) -> LiteralObjectVar[OBJECT_TYPE]:
@ -466,7 +477,7 @@ def object_merge_operation(lhs: ObjectVar, rhs: ObjectVar):
"""
return var_operation_return(
js_expression=f"({{...{lhs}, ...{rhs}}})",
var_type=Dict[
var_type=Mapping[
Union[lhs._key_type(), rhs._key_type()],
Union[lhs._value_type(), rhs._value_type()],
],

View File

@ -987,7 +987,7 @@ class ArrayVar(Var[ARRAY_VAR_TYPE], python_types=(list, tuple, set)):
raise_unsupported_operand_types("[]", (type(self), type(i)))
return array_item_operation(self, i)
def length(self) -> NumberVar:
def length(self) -> NumberVar[int]:
"""Get the length of the array.
Returns:

View File

@ -22,22 +22,22 @@ def ComputedVars():
count: int = 0
# cached var with dep on count
@rx.var(cache=True, interval=15)
@rx.var(interval=15)
def count1(self) -> int:
return self.count
# cached backend var with dep on count
@rx.var(cache=True, interval=15, backend=True)
@rx.var(interval=15, backend=True)
def count1_backend(self) -> int:
return self.count
# same as above but implicit backend with `_` prefix
@rx.var(cache=True, interval=15)
@rx.var(interval=15)
def _count1_backend(self) -> int:
return self.count
# explicit disabled auto_deps
@rx.var(interval=15, cache=True, auto_deps=False)
@rx.var(interval=15, auto_deps=False)
def count3(self) -> int:
# this will not add deps, because auto_deps is False
print(self.count1)
@ -45,19 +45,27 @@ def ComputedVars():
return self.count
# explicit dependency on count var
@rx.var(cache=True, deps=["count"], auto_deps=False)
@rx.var(deps=["count"], auto_deps=False)
def depends_on_count(self) -> int:
return self.count
# explicit dependency on count1 var
@rx.var(cache=True, deps=[count1], auto_deps=False)
@rx.var(deps=[count1], auto_deps=False)
def depends_on_count1(self) -> int:
return self.count
@rx.var(deps=[count3], auto_deps=False, cache=True)
@rx.var(
deps=[count3],
auto_deps=False,
)
def depends_on_count3(self) -> int:
return self.count
# special floats should be properly decoded on the frontend
@rx.var(cache=True, initial_value=[])
def special_floats(self) -> list[float]:
return [42.9, float("nan"), float("inf"), float("-inf")]
@rx.event
def increment(self):
self.count += 1
@ -103,6 +111,11 @@ def ComputedVars():
State.depends_on_count3,
id="depends_on_count3",
),
rx.text("special_floats:"),
rx.text(
State.special_floats.join(", "),
id="special_floats",
),
),
)
@ -224,6 +237,10 @@ async def test_computed_vars(
assert depends_on_count3
assert depends_on_count3.text == "0"
special_floats = driver.find_element(By.ID, "special_floats")
assert special_floats
assert special_floats.text == "42.9, NaN, Infinity, -Infinity"
increment = driver.find_element(By.ID, "increment")
assert increment.is_enabled()

View File

@ -71,9 +71,10 @@ def has_error_modal(driver: WebDriver) -> bool:
"""
try:
driver.find_element(By.XPATH, CONNECTION_ERROR_XPATH)
return True
except NoSuchElementException:
return False
else:
return True
@pytest.mark.asyncio

View File

@ -74,16 +74,16 @@ def DynamicRoute():
class ArgState(rx.State):
"""The app state."""
@rx.var
@rx.var(cache=False)
def arg(self) -> int:
return int(self.arg_str or 0)
class ArgSubState(ArgState):
@rx.var(cache=True)
@rx.var
def cached_arg(self) -> int:
return self.arg
@rx.var(cache=True)
@rx.var
def cached_arg_str(self) -> str:
return self.arg_str

View File

@ -36,7 +36,7 @@ def LifespanApp():
print("Lifespan global started.")
try:
while True:
lifespan_task_global += inc # pyright: ignore[reportUnboundVariable]
lifespan_task_global += inc # pyright: ignore[reportUnboundVariable, reportPossiblyUnboundVariable]
await asyncio.sleep(0.1)
except asyncio.CancelledError as ce:
print(f"Lifespan global cancelled: {ce}.")
@ -45,11 +45,11 @@ def LifespanApp():
class LifespanState(rx.State):
interval: int = 100
@rx.var
@rx.var(cache=False)
def task_global(self) -> int:
return lifespan_task_global
@rx.var
@rx.var(cache=False)
def context_global(self) -> int:
return lifespan_context_global

View File

@ -22,31 +22,31 @@ def MediaApp():
img.format = format # type: ignore
return img
@rx.var(cache=True)
@rx.var
def img_default(self) -> Image.Image:
return self._blue()
@rx.var(cache=True)
@rx.var
def img_bmp(self) -> Image.Image:
return self._blue(format="BMP")
@rx.var(cache=True)
@rx.var
def img_jpg(self) -> Image.Image:
return self._blue(format="JPEG")
@rx.var(cache=True)
@rx.var
def img_png(self) -> Image.Image:
return self._blue(format="PNG")
@rx.var(cache=True)
@rx.var
def img_gif(self) -> Image.Image:
return self._blue(format="GIF")
@rx.var(cache=True)
@rx.var
def img_webp(self) -> Image.Image:
return self._blue(format="WEBP")
@rx.var(cache=True)
@rx.var
def img_from_url(self) -> Image.Image:
img_url = "https://picsum.photos/id/1/200/300"
img_resp = httpx.get(img_url, follow_redirects=True)

View File

@ -1,4 +1,4 @@
from typing import Dict, List, Tuple
from typing import List, Mapping, Tuple
import pytest
@ -67,7 +67,7 @@ def test_match_components():
assert fourth_return_value_render["children"][0]["contents"] == '{"fourth value"}'
assert match_cases[4][0]._js_expr == '({ ["foo"] : "bar" })'
assert match_cases[4][0]._var_type == Dict[str, str]
assert match_cases[4][0]._var_type == Mapping[str, str]
fifth_return_value_render = match_cases[4][1].render()
assert fifth_return_value_render["name"] == "RadixThemesText"
assert fifth_return_value_render["children"][0]["contents"] == '{"fifth value"}'

View File

@ -1,13 +1,19 @@
import pytest
from reflex.components.lucide.icon import LUCIDE_ICON_LIST, Icon
from reflex.components.lucide.icon import (
LUCIDE_ICON_LIST,
LUCIDE_ICON_MAPPING_OVERRIDE,
Icon,
)
from reflex.utils import format
@pytest.mark.parametrize("tag", LUCIDE_ICON_LIST)
def test_icon(tag):
icon = Icon.create(tag)
assert icon.alias == f"Lucide{format.to_title_case(tag)}Icon"
assert icon.alias == "Lucide" + LUCIDE_ICON_MAPPING_OVERRIDE.get(
tag, f"{format.to_title_case(tag)}Icon"
)
def test_icon_missing_tag():

View File

@ -27,7 +27,7 @@ from reflex.event import (
from reflex.state import BaseState
from reflex.style import Style
from reflex.utils import imports
from reflex.utils.exceptions import EventFnArgMismatch
from reflex.utils.exceptions import ChildrenTypeError, EventFnArgMismatch
from reflex.utils.imports import ImportDict, ImportVar, ParsedImportDict, parse_imports
from reflex.vars import VarData
from reflex.vars.base import LiteralVar, Var
@ -645,14 +645,17 @@ def test_create_filters_none_props(test_component):
assert str(component.style["text-align"]) == '"center"'
@pytest.mark.parametrize("children", [((None,),), ("foo", ("bar", (None,)))])
@pytest.mark.parametrize(
"children",
[
((None,),),
("foo", ("bar", (None,))),
({"foo": "bar"},),
],
)
def test_component_create_unallowed_types(children, test_component):
with pytest.raises(TypeError) as err:
with pytest.raises(ChildrenTypeError):
test_component.create(*children)
assert (
err.value.args[0]
== "Children of Reflex components must be other components, state vars, or primitive Python types. Got child None of type <class 'NoneType'>."
)
@pytest.mark.parametrize(

View File

@ -908,7 +908,7 @@ class DynamicState(BaseState):
"""Increment the counter var."""
self.counter = self.counter + 1
@computed_var(cache=True)
@computed_var
def comp_dynamic(self) -> str:
"""A computed var that depends on the dynamic var.
@ -1549,11 +1549,11 @@ def test_app_with_valid_var_dependencies(compilable_app: tuple[App, Path]):
base: int = 0
_backend: int = 0
@computed_var(cache=True)
@computed_var()
def foo(self) -> str:
return "foo"
@computed_var(deps=["_backend", "base", foo], cache=True)
@computed_var(deps=["_backend", "base", foo])
def bar(self) -> str:
return "bar"
@ -1565,7 +1565,7 @@ def test_app_with_invalid_var_dependencies(compilable_app: tuple[App, Path]):
app, _ = compilable_app
class InvalidDepState(BaseState):
@computed_var(deps=["foolksjdf"], cache=True)
@computed_var(deps=["foolksjdf"])
def bar(self) -> str:
return "bar"

View File

@ -202,7 +202,7 @@ class GrandchildState(ChildState):
class GrandchildState2(ChildState2):
"""A grandchild state fixture."""
@rx.var(cache=True)
@rx.var
def cached(self) -> str:
"""A cached var.
@ -215,7 +215,7 @@ class GrandchildState2(ChildState2):
class GrandchildState3(ChildState3):
"""A great grandchild state fixture."""
@rx.var
@rx.var(cache=False)
def computed(self) -> str:
"""A computed var.
@ -796,7 +796,7 @@ async def test_process_event_simple(test_state):
# The delta should contain the changes, including computed vars.
assert update.delta == {
TestState.get_full_name(): {"num1": 69, "sum": 72.14, "upper": ""},
TestState.get_full_name(): {"num1": 69, "sum": 72.14},
GrandchildState3.get_full_name(): {"computed": ""},
}
assert update.events == []
@ -823,7 +823,7 @@ async def test_process_event_substate(test_state, child_state, grandchild_state)
assert child_state.value == "HI"
assert child_state.count == 24
assert update.delta == {
TestState.get_full_name(): {"sum": 3.14, "upper": ""},
# TestState.get_full_name(): {"sum": 3.14, "upper": ""},
ChildState.get_full_name(): {"value": "HI", "count": 24},
GrandchildState3.get_full_name(): {"computed": ""},
}
@ -839,7 +839,7 @@ async def test_process_event_substate(test_state, child_state, grandchild_state)
update = await test_state._process(event).__anext__()
assert grandchild_state.value2 == "new"
assert update.delta == {
TestState.get_full_name(): {"sum": 3.14, "upper": ""},
# TestState.get_full_name(): {"sum": 3.14, "upper": ""},
GrandchildState.get_full_name(): {"value2": "new"},
GrandchildState3.get_full_name(): {"computed": ""},
}
@ -989,7 +989,7 @@ class InterdependentState(BaseState):
v1: int = 0
_v2: int = 1
@rx.var(cache=True)
@rx.var
def v1x2(self) -> int:
"""Depends on var v1.
@ -998,7 +998,7 @@ class InterdependentState(BaseState):
"""
return self.v1 * 2
@rx.var(cache=True)
@rx.var
def v2x2(self) -> int:
"""Depends on backend var _v2.
@ -1007,7 +1007,7 @@ class InterdependentState(BaseState):
"""
return self._v2 * 2
@rx.var(cache=True, backend=True)
@rx.var(backend=True)
def v2x2_backend(self) -> int:
"""Depends on backend var _v2.
@ -1016,7 +1016,7 @@ class InterdependentState(BaseState):
"""
return self._v2 * 2
@rx.var(cache=True)
@rx.var
def v1x2x2(self) -> int:
"""Depends on ComputedVar v1x2.
@ -1025,7 +1025,7 @@ class InterdependentState(BaseState):
"""
return self.v1x2 * 2 # type: ignore
@rx.var(cache=True)
@rx.var
def _v3(self) -> int:
"""Depends on backend var _v2.
@ -1034,7 +1034,7 @@ class InterdependentState(BaseState):
"""
return self._v2
@rx.var(cache=True)
@rx.var
def v3x2(self) -> int:
"""Depends on ComputedVar _v3.
@ -1239,7 +1239,7 @@ def test_computed_var_cached():
class ComputedState(BaseState):
v: int = 0
@rx.var(cache=True)
@rx.var
def comp_v(self) -> int:
nonlocal comp_v_calls
comp_v_calls += 1
@ -1264,15 +1264,15 @@ def test_computed_var_cached_depends_on_non_cached():
class ComputedState(BaseState):
v: int = 0
@rx.var
@rx.var(cache=False)
def no_cache_v(self) -> int:
return self.v
@rx.var(cache=True)
@rx.var
def dep_v(self) -> int:
return self.no_cache_v # type: ignore
@rx.var(cache=True)
@rx.var
def comp_v(self) -> int:
return self.v
@ -1304,14 +1304,14 @@ def test_computed_var_depends_on_parent_non_cached():
counter = 0
class ParentState(BaseState):
@rx.var
@rx.var(cache=False)
def no_cache_v(self) -> int:
nonlocal counter
counter += 1
return counter
class ChildState(ParentState):
@rx.var(cache=True)
@rx.var
def dep_v(self) -> int:
return self.no_cache_v # type: ignore
@ -1357,7 +1357,7 @@ def test_cached_var_depends_on_event_handler(use_partial: bool):
def handler(self):
self.x = self.x + 1
@rx.var(cache=True)
@rx.var
def cached_x_side_effect(self) -> int:
self.handler()
nonlocal counter
@ -1393,7 +1393,7 @@ def test_computed_var_dependencies():
def testprop(self) -> int:
return self.v
@rx.var(cache=True)
@rx.var
def comp_v(self) -> int:
"""Direct access.
@ -1402,7 +1402,7 @@ def test_computed_var_dependencies():
"""
return self.v
@rx.var(cache=True, backend=True)
@rx.var(backend=True)
def comp_v_backend(self) -> int:
"""Direct access backend var.
@ -1411,7 +1411,7 @@ def test_computed_var_dependencies():
"""
return self.v
@rx.var(cache=True)
@rx.var
def comp_v_via_property(self) -> int:
"""Access v via property.
@ -1420,7 +1420,7 @@ def test_computed_var_dependencies():
"""
return self.testprop
@rx.var(cache=True)
@rx.var
def comp_w(self):
"""Nested lambda.
@ -1429,7 +1429,7 @@ def test_computed_var_dependencies():
"""
return lambda: self.w
@rx.var(cache=True)
@rx.var
def comp_x(self):
"""Nested function.
@ -1442,7 +1442,7 @@ def test_computed_var_dependencies():
return _
@rx.var(cache=True)
@rx.var
def comp_y(self) -> List[int]:
"""Comprehension iterating over attribute.
@ -1451,7 +1451,7 @@ def test_computed_var_dependencies():
"""
return [round(y) for y in self.y]
@rx.var(cache=True)
@rx.var
def comp_z(self) -> List[bool]:
"""Comprehension accesses attribute.
@ -2027,10 +2027,6 @@ async def test_state_proxy(grandchild_state: GrandchildState, mock_app: rx.App):
assert mcall.args[0] == str(SocketEvent.EVENT)
assert mcall.args[1] == StateUpdate(
delta={
parent_state.get_full_name(): {
"upper": "",
"sum": 3.14,
},
grandchild_state.get_full_name(): {
"value2": "42",
},
@ -2053,7 +2049,7 @@ class BackgroundTaskState(BaseState):
super().__init__(**kwargs)
self.router_data = {"simulate": "hydrate"}
@rx.var
@rx.var(cache=False)
def computed_order(self) -> List[str]:
"""Get the order as a computed var.
@ -3040,10 +3036,6 @@ async def test_get_state(mock_app: rx.App, token: str):
grandchild_state.value2 = "set_value"
assert test_state.get_delta() == {
TestState.get_full_name(): {
"sum": 3.14,
"upper": "",
},
GrandchildState.get_full_name(): {
"value2": "set_value",
},
@ -3081,10 +3073,6 @@ async def test_get_state(mock_app: rx.App, token: str):
child_state2.value = "set_c2_value"
assert new_test_state.get_delta() == {
TestState.get_full_name(): {
"sum": 3.14,
"upper": "",
},
ChildState2.get_full_name(): {
"value": "set_c2_value",
},
@ -3139,7 +3127,7 @@ async def test_get_state_from_sibling_not_cached(mock_app: rx.App, token: str):
child3_var: int = 0
@rx.var
@rx.var(cache=False)
def v(self):
pass
@ -3210,8 +3198,8 @@ def test_potentially_dirty_substates():
def bar(self) -> str:
return ""
assert RxState._potentially_dirty_substates() == {State}
assert State._potentially_dirty_substates() == {C1}
assert RxState._potentially_dirty_substates() == set()
assert State._potentially_dirty_substates() == set()
assert C1._potentially_dirty_substates() == set()
@ -3226,7 +3214,7 @@ def test_router_var_dep() -> None:
class RouterVarDepState(RouterVarParentState):
"""A state with a router var dependency."""
@rx.var(cache=True)
@rx.var
def foo(self) -> str:
return self.router.page.params.get("foo", "")
@ -3421,7 +3409,7 @@ class MixinState(State, mixin=True):
_backend: int = 0
_backend_no_default: dict
@rx.var(cache=True)
@rx.var
def computed(self) -> str:
"""A computed var on mixin state.

View File

@ -42,7 +42,7 @@ class SubA_A_A_A(SubA_A_A):
class SubA_A_A_B(SubA_A_A):
"""SubA_A_A_B is a child of SubA_A_A."""
@rx.var(cache=True)
@rx.var
def sub_a_a_a_cached(self) -> int:
"""A cached var.
@ -117,7 +117,7 @@ class TreeD(Root):
d: int
@rx.var
@rx.var(cache=False)
def d_var(self) -> int:
"""A computed var.
@ -156,7 +156,7 @@ class SubE_A_A_A_A(SubE_A_A_A):
sub_e_a_a_a_a: int
@rx.var
@rx.var(cache=False)
def sub_e_a_a_a_a_var(self) -> int:
"""A computed var.
@ -183,7 +183,7 @@ class SubE_A_A_A_D(SubE_A_A_A):
sub_e_a_a_a_d: int
@rx.var(cache=True)
@rx.var
def sub_e_a_a_a_d_var(self) -> int:
"""A computed var.

View File

@ -1,6 +1,6 @@
from __future__ import annotations
from typing import Any, Dict
from typing import Any, Mapping
import pytest
@ -379,7 +379,7 @@ class StyleState(rx.State):
{
"css": Var(
_js_expr=f'({{ ["color"] : ("dark"+{StyleState.color}) }})'
).to(Dict[str, str])
).to(Mapping[str, str])
},
),
(

View File

@ -2,7 +2,7 @@ import json
import math
import sys
import typing
from typing import Dict, List, Optional, Set, Tuple, Union, cast
from typing import Dict, List, Mapping, Optional, Set, Tuple, Union, cast
import pytest
from pandas import DataFrame
@ -270,7 +270,7 @@ def test_get_setter(prop: Var, expected):
([1, 2, 3], Var(_js_expr="[1, 2, 3]", _var_type=List[int])),
(
{"a": 1, "b": 2},
Var(_js_expr='({ ["a"] : 1, ["b"] : 2 })', _var_type=Dict[str, int]),
Var(_js_expr='({ ["a"] : 1, ["b"] : 2 })', _var_type=Mapping[str, int]),
),
],
)
@ -1004,7 +1004,7 @@ def test_all_number_operations():
assert (
str(even_more_complicated_number)
== "!(((Math.abs(Math.floor(((Math.floor(((-((-5.4 + 1)) * 2) / 3) / 2) % 3) ** 2))) || (2 && Math.round(((Math.floor(((-((-5.4 + 1)) * 2) / 3) / 2) % 3) ** 2)))) !== 0))"
== "!(isTrue((Math.abs(Math.floor(((Math.floor(((-((-5.4 + 1)) * 2) / 3) / 2) % 3) ** 2))) || (2 && Math.round(((Math.floor(((-((-5.4 + 1)) * 2) / 3) / 2) % 3) ** 2))))))"
)
assert str(LiteralNumberVar.create(5) > False) == "(5 > 0)"
@ -1814,10 +1814,7 @@ def cv_fget(state: BaseState) -> int:
],
)
def test_computed_var_deps(deps: List[Union[str, Var]], expected: Set[str]):
@computed_var(
deps=deps,
cache=True,
)
@computed_var(deps=deps)
def test_var(state) -> int:
return 1
@ -1835,10 +1832,7 @@ def test_computed_var_deps(deps: List[Union[str, Var]], expected: Set[str]):
def test_invalid_computed_var_deps(deps: List):
with pytest.raises(TypeError):
@computed_var(
deps=deps,
cache=True,
)
@computed_var(deps=deps)
def test_var(state) -> int:
return 1

View File

@ -1,4 +1,4 @@
from typing import Dict, List, Union
from typing import List, Mapping, Union
import pytest
@ -37,12 +37,12 @@ class ChildGenericDict(GenericDict):
("a", str),
([1, 2, 3], List[int]),
([1, 2.0, "a"], List[Union[int, float, str]]),
({"a": 1, "b": 2}, Dict[str, int]),
({"a": 1, 2: "b"}, Dict[Union[int, str], Union[str, int]]),
({"a": 1, "b": 2}, Mapping[str, int]),
({"a": 1, 2: "b"}, Mapping[Union[int, str], Union[str, int]]),
(CustomDict(), CustomDict),
(ChildCustomDict(), ChildCustomDict),
(GenericDict({1: 1}), Dict[int, int]),
(ChildGenericDict({1: 1}), Dict[int, int]),
(GenericDict({1: 1}), Mapping[int, int]),
(ChildGenericDict({1: 1}), Mapping[int, int]),
],
)
def test_figure_out_type(value, expected):