improve typing for serializer decorator (#4317)

* improve typing for serializer decorator

* use wrapped logic

* dang it darglint

---------

Co-authored-by: Khaleel Al-Adhami <khaleel.aladhami@gmail.com>
This commit is contained in:
benedikt-bartscher 2024-11-08 00:52:11 +01:00 committed by GitHub
parent 0c482bda3c
commit 8fd5c9f200
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194

View File

@ -18,6 +18,7 @@ from typing import (
Set,
Tuple,
Type,
TypeVar,
Union,
get_type_hints,
overload,
@ -32,17 +33,33 @@ from reflex.utils import types
SerializedType = Union[str, bool, int, float, list, dict, None]
Serializer = Callable[[Type], SerializedType]
Serializer = Callable[[Any], SerializedType]
SERIALIZERS: dict[Type, Serializer] = {}
SERIALIZER_TYPES: dict[Type, Type] = {}
SERIALIZED_FUNCTION = TypeVar("SERIALIZED_FUNCTION", bound=Serializer)
@overload
def serializer(
fn: None = None,
to: Type[SerializedType] | None = None,
) -> Callable[[SERIALIZED_FUNCTION], SERIALIZED_FUNCTION]: ...
@overload
def serializer(
fn: SERIALIZED_FUNCTION,
to: Type[SerializedType] | None = None,
) -> SERIALIZED_FUNCTION: ...
def serializer(
fn: Serializer | None = None,
to: Type | None = None,
) -> Serializer:
fn: SERIALIZED_FUNCTION | None = None,
to: Any = None,
) -> SERIALIZED_FUNCTION | Callable[[SERIALIZED_FUNCTION], SERIALIZED_FUNCTION]:
"""Decorator to add a serializer for a given type.
Args:
@ -51,43 +68,44 @@ def serializer(
Returns:
The decorated function.
Raises:
ValueError: If the function does not take a single argument.
"""
if fn is None:
# If the function is not provided, return a partial that acts as a decorator.
return functools.partial(serializer, to=to) # type: ignore
# Check the type hints to get the type of the argument.
type_hints = get_type_hints(fn)
args = [arg for arg in type_hints if arg != "return"]
def wrapper(fn: SERIALIZED_FUNCTION) -> SERIALIZED_FUNCTION:
# Check the type hints to get the type of the argument.
type_hints = get_type_hints(fn)
args = [arg for arg in type_hints if arg != "return"]
# Make sure the function takes a single argument.
if len(args) != 1:
raise ValueError("Serializer must take a single argument.")
# Make sure the function takes a single argument.
if len(args) != 1:
raise ValueError("Serializer must take a single argument.")
# Get the type of the argument.
type_ = type_hints[args[0]]
# Get the type of the argument.
type_ = type_hints[args[0]]
# Make sure the type is not already registered.
registered_fn = SERIALIZERS.get(type_)
if registered_fn is not None and registered_fn != fn:
raise ValueError(
f"Serializer for type {type_} is already registered as {registered_fn.__qualname__}."
)
# Make sure the type is not already registered.
registered_fn = SERIALIZERS.get(type_)
if registered_fn is not None and registered_fn != fn:
raise ValueError(
f"Serializer for type {type_} is already registered as {registered_fn.__qualname__}."
)
# Apply type transformation if requested
if to is not None or ((to := type_hints.get("return")) is not None):
SERIALIZER_TYPES[type_] = to
get_serializer_type.cache_clear()
to_type = to or type_hints.get("return")
# Register the serializer.
SERIALIZERS[type_] = fn
get_serializer.cache_clear()
# Apply type transformation if requested
if to_type:
SERIALIZER_TYPES[type_] = to_type
get_serializer_type.cache_clear()
# Return the function.
return fn
# Register the serializer.
SERIALIZERS[type_] = fn
get_serializer.cache_clear()
# Return the function.
return fn
if fn is not None:
return wrapper(fn)
return wrapper
@overload