From dba051e8ca4ba12cadad53c633c6a284cca46227 Mon Sep 17 00:00:00 2001 From: Khaleel Al-Adhami Date: Mon, 16 Sep 2024 16:36:01 -0700 Subject: [PATCH] use serializer for state update and rework serializers (#3934) * use serializer for state update and rework serializers * format --- reflex/state.py | 24 +------------- reflex/utils/format.py | 24 ++------------ reflex/utils/serializers.py | 29 +++++++--------- reflex/vars/base.py | 28 ++++++++-------- tests/utils/test_format.py | 42 +++++++++++------------ tests/utils/test_serializers.py | 59 +++++++++++++++++++-------------- 6 files changed, 84 insertions(+), 122 deletions(-) diff --git a/reflex/state.py b/reflex/state.py index 4581bbae1..fbef6dd4a 100644 --- a/reflex/state.py +++ b/reflex/state.py @@ -8,7 +8,6 @@ import copy import dataclasses import functools import inspect -import json import os import uuid from abc import ABC, abstractmethod @@ -206,27 +205,6 @@ class RouterData: object.__setattr__(self, "headers", HeaderData(router_data)) object.__setattr__(self, "page", PageData(router_data)) - def toJson(self) -> str: - """Convert the object to a JSON string. - - Returns: - The JSON string. - """ - return json.dumps(dataclasses.asdict(self)) - - -@serializer -def serialize_routerdata(value: RouterData) -> str: - """Serialize a RouterData instance. - - Args: - value: The RouterData to serialize. - - Returns: - The serialized RouterData. - """ - return value.toJson() - def _no_chain_background_task( state_cls: Type["BaseState"], name: str, fn: Callable @@ -2415,7 +2393,7 @@ class StateUpdate: Returns: The state update as a JSON string. """ - return json.dumps(dataclasses.asdict(self)) + return format.json_dumps(dataclasses.asdict(self)) class StateManager(Base, ABC): diff --git a/reflex/utils/format.py b/reflex/utils/format.py index 39f886f93..e8b040230 100644 --- a/reflex/utils/format.py +++ b/reflex/utils/format.py @@ -2,7 +2,6 @@ from __future__ import annotations -import dataclasses import inspect import json import os @@ -410,22 +409,11 @@ def format_props(*single_props, **key_value_props) -> list[str]: return [ ( - f"{name}={format_prop(prop)}" - if isinstance(prop, Var) and not isinstance(prop, Var) - else ( - f"{name}={{{format_prop(prop if isinstance(prop, Var) else LiteralVar.create(prop))}}}" - ) + f"{name}={{{format_prop(prop if isinstance(prop, Var) else LiteralVar.create(prop))}}}" ) for name, prop in sorted(key_value_props.items()) if prop is not None - ] + [ - ( - str(prop) - if isinstance(prop, Var) and not isinstance(prop, Var) - else f"{str(LiteralVar.create(prop))}" - ) - for prop in single_props - ] + ] + [(f"{str(LiteralVar.create(prop))}") for prop in single_props] def get_event_handler_parts(handler: EventHandler) -> tuple[str, str]: @@ -623,14 +611,6 @@ def format_state(value: Any, key: Optional[str] = None) -> Any: if isinstance(value, dict): return {k: format_state(v, k) for k, v in value.items()} - # Hand dataclasses. - if dataclasses.is_dataclass(value): - if isinstance(value, type): - raise TypeError( - f"Cannot format state of type {type(value)}. Please provide an instance of the dataclass." - ) - return {k: format_state(v, k) for k, v in dataclasses.asdict(value).items()} - # Handle lists, sets, typles. if isinstance(value, types.StateIterBases): return [format_state(v) for v in value] diff --git a/reflex/utils/serializers.py b/reflex/utils/serializers.py index c5cded3b6..42fb82916 100644 --- a/reflex/utils/serializers.py +++ b/reflex/utils/serializers.py @@ -2,6 +2,7 @@ from __future__ import annotations +import dataclasses import functools import json import warnings @@ -29,7 +30,7 @@ from reflex.utils import types # Mapping from type to a serializer. # The serializer should convert the type to a JSON object. -SerializedType = Union[str, bool, int, float, list, dict] +SerializedType = Union[str, bool, int, float, list, dict, None] Serializer = Callable[[Type], SerializedType] @@ -124,6 +125,8 @@ def serialize( # If there is no serializer, return None. if serializer is None: + if dataclasses.is_dataclass(value) and not isinstance(value, type): + return serialize(dataclasses.asdict(value)) if get_type: return None, None return None @@ -225,7 +228,7 @@ def serialize_str(value: str) -> str: @serializer -def serialize_primitive(value: Union[bool, int, float, None]) -> str: +def serialize_primitive(value: Union[bool, int, float, None]): """Serialize a primitive type. Args: @@ -234,13 +237,11 @@ def serialize_primitive(value: Union[bool, int, float, None]) -> str: Returns: The serialized number/bool/None. """ - from reflex.utils import format - - return format.json_dumps(value) + return value @serializer -def serialize_base(value: Base) -> str: +def serialize_base(value: Base) -> dict: """Serialize a Base instance. Args: @@ -249,13 +250,11 @@ def serialize_base(value: Base) -> str: Returns: The serialized Base. """ - from reflex.vars import LiteralVar - - return str(LiteralVar.create(value)) + return {k: serialize(v) for k, v in value.dict().items() if not callable(v)} @serializer -def serialize_list(value: Union[List, Tuple, Set]) -> str: +def serialize_list(value: Union[List, Tuple, Set]) -> list: """Serialize a list to a JSON string. Args: @@ -264,13 +263,11 @@ def serialize_list(value: Union[List, Tuple, Set]) -> str: Returns: The serialized list. """ - from reflex.vars import LiteralArrayVar - - return str(LiteralArrayVar.create(value)) + return [serialize(item) for item in value] @serializer -def serialize_dict(prop: Dict[str, Any]) -> str: +def serialize_dict(prop: Dict[str, Any]) -> dict: """Serialize a dictionary to a JSON string. Args: @@ -279,9 +276,7 @@ def serialize_dict(prop: Dict[str, Any]) -> str: Returns: The serialized dictionary. """ - from reflex.vars import LiteralObjectVar - - return str(LiteralObjectVar.create(prop)) + return {k: serialize(v) for k, v in prop.items()} @serializer(to=str) diff --git a/reflex/vars/base.py b/reflex/vars/base.py index d0d14a825..4a7584258 100644 --- a/reflex/vars/base.py +++ b/reflex/vars/base.py @@ -936,21 +936,6 @@ class Var(Generic[VAR_TYPE]): OUTPUT = TypeVar("OUTPUT", bound=Var) -def _encode_var(value: Var) -> str: - """Encode the state name into a formatted var. - - Args: - value: The value to encode the state name into. - - Returns: - The encoded var. - """ - return f"{value}" - - -serializers.serializer(_encode_var) - - class LiteralVar(Var): """Base class for immutable literal vars.""" @@ -1101,6 +1086,19 @@ class LiteralVar(Var): ) +@serializers.serializer +def serialize_literal(value: LiteralVar): + """Serialize a Literal type. + + Args: + value: The Literal to serialize. + + Returns: + The serialized Literal. + """ + return serializers.serialize(value._var_value) + + P = ParamSpec("P") T = TypeVar("T") diff --git a/tests/utils/test_format.py b/tests/utils/test_format.py index ec5eabb73..b108f9dc2 100644 --- a/tests/utils/test_format.py +++ b/tests/utils/test_format.py @@ -352,28 +352,28 @@ def test_format_match( "prop,formatted", [ ("string", '"string"'), - ("{wrapped_string}", "{wrapped_string}"), - (True, "{true}"), - (False, "{false}"), - (123, "{123}"), - (3.14, "{3.14}"), - ([1, 2, 3], "{[1, 2, 3]}"), - (["a", "b", "c"], '{["a", "b", "c"]}'), - ({"a": 1, "b": 2, "c": 3}, '{({ ["a"] : 1, ["b"] : 2, ["c"] : 3 })}'), - ({"a": 'foo "bar" baz'}, r'{({ ["a"] : "foo \"bar\" baz" })}'), + ("{wrapped_string}", '"{wrapped_string}"'), + (True, "true"), + (False, "false"), + (123, "123"), + (3.14, "3.14"), + ([1, 2, 3], "[1, 2, 3]"), + (["a", "b", "c"], '["a", "b", "c"]'), + ({"a": 1, "b": 2, "c": 3}, '({ ["a"] : 1, ["b"] : 2, ["c"] : 3 })'), + ({"a": 'foo "bar" baz'}, r'({ ["a"] : "foo \"bar\" baz" })'), ( { "a": 'foo "{ "bar" }" baz', "b": Var(_js_expr="val", _var_type=str).guess_type(), }, - r'{({ ["a"] : "foo \"{ \"bar\" }\" baz", ["b"] : val })}', + r'({ ["a"] : "foo \"{ \"bar\" }\" baz", ["b"] : val })', ), ( EventChain( events=[EventSpec(handler=EventHandler(fn=mock_event))], args_spec=lambda: [], ), - '{(...args) => addEvents([Event("mock_event", {})], args, {})}', + '((...args) => ((addEvents([(Event("mock_event", ({ })))], args, ({ })))))', ), ( EventChain( @@ -382,7 +382,7 @@ def test_format_match( handler=EventHandler(fn=mock_event), args=( ( - LiteralVar.create("arg"), + Var(_js_expr="arg"), Var( _js_expr="_e", ) @@ -394,7 +394,7 @@ def test_format_match( ], args_spec=lambda e: [e.target.value], ), - '{(_e) => addEvents([Event("mock_event", {"arg":_e["target"]["value"]})], [_e], {})}', + '((_e) => ((addEvents([(Event("mock_event", ({ ["arg"] : _e["target"]["value"] })))], [_e], ({ })))))', ), ( EventChain( @@ -402,7 +402,7 @@ def test_format_match( args_spec=lambda: [], event_actions={"stopPropagation": True}, ), - '{(...args) => addEvents([Event("mock_event", {})], args, {"stopPropagation": true})}', + '((...args) => ((addEvents([(Event("mock_event", ({ })))], args, ({ ["stopPropagation"] : true })))))', ), ( EventChain( @@ -410,9 +410,9 @@ def test_format_match( args_spec=lambda: [], event_actions={"preventDefault": True}, ), - '{(...args) => addEvents([Event("mock_event", {})], args, {"preventDefault": true})}', + '((...args) => ((addEvents([(Event("mock_event", ({ })))], args, ({ ["preventDefault"] : true })))))', ), - ({"a": "red", "b": "blue"}, '{({ ["a"] : "red", ["b"] : "blue" })}'), + ({"a": "red", "b": "blue"}, '({ ["a"] : "red", ["b"] : "blue" })'), (Var(_js_expr="var", _var_type=int).guess_type(), "var"), ( Var( @@ -427,15 +427,15 @@ def test_format_match( ), ( {"a": Var(_js_expr="val", _var_type=str).guess_type()}, - '{({ ["a"] : val })}', + '({ ["a"] : val })', ), ( {"a": Var(_js_expr='"val"', _var_type=str).guess_type()}, - '{({ ["a"] : "val" })}', + '({ ["a"] : "val" })', ), ( {"a": Var(_js_expr='state.colors["val"]', _var_type=str).guess_type()}, - '{({ ["a"] : state.colors["val"] })}', + '({ ["a"] : state.colors["val"] })', ), # tricky real-world case from markdown component ( @@ -444,7 +444,7 @@ def test_format_match( _js_expr=f"(({{node, ...props}}) => )" ), }, - '{({ ["h1"] : (({node, ...props}) => ) })}', + '({ ["h1"] : (({node, ...props}) => ) })', ), ], ) @@ -455,7 +455,7 @@ def test_format_prop(prop: Var, formatted: str): prop: The prop to test. formatted: The expected formatted value. """ - assert format.format_prop(prop) == formatted + assert format.format_prop(LiteralVar.create(prop)) == formatted @pytest.mark.parametrize( diff --git a/tests/utils/test_serializers.py b/tests/utils/test_serializers.py index 2605edba9..97da98792 100644 --- a/tests/utils/test_serializers.py +++ b/tests/utils/test_serializers.py @@ -8,7 +8,7 @@ import pytest from reflex.base import Base from reflex.components.core.colors import Color from reflex.utils import serializers -from reflex.vars.base import LiteralVar, Var +from reflex.vars.base import LiteralVar @pytest.mark.parametrize( @@ -123,48 +123,59 @@ class BaseSubclass(Base): "value,expected", [ ("test", "test"), - (1, "1"), - (1.0, "1.0"), - (True, "true"), - (False, "false"), - (None, "null"), - ([1, 2, 3], "[1, 2, 3]"), - ([1, "2", 3.0], '[1, "2", 3.0]'), - ([{"key": 1}, {"key": 2}], '[({ ["key"] : 1 }), ({ ["key"] : 2 })]'), + (1, 1), + (1.0, 1.0), + (True, True), + (False, False), + (None, None), + ([1, 2, 3], [1, 2, 3]), + ([1, "2", 3.0], [1, "2", 3.0]), + ([{"key": 1}, {"key": 2}], [{"key": 1}, {"key": 2}]), (StrEnum.FOO, "foo"), - ([StrEnum.FOO, StrEnum.BAR], '["foo", "bar"]'), + ([StrEnum.FOO, StrEnum.BAR], ["foo", "bar"]), ( {"key1": [1, 2, 3], "key2": [StrEnum.FOO, StrEnum.BAR]}, - '({ ["key1"] : [1, 2, 3], ["key2"] : ["foo", "bar"] })', + { + "key1": [1, 2, 3], + "key2": ["foo", "bar"], + }, ), (EnumWithPrefix.FOO, "prefix_foo"), - ([EnumWithPrefix.FOO, EnumWithPrefix.BAR], '["prefix_foo", "prefix_bar"]'), + ([EnumWithPrefix.FOO, EnumWithPrefix.BAR], ["prefix_foo", "prefix_bar"]), ( {"key1": EnumWithPrefix.FOO, "key2": EnumWithPrefix.BAR}, - '({ ["key1"] : "prefix_foo", ["key2"] : "prefix_bar" })', + { + "key1": "prefix_foo", + "key2": "prefix_bar", + }, ), (TestEnum.FOO, "foo"), - ([TestEnum.FOO, TestEnum.BAR], '["foo", "bar"]'), + ([TestEnum.FOO, TestEnum.BAR], ["foo", "bar"]), ( {"key1": TestEnum.FOO, "key2": TestEnum.BAR}, - '({ ["key1"] : "foo", ["key2"] : "bar" })', + { + "key1": "foo", + "key2": "bar", + }, ), ( BaseSubclass(ts=datetime.timedelta(1, 1, 1)), - '({ ["ts"] : "1 day, 0:00:01.000001" })', + { + "ts": "1 day, 0:00:01.000001", + }, ), ( - [1, LiteralVar.create("hi"), Var(_js_expr="bye")], - '[1, "hi", bye]', + [1, LiteralVar.create("hi")], + [1, "hi"], ), ( - (1, LiteralVar.create("hi"), Var(_js_expr="bye")), - '[1, "hi", bye]', + (1, LiteralVar.create("hi")), + [1, "hi"], ), - ({1: 2, 3: 4}, "({ [1] : 2, [3] : 4 })"), + ({1: 2, 3: 4}, {1: 2, 3: 4}), ( - {1: LiteralVar.create("hi"), 3: Var(_js_expr="bye")}, - '({ [1] : "hi", [3] : bye })', + {1: LiteralVar.create("hi")}, + {1: "hi"}, ), (datetime.datetime(2021, 1, 1, 1, 1, 1, 1), "2021-01-01 01:01:01.000001"), (datetime.date(2021, 1, 1), "2021-01-01"), @@ -172,7 +183,7 @@ class BaseSubclass(Base): (datetime.timedelta(1, 1, 1), "1 day, 0:00:01.000001"), ( [datetime.timedelta(1, 1, 1), datetime.timedelta(1, 1, 2)], - '["1 day, 0:00:01.000001", "1 day, 0:00:01.000002"]', + ["1 day, 0:00:01.000001", "1 day, 0:00:01.000002"], ), (Color(color="slate", shade=1), "var(--slate-1)"), (Color(color="orange", shade=1, alpha=True), "var(--orange-a1)"),