Merge branch 'main' into masenf/base-model

This commit is contained in:
Masen Furer 2024-11-13 12:16:33 -08:00 committed by GitHub
commit 3c195142af
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
24 changed files with 588 additions and 135 deletions

View File

@ -25,7 +25,7 @@
# Stage 1: init
FROM python:3.11 as init
ARG uv=/root/.cargo/bin/uv
ARG uv=/root/.local/bin/uv
# Install `uv` for faster package boostrapping
ADD --chmod=755 https://astral.sh/uv/install.sh /install.sh

View File

@ -4,7 +4,7 @@
# Stage 1: init
FROM python:3.11 as init
ARG uv=/root/.cargo/bin/uv
ARG uv=/root/.local/bin/uv
# Install `uv` for faster package boostrapping
ADD --chmod=755 https://astral.sh/uv/install.sh /install.sh

6
poetry.lock generated
View File

@ -1350,8 +1350,8 @@ files = [
[package.dependencies]
numpy = [
{version = ">=1.26.0", markers = "python_version >= \"3.12\""},
{version = ">=1.23.2", markers = "python_version == \"3.11\""},
{version = ">=1.26.0", markers = "python_version >= \"3.12\""},
{version = ">=1.22.4", markers = "python_version < \"3.11\""},
]
python-dateutil = ">=2.8.2"
@ -1669,8 +1669,8 @@ files = [
annotated-types = ">=0.6.0"
pydantic-core = "2.23.4"
typing-extensions = [
{version = ">=4.12.2", markers = "python_version >= \"3.13\""},
{version = ">=4.6.1", markers = "python_version < \"3.13\""},
{version = ">=4.12.2", markers = "python_version >= \"3.13\""},
]
[package.extras]
@ -3050,4 +3050,4 @@ type = ["pytest-mypy"]
[metadata]
lock-version = "2.0"
python-versions = "^3.9"
content-hash = "593a52e9f54e95b50074f1bc4b7cdbabe4fab325051c72b23219268c0c9aa3ba"
content-hash = "937f0cadb1a4566117dad8d0be6018ad1a8fe9aeb19c499d2a010d36ef391ee1"

View File

@ -1,6 +1,6 @@
[tool.poetry]
name = "reflex"
version = "0.6.5dev1"
version = "0.6.6dev1"
description = "Web apps in pure Python."
license = "Apache-2.0"
authors = [
@ -49,7 +49,7 @@ wrapt = [
{version = ">=1.11.0,<2.0", python = "<3.11"},
]
packaging = ">=23.1,<25.0"
reflex-hosting-cli = ">=0.1.5,<2.0"
reflex-hosting-cli = ">=0.1.15,<2.0"
charset-normalizer = ">=3.3.2,<4.0"
wheel = ">=0.42.0,<1.0"
build = ">=1.0.3,<2.0"

View File

@ -298,6 +298,7 @@ _MAPPING: dict = {
"components.moment": ["MomentDelta", "moment"],
"config": ["Config", "DBConfig"],
"constants": ["Env"],
"constants.colors": ["Color"],
"event": [
"EventChain",
"EventHandler",
@ -338,7 +339,7 @@ _MAPPING: dict = {
],
"istate.wrappers": ["get_state"],
"style": ["Style", "toggle_color_mode"],
"utils.imports": ["ImportVar"],
"utils.imports": ["ImportDict", "ImportVar"],
"utils.serializers": ["serializer"],
"vars": ["Var", "field", "Field"],
}

View File

@ -152,6 +152,7 @@ from .components.suneditor import editor as editor
from .config import Config as Config
from .config import DBConfig as DBConfig
from .constants import Env as Env
from .constants.colors import Color as Color
from .event import EventChain as EventChain
from .event import EventHandler as EventHandler
from .event import background as background
@ -192,6 +193,7 @@ from .state import dynamic as dynamic
from .state import var as var
from .style import Style as Style
from .style import toggle_color_mode as toggle_color_mode
from .utils.imports import ImportDict as ImportDict
from .utils.imports import ImportVar as ImportVar
from .utils.serializers import serializer as serializer
from .vars import Field as Field

View File

@ -130,8 +130,8 @@ class Base(BaseModel): # pyright: ignore [reportUnboundVariable]
Returns:
The value of the field.
"""
if isinstance(key, str) and key in self.__fields__:
if isinstance(key, str):
# Seems like this function signature was wrong all along?
# If the user wants a field that we know of, get it and pass it off to _get_value
key = getattr(self, key)
return getattr(self, key, key)
return key

View File

@ -2,14 +2,15 @@
from __future__ import annotations
from typing import Dict, List, Tuple
from typing import Dict, Tuple
from reflex.compiler.compiler import _compile_component
from reflex.components.component import Component
from reflex.components.el import div, p
from reflex.event import EventHandler
from reflex.components.datadisplay.logo import svg_logo
from reflex.components.el import a, button, details, div, h2, hr, p, pre, summary
from reflex.event import EventHandler, set_clipboard
from reflex.state import FrontendEventExceptionState
from reflex.vars.base import Var
from reflex.vars.function import ArgsFunctionOperation
def on_error_spec(
@ -40,38 +41,7 @@ class ErrorBoundary(Component):
on_error: EventHandler[on_error_spec]
# Rendered instead of the children when an error is caught.
Fallback_component: Var[Component] = Var(_js_expr="Fallback")._replace(
_var_type=Component
)
def add_custom_code(self) -> List[str]:
"""Add custom Javascript code into the page that contains this component.
Custom code is inserted at module level, after any imports.
Returns:
The custom code to add.
"""
fallback_container = div(
p("Ooops...Unknown Reflex error has occured:"),
p(
Var(_js_expr="error.message"),
color="red",
),
p("Please contact the support."),
)
compiled_fallback = _compile_component(fallback_container)
return [
f"""
function Fallback({{ error, resetErrorBoundary }}) {{
return (
{compiled_fallback}
);
}}
"""
]
fallback_render: Var[Component]
@classmethod
def create(cls, *children, **props):
@ -86,6 +56,99 @@ class ErrorBoundary(Component):
"""
if "on_error" not in props:
props["on_error"] = FrontendEventExceptionState.handle_frontend_exception
if "fallback_render" not in props:
props["fallback_render"] = ArgsFunctionOperation.create(
("event_args",),
Var.create(
div(
div(
div(
h2(
"An error occurred while rendering this page.",
font_size="1.25rem",
font_weight="bold",
),
p(
"This is an error with the application itself.",
opacity="0.75",
),
details(
summary("Error message", padding="0.5rem"),
div(
div(
pre(
Var(
_js_expr="event_args.error.stack",
),
),
padding="0.5rem",
width="fit-content",
),
width="100%",
max_height="50vh",
overflow="auto",
background="#000",
color="#fff",
border_radius="0.25rem",
),
button(
"Copy",
on_click=set_clipboard(
Var(_js_expr="event_args.error.stack"),
),
padding="0.35rem 0.75rem",
margin="0.5rem",
background="#fff",
color="#000",
border="1px solid #000",
border_radius="0.25rem",
font_weight="bold",
),
),
display="flex",
flex_direction="column",
gap="1rem",
max_width="50ch",
border="1px solid #888888",
border_radius="0.25rem",
padding="1rem",
),
hr(
border_color="currentColor",
opacity="0.25",
),
a(
div(
"Built with ",
svg_logo("currentColor"),
display="flex",
align_items="baseline",
justify_content="center",
font_family="monospace",
gap="0.5rem",
),
href="https://reflex.dev",
),
display="flex",
flex_direction="column",
gap="1rem",
),
height="100%",
width="100%",
position="absolute",
display="flex",
align_items="center",
justify_content="center",
)
),
_var_type=Component,
)
else:
props["fallback_render"] = ArgsFunctionOperation.create(
("event_args",),
props["fallback_render"],
_var_type=Component,
)
return super().create(*children, **props)

View File

@ -3,7 +3,7 @@
# ------------------- DO NOT EDIT ----------------------
# This file was generated by `reflex/utils/pyi_generator.py`!
# ------------------------------------------------------
from typing import Any, Dict, List, Optional, Tuple, Union, overload
from typing import Any, Dict, Optional, Tuple, Union, overload
from reflex.components.component import Component
from reflex.event import BASE_STATE, EventType
@ -15,13 +15,12 @@ def on_error_spec(
) -> Tuple[Var[str], Var[str]]: ...
class ErrorBoundary(Component):
def add_custom_code(self) -> List[str]: ...
@overload
@classmethod
def create( # type: ignore
cls,
*children,
Fallback_component: Optional[Union[Component, Var[Component]]] = None,
fallback_render: Optional[Union[Component, Var[Component]]] = None,
style: Optional[Style] = None,
key: Optional[Any] = None,
id: Optional[Any] = None,
@ -57,7 +56,7 @@ class ErrorBoundary(Component):
Args:
*children: The children of the component.
on_error: Fired when the boundary catches an error.
Fallback_component: Rendered instead of the children when an error is caught.
fallback_render: Rendered instead of the children when an error is caught.
style: The style of the component.
key: A unique key for the component.
id: The id for the component.

View File

@ -171,6 +171,14 @@ def cond(condition: Any, c1: Any, c2: Any = None) -> Component | Var:
)
@overload
def color_mode_cond(light: Component, dark: Component | None = None) -> Component: ... # type: ignore
@overload
def color_mode_cond(light: Any, dark: Any = None) -> Var: ...
def color_mode_cond(light: Any, dark: Any = None) -> Var | Component:
"""Create a component or Prop based on color_mode.

View File

@ -5,6 +5,7 @@ from __future__ import annotations
from pathlib import Path
from typing import Any, Callable, ClassVar, Dict, List, Optional, Tuple
from reflex.components.base.fragment import Fragment
from reflex.components.component import (
Component,
ComponentNamespace,
@ -181,6 +182,13 @@ class UploadFilesProvider(Component):
tag = "UploadFilesProvider"
class GhostUpload(Fragment):
"""A ghost upload component."""
# Fired when files are dropped.
on_drop: EventHandler[_on_drop_spec]
class Upload(MemoizationLeaf):
"""A file upload component."""
@ -276,8 +284,8 @@ class Upload(MemoizationLeaf):
root_props_unique_name = get_unique_variable_name()
event_var, callback_str = StatefulComponent._get_memoized_event_triggers(
Box.create(on_click=upload_props["on_drop"]) # type: ignore
)["on_click"]
GhostUpload.create(on_drop=upload_props["on_drop"])
)["on_drop"]
upload_props["on_drop"] = event_var

View File

@ -6,6 +6,7 @@
from pathlib import Path
from typing import Any, ClassVar, Dict, List, Optional, Union, overload
from reflex.components.base.fragment import Fragment
from reflex.components.component import Component, ComponentNamespace, MemoizationLeaf
from reflex.constants import Dirs
from reflex.event import BASE_STATE, CallableEventSpec, EventSpec, EventType
@ -84,6 +85,56 @@ class UploadFilesProvider(Component):
"""
...
class GhostUpload(Fragment):
@overload
@classmethod
def create( # type: ignore
cls,
*children,
style: Optional[Style] = None,
key: Optional[Any] = None,
id: Optional[Any] = None,
class_name: Optional[Any] = None,
autofocus: Optional[bool] = None,
custom_attrs: Optional[Dict[str, Union[Var, Any]]] = None,
on_blur: Optional[EventType[[], BASE_STATE]] = None,
on_click: Optional[EventType[[], BASE_STATE]] = None,
on_context_menu: Optional[EventType[[], BASE_STATE]] = None,
on_double_click: Optional[EventType[[], BASE_STATE]] = None,
on_drop: Optional[
Union[EventType[[], BASE_STATE], EventType[[Any], BASE_STATE]]
] = None,
on_focus: Optional[EventType[[], BASE_STATE]] = None,
on_mount: Optional[EventType[[], BASE_STATE]] = None,
on_mouse_down: Optional[EventType[[], BASE_STATE]] = None,
on_mouse_enter: Optional[EventType[[], BASE_STATE]] = None,
on_mouse_leave: Optional[EventType[[], BASE_STATE]] = None,
on_mouse_move: Optional[EventType[[], BASE_STATE]] = None,
on_mouse_out: Optional[EventType[[], BASE_STATE]] = None,
on_mouse_over: Optional[EventType[[], BASE_STATE]] = None,
on_mouse_up: Optional[EventType[[], BASE_STATE]] = None,
on_scroll: Optional[EventType[[], BASE_STATE]] = None,
on_unmount: Optional[EventType[[], BASE_STATE]] = None,
**props,
) -> "GhostUpload":
"""Create the component.
Args:
*children: The children of the component.
on_drop: Fired when files are dropped.
style: The style of the component.
key: A unique key for the component.
id: The id for the component.
class_name: The class name for the component.
autofocus: Whether the component should take the focus once the page is loaded
custom_attrs: custom attribute
**props: The props of the component.
Returns:
The component.
"""
...
class Upload(MemoizationLeaf):
is_used: ClassVar[bool] = False

View File

@ -1,22 +1,23 @@
"""A Reflex logo component."""
from typing import Union
import reflex as rx
def logo(**props):
"""A Reflex logo.
def svg_logo(color: Union[str, rx.Var[str]] = rx.color_mode_cond("#110F1F", "white")):
"""A Reflex logo SVG.
Args:
**props: The props to pass to the component.
color: The color of the logo.
Returns:
The logo component.
The Reflex logo SVG.
"""
def logo_path(d):
return rx.el.svg.path(
d=d,
fill=rx.color_mode_cond("#110F1F", "white"),
)
paths = [
@ -28,18 +29,30 @@ def logo(**props):
"M47.04 4.8799V0.399902H49.28V4.8799H47.04ZM53.76 4.8799V0.399902H56V4.8799H53.76ZM49.28 7.1199V4.8799H53.76V7.1199H49.28ZM47.04 11.5999V7.1199H49.28V11.5999H47.04ZM53.76 11.5999V7.1199H56V11.5999H53.76Z",
]
return rx.el.svg(
*[logo_path(d) for d in paths],
width="56",
height="12",
viewBox="0 0 56 12",
fill=color,
xmlns="http://www.w3.org/2000/svg",
)
def logo(**props):
"""A Reflex logo.
Args:
**props: The props to pass to the component.
Returns:
The logo component.
"""
return rx.center(
rx.link(
rx.hstack(
"Built with ",
rx.el.svg(
*[logo_path(d) for d in paths],
width="56",
height="12",
viewBox="0 0 56 12",
fill="none",
xmlns="http://www.w3.org/2000/svg",
),
svg_logo(),
text_align="center",
align="center",
padding="1em",

View File

@ -112,6 +112,9 @@ class RadixThemesComponent(Component):
library = "@radix-ui/themes@^3.0.0"
# Temporary pin < 3.1.5 until radix-ui/themes#627 is resolved.
library = library + " && <3.1.5"
# "Fake" prop color_scheme is used to avoid shadowing CSS prop "color".
_rename_props: Dict[str, str] = {"colorScheme": "color"}

View File

@ -45,6 +45,8 @@ from reflex.vars import VarData
from reflex.vars.base import LiteralVar, Var
from reflex.vars.function import (
ArgsFunctionOperation,
ArgsFunctionOperationBuilder,
BuilderFunctionVar,
FunctionArgs,
FunctionStringVar,
FunctionVar,
@ -797,8 +799,7 @@ def scroll_to(elem_id: str, align_to_top: bool | Var[bool] = True) -> EventSpec:
get_element_by_id = FunctionStringVar.create("document.getElementById")
return run_script(
get_element_by_id(elem_id)
.call(elem_id)
get_element_by_id.call(elem_id)
.to(ObjectVar)
.scrollIntoView.to(FunctionVar)
.call(align_to_top),
@ -899,7 +900,7 @@ def remove_session_storage(key: str) -> EventSpec:
)
def set_clipboard(content: str) -> EventSpec:
def set_clipboard(content: Union[str, Var[str]]) -> EventSpec:
"""Set the text in content in the clipboard.
Args:
@ -1580,7 +1581,7 @@ class LiteralEventVar(VarOperationCall, LiteralVar, EventVar):
)
class EventChainVar(FunctionVar, python_types=EventChain):
class EventChainVar(BuilderFunctionVar, python_types=EventChain):
"""Base class for event chain vars."""
@ -1592,7 +1593,7 @@ class EventChainVar(FunctionVar, python_types=EventChain):
# Note: LiteralVar is second in the inheritance list allowing it act like a
# CachedVarOperation (ArgsFunctionOperation) and get the _js_expr from the
# _cached_var_name property.
class LiteralEventChainVar(ArgsFunctionOperation, LiteralVar, EventChainVar):
class LiteralEventChainVar(ArgsFunctionOperationBuilder, LiteralVar, EventChainVar):
"""A literal event chain var."""
_var_value: EventChain = dataclasses.field(default=None) # type: ignore

View File

@ -106,7 +106,7 @@ def _init(
template = constants.Templates.DEFAULT
# Initialize the app.
prerequisites.initialize_app(app_name, template)
template = prerequisites.initialize_app(app_name, template)
# If a reflex.build generation hash is available, download the code and apply it to the main module.
if generation_hash:
@ -120,8 +120,9 @@ def _init(
# Initialize the requirements.txt.
prerequisites.initialize_requirements_txt()
template_msg = "" if template else f" using the {template} template"
# Finish initializing the app.
console.success(f"Initialized {app_name}")
console.success(f"Initialized {app_name}{template_msg}")
@cli.command()

View File

@ -46,6 +46,7 @@ from reflex import event
from reflex.config import get_config
from reflex.istate.data import RouterData
from reflex.istate.storage import ClientStorageBase
from reflex.model import Model
from reflex.vars.base import (
ComputedVar,
DynamicRouteVar,
@ -1740,15 +1741,20 @@ class BaseState(Base, ABC, extra=pydantic.Extra.allow):
if value is None:
continue
hinted_args = value_inside_optional(hinted_args)
if (
isinstance(value, dict)
and inspect.isclass(hinted_args)
and (
dataclasses.is_dataclass(hinted_args)
or issubclass(hinted_args, Base)
)
):
payload[arg] = hinted_args(**value)
if isinstance(value, dict) and inspect.isclass(hinted_args):
if issubclass(hinted_args, Model):
# Remove non-fields from the payload
payload[arg] = hinted_args(
**{
key: value
for key, value in value.items()
if key in hinted_args.__fields__
}
)
elif dataclasses.is_dataclass(hinted_args) or issubclass(
hinted_args, Base
):
payload[arg] = hinted_args(**value)
if isinstance(value, list) and (hinted_args is set or hinted_args is Set):
payload[arg] = set(value)
if isinstance(value, list) and (
@ -1891,7 +1897,7 @@ class BaseState(Base, ABC, extra=pydantic.Extra.allow):
)
subdelta: Dict[str, Any] = {
prop: self.get_value(getattr(self, prop))
prop: self.get_value(prop)
for prop in delta_vars
if not types.is_backend_base_variable(prop, type(self))
}
@ -1983,9 +1989,10 @@ class BaseState(Base, ABC, extra=pydantic.Extra.allow):
Returns:
The value of the field.
"""
if isinstance(key, MutableProxy):
return super().get_value(key.__wrapped__)
return super().get_value(key)
value = super().get_value(key)
if isinstance(value, MutableProxy):
return value.__wrapped__
return value
def dict(
self, include_computed: bool = True, initial: bool = False, **kwargs
@ -2007,8 +2014,7 @@ class BaseState(Base, ABC, extra=pydantic.Extra.allow):
self._mark_dirty()
base_vars = {
prop_name: self.get_value(getattr(self, prop_name))
for prop_name in self.base_vars
prop_name: self.get_value(prop_name) for prop_name in self.base_vars
}
if initial and include_computed:
computed_vars = {
@ -2017,7 +2023,7 @@ class BaseState(Base, ABC, extra=pydantic.Extra.allow):
cv._initial_value
if is_computed_var(cv)
and not isinstance(cv._initial_value, types.Unset)
else self.get_value(getattr(self, prop_name))
else self.get_value(prop_name)
)
for prop_name, cv in self.computed_vars.items()
if not cv._backend
@ -2025,7 +2031,7 @@ class BaseState(Base, ABC, extra=pydantic.Extra.allow):
elif include_computed:
computed_vars = {
# Include the computed vars.
prop_name: self.get_value(getattr(self, prop_name))
prop_name: self.get_value(prop_name)
for prop_name, cv in self.computed_vars.items()
if not cv._backend
}

View File

@ -1378,7 +1378,7 @@ def create_config_init_app_from_remote_template(app_name: str, template_url: str
shutil.rmtree(unzip_dir)
def initialize_app(app_name: str, template: str | None = None):
def initialize_app(app_name: str, template: str | None = None) -> str | None:
"""Initialize the app either from a remote template or a blank app. If the config file exists, it is considered as reinit.
Args:
@ -1387,6 +1387,9 @@ def initialize_app(app_name: str, template: str | None = None):
Raises:
Exit: If template is directly provided in the command flag and is invalid.
Returns:
The name of the template.
"""
# Local imports to avoid circular imports.
from reflex.utils import telemetry
@ -1441,6 +1444,7 @@ def initialize_app(app_name: str, template: str | None = None):
)
telemetry.send("init", template=template)
return template
def initialize_main_module_index_from_generation(app_name: str, generation_hash: str):

View File

@ -51,7 +51,8 @@ def get_python_version() -> str:
Returns:
The Python version.
"""
return platform.python_version()
# Remove the "+" from the version string in case user is using a pre-release version.
return platform.python_version().rstrip("+")
def get_reflex_version() -> str:

View File

@ -361,21 +361,29 @@ class Var(Generic[VAR_TYPE]):
return False
def __init_subclass__(
cls, python_types: Tuple[GenericType, ...] | GenericType = types.Unset, **kwargs
cls,
python_types: Tuple[GenericType, ...] | GenericType = types.Unset(),
default_type: GenericType = types.Unset(),
**kwargs,
):
"""Initialize the subclass.
Args:
python_types: The python types that the var represents.
default_type: The default type of the var. Defaults to the first python type.
**kwargs: Additional keyword arguments.
"""
super().__init_subclass__(**kwargs)
if python_types is not types.Unset:
if python_types or default_type:
python_types = (
python_types if isinstance(python_types, tuple) else (python_types,)
(python_types if isinstance(python_types, tuple) else (python_types,))
if python_types
else ()
)
default_type = default_type or (python_types[0] if python_types else Any)
@dataclasses.dataclass(
eq=False,
frozen=True,
@ -388,7 +396,7 @@ class Var(Generic[VAR_TYPE]):
default=Var(_js_expr="null", _var_type=None),
)
_default_var_type: ClassVar[GenericType] = python_types[0]
_default_var_type: ClassVar[GenericType] = default_type
ToVarOperation.__name__ = f'To{cls.__name__.removesuffix("Var")}Operation'
@ -588,6 +596,12 @@ class Var(Generic[VAR_TYPE]):
output: type[list] | type[tuple] | type[set],
) -> ArrayVar: ...
@overload
def to(
self,
output: type[dict],
) -> ObjectVar[dict]: ...
@overload
def to(
self, output: Type[ObjectVar], var_type: Type[VAR_INSIDE]

View File

@ -4,32 +4,177 @@ from __future__ import annotations
import dataclasses
import sys
from typing import Any, Callable, Optional, Sequence, Tuple, Type, Union
from typing import Any, Callable, Optional, Sequence, Tuple, Type, Union, overload
from typing_extensions import Concatenate, Generic, ParamSpec, Protocol, TypeVar
from reflex.utils import format
from reflex.utils.types import GenericType
from .base import CachedVarOperation, LiteralVar, Var, VarData, cached_property_no_lock
P = ParamSpec("P")
V1 = TypeVar("V1")
V2 = TypeVar("V2")
V3 = TypeVar("V3")
V4 = TypeVar("V4")
V5 = TypeVar("V5")
V6 = TypeVar("V6")
R = TypeVar("R")
class FunctionVar(Var[Callable], python_types=Callable):
class ReflexCallable(Protocol[P, R]):
"""Protocol for a callable."""
__call__: Callable[P, R]
CALLABLE_TYPE = TypeVar("CALLABLE_TYPE", bound=ReflexCallable, infer_variance=True)
OTHER_CALLABLE_TYPE = TypeVar(
"OTHER_CALLABLE_TYPE", bound=ReflexCallable, infer_variance=True
)
class FunctionVar(Var[CALLABLE_TYPE], default_type=ReflexCallable[Any, Any]):
"""Base class for immutable function vars."""
def __call__(self, *args: Var | Any) -> ArgsFunctionOperation:
"""Call the function with the given arguments.
@overload
def partial(self) -> FunctionVar[CALLABLE_TYPE]: ...
@overload
def partial(
self: FunctionVar[ReflexCallable[Concatenate[V1, P], R]],
arg1: Union[V1, Var[V1]],
) -> FunctionVar[ReflexCallable[P, R]]: ...
@overload
def partial(
self: FunctionVar[ReflexCallable[Concatenate[V1, V2, P], R]],
arg1: Union[V1, Var[V1]],
arg2: Union[V2, Var[V2]],
) -> FunctionVar[ReflexCallable[P, R]]: ...
@overload
def partial(
self: FunctionVar[ReflexCallable[Concatenate[V1, V2, V3, P], R]],
arg1: Union[V1, Var[V1]],
arg2: Union[V2, Var[V2]],
arg3: Union[V3, Var[V3]],
) -> FunctionVar[ReflexCallable[P, R]]: ...
@overload
def partial(
self: FunctionVar[ReflexCallable[Concatenate[V1, V2, V3, V4, P], R]],
arg1: Union[V1, Var[V1]],
arg2: Union[V2, Var[V2]],
arg3: Union[V3, Var[V3]],
arg4: Union[V4, Var[V4]],
) -> FunctionVar[ReflexCallable[P, R]]: ...
@overload
def partial(
self: FunctionVar[ReflexCallable[Concatenate[V1, V2, V3, V4, V5, P], R]],
arg1: Union[V1, Var[V1]],
arg2: Union[V2, Var[V2]],
arg3: Union[V3, Var[V3]],
arg4: Union[V4, Var[V4]],
arg5: Union[V5, Var[V5]],
) -> FunctionVar[ReflexCallable[P, R]]: ...
@overload
def partial(
self: FunctionVar[ReflexCallable[Concatenate[V1, V2, V3, V4, V5, V6, P], R]],
arg1: Union[V1, Var[V1]],
arg2: Union[V2, Var[V2]],
arg3: Union[V3, Var[V3]],
arg4: Union[V4, Var[V4]],
arg5: Union[V5, Var[V5]],
arg6: Union[V6, Var[V6]],
) -> FunctionVar[ReflexCallable[P, R]]: ...
@overload
def partial(
self: FunctionVar[ReflexCallable[P, R]], *args: Var | Any
) -> FunctionVar[ReflexCallable[P, R]]: ...
@overload
def partial(self, *args: Var | Any) -> FunctionVar: ...
def partial(self, *args: Var | Any) -> FunctionVar: # type: ignore
"""Partially apply the function with the given arguments.
Args:
*args: The arguments to call the function with.
*args: The arguments to partially apply the function with.
Returns:
The function call operation.
The partially applied function.
"""
if not args:
return ArgsFunctionOperation.create((), self)
return ArgsFunctionOperation.create(
("...args",),
VarOperationCall.create(self, *args, Var(_js_expr="...args")),
)
def call(self, *args: Var | Any) -> VarOperationCall:
@overload
def call(
self: FunctionVar[ReflexCallable[[V1], R]], arg1: Union[V1, Var[V1]]
) -> VarOperationCall[[V1], R]: ...
@overload
def call(
self: FunctionVar[ReflexCallable[[V1, V2], R]],
arg1: Union[V1, Var[V1]],
arg2: Union[V2, Var[V2]],
) -> VarOperationCall[[V1, V2], R]: ...
@overload
def call(
self: FunctionVar[ReflexCallable[[V1, V2, V3], R]],
arg1: Union[V1, Var[V1]],
arg2: Union[V2, Var[V2]],
arg3: Union[V3, Var[V3]],
) -> VarOperationCall[[V1, V2, V3], R]: ...
@overload
def call(
self: FunctionVar[ReflexCallable[[V1, V2, V3, V4], R]],
arg1: Union[V1, Var[V1]],
arg2: Union[V2, Var[V2]],
arg3: Union[V3, Var[V3]],
arg4: Union[V4, Var[V4]],
) -> VarOperationCall[[V1, V2, V3, V4], R]: ...
@overload
def call(
self: FunctionVar[ReflexCallable[[V1, V2, V3, V4, V5], R]],
arg1: Union[V1, Var[V1]],
arg2: Union[V2, Var[V2]],
arg3: Union[V3, Var[V3]],
arg4: Union[V4, Var[V4]],
arg5: Union[V5, Var[V5]],
) -> VarOperationCall[[V1, V2, V3, V4, V5], R]: ...
@overload
def call(
self: FunctionVar[ReflexCallable[[V1, V2, V3, V4, V5, V6], R]],
arg1: Union[V1, Var[V1]],
arg2: Union[V2, Var[V2]],
arg3: Union[V3, Var[V3]],
arg4: Union[V4, Var[V4]],
arg5: Union[V5, Var[V5]],
arg6: Union[V6, Var[V6]],
) -> VarOperationCall[[V1, V2, V3, V4, V5, V6], R]: ...
@overload
def call(
self: FunctionVar[ReflexCallable[P, R]], *args: Var | Any
) -> VarOperationCall[P, R]: ...
@overload
def call(self, *args: Var | Any) -> Var: ...
def call(self, *args: Var | Any) -> Var: # type: ignore
"""Call the function with the given arguments.
Args:
@ -38,19 +183,29 @@ class FunctionVar(Var[Callable], python_types=Callable):
Returns:
The function call operation.
"""
return VarOperationCall.create(self, *args)
return VarOperationCall.create(self, *args).guess_type()
__call__ = call
class FunctionStringVar(FunctionVar):
class BuilderFunctionVar(
FunctionVar[CALLABLE_TYPE], default_type=ReflexCallable[Any, Any]
):
"""Base class for immutable function vars with the builder pattern."""
__call__ = FunctionVar.partial
class FunctionStringVar(FunctionVar[CALLABLE_TYPE]):
"""Base class for immutable function vars from a string."""
@classmethod
def create(
cls,
func: str,
_var_type: Type[Callable] = Callable,
_var_type: Type[OTHER_CALLABLE_TYPE] = ReflexCallable[Any, Any],
_var_data: VarData | None = None,
) -> FunctionStringVar:
) -> FunctionStringVar[OTHER_CALLABLE_TYPE]:
"""Create a new function var from a string.
Args:
@ -60,7 +215,7 @@ class FunctionStringVar(FunctionVar):
Returns:
The function var.
"""
return cls(
return FunctionStringVar(
_js_expr=func,
_var_type=_var_type,
_var_data=_var_data,
@ -72,10 +227,10 @@ class FunctionStringVar(FunctionVar):
frozen=True,
**{"slots": True} if sys.version_info >= (3, 10) else {},
)
class VarOperationCall(CachedVarOperation, Var):
class VarOperationCall(Generic[P, R], CachedVarOperation, Var[R]):
"""Base class for immutable vars that are the result of a function call."""
_func: Optional[FunctionVar] = dataclasses.field(default=None)
_func: Optional[FunctionVar[ReflexCallable[P, R]]] = dataclasses.field(default=None)
_args: Tuple[Union[Var, Any], ...] = dataclasses.field(default_factory=tuple)
@cached_property_no_lock
@ -103,7 +258,7 @@ class VarOperationCall(CachedVarOperation, Var):
@classmethod
def create(
cls,
func: FunctionVar,
func: FunctionVar[ReflexCallable[P, R]],
*args: Var | Any,
_var_type: GenericType = Any,
_var_data: VarData | None = None,
@ -118,9 +273,15 @@ class VarOperationCall(CachedVarOperation, Var):
Returns:
The function call var.
"""
function_return_type = (
func._var_type.__args__[1]
if getattr(func._var_type, "__args__", None)
else Any
)
var_type = _var_type if _var_type is not Any else function_return_type
return cls(
_js_expr="",
_var_type=_var_type,
_var_type=var_type,
_var_data=_var_data,
_func=func,
_args=args,
@ -157,6 +318,33 @@ class FunctionArgs:
rest: Optional[str] = None
def format_args_function_operation(
args: FunctionArgs, return_expr: Var | Any, explicit_return: bool
) -> str:
"""Format an args function operation.
Args:
args: The function arguments.
return_expr: The return expression.
explicit_return: Whether to use explicit return syntax.
Returns:
The formatted args function operation.
"""
arg_names_str = ", ".join(
[arg if isinstance(arg, str) else arg.to_javascript() for arg in args.args]
) + (f", ...{args.rest}" if args.rest else "")
return_expr_str = str(LiteralVar.create(return_expr))
# Wrap return expression in curly braces if explicit return syntax is used.
return_expr_str_wrapped = (
format.wrap(return_expr_str, "{", "}") if explicit_return else return_expr_str
)
return f"(({arg_names_str}) => {return_expr_str_wrapped})"
@dataclasses.dataclass(
eq=False,
frozen=True,
@ -176,24 +364,10 @@ class ArgsFunctionOperation(CachedVarOperation, FunctionVar):
Returns:
The name of the var.
"""
arg_names_str = ", ".join(
[
arg if isinstance(arg, str) else arg.to_javascript()
for arg in self._args.args
]
) + (f", ...{self._args.rest}" if self._args.rest else "")
return_expr_str = str(LiteralVar.create(self._return_expr))
# Wrap return expression in curly braces if explicit return syntax is used.
return_expr_str_wrapped = (
format.wrap(return_expr_str, "{", "}")
if self._explicit_return
else return_expr_str
return format_args_function_operation(
self._args, self._return_expr, self._explicit_return
)
return f"(({arg_names_str}) => {return_expr_str_wrapped})"
@classmethod
def create(
cls,
@ -203,7 +377,7 @@ class ArgsFunctionOperation(CachedVarOperation, FunctionVar):
explicit_return: bool = False,
_var_type: GenericType = Callable,
_var_data: VarData | None = None,
) -> ArgsFunctionOperation:
):
"""Create a new function var.
Args:
@ -226,8 +400,80 @@ class ArgsFunctionOperation(CachedVarOperation, FunctionVar):
)
JSON_STRINGIFY = FunctionStringVar.create("JSON.stringify")
ARRAY_ISARRAY = FunctionStringVar.create("Array.isArray")
PROTOTYPE_TO_STRING = FunctionStringVar.create(
"((__to_string) => __to_string.toString())"
@dataclasses.dataclass(
eq=False,
frozen=True,
**{"slots": True} if sys.version_info >= (3, 10) else {},
)
class ArgsFunctionOperationBuilder(CachedVarOperation, BuilderFunctionVar):
"""Base class for immutable function defined via arguments and return expression with the builder pattern."""
_args: FunctionArgs = dataclasses.field(default_factory=FunctionArgs)
_return_expr: Union[Var, Any] = dataclasses.field(default=None)
_explicit_return: bool = dataclasses.field(default=False)
@cached_property_no_lock
def _cached_var_name(self) -> str:
"""The name of the var.
Returns:
The name of the var.
"""
return format_args_function_operation(
self._args, self._return_expr, self._explicit_return
)
@classmethod
def create(
cls,
args_names: Sequence[Union[str, DestructuredArg]],
return_expr: Var | Any,
rest: str | None = None,
explicit_return: bool = False,
_var_type: GenericType = Callable,
_var_data: VarData | None = None,
):
"""Create a new function var.
Args:
args_names: The names of the arguments.
return_expr: The return expression of the function.
rest: The name of the rest argument.
explicit_return: Whether to use explicit return syntax.
_var_data: Additional hooks and imports associated with the Var.
Returns:
The function var.
"""
return cls(
_js_expr="",
_var_type=_var_type,
_var_data=_var_data,
_args=FunctionArgs(args=tuple(args_names), rest=rest),
_return_expr=return_expr,
_explicit_return=explicit_return,
)
if python_version := sys.version_info[:2] >= (3, 10):
JSON_STRINGIFY = FunctionStringVar.create(
"JSON.stringify", _var_type=ReflexCallable[[Any], str]
)
ARRAY_ISARRAY = FunctionStringVar.create(
"Array.isArray", _var_type=ReflexCallable[[Any], bool]
)
PROTOTYPE_TO_STRING = FunctionStringVar.create(
"((__to_string) => __to_string.toString())",
_var_type=ReflexCallable[[Any], str],
)
else:
JSON_STRINGIFY = FunctionStringVar.create(
"JSON.stringify", _var_type=ReflexCallable[Any, str]
)
ARRAY_ISARRAY = FunctionStringVar.create(
"Array.isArray", _var_type=ReflexCallable[Any, bool]
)
PROTOTYPE_TO_STRING = FunctionStringVar.create(
"((__to_string) => __to_string.toString())",
_var_type=ReflexCallable[Any, str],
)

View File

@ -4,6 +4,7 @@ from reflex.components.core.banner import (
ConnectionPulser,
WebsocketTargetURL,
)
from reflex.components.radix.themes.base import RadixThemesComponent
from reflex.components.radix.themes.typography.text import Text
@ -24,7 +25,7 @@ def test_connection_banner():
"react",
"$/utils/context",
"$/utils/state",
"@radix-ui/themes@^3.0.0",
RadixThemesComponent().library or "",
"$/env.json",
)
)
@ -42,7 +43,7 @@ def test_connection_modal():
"react",
"$/utils/context",
"$/utils/state",
"@radix-ui/themes@^3.0.0",
RadixThemesComponent().library or "",
"$/env.json",
)
)

View File

@ -3460,3 +3460,34 @@ def test_mutable_models():
# state.dc.foo = "baz"
# assert state.dirty_vars == {"dc"}
# state.dirty_vars.clear()
def test_get_value():
class GetValueState(rx.State):
foo: str = "FOO"
bar: str = "BAR"
state = GetValueState()
assert state.dict() == {
state.get_full_name(): {
"foo": "FOO",
"bar": "BAR",
}
}
assert state.get_delta() == {}
state.bar = "foo"
assert state.dict() == {
state.get_full_name(): {
"foo": "FOO",
"bar": "foo",
}
}
assert state.get_delta() == {
state.get_full_name(): {
"bar": "foo",
}
}

View File

@ -928,7 +928,7 @@ def test_function_var():
== '(((a, b) => ({ ["args"] : [a, b], ["result"] : a + b }))(1, 2))'
)
increment_func = addition_func(1)
increment_func = addition_func.partial(1)
assert (
str(increment_func.call(2))
== "(((...args) => (((a, b) => a + b)(1, ...args)))(2))"