[REF-2265] ComponentState: scaffold for copying State per Component instance (#2923)
* [REF-2265] ComponentState: scaffold for copying State per Component instance Define a base ComponentState which can be used to easily create copies of the given State definition (Vars and EventHandlers) that are tied to a particular instance of a Component (returned by get_component) * Define `State` field on `Component` for typing compatibility. This is an Optional field of Type[State] and is populated by ComponentState. * Add integration/test_component_state.py Create two independent counters and increment them separately * Add unit test for ComponentState
This commit is contained in:
parent
f372402ee4
commit
5510eaf820
107
integration/test_component_state.py
Normal file
107
integration/test_component_state.py
Normal file
@ -0,0 +1,107 @@
|
|||||||
|
"""Test that per-component state scaffold works and operates independently."""
|
||||||
|
from typing import Generator
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
from selenium.webdriver.common.by import By
|
||||||
|
|
||||||
|
from reflex.testing import AppHarness
|
||||||
|
|
||||||
|
from . import utils
|
||||||
|
|
||||||
|
|
||||||
|
def ComponentStateApp():
|
||||||
|
"""App using per component state."""
|
||||||
|
import reflex as rx
|
||||||
|
|
||||||
|
class MultiCounter(rx.ComponentState):
|
||||||
|
count: int = 0
|
||||||
|
|
||||||
|
def increment(self):
|
||||||
|
self.count += 1
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def get_component(cls, *children, **props):
|
||||||
|
return rx.vstack(
|
||||||
|
*children,
|
||||||
|
rx.heading(cls.count, id=f"count-{props.get('id', 'default')}"),
|
||||||
|
rx.button(
|
||||||
|
"Increment",
|
||||||
|
on_click=cls.increment,
|
||||||
|
id=f"button-{props.get('id', 'default')}",
|
||||||
|
),
|
||||||
|
**props,
|
||||||
|
)
|
||||||
|
|
||||||
|
app = rx.App(state=rx.State) # noqa
|
||||||
|
|
||||||
|
@rx.page()
|
||||||
|
def index():
|
||||||
|
mc_a = MultiCounter.create(id="a")
|
||||||
|
mc_b = MultiCounter.create(id="b")
|
||||||
|
assert mc_a.State != mc_b.State
|
||||||
|
return rx.vstack(
|
||||||
|
mc_a,
|
||||||
|
mc_b,
|
||||||
|
rx.button(
|
||||||
|
"Inc A",
|
||||||
|
on_click=mc_a.State.increment, # type: ignore
|
||||||
|
id="inc-a",
|
||||||
|
),
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture()
|
||||||
|
def component_state_app(tmp_path) -> Generator[AppHarness, None, None]:
|
||||||
|
"""Start ComponentStateApp app at tmp_path via AppHarness.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
tmp_path: pytest tmp_path fixture
|
||||||
|
|
||||||
|
Yields:
|
||||||
|
running AppHarness instance
|
||||||
|
"""
|
||||||
|
with AppHarness.create(
|
||||||
|
root=tmp_path,
|
||||||
|
app_source=ComponentStateApp, # type: ignore
|
||||||
|
) as harness:
|
||||||
|
yield harness
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_component_state_app(component_state_app: AppHarness):
|
||||||
|
"""Increment counters independently.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
component_state_app: harness for ComponentStateApp app
|
||||||
|
"""
|
||||||
|
assert component_state_app.app_instance is not None, "app is not running"
|
||||||
|
driver = component_state_app.frontend()
|
||||||
|
|
||||||
|
ss = utils.SessionStorage(driver)
|
||||||
|
token = AppHarness._poll_for(lambda: ss.get("token") is not None)
|
||||||
|
assert token is not None
|
||||||
|
|
||||||
|
count_a = driver.find_element(By.ID, "count-a")
|
||||||
|
count_b = driver.find_element(By.ID, "count-b")
|
||||||
|
button_a = driver.find_element(By.ID, "button-a")
|
||||||
|
button_b = driver.find_element(By.ID, "button-b")
|
||||||
|
button_inc_a = driver.find_element(By.ID, "inc-a")
|
||||||
|
|
||||||
|
assert count_a.text == "0"
|
||||||
|
|
||||||
|
button_a.click()
|
||||||
|
assert component_state_app.poll_for_content(count_a, exp_not_equal="0") == "1"
|
||||||
|
|
||||||
|
button_a.click()
|
||||||
|
assert component_state_app.poll_for_content(count_a, exp_not_equal="1") == "2"
|
||||||
|
|
||||||
|
button_inc_a.click()
|
||||||
|
assert component_state_app.poll_for_content(count_a, exp_not_equal="2") == "3"
|
||||||
|
|
||||||
|
assert count_b.text == "0"
|
||||||
|
|
||||||
|
button_b.click()
|
||||||
|
assert component_state_app.poll_for_content(count_b, exp_not_equal="0") == "1"
|
||||||
|
|
||||||
|
button_b.click()
|
||||||
|
assert component_state_app.poll_for_content(count_b, exp_not_equal="1") == "2"
|
@ -38,6 +38,8 @@ class LocalStorage:
|
|||||||
https://stackoverflow.com/a/46361900
|
https://stackoverflow.com/a/46361900
|
||||||
"""
|
"""
|
||||||
|
|
||||||
|
storage_key = "localStorage"
|
||||||
|
|
||||||
def __init__(self, driver: WebDriver):
|
def __init__(self, driver: WebDriver):
|
||||||
"""Initialize the class.
|
"""Initialize the class.
|
||||||
|
|
||||||
@ -171,3 +173,12 @@ class LocalStorage:
|
|||||||
An iterator over the items in local storage.
|
An iterator over the items in local storage.
|
||||||
"""
|
"""
|
||||||
return iter(self.keys())
|
return iter(self.keys())
|
||||||
|
|
||||||
|
|
||||||
|
class SessionStorage(LocalStorage):
|
||||||
|
"""Class to access session storage.
|
||||||
|
|
||||||
|
https://stackoverflow.com/a/46361900
|
||||||
|
"""
|
||||||
|
|
||||||
|
storage_key = "sessionStorage"
|
||||||
|
@ -153,7 +153,14 @@ _MAPPING = {
|
|||||||
"reflex.model": ["model", "session", "Model"],
|
"reflex.model": ["model", "session", "Model"],
|
||||||
"reflex.page": ["page"],
|
"reflex.page": ["page"],
|
||||||
"reflex.route": ["route"],
|
"reflex.route": ["route"],
|
||||||
"reflex.state": ["state", "var", "Cookie", "LocalStorage", "State"],
|
"reflex.state": [
|
||||||
|
"state",
|
||||||
|
"var",
|
||||||
|
"Cookie",
|
||||||
|
"LocalStorage",
|
||||||
|
"ComponentState",
|
||||||
|
"State",
|
||||||
|
],
|
||||||
"reflex.style": ["style", "toggle_color_mode"],
|
"reflex.style": ["style", "toggle_color_mode"],
|
||||||
"reflex.testing": ["testing"],
|
"reflex.testing": ["testing"],
|
||||||
"reflex.utils": ["utils"],
|
"reflex.utils": ["utils"],
|
||||||
|
@ -141,6 +141,7 @@ from reflex import state as state
|
|||||||
from reflex.state import var as var
|
from reflex.state import var as var
|
||||||
from reflex.state import Cookie as Cookie
|
from reflex.state import Cookie as Cookie
|
||||||
from reflex.state import LocalStorage as LocalStorage
|
from reflex.state import LocalStorage as LocalStorage
|
||||||
|
from reflex.state import ComponentState as ComponentState
|
||||||
from reflex.state import State as State
|
from reflex.state import State as State
|
||||||
from reflex import style as style
|
from reflex import style as style
|
||||||
from reflex.style import toggle_color_mode as toggle_color_mode
|
from reflex.style import toggle_color_mode as toggle_color_mode
|
||||||
|
@ -21,6 +21,7 @@ from typing import (
|
|||||||
Union,
|
Union,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
import reflex.state
|
||||||
from reflex.base import Base
|
from reflex.base import Base
|
||||||
from reflex.compiler.templates import STATEFUL_COMPONENT
|
from reflex.compiler.templates import STATEFUL_COMPONENT
|
||||||
from reflex.components.tags import Tag
|
from reflex.components.tags import Tag
|
||||||
@ -214,6 +215,9 @@ class Component(BaseComponent, ABC):
|
|||||||
# When to memoize this component and its children.
|
# When to memoize this component and its children.
|
||||||
_memoization_mode: MemoizationMode = MemoizationMode()
|
_memoization_mode: MemoizationMode = MemoizationMode()
|
||||||
|
|
||||||
|
# State class associated with this component instance
|
||||||
|
State: Optional[Type[reflex.state.State]] = None
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def __init_subclass__(cls, **kwargs):
|
def __init_subclass__(cls, **kwargs):
|
||||||
"""Set default properties.
|
"""Set default properties.
|
||||||
|
@ -15,6 +15,7 @@ from abc import ABC, abstractmethod
|
|||||||
from collections import defaultdict
|
from collections import defaultdict
|
||||||
from types import FunctionType, MethodType
|
from types import FunctionType, MethodType
|
||||||
from typing import (
|
from typing import (
|
||||||
|
TYPE_CHECKING,
|
||||||
Any,
|
Any,
|
||||||
AsyncIterator,
|
AsyncIterator,
|
||||||
Callable,
|
Callable,
|
||||||
@ -47,6 +48,10 @@ from reflex.utils.exec import is_testing_env
|
|||||||
from reflex.utils.serializers import SerializedType, serialize, serializer
|
from reflex.utils.serializers import SerializedType, serialize, serializer
|
||||||
from reflex.vars import BaseVar, ComputedVar, Var, computed_var
|
from reflex.vars import BaseVar, ComputedVar, Var, computed_var
|
||||||
|
|
||||||
|
if TYPE_CHECKING:
|
||||||
|
from reflex.components.component import Component
|
||||||
|
|
||||||
|
|
||||||
Delta = Dict[str, Any]
|
Delta = Dict[str, Any]
|
||||||
var = computed_var
|
var = computed_var
|
||||||
|
|
||||||
@ -1835,6 +1840,45 @@ class OnLoadInternalState(State):
|
|||||||
]
|
]
|
||||||
|
|
||||||
|
|
||||||
|
class ComponentState(Base):
|
||||||
|
"""The base class for a State that is copied for each Component associated with it."""
|
||||||
|
|
||||||
|
_per_component_state_instance_count: ClassVar[int] = 0
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def get_component(cls, *children, **props) -> "Component":
|
||||||
|
"""Get the component instance.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
children: The children of the component.
|
||||||
|
props: The props of the component.
|
||||||
|
|
||||||
|
Raises:
|
||||||
|
NotImplementedError: if the subclass does not override this method.
|
||||||
|
"""
|
||||||
|
raise NotImplementedError(
|
||||||
|
f"{cls.__name__} must implement get_component to return the component instance."
|
||||||
|
)
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def create(cls, *children, **props) -> "Component":
|
||||||
|
"""Create a new instance of the Component.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
children: The children of the component.
|
||||||
|
props: The props of the component.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
A new instance of the Component with an independent copy of the State.
|
||||||
|
"""
|
||||||
|
cls._per_component_state_instance_count += 1
|
||||||
|
state_cls_name = f"{cls.__name__}_n{cls._per_component_state_instance_count}"
|
||||||
|
component_state = type(state_cls_name, (cls, State), {})
|
||||||
|
component = component_state.get_component(*children, **props)
|
||||||
|
component.State = component_state
|
||||||
|
return component
|
||||||
|
|
||||||
|
|
||||||
class StateProxy(wrapt.ObjectProxy):
|
class StateProxy(wrapt.ObjectProxy):
|
||||||
"""Proxy of a state instance to control mutability of vars for a background task.
|
"""Proxy of a state instance to control mutability of vars for a background task.
|
||||||
|
|
||||||
|
@ -57,6 +57,7 @@ EXCLUDED_PROPS = [
|
|||||||
"_rename_props",
|
"_rename_props",
|
||||||
"_valid_children",
|
"_valid_children",
|
||||||
"_valid_parents",
|
"_valid_parents",
|
||||||
|
"State",
|
||||||
]
|
]
|
||||||
|
|
||||||
DEFAULT_TYPING_IMPORTS = {
|
DEFAULT_TYPING_IMPORTS = {
|
||||||
|
42
tests/components/test_component_state.py
Normal file
42
tests/components/test_component_state.py
Normal file
@ -0,0 +1,42 @@
|
|||||||
|
"""Ensure that Components returned by ComponentState.create have independent State classes."""
|
||||||
|
|
||||||
|
import reflex as rx
|
||||||
|
from reflex.components.base.bare import Bare
|
||||||
|
|
||||||
|
|
||||||
|
def test_component_state():
|
||||||
|
"""Create two components with independent state classes."""
|
||||||
|
|
||||||
|
class CS(rx.ComponentState):
|
||||||
|
count: int = 0
|
||||||
|
|
||||||
|
def increment(self):
|
||||||
|
self.count += 1
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def get_component(cls, *children, **props):
|
||||||
|
return rx.el.div(
|
||||||
|
*children,
|
||||||
|
**props,
|
||||||
|
)
|
||||||
|
|
||||||
|
cs1, cs2 = CS.create("a", id="a"), CS.create("b", id="b")
|
||||||
|
assert isinstance(cs1, rx.Component)
|
||||||
|
assert isinstance(cs2, rx.Component)
|
||||||
|
assert cs1.State is not None
|
||||||
|
assert cs2.State is not None
|
||||||
|
assert cs1.State != cs2.State
|
||||||
|
assert issubclass(cs1.State, CS)
|
||||||
|
assert issubclass(cs1.State, rx.State)
|
||||||
|
assert issubclass(cs2.State, CS)
|
||||||
|
assert issubclass(cs2.State, rx.State)
|
||||||
|
assert CS._per_component_state_instance_count == 2
|
||||||
|
assert isinstance(cs1.State.increment, rx.event.EventHandler)
|
||||||
|
assert cs1.State.increment != cs2.State.increment
|
||||||
|
|
||||||
|
assert len(cs1.children) == 1
|
||||||
|
assert cs1.children[0].render() == Bare.create("{`a`}").render()
|
||||||
|
assert cs1.id == "a"
|
||||||
|
assert len(cs2.children) == 1
|
||||||
|
assert cs2.children[0].render() == Bare.create("{`b`}").render()
|
||||||
|
assert cs2.id == "b"
|
Loading…
Reference in New Issue
Block a user