diff --git a/reflex/utils/types.py b/reflex/utils/types.py index 6bedf5b61..dec0f4eaf 100644 --- a/reflex/utils/types.py +++ b/reflex/utils/types.py @@ -182,6 +182,26 @@ def is_generic_alias(cls: GenericType) -> bool: return isinstance(cls, GenericAliasTypes) +def unionize(*args: GenericType) -> Type: + """Unionize the types. + + Args: + args: The types to unionize. + + Returns: + The unionized types. + """ + if not args: + return Any + if len(args) == 1: + return args[0] + # We are bisecting the args list here to avoid hitting the recursion limit + # In Python versions >= 3.11, we can simply do `return Union[*args]` + midpoint = len(args) // 2 + first_half, second_half = args[:midpoint], args[midpoint:] + return Union[unionize(*first_half), unionize(*second_half)] + + def is_none(cls: GenericType) -> bool: """Check if a class is None. @@ -337,11 +357,9 @@ def get_attribute_access_type(cls: GenericType, name: str) -> GenericType | None return type_ elif is_union(cls): # Check in each arg of the annotation. - for arg in get_args(cls): - type_ = get_attribute_access_type(arg, name) - if type_ is not None: - # Return the first attribute type that is accessible. - return type_ + return unionize( + *(get_attribute_access_type(arg, name) for arg in get_args(cls)) + ) elif isinstance(cls, type): # Bare class if sys.version_info >= (3, 10): diff --git a/reflex/vars/base.py b/reflex/vars/base.py index 58a73e025..9508d87e5 100644 --- a/reflex/vars/base.py +++ b/reflex/vars/base.py @@ -56,7 +56,7 @@ from reflex.utils.imports import ( ParsedImportDict, parse_imports, ) -from reflex.utils.types import GenericType, Self, get_origin +from reflex.utils.types import GenericType, Self, get_origin, unionize if TYPE_CHECKING: from reflex.state import BaseState @@ -1237,26 +1237,6 @@ def var_operation( return wrapper -def unionize(*args: Type) -> Type: - """Unionize the types. - - Args: - args: The types to unionize. - - Returns: - The unionized types. - """ - if not args: - return Any - if len(args) == 1: - return args[0] - # We are bisecting the args list here to avoid hitting the recursion limit - # In Python versions >= 3.11, we can simply do `return Union[*args]` - midpoint = len(args) // 2 - first_half, second_half = args[:midpoint], args[midpoint:] - return Union[unionize(*first_half), unionize(*second_half)] - - def figure_out_type(value: Any) -> types.GenericType: """Figure out the type of the value.