improve event handler state references (#2818)
This commit is contained in:
parent
c809107d09
commit
8a3c9383fb
@ -147,6 +147,10 @@ class EventHandler(EventActionsMixin):
|
||||
# The function to call in response to the event.
|
||||
fn: Any
|
||||
|
||||
# The full name of the state class this event handler is attached to.
|
||||
# Emtpy string means this event handler is a server side event.
|
||||
state_full_name: str = ""
|
||||
|
||||
class Config:
|
||||
"""The Pydantic config."""
|
||||
|
||||
|
@ -472,7 +472,7 @@ class BaseState(Base, ABC, extra=pydantic.Extra.allow):
|
||||
events[name] = value
|
||||
|
||||
for name, fn in events.items():
|
||||
handler = EventHandler(fn=fn)
|
||||
handler = cls._create_event_handler(fn)
|
||||
cls.event_handlers[name] = handler
|
||||
setattr(cls, name, handler)
|
||||
|
||||
@ -677,7 +677,7 @@ class BaseState(Base, ABC, extra=pydantic.Extra.allow):
|
||||
|
||||
@classmethod
|
||||
@functools.lru_cache()
|
||||
def get_class_substate(cls, path: Sequence[str]) -> Type[BaseState]:
|
||||
def get_class_substate(cls, path: Sequence[str] | str) -> Type[BaseState]:
|
||||
"""Get the class substate.
|
||||
|
||||
Args:
|
||||
@ -689,6 +689,9 @@ class BaseState(Base, ABC, extra=pydantic.Extra.allow):
|
||||
Raises:
|
||||
ValueError: If the substate is not found.
|
||||
"""
|
||||
if isinstance(path, str):
|
||||
path = tuple(path.split("."))
|
||||
|
||||
if len(path) == 0:
|
||||
return cls
|
||||
if path[0] == cls.get_name():
|
||||
@ -789,6 +792,18 @@ class BaseState(Base, ABC, extra=pydantic.Extra.allow):
|
||||
"""
|
||||
setattr(cls, prop._var_name, prop)
|
||||
|
||||
@classmethod
|
||||
def _create_event_handler(cls, fn):
|
||||
"""Create an event handler for the given function.
|
||||
|
||||
Args:
|
||||
fn: The function to create an event handler for.
|
||||
|
||||
Returns:
|
||||
The event handler.
|
||||
"""
|
||||
return EventHandler(fn=fn, state_full_name=cls.get_full_name())
|
||||
|
||||
@classmethod
|
||||
def _create_setter(cls, prop: BaseVar):
|
||||
"""Create a setter for the var.
|
||||
@ -798,7 +813,7 @@ class BaseState(Base, ABC, extra=pydantic.Extra.allow):
|
||||
"""
|
||||
setter_name = prop.get_setter_name(include_state=False)
|
||||
if setter_name not in cls.__dict__:
|
||||
event_handler = EventHandler(fn=prop.get_setter())
|
||||
event_handler = cls._create_event_handler(prop.get_setter())
|
||||
cls.event_handlers[setter_name] = event_handler
|
||||
setattr(cls, setter_name, event_handler)
|
||||
|
||||
@ -1752,7 +1767,7 @@ class UpdateVarsInternalState(State):
|
||||
"""
|
||||
for var, value in vars.items():
|
||||
state_name, _, var_name = var.rpartition(".")
|
||||
var_state_cls = State.get_class_substate(tuple(state_name.split(".")))
|
||||
var_state_cls = State.get_class_substate(state_name)
|
||||
var_state = await self.get_state(var_state_cls)
|
||||
setattr(var_state, var_name, value)
|
||||
|
||||
@ -2268,7 +2283,7 @@ class StateManagerRedis(StateManager):
|
||||
_, state_path = _split_substate_key(token)
|
||||
if state_path:
|
||||
# Get the State class associated with the given path.
|
||||
state_cls = self.state.get_class_substate(tuple(state_path.split(".")))
|
||||
state_cls = self.state.get_class_substate(state_path)
|
||||
else:
|
||||
raise RuntimeError(
|
||||
"StateManagerRedis requires token to be specified in the form of {token}_{state_full_name}"
|
||||
|
@ -6,7 +6,6 @@ import inspect
|
||||
import json
|
||||
import os
|
||||
import re
|
||||
import sys
|
||||
from typing import TYPE_CHECKING, Any, List, Union
|
||||
|
||||
from reflex import constants
|
||||
@ -470,18 +469,18 @@ def get_event_handler_parts(handler: EventHandler) -> tuple[str, str]:
|
||||
if len(parts) == 1:
|
||||
return ("", parts[-1])
|
||||
|
||||
# Get the state and the function name.
|
||||
state_name, name = parts[-2:]
|
||||
# Get the state full name
|
||||
state_full_name = handler.state_full_name
|
||||
|
||||
# Construct the full event handler name.
|
||||
try:
|
||||
# Try to get the state from the module.
|
||||
state = vars(sys.modules[handler.fn.__module__])[state_name]
|
||||
except Exception:
|
||||
# If the state isn't in the module, just return the function name.
|
||||
# Get the function name
|
||||
name = parts[-1]
|
||||
|
||||
from reflex.state import State
|
||||
|
||||
if state_full_name == "state" and name not in State.__dict__:
|
||||
return ("", to_snake_case(handler.fn.__qualname__))
|
||||
|
||||
return (state.get_full_name(), name)
|
||||
return (state_full_name, name)
|
||||
|
||||
|
||||
def format_event_handler(handler: EventHandler) -> str:
|
||||
|
Loading…
Reference in New Issue
Block a user