Use Var[VAR_TYPE] annotation to take advantage of generics

This requires rx.Field to pass typing where used.
This commit is contained in:
Masen Furer 2024-12-18 17:07:31 -08:00
parent 592db8cdca
commit a52479290a
No known key found for this signature in database
GPG Key ID: 2AE2BD5531FF94F4
2 changed files with 10 additions and 7 deletions

View File

@ -144,6 +144,9 @@ HANDLED_PICKLE_ERRORS = (
ValueError,
)
# For BaseState.get_var_value
VAR_TYPE = TypeVar("VAR_TYPE", bound=Var)
def _no_chain_background_task(
state_cls: Type["BaseState"], name: str, fn: Callable
@ -1597,7 +1600,7 @@ class BaseState(Base, ABC, extra=pydantic.Extra.allow):
# Slow case - fetch missing parent states from redis.
return await self._get_state_from_redis(state_cls)
async def get_var_value(self, var: Var) -> Any:
async def get_var_value(self, var: Var[VAR_TYPE]) -> VAR_TYPE:
"""Get the value of an rx.Var from another state.
Args:

View File

@ -116,7 +116,7 @@ class TestState(BaseState):
# Set this class as not test one
__test__ = False
num1: int
num1: rx.Field[int]
num2: float = 3.14
key: str
map_key: str = "a"
@ -164,7 +164,7 @@ class ChildState(TestState):
"""A child state fixture."""
value: str
count: int = 23
count: rx.Field[int] = rx.field(23)
def change_both(self, value: str, count: int):
"""Change both the value and count.
@ -1664,7 +1664,7 @@ async def state_manager(request) -> AsyncGenerator[StateManager, None]:
@pytest.fixture()
def substate_token(state_manager, token):
def substate_token(state_manager, token) -> str:
"""A token + substate name for looking up in state manager.
Args:
@ -3768,14 +3768,14 @@ async def test_upcast_event_handler_arg(handler, payload):
@pytest.mark.asyncio
async def test_get_var_value(state_manager, token):
async def test_get_var_value(state_manager: StateManager, substate_token: str):
"""Test that get_var_value works correctly.
Args:
state_manager: The state manager to use.
token: A token.
substate_token: Token for the substate used by state_manager.
"""
state = await state_manager.get_state(_substate_key(token, TestState))
state = await state_manager.get_state(substate_token)
# State Var from same state
assert await state.get_var_value(TestState.num1) == 0