use getattr when given str in getitem (#4761)

* use getattr when given str in getitem

* stronger checking and tests

* switch ordering

* use safe issubclass

* calculate origin differently
This commit is contained in:
Khaleel Al-Adhami 2025-02-06 10:09:05 -08:00 committed by Masen Furer
parent 88a44f45ec
commit 0875a3eac0
No known key found for this signature in database
GPG Key ID: B0008AD22B3B3A95
2 changed files with 32 additions and 4 deletions

View File

@ -22,7 +22,12 @@ from typing_extensions import is_typeddict
from reflex.utils import types from reflex.utils import types
from reflex.utils.exceptions import VarAttributeError from reflex.utils.exceptions import VarAttributeError
from reflex.utils.types import GenericType, get_attribute_access_type, get_origin from reflex.utils.types import (
GenericType,
get_attribute_access_type,
get_origin,
safe_issubclass,
)
from .base import ( from .base import (
CachedVarOperation, CachedVarOperation,
@ -187,10 +192,14 @@ class ObjectVar(Var[OBJECT_TYPE], python_types=Mapping):
Returns: Returns:
The item from the object. The item from the object.
""" """
from .sequence import LiteralStringVar
if not isinstance(key, (StringVar, str, int, NumberVar)) or ( if not isinstance(key, (StringVar, str, int, NumberVar)) or (
isinstance(key, NumberVar) and key._is_strict_float() isinstance(key, NumberVar) and key._is_strict_float()
): ):
raise_unsupported_operand_types("[]", (type(self), type(key))) raise_unsupported_operand_types("[]", (type(self), type(key)))
if isinstance(key, str) and isinstance(Var.create(key), LiteralStringVar):
return self.__getattr__(key)
return ObjectItemOperation.create(self, key).guess_type() return ObjectItemOperation.create(self, key).guess_type()
# NoReturn is used here to catch when key value is Any # NoReturn is used here to catch when key value is Any
@ -260,12 +269,12 @@ class ObjectVar(Var[OBJECT_TYPE], python_types=Mapping):
if types.is_optional(var_type): if types.is_optional(var_type):
var_type = get_args(var_type)[0] var_type = get_args(var_type)[0]
fixed_type = var_type if isclass(var_type) else get_origin(var_type) fixed_type = get_origin(var_type) or var_type
if ( if (
(isclass(fixed_type) and not issubclass(fixed_type, Mapping)) is_typeddict(fixed_type)
or (isclass(fixed_type) and not safe_issubclass(fixed_type, Mapping))
or (fixed_type in types.UnionTypes) or (fixed_type in types.UnionTypes)
or is_typeddict(fixed_type)
): ):
attribute_type = get_attribute_access_type(var_type, name) attribute_type = get_attribute_access_type(var_type, name)
if attribute_type is None: if attribute_type is None:

View File

@ -10,6 +10,8 @@ from reflex.testing import AppHarness
def VarOperations(): def VarOperations():
"""App with var operations.""" """App with var operations."""
from typing import TypedDict
import reflex as rx import reflex as rx
from reflex.vars.base import LiteralVar from reflex.vars.base import LiteralVar
from reflex.vars.sequence import ArrayVar from reflex.vars.sequence import ArrayVar
@ -17,6 +19,10 @@ def VarOperations():
class Object(rx.Base): class Object(rx.Base):
name: str = "hello" name: str = "hello"
class Person(TypedDict):
name: str
age: int
class VarOperationState(rx.State): class VarOperationState(rx.State):
int_var1: rx.Field[int] = rx.field(10) int_var1: rx.Field[int] = rx.field(10)
int_var2: rx.Field[int] = rx.field(5) int_var2: rx.Field[int] = rx.field(5)
@ -34,6 +40,9 @@ def VarOperations():
dict1: rx.Field[dict[int, int]] = rx.field({1: 2}) dict1: rx.Field[dict[int, int]] = rx.field({1: 2})
dict2: rx.Field[dict[int, int]] = rx.field({3: 4}) dict2: rx.Field[dict[int, int]] = rx.field({3: 4})
html_str: rx.Field[str] = rx.field("<div>hello</div>") html_str: rx.Field[str] = rx.field("<div>hello</div>")
people: rx.Field[list[Person]] = rx.field(
[{"name": "Alice", "age": 30}, {"name": "Bob", "age": 25}]
)
app = rx.App(_state=rx.State) app = rx.App(_state=rx.State)
@ -619,6 +628,15 @@ def VarOperations():
), ),
id="dict_in_foreach3", id="dict_in_foreach3",
), ),
rx.box(
rx.foreach(
VarOperationState.people,
lambda person: rx.text.span(
"Hello " + person["name"], person["age"] + 3
),
),
id="typed_dict_in_foreach",
),
) )
@ -826,6 +844,7 @@ def test_var_operations(driver, var_operations: AppHarness):
("dict_in_foreach1", "a1b2"), ("dict_in_foreach1", "a1b2"),
("dict_in_foreach2", "12"), ("dict_in_foreach2", "12"),
("dict_in_foreach3", "1234"), ("dict_in_foreach3", "1234"),
("typed_dict_in_foreach", "Hello Alice33Hello Bob28"),
] ]
for tag, expected in tests: for tag, expected in tests: