diff --git a/integration/test_component_state.py b/integration/test_component_state.py new file mode 100644 index 000000000..d2c10c766 --- /dev/null +++ b/integration/test_component_state.py @@ -0,0 +1,107 @@ +"""Test that per-component state scaffold works and operates independently.""" +from typing import Generator + +import pytest +from selenium.webdriver.common.by import By + +from reflex.testing import AppHarness + +from . import utils + + +def ComponentStateApp(): + """App using per component state.""" + import reflex as rx + + class MultiCounter(rx.ComponentState): + count: int = 0 + + def increment(self): + self.count += 1 + + @classmethod + def get_component(cls, *children, **props): + return rx.vstack( + *children, + rx.heading(cls.count, id=f"count-{props.get('id', 'default')}"), + rx.button( + "Increment", + on_click=cls.increment, + id=f"button-{props.get('id', 'default')}", + ), + **props, + ) + + app = rx.App(state=rx.State) # noqa + + @rx.page() + def index(): + mc_a = MultiCounter.create(id="a") + mc_b = MultiCounter.create(id="b") + assert mc_a.State != mc_b.State + return rx.vstack( + mc_a, + mc_b, + rx.button( + "Inc A", + on_click=mc_a.State.increment, # type: ignore + id="inc-a", + ), + ) + + +@pytest.fixture() +def component_state_app(tmp_path) -> Generator[AppHarness, None, None]: + """Start ComponentStateApp app at tmp_path via AppHarness. + + Args: + tmp_path: pytest tmp_path fixture + + Yields: + running AppHarness instance + """ + with AppHarness.create( + root=tmp_path, + app_source=ComponentStateApp, # type: ignore + ) as harness: + yield harness + + +@pytest.mark.asyncio +async def test_component_state_app(component_state_app: AppHarness): + """Increment counters independently. + + Args: + component_state_app: harness for ComponentStateApp app + """ + assert component_state_app.app_instance is not None, "app is not running" + driver = component_state_app.frontend() + + ss = utils.SessionStorage(driver) + token = AppHarness._poll_for(lambda: ss.get("token") is not None) + assert token is not None + + count_a = driver.find_element(By.ID, "count-a") + count_b = driver.find_element(By.ID, "count-b") + button_a = driver.find_element(By.ID, "button-a") + button_b = driver.find_element(By.ID, "button-b") + button_inc_a = driver.find_element(By.ID, "inc-a") + + assert count_a.text == "0" + + button_a.click() + assert component_state_app.poll_for_content(count_a, exp_not_equal="0") == "1" + + button_a.click() + assert component_state_app.poll_for_content(count_a, exp_not_equal="1") == "2" + + button_inc_a.click() + assert component_state_app.poll_for_content(count_a, exp_not_equal="2") == "3" + + assert count_b.text == "0" + + button_b.click() + assert component_state_app.poll_for_content(count_b, exp_not_equal="0") == "1" + + button_b.click() + assert component_state_app.poll_for_content(count_b, exp_not_equal="1") == "2" diff --git a/integration/utils.py b/integration/utils.py index 273094c84..bcbd6c497 100644 --- a/integration/utils.py +++ b/integration/utils.py @@ -38,6 +38,8 @@ class LocalStorage: https://stackoverflow.com/a/46361900 """ + storage_key = "localStorage" + def __init__(self, driver: WebDriver): """Initialize the class. @@ -171,3 +173,12 @@ class LocalStorage: An iterator over the items in local storage. """ return iter(self.keys()) + + +class SessionStorage(LocalStorage): + """Class to access session storage. + + https://stackoverflow.com/a/46361900 + """ + + storage_key = "sessionStorage" diff --git a/reflex/__init__.py b/reflex/__init__.py index 9c5faea0f..8b6a656a3 100644 --- a/reflex/__init__.py +++ b/reflex/__init__.py @@ -153,7 +153,14 @@ _MAPPING = { "reflex.model": ["model", "session", "Model"], "reflex.page": ["page"], "reflex.route": ["route"], - "reflex.state": ["state", "var", "Cookie", "LocalStorage", "State"], + "reflex.state": [ + "state", + "var", + "Cookie", + "LocalStorage", + "ComponentState", + "State", + ], "reflex.style": ["style", "toggle_color_mode"], "reflex.testing": ["testing"], "reflex.utils": ["utils"], diff --git a/reflex/__init__.pyi b/reflex/__init__.pyi index 51c291343..5b46dd4f4 100644 --- a/reflex/__init__.pyi +++ b/reflex/__init__.pyi @@ -141,6 +141,7 @@ from reflex import state as state from reflex.state import var as var from reflex.state import Cookie as Cookie from reflex.state import LocalStorage as LocalStorage +from reflex.state import ComponentState as ComponentState from reflex.state import State as State from reflex import style as style from reflex.style import toggle_color_mode as toggle_color_mode diff --git a/reflex/components/component.py b/reflex/components/component.py index 4e7654a76..6cb30186d 100644 --- a/reflex/components/component.py +++ b/reflex/components/component.py @@ -21,6 +21,7 @@ from typing import ( Union, ) +import reflex.state from reflex.base import Base from reflex.compiler.templates import STATEFUL_COMPONENT from reflex.components.tags import Tag @@ -214,6 +215,9 @@ class Component(BaseComponent, ABC): # When to memoize this component and its children. _memoization_mode: MemoizationMode = MemoizationMode() + # State class associated with this component instance + State: Optional[Type[reflex.state.State]] = None + @classmethod def __init_subclass__(cls, **kwargs): """Set default properties. diff --git a/reflex/state.py b/reflex/state.py index 847d65433..42cfc9866 100644 --- a/reflex/state.py +++ b/reflex/state.py @@ -15,6 +15,7 @@ from abc import ABC, abstractmethod from collections import defaultdict from types import FunctionType, MethodType from typing import ( + TYPE_CHECKING, Any, AsyncIterator, Callable, @@ -47,6 +48,10 @@ from reflex.utils.exec import is_testing_env from reflex.utils.serializers import SerializedType, serialize, serializer from reflex.vars import BaseVar, ComputedVar, Var, computed_var +if TYPE_CHECKING: + from reflex.components.component import Component + + Delta = Dict[str, Any] var = computed_var @@ -1835,6 +1840,45 @@ class OnLoadInternalState(State): ] +class ComponentState(Base): + """The base class for a State that is copied for each Component associated with it.""" + + _per_component_state_instance_count: ClassVar[int] = 0 + + @classmethod + def get_component(cls, *children, **props) -> "Component": + """Get the component instance. + + Args: + children: The children of the component. + props: The props of the component. + + Raises: + NotImplementedError: if the subclass does not override this method. + """ + raise NotImplementedError( + f"{cls.__name__} must implement get_component to return the component instance." + ) + + @classmethod + def create(cls, *children, **props) -> "Component": + """Create a new instance of the Component. + + Args: + children: The children of the component. + props: The props of the component. + + Returns: + A new instance of the Component with an independent copy of the State. + """ + cls._per_component_state_instance_count += 1 + state_cls_name = f"{cls.__name__}_n{cls._per_component_state_instance_count}" + component_state = type(state_cls_name, (cls, State), {}) + component = component_state.get_component(*children, **props) + component.State = component_state + return component + + class StateProxy(wrapt.ObjectProxy): """Proxy of a state instance to control mutability of vars for a background task. diff --git a/scripts/pyi_generator.py b/scripts/pyi_generator.py index 2e8936d13..35ce6f50a 100644 --- a/scripts/pyi_generator.py +++ b/scripts/pyi_generator.py @@ -57,6 +57,7 @@ EXCLUDED_PROPS = [ "_rename_props", "_valid_children", "_valid_parents", + "State", ] DEFAULT_TYPING_IMPORTS = { diff --git a/tests/components/test_component_state.py b/tests/components/test_component_state.py new file mode 100644 index 000000000..0dc0956e2 --- /dev/null +++ b/tests/components/test_component_state.py @@ -0,0 +1,42 @@ +"""Ensure that Components returned by ComponentState.create have independent State classes.""" + +import reflex as rx +from reflex.components.base.bare import Bare + + +def test_component_state(): + """Create two components with independent state classes.""" + + class CS(rx.ComponentState): + count: int = 0 + + def increment(self): + self.count += 1 + + @classmethod + def get_component(cls, *children, **props): + return rx.el.div( + *children, + **props, + ) + + cs1, cs2 = CS.create("a", id="a"), CS.create("b", id="b") + assert isinstance(cs1, rx.Component) + assert isinstance(cs2, rx.Component) + assert cs1.State is not None + assert cs2.State is not None + assert cs1.State != cs2.State + assert issubclass(cs1.State, CS) + assert issubclass(cs1.State, rx.State) + assert issubclass(cs2.State, CS) + assert issubclass(cs2.State, rx.State) + assert CS._per_component_state_instance_count == 2 + assert isinstance(cs1.State.increment, rx.event.EventHandler) + assert cs1.State.increment != cs2.State.increment + + assert len(cs1.children) == 1 + assert cs1.children[0].render() == Bare.create("{`a`}").render() + assert cs1.id == "a" + assert len(cs2.children) == 1 + assert cs2.children[0].render() == Bare.create("{`b`}").render() + assert cs2.id == "b"