expose rx.get_state() to get instance of state from anywhere

This commit is contained in:
Lendemor 2024-09-19 15:34:52 -07:00
parent c46d1d9c7e
commit 2adb221a6c
4 changed files with 30 additions and 0 deletions

View File

@ -328,6 +328,7 @@ _MAPPING: dict = {
"ComponentState", "ComponentState",
"State", "State",
], ],
"istate": ["get_state"],
"style": ["Style", "toggle_color_mode"], "style": ["Style", "toggle_color_mode"],
"utils.imports": ["ImportVar"], "utils.imports": ["ImportVar"],
"utils.serializers": ["serializer"], "utils.serializers": ["serializer"],

View File

@ -174,6 +174,7 @@ from .event import stop_propagation as stop_propagation
from .event import upload_files as upload_files from .event import upload_files as upload_files
from .event import window_alert as window_alert from .event import window_alert as window_alert
from .experimental import _x as _x from .experimental import _x as _x
from .istate import get_state as get_state
from .middleware import Middleware as Middleware from .middleware import Middleware as Middleware
from .middleware import middleware as middleware from .middleware import middleware as middleware
from .model import Model as Model from .model import Model as Model

View File

@ -0,0 +1,3 @@
"""This module will provide interfaces for the state."""
from .wrappers import get_state

25
reflex/istate/wrappers.py Normal file
View File

@ -0,0 +1,25 @@
"""Wrappers for the state manager."""
from typing import Any
from reflex.state import _split_substate_key, _substate_key, get_state_manager
async def get_state(token, state_cls: Any | None = None):
"""Get the instance of a state for a token.
Args:
token: The token for the state.
state_cls: The class of the state.
Returns:
The state instance.
"""
mng = get_state_manager()
if state_cls is not None:
root_state = await mng.get_state(_substate_key(token, state_cls))
else:
root_state = await mng.get_state(token)
_, state_path = _split_substate_key(token)
state_cls = root_state.get_class_substate(tuple(state_path.split(".")))
return await root_state.get_state(state_cls)