remove format_state and override behavior for bare (#3979)

* remove format_state and override behavior for bare

* pass the test cases

* only do one level of dicting dataclasses

* remove dict and replace list with set

* delete unnecessary serialize calls

* remove serialize for mutable proxy

* dang it darglint
This commit is contained in:
Khaleel Al-Adhami 2024-09-26 16:00:28 -07:00 committed by GitHub
parent 70bd88c682
commit 0ab161c119
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
11 changed files with 30 additions and 117 deletions

View File

@ -155,7 +155,7 @@ def compile_state(state: Type[BaseState]) -> dict:
initial_state = state(_reflex_internal_init=True).dict( initial_state = state(_reflex_internal_init=True).dict(
initial=True, include_computed=False initial=True, include_computed=False
) )
return format.format_state(initial_state) return initial_state
def _compile_client_storage_field( def _compile_client_storage_field(

View File

@ -7,7 +7,7 @@ from typing import Any, Iterator
from reflex.components.component import Component from reflex.components.component import Component
from reflex.components.tags import Tag from reflex.components.tags import Tag
from reflex.components.tags.tagless import Tagless from reflex.components.tags.tagless import Tagless
from reflex.vars.base import Var from reflex.vars import ArrayVar, BooleanVar, ObjectVar, Var
class Bare(Component): class Bare(Component):
@ -33,6 +33,8 @@ class Bare(Component):
def _render(self) -> Tag: def _render(self) -> Tag:
if isinstance(self.contents, Var): if isinstance(self.contents, Var):
if isinstance(self.contents, (BooleanVar, ObjectVar, ArrayVar)):
return Tagless(contents=f"{{{str(self.contents.to_string())}}}")
return Tagless(contents=f"{{{str(self.contents)}}}") return Tagless(contents=f"{{{str(self.contents)}}}")
return Tagless(contents=str(self.contents)) return Tagless(contents=str(self.contents))

View File

@ -9,7 +9,6 @@ from reflex import constants
from reflex.event import Event, get_hydrate_event from reflex.event import Event, get_hydrate_event
from reflex.middleware.middleware import Middleware from reflex.middleware.middleware import Middleware
from reflex.state import BaseState, StateUpdate from reflex.state import BaseState, StateUpdate
from reflex.utils import format
if TYPE_CHECKING: if TYPE_CHECKING:
from reflex.app import App from reflex.app import App
@ -43,7 +42,7 @@ class HydrateMiddleware(Middleware):
setattr(state, constants.CompileVars.IS_HYDRATED, False) setattr(state, constants.CompileVars.IS_HYDRATED, False)
# Get the initial state. # Get the initial state.
delta = format.format_state(state.dict()) delta = state.dict()
# since a full dict was captured, clean any dirtiness # since a full dict was captured, clean any dirtiness
state._clean() state._clean()

View File

@ -73,7 +73,7 @@ from reflex.utils.exceptions import (
LockExpiredError, LockExpiredError,
) )
from reflex.utils.exec import is_testing_env from reflex.utils.exec import is_testing_env
from reflex.utils.serializers import SerializedType, serialize, serializer from reflex.utils.serializers import serializer
from reflex.utils.types import override from reflex.utils.types import override
from reflex.vars import VarData from reflex.vars import VarData
@ -1790,9 +1790,6 @@ class BaseState(Base, ABC, extra=pydantic.Extra.allow):
for substate in self.dirty_substates.union(self._always_dirty_substates): for substate in self.dirty_substates.union(self._always_dirty_substates):
delta.update(substates[substate].get_delta()) delta.update(substates[substate].get_delta())
# Format the delta.
delta = format.format_state(delta)
# Return the delta. # Return the delta.
return delta return delta
@ -2433,7 +2430,7 @@ class StateUpdate:
Returns: Returns:
The state update as a JSON string. The state update as a JSON string.
""" """
return format.json_dumps(dataclasses.asdict(self)) return format.json_dumps(self)
class StateManager(Base, ABC): class StateManager(Base, ABC):
@ -3660,22 +3657,16 @@ class MutableProxy(wrapt.ObjectProxy):
@serializer @serializer
def serialize_mutable_proxy(mp: MutableProxy) -> SerializedType: def serialize_mutable_proxy(mp: MutableProxy):
"""Serialize the wrapped value of a MutableProxy. """Return the wrapped value of a MutableProxy.
Args: Args:
mp: The MutableProxy to serialize. mp: The MutableProxy to serialize.
Returns: Returns:
The serialized wrapped object. The wrapped object.
Raises:
ValueError: when the wrapped object is not serializable.
""" """
value = serialize(mp.__wrapped__) return mp.__wrapped__
if value is None:
raise ValueError(f"Cannot serialize {type(mp.__wrapped__)}")
return value
class ImmutableMutableProxy(MutableProxy): class ImmutableMutableProxy(MutableProxy):

View File

@ -9,7 +9,7 @@ import re
from typing import TYPE_CHECKING, Any, Callable, List, Optional, Union from typing import TYPE_CHECKING, Any, Callable, List, Optional, Union
from reflex import constants from reflex import constants
from reflex.utils import exceptions, types from reflex.utils import exceptions
from reflex.utils.console import deprecate from reflex.utils.console import deprecate
if TYPE_CHECKING: if TYPE_CHECKING:
@ -624,48 +624,6 @@ def format_query_params(router_data: dict[str, Any]) -> dict[str, str]:
return {k.replace("-", "_"): v for k, v in params.items()} return {k.replace("-", "_"): v for k, v in params.items()}
def format_state(value: Any, key: Optional[str] = None) -> Any:
"""Recursively format values in the given state.
Args:
value: The state to format.
key: The key associated with the value (optional).
Returns:
The formatted state.
Raises:
TypeError: If the given value is not a valid state.
"""
from reflex.utils import serializers
# Handle dicts.
if isinstance(value, dict):
return {k: format_state(v, k) for k, v in value.items()}
# Handle lists, sets, typles.
if isinstance(value, types.StateIterBases):
return [format_state(v) for v in value]
# Return state vars as is.
if isinstance(value, types.StateBases):
return value
# Serialize the value.
serialized = serializers.serialize(value)
if serialized is not None:
return serialized
if key is None:
raise TypeError(
f"No JSON serializer found for var {value} of type {type(value)}."
)
else:
raise TypeError(
f"No JSON serializer found for State Var '{key}' of value {value} of type {type(value)}."
)
def format_state_name(state_name: str) -> str: def format_state_name(state_name: str) -> str:
"""Format a state name, replacing dots with double underscore. """Format a state name, replacing dots with double underscore.

View File

@ -12,7 +12,6 @@ from pathlib import Path
from typing import ( from typing import (
Any, Any,
Callable, Callable,
Dict,
List, List,
Literal, Literal,
Optional, Optional,
@ -126,7 +125,8 @@ def serialize(
# If there is no serializer, return None. # If there is no serializer, return None.
if serializer is None: if serializer is None:
if dataclasses.is_dataclass(value) and not isinstance(value, type): if dataclasses.is_dataclass(value) and not isinstance(value, type):
return serialize(dataclasses.asdict(value)) return {k.name: getattr(value, k.name) for k in dataclasses.fields(value)}
if get_type: if get_type:
return None, None return None, None
return None return None
@ -214,32 +214,6 @@ def serialize_type(value: type) -> str:
return value.__name__ return value.__name__
@serializer
def serialize_str(value: str) -> str:
"""Serialize a string.
Args:
value: The string to serialize.
Returns:
The serialized string.
"""
return value
@serializer
def serialize_primitive(value: Union[bool, int, float, None]):
"""Serialize a primitive type.
Args:
value: The number/bool/None to serialize.
Returns:
The serialized number/bool/None.
"""
return value
@serializer @serializer
def serialize_base(value: Base) -> dict: def serialize_base(value: Base) -> dict:
"""Serialize a Base instance. """Serialize a Base instance.
@ -250,33 +224,20 @@ def serialize_base(value: Base) -> dict:
Returns: Returns:
The serialized Base. The serialized Base.
""" """
return {k: serialize(v) for k, v in value.dict().items() if not callable(v)} return {k: v for k, v in value.dict().items() if not callable(v)}
@serializer @serializer
def serialize_list(value: Union[List, Tuple, Set]) -> list: def serialize_set(value: Set) -> list:
"""Serialize a list to a JSON string. """Serialize a set to a JSON serializable list.
Args: Args:
value: The list to serialize. value: The set to serialize.
Returns: Returns:
The serialized list. The serialized list.
""" """
return [serialize(item) for item in value] return list(value)
@serializer
def serialize_dict(prop: Dict[str, Any]) -> dict:
"""Serialize a dictionary to a JSON string.
Args:
prop: The dictionary to serialize.
Returns:
The serialized dictionary.
"""
return {k: serialize(v) for k, v in prop.items()}
@serializer(to=str) @serializer(to=str)

View File

@ -1141,7 +1141,7 @@ def serialize_literal(value: LiteralVar):
Returns: Returns:
The serialized Literal. The serialized Literal.
""" """
return serializers.serialize(value._var_value) return value._var_value
P = ParamSpec("P") P = ParamSpec("P")

View File

@ -793,8 +793,8 @@ def test_var_operations(driver, var_operations: AppHarness):
("foreach_list_ix", "1\n2"), ("foreach_list_ix", "1\n2"),
("foreach_list_nested", "1\n1\n2"), ("foreach_list_nested", "1\n1\n2"),
# rx.memo component with state # rx.memo component with state
("memo_comp", "1210"), ("memo_comp", "[1,2]10"),
("memo_comp_nested", "345"), ("memo_comp_nested", "[3,4]5"),
# foreach in a match # foreach in a match
("foreach_in_match", "first\nsecond\nthird"), ("foreach_in_match", "first\nsecond\nthird"),
] ]

View File

@ -1,6 +1,5 @@
from __future__ import annotations from __future__ import annotations
import dataclasses
import functools import functools
import io import io
import json import json
@ -1053,7 +1052,7 @@ async def test_dynamic_route_var_route_change_completed_on_load(
f"comp_{arg_name}": exp_val, f"comp_{arg_name}": exp_val,
constants.CompileVars.IS_HYDRATED: False, constants.CompileVars.IS_HYDRATED: False,
# "side_effect_counter": exp_index, # "side_effect_counter": exp_index,
"router": dataclasses.asdict(exp_router), "router": exp_router,
} }
}, },
events=[ events=[

View File

@ -1,6 +1,7 @@
from __future__ import annotations from __future__ import annotations
import datetime import datetime
import json
from typing import Any, List from typing import Any, List
import plotly.graph_objects as go import plotly.graph_objects as go
@ -621,7 +622,7 @@ def test_format_state(input, output):
input: The state to format. input: The state to format.
output: The expected formatted state. output: The expected formatted state.
""" """
assert format.format_state(input) == output assert json.loads(format.json_dumps(input)) == json.loads(format.json_dumps(output))
@pytest.mark.parametrize( @pytest.mark.parametrize(

View File

@ -1,19 +1,21 @@
import datetime import datetime
import json
from enum import Enum from enum import Enum
from pathlib import Path from pathlib import Path
from typing import Any, Dict, Type from typing import Any, Type
import pytest import pytest
from reflex.base import Base from reflex.base import Base
from reflex.components.core.colors import Color from reflex.components.core.colors import Color
from reflex.utils import serializers from reflex.utils import serializers
from reflex.utils.format import json_dumps
from reflex.vars.base import LiteralVar from reflex.vars.base import LiteralVar
@pytest.mark.parametrize( @pytest.mark.parametrize(
"type_,expected", "type_,expected",
[(str, True), (dict, True), (Dict[int, int], True), (Enum, True)], [(Enum, True)],
) )
def test_has_serializer(type_: Type, expected: bool): def test_has_serializer(type_: Type, expected: bool):
"""Test that has_serializer returns the correct value. """Test that has_serializer returns the correct value.
@ -198,7 +200,7 @@ def test_serialize(value: Any, expected: str):
value: The value to serialize. value: The value to serialize.
expected: The expected result. expected: The expected result.
""" """
assert serializers.serialize(value) == expected assert json.loads(json_dumps(value)) == json.loads(json_dumps(expected))
@pytest.mark.parametrize( @pytest.mark.parametrize(