use wrapped logic
This commit is contained in:
parent
c684444561
commit
28b854321d
@ -18,6 +18,7 @@ from typing import (
|
|||||||
Set,
|
Set,
|
||||||
Tuple,
|
Tuple,
|
||||||
Type,
|
Type,
|
||||||
|
TypeVar,
|
||||||
Union,
|
Union,
|
||||||
get_type_hints,
|
get_type_hints,
|
||||||
overload,
|
overload,
|
||||||
@ -32,31 +33,33 @@ from reflex.utils import types
|
|||||||
SerializedType = Union[str, bool, int, float, list, dict, None]
|
SerializedType = Union[str, bool, int, float, list, dict, None]
|
||||||
|
|
||||||
|
|
||||||
Serializer = Callable[[Type], SerializedType]
|
Serializer = Callable[[Any], SerializedType]
|
||||||
|
|
||||||
|
|
||||||
SERIALIZERS: dict[Type, Serializer] = {}
|
SERIALIZERS: dict[Type, Serializer] = {}
|
||||||
SERIALIZER_TYPES: dict[Type, Type] = {}
|
SERIALIZER_TYPES: dict[Type, Type] = {}
|
||||||
|
|
||||||
|
SERIALIZED_FUNCTION = TypeVar("SERIALIZED_FUNCTION", bound=Serializer)
|
||||||
|
|
||||||
|
|
||||||
@overload
|
@overload
|
||||||
def serializer(
|
def serializer(
|
||||||
fn: None = None,
|
fn: None = None,
|
||||||
to: Any = None,
|
to: Type[SerializedType] | None = None,
|
||||||
) -> Serializer: ...
|
) -> Callable[[SERIALIZED_FUNCTION], SERIALIZED_FUNCTION]: ...
|
||||||
|
|
||||||
|
|
||||||
@overload
|
@overload
|
||||||
def serializer(
|
def serializer(
|
||||||
fn: Serializer,
|
fn: SERIALIZED_FUNCTION,
|
||||||
to: None = None,
|
to: Type[SerializedType] | None = None,
|
||||||
) -> functools.partial[Serializer]: ...
|
) -> SERIALIZED_FUNCTION: ...
|
||||||
|
|
||||||
|
|
||||||
def serializer(
|
def serializer(
|
||||||
fn: Serializer | None = None,
|
fn: SERIALIZED_FUNCTION | None = None,
|
||||||
to: Any = None,
|
to: Any = None,
|
||||||
) -> Serializer | functools.partial[Serializer]:
|
) -> SERIALIZED_FUNCTION | Callable[[SERIALIZED_FUNCTION], SERIALIZED_FUNCTION]:
|
||||||
"""Decorator to add a serializer for a given type.
|
"""Decorator to add a serializer for a given type.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
@ -69,39 +72,43 @@ def serializer(
|
|||||||
Raises:
|
Raises:
|
||||||
ValueError: If the function does not take a single argument.
|
ValueError: If the function does not take a single argument.
|
||||||
"""
|
"""
|
||||||
if fn is None:
|
|
||||||
# If the function is not provided, return a partial that acts as a decorator.
|
|
||||||
return functools.partial(serializer, to=to) # type: ignore
|
|
||||||
|
|
||||||
# Check the type hints to get the type of the argument.
|
def wrapper(fn: SERIALIZED_FUNCTION) -> SERIALIZED_FUNCTION:
|
||||||
type_hints = get_type_hints(fn)
|
# Check the type hints to get the type of the argument.
|
||||||
args = [arg for arg in type_hints if arg != "return"]
|
type_hints = get_type_hints(fn)
|
||||||
|
args = [arg for arg in type_hints if arg != "return"]
|
||||||
|
|
||||||
# Make sure the function takes a single argument.
|
# Make sure the function takes a single argument.
|
||||||
if len(args) != 1:
|
if len(args) != 1:
|
||||||
raise ValueError("Serializer must take a single argument.")
|
raise ValueError("Serializer must take a single argument.")
|
||||||
|
|
||||||
# Get the type of the argument.
|
# Get the type of the argument.
|
||||||
type_ = type_hints[args[0]]
|
type_ = type_hints[args[0]]
|
||||||
|
|
||||||
# Make sure the type is not already registered.
|
# Make sure the type is not already registered.
|
||||||
registered_fn = SERIALIZERS.get(type_)
|
registered_fn = SERIALIZERS.get(type_)
|
||||||
if registered_fn is not None and registered_fn != fn:
|
if registered_fn is not None and registered_fn != fn:
|
||||||
raise ValueError(
|
raise ValueError(
|
||||||
f"Serializer for type {type_} is already registered as {registered_fn.__qualname__}."
|
f"Serializer for type {type_} is already registered as {registered_fn.__qualname__}."
|
||||||
)
|
)
|
||||||
|
|
||||||
# Apply type transformation if requested
|
to_type = to or type_hints.get("return")
|
||||||
if to is not None or ((to := type_hints.get("return")) is not None):
|
|
||||||
SERIALIZER_TYPES[type_] = to
|
|
||||||
get_serializer_type.cache_clear()
|
|
||||||
|
|
||||||
# Register the serializer.
|
# Apply type transformation if requested
|
||||||
SERIALIZERS[type_] = fn
|
if to_type:
|
||||||
get_serializer.cache_clear()
|
SERIALIZER_TYPES[type_] = to_type
|
||||||
|
get_serializer_type.cache_clear()
|
||||||
|
|
||||||
# Return the function.
|
# Register the serializer.
|
||||||
return fn
|
SERIALIZERS[type_] = fn
|
||||||
|
get_serializer.cache_clear()
|
||||||
|
|
||||||
|
# Return the function.
|
||||||
|
return fn
|
||||||
|
|
||||||
|
if fn is not None:
|
||||||
|
return wrapper(fn)
|
||||||
|
return wrapper
|
||||||
|
|
||||||
|
|
||||||
@overload
|
@overload
|
||||||
|
Loading…
Reference in New Issue
Block a user