BaseState.get_var_value helper to get a value from a Var

When given a state Var or a LiteralVar, retrieve the actual value associated
with the Var.

For state Vars, the returned value is directly tied to the associated state and
can be modified.

Modifying LiteralVar values or ComputedVar values will have no useful effect.
This commit is contained in:
Masen Furer 2024-12-17 17:57:25 -08:00
parent d8e988105f
commit 592db8cdca
No known key found for this signature in database
GPG Key ID: 2AE2BD5531FF94F4
3 changed files with 65 additions and 0 deletions

View File

@ -107,6 +107,7 @@ from reflex.utils.exceptions import (
StateSchemaMismatchError, StateSchemaMismatchError,
StateSerializationError, StateSerializationError,
StateTooLargeError, StateTooLargeError,
UnretrievableVarValueError,
) )
from reflex.utils.exec import is_testing_env from reflex.utils.exec import is_testing_env
from reflex.utils.serializers import serializer from reflex.utils.serializers import serializer
@ -1596,6 +1597,36 @@ class BaseState(Base, ABC, extra=pydantic.Extra.allow):
# Slow case - fetch missing parent states from redis. # Slow case - fetch missing parent states from redis.
return await self._get_state_from_redis(state_cls) return await self._get_state_from_redis(state_cls)
async def get_var_value(self, var: Var) -> Any:
"""Get the value of an rx.Var from another state.
Args:
var: The var to get the value for.
Returns:
The value of the var.
Raises:
UnretrievableVarValueError: If the var does not have a literal value
or associated state.
"""
# Fast case: this is a literal var and the value is known.
if hasattr(var, "_var_value"):
return var._var_value
var_data = var._get_all_var_data()
if var_data is None or not var_data.state:
raise UnretrievableVarValueError(
f"Unable to retrieve value for {var._js_expr}: not associated with any state."
)
# Fastish case: this var belongs to this state
if var_data.state == self.get_full_name():
return getattr(self, var_data.field_name)
# Slow case: this var belongs to another state
other_state = await self.get_state(
self._get_root_state().get_class_substate(var_data.state)
)
return getattr(other_state, var_data.field_name)
def _get_event_handler( def _get_event_handler(
self, event: Event self, event: Event
) -> tuple[BaseState | StateProxy, EventHandler]: ) -> tuple[BaseState | StateProxy, EventHandler]:

View File

@ -187,3 +187,7 @@ def raise_system_package_missing_error(package: str) -> NoReturn:
class InvalidLockWarningThresholdError(ReflexError): class InvalidLockWarningThresholdError(ReflexError):
"""Raised when an invalid lock warning threshold is provided.""" """Raised when an invalid lock warning threshold is provided."""
class UnretrievableVarValueError(ReflexError):
"""Raised when the value of a var is not retrievable."""

View File

@ -60,6 +60,7 @@ from reflex.utils.exceptions import (
ReflexRuntimeError, ReflexRuntimeError,
SetUndefinedStateVarError, SetUndefinedStateVarError,
StateSerializationError, StateSerializationError,
UnretrievableVarValueError,
) )
from reflex.utils.format import json_dumps from reflex.utils.format import json_dumps
from reflex.vars.base import Var, computed_var from reflex.vars.base import Var, computed_var
@ -3764,3 +3765,32 @@ async def test_upcast_event_handler_arg(handler, payload):
state = UpcastState() state = UpcastState()
async for update in state._process_event(handler, state, payload): async for update in state._process_event(handler, state, payload):
assert update.delta == {UpcastState.get_full_name(): {"passed": True}} assert update.delta == {UpcastState.get_full_name(): {"passed": True}}
@pytest.mark.asyncio
async def test_get_var_value(state_manager, token):
"""Test that get_var_value works correctly.
Args:
state_manager: The state manager to use.
token: A token.
"""
state = await state_manager.get_state(_substate_key(token, TestState))
# State Var from same state
assert await state.get_var_value(TestState.num1) == 0
state.num1 = 42
assert await state.get_var_value(TestState.num1) == 42
# State Var from another state
child_state = await state.get_state(ChildState)
assert await state.get_var_value(ChildState.count) == 23
child_state.count = 66
assert await state.get_var_value(ChildState.count) == 66
# LiteralVar with known value
assert await state.get_var_value(rx.Var.create([1, 2, 3])) == [1, 2, 3]
# Generic Var with no state
with pytest.raises(UnretrievableVarValueError):
await state.get_var_value(rx.Var("undefined"))