diff --git a/reflex/utils/serializers.py b/reflex/utils/serializers.py index d3dbb1d4c..b87909aec 100644 --- a/reflex/utils/serializers.py +++ b/reflex/utils/serializers.py @@ -18,6 +18,7 @@ from typing import ( Set, Tuple, Type, + TypeVar, Union, get_type_hints, overload, @@ -32,17 +33,33 @@ from reflex.utils import types SerializedType = Union[str, bool, int, float, list, dict, None] -Serializer = Callable[[Type], SerializedType] +Serializer = Callable[[Any], SerializedType] SERIALIZERS: dict[Type, Serializer] = {} SERIALIZER_TYPES: dict[Type, Type] = {} +SERIALIZED_FUNCTION = TypeVar("SERIALIZED_FUNCTION", bound=Serializer) + + +@overload +def serializer( + fn: None = None, + to: Type[SerializedType] | None = None, +) -> Callable[[SERIALIZED_FUNCTION], SERIALIZED_FUNCTION]: ... + + +@overload +def serializer( + fn: SERIALIZED_FUNCTION, + to: Type[SerializedType] | None = None, +) -> SERIALIZED_FUNCTION: ... + def serializer( - fn: Serializer | None = None, - to: Type | None = None, -) -> Serializer: + fn: SERIALIZED_FUNCTION | None = None, + to: Any = None, +) -> SERIALIZED_FUNCTION | Callable[[SERIALIZED_FUNCTION], SERIALIZED_FUNCTION]: """Decorator to add a serializer for a given type. Args: @@ -51,43 +68,44 @@ def serializer( Returns: The decorated function. - - Raises: - ValueError: If the function does not take a single argument. """ - if fn is None: - # If the function is not provided, return a partial that acts as a decorator. - return functools.partial(serializer, to=to) # type: ignore - # Check the type hints to get the type of the argument. - type_hints = get_type_hints(fn) - args = [arg for arg in type_hints if arg != "return"] + def wrapper(fn: SERIALIZED_FUNCTION) -> SERIALIZED_FUNCTION: + # Check the type hints to get the type of the argument. + type_hints = get_type_hints(fn) + args = [arg for arg in type_hints if arg != "return"] - # Make sure the function takes a single argument. - if len(args) != 1: - raise ValueError("Serializer must take a single argument.") + # Make sure the function takes a single argument. + if len(args) != 1: + raise ValueError("Serializer must take a single argument.") - # Get the type of the argument. - type_ = type_hints[args[0]] + # Get the type of the argument. + type_ = type_hints[args[0]] - # Make sure the type is not already registered. - registered_fn = SERIALIZERS.get(type_) - if registered_fn is not None and registered_fn != fn: - raise ValueError( - f"Serializer for type {type_} is already registered as {registered_fn.__qualname__}." - ) + # Make sure the type is not already registered. + registered_fn = SERIALIZERS.get(type_) + if registered_fn is not None and registered_fn != fn: + raise ValueError( + f"Serializer for type {type_} is already registered as {registered_fn.__qualname__}." + ) - # Apply type transformation if requested - if to is not None or ((to := type_hints.get("return")) is not None): - SERIALIZER_TYPES[type_] = to - get_serializer_type.cache_clear() + to_type = to or type_hints.get("return") - # Register the serializer. - SERIALIZERS[type_] = fn - get_serializer.cache_clear() + # Apply type transformation if requested + if to_type: + SERIALIZER_TYPES[type_] = to_type + get_serializer_type.cache_clear() - # Return the function. - return fn + # Register the serializer. + SERIALIZERS[type_] = fn + get_serializer.cache_clear() + + # Return the function. + return fn + + if fn is not None: + return wrapper(fn) + return wrapper @overload