Merge branch 'main' into maintenance/removing_deprecated_features

This commit is contained in:
Lendemor 2025-01-21 23:39:28 +01:00
commit 85b6b446e5
31 changed files with 488 additions and 174 deletions

View File

@ -33,7 +33,7 @@ env:
PR_TITLE: ${{ github.event.pull_request.title }} PR_TITLE: ${{ github.event.pull_request.title }}
jobs: jobs:
example-counter: example-counter-and-nba-proxy:
env: env:
OUTPUT_FILE: import_benchmark.json OUTPUT_FILE: import_benchmark.json
timeout-minutes: 30 timeout-minutes: 30
@ -114,6 +114,26 @@ jobs:
--benchmark-json "./reflex-examples/counter/${{ env.OUTPUT_FILE }}" --benchmark-json "./reflex-examples/counter/${{ env.OUTPUT_FILE }}"
--branch-name "${{ github.head_ref || github.ref_name }}" --pr-id "${{ github.event.pull_request.id }}" --branch-name "${{ github.head_ref || github.ref_name }}" --pr-id "${{ github.event.pull_request.id }}"
--app-name "counter" --app-name "counter"
- name: Install requirements for nba proxy example
working-directory: ./reflex-examples/nba-proxy
run: |
poetry run uv pip install -r requirements.txt
- name: Install additional dependencies for DB access
run: poetry run uv pip install psycopg
- name: Check export --backend-only before init for nba-proxy example
working-directory: ./reflex-examples/nba-proxy
run: |
poetry run reflex export --backend-only
- name: Init Website for nba-proxy example
working-directory: ./reflex-examples/nba-proxy
run: |
poetry run reflex init --loglevel debug
- name: Run Website and Check for errors
run: |
# Check that npm is home
npm -v
poetry run bash scripts/integration.sh ./reflex-examples/nba-proxy dev
reflex-web: reflex-web:
strategy: strategy:

View File

@ -85,8 +85,8 @@ build-backend = "poetry.core.masonry.api"
target-version = "py39" target-version = "py39"
output-format = "concise" output-format = "concise"
lint.isort.split-on-trailing-comma = false lint.isort.split-on-trailing-comma = false
lint.select = ["B", "C4", "D", "E", "ERA", "F", "FURB", "I", "PERF", "PTH", "RUF", "SIM", "T", "W"] lint.select = ["B", "C4", "D", "E", "ERA", "F", "FURB", "I", "PERF", "PTH", "RUF", "SIM", "T", "TRY", "W"]
lint.ignore = ["B008", "D205", "E501", "F403", "SIM115", "RUF006", "RUF012"] lint.ignore = ["B008", "D205", "E501", "F403", "SIM115", "RUF006", "RUF012", "TRY0"]
lint.pydocstyle.convention = "google" lint.pydocstyle.convention = "google"
[tool.ruff.lint.per-file-ignores] [tool.ruff.lint.per-file-ignores]

View File

@ -3,6 +3,7 @@ import axios from "axios";
import io from "socket.io-client"; import io from "socket.io-client";
import JSON5 from "json5"; import JSON5 from "json5";
import env from "$/env.json"; import env from "$/env.json";
import reflexEnvironment from "$/reflex.json";
import Cookies from "universal-cookie"; import Cookies from "universal-cookie";
import { useEffect, useRef, useState } from "react"; import { useEffect, useRef, useState } from "react";
import Router, { useRouter } from "next/router"; import Router, { useRouter } from "next/router";
@ -407,6 +408,7 @@ export const connect = async (
socket.current = io(endpoint.href, { socket.current = io(endpoint.href, {
path: endpoint["pathname"], path: endpoint["pathname"],
transports: transports, transports: transports,
protocols: env.TEST_MODE ? undefined : [reflexEnvironment.version],
autoUnref: false, autoUnref: false,
}); });
// Ensure undefined fields in events are sent as null instead of removed // Ensure undefined fields in events are sent as null instead of removed

View File

@ -463,14 +463,8 @@ class App(MiddlewareMixin, LifespanMixin):
Returns: Returns:
The generated component. The generated component.
Raises:
exceptions.MatchTypeError: If the return types of match cases in rx.match are different.
""" """
try: return component if isinstance(component, Component) else component()
return component if isinstance(component, Component) else component()
except exceptions.MatchTypeError:
raise
def add_page( def add_page(
self, self,
@ -564,11 +558,12 @@ class App(MiddlewareMixin, LifespanMixin):
meta=meta, meta=meta,
) )
def _compile_page(self, route: str): def _compile_page(self, route: str, save_page: bool = True):
"""Compile a page. """Compile a page.
Args: Args:
route: The route of the page to compile. route: The route of the page to compile.
save_page: If True, the compiled page is saved to self.pages.
""" """
component, enable_state = compiler.compile_unevaluated_page( component, enable_state = compiler.compile_unevaluated_page(
route, self.unevaluated_pages[route], self.state, self.style, self.theme route, self.unevaluated_pages[route], self.state, self.style, self.theme
@ -579,7 +574,8 @@ class App(MiddlewareMixin, LifespanMixin):
# Add the page. # Add the page.
self._check_routes_conflict(route) self._check_routes_conflict(route)
self.pages[route] = component if save_page:
self.pages[route] = component
def get_load_events(self, route: str) -> list[IndividualEventType[[], Any]]: def get_load_events(self, route: str) -> list[IndividualEventType[[], Any]]:
"""Get the load events for a route. """Get the load events for a route.
@ -879,14 +875,16 @@ class App(MiddlewareMixin, LifespanMixin):
# If a theme component was provided, wrap the app with it # If a theme component was provided, wrap the app with it
app_wrappers[(20, "Theme")] = self.theme app_wrappers[(20, "Theme")] = self.theme
should_compile = self._should_compile()
for route in self.unevaluated_pages: for route in self.unevaluated_pages:
console.debug(f"Evaluating page: {route}") console.debug(f"Evaluating page: {route}")
self._compile_page(route) self._compile_page(route, save_page=should_compile)
# Add the optional endpoints (_upload) # Add the optional endpoints (_upload)
self._add_optional_endpoints() self._add_optional_endpoints()
if not self._should_compile(): if not should_compile:
return return
self._validate_var_dependencies() self._validate_var_dependencies()
@ -1530,7 +1528,11 @@ class EventNamespace(AsyncNamespace):
sid: The Socket.IO session id. sid: The Socket.IO session id.
environ: The request information, including HTTP headers. environ: The request information, including HTTP headers.
""" """
pass subprotocol = environ.get("HTTP_SEC_WEBSOCKET_PROTOCOL", None)
if subprotocol and subprotocol != constants.Reflex.VERSION:
console.warn(
f"Frontend version {subprotocol} for session {sid} does not match the backend version {constants.Reflex.VERSION}."
)
def on_disconnect(self, sid): def on_disconnect(self, sid):
"""Event for when the websocket disconnects. """Event for when the websocket disconnects.
@ -1563,10 +1565,36 @@ class EventNamespace(AsyncNamespace):
Args: Args:
sid: The Socket.IO session id. sid: The Socket.IO session id.
data: The event data. data: The event data.
Raises:
EventDeserializationError: If the event data is not a dictionary.
""" """
fields = data fields = data
# Get the event.
event = Event(**{k: v for k, v in fields.items() if k in _EVENT_FIELDS}) if isinstance(fields, str):
console.warn(
"Received event data as a string. This generally should not happen and may indicate a bug."
f" Event data: {fields}"
)
try:
fields = json.loads(fields)
except json.JSONDecodeError as ex:
raise exceptions.EventDeserializationError(
f"Failed to deserialize event data: {fields}."
) from ex
if not isinstance(fields, dict):
raise exceptions.EventDeserializationError(
f"Event data must be a dictionary, but received {fields} of type {type(fields)}."
)
try:
# Get the event.
event = Event(**{k: v for k, v in fields.items() if k in _EVENT_FIELDS})
except (TypeError, ValueError) as ex:
raise exceptions.EventDeserializationError(
f"Failed to deserialize event data: {fields}."
) from ex
self.token_to_sid[event.token] = sid self.token_to_sid[event.token] = sid
self.sid_to_token[sid] = event.token self.sid_to_token[sid] = event.token

View File

@ -7,14 +7,13 @@ from concurrent.futures import ThreadPoolExecutor
from reflex import constants from reflex import constants
from reflex.utils import telemetry from reflex.utils import telemetry
from reflex.utils.exec import is_prod_mode from reflex.utils.exec import is_prod_mode
from reflex.utils.prerequisites import get_app from reflex.utils.prerequisites import get_and_validate_app
if constants.CompileVars.APP != "app": if constants.CompileVars.APP != "app":
raise AssertionError("unexpected variable name for 'app'") raise AssertionError("unexpected variable name for 'app'")
telemetry.send("compile") telemetry.send("compile")
app_module = get_app(reload=False) app, app_module = get_and_validate_app(reload=False)
app = getattr(app_module, constants.CompileVars.APP)
# For py3.9 compatibility when redis is used, we MUST add any decorator pages # For py3.9 compatibility when redis is used, we MUST add any decorator pages
# before compiling the app in a thread to avoid event loop error (REF-2172). # before compiling the app in a thread to avoid event loop error (REF-2172).
app._apply_decorated_pages() app._apply_decorated_pages()
@ -30,7 +29,7 @@ if is_prod_mode():
# ensure only "app" is exposed. # ensure only "app" is exposed.
del app_module del app_module
del compile_future del compile_future
del get_app del get_and_validate_app
del is_prod_mode del is_prod_mode
del telemetry del telemetry
del constants del constants

View File

@ -426,20 +426,22 @@ class Component(BaseComponent, ABC):
else: else:
continue continue
def determine_key(value):
# Try to create a var from the value
key = value if isinstance(value, Var) else LiteralVar.create(value)
# Check that the var type is not None.
if key is None:
raise TypeError
return key
# Check whether the key is a component prop. # Check whether the key is a component prop.
if types._issubclass(field_type, Var): if types._issubclass(field_type, Var):
# Used to store the passed types if var type is a union. # Used to store the passed types if var type is a union.
passed_types = None passed_types = None
try: try:
# Try to create a var from the value. kwargs[key] = determine_key(value)
if isinstance(value, Var):
kwargs[key] = value
else:
kwargs[key] = LiteralVar.create(value)
# Check that the var type is not None.
if kwargs[key] is None:
raise TypeError
expected_type = fields[key].outer_type_.__args__[0] expected_type = fields[key].outer_type_.__args__[0]
# validate literal fields. # validate literal fields.
@ -702,22 +704,21 @@ class Component(BaseComponent, ABC):
# Import here to avoid circular imports. # Import here to avoid circular imports.
from reflex.components.base.bare import Bare from reflex.components.base.bare import Bare
from reflex.components.base.fragment import Fragment from reflex.components.base.fragment import Fragment
from reflex.utils.exceptions import ComponentTypeError from reflex.utils.exceptions import ChildrenTypeError
# Filter out None props # Filter out None props
props = {key: value for key, value in props.items() if value is not None} props = {key: value for key, value in props.items() if value is not None}
def validate_children(children): def validate_children(children):
for child in children: for child in children:
if isinstance(child, tuple): if isinstance(child, (tuple, list)):
validate_children(child) validate_children(child)
# Make sure the child is a valid type. # Make sure the child is a valid type.
if not types._isinstance(child, ComponentChild): if isinstance(child, dict) or not types._isinstance(
raise ComponentTypeError( child, ComponentChild
"Children of Reflex components must be other components, " ):
"state vars, or primitive Python types. " raise ChildrenTypeError(component=cls.__name__, child=child)
f"Got child {child} of type {type(child)}.",
)
# Validate all the children. # Validate all the children.
validate_children(children) validate_children(children)

View File

@ -2,13 +2,15 @@
from reflex.components.component import Component from reflex.components.component import Component
from reflex.utils import format from reflex.utils import format
from reflex.vars.base import Var from reflex.utils.imports import ImportVar
from reflex.vars.base import LiteralVar, Var
from reflex.vars.sequence import LiteralStringVar
class LucideIconComponent(Component): class LucideIconComponent(Component):
"""Lucide Icon Component.""" """Lucide Icon Component."""
library = "lucide-react@0.469.0" library = "lucide-react@0.471.1"
class Icon(LucideIconComponent): class Icon(LucideIconComponent):
@ -32,6 +34,7 @@ class Icon(LucideIconComponent):
Raises: Raises:
AttributeError: The errors tied to bad usage of the Icon component. AttributeError: The errors tied to bad usage of the Icon component.
ValueError: If the icon tag is invalid. ValueError: If the icon tag is invalid.
TypeError: If the icon name is not a string.
Returns: Returns:
The created component. The created component.
@ -39,7 +42,6 @@ class Icon(LucideIconComponent):
if children: if children:
if len(children) == 1 and isinstance(children[0], str): if len(children) == 1 and isinstance(children[0], str):
props["tag"] = children[0] props["tag"] = children[0]
children = []
else: else:
raise AttributeError( raise AttributeError(
f"Passing multiple children to Icon component is not allowed: remove positional arguments {children[1:]} to fix" f"Passing multiple children to Icon component is not allowed: remove positional arguments {children[1:]} to fix"
@ -47,24 +49,46 @@ class Icon(LucideIconComponent):
if "tag" not in props: if "tag" not in props:
raise AttributeError("Missing 'tag' keyword-argument for Icon") raise AttributeError("Missing 'tag' keyword-argument for Icon")
tag: str | Var | LiteralVar = props.pop("tag")
if isinstance(tag, LiteralVar):
if isinstance(tag, LiteralStringVar):
tag = tag._var_value
else:
raise TypeError(f"Icon name must be a string, got {type(tag)}")
elif isinstance(tag, Var):
return DynamicIcon.create(name=tag, **props)
if ( if (
not isinstance(props["tag"], str) not isinstance(tag, str)
or format.to_snake_case(props["tag"]) not in LUCIDE_ICON_LIST or format.to_snake_case(tag) not in LUCIDE_ICON_LIST
): ):
raise ValueError( raise ValueError(
f"Invalid icon tag: {props['tag']}. Please use one of the following: {', '.join(LUCIDE_ICON_LIST[0:25])}, ..." f"Invalid icon tag: {tag}. Please use one of the following: {', '.join(LUCIDE_ICON_LIST[0:25])}, ..."
"\nSee full list at https://lucide.dev/icons." "\nSee full list at https://lucide.dev/icons."
) )
if props["tag"] in LUCIDE_ICON_MAPPING_OVERRIDE: if tag in LUCIDE_ICON_MAPPING_OVERRIDE:
props["tag"] = LUCIDE_ICON_MAPPING_OVERRIDE[props["tag"]] props["tag"] = LUCIDE_ICON_MAPPING_OVERRIDE[tag]
else: else:
props["tag"] = ( props["tag"] = format.to_title_case(format.to_snake_case(tag)) + "Icon"
format.to_title_case(format.to_snake_case(props["tag"])) + "Icon"
)
props["alias"] = f"Lucide{props['tag']}" props["alias"] = f"Lucide{props['tag']}"
props.setdefault("color", "var(--current-color)") props.setdefault("color", "var(--current-color)")
return super().create(*children, **props) return super().create(**props)
class DynamicIcon(LucideIconComponent):
"""A DynamicIcon component."""
tag = "DynamicIcon"
name: Var[str]
def _get_imports(self):
_imports = super()._get_imports()
if self.library:
_imports.pop(self.library)
_imports["lucide-react/dynamic"] = [ImportVar("DynamicIcon", install=False)]
return _imports
LUCIDE_ICON_LIST = [ LUCIDE_ICON_LIST = [
@ -846,6 +870,7 @@ LUCIDE_ICON_LIST = [
"house", "house",
"house_plug", "house_plug",
"house_plus", "house_plus",
"house_wifi",
"ice_cream_bowl", "ice_cream_bowl",
"ice_cream_cone", "ice_cream_cone",
"id_card", "id_card",
@ -1534,6 +1559,7 @@ LUCIDE_ICON_LIST = [
"trending_up_down", "trending_up_down",
"triangle", "triangle",
"triangle_alert", "triangle_alert",
"triangle_dashed",
"triangle_right", "triangle_right",
"trophy", "trophy",
"truck", "truck",

View File

@ -104,12 +104,60 @@ class Icon(LucideIconComponent):
Raises: Raises:
AttributeError: The errors tied to bad usage of the Icon component. AttributeError: The errors tied to bad usage of the Icon component.
ValueError: If the icon tag is invalid. ValueError: If the icon tag is invalid.
TypeError: If the icon name is not a string.
Returns: Returns:
The created component. The created component.
""" """
... ...
class DynamicIcon(LucideIconComponent):
@overload
@classmethod
def create( # type: ignore
cls,
*children,
name: Optional[Union[Var[str], str]] = None,
style: Optional[Style] = None,
key: Optional[Any] = None,
id: Optional[Any] = None,
class_name: Optional[Any] = None,
autofocus: Optional[bool] = None,
custom_attrs: Optional[Dict[str, Union[Var, Any]]] = None,
on_blur: Optional[EventType[[], BASE_STATE]] = None,
on_click: Optional[EventType[[], BASE_STATE]] = None,
on_context_menu: Optional[EventType[[], BASE_STATE]] = None,
on_double_click: Optional[EventType[[], BASE_STATE]] = None,
on_focus: Optional[EventType[[], BASE_STATE]] = None,
on_mount: Optional[EventType[[], BASE_STATE]] = None,
on_mouse_down: Optional[EventType[[], BASE_STATE]] = None,
on_mouse_enter: Optional[EventType[[], BASE_STATE]] = None,
on_mouse_leave: Optional[EventType[[], BASE_STATE]] = None,
on_mouse_move: Optional[EventType[[], BASE_STATE]] = None,
on_mouse_out: Optional[EventType[[], BASE_STATE]] = None,
on_mouse_over: Optional[EventType[[], BASE_STATE]] = None,
on_mouse_up: Optional[EventType[[], BASE_STATE]] = None,
on_scroll: Optional[EventType[[], BASE_STATE]] = None,
on_unmount: Optional[EventType[[], BASE_STATE]] = None,
**props,
) -> "DynamicIcon":
"""Create the component.
Args:
*children: The children of the component.
style: The style of the component.
key: A unique key for the component.
id: The id for the component.
class_name: The class name for the component.
autofocus: Whether the component should take the focus once the page is loaded
custom_attrs: custom attribute
**props: The props of the component.
Returns:
The component.
"""
...
LUCIDE_ICON_LIST = [ LUCIDE_ICON_LIST = [
"a_arrow_down", "a_arrow_down",
"a_arrow_up", "a_arrow_up",
@ -889,6 +937,7 @@ LUCIDE_ICON_LIST = [
"house", "house",
"house_plug", "house_plug",
"house_plus", "house_plus",
"house_wifi",
"ice_cream_bowl", "ice_cream_bowl",
"ice_cream_cone", "ice_cream_cone",
"id_card", "id_card",
@ -1577,6 +1626,7 @@ LUCIDE_ICON_LIST = [
"trending_up_down", "trending_up_down",
"triangle", "triangle",
"triangle_alert", "triangle_alert",
"triangle_dashed",
"triangle_right", "triangle_right",
"trophy", "trophy",
"truck", "truck",

View File

@ -12,6 +12,7 @@ import threading
import urllib.parse import urllib.parse
from importlib.util import find_spec from importlib.util import find_spec
from pathlib import Path from pathlib import Path
from types import ModuleType
from typing import ( from typing import (
TYPE_CHECKING, TYPE_CHECKING,
Any, Any,
@ -607,6 +608,9 @@ class Config(Base):
# The name of the app (should match the name of the app directory). # The name of the app (should match the name of the app directory).
app_name: str app_name: str
# The path to the app module.
app_module_import: Optional[str] = None
# The log level to use. # The log level to use.
loglevel: constants.LogLevel = constants.LogLevel.DEFAULT loglevel: constants.LogLevel = constants.LogLevel.DEFAULT
@ -729,6 +733,19 @@ class Config(Base):
"REDIS_URL is required when using the redis state manager." "REDIS_URL is required when using the redis state manager."
) )
@property
def app_module(self) -> ModuleType | None:
"""Return the app module if `app_module_import` is set.
Returns:
The app module.
"""
return (
importlib.import_module(self.app_module_import)
if self.app_module_import
else None
)
@property @property
def module(self) -> str: def module(self) -> str:
"""Get the module name of the app. """Get the module name of the app.
@ -736,6 +753,8 @@ class Config(Base):
Returns: Returns:
The module name. The module name.
""" """
if self.app_module is not None:
return self.app_module.__name__
return ".".join([self.app_name, self.app_name]) return ".".join([self.app_name, self.app_name])
def update_from_env(self) -> dict[str, Any]: def update_from_env(self) -> dict[str, Any]:
@ -874,7 +893,7 @@ def get_config(reload: bool = False) -> Config:
return cached_rxconfig.config return cached_rxconfig.config
with _config_lock: with _config_lock:
sys_path = sys.path.copy() orig_sys_path = sys.path.copy()
sys.path.clear() sys.path.clear()
sys.path.append(str(Path.cwd())) sys.path.append(str(Path.cwd()))
try: try:
@ -882,9 +901,14 @@ def get_config(reload: bool = False) -> Config:
return _get_config() return _get_config()
except Exception: except Exception:
# If the module import fails, try to import with the original sys.path. # If the module import fails, try to import with the original sys.path.
sys.path.extend(sys_path) sys.path.extend(orig_sys_path)
return _get_config() return _get_config()
finally: finally:
# Find any entries added to sys.path by rxconfig.py itself.
extra_paths = [
p for p in sys.path if p not in orig_sys_path and p != str(Path.cwd())
]
# Restore the original sys.path. # Restore the original sys.path.
sys.path.clear() sys.path.clear()
sys.path.extend(sys_path) sys.path.extend(extra_paths)
sys.path.extend(orig_sys_path)

View File

@ -1,6 +1,7 @@
"""The constants package.""" """The constants package."""
from .base import ( from .base import (
APP_HARNESS_FLAG,
COOKIES, COOKIES,
IS_LINUX, IS_LINUX,
IS_MACOS, IS_MACOS,

View File

@ -257,6 +257,7 @@ SESSION_STORAGE = "session_storage"
# Testing variables. # Testing variables.
# Testing os env set by pytest when running a test case. # Testing os env set by pytest when running a test case.
PYTEST_CURRENT_TEST = "PYTEST_CURRENT_TEST" PYTEST_CURRENT_TEST = "PYTEST_CURRENT_TEST"
APP_HARNESS_FLAG = "APP_HARNESS_FLAG"
REFLEX_VAR_OPENING_TAG = "<reflex.Var>" REFLEX_VAR_OPENING_TAG = "<reflex.Var>"
REFLEX_VAR_CLOSING_TAG = "</reflex.Var>" REFLEX_VAR_CLOSING_TAG = "</reflex.Var>"

View File

@ -421,12 +421,13 @@ def _run_commands_in_subprocess(cmds: list[str]) -> bool:
console.debug(f"Running command: {' '.join(cmds)}") console.debug(f"Running command: {' '.join(cmds)}")
try: try:
result = subprocess.run(cmds, capture_output=True, text=True, check=True) result = subprocess.run(cmds, capture_output=True, text=True, check=True)
console.debug(result.stdout)
return True
except subprocess.CalledProcessError as cpe: except subprocess.CalledProcessError as cpe:
console.error(cpe.stdout) console.error(cpe.stdout)
console.error(cpe.stderr) console.error(cpe.stderr)
return False return False
else:
console.debug(result.stdout)
return True
def _make_pyi_files(): def _make_pyi_files():
@ -931,10 +932,11 @@ def _get_file_from_prompt_in_loop() -> Tuple[bytes, str] | None:
file_extension = image_filepath.suffix file_extension = image_filepath.suffix
try: try:
image_file = image_filepath.read_bytes() image_file = image_filepath.read_bytes()
return image_file, file_extension
except OSError as ose: except OSError as ose:
console.error(f"Unable to read the {file_extension} file due to {ose}") console.error(f"Unable to read the {file_extension} file due to {ose}")
raise typer.Exit(code=1) from ose raise typer.Exit(code=1) from ose
else:
return image_file, file_extension
console.debug(f"File extension detected: {file_extension}") console.debug(f"File extension detected: {file_extension}")
return None return None

View File

@ -1534,7 +1534,7 @@ def get_handler_args(
def fix_events( def fix_events(
events: list[EventHandler | EventSpec] | None, events: list[EventSpec | EventHandler] | None,
token: str, token: str,
router_data: dict[str, Any] | None = None, router_data: dict[str, Any] | None = None,
) -> list[Event]: ) -> list[Event]:

View File

@ -440,7 +440,11 @@ def deploy(
config.app_name, config.app_name,
"--app-name", "--app-name",
help="The name of the App to deploy under.", help="The name of the App to deploy under.",
hidden=True, ),
app_id: str = typer.Option(
None,
"--app-id",
help="The ID of the App to deploy over.",
), ),
regions: List[str] = typer.Option( regions: List[str] = typer.Option(
[], [],
@ -480,6 +484,11 @@ def deploy(
"--project", "--project",
help="project id to deploy to", help="project id to deploy to",
), ),
project_name: Optional[str] = typer.Option(
None,
"--project-name",
help="The name of the project to deploy to.",
),
token: Optional[str] = typer.Option( token: Optional[str] = typer.Option(
None, None,
"--token", "--token",
@ -503,13 +512,6 @@ def deploy(
# Set the log level. # Set the log level.
console.set_log_level(loglevel) console.set_log_level(loglevel)
if not token:
# make sure user is logged in.
if interactive:
hosting_cli.login()
else:
raise SystemExit("Token is required for non-interactive mode.")
# Only check requirements if interactive. # Only check requirements if interactive.
# There is user interaction for requirements update. # There is user interaction for requirements update.
if interactive: if interactive:
@ -526,6 +528,7 @@ def deploy(
hosting_cli.deploy( hosting_cli.deploy(
app_name=app_name, app_name=app_name,
app_id=app_id,
export_fn=lambda zip_dest_dir, export_fn=lambda zip_dest_dir,
api_url, api_url,
deploy_url, deploy_url,
@ -549,6 +552,8 @@ def deploy(
loglevel=type(loglevel).INFO, # type: ignore loglevel=type(loglevel).INFO, # type: ignore
token=token, token=token,
project=project, project=project,
config_path=config_path,
project_name=project_name,
**extra, **extra,
) )

View File

@ -1773,9 +1773,9 @@ class BaseState(Base, ABC, extra=pydantic.Extra.allow):
except Exception as ex: except Exception as ex:
state._clean() state._clean()
app_instance = getattr(prerequisites.get_app(), constants.CompileVars.APP) event_specs = (
prerequisites.get_and_validate_app().app.backend_exception_handler(ex)
event_specs = app_instance.backend_exception_handler(ex) )
if event_specs is None: if event_specs is None:
return StateUpdate() return StateUpdate()
@ -1885,9 +1885,9 @@ class BaseState(Base, ABC, extra=pydantic.Extra.allow):
except Exception as ex: except Exception as ex:
telemetry.send_error(ex, context="backend") telemetry.send_error(ex, context="backend")
app_instance = getattr(prerequisites.get_app(), constants.CompileVars.APP) event_specs = (
prerequisites.get_and_validate_app().app.backend_exception_handler(ex)
event_specs = app_instance.backend_exception_handler(ex) )
yield state._as_state_update( yield state._as_state_update(
handler, handler,
@ -2400,8 +2400,9 @@ class FrontendEventExceptionState(State):
component_stack: The stack trace of the component where the exception occurred. component_stack: The stack trace of the component where the exception occurred.
""" """
app_instance = getattr(prerequisites.get_app(), constants.CompileVars.APP) prerequisites.get_and_validate_app().app.frontend_exception_handler(
app_instance.frontend_exception_handler(Exception(stack)) Exception(stack)
)
class UpdateVarsInternalState(State): class UpdateVarsInternalState(State):
@ -2439,15 +2440,16 @@ class OnLoadInternalState(State):
The list of events to queue for on load handling. The list of events to queue for on load handling.
""" """
# Do not app._compile()! It should be already compiled by now. # Do not app._compile()! It should be already compiled by now.
app = getattr(prerequisites.get_app(), constants.CompileVars.APP) load_events = prerequisites.get_and_validate_app().app.get_load_events(
load_events = app.get_load_events(self.router.page.path) self.router.page.path
)
if not load_events: if not load_events:
self.is_hydrated = True self.is_hydrated = True
return # Fast path for navigation with no on_load events defined. return # Fast path for navigation with no on_load events defined.
self.is_hydrated = False self.is_hydrated = False
return [ return [
*fix_events( *fix_events(
load_events, cast(list[Union[EventSpec, EventHandler]], load_events),
self.router.session.client_token, self.router.session.client_token,
router_data=self.router_data, router_data=self.router_data,
), ),
@ -2606,7 +2608,7 @@ class StateProxy(wrapt.ObjectProxy):
""" """
super().__init__(state_instance) super().__init__(state_instance)
# compile is not relevant to backend logic # compile is not relevant to backend logic
self._self_app = getattr(prerequisites.get_app(), constants.CompileVars.APP) self._self_app = prerequisites.get_and_validate_app().app
self._self_substate_path = tuple(state_instance.get_full_name().split(".")) self._self_substate_path = tuple(state_instance.get_full_name().split("."))
self._self_actx = None self._self_actx = None
self._self_mutable = False self._self_mutable = False
@ -3699,8 +3701,7 @@ def get_state_manager() -> StateManager:
Returns: Returns:
The state manager. The state manager.
""" """
app = getattr(prerequisites.get_app(), constants.CompileVars.APP) return prerequisites.get_and_validate_app().app.state_manager
return app.state_manager
class MutableProxy(wrapt.ObjectProxy): class MutableProxy(wrapt.ObjectProxy):

View File

@ -282,6 +282,7 @@ class AppHarness:
before_decorated_pages = reflex.app.DECORATED_PAGES[self.app_name].copy() before_decorated_pages = reflex.app.DECORATED_PAGES[self.app_name].copy()
# Ensure the AppHarness test does not skip State assignment due to running via pytest # Ensure the AppHarness test does not skip State assignment due to running via pytest
os.environ.pop(reflex.constants.PYTEST_CURRENT_TEST, None) os.environ.pop(reflex.constants.PYTEST_CURRENT_TEST, None)
os.environ[reflex.constants.APP_HARNESS_FLAG] = "true"
self.app_module = reflex.utils.prerequisites.get_compiled_app( self.app_module = reflex.utils.prerequisites.get_compiled_app(
# Do not reload the module for pre-existing apps (only apps generated from source) # Do not reload the module for pre-existing apps (only apps generated from source)
reload=self.app_source is not None reload=self.app_source is not None

View File

@ -13,13 +13,17 @@ from rich.progress import MofNCompleteColumn, Progress, TimeElapsedColumn
from reflex import constants from reflex import constants
from reflex.config import get_config from reflex.config import get_config
from reflex.utils import console, path_ops, prerequisites, processes from reflex.utils import console, path_ops, prerequisites, processes
from reflex.utils.exec import is_in_app_harness
def set_env_json(): def set_env_json():
"""Write the upload url to a REFLEX_JSON.""" """Write the upload url to a REFLEX_JSON."""
path_ops.update_json_file( path_ops.update_json_file(
str(prerequisites.get_web_dir() / constants.Dirs.ENV_JSON), str(prerequisites.get_web_dir() / constants.Dirs.ENV_JSON),
{endpoint.name: endpoint.get_url() for endpoint in constants.Endpoint}, {
**{endpoint.name: endpoint.get_url() for endpoint in constants.Endpoint},
"TEST_MODE": is_in_app_harness(),
},
) )

View File

@ -1,5 +1,7 @@
"""Custom Exceptions.""" """Custom Exceptions."""
from typing import Any
class ReflexError(Exception): class ReflexError(Exception):
"""Base exception for all Reflex exceptions.""" """Base exception for all Reflex exceptions."""
@ -29,6 +31,22 @@ class ComponentTypeError(ReflexError, TypeError):
"""Custom TypeError for component related errors.""" """Custom TypeError for component related errors."""
class ChildrenTypeError(ComponentTypeError):
"""Raised when the children prop of a component is not a valid type."""
def __init__(self, component: str, child: Any):
"""Initialize the exception.
Args:
component: The name of the component.
child: The child that caused the error.
"""
super().__init__(
f"Component {component} received child {child} of type {type(child)}. "
"Accepted types are other components, state vars, or primitive Python types (dict excluded)."
)
class EventHandlerTypeError(ReflexError, TypeError): class EventHandlerTypeError(ReflexError, TypeError):
"""Custom TypeError for event handler related errors.""" """Custom TypeError for event handler related errors."""
@ -209,6 +227,10 @@ class SystemPackageMissingError(ReflexError):
) )
class EventDeserializationError(ReflexError, ValueError):
"""Raised when an event cannot be deserialized."""
class InvalidLockWarningThresholdError(ReflexError): class InvalidLockWarningThresholdError(ReflexError):
"""Raised when an invalid lock warning threshold is provided.""" """Raised when an invalid lock warning threshold is provided."""

View File

@ -240,6 +240,28 @@ def run_backend(
run_uvicorn_backend(host, port, loglevel) run_uvicorn_backend(host, port, loglevel)
def get_reload_dirs() -> list[str]:
"""Get the reload directories for the backend.
Returns:
The reload directories for the backend.
"""
config = get_config()
reload_dirs = [config.app_name]
if config.app_module is not None and config.app_module.__file__:
module_path = Path(config.app_module.__file__).resolve().parent
while module_path.parent.name:
for parent_file in module_path.parent.iterdir():
if parent_file == "__init__.py":
# go up a level to find dir without `__init__.py`
module_path = module_path.parent
break
else:
break
reload_dirs.append(str(module_path))
return reload_dirs
def run_uvicorn_backend(host, port, loglevel: LogLevel): def run_uvicorn_backend(host, port, loglevel: LogLevel):
"""Run the backend in development mode using Uvicorn. """Run the backend in development mode using Uvicorn.
@ -256,7 +278,7 @@ def run_uvicorn_backend(host, port, loglevel: LogLevel):
port=port, port=port,
log_level=loglevel.value, log_level=loglevel.value,
reload=True, reload=True,
reload_dirs=[get_config().app_name], reload_dirs=get_reload_dirs(),
) )
@ -281,7 +303,7 @@ def run_granian_backend(host, port, loglevel: LogLevel):
interface=Interfaces.ASGI, interface=Interfaces.ASGI,
log_level=LogLevels(loglevel.value), log_level=LogLevels(loglevel.value),
reload=True, reload=True,
reload_paths=[Path(get_config().app_name)], reload_paths=get_reload_dirs(),
reload_ignore_dirs=[".web"], reload_ignore_dirs=[".web"],
).serve() ).serve()
except ImportError: except ImportError:
@ -487,6 +509,15 @@ def is_testing_env() -> bool:
return constants.PYTEST_CURRENT_TEST in os.environ return constants.PYTEST_CURRENT_TEST in os.environ
def is_in_app_harness() -> bool:
"""Whether the app is running in the app harness.
Returns:
True if the app is running in the app harness.
"""
return constants.APP_HARNESS_FLAG in os.environ
def is_prod_mode() -> bool: def is_prod_mode() -> bool:
"""Check if the app is running in production mode. """Check if the app is running in production mode.

View File

@ -17,11 +17,12 @@ import stat
import sys import sys
import tempfile import tempfile
import time import time
import typing
import zipfile import zipfile
from datetime import datetime from datetime import datetime
from pathlib import Path from pathlib import Path
from types import ModuleType from types import ModuleType
from typing import Callable, List, Optional from typing import Callable, List, NamedTuple, Optional
import httpx import httpx
import typer import typer
@ -42,9 +43,19 @@ from reflex.utils.exceptions import (
from reflex.utils.format import format_library_name from reflex.utils.format import format_library_name
from reflex.utils.registry import _get_npm_registry from reflex.utils.registry import _get_npm_registry
if typing.TYPE_CHECKING:
from reflex.app import App
CURRENTLY_INSTALLING_NODE = False CURRENTLY_INSTALLING_NODE = False
class AppInfo(NamedTuple):
"""A tuple containing the app instance and module."""
app: App
module: ModuleType
@dataclasses.dataclass(frozen=True) @dataclasses.dataclass(frozen=True)
class Template: class Template:
"""A template for a Reflex app.""" """A template for a Reflex app."""
@ -253,6 +264,22 @@ def windows_npm_escape_hatch() -> bool:
return environment.REFLEX_USE_NPM.get() return environment.REFLEX_USE_NPM.get()
def _check_app_name(config: Config):
"""Check if the app name is set in the config.
Args:
config: The config object.
Raises:
RuntimeError: If the app name is not set in the config.
"""
if not config.app_name:
raise RuntimeError(
"Cannot get the app module because `app_name` is not set in rxconfig! "
"If this error occurs in a reflex test case, ensure that `get_app` is mocked."
)
def get_app(reload: bool = False) -> ModuleType: def get_app(reload: bool = False) -> ModuleType:
"""Get the app module based on the default config. """Get the app module based on the default config.
@ -263,22 +290,23 @@ def get_app(reload: bool = False) -> ModuleType:
The app based on the default config. The app based on the default config.
Raises: Raises:
RuntimeError: If the app name is not set in the config. Exception: If an error occurs while getting the app module.
""" """
from reflex.utils import telemetry from reflex.utils import telemetry
try: try:
environment.RELOAD_CONFIG.set(reload) environment.RELOAD_CONFIG.set(reload)
config = get_config() config = get_config()
if not config.app_name:
raise RuntimeError( _check_app_name(config)
"Cannot get the app module because `app_name` is not set in rxconfig! "
"If this error occurs in a reflex test case, ensure that `get_app` is mocked."
)
module = config.module module = config.module
sys.path.insert(0, str(Path.cwd())) sys.path.insert(0, str(Path.cwd()))
app = __import__(module, fromlist=(constants.CompileVars.APP,)) app = (
__import__(module, fromlist=(constants.CompileVars.APP,))
if not config.app_module
else config.app_module
)
if reload: if reload:
from reflex.state import reload_state_module from reflex.state import reload_state_module
@ -287,11 +315,34 @@ def get_app(reload: bool = False) -> ModuleType:
# Reload the app module. # Reload the app module.
importlib.reload(app) importlib.reload(app)
return app
except Exception as ex: except Exception as ex:
telemetry.send_error(ex, context="frontend") telemetry.send_error(ex, context="frontend")
raise raise
else:
return app
def get_and_validate_app(reload: bool = False) -> AppInfo:
"""Get the app instance based on the default config and validate it.
Args:
reload: Re-import the app module from disk
Returns:
The app instance and the app module.
Raises:
RuntimeError: If the app instance is not an instance of rx.App.
"""
from reflex.app import App
app_module = get_app(reload=reload)
app = getattr(app_module, constants.CompileVars.APP)
if not isinstance(app, App):
raise RuntimeError(
"The app instance in the specified app_module_import in rxconfig must be an instance of rx.App."
)
return AppInfo(app=app, module=app_module)
def get_compiled_app(reload: bool = False, export: bool = False) -> ModuleType: def get_compiled_app(reload: bool = False, export: bool = False) -> ModuleType:
@ -304,8 +355,7 @@ def get_compiled_app(reload: bool = False, export: bool = False) -> ModuleType:
Returns: Returns:
The compiled app based on the default config. The compiled app based on the default config.
""" """
app_module = get_app(reload=reload) app, app_module = get_and_validate_app(reload=reload)
app = getattr(app_module, constants.CompileVars.APP)
# For py3.9 compatibility when redis is used, we MUST add any decorator pages # For py3.9 compatibility when redis is used, we MUST add any decorator pages
# before compiling the app in a thread to avoid event loop error (REF-2172). # before compiling the app in a thread to avoid event loop error (REF-2172).
app._apply_decorated_pages() app._apply_decorated_pages()
@ -1143,11 +1193,12 @@ def ensure_reflex_installation_id() -> Optional[int]:
if installation_id is None: if installation_id is None:
installation_id = random.getrandbits(128) installation_id = random.getrandbits(128)
installation_id_file.write_text(str(installation_id)) installation_id_file.write_text(str(installation_id))
# If we get here, installation_id is definitely set
return installation_id
except Exception as e: except Exception as e:
console.debug(f"Failed to ensure reflex installation id: {e}") console.debug(f"Failed to ensure reflex installation id: {e}")
return None return None
else:
# If we get here, installation_id is definitely set
return installation_id
def initialize_reflex_user_directory(): def initialize_reflex_user_directory():
@ -1361,19 +1412,22 @@ def create_config_init_app_from_remote_template(app_name: str, template_url: str
except OSError as ose: except OSError as ose:
console.error(f"Failed to create temp directory for extracting zip: {ose}") console.error(f"Failed to create temp directory for extracting zip: {ose}")
raise typer.Exit(1) from ose raise typer.Exit(1) from ose
try: try:
zipfile.ZipFile(zip_file_path).extractall(path=unzip_dir) zipfile.ZipFile(zip_file_path).extractall(path=unzip_dir)
# The zip file downloaded from github looks like: # The zip file downloaded from github looks like:
# repo-name-branch/**/*, so we need to remove the top level directory. # repo-name-branch/**/*, so we need to remove the top level directory.
if len(subdirs := os.listdir(unzip_dir)) != 1:
console.error(f"Expected one directory in the zip, found {subdirs}")
raise typer.Exit(1)
template_dir = unzip_dir / subdirs[0]
console.debug(f"Template folder is located at {template_dir}")
except Exception as uze: except Exception as uze:
console.error(f"Failed to unzip the template: {uze}") console.error(f"Failed to unzip the template: {uze}")
raise typer.Exit(1) from uze raise typer.Exit(1) from uze
if len(subdirs := os.listdir(unzip_dir)) != 1:
console.error(f"Expected one directory in the zip, found {subdirs}")
raise typer.Exit(1)
template_dir = unzip_dir / subdirs[0]
console.debug(f"Template folder is located at {template_dir}")
# Move the rxconfig file here first. # Move the rxconfig file here first.
path_ops.mv(str(template_dir / constants.Config.FILE), constants.Config.FILE) path_ops.mv(str(template_dir / constants.Config.FILE), constants.Config.FILE)
new_config = get_config(reload=True) new_config = get_config(reload=True)

View File

@ -156,9 +156,10 @@ def _prepare_event(event: str, **kwargs) -> dict:
def _send_event(event_data: dict) -> bool: def _send_event(event_data: dict) -> bool:
try: try:
httpx.post(POSTHOG_API_URL, json=event_data) httpx.post(POSTHOG_API_URL, json=event_data)
return True
except Exception: except Exception:
return False return False
else:
return True
def _send(event, telemetry_enabled, **kwargs): def _send(event, telemetry_enabled, **kwargs):

View File

@ -829,6 +829,22 @@ StateBases = get_base_class(StateVar)
StateIterBases = get_base_class(StateIterVar) StateIterBases = get_base_class(StateIterVar)
def safe_issubclass(cls: Type, cls_check: Type | Tuple[Type, ...]):
"""Check if a class is a subclass of another class. Returns False if internal error occurs.
Args:
cls: The class to check.
cls_check: The class to check against.
Returns:
Whether the class is a subclass of the other class.
"""
try:
return issubclass(cls, cls_check)
except TypeError:
return False
def typehint_issubclass(possible_subclass: Any, possible_superclass: Any) -> bool: def typehint_issubclass(possible_subclass: Any, possible_superclass: Any) -> bool:
"""Check if a type hint is a subclass of another type hint. """Check if a type hint is a subclass of another type hint.

View File

@ -26,6 +26,7 @@ from typing import (
Iterable, Iterable,
List, List,
Literal, Literal,
Mapping,
NoReturn, NoReturn,
Optional, Optional,
Set, Set,
@ -65,6 +66,7 @@ from reflex.utils.types import (
_isinstance, _isinstance,
get_origin, get_origin,
has_args, has_args,
safe_issubclass,
unionize, unionize,
) )
@ -128,7 +130,7 @@ class VarData:
state: str = "", state: str = "",
field_name: str = "", field_name: str = "",
imports: ImportDict | ParsedImportDict | None = None, imports: ImportDict | ParsedImportDict | None = None,
hooks: dict[str, VarData | None] | None = None, hooks: Mapping[str, VarData | None] | None = None,
deps: list[Var] | None = None, deps: list[Var] | None = None,
position: Hooks.HookPosition | None = None, position: Hooks.HookPosition | None = None,
): ):
@ -613,8 +615,8 @@ class Var(Generic[VAR_TYPE]):
@overload @overload
def to( def to(
self, self,
output: type[dict], output: type[Mapping],
) -> ObjectVar[dict]: ... ) -> ObjectVar[Mapping]: ...
@overload @overload
def to( def to(
@ -656,7 +658,9 @@ class Var(Generic[VAR_TYPE]):
# If the first argument is a python type, we map it to the corresponding Var type. # If the first argument is a python type, we map it to the corresponding Var type.
for var_subclass in _var_subclasses[::-1]: for var_subclass in _var_subclasses[::-1]:
if fixed_output_type in var_subclass.python_types: if fixed_output_type in var_subclass.python_types or safe_issubclass(
fixed_output_type, var_subclass.python_types
):
return self.to(var_subclass.var_subclass, output) return self.to(var_subclass.var_subclass, output)
if fixed_output_type is None: if fixed_output_type is None:
@ -790,7 +794,7 @@ class Var(Generic[VAR_TYPE]):
return False return False
if issubclass(type_, list): if issubclass(type_, list):
return [] return []
if issubclass(type_, dict): if issubclass(type_, Mapping):
return {} return {}
if issubclass(type_, tuple): if issubclass(type_, tuple):
return () return ()
@ -996,7 +1000,7 @@ class Var(Generic[VAR_TYPE]):
f"$/{constants.Dirs.STATE_PATH}": [imports.ImportVar(tag="refs")] f"$/{constants.Dirs.STATE_PATH}": [imports.ImportVar(tag="refs")]
} }
), ),
).to(ObjectVar, Dict[str, str]) ).to(ObjectVar, Mapping[str, str])
return refs[LiteralVar.create(str(self))] return refs[LiteralVar.create(str(self))]
@deprecated("Use `.js_type()` instead.") @deprecated("Use `.js_type()` instead.")
@ -1343,7 +1347,7 @@ class LiteralVar(Var):
serialized_value = serializers.serialize(value) serialized_value = serializers.serialize(value)
if serialized_value is not None: if serialized_value is not None:
if isinstance(serialized_value, dict): if isinstance(serialized_value, Mapping):
return LiteralObjectVar.create( return LiteralObjectVar.create(
serialized_value, serialized_value,
_var_type=type(value), _var_type=type(value),
@ -1468,7 +1472,7 @@ def var_operation(
) -> Callable[P, ArrayVar[LIST_T]]: ... ) -> Callable[P, ArrayVar[LIST_T]]: ...
OBJECT_TYPE = TypeVar("OBJECT_TYPE", bound=Dict) OBJECT_TYPE = TypeVar("OBJECT_TYPE", bound=Mapping)
@overload @overload
@ -1543,8 +1547,8 @@ def figure_out_type(value: Any) -> types.GenericType:
return Set[unionize(*(figure_out_type(v) for v in value))] return Set[unionize(*(figure_out_type(v) for v in value))]
if isinstance(value, tuple): if isinstance(value, tuple):
return Tuple[unionize(*(figure_out_type(v) for v in value)), ...] return Tuple[unionize(*(figure_out_type(v) for v in value)), ...]
if isinstance(value, dict): if isinstance(value, Mapping):
return Dict[ return Mapping[
unionize(*(figure_out_type(k) for k in value)), unionize(*(figure_out_type(k) for k in value)),
unionize(*(figure_out_type(v) for v in value.values())), unionize(*(figure_out_type(v) for v in value.values())),
] ]
@ -1968,10 +1972,10 @@ class ComputedVar(Var[RETURN_TYPE]):
@overload @overload
def __get__( def __get__(
self: ComputedVar[dict[DICT_KEY, DICT_VAL]], self: ComputedVar[Mapping[DICT_KEY, DICT_VAL]],
instance: None, instance: None,
owner: Type, owner: Type,
) -> ObjectVar[dict[DICT_KEY, DICT_VAL]]: ... ) -> ObjectVar[Mapping[DICT_KEY, DICT_VAL]]: ...
@overload @overload
def __get__( def __get__(
@ -2878,11 +2882,14 @@ V = TypeVar("V")
BASE_TYPE = TypeVar("BASE_TYPE", bound=Base) BASE_TYPE = TypeVar("BASE_TYPE", bound=Base)
FIELD_TYPE = TypeVar("FIELD_TYPE")
MAPPING_TYPE = TypeVar("MAPPING_TYPE", bound=Mapping)
class Field(Generic[T]):
class Field(Generic[FIELD_TYPE]):
"""Shadow class for Var to allow for type hinting in the IDE.""" """Shadow class for Var to allow for type hinting in the IDE."""
def __set__(self, instance, value: T): def __set__(self, instance, value: FIELD_TYPE):
"""Set the Var. """Set the Var.
Args: Args:
@ -2894,7 +2901,9 @@ class Field(Generic[T]):
def __get__(self: Field[bool], instance: None, owner) -> BooleanVar: ... def __get__(self: Field[bool], instance: None, owner) -> BooleanVar: ...
@overload @overload
def __get__(self: Field[int], instance: None, owner) -> NumberVar: ... def __get__(
self: Field[int] | Field[float] | Field[int | float], instance: None, owner
) -> NumberVar: ...
@overload @overload
def __get__(self: Field[str], instance: None, owner) -> StringVar: ... def __get__(self: Field[str], instance: None, owner) -> StringVar: ...
@ -2911,8 +2920,8 @@ class Field(Generic[T]):
@overload @overload
def __get__( def __get__(
self: Field[Dict[str, V]], instance: None, owner self: Field[MAPPING_TYPE], instance: None, owner
) -> ObjectVar[Dict[str, V]]: ... ) -> ObjectVar[MAPPING_TYPE]: ...
@overload @overload
def __get__( def __get__(
@ -2920,10 +2929,10 @@ class Field(Generic[T]):
) -> ObjectVar[BASE_TYPE]: ... ) -> ObjectVar[BASE_TYPE]: ...
@overload @overload
def __get__(self, instance: None, owner) -> Var[T]: ... def __get__(self, instance: None, owner) -> Var[FIELD_TYPE]: ...
@overload @overload
def __get__(self, instance, owner) -> T: ... def __get__(self, instance, owner) -> FIELD_TYPE: ...
def __get__(self, instance, owner): # type: ignore def __get__(self, instance, owner): # type: ignore
"""Get the Var. """Get the Var.
@ -2934,7 +2943,7 @@ class Field(Generic[T]):
""" """
def field(value: T) -> Field[T]: def field(value: FIELD_TYPE) -> Field[FIELD_TYPE]:
"""Create a Field with a value. """Create a Field with a value.
Args: Args:

View File

@ -8,8 +8,8 @@ import typing
from inspect import isclass from inspect import isclass
from typing import ( from typing import (
Any, Any,
Dict,
List, List,
Mapping,
NoReturn, NoReturn,
Tuple, Tuple,
Type, Type,
@ -19,6 +19,8 @@ from typing import (
overload, overload,
) )
from typing_extensions import is_typeddict
from reflex.utils import types from reflex.utils import types
from reflex.utils.exceptions import VarAttributeError from reflex.utils.exceptions import VarAttributeError
from reflex.utils.types import GenericType, get_attribute_access_type, get_origin from reflex.utils.types import GenericType, get_attribute_access_type, get_origin
@ -36,7 +38,7 @@ from .base import (
from .number import BooleanVar, NumberVar, raise_unsupported_operand_types from .number import BooleanVar, NumberVar, raise_unsupported_operand_types
from .sequence import ArrayVar, StringVar from .sequence import ArrayVar, StringVar
OBJECT_TYPE = TypeVar("OBJECT_TYPE") OBJECT_TYPE = TypeVar("OBJECT_TYPE", covariant=True)
KEY_TYPE = TypeVar("KEY_TYPE") KEY_TYPE = TypeVar("KEY_TYPE")
VALUE_TYPE = TypeVar("VALUE_TYPE") VALUE_TYPE = TypeVar("VALUE_TYPE")
@ -46,7 +48,7 @@ ARRAY_INNER_TYPE = TypeVar("ARRAY_INNER_TYPE")
OTHER_KEY_TYPE = TypeVar("OTHER_KEY_TYPE") OTHER_KEY_TYPE = TypeVar("OTHER_KEY_TYPE")
class ObjectVar(Var[OBJECT_TYPE], python_types=dict): class ObjectVar(Var[OBJECT_TYPE], python_types=Mapping):
"""Base class for immutable object vars.""" """Base class for immutable object vars."""
def _key_type(self) -> Type: def _key_type(self) -> Type:
@ -59,7 +61,7 @@ class ObjectVar(Var[OBJECT_TYPE], python_types=dict):
@overload @overload
def _value_type( def _value_type(
self: ObjectVar[Dict[Any, VALUE_TYPE]], self: ObjectVar[Mapping[Any, VALUE_TYPE]],
) -> Type[VALUE_TYPE]: ... ) -> Type[VALUE_TYPE]: ...
@overload @overload
@ -74,7 +76,7 @@ class ObjectVar(Var[OBJECT_TYPE], python_types=dict):
fixed_type = get_origin(self._var_type) or self._var_type fixed_type = get_origin(self._var_type) or self._var_type
if not isclass(fixed_type): if not isclass(fixed_type):
return Any return Any
args = get_args(self._var_type) if issubclass(fixed_type, dict) else () args = get_args(self._var_type) if issubclass(fixed_type, Mapping) else ()
return args[1] if args else Any return args[1] if args else Any
def keys(self) -> ArrayVar[List[str]]: def keys(self) -> ArrayVar[List[str]]:
@ -87,7 +89,7 @@ class ObjectVar(Var[OBJECT_TYPE], python_types=dict):
@overload @overload
def values( def values(
self: ObjectVar[Dict[Any, VALUE_TYPE]], self: ObjectVar[Mapping[Any, VALUE_TYPE]],
) -> ArrayVar[List[VALUE_TYPE]]: ... ) -> ArrayVar[List[VALUE_TYPE]]: ...
@overload @overload
@ -103,7 +105,7 @@ class ObjectVar(Var[OBJECT_TYPE], python_types=dict):
@overload @overload
def entries( def entries(
self: ObjectVar[Dict[Any, VALUE_TYPE]], self: ObjectVar[Mapping[Any, VALUE_TYPE]],
) -> ArrayVar[List[Tuple[str, VALUE_TYPE]]]: ... ) -> ArrayVar[List[Tuple[str, VALUE_TYPE]]]: ...
@overload @overload
@ -133,49 +135,55 @@ class ObjectVar(Var[OBJECT_TYPE], python_types=dict):
# NoReturn is used here to catch when key value is Any # NoReturn is used here to catch when key value is Any
@overload @overload
def __getitem__( def __getitem__(
self: ObjectVar[Dict[Any, NoReturn]], self: ObjectVar[Mapping[Any, NoReturn]],
key: Var | Any, key: Var | Any,
) -> Var: ... ) -> Var: ...
@overload
def __getitem__(
self: (ObjectVar[Mapping[Any, bool]]),
key: Var | Any,
) -> BooleanVar: ...
@overload @overload
def __getitem__( def __getitem__(
self: ( self: (
ObjectVar[Dict[Any, int]] ObjectVar[Mapping[Any, int]]
| ObjectVar[Dict[Any, float]] | ObjectVar[Mapping[Any, float]]
| ObjectVar[Dict[Any, int | float]] | ObjectVar[Mapping[Any, int | float]]
), ),
key: Var | Any, key: Var | Any,
) -> NumberVar: ... ) -> NumberVar: ...
@overload @overload
def __getitem__( def __getitem__(
self: ObjectVar[Dict[Any, str]], self: ObjectVar[Mapping[Any, str]],
key: Var | Any, key: Var | Any,
) -> StringVar: ... ) -> StringVar: ...
@overload @overload
def __getitem__( def __getitem__(
self: ObjectVar[Dict[Any, list[ARRAY_INNER_TYPE]]], self: ObjectVar[Mapping[Any, list[ARRAY_INNER_TYPE]]],
key: Var | Any, key: Var | Any,
) -> ArrayVar[list[ARRAY_INNER_TYPE]]: ... ) -> ArrayVar[list[ARRAY_INNER_TYPE]]: ...
@overload @overload
def __getitem__( def __getitem__(
self: ObjectVar[Dict[Any, set[ARRAY_INNER_TYPE]]], self: ObjectVar[Mapping[Any, set[ARRAY_INNER_TYPE]]],
key: Var | Any, key: Var | Any,
) -> ArrayVar[set[ARRAY_INNER_TYPE]]: ... ) -> ArrayVar[set[ARRAY_INNER_TYPE]]: ...
@overload @overload
def __getitem__( def __getitem__(
self: ObjectVar[Dict[Any, tuple[ARRAY_INNER_TYPE, ...]]], self: ObjectVar[Mapping[Any, tuple[ARRAY_INNER_TYPE, ...]]],
key: Var | Any, key: Var | Any,
) -> ArrayVar[tuple[ARRAY_INNER_TYPE, ...]]: ... ) -> ArrayVar[tuple[ARRAY_INNER_TYPE, ...]]: ...
@overload @overload
def __getitem__( def __getitem__(
self: ObjectVar[Dict[Any, dict[OTHER_KEY_TYPE, VALUE_TYPE]]], self: ObjectVar[Mapping[Any, Mapping[OTHER_KEY_TYPE, VALUE_TYPE]]],
key: Var | Any, key: Var | Any,
) -> ObjectVar[dict[OTHER_KEY_TYPE, VALUE_TYPE]]: ... ) -> ObjectVar[Mapping[OTHER_KEY_TYPE, VALUE_TYPE]]: ...
def __getitem__(self, key: Var | Any) -> Var: def __getitem__(self, key: Var | Any) -> Var:
"""Get an item from the object. """Get an item from the object.
@ -195,49 +203,49 @@ class ObjectVar(Var[OBJECT_TYPE], python_types=dict):
# NoReturn is used here to catch when key value is Any # NoReturn is used here to catch when key value is Any
@overload @overload
def __getattr__( def __getattr__(
self: ObjectVar[Dict[Any, NoReturn]], self: ObjectVar[Mapping[Any, NoReturn]],
name: str, name: str,
) -> Var: ... ) -> Var: ...
@overload @overload
def __getattr__( def __getattr__(
self: ( self: (
ObjectVar[Dict[Any, int]] ObjectVar[Mapping[Any, int]]
| ObjectVar[Dict[Any, float]] | ObjectVar[Mapping[Any, float]]
| ObjectVar[Dict[Any, int | float]] | ObjectVar[Mapping[Any, int | float]]
), ),
name: str, name: str,
) -> NumberVar: ... ) -> NumberVar: ...
@overload @overload
def __getattr__( def __getattr__(
self: ObjectVar[Dict[Any, str]], self: ObjectVar[Mapping[Any, str]],
name: str, name: str,
) -> StringVar: ... ) -> StringVar: ...
@overload @overload
def __getattr__( def __getattr__(
self: ObjectVar[Dict[Any, list[ARRAY_INNER_TYPE]]], self: ObjectVar[Mapping[Any, list[ARRAY_INNER_TYPE]]],
name: str, name: str,
) -> ArrayVar[list[ARRAY_INNER_TYPE]]: ... ) -> ArrayVar[list[ARRAY_INNER_TYPE]]: ...
@overload @overload
def __getattr__( def __getattr__(
self: ObjectVar[Dict[Any, set[ARRAY_INNER_TYPE]]], self: ObjectVar[Mapping[Any, set[ARRAY_INNER_TYPE]]],
name: str, name: str,
) -> ArrayVar[set[ARRAY_INNER_TYPE]]: ... ) -> ArrayVar[set[ARRAY_INNER_TYPE]]: ...
@overload @overload
def __getattr__( def __getattr__(
self: ObjectVar[Dict[Any, tuple[ARRAY_INNER_TYPE, ...]]], self: ObjectVar[Mapping[Any, tuple[ARRAY_INNER_TYPE, ...]]],
name: str, name: str,
) -> ArrayVar[tuple[ARRAY_INNER_TYPE, ...]]: ... ) -> ArrayVar[tuple[ARRAY_INNER_TYPE, ...]]: ...
@overload @overload
def __getattr__( def __getattr__(
self: ObjectVar[Dict[Any, dict[OTHER_KEY_TYPE, VALUE_TYPE]]], self: ObjectVar[Mapping[Any, Mapping[OTHER_KEY_TYPE, VALUE_TYPE]]],
name: str, name: str,
) -> ObjectVar[dict[OTHER_KEY_TYPE, VALUE_TYPE]]: ... ) -> ObjectVar[Mapping[OTHER_KEY_TYPE, VALUE_TYPE]]: ...
@overload @overload
def __getattr__( def __getattr__(
@ -266,8 +274,11 @@ class ObjectVar(Var[OBJECT_TYPE], python_types=dict):
var_type = get_args(var_type)[0] var_type = get_args(var_type)[0]
fixed_type = var_type if isclass(var_type) else get_origin(var_type) fixed_type = var_type if isclass(var_type) else get_origin(var_type)
if (isclass(fixed_type) and not issubclass(fixed_type, dict)) or (
fixed_type in types.UnionTypes if (
(isclass(fixed_type) and not issubclass(fixed_type, Mapping))
or (fixed_type in types.UnionTypes)
or is_typeddict(fixed_type)
): ):
attribute_type = get_attribute_access_type(var_type, name) attribute_type = get_attribute_access_type(var_type, name)
if attribute_type is None: if attribute_type is None:
@ -299,7 +310,7 @@ class ObjectVar(Var[OBJECT_TYPE], python_types=dict):
class LiteralObjectVar(CachedVarOperation, ObjectVar[OBJECT_TYPE], LiteralVar): class LiteralObjectVar(CachedVarOperation, ObjectVar[OBJECT_TYPE], LiteralVar):
"""Base class for immutable literal object vars.""" """Base class for immutable literal object vars."""
_var_value: Dict[Union[Var, Any], Union[Var, Any]] = dataclasses.field( _var_value: Mapping[Union[Var, Any], Union[Var, Any]] = dataclasses.field(
default_factory=dict default_factory=dict
) )
@ -383,7 +394,7 @@ class LiteralObjectVar(CachedVarOperation, ObjectVar[OBJECT_TYPE], LiteralVar):
@classmethod @classmethod
def create( def create(
cls, cls,
_var_value: dict, _var_value: Mapping,
_var_type: Type[OBJECT_TYPE] | None = None, _var_type: Type[OBJECT_TYPE] | None = None,
_var_data: VarData | None = None, _var_data: VarData | None = None,
) -> LiteralObjectVar[OBJECT_TYPE]: ) -> LiteralObjectVar[OBJECT_TYPE]:
@ -466,7 +477,7 @@ def object_merge_operation(lhs: ObjectVar, rhs: ObjectVar):
""" """
return var_operation_return( return var_operation_return(
js_expression=f"({{...{lhs}, ...{rhs}}})", js_expression=f"({{...{lhs}, ...{rhs}}})",
var_type=Dict[ var_type=Mapping[
Union[lhs._key_type(), rhs._key_type()], Union[lhs._key_type(), rhs._key_type()],
Union[lhs._value_type(), rhs._value_type()], Union[lhs._value_type(), rhs._value_type()],
], ],

View File

@ -987,7 +987,7 @@ class ArrayVar(Var[ARRAY_VAR_TYPE], python_types=(list, tuple, set)):
raise_unsupported_operand_types("[]", (type(self), type(i))) raise_unsupported_operand_types("[]", (type(self), type(i)))
return array_item_operation(self, i) return array_item_operation(self, i)
def length(self) -> NumberVar: def length(self) -> NumberVar[int]:
"""Get the length of the array. """Get the length of the array.
Returns: Returns:

View File

@ -71,9 +71,10 @@ def has_error_modal(driver: WebDriver) -> bool:
""" """
try: try:
driver.find_element(By.XPATH, CONNECTION_ERROR_XPATH) driver.find_element(By.XPATH, CONNECTION_ERROR_XPATH)
return True
except NoSuchElementException: except NoSuchElementException:
return False return False
else:
return True
@pytest.mark.asyncio @pytest.mark.asyncio

View File

@ -1,4 +1,4 @@
from typing import Dict, List, Tuple from typing import List, Mapping, Tuple
import pytest import pytest
@ -67,7 +67,7 @@ def test_match_components():
assert fourth_return_value_render["children"][0]["contents"] == '{"fourth value"}' assert fourth_return_value_render["children"][0]["contents"] == '{"fourth value"}'
assert match_cases[4][0]._js_expr == '({ ["foo"] : "bar" })' assert match_cases[4][0]._js_expr == '({ ["foo"] : "bar" })'
assert match_cases[4][0]._var_type == Dict[str, str] assert match_cases[4][0]._var_type == Mapping[str, str]
fifth_return_value_render = match_cases[4][1].render() fifth_return_value_render = match_cases[4][1].render()
assert fifth_return_value_render["name"] == "RadixThemesText" assert fifth_return_value_render["name"] == "RadixThemesText"
assert fifth_return_value_render["children"][0]["contents"] == '{"fifth value"}' assert fifth_return_value_render["children"][0]["contents"] == '{"fifth value"}'

View File

@ -29,6 +29,7 @@ from reflex.state import BaseState
from reflex.style import Style from reflex.style import Style
from reflex.utils import imports from reflex.utils import imports
from reflex.utils.exceptions import ( from reflex.utils.exceptions import (
ChildrenTypeError,
EventFnArgMismatchError, EventFnArgMismatchError,
EventHandlerArgTypeMismatchError, EventHandlerArgTypeMismatchError,
) )
@ -652,14 +653,17 @@ def test_create_filters_none_props(test_component):
assert str(component.style["text-align"]) == '"center"' assert str(component.style["text-align"]) == '"center"'
@pytest.mark.parametrize("children", [((None,),), ("foo", ("bar", (None,)))]) @pytest.mark.parametrize(
"children",
[
((None,),),
("foo", ("bar", (None,))),
({"foo": "bar"},),
],
)
def test_component_create_unallowed_types(children, test_component): def test_component_create_unallowed_types(children, test_component):
with pytest.raises(TypeError) as err: with pytest.raises(ChildrenTypeError):
test_component.create(*children) test_component.create(*children)
assert (
err.value.args[0]
== "Children of Reflex components must be other components, state vars, or primitive Python types. Got child None of type <class 'NoneType'>."
)
@pytest.mark.parametrize( @pytest.mark.parametrize(

View File

@ -1,6 +1,6 @@
from __future__ import annotations from __future__ import annotations
from typing import Any, Dict from typing import Any, Mapping
import pytest import pytest
@ -379,7 +379,7 @@ class StyleState(rx.State):
{ {
"css": Var( "css": Var(
_js_expr=f'({{ ["color"] : ("dark"+{StyleState.color}) }})' _js_expr=f'({{ ["color"] : ("dark"+{StyleState.color}) }})'
).to(Dict[str, str]) ).to(Mapping[str, str])
}, },
), ),
( (

View File

@ -2,7 +2,7 @@ import json
import math import math
import sys import sys
import typing import typing
from typing import Dict, List, Optional, Set, Tuple, Union, cast from typing import Dict, List, Mapping, Optional, Set, Tuple, Union, cast
import pytest import pytest
from pandas import DataFrame from pandas import DataFrame
@ -273,7 +273,7 @@ def test_get_setter(prop: Var, expected):
([1, 2, 3], Var(_js_expr="[1, 2, 3]", _var_type=List[int])), ([1, 2, 3], Var(_js_expr="[1, 2, 3]", _var_type=List[int])),
( (
{"a": 1, "b": 2}, {"a": 1, "b": 2},
Var(_js_expr='({ ["a"] : 1, ["b"] : 2 })', _var_type=Dict[str, int]), Var(_js_expr='({ ["a"] : 1, ["b"] : 2 })', _var_type=Mapping[str, int]),
), ),
], ],
) )

View File

@ -1,4 +1,4 @@
from typing import Dict, List, Union from typing import List, Mapping, Union
import pytest import pytest
@ -37,12 +37,12 @@ class ChildGenericDict(GenericDict):
("a", str), ("a", str),
([1, 2, 3], List[int]), ([1, 2, 3], List[int]),
([1, 2.0, "a"], List[Union[int, float, str]]), ([1, 2.0, "a"], List[Union[int, float, str]]),
({"a": 1, "b": 2}, Dict[str, int]), ({"a": 1, "b": 2}, Mapping[str, int]),
({"a": 1, 2: "b"}, Dict[Union[int, str], Union[str, int]]), ({"a": 1, 2: "b"}, Mapping[Union[int, str], Union[str, int]]),
(CustomDict(), CustomDict), (CustomDict(), CustomDict),
(ChildCustomDict(), ChildCustomDict), (ChildCustomDict(), ChildCustomDict),
(GenericDict({1: 1}), Dict[int, int]), (GenericDict({1: 1}), Mapping[int, int]),
(ChildGenericDict({1: 1}), Dict[int, int]), (ChildGenericDict({1: 1}), Mapping[int, int]),
], ],
) )
def test_figure_out_type(value, expected): def test_figure_out_type(value, expected):