use wrapped logic
This commit is contained in:
parent
c684444561
commit
28b854321d
@ -18,6 +18,7 @@ from typing import (
|
|||||||
Set,
|
Set,
|
||||||
Tuple,
|
Tuple,
|
||||||
Type,
|
Type,
|
||||||
|
TypeVar,
|
||||||
Union,
|
Union,
|
||||||
get_type_hints,
|
get_type_hints,
|
||||||
overload,
|
overload,
|
||||||
@ -32,31 +33,33 @@ from reflex.utils import types
|
|||||||
SerializedType = Union[str, bool, int, float, list, dict, None]
|
SerializedType = Union[str, bool, int, float, list, dict, None]
|
||||||
|
|
||||||
|
|
||||||
Serializer = Callable[[Type], SerializedType]
|
Serializer = Callable[[Any], SerializedType]
|
||||||
|
|
||||||
|
|
||||||
SERIALIZERS: dict[Type, Serializer] = {}
|
SERIALIZERS: dict[Type, Serializer] = {}
|
||||||
SERIALIZER_TYPES: dict[Type, Type] = {}
|
SERIALIZER_TYPES: dict[Type, Type] = {}
|
||||||
|
|
||||||
|
SERIALIZED_FUNCTION = TypeVar("SERIALIZED_FUNCTION", bound=Serializer)
|
||||||
|
|
||||||
|
|
||||||
@overload
|
@overload
|
||||||
def serializer(
|
def serializer(
|
||||||
fn: None = None,
|
fn: None = None,
|
||||||
to: Any = None,
|
to: Type[SerializedType] | None = None,
|
||||||
) -> Serializer: ...
|
) -> Callable[[SERIALIZED_FUNCTION], SERIALIZED_FUNCTION]: ...
|
||||||
|
|
||||||
|
|
||||||
@overload
|
@overload
|
||||||
def serializer(
|
def serializer(
|
||||||
fn: Serializer,
|
fn: SERIALIZED_FUNCTION,
|
||||||
to: None = None,
|
to: Type[SerializedType] | None = None,
|
||||||
) -> functools.partial[Serializer]: ...
|
) -> SERIALIZED_FUNCTION: ...
|
||||||
|
|
||||||
|
|
||||||
def serializer(
|
def serializer(
|
||||||
fn: Serializer | None = None,
|
fn: SERIALIZED_FUNCTION | None = None,
|
||||||
to: Any = None,
|
to: Any = None,
|
||||||
) -> Serializer | functools.partial[Serializer]:
|
) -> SERIALIZED_FUNCTION | Callable[[SERIALIZED_FUNCTION], SERIALIZED_FUNCTION]:
|
||||||
"""Decorator to add a serializer for a given type.
|
"""Decorator to add a serializer for a given type.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
@ -69,10 +72,8 @@ def serializer(
|
|||||||
Raises:
|
Raises:
|
||||||
ValueError: If the function does not take a single argument.
|
ValueError: If the function does not take a single argument.
|
||||||
"""
|
"""
|
||||||
if fn is None:
|
|
||||||
# If the function is not provided, return a partial that acts as a decorator.
|
|
||||||
return functools.partial(serializer, to=to) # type: ignore
|
|
||||||
|
|
||||||
|
def wrapper(fn: SERIALIZED_FUNCTION) -> SERIALIZED_FUNCTION:
|
||||||
# Check the type hints to get the type of the argument.
|
# Check the type hints to get the type of the argument.
|
||||||
type_hints = get_type_hints(fn)
|
type_hints = get_type_hints(fn)
|
||||||
args = [arg for arg in type_hints if arg != "return"]
|
args = [arg for arg in type_hints if arg != "return"]
|
||||||
@ -91,9 +92,11 @@ def serializer(
|
|||||||
f"Serializer for type {type_} is already registered as {registered_fn.__qualname__}."
|
f"Serializer for type {type_} is already registered as {registered_fn.__qualname__}."
|
||||||
)
|
)
|
||||||
|
|
||||||
|
to_type = to or type_hints.get("return")
|
||||||
|
|
||||||
# Apply type transformation if requested
|
# Apply type transformation if requested
|
||||||
if to is not None or ((to := type_hints.get("return")) is not None):
|
if to_type:
|
||||||
SERIALIZER_TYPES[type_] = to
|
SERIALIZER_TYPES[type_] = to_type
|
||||||
get_serializer_type.cache_clear()
|
get_serializer_type.cache_clear()
|
||||||
|
|
||||||
# Register the serializer.
|
# Register the serializer.
|
||||||
@ -103,6 +106,10 @@ def serializer(
|
|||||||
# Return the function.
|
# Return the function.
|
||||||
return fn
|
return fn
|
||||||
|
|
||||||
|
if fn is not None:
|
||||||
|
return wrapper(fn)
|
||||||
|
return wrapper
|
||||||
|
|
||||||
|
|
||||||
@overload
|
@overload
|
||||||
def serialize(
|
def serialize(
|
||||||
|
Loading…
Reference in New Issue
Block a user