Handle rx.State subclasses defined in function

* create a new container module: `reflex.istate.dynamic` to save references to
  dynamically generated substates.
* for substates with `<locals>` in the name, copy these to the container module
  and update the name to avoid duplication.
* add test for "poor man" ComponentState

Fix #4128
This commit is contained in:
Masen Furer 2024-10-08 11:21:42 -07:00
parent 5e3cfecdea
commit 9b999e2511
No known key found for this signature in database
GPG Key ID: B0008AD22B3B3A95
3 changed files with 77 additions and 2 deletions

3
reflex/istate/dynamic.py Normal file
View File

@ -0,0 +1,3 @@
"""A container for dynamically generated states."""
# This page intentionally left blank.

View File

@ -60,6 +60,7 @@ import wrapt
from redis.asyncio import Redis
from redis.exceptions import ResponseError
import reflex.istate.dynamic
from reflex import constants
from reflex.base import Base
from reflex.event import (
@ -424,6 +425,10 @@ class BaseState(Base, ABC, extra=pydantic.Extra.allow):
if mixin:
return
# Handle locally-defined states for pickling.
if "<locals>" in cls.__qualname__:
cls._handle_local_def()
# Validate the module name.
cls._validate_module_name()
@ -646,6 +651,19 @@ class BaseState(Base, ABC, extra=pydantic.Extra.allow):
)
]
@classmethod
def _handle_local_def(cls):
"""Handle locally-defined states for pickling."""
known_names = dir(reflex.istate.dynamic)
proposed_name = cls.__name__
for ix in range(len(known_names)):
if proposed_name not in known_names:
break
proposed_name = f"{cls.__name__}_{ix}"
setattr(reflex.istate.dynamic, proposed_name, cls)
cls.__name__ = cls.__qualname__ = proposed_name
cls.__module__ = reflex.istate.dynamic.__name__
@classmethod
def _init_var_dependency_dicts(cls):
"""Initialize the var dependency tracking dicts.
@ -2203,10 +2221,13 @@ class ComponentState(State, mixin=True):
cls._per_component_state_instance_count += 1
state_cls_name = f"{cls.__name__}_n{cls._per_component_state_instance_count}"
component_state = type(
state_cls_name, (cls, State), {"__module__": __name__}, mixin=False
state_cls_name,
(cls, State),
{"__module__": reflex.istate.dynamic.__name__},
mixin=False,
)
# Save a reference to the dynamic state for pickle/unpickle.
globals()[state_cls_name] = component_state
setattr(reflex.istate.dynamic, state_cls_name, component_state)
component = component_state.get_component(*children, **props)
component.State = component_state
return component

View File

@ -20,6 +20,8 @@ def ComponentStateApp():
E = TypeVar("E")
class MultiCounter(rx.ComponentState, Generic[E]):
"""ComponentState style."""
count: int = 0
_be: E
_be_int: int
@ -42,16 +44,47 @@ def ComponentStateApp():
**props,
)
def multi_counter_func(id: str = "default") -> rx.Component:
"""Local-substate style.
Args:
id: identifier for this instance
Returns:
A new instance of the component with its own state.
"""
class _Counter(rx.State):
count: int = 0
def increment(self):
self.count += 1
return rx.vstack(
rx.heading(_Counter.count, id=f"count-{id}"),
rx.button(
"Increment",
on_click=_Counter.increment,
id=f"button-{id}",
),
State=_Counter,
)
app = rx.App(state=rx.State) # noqa
@rx.page()
def index():
mc_a = MultiCounter.create(id="a")
mc_b = MultiCounter.create(id="b")
mc_c = multi_counter_func(id="c")
mc_d = multi_counter_func(id="d")
assert mc_a.State != mc_b.State
assert mc_c.State != mc_d.State
return rx.vstack(
mc_a,
mc_b,
mc_c,
mc_d,
rx.button(
"Inc A",
on_click=mc_a.State.increment, # type: ignore
@ -149,3 +182,21 @@ async def test_component_state_app(component_state_app: AppHarness):
b_state = root_state.substates[b_state_name]
assert b_state._backend_vars != b_state.backend_vars
assert b_state._be == b_state._backend_vars["_be"] == 2
# Check locally-defined substate style
count_c = driver.find_element(By.ID, "count-c")
count_d = driver.find_element(By.ID, "count-d")
button_c = driver.find_element(By.ID, "button-c")
button_d = driver.find_element(By.ID, "button-d")
assert component_state_app.poll_for_content(count_c, exp_not_equal="") == "0"
assert component_state_app.poll_for_content(count_d, exp_not_equal="") == "0"
button_c.click()
assert component_state_app.poll_for_content(count_c, exp_not_equal="0") == "1"
assert component_state_app.poll_for_content(count_d, exp_not_equal="") == "0"
button_c.click()
assert component_state_app.poll_for_content(count_c, exp_not_equal="1") == "2"
assert component_state_app.poll_for_content(count_d, exp_not_equal="") == "0"
button_d.click()
assert component_state_app.poll_for_content(count_c, exp_not_equal="1") == "2"
assert component_state_app.poll_for_content(count_d, exp_not_equal="0") == "1"