Get default for backend var defined in mixin (#4060)
* Get default for backend var defined in mixin If the backend var is defined in a mixin class, it won't appear in `cls.__dict__`, but the value is still retrievable via `getattr` on `cls`. Prefer to use the actual defined default before using `Var.get_default_value()`. If `Var.get_default_value()` fails, set the default to `None` such that the backend var still gets recognized as a backend var when it is used on `self`. ---- Update test_component_state to include backend vars Extra coverage for backend vars with and without defaults, defined in a ComponentState/mixin class. * fix integration test
This commit is contained in:
parent
aa69234b76
commit
5c0518053d
@ -461,10 +461,10 @@ class BaseState(Base, ABC, extra=pydantic.Extra.allow):
|
||||
for name, value in cls.__dict__.items()
|
||||
if types.is_backend_base_variable(name, cls)
|
||||
}
|
||||
# Add annotated backend vars that do not have a default value.
|
||||
# Add annotated backend vars that may not have a default value.
|
||||
new_backend_vars.update(
|
||||
{
|
||||
name: Var("", _var_type=annotation_value).get_default_value()
|
||||
name: cls._get_var_default(name, annotation_value)
|
||||
for name, annotation_value in get_type_hints(cls).items()
|
||||
if name not in new_backend_vars
|
||||
and types.is_backend_base_variable(name, cls)
|
||||
@ -990,6 +990,26 @@ class BaseState(Base, ABC, extra=pydantic.Extra.allow):
|
||||
# Ensure frontend uses null coalescing when accessing.
|
||||
object.__setattr__(prop, "_var_type", Optional[prop._var_type])
|
||||
|
||||
@classmethod
|
||||
def _get_var_default(cls, name: str, annotation_value: Any) -> Any:
|
||||
"""Get the default value of a (backend) var.
|
||||
|
||||
Args:
|
||||
name: The name of the var.
|
||||
annotation_value: The annotation value of the var.
|
||||
|
||||
Returns:
|
||||
The default value of the var or None.
|
||||
"""
|
||||
try:
|
||||
return getattr(cls, name)
|
||||
except AttributeError:
|
||||
try:
|
||||
return Var("", _var_type=annotation_value).get_default_value()
|
||||
except TypeError:
|
||||
pass
|
||||
return None
|
||||
|
||||
@staticmethod
|
||||
def _get_base_functions() -> dict[str, FunctionType]:
|
||||
"""Get all functions of the state class excluding dunder methods.
|
||||
|
@ -5,6 +5,7 @@ from typing import Generator
|
||||
import pytest
|
||||
from selenium.webdriver.common.by import By
|
||||
|
||||
from reflex.state import State, _substate_key
|
||||
from reflex.testing import AppHarness
|
||||
|
||||
from . import utils
|
||||
@ -12,13 +13,21 @@ from . import utils
|
||||
|
||||
def ComponentStateApp():
|
||||
"""App using per component state."""
|
||||
from typing import Generic, TypeVar
|
||||
|
||||
import reflex as rx
|
||||
|
||||
class MultiCounter(rx.ComponentState):
|
||||
E = TypeVar("E")
|
||||
|
||||
class MultiCounter(rx.ComponentState, Generic[E]):
|
||||
count: int = 0
|
||||
_be: E
|
||||
_be_int: int
|
||||
_be_str: str = "42"
|
||||
|
||||
def increment(self):
|
||||
self.count += 1
|
||||
self._be = self.count # type: ignore
|
||||
|
||||
@classmethod
|
||||
def get_component(cls, *children, **props):
|
||||
@ -48,6 +57,14 @@ def ComponentStateApp():
|
||||
on_click=mc_a.State.increment, # type: ignore
|
||||
id="inc-a",
|
||||
),
|
||||
rx.text(
|
||||
mc_a.State.get_name() if mc_a.State is not None else "",
|
||||
id="a_state_name",
|
||||
),
|
||||
rx.text(
|
||||
mc_b.State.get_name() if mc_b.State is not None else "",
|
||||
id="b_state_name",
|
||||
),
|
||||
)
|
||||
|
||||
|
||||
@ -80,6 +97,7 @@ async def test_component_state_app(component_state_app: AppHarness):
|
||||
|
||||
ss = utils.SessionStorage(driver)
|
||||
assert AppHarness._poll_for(lambda: ss.get("token") is not None), "token not found"
|
||||
root_state_token = _substate_key(ss.get("token"), State)
|
||||
|
||||
count_a = driver.find_element(By.ID, "count-a")
|
||||
count_b = driver.find_element(By.ID, "count-b")
|
||||
@ -87,6 +105,18 @@ async def test_component_state_app(component_state_app: AppHarness):
|
||||
button_b = driver.find_element(By.ID, "button-b")
|
||||
button_inc_a = driver.find_element(By.ID, "inc-a")
|
||||
|
||||
# Check that backend vars in mixins are okay
|
||||
a_state_name = driver.find_element(By.ID, "a_state_name").text
|
||||
b_state_name = driver.find_element(By.ID, "b_state_name").text
|
||||
root_state = await component_state_app.get_state(root_state_token)
|
||||
a_state = root_state.substates[a_state_name]
|
||||
b_state = root_state.substates[b_state_name]
|
||||
assert a_state._backend_vars == a_state.backend_vars
|
||||
assert a_state._backend_vars == b_state._backend_vars
|
||||
assert a_state._backend_vars["_be"] is None
|
||||
assert a_state._backend_vars["_be_int"] == 0
|
||||
assert a_state._backend_vars["_be_str"] == "42"
|
||||
|
||||
assert count_a.text == "0"
|
||||
|
||||
button_a.click()
|
||||
@ -98,6 +128,14 @@ async def test_component_state_app(component_state_app: AppHarness):
|
||||
button_inc_a.click()
|
||||
assert component_state_app.poll_for_content(count_a, exp_not_equal="2") == "3"
|
||||
|
||||
root_state = await component_state_app.get_state(root_state_token)
|
||||
a_state = root_state.substates[a_state_name]
|
||||
b_state = root_state.substates[b_state_name]
|
||||
assert a_state._backend_vars != a_state.backend_vars
|
||||
assert a_state._be == a_state._backend_vars["_be"] == 3
|
||||
assert b_state._be is None
|
||||
assert b_state._backend_vars["_be"] is None
|
||||
|
||||
assert count_b.text == "0"
|
||||
|
||||
button_b.click()
|
||||
@ -105,3 +143,9 @@ async def test_component_state_app(component_state_app: AppHarness):
|
||||
|
||||
button_b.click()
|
||||
assert component_state_app.poll_for_content(count_b, exp_not_equal="1") == "2"
|
||||
|
||||
root_state = await component_state_app.get_state(root_state_token)
|
||||
a_state = root_state.substates[a_state_name]
|
||||
b_state = root_state.substates[b_state_name]
|
||||
assert b_state._backend_vars != b_state.backend_vars
|
||||
assert b_state._be == b_state._backend_vars["_be"] == 2
|
||||
|
Loading…
Reference in New Issue
Block a user