Remove Pydantic from some classes (#3907)

* half of the way there

* add dataclass support

* Forbid Computed var shadowing (#3843)

* get it right pyright

* fix unit tests

* rip out more pydantic

* fix weird issues with merge_imports

* add missing docstring

* make special props a list instead of a set

* fix moment pyi

* actually ignore the runtime error

* it's ruff out there

---------

Co-authored-by: benedikt-bartscher <31854409+benedikt-bartscher@users.noreply.github.com>
This commit is contained in:
Khaleel Al-Adhami 2024-09-13 12:53:30 -07:00 committed by GitHub
parent 7c25358607
commit 8f937f0417
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
27 changed files with 394 additions and 248 deletions

View File

@ -9,6 +9,7 @@ import copy
import functools
import inspect
import io
import json
import multiprocessing
import os
import platform
@ -1096,6 +1097,7 @@ class App(MiddlewareMixin, LifespanMixin, Base):
if delta:
# When the state is modified reset dirty status and emit the delta to the frontend.
state._clean()
print(dir(state.router))
await self.event_namespace.emit_update(
update=StateUpdate(delta=delta),
sid=state.router.session.session_id,
@ -1531,8 +1533,9 @@ class EventNamespace(AsyncNamespace):
sid: The Socket.IO session id.
data: The event data.
"""
fields = json.loads(data)
# Get the event.
event = Event.parse_raw(data)
event = Event(**{k: v for k, v in fields.items() if k != "handler"})
self.token_to_sid[event.token] = sid
self.sid_to_token[sid] = event.token

View File

@ -4,6 +4,7 @@ from __future__ import annotations
import copy
import typing
import warnings
from abc import ABC, abstractmethod
from functools import lru_cache, wraps
from hashlib import md5
@ -169,6 +170,8 @@ ComponentStyle = Dict[
]
ComponentChild = Union[types.PrimitiveType, Var, BaseComponent]
warnings.filterwarnings("ignore", message="fields may not start with an underscore")
class Component(BaseComponent, ABC):
"""A component with style, event trigger and other props."""
@ -195,7 +198,7 @@ class Component(BaseComponent, ABC):
class_name: Any = None
# Special component props.
special_props: Set[ImmutableVar] = set()
special_props: List[ImmutableVar] = []
# Whether the component should take the focus once the page is loaded
autofocus: bool = False
@ -655,7 +658,7 @@ class Component(BaseComponent, ABC):
"""
# Create the base tag.
tag = Tag(
name=self.tag if not self.alias else self.alias,
name=(self.tag if not self.alias else self.alias) or "",
special_props=self.special_props,
)
@ -2244,7 +2247,7 @@ class StatefulComponent(BaseComponent):
Returns:
The tag to render.
"""
return dict(Tag(name=self.tag))
return dict(Tag(name=self.tag or ""))
def __str__(self) -> str:
"""Represent the component in React.

View File

@ -247,9 +247,9 @@ class Upload(MemoizationLeaf):
}
# The file input to use.
upload = Input.create(type="file")
upload.special_props = {
upload.special_props = [
ImmutableVar(_var_name="{...getInputProps()}", _var_type=None)
}
]
# The dropzone to use.
zone = Box.create(
@ -257,9 +257,9 @@ class Upload(MemoizationLeaf):
*children,
**{k: v for k, v in props.items() if k not in supported_props},
)
zone.special_props = {
zone.special_props = [
ImmutableVar(_var_name="{...getRootProps()}", _var_type=None)
}
]
# Create the component.
upload_props["id"] = props.get("id", DEFAULT_UPLOAD_ID)

View File

@ -1,6 +1,6 @@
"""Element classes. This is an auto-generated file. Do not edit. See ../generate.py."""
from typing import Set, Union
from typing import List, Union
from reflex.components.el.element import Element
from reflex.ivars.base import ImmutableVar
@ -90,9 +90,9 @@ class StyleEl(Element): # noqa: E742
media: Var[Union[str, int, bool]]
special_props: Set[ImmutableVar] = {
special_props: List[ImmutableVar] = [
ImmutableVar.create_safe("suppressHydrationWarning")
}
]
base = Base.create

View File

@ -195,17 +195,17 @@ class Markdown(Component):
if tag not in self.component_map:
raise ValueError(f"No markdown component found for tag: {tag}.")
special_props = {_PROPS_IN_TAG}
special_props = [_PROPS_IN_TAG]
children = [_CHILDREN]
# For certain tags, the props from the markdown renderer are not actually valid for the component.
if tag in NO_PROPS_TAGS:
special_props = set()
special_props = []
# If the children are set as a prop, don't pass them as children.
children_prop = props.pop("children", None)
if children_prop is not None:
special_props.add(
special_props.append(
ImmutableVar.create_safe(f"children={{{str(children_prop)}}}")
)
children = []

View File

@ -1,26 +1,27 @@
"""Moment component for humanized date rendering."""
import dataclasses
from typing import List, Optional
from reflex.base import Base
from reflex.components.component import Component, NoSSRComponent
from reflex.event import EventHandler
from reflex.utils.imports import ImportDict
from reflex.vars import Var
class MomentDelta(Base):
@dataclasses.dataclass(frozen=True)
class MomentDelta:
"""A delta used for add/subtract prop in Moment."""
years: Optional[int]
quarters: Optional[int]
months: Optional[int]
weeks: Optional[int]
days: Optional[int]
hours: Optional[int]
minutess: Optional[int]
seconds: Optional[int]
milliseconds: Optional[int]
years: Optional[int] = dataclasses.field(default=None)
quarters: Optional[int] = dataclasses.field(default=None)
months: Optional[int] = dataclasses.field(default=None)
weeks: Optional[int] = dataclasses.field(default=None)
days: Optional[int] = dataclasses.field(default=None)
hours: Optional[int] = dataclasses.field(default=None)
minutess: Optional[int] = dataclasses.field(default=None)
seconds: Optional[int] = dataclasses.field(default=None)
milliseconds: Optional[int] = dataclasses.field(default=None)
class Moment(NoSSRComponent):

View File

@ -3,9 +3,9 @@
# ------------------- DO NOT EDIT ----------------------
# This file was generated by `reflex/utils/pyi_generator.py`!
# ------------------------------------------------------
import dataclasses
from typing import Any, Callable, Dict, Optional, Union, overload
from reflex.base import Base
from reflex.components.component import NoSSRComponent
from reflex.event import EventHandler, EventSpec
from reflex.ivars.base import ImmutableVar
@ -13,7 +13,8 @@ from reflex.style import Style
from reflex.utils.imports import ImportDict
from reflex.vars import Var
class MomentDelta(Base):
@dataclasses.dataclass(frozen=True)
class MomentDelta:
years: Optional[int]
quarters: Optional[int]
months: Optional[int]

View File

@ -267,7 +267,7 @@ const extractPoints = (points) => {
template_dict = LiteralVar.create({"layout": {"template": self.template}})
merge_dicts.append(template_dict.without_data())
if merge_dicts:
tag.special_props.add(
tag.special_props.append(
# Merge all dictionaries and spread the result over props.
ImmutableVar.create_safe(
f"{{...mergician({str(figure)},"
@ -276,5 +276,5 @@ const extractPoints = (points) => {
)
else:
# Spread the figure dict over props, nothing to merge.
tag.special_props.add(ImmutableVar.create_safe(f"{{...{str(figure)}}}"))
tag.special_props.append(ImmutableVar.create_safe(f"{{...{str(figure)}}}"))
return tag

View File

@ -1,19 +1,22 @@
"""Tag to conditionally render components."""
import dataclasses
from typing import Any, Dict, Optional
from reflex.components.tags.tag import Tag
from reflex.ivars.base import LiteralVar
from reflex.vars import Var
@dataclasses.dataclass()
class CondTag(Tag):
"""A conditional tag."""
# The condition to determine which component to render.
cond: Var[Any]
cond: Var[Any] = dataclasses.field(default_factory=lambda: LiteralVar.create(True))
# The code to render if the condition is true.
true_value: Dict
true_value: Dict = dataclasses.field(default_factory=dict)
# The code to render if the condition is false.
false_value: Optional[Dict]
false_value: Optional[Dict] = None

View File

@ -2,31 +2,36 @@
from __future__ import annotations
import dataclasses
import inspect
from typing import TYPE_CHECKING, Any, Callable, List, Tuple, Type, Union, get_args
from reflex.components.tags.tag import Tag
from reflex.ivars.base import ImmutableVar
from reflex.vars import Var
from reflex.ivars.sequence import LiteralArrayVar
from reflex.vars import Var, get_unique_variable_name
if TYPE_CHECKING:
from reflex.components.component import Component
@dataclasses.dataclass()
class IterTag(Tag):
"""An iterator tag."""
# The var to iterate over.
iterable: Var[List]
iterable: Var[List] = dataclasses.field(
default_factory=lambda: LiteralArrayVar.create([])
)
# The component render function for each item in the iterable.
render_fn: Callable
render_fn: Callable = dataclasses.field(default_factory=lambda: lambda x: x)
# The name of the arg var.
arg_var_name: str
arg_var_name: str = dataclasses.field(default_factory=get_unique_variable_name)
# The name of the index var.
index_var_name: str
index_var_name: str = dataclasses.field(default_factory=get_unique_variable_name)
def get_iterable_var_type(self) -> Type:
"""Get the type of the iterable var.

View File

@ -1,19 +1,22 @@
"""Tag to conditionally match cases."""
import dataclasses
from typing import Any, List
from reflex.components.tags.tag import Tag
from reflex.ivars.base import LiteralVar
from reflex.vars import Var
@dataclasses.dataclass()
class MatchTag(Tag):
"""A match tag."""
# The condition to determine which case to match.
cond: Var[Any]
cond: Var[Any] = dataclasses.field(default_factory=lambda: LiteralVar.create(True))
# The list of match cases to be matched.
match_cases: List[Any]
match_cases: List[Any] = dataclasses.field(default_factory=list)
# The catchall case to match.
default: Any
default: Any = dataclasses.field(default=LiteralVar.create(None))

View File

@ -2,22 +2,23 @@
from __future__ import annotations
from typing import Any, Dict, List, Optional, Set, Tuple, Union
import dataclasses
from typing import Any, Dict, List, Optional, Tuple, Union
from reflex.base import Base
from reflex.event import EventChain
from reflex.ivars.base import ImmutableVar, LiteralVar
from reflex.utils import format, types
class Tag(Base):
@dataclasses.dataclass()
class Tag:
"""A React tag."""
# The name of the tag.
name: str = ""
# The props of the tag.
props: Dict[str, Any] = {}
props: Dict[str, Any] = dataclasses.field(default_factory=dict)
# The inner contents of the tag.
contents: str = ""
@ -26,25 +27,18 @@ class Tag(Base):
args: Optional[Tuple[str, ...]] = None
# Special props that aren't key value pairs.
special_props: Set[ImmutableVar] = set()
special_props: List[ImmutableVar] = dataclasses.field(default_factory=list)
# The children components.
children: List[Any] = []
children: List[Any] = dataclasses.field(default_factory=list)
def __init__(self, *args, **kwargs):
"""Initialize the tag.
Args:
*args: Args to initialize the tag.
**kwargs: Kwargs to initialize the tag.
"""
# Convert any props to vars.
if "props" in kwargs:
kwargs["props"] = {
name: LiteralVar.create(value)
for name, value in kwargs["props"].items()
}
super().__init__(*args, **kwargs)
def __post_init__(self):
"""Post initialize the tag."""
object.__setattr__(
self,
"props",
{name: LiteralVar.create(value) for name, value in self.props.items()},
)
def format_props(self) -> List:
"""Format the tag's props.
@ -54,6 +48,29 @@ class Tag(Base):
"""
return format.format_props(*self.special_props, **self.props)
def set(self, **kwargs: Any):
"""Set the tag's fields.
Args:
kwargs: The fields to set.
Returns:
The tag with the fields
"""
for name, value in kwargs.items():
setattr(self, name, value)
return self
def __iter__(self):
"""Iterate over the tag's fields.
Yields:
Tuple[str, Any]: The field name and value.
"""
for field in dataclasses.fields(self):
yield field.name, getattr(self, field.name)
def add_props(self, **kwargs: Optional[Any]) -> Tag:
"""Add props to the tag.

View File

@ -2,6 +2,7 @@
from __future__ import annotations
import dataclasses
import inspect
import types
import urllib.parse
@ -18,7 +19,6 @@ from typing import (
)
from reflex import constants
from reflex.base import Base
from reflex.ivars.base import ImmutableVar, LiteralVar
from reflex.ivars.function import FunctionStringVar, FunctionVar
from reflex.ivars.object import ObjectVar
@ -33,7 +33,11 @@ except ImportError:
from typing_extensions import Annotated
class Event(Base):
@dataclasses.dataclass(
init=True,
frozen=True,
)
class Event:
"""An event that describes any state change in the app."""
# The token to specify the client that the event is for.
@ -43,10 +47,10 @@ class Event(Base):
name: str
# The routing data where event occurred
router_data: Dict[str, Any] = {}
router_data: Dict[str, Any] = dataclasses.field(default_factory=dict)
# The event payload.
payload: Dict[str, Any] = {}
payload: Dict[str, Any] = dataclasses.field(default_factory=dict)
@property
def substate_token(self) -> str:
@ -81,11 +85,15 @@ def background(fn):
return fn
class EventActionsMixin(Base):
@dataclasses.dataclass(
init=True,
frozen=True,
)
class EventActionsMixin:
"""Mixin for DOM event actions."""
# Whether to `preventDefault` or `stopPropagation` on the event.
event_actions: Dict[str, Union[bool, int]] = {}
event_actions: Dict[str, Union[bool, int]] = dataclasses.field(default_factory=dict)
@property
def stop_propagation(self):
@ -94,8 +102,9 @@ class EventActionsMixin(Base):
Returns:
New EventHandler-like with stopPropagation set to True.
"""
return self.copy(
update={"event_actions": {"stopPropagation": True, **self.event_actions}},
return dataclasses.replace(
self,
event_actions={"stopPropagation": True, **self.event_actions},
)
@property
@ -105,8 +114,9 @@ class EventActionsMixin(Base):
Returns:
New EventHandler-like with preventDefault set to True.
"""
return self.copy(
update={"event_actions": {"preventDefault": True, **self.event_actions}},
return dataclasses.replace(
self,
event_actions={"preventDefault": True, **self.event_actions},
)
def throttle(self, limit_ms: int):
@ -118,8 +128,9 @@ class EventActionsMixin(Base):
Returns:
New EventHandler-like with throttle set to limit_ms.
"""
return self.copy(
update={"event_actions": {"throttle": limit_ms, **self.event_actions}},
return dataclasses.replace(
self,
event_actions={"throttle": limit_ms, **self.event_actions},
)
def debounce(self, delay_ms: int):
@ -131,26 +142,25 @@ class EventActionsMixin(Base):
Returns:
New EventHandler-like with debounce set to delay_ms.
"""
return self.copy(
update={"event_actions": {"debounce": delay_ms, **self.event_actions}},
return dataclasses.replace(
self,
event_actions={"debounce": delay_ms, **self.event_actions},
)
@dataclasses.dataclass(
init=True,
frozen=True,
)
class EventHandler(EventActionsMixin):
"""An event handler responds to an event to update the state."""
# The function to call in response to the event.
fn: Any
fn: Any = dataclasses.field(default=None)
# The full name of the state class this event handler is attached to.
# Empty string means this event handler is a server side event.
state_full_name: str = ""
class Config:
"""The Pydantic config."""
# Needed to allow serialization of Callable.
frozen = True
state_full_name: str = dataclasses.field(default="")
@classmethod
def __class_getitem__(cls, args_spec: str) -> Annotated:
@ -215,6 +225,10 @@ class EventHandler(EventActionsMixin):
)
@dataclasses.dataclass(
init=True,
frozen=True,
)
class EventSpec(EventActionsMixin):
"""An event specification.
@ -223,19 +237,37 @@ class EventSpec(EventActionsMixin):
"""
# The event handler.
handler: EventHandler
handler: EventHandler = dataclasses.field(default=None) # type: ignore
# The handler on the client to process event.
client_handler_name: str = ""
client_handler_name: str = dataclasses.field(default="")
# The arguments to pass to the function.
args: Tuple[Tuple[ImmutableVar, ImmutableVar], ...] = ()
args: Tuple[Tuple[ImmutableVar, ImmutableVar], ...] = dataclasses.field(
default_factory=tuple
)
class Config:
"""The Pydantic config."""
def __init__(
self,
handler: EventHandler,
event_actions: Dict[str, Union[bool, int]] | None = None,
client_handler_name: str = "",
args: Tuple[Tuple[ImmutableVar, ImmutableVar], ...] = tuple(),
):
"""Initialize an EventSpec.
# Required to allow tuple fields.
frozen = True
Args:
event_actions: The event actions.
handler: The event handler.
client_handler_name: The client handler name.
args: The arguments to pass to the function.
"""
if event_actions is None:
event_actions = {}
object.__setattr__(self, "event_actions", event_actions)
object.__setattr__(self, "handler", handler)
object.__setattr__(self, "client_handler_name", client_handler_name)
object.__setattr__(self, "args", args or tuple())
def with_args(
self, args: Tuple[Tuple[ImmutableVar, ImmutableVar], ...]
@ -286,6 +318,9 @@ class EventSpec(EventActionsMixin):
return self.with_args(self.args + new_payload)
@dataclasses.dataclass(
frozen=True,
)
class CallableEventSpec(EventSpec):
"""Decorate an EventSpec-returning function to act as both a EventSpec and a function.
@ -305,10 +340,13 @@ class CallableEventSpec(EventSpec):
if fn is not None:
default_event_spec = fn()
super().__init__(
fn=fn, # type: ignore
**default_event_spec.dict(),
event_actions=default_event_spec.event_actions,
client_handler_name=default_event_spec.client_handler_name,
args=default_event_spec.args,
handler=default_event_spec.handler,
**kwargs,
)
object.__setattr__(self, "fn", fn)
else:
super().__init__(**kwargs)
@ -332,12 +370,16 @@ class CallableEventSpec(EventSpec):
return self.fn(*args, **kwargs)
@dataclasses.dataclass(
init=True,
frozen=True,
)
class EventChain(EventActionsMixin):
"""Container for a chain of events that will be executed in order."""
events: List[EventSpec]
events: List[EventSpec] = dataclasses.field(default_factory=list)
args_spec: Optional[Callable]
args_spec: Optional[Callable] = dataclasses.field(default=None)
# These chains can be used for their side effects when no other events are desired.
@ -345,14 +387,22 @@ stop_propagation = EventChain(events=[], args_spec=lambda: []).stop_propagation
prevent_default = EventChain(events=[], args_spec=lambda: []).prevent_default
class Target(Base):
@dataclasses.dataclass(
init=True,
frozen=True,
)
class Target:
"""A Javascript event target."""
checked: bool = False
value: Any = None
class FrontendEvent(Base):
@dataclasses.dataclass(
init=True,
frozen=True,
)
class FrontendEvent:
"""A Javascript event."""
target: Target = Target()
@ -360,7 +410,11 @@ class FrontendEvent(Base):
value: Any = None
class FileUpload(Base):
@dataclasses.dataclass(
init=True,
frozen=True,
)
class FileUpload:
"""Class to represent a file upload."""
upload_id: Optional[str] = None

View File

@ -421,6 +421,9 @@ class ImmutableVar(Var, Generic[VAR_TYPE]):
if issubclass(output, (ObjectVar, Base)):
return ToObjectOperation.create(self, var_type or dict)
if dataclasses.is_dataclass(output):
return ToObjectOperation.create(self, var_type or dict)
if issubclass(output, FunctionVar):
# if fixed_type is not None and not issubclass(fixed_type, Callable):
# raise TypeError(
@ -479,7 +482,11 @@ class ImmutableVar(Var, Generic[VAR_TYPE]):
):
return self.to(NumberVar, self._var_type)
if all(inspect.isclass(t) and issubclass(t, Base) for t in inner_types):
if all(
inspect.isclass(t)
and (issubclass(t, Base) or dataclasses.is_dataclass(t))
for t in inner_types
):
return self.to(ObjectVar, self._var_type)
return self
@ -499,6 +506,8 @@ class ImmutableVar(Var, Generic[VAR_TYPE]):
return self.to(StringVar, self._var_type)
if issubclass(fixed_type, Base):
return self.to(ObjectVar, self._var_type)
if dataclasses.is_dataclass(fixed_type):
return self.to(ObjectVar, self._var_type)
return self
def get_default_value(self) -> Any:
@ -985,6 +994,16 @@ class LiteralVar(ImmutableVar):
)
return LiteralVar.create(serialized_value, _var_data=_var_data)
if dataclasses.is_dataclass(value) and not isinstance(value, type):
return LiteralObjectVar.create(
{
k: (None if callable(v) else v)
for k, v in dataclasses.asdict(value).items()
},
_var_type=type(value),
_var_data=_var_data,
)
raise TypeError(
f"Unsupported type {type(value)} for LiteralVar. Tried to create a LiteralVar from {value}."
)

View File

@ -2,6 +2,7 @@
from __future__ import annotations
import dataclasses
from typing import TYPE_CHECKING, Optional
from reflex import constants
@ -14,6 +15,7 @@ if TYPE_CHECKING:
from reflex.app import App
@dataclasses.dataclass(init=True)
class HydrateMiddleware(Middleware):
"""Middleware to handle initial app hydration."""

View File

@ -2,10 +2,9 @@
from __future__ import annotations
from abc import ABC
from abc import ABC, abstractmethod
from typing import TYPE_CHECKING, Optional
from reflex.base import Base
from reflex.event import Event
from reflex.state import BaseState, StateUpdate
@ -13,9 +12,10 @@ if TYPE_CHECKING:
from reflex.app import App
class Middleware(Base, ABC):
class Middleware(ABC):
"""Middleware to preprocess and postprocess requests."""
@abstractmethod
async def preprocess(
self, app: App, state: BaseState, event: Event
) -> Optional[StateUpdate]:

View File

@ -5,8 +5,10 @@ from __future__ import annotations
import asyncio
import contextlib
import copy
import dataclasses
import functools
import inspect
import json
import os
import uuid
from abc import ABC, abstractmethod
@ -83,13 +85,15 @@ var = immutable_computed_var
TOO_LARGE_SERIALIZED_STATE = 100 * 1024 # 100kb
class HeaderData(Base):
@dataclasses.dataclass(frozen=True)
class HeaderData:
"""An object containing headers data."""
host: str = ""
origin: str = ""
upgrade: str = ""
connection: str = ""
cookie: str = ""
pragma: str = ""
cache_control: str = ""
user_agent: str = ""
@ -105,13 +109,16 @@ class HeaderData(Base):
Args:
router_data: the router_data dict.
"""
super().__init__()
if router_data:
for k, v in router_data.get(constants.RouteVar.HEADERS, {}).items():
setattr(self, format.to_snake_case(k), v)
object.__setattr__(self, format.to_snake_case(k), v)
else:
for k in dataclasses.fields(self):
object.__setattr__(self, k.name, "")
class PageData(Base):
@dataclasses.dataclass(frozen=True)
class PageData:
"""An object containing page data."""
host: str = "" # repeated with self.headers.origin (remove or keep the duplicate?)
@ -119,7 +126,7 @@ class PageData(Base):
raw_path: str = ""
full_path: str = ""
full_raw_path: str = ""
params: dict = {}
params: dict = dataclasses.field(default_factory=dict)
def __init__(self, router_data: Optional[dict] = None):
"""Initalize the PageData object based on router_data.
@ -127,17 +134,34 @@ class PageData(Base):
Args:
router_data: the router_data dict.
"""
super().__init__()
if router_data:
self.host = router_data.get(constants.RouteVar.HEADERS, {}).get("origin")
self.path = router_data.get(constants.RouteVar.PATH, "")
self.raw_path = router_data.get(constants.RouteVar.ORIGIN, "")
self.full_path = f"{self.host}{self.path}"
self.full_raw_path = f"{self.host}{self.raw_path}"
self.params = router_data.get(constants.RouteVar.QUERY, {})
object.__setattr__(
self,
"host",
router_data.get(constants.RouteVar.HEADERS, {}).get("origin", ""),
)
object.__setattr__(
self, "path", router_data.get(constants.RouteVar.PATH, "")
)
object.__setattr__(
self, "raw_path", router_data.get(constants.RouteVar.ORIGIN, "")
)
object.__setattr__(self, "full_path", f"{self.host}{self.path}")
object.__setattr__(self, "full_raw_path", f"{self.host}{self.raw_path}")
object.__setattr__(
self, "params", router_data.get(constants.RouteVar.QUERY, {})
)
else:
object.__setattr__(self, "host", "")
object.__setattr__(self, "path", "")
object.__setattr__(self, "raw_path", "")
object.__setattr__(self, "full_path", "")
object.__setattr__(self, "full_raw_path", "")
object.__setattr__(self, "params", {})
class SessionData(Base):
@dataclasses.dataclass(frozen=True, init=False)
class SessionData:
"""An object containing session data."""
client_token: str = ""
@ -150,19 +174,24 @@ class SessionData(Base):
Args:
router_data: the router_data dict.
"""
super().__init__()
if router_data:
self.client_token = router_data.get(constants.RouteVar.CLIENT_TOKEN, "")
self.client_ip = router_data.get(constants.RouteVar.CLIENT_IP, "")
self.session_id = router_data.get(constants.RouteVar.SESSION_ID, "")
client_token = router_data.get(constants.RouteVar.CLIENT_TOKEN, "")
client_ip = router_data.get(constants.RouteVar.CLIENT_IP, "")
session_id = router_data.get(constants.RouteVar.SESSION_ID, "")
else:
client_token = client_ip = session_id = ""
object.__setattr__(self, "client_token", client_token)
object.__setattr__(self, "client_ip", client_ip)
object.__setattr__(self, "session_id", session_id)
class RouterData(Base):
@dataclasses.dataclass(frozen=True, init=False)
class RouterData:
"""An object containing RouterData."""
session: SessionData = SessionData()
headers: HeaderData = HeaderData()
page: PageData = PageData()
session: SessionData = dataclasses.field(default_factory=SessionData)
headers: HeaderData = dataclasses.field(default_factory=HeaderData)
page: PageData = dataclasses.field(default_factory=PageData)
def __init__(self, router_data: Optional[dict] = None):
"""Initialize the RouterData object.
@ -170,10 +199,30 @@ class RouterData(Base):
Args:
router_data: the router_data dict.
"""
super().__init__()
self.session = SessionData(router_data)
self.headers = HeaderData(router_data)
self.page = PageData(router_data)
object.__setattr__(self, "session", SessionData(router_data))
object.__setattr__(self, "headers", HeaderData(router_data))
object.__setattr__(self, "page", PageData(router_data))
def toJson(self) -> str:
"""Convert the object to a JSON string.
Returns:
The JSON string.
"""
return json.dumps(dataclasses.asdict(self))
@serializer
def serialize_routerdata(value: RouterData) -> str:
"""Serialize a RouterData instance.
Args:
value: The RouterData to serialize.
Returns:
The serialized RouterData.
"""
return value.toJson()
def _no_chain_background_task(
@ -249,10 +298,11 @@ def _split_substate_key(substate_key: str) -> tuple[str, str]:
return token, state_name
@dataclasses.dataclass(frozen=True, init=False)
class EventHandlerSetVar(EventHandler):
"""A special event handler to wrap setvar functionality."""
state_cls: Type[BaseState]
state_cls: Type[BaseState] = dataclasses.field(init=False)
def __init__(self, state_cls: Type[BaseState]):
"""Initialize the EventHandlerSetVar.
@ -263,8 +313,8 @@ class EventHandlerSetVar(EventHandler):
super().__init__(
fn=type(self).setvar,
state_full_name=state_cls.get_full_name(),
state_cls=state_cls, # type: ignore
)
object.__setattr__(self, "state_cls", state_cls)
def setvar(self, var_name: str, value: Any):
"""Set the state variable to the value of the event.
@ -1826,8 +1876,13 @@ class BaseState(Base, ABC, extra=pydantic.Extra.allow):
self.dirty_vars.update(self._always_dirty_computed_vars)
self._mark_dirty()
def dictify(value: Any):
if dataclasses.is_dataclass(value) and not isinstance(value, type):
return dataclasses.asdict(value)
return value
base_vars = {
prop_name: self.get_value(getattr(self, prop_name))
prop_name: dictify(self.get_value(getattr(self, prop_name)))
for prop_name in self.base_vars
}
if initial and include_computed:
@ -1907,9 +1962,6 @@ class BaseState(Base, ABC, extra=pydantic.Extra.allow):
return state
EventHandlerSetVar.update_forward_refs()
class State(BaseState):
"""The app Base State."""
@ -2341,18 +2393,29 @@ class StateProxy(wrapt.ObjectProxy):
self._self_mutable = original_mutable
class StateUpdate(Base):
@dataclasses.dataclass(
frozen=True,
)
class StateUpdate:
"""A state update sent to the frontend."""
# The state delta.
delta: Delta = {}
delta: Delta = dataclasses.field(default_factory=dict)
# Events to be added to the event queue.
events: List[Event] = []
events: List[Event] = dataclasses.field(default_factory=list)
# Whether this is the final state update for the event.
final: bool = True
def json(self) -> str:
"""Convert the state update to a JSON string.
Returns:
The state update as a JSON string.
"""
return json.dumps(dataclasses.asdict(self))
class StateManager(Base, ABC):
"""A class to manage many client states."""

View File

@ -2,6 +2,7 @@
from __future__ import annotations
import dataclasses
import inspect
import json
import os
@ -623,6 +624,14 @@ def format_state(value: Any, key: Optional[str] = None) -> Any:
if isinstance(value, dict):
return {k: format_state(v, k) for k, v in value.items()}
# Hand dataclasses.
if dataclasses.is_dataclass(value):
if isinstance(value, type):
raise TypeError(
f"Cannot format state of type {type(value)}. Please provide an instance of the dataclass."
)
return {k: format_state(v, k) for k, v in dataclasses.asdict(value).items()}
# Handle lists, sets, typles.
if isinstance(value, types.StateIterBases):
return [format_state(v) for v in value]

View File

@ -2,10 +2,9 @@
from __future__ import annotations
import dataclasses
from collections import defaultdict
from typing import Dict, List, Optional, Tuple, Union
from reflex.base import Base
from typing import DefaultDict, Dict, List, Optional, Tuple, Union
def merge_imports(
@ -19,12 +18,22 @@ def merge_imports(
Returns:
The merged import dicts.
"""
all_imports = defaultdict(list)
all_imports: DefaultDict[str, List[ImportVar]] = defaultdict(list)
for import_dict in imports:
for lib, fields in (
import_dict if isinstance(import_dict, tuple) else import_dict.items()
):
all_imports[lib].extend(fields)
if isinstance(fields, (list, tuple, set)):
all_imports[lib].extend(
(
ImportVar(field) if isinstance(field, str) else field
for field in fields
)
)
else:
all_imports[lib].append(
ImportVar(fields) if isinstance(fields, str) else fields
)
return all_imports
@ -75,7 +84,8 @@ def collapse_imports(
}
class ImportVar(Base):
@dataclasses.dataclass(order=True, frozen=True)
class ImportVar:
"""An import var."""
# The name of the import tag.
@ -111,73 +121,6 @@ class ImportVar(Base):
else:
return self.tag or ""
def __lt__(self, other: ImportVar) -> bool:
"""Compare two ImportVar objects.
Args:
other: The other ImportVar object to compare.
Returns:
Whether this ImportVar object is less than the other.
"""
return (
self.tag,
self.is_default,
self.alias,
self.install,
self.render,
self.transpile,
) < (
other.tag,
other.is_default,
other.alias,
other.install,
other.render,
other.transpile,
)
def __eq__(self, other: ImportVar) -> bool:
"""Check if two ImportVar objects are equal.
Args:
other: The other ImportVar object to compare.
Returns:
Whether the two ImportVar objects are equal.
"""
return (
self.tag,
self.is_default,
self.alias,
self.install,
self.render,
self.transpile,
) == (
other.tag,
other.is_default,
other.alias,
other.install,
other.render,
other.transpile,
)
def __hash__(self) -> int:
"""Hash the ImportVar object.
Returns:
The hash of the ImportVar object.
"""
return hash(
(
self.tag,
self.is_default,
self.alias,
self.install,
self.render,
self.transpile,
)
)
ImportTypes = Union[str, ImportVar, List[Union[str, ImportVar]], List[ImportVar]]
ImportDict = Dict[str, ImportTypes]

View File

@ -2,6 +2,7 @@
from __future__ import annotations
import dataclasses
import functools
import glob
import importlib
@ -32,7 +33,6 @@ from redis import exceptions
from redis.asyncio import Redis
from reflex import constants, model
from reflex.base import Base
from reflex.compiler import templates
from reflex.config import Config, get_config
from reflex.utils import console, net, path_ops, processes
@ -43,7 +43,8 @@ from reflex.utils.registry import _get_best_registry
CURRENTLY_INSTALLING_NODE = False
class Template(Base):
@dataclasses.dataclass(frozen=True)
class Template:
"""A template for a Reflex app."""
name: str
@ -52,7 +53,8 @@ class Template(Base):
demo_url: str
class CpuInfo(Base):
@dataclasses.dataclass(frozen=True)
class CpuInfo:
"""Model to save cpu info."""
manufacturer_id: Optional[str]
@ -1279,7 +1281,7 @@ def fetch_app_templates(version: str) -> dict[str, Template]:
None,
)
return {
tp["name"]: Template.parse_obj(tp)
tp["name"]: Template(**tp)
for tp in templates_data
if not tp["hidden"] and tp["code_url"] is not None
}

View File

@ -3,6 +3,7 @@
from __future__ import annotations
import asyncio
import dataclasses
import multiprocessing
import platform
import warnings
@ -144,7 +145,7 @@ def _prepare_event(event: str, **kwargs) -> dict:
"python_version": get_python_version(),
"cpu_count": get_cpu_count(),
"memory": get_memory(),
"cpu_info": dict(cpuinfo) if cpuinfo else {},
"cpu_info": dataclasses.asdict(cpuinfo) if cpuinfo else {},
**additional_fields,
},
"timestamp": stamp,

View File

@ -3,6 +3,7 @@
from __future__ import annotations
import contextlib
import dataclasses
import inspect
import sys
import types
@ -480,7 +481,11 @@ def is_valid_var_type(type_: Type) -> bool:
if is_union(type_):
return all((is_valid_var_type(arg) for arg in get_args(type_)))
return _issubclass(type_, StateVar) or serializers.has_serializer(type_)
return (
_issubclass(type_, StateVar)
or serializers.has_serializer(type_)
or dataclasses.is_dataclass(type_)
)
def is_backend_base_variable(name: str, cls: Type) -> bool:

View File

@ -637,21 +637,21 @@ def test_component_create_unallowed_types(children, test_component):
"props": [],
"contents": "",
"args": None,
"special_props": set(),
"special_props": [],
"children": [
{
"name": "RadixThemesText",
"props": ['as={"p"}'],
"contents": "",
"args": None,
"special_props": set(),
"special_props": [],
"children": [
{
"name": "",
"props": [],
"contents": '{"first_text"}',
"args": None,
"special_props": set(),
"special_props": [],
"children": [],
"autofocus": False,
}
@ -679,13 +679,13 @@ def test_component_create_unallowed_types(children, test_component):
"contents": '{"first_text"}',
"name": "",
"props": [],
"special_props": set(),
"special_props": [],
}
],
"contents": "",
"name": "RadixThemesText",
"props": ['as={"p"}'],
"special_props": set(),
"special_props": [],
},
{
"args": None,
@ -698,19 +698,19 @@ def test_component_create_unallowed_types(children, test_component):
"contents": '{"second_text"}',
"name": "",
"props": [],
"special_props": set(),
"special_props": [],
}
],
"contents": "",
"name": "RadixThemesText",
"props": ['as={"p"}'],
"special_props": set(),
"special_props": [],
},
],
"contents": "",
"name": "Fragment",
"props": [],
"special_props": set(),
"special_props": [],
},
),
(
@ -730,13 +730,13 @@ def test_component_create_unallowed_types(children, test_component):
"contents": '{"first_text"}',
"name": "",
"props": [],
"special_props": set(),
"special_props": [],
}
],
"contents": "",
"name": "RadixThemesText",
"props": ['as={"p"}'],
"special_props": set(),
"special_props": [],
},
{
"args": None,
@ -757,31 +757,31 @@ def test_component_create_unallowed_types(children, test_component):
"contents": '{"second_text"}',
"name": "",
"props": [],
"special_props": set(),
"special_props": [],
}
],
"contents": "",
"name": "RadixThemesText",
"props": ['as={"p"}'],
"special_props": set(),
"special_props": [],
}
],
"contents": "",
"name": "Fragment",
"props": [],
"special_props": set(),
"special_props": [],
}
],
"contents": "",
"name": "RadixThemesBox",
"props": [],
"special_props": set(),
"special_props": [],
},
],
"contents": "",
"name": "Fragment",
"props": [],
"special_props": set(),
"special_props": [],
},
),
],
@ -1289,12 +1289,12 @@ class EventState(rx.State):
id="fstring-class_name",
),
pytest.param(
rx.fragment(special_props={TEST_VAR}),
rx.fragment(special_props=[TEST_VAR]),
[TEST_VAR],
id="direct-special_props",
),
pytest.param(
rx.fragment(special_props={LiteralVar.create(f"foo{TEST_VAR}bar")}),
rx.fragment(special_props=[LiteralVar.create(f"foo{TEST_VAR}bar")]),
[FORMATTED_TEST_VAR],
id="fstring-special_props",
),

View File

@ -1,5 +1,6 @@
from __future__ import annotations
import dataclasses
import functools
import io
import json
@ -1052,7 +1053,7 @@ async def test_dynamic_route_var_route_change_completed_on_load(
f"comp_{arg_name}": exp_val,
constants.CompileVars.IS_HYDRATED: False,
# "side_effect_counter": exp_index,
"router": exp_router,
"router": dataclasses.asdict(exp_router),
}
},
events=[

View File

@ -2,6 +2,7 @@ from __future__ import annotations
import asyncio
import copy
import dataclasses
import datetime
import functools
import json
@ -58,6 +59,7 @@ formatted_router = {
"origin": "",
"upgrade": "",
"connection": "",
"cookie": "",
"pragma": "",
"cache_control": "",
"user_agent": "",
@ -865,8 +867,10 @@ def test_get_headers(test_state, router_data, router_data_headers):
router_data: The router data fixture.
router_data_headers: The expected headers.
"""
print(router_data_headers)
test_state.router = RouterData(router_data)
assert test_state.router.headers.dict() == {
print(test_state.router.headers)
assert dataclasses.asdict(test_state.router.headers) == {
format.to_snake_case(k): v for k, v in router_data_headers.items()
}
@ -1908,19 +1912,21 @@ async def test_state_proxy(grandchild_state: GrandchildState, mock_app: rx.App):
mock_app.event_namespace.emit.assert_called_once()
mcall = mock_app.event_namespace.emit.mock_calls[0]
assert mcall.args[0] == str(SocketEvent.EVENT)
assert json.loads(mcall.args[1]) == StateUpdate(
delta={
parent_state.get_full_name(): {
"upper": "",
"sum": 3.14,
},
grandchild_state.get_full_name(): {
"value2": "42",
},
GrandchildState3.get_full_name(): {
"computed": "",
},
}
assert json.loads(mcall.args[1]) == dataclasses.asdict(
StateUpdate(
delta={
parent_state.get_full_name(): {
"upper": "",
"sum": 3.14,
},
grandchild_state.get_full_name(): {
"value2": "42",
},
GrandchildState3.get_full_name(): {
"computed": "",
},
}
)
)
assert mcall.kwargs["to"] == grandchild_state.router.session.session_id

View File

@ -553,6 +553,7 @@ formatted_router = {
"origin": "",
"upgrade": "",
"connection": "",
"cookie": "",
"pragma": "",
"cache_control": "",
"user_agent": "",

View File

@ -54,17 +54,21 @@ def test_import_var(import_var, expected_name):
(
{"react": {"Component"}},
{"react": {"Component"}, "react-dom": {"render"}},
{"react": {"Component"}, "react-dom": {"render"}},
{"react": {ImportVar("Component")}, "react-dom": {ImportVar("render")}},
),
(
{"react": {"Component"}, "next/image": {"Image"}},
{"react": {"Component"}, "react-dom": {"render"}},
{"react": {"Component"}, "react-dom": {"render"}, "next/image": {"Image"}},
{
"react": {ImportVar("Component")},
"react-dom": {ImportVar("render")},
"next/image": {ImportVar("Image")},
},
),
(
{"react": {"Component"}},
{"": {"some/custom.css"}},
{"react": {"Component"}, "": {"some/custom.css"}},
{"react": {ImportVar("Component")}, "": {ImportVar("some/custom.css")}},
),
],
)