fix serialization as a whole for list/dict/Base containing custom items to serialize (#1984)

This commit is contained in:
Thomas Brandého 2023-10-17 19:44:54 +02:00 committed by GitHub
parent d1d5812602
commit df09c716c6
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
5 changed files with 78 additions and 6 deletions

View File

@ -30,7 +30,9 @@ class Base(pydantic.BaseModel):
Returns:
The object as a json string.
"""
return self.__config__.json_dumps(self.dict(), default=list)
from reflex.utils.serializers import serialize
return self.__config__.json_dumps(self.dict(), default=serialize)
def set(self, **kwargs):
"""Set multiple fields and return the object.

View File

@ -597,7 +597,7 @@ def json_dumps(obj: Any) -> str:
Returns:
A string
"""
return json.dumps(obj, ensure_ascii=False, default=list)
return json.dumps(obj, ensure_ascii=False, default=serialize)
def unwrap_vars(value: str) -> str:

View File

@ -93,7 +93,7 @@ def get_serializer(type_: Type) -> Serializer | None:
return serializer
# If the type is not registered, check if it is a subclass of a registered type.
for registered_type, serializer in SERIALIZERS.items():
for registered_type, serializer in reversed(SERIALIZERS.items()):
if types._issubclass(type_, registered_type):
return serializer
@ -127,18 +127,31 @@ def serialize_str(value: str) -> str:
@serializer
def serialize_primitive(value: Union[bool, int, float, Base, None]) -> str:
def serialize_primitive(value: Union[bool, int, float, None]) -> str:
"""Serialize a primitive type.
Args:
value: The number to serialize.
value: The number/bool/None to serialize.
Returns:
The serialized number.
The serialized number/bool/None.
"""
return format.json_dumps(value)
@serializer
def serialize_base(value: Base) -> str:
"""Serialize a Base instance.
Args:
value : The Base to serialize.
Returns:
The serialized Base.
"""
return value.json()
@serializer
def serialize_list(value: Union[List, Tuple, Set]) -> str:
"""Serialize a list to a JSON string.

View File

@ -1,3 +1,4 @@
import datetime
from typing import Any
import pytest
@ -604,6 +605,14 @@ def test_format_library_name(input: str, output: str):
([1, 2, 3], "[1, 2, 3]"),
({}, "{}"),
({"k1": False, "k2": True}, '{"k1": false, "k2": true}'),
(
[datetime.timedelta(1, 1, 1), datetime.timedelta(1, 1, 2)],
'["1 day, 0:00:01.000001", "1 day, 0:00:01.000002"]',
),
(
{"key1": datetime.timedelta(1, 1, 1), "key2": datetime.timedelta(1, 1, 2)},
'{"key1": "1 day, 0:00:01.000001", "key2": "1 day, 0:00:01.000002"}',
),
],
)
def test_json_dumps(input, output):

View File

@ -1,8 +1,10 @@
import datetime
from enum import Enum
from typing import Any, Dict, List, Type
import pytest
from reflex.base import Base
from reflex.utils import serializers
from reflex.vars import Var
@ -93,6 +95,31 @@ def test_add_serializer():
assert not serializers.has_serializer(Foo)
class StrEnum(str, Enum):
"""An enum also inheriting from str."""
FOO = "foo"
BAR = "bar"
class EnumWithPrefix(Enum):
"""An enum with a serializer adding a prefix."""
FOO = "foo"
BAR = "bar"
@serializers.serializer
def serialize_EnumWithPrefix(enum: EnumWithPrefix) -> str:
return "prefix_" + enum.value
class BaseSubclass(Base):
"""A class inheriting from Base for testing."""
ts: datetime.timedelta = datetime.timedelta(1, 1, 1)
@pytest.mark.parametrize(
"value,expected",
[
@ -104,6 +131,23 @@ def test_add_serializer():
(None, "null"),
([1, 2, 3], "[1, 2, 3]"),
([1, "2", 3.0], '[1, "2", 3.0]'),
([{"key": 1}, {"key": 2}], '[{"key": 1}, {"key": 2}]'),
(StrEnum.FOO, "foo"),
([StrEnum.FOO, StrEnum.BAR], '["foo", "bar"]'),
(
{"key1": [1, 2, 3], "key2": [StrEnum.FOO, StrEnum.BAR]},
'{"key1": [1, 2, 3], "key2": ["foo", "bar"]}',
),
(EnumWithPrefix.FOO, "prefix_foo"),
([EnumWithPrefix.FOO, EnumWithPrefix.BAR], '["prefix_foo", "prefix_bar"]'),
(
{"key1": EnumWithPrefix.FOO, "key2": EnumWithPrefix.BAR},
'{"key1": "prefix_foo", "key2": "prefix_bar"}',
),
(
BaseSubclass(ts=datetime.timedelta(1, 1, 1)),
'{"ts": "1 day, 0:00:01.000001"}',
),
(
[1, Var.create_safe("hi"), Var.create_safe("bye", _var_is_local=False)],
'[1, "hi", bye]',
@ -121,6 +165,10 @@ def test_add_serializer():
(datetime.date(2021, 1, 1), "2021-01-01"),
(datetime.time(1, 1, 1, 1), "01:01:01.000001"),
(datetime.timedelta(1, 1, 1), "1 day, 0:00:01.000001"),
(
[datetime.timedelta(1, 1, 1), datetime.timedelta(1, 1, 2)],
'["1 day, 0:00:01.000001", "1 day, 0:00:01.000002"]',
),
],
)
def test_serialize(value: Any, expected: str):