diff --git a/reflex/utils/types.py b/reflex/utils/types.py index 3f3b05258..d2de4a156 100644 --- a/reflex/utils/types.py +++ b/reflex/utils/types.py @@ -182,6 +182,26 @@ def is_generic_alias(cls: GenericType) -> bool: return isinstance(cls, GenericAliasTypes) +def unionize(*args: GenericType) -> Type: + """Unionize the types. + + Args: + args: The types to unionize. + + Returns: + The unionized types. + """ + if not args: + return Any + if len(args) == 1: + return args[0] + # We are bisecting the args list here to avoid hitting the recursion limit + # In Python versions >= 3.11, we can simply do `return Union[*args]` + midpoint = len(args) // 2 + first_half, second_half = args[:midpoint], args[midpoint:] + return Union[unionize(*first_half), unionize(*second_half)] + + def is_none(cls: GenericType) -> bool: """Check if a class is None. @@ -358,11 +378,9 @@ def get_attribute_access_type(cls: GenericType, name: str) -> GenericType | None return type_ elif is_union(cls): # Check in each arg of the annotation. - for arg in get_args(cls): - type_ = get_attribute_access_type(arg, name) - if type_ is not None: - # Return the first attribute type that is accessible. - return type_ + return unionize( + *(get_attribute_access_type(arg, name) for arg in get_args(cls)) + ) elif isinstance(cls, type): # Bare class if sys.version_info >= (3, 10): diff --git a/reflex/vars/base.py b/reflex/vars/base.py index df9a0e122..f62f8513a 100644 --- a/reflex/vars/base.py +++ b/reflex/vars/base.py @@ -56,7 +56,7 @@ from reflex.utils.imports import ( ParsedImportDict, parse_imports, ) -from reflex.utils.types import GenericType, Self, get_origin, has_args +from reflex.utils.types import GenericType, Self, get_origin, has_args, unionize if TYPE_CHECKING: from reflex.state import BaseState @@ -1237,26 +1237,6 @@ def var_operation( return wrapper -def unionize(*args: Type) -> Type: - """Unionize the types. - - Args: - args: The types to unionize. - - Returns: - The unionized types. - """ - if not args: - return Any - if len(args) == 1: - return args[0] - # We are bisecting the args list here to avoid hitting the recursion limit - # In Python versions >= 3.11, we can simply do `return Union[*args]` - midpoint = len(args) // 2 - first_half, second_half = args[:midpoint], args[midpoint:] - return Union[unionize(*first_half), unionize(*second_half)] - - def figure_out_type(value: Any) -> types.GenericType: """Figure out the type of the value. diff --git a/reflex/vars/object.py b/reflex/vars/object.py index a9175a703..38add7779 100644 --- a/reflex/vars/object.py +++ b/reflex/vars/object.py @@ -262,7 +262,9 @@ class ObjectVar(Var[OBJECT_TYPE]): var_type = get_args(var_type)[0] fixed_type = var_type if isclass(var_type) else get_origin(var_type) - if isclass(fixed_type) and not issubclass(fixed_type, dict): + if (isclass(fixed_type) and not issubclass(fixed_type, dict)) or ( + fixed_type in types.UnionTypes + ): attribute_type = get_attribute_access_type(var_type, name) if attribute_type is None: raise VarAttributeError( diff --git a/tests/units/test_var.py b/tests/units/test_var.py index c02acefe6..c04e554a9 100644 --- a/tests/units/test_var.py +++ b/tests/units/test_var.py @@ -1,5 +1,6 @@ import json import math +import sys import typing from typing import Dict, List, Optional, Set, Tuple, Union, cast @@ -398,6 +399,44 @@ def test_list_tuple_contains(var, expected): assert str(var.contains(other_var)) == f"{expected}.includes(other)" +class Foo(rx.Base): + """Foo class.""" + + bar: int + baz: str + + +class Bar(rx.Base): + """Bar class.""" + + bar: str + baz: str + foo: int + + +@pytest.mark.parametrize( + ("var", "var_type"), + ( + [ + (Var(_js_expr="", _var_type=Foo | Bar).guess_type(), Foo | Bar), + (Var(_js_expr="", _var_type=Foo | Bar).guess_type().bar, Union[int, str]), + ] + if sys.version_info >= (3, 10) + else [] + ) + + [ + (Var(_js_expr="", _var_type=Union[Foo, Bar]).guess_type(), Union[Foo, Bar]), + (Var(_js_expr="", _var_type=Union[Foo, Bar]).guess_type().baz, str), + ( + Var(_js_expr="", _var_type=Union[Foo, Bar]).guess_type().foo, + Union[int, None], + ), + ], +) +def test_var_types(var, var_type): + assert var._var_type == var_type + + @pytest.mark.parametrize( "var, expected", [