use wrapped logic

This commit is contained in:
Khaleel Al-Adhami 2024-11-06 17:58:10 -08:00
parent c684444561
commit 28b854321d

View File

@ -18,6 +18,7 @@ from typing import (
Set,
Tuple,
Type,
TypeVar,
Union,
get_type_hints,
overload,
@ -32,31 +33,33 @@ from reflex.utils import types
SerializedType = Union[str, bool, int, float, list, dict, None]
Serializer = Callable[[Type], SerializedType]
Serializer = Callable[[Any], SerializedType]
SERIALIZERS: dict[Type, Serializer] = {}
SERIALIZER_TYPES: dict[Type, Type] = {}
SERIALIZED_FUNCTION = TypeVar("SERIALIZED_FUNCTION", bound=Serializer)
@overload
def serializer(
fn: None = None,
to: Any = None,
) -> Serializer: ...
to: Type[SerializedType] | None = None,
) -> Callable[[SERIALIZED_FUNCTION], SERIALIZED_FUNCTION]: ...
@overload
def serializer(
fn: Serializer,
to: None = None,
) -> functools.partial[Serializer]: ...
fn: SERIALIZED_FUNCTION,
to: Type[SerializedType] | None = None,
) -> SERIALIZED_FUNCTION: ...
def serializer(
fn: Serializer | None = None,
fn: SERIALIZED_FUNCTION | None = None,
to: Any = None,
) -> Serializer | functools.partial[Serializer]:
) -> SERIALIZED_FUNCTION | Callable[[SERIALIZED_FUNCTION], SERIALIZED_FUNCTION]:
"""Decorator to add a serializer for a given type.
Args:
@ -69,10 +72,8 @@ def serializer(
Raises:
ValueError: If the function does not take a single argument.
"""
if fn is None:
# If the function is not provided, return a partial that acts as a decorator.
return functools.partial(serializer, to=to) # type: ignore
def wrapper(fn: SERIALIZED_FUNCTION) -> SERIALIZED_FUNCTION:
# Check the type hints to get the type of the argument.
type_hints = get_type_hints(fn)
args = [arg for arg in type_hints if arg != "return"]
@ -91,9 +92,11 @@ def serializer(
f"Serializer for type {type_} is already registered as {registered_fn.__qualname__}."
)
to_type = to or type_hints.get("return")
# Apply type transformation if requested
if to is not None or ((to := type_hints.get("return")) is not None):
SERIALIZER_TYPES[type_] = to
if to_type:
SERIALIZER_TYPES[type_] = to_type
get_serializer_type.cache_clear()
# Register the serializer.
@ -103,6 +106,10 @@ def serializer(
# Return the function.
return fn
if fn is not None:
return wrapper(fn)
return wrapper
@overload
def serialize(