unionize base var fields types (#4153)

* unionize base var fields types

* add tests

* fix union types for vars (#4152)

* remove 3.11 special casing

* special case on version

* fix old versions of python

---------

Co-authored-by: Masen Furer <m_github@0x26.net>
This commit is contained in:
Khaleel Al-Adhami 2024-10-11 17:27:15 -07:00 committed by GitHub
parent 0889276e24
commit b1d449897a
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
4 changed files with 66 additions and 27 deletions

View File

@ -182,6 +182,26 @@ def is_generic_alias(cls: GenericType) -> bool:
return isinstance(cls, GenericAliasTypes) return isinstance(cls, GenericAliasTypes)
def unionize(*args: GenericType) -> Type:
"""Unionize the types.
Args:
args: The types to unionize.
Returns:
The unionized types.
"""
if not args:
return Any
if len(args) == 1:
return args[0]
# We are bisecting the args list here to avoid hitting the recursion limit
# In Python versions >= 3.11, we can simply do `return Union[*args]`
midpoint = len(args) // 2
first_half, second_half = args[:midpoint], args[midpoint:]
return Union[unionize(*first_half), unionize(*second_half)]
def is_none(cls: GenericType) -> bool: def is_none(cls: GenericType) -> bool:
"""Check if a class is None. """Check if a class is None.
@ -358,11 +378,9 @@ def get_attribute_access_type(cls: GenericType, name: str) -> GenericType | None
return type_ return type_
elif is_union(cls): elif is_union(cls):
# Check in each arg of the annotation. # Check in each arg of the annotation.
for arg in get_args(cls): return unionize(
type_ = get_attribute_access_type(arg, name) *(get_attribute_access_type(arg, name) for arg in get_args(cls))
if type_ is not None: )
# Return the first attribute type that is accessible.
return type_
elif isinstance(cls, type): elif isinstance(cls, type):
# Bare class # Bare class
if sys.version_info >= (3, 10): if sys.version_info >= (3, 10):

View File

@ -56,7 +56,7 @@ from reflex.utils.imports import (
ParsedImportDict, ParsedImportDict,
parse_imports, parse_imports,
) )
from reflex.utils.types import GenericType, Self, get_origin, has_args from reflex.utils.types import GenericType, Self, get_origin, has_args, unionize
if TYPE_CHECKING: if TYPE_CHECKING:
from reflex.state import BaseState from reflex.state import BaseState
@ -1237,26 +1237,6 @@ def var_operation(
return wrapper return wrapper
def unionize(*args: Type) -> Type:
"""Unionize the types.
Args:
args: The types to unionize.
Returns:
The unionized types.
"""
if not args:
return Any
if len(args) == 1:
return args[0]
# We are bisecting the args list here to avoid hitting the recursion limit
# In Python versions >= 3.11, we can simply do `return Union[*args]`
midpoint = len(args) // 2
first_half, second_half = args[:midpoint], args[midpoint:]
return Union[unionize(*first_half), unionize(*second_half)]
def figure_out_type(value: Any) -> types.GenericType: def figure_out_type(value: Any) -> types.GenericType:
"""Figure out the type of the value. """Figure out the type of the value.

View File

@ -262,7 +262,9 @@ class ObjectVar(Var[OBJECT_TYPE]):
var_type = get_args(var_type)[0] var_type = get_args(var_type)[0]
fixed_type = var_type if isclass(var_type) else get_origin(var_type) fixed_type = var_type if isclass(var_type) else get_origin(var_type)
if isclass(fixed_type) and not issubclass(fixed_type, dict): if (isclass(fixed_type) and not issubclass(fixed_type, dict)) or (
fixed_type in types.UnionTypes
):
attribute_type = get_attribute_access_type(var_type, name) attribute_type = get_attribute_access_type(var_type, name)
if attribute_type is None: if attribute_type is None:
raise VarAttributeError( raise VarAttributeError(

View File

@ -1,5 +1,6 @@
import json import json
import math import math
import sys
import typing import typing
from typing import Dict, List, Optional, Set, Tuple, Union, cast from typing import Dict, List, Optional, Set, Tuple, Union, cast
@ -398,6 +399,44 @@ def test_list_tuple_contains(var, expected):
assert str(var.contains(other_var)) == f"{expected}.includes(other)" assert str(var.contains(other_var)) == f"{expected}.includes(other)"
class Foo(rx.Base):
"""Foo class."""
bar: int
baz: str
class Bar(rx.Base):
"""Bar class."""
bar: str
baz: str
foo: int
@pytest.mark.parametrize(
("var", "var_type"),
(
[
(Var(_js_expr="", _var_type=Foo | Bar).guess_type(), Foo | Bar),
(Var(_js_expr="", _var_type=Foo | Bar).guess_type().bar, Union[int, str]),
]
if sys.version_info >= (3, 10)
else []
)
+ [
(Var(_js_expr="", _var_type=Union[Foo, Bar]).guess_type(), Union[Foo, Bar]),
(Var(_js_expr="", _var_type=Union[Foo, Bar]).guess_type().baz, str),
(
Var(_js_expr="", _var_type=Union[Foo, Bar]).guess_type().foo,
Union[int, None],
),
],
)
def test_var_types(var, var_type):
assert var._var_type == var_type
@pytest.mark.parametrize( @pytest.mark.parametrize(
"var, expected", "var, expected",
[ [