remove format_state and override behavior for bare (#3979)

* remove format_state and override behavior for bare

* pass the test cases

* only do one level of dicting dataclasses

* remove dict and replace list with set

* delete unnecessary serialize calls

* remove serialize for mutable proxy

* dang it darglint
This commit is contained in:
Khaleel Al-Adhami 2024-09-26 16:00:28 -07:00 committed by GitHub
parent 70bd88c682
commit 0ab161c119
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
11 changed files with 30 additions and 117 deletions

View File

@ -155,7 +155,7 @@ def compile_state(state: Type[BaseState]) -> dict:
initial_state = state(_reflex_internal_init=True).dict(
initial=True, include_computed=False
)
return format.format_state(initial_state)
return initial_state
def _compile_client_storage_field(

View File

@ -7,7 +7,7 @@ from typing import Any, Iterator
from reflex.components.component import Component
from reflex.components.tags import Tag
from reflex.components.tags.tagless import Tagless
from reflex.vars.base import Var
from reflex.vars import ArrayVar, BooleanVar, ObjectVar, Var
class Bare(Component):
@ -33,6 +33,8 @@ class Bare(Component):
def _render(self) -> Tag:
if isinstance(self.contents, Var):
if isinstance(self.contents, (BooleanVar, ObjectVar, ArrayVar)):
return Tagless(contents=f"{{{str(self.contents.to_string())}}}")
return Tagless(contents=f"{{{str(self.contents)}}}")
return Tagless(contents=str(self.contents))

View File

@ -9,7 +9,6 @@ from reflex import constants
from reflex.event import Event, get_hydrate_event
from reflex.middleware.middleware import Middleware
from reflex.state import BaseState, StateUpdate
from reflex.utils import format
if TYPE_CHECKING:
from reflex.app import App
@ -43,7 +42,7 @@ class HydrateMiddleware(Middleware):
setattr(state, constants.CompileVars.IS_HYDRATED, False)
# Get the initial state.
delta = format.format_state(state.dict())
delta = state.dict()
# since a full dict was captured, clean any dirtiness
state._clean()

View File

@ -73,7 +73,7 @@ from reflex.utils.exceptions import (
LockExpiredError,
)
from reflex.utils.exec import is_testing_env
from reflex.utils.serializers import SerializedType, serialize, serializer
from reflex.utils.serializers import serializer
from reflex.utils.types import override
from reflex.vars import VarData
@ -1790,9 +1790,6 @@ class BaseState(Base, ABC, extra=pydantic.Extra.allow):
for substate in self.dirty_substates.union(self._always_dirty_substates):
delta.update(substates[substate].get_delta())
# Format the delta.
delta = format.format_state(delta)
# Return the delta.
return delta
@ -2433,7 +2430,7 @@ class StateUpdate:
Returns:
The state update as a JSON string.
"""
return format.json_dumps(dataclasses.asdict(self))
return format.json_dumps(self)
class StateManager(Base, ABC):
@ -3660,22 +3657,16 @@ class MutableProxy(wrapt.ObjectProxy):
@serializer
def serialize_mutable_proxy(mp: MutableProxy) -> SerializedType:
"""Serialize the wrapped value of a MutableProxy.
def serialize_mutable_proxy(mp: MutableProxy):
"""Return the wrapped value of a MutableProxy.
Args:
mp: The MutableProxy to serialize.
Returns:
The serialized wrapped object.
Raises:
ValueError: when the wrapped object is not serializable.
The wrapped object.
"""
value = serialize(mp.__wrapped__)
if value is None:
raise ValueError(f"Cannot serialize {type(mp.__wrapped__)}")
return value
return mp.__wrapped__
class ImmutableMutableProxy(MutableProxy):

View File

@ -9,7 +9,7 @@ import re
from typing import TYPE_CHECKING, Any, Callable, List, Optional, Union
from reflex import constants
from reflex.utils import exceptions, types
from reflex.utils import exceptions
from reflex.utils.console import deprecate
if TYPE_CHECKING:
@ -624,48 +624,6 @@ def format_query_params(router_data: dict[str, Any]) -> dict[str, str]:
return {k.replace("-", "_"): v for k, v in params.items()}
def format_state(value: Any, key: Optional[str] = None) -> Any:
"""Recursively format values in the given state.
Args:
value: The state to format.
key: The key associated with the value (optional).
Returns:
The formatted state.
Raises:
TypeError: If the given value is not a valid state.
"""
from reflex.utils import serializers
# Handle dicts.
if isinstance(value, dict):
return {k: format_state(v, k) for k, v in value.items()}
# Handle lists, sets, typles.
if isinstance(value, types.StateIterBases):
return [format_state(v) for v in value]
# Return state vars as is.
if isinstance(value, types.StateBases):
return value
# Serialize the value.
serialized = serializers.serialize(value)
if serialized is not None:
return serialized
if key is None:
raise TypeError(
f"No JSON serializer found for var {value} of type {type(value)}."
)
else:
raise TypeError(
f"No JSON serializer found for State Var '{key}' of value {value} of type {type(value)}."
)
def format_state_name(state_name: str) -> str:
"""Format a state name, replacing dots with double underscore.

View File

@ -12,7 +12,6 @@ from pathlib import Path
from typing import (
Any,
Callable,
Dict,
List,
Literal,
Optional,
@ -126,7 +125,8 @@ def serialize(
# If there is no serializer, return None.
if serializer is None:
if dataclasses.is_dataclass(value) and not isinstance(value, type):
return serialize(dataclasses.asdict(value))
return {k.name: getattr(value, k.name) for k in dataclasses.fields(value)}
if get_type:
return None, None
return None
@ -214,32 +214,6 @@ def serialize_type(value: type) -> str:
return value.__name__
@serializer
def serialize_str(value: str) -> str:
"""Serialize a string.
Args:
value: The string to serialize.
Returns:
The serialized string.
"""
return value
@serializer
def serialize_primitive(value: Union[bool, int, float, None]):
"""Serialize a primitive type.
Args:
value: The number/bool/None to serialize.
Returns:
The serialized number/bool/None.
"""
return value
@serializer
def serialize_base(value: Base) -> dict:
"""Serialize a Base instance.
@ -250,33 +224,20 @@ def serialize_base(value: Base) -> dict:
Returns:
The serialized Base.
"""
return {k: serialize(v) for k, v in value.dict().items() if not callable(v)}
return {k: v for k, v in value.dict().items() if not callable(v)}
@serializer
def serialize_list(value: Union[List, Tuple, Set]) -> list:
"""Serialize a list to a JSON string.
def serialize_set(value: Set) -> list:
"""Serialize a set to a JSON serializable list.
Args:
value: The list to serialize.
value: The set to serialize.
Returns:
The serialized list.
"""
return [serialize(item) for item in value]
@serializer
def serialize_dict(prop: Dict[str, Any]) -> dict:
"""Serialize a dictionary to a JSON string.
Args:
prop: The dictionary to serialize.
Returns:
The serialized dictionary.
"""
return {k: serialize(v) for k, v in prop.items()}
return list(value)
@serializer(to=str)

View File

@ -1141,7 +1141,7 @@ def serialize_literal(value: LiteralVar):
Returns:
The serialized Literal.
"""
return serializers.serialize(value._var_value)
return value._var_value
P = ParamSpec("P")

View File

@ -793,8 +793,8 @@ def test_var_operations(driver, var_operations: AppHarness):
("foreach_list_ix", "1\n2"),
("foreach_list_nested", "1\n1\n2"),
# rx.memo component with state
("memo_comp", "1210"),
("memo_comp_nested", "345"),
("memo_comp", "[1,2]10"),
("memo_comp_nested", "[3,4]5"),
# foreach in a match
("foreach_in_match", "first\nsecond\nthird"),
]

View File

@ -1,6 +1,5 @@
from __future__ import annotations
import dataclasses
import functools
import io
import json
@ -1053,7 +1052,7 @@ async def test_dynamic_route_var_route_change_completed_on_load(
f"comp_{arg_name}": exp_val,
constants.CompileVars.IS_HYDRATED: False,
# "side_effect_counter": exp_index,
"router": dataclasses.asdict(exp_router),
"router": exp_router,
}
},
events=[

View File

@ -1,6 +1,7 @@
from __future__ import annotations
import datetime
import json
from typing import Any, List
import plotly.graph_objects as go
@ -621,7 +622,7 @@ def test_format_state(input, output):
input: The state to format.
output: The expected formatted state.
"""
assert format.format_state(input) == output
assert json.loads(format.json_dumps(input)) == json.loads(format.json_dumps(output))
@pytest.mark.parametrize(

View File

@ -1,19 +1,21 @@
import datetime
import json
from enum import Enum
from pathlib import Path
from typing import Any, Dict, Type
from typing import Any, Type
import pytest
from reflex.base import Base
from reflex.components.core.colors import Color
from reflex.utils import serializers
from reflex.utils.format import json_dumps
from reflex.vars.base import LiteralVar
@pytest.mark.parametrize(
"type_,expected",
[(str, True), (dict, True), (Dict[int, int], True), (Enum, True)],
[(Enum, True)],
)
def test_has_serializer(type_: Type, expected: bool):
"""Test that has_serializer returns the correct value.
@ -198,7 +200,7 @@ def test_serialize(value: Any, expected: str):
value: The value to serialize.
expected: The expected result.
"""
assert serializers.serialize(value) == expected
assert json.loads(json_dumps(value)) == json.loads(json_dumps(expected))
@pytest.mark.parametrize(