diff --git a/reflex/utils/serializers.py b/reflex/utils/serializers.py index c35be95d3..796d28be3 100644 --- a/reflex/utils/serializers.py +++ b/reflex/utils/serializers.py @@ -2,13 +2,27 @@ from __future__ import annotations +import functools import json import types as builtin_types import warnings from datetime import date, datetime, time, timedelta from enum import Enum from pathlib import Path -from typing import Any, Callable, Dict, List, Set, Tuple, Type, Union, get_type_hints +from typing import ( + Any, + Callable, + Dict, + List, + Literal, + Optional, + Set, + Tuple, + Type, + Union, + get_type_hints, + overload, +) from reflex.base import Base from reflex.constants.colors import Color, format_color @@ -17,15 +31,24 @@ from reflex.utils import exceptions, types # Mapping from type to a serializer. # The serializer should convert the type to a JSON object. SerializedType = Union[str, bool, int, float, list, dict] + + Serializer = Callable[[Type], SerializedType] + + SERIALIZERS: dict[Type, Serializer] = {} +SERIALIZER_TYPES: dict[Type, Type] = {} -def serializer(fn: Serializer) -> Serializer: +def serializer( + fn: Serializer | None = None, + to: Type | None = None, +) -> Serializer: """Decorator to add a serializer for a given type. Args: fn: The function to decorate. + to: The type returned by the serializer. If this is `str`, then any Var created from this type will be treated as a string. Returns: The decorated function. @@ -33,8 +56,9 @@ def serializer(fn: Serializer) -> Serializer: Raises: ValueError: If the function does not take a single argument. """ - # Get the global serializers. - global SERIALIZERS + if fn is None: + # If the function is not provided, return a partial that acts as a decorator. + return functools.partial(serializer, to=to) # type: ignore # Check the type hints to get the type of the argument. type_hints = get_type_hints(fn) @@ -54,18 +78,47 @@ def serializer(fn: Serializer) -> Serializer: f"Serializer for type {type_} is already registered as {registered_fn.__qualname__}." ) + # Apply type transformation if requested + if to is not None: + SERIALIZER_TYPES[type_] = to + get_serializer_type.cache_clear() + # Register the serializer. SERIALIZERS[type_] = fn + get_serializer.cache_clear() # Return the function. return fn -def serialize(value: Any) -> SerializedType | None: +@overload +def serialize( + value: Any, get_type: Literal[True] +) -> Tuple[Optional[SerializedType], Optional[types.GenericType]]: + ... + + +@overload +def serialize(value: Any, get_type: Literal[False]) -> Optional[SerializedType]: + ... + + +@overload +def serialize(value: Any) -> Optional[SerializedType]: + ... + + +def serialize( + value: Any, get_type: bool = False +) -> Union[ + Optional[SerializedType], + Tuple[Optional[SerializedType], Optional[types.GenericType]], +]: """Serialize the value to a JSON string. Args: value: The value to serialize. + get_type: Whether to return the type of the serialized value. Returns: The serialized value, or None if a serializer is not found. @@ -75,13 +128,22 @@ def serialize(value: Any) -> SerializedType | None: # If there is no serializer, return None. if serializer is None: + if get_type: + return None, None return None # Serialize the value. - return serializer(value) + serialized = serializer(value) + + # Return the serialized value and the type. + if get_type: + return serialized, get_serializer_type(type(value)) + else: + return serialized -def get_serializer(type_: Type) -> Serializer | None: +@functools.lru_cache +def get_serializer(type_: Type) -> Optional[Serializer]: """Get the serializer for the type. Args: @@ -90,8 +152,6 @@ def get_serializer(type_: Type) -> Serializer | None: Returns: The serializer for the type, or None if there is no serializer. """ - global SERIALIZERS - # First, check if the type is registered. serializer = SERIALIZERS.get(type_) if serializer is not None: @@ -106,6 +166,30 @@ def get_serializer(type_: Type) -> Serializer | None: return None +@functools.lru_cache +def get_serializer_type(type_: Type) -> Optional[Type]: + """Get the converted type for the type after serializing. + + Args: + type_: The type to get the serializer type for. + + Returns: + The serialized type for the type, or None if there is no type conversion registered. + """ + # First, check if the type is registered. + serializer = SERIALIZER_TYPES.get(type_) + if serializer is not None: + return serializer + + # If the type is not registered, check if it is a subclass of a registered type. + for registered_type, serializer in reversed(SERIALIZER_TYPES.items()): + if types._issubclass(type_, registered_type): + return serializer + + # If there is no serializer, return None. + return None + + def has_serializer(type_: Type) -> bool: """Check if there is a serializer for the type. @@ -118,7 +202,7 @@ def has_serializer(type_: Type) -> bool: return get_serializer(type_) is not None -@serializer +@serializer(to=str) def serialize_type(value: type) -> str: """Serialize a python type. @@ -226,7 +310,7 @@ def serialize_dict(prop: Dict[str, Any]) -> str: return format.unwrap_vars(fprop) -@serializer +@serializer(to=str) def serialize_datetime(dt: Union[date, datetime, time, timedelta]) -> str: """Serialize a datetime to a JSON string. @@ -239,8 +323,8 @@ def serialize_datetime(dt: Union[date, datetime, time, timedelta]) -> str: return str(dt) -@serializer -def serialize_path(path: Path): +@serializer(to=str) +def serialize_path(path: Path) -> str: """Serialize a pathlib.Path to a JSON string. Args: @@ -265,7 +349,7 @@ def serialize_enum(en: Enum) -> str: return en.value -@serializer +@serializer(to=str) def serialize_color(color: Color) -> str: """Serialize a color. diff --git a/reflex/vars.py b/reflex/vars.py index cad3d735b..ce51c0324 100644 --- a/reflex/vars.py +++ b/reflex/vars.py @@ -347,7 +347,7 @@ class Var: cls, value: Any, _var_is_local: bool = True, - _var_is_string: bool = False, + _var_is_string: bool | None = None, _var_data: Optional[VarData] = None, ) -> Var | None: """Create a var from a value. @@ -380,18 +380,39 @@ class Var: # Try to serialize the value. type_ = type(value) - name = value if type_ in types.JSONType else serializers.serialize(value) + if type_ in types.JSONType: + name = value + else: + name, serialized_type = serializers.serialize(value, get_type=True) + if ( + serialized_type is not None + and _var_is_string is None + and issubclass(serialized_type, str) + ): + _var_is_string = True if name is None: raise VarTypeError( f"No JSON serializer found for var {value} of type {type_}." ) name = name if isinstance(name, str) else format.json_dumps(name) + if _var_is_string is None and type_ is str: + console.deprecate( + feature_name="Creating a Var from a string without specifying _var_is_string", + reason=( + "Specify _var_is_string=False to create a Var that is not a string literal. " + "In the future, creating a Var from a string will be treated as a string literal " + "by default." + ), + deprecation_version="0.5.4", + removal_version="0.6.0", + ) + return BaseVar( _var_name=name, _var_type=type_, _var_is_local=_var_is_local, - _var_is_string=_var_is_string, + _var_is_string=_var_is_string if _var_is_string is not None else False, _var_data=_var_data, ) @@ -400,7 +421,7 @@ class Var: cls, value: Any, _var_is_local: bool = True, - _var_is_string: bool = False, + _var_is_string: bool | None = None, _var_data: Optional[VarData] = None, ) -> Var: """Create a var from a value, asserting that it is not None. @@ -847,19 +868,19 @@ class Var: if invoke_fn: # invoke the function on left operand. operation_name = ( - f"{left_operand_full_name}.{fn}({right_operand_full_name})" - ) # type: ignore + f"{left_operand_full_name}.{fn}({right_operand_full_name})" # type: ignore + ) else: # pass the operands as arguments to the function. operation_name = ( - f"{left_operand_full_name} {op} {right_operand_full_name}" - ) # type: ignore + f"{left_operand_full_name} {op} {right_operand_full_name}" # type: ignore + ) operation_name = f"{fn}({operation_name})" else: # apply operator to operands (left operand right_operand) operation_name = ( - f"{left_operand_full_name} {op} {right_operand_full_name}" - ) # type: ignore + f"{left_operand_full_name} {op} {right_operand_full_name}" # type: ignore + ) operation_name = format.wrap(operation_name, "(") else: # apply operator to left operand ( left_operand) diff --git a/reflex/vars.pyi b/reflex/vars.pyi index 01b276342..e95561318 100644 --- a/reflex/vars.pyi +++ b/reflex/vars.pyi @@ -51,11 +51,11 @@ class Var: _var_data: VarData | None = None @classmethod def create( - cls, value: Any, _var_is_local: bool = False, _var_is_string: bool = False, _var_data: VarData | None = None, + cls, value: Any, _var_is_local: bool = True, _var_is_string: bool | None = None, _var_data: VarData | None = None, ) -> Optional[Var]: ... @classmethod def create_safe( - cls, value: Any, _var_is_local: bool = False, _var_is_string: bool = False, _var_data: VarData | None = None, + cls, value: Any, _var_is_local: bool = True, _var_is_string: bool | None = None, _var_data: VarData | None = None, ) -> Var: ... @classmethod def __class_getitem__(cls, type_: Type) -> _GenericAlias: ... diff --git a/tests/components/core/test_colors.py b/tests/components/core/test_colors.py index 53aa4d86c..28078059f 100644 --- a/tests/components/core/test_colors.py +++ b/tests/components/core/test_colors.py @@ -2,6 +2,7 @@ import pytest import reflex as rx from reflex.components.datadisplay.code import CodeBlock +from reflex.constants.colors import Color from reflex.vars import Var @@ -50,7 +51,12 @@ def create_color_var(color): ], ) def test_color(color, expected): - assert str(color) == expected + assert color._var_is_string or color._var_type is str + assert color._var_full_name == expected + if color._var_type == Color: + assert str(color) == f"{{`{expected}`}}" + else: + assert str(color) == expected @pytest.mark.parametrize( @@ -96,9 +102,9 @@ def test_color_with_conditionals(cond_var, expected): @pytest.mark.parametrize( "color, expected", [ - (create_color_var(rx.color("red")), "var(--red-7)"), - (create_color_var(rx.color("green", shade=1)), "var(--green-1)"), - (create_color_var(rx.color("blue", alpha=True)), "var(--blue-a7)"), + (create_color_var(rx.color("red")), "{`var(--red-7)`}"), + (create_color_var(rx.color("green", shade=1)), "{`var(--green-1)`}"), + (create_color_var(rx.color("blue", alpha=True)), "{`var(--blue-a7)`}"), ("red", "red"), ("green", "green"), ("blue", "blue"), diff --git a/tests/utils/test_serializers.py b/tests/utils/test_serializers.py index 62834d3cc..b516f5eb3 100644 --- a/tests/utils/test_serializers.py +++ b/tests/utils/test_serializers.py @@ -1,5 +1,6 @@ import datetime from enum import Enum +from pathlib import Path from typing import Any, Dict, List, Type import pytest @@ -90,6 +91,9 @@ def test_add_serializer(): # Remove the serializer. serializers.SERIALIZERS.pop(Foo) + # LRU cache will still have the serializer, so we need to clear it. + assert serializers.has_serializer(Foo) + serializers.get_serializer.cache_clear() assert not serializers.has_serializer(Foo) @@ -194,3 +198,39 @@ def test_serialize(value: Any, expected: str): expected: The expected result. """ assert serializers.serialize(value) == expected + + +@pytest.mark.parametrize( + "value,expected,exp_var_is_string", + [ + ("test", "test", False), + (1, "1", False), + (1.0, "1.0", False), + (True, "true", False), + (False, "false", False), + ([1, 2, 3], "[1, 2, 3]", False), + ([{"key": 1}, {"key": 2}], '[{"key": 1}, {"key": 2}]', False), + (StrEnum.FOO, "foo", False), + ([StrEnum.FOO, StrEnum.BAR], '["foo", "bar"]', False), + ( + BaseSubclass(ts=datetime.timedelta(1, 1, 1)), + '{"ts": "1 day, 0:00:01.000001"}', + False, + ), + (datetime.datetime(2021, 1, 1, 1, 1, 1, 1), "2021-01-01 01:01:01.000001", True), + (Color(color="slate", shade=1), "var(--slate-1)", True), + (BaseSubclass, "BaseSubclass", True), + (Path("."), ".", True), + ], +) +def test_serialize_var_to_str(value: Any, expected: str, exp_var_is_string: bool): + """Test that serialize with `to=str` passed to a Var is marked with _var_is_string. + + Args: + value: The value to serialize. + expected: The expected result. + exp_var_is_string: The expected value of _var_is_string. + """ + v = Var.create_safe(value) + assert v._var_full_name == expected + assert v._var_is_string == exp_var_is_string