minor State cleanup (#3768)
This commit is contained in:
parent
9c70971dc6
commit
2d9380a6fd
@ -55,6 +55,7 @@ from reflex.utils import console, format, prerequisites, types
|
|||||||
from reflex.utils.exceptions import ImmutableStateError, LockExpiredError
|
from reflex.utils.exceptions import ImmutableStateError, LockExpiredError
|
||||||
from reflex.utils.exec import is_testing_env
|
from reflex.utils.exec import is_testing_env
|
||||||
from reflex.utils.serializers import SerializedType, serialize, serializer
|
from reflex.utils.serializers import SerializedType, serialize, serializer
|
||||||
|
from reflex.utils.types import override
|
||||||
from reflex.vars import BaseVar, ComputedVar, Var, computed_var
|
from reflex.vars import BaseVar, ComputedVar, Var, computed_var
|
||||||
|
|
||||||
if TYPE_CHECKING:
|
if TYPE_CHECKING:
|
||||||
@ -1232,6 +1233,17 @@ class BaseState(Base, ABC, extra=pydantic.Extra.allow):
|
|||||||
parent_states_with_name.append((parent_state.get_full_name(), parent_state))
|
parent_states_with_name.append((parent_state.get_full_name(), parent_state))
|
||||||
return parent_states_with_name
|
return parent_states_with_name
|
||||||
|
|
||||||
|
def _get_root_state(self) -> BaseState:
|
||||||
|
"""Get the root state of the state tree.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
The root state of the state tree.
|
||||||
|
"""
|
||||||
|
parent_state = self
|
||||||
|
while parent_state.parent_state is not None:
|
||||||
|
parent_state = parent_state.parent_state
|
||||||
|
return parent_state
|
||||||
|
|
||||||
async def _populate_parent_states(self, target_state_cls: Type[BaseState]):
|
async def _populate_parent_states(self, target_state_cls: Type[BaseState]):
|
||||||
"""Populate substates in the tree between the target_state_cls and common ancestor of this state.
|
"""Populate substates in the tree between the target_state_cls and common ancestor of this state.
|
||||||
|
|
||||||
@ -1291,10 +1303,7 @@ class BaseState(Base, ABC, extra=pydantic.Extra.allow):
|
|||||||
Returns:
|
Returns:
|
||||||
The instance of state_cls associated with this state's client_token.
|
The instance of state_cls associated with this state's client_token.
|
||||||
"""
|
"""
|
||||||
if self.parent_state is None:
|
root_state = self._get_root_state()
|
||||||
root_state = self
|
|
||||||
else:
|
|
||||||
root_state = self._get_parent_states()[-1][1]
|
|
||||||
return root_state.get_substate(state_cls.get_full_name().split("."))
|
return root_state.get_substate(state_cls.get_full_name().split("."))
|
||||||
|
|
||||||
async def _get_state_from_redis(self, state_cls: Type[BaseState]) -> BaseState:
|
async def _get_state_from_redis(self, state_cls: Type[BaseState]) -> BaseState:
|
||||||
@ -1445,9 +1454,7 @@ class BaseState(Base, ABC, extra=pydantic.Extra.allow):
|
|||||||
The valid StateUpdate containing the events and final flag.
|
The valid StateUpdate containing the events and final flag.
|
||||||
"""
|
"""
|
||||||
# get the delta from the root of the state tree
|
# get the delta from the root of the state tree
|
||||||
state = self
|
state = self._get_root_state()
|
||||||
while state.parent_state is not None:
|
|
||||||
state = state.parent_state
|
|
||||||
|
|
||||||
token = self.router.session.client_token
|
token = self.router.session.client_token
|
||||||
|
|
||||||
@ -2368,6 +2375,7 @@ class StateManagerMemory(StateManager):
|
|||||||
"_states_locks": {"exclude": True},
|
"_states_locks": {"exclude": True},
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@override
|
||||||
async def get_state(self, token: str) -> BaseState:
|
async def get_state(self, token: str) -> BaseState:
|
||||||
"""Get the state for a token.
|
"""Get the state for a token.
|
||||||
|
|
||||||
@ -2383,6 +2391,7 @@ class StateManagerMemory(StateManager):
|
|||||||
self.states[token] = self.state(_reflex_internal_init=True)
|
self.states[token] = self.state(_reflex_internal_init=True)
|
||||||
return self.states[token]
|
return self.states[token]
|
||||||
|
|
||||||
|
@override
|
||||||
async def set_state(self, token: str, state: BaseState):
|
async def set_state(self, token: str, state: BaseState):
|
||||||
"""Set the state for a token.
|
"""Set the state for a token.
|
||||||
|
|
||||||
@ -2392,6 +2401,7 @@ class StateManagerMemory(StateManager):
|
|||||||
"""
|
"""
|
||||||
pass
|
pass
|
||||||
|
|
||||||
|
@override
|
||||||
@contextlib.asynccontextmanager
|
@contextlib.asynccontextmanager
|
||||||
async def modify_state(self, token: str) -> AsyncIterator[BaseState]:
|
async def modify_state(self, token: str) -> AsyncIterator[BaseState]:
|
||||||
"""Modify the state for a token while holding exclusive lock.
|
"""Modify the state for a token while holding exclusive lock.
|
||||||
@ -2483,19 +2493,6 @@ class StateManagerRedis(StateManager):
|
|||||||
# Only warn about each state class size once.
|
# Only warn about each state class size once.
|
||||||
_warned_about_state_size: ClassVar[Set[str]] = set()
|
_warned_about_state_size: ClassVar[Set[str]] = set()
|
||||||
|
|
||||||
def _get_root_state(self, state: BaseState) -> BaseState:
|
|
||||||
"""Chase parent_state pointers to find an instance of the top-level state.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
state: The state to start from.
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
An instance of the top-level state (self.state).
|
|
||||||
"""
|
|
||||||
while type(state) != self.state and state.parent_state is not None:
|
|
||||||
state = state.parent_state
|
|
||||||
return state
|
|
||||||
|
|
||||||
async def _get_parent_state(self, token: str) -> BaseState | None:
|
async def _get_parent_state(self, token: str) -> BaseState | None:
|
||||||
"""Get the parent state for the state requested in the token.
|
"""Get the parent state for the state requested in the token.
|
||||||
|
|
||||||
@ -2558,6 +2555,7 @@ class StateManagerRedis(StateManager):
|
|||||||
for substate_name, substate_task in tasks.items():
|
for substate_name, substate_task in tasks.items():
|
||||||
state.substates[substate_name] = await substate_task
|
state.substates[substate_name] = await substate_task
|
||||||
|
|
||||||
|
@override
|
||||||
async def get_state(
|
async def get_state(
|
||||||
self,
|
self,
|
||||||
token: str,
|
token: str,
|
||||||
@ -2609,7 +2607,7 @@ class StateManagerRedis(StateManager):
|
|||||||
# To retain compatibility with previous implementation, by default, we return
|
# To retain compatibility with previous implementation, by default, we return
|
||||||
# the top-level state by chasing `parent_state` pointers up the tree.
|
# the top-level state by chasing `parent_state` pointers up the tree.
|
||||||
if top_level:
|
if top_level:
|
||||||
return self._get_root_state(state)
|
return state._get_root_state()
|
||||||
return state
|
return state
|
||||||
|
|
||||||
# TODO: dedupe the following logic with the above block
|
# TODO: dedupe the following logic with the above block
|
||||||
@ -2631,7 +2629,7 @@ class StateManagerRedis(StateManager):
|
|||||||
# To retain compatibility with previous implementation, by default, we return
|
# To retain compatibility with previous implementation, by default, we return
|
||||||
# the top-level state by chasing `parent_state` pointers up the tree.
|
# the top-level state by chasing `parent_state` pointers up the tree.
|
||||||
if top_level:
|
if top_level:
|
||||||
return self._get_root_state(state)
|
return state._get_root_state()
|
||||||
return state
|
return state
|
||||||
|
|
||||||
def _warn_if_too_large(
|
def _warn_if_too_large(
|
||||||
@ -2657,6 +2655,7 @@ class StateManagerRedis(StateManager):
|
|||||||
)
|
)
|
||||||
self._warned_about_state_size.add(state_full_name)
|
self._warned_about_state_size.add(state_full_name)
|
||||||
|
|
||||||
|
@override
|
||||||
async def set_state(
|
async def set_state(
|
||||||
self,
|
self,
|
||||||
token: str,
|
token: str,
|
||||||
@ -2717,6 +2716,7 @@ class StateManagerRedis(StateManager):
|
|||||||
for t in tasks:
|
for t in tasks:
|
||||||
await t
|
await t
|
||||||
|
|
||||||
|
@override
|
||||||
@contextlib.asynccontextmanager
|
@contextlib.asynccontextmanager
|
||||||
async def modify_state(self, token: str) -> AsyncIterator[BaseState]:
|
async def modify_state(self, token: str) -> AsyncIterator[BaseState]:
|
||||||
"""Modify the state for a token while holding exclusive lock.
|
"""Modify the state for a token while holding exclusive lock.
|
||||||
|
@ -49,7 +49,7 @@ from reflex.base import Base
|
|||||||
from reflex.utils import console
|
from reflex.utils import console
|
||||||
|
|
||||||
if sys.version_info >= (3, 12):
|
if sys.version_info >= (3, 12):
|
||||||
from typing import override
|
from typing import override as override
|
||||||
else:
|
else:
|
||||||
|
|
||||||
def override(func: Callable) -> Callable:
|
def override(func: Callable) -> Callable:
|
||||||
|
Loading…
Reference in New Issue
Block a user