[ENG-4098] Deconfuse key/value of State.get_value / dict / get_delta (#4371)

Because of some dodgy logic in Base.get_value and State.dict / State.get_delta
when the value of some state var X happened to be the name of another var in
the state Y, then the value for X would be returned as the value of Y.

wat.

Fixes #4369
This commit is contained in:
Masen Furer 2024-11-12 13:24:06 -08:00 committed by GitHub
parent 2b7ef0dccc
commit 5d88263cd8
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
3 changed files with 40 additions and 10 deletions

View File

@ -130,8 +130,8 @@ class Base(BaseModel): # pyright: ignore [reportUnboundVariable]
Returns: Returns:
The value of the field. The value of the field.
""" """
if isinstance(key, str) and key in self.__fields__: if isinstance(key, str):
# Seems like this function signature was wrong all along? # Seems like this function signature was wrong all along?
# If the user wants a field that we know of, get it and pass it off to _get_value # If the user wants a field that we know of, get it and pass it off to _get_value
key = getattr(self, key) return getattr(self, key, key)
return key return key

View File

@ -1890,7 +1890,7 @@ class BaseState(Base, ABC, extra=pydantic.Extra.allow):
) )
subdelta: Dict[str, Any] = { subdelta: Dict[str, Any] = {
prop: self.get_value(getattr(self, prop)) prop: self.get_value(prop)
for prop in delta_vars for prop in delta_vars
if not types.is_backend_base_variable(prop, type(self)) if not types.is_backend_base_variable(prop, type(self))
} }
@ -1982,9 +1982,10 @@ class BaseState(Base, ABC, extra=pydantic.Extra.allow):
Returns: Returns:
The value of the field. The value of the field.
""" """
if isinstance(key, MutableProxy): value = super().get_value(key)
return super().get_value(key.__wrapped__) if isinstance(value, MutableProxy):
return super().get_value(key) return value.__wrapped__
return value
def dict( def dict(
self, include_computed: bool = True, initial: bool = False, **kwargs self, include_computed: bool = True, initial: bool = False, **kwargs
@ -2006,8 +2007,7 @@ class BaseState(Base, ABC, extra=pydantic.Extra.allow):
self._mark_dirty() self._mark_dirty()
base_vars = { base_vars = {
prop_name: self.get_value(getattr(self, prop_name)) prop_name: self.get_value(prop_name) for prop_name in self.base_vars
for prop_name in self.base_vars
} }
if initial and include_computed: if initial and include_computed:
computed_vars = { computed_vars = {
@ -2016,7 +2016,7 @@ class BaseState(Base, ABC, extra=pydantic.Extra.allow):
cv._initial_value cv._initial_value
if is_computed_var(cv) if is_computed_var(cv)
and not isinstance(cv._initial_value, types.Unset) and not isinstance(cv._initial_value, types.Unset)
else self.get_value(getattr(self, prop_name)) else self.get_value(prop_name)
) )
for prop_name, cv in self.computed_vars.items() for prop_name, cv in self.computed_vars.items()
if not cv._backend if not cv._backend
@ -2024,7 +2024,7 @@ class BaseState(Base, ABC, extra=pydantic.Extra.allow):
elif include_computed: elif include_computed:
computed_vars = { computed_vars = {
# Include the computed vars. # Include the computed vars.
prop_name: self.get_value(getattr(self, prop_name)) prop_name: self.get_value(prop_name)
for prop_name, cv in self.computed_vars.items() for prop_name, cv in self.computed_vars.items()
if not cv._backend if not cv._backend
} }

View File

@ -3411,3 +3411,33 @@ def test_typed_state() -> None:
field: rx.Field[str] = rx.field("") field: rx.Field[str] = rx.field("")
_ = TypedState(field="str") _ = TypedState(field="str")
def test_get_value():
class GetValueState(rx.State):
foo: str = "FOO"
bar: str = "BAR"
state = GetValueState()
assert state.dict() == {
state.get_full_name(): {
"foo": "FOO",
"bar": "BAR",
}
}
assert state.get_delta() == {}
state.bar = "foo"
assert state.dict() == {
state.get_full_name(): {
"foo": "FOO",
"bar": "foo",
}
}
assert state.get_delta() == {
state.get_full_name(): {
"bar": "foo",
}
}