unionize base var fields types (#4153)
* unionize base var fields types * add tests * fix union types for vars (#4152) * remove 3.11 special casing * special case on version * fix old versions of python --------- Co-authored-by: Masen Furer <m_github@0x26.net>
This commit is contained in:
parent
0889276e24
commit
b1d449897a
@ -182,6 +182,26 @@ def is_generic_alias(cls: GenericType) -> bool:
|
||||
return isinstance(cls, GenericAliasTypes)
|
||||
|
||||
|
||||
def unionize(*args: GenericType) -> Type:
|
||||
"""Unionize the types.
|
||||
|
||||
Args:
|
||||
args: The types to unionize.
|
||||
|
||||
Returns:
|
||||
The unionized types.
|
||||
"""
|
||||
if not args:
|
||||
return Any
|
||||
if len(args) == 1:
|
||||
return args[0]
|
||||
# We are bisecting the args list here to avoid hitting the recursion limit
|
||||
# In Python versions >= 3.11, we can simply do `return Union[*args]`
|
||||
midpoint = len(args) // 2
|
||||
first_half, second_half = args[:midpoint], args[midpoint:]
|
||||
return Union[unionize(*first_half), unionize(*second_half)]
|
||||
|
||||
|
||||
def is_none(cls: GenericType) -> bool:
|
||||
"""Check if a class is None.
|
||||
|
||||
@ -358,11 +378,9 @@ def get_attribute_access_type(cls: GenericType, name: str) -> GenericType | None
|
||||
return type_
|
||||
elif is_union(cls):
|
||||
# Check in each arg of the annotation.
|
||||
for arg in get_args(cls):
|
||||
type_ = get_attribute_access_type(arg, name)
|
||||
if type_ is not None:
|
||||
# Return the first attribute type that is accessible.
|
||||
return type_
|
||||
return unionize(
|
||||
*(get_attribute_access_type(arg, name) for arg in get_args(cls))
|
||||
)
|
||||
elif isinstance(cls, type):
|
||||
# Bare class
|
||||
if sys.version_info >= (3, 10):
|
||||
|
@ -56,7 +56,7 @@ from reflex.utils.imports import (
|
||||
ParsedImportDict,
|
||||
parse_imports,
|
||||
)
|
||||
from reflex.utils.types import GenericType, Self, get_origin, has_args
|
||||
from reflex.utils.types import GenericType, Self, get_origin, has_args, unionize
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from reflex.state import BaseState
|
||||
@ -1237,26 +1237,6 @@ def var_operation(
|
||||
return wrapper
|
||||
|
||||
|
||||
def unionize(*args: Type) -> Type:
|
||||
"""Unionize the types.
|
||||
|
||||
Args:
|
||||
args: The types to unionize.
|
||||
|
||||
Returns:
|
||||
The unionized types.
|
||||
"""
|
||||
if not args:
|
||||
return Any
|
||||
if len(args) == 1:
|
||||
return args[0]
|
||||
# We are bisecting the args list here to avoid hitting the recursion limit
|
||||
# In Python versions >= 3.11, we can simply do `return Union[*args]`
|
||||
midpoint = len(args) // 2
|
||||
first_half, second_half = args[:midpoint], args[midpoint:]
|
||||
return Union[unionize(*first_half), unionize(*second_half)]
|
||||
|
||||
|
||||
def figure_out_type(value: Any) -> types.GenericType:
|
||||
"""Figure out the type of the value.
|
||||
|
||||
|
@ -262,7 +262,9 @@ class ObjectVar(Var[OBJECT_TYPE]):
|
||||
var_type = get_args(var_type)[0]
|
||||
|
||||
fixed_type = var_type if isclass(var_type) else get_origin(var_type)
|
||||
if isclass(fixed_type) and not issubclass(fixed_type, dict):
|
||||
if (isclass(fixed_type) and not issubclass(fixed_type, dict)) or (
|
||||
fixed_type in types.UnionTypes
|
||||
):
|
||||
attribute_type = get_attribute_access_type(var_type, name)
|
||||
if attribute_type is None:
|
||||
raise VarAttributeError(
|
||||
|
@ -1,5 +1,6 @@
|
||||
import json
|
||||
import math
|
||||
import sys
|
||||
import typing
|
||||
from typing import Dict, List, Optional, Set, Tuple, Union, cast
|
||||
|
||||
@ -398,6 +399,44 @@ def test_list_tuple_contains(var, expected):
|
||||
assert str(var.contains(other_var)) == f"{expected}.includes(other)"
|
||||
|
||||
|
||||
class Foo(rx.Base):
|
||||
"""Foo class."""
|
||||
|
||||
bar: int
|
||||
baz: str
|
||||
|
||||
|
||||
class Bar(rx.Base):
|
||||
"""Bar class."""
|
||||
|
||||
bar: str
|
||||
baz: str
|
||||
foo: int
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
("var", "var_type"),
|
||||
(
|
||||
[
|
||||
(Var(_js_expr="", _var_type=Foo | Bar).guess_type(), Foo | Bar),
|
||||
(Var(_js_expr="", _var_type=Foo | Bar).guess_type().bar, Union[int, str]),
|
||||
]
|
||||
if sys.version_info >= (3, 10)
|
||||
else []
|
||||
)
|
||||
+ [
|
||||
(Var(_js_expr="", _var_type=Union[Foo, Bar]).guess_type(), Union[Foo, Bar]),
|
||||
(Var(_js_expr="", _var_type=Union[Foo, Bar]).guess_type().baz, str),
|
||||
(
|
||||
Var(_js_expr="", _var_type=Union[Foo, Bar]).guess_type().foo,
|
||||
Union[int, None],
|
||||
),
|
||||
],
|
||||
)
|
||||
def test_var_types(var, var_type):
|
||||
assert var._var_type == var_type
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"var, expected",
|
||||
[
|
||||
|
Loading…
Reference in New Issue
Block a user