improve typing for serializer decorator (#4317)
* improve typing for serializer decorator * use wrapped logic * dang it darglint --------- Co-authored-by: Khaleel Al-Adhami <khaleel.aladhami@gmail.com>
This commit is contained in:
parent
0c482bda3c
commit
8fd5c9f200
@ -18,6 +18,7 @@ from typing import (
|
|||||||
Set,
|
Set,
|
||||||
Tuple,
|
Tuple,
|
||||||
Type,
|
Type,
|
||||||
|
TypeVar,
|
||||||
Union,
|
Union,
|
||||||
get_type_hints,
|
get_type_hints,
|
||||||
overload,
|
overload,
|
||||||
@ -32,17 +33,33 @@ from reflex.utils import types
|
|||||||
SerializedType = Union[str, bool, int, float, list, dict, None]
|
SerializedType = Union[str, bool, int, float, list, dict, None]
|
||||||
|
|
||||||
|
|
||||||
Serializer = Callable[[Type], SerializedType]
|
Serializer = Callable[[Any], SerializedType]
|
||||||
|
|
||||||
|
|
||||||
SERIALIZERS: dict[Type, Serializer] = {}
|
SERIALIZERS: dict[Type, Serializer] = {}
|
||||||
SERIALIZER_TYPES: dict[Type, Type] = {}
|
SERIALIZER_TYPES: dict[Type, Type] = {}
|
||||||
|
|
||||||
|
SERIALIZED_FUNCTION = TypeVar("SERIALIZED_FUNCTION", bound=Serializer)
|
||||||
|
|
||||||
|
|
||||||
|
@overload
|
||||||
|
def serializer(
|
||||||
|
fn: None = None,
|
||||||
|
to: Type[SerializedType] | None = None,
|
||||||
|
) -> Callable[[SERIALIZED_FUNCTION], SERIALIZED_FUNCTION]: ...
|
||||||
|
|
||||||
|
|
||||||
|
@overload
|
||||||
|
def serializer(
|
||||||
|
fn: SERIALIZED_FUNCTION,
|
||||||
|
to: Type[SerializedType] | None = None,
|
||||||
|
) -> SERIALIZED_FUNCTION: ...
|
||||||
|
|
||||||
|
|
||||||
def serializer(
|
def serializer(
|
||||||
fn: Serializer | None = None,
|
fn: SERIALIZED_FUNCTION | None = None,
|
||||||
to: Type | None = None,
|
to: Any = None,
|
||||||
) -> Serializer:
|
) -> SERIALIZED_FUNCTION | Callable[[SERIALIZED_FUNCTION], SERIALIZED_FUNCTION]:
|
||||||
"""Decorator to add a serializer for a given type.
|
"""Decorator to add a serializer for a given type.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
@ -51,43 +68,44 @@ def serializer(
|
|||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
The decorated function.
|
The decorated function.
|
||||||
|
|
||||||
Raises:
|
|
||||||
ValueError: If the function does not take a single argument.
|
|
||||||
"""
|
"""
|
||||||
if fn is None:
|
|
||||||
# If the function is not provided, return a partial that acts as a decorator.
|
|
||||||
return functools.partial(serializer, to=to) # type: ignore
|
|
||||||
|
|
||||||
# Check the type hints to get the type of the argument.
|
def wrapper(fn: SERIALIZED_FUNCTION) -> SERIALIZED_FUNCTION:
|
||||||
type_hints = get_type_hints(fn)
|
# Check the type hints to get the type of the argument.
|
||||||
args = [arg for arg in type_hints if arg != "return"]
|
type_hints = get_type_hints(fn)
|
||||||
|
args = [arg for arg in type_hints if arg != "return"]
|
||||||
|
|
||||||
# Make sure the function takes a single argument.
|
# Make sure the function takes a single argument.
|
||||||
if len(args) != 1:
|
if len(args) != 1:
|
||||||
raise ValueError("Serializer must take a single argument.")
|
raise ValueError("Serializer must take a single argument.")
|
||||||
|
|
||||||
# Get the type of the argument.
|
# Get the type of the argument.
|
||||||
type_ = type_hints[args[0]]
|
type_ = type_hints[args[0]]
|
||||||
|
|
||||||
# Make sure the type is not already registered.
|
# Make sure the type is not already registered.
|
||||||
registered_fn = SERIALIZERS.get(type_)
|
registered_fn = SERIALIZERS.get(type_)
|
||||||
if registered_fn is not None and registered_fn != fn:
|
if registered_fn is not None and registered_fn != fn:
|
||||||
raise ValueError(
|
raise ValueError(
|
||||||
f"Serializer for type {type_} is already registered as {registered_fn.__qualname__}."
|
f"Serializer for type {type_} is already registered as {registered_fn.__qualname__}."
|
||||||
)
|
)
|
||||||
|
|
||||||
# Apply type transformation if requested
|
to_type = to or type_hints.get("return")
|
||||||
if to is not None or ((to := type_hints.get("return")) is not None):
|
|
||||||
SERIALIZER_TYPES[type_] = to
|
|
||||||
get_serializer_type.cache_clear()
|
|
||||||
|
|
||||||
# Register the serializer.
|
# Apply type transformation if requested
|
||||||
SERIALIZERS[type_] = fn
|
if to_type:
|
||||||
get_serializer.cache_clear()
|
SERIALIZER_TYPES[type_] = to_type
|
||||||
|
get_serializer_type.cache_clear()
|
||||||
|
|
||||||
# Return the function.
|
# Register the serializer.
|
||||||
return fn
|
SERIALIZERS[type_] = fn
|
||||||
|
get_serializer.cache_clear()
|
||||||
|
|
||||||
|
# Return the function.
|
||||||
|
return fn
|
||||||
|
|
||||||
|
if fn is not None:
|
||||||
|
return wrapper(fn)
|
||||||
|
return wrapper
|
||||||
|
|
||||||
|
|
||||||
@overload
|
@overload
|
||||||
|
Loading…
Reference in New Issue
Block a user