Handle rx.State subclasses defined in function

* create a new container module: `reflex.istate.dynamic` to save references to
  dynamically generated substates.
* for substates with `<locals>` in the name, copy these to the container module
  and update the name to avoid duplication.
* add test for "poor man" ComponentState

Fix #4128
This commit is contained in:
Masen Furer 2024-10-08 11:21:42 -07:00
parent 5e3cfecdea
commit 9b999e2511
No known key found for this signature in database
GPG Key ID: B0008AD22B3B3A95
3 changed files with 77 additions and 2 deletions

3
reflex/istate/dynamic.py Normal file
View File

@ -0,0 +1,3 @@
"""A container for dynamically generated states."""
# This page intentionally left blank.

View File

@ -60,6 +60,7 @@ import wrapt
from redis.asyncio import Redis from redis.asyncio import Redis
from redis.exceptions import ResponseError from redis.exceptions import ResponseError
import reflex.istate.dynamic
from reflex import constants from reflex import constants
from reflex.base import Base from reflex.base import Base
from reflex.event import ( from reflex.event import (
@ -424,6 +425,10 @@ class BaseState(Base, ABC, extra=pydantic.Extra.allow):
if mixin: if mixin:
return return
# Handle locally-defined states for pickling.
if "<locals>" in cls.__qualname__:
cls._handle_local_def()
# Validate the module name. # Validate the module name.
cls._validate_module_name() cls._validate_module_name()
@ -646,6 +651,19 @@ class BaseState(Base, ABC, extra=pydantic.Extra.allow):
) )
] ]
@classmethod
def _handle_local_def(cls):
"""Handle locally-defined states for pickling."""
known_names = dir(reflex.istate.dynamic)
proposed_name = cls.__name__
for ix in range(len(known_names)):
if proposed_name not in known_names:
break
proposed_name = f"{cls.__name__}_{ix}"
setattr(reflex.istate.dynamic, proposed_name, cls)
cls.__name__ = cls.__qualname__ = proposed_name
cls.__module__ = reflex.istate.dynamic.__name__
@classmethod @classmethod
def _init_var_dependency_dicts(cls): def _init_var_dependency_dicts(cls):
"""Initialize the var dependency tracking dicts. """Initialize the var dependency tracking dicts.
@ -2203,10 +2221,13 @@ class ComponentState(State, mixin=True):
cls._per_component_state_instance_count += 1 cls._per_component_state_instance_count += 1
state_cls_name = f"{cls.__name__}_n{cls._per_component_state_instance_count}" state_cls_name = f"{cls.__name__}_n{cls._per_component_state_instance_count}"
component_state = type( component_state = type(
state_cls_name, (cls, State), {"__module__": __name__}, mixin=False state_cls_name,
(cls, State),
{"__module__": reflex.istate.dynamic.__name__},
mixin=False,
) )
# Save a reference to the dynamic state for pickle/unpickle. # Save a reference to the dynamic state for pickle/unpickle.
globals()[state_cls_name] = component_state setattr(reflex.istate.dynamic, state_cls_name, component_state)
component = component_state.get_component(*children, **props) component = component_state.get_component(*children, **props)
component.State = component_state component.State = component_state
return component return component

View File

@ -20,6 +20,8 @@ def ComponentStateApp():
E = TypeVar("E") E = TypeVar("E")
class MultiCounter(rx.ComponentState, Generic[E]): class MultiCounter(rx.ComponentState, Generic[E]):
"""ComponentState style."""
count: int = 0 count: int = 0
_be: E _be: E
_be_int: int _be_int: int
@ -42,16 +44,47 @@ def ComponentStateApp():
**props, **props,
) )
def multi_counter_func(id: str = "default") -> rx.Component:
"""Local-substate style.
Args:
id: identifier for this instance
Returns:
A new instance of the component with its own state.
"""
class _Counter(rx.State):
count: int = 0
def increment(self):
self.count += 1
return rx.vstack(
rx.heading(_Counter.count, id=f"count-{id}"),
rx.button(
"Increment",
on_click=_Counter.increment,
id=f"button-{id}",
),
State=_Counter,
)
app = rx.App(state=rx.State) # noqa app = rx.App(state=rx.State) # noqa
@rx.page() @rx.page()
def index(): def index():
mc_a = MultiCounter.create(id="a") mc_a = MultiCounter.create(id="a")
mc_b = MultiCounter.create(id="b") mc_b = MultiCounter.create(id="b")
mc_c = multi_counter_func(id="c")
mc_d = multi_counter_func(id="d")
assert mc_a.State != mc_b.State assert mc_a.State != mc_b.State
assert mc_c.State != mc_d.State
return rx.vstack( return rx.vstack(
mc_a, mc_a,
mc_b, mc_b,
mc_c,
mc_d,
rx.button( rx.button(
"Inc A", "Inc A",
on_click=mc_a.State.increment, # type: ignore on_click=mc_a.State.increment, # type: ignore
@ -149,3 +182,21 @@ async def test_component_state_app(component_state_app: AppHarness):
b_state = root_state.substates[b_state_name] b_state = root_state.substates[b_state_name]
assert b_state._backend_vars != b_state.backend_vars assert b_state._backend_vars != b_state.backend_vars
assert b_state._be == b_state._backend_vars["_be"] == 2 assert b_state._be == b_state._backend_vars["_be"] == 2
# Check locally-defined substate style
count_c = driver.find_element(By.ID, "count-c")
count_d = driver.find_element(By.ID, "count-d")
button_c = driver.find_element(By.ID, "button-c")
button_d = driver.find_element(By.ID, "button-d")
assert component_state_app.poll_for_content(count_c, exp_not_equal="") == "0"
assert component_state_app.poll_for_content(count_d, exp_not_equal="") == "0"
button_c.click()
assert component_state_app.poll_for_content(count_c, exp_not_equal="0") == "1"
assert component_state_app.poll_for_content(count_d, exp_not_equal="") == "0"
button_c.click()
assert component_state_app.poll_for_content(count_c, exp_not_equal="1") == "2"
assert component_state_app.poll_for_content(count_d, exp_not_equal="") == "0"
button_d.click()
assert component_state_app.poll_for_content(count_c, exp_not_equal="1") == "2"
assert component_state_app.poll_for_content(count_d, exp_not_equal="0") == "1"