From df09c716c6c02e6cbf0a68ce216e2da9633bae16 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Thomas=20Brand=C3=A9ho?= Date: Tue, 17 Oct 2023 19:44:54 +0200 Subject: [PATCH] fix serialization as a whole for list/dict/Base containing custom items to serialize (#1984) --- reflex/base.py | 4 ++- reflex/utils/format.py | 2 +- reflex/utils/serializers.py | 21 ++++++++++++--- tests/utils/test_format.py | 9 +++++++ tests/utils/test_serializers.py | 48 +++++++++++++++++++++++++++++++++ 5 files changed, 78 insertions(+), 6 deletions(-) diff --git a/reflex/base.py b/reflex/base.py index fef91189b..62bc70855 100644 --- a/reflex/base.py +++ b/reflex/base.py @@ -30,7 +30,9 @@ class Base(pydantic.BaseModel): Returns: The object as a json string. """ - return self.__config__.json_dumps(self.dict(), default=list) + from reflex.utils.serializers import serialize + + return self.__config__.json_dumps(self.dict(), default=serialize) def set(self, **kwargs): """Set multiple fields and return the object. diff --git a/reflex/utils/format.py b/reflex/utils/format.py index 9514d13d5..437b4c465 100644 --- a/reflex/utils/format.py +++ b/reflex/utils/format.py @@ -597,7 +597,7 @@ def json_dumps(obj: Any) -> str: Returns: A string """ - return json.dumps(obj, ensure_ascii=False, default=list) + return json.dumps(obj, ensure_ascii=False, default=serialize) def unwrap_vars(value: str) -> str: diff --git a/reflex/utils/serializers.py b/reflex/utils/serializers.py index 41ffec0a4..b8646d615 100644 --- a/reflex/utils/serializers.py +++ b/reflex/utils/serializers.py @@ -93,7 +93,7 @@ def get_serializer(type_: Type) -> Serializer | None: return serializer # If the type is not registered, check if it is a subclass of a registered type. - for registered_type, serializer in SERIALIZERS.items(): + for registered_type, serializer in reversed(SERIALIZERS.items()): if types._issubclass(type_, registered_type): return serializer @@ -127,18 +127,31 @@ def serialize_str(value: str) -> str: @serializer -def serialize_primitive(value: Union[bool, int, float, Base, None]) -> str: +def serialize_primitive(value: Union[bool, int, float, None]) -> str: """Serialize a primitive type. Args: - value: The number to serialize. + value: The number/bool/None to serialize. Returns: - The serialized number. + The serialized number/bool/None. """ return format.json_dumps(value) +@serializer +def serialize_base(value: Base) -> str: + """Serialize a Base instance. + + Args: + value : The Base to serialize. + + Returns: + The serialized Base. + """ + return value.json() + + @serializer def serialize_list(value: Union[List, Tuple, Set]) -> str: """Serialize a list to a JSON string. diff --git a/tests/utils/test_format.py b/tests/utils/test_format.py index ab12e8a62..79264d0cd 100644 --- a/tests/utils/test_format.py +++ b/tests/utils/test_format.py @@ -1,3 +1,4 @@ +import datetime from typing import Any import pytest @@ -604,6 +605,14 @@ def test_format_library_name(input: str, output: str): ([1, 2, 3], "[1, 2, 3]"), ({}, "{}"), ({"k1": False, "k2": True}, '{"k1": false, "k2": true}'), + ( + [datetime.timedelta(1, 1, 1), datetime.timedelta(1, 1, 2)], + '["1 day, 0:00:01.000001", "1 day, 0:00:01.000002"]', + ), + ( + {"key1": datetime.timedelta(1, 1, 1), "key2": datetime.timedelta(1, 1, 2)}, + '{"key1": "1 day, 0:00:01.000001", "key2": "1 day, 0:00:01.000002"}', + ), ], ) def test_json_dumps(input, output): diff --git a/tests/utils/test_serializers.py b/tests/utils/test_serializers.py index 5abad9f15..4e8808840 100644 --- a/tests/utils/test_serializers.py +++ b/tests/utils/test_serializers.py @@ -1,8 +1,10 @@ import datetime +from enum import Enum from typing import Any, Dict, List, Type import pytest +from reflex.base import Base from reflex.utils import serializers from reflex.vars import Var @@ -93,6 +95,31 @@ def test_add_serializer(): assert not serializers.has_serializer(Foo) +class StrEnum(str, Enum): + """An enum also inheriting from str.""" + + FOO = "foo" + BAR = "bar" + + +class EnumWithPrefix(Enum): + """An enum with a serializer adding a prefix.""" + + FOO = "foo" + BAR = "bar" + + +@serializers.serializer +def serialize_EnumWithPrefix(enum: EnumWithPrefix) -> str: + return "prefix_" + enum.value + + +class BaseSubclass(Base): + """A class inheriting from Base for testing.""" + + ts: datetime.timedelta = datetime.timedelta(1, 1, 1) + + @pytest.mark.parametrize( "value,expected", [ @@ -104,6 +131,23 @@ def test_add_serializer(): (None, "null"), ([1, 2, 3], "[1, 2, 3]"), ([1, "2", 3.0], '[1, "2", 3.0]'), + ([{"key": 1}, {"key": 2}], '[{"key": 1}, {"key": 2}]'), + (StrEnum.FOO, "foo"), + ([StrEnum.FOO, StrEnum.BAR], '["foo", "bar"]'), + ( + {"key1": [1, 2, 3], "key2": [StrEnum.FOO, StrEnum.BAR]}, + '{"key1": [1, 2, 3], "key2": ["foo", "bar"]}', + ), + (EnumWithPrefix.FOO, "prefix_foo"), + ([EnumWithPrefix.FOO, EnumWithPrefix.BAR], '["prefix_foo", "prefix_bar"]'), + ( + {"key1": EnumWithPrefix.FOO, "key2": EnumWithPrefix.BAR}, + '{"key1": "prefix_foo", "key2": "prefix_bar"}', + ), + ( + BaseSubclass(ts=datetime.timedelta(1, 1, 1)), + '{"ts": "1 day, 0:00:01.000001"}', + ), ( [1, Var.create_safe("hi"), Var.create_safe("bye", _var_is_local=False)], '[1, "hi", bye]', @@ -121,6 +165,10 @@ def test_add_serializer(): (datetime.date(2021, 1, 1), "2021-01-01"), (datetime.time(1, 1, 1, 1), "01:01:01.000001"), (datetime.timedelta(1, 1, 1), "1 day, 0:00:01.000001"), + ( + [datetime.timedelta(1, 1, 1), datetime.timedelta(1, 1, 2)], + '["1 day, 0:00:01.000001", "1 day, 0:00:01.000002"]', + ), ], ) def test_serialize(value: Any, expected: str):