Handle rx.State subclasses defined in function
* create a new container module: `reflex.istate.dynamic` to save references to dynamically generated substates. * for substates with `<locals>` in the name, copy these to the container module and update the name to avoid duplication. * add test for "poor man" ComponentState Fix #4128
This commit is contained in:
parent
5e3cfecdea
commit
9b999e2511
3
reflex/istate/dynamic.py
Normal file
3
reflex/istate/dynamic.py
Normal file
@ -0,0 +1,3 @@
|
||||
"""A container for dynamically generated states."""
|
||||
|
||||
# This page intentionally left blank.
|
@ -60,6 +60,7 @@ import wrapt
|
||||
from redis.asyncio import Redis
|
||||
from redis.exceptions import ResponseError
|
||||
|
||||
import reflex.istate.dynamic
|
||||
from reflex import constants
|
||||
from reflex.base import Base
|
||||
from reflex.event import (
|
||||
@ -424,6 +425,10 @@ class BaseState(Base, ABC, extra=pydantic.Extra.allow):
|
||||
if mixin:
|
||||
return
|
||||
|
||||
# Handle locally-defined states for pickling.
|
||||
if "<locals>" in cls.__qualname__:
|
||||
cls._handle_local_def()
|
||||
|
||||
# Validate the module name.
|
||||
cls._validate_module_name()
|
||||
|
||||
@ -646,6 +651,19 @@ class BaseState(Base, ABC, extra=pydantic.Extra.allow):
|
||||
)
|
||||
]
|
||||
|
||||
@classmethod
|
||||
def _handle_local_def(cls):
|
||||
"""Handle locally-defined states for pickling."""
|
||||
known_names = dir(reflex.istate.dynamic)
|
||||
proposed_name = cls.__name__
|
||||
for ix in range(len(known_names)):
|
||||
if proposed_name not in known_names:
|
||||
break
|
||||
proposed_name = f"{cls.__name__}_{ix}"
|
||||
setattr(reflex.istate.dynamic, proposed_name, cls)
|
||||
cls.__name__ = cls.__qualname__ = proposed_name
|
||||
cls.__module__ = reflex.istate.dynamic.__name__
|
||||
|
||||
@classmethod
|
||||
def _init_var_dependency_dicts(cls):
|
||||
"""Initialize the var dependency tracking dicts.
|
||||
@ -2203,10 +2221,13 @@ class ComponentState(State, mixin=True):
|
||||
cls._per_component_state_instance_count += 1
|
||||
state_cls_name = f"{cls.__name__}_n{cls._per_component_state_instance_count}"
|
||||
component_state = type(
|
||||
state_cls_name, (cls, State), {"__module__": __name__}, mixin=False
|
||||
state_cls_name,
|
||||
(cls, State),
|
||||
{"__module__": reflex.istate.dynamic.__name__},
|
||||
mixin=False,
|
||||
)
|
||||
# Save a reference to the dynamic state for pickle/unpickle.
|
||||
globals()[state_cls_name] = component_state
|
||||
setattr(reflex.istate.dynamic, state_cls_name, component_state)
|
||||
component = component_state.get_component(*children, **props)
|
||||
component.State = component_state
|
||||
return component
|
||||
|
@ -20,6 +20,8 @@ def ComponentStateApp():
|
||||
E = TypeVar("E")
|
||||
|
||||
class MultiCounter(rx.ComponentState, Generic[E]):
|
||||
"""ComponentState style."""
|
||||
|
||||
count: int = 0
|
||||
_be: E
|
||||
_be_int: int
|
||||
@ -42,16 +44,47 @@ def ComponentStateApp():
|
||||
**props,
|
||||
)
|
||||
|
||||
def multi_counter_func(id: str = "default") -> rx.Component:
|
||||
"""Local-substate style.
|
||||
|
||||
Args:
|
||||
id: identifier for this instance
|
||||
|
||||
Returns:
|
||||
A new instance of the component with its own state.
|
||||
"""
|
||||
|
||||
class _Counter(rx.State):
|
||||
count: int = 0
|
||||
|
||||
def increment(self):
|
||||
self.count += 1
|
||||
|
||||
return rx.vstack(
|
||||
rx.heading(_Counter.count, id=f"count-{id}"),
|
||||
rx.button(
|
||||
"Increment",
|
||||
on_click=_Counter.increment,
|
||||
id=f"button-{id}",
|
||||
),
|
||||
State=_Counter,
|
||||
)
|
||||
|
||||
app = rx.App(state=rx.State) # noqa
|
||||
|
||||
@rx.page()
|
||||
def index():
|
||||
mc_a = MultiCounter.create(id="a")
|
||||
mc_b = MultiCounter.create(id="b")
|
||||
mc_c = multi_counter_func(id="c")
|
||||
mc_d = multi_counter_func(id="d")
|
||||
assert mc_a.State != mc_b.State
|
||||
assert mc_c.State != mc_d.State
|
||||
return rx.vstack(
|
||||
mc_a,
|
||||
mc_b,
|
||||
mc_c,
|
||||
mc_d,
|
||||
rx.button(
|
||||
"Inc A",
|
||||
on_click=mc_a.State.increment, # type: ignore
|
||||
@ -149,3 +182,21 @@ async def test_component_state_app(component_state_app: AppHarness):
|
||||
b_state = root_state.substates[b_state_name]
|
||||
assert b_state._backend_vars != b_state.backend_vars
|
||||
assert b_state._be == b_state._backend_vars["_be"] == 2
|
||||
|
||||
# Check locally-defined substate style
|
||||
count_c = driver.find_element(By.ID, "count-c")
|
||||
count_d = driver.find_element(By.ID, "count-d")
|
||||
button_c = driver.find_element(By.ID, "button-c")
|
||||
button_d = driver.find_element(By.ID, "button-d")
|
||||
|
||||
assert component_state_app.poll_for_content(count_c, exp_not_equal="") == "0"
|
||||
assert component_state_app.poll_for_content(count_d, exp_not_equal="") == "0"
|
||||
button_c.click()
|
||||
assert component_state_app.poll_for_content(count_c, exp_not_equal="0") == "1"
|
||||
assert component_state_app.poll_for_content(count_d, exp_not_equal="") == "0"
|
||||
button_c.click()
|
||||
assert component_state_app.poll_for_content(count_c, exp_not_equal="1") == "2"
|
||||
assert component_state_app.poll_for_content(count_d, exp_not_equal="") == "0"
|
||||
button_d.click()
|
||||
assert component_state_app.poll_for_content(count_c, exp_not_equal="1") == "2"
|
||||
assert component_state_app.poll_for_content(count_d, exp_not_equal="0") == "1"
|
||||
|
Loading…
Reference in New Issue
Block a user