diff --git a/integration/test_state_inheritance.py b/integration/test_state_inheritance.py index 9ceafa228..476efa466 100644 --- a/integration/test_state_inheritance.py +++ b/integration/test_state_inheritance.py @@ -1,5 +1,6 @@ """Test state inheritance.""" +import time from typing import Generator import pytest @@ -8,30 +9,57 @@ from selenium.webdriver.common.by import By from reflex.testing import DEFAULT_TIMEOUT, AppHarness, WebDriver +def raises_alert(driver: WebDriver, element: str) -> None: + """Click an element and check that an alert is raised. + + Args: + driver: WebDriver instance. + element: The element to click. + """ + btn = driver.find_element(By.ID, element) + btn.click() + time.sleep(0.2) # wait for the alert to appear + alert = driver.switch_to.alert + assert alert.text == "clicked" + alert.accept() + + def StateInheritance(): """Test that state inheritance works as expected.""" import reflex as rx class ChildMixin: - child_mixin: str = "child_mixin" + # mixin basevars only work with pydantic/rx.Base models + # child_mixin: str = "child_mixin" @rx.var def computed_child_mixin(self) -> str: return "computed_child_mixin" class Mixin(ChildMixin): - mixin: str = "mixin" + # mixin basevars only work with pydantic/rx.Base models + # mixin: str = "mixin" @rx.var def computed_mixin(self) -> str: return "computed_mixin" + def on_click_mixin(self): + return rx.call_script("alert('clicked')") + class OtherMixin(rx.Base): other_mixin: str = "other_mixin" + other_mixin_clicks: int = 0 @rx.var def computed_other_mixin(self) -> str: - return "computed_other_mixin" + return self.other_mixin + + def on_click_other_mixin(self): + self.other_mixin_clicks += 1 + self.other_mixin = ( + f"{self.__class__.__name__}.clicked.{self.other_mixin_clicks}" + ) class Base1(rx.State, Mixin): base1: str = "base1" @@ -65,30 +93,49 @@ def StateInheritance(): rx.input( id="token", value=Base1.router.session.client_token, is_read_only=True ), + # Base 1 rx.heading(Base1.computed_mixin, id="base1-computed_mixin"), rx.heading(Base1.computed_basevar, id="base1-computed_basevar"), rx.heading(Base1.computed_child_mixin, id="base1-child-mixin"), rx.heading(Base1.base1, id="base1-base1"), - rx.heading(Base1.mixin, id="base1-mixin"), - rx.heading(Base1.child_mixin, id="base1-child_mixin"), + rx.button( + "Base1.on_click_mixin", + on_click=Base1.on_click_mixin, # type: ignore + id="base1-mixin-btn", + ), + # Base 2 rx.heading(Base2.computed_basevar, id="base2-computed_basevar"), rx.heading(Base2.base2, id="base2-base2"), + # Child 1 rx.heading(Child1.computed_basevar, id="child1-computed_basevar"), rx.heading(Child1.computed_mixin, id="child1-computed_mixin"), rx.heading(Child1.computed_other_mixin, id="child1-other-mixin"), rx.heading(Child1.computed_child_mixin, id="child1-child-mixin"), rx.heading(Child1.base1, id="child1-base1"), - rx.heading(Child1.mixin, id="child1-mixin"), rx.heading(Child1.other_mixin, id="child1-other_mixin"), - rx.heading(Child1.child_mixin, id="child1-child_mixin"), + rx.button( + "Child1.on_click_other_mixin", + on_click=Child1.on_click_other_mixin, # type: ignore + id="child1-other-mixin-btn", + ), + # Child 2 rx.heading(Child2.computed_basevar, id="child2-computed_basevar"), rx.heading(Child2.computed_mixin, id="child2-computed_mixin"), rx.heading(Child2.computed_other_mixin, id="child2-other-mixin"), rx.heading(Child2.computed_child_mixin, id="child2-child-mixin"), rx.heading(Child2.base2, id="child2-base2"), - rx.heading(Child2.mixin, id="child2-mixin"), rx.heading(Child2.other_mixin, id="child2-other_mixin"), - rx.heading(Child2.child_mixin, id="child2-child_mixin"), + rx.button( + "Child2.on_click_mixin", + on_click=Child2.on_click_mixin, # type: ignore + id="child2-mixin-btn", + ), + rx.button( + "Child2.on_click_other_mixin", + on_click=Child2.on_click_other_mixin, # type: ignore + id="child2-other-mixin-btn", + ), + # Child 3 rx.heading(Child3.computed_basevar, id="child3-computed_basevar"), rx.heading(Child3.computed_mixin, id="child3-computed_mixin"), rx.heading(Child3.computed_other_mixin, id="child3-other-mixin"), @@ -96,9 +143,17 @@ def StateInheritance(): rx.heading(Child3.computed_child_mixin, id="child3-child-mixin"), rx.heading(Child3.child3, id="child3-child3"), rx.heading(Child3.base2, id="child3-base2"), - rx.heading(Child3.mixin, id="child3-mixin"), rx.heading(Child3.other_mixin, id="child3-other_mixin"), - rx.heading(Child3.child_mixin, id="child3-child_mixin"), + rx.button( + "Child3.on_click_mixin", + on_click=Child3.on_click_mixin, # type: ignore + id="child3-mixin-btn", + ), + rx.button( + "Child3.on_click_other_mixin", + on_click=Child3.on_click_other_mixin, # type: ignore + id="child3-other-mixin-btn", + ), ) app = rx.App() @@ -178,6 +233,8 @@ def test_state_inheritance( """ assert state_inheritance.app_instance is not None + # Initial State values Test + # Base 1 base1_mixin = driver.find_element(By.ID, "base1-computed_mixin") assert base1_mixin.text == "computed_mixin" @@ -190,18 +247,14 @@ def test_state_inheritance( base1_base1 = driver.find_element(By.ID, "base1-base1") assert base1_base1.text == "base1" - base1_mixin = driver.find_element(By.ID, "base1-mixin") - assert base1_mixin.text == "mixin" - - base1_child_mixin = driver.find_element(By.ID, "base1-child_mixin") - assert base1_child_mixin.text == "child_mixin" - + # Base 2 base2_computed_basevar = driver.find_element(By.ID, "base2-computed_basevar") assert base2_computed_basevar.text == "computed_basevar2" base2_base2 = driver.find_element(By.ID, "base2-base2") assert base2_base2.text == "base2" + # Child 1 child1_computed_basevar = driver.find_element(By.ID, "child1-computed_basevar") assert child1_computed_basevar.text == "computed_basevar1" @@ -209,7 +262,7 @@ def test_state_inheritance( assert child1_mixin.text == "computed_mixin" child1_computed_other_mixin = driver.find_element(By.ID, "child1-other-mixin") - assert child1_computed_other_mixin.text == "computed_other_mixin" + assert child1_computed_other_mixin.text == "other_mixin" child1_computed_child_mixin = driver.find_element(By.ID, "child1-child-mixin") assert child1_computed_child_mixin.text == "computed_child_mixin" @@ -217,15 +270,10 @@ def test_state_inheritance( child1_base1 = driver.find_element(By.ID, "child1-base1") assert child1_base1.text == "base1" - child1_mixin = driver.find_element(By.ID, "child1-mixin") - assert child1_mixin.text == "mixin" - child1_other_mixin = driver.find_element(By.ID, "child1-other_mixin") assert child1_other_mixin.text == "other_mixin" - child1_child_mixin = driver.find_element(By.ID, "child1-child_mixin") - assert child1_child_mixin.text == "child_mixin" - + # Child 2 child2_computed_basevar = driver.find_element(By.ID, "child2-computed_basevar") assert child2_computed_basevar.text == "computed_basevar2" @@ -233,7 +281,7 @@ def test_state_inheritance( assert child2_mixin.text == "computed_mixin" child2_computed_other_mixin = driver.find_element(By.ID, "child2-other-mixin") - assert child2_computed_other_mixin.text == "computed_other_mixin" + assert child2_computed_other_mixin.text == "other_mixin" child2_computed_child_mixin = driver.find_element(By.ID, "child2-child-mixin") assert child2_computed_child_mixin.text == "computed_child_mixin" @@ -241,15 +289,10 @@ def test_state_inheritance( child2_base2 = driver.find_element(By.ID, "child2-base2") assert child2_base2.text == "base2" - child2_mixin = driver.find_element(By.ID, "child2-mixin") - assert child2_mixin.text == "mixin" - child2_other_mixin = driver.find_element(By.ID, "child2-other_mixin") assert child2_other_mixin.text == "other_mixin" - child2_child_mixin = driver.find_element(By.ID, "child2-child_mixin") - assert child2_child_mixin.text == "child_mixin" - + # Child 3 child3_computed_basevar = driver.find_element(By.ID, "child3-computed_basevar") assert child3_computed_basevar.text == "computed_basevar2" @@ -257,7 +300,7 @@ def test_state_inheritance( assert child3_mixin.text == "computed_mixin" child3_computed_other_mixin = driver.find_element(By.ID, "child3-other-mixin") - assert child3_computed_other_mixin.text == "computed_other_mixin" + assert child3_computed_other_mixin.text == "other_mixin" child3_computed_childvar = driver.find_element(By.ID, "child3-computed_childvar") assert child3_computed_childvar.text == "computed_childvar" @@ -271,11 +314,59 @@ def test_state_inheritance( child3_base2 = driver.find_element(By.ID, "child3-base2") assert child3_base2.text == "base2" - child3_mixin = driver.find_element(By.ID, "child3-mixin") - assert child3_mixin.text == "mixin" - child3_other_mixin = driver.find_element(By.ID, "child3-other_mixin") assert child3_other_mixin.text == "other_mixin" - child3_child_mixin = driver.find_element(By.ID, "child3-child_mixin") - assert child3_child_mixin.text == "child_mixin" + # Event Handler Tests + raises_alert(driver, "base1-mixin-btn") + raises_alert(driver, "child2-mixin-btn") + raises_alert(driver, "child3-mixin-btn") + + child1_other_mixin_btn = driver.find_element(By.ID, "child1-other-mixin-btn") + child1_other_mixin_btn.click() + child1_other_mixin_value = state_inheritance.poll_for_content( + child1_other_mixin, exp_not_equal="other_mixin" + ) + child1_computed_mixin_value = state_inheritance.poll_for_content( + child1_computed_other_mixin, exp_not_equal="other_mixin" + ) + assert child1_other_mixin_value == "Child1.clicked.1" + assert child1_computed_mixin_value == "Child1.clicked.1" + + child2_other_mixin_btn = driver.find_element(By.ID, "child2-other-mixin-btn") + child2_other_mixin_btn.click() + child2_other_mixin_value = state_inheritance.poll_for_content( + child2_other_mixin, exp_not_equal="other_mixin" + ) + child2_computed_mixin_value = state_inheritance.poll_for_content( + child2_computed_other_mixin, exp_not_equal="other_mixin" + ) + child3_other_mixin_value = state_inheritance.poll_for_content( + child3_other_mixin, exp_not_equal="other_mixin" + ) + child3_computed_mixin_value = state_inheritance.poll_for_content( + child3_computed_other_mixin, exp_not_equal="other_mixin" + ) + assert child2_other_mixin_value == "Child2.clicked.1" + assert child2_computed_mixin_value == "Child2.clicked.1" + assert child3_other_mixin_value == "Child2.clicked.1" + assert child3_computed_mixin_value == "Child2.clicked.1" + + child3_other_mixin_btn = driver.find_element(By.ID, "child3-other-mixin-btn") + child3_other_mixin_btn.click() + child2_other_mixin_value = state_inheritance.poll_for_content( + child2_other_mixin, exp_not_equal="other_mixin" + ) + child2_computed_mixin_value = state_inheritance.poll_for_content( + child2_computed_other_mixin, exp_not_equal="other_mixin" + ) + child3_other_mixin_value = state_inheritance.poll_for_content( + child3_other_mixin, exp_not_equal="other_mixin" + ) + child3_computed_mixin_value = state_inheritance.poll_for_content( + child3_computed_other_mixin, exp_not_equal="other_mixin" + ) + assert child2_other_mixin_value == "Child2.clicked.2" + assert child2_computed_mixin_value == "Child2.clicked.2" + assert child3_other_mixin.text == "Child2.clicked.2" + assert child3_computed_other_mixin.text == "Child2.clicked.2" diff --git a/reflex/state.py b/reflex/state.py index cc7b153d9..51b925fe0 100644 --- a/reflex/state.py +++ b/reflex/state.py @@ -332,9 +332,7 @@ class BaseState(Base, ABC, extra=pydantic.Extra.allow): } cls.computed_vars = { v._var_name: v._var_set_state(cls) - for mixin in cls.__mro__ - if mixin is cls or not issubclass(mixin, (BaseState, ABC)) - for v in mixin.__dict__.values() + for v in cls.__dict__.values() if isinstance(v, ComputedVar) } cls.vars = { @@ -352,10 +350,29 @@ class BaseState(Base, ABC, extra=pydantic.Extra.allow): events = { name: fn for name, fn in cls.__dict__.items() - if not name.startswith("_") - and isinstance(fn, Callable) - and not isinstance(fn, EventHandler) + if cls._item_is_event_handler(name, fn) } + + for mixin in cls._mixins(): + for name, value in mixin.__dict__.items(): + if isinstance(value, ComputedVar): + fget = cls._copy_fn(value.fget) + newcv = ComputedVar(fget=fget, _var_name=value._var_name) + newcv._var_set_state(cls) + setattr(cls, name, newcv) + cls.computed_vars[newcv._var_name] = newcv + cls.vars[newcv._var_name] = newcv + continue + if events.get(name) is not None: + continue + if not cls._item_is_event_handler(name, value): + continue + if parent_state is not None and parent_state.event_handlers.get(name): + continue + value = cls._copy_fn(value) + value.__qualname__ = f"{cls.__name__}.{name}" + events[name] = value + for name, fn in events.items(): handler = EventHandler(fn=fn) cls.event_handlers[name] = handler @@ -363,6 +380,58 @@ class BaseState(Base, ABC, extra=pydantic.Extra.allow): cls._init_var_dependency_dicts() + @staticmethod + def _copy_fn(fn: Callable) -> Callable: + """Copy a function. Used to copy ComputedVars and EventHandlers from mixins. + + Args: + fn: The function to copy. + + Returns: + The copied function. + """ + newfn = FunctionType( + fn.__code__, + fn.__globals__, + name=fn.__name__, + argdefs=fn.__defaults__, + closure=fn.__closure__, + ) + newfn.__annotations__ = fn.__annotations__ + return newfn + + @staticmethod + def _item_is_event_handler(name: str, value: Any) -> bool: + """Check if the item is an event handler. + + Args: + name: The name of the item. + value: The value of the item. + + Returns: + Whether the item is an event handler. + """ + return ( + not name.startswith("_") + and isinstance(value, Callable) + and not isinstance(value, EventHandler) + and hasattr(value, "__code__") + ) + + @classmethod + def _mixins(cls) -> List[Type]: + """Get the mixin classes of the state. + + Returns: + The mixin classes of the state. + """ + return [ + mixin + for mixin in cls.__mro__ + if not issubclass(mixin, (BaseState, ABC)) + and mixin not in [pydantic.BaseModel, Base] + ] + @classmethod def _init_var_dependency_dicts(cls): """Initialize the var dependency tracking dicts. diff --git a/reflex/vars.pyi b/reflex/vars.pyi index 6ec1bd987..fc5d7e100 100644 --- a/reflex/vars.pyi +++ b/reflex/vars.pyi @@ -19,6 +19,7 @@ from typing import ( Set, Type, Union, + overload, _GenericAlias, # type: ignore ) @@ -136,6 +137,16 @@ class ComputedVar(Var): def _deps(self, objclass: Type, obj: Optional[FunctionType] = ...) -> Set[str]: ... def mark_dirty(self, instance) -> None: ... def _determine_var_type(self) -> Type: ... + @overload + def __init__( + self, + fget: Callable[[BaseState], Any], + fset: Callable[[BaseState, Any], None] | None = None, + fdel: Callable[[BaseState], Any] | None = None, + doc: str | None = None, + **kwargs, + ) -> None: ... + @overload def __init__(self, func) -> None: ... def cached_var(fget: Callable[[Any], Any]) -> ComputedVar: ...