reflex/reflex/utils/types.py
Malte Klemm 86526cba51
[REF-2127] Loosen requirements (#2796)
* Remove upper bounds of most dependencies.

Also adds an import try/except block for pydantic.v1 and relocks.

Keep the black and ruff constraints so as not to mess too much with the current formatting

Make pyright see the right import as long as the constraints still lock pydantic v1

Down pin pytest-asyncio again due to known issue

Fix upload handler with latest versions of fastapi

Change comment

* Add changed lockfile

* Set max versions for deps

* Revert app.pyi

---------

Co-authored-by: Malte Klemm <malte.klemm@blueyonder.com>
Co-authored-by: Nikhil Rao <nikhil@reflex.dev>
2024-03-29 09:26:53 -07:00

464 lines
13 KiB
Python

"""Contains custom types and methods to check types."""
from __future__ import annotations
import contextlib
import inspect
import types
from functools import wraps
from typing import (
TYPE_CHECKING,
Any,
Callable,
Iterable,
List,
Literal,
Optional,
Type,
Union,
_GenericAlias, # type: ignore
get_args,
get_origin,
get_type_hints,
)
import sqlalchemy
try:
    # TODO The type checking guard can be removed once
    # reflex-hosting-cli tools are compatible with pydantic v2
    if not TYPE_CHECKING:
        # Import the ModelField *class*. `import pydantic.v1.fields as ModelField`
        # would bind the module instead, and `isinstance(x, ModelField)` would
        # then raise TypeError.
        from pydantic.v1.fields import ModelField
    else:
        # Force the fallback below so static type checkers resolve ModelField
        # from `pydantic.fields`.
        raise ModuleNotFoundError
except ModuleNotFoundError:
    # Fallback when the `pydantic.v1` namespace is unavailable, or under type
    # checking.
    from pydantic.fields import ModelField
from sqlalchemy.ext.associationproxy import AssociationProxyInstance
from sqlalchemy.ext.hybrid import hybrid_property
from sqlalchemy.orm import DeclarativeBase, Mapped, QueryableAttribute, Relationship
from reflex import constants
from reflex.base import Base
from reflex.utils import serializers
# Potential GenericAlias types for isinstance checks.
# Built as a list so optional alias classes can be appended when they are
# importable, then frozen into a tuple (the form `isinstance` accepts).
GenericAliasTypes = [_GenericAlias]
with contextlib.suppress(ImportError):
    # For newer versions of Python.
    from types import GenericAlias  # type: ignore

    GenericAliasTypes.append(GenericAlias)
with contextlib.suppress(ImportError):
    # For older versions of Python.
    from typing import _SpecialGenericAlias  # type: ignore

    GenericAliasTypes.append(_SpecialGenericAlias)
GenericAliasTypes = tuple(GenericAliasTypes)
# Potential Union types for isinstance checks (UnionType added in py3.10).
# `is_union` compares `get_origin(cls)` against this tuple.
UnionTypes = (Union, types.UnionType) if hasattr(types, "UnionType") else (Union,)
# Union of generic types.
GenericType = Union[Type, _GenericAlias]
# Valid state var types.
# JSONType is a *set* of scalar classes, not a typing Union.
JSONType = {str, int, float, bool}
PrimitiveType = Union[int, float, bool, str, list, dict, set, tuple]
StateVar = Union[PrimitiveType, Base, None]
StateIterVar = Union[list, set, tuple]
# Conceptually `Callable[[Var], list[Var]]`; kept as a bare `Callable`
# (presumably to avoid importing Var at module level - TODO confirm).
ArgsSpec = Callable
class Unset:
    """Sentinel type marking a value that was never provided.

    Instances let callers tell "no value given" apart from an explicit ``None``.
    """

    def __repr__(self) -> str:
        """Render the sentinel for debugging output.

        Returns:
            The fixed string "Unset".
        """
        label = "Unset"
        return label

    def __bool__(self) -> bool:
        """Make the sentinel falsy so it can be tested like an empty value.

        Returns:
            Always False.
        """
        is_truthy = False
        return is_truthy
def is_generic_alias(cls: GenericType) -> bool:
    """Tell whether ``cls`` is an instance of one of the generic-alias classes.

    Args:
        cls: The class to check.

    Returns:
        True if ``cls`` is an instance of any type in ``GenericAliasTypes``.
    """
    alias_classes = GenericAliasTypes
    return isinstance(cls, alias_classes)
def is_union(cls: GenericType) -> bool:
    """Tell whether ``cls`` is a union annotation.

    Args:
        cls: The class to check.

    Returns:
        True if the origin of ``cls`` is one of ``UnionTypes``.
    """
    origin = get_origin(cls)
    return origin in UnionTypes
def is_literal(cls: GenericType) -> bool:
    """Tell whether ``cls`` is a ``Literal[...]`` annotation.

    Args:
        cls: The class to check.

    Returns:
        True if the origin of ``cls`` is ``typing.Literal``.
    """
    origin = get_origin(cls)
    return origin is Literal
def is_optional(cls: GenericType) -> bool:
    """Tell whether ``cls`` is a union that admits ``None``.

    Args:
        cls: The class to check.

    Returns:
        True if ``cls`` is a union whose arguments include ``NoneType``.
    """
    if not is_union(cls):
        return False
    return type(None) in get_args(cls)
def get_property_hint(attr: Any | None) -> GenericType | None:
    """Return the declared return type of a property-like descriptor.

    Args:
        attr: The descriptor to check.

    Returns:
        The ``return`` type hint of ``attr.fget`` when ``attr`` is a
        ``property`` or ``hybrid_property`` (None if it has no return hint),
        otherwise None.
    """
    if isinstance(attr, (property, hybrid_property)):
        return get_type_hints(attr.fget).get("return")
    return None
def get_attribute_access_type(cls: GenericType, name: str) -> GenericType | None:
    """Check if an attribute can be accessed on the cls and return its type.

    Supports pydantic models, unions, and annotated attributes on rx.Model.
    Resolution order: property return hints, pydantic ``__fields__``,
    SQLAlchemy ``DeclarativeBase`` mappings, ``rx.Model`` annotations, and
    finally each member of a union (first accessible match wins).

    Args:
        cls: The class to check.
        name: The name of the attribute to check.

    Returns:
        The type of the attribute, if accessible, or None
    """
    from reflex.model import Model

    attr = getattr(cls, name, None)
    # Properties (including hybrid properties) expose their return hint.
    if hint := get_property_hint(attr):
        return hint

    if hasattr(cls, "__fields__") and name in cls.__fields__:
        # pydantic models
        field = cls.__fields__[name]
        type_ = field.outer_type_
        if isinstance(type_, ModelField):
            # Defensive: unwrap if the outer type is itself a ModelField.
            type_ = type_.type_
        if not field.required and field.default is None:
            # Ensure frontend uses null coalescing when accessing.
            type_ = Optional[type_]
        return type_
    elif isinstance(cls, type) and issubclass(cls, DeclarativeBase):
        insp = sqlalchemy.inspect(cls)
        if name in insp.columns:
            # Mapped column: use the column type's Python type.
            return insp.columns[name].type.python_type
        if name not in insp.all_orm_descriptors.keys():
            return None
        descriptor = insp.all_orm_descriptors[name]
        if hint := get_property_hint(descriptor):
            return hint
        if isinstance(descriptor, QueryableAttribute):
            prop = descriptor.property
            if not isinstance(prop, Relationship):
                return None
            class_ = prop.mapper.class_
            # Collection-valued relationships (uselist) are exposed as lists.
            return List[class_] if prop.uselist else class_
        if isinstance(attr, AssociationProxyInstance):
            target_type = get_attribute_access_type(
                attr.target_class,
                attr.remote_attr.key,  # type: ignore[attr-defined]
            )
            # If the proxied attribute is not accessible, neither is the proxy;
            # don't report a misleading List[None].
            if target_type is None:
                return None
            return List[target_type]
    elif isinstance(cls, type) and not is_generic_alias(cls) and issubclass(cls, Model):
        # Check in the annotations directly (for sqlmodel.Relationship)
        hints = get_type_hints(cls)
        if name in hints:
            type_ = hints[name]
            type_origin = get_origin(type_)
            if isinstance(type_origin, type) and issubclass(type_origin, Mapped):
                return get_args(type_)[0]  # SQLAlchemy v2
            if isinstance(type_, ModelField):
                return type_.type_  # SQLAlchemy v1.4
            return type_
    elif is_union(cls):
        # Check in each arg of the annotation.
        for arg in get_args(cls):
            type_ = get_attribute_access_type(arg, name)
            if type_ is not None:
                # Return the first attribute type that is accessible.
                return type_
    return None  # Attribute is not accessible.
def get_base_class(cls: GenericType) -> Type:
    """Reduce an annotation to the runtime class(es) behind it.

    ``Literal[...]`` maps to the type of its values, a union maps to a tuple
    of its members' base classes, a generic alias maps to the base class of
    its ``__origin__``, and anything else is returned unchanged.

    Args:
        cls: The class.

    Returns:
        The base class of the class (a tuple of base classes for unions).

    Raises:
        TypeError: If a literal has multiple types.
    """
    if is_literal(cls):
        # only literals of the same type are supported.
        literal_types = [type(value) for value in get_args(cls)]
        first_type = literal_types[0]
        if any(t != first_type for t in literal_types):
            raise TypeError("only literals of the same type are supported")
        return first_type
    if is_union(cls):
        return tuple(get_base_class(member) for member in get_args(cls))
    if is_generic_alias(cls):
        return get_base_class(cls.__origin__)
    return cls
def _issubclass(cls: GenericType, cls_check: GenericType) -> bool:
    """Subclass check that understands typing constructs.

    Args:
        cls: The class to check.
        cls_check: The class to check against.

    Returns:
        Whether the class is a subclass of the other class.

    Raises:
        TypeError: If the base class is not valid for issubclass.
    """
    # Everything counts as a subclass of Any.
    if cls_check == Any:
        return True
    # These never count as subclasses of anything more specific.
    if cls in (Any, Callable, None):
        return False

    base = get_base_class(cls)
    check_base = get_base_class(cls_check)

    # A union on the left-hand side is never treated as a subclass.
    if isinstance(base, tuple):
        return False

    try:
        if check_base == Any:
            return True
        return issubclass(base, check_base)
    except TypeError as err:
        # Bad annotations are hard to debug without the offending type.
        raise TypeError(f"Invalid type for issubclass: {base}") from err
def _isinstance(obj: Any, cls: GenericType) -> bool:
    """Instance check against the runtime base class(es) of an annotation.

    Args:
        obj: The object to check.
        cls: The class to check against.

    Returns:
        True if ``obj`` is an instance of ``get_base_class(cls)``.
    """
    base = get_base_class(cls)
    return isinstance(obj, base)
def is_dataframe(value: Type) -> bool:
    """Check if the given value is a dataframe type.

    The check is by class name only (``"DataFrame"``), so no dataframe library
    has to be imported here.

    Args:
        value: The value to check.

    Returns:
        Whether the value is a dataframe.
    """
    if is_generic_alias(value) or value == Any:
        return False
    # Not every type-like value has a __name__ (e.g. None); treat those as
    # "not a dataframe" instead of raising AttributeError.
    return getattr(value, "__name__", None) == "DataFrame"
def is_valid_var_type(type_: Type) -> bool:
    """Decide whether ``type_`` may be used as a prop/var type.

    Args:
        type_: The type to check.

    Returns:
        For unions, whether every member is valid; otherwise whether the type
        is a subclass of ``StateVar`` or has a registered serializer.
    """
    if is_union(type_):
        for member in get_args(type_):
            if not is_valid_var_type(member):
                return False
        return True
    if _issubclass(type_, StateVar):
        return True
    return serializers.has_serializer(type_)
def is_backend_variable(name: str, cls: Type | None = None) -> bool:
    """Decide whether ``name`` refers to a backend variable.

    A backend variable starts with a single underscore. Dunder-prefixed names
    are excluded, and so are names mangled for ``cls`` (``_<ClassName>__x``)
    when ``cls`` is given.

    Args:
        name: The name of the variable to check
        cls: The class of the variable to check

    Returns:
        bool: The result of the check
    """
    if cls is not None:
        mangled_prefix = f"_{cls.__name__}__"
        if name.startswith(mangled_prefix):
            return False
    if not name.startswith("_"):
        return False
    return not name.startswith("__")
def check_type_in_allowed_types(value_type: Type, allowed_types: Iterable) -> bool:
    """Test whether the base class of ``value_type`` is among ``allowed_types``.

    Args:
        value_type: Type of value.
        allowed_types: Iterable of allowed types.

    Returns:
        True if ``get_base_class(value_type)`` is in ``allowed_types``.
    """
    base = get_base_class(value_type)
    return base in allowed_types
def check_prop_in_allowed_types(prop: Any, allowed_types: Iterable) -> bool:
    """Test whether a prop's type is one of ``allowed_types``.

    For a state ``Var`` the declared ``_var_type`` is used; for a raw value,
    its runtime ``type``.

    Args:
        prop: The prop to check.
        allowed_types: The list of allowed types.

    Returns:
        If the prop type match one of the allowed_types.
    """
    from reflex.vars import Var

    if _isinstance(prop, Var):
        prop_type = prop._var_type
    else:
        prop_type = type(prop)
    return prop_type in allowed_types
def is_encoded_fstring(value) -> bool:
    """Tell whether ``value`` is a string containing an encoded Var marker.

    Args:
        value: The value string to check.

    Returns:
        True if ``value`` is a ``str`` containing
        ``constants.REFLEX_VAR_OPENING_TAG``.
    """
    if not isinstance(value, str):
        return False
    return constants.REFLEX_VAR_OPENING_TAG in value
def validate_literal(key: str, value: Any, expected_type: Type, comp_name: str):
    """Check that a value is a valid literal.

    Validation is skipped when ``expected_type`` is not a ``Literal``, when
    ``value`` is a ``Var``, or when ``value`` is an encoded f-string.

    Args:
        key: The prop name.
        value: The prop value to validate.
        expected_type: The expected type(literal type).
        comp_name: Name of the component.

    Raises:
        ValueError: When the value is not a valid literal.
    """
    from reflex.vars import Var

    if not is_literal(expected_type):
        return
    if isinstance(value, Var):  # validating vars is not supported yet.
        return
    if is_encoded_fstring(value):  # f-strings are not supported.
        return

    allowed_values = expected_type.__args__
    if value in allowed_values:
        return

    # Quote string options so the error message shows them as literals.
    value_str = ",".join(
        [str(v) if not isinstance(v, str) else f"'{v}'" for v in allowed_values]
    )
    raise ValueError(
        f"prop value for {str(key)} of the `{comp_name}` component should be one of the following: {value_str}. Got '{value}' instead"
    )
def validate_parameter_literals(func):
    """Decorator to check that the arguments passed to a function
    correspond to the correct function parameter if it (the parameter)
    is a literal type.

    Args:
        func: The function to validate.

    Returns:
        The wrapper function.
    """
    # The signature does not change between calls, so inspect it once at
    # decoration time instead of on every call.
    # NOTE(review): if the decorated module uses `from __future__ import
    # annotations`, these annotations are strings and `validate_literal` will
    # not recognise them as Literal types - confirm callers don't rely on that.
    annotations = {
        param_name: param.annotation
        for param_name, param in inspect.signature(func).parameters.items()
    }

    @wraps(func)
    def wrapper(*args, **kwargs):
        # validate positional args, paired with parameters in declaration order.
        for param, arg in zip(annotations.keys(), args):
            if annotations[param] is inspect.Parameter.empty:
                continue
            validate_literal(param, arg, annotations[param], func.__name__)
        # validate kwargs.
        for key, value in kwargs.items():
            annotation = annotations.get(key)
            if not annotation or annotation is inspect.Parameter.empty:
                continue
            validate_literal(key, value, annotation, func.__name__)
        return func(*args, **kwargs)

    return wrapper
# Precomputed once at import time, for performance: the runtime base classes
# of the StateVar and StateIterVar unions.
StateBases, StateIterBases = get_base_class(StateVar), get_base_class(StateIterVar)