diff --git a/reflex/istate/dynamic.py b/reflex/istate/dynamic.py new file mode 100644 index 000000000..e5da36c26 --- /dev/null +++ b/reflex/istate/dynamic.py @@ -0,0 +1,3 @@ +"""A container for dynamically generated states.""" + +# This page intentionally left blank. diff --git a/reflex/state.py b/reflex/state.py index 5d1d51b91..935535ace 100644 --- a/reflex/state.py +++ b/reflex/state.py @@ -60,6 +60,7 @@ import wrapt from redis.asyncio import Redis from redis.exceptions import ResponseError +import reflex.istate.dynamic from reflex import constants from reflex.base import Base from reflex.event import ( @@ -424,6 +425,10 @@ class BaseState(Base, ABC, extra=pydantic.Extra.allow): if mixin: return + # Handle locally-defined states for pickling. + if "" in cls.__qualname__: + cls._handle_local_def() + # Validate the module name. cls._validate_module_name() @@ -646,6 +651,19 @@ class BaseState(Base, ABC, extra=pydantic.Extra.allow): ) ] + @classmethod + def _handle_local_def(cls): + """Handle locally-defined states for pickling.""" + known_names = dir(reflex.istate.dynamic) + proposed_name = cls.__name__ + for ix in range(len(known_names)): + if proposed_name not in known_names: + break + proposed_name = f"{cls.__name__}_{ix}" + setattr(reflex.istate.dynamic, proposed_name, cls) + cls.__name__ = cls.__qualname__ = proposed_name + cls.__module__ = reflex.istate.dynamic.__name__ + @classmethod def _init_var_dependency_dicts(cls): """Initialize the var dependency tracking dicts. @@ -2203,10 +2221,13 @@ class ComponentState(State, mixin=True): cls._per_component_state_instance_count += 1 state_cls_name = f"{cls.__name__}_n{cls._per_component_state_instance_count}" component_state = type( - state_cls_name, (cls, State), {"__module__": __name__}, mixin=False + state_cls_name, + (cls, State), + {"__module__": reflex.istate.dynamic.__name__}, + mixin=False, ) # Save a reference to the dynamic state for pickle/unpickle. - globals()[state_cls_name] = component_state + setattr(reflex.istate.dynamic, state_cls_name, component_state) component = component_state.get_component(*children, **props) component.State = component_state return component diff --git a/tests/integration/test_component_state.py b/tests/integration/test_component_state.py index 7b35f8116..1555577d1 100644 --- a/tests/integration/test_component_state.py +++ b/tests/integration/test_component_state.py @@ -20,6 +20,8 @@ def ComponentStateApp(): E = TypeVar("E") class MultiCounter(rx.ComponentState, Generic[E]): + """ComponentState style.""" + count: int = 0 _be: E _be_int: int @@ -42,16 +44,47 @@ def ComponentStateApp(): **props, ) + def multi_counter_func(id: str = "default") -> rx.Component: + """Local-substate style. + + Args: + id: identifier for this instance + + Returns: + A new instance of the component with its own state. + """ + + class _Counter(rx.State): + count: int = 0 + + def increment(self): + self.count += 1 + + return rx.vstack( + rx.heading(_Counter.count, id=f"count-{id}"), + rx.button( + "Increment", + on_click=_Counter.increment, + id=f"button-{id}", + ), + State=_Counter, + ) + app = rx.App(state=rx.State) # noqa @rx.page() def index(): mc_a = MultiCounter.create(id="a") mc_b = MultiCounter.create(id="b") + mc_c = multi_counter_func(id="c") + mc_d = multi_counter_func(id="d") assert mc_a.State != mc_b.State + assert mc_c.State != mc_d.State return rx.vstack( mc_a, mc_b, + mc_c, + mc_d, rx.button( "Inc A", on_click=mc_a.State.increment, # type: ignore @@ -149,3 +182,21 @@ async def test_component_state_app(component_state_app: AppHarness): b_state = root_state.substates[b_state_name] assert b_state._backend_vars != b_state.backend_vars assert b_state._be == b_state._backend_vars["_be"] == 2 + + # Check locally-defined substate style + count_c = driver.find_element(By.ID, "count-c") + count_d = driver.find_element(By.ID, "count-d") + button_c = driver.find_element(By.ID, "button-c") + button_d = driver.find_element(By.ID, "button-d") + + assert component_state_app.poll_for_content(count_c, exp_not_equal="") == "0" + assert component_state_app.poll_for_content(count_d, exp_not_equal="") == "0" + button_c.click() + assert component_state_app.poll_for_content(count_c, exp_not_equal="0") == "1" + assert component_state_app.poll_for_content(count_d, exp_not_equal="") == "0" + button_c.click() + assert component_state_app.poll_for_content(count_c, exp_not_equal="1") == "2" + assert component_state_app.poll_for_content(count_d, exp_not_equal="") == "0" + button_d.click() + assert component_state_app.poll_for_content(count_c, exp_not_equal="1") == "2" + assert component_state_app.poll_for_content(count_d, exp_not_equal="0") == "1"