diff --git a/.github/workflows/integration_tests.yml b/.github/workflows/integration_tests.yml
index 017336ba5..2ca9aed23 100644
--- a/.github/workflows/integration_tests.yml
+++ b/.github/workflows/integration_tests.yml
@@ -33,7 +33,7 @@ env:
PR_TITLE: ${{ github.event.pull_request.title }}
jobs:
- example-counter:
+ example-counter-and-nba-proxy:
env:
OUTPUT_FILE: import_benchmark.json
timeout-minutes: 30
@@ -119,6 +119,26 @@ jobs:
--benchmark-json "./reflex-examples/counter/${{ env.OUTPUT_FILE }}"
--branch-name "${{ github.head_ref || github.ref_name }}" --pr-id "${{ github.event.pull_request.id }}"
--app-name "counter"
+ - name: Install requirements for nba proxy example
+ working-directory: ./reflex-examples/nba-proxy
+ run: |
+ poetry run uv pip install -r requirements.txt
+ - name: Install additional dependencies for DB access
+ run: poetry run uv pip install psycopg
+ - name: Check export --backend-only before init for nba-proxy example
+ working-directory: ./reflex-examples/nba-proxy
+ run: |
+ poetry run reflex export --backend-only
+ - name: Init Website for nba-proxy example
+ working-directory: ./reflex-examples/nba-proxy
+ run: |
+ poetry run reflex init --loglevel debug
+ - name: Run Website and Check for errors
+ run: |
+ # Check that npm is home
+ npm -v
+ poetry run bash scripts/integration.sh ./reflex-examples/nba-proxy dev
+
reflex-web:
strategy:
diff --git a/pyproject.toml b/pyproject.toml
index eccf21230..d1ae1dcf0 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -86,8 +86,8 @@ build-backend = "poetry.core.masonry.api"
target-version = "py39"
output-format = "concise"
lint.isort.split-on-trailing-comma = false
-lint.select = ["B", "C4", "D", "E", "ERA", "F", "FURB", "I", "PERF", "PTH", "RUF", "SIM", "T", "W"]
-lint.ignore = ["B008", "D205", "E501", "F403", "SIM115", "RUF006", "RUF012"]
+lint.select = ["B", "C4", "D", "E", "ERA", "F", "FURB", "I", "PERF", "PTH", "RUF", "SIM", "T", "TRY", "W"]
+lint.ignore = ["B008", "D205", "E501", "F403", "SIM115", "RUF006", "RUF012", "TRY0"]
lint.pydocstyle.convention = "google"
[tool.ruff.lint.per-file-ignores]
diff --git a/reflex/.templates/web/utils/state.js b/reflex/.templates/web/utils/state.js
index 41dbee446..5b8046347 100644
--- a/reflex/.templates/web/utils/state.js
+++ b/reflex/.templates/web/utils/state.js
@@ -3,6 +3,7 @@ import axios from "axios";
import io from "socket.io-client";
import JSON5 from "json5";
import env from "$/env.json";
+import reflexEnvironment from "$/reflex.json";
import Cookies from "universal-cookie";
import { useEffect, useRef, useState } from "react";
import Router, { useRouter } from "next/router";
@@ -407,10 +408,18 @@ export const connect = async (
socket.current = io(endpoint.href, {
path: endpoint["pathname"],
transports: transports,
+ protocols: env.TEST_MODE ? undefined : [reflexEnvironment.version],
autoUnref: false,
});
// Ensure undefined fields in events are sent as null instead of removed
- socket.current.io.encoder.replacer = (k, v) => (v === undefined ? null : v)
+ socket.current.io.encoder.replacer = (k, v) => (v === undefined ? null : v);
+ socket.current.io.decoder.tryParse = (str) => {
+ try {
+ return JSON5.parse(str);
+ } catch (e) {
+ return false;
+ }
+ };
function checkVisibility() {
if (document.visibilityState === "visible") {
diff --git a/reflex/app.py b/reflex/app.py
index 712dcee9f..e729eeefd 100644
--- a/reflex/app.py
+++ b/reflex/app.py
@@ -495,14 +495,8 @@ class App(MiddlewareMixin, LifespanMixin):
Returns:
The generated component.
-
- Raises:
- exceptions.MatchTypeError: If the return types of match cases in rx.match are different.
"""
- try:
- return component if isinstance(component, Component) else component()
- except exceptions.MatchTypeError:
- raise
+ return component if isinstance(component, Component) else component()
def add_page(
self,
@@ -596,11 +590,12 @@ class App(MiddlewareMixin, LifespanMixin):
meta=meta,
)
- def _compile_page(self, route: str):
+ def _compile_page(self, route: str, save_page: bool = True):
"""Compile a page.
Args:
route: The route of the page to compile.
+ save_page: If True, the compiled page is saved to self.pages.
"""
component, enable_state = compiler.compile_unevaluated_page(
route, self._unevaluated_pages[route], self._state, self.style, self.theme
@@ -611,7 +606,8 @@ class App(MiddlewareMixin, LifespanMixin):
# Add the page.
self._check_routes_conflict(route)
- self._pages[route] = component
+ if save_page:
+ self.pages[route] = component
def get_load_events(self, route: str) -> list[IndividualEventType[[], Any]]:
"""Get the load events for a route.
@@ -914,14 +910,16 @@ class App(MiddlewareMixin, LifespanMixin):
# If a theme component was provided, wrap the app with it
app_wrappers[(20, "Theme")] = self.theme
- for route in self._unevaluated_pages:
+ should_compile = self._should_compile()
+
+ for route in self.unevaluated_pages:
console.debug(f"Evaluating page: {route}")
- self._compile_page(route)
+ self._compile_page(route, save_page=should_compile)
# Add the optional endpoints (_upload)
self._add_optional_endpoints()
- if not self._should_compile():
+ if not should_compile:
return
self._validate_var_dependencies()
@@ -1565,7 +1563,11 @@ class EventNamespace(AsyncNamespace):
sid: The Socket.IO session id.
environ: The request information, including HTTP headers.
"""
- pass
+ subprotocol = environ.get("HTTP_SEC_WEBSOCKET_PROTOCOL", None)
+ if subprotocol and subprotocol != constants.Reflex.VERSION:
+ console.warn(
+ f"Frontend version {subprotocol} for session {sid} does not match the backend version {constants.Reflex.VERSION}."
+ )
def on_disconnect(self, sid):
"""Event for when the websocket disconnects.
@@ -1598,10 +1600,36 @@ class EventNamespace(AsyncNamespace):
Args:
sid: The Socket.IO session id.
data: The event data.
+
+ Raises:
+ EventDeserializationError: If the event data is not a dictionary.
"""
fields = data
- # Get the event.
- event = Event(**{k: v for k, v in fields.items() if k in _EVENT_FIELDS})
+
+ if isinstance(fields, str):
+ console.warn(
+ "Received event data as a string. This generally should not happen and may indicate a bug."
+ f" Event data: {fields}"
+ )
+ try:
+ fields = json.loads(fields)
+ except json.JSONDecodeError as ex:
+ raise exceptions.EventDeserializationError(
+ f"Failed to deserialize event data: {fields}."
+ ) from ex
+
+ if not isinstance(fields, dict):
+ raise exceptions.EventDeserializationError(
+ f"Event data must be a dictionary, but received {fields} of type {type(fields)}."
+ )
+
+ try:
+ # Get the event.
+ event = Event(**{k: v for k, v in fields.items() if k in _EVENT_FIELDS})
+ except (TypeError, ValueError) as ex:
+ raise exceptions.EventDeserializationError(
+ f"Failed to deserialize event data: {fields}."
+ ) from ex
self.token_to_sid[event.token] = sid
self.sid_to_token[sid] = event.token
diff --git a/reflex/app_module_for_backend.py b/reflex/app_module_for_backend.py
index 8109fc3d6..b0ae0a29f 100644
--- a/reflex/app_module_for_backend.py
+++ b/reflex/app_module_for_backend.py
@@ -7,14 +7,13 @@ from concurrent.futures import ThreadPoolExecutor
from reflex import constants
from reflex.utils import telemetry
from reflex.utils.exec import is_prod_mode
-from reflex.utils.prerequisites import get_app
+from reflex.utils.prerequisites import get_and_validate_app
if constants.CompileVars.APP != "app":
raise AssertionError("unexpected variable name for 'app'")
telemetry.send("compile")
-app_module = get_app(reload=False)
-app = getattr(app_module, constants.CompileVars.APP)
+app, app_module = get_and_validate_app(reload=False)
# For py3.9 compatibility when redis is used, we MUST add any decorator pages
# before compiling the app in a thread to avoid event loop error (REF-2172).
app._apply_decorated_pages()
@@ -30,7 +29,7 @@ if is_prod_mode():
# ensure only "app" is exposed.
del app_module
del compile_future
-del get_app
+del get_and_validate_app
del is_prod_mode
del telemetry
del constants
diff --git a/reflex/components/component.py b/reflex/components/component.py
index 8649b593d..ed90a0f24 100644
--- a/reflex/components/component.py
+++ b/reflex/components/component.py
@@ -429,20 +429,22 @@ class Component(BaseComponent, ABC):
else:
continue
+ def determine_key(value):
+ # Try to create a var from the value
+ key = value if isinstance(value, Var) else LiteralVar.create(value)
+
+ # Check that the var type is not None.
+ if key is None:
+ raise TypeError
+
+ return key
+
# Check whether the key is a component prop.
if types._issubclass(field_type, Var):
# Used to store the passed types if var type is a union.
passed_types = None
try:
- # Try to create a var from the value.
- if isinstance(value, Var):
- kwargs[key] = value
- else:
- kwargs[key] = LiteralVar.create(value)
-
- # Check that the var type is not None.
- if kwargs[key] is None:
- raise TypeError
+ kwargs[key] = determine_key(value)
expected_type = fields[key].outer_type_.__args__[0]
# validate literal fields.
@@ -740,22 +742,21 @@ class Component(BaseComponent, ABC):
# Import here to avoid circular imports.
from reflex.components.base.bare import Bare
from reflex.components.base.fragment import Fragment
- from reflex.utils.exceptions import ComponentTypeError
+ from reflex.utils.exceptions import ChildrenTypeError
# Filter out None props
props = {key: value for key, value in props.items() if value is not None}
def validate_children(children):
for child in children:
- if isinstance(child, tuple):
+ if isinstance(child, (tuple, list)):
validate_children(child)
+
# Make sure the child is a valid type.
- if not types._isinstance(child, ComponentChild):
- raise ComponentTypeError(
- "Children of Reflex components must be other components, "
- "state vars, or primitive Python types. "
- f"Got child {child} of type {type(child)}.",
- )
+ if isinstance(child, dict) or not types._isinstance(
+ child, ComponentChild
+ ):
+ raise ChildrenTypeError(component=cls.__name__, child=child)
# Validate all the children.
validate_children(children)
diff --git a/reflex/components/dynamic.py b/reflex/components/dynamic.py
index fbfc55f97..806d610df 100644
--- a/reflex/components/dynamic.py
+++ b/reflex/components/dynamic.py
@@ -136,6 +136,23 @@ def load_dynamic_serializer():
module_code_lines.insert(0, "const React = window.__reflex.react;")
+ function_line = next(
+ index
+ for index, line in enumerate(module_code_lines)
+ if line.startswith("export default function")
+ )
+
+ module_code_lines = [
+ line
+ for _, line in sorted(
+ enumerate(module_code_lines),
+ key=lambda x: (
+ not (x[1].startswith("import ") and x[0] < function_line),
+ x[0],
+ ),
+ )
+ ]
+
return "\n".join(
[
"//__reflex_evaluate",
diff --git a/reflex/components/lucide/icon.py b/reflex/components/lucide/icon.py
index 04410ac56..6c7cbede7 100644
--- a/reflex/components/lucide/icon.py
+++ b/reflex/components/lucide/icon.py
@@ -2,13 +2,15 @@
from reflex.components.component import Component
from reflex.utils import format
-from reflex.vars.base import Var
+from reflex.utils.imports import ImportVar
+from reflex.vars.base import LiteralVar, Var
+from reflex.vars.sequence import LiteralStringVar
class LucideIconComponent(Component):
"""Lucide Icon Component."""
- library = "lucide-react@0.469.0"
+ library = "lucide-react@0.471.1"
class Icon(LucideIconComponent):
@@ -32,6 +34,7 @@ class Icon(LucideIconComponent):
Raises:
AttributeError: The errors tied to bad usage of the Icon component.
ValueError: If the icon tag is invalid.
+ TypeError: If the icon name is not a string.
Returns:
The created component.
@@ -39,7 +42,6 @@ class Icon(LucideIconComponent):
if children:
if len(children) == 1 and isinstance(children[0], str):
props["tag"] = children[0]
- children = []
else:
raise AttributeError(
f"Passing multiple children to Icon component is not allowed: remove positional arguments {children[1:]} to fix"
@@ -47,24 +49,46 @@ class Icon(LucideIconComponent):
if "tag" not in props:
raise AttributeError("Missing 'tag' keyword-argument for Icon")
+ tag: str | Var | LiteralVar = props.pop("tag")
+ if isinstance(tag, LiteralVar):
+ if isinstance(tag, LiteralStringVar):
+ tag = tag._var_value
+ else:
+ raise TypeError(f"Icon name must be a string, got {type(tag)}")
+ elif isinstance(tag, Var):
+ return DynamicIcon.create(name=tag, **props)
+
if (
- not isinstance(props["tag"], str)
- or format.to_snake_case(props["tag"]) not in LUCIDE_ICON_LIST
+ not isinstance(tag, str)
+ or format.to_snake_case(tag) not in LUCIDE_ICON_LIST
):
raise ValueError(
- f"Invalid icon tag: {props['tag']}. Please use one of the following: {', '.join(LUCIDE_ICON_LIST[0:25])}, ..."
+ f"Invalid icon tag: {tag}. Please use one of the following: {', '.join(LUCIDE_ICON_LIST[0:25])}, ..."
"\nSee full list at https://lucide.dev/icons."
)
- if props["tag"] in LUCIDE_ICON_MAPPING_OVERRIDE:
- props["tag"] = LUCIDE_ICON_MAPPING_OVERRIDE[props["tag"]]
+ if tag in LUCIDE_ICON_MAPPING_OVERRIDE:
+ props["tag"] = LUCIDE_ICON_MAPPING_OVERRIDE[tag]
else:
- props["tag"] = (
- format.to_title_case(format.to_snake_case(props["tag"])) + "Icon"
- )
+ props["tag"] = format.to_title_case(format.to_snake_case(tag)) + "Icon"
props["alias"] = f"Lucide{props['tag']}"
props.setdefault("color", "var(--current-color)")
- return super().create(*children, **props)
+ return super().create(**props)
+
+
+class DynamicIcon(LucideIconComponent):
+ """A DynamicIcon component."""
+
+ tag = "DynamicIcon"
+
+ name: Var[str]
+
+ def _get_imports(self):
+ _imports = super()._get_imports()
+ if self.library:
+ _imports.pop(self.library)
+ _imports["lucide-react/dynamic"] = [ImportVar("DynamicIcon", install=False)]
+ return _imports
LUCIDE_ICON_LIST = [
@@ -846,6 +870,7 @@ LUCIDE_ICON_LIST = [
"house",
"house_plug",
"house_plus",
+ "house_wifi",
"ice_cream_bowl",
"ice_cream_cone",
"id_card",
@@ -1534,6 +1559,7 @@ LUCIDE_ICON_LIST = [
"trending_up_down",
"triangle",
"triangle_alert",
+ "triangle_dashed",
"triangle_right",
"trophy",
"truck",
diff --git a/reflex/components/lucide/icon.pyi b/reflex/components/lucide/icon.pyi
index 39a1da0e6..6094cfd87 100644
--- a/reflex/components/lucide/icon.pyi
+++ b/reflex/components/lucide/icon.pyi
@@ -104,12 +104,60 @@ class Icon(LucideIconComponent):
Raises:
AttributeError: The errors tied to bad usage of the Icon component.
ValueError: If the icon tag is invalid.
+ TypeError: If the icon name is not a string.
Returns:
The created component.
"""
...
+class DynamicIcon(LucideIconComponent):
+ @overload
+ @classmethod
+ def create( # type: ignore
+ cls,
+ *children,
+ name: Optional[Union[Var[str], str]] = None,
+ style: Optional[Style] = None,
+ key: Optional[Any] = None,
+ id: Optional[Any] = None,
+ class_name: Optional[Any] = None,
+ autofocus: Optional[bool] = None,
+ custom_attrs: Optional[Dict[str, Union[Var, Any]]] = None,
+ on_blur: Optional[EventType[[], BASE_STATE]] = None,
+ on_click: Optional[EventType[[], BASE_STATE]] = None,
+ on_context_menu: Optional[EventType[[], BASE_STATE]] = None,
+ on_double_click: Optional[EventType[[], BASE_STATE]] = None,
+ on_focus: Optional[EventType[[], BASE_STATE]] = None,
+ on_mount: Optional[EventType[[], BASE_STATE]] = None,
+ on_mouse_down: Optional[EventType[[], BASE_STATE]] = None,
+ on_mouse_enter: Optional[EventType[[], BASE_STATE]] = None,
+ on_mouse_leave: Optional[EventType[[], BASE_STATE]] = None,
+ on_mouse_move: Optional[EventType[[], BASE_STATE]] = None,
+ on_mouse_out: Optional[EventType[[], BASE_STATE]] = None,
+ on_mouse_over: Optional[EventType[[], BASE_STATE]] = None,
+ on_mouse_up: Optional[EventType[[], BASE_STATE]] = None,
+ on_scroll: Optional[EventType[[], BASE_STATE]] = None,
+ on_unmount: Optional[EventType[[], BASE_STATE]] = None,
+ **props,
+ ) -> "DynamicIcon":
+ """Create the component.
+
+ Args:
+ *children: The children of the component.
+ style: The style of the component.
+ key: A unique key for the component.
+ id: The id for the component.
+ class_name: The class name for the component.
+ autofocus: Whether the component should take the focus once the page is loaded
+ custom_attrs: custom attribute
+ **props: The props of the component.
+
+ Returns:
+ The component.
+ """
+ ...
+
LUCIDE_ICON_LIST = [
"a_arrow_down",
"a_arrow_up",
@@ -889,6 +937,7 @@ LUCIDE_ICON_LIST = [
"house",
"house_plug",
"house_plus",
+ "house_wifi",
"ice_cream_bowl",
"ice_cream_cone",
"id_card",
@@ -1577,6 +1626,7 @@ LUCIDE_ICON_LIST = [
"trending_up_down",
"triangle",
"triangle_alert",
+ "triangle_dashed",
"triangle_right",
"trophy",
"truck",
diff --git a/reflex/config.py b/reflex/config.py
index 0579b019f..8511694fb 100644
--- a/reflex/config.py
+++ b/reflex/config.py
@@ -12,6 +12,7 @@ import threading
import urllib.parse
from importlib.util import find_spec
from pathlib import Path
+from types import ModuleType
from typing import (
TYPE_CHECKING,
Any,
@@ -567,6 +568,9 @@ class EnvironmentVariables:
# The maximum size of the reflex state in kilobytes.
REFLEX_STATE_SIZE_LIMIT: EnvVar[int] = env_var(1000)
+ # Whether to use the turbopack bundler.
+ REFLEX_USE_TURBOPACK: EnvVar[bool] = env_var(True)
+
environment = EnvironmentVariables()
@@ -604,6 +608,9 @@ class Config(Base):
# The name of the app (should match the name of the app directory).
app_name: str
+ # The path to the app module.
+ app_module_import: Optional[str] = None
+
# The log level to use.
loglevel: constants.LogLevel = constants.LogLevel.DEFAULT
@@ -726,6 +733,19 @@ class Config(Base):
"REDIS_URL is required when using the redis state manager."
)
+ @property
+ def app_module(self) -> ModuleType | None:
+ """Return the app module if `app_module_import` is set.
+
+ Returns:
+ The app module.
+ """
+ return (
+ importlib.import_module(self.app_module_import)
+ if self.app_module_import
+ else None
+ )
+
@property
def module(self) -> str:
"""Get the module name of the app.
@@ -733,6 +753,8 @@ class Config(Base):
Returns:
The module name.
"""
+ if self.app_module is not None:
+ return self.app_module.__name__
return ".".join([self.app_name, self.app_name])
def update_from_env(self) -> dict[str, Any]:
@@ -871,7 +893,7 @@ def get_config(reload: bool = False) -> Config:
return cached_rxconfig.config
with _config_lock:
- sys_path = sys.path.copy()
+ orig_sys_path = sys.path.copy()
sys.path.clear()
sys.path.append(str(Path.cwd()))
try:
@@ -879,9 +901,14 @@ def get_config(reload: bool = False) -> Config:
return _get_config()
except Exception:
# If the module import fails, try to import with the original sys.path.
- sys.path.extend(sys_path)
+ sys.path.extend(orig_sys_path)
return _get_config()
finally:
+ # Find any entries added to sys.path by rxconfig.py itself.
+ extra_paths = [
+ p for p in sys.path if p not in orig_sys_path and p != str(Path.cwd())
+ ]
# Restore the original sys.path.
sys.path.clear()
- sys.path.extend(sys_path)
+ sys.path.extend(extra_paths)
+ sys.path.extend(orig_sys_path)
diff --git a/reflex/constants/__init__.py b/reflex/constants/__init__.py
index e816da0f7..f5946bf5e 100644
--- a/reflex/constants/__init__.py
+++ b/reflex/constants/__init__.py
@@ -1,6 +1,7 @@
"""The constants package."""
from .base import (
+ APP_HARNESS_FLAG,
COOKIES,
IS_LINUX,
IS_MACOS,
diff --git a/reflex/constants/base.py b/reflex/constants/base.py
index af96583ad..f737858c0 100644
--- a/reflex/constants/base.py
+++ b/reflex/constants/base.py
@@ -257,6 +257,7 @@ SESSION_STORAGE = "session_storage"
# Testing variables.
# Testing os env set by pytest when running a test case.
PYTEST_CURRENT_TEST = "PYTEST_CURRENT_TEST"
+APP_HARNESS_FLAG = "APP_HARNESS_FLAG"
REFLEX_VAR_OPENING_TAG = ""
REFLEX_VAR_CLOSING_TAG = ""
diff --git a/reflex/constants/installer.py b/reflex/constants/installer.py
index 0b45586dd..f9dd26b5a 100644
--- a/reflex/constants/installer.py
+++ b/reflex/constants/installer.py
@@ -182,7 +182,7 @@ class PackageJson(SimpleNamespace):
"@emotion/react": "11.13.3",
"axios": "1.7.7",
"json5": "2.2.3",
- "next": "14.2.16",
+ "next": "15.1.4",
"next-sitemap": "4.2.3",
"next-themes": "0.4.3",
"react": "18.3.1",
diff --git a/reflex/custom_components/custom_components.py b/reflex/custom_components/custom_components.py
index 4a169802f..8000e7f4c 100644
--- a/reflex/custom_components/custom_components.py
+++ b/reflex/custom_components/custom_components.py
@@ -421,12 +421,13 @@ def _run_commands_in_subprocess(cmds: list[str]) -> bool:
console.debug(f"Running command: {' '.join(cmds)}")
try:
result = subprocess.run(cmds, capture_output=True, text=True, check=True)
- console.debug(result.stdout)
- return True
except subprocess.CalledProcessError as cpe:
console.error(cpe.stdout)
console.error(cpe.stderr)
return False
+ else:
+ console.debug(result.stdout)
+ return True
def _make_pyi_files():
@@ -931,10 +932,11 @@ def _get_file_from_prompt_in_loop() -> Tuple[bytes, str] | None:
file_extension = image_filepath.suffix
try:
image_file = image_filepath.read_bytes()
- return image_file, file_extension
except OSError as ose:
console.error(f"Unable to read the {file_extension} file due to {ose}")
raise typer.Exit(code=1) from ose
+ else:
+ return image_file, file_extension
console.debug(f"File extension detected: {file_extension}")
return None
diff --git a/reflex/event.py b/reflex/event.py
index 28852fde5..886a306c1 100644
--- a/reflex/event.py
+++ b/reflex/event.py
@@ -1591,7 +1591,7 @@ def get_handler_args(
def fix_events(
- events: list[EventHandler | EventSpec] | None,
+ events: list[EventSpec | EventHandler] | None,
token: str,
router_data: dict[str, Any] | None = None,
) -> list[Event]:
diff --git a/reflex/reflex.py b/reflex/reflex.py
index 22fcb9fb8..2d6ebc30c 100644
--- a/reflex/reflex.py
+++ b/reflex/reflex.py
@@ -440,7 +440,11 @@ def deploy(
config.app_name,
"--app-name",
help="The name of the App to deploy under.",
- hidden=True,
+ ),
+ app_id: str = typer.Option(
+ None,
+ "--app-id",
+ help="The ID of the App to deploy over.",
),
regions: List[str] = typer.Option(
[],
@@ -480,6 +484,11 @@ def deploy(
"--project",
help="project id to deploy to",
),
+ project_name: Optional[str] = typer.Option(
+ None,
+ "--project-name",
+ help="The name of the project to deploy to.",
+ ),
token: Optional[str] = typer.Option(
None,
"--token",
@@ -503,13 +512,6 @@ def deploy(
# Set the log level.
console.set_log_level(loglevel)
- if not token:
- # make sure user is logged in.
- if interactive:
- hosting_cli.login()
- else:
- raise SystemExit("Token is required for non-interactive mode.")
-
# Only check requirements if interactive.
# There is user interaction for requirements update.
if interactive:
@@ -519,9 +521,12 @@ def deploy(
if prerequisites.needs_reinit(frontend=True):
_init(name=config.app_name, loglevel=loglevel)
prerequisites.check_latest_package_version(constants.ReflexHostingCLI.MODULE_NAME)
-
+ extra: dict[str, str] = (
+ {"config_path": config_path} if config_path is not None else {}
+ )
hosting_cli.deploy(
app_name=app_name,
+ app_id=app_id,
export_fn=lambda zip_dest_dir,
api_url,
deploy_url,
@@ -546,6 +551,8 @@ def deploy(
token=token,
project=project,
config_path=config_path,
+ project_name=project_name,
+ **extra,
)
diff --git a/reflex/state.py b/reflex/state.py
index a31aae032..66098d232 100644
--- a/reflex/state.py
+++ b/reflex/state.py
@@ -104,6 +104,7 @@ from reflex.utils.exceptions import (
LockExpiredError,
ReflexRuntimeError,
SetUndefinedStateVarError,
+ StateMismatchError,
StateSchemaMismatchError,
StateSerializationError,
StateTooLargeError,
@@ -1199,7 +1200,6 @@ class BaseState(Base, ABC, extra=pydantic.Extra.allow):
fget=func,
auto_deps=False,
deps=["router"],
- cache=True,
_js_expr=param,
_var_data=VarData.from_state(cls),
)
@@ -1543,7 +1543,7 @@ class BaseState(Base, ABC, extra=pydantic.Extra.allow):
# Return the direct parent of target_state_cls for subsequent linking.
return parent_state
- def _get_state_from_cache(self, state_cls: Type[BaseState]) -> BaseState:
+ def _get_state_from_cache(self, state_cls: Type[T_STATE]) -> T_STATE:
"""Get a state instance from the cache.
Args:
@@ -1551,11 +1551,19 @@ class BaseState(Base, ABC, extra=pydantic.Extra.allow):
Returns:
The instance of state_cls associated with this state's client_token.
+
+ Raises:
+ StateMismatchError: If the state instance is not of the expected type.
"""
root_state = self._get_root_state()
- return root_state.get_substate(state_cls.get_full_name().split("."))
+ substate = root_state.get_substate(state_cls.get_full_name().split("."))
+ if not isinstance(substate, state_cls):
+ raise StateMismatchError(
+ f"Searched for state {state_cls.get_full_name()} but found {substate}."
+ )
+ return substate
- async def _get_state_from_redis(self, state_cls: Type[BaseState]) -> BaseState:
+ async def _get_state_from_redis(self, state_cls: Type[T_STATE]) -> T_STATE:
"""Get a state instance from redis.
Args:
@@ -1566,6 +1574,7 @@ class BaseState(Base, ABC, extra=pydantic.Extra.allow):
Raises:
RuntimeError: If redis is not used in this backend process.
+ StateMismatchError: If the state instance is not of the expected type.
"""
# Fetch all missing parent states from redis.
parent_state_of_state_cls = await self._populate_parent_states(state_cls)
@@ -1577,14 +1586,22 @@ class BaseState(Base, ABC, extra=pydantic.Extra.allow):
f"Requested state {state_cls.get_full_name()} is not cached and cannot be accessed without redis. "
"(All states should already be available -- this is likely a bug).",
)
- return await state_manager.get_state(
+
+ state_in_redis = await state_manager.get_state(
token=_substate_key(self.router.session.client_token, state_cls),
top_level=False,
get_substates=True,
parent_state=parent_state_of_state_cls,
)
- async def get_state(self, state_cls: Type[BaseState]) -> BaseState:
+ if not isinstance(state_in_redis, state_cls):
+ raise StateMismatchError(
+ f"Searched for state {state_cls.get_full_name()} but found {state_in_redis}."
+ )
+
+ return state_in_redis
+
+ async def get_state(self, state_cls: Type[T_STATE]) -> T_STATE:
"""Get an instance of the state associated with this token.
Allows for arbitrary access to sibling states from within an event handler.
@@ -1759,9 +1776,9 @@ class BaseState(Base, ABC, extra=pydantic.Extra.allow):
except Exception as ex:
state._clean()
- app_instance = getattr(prerequisites.get_app(), constants.CompileVars.APP)
-
- event_specs = app_instance.backend_exception_handler(ex)
+ event_specs = (
+ prerequisites.get_and_validate_app().app.backend_exception_handler(ex)
+ )
if event_specs is None:
return StateUpdate()
@@ -1871,9 +1888,9 @@ class BaseState(Base, ABC, extra=pydantic.Extra.allow):
except Exception as ex:
telemetry.send_error(ex, context="backend")
- app_instance = getattr(prerequisites.get_app(), constants.CompileVars.APP)
-
- event_specs = app_instance.backend_exception_handler(ex)
+ event_specs = (
+ prerequisites.get_and_validate_app().app.backend_exception_handler(ex)
+ )
yield state._as_state_update(
handler,
@@ -2316,6 +2333,9 @@ class BaseState(Base, ABC, extra=pydantic.Extra.allow):
return state
+T_STATE = TypeVar("T_STATE", bound=BaseState)
+
+
class State(BaseState):
"""The app Base State."""
@@ -2383,8 +2403,9 @@ class FrontendEventExceptionState(State):
component_stack: The stack trace of the component where the exception occurred.
"""
- app_instance = getattr(prerequisites.get_app(), constants.CompileVars.APP)
- app_instance.frontend_exception_handler(Exception(stack))
+ prerequisites.get_and_validate_app().app.frontend_exception_handler(
+ Exception(stack)
+ )
class UpdateVarsInternalState(State):
@@ -2422,15 +2443,16 @@ class OnLoadInternalState(State):
The list of events to queue for on load handling.
"""
# Do not app._compile()! It should be already compiled by now.
- app = getattr(prerequisites.get_app(), constants.CompileVars.APP)
- load_events = app.get_load_events(self.router.page.path)
+ load_events = prerequisites.get_and_validate_app().app.get_load_events(
+ self.router.page.path
+ )
if not load_events:
self.is_hydrated = True
return # Fast path for navigation with no on_load events defined.
self.is_hydrated = False
return [
*fix_events(
- load_events,
+ cast(list[Union[EventSpec, EventHandler]], load_events),
self.router.session.client_token,
router_data=self.router_data,
),
@@ -2589,7 +2611,7 @@ class StateProxy(wrapt.ObjectProxy):
"""
super().__init__(state_instance)
# compile is not relevant to backend logic
- self._self_app = getattr(prerequisites.get_app(), constants.CompileVars.APP)
+ self._self_app = prerequisites.get_and_validate_app().app
self._self_substate_path = tuple(state_instance.get_full_name().split("."))
self._self_actx = None
self._self_mutable = False
@@ -3682,8 +3704,7 @@ def get_state_manager() -> StateManager:
Returns:
The state manager.
"""
- app = getattr(prerequisites.get_app(), constants.CompileVars.APP)
- return app.state_manager
+ return prerequisites.get_and_validate_app().app.state_manager
class MutableProxy(wrapt.ObjectProxy):
diff --git a/reflex/testing.py b/reflex/testing.py
index f9bef2c09..027648f97 100644
--- a/reflex/testing.py
+++ b/reflex/testing.py
@@ -282,6 +282,7 @@ class AppHarness:
before_decorated_pages = reflex.app.DECORATED_PAGES[self.app_name].copy()
# Ensure the AppHarness test does not skip State assignment due to running via pytest
os.environ.pop(reflex.constants.PYTEST_CURRENT_TEST, None)
+ os.environ[reflex.constants.APP_HARNESS_FLAG] = "true"
self.app_module = reflex.utils.prerequisites.get_compiled_app(
# Do not reload the module for pre-existing apps (only apps generated from source)
reload=self.app_source is not None
diff --git a/reflex/utils/build.py b/reflex/utils/build.py
index e263374e1..9ea941792 100644
--- a/reflex/utils/build.py
+++ b/reflex/utils/build.py
@@ -13,13 +13,17 @@ from rich.progress import MofNCompleteColumn, Progress, TimeElapsedColumn
from reflex import constants
from reflex.config import get_config
from reflex.utils import console, path_ops, prerequisites, processes
+from reflex.utils.exec import is_in_app_harness
def set_env_json():
"""Write the upload url to a REFLEX_JSON."""
path_ops.update_json_file(
str(prerequisites.get_web_dir() / constants.Dirs.ENV_JSON),
- {endpoint.name: endpoint.get_url() for endpoint in constants.Endpoint},
+ {
+ **{endpoint.name: endpoint.get_url() for endpoint in constants.Endpoint},
+ "TEST_MODE": is_in_app_harness(),
+ },
)
diff --git a/reflex/utils/console.py b/reflex/utils/console.py
index be545140a..8929b63b6 100644
--- a/reflex/utils/console.py
+++ b/reflex/utils/console.py
@@ -2,6 +2,11 @@
from __future__ import annotations
+import inspect
+import shutil
+from pathlib import Path
+from types import FrameType
+
from rich.console import Console
from rich.progress import MofNCompleteColumn, Progress, TimeElapsedColumn
from rich.prompt import Prompt
@@ -188,6 +193,33 @@ def warn(msg: str, dedupe: bool = False, **kwargs):
print(f"[orange1]Warning: {msg}[/orange1]", **kwargs)
+def _get_first_non_framework_frame() -> FrameType | None:
+ import click
+ import typer
+ import typing_extensions
+
+ import reflex as rx
+
+ # Exclude utility modules that should never be the source of deprecated reflex usage.
+ exclude_modules = [click, rx, typer, typing_extensions]
+ exclude_roots = [
+ p.parent.resolve()
+ if (p := Path(m.__file__)).name == "__init__.py"
+ else p.resolve()
+ for m in exclude_modules
+ ]
+ # Specifically exclude the reflex cli module.
+ if reflex_bin := shutil.which(b"reflex"):
+ exclude_roots.append(Path(reflex_bin.decode()))
+
+ frame = inspect.currentframe()
+ while frame := frame and frame.f_back:
+ frame_path = Path(inspect.getfile(frame)).resolve()
+ if not any(frame_path.is_relative_to(root) for root in exclude_roots):
+ break
+ return frame
+
+
def deprecate(
feature_name: str,
reason: str,
@@ -206,15 +238,27 @@ def deprecate(
dedupe: If True, suppress multiple console logs of deprecation message.
kwargs: Keyword arguments to pass to the print function.
"""
- if feature_name not in _EMITTED_DEPRECATION_WARNINGS:
+ dedupe_key = feature_name
+ loc = ""
+
+ # See if we can find where the deprecation exists in "user code"
+ origin_frame = _get_first_non_framework_frame()
+ if origin_frame is not None:
+ filename = Path(origin_frame.f_code.co_filename)
+ if filename.is_relative_to(Path.cwd()):
+ filename = filename.relative_to(Path.cwd())
+ loc = f"{filename}:{origin_frame.f_lineno}"
+ dedupe_key = f"{dedupe_key} {loc}"
+
+ if dedupe_key not in _EMITTED_DEPRECATION_WARNINGS:
msg = (
f"{feature_name} has been deprecated in version {deprecation_version} {reason.rstrip('.')}. It will be completely "
- f"removed in {removal_version}"
+ f"removed in {removal_version}. ({loc})"
)
if _LOG_LEVEL <= LogLevel.WARNING:
print(f"[yellow]DeprecationWarning: {msg}[/yellow]", **kwargs)
if dedupe:
- _EMITTED_DEPRECATION_WARNINGS.add(feature_name)
+ _EMITTED_DEPRECATION_WARNINGS.add(dedupe_key)
def error(msg: str, dedupe: bool = False, **kwargs):
diff --git a/reflex/utils/exceptions.py b/reflex/utils/exceptions.py
index bceadc977..37a68e420 100644
--- a/reflex/utils/exceptions.py
+++ b/reflex/utils/exceptions.py
@@ -1,6 +1,6 @@
"""Custom Exceptions."""
-from typing import NoReturn
+from typing import Any, NoReturn
class ReflexError(Exception):
@@ -31,6 +31,22 @@ class ComponentTypeError(ReflexError, TypeError):
"""Custom TypeError for component related errors."""
+class ChildrenTypeError(ComponentTypeError):
+ """Raised when the children prop of a component is not a valid type."""
+
+ def __init__(self, component: str, child: Any):
+ """Initialize the exception.
+
+ Args:
+ component: The name of the component.
+ child: The child that caused the error.
+ """
+ super().__init__(
+ f"Component {component} received child {child} of type {type(child)}. "
+ "Accepted types are other components, state vars, or primitive Python types (dict excluded)."
+ )
+
+
class EventHandlerTypeError(ReflexError, TypeError):
"""Custom TypeError for event handler related errors."""
@@ -163,10 +179,18 @@ class StateSerializationError(ReflexError):
"""Raised when the state cannot be serialized."""
+class StateMismatchError(ReflexError, ValueError):
+ """Raised when the state retrieved does not match the expected state."""
+
+
class SystemPackageMissingError(ReflexError):
"""Raised when a system package is missing."""
+class EventDeserializationError(ReflexError, ValueError):
+ """Raised when an event cannot be deserialized."""
+
+
def raise_system_package_missing_error(package: str) -> NoReturn:
"""Raise a SystemPackageMissingError.
diff --git a/reflex/utils/exec.py b/reflex/utils/exec.py
index 621c4a608..c10b6b856 100644
--- a/reflex/utils/exec.py
+++ b/reflex/utils/exec.py
@@ -240,6 +240,28 @@ def run_backend(
run_uvicorn_backend(host, port, loglevel)
+def get_reload_dirs() -> list[str]:
+ """Get the reload directories for the backend.
+
+ Returns:
+ The reload directories for the backend.
+ """
+ config = get_config()
+ reload_dirs = [config.app_name]
+ if config.app_module is not None and config.app_module.__file__:
+ module_path = Path(config.app_module.__file__).resolve().parent
+ while module_path.parent.name:
+ for parent_file in module_path.parent.iterdir():
+ if parent_file == "__init__.py":
+ # go up a level to find dir without `__init__.py`
+ module_path = module_path.parent
+ break
+ else:
+ break
+ reload_dirs.append(str(module_path))
+ return reload_dirs
+
+
def run_uvicorn_backend(host, port, loglevel: LogLevel):
"""Run the backend in development mode using Uvicorn.
@@ -256,7 +278,7 @@ def run_uvicorn_backend(host, port, loglevel: LogLevel):
port=port,
log_level=loglevel.value,
reload=True,
- reload_dirs=[get_config().app_name],
+ reload_dirs=get_reload_dirs(),
)
@@ -281,7 +303,7 @@ def run_granian_backend(host, port, loglevel: LogLevel):
interface=Interfaces.ASGI,
log_level=LogLevels(loglevel.value),
reload=True,
- reload_paths=[Path(get_config().app_name)],
+ reload_paths=get_reload_dirs(),
reload_ignore_dirs=[".web"],
).serve()
except ImportError:
@@ -487,6 +509,15 @@ def is_testing_env() -> bool:
return constants.PYTEST_CURRENT_TEST in os.environ
+def is_in_app_harness() -> bool:
+ """Whether the app is running in the app harness.
+
+ Returns:
+ True if the app is running in the app harness.
+ """
+ return constants.APP_HARNESS_FLAG in os.environ
+
+
def is_prod_mode() -> bool:
"""Check if the app is running in production mode.
diff --git a/reflex/utils/prerequisites.py b/reflex/utils/prerequisites.py
index d838c0eea..4f9cc0c14 100644
--- a/reflex/utils/prerequisites.py
+++ b/reflex/utils/prerequisites.py
@@ -17,11 +17,12 @@ import stat
import sys
import tempfile
import time
+import typing
import zipfile
from datetime import datetime
from pathlib import Path
from types import ModuleType
-from typing import Callable, List, Optional
+from typing import Callable, List, NamedTuple, Optional
import httpx
import typer
@@ -42,9 +43,19 @@ from reflex.utils.exceptions import (
from reflex.utils.format import format_library_name
from reflex.utils.registry import _get_npm_registry
+if typing.TYPE_CHECKING:
+ from reflex.app import App
+
CURRENTLY_INSTALLING_NODE = False
+class AppInfo(NamedTuple):
+ """A tuple containing the app instance and module."""
+
+ app: App
+ module: ModuleType
+
+
@dataclasses.dataclass(frozen=True)
class Template:
"""A template for a Reflex app."""
@@ -267,6 +278,22 @@ def windows_npm_escape_hatch() -> bool:
return environment.REFLEX_USE_NPM.get()
+def _check_app_name(config: Config):
+ """Check if the app name is set in the config.
+
+ Args:
+ config: The config object.
+
+ Raises:
+ RuntimeError: If the app name is not set in the config.
+ """
+ if not config.app_name:
+ raise RuntimeError(
+ "Cannot get the app module because `app_name` is not set in rxconfig! "
+ "If this error occurs in a reflex test case, ensure that `get_app` is mocked."
+ )
+
+
def get_app(reload: bool = False) -> ModuleType:
"""Get the app module based on the default config.
@@ -277,22 +304,23 @@ def get_app(reload: bool = False) -> ModuleType:
The app based on the default config.
Raises:
- RuntimeError: If the app name is not set in the config.
+ Exception: If an error occurs while getting the app module.
"""
from reflex.utils import telemetry
try:
environment.RELOAD_CONFIG.set(reload)
config = get_config()
- if not config.app_name:
- raise RuntimeError(
- "Cannot get the app module because `app_name` is not set in rxconfig! "
- "If this error occurs in a reflex test case, ensure that `get_app` is mocked."
- )
+
+ _check_app_name(config)
+
module = config.module
sys.path.insert(0, str(Path.cwd()))
- app = __import__(module, fromlist=(constants.CompileVars.APP,))
-
+ app = (
+ __import__(module, fromlist=(constants.CompileVars.APP,))
+ if not config.app_module
+ else config.app_module
+ )
if reload:
from reflex.state import reload_state_module
@@ -301,11 +329,34 @@ def get_app(reload: bool = False) -> ModuleType:
# Reload the app module.
importlib.reload(app)
-
- return app
except Exception as ex:
telemetry.send_error(ex, context="frontend")
raise
+ else:
+ return app
+
+
+def get_and_validate_app(reload: bool = False) -> AppInfo:
+ """Get the app instance based on the default config and validate it.
+
+ Args:
+ reload: Re-import the app module from disk
+
+ Returns:
+ The app instance and the app module.
+
+ Raises:
+ RuntimeError: If the app instance is not an instance of rx.App.
+ """
+ from reflex.app import App
+
+ app_module = get_app(reload=reload)
+ app = getattr(app_module, constants.CompileVars.APP)
+ if not isinstance(app, App):
+ raise RuntimeError(
+ "The app instance in the specified app_module_import in rxconfig must be an instance of rx.App."
+ )
+ return AppInfo(app=app, module=app_module)
def get_compiled_app(reload: bool = False, export: bool = False) -> ModuleType:
@@ -318,8 +369,7 @@ def get_compiled_app(reload: bool = False, export: bool = False) -> ModuleType:
Returns:
The compiled app based on the default config.
"""
- app_module = get_app(reload=reload)
- app = getattr(app_module, constants.CompileVars.APP)
+ app, app_module = get_and_validate_app(reload=reload)
# For py3.9 compatibility when redis is used, we MUST add any decorator pages
# before compiling the app in a thread to avoid event loop error (REF-2172).
app._apply_decorated_pages()
@@ -610,10 +660,14 @@ def initialize_web_directory():
init_reflex_json(project_hash=project_hash)
+def _turbopack_flag() -> str:
+ return " --turbopack" if environment.REFLEX_USE_TURBOPACK.get() else ""
+
+
def _compile_package_json():
return templates.PACKAGE_JSON.render(
scripts={
- "dev": constants.PackageJson.Commands.DEV,
+ "dev": constants.PackageJson.Commands.DEV + _turbopack_flag(),
"export": constants.PackageJson.Commands.EXPORT,
"export_sitemap": constants.PackageJson.Commands.EXPORT_SITEMAP,
"prod": constants.PackageJson.Commands.PROD,
@@ -1149,11 +1203,12 @@ def ensure_reflex_installation_id() -> Optional[int]:
if installation_id is None:
installation_id = random.getrandbits(128)
installation_id_file.write_text(str(installation_id))
- # If we get here, installation_id is definitely set
- return installation_id
except Exception as e:
console.debug(f"Failed to ensure reflex installation id: {e}")
return None
+ else:
+ # If we get here, installation_id is definitely set
+ return installation_id
def initialize_reflex_user_directory():
@@ -1367,19 +1422,22 @@ def create_config_init_app_from_remote_template(app_name: str, template_url: str
except OSError as ose:
console.error(f"Failed to create temp directory for extracting zip: {ose}")
raise typer.Exit(1) from ose
+
try:
zipfile.ZipFile(zip_file_path).extractall(path=unzip_dir)
# The zip file downloaded from github looks like:
# repo-name-branch/**/*, so we need to remove the top level directory.
- if len(subdirs := os.listdir(unzip_dir)) != 1:
- console.error(f"Expected one directory in the zip, found {subdirs}")
- raise typer.Exit(1)
- template_dir = unzip_dir / subdirs[0]
- console.debug(f"Template folder is located at {template_dir}")
except Exception as uze:
console.error(f"Failed to unzip the template: {uze}")
raise typer.Exit(1) from uze
+ if len(subdirs := os.listdir(unzip_dir)) != 1:
+ console.error(f"Expected one directory in the zip, found {subdirs}")
+ raise typer.Exit(1)
+
+ template_dir = unzip_dir / subdirs[0]
+ console.debug(f"Template folder is located at {template_dir}")
+
# Move the rxconfig file here first.
path_ops.mv(str(template_dir / constants.Config.FILE), constants.Config.FILE)
new_config = get_config(reload=True)
diff --git a/reflex/utils/processes.py b/reflex/utils/processes.py
index 871b5f323..3673b36b2 100644
--- a/reflex/utils/processes.py
+++ b/reflex/utils/processes.py
@@ -17,6 +17,7 @@ import typer
from redis.exceptions import RedisError
from reflex import constants
+from reflex.config import environment
from reflex.utils import console, path_ops, prerequisites
@@ -156,24 +157,30 @@ def new_process(args, run: bool = False, show_logs: bool = False, **kwargs):
Raises:
Exit: When attempting to run a command with a None value.
"""
- node_bin_path = str(path_ops.get_node_bin_path())
- if not node_bin_path and not prerequisites.CURRENTLY_INSTALLING_NODE:
- console.warn(
- "The path to the Node binary could not be found. Please ensure that Node is properly "
- "installed and added to your system's PATH environment variable or try running "
- "`reflex init` again."
- )
+ # Check for invalid command first.
if None in args:
console.error(f"Invalid command: {args}")
raise typer.Exit(1)
- # Add the node bin path to the PATH environment variable.
+
+ path_env: str = os.environ.get("PATH", "")
+
+ # Add node_bin_path to the PATH environment variable.
+ if not environment.REFLEX_BACKEND_ONLY.get():
+ node_bin_path = str(path_ops.get_node_bin_path())
+ if not node_bin_path and not prerequisites.CURRENTLY_INSTALLING_NODE:
+ console.warn(
+ "The path to the Node binary could not be found. Please ensure that Node is properly "
+ "installed and added to your system's PATH environment variable or try running "
+ "`reflex init` again."
+ )
+ path_env = os.pathsep.join([node_bin_path, path_env])
+
env: dict[str, str] = {
**os.environ,
- "PATH": os.pathsep.join(
- [node_bin_path if node_bin_path else "", os.environ["PATH"]]
- ), # type: ignore
+ "PATH": path_env,
**kwargs.pop("env", {}),
}
+
kwargs = {
"env": env,
"stderr": None if show_logs else subprocess.STDOUT,
diff --git a/reflex/utils/telemetry.py b/reflex/utils/telemetry.py
index fc90932a6..8e9130b09 100644
--- a/reflex/utils/telemetry.py
+++ b/reflex/utils/telemetry.py
@@ -156,9 +156,10 @@ def _prepare_event(event: str, **kwargs) -> dict:
def _send_event(event_data: dict) -> bool:
try:
httpx.post(POSTHOG_API_URL, json=event_data)
- return True
except Exception:
return False
+ else:
+ return True
def _send(event, telemetry_enabled, **kwargs):
diff --git a/reflex/utils/types.py b/reflex/utils/types.py
index b8bcbf2d6..ac30342e2 100644
--- a/reflex/utils/types.py
+++ b/reflex/utils/types.py
@@ -829,6 +829,22 @@ StateBases = get_base_class(StateVar)
StateIterBases = get_base_class(StateIterVar)
+def safe_issubclass(cls: Type, cls_check: Type | Tuple[Type, ...]):
+ """Check if a class is a subclass of another class. Returns False if internal error occurs.
+
+ Args:
+ cls: The class to check.
+ cls_check: The class to check against.
+
+ Returns:
+ Whether the class is a subclass of the other class.
+ """
+ try:
+ return issubclass(cls, cls_check)
+ except TypeError:
+ return False
+
+
def typehint_issubclass(possible_subclass: Any, possible_superclass: Any) -> bool:
"""Check if a type hint is a subclass of another type hint.
diff --git a/reflex/vars/base.py b/reflex/vars/base.py
index 0a93901cd..122545187 100644
--- a/reflex/vars/base.py
+++ b/reflex/vars/base.py
@@ -26,6 +26,7 @@ from typing import (
Iterable,
List,
Literal,
+ Mapping,
NoReturn,
Optional,
Set,
@@ -64,6 +65,7 @@ from reflex.utils.types import (
_isinstance,
get_origin,
has_args,
+ safe_issubclass,
unionize,
)
@@ -127,7 +129,7 @@ class VarData:
state: str = "",
field_name: str = "",
imports: ImportDict | ParsedImportDict | None = None,
- hooks: dict[str, VarData | None] | None = None,
+ hooks: Mapping[str, VarData | None] | None = None,
deps: list[Var] | None = None,
position: Hooks.HookPosition | None = None,
):
@@ -561,7 +563,7 @@ class Var(Generic[VAR_TYPE]):
if _var_is_local is not None:
console.deprecate(
feature_name="_var_is_local",
- reason="The _var_is_local argument is not supported for Var."
+ reason="The _var_is_local argument is not supported for Var. "
"If you want to create a Var from a raw Javascript expression, use the constructor directly",
deprecation_version="0.6.0",
removal_version="0.7.0",
@@ -569,7 +571,7 @@ class Var(Generic[VAR_TYPE]):
if _var_is_string is not None:
console.deprecate(
feature_name="_var_is_string",
- reason="The _var_is_string argument is not supported for Var."
+ reason="The _var_is_string argument is not supported for Var. "
"If you want to create a Var from a raw Javascript expression, use the constructor directly",
deprecation_version="0.6.0",
removal_version="0.7.0",
@@ -643,8 +645,8 @@ class Var(Generic[VAR_TYPE]):
@overload
def to(
self,
- output: type[dict],
- ) -> ObjectVar[dict]: ...
+ output: type[Mapping],
+ ) -> ObjectVar[Mapping]: ...
@overload
def to(
@@ -686,7 +688,9 @@ class Var(Generic[VAR_TYPE]):
# If the first argument is a python type, we map it to the corresponding Var type.
for var_subclass in _var_subclasses[::-1]:
- if fixed_output_type in var_subclass.python_types:
+ if fixed_output_type in var_subclass.python_types or safe_issubclass(
+ fixed_output_type, var_subclass.python_types
+ ):
return self.to(var_subclass.var_subclass, output)
if fixed_output_type is None:
@@ -820,7 +824,7 @@ class Var(Generic[VAR_TYPE]):
return False
if issubclass(type_, list):
return []
- if issubclass(type_, dict):
+ if issubclass(type_, Mapping):
return {}
if issubclass(type_, tuple):
return ()
@@ -1026,7 +1030,7 @@ class Var(Generic[VAR_TYPE]):
f"$/{constants.Dirs.STATE_PATH}": [imports.ImportVar(tag="refs")]
}
),
- ).to(ObjectVar, Dict[str, str])
+ ).to(ObjectVar, Mapping[str, str])
return refs[LiteralVar.create(str(self))]
@deprecated("Use `.js_type()` instead.")
@@ -1373,7 +1377,7 @@ class LiteralVar(Var):
serialized_value = serializers.serialize(value)
if serialized_value is not None:
- if isinstance(serialized_value, dict):
+ if isinstance(serialized_value, Mapping):
return LiteralObjectVar.create(
serialized_value,
_var_type=type(value),
@@ -1498,7 +1502,7 @@ def var_operation(
) -> Callable[P, ArrayVar[LIST_T]]: ...
-OBJECT_TYPE = TypeVar("OBJECT_TYPE", bound=Dict)
+OBJECT_TYPE = TypeVar("OBJECT_TYPE", bound=Mapping)
@overload
@@ -1573,8 +1577,8 @@ def figure_out_type(value: Any) -> types.GenericType:
return Set[unionize(*(figure_out_type(v) for v in value))]
if isinstance(value, tuple):
return Tuple[unionize(*(figure_out_type(v) for v in value)), ...]
- if isinstance(value, dict):
- return Dict[
+ if isinstance(value, Mapping):
+ return Mapping[
unionize(*(figure_out_type(k) for k in value)),
unionize(*(figure_out_type(v) for v in value.values())),
]
@@ -1838,7 +1842,7 @@ class ComputedVar(Var[RETURN_TYPE]):
self,
fget: Callable[[BASE_STATE], RETURN_TYPE],
initial_value: RETURN_TYPE | types.Unset = types.Unset(),
- cache: bool = False,
+ cache: bool = True,
deps: Optional[List[Union[str, Var]]] = None,
auto_deps: bool = True,
interval: Optional[Union[int, datetime.timedelta]] = None,
@@ -2002,10 +2006,10 @@ class ComputedVar(Var[RETURN_TYPE]):
@overload
def __get__(
- self: ComputedVar[dict[DICT_KEY, DICT_VAL]],
+ self: ComputedVar[Mapping[DICT_KEY, DICT_VAL]],
instance: None,
owner: Type,
- ) -> ObjectVar[dict[DICT_KEY, DICT_VAL]]: ...
+ ) -> ObjectVar[Mapping[DICT_KEY, DICT_VAL]]: ...
@overload
def __get__(
@@ -2253,7 +2257,7 @@ if TYPE_CHECKING:
def computed_var(
fget: None = None,
initial_value: Any | types.Unset = types.Unset(),
- cache: bool = False,
+ cache: bool = True,
deps: Optional[List[Union[str, Var]]] = None,
auto_deps: bool = True,
interval: Optional[Union[datetime.timedelta, int]] = None,
@@ -2266,7 +2270,7 @@ def computed_var(
def computed_var(
fget: Callable[[BASE_STATE], RETURN_TYPE],
initial_value: RETURN_TYPE | types.Unset = types.Unset(),
- cache: bool = False,
+ cache: bool = True,
deps: Optional[List[Union[str, Var]]] = None,
auto_deps: bool = True,
interval: Optional[Union[datetime.timedelta, int]] = None,
@@ -2278,7 +2282,7 @@ def computed_var(
def computed_var(
fget: Callable[[BASE_STATE], Any] | None = None,
initial_value: Any | types.Unset = types.Unset(),
- cache: Optional[bool] = None,
+ cache: bool = True,
deps: Optional[List[Union[str, Var]]] = None,
auto_deps: bool = True,
interval: Optional[Union[datetime.timedelta, int]] = None,
@@ -2304,15 +2308,6 @@ def computed_var(
ValueError: If caching is disabled and an update interval is set.
VarDependencyError: If user supplies dependencies without caching.
"""
- if cache is None:
- cache = False
- console.deprecate(
- "Default non-cached rx.var",
- "the default value will be `@rx.var(cache=True)` in a future release. "
- "To retain uncached var, explicitly pass `@rx.var(cache=False)`",
- deprecation_version="0.6.8",
- removal_version="0.7.0",
- )
if cache is False and interval is not None:
raise ValueError("Cannot set update interval without caching.")
@@ -2924,11 +2919,14 @@ V = TypeVar("V")
BASE_TYPE = TypeVar("BASE_TYPE", bound=Base)
+FIELD_TYPE = TypeVar("FIELD_TYPE")
+MAPPING_TYPE = TypeVar("MAPPING_TYPE", bound=Mapping)
-class Field(Generic[T]):
+
+class Field(Generic[FIELD_TYPE]):
"""Shadow class for Var to allow for type hinting in the IDE."""
- def __set__(self, instance, value: T):
+ def __set__(self, instance, value: FIELD_TYPE):
"""Set the Var.
Args:
@@ -2940,7 +2938,9 @@ class Field(Generic[T]):
def __get__(self: Field[bool], instance: None, owner) -> BooleanVar: ...
@overload
- def __get__(self: Field[int], instance: None, owner) -> NumberVar: ...
+ def __get__(
+ self: Field[int] | Field[float] | Field[int | float], instance: None, owner
+ ) -> NumberVar: ...
@overload
def __get__(self: Field[str], instance: None, owner) -> StringVar: ...
@@ -2957,8 +2957,8 @@ class Field(Generic[T]):
@overload
def __get__(
- self: Field[Dict[str, V]], instance: None, owner
- ) -> ObjectVar[Dict[str, V]]: ...
+ self: Field[MAPPING_TYPE], instance: None, owner
+ ) -> ObjectVar[MAPPING_TYPE]: ...
@overload
def __get__(
@@ -2966,10 +2966,10 @@ class Field(Generic[T]):
) -> ObjectVar[BASE_TYPE]: ...
@overload
- def __get__(self, instance: None, owner) -> Var[T]: ...
+ def __get__(self, instance: None, owner) -> Var[FIELD_TYPE]: ...
@overload
- def __get__(self, instance, owner) -> T: ...
+ def __get__(self, instance, owner) -> FIELD_TYPE: ...
def __get__(self, instance, owner): # type: ignore
"""Get the Var.
@@ -2980,7 +2980,7 @@ class Field(Generic[T]):
"""
-def field(value: T) -> Field[T]:
+def field(value: FIELD_TYPE) -> Field[FIELD_TYPE]:
"""Create a Field with a value.
Args:
diff --git a/reflex/vars/function.py b/reflex/vars/function.py
index 2a7d50e1b..131f15b9f 100644
--- a/reflex/vars/function.py
+++ b/reflex/vars/function.py
@@ -390,6 +390,7 @@ class ArgsFunctionOperation(CachedVarOperation, FunctionVar):
Returns:
The function var.
"""
+ return_expr = Var.create(return_expr)
return cls(
_js_expr="",
_var_type=_var_type,
@@ -445,6 +446,7 @@ class ArgsFunctionOperationBuilder(CachedVarOperation, BuilderFunctionVar):
Returns:
The function var.
"""
+ return_expr = Var.create(return_expr)
return cls(
_js_expr="",
_var_type=_var_type,
diff --git a/reflex/vars/number.py b/reflex/vars/number.py
index d04aded35..a2a0293d5 100644
--- a/reflex/vars/number.py
+++ b/reflex/vars/number.py
@@ -20,7 +20,6 @@ from typing import (
from reflex.constants.base import Dirs
from reflex.utils.exceptions import PrimitiveUnserializableToJSON, VarTypeError
from reflex.utils.imports import ImportDict, ImportVar
-from reflex.utils.types import is_optional
from .base import (
CustomVarOperationReturn,
@@ -431,7 +430,7 @@ class NumberVar(Var[NUMBER_T], python_types=(int, float)):
"""
if not isinstance(other, NUMBER_TYPES):
raise_unsupported_operand_types("<", (type(self), type(other)))
- return less_than_operation(self, +other)
+ return less_than_operation(+self, +other)
@overload
def __le__(self, other: number_types) -> BooleanVar: ...
@@ -450,7 +449,7 @@ class NumberVar(Var[NUMBER_T], python_types=(int, float)):
"""
if not isinstance(other, NUMBER_TYPES):
raise_unsupported_operand_types("<=", (type(self), type(other)))
- return less_than_or_equal_operation(self, +other)
+ return less_than_or_equal_operation(+self, +other)
def __eq__(self, other: Any):
"""Equal comparison.
@@ -462,7 +461,7 @@ class NumberVar(Var[NUMBER_T], python_types=(int, float)):
The result of the comparison.
"""
if isinstance(other, NUMBER_TYPES):
- return equal_operation(self, +other)
+ return equal_operation(+self, +other)
return equal_operation(self, other)
def __ne__(self, other: Any):
@@ -475,7 +474,7 @@ class NumberVar(Var[NUMBER_T], python_types=(int, float)):
The result of the comparison.
"""
if isinstance(other, NUMBER_TYPES):
- return not_equal_operation(self, +other)
+ return not_equal_operation(+self, +other)
return not_equal_operation(self, other)
@overload
@@ -495,7 +494,7 @@ class NumberVar(Var[NUMBER_T], python_types=(int, float)):
"""
if not isinstance(other, NUMBER_TYPES):
raise_unsupported_operand_types(">", (type(self), type(other)))
- return greater_than_operation(self, +other)
+ return greater_than_operation(+self, +other)
@overload
def __ge__(self, other: number_types) -> BooleanVar: ...
@@ -514,17 +513,7 @@ class NumberVar(Var[NUMBER_T], python_types=(int, float)):
"""
if not isinstance(other, NUMBER_TYPES):
raise_unsupported_operand_types(">=", (type(self), type(other)))
- return greater_than_or_equal_operation(self, +other)
-
- def bool(self):
- """Boolean conversion.
-
- Returns:
- The boolean value of the number.
- """
- if is_optional(self._var_type):
- return boolify((self != None) & (self != 0)) # noqa: E711
- return self != 0
+ return greater_than_or_equal_operation(+self, +other)
def _is_strict_float(self) -> bool:
"""Check if the number is a float.
diff --git a/reflex/vars/object.py b/reflex/vars/object.py
index 5de431f5a..7b951c559 100644
--- a/reflex/vars/object.py
+++ b/reflex/vars/object.py
@@ -8,8 +8,8 @@ import typing
from inspect import isclass
from typing import (
Any,
- Dict,
List,
+ Mapping,
NoReturn,
Tuple,
Type,
@@ -19,6 +19,8 @@ from typing import (
overload,
)
+from typing_extensions import is_typeddict
+
from reflex.utils import types
from reflex.utils.exceptions import VarAttributeError
from reflex.utils.types import GenericType, get_attribute_access_type, get_origin
@@ -36,7 +38,7 @@ from .base import (
from .number import BooleanVar, NumberVar, raise_unsupported_operand_types
from .sequence import ArrayVar, StringVar
-OBJECT_TYPE = TypeVar("OBJECT_TYPE")
+OBJECT_TYPE = TypeVar("OBJECT_TYPE", covariant=True)
KEY_TYPE = TypeVar("KEY_TYPE")
VALUE_TYPE = TypeVar("VALUE_TYPE")
@@ -46,7 +48,7 @@ ARRAY_INNER_TYPE = TypeVar("ARRAY_INNER_TYPE")
OTHER_KEY_TYPE = TypeVar("OTHER_KEY_TYPE")
-class ObjectVar(Var[OBJECT_TYPE], python_types=dict):
+class ObjectVar(Var[OBJECT_TYPE], python_types=Mapping):
"""Base class for immutable object vars."""
def _key_type(self) -> Type:
@@ -59,7 +61,7 @@ class ObjectVar(Var[OBJECT_TYPE], python_types=dict):
@overload
def _value_type(
- self: ObjectVar[Dict[Any, VALUE_TYPE]],
+ self: ObjectVar[Mapping[Any, VALUE_TYPE]],
) -> Type[VALUE_TYPE]: ...
@overload
@@ -74,7 +76,7 @@ class ObjectVar(Var[OBJECT_TYPE], python_types=dict):
fixed_type = get_origin(self._var_type) or self._var_type
if not isclass(fixed_type):
return Any
- args = get_args(self._var_type) if issubclass(fixed_type, dict) else ()
+ args = get_args(self._var_type) if issubclass(fixed_type, Mapping) else ()
return args[1] if args else Any
def keys(self) -> ArrayVar[List[str]]:
@@ -87,7 +89,7 @@ class ObjectVar(Var[OBJECT_TYPE], python_types=dict):
@overload
def values(
- self: ObjectVar[Dict[Any, VALUE_TYPE]],
+ self: ObjectVar[Mapping[Any, VALUE_TYPE]],
) -> ArrayVar[List[VALUE_TYPE]]: ...
@overload
@@ -103,7 +105,7 @@ class ObjectVar(Var[OBJECT_TYPE], python_types=dict):
@overload
def entries(
- self: ObjectVar[Dict[Any, VALUE_TYPE]],
+ self: ObjectVar[Mapping[Any, VALUE_TYPE]],
) -> ArrayVar[List[Tuple[str, VALUE_TYPE]]]: ...
@overload
@@ -133,49 +135,55 @@ class ObjectVar(Var[OBJECT_TYPE], python_types=dict):
# NoReturn is used here to catch when key value is Any
@overload
def __getitem__(
- self: ObjectVar[Dict[Any, NoReturn]],
+ self: ObjectVar[Mapping[Any, NoReturn]],
key: Var | Any,
) -> Var: ...
+ @overload
+ def __getitem__(
+ self: (ObjectVar[Mapping[Any, bool]]),
+ key: Var | Any,
+ ) -> BooleanVar: ...
+
@overload
def __getitem__(
self: (
- ObjectVar[Dict[Any, int]]
- | ObjectVar[Dict[Any, float]]
- | ObjectVar[Dict[Any, int | float]]
+ ObjectVar[Mapping[Any, int]]
+ | ObjectVar[Mapping[Any, float]]
+ | ObjectVar[Mapping[Any, int | float]]
),
key: Var | Any,
) -> NumberVar: ...
@overload
def __getitem__(
- self: ObjectVar[Dict[Any, str]],
+ self: ObjectVar[Mapping[Any, str]],
key: Var | Any,
) -> StringVar: ...
@overload
def __getitem__(
- self: ObjectVar[Dict[Any, list[ARRAY_INNER_TYPE]]],
+ self: ObjectVar[Mapping[Any, list[ARRAY_INNER_TYPE]]],
key: Var | Any,
) -> ArrayVar[list[ARRAY_INNER_TYPE]]: ...
@overload
def __getitem__(
- self: ObjectVar[Dict[Any, set[ARRAY_INNER_TYPE]]],
+ self: ObjectVar[Mapping[Any, set[ARRAY_INNER_TYPE]]],
key: Var | Any,
) -> ArrayVar[set[ARRAY_INNER_TYPE]]: ...
@overload
def __getitem__(
- self: ObjectVar[Dict[Any, tuple[ARRAY_INNER_TYPE, ...]]],
+ self: ObjectVar[Mapping[Any, tuple[ARRAY_INNER_TYPE, ...]]],
key: Var | Any,
) -> ArrayVar[tuple[ARRAY_INNER_TYPE, ...]]: ...
@overload
def __getitem__(
- self: ObjectVar[Dict[Any, dict[OTHER_KEY_TYPE, VALUE_TYPE]]],
+ self: ObjectVar[Mapping[Any, Mapping[OTHER_KEY_TYPE, VALUE_TYPE]]],
key: Var | Any,
- ) -> ObjectVar[dict[OTHER_KEY_TYPE, VALUE_TYPE]]: ...
+ ) -> ObjectVar[Mapping[OTHER_KEY_TYPE, VALUE_TYPE]]: ...
def __getitem__(self, key: Var | Any) -> Var:
"""Get an item from the object.
@@ -195,49 +203,49 @@ class ObjectVar(Var[OBJECT_TYPE], python_types=dict):
# NoReturn is used here to catch when key value is Any
@overload
def __getattr__(
- self: ObjectVar[Dict[Any, NoReturn]],
+ self: ObjectVar[Mapping[Any, NoReturn]],
name: str,
) -> Var: ...
@overload
def __getattr__(
self: (
- ObjectVar[Dict[Any, int]]
- | ObjectVar[Dict[Any, float]]
- | ObjectVar[Dict[Any, int | float]]
+ ObjectVar[Mapping[Any, int]]
+ | ObjectVar[Mapping[Any, float]]
+ | ObjectVar[Mapping[Any, int | float]]
),
name: str,
) -> NumberVar: ...
@overload
def __getattr__(
- self: ObjectVar[Dict[Any, str]],
+ self: ObjectVar[Mapping[Any, str]],
name: str,
) -> StringVar: ...
@overload
def __getattr__(
- self: ObjectVar[Dict[Any, list[ARRAY_INNER_TYPE]]],
+ self: ObjectVar[Mapping[Any, list[ARRAY_INNER_TYPE]]],
name: str,
) -> ArrayVar[list[ARRAY_INNER_TYPE]]: ...
@overload
def __getattr__(
- self: ObjectVar[Dict[Any, set[ARRAY_INNER_TYPE]]],
+ self: ObjectVar[Mapping[Any, set[ARRAY_INNER_TYPE]]],
name: str,
) -> ArrayVar[set[ARRAY_INNER_TYPE]]: ...
@overload
def __getattr__(
- self: ObjectVar[Dict[Any, tuple[ARRAY_INNER_TYPE, ...]]],
+ self: ObjectVar[Mapping[Any, tuple[ARRAY_INNER_TYPE, ...]]],
name: str,
) -> ArrayVar[tuple[ARRAY_INNER_TYPE, ...]]: ...
@overload
def __getattr__(
- self: ObjectVar[Dict[Any, dict[OTHER_KEY_TYPE, VALUE_TYPE]]],
+ self: ObjectVar[Mapping[Any, Mapping[OTHER_KEY_TYPE, VALUE_TYPE]]],
name: str,
- ) -> ObjectVar[dict[OTHER_KEY_TYPE, VALUE_TYPE]]: ...
+ ) -> ObjectVar[Mapping[OTHER_KEY_TYPE, VALUE_TYPE]]: ...
@overload
def __getattr__(
@@ -266,8 +274,11 @@ class ObjectVar(Var[OBJECT_TYPE], python_types=dict):
var_type = get_args(var_type)[0]
fixed_type = var_type if isclass(var_type) else get_origin(var_type)
- if (isclass(fixed_type) and not issubclass(fixed_type, dict)) or (
- fixed_type in types.UnionTypes
+
+ if (
+ (isclass(fixed_type) and not issubclass(fixed_type, Mapping))
+ or (fixed_type in types.UnionTypes)
+ or is_typeddict(fixed_type)
):
attribute_type = get_attribute_access_type(var_type, name)
if attribute_type is None:
@@ -299,7 +310,7 @@ class ObjectVar(Var[OBJECT_TYPE], python_types=dict):
class LiteralObjectVar(CachedVarOperation, ObjectVar[OBJECT_TYPE], LiteralVar):
"""Base class for immutable literal object vars."""
- _var_value: Dict[Union[Var, Any], Union[Var, Any]] = dataclasses.field(
+ _var_value: Mapping[Union[Var, Any], Union[Var, Any]] = dataclasses.field(
default_factory=dict
)
@@ -383,7 +394,7 @@ class LiteralObjectVar(CachedVarOperation, ObjectVar[OBJECT_TYPE], LiteralVar):
@classmethod
def create(
cls,
- _var_value: dict,
+ _var_value: Mapping,
_var_type: Type[OBJECT_TYPE] | None = None,
_var_data: VarData | None = None,
) -> LiteralObjectVar[OBJECT_TYPE]:
@@ -466,7 +477,7 @@ def object_merge_operation(lhs: ObjectVar, rhs: ObjectVar):
"""
return var_operation_return(
js_expression=f"({{...{lhs}, ...{rhs}}})",
- var_type=Dict[
+ var_type=Mapping[
Union[lhs._key_type(), rhs._key_type()],
Union[lhs._value_type(), rhs._value_type()],
],
diff --git a/reflex/vars/sequence.py b/reflex/vars/sequence.py
index 5864e70b9..1b11f50e6 100644
--- a/reflex/vars/sequence.py
+++ b/reflex/vars/sequence.py
@@ -987,7 +987,7 @@ class ArrayVar(Var[ARRAY_VAR_TYPE], python_types=(list, tuple, set)):
raise_unsupported_operand_types("[]", (type(self), type(i)))
return array_item_operation(self, i)
- def length(self) -> NumberVar:
+ def length(self) -> NumberVar[int]:
"""Get the length of the array.
Returns:
diff --git a/tests/integration/test_computed_vars.py b/tests/integration/test_computed_vars.py
index 03aaf18b4..f56001ea8 100644
--- a/tests/integration/test_computed_vars.py
+++ b/tests/integration/test_computed_vars.py
@@ -22,22 +22,22 @@ def ComputedVars():
count: int = 0
# cached var with dep on count
- @rx.var(cache=True, interval=15)
+ @rx.var(interval=15)
def count1(self) -> int:
return self.count
# cached backend var with dep on count
- @rx.var(cache=True, interval=15, backend=True)
+ @rx.var(interval=15, backend=True)
def count1_backend(self) -> int:
return self.count
# same as above but implicit backend with `_` prefix
- @rx.var(cache=True, interval=15)
+ @rx.var(interval=15)
def _count1_backend(self) -> int:
return self.count
# explicit disabled auto_deps
- @rx.var(interval=15, cache=True, auto_deps=False)
+ @rx.var(interval=15, auto_deps=False)
def count3(self) -> int:
# this will not add deps, because auto_deps is False
print(self.count1)
@@ -45,19 +45,27 @@ def ComputedVars():
return self.count
# explicit dependency on count var
- @rx.var(cache=True, deps=["count"], auto_deps=False)
+ @rx.var(deps=["count"], auto_deps=False)
def depends_on_count(self) -> int:
return self.count
# explicit dependency on count1 var
- @rx.var(cache=True, deps=[count1], auto_deps=False)
+ @rx.var(deps=[count1], auto_deps=False)
def depends_on_count1(self) -> int:
return self.count
- @rx.var(deps=[count3], auto_deps=False, cache=True)
+ @rx.var(
+ deps=[count3],
+ auto_deps=False,
+ )
def depends_on_count3(self) -> int:
return self.count
+ # special floats should be properly decoded on the frontend
+ @rx.var(cache=True, initial_value=[])
+ def special_floats(self) -> list[float]:
+ return [42.9, float("nan"), float("inf"), float("-inf")]
+
@rx.event
def increment(self):
self.count += 1
@@ -103,6 +111,11 @@ def ComputedVars():
State.depends_on_count3,
id="depends_on_count3",
),
+ rx.text("special_floats:"),
+ rx.text(
+ State.special_floats.join(", "),
+ id="special_floats",
+ ),
),
)
@@ -224,6 +237,10 @@ async def test_computed_vars(
assert depends_on_count3
assert depends_on_count3.text == "0"
+ special_floats = driver.find_element(By.ID, "special_floats")
+ assert special_floats
+ assert special_floats.text == "42.9, NaN, Infinity, -Infinity"
+
increment = driver.find_element(By.ID, "increment")
assert increment.is_enabled()
diff --git a/tests/integration/test_connection_banner.py b/tests/integration/test_connection_banner.py
index 735dbd243..c2a912af6 100644
--- a/tests/integration/test_connection_banner.py
+++ b/tests/integration/test_connection_banner.py
@@ -71,9 +71,10 @@ def has_error_modal(driver: WebDriver) -> bool:
"""
try:
driver.find_element(By.XPATH, CONNECTION_ERROR_XPATH)
- return True
except NoSuchElementException:
return False
+ else:
+ return True
@pytest.mark.asyncio
diff --git a/tests/integration/test_dynamic_routes.py b/tests/integration/test_dynamic_routes.py
index 50d9f23b1..9032fd84c 100644
--- a/tests/integration/test_dynamic_routes.py
+++ b/tests/integration/test_dynamic_routes.py
@@ -74,16 +74,16 @@ def DynamicRoute():
class ArgState(rx.State):
"""The app state."""
- @rx.var
+ @rx.var(cache=False)
def arg(self) -> int:
return int(self.arg_str or 0)
class ArgSubState(ArgState):
- @rx.var(cache=True)
+ @rx.var
def cached_arg(self) -> int:
return self.arg
- @rx.var(cache=True)
+ @rx.var
def cached_arg_str(self) -> str:
return self.arg_str
diff --git a/tests/integration/test_lifespan.py b/tests/integration/test_lifespan.py
index 0fa4a7e92..d79273fbc 100644
--- a/tests/integration/test_lifespan.py
+++ b/tests/integration/test_lifespan.py
@@ -36,7 +36,7 @@ def LifespanApp():
print("Lifespan global started.")
try:
while True:
- lifespan_task_global += inc # pyright: ignore[reportUnboundVariable]
+ lifespan_task_global += inc # pyright: ignore[reportUnboundVariable, reportPossiblyUnboundVariable]
await asyncio.sleep(0.1)
except asyncio.CancelledError as ce:
print(f"Lifespan global cancelled: {ce}.")
@@ -45,11 +45,11 @@ def LifespanApp():
class LifespanState(rx.State):
interval: int = 100
- @rx.var
+ @rx.var(cache=False)
def task_global(self) -> int:
return lifespan_task_global
- @rx.var
+ @rx.var(cache=False)
def context_global(self) -> int:
return lifespan_context_global
diff --git a/tests/integration/test_media.py b/tests/integration/test_media.py
index 10af26591..649038a7e 100644
--- a/tests/integration/test_media.py
+++ b/tests/integration/test_media.py
@@ -22,31 +22,31 @@ def MediaApp():
img.format = format # type: ignore
return img
- @rx.var(cache=True)
+ @rx.var
def img_default(self) -> Image.Image:
return self._blue()
- @rx.var(cache=True)
+ @rx.var
def img_bmp(self) -> Image.Image:
return self._blue(format="BMP")
- @rx.var(cache=True)
+ @rx.var
def img_jpg(self) -> Image.Image:
return self._blue(format="JPEG")
- @rx.var(cache=True)
+ @rx.var
def img_png(self) -> Image.Image:
return self._blue(format="PNG")
- @rx.var(cache=True)
+ @rx.var
def img_gif(self) -> Image.Image:
return self._blue(format="GIF")
- @rx.var(cache=True)
+ @rx.var
def img_webp(self) -> Image.Image:
return self._blue(format="WEBP")
- @rx.var(cache=True)
+ @rx.var
def img_from_url(self) -> Image.Image:
img_url = "https://picsum.photos/id/1/200/300"
img_resp = httpx.get(img_url, follow_redirects=True)
diff --git a/tests/units/components/core/test_match.py b/tests/units/components/core/test_match.py
index f09e800e5..eaf98414d 100644
--- a/tests/units/components/core/test_match.py
+++ b/tests/units/components/core/test_match.py
@@ -1,4 +1,4 @@
-from typing import Dict, List, Tuple
+from typing import List, Mapping, Tuple
import pytest
@@ -67,7 +67,7 @@ def test_match_components():
assert fourth_return_value_render["children"][0]["contents"] == '{"fourth value"}'
assert match_cases[4][0]._js_expr == '({ ["foo"] : "bar" })'
- assert match_cases[4][0]._var_type == Dict[str, str]
+ assert match_cases[4][0]._var_type == Mapping[str, str]
fifth_return_value_render = match_cases[4][1].render()
assert fifth_return_value_render["name"] == "RadixThemesText"
assert fifth_return_value_render["children"][0]["contents"] == '{"fifth value"}'
diff --git a/tests/units/components/test_component.py b/tests/units/components/test_component.py
index 674873b69..6396e4322 100644
--- a/tests/units/components/test_component.py
+++ b/tests/units/components/test_component.py
@@ -27,7 +27,7 @@ from reflex.event import (
from reflex.state import BaseState
from reflex.style import Style
from reflex.utils import imports
-from reflex.utils.exceptions import EventFnArgMismatch
+from reflex.utils.exceptions import ChildrenTypeError, EventFnArgMismatch
from reflex.utils.imports import ImportDict, ImportVar, ParsedImportDict, parse_imports
from reflex.vars import VarData
from reflex.vars.base import LiteralVar, Var
@@ -645,14 +645,17 @@ def test_create_filters_none_props(test_component):
assert str(component.style["text-align"]) == '"center"'
-@pytest.mark.parametrize("children", [((None,),), ("foo", ("bar", (None,)))])
+@pytest.mark.parametrize(
+ "children",
+ [
+ ((None,),),
+ ("foo", ("bar", (None,))),
+ ({"foo": "bar"},),
+ ],
+)
def test_component_create_unallowed_types(children, test_component):
- with pytest.raises(TypeError) as err:
+ with pytest.raises(ChildrenTypeError):
test_component.create(*children)
- assert (
- err.value.args[0]
- == "Children of Reflex components must be other components, state vars, or primitive Python types. Got child None of type ."
- )
@pytest.mark.parametrize(
diff --git a/tests/units/test_app.py b/tests/units/test_app.py
index a09fde972..80e0be5fd 100644
--- a/tests/units/test_app.py
+++ b/tests/units/test_app.py
@@ -908,7 +908,7 @@ class DynamicState(BaseState):
"""Increment the counter var."""
self.counter = self.counter + 1
- @computed_var(cache=True)
+ @computed_var
def comp_dynamic(self) -> str:
"""A computed var that depends on the dynamic var.
@@ -1549,11 +1549,11 @@ def test_app_with_valid_var_dependencies(compilable_app: tuple[App, Path]):
base: int = 0
_backend: int = 0
- @computed_var(cache=True)
+ @computed_var()
def foo(self) -> str:
return "foo"
- @computed_var(deps=["_backend", "base", foo], cache=True)
+ @computed_var(deps=["_backend", "base", foo])
def bar(self) -> str:
return "bar"
@@ -1565,7 +1565,7 @@ def test_app_with_invalid_var_dependencies(compilable_app: tuple[App, Path]):
app, _ = compilable_app
class InvalidDepState(BaseState):
- @computed_var(deps=["foolksjdf"], cache=True)
+ @computed_var(deps=["foolksjdf"])
def bar(self) -> str:
return "bar"
diff --git a/tests/units/test_state.py b/tests/units/test_state.py
index cf3363770..05633b75f 100644
--- a/tests/units/test_state.py
+++ b/tests/units/test_state.py
@@ -202,7 +202,7 @@ class GrandchildState(ChildState):
class GrandchildState2(ChildState2):
"""A grandchild state fixture."""
- @rx.var(cache=True)
+ @rx.var
def cached(self) -> str:
"""A cached var.
@@ -215,7 +215,7 @@ class GrandchildState2(ChildState2):
class GrandchildState3(ChildState3):
"""A great grandchild state fixture."""
- @rx.var
+ @rx.var(cache=False)
def computed(self) -> str:
"""A computed var.
@@ -796,7 +796,7 @@ async def test_process_event_simple(test_state):
# The delta should contain the changes, including computed vars.
assert update.delta == {
- TestState.get_full_name(): {"num1": 69, "sum": 72.14, "upper": ""},
+ TestState.get_full_name(): {"num1": 69, "sum": 72.14},
GrandchildState3.get_full_name(): {"computed": ""},
}
assert update.events == []
@@ -823,7 +823,7 @@ async def test_process_event_substate(test_state, child_state, grandchild_state)
assert child_state.value == "HI"
assert child_state.count == 24
assert update.delta == {
- TestState.get_full_name(): {"sum": 3.14, "upper": ""},
+ # TestState.get_full_name(): {"sum": 3.14, "upper": ""},
ChildState.get_full_name(): {"value": "HI", "count": 24},
GrandchildState3.get_full_name(): {"computed": ""},
}
@@ -839,7 +839,7 @@ async def test_process_event_substate(test_state, child_state, grandchild_state)
update = await test_state._process(event).__anext__()
assert grandchild_state.value2 == "new"
assert update.delta == {
- TestState.get_full_name(): {"sum": 3.14, "upper": ""},
+ # TestState.get_full_name(): {"sum": 3.14, "upper": ""},
GrandchildState.get_full_name(): {"value2": "new"},
GrandchildState3.get_full_name(): {"computed": ""},
}
@@ -989,7 +989,7 @@ class InterdependentState(BaseState):
v1: int = 0
_v2: int = 1
- @rx.var(cache=True)
+ @rx.var
def v1x2(self) -> int:
"""Depends on var v1.
@@ -998,7 +998,7 @@ class InterdependentState(BaseState):
"""
return self.v1 * 2
- @rx.var(cache=True)
+ @rx.var
def v2x2(self) -> int:
"""Depends on backend var _v2.
@@ -1007,7 +1007,7 @@ class InterdependentState(BaseState):
"""
return self._v2 * 2
- @rx.var(cache=True, backend=True)
+ @rx.var(backend=True)
def v2x2_backend(self) -> int:
"""Depends on backend var _v2.
@@ -1016,7 +1016,7 @@ class InterdependentState(BaseState):
"""
return self._v2 * 2
- @rx.var(cache=True)
+ @rx.var
def v1x2x2(self) -> int:
"""Depends on ComputedVar v1x2.
@@ -1025,7 +1025,7 @@ class InterdependentState(BaseState):
"""
return self.v1x2 * 2 # type: ignore
- @rx.var(cache=True)
+ @rx.var
def _v3(self) -> int:
"""Depends on backend var _v2.
@@ -1034,7 +1034,7 @@ class InterdependentState(BaseState):
"""
return self._v2
- @rx.var(cache=True)
+ @rx.var
def v3x2(self) -> int:
"""Depends on ComputedVar _v3.
@@ -1239,7 +1239,7 @@ def test_computed_var_cached():
class ComputedState(BaseState):
v: int = 0
- @rx.var(cache=True)
+ @rx.var
def comp_v(self) -> int:
nonlocal comp_v_calls
comp_v_calls += 1
@@ -1264,15 +1264,15 @@ def test_computed_var_cached_depends_on_non_cached():
class ComputedState(BaseState):
v: int = 0
- @rx.var
+ @rx.var(cache=False)
def no_cache_v(self) -> int:
return self.v
- @rx.var(cache=True)
+ @rx.var
def dep_v(self) -> int:
return self.no_cache_v # type: ignore
- @rx.var(cache=True)
+ @rx.var
def comp_v(self) -> int:
return self.v
@@ -1304,14 +1304,14 @@ def test_computed_var_depends_on_parent_non_cached():
counter = 0
class ParentState(BaseState):
- @rx.var
+ @rx.var(cache=False)
def no_cache_v(self) -> int:
nonlocal counter
counter += 1
return counter
class ChildState(ParentState):
- @rx.var(cache=True)
+ @rx.var
def dep_v(self) -> int:
return self.no_cache_v # type: ignore
@@ -1357,7 +1357,7 @@ def test_cached_var_depends_on_event_handler(use_partial: bool):
def handler(self):
self.x = self.x + 1
- @rx.var(cache=True)
+ @rx.var
def cached_x_side_effect(self) -> int:
self.handler()
nonlocal counter
@@ -1393,7 +1393,7 @@ def test_computed_var_dependencies():
def testprop(self) -> int:
return self.v
- @rx.var(cache=True)
+ @rx.var
def comp_v(self) -> int:
"""Direct access.
@@ -1402,7 +1402,7 @@ def test_computed_var_dependencies():
"""
return self.v
- @rx.var(cache=True, backend=True)
+ @rx.var(backend=True)
def comp_v_backend(self) -> int:
"""Direct access backend var.
@@ -1411,7 +1411,7 @@ def test_computed_var_dependencies():
"""
return self.v
- @rx.var(cache=True)
+ @rx.var
def comp_v_via_property(self) -> int:
"""Access v via property.
@@ -1420,7 +1420,7 @@ def test_computed_var_dependencies():
"""
return self.testprop
- @rx.var(cache=True)
+ @rx.var
def comp_w(self):
"""Nested lambda.
@@ -1429,7 +1429,7 @@ def test_computed_var_dependencies():
"""
return lambda: self.w
- @rx.var(cache=True)
+ @rx.var
def comp_x(self):
"""Nested function.
@@ -1442,7 +1442,7 @@ def test_computed_var_dependencies():
return _
- @rx.var(cache=True)
+ @rx.var
def comp_y(self) -> List[int]:
"""Comprehension iterating over attribute.
@@ -1451,7 +1451,7 @@ def test_computed_var_dependencies():
"""
return [round(y) for y in self.y]
- @rx.var(cache=True)
+ @rx.var
def comp_z(self) -> List[bool]:
"""Comprehension accesses attribute.
@@ -2027,10 +2027,6 @@ async def test_state_proxy(grandchild_state: GrandchildState, mock_app: rx.App):
assert mcall.args[0] == str(SocketEvent.EVENT)
assert mcall.args[1] == StateUpdate(
delta={
- parent_state.get_full_name(): {
- "upper": "",
- "sum": 3.14,
- },
grandchild_state.get_full_name(): {
"value2": "42",
},
@@ -2053,7 +2049,7 @@ class BackgroundTaskState(BaseState):
super().__init__(**kwargs)
self.router_data = {"simulate": "hydrate"}
- @rx.var
+ @rx.var(cache=False)
def computed_order(self) -> List[str]:
"""Get the order as a computed var.
@@ -3040,10 +3036,6 @@ async def test_get_state(mock_app: rx.App, token: str):
grandchild_state.value2 = "set_value"
assert test_state.get_delta() == {
- TestState.get_full_name(): {
- "sum": 3.14,
- "upper": "",
- },
GrandchildState.get_full_name(): {
"value2": "set_value",
},
@@ -3081,10 +3073,6 @@ async def test_get_state(mock_app: rx.App, token: str):
child_state2.value = "set_c2_value"
assert new_test_state.get_delta() == {
- TestState.get_full_name(): {
- "sum": 3.14,
- "upper": "",
- },
ChildState2.get_full_name(): {
"value": "set_c2_value",
},
@@ -3139,7 +3127,7 @@ async def test_get_state_from_sibling_not_cached(mock_app: rx.App, token: str):
child3_var: int = 0
- @rx.var
+ @rx.var(cache=False)
def v(self):
pass
@@ -3210,8 +3198,8 @@ def test_potentially_dirty_substates():
def bar(self) -> str:
return ""
- assert RxState._potentially_dirty_substates() == {State}
- assert State._potentially_dirty_substates() == {C1}
+ assert RxState._potentially_dirty_substates() == set()
+ assert State._potentially_dirty_substates() == set()
assert C1._potentially_dirty_substates() == set()
@@ -3226,7 +3214,7 @@ def test_router_var_dep() -> None:
class RouterVarDepState(RouterVarParentState):
"""A state with a router var dependency."""
- @rx.var(cache=True)
+ @rx.var
def foo(self) -> str:
return self.router.page.params.get("foo", "")
@@ -3421,7 +3409,7 @@ class MixinState(State, mixin=True):
_backend: int = 0
_backend_no_default: dict
- @rx.var(cache=True)
+ @rx.var
def computed(self) -> str:
"""A computed var on mixin state.
diff --git a/tests/units/test_state_tree.py b/tests/units/test_state_tree.py
index 44ff58818..70ef71cb8 100644
--- a/tests/units/test_state_tree.py
+++ b/tests/units/test_state_tree.py
@@ -42,7 +42,7 @@ class SubA_A_A_A(SubA_A_A):
class SubA_A_A_B(SubA_A_A):
"""SubA_A_A_B is a child of SubA_A_A."""
- @rx.var(cache=True)
+ @rx.var
def sub_a_a_a_cached(self) -> int:
"""A cached var.
@@ -117,7 +117,7 @@ class TreeD(Root):
d: int
- @rx.var
+ @rx.var(cache=False)
def d_var(self) -> int:
"""A computed var.
@@ -156,7 +156,7 @@ class SubE_A_A_A_A(SubE_A_A_A):
sub_e_a_a_a_a: int
- @rx.var
+ @rx.var(cache=False)
def sub_e_a_a_a_a_var(self) -> int:
"""A computed var.
@@ -183,7 +183,7 @@ class SubE_A_A_A_D(SubE_A_A_A):
sub_e_a_a_a_d: int
- @rx.var(cache=True)
+ @rx.var
def sub_e_a_a_a_d_var(self) -> int:
"""A computed var.
diff --git a/tests/units/test_style.py b/tests/units/test_style.py
index e1d652798..bb585fd22 100644
--- a/tests/units/test_style.py
+++ b/tests/units/test_style.py
@@ -1,6 +1,6 @@
from __future__ import annotations
-from typing import Any, Dict
+from typing import Any, Mapping
import pytest
@@ -379,7 +379,7 @@ class StyleState(rx.State):
{
"css": Var(
_js_expr=f'({{ ["color"] : ("dark"+{StyleState.color}) }})'
- ).to(Dict[str, str])
+ ).to(Mapping[str, str])
},
),
(
diff --git a/tests/units/test_var.py b/tests/units/test_var.py
index bfa8aa35a..a8e9cd88c 100644
--- a/tests/units/test_var.py
+++ b/tests/units/test_var.py
@@ -2,7 +2,7 @@ import json
import math
import sys
import typing
-from typing import Dict, List, Optional, Set, Tuple, Union, cast
+from typing import Dict, List, Mapping, Optional, Set, Tuple, Union, cast
import pytest
from pandas import DataFrame
@@ -270,7 +270,7 @@ def test_get_setter(prop: Var, expected):
([1, 2, 3], Var(_js_expr="[1, 2, 3]", _var_type=List[int])),
(
{"a": 1, "b": 2},
- Var(_js_expr='({ ["a"] : 1, ["b"] : 2 })', _var_type=Dict[str, int]),
+ Var(_js_expr='({ ["a"] : 1, ["b"] : 2 })', _var_type=Mapping[str, int]),
),
],
)
@@ -1004,7 +1004,7 @@ def test_all_number_operations():
assert (
str(even_more_complicated_number)
- == "!(((Math.abs(Math.floor(((Math.floor(((-((-5.4 + 1)) * 2) / 3) / 2) % 3) ** 2))) || (2 && Math.round(((Math.floor(((-((-5.4 + 1)) * 2) / 3) / 2) % 3) ** 2)))) !== 0))"
+ == "!(isTrue((Math.abs(Math.floor(((Math.floor(((-((-5.4 + 1)) * 2) / 3) / 2) % 3) ** 2))) || (2 && Math.round(((Math.floor(((-((-5.4 + 1)) * 2) / 3) / 2) % 3) ** 2))))))"
)
assert str(LiteralNumberVar.create(5) > False) == "(5 > 0)"
@@ -1814,10 +1814,7 @@ def cv_fget(state: BaseState) -> int:
],
)
def test_computed_var_deps(deps: List[Union[str, Var]], expected: Set[str]):
- @computed_var(
- deps=deps,
- cache=True,
- )
+ @computed_var(deps=deps)
def test_var(state) -> int:
return 1
@@ -1835,10 +1832,7 @@ def test_computed_var_deps(deps: List[Union[str, Var]], expected: Set[str]):
def test_invalid_computed_var_deps(deps: List):
with pytest.raises(TypeError):
- @computed_var(
- deps=deps,
- cache=True,
- )
+ @computed_var(deps=deps)
def test_var(state) -> int:
return 1
diff --git a/tests/units/vars/test_base.py b/tests/units/vars/test_base.py
index 68bc0c38e..e4ae7327a 100644
--- a/tests/units/vars/test_base.py
+++ b/tests/units/vars/test_base.py
@@ -1,4 +1,4 @@
-from typing import Dict, List, Union
+from typing import List, Mapping, Union
import pytest
@@ -37,12 +37,12 @@ class ChildGenericDict(GenericDict):
("a", str),
([1, 2, 3], List[int]),
([1, 2.0, "a"], List[Union[int, float, str]]),
- ({"a": 1, "b": 2}, Dict[str, int]),
- ({"a": 1, 2: "b"}, Dict[Union[int, str], Union[str, int]]),
+ ({"a": 1, "b": 2}, Mapping[str, int]),
+ ({"a": 1, 2: "b"}, Mapping[Union[int, str], Union[str, int]]),
(CustomDict(), CustomDict),
(ChildCustomDict(), ChildCustomDict),
- (GenericDict({1: 1}), Dict[int, int]),
- (ChildGenericDict({1: 1}), Dict[int, int]),
+ (GenericDict({1: 1}), Mapping[int, int]),
+ (ChildGenericDict({1: 1}), Mapping[int, int]),
],
)
def test_figure_out_type(value, expected):