[REF-3009] type transforming serializers (#3227)

* wip type transforming serializers

* old python sucks

* typing fixups

* Expose the `to` parameter on `rx.serializer` for type conversion

Serializers can also return a tuple of `(serialized_value, type)`, if both ways
are specified, then the returned value MUST match the `to` parameter.

When initializing a new rx.Var, if `_var_is_string` is not specified and the serializer returns a `str` type, then mark `_var_is_string=True` to indicate that the Var should be treated like a string literal.

Include datetime, color, types, and paths as "serializing to str" type.

Avoid other changes at this point to reduce fallout from this change:

  Notably, the `serialize_str` function does NOT cast to `str`, which
  would cause existing code to treat all Var initialized with a str as a
  str literal even though this was NOT the default before.

Update test cases to accomodate these changes.

* Raise deprecation warning for rx.Var.create with string literal

In the future, we will treat strings as string literals in the JS code. To get
a Var that is not treated like a string, pass _var_is_string=False.

This will allow our serializers to automatically identify cast string literals
with less special cases (and the special cases need to be explicitly
identified).

* Add test case for mismatched serialized types

* fix old python

* Remove serializer returning a tuple feature

Simplify the logic; instead of making a wrapper function that returns
a tuple, just save the type conversions in a separate global.

* Reset the LRU cache when adding new serializers

---------

Co-authored-by: Masen Furer <m_github@0x26.net>
This commit is contained in:
benedikt-bartscher 2024-06-07 18:50:10 +02:00 committed by GitHub
parent 168501d58a
commit e42d4ed9ef
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
5 changed files with 181 additions and 30 deletions

View File

@ -2,13 +2,27 @@
from __future__ import annotations
import functools
import json
import types as builtin_types
import warnings
from datetime import date, datetime, time, timedelta
from enum import Enum
from pathlib import Path
from typing import Any, Callable, Dict, List, Set, Tuple, Type, Union, get_type_hints
from typing import (
Any,
Callable,
Dict,
List,
Literal,
Optional,
Set,
Tuple,
Type,
Union,
get_type_hints,
overload,
)
from reflex.base import Base
from reflex.constants.colors import Color, format_color
@ -17,15 +31,24 @@ from reflex.utils import exceptions, types
# Mapping from type to a serializer.
# The serializer should convert the type to a JSON object.
SerializedType = Union[str, bool, int, float, list, dict]
Serializer = Callable[[Type], SerializedType]
SERIALIZERS: dict[Type, Serializer] = {}
SERIALIZER_TYPES: dict[Type, Type] = {}
def serializer(fn: Serializer) -> Serializer:
def serializer(
fn: Serializer | None = None,
to: Type | None = None,
) -> Serializer:
"""Decorator to add a serializer for a given type.
Args:
fn: The function to decorate.
to: The type returned by the serializer. If this is `str`, then any Var created from this type will be treated as a string.
Returns:
The decorated function.
@ -33,8 +56,9 @@ def serializer(fn: Serializer) -> Serializer:
Raises:
ValueError: If the function does not take a single argument.
"""
# Get the global serializers.
global SERIALIZERS
if fn is None:
# If the function is not provided, return a partial that acts as a decorator.
return functools.partial(serializer, to=to) # type: ignore
# Check the type hints to get the type of the argument.
type_hints = get_type_hints(fn)
@ -54,18 +78,47 @@ def serializer(fn: Serializer) -> Serializer:
f"Serializer for type {type_} is already registered as {registered_fn.__qualname__}."
)
# Apply type transformation if requested
if to is not None:
SERIALIZER_TYPES[type_] = to
get_serializer_type.cache_clear()
# Register the serializer.
SERIALIZERS[type_] = fn
get_serializer.cache_clear()
# Return the function.
return fn
def serialize(value: Any) -> SerializedType | None:
@overload
def serialize(
value: Any, get_type: Literal[True]
) -> Tuple[Optional[SerializedType], Optional[types.GenericType]]:
...
@overload
def serialize(value: Any, get_type: Literal[False]) -> Optional[SerializedType]:
...
@overload
def serialize(value: Any) -> Optional[SerializedType]:
...
def serialize(
value: Any, get_type: bool = False
) -> Union[
Optional[SerializedType],
Tuple[Optional[SerializedType], Optional[types.GenericType]],
]:
"""Serialize the value to a JSON string.
Args:
value: The value to serialize.
get_type: Whether to return the type of the serialized value.
Returns:
The serialized value, or None if a serializer is not found.
@ -75,13 +128,22 @@ def serialize(value: Any) -> SerializedType | None:
# If there is no serializer, return None.
if serializer is None:
if get_type:
return None, None
return None
# Serialize the value.
return serializer(value)
serialized = serializer(value)
# Return the serialized value and the type.
if get_type:
return serialized, get_serializer_type(type(value))
else:
return serialized
def get_serializer(type_: Type) -> Serializer | None:
@functools.lru_cache
def get_serializer(type_: Type) -> Optional[Serializer]:
"""Get the serializer for the type.
Args:
@ -90,8 +152,6 @@ def get_serializer(type_: Type) -> Serializer | None:
Returns:
The serializer for the type, or None if there is no serializer.
"""
global SERIALIZERS
# First, check if the type is registered.
serializer = SERIALIZERS.get(type_)
if serializer is not None:
@ -106,6 +166,30 @@ def get_serializer(type_: Type) -> Serializer | None:
return None
@functools.lru_cache
def get_serializer_type(type_: Type) -> Optional[Type]:
"""Get the converted type for the type after serializing.
Args:
type_: The type to get the serializer type for.
Returns:
The serialized type for the type, or None if there is no type conversion registered.
"""
# First, check if the type is registered.
serializer = SERIALIZER_TYPES.get(type_)
if serializer is not None:
return serializer
# If the type is not registered, check if it is a subclass of a registered type.
for registered_type, serializer in reversed(SERIALIZER_TYPES.items()):
if types._issubclass(type_, registered_type):
return serializer
# If there is no serializer, return None.
return None
def has_serializer(type_: Type) -> bool:
"""Check if there is a serializer for the type.
@ -118,7 +202,7 @@ def has_serializer(type_: Type) -> bool:
return get_serializer(type_) is not None
@serializer
@serializer(to=str)
def serialize_type(value: type) -> str:
"""Serialize a python type.
@ -226,7 +310,7 @@ def serialize_dict(prop: Dict[str, Any]) -> str:
return format.unwrap_vars(fprop)
@serializer
@serializer(to=str)
def serialize_datetime(dt: Union[date, datetime, time, timedelta]) -> str:
"""Serialize a datetime to a JSON string.
@ -239,8 +323,8 @@ def serialize_datetime(dt: Union[date, datetime, time, timedelta]) -> str:
return str(dt)
@serializer
def serialize_path(path: Path):
@serializer(to=str)
def serialize_path(path: Path) -> str:
"""Serialize a pathlib.Path to a JSON string.
Args:
@ -265,7 +349,7 @@ def serialize_enum(en: Enum) -> str:
return en.value
@serializer
@serializer(to=str)
def serialize_color(color: Color) -> str:
"""Serialize a color.

View File

@ -347,7 +347,7 @@ class Var:
cls,
value: Any,
_var_is_local: bool = True,
_var_is_string: bool = False,
_var_is_string: bool | None = None,
_var_data: Optional[VarData] = None,
) -> Var | None:
"""Create a var from a value.
@ -380,18 +380,39 @@ class Var:
# Try to serialize the value.
type_ = type(value)
name = value if type_ in types.JSONType else serializers.serialize(value)
if type_ in types.JSONType:
name = value
else:
name, serialized_type = serializers.serialize(value, get_type=True)
if (
serialized_type is not None
and _var_is_string is None
and issubclass(serialized_type, str)
):
_var_is_string = True
if name is None:
raise VarTypeError(
f"No JSON serializer found for var {value} of type {type_}."
)
name = name if isinstance(name, str) else format.json_dumps(name)
if _var_is_string is None and type_ is str:
console.deprecate(
feature_name="Creating a Var from a string without specifying _var_is_string",
reason=(
"Specify _var_is_string=False to create a Var that is not a string literal. "
"In the future, creating a Var from a string will be treated as a string literal "
"by default."
),
deprecation_version="0.5.4",
removal_version="0.6.0",
)
return BaseVar(
_var_name=name,
_var_type=type_,
_var_is_local=_var_is_local,
_var_is_string=_var_is_string,
_var_is_string=_var_is_string if _var_is_string is not None else False,
_var_data=_var_data,
)
@ -400,7 +421,7 @@ class Var:
cls,
value: Any,
_var_is_local: bool = True,
_var_is_string: bool = False,
_var_is_string: bool | None = None,
_var_data: Optional[VarData] = None,
) -> Var:
"""Create a var from a value, asserting that it is not None.
@ -847,19 +868,19 @@ class Var:
if invoke_fn:
# invoke the function on left operand.
operation_name = (
f"{left_operand_full_name}.{fn}({right_operand_full_name})"
) # type: ignore
f"{left_operand_full_name}.{fn}({right_operand_full_name})" # type: ignore
)
else:
# pass the operands as arguments to the function.
operation_name = (
f"{left_operand_full_name} {op} {right_operand_full_name}"
) # type: ignore
f"{left_operand_full_name} {op} {right_operand_full_name}" # type: ignore
)
operation_name = f"{fn}({operation_name})"
else:
# apply operator to operands (left operand <operator> right_operand)
operation_name = (
f"{left_operand_full_name} {op} {right_operand_full_name}"
) # type: ignore
f"{left_operand_full_name} {op} {right_operand_full_name}" # type: ignore
)
operation_name = format.wrap(operation_name, "(")
else:
# apply operator to left operand (<operator> left_operand)

View File

@ -51,11 +51,11 @@ class Var:
_var_data: VarData | None = None
@classmethod
def create(
cls, value: Any, _var_is_local: bool = False, _var_is_string: bool = False, _var_data: VarData | None = None,
cls, value: Any, _var_is_local: bool = True, _var_is_string: bool | None = None, _var_data: VarData | None = None,
) -> Optional[Var]: ...
@classmethod
def create_safe(
cls, value: Any, _var_is_local: bool = False, _var_is_string: bool = False, _var_data: VarData | None = None,
cls, value: Any, _var_is_local: bool = True, _var_is_string: bool | None = None, _var_data: VarData | None = None,
) -> Var: ...
@classmethod
def __class_getitem__(cls, type_: Type) -> _GenericAlias: ...

View File

@ -2,6 +2,7 @@ import pytest
import reflex as rx
from reflex.components.datadisplay.code import CodeBlock
from reflex.constants.colors import Color
from reflex.vars import Var
@ -50,7 +51,12 @@ def create_color_var(color):
],
)
def test_color(color, expected):
assert str(color) == expected
assert color._var_is_string or color._var_type is str
assert color._var_full_name == expected
if color._var_type == Color:
assert str(color) == f"{{`{expected}`}}"
else:
assert str(color) == expected
@pytest.mark.parametrize(
@ -96,9 +102,9 @@ def test_color_with_conditionals(cond_var, expected):
@pytest.mark.parametrize(
"color, expected",
[
(create_color_var(rx.color("red")), "var(--red-7)"),
(create_color_var(rx.color("green", shade=1)), "var(--green-1)"),
(create_color_var(rx.color("blue", alpha=True)), "var(--blue-a7)"),
(create_color_var(rx.color("red")), "{`var(--red-7)`}"),
(create_color_var(rx.color("green", shade=1)), "{`var(--green-1)`}"),
(create_color_var(rx.color("blue", alpha=True)), "{`var(--blue-a7)`}"),
("red", "red"),
("green", "green"),
("blue", "blue"),

View File

@ -1,5 +1,6 @@
import datetime
from enum import Enum
from pathlib import Path
from typing import Any, Dict, List, Type
import pytest
@ -90,6 +91,9 @@ def test_add_serializer():
# Remove the serializer.
serializers.SERIALIZERS.pop(Foo)
# LRU cache will still have the serializer, so we need to clear it.
assert serializers.has_serializer(Foo)
serializers.get_serializer.cache_clear()
assert not serializers.has_serializer(Foo)
@ -194,3 +198,39 @@ def test_serialize(value: Any, expected: str):
expected: The expected result.
"""
assert serializers.serialize(value) == expected
@pytest.mark.parametrize(
"value,expected,exp_var_is_string",
[
("test", "test", False),
(1, "1", False),
(1.0, "1.0", False),
(True, "true", False),
(False, "false", False),
([1, 2, 3], "[1, 2, 3]", False),
([{"key": 1}, {"key": 2}], '[{"key": 1}, {"key": 2}]', False),
(StrEnum.FOO, "foo", False),
([StrEnum.FOO, StrEnum.BAR], '["foo", "bar"]', False),
(
BaseSubclass(ts=datetime.timedelta(1, 1, 1)),
'{"ts": "1 day, 0:00:01.000001"}',
False,
),
(datetime.datetime(2021, 1, 1, 1, 1, 1, 1), "2021-01-01 01:01:01.000001", True),
(Color(color="slate", shade=1), "var(--slate-1)", True),
(BaseSubclass, "BaseSubclass", True),
(Path("."), ".", True),
],
)
def test_serialize_var_to_str(value: Any, expected: str, exp_var_is_string: bool):
"""Test that serialize with `to=str` passed to a Var is marked with _var_is_string.
Args:
value: The value to serialize.
expected: The expected result.
exp_var_is_string: The expected value of _var_is_string.
"""
v = Var.create_safe(value)
assert v._var_full_name == expected
assert v._var_is_string == exp_var_is_string