[REF-3009] type transforming serializers (#3227)

* wip type transforming serializers

* old python sucks

* typing fixups

* Expose the `to` parameter on `rx.serializer` for type conversion

Serializers can also return a tuple of `(serialized_value, type)`. If both ways
are specified, then the returned type MUST match the `to` parameter.

When initializing a new rx.Var, if `_var_is_string` is not specified and the serializer returns a `str` type, then mark `_var_is_string=True` to indicate that the Var should be treated like a string literal.

Include datetime, color, types, and paths as "serializing to str" type.

Avoid other changes at this point to reduce fallout from this change:

  Notably, the `serialize_str` function does NOT cast to `str`, which
  would cause existing code to treat all Var initialized with a str as a
  str literal even though this was NOT the default before.

Update test cases to accommodate these changes.

* Raise deprecation warning for rx.Var.create with string literal

In the future, we will treat strings as string literals in the JS code. To get
a Var that is not treated like a string, pass _var_is_string=False.

This will allow our serializers to automatically identify cast string literals
with fewer special cases (and the special cases need to be explicitly
identified).

* Add test case for mismatched serialized types

* fix old python

* Remove serializer returning a tuple feature

Simplify the logic; instead of making a wrapper function that returns
a tuple, just save the type conversions in a separate global.

* Reset the LRU cache when adding new serializers

---------

Co-authored-by: Masen Furer <m_github@0x26.net>
This commit is contained in:
benedikt-bartscher 2024-06-07 18:50:10 +02:00 committed by GitHub
parent 168501d58a
commit e42d4ed9ef
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
5 changed files with 181 additions and 30 deletions

View File

@ -2,13 +2,27 @@
from __future__ import annotations from __future__ import annotations
import functools
import json import json
import types as builtin_types import types as builtin_types
import warnings import warnings
from datetime import date, datetime, time, timedelta from datetime import date, datetime, time, timedelta
from enum import Enum from enum import Enum
from pathlib import Path from pathlib import Path
from typing import Any, Callable, Dict, List, Set, Tuple, Type, Union, get_type_hints from typing import (
Any,
Callable,
Dict,
List,
Literal,
Optional,
Set,
Tuple,
Type,
Union,
get_type_hints,
overload,
)
from reflex.base import Base from reflex.base import Base
from reflex.constants.colors import Color, format_color from reflex.constants.colors import Color, format_color
@ -17,15 +31,24 @@ from reflex.utils import exceptions, types
# Mapping from type to a serializer. # Mapping from type to a serializer.
# The serializer should convert the type to a JSON object. # The serializer should convert the type to a JSON object.
SerializedType = Union[str, bool, int, float, list, dict] SerializedType = Union[str, bool, int, float, list, dict]
Serializer = Callable[[Type], SerializedType] Serializer = Callable[[Type], SerializedType]
SERIALIZERS: dict[Type, Serializer] = {} SERIALIZERS: dict[Type, Serializer] = {}
SERIALIZER_TYPES: dict[Type, Type] = {}
def serializer(fn: Serializer) -> Serializer: def serializer(
fn: Serializer | None = None,
to: Type | None = None,
) -> Serializer:
"""Decorator to add a serializer for a given type. """Decorator to add a serializer for a given type.
Args: Args:
fn: The function to decorate. fn: The function to decorate.
to: The type returned by the serializer. If this is `str`, then any Var created from this type will be treated as a string.
Returns: Returns:
The decorated function. The decorated function.
@ -33,8 +56,9 @@ def serializer(fn: Serializer) -> Serializer:
Raises: Raises:
ValueError: If the function does not take a single argument. ValueError: If the function does not take a single argument.
""" """
# Get the global serializers. if fn is None:
global SERIALIZERS # If the function is not provided, return a partial that acts as a decorator.
return functools.partial(serializer, to=to) # type: ignore
# Check the type hints to get the type of the argument. # Check the type hints to get the type of the argument.
type_hints = get_type_hints(fn) type_hints = get_type_hints(fn)
@ -54,18 +78,47 @@ def serializer(fn: Serializer) -> Serializer:
f"Serializer for type {type_} is already registered as {registered_fn.__qualname__}." f"Serializer for type {type_} is already registered as {registered_fn.__qualname__}."
) )
# Apply type transformation if requested
if to is not None:
SERIALIZER_TYPES[type_] = to
get_serializer_type.cache_clear()
# Register the serializer. # Register the serializer.
SERIALIZERS[type_] = fn SERIALIZERS[type_] = fn
get_serializer.cache_clear()
# Return the function. # Return the function.
return fn return fn
def serialize(value: Any) -> SerializedType | None: @overload
def serialize(
value: Any, get_type: Literal[True]
) -> Tuple[Optional[SerializedType], Optional[types.GenericType]]:
...
@overload
def serialize(value: Any, get_type: Literal[False]) -> Optional[SerializedType]:
...
@overload
def serialize(value: Any) -> Optional[SerializedType]:
...
def serialize(
value: Any, get_type: bool = False
) -> Union[
Optional[SerializedType],
Tuple[Optional[SerializedType], Optional[types.GenericType]],
]:
"""Serialize the value to a JSON string. """Serialize the value to a JSON string.
Args: Args:
value: The value to serialize. value: The value to serialize.
get_type: Whether to return the type of the serialized value.
Returns: Returns:
The serialized value, or None if a serializer is not found. The serialized value, or None if a serializer is not found.
@ -75,13 +128,22 @@ def serialize(value: Any) -> SerializedType | None:
# If there is no serializer, return None. # If there is no serializer, return None.
if serializer is None: if serializer is None:
if get_type:
return None, None
return None return None
# Serialize the value. # Serialize the value.
return serializer(value) serialized = serializer(value)
# Return the serialized value and the type.
if get_type:
return serialized, get_serializer_type(type(value))
else:
return serialized
def get_serializer(type_: Type) -> Serializer | None: @functools.lru_cache
def get_serializer(type_: Type) -> Optional[Serializer]:
"""Get the serializer for the type. """Get the serializer for the type.
Args: Args:
@ -90,8 +152,6 @@ def get_serializer(type_: Type) -> Serializer | None:
Returns: Returns:
The serializer for the type, or None if there is no serializer. The serializer for the type, or None if there is no serializer.
""" """
global SERIALIZERS
# First, check if the type is registered. # First, check if the type is registered.
serializer = SERIALIZERS.get(type_) serializer = SERIALIZERS.get(type_)
if serializer is not None: if serializer is not None:
@ -106,6 +166,30 @@ def get_serializer(type_: Type) -> Serializer | None:
return None return None
@functools.lru_cache
def get_serializer_type(type_: Type) -> Optional[Type]:
    """Look up the type that values of the given type become once serialized.

    Args:
        type_: The type whose post-serialization type is wanted.

    Returns:
        The registered target type for the type, or None if no type conversion is registered.
    """
    # An exact registration always takes precedence.
    exact_match = SERIALIZER_TYPES.get(type_)
    if exact_match is not None:
        return exact_match

    # Otherwise walk the registrations newest-first and take the first one
    # whose registered type is a base of type_; None if nothing matches.
    return next(
        (
            target_type
            for base_type, target_type in reversed(SERIALIZER_TYPES.items())
            if types._issubclass(type_, base_type)
        ),
        None,
    )
def has_serializer(type_: Type) -> bool: def has_serializer(type_: Type) -> bool:
"""Check if there is a serializer for the type. """Check if there is a serializer for the type.
@ -118,7 +202,7 @@ def has_serializer(type_: Type) -> bool:
return get_serializer(type_) is not None return get_serializer(type_) is not None
@serializer @serializer(to=str)
def serialize_type(value: type) -> str: def serialize_type(value: type) -> str:
"""Serialize a python type. """Serialize a python type.
@ -226,7 +310,7 @@ def serialize_dict(prop: Dict[str, Any]) -> str:
return format.unwrap_vars(fprop) return format.unwrap_vars(fprop)
@serializer @serializer(to=str)
def serialize_datetime(dt: Union[date, datetime, time, timedelta]) -> str: def serialize_datetime(dt: Union[date, datetime, time, timedelta]) -> str:
"""Serialize a datetime to a JSON string. """Serialize a datetime to a JSON string.
@ -239,8 +323,8 @@ def serialize_datetime(dt: Union[date, datetime, time, timedelta]) -> str:
return str(dt) return str(dt)
@serializer @serializer(to=str)
def serialize_path(path: Path): def serialize_path(path: Path) -> str:
"""Serialize a pathlib.Path to a JSON string. """Serialize a pathlib.Path to a JSON string.
Args: Args:
@ -265,7 +349,7 @@ def serialize_enum(en: Enum) -> str:
return en.value return en.value
@serializer @serializer(to=str)
def serialize_color(color: Color) -> str: def serialize_color(color: Color) -> str:
"""Serialize a color. """Serialize a color.

View File

@ -347,7 +347,7 @@ class Var:
cls, cls,
value: Any, value: Any,
_var_is_local: bool = True, _var_is_local: bool = True,
_var_is_string: bool = False, _var_is_string: bool | None = None,
_var_data: Optional[VarData] = None, _var_data: Optional[VarData] = None,
) -> Var | None: ) -> Var | None:
"""Create a var from a value. """Create a var from a value.
@ -380,18 +380,39 @@ class Var:
# Try to serialize the value. # Try to serialize the value.
type_ = type(value) type_ = type(value)
name = value if type_ in types.JSONType else serializers.serialize(value) if type_ in types.JSONType:
name = value
else:
name, serialized_type = serializers.serialize(value, get_type=True)
if (
serialized_type is not None
and _var_is_string is None
and issubclass(serialized_type, str)
):
_var_is_string = True
if name is None: if name is None:
raise VarTypeError( raise VarTypeError(
f"No JSON serializer found for var {value} of type {type_}." f"No JSON serializer found for var {value} of type {type_}."
) )
name = name if isinstance(name, str) else format.json_dumps(name) name = name if isinstance(name, str) else format.json_dumps(name)
if _var_is_string is None and type_ is str:
console.deprecate(
feature_name="Creating a Var from a string without specifying _var_is_string",
reason=(
"Specify _var_is_string=False to create a Var that is not a string literal. "
"In the future, creating a Var from a string will be treated as a string literal "
"by default."
),
deprecation_version="0.5.4",
removal_version="0.6.0",
)
return BaseVar( return BaseVar(
_var_name=name, _var_name=name,
_var_type=type_, _var_type=type_,
_var_is_local=_var_is_local, _var_is_local=_var_is_local,
_var_is_string=_var_is_string, _var_is_string=_var_is_string if _var_is_string is not None else False,
_var_data=_var_data, _var_data=_var_data,
) )
@ -400,7 +421,7 @@ class Var:
cls, cls,
value: Any, value: Any,
_var_is_local: bool = True, _var_is_local: bool = True,
_var_is_string: bool = False, _var_is_string: bool | None = None,
_var_data: Optional[VarData] = None, _var_data: Optional[VarData] = None,
) -> Var: ) -> Var:
"""Create a var from a value, asserting that it is not None. """Create a var from a value, asserting that it is not None.
@ -847,19 +868,19 @@ class Var:
if invoke_fn: if invoke_fn:
# invoke the function on left operand. # invoke the function on left operand.
operation_name = ( operation_name = (
f"{left_operand_full_name}.{fn}({right_operand_full_name})" f"{left_operand_full_name}.{fn}({right_operand_full_name})" # type: ignore
) # type: ignore )
else: else:
# pass the operands as arguments to the function. # pass the operands as arguments to the function.
operation_name = ( operation_name = (
f"{left_operand_full_name} {op} {right_operand_full_name}" f"{left_operand_full_name} {op} {right_operand_full_name}" # type: ignore
) # type: ignore )
operation_name = f"{fn}({operation_name})" operation_name = f"{fn}({operation_name})"
else: else:
# apply operator to operands (left operand <operator> right_operand) # apply operator to operands (left operand <operator> right_operand)
operation_name = ( operation_name = (
f"{left_operand_full_name} {op} {right_operand_full_name}" f"{left_operand_full_name} {op} {right_operand_full_name}" # type: ignore
) # type: ignore )
operation_name = format.wrap(operation_name, "(") operation_name = format.wrap(operation_name, "(")
else: else:
# apply operator to left operand (<operator> left_operand) # apply operator to left operand (<operator> left_operand)

View File

@ -51,11 +51,11 @@ class Var:
_var_data: VarData | None = None _var_data: VarData | None = None
@classmethod @classmethod
def create( def create(
cls, value: Any, _var_is_local: bool = False, _var_is_string: bool = False, _var_data: VarData | None = None, cls, value: Any, _var_is_local: bool = True, _var_is_string: bool | None = None, _var_data: VarData | None = None,
) -> Optional[Var]: ... ) -> Optional[Var]: ...
@classmethod @classmethod
def create_safe( def create_safe(
cls, value: Any, _var_is_local: bool = False, _var_is_string: bool = False, _var_data: VarData | None = None, cls, value: Any, _var_is_local: bool = True, _var_is_string: bool | None = None, _var_data: VarData | None = None,
) -> Var: ... ) -> Var: ...
@classmethod @classmethod
def __class_getitem__(cls, type_: Type) -> _GenericAlias: ... def __class_getitem__(cls, type_: Type) -> _GenericAlias: ...

View File

@ -2,6 +2,7 @@ import pytest
import reflex as rx import reflex as rx
from reflex.components.datadisplay.code import CodeBlock from reflex.components.datadisplay.code import CodeBlock
from reflex.constants.colors import Color
from reflex.vars import Var from reflex.vars import Var
@ -50,7 +51,12 @@ def create_color_var(color):
], ],
) )
def test_color(color, expected): def test_color(color, expected):
assert str(color) == expected assert color._var_is_string or color._var_type is str
assert color._var_full_name == expected
if color._var_type == Color:
assert str(color) == f"{{`{expected}`}}"
else:
assert str(color) == expected
@pytest.mark.parametrize( @pytest.mark.parametrize(
@ -96,9 +102,9 @@ def test_color_with_conditionals(cond_var, expected):
@pytest.mark.parametrize( @pytest.mark.parametrize(
"color, expected", "color, expected",
[ [
(create_color_var(rx.color("red")), "var(--red-7)"), (create_color_var(rx.color("red")), "{`var(--red-7)`}"),
(create_color_var(rx.color("green", shade=1)), "var(--green-1)"), (create_color_var(rx.color("green", shade=1)), "{`var(--green-1)`}"),
(create_color_var(rx.color("blue", alpha=True)), "var(--blue-a7)"), (create_color_var(rx.color("blue", alpha=True)), "{`var(--blue-a7)`}"),
("red", "red"), ("red", "red"),
("green", "green"), ("green", "green"),
("blue", "blue"), ("blue", "blue"),

View File

@ -1,5 +1,6 @@
import datetime import datetime
from enum import Enum from enum import Enum
from pathlib import Path
from typing import Any, Dict, List, Type from typing import Any, Dict, List, Type
import pytest import pytest
@ -90,6 +91,9 @@ def test_add_serializer():
# Remove the serializer. # Remove the serializer.
serializers.SERIALIZERS.pop(Foo) serializers.SERIALIZERS.pop(Foo)
# LRU cache will still have the serializer, so we need to clear it.
assert serializers.has_serializer(Foo)
serializers.get_serializer.cache_clear()
assert not serializers.has_serializer(Foo) assert not serializers.has_serializer(Foo)
@ -194,3 +198,39 @@ def test_serialize(value: Any, expected: str):
expected: The expected result. expected: The expected result.
""" """
assert serializers.serialize(value) == expected assert serializers.serialize(value) == expected
@pytest.mark.parametrize(
    "value,expected,exp_var_is_string",
    [
        ("test", "test", False),
        (1, "1", False),
        (1.0, "1.0", False),
        (True, "true", False),
        (False, "false", False),
        ([1, 2, 3], "[1, 2, 3]", False),
        ([{"key": 1}, {"key": 2}], '[{"key": 1}, {"key": 2}]', False),
        (StrEnum.FOO, "foo", False),
        ([StrEnum.FOO, StrEnum.BAR], '["foo", "bar"]', False),
        (
            BaseSubclass(ts=datetime.timedelta(1, 1, 1)),
            '{"ts": "1 day, 0:00:01.000001"}',
            False,
        ),
        (datetime.datetime(2021, 1, 1, 1, 1, 1, 1), "2021-01-01 01:01:01.000001", True),
        (Color(color="slate", shade=1), "var(--slate-1)", True),
        (BaseSubclass, "BaseSubclass", True),
        (Path("."), ".", True),
    ],
)
def test_serialize_var_to_str(value: Any, expected: str, exp_var_is_string: bool):
    """Check that a Var built from a value is flagged as a string only when expected.

    The last four cases (datetime, Color, a type, and Path) expect
    _var_is_string to be True; every other case expects False.

    Args:
        value: The value to wrap in a Var.
        expected: The expected full name of the resulting Var.
        exp_var_is_string: The expected value of _var_is_string.
    """
    created_var = Var.create_safe(value)
    assert created_var._var_full_name == expected
    assert created_var._var_is_string == exp_var_is_string