use wrapped logic

This commit is contained in:
Khaleel Al-Adhami 2024-11-06 17:58:10 -08:00
parent c684444561
commit 28b854321d

View File

@ -18,6 +18,7 @@ from typing import (
Set, Set,
Tuple, Tuple,
Type, Type,
TypeVar,
Union, Union,
get_type_hints, get_type_hints,
overload, overload,
@ -32,31 +33,33 @@ from reflex.utils import types
SerializedType = Union[str, bool, int, float, list, dict, None] SerializedType = Union[str, bool, int, float, list, dict, None]
Serializer = Callable[[Type], SerializedType] Serializer = Callable[[Any], SerializedType]
SERIALIZERS: dict[Type, Serializer] = {} SERIALIZERS: dict[Type, Serializer] = {}
SERIALIZER_TYPES: dict[Type, Type] = {} SERIALIZER_TYPES: dict[Type, Type] = {}
SERIALIZED_FUNCTION = TypeVar("SERIALIZED_FUNCTION", bound=Serializer)
@overload @overload
def serializer( def serializer(
fn: None = None, fn: None = None,
to: Any = None, to: Type[SerializedType] | None = None,
) -> Serializer: ... ) -> Callable[[SERIALIZED_FUNCTION], SERIALIZED_FUNCTION]: ...
@overload @overload
def serializer( def serializer(
fn: Serializer, fn: SERIALIZED_FUNCTION,
to: None = None, to: Type[SerializedType] | None = None,
) -> functools.partial[Serializer]: ... ) -> SERIALIZED_FUNCTION: ...
def serializer( def serializer(
fn: Serializer | None = None, fn: SERIALIZED_FUNCTION | None = None,
to: Any = None, to: Any = None,
) -> Serializer | functools.partial[Serializer]: ) -> SERIALIZED_FUNCTION | Callable[[SERIALIZED_FUNCTION], SERIALIZED_FUNCTION]:
"""Decorator to add a serializer for a given type. """Decorator to add a serializer for a given type.
Args: Args:
@ -69,10 +72,8 @@ def serializer(
Raises: Raises:
ValueError: If the function does not take a single argument. ValueError: If the function does not take a single argument.
""" """
if fn is None:
# If the function is not provided, return a partial that acts as a decorator.
return functools.partial(serializer, to=to) # type: ignore
def wrapper(fn: SERIALIZED_FUNCTION) -> SERIALIZED_FUNCTION:
# Check the type hints to get the type of the argument. # Check the type hints to get the type of the argument.
type_hints = get_type_hints(fn) type_hints = get_type_hints(fn)
args = [arg for arg in type_hints if arg != "return"] args = [arg for arg in type_hints if arg != "return"]
@ -91,9 +92,11 @@ def serializer(
f"Serializer for type {type_} is already registered as {registered_fn.__qualname__}." f"Serializer for type {type_} is already registered as {registered_fn.__qualname__}."
) )
to_type = to or type_hints.get("return")
# Apply type transformation if requested # Apply type transformation if requested
if to is not None or ((to := type_hints.get("return")) is not None): if to_type:
SERIALIZER_TYPES[type_] = to SERIALIZER_TYPES[type_] = to_type
get_serializer_type.cache_clear() get_serializer_type.cache_clear()
# Register the serializer. # Register the serializer.
@ -103,6 +106,10 @@ def serializer(
# Return the function. # Return the function.
return fn return fn
if fn is not None:
return wrapper(fn)
return wrapper
@overload @overload
def serialize( def serialize(