[REF-2265] ComponentState: scaffold for copying State per Component instance (#2923)
* [REF-2265] ComponentState: scaffold for copying State per Component instance Define a base ComponentState which can be used to easily create copies of the given State definition (Vars and EventHandlers) that are tied to a particular instance of a Component (returned by get_component) * Define `State` field on `Component` for typing compatibility. This is an Optional field of Type[State] and is populated by ComponentState. * Add integration/test_component_state.py Create two independent counters and increment them separately * Add unit test for ComponentState
This commit is contained in:
parent
f372402ee4
commit
5510eaf820
107
integration/test_component_state.py
Normal file
107
integration/test_component_state.py
Normal file
@ -0,0 +1,107 @@
|
||||
"""Test that per-component state scaffold works and operates independently."""
|
||||
from typing import Generator
|
||||
|
||||
import pytest
|
||||
from selenium.webdriver.common.by import By
|
||||
|
||||
from reflex.testing import AppHarness
|
||||
|
||||
from . import utils
|
||||
|
||||
|
||||
def ComponentStateApp():
|
||||
"""App using per component state."""
|
||||
import reflex as rx
|
||||
|
||||
class MultiCounter(rx.ComponentState):
|
||||
count: int = 0
|
||||
|
||||
def increment(self):
|
||||
self.count += 1
|
||||
|
||||
@classmethod
|
||||
def get_component(cls, *children, **props):
|
||||
return rx.vstack(
|
||||
*children,
|
||||
rx.heading(cls.count, id=f"count-{props.get('id', 'default')}"),
|
||||
rx.button(
|
||||
"Increment",
|
||||
on_click=cls.increment,
|
||||
id=f"button-{props.get('id', 'default')}",
|
||||
),
|
||||
**props,
|
||||
)
|
||||
|
||||
app = rx.App(state=rx.State) # noqa
|
||||
|
||||
@rx.page()
|
||||
def index():
|
||||
mc_a = MultiCounter.create(id="a")
|
||||
mc_b = MultiCounter.create(id="b")
|
||||
assert mc_a.State != mc_b.State
|
||||
return rx.vstack(
|
||||
mc_a,
|
||||
mc_b,
|
||||
rx.button(
|
||||
"Inc A",
|
||||
on_click=mc_a.State.increment, # type: ignore
|
||||
id="inc-a",
|
||||
),
|
||||
)
|
||||
|
||||
|
||||
@pytest.fixture()
|
||||
def component_state_app(tmp_path) -> Generator[AppHarness, None, None]:
|
||||
"""Start ComponentStateApp app at tmp_path via AppHarness.
|
||||
|
||||
Args:
|
||||
tmp_path: pytest tmp_path fixture
|
||||
|
||||
Yields:
|
||||
running AppHarness instance
|
||||
"""
|
||||
with AppHarness.create(
|
||||
root=tmp_path,
|
||||
app_source=ComponentStateApp, # type: ignore
|
||||
) as harness:
|
||||
yield harness
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_component_state_app(component_state_app: AppHarness):
|
||||
"""Increment counters independently.
|
||||
|
||||
Args:
|
||||
component_state_app: harness for ComponentStateApp app
|
||||
"""
|
||||
assert component_state_app.app_instance is not None, "app is not running"
|
||||
driver = component_state_app.frontend()
|
||||
|
||||
ss = utils.SessionStorage(driver)
|
||||
token = AppHarness._poll_for(lambda: ss.get("token") is not None)
|
||||
assert token is not None
|
||||
|
||||
count_a = driver.find_element(By.ID, "count-a")
|
||||
count_b = driver.find_element(By.ID, "count-b")
|
||||
button_a = driver.find_element(By.ID, "button-a")
|
||||
button_b = driver.find_element(By.ID, "button-b")
|
||||
button_inc_a = driver.find_element(By.ID, "inc-a")
|
||||
|
||||
assert count_a.text == "0"
|
||||
|
||||
button_a.click()
|
||||
assert component_state_app.poll_for_content(count_a, exp_not_equal="0") == "1"
|
||||
|
||||
button_a.click()
|
||||
assert component_state_app.poll_for_content(count_a, exp_not_equal="1") == "2"
|
||||
|
||||
button_inc_a.click()
|
||||
assert component_state_app.poll_for_content(count_a, exp_not_equal="2") == "3"
|
||||
|
||||
assert count_b.text == "0"
|
||||
|
||||
button_b.click()
|
||||
assert component_state_app.poll_for_content(count_b, exp_not_equal="0") == "1"
|
||||
|
||||
button_b.click()
|
||||
assert component_state_app.poll_for_content(count_b, exp_not_equal="1") == "2"
|
@ -38,6 +38,8 @@ class LocalStorage:
|
||||
https://stackoverflow.com/a/46361900
|
||||
"""
|
||||
|
||||
storage_key = "localStorage"
|
||||
|
||||
def __init__(self, driver: WebDriver):
|
||||
"""Initialize the class.
|
||||
|
||||
@ -171,3 +173,12 @@ class LocalStorage:
|
||||
An iterator over the items in local storage.
|
||||
"""
|
||||
return iter(self.keys())
|
||||
|
||||
|
||||
class SessionStorage(LocalStorage):
|
||||
"""Class to access session storage.
|
||||
|
||||
https://stackoverflow.com/a/46361900
|
||||
"""
|
||||
|
||||
storage_key = "sessionStorage"
|
||||
|
@ -153,7 +153,14 @@ _MAPPING = {
|
||||
"reflex.model": ["model", "session", "Model"],
|
||||
"reflex.page": ["page"],
|
||||
"reflex.route": ["route"],
|
||||
"reflex.state": ["state", "var", "Cookie", "LocalStorage", "State"],
|
||||
"reflex.state": [
|
||||
"state",
|
||||
"var",
|
||||
"Cookie",
|
||||
"LocalStorage",
|
||||
"ComponentState",
|
||||
"State",
|
||||
],
|
||||
"reflex.style": ["style", "toggle_color_mode"],
|
||||
"reflex.testing": ["testing"],
|
||||
"reflex.utils": ["utils"],
|
||||
|
@ -141,6 +141,7 @@ from reflex import state as state
|
||||
from reflex.state import var as var
|
||||
from reflex.state import Cookie as Cookie
|
||||
from reflex.state import LocalStorage as LocalStorage
|
||||
from reflex.state import ComponentState as ComponentState
|
||||
from reflex.state import State as State
|
||||
from reflex import style as style
|
||||
from reflex.style import toggle_color_mode as toggle_color_mode
|
||||
|
@ -21,6 +21,7 @@ from typing import (
|
||||
Union,
|
||||
)
|
||||
|
||||
import reflex.state
|
||||
from reflex.base import Base
|
||||
from reflex.compiler.templates import STATEFUL_COMPONENT
|
||||
from reflex.components.tags import Tag
|
||||
@ -214,6 +215,9 @@ class Component(BaseComponent, ABC):
|
||||
# When to memoize this component and its children.
|
||||
_memoization_mode: MemoizationMode = MemoizationMode()
|
||||
|
||||
# State class associated with this component instance
|
||||
State: Optional[Type[reflex.state.State]] = None
|
||||
|
||||
@classmethod
|
||||
def __init_subclass__(cls, **kwargs):
|
||||
"""Set default properties.
|
||||
|
@ -15,6 +15,7 @@ from abc import ABC, abstractmethod
|
||||
from collections import defaultdict
|
||||
from types import FunctionType, MethodType
|
||||
from typing import (
|
||||
TYPE_CHECKING,
|
||||
Any,
|
||||
AsyncIterator,
|
||||
Callable,
|
||||
@ -47,6 +48,10 @@ from reflex.utils.exec import is_testing_env
|
||||
from reflex.utils.serializers import SerializedType, serialize, serializer
|
||||
from reflex.vars import BaseVar, ComputedVar, Var, computed_var
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from reflex.components.component import Component
|
||||
|
||||
|
||||
Delta = Dict[str, Any]
|
||||
var = computed_var
|
||||
|
||||
@ -1835,6 +1840,45 @@ class OnLoadInternalState(State):
|
||||
]
|
||||
|
||||
|
||||
class ComponentState(Base):
|
||||
"""The base class for a State that is copied for each Component associated with it."""
|
||||
|
||||
_per_component_state_instance_count: ClassVar[int] = 0
|
||||
|
||||
@classmethod
|
||||
def get_component(cls, *children, **props) -> "Component":
|
||||
"""Get the component instance.
|
||||
|
||||
Args:
|
||||
children: The children of the component.
|
||||
props: The props of the component.
|
||||
|
||||
Raises:
|
||||
NotImplementedError: if the subclass does not override this method.
|
||||
"""
|
||||
raise NotImplementedError(
|
||||
f"{cls.__name__} must implement get_component to return the component instance."
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def create(cls, *children, **props) -> "Component":
|
||||
"""Create a new instance of the Component.
|
||||
|
||||
Args:
|
||||
children: The children of the component.
|
||||
props: The props of the component.
|
||||
|
||||
Returns:
|
||||
A new instance of the Component with an independent copy of the State.
|
||||
"""
|
||||
cls._per_component_state_instance_count += 1
|
||||
state_cls_name = f"{cls.__name__}_n{cls._per_component_state_instance_count}"
|
||||
component_state = type(state_cls_name, (cls, State), {})
|
||||
component = component_state.get_component(*children, **props)
|
||||
component.State = component_state
|
||||
return component
|
||||
|
||||
|
||||
class StateProxy(wrapt.ObjectProxy):
|
||||
"""Proxy of a state instance to control mutability of vars for a background task.
|
||||
|
||||
|
@ -57,6 +57,7 @@ EXCLUDED_PROPS = [
|
||||
"_rename_props",
|
||||
"_valid_children",
|
||||
"_valid_parents",
|
||||
"State",
|
||||
]
|
||||
|
||||
DEFAULT_TYPING_IMPORTS = {
|
||||
|
42
tests/components/test_component_state.py
Normal file
42
tests/components/test_component_state.py
Normal file
@ -0,0 +1,42 @@
|
||||
"""Ensure that Components returned by ComponentState.create have independent State classes."""
|
||||
|
||||
import reflex as rx
|
||||
from reflex.components.base.bare import Bare
|
||||
|
||||
|
||||
def test_component_state():
|
||||
"""Create two components with independent state classes."""
|
||||
|
||||
class CS(rx.ComponentState):
|
||||
count: int = 0
|
||||
|
||||
def increment(self):
|
||||
self.count += 1
|
||||
|
||||
@classmethod
|
||||
def get_component(cls, *children, **props):
|
||||
return rx.el.div(
|
||||
*children,
|
||||
**props,
|
||||
)
|
||||
|
||||
cs1, cs2 = CS.create("a", id="a"), CS.create("b", id="b")
|
||||
assert isinstance(cs1, rx.Component)
|
||||
assert isinstance(cs2, rx.Component)
|
||||
assert cs1.State is not None
|
||||
assert cs2.State is not None
|
||||
assert cs1.State != cs2.State
|
||||
assert issubclass(cs1.State, CS)
|
||||
assert issubclass(cs1.State, rx.State)
|
||||
assert issubclass(cs2.State, CS)
|
||||
assert issubclass(cs2.State, rx.State)
|
||||
assert CS._per_component_state_instance_count == 2
|
||||
assert isinstance(cs1.State.increment, rx.event.EventHandler)
|
||||
assert cs1.State.increment != cs2.State.increment
|
||||
|
||||
assert len(cs1.children) == 1
|
||||
assert cs1.children[0].render() == Bare.create("{`a`}").render()
|
||||
assert cs1.id == "a"
|
||||
assert len(cs2.children) == 1
|
||||
assert cs2.children[0].render() == Bare.create("{`b`}").render()
|
||||
assert cs2.id == "b"
|
Loading…
Reference in New Issue
Block a user