improve typing for serializer decorator (#4317)
* improve typing for serializer decorator * use wrapped logic * dang it darglint --------- Co-authored-by: Khaleel Al-Adhami <khaleel.aladhami@gmail.com>
This commit is contained in:
parent
0c482bda3c
commit
8fd5c9f200
@ -18,6 +18,7 @@ from typing import (
|
||||
Set,
|
||||
Tuple,
|
||||
Type,
|
||||
TypeVar,
|
||||
Union,
|
||||
get_type_hints,
|
||||
overload,
|
||||
@ -32,17 +33,33 @@ from reflex.utils import types
|
||||
SerializedType = Union[str, bool, int, float, list, dict, None]
|
||||
|
||||
|
||||
Serializer = Callable[[Type], SerializedType]
|
||||
Serializer = Callable[[Any], SerializedType]
|
||||
|
||||
|
||||
SERIALIZERS: dict[Type, Serializer] = {}
|
||||
SERIALIZER_TYPES: dict[Type, Type] = {}
|
||||
|
||||
SERIALIZED_FUNCTION = TypeVar("SERIALIZED_FUNCTION", bound=Serializer)
|
||||
|
||||
|
||||
@overload
|
||||
def serializer(
|
||||
fn: None = None,
|
||||
to: Type[SerializedType] | None = None,
|
||||
) -> Callable[[SERIALIZED_FUNCTION], SERIALIZED_FUNCTION]: ...
|
||||
|
||||
|
||||
@overload
|
||||
def serializer(
|
||||
fn: SERIALIZED_FUNCTION,
|
||||
to: Type[SerializedType] | None = None,
|
||||
) -> SERIALIZED_FUNCTION: ...
|
||||
|
||||
|
||||
def serializer(
|
||||
fn: Serializer | None = None,
|
||||
to: Type | None = None,
|
||||
) -> Serializer:
|
||||
fn: SERIALIZED_FUNCTION | None = None,
|
||||
to: Any = None,
|
||||
) -> SERIALIZED_FUNCTION | Callable[[SERIALIZED_FUNCTION], SERIALIZED_FUNCTION]:
|
||||
"""Decorator to add a serializer for a given type.
|
||||
|
||||
Args:
|
||||
@ -51,43 +68,44 @@ def serializer(
|
||||
|
||||
Returns:
|
||||
The decorated function.
|
||||
|
||||
Raises:
|
||||
ValueError: If the function does not take a single argument.
|
||||
"""
|
||||
if fn is None:
|
||||
# If the function is not provided, return a partial that acts as a decorator.
|
||||
return functools.partial(serializer, to=to) # type: ignore
|
||||
|
||||
# Check the type hints to get the type of the argument.
|
||||
type_hints = get_type_hints(fn)
|
||||
args = [arg for arg in type_hints if arg != "return"]
|
||||
def wrapper(fn: SERIALIZED_FUNCTION) -> SERIALIZED_FUNCTION:
|
||||
# Check the type hints to get the type of the argument.
|
||||
type_hints = get_type_hints(fn)
|
||||
args = [arg for arg in type_hints if arg != "return"]
|
||||
|
||||
# Make sure the function takes a single argument.
|
||||
if len(args) != 1:
|
||||
raise ValueError("Serializer must take a single argument.")
|
||||
# Make sure the function takes a single argument.
|
||||
if len(args) != 1:
|
||||
raise ValueError("Serializer must take a single argument.")
|
||||
|
||||
# Get the type of the argument.
|
||||
type_ = type_hints[args[0]]
|
||||
# Get the type of the argument.
|
||||
type_ = type_hints[args[0]]
|
||||
|
||||
# Make sure the type is not already registered.
|
||||
registered_fn = SERIALIZERS.get(type_)
|
||||
if registered_fn is not None and registered_fn != fn:
|
||||
raise ValueError(
|
||||
f"Serializer for type {type_} is already registered as {registered_fn.__qualname__}."
|
||||
)
|
||||
# Make sure the type is not already registered.
|
||||
registered_fn = SERIALIZERS.get(type_)
|
||||
if registered_fn is not None and registered_fn != fn:
|
||||
raise ValueError(
|
||||
f"Serializer for type {type_} is already registered as {registered_fn.__qualname__}."
|
||||
)
|
||||
|
||||
# Apply type transformation if requested
|
||||
if to is not None or ((to := type_hints.get("return")) is not None):
|
||||
SERIALIZER_TYPES[type_] = to
|
||||
get_serializer_type.cache_clear()
|
||||
to_type = to or type_hints.get("return")
|
||||
|
||||
# Register the serializer.
|
||||
SERIALIZERS[type_] = fn
|
||||
get_serializer.cache_clear()
|
||||
# Apply type transformation if requested
|
||||
if to_type:
|
||||
SERIALIZER_TYPES[type_] = to_type
|
||||
get_serializer_type.cache_clear()
|
||||
|
||||
# Return the function.
|
||||
return fn
|
||||
# Register the serializer.
|
||||
SERIALIZERS[type_] = fn
|
||||
get_serializer.cache_clear()
|
||||
|
||||
# Return the function.
|
||||
return fn
|
||||
|
||||
if fn is not None:
|
||||
return wrapper(fn)
|
||||
return wrapper
|
||||
|
||||
|
||||
@overload
|
||||
|
Loading…
Reference in New Issue
Block a user