From a52479290a29309fed7b91d90af302c54391f15e Mon Sep 17 00:00:00 2001 From: Masen Furer Date: Wed, 18 Dec 2024 17:07:31 -0800 Subject: [PATCH] Use Var[VAR_TYPE] annotation to take advantage of generics This requires rx.Field to pass typing where used. --- reflex/state.py | 5 ++++- tests/units/test_state.py | 12 ++++++------ 2 files changed, 10 insertions(+), 7 deletions(-) diff --git a/reflex/state.py b/reflex/state.py index bd656b388..714674e65 100644 --- a/reflex/state.py +++ b/reflex/state.py @@ -144,6 +144,9 @@ HANDLED_PICKLE_ERRORS = ( ValueError, ) +# For BaseState.get_var_value +VAR_TYPE = TypeVar("VAR_TYPE", bound=Var) + def _no_chain_background_task( state_cls: Type["BaseState"], name: str, fn: Callable @@ -1597,7 +1600,7 @@ class BaseState(Base, ABC, extra=pydantic.Extra.allow): # Slow case - fetch missing parent states from redis. return await self._get_state_from_redis(state_cls) - async def get_var_value(self, var: Var) -> Any: + async def get_var_value(self, var: Var[VAR_TYPE]) -> VAR_TYPE: """Get the value of an rx.Var from another state. Args: diff --git a/tests/units/test_state.py b/tests/units/test_state.py index 2176e828f..c1780b4f0 100644 --- a/tests/units/test_state.py +++ b/tests/units/test_state.py @@ -116,7 +116,7 @@ class TestState(BaseState): # Set this class as not test one __test__ = False - num1: int + num1: rx.Field[int] num2: float = 3.14 key: str map_key: str = "a" @@ -164,7 +164,7 @@ class ChildState(TestState): """A child state fixture.""" value: str - count: int = 23 + count: rx.Field[int] = rx.field(23) def change_both(self, value: str, count: int): """Change both the value and count. @@ -1664,7 +1664,7 @@ async def state_manager(request) -> AsyncGenerator[StateManager, None]: @pytest.fixture() -def substate_token(state_manager, token): +def substate_token(state_manager, token) -> str: """A token + substate name for looking up in state manager. Args: @@ -3768,14 +3768,14 @@ async def test_upcast_event_handler_arg(handler, payload): @pytest.mark.asyncio -async def test_get_var_value(state_manager, token): +async def test_get_var_value(state_manager: StateManager, substate_token: str): """Test that get_var_value works correctly. Args: state_manager: The state manager to use. - token: A token. + substate_token: Token for the substate used by state_manager. """ - state = await state_manager.get_state(_substate_key(token, TestState)) + state = await state_manager.get_state(substate_token) # State Var from same state assert await state.get_var_value(TestState.num1) == 0