Event Handlers should not shadow base state methods (#1543)
This commit is contained in:
parent
cebc5982f3
commit
2fa087a0fa
@ -40,7 +40,7 @@ class HydrateMiddleware(Middleware):
|
||||
setattr(state, constants.IS_HYDRATED, False)
|
||||
delta = format.format_state({state.get_name(): state.dict()})
|
||||
# since a full dict was captured, clean any dirtiness
|
||||
state.clean()
|
||||
state._clean()
|
||||
|
||||
# Get the route for on_load events.
|
||||
route = event.router_data.get(constants.RouteVar.PATH, "")
|
||||
|
@ -105,7 +105,6 @@ class State(Base, ABC, extra=pydantic.Extra.allow):
|
||||
# Setup the substates.
|
||||
for substate in self.get_substates():
|
||||
self.substates[substate.get_name()] = substate(parent_state=self)
|
||||
|
||||
# Convert the event handlers to functions.
|
||||
for name, event_handler in self.event_handlers.items():
|
||||
fn = functools.partial(event_handler.fn, self)
|
||||
@ -154,7 +153,7 @@ class State(Base, ABC, extra=pydantic.Extra.allow):
|
||||
if types._issubclass(field.type_, Union[List, Dict]):
|
||||
setattr(self, field.name, value_in_rx_data)
|
||||
|
||||
self.clean()
|
||||
self._clean()
|
||||
|
||||
def _reassign_field(self, field_name: str):
|
||||
"""Reassign the given field.
|
||||
@ -186,6 +185,8 @@ class State(Base, ABC, extra=pydantic.Extra.allow):
|
||||
**kwargs: The kwargs to pass to the pydantic init_subclass method.
|
||||
"""
|
||||
super().__init_subclass__(**kwargs)
|
||||
# Event handlers should not shadow builtin state methods.
|
||||
cls._check_overridden_methods()
|
||||
|
||||
# Get the parent vars.
|
||||
parent_state = cls.get_parent_state()
|
||||
@ -238,6 +239,29 @@ class State(Base, ABC, extra=pydantic.Extra.allow):
|
||||
cls.event_handlers[name] = handler
|
||||
setattr(cls, name, handler)
|
||||
|
||||
@classmethod
|
||||
def _check_overridden_methods(cls):
|
||||
"""Check for shadow methods and raise error if any.
|
||||
|
||||
Raises:
|
||||
NameError: When an event handler shadows an inbuilt state method.
|
||||
"""
|
||||
overridden_methods = set()
|
||||
state_base_functions = cls._get_base_functions()
|
||||
for name, method in inspect.getmembers(cls, inspect.isfunction):
|
||||
# Check if the method is overridden and not a dunder method
|
||||
if (
|
||||
not name.startswith("__")
|
||||
and method.__name__ in state_base_functions
|
||||
and state_base_functions[method.__name__] != method
|
||||
):
|
||||
overridden_methods.add(method.__name__)
|
||||
|
||||
for method_name in overridden_methods:
|
||||
raise NameError(
|
||||
f"The event handler name `{method_name}` shadows a builtin State method; use a different name instead"
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def get_skip_vars(cls) -> Set[str]:
|
||||
"""Get the vars to skip when serializing.
|
||||
@ -444,6 +468,19 @@ class State(Base, ABC, extra=pydantic.Extra.allow):
|
||||
field.required = False
|
||||
field.default = default_value
|
||||
|
||||
@staticmethod
|
||||
def _get_base_functions() -> Dict[str, FunctionType]:
|
||||
"""Get all functions of the state class excluding dunder methods.
|
||||
|
||||
Returns:
|
||||
The functions of rx.State class as a dict.
|
||||
"""
|
||||
return {
|
||||
func[0]: func[1]
|
||||
for func in inspect.getmembers(State, predicate=inspect.isfunction)
|
||||
if not func[0].startswith("__")
|
||||
}
|
||||
|
||||
def get_token(self) -> str:
|
||||
"""Return the token of the client associated with this state.
|
||||
|
||||
@ -598,7 +635,7 @@ class State(Base, ABC, extra=pydantic.Extra.allow):
|
||||
if types.is_backend_variable(name) and name != "_backend_vars":
|
||||
self._backend_vars.__setitem__(name, value)
|
||||
self.dirty_vars.add(name)
|
||||
self.mark_dirty()
|
||||
self._mark_dirty()
|
||||
return
|
||||
|
||||
# Make sure lists and dicts are converted to ReflexList, ReflexDict and ReflexSet.
|
||||
@ -611,12 +648,12 @@ class State(Base, ABC, extra=pydantic.Extra.allow):
|
||||
# Add the var to the dirty list.
|
||||
if name in self.vars or name in self.computed_var_dependencies:
|
||||
self.dirty_vars.add(name)
|
||||
self.mark_dirty()
|
||||
self._mark_dirty()
|
||||
|
||||
# For now, handle router_data updates as a special case
|
||||
if name == constants.ROUTER_DATA:
|
||||
self.dirty_vars.add(name)
|
||||
self.mark_dirty()
|
||||
self._mark_dirty()
|
||||
# propagate router_data updates down the state tree
|
||||
for substate in self.substates.values():
|
||||
setattr(substate, name, value)
|
||||
@ -685,7 +722,7 @@ class State(Base, ABC, extra=pydantic.Extra.allow):
|
||||
)
|
||||
|
||||
# Clean the state before processing the event.
|
||||
self.clean()
|
||||
self._clean()
|
||||
|
||||
# Run the event generator and return state updates.
|
||||
async for events, final in event_iter:
|
||||
@ -699,7 +736,7 @@ class State(Base, ABC, extra=pydantic.Extra.allow):
|
||||
yield StateUpdate(delta=delta, events=events, final=final)
|
||||
|
||||
# Clean the state to prepare for the next event.
|
||||
self.clean()
|
||||
self._clean()
|
||||
|
||||
async def _process_event(
|
||||
self, handler: EventHandler, state: State, payload: Dict
|
||||
@ -806,7 +843,7 @@ class State(Base, ABC, extra=pydantic.Extra.allow):
|
||||
|
||||
# Apply dirty variables down into substates
|
||||
self.dirty_vars.update(self._always_dirty_computed_vars())
|
||||
self.mark_dirty()
|
||||
self._mark_dirty()
|
||||
|
||||
# Return the dirty vars for this instance, any cached/dependent computed vars,
|
||||
# and always dirty computed vars (cache=False)
|
||||
@ -835,7 +872,7 @@ class State(Base, ABC, extra=pydantic.Extra.allow):
|
||||
# Return the delta.
|
||||
return delta
|
||||
|
||||
def mark_dirty(self):
|
||||
def _mark_dirty(self):
|
||||
"""Mark the substate and all parent states as dirty."""
|
||||
state_name = self.get_name()
|
||||
if (
|
||||
@ -843,7 +880,7 @@ class State(Base, ABC, extra=pydantic.Extra.allow):
|
||||
and state_name not in self.parent_state.dirty_substates
|
||||
):
|
||||
self.parent_state.dirty_substates.add(self.get_name())
|
||||
self.parent_state.mark_dirty()
|
||||
self.parent_state._mark_dirty()
|
||||
|
||||
# have to mark computed vars dirty to allow access to newly computed
|
||||
# values within the same ComputedVar function
|
||||
@ -856,13 +893,13 @@ class State(Base, ABC, extra=pydantic.Extra.allow):
|
||||
self.dirty_substates.add(substate_name)
|
||||
substate = substates[substate_name]
|
||||
substate.dirty_vars.add(var)
|
||||
substate.mark_dirty()
|
||||
substate._mark_dirty()
|
||||
|
||||
def clean(self):
|
||||
def _clean(self):
|
||||
"""Reset the dirty vars."""
|
||||
# Recursively clean the substates.
|
||||
for substate in self.dirty_substates:
|
||||
self.substates[substate].clean()
|
||||
self.substates[substate]._clean()
|
||||
|
||||
# Clean this state.
|
||||
self.dirty_vars = set()
|
||||
@ -882,7 +919,7 @@ class State(Base, ABC, extra=pydantic.Extra.allow):
|
||||
# Apply dirty variables down into substates to allow never-cached ComputedVar to
|
||||
# trigger recalculation of dependent vars
|
||||
self.dirty_vars.update(self._always_dirty_computed_vars())
|
||||
self.mark_dirty()
|
||||
self._mark_dirty()
|
||||
|
||||
base_vars = {
|
||||
prop_name: self.get_value(getattr(self, prop_name))
|
||||
|
@ -365,7 +365,7 @@ class AppHarness:
|
||||
delta = state.get_delta()
|
||||
if delta:
|
||||
update = reflex.state.StateUpdate(delta=delta, events=[], final=True)
|
||||
state.clean()
|
||||
state._clean()
|
||||
# Emit the event.
|
||||
pending.append(
|
||||
event_ns.emit(
|
||||
|
@ -498,7 +498,7 @@ def test_set_dirty_var(test_state):
|
||||
assert test_state.dirty_vars == {"num1", "num2", "sum"}
|
||||
|
||||
# Cleaning the state should remove all dirty vars.
|
||||
test_state.clean()
|
||||
test_state._clean()
|
||||
assert test_state.dirty_vars == set()
|
||||
|
||||
|
||||
@ -524,7 +524,7 @@ def test_set_dirty_substate(test_state, child_state, child_state2, grandchild_st
|
||||
assert child_state.dirty_substates == set()
|
||||
|
||||
# Cleaning the parent state should remove the dirty substate.
|
||||
test_state.clean()
|
||||
test_state._clean()
|
||||
assert test_state.dirty_substates == set()
|
||||
assert child_state.dirty_vars == set()
|
||||
|
||||
@ -534,7 +534,7 @@ def test_set_dirty_substate(test_state, child_state, child_state2, grandchild_st
|
||||
assert test_state.dirty_substates == {"child_state"}
|
||||
|
||||
# Cleaning the middle state should keep the parent state dirty.
|
||||
child_state.clean()
|
||||
child_state._clean()
|
||||
assert test_state.dirty_substates == {"child_state"}
|
||||
assert child_state.dirty_substates == set()
|
||||
assert grandchild_state.dirty_vars == set()
|
||||
@ -626,7 +626,7 @@ async def test_process_event_substate(test_state, child_state, grandchild_state)
|
||||
"test_state": {"sum": 3.14, "upper": ""},
|
||||
"test_state.child_state": {"value": "HI", "count": 24},
|
||||
}
|
||||
test_state.clean()
|
||||
test_state._clean()
|
||||
|
||||
# Test with the granchild state.
|
||||
assert grandchild_state.value2 == ""
|
||||
@ -1044,23 +1044,23 @@ def test_computed_var_cached_depends_on_non_cached():
|
||||
cs = ComputedState()
|
||||
assert cs.dirty_vars == set()
|
||||
assert cs.get_delta() == {cs.get_name(): {"no_cache_v": 0, "dep_v": 0}}
|
||||
cs.clean()
|
||||
cs._clean()
|
||||
assert cs.dirty_vars == set()
|
||||
assert cs.get_delta() == {cs.get_name(): {"no_cache_v": 0, "dep_v": 0}}
|
||||
cs.clean()
|
||||
cs._clean()
|
||||
assert cs.dirty_vars == set()
|
||||
cs.v = 1
|
||||
assert cs.dirty_vars == {"v", "comp_v", "dep_v", "no_cache_v"}
|
||||
assert cs.get_delta() == {
|
||||
cs.get_name(): {"v": 1, "no_cache_v": 1, "dep_v": 1, "comp_v": 1}
|
||||
}
|
||||
cs.clean()
|
||||
cs._clean()
|
||||
assert cs.dirty_vars == set()
|
||||
assert cs.get_delta() == {cs.get_name(): {"no_cache_v": 1, "dep_v": 1}}
|
||||
cs.clean()
|
||||
cs._clean()
|
||||
assert cs.dirty_vars == set()
|
||||
assert cs.get_delta() == {cs.get_name(): {"no_cache_v": 1, "dep_v": 1}}
|
||||
cs.clean()
|
||||
cs._clean()
|
||||
assert cs.dirty_vars == set()
|
||||
|
||||
|
||||
@ -1191,3 +1191,17 @@ def test_setattr_of_mutable_types(mutable_state):
|
||||
assert isinstance(hashmap["mod_third_key"], ReflexDict)
|
||||
|
||||
assert isinstance(test_set, ReflexSet)
|
||||
|
||||
|
||||
def test_error_on_state_method_shadow():
|
||||
"""Test that an error is thrown when an event handler shadows a state method."""
|
||||
with pytest.raises(NameError) as err:
|
||||
|
||||
class InvalidTest(rx.State):
|
||||
def reset(self):
|
||||
pass
|
||||
|
||||
assert (
|
||||
err.value.args[0]
|
||||
== f"The event handler name `reset` shadows a builtin State method; use a different name instead"
|
||||
)
|
||||
|
Loading…
Reference in New Issue
Block a user