unionize base var fields types

This commit is contained in:
Khaleel Al-Adhami 2024-10-10 16:21:49 -07:00
parent 3da1a8d082
commit 220a058273
2 changed files with 24 additions and 26 deletions

View File

@ -182,6 +182,26 @@ def is_generic_alias(cls: GenericType) -> bool:
return isinstance(cls, GenericAliasTypes)
def unionize(*args: GenericType) -> Type:
"""Unionize the types.
Args:
args: The types to unionize.
Returns:
The unionized types.
"""
if not args:
return Any
if len(args) == 1:
return args[0]
# We are bisecting the args list here to avoid hitting the recursion limit
# In Python versions >= 3.11, we can simply do `return Union[*args]`
midpoint = len(args) // 2
first_half, second_half = args[:midpoint], args[midpoint:]
return Union[unionize(*first_half), unionize(*second_half)]
def is_none(cls: GenericType) -> bool:
"""Check if a class is None.
@ -337,11 +357,9 @@ def get_attribute_access_type(cls: GenericType, name: str) -> GenericType | None
return type_
elif is_union(cls):
# Check in each arg of the annotation.
for arg in get_args(cls):
type_ = get_attribute_access_type(arg, name)
if type_ is not None:
# Return the first attribute type that is accessible.
return type_
return unionize(
*(get_attribute_access_type(arg, name) for arg in get_args(cls))
)
elif isinstance(cls, type):
# Bare class
if sys.version_info >= (3, 10):

View File

@ -56,7 +56,7 @@ from reflex.utils.imports import (
ParsedImportDict,
parse_imports,
)
from reflex.utils.types import GenericType, Self, get_origin
from reflex.utils.types import GenericType, Self, get_origin, unionize
if TYPE_CHECKING:
from reflex.state import BaseState
@ -1237,26 +1237,6 @@ def var_operation(
return wrapper
def unionize(*args: Type) -> Type:
"""Unionize the types.
Args:
args: The types to unionize.
Returns:
The unionized types.
"""
if not args:
return Any
if len(args) == 1:
return args[0]
# We are bisecting the args list here to avoid hitting the recursion limit
# In Python versions >= 3.11, we can simply do `return Union[*args]`
midpoint = len(args) // 2
first_half, second_half = args[:midpoint], args[midpoint:]
return Union[unionize(*first_half), unionize(*second_half)]
def figure_out_type(value: Any) -> types.GenericType:
"""Figure out the type of the value.