[REF-2265] ComponentState: scaffold for copying State per Component instance (#2923)

* [REF-2265] ComponentState: scaffold for copying State per Component instance

Define a base ComponentState which can be used to easily create copies of the
given State definition (Vars and EventHandlers) that are tied to a particular
instance of a Component (returned by get_component)

* Define `State` field on `Component` for typing compatibility.

This is an Optional field of Type[State] and is populated by ComponentState.

* Add integration/test_component_state.py

Create two independent counters and increment them separately

* Add unit test for ComponentState
This commit is contained in:
Masen Furer 2024-03-29 09:22:25 -07:00 committed by GitHub
parent f372402ee4
commit 5510eaf820
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
8 changed files with 218 additions and 1 deletions

View File

@ -0,0 +1,107 @@
"""Test that per-component state scaffold works and operates independently."""
from typing import Generator
import pytest
from selenium.webdriver.common.by import By
from reflex.testing import AppHarness
from . import utils
def ComponentStateApp():
"""App using per component state."""
import reflex as rx
class MultiCounter(rx.ComponentState):
count: int = 0
def increment(self):
self.count += 1
@classmethod
def get_component(cls, *children, **props):
return rx.vstack(
*children,
rx.heading(cls.count, id=f"count-{props.get('id', 'default')}"),
rx.button(
"Increment",
on_click=cls.increment,
id=f"button-{props.get('id', 'default')}",
),
**props,
)
app = rx.App(state=rx.State) # noqa
@rx.page()
def index():
mc_a = MultiCounter.create(id="a")
mc_b = MultiCounter.create(id="b")
assert mc_a.State != mc_b.State
return rx.vstack(
mc_a,
mc_b,
rx.button(
"Inc A",
on_click=mc_a.State.increment, # type: ignore
id="inc-a",
),
)
@pytest.fixture()
def component_state_app(tmp_path) -> Generator[AppHarness, None, None]:
"""Start ComponentStateApp app at tmp_path via AppHarness.
Args:
tmp_path: pytest tmp_path fixture
Yields:
running AppHarness instance
"""
with AppHarness.create(
root=tmp_path,
app_source=ComponentStateApp, # type: ignore
) as harness:
yield harness
@pytest.mark.asyncio
async def test_component_state_app(component_state_app: AppHarness):
"""Increment counters independently.
Args:
component_state_app: harness for ComponentStateApp app
"""
assert component_state_app.app_instance is not None, "app is not running"
driver = component_state_app.frontend()
ss = utils.SessionStorage(driver)
token = AppHarness._poll_for(lambda: ss.get("token") is not None)
assert token is not None
count_a = driver.find_element(By.ID, "count-a")
count_b = driver.find_element(By.ID, "count-b")
button_a = driver.find_element(By.ID, "button-a")
button_b = driver.find_element(By.ID, "button-b")
button_inc_a = driver.find_element(By.ID, "inc-a")
assert count_a.text == "0"
button_a.click()
assert component_state_app.poll_for_content(count_a, exp_not_equal="0") == "1"
button_a.click()
assert component_state_app.poll_for_content(count_a, exp_not_equal="1") == "2"
button_inc_a.click()
assert component_state_app.poll_for_content(count_a, exp_not_equal="2") == "3"
assert count_b.text == "0"
button_b.click()
assert component_state_app.poll_for_content(count_b, exp_not_equal="0") == "1"
button_b.click()
assert component_state_app.poll_for_content(count_b, exp_not_equal="1") == "2"

View File

@ -38,6 +38,8 @@ class LocalStorage:
https://stackoverflow.com/a/46361900
"""
storage_key = "localStorage"
def __init__(self, driver: WebDriver):
"""Initialize the class.
@ -171,3 +173,12 @@ class LocalStorage:
An iterator over the items in local storage.
"""
return iter(self.keys())
class SessionStorage(LocalStorage):
"""Class to access session storage.
https://stackoverflow.com/a/46361900
"""
storage_key = "sessionStorage"

View File

@ -153,7 +153,14 @@ _MAPPING = {
"reflex.model": ["model", "session", "Model"],
"reflex.page": ["page"],
"reflex.route": ["route"],
"reflex.state": ["state", "var", "Cookie", "LocalStorage", "State"],
"reflex.state": [
"state",
"var",
"Cookie",
"LocalStorage",
"ComponentState",
"State",
],
"reflex.style": ["style", "toggle_color_mode"],
"reflex.testing": ["testing"],
"reflex.utils": ["utils"],

View File

@ -141,6 +141,7 @@ from reflex import state as state
from reflex.state import var as var
from reflex.state import Cookie as Cookie
from reflex.state import LocalStorage as LocalStorage
from reflex.state import ComponentState as ComponentState
from reflex.state import State as State
from reflex import style as style
from reflex.style import toggle_color_mode as toggle_color_mode

View File

@ -21,6 +21,7 @@ from typing import (
Union,
)
import reflex.state
from reflex.base import Base
from reflex.compiler.templates import STATEFUL_COMPONENT
from reflex.components.tags import Tag
@ -214,6 +215,9 @@ class Component(BaseComponent, ABC):
# When to memoize this component and its children.
_memoization_mode: MemoizationMode = MemoizationMode()
# State class associated with this component instance
State: Optional[Type[reflex.state.State]] = None
@classmethod
def __init_subclass__(cls, **kwargs):
"""Set default properties.

View File

@ -15,6 +15,7 @@ from abc import ABC, abstractmethod
from collections import defaultdict
from types import FunctionType, MethodType
from typing import (
TYPE_CHECKING,
Any,
AsyncIterator,
Callable,
@ -47,6 +48,10 @@ from reflex.utils.exec import is_testing_env
from reflex.utils.serializers import SerializedType, serialize, serializer
from reflex.vars import BaseVar, ComputedVar, Var, computed_var
if TYPE_CHECKING:
from reflex.components.component import Component
Delta = Dict[str, Any]
var = computed_var
@ -1835,6 +1840,45 @@ class OnLoadInternalState(State):
]
class ComponentState(Base):
"""The base class for a State that is copied for each Component associated with it."""
_per_component_state_instance_count: ClassVar[int] = 0
@classmethod
def get_component(cls, *children, **props) -> "Component":
"""Get the component instance.
Args:
children: The children of the component.
props: The props of the component.
Raises:
NotImplementedError: if the subclass does not override this method.
"""
raise NotImplementedError(
f"{cls.__name__} must implement get_component to return the component instance."
)
@classmethod
def create(cls, *children, **props) -> "Component":
"""Create a new instance of the Component.
Args:
children: The children of the component.
props: The props of the component.
Returns:
A new instance of the Component with an independent copy of the State.
"""
cls._per_component_state_instance_count += 1
state_cls_name = f"{cls.__name__}_n{cls._per_component_state_instance_count}"
component_state = type(state_cls_name, (cls, State), {})
component = component_state.get_component(*children, **props)
component.State = component_state
return component
class StateProxy(wrapt.ObjectProxy):
"""Proxy of a state instance to control mutability of vars for a background task.

View File

@ -57,6 +57,7 @@ EXCLUDED_PROPS = [
"_rename_props",
"_valid_children",
"_valid_parents",
"State",
]
DEFAULT_TYPING_IMPORTS = {

View File

@ -0,0 +1,42 @@
"""Ensure that Components returned by ComponentState.create have independent State classes."""
import reflex as rx
from reflex.components.base.bare import Bare
def test_component_state():
"""Create two components with independent state classes."""
class CS(rx.ComponentState):
count: int = 0
def increment(self):
self.count += 1
@classmethod
def get_component(cls, *children, **props):
return rx.el.div(
*children,
**props,
)
cs1, cs2 = CS.create("a", id="a"), CS.create("b", id="b")
assert isinstance(cs1, rx.Component)
assert isinstance(cs2, rx.Component)
assert cs1.State is not None
assert cs2.State is not None
assert cs1.State != cs2.State
assert issubclass(cs1.State, CS)
assert issubclass(cs1.State, rx.State)
assert issubclass(cs2.State, CS)
assert issubclass(cs2.State, rx.State)
assert CS._per_component_state_instance_count == 2
assert isinstance(cs1.State.increment, rx.event.EventHandler)
assert cs1.State.increment != cs2.State.increment
assert len(cs1.children) == 1
assert cs1.children[0].render() == Bare.create("{`a`}").render()
assert cs1.id == "a"
assert len(cs2.children) == 1
assert cs2.children[0].render() == Bare.create("{`b`}").render()
assert cs2.id == "b"