unionize base var fields types (#4153)

* unionize base var fields types

* add tests

* fix union types for vars (#4152)

* remove 3.11 special casing

* special case on version

* fix old versions of python

---------

Co-authored-by: Masen Furer <m_github@0x26.net>
This commit is contained in:
Khaleel Al-Adhami 2024-10-11 17:27:15 -07:00 committed by GitHub
parent 0889276e24
commit b1d449897a
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
4 changed files with 66 additions and 27 deletions

View File

@ -182,6 +182,26 @@ def is_generic_alias(cls: GenericType) -> bool:
return isinstance(cls, GenericAliasTypes)
def unionize(*args: GenericType) -> Type:
"""Unionize the types.
Args:
args: The types to unionize.
Returns:
The unionized types.
"""
if not args:
return Any
if len(args) == 1:
return args[0]
# We are bisecting the args list here to avoid hitting the recursion limit
# In Python versions >= 3.11, we can simply do `return Union[*args]`
midpoint = len(args) // 2
first_half, second_half = args[:midpoint], args[midpoint:]
return Union[unionize(*first_half), unionize(*second_half)]
def is_none(cls: GenericType) -> bool:
"""Check if a class is None.
@ -358,11 +378,9 @@ def get_attribute_access_type(cls: GenericType, name: str) -> GenericType | None
return type_
elif is_union(cls):
# Check in each arg of the annotation.
for arg in get_args(cls):
type_ = get_attribute_access_type(arg, name)
if type_ is not None:
# Return the first attribute type that is accessible.
return type_
return unionize(
*(get_attribute_access_type(arg, name) for arg in get_args(cls))
)
elif isinstance(cls, type):
# Bare class
if sys.version_info >= (3, 10):

View File

@ -56,7 +56,7 @@ from reflex.utils.imports import (
ParsedImportDict,
parse_imports,
)
from reflex.utils.types import GenericType, Self, get_origin, has_args
from reflex.utils.types import GenericType, Self, get_origin, has_args, unionize
if TYPE_CHECKING:
from reflex.state import BaseState
@ -1237,26 +1237,6 @@ def var_operation(
return wrapper
def unionize(*args: Type) -> Type:
"""Unionize the types.
Args:
args: The types to unionize.
Returns:
The unionized types.
"""
if not args:
return Any
if len(args) == 1:
return args[0]
# We are bisecting the args list here to avoid hitting the recursion limit
# In Python versions >= 3.11, we can simply do `return Union[*args]`
midpoint = len(args) // 2
first_half, second_half = args[:midpoint], args[midpoint:]
return Union[unionize(*first_half), unionize(*second_half)]
def figure_out_type(value: Any) -> types.GenericType:
"""Figure out the type of the value.

View File

@ -262,7 +262,9 @@ class ObjectVar(Var[OBJECT_TYPE]):
var_type = get_args(var_type)[0]
fixed_type = var_type if isclass(var_type) else get_origin(var_type)
if isclass(fixed_type) and not issubclass(fixed_type, dict):
if (isclass(fixed_type) and not issubclass(fixed_type, dict)) or (
fixed_type in types.UnionTypes
):
attribute_type = get_attribute_access_type(var_type, name)
if attribute_type is None:
raise VarAttributeError(

View File

@ -1,5 +1,6 @@
import json
import math
import sys
import typing
from typing import Dict, List, Optional, Set, Tuple, Union, cast
@ -398,6 +399,44 @@ def test_list_tuple_contains(var, expected):
assert str(var.contains(other_var)) == f"{expected}.includes(other)"
class Foo(rx.Base):
"""Foo class."""
bar: int
baz: str
class Bar(rx.Base):
"""Bar class."""
bar: str
baz: str
foo: int
@pytest.mark.parametrize(
("var", "var_type"),
(
[
(Var(_js_expr="", _var_type=Foo | Bar).guess_type(), Foo | Bar),
(Var(_js_expr="", _var_type=Foo | Bar).guess_type().bar, Union[int, str]),
]
if sys.version_info >= (3, 10)
else []
)
+ [
(Var(_js_expr="", _var_type=Union[Foo, Bar]).guess_type(), Union[Foo, Bar]),
(Var(_js_expr="", _var_type=Union[Foo, Bar]).guess_type().baz, str),
(
Var(_js_expr="", _var_type=Union[Foo, Bar]).guess_type().foo,
Union[int, None],
),
],
)
def test_var_types(var, var_type):
assert var._var_type == var_type
@pytest.mark.parametrize(
"var, expected",
[