From 2adb221a6c66c256cdea19114cb39f008f6124db Mon Sep 17 00:00:00 2001 From: Lendemor Date: Thu, 19 Sep 2024 15:34:52 -0700 Subject: [PATCH] expose rx.get_state() to get instance of state from anywhere --- reflex/__init__.py | 1 + reflex/__init__.pyi | 1 + reflex/istate/__init__.py | 3 +++ reflex/istate/wrappers.py | 25 +++++++++++++++++++++++++ 4 files changed, 30 insertions(+) create mode 100644 reflex/istate/__init__.py create mode 100644 reflex/istate/wrappers.py diff --git a/reflex/__init__.py b/reflex/__init__.py index 63de1f386..5036bfc0b 100644 --- a/reflex/__init__.py +++ b/reflex/__init__.py @@ -328,6 +328,7 @@ _MAPPING: dict = { "ComponentState", "State", ], + "istate": ["get_state"], "style": ["Style", "toggle_color_mode"], "utils.imports": ["ImportVar"], "utils.serializers": ["serializer"], diff --git a/reflex/__init__.pyi b/reflex/__init__.pyi index ef5bcfd8f..d764e5c39 100644 --- a/reflex/__init__.pyi +++ b/reflex/__init__.pyi @@ -174,6 +174,7 @@ from .event import stop_propagation as stop_propagation from .event import upload_files as upload_files from .event import window_alert as window_alert from .experimental import _x as _x +from .istate import get_state as get_state from .middleware import Middleware as Middleware from .middleware import middleware as middleware from .model import Model as Model diff --git a/reflex/istate/__init__.py b/reflex/istate/__init__.py new file mode 100644 index 000000000..d10d841c7 --- /dev/null +++ b/reflex/istate/__init__.py @@ -0,0 +1,3 @@ +"""This module will provide interfaces for the state.""" + +from .wrappers import get_state diff --git a/reflex/istate/wrappers.py b/reflex/istate/wrappers.py new file mode 100644 index 000000000..bbc8382b0 --- /dev/null +++ b/reflex/istate/wrappers.py @@ -0,0 +1,25 @@ +"""Wrappers for the state manager.""" + +from typing import Any + +from reflex.state import _split_substate_key, _substate_key, get_state_manager + + +async def get_state(token, state_cls: Any | None = None): + """Get the instance of a state for a token. + + Args: + token: The token for the state. + state_cls: The class of the state. + + Returns: + The state instance. + """ + mng = get_state_manager() + if state_cls is not None: + root_state = await mng.get_state(_substate_key(token, state_cls)) + else: + root_state = await mng.get_state(token) + _, state_path = _split_substate_key(token) + state_cls = root_state.get_class_substate(tuple(state_path.split("."))) + return await root_state.get_state(state_cls)