Remove Pydantic from some classes (#3907)

* half of the way there

* add dataclass support

* Forbid Computed var shadowing (#3843)

* get it right pyright

* fix unit tests

* rip out more pydantic

* fix weird issues with merge_imports

* add missing docstring

* make special props a list instead of a set

* fix moment pyi

* actually ignore the runtime error

* it's ruff out there

---------

Co-authored-by: benedikt-bartscher <31854409+benedikt-bartscher@users.noreply.github.com>
This commit is contained in:
Khaleel Al-Adhami 2024-09-13 12:53:30 -07:00 committed by GitHub
parent 7c25358607
commit 8f937f0417
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
27 changed files with 394 additions and 248 deletions

View File

@ -9,6 +9,7 @@ import copy
import functools import functools
import inspect import inspect
import io import io
import json
import multiprocessing import multiprocessing
import os import os
import platform import platform
@ -1096,6 +1097,7 @@ class App(MiddlewareMixin, LifespanMixin, Base):
if delta: if delta:
# When the state is modified reset dirty status and emit the delta to the frontend. # When the state is modified reset dirty status and emit the delta to the frontend.
state._clean() state._clean()
print(dir(state.router))
await self.event_namespace.emit_update( await self.event_namespace.emit_update(
update=StateUpdate(delta=delta), update=StateUpdate(delta=delta),
sid=state.router.session.session_id, sid=state.router.session.session_id,
@ -1531,8 +1533,9 @@ class EventNamespace(AsyncNamespace):
sid: The Socket.IO session id. sid: The Socket.IO session id.
data: The event data. data: The event data.
""" """
fields = json.loads(data)
# Get the event. # Get the event.
event = Event.parse_raw(data) event = Event(**{k: v for k, v in fields.items() if k != "handler"})
self.token_to_sid[event.token] = sid self.token_to_sid[event.token] = sid
self.sid_to_token[sid] = event.token self.sid_to_token[sid] = event.token

View File

@ -4,6 +4,7 @@ from __future__ import annotations
import copy import copy
import typing import typing
import warnings
from abc import ABC, abstractmethod from abc import ABC, abstractmethod
from functools import lru_cache, wraps from functools import lru_cache, wraps
from hashlib import md5 from hashlib import md5
@ -169,6 +170,8 @@ ComponentStyle = Dict[
] ]
ComponentChild = Union[types.PrimitiveType, Var, BaseComponent] ComponentChild = Union[types.PrimitiveType, Var, BaseComponent]
warnings.filterwarnings("ignore", message="fields may not start with an underscore")
class Component(BaseComponent, ABC): class Component(BaseComponent, ABC):
"""A component with style, event trigger and other props.""" """A component with style, event trigger and other props."""
@ -195,7 +198,7 @@ class Component(BaseComponent, ABC):
class_name: Any = None class_name: Any = None
# Special component props. # Special component props.
special_props: Set[ImmutableVar] = set() special_props: List[ImmutableVar] = []
# Whether the component should take the focus once the page is loaded # Whether the component should take the focus once the page is loaded
autofocus: bool = False autofocus: bool = False
@ -655,7 +658,7 @@ class Component(BaseComponent, ABC):
""" """
# Create the base tag. # Create the base tag.
tag = Tag( tag = Tag(
name=self.tag if not self.alias else self.alias, name=(self.tag if not self.alias else self.alias) or "",
special_props=self.special_props, special_props=self.special_props,
) )
@ -2244,7 +2247,7 @@ class StatefulComponent(BaseComponent):
Returns: Returns:
The tag to render. The tag to render.
""" """
return dict(Tag(name=self.tag)) return dict(Tag(name=self.tag or ""))
def __str__(self) -> str: def __str__(self) -> str:
"""Represent the component in React. """Represent the component in React.

View File

@ -247,9 +247,9 @@ class Upload(MemoizationLeaf):
} }
# The file input to use. # The file input to use.
upload = Input.create(type="file") upload = Input.create(type="file")
upload.special_props = { upload.special_props = [
ImmutableVar(_var_name="{...getInputProps()}", _var_type=None) ImmutableVar(_var_name="{...getInputProps()}", _var_type=None)
} ]
# The dropzone to use. # The dropzone to use.
zone = Box.create( zone = Box.create(
@ -257,9 +257,9 @@ class Upload(MemoizationLeaf):
*children, *children,
**{k: v for k, v in props.items() if k not in supported_props}, **{k: v for k, v in props.items() if k not in supported_props},
) )
zone.special_props = { zone.special_props = [
ImmutableVar(_var_name="{...getRootProps()}", _var_type=None) ImmutableVar(_var_name="{...getRootProps()}", _var_type=None)
} ]
# Create the component. # Create the component.
upload_props["id"] = props.get("id", DEFAULT_UPLOAD_ID) upload_props["id"] = props.get("id", DEFAULT_UPLOAD_ID)

View File

@ -1,6 +1,6 @@
"""Element classes. This is an auto-generated file. Do not edit. See ../generate.py.""" """Element classes. This is an auto-generated file. Do not edit. See ../generate.py."""
from typing import Set, Union from typing import List, Union
from reflex.components.el.element import Element from reflex.components.el.element import Element
from reflex.ivars.base import ImmutableVar from reflex.ivars.base import ImmutableVar
@ -90,9 +90,9 @@ class StyleEl(Element): # noqa: E742
media: Var[Union[str, int, bool]] media: Var[Union[str, int, bool]]
special_props: Set[ImmutableVar] = { special_props: List[ImmutableVar] = [
ImmutableVar.create_safe("suppressHydrationWarning") ImmutableVar.create_safe("suppressHydrationWarning")
} ]
base = Base.create base = Base.create

View File

@ -195,17 +195,17 @@ class Markdown(Component):
if tag not in self.component_map: if tag not in self.component_map:
raise ValueError(f"No markdown component found for tag: {tag}.") raise ValueError(f"No markdown component found for tag: {tag}.")
special_props = {_PROPS_IN_TAG} special_props = [_PROPS_IN_TAG]
children = [_CHILDREN] children = [_CHILDREN]
# For certain tags, the props from the markdown renderer are not actually valid for the component. # For certain tags, the props from the markdown renderer are not actually valid for the component.
if tag in NO_PROPS_TAGS: if tag in NO_PROPS_TAGS:
special_props = set() special_props = []
# If the children are set as a prop, don't pass them as children. # If the children are set as a prop, don't pass them as children.
children_prop = props.pop("children", None) children_prop = props.pop("children", None)
if children_prop is not None: if children_prop is not None:
special_props.add( special_props.append(
ImmutableVar.create_safe(f"children={{{str(children_prop)}}}") ImmutableVar.create_safe(f"children={{{str(children_prop)}}}")
) )
children = [] children = []

View File

@ -1,26 +1,27 @@
"""Moment component for humanized date rendering.""" """Moment component for humanized date rendering."""
import dataclasses
from typing import List, Optional from typing import List, Optional
from reflex.base import Base
from reflex.components.component import Component, NoSSRComponent from reflex.components.component import Component, NoSSRComponent
from reflex.event import EventHandler from reflex.event import EventHandler
from reflex.utils.imports import ImportDict from reflex.utils.imports import ImportDict
from reflex.vars import Var from reflex.vars import Var
class MomentDelta(Base): @dataclasses.dataclass(frozen=True)
class MomentDelta:
"""A delta used for add/subtract prop in Moment.""" """A delta used for add/subtract prop in Moment."""
years: Optional[int] years: Optional[int] = dataclasses.field(default=None)
quarters: Optional[int] quarters: Optional[int] = dataclasses.field(default=None)
months: Optional[int] months: Optional[int] = dataclasses.field(default=None)
weeks: Optional[int] weeks: Optional[int] = dataclasses.field(default=None)
days: Optional[int] days: Optional[int] = dataclasses.field(default=None)
hours: Optional[int] hours: Optional[int] = dataclasses.field(default=None)
minutess: Optional[int] minutess: Optional[int] = dataclasses.field(default=None)
seconds: Optional[int] seconds: Optional[int] = dataclasses.field(default=None)
milliseconds: Optional[int] milliseconds: Optional[int] = dataclasses.field(default=None)
class Moment(NoSSRComponent): class Moment(NoSSRComponent):

View File

@ -3,9 +3,9 @@
# ------------------- DO NOT EDIT ---------------------- # ------------------- DO NOT EDIT ----------------------
# This file was generated by `reflex/utils/pyi_generator.py`! # This file was generated by `reflex/utils/pyi_generator.py`!
# ------------------------------------------------------ # ------------------------------------------------------
import dataclasses
from typing import Any, Callable, Dict, Optional, Union, overload from typing import Any, Callable, Dict, Optional, Union, overload
from reflex.base import Base
from reflex.components.component import NoSSRComponent from reflex.components.component import NoSSRComponent
from reflex.event import EventHandler, EventSpec from reflex.event import EventHandler, EventSpec
from reflex.ivars.base import ImmutableVar from reflex.ivars.base import ImmutableVar
@ -13,7 +13,8 @@ from reflex.style import Style
from reflex.utils.imports import ImportDict from reflex.utils.imports import ImportDict
from reflex.vars import Var from reflex.vars import Var
class MomentDelta(Base): @dataclasses.dataclass(frozen=True)
class MomentDelta:
years: Optional[int] years: Optional[int]
quarters: Optional[int] quarters: Optional[int]
months: Optional[int] months: Optional[int]

View File

@ -267,7 +267,7 @@ const extractPoints = (points) => {
template_dict = LiteralVar.create({"layout": {"template": self.template}}) template_dict = LiteralVar.create({"layout": {"template": self.template}})
merge_dicts.append(template_dict.without_data()) merge_dicts.append(template_dict.without_data())
if merge_dicts: if merge_dicts:
tag.special_props.add( tag.special_props.append(
# Merge all dictionaries and spread the result over props. # Merge all dictionaries and spread the result over props.
ImmutableVar.create_safe( ImmutableVar.create_safe(
f"{{...mergician({str(figure)}," f"{{...mergician({str(figure)},"
@ -276,5 +276,5 @@ const extractPoints = (points) => {
) )
else: else:
# Spread the figure dict over props, nothing to merge. # Spread the figure dict over props, nothing to merge.
tag.special_props.add(ImmutableVar.create_safe(f"{{...{str(figure)}}}")) tag.special_props.append(ImmutableVar.create_safe(f"{{...{str(figure)}}}"))
return tag return tag

View File

@ -1,19 +1,22 @@
"""Tag to conditionally render components.""" """Tag to conditionally render components."""
import dataclasses
from typing import Any, Dict, Optional from typing import Any, Dict, Optional
from reflex.components.tags.tag import Tag from reflex.components.tags.tag import Tag
from reflex.ivars.base import LiteralVar
from reflex.vars import Var from reflex.vars import Var
@dataclasses.dataclass()
class CondTag(Tag): class CondTag(Tag):
"""A conditional tag.""" """A conditional tag."""
# The condition to determine which component to render. # The condition to determine which component to render.
cond: Var[Any] cond: Var[Any] = dataclasses.field(default_factory=lambda: LiteralVar.create(True))
# The code to render if the condition is true. # The code to render if the condition is true.
true_value: Dict true_value: Dict = dataclasses.field(default_factory=dict)
# The code to render if the condition is false. # The code to render if the condition is false.
false_value: Optional[Dict] false_value: Optional[Dict] = None

View File

@ -2,31 +2,36 @@
from __future__ import annotations from __future__ import annotations
import dataclasses
import inspect import inspect
from typing import TYPE_CHECKING, Any, Callable, List, Tuple, Type, Union, get_args from typing import TYPE_CHECKING, Any, Callable, List, Tuple, Type, Union, get_args
from reflex.components.tags.tag import Tag from reflex.components.tags.tag import Tag
from reflex.ivars.base import ImmutableVar from reflex.ivars.base import ImmutableVar
from reflex.vars import Var from reflex.ivars.sequence import LiteralArrayVar
from reflex.vars import Var, get_unique_variable_name
if TYPE_CHECKING: if TYPE_CHECKING:
from reflex.components.component import Component from reflex.components.component import Component
@dataclasses.dataclass()
class IterTag(Tag): class IterTag(Tag):
"""An iterator tag.""" """An iterator tag."""
# The var to iterate over. # The var to iterate over.
iterable: Var[List] iterable: Var[List] = dataclasses.field(
default_factory=lambda: LiteralArrayVar.create([])
)
# The component render function for each item in the iterable. # The component render function for each item in the iterable.
render_fn: Callable render_fn: Callable = dataclasses.field(default_factory=lambda: lambda x: x)
# The name of the arg var. # The name of the arg var.
arg_var_name: str arg_var_name: str = dataclasses.field(default_factory=get_unique_variable_name)
# The name of the index var. # The name of the index var.
index_var_name: str index_var_name: str = dataclasses.field(default_factory=get_unique_variable_name)
def get_iterable_var_type(self) -> Type: def get_iterable_var_type(self) -> Type:
"""Get the type of the iterable var. """Get the type of the iterable var.

View File

@ -1,19 +1,22 @@
"""Tag to conditionally match cases.""" """Tag to conditionally match cases."""
import dataclasses
from typing import Any, List from typing import Any, List
from reflex.components.tags.tag import Tag from reflex.components.tags.tag import Tag
from reflex.ivars.base import LiteralVar
from reflex.vars import Var from reflex.vars import Var
@dataclasses.dataclass()
class MatchTag(Tag): class MatchTag(Tag):
"""A match tag.""" """A match tag."""
# The condition to determine which case to match. # The condition to determine which case to match.
cond: Var[Any] cond: Var[Any] = dataclasses.field(default_factory=lambda: LiteralVar.create(True))
# The list of match cases to be matched. # The list of match cases to be matched.
match_cases: List[Any] match_cases: List[Any] = dataclasses.field(default_factory=list)
# The catchall case to match. # The catchall case to match.
default: Any default: Any = dataclasses.field(default=LiteralVar.create(None))

View File

@ -2,22 +2,23 @@
from __future__ import annotations from __future__ import annotations
from typing import Any, Dict, List, Optional, Set, Tuple, Union import dataclasses
from typing import Any, Dict, List, Optional, Tuple, Union
from reflex.base import Base
from reflex.event import EventChain from reflex.event import EventChain
from reflex.ivars.base import ImmutableVar, LiteralVar from reflex.ivars.base import ImmutableVar, LiteralVar
from reflex.utils import format, types from reflex.utils import format, types
class Tag(Base): @dataclasses.dataclass()
class Tag:
"""A React tag.""" """A React tag."""
# The name of the tag. # The name of the tag.
name: str = "" name: str = ""
# The props of the tag. # The props of the tag.
props: Dict[str, Any] = {} props: Dict[str, Any] = dataclasses.field(default_factory=dict)
# The inner contents of the tag. # The inner contents of the tag.
contents: str = "" contents: str = ""
@ -26,25 +27,18 @@ class Tag(Base):
args: Optional[Tuple[str, ...]] = None args: Optional[Tuple[str, ...]] = None
# Special props that aren't key value pairs. # Special props that aren't key value pairs.
special_props: Set[ImmutableVar] = set() special_props: List[ImmutableVar] = dataclasses.field(default_factory=list)
# The children components. # The children components.
children: List[Any] = [] children: List[Any] = dataclasses.field(default_factory=list)
def __init__(self, *args, **kwargs): def __post_init__(self):
"""Initialize the tag. """Post initialize the tag."""
object.__setattr__(
Args: self,
*args: Args to initialize the tag. "props",
**kwargs: Kwargs to initialize the tag. {name: LiteralVar.create(value) for name, value in self.props.items()},
""" )
# Convert any props to vars.
if "props" in kwargs:
kwargs["props"] = {
name: LiteralVar.create(value)
for name, value in kwargs["props"].items()
}
super().__init__(*args, **kwargs)
def format_props(self) -> List: def format_props(self) -> List:
"""Format the tag's props. """Format the tag's props.
@ -54,6 +48,29 @@ class Tag(Base):
""" """
return format.format_props(*self.special_props, **self.props) return format.format_props(*self.special_props, **self.props)
def set(self, **kwargs: Any):
"""Set the tag's fields.
Args:
kwargs: The fields to set.
Returns:
The tag with the fields
"""
for name, value in kwargs.items():
setattr(self, name, value)
return self
def __iter__(self):
"""Iterate over the tag's fields.
Yields:
Tuple[str, Any]: The field name and value.
"""
for field in dataclasses.fields(self):
yield field.name, getattr(self, field.name)
def add_props(self, **kwargs: Optional[Any]) -> Tag: def add_props(self, **kwargs: Optional[Any]) -> Tag:
"""Add props to the tag. """Add props to the tag.

View File

@ -2,6 +2,7 @@
from __future__ import annotations from __future__ import annotations
import dataclasses
import inspect import inspect
import types import types
import urllib.parse import urllib.parse
@ -18,7 +19,6 @@ from typing import (
) )
from reflex import constants from reflex import constants
from reflex.base import Base
from reflex.ivars.base import ImmutableVar, LiteralVar from reflex.ivars.base import ImmutableVar, LiteralVar
from reflex.ivars.function import FunctionStringVar, FunctionVar from reflex.ivars.function import FunctionStringVar, FunctionVar
from reflex.ivars.object import ObjectVar from reflex.ivars.object import ObjectVar
@ -33,7 +33,11 @@ except ImportError:
from typing_extensions import Annotated from typing_extensions import Annotated
class Event(Base): @dataclasses.dataclass(
init=True,
frozen=True,
)
class Event:
"""An event that describes any state change in the app.""" """An event that describes any state change in the app."""
# The token to specify the client that the event is for. # The token to specify the client that the event is for.
@ -43,10 +47,10 @@ class Event(Base):
name: str name: str
# The routing data where event occurred # The routing data where event occurred
router_data: Dict[str, Any] = {} router_data: Dict[str, Any] = dataclasses.field(default_factory=dict)
# The event payload. # The event payload.
payload: Dict[str, Any] = {} payload: Dict[str, Any] = dataclasses.field(default_factory=dict)
@property @property
def substate_token(self) -> str: def substate_token(self) -> str:
@ -81,11 +85,15 @@ def background(fn):
return fn return fn
class EventActionsMixin(Base): @dataclasses.dataclass(
init=True,
frozen=True,
)
class EventActionsMixin:
"""Mixin for DOM event actions.""" """Mixin for DOM event actions."""
# Whether to `preventDefault` or `stopPropagation` on the event. # Whether to `preventDefault` or `stopPropagation` on the event.
event_actions: Dict[str, Union[bool, int]] = {} event_actions: Dict[str, Union[bool, int]] = dataclasses.field(default_factory=dict)
@property @property
def stop_propagation(self): def stop_propagation(self):
@ -94,8 +102,9 @@ class EventActionsMixin(Base):
Returns: Returns:
New EventHandler-like with stopPropagation set to True. New EventHandler-like with stopPropagation set to True.
""" """
return self.copy( return dataclasses.replace(
update={"event_actions": {"stopPropagation": True, **self.event_actions}}, self,
event_actions={"stopPropagation": True, **self.event_actions},
) )
@property @property
@ -105,8 +114,9 @@ class EventActionsMixin(Base):
Returns: Returns:
New EventHandler-like with preventDefault set to True. New EventHandler-like with preventDefault set to True.
""" """
return self.copy( return dataclasses.replace(
update={"event_actions": {"preventDefault": True, **self.event_actions}}, self,
event_actions={"preventDefault": True, **self.event_actions},
) )
def throttle(self, limit_ms: int): def throttle(self, limit_ms: int):
@ -118,8 +128,9 @@ class EventActionsMixin(Base):
Returns: Returns:
New EventHandler-like with throttle set to limit_ms. New EventHandler-like with throttle set to limit_ms.
""" """
return self.copy( return dataclasses.replace(
update={"event_actions": {"throttle": limit_ms, **self.event_actions}}, self,
event_actions={"throttle": limit_ms, **self.event_actions},
) )
def debounce(self, delay_ms: int): def debounce(self, delay_ms: int):
@ -131,26 +142,25 @@ class EventActionsMixin(Base):
Returns: Returns:
New EventHandler-like with debounce set to delay_ms. New EventHandler-like with debounce set to delay_ms.
""" """
return self.copy( return dataclasses.replace(
update={"event_actions": {"debounce": delay_ms, **self.event_actions}}, self,
event_actions={"debounce": delay_ms, **self.event_actions},
) )
@dataclasses.dataclass(
init=True,
frozen=True,
)
class EventHandler(EventActionsMixin): class EventHandler(EventActionsMixin):
"""An event handler responds to an event to update the state.""" """An event handler responds to an event to update the state."""
# The function to call in response to the event. # The function to call in response to the event.
fn: Any fn: Any = dataclasses.field(default=None)
# The full name of the state class this event handler is attached to. # The full name of the state class this event handler is attached to.
# Empty string means this event handler is a server side event. # Empty string means this event handler is a server side event.
state_full_name: str = "" state_full_name: str = dataclasses.field(default="")
class Config:
"""The Pydantic config."""
# Needed to allow serialization of Callable.
frozen = True
@classmethod @classmethod
def __class_getitem__(cls, args_spec: str) -> Annotated: def __class_getitem__(cls, args_spec: str) -> Annotated:
@ -215,6 +225,10 @@ class EventHandler(EventActionsMixin):
) )
@dataclasses.dataclass(
init=True,
frozen=True,
)
class EventSpec(EventActionsMixin): class EventSpec(EventActionsMixin):
"""An event specification. """An event specification.
@ -223,19 +237,37 @@ class EventSpec(EventActionsMixin):
""" """
# The event handler. # The event handler.
handler: EventHandler handler: EventHandler = dataclasses.field(default=None) # type: ignore
# The handler on the client to process event. # The handler on the client to process event.
client_handler_name: str = "" client_handler_name: str = dataclasses.field(default="")
# The arguments to pass to the function. # The arguments to pass to the function.
args: Tuple[Tuple[ImmutableVar, ImmutableVar], ...] = () args: Tuple[Tuple[ImmutableVar, ImmutableVar], ...] = dataclasses.field(
default_factory=tuple
)
class Config: def __init__(
"""The Pydantic config.""" self,
handler: EventHandler,
event_actions: Dict[str, Union[bool, int]] | None = None,
client_handler_name: str = "",
args: Tuple[Tuple[ImmutableVar, ImmutableVar], ...] = tuple(),
):
"""Initialize an EventSpec.
# Required to allow tuple fields. Args:
frozen = True event_actions: The event actions.
handler: The event handler.
client_handler_name: The client handler name.
args: The arguments to pass to the function.
"""
if event_actions is None:
event_actions = {}
object.__setattr__(self, "event_actions", event_actions)
object.__setattr__(self, "handler", handler)
object.__setattr__(self, "client_handler_name", client_handler_name)
object.__setattr__(self, "args", args or tuple())
def with_args( def with_args(
self, args: Tuple[Tuple[ImmutableVar, ImmutableVar], ...] self, args: Tuple[Tuple[ImmutableVar, ImmutableVar], ...]
@ -286,6 +318,9 @@ class EventSpec(EventActionsMixin):
return self.with_args(self.args + new_payload) return self.with_args(self.args + new_payload)
@dataclasses.dataclass(
frozen=True,
)
class CallableEventSpec(EventSpec): class CallableEventSpec(EventSpec):
"""Decorate an EventSpec-returning function to act as both a EventSpec and a function. """Decorate an EventSpec-returning function to act as both a EventSpec and a function.
@ -305,10 +340,13 @@ class CallableEventSpec(EventSpec):
if fn is not None: if fn is not None:
default_event_spec = fn() default_event_spec = fn()
super().__init__( super().__init__(
fn=fn, # type: ignore event_actions=default_event_spec.event_actions,
**default_event_spec.dict(), client_handler_name=default_event_spec.client_handler_name,
args=default_event_spec.args,
handler=default_event_spec.handler,
**kwargs, **kwargs,
) )
object.__setattr__(self, "fn", fn)
else: else:
super().__init__(**kwargs) super().__init__(**kwargs)
@ -332,12 +370,16 @@ class CallableEventSpec(EventSpec):
return self.fn(*args, **kwargs) return self.fn(*args, **kwargs)
@dataclasses.dataclass(
init=True,
frozen=True,
)
class EventChain(EventActionsMixin): class EventChain(EventActionsMixin):
"""Container for a chain of events that will be executed in order.""" """Container for a chain of events that will be executed in order."""
events: List[EventSpec] events: List[EventSpec] = dataclasses.field(default_factory=list)
args_spec: Optional[Callable] args_spec: Optional[Callable] = dataclasses.field(default=None)
# These chains can be used for their side effects when no other events are desired. # These chains can be used for their side effects when no other events are desired.
@ -345,14 +387,22 @@ stop_propagation = EventChain(events=[], args_spec=lambda: []).stop_propagation
prevent_default = EventChain(events=[], args_spec=lambda: []).prevent_default prevent_default = EventChain(events=[], args_spec=lambda: []).prevent_default
class Target(Base): @dataclasses.dataclass(
init=True,
frozen=True,
)
class Target:
"""A Javascript event target.""" """A Javascript event target."""
checked: bool = False checked: bool = False
value: Any = None value: Any = None
class FrontendEvent(Base): @dataclasses.dataclass(
init=True,
frozen=True,
)
class FrontendEvent:
"""A Javascript event.""" """A Javascript event."""
target: Target = Target() target: Target = Target()
@ -360,7 +410,11 @@ class FrontendEvent(Base):
value: Any = None value: Any = None
class FileUpload(Base): @dataclasses.dataclass(
init=True,
frozen=True,
)
class FileUpload:
"""Class to represent a file upload.""" """Class to represent a file upload."""
upload_id: Optional[str] = None upload_id: Optional[str] = None

View File

@ -421,6 +421,9 @@ class ImmutableVar(Var, Generic[VAR_TYPE]):
if issubclass(output, (ObjectVar, Base)): if issubclass(output, (ObjectVar, Base)):
return ToObjectOperation.create(self, var_type or dict) return ToObjectOperation.create(self, var_type or dict)
if dataclasses.is_dataclass(output):
return ToObjectOperation.create(self, var_type or dict)
if issubclass(output, FunctionVar): if issubclass(output, FunctionVar):
# if fixed_type is not None and not issubclass(fixed_type, Callable): # if fixed_type is not None and not issubclass(fixed_type, Callable):
# raise TypeError( # raise TypeError(
@ -479,7 +482,11 @@ class ImmutableVar(Var, Generic[VAR_TYPE]):
): ):
return self.to(NumberVar, self._var_type) return self.to(NumberVar, self._var_type)
if all(inspect.isclass(t) and issubclass(t, Base) for t in inner_types): if all(
inspect.isclass(t)
and (issubclass(t, Base) or dataclasses.is_dataclass(t))
for t in inner_types
):
return self.to(ObjectVar, self._var_type) return self.to(ObjectVar, self._var_type)
return self return self
@ -499,6 +506,8 @@ class ImmutableVar(Var, Generic[VAR_TYPE]):
return self.to(StringVar, self._var_type) return self.to(StringVar, self._var_type)
if issubclass(fixed_type, Base): if issubclass(fixed_type, Base):
return self.to(ObjectVar, self._var_type) return self.to(ObjectVar, self._var_type)
if dataclasses.is_dataclass(fixed_type):
return self.to(ObjectVar, self._var_type)
return self return self
def get_default_value(self) -> Any: def get_default_value(self) -> Any:
@ -985,6 +994,16 @@ class LiteralVar(ImmutableVar):
) )
return LiteralVar.create(serialized_value, _var_data=_var_data) return LiteralVar.create(serialized_value, _var_data=_var_data)
if dataclasses.is_dataclass(value) and not isinstance(value, type):
return LiteralObjectVar.create(
{
k: (None if callable(v) else v)
for k, v in dataclasses.asdict(value).items()
},
_var_type=type(value),
_var_data=_var_data,
)
raise TypeError( raise TypeError(
f"Unsupported type {type(value)} for LiteralVar. Tried to create a LiteralVar from {value}." f"Unsupported type {type(value)} for LiteralVar. Tried to create a LiteralVar from {value}."
) )

View File

@ -2,6 +2,7 @@
from __future__ import annotations from __future__ import annotations
import dataclasses
from typing import TYPE_CHECKING, Optional from typing import TYPE_CHECKING, Optional
from reflex import constants from reflex import constants
@ -14,6 +15,7 @@ if TYPE_CHECKING:
from reflex.app import App from reflex.app import App
@dataclasses.dataclass(init=True)
class HydrateMiddleware(Middleware): class HydrateMiddleware(Middleware):
"""Middleware to handle initial app hydration.""" """Middleware to handle initial app hydration."""

View File

@ -2,10 +2,9 @@
from __future__ import annotations from __future__ import annotations
from abc import ABC from abc import ABC, abstractmethod
from typing import TYPE_CHECKING, Optional from typing import TYPE_CHECKING, Optional
from reflex.base import Base
from reflex.event import Event from reflex.event import Event
from reflex.state import BaseState, StateUpdate from reflex.state import BaseState, StateUpdate
@ -13,9 +12,10 @@ if TYPE_CHECKING:
from reflex.app import App from reflex.app import App
class Middleware(Base, ABC): class Middleware(ABC):
"""Middleware to preprocess and postprocess requests.""" """Middleware to preprocess and postprocess requests."""
@abstractmethod
async def preprocess( async def preprocess(
self, app: App, state: BaseState, event: Event self, app: App, state: BaseState, event: Event
) -> Optional[StateUpdate]: ) -> Optional[StateUpdate]:

View File

@ -5,8 +5,10 @@ from __future__ import annotations
import asyncio import asyncio
import contextlib import contextlib
import copy import copy
import dataclasses
import functools import functools
import inspect import inspect
import json
import os import os
import uuid import uuid
from abc import ABC, abstractmethod from abc import ABC, abstractmethod
@ -83,13 +85,15 @@ var = immutable_computed_var
TOO_LARGE_SERIALIZED_STATE = 100 * 1024 # 100kb TOO_LARGE_SERIALIZED_STATE = 100 * 1024 # 100kb
class HeaderData(Base): @dataclasses.dataclass(frozen=True)
class HeaderData:
"""An object containing headers data.""" """An object containing headers data."""
host: str = "" host: str = ""
origin: str = "" origin: str = ""
upgrade: str = "" upgrade: str = ""
connection: str = "" connection: str = ""
cookie: str = ""
pragma: str = "" pragma: str = ""
cache_control: str = "" cache_control: str = ""
user_agent: str = "" user_agent: str = ""
@ -105,13 +109,16 @@ class HeaderData(Base):
Args: Args:
router_data: the router_data dict. router_data: the router_data dict.
""" """
super().__init__()
if router_data: if router_data:
for k, v in router_data.get(constants.RouteVar.HEADERS, {}).items(): for k, v in router_data.get(constants.RouteVar.HEADERS, {}).items():
setattr(self, format.to_snake_case(k), v) object.__setattr__(self, format.to_snake_case(k), v)
else:
for k in dataclasses.fields(self):
object.__setattr__(self, k.name, "")
class PageData(Base): @dataclasses.dataclass(frozen=True)
class PageData:
"""An object containing page data.""" """An object containing page data."""
host: str = "" # repeated with self.headers.origin (remove or keep the duplicate?) host: str = "" # repeated with self.headers.origin (remove or keep the duplicate?)
@ -119,7 +126,7 @@ class PageData(Base):
raw_path: str = "" raw_path: str = ""
full_path: str = "" full_path: str = ""
full_raw_path: str = "" full_raw_path: str = ""
params: dict = {} params: dict = dataclasses.field(default_factory=dict)
def __init__(self, router_data: Optional[dict] = None): def __init__(self, router_data: Optional[dict] = None):
"""Initalize the PageData object based on router_data. """Initalize the PageData object based on router_data.
@ -127,17 +134,34 @@ class PageData(Base):
Args: Args:
router_data: the router_data dict. router_data: the router_data dict.
""" """
super().__init__()
if router_data: if router_data:
self.host = router_data.get(constants.RouteVar.HEADERS, {}).get("origin") object.__setattr__(
self.path = router_data.get(constants.RouteVar.PATH, "") self,
self.raw_path = router_data.get(constants.RouteVar.ORIGIN, "") "host",
self.full_path = f"{self.host}{self.path}" router_data.get(constants.RouteVar.HEADERS, {}).get("origin", ""),
self.full_raw_path = f"{self.host}{self.raw_path}" )
self.params = router_data.get(constants.RouteVar.QUERY, {}) object.__setattr__(
self, "path", router_data.get(constants.RouteVar.PATH, "")
)
object.__setattr__(
self, "raw_path", router_data.get(constants.RouteVar.ORIGIN, "")
)
object.__setattr__(self, "full_path", f"{self.host}{self.path}")
object.__setattr__(self, "full_raw_path", f"{self.host}{self.raw_path}")
object.__setattr__(
self, "params", router_data.get(constants.RouteVar.QUERY, {})
)
else:
object.__setattr__(self, "host", "")
object.__setattr__(self, "path", "")
object.__setattr__(self, "raw_path", "")
object.__setattr__(self, "full_path", "")
object.__setattr__(self, "full_raw_path", "")
object.__setattr__(self, "params", {})
class SessionData(Base): @dataclasses.dataclass(frozen=True, init=False)
class SessionData:
"""An object containing session data.""" """An object containing session data."""
client_token: str = "" client_token: str = ""
@ -150,19 +174,24 @@ class SessionData(Base):
Args: Args:
router_data: the router_data dict. router_data: the router_data dict.
""" """
super().__init__()
if router_data: if router_data:
self.client_token = router_data.get(constants.RouteVar.CLIENT_TOKEN, "") client_token = router_data.get(constants.RouteVar.CLIENT_TOKEN, "")
self.client_ip = router_data.get(constants.RouteVar.CLIENT_IP, "") client_ip = router_data.get(constants.RouteVar.CLIENT_IP, "")
self.session_id = router_data.get(constants.RouteVar.SESSION_ID, "") session_id = router_data.get(constants.RouteVar.SESSION_ID, "")
else:
client_token = client_ip = session_id = ""
object.__setattr__(self, "client_token", client_token)
object.__setattr__(self, "client_ip", client_ip)
object.__setattr__(self, "session_id", session_id)
class RouterData(Base): @dataclasses.dataclass(frozen=True, init=False)
class RouterData:
"""An object containing RouterData.""" """An object containing RouterData."""
session: SessionData = SessionData() session: SessionData = dataclasses.field(default_factory=SessionData)
headers: HeaderData = HeaderData() headers: HeaderData = dataclasses.field(default_factory=HeaderData)
page: PageData = PageData() page: PageData = dataclasses.field(default_factory=PageData)
def __init__(self, router_data: Optional[dict] = None): def __init__(self, router_data: Optional[dict] = None):
"""Initialize the RouterData object. """Initialize the RouterData object.
@ -170,10 +199,30 @@ class RouterData(Base):
Args: Args:
router_data: the router_data dict. router_data: the router_data dict.
""" """
super().__init__() object.__setattr__(self, "session", SessionData(router_data))
self.session = SessionData(router_data) object.__setattr__(self, "headers", HeaderData(router_data))
self.headers = HeaderData(router_data) object.__setattr__(self, "page", PageData(router_data))
self.page = PageData(router_data)
def toJson(self) -> str:
"""Convert the object to a JSON string.
Returns:
The JSON string.
"""
return json.dumps(dataclasses.asdict(self))
@serializer
def serialize_routerdata(value: RouterData) -> str:
"""Serialize a RouterData instance.
Args:
value: The RouterData to serialize.
Returns:
The serialized RouterData.
"""
return value.toJson()
def _no_chain_background_task( def _no_chain_background_task(
@ -249,10 +298,11 @@ def _split_substate_key(substate_key: str) -> tuple[str, str]:
return token, state_name return token, state_name
@dataclasses.dataclass(frozen=True, init=False)
class EventHandlerSetVar(EventHandler): class EventHandlerSetVar(EventHandler):
"""A special event handler to wrap setvar functionality.""" """A special event handler to wrap setvar functionality."""
state_cls: Type[BaseState] state_cls: Type[BaseState] = dataclasses.field(init=False)
def __init__(self, state_cls: Type[BaseState]): def __init__(self, state_cls: Type[BaseState]):
"""Initialize the EventHandlerSetVar. """Initialize the EventHandlerSetVar.
@ -263,8 +313,8 @@ class EventHandlerSetVar(EventHandler):
super().__init__( super().__init__(
fn=type(self).setvar, fn=type(self).setvar,
state_full_name=state_cls.get_full_name(), state_full_name=state_cls.get_full_name(),
state_cls=state_cls, # type: ignore
) )
object.__setattr__(self, "state_cls", state_cls)
def setvar(self, var_name: str, value: Any): def setvar(self, var_name: str, value: Any):
"""Set the state variable to the value of the event. """Set the state variable to the value of the event.
@ -1826,8 +1876,13 @@ class BaseState(Base, ABC, extra=pydantic.Extra.allow):
self.dirty_vars.update(self._always_dirty_computed_vars) self.dirty_vars.update(self._always_dirty_computed_vars)
self._mark_dirty() self._mark_dirty()
def dictify(value: Any):
if dataclasses.is_dataclass(value) and not isinstance(value, type):
return dataclasses.asdict(value)
return value
base_vars = { base_vars = {
prop_name: self.get_value(getattr(self, prop_name)) prop_name: dictify(self.get_value(getattr(self, prop_name)))
for prop_name in self.base_vars for prop_name in self.base_vars
} }
if initial and include_computed: if initial and include_computed:
@ -1907,9 +1962,6 @@ class BaseState(Base, ABC, extra=pydantic.Extra.allow):
return state return state
EventHandlerSetVar.update_forward_refs()
class State(BaseState): class State(BaseState):
"""The app Base State.""" """The app Base State."""
@ -2341,18 +2393,29 @@ class StateProxy(wrapt.ObjectProxy):
self._self_mutable = original_mutable self._self_mutable = original_mutable
class StateUpdate(Base): @dataclasses.dataclass(
frozen=True,
)
class StateUpdate:
"""A state update sent to the frontend.""" """A state update sent to the frontend."""
# The state delta. # The state delta.
delta: Delta = {} delta: Delta = dataclasses.field(default_factory=dict)
# Events to be added to the event queue. # Events to be added to the event queue.
events: List[Event] = [] events: List[Event] = dataclasses.field(default_factory=list)
# Whether this is the final state update for the event. # Whether this is the final state update for the event.
final: bool = True final: bool = True
def json(self) -> str:
"""Convert the state update to a JSON string.
Returns:
The state update as a JSON string.
"""
return json.dumps(dataclasses.asdict(self))
class StateManager(Base, ABC): class StateManager(Base, ABC):
"""A class to manage many client states.""" """A class to manage many client states."""

View File

@ -2,6 +2,7 @@
from __future__ import annotations from __future__ import annotations
import dataclasses
import inspect import inspect
import json import json
import os import os
@ -623,6 +624,14 @@ def format_state(value: Any, key: Optional[str] = None) -> Any:
if isinstance(value, dict): if isinstance(value, dict):
return {k: format_state(v, k) for k, v in value.items()} return {k: format_state(v, k) for k, v in value.items()}
# Hand dataclasses.
if dataclasses.is_dataclass(value):
if isinstance(value, type):
raise TypeError(
f"Cannot format state of type {type(value)}. Please provide an instance of the dataclass."
)
return {k: format_state(v, k) for k, v in dataclasses.asdict(value).items()}
# Handle lists, sets, typles. # Handle lists, sets, typles.
if isinstance(value, types.StateIterBases): if isinstance(value, types.StateIterBases):
return [format_state(v) for v in value] return [format_state(v) for v in value]

View File

@ -2,10 +2,9 @@
from __future__ import annotations from __future__ import annotations
import dataclasses
from collections import defaultdict from collections import defaultdict
from typing import Dict, List, Optional, Tuple, Union from typing import DefaultDict, Dict, List, Optional, Tuple, Union
from reflex.base import Base
def merge_imports( def merge_imports(
@ -19,12 +18,22 @@ def merge_imports(
Returns: Returns:
The merged import dicts. The merged import dicts.
""" """
all_imports = defaultdict(list) all_imports: DefaultDict[str, List[ImportVar]] = defaultdict(list)
for import_dict in imports: for import_dict in imports:
for lib, fields in ( for lib, fields in (
import_dict if isinstance(import_dict, tuple) else import_dict.items() import_dict if isinstance(import_dict, tuple) else import_dict.items()
): ):
all_imports[lib].extend(fields) if isinstance(fields, (list, tuple, set)):
all_imports[lib].extend(
(
ImportVar(field) if isinstance(field, str) else field
for field in fields
)
)
else:
all_imports[lib].append(
ImportVar(fields) if isinstance(fields, str) else fields
)
return all_imports return all_imports
@ -75,7 +84,8 @@ def collapse_imports(
} }
class ImportVar(Base): @dataclasses.dataclass(order=True, frozen=True)
class ImportVar:
"""An import var.""" """An import var."""
# The name of the import tag. # The name of the import tag.
@ -111,73 +121,6 @@ class ImportVar(Base):
else: else:
return self.tag or "" return self.tag or ""
def __lt__(self, other: ImportVar) -> bool:
"""Compare two ImportVar objects.
Args:
other: The other ImportVar object to compare.
Returns:
Whether this ImportVar object is less than the other.
"""
return (
self.tag,
self.is_default,
self.alias,
self.install,
self.render,
self.transpile,
) < (
other.tag,
other.is_default,
other.alias,
other.install,
other.render,
other.transpile,
)
def __eq__(self, other: ImportVar) -> bool:
"""Check if two ImportVar objects are equal.
Args:
other: The other ImportVar object to compare.
Returns:
Whether the two ImportVar objects are equal.
"""
return (
self.tag,
self.is_default,
self.alias,
self.install,
self.render,
self.transpile,
) == (
other.tag,
other.is_default,
other.alias,
other.install,
other.render,
other.transpile,
)
def __hash__(self) -> int:
"""Hash the ImportVar object.
Returns:
The hash of the ImportVar object.
"""
return hash(
(
self.tag,
self.is_default,
self.alias,
self.install,
self.render,
self.transpile,
)
)
ImportTypes = Union[str, ImportVar, List[Union[str, ImportVar]], List[ImportVar]] ImportTypes = Union[str, ImportVar, List[Union[str, ImportVar]], List[ImportVar]]
ImportDict = Dict[str, ImportTypes] ImportDict = Dict[str, ImportTypes]

View File

@ -2,6 +2,7 @@
from __future__ import annotations from __future__ import annotations
import dataclasses
import functools import functools
import glob import glob
import importlib import importlib
@ -32,7 +33,6 @@ from redis import exceptions
from redis.asyncio import Redis from redis.asyncio import Redis
from reflex import constants, model from reflex import constants, model
from reflex.base import Base
from reflex.compiler import templates from reflex.compiler import templates
from reflex.config import Config, get_config from reflex.config import Config, get_config
from reflex.utils import console, net, path_ops, processes from reflex.utils import console, net, path_ops, processes
@ -43,7 +43,8 @@ from reflex.utils.registry import _get_best_registry
CURRENTLY_INSTALLING_NODE = False CURRENTLY_INSTALLING_NODE = False
class Template(Base): @dataclasses.dataclass(frozen=True)
class Template:
"""A template for a Reflex app.""" """A template for a Reflex app."""
name: str name: str
@ -52,7 +53,8 @@ class Template(Base):
demo_url: str demo_url: str
class CpuInfo(Base): @dataclasses.dataclass(frozen=True)
class CpuInfo:
"""Model to save cpu info.""" """Model to save cpu info."""
manufacturer_id: Optional[str] manufacturer_id: Optional[str]
@ -1279,7 +1281,7 @@ def fetch_app_templates(version: str) -> dict[str, Template]:
None, None,
) )
return { return {
tp["name"]: Template.parse_obj(tp) tp["name"]: Template(**tp)
for tp in templates_data for tp in templates_data
if not tp["hidden"] and tp["code_url"] is not None if not tp["hidden"] and tp["code_url"] is not None
} }

View File

@ -3,6 +3,7 @@
from __future__ import annotations from __future__ import annotations
import asyncio import asyncio
import dataclasses
import multiprocessing import multiprocessing
import platform import platform
import warnings import warnings
@ -144,7 +145,7 @@ def _prepare_event(event: str, **kwargs) -> dict:
"python_version": get_python_version(), "python_version": get_python_version(),
"cpu_count": get_cpu_count(), "cpu_count": get_cpu_count(),
"memory": get_memory(), "memory": get_memory(),
"cpu_info": dict(cpuinfo) if cpuinfo else {}, "cpu_info": dataclasses.asdict(cpuinfo) if cpuinfo else {},
**additional_fields, **additional_fields,
}, },
"timestamp": stamp, "timestamp": stamp,

View File

@ -3,6 +3,7 @@
from __future__ import annotations from __future__ import annotations
import contextlib import contextlib
import dataclasses
import inspect import inspect
import sys import sys
import types import types
@ -480,7 +481,11 @@ def is_valid_var_type(type_: Type) -> bool:
if is_union(type_): if is_union(type_):
return all((is_valid_var_type(arg) for arg in get_args(type_))) return all((is_valid_var_type(arg) for arg in get_args(type_)))
return _issubclass(type_, StateVar) or serializers.has_serializer(type_) return (
_issubclass(type_, StateVar)
or serializers.has_serializer(type_)
or dataclasses.is_dataclass(type_)
)
def is_backend_base_variable(name: str, cls: Type) -> bool: def is_backend_base_variable(name: str, cls: Type) -> bool:

View File

@ -637,21 +637,21 @@ def test_component_create_unallowed_types(children, test_component):
"props": [], "props": [],
"contents": "", "contents": "",
"args": None, "args": None,
"special_props": set(), "special_props": [],
"children": [ "children": [
{ {
"name": "RadixThemesText", "name": "RadixThemesText",
"props": ['as={"p"}'], "props": ['as={"p"}'],
"contents": "", "contents": "",
"args": None, "args": None,
"special_props": set(), "special_props": [],
"children": [ "children": [
{ {
"name": "", "name": "",
"props": [], "props": [],
"contents": '{"first_text"}', "contents": '{"first_text"}',
"args": None, "args": None,
"special_props": set(), "special_props": [],
"children": [], "children": [],
"autofocus": False, "autofocus": False,
} }
@ -679,13 +679,13 @@ def test_component_create_unallowed_types(children, test_component):
"contents": '{"first_text"}', "contents": '{"first_text"}',
"name": "", "name": "",
"props": [], "props": [],
"special_props": set(), "special_props": [],
} }
], ],
"contents": "", "contents": "",
"name": "RadixThemesText", "name": "RadixThemesText",
"props": ['as={"p"}'], "props": ['as={"p"}'],
"special_props": set(), "special_props": [],
}, },
{ {
"args": None, "args": None,
@ -698,19 +698,19 @@ def test_component_create_unallowed_types(children, test_component):
"contents": '{"second_text"}', "contents": '{"second_text"}',
"name": "", "name": "",
"props": [], "props": [],
"special_props": set(), "special_props": [],
} }
], ],
"contents": "", "contents": "",
"name": "RadixThemesText", "name": "RadixThemesText",
"props": ['as={"p"}'], "props": ['as={"p"}'],
"special_props": set(), "special_props": [],
}, },
], ],
"contents": "", "contents": "",
"name": "Fragment", "name": "Fragment",
"props": [], "props": [],
"special_props": set(), "special_props": [],
}, },
), ),
( (
@ -730,13 +730,13 @@ def test_component_create_unallowed_types(children, test_component):
"contents": '{"first_text"}', "contents": '{"first_text"}',
"name": "", "name": "",
"props": [], "props": [],
"special_props": set(), "special_props": [],
} }
], ],
"contents": "", "contents": "",
"name": "RadixThemesText", "name": "RadixThemesText",
"props": ['as={"p"}'], "props": ['as={"p"}'],
"special_props": set(), "special_props": [],
}, },
{ {
"args": None, "args": None,
@ -757,31 +757,31 @@ def test_component_create_unallowed_types(children, test_component):
"contents": '{"second_text"}', "contents": '{"second_text"}',
"name": "", "name": "",
"props": [], "props": [],
"special_props": set(), "special_props": [],
} }
], ],
"contents": "", "contents": "",
"name": "RadixThemesText", "name": "RadixThemesText",
"props": ['as={"p"}'], "props": ['as={"p"}'],
"special_props": set(), "special_props": [],
} }
], ],
"contents": "", "contents": "",
"name": "Fragment", "name": "Fragment",
"props": [], "props": [],
"special_props": set(), "special_props": [],
} }
], ],
"contents": "", "contents": "",
"name": "RadixThemesBox", "name": "RadixThemesBox",
"props": [], "props": [],
"special_props": set(), "special_props": [],
}, },
], ],
"contents": "", "contents": "",
"name": "Fragment", "name": "Fragment",
"props": [], "props": [],
"special_props": set(), "special_props": [],
}, },
), ),
], ],
@ -1289,12 +1289,12 @@ class EventState(rx.State):
id="fstring-class_name", id="fstring-class_name",
), ),
pytest.param( pytest.param(
rx.fragment(special_props={TEST_VAR}), rx.fragment(special_props=[TEST_VAR]),
[TEST_VAR], [TEST_VAR],
id="direct-special_props", id="direct-special_props",
), ),
pytest.param( pytest.param(
rx.fragment(special_props={LiteralVar.create(f"foo{TEST_VAR}bar")}), rx.fragment(special_props=[LiteralVar.create(f"foo{TEST_VAR}bar")]),
[FORMATTED_TEST_VAR], [FORMATTED_TEST_VAR],
id="fstring-special_props", id="fstring-special_props",
), ),

View File

@ -1,5 +1,6 @@
from __future__ import annotations from __future__ import annotations
import dataclasses
import functools import functools
import io import io
import json import json
@ -1052,7 +1053,7 @@ async def test_dynamic_route_var_route_change_completed_on_load(
f"comp_{arg_name}": exp_val, f"comp_{arg_name}": exp_val,
constants.CompileVars.IS_HYDRATED: False, constants.CompileVars.IS_HYDRATED: False,
# "side_effect_counter": exp_index, # "side_effect_counter": exp_index,
"router": exp_router, "router": dataclasses.asdict(exp_router),
} }
}, },
events=[ events=[

View File

@ -2,6 +2,7 @@ from __future__ import annotations
import asyncio import asyncio
import copy import copy
import dataclasses
import datetime import datetime
import functools import functools
import json import json
@ -58,6 +59,7 @@ formatted_router = {
"origin": "", "origin": "",
"upgrade": "", "upgrade": "",
"connection": "", "connection": "",
"cookie": "",
"pragma": "", "pragma": "",
"cache_control": "", "cache_control": "",
"user_agent": "", "user_agent": "",
@ -865,8 +867,10 @@ def test_get_headers(test_state, router_data, router_data_headers):
router_data: The router data fixture. router_data: The router data fixture.
router_data_headers: The expected headers. router_data_headers: The expected headers.
""" """
print(router_data_headers)
test_state.router = RouterData(router_data) test_state.router = RouterData(router_data)
assert test_state.router.headers.dict() == { print(test_state.router.headers)
assert dataclasses.asdict(test_state.router.headers) == {
format.to_snake_case(k): v for k, v in router_data_headers.items() format.to_snake_case(k): v for k, v in router_data_headers.items()
} }
@ -1908,19 +1912,21 @@ async def test_state_proxy(grandchild_state: GrandchildState, mock_app: rx.App):
mock_app.event_namespace.emit.assert_called_once() mock_app.event_namespace.emit.assert_called_once()
mcall = mock_app.event_namespace.emit.mock_calls[0] mcall = mock_app.event_namespace.emit.mock_calls[0]
assert mcall.args[0] == str(SocketEvent.EVENT) assert mcall.args[0] == str(SocketEvent.EVENT)
assert json.loads(mcall.args[1]) == StateUpdate( assert json.loads(mcall.args[1]) == dataclasses.asdict(
delta={ StateUpdate(
parent_state.get_full_name(): { delta={
"upper": "", parent_state.get_full_name(): {
"sum": 3.14, "upper": "",
}, "sum": 3.14,
grandchild_state.get_full_name(): { },
"value2": "42", grandchild_state.get_full_name(): {
}, "value2": "42",
GrandchildState3.get_full_name(): { },
"computed": "", GrandchildState3.get_full_name(): {
}, "computed": "",
} },
}
)
) )
assert mcall.kwargs["to"] == grandchild_state.router.session.session_id assert mcall.kwargs["to"] == grandchild_state.router.session.session_id

View File

@ -553,6 +553,7 @@ formatted_router = {
"origin": "", "origin": "",
"upgrade": "", "upgrade": "",
"connection": "", "connection": "",
"cookie": "",
"pragma": "", "pragma": "",
"cache_control": "", "cache_control": "",
"user_agent": "", "user_agent": "",

View File

@ -54,17 +54,21 @@ def test_import_var(import_var, expected_name):
( (
{"react": {"Component"}}, {"react": {"Component"}},
{"react": {"Component"}, "react-dom": {"render"}}, {"react": {"Component"}, "react-dom": {"render"}},
{"react": {"Component"}, "react-dom": {"render"}}, {"react": {ImportVar("Component")}, "react-dom": {ImportVar("render")}},
), ),
( (
{"react": {"Component"}, "next/image": {"Image"}}, {"react": {"Component"}, "next/image": {"Image"}},
{"react": {"Component"}, "react-dom": {"render"}}, {"react": {"Component"}, "react-dom": {"render"}},
{"react": {"Component"}, "react-dom": {"render"}, "next/image": {"Image"}}, {
"react": {ImportVar("Component")},
"react-dom": {ImportVar("render")},
"next/image": {ImportVar("Image")},
},
), ),
( (
{"react": {"Component"}}, {"react": {"Component"}},
{"": {"some/custom.css"}}, {"": {"some/custom.css"}},
{"react": {"Component"}, "": {"some/custom.css"}}, {"react": {ImportVar("Component")}, "": {ImportVar("some/custom.css")}},
), ),
], ],
) )