Event Handlers should not shadow base state methods (#1543)

This commit is contained in:
Elijah Ahianyo 2023-08-10 19:47:35 +00:00 committed by GitHub
parent cebc5982f3
commit 2fa087a0fa
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
4 changed files with 76 additions and 25 deletions

View File

@ -40,7 +40,7 @@ class HydrateMiddleware(Middleware):
setattr(state, constants.IS_HYDRATED, False)
delta = format.format_state({state.get_name(): state.dict()})
# since a full dict was captured, clean any dirtiness
state.clean()
state._clean()
# Get the route for on_load events.
route = event.router_data.get(constants.RouteVar.PATH, "")

View File

@ -105,7 +105,6 @@ class State(Base, ABC, extra=pydantic.Extra.allow):
# Setup the substates.
for substate in self.get_substates():
self.substates[substate.get_name()] = substate(parent_state=self)
# Convert the event handlers to functions.
for name, event_handler in self.event_handlers.items():
fn = functools.partial(event_handler.fn, self)
@ -154,7 +153,7 @@ class State(Base, ABC, extra=pydantic.Extra.allow):
if types._issubclass(field.type_, Union[List, Dict]):
setattr(self, field.name, value_in_rx_data)
self.clean()
self._clean()
def _reassign_field(self, field_name: str):
"""Reassign the given field.
@ -186,6 +185,8 @@ class State(Base, ABC, extra=pydantic.Extra.allow):
**kwargs: The kwargs to pass to the pydantic init_subclass method.
"""
super().__init_subclass__(**kwargs)
# Event handlers should not shadow builtin state methods.
cls._check_overridden_methods()
# Get the parent vars.
parent_state = cls.get_parent_state()
@ -238,6 +239,29 @@ class State(Base, ABC, extra=pydantic.Extra.allow):
cls.event_handlers[name] = handler
setattr(cls, name, handler)
@classmethod
def _check_overridden_methods(cls):
"""Check for shadow methods and raise error if any.
Raises:
NameError: When an event handler shadows an inbuilt state method.
"""
overridden_methods = set()
state_base_functions = cls._get_base_functions()
for name, method in inspect.getmembers(cls, inspect.isfunction):
# Check if the method is overridden and not a dunder method
if (
not name.startswith("__")
and method.__name__ in state_base_functions
and state_base_functions[method.__name__] != method
):
overridden_methods.add(method.__name__)
for method_name in overridden_methods:
raise NameError(
f"The event handler name `{method_name}` shadows a builtin State method; use a different name instead"
)
@classmethod
def get_skip_vars(cls) -> Set[str]:
"""Get the vars to skip when serializing.
@ -444,6 +468,19 @@ class State(Base, ABC, extra=pydantic.Extra.allow):
field.required = False
field.default = default_value
@staticmethod
def _get_base_functions() -> Dict[str, FunctionType]:
"""Get all functions of the state class excluding dunder methods.
Returns:
The functions of rx.State class as a dict.
"""
return {
func[0]: func[1]
for func in inspect.getmembers(State, predicate=inspect.isfunction)
if not func[0].startswith("__")
}
def get_token(self) -> str:
"""Return the token of the client associated with this state.
@ -598,7 +635,7 @@ class State(Base, ABC, extra=pydantic.Extra.allow):
if types.is_backend_variable(name) and name != "_backend_vars":
self._backend_vars.__setitem__(name, value)
self.dirty_vars.add(name)
self.mark_dirty()
self._mark_dirty()
return
# Make sure lists and dicts are converted to ReflexList, ReflexDict and ReflexSet.
@ -611,12 +648,12 @@ class State(Base, ABC, extra=pydantic.Extra.allow):
# Add the var to the dirty list.
if name in self.vars or name in self.computed_var_dependencies:
self.dirty_vars.add(name)
self.mark_dirty()
self._mark_dirty()
# For now, handle router_data updates as a special case
if name == constants.ROUTER_DATA:
self.dirty_vars.add(name)
self.mark_dirty()
self._mark_dirty()
# propagate router_data updates down the state tree
for substate in self.substates.values():
setattr(substate, name, value)
@ -685,7 +722,7 @@ class State(Base, ABC, extra=pydantic.Extra.allow):
)
# Clean the state before processing the event.
self.clean()
self._clean()
# Run the event generator and return state updates.
async for events, final in event_iter:
@ -699,7 +736,7 @@ class State(Base, ABC, extra=pydantic.Extra.allow):
yield StateUpdate(delta=delta, events=events, final=final)
# Clean the state to prepare for the next event.
self.clean()
self._clean()
async def _process_event(
self, handler: EventHandler, state: State, payload: Dict
@ -806,7 +843,7 @@ class State(Base, ABC, extra=pydantic.Extra.allow):
# Apply dirty variables down into substates
self.dirty_vars.update(self._always_dirty_computed_vars())
self.mark_dirty()
self._mark_dirty()
# Return the dirty vars for this instance, any cached/dependent computed vars,
# and always dirty computed vars (cache=False)
@ -835,7 +872,7 @@ class State(Base, ABC, extra=pydantic.Extra.allow):
# Return the delta.
return delta
def mark_dirty(self):
def _mark_dirty(self):
"""Mark the substate and all parent states as dirty."""
state_name = self.get_name()
if (
@ -843,7 +880,7 @@ class State(Base, ABC, extra=pydantic.Extra.allow):
and state_name not in self.parent_state.dirty_substates
):
self.parent_state.dirty_substates.add(self.get_name())
self.parent_state.mark_dirty()
self.parent_state._mark_dirty()
# have to mark computed vars dirty to allow access to newly computed
# values within the same ComputedVar function
@ -856,13 +893,13 @@ class State(Base, ABC, extra=pydantic.Extra.allow):
self.dirty_substates.add(substate_name)
substate = substates[substate_name]
substate.dirty_vars.add(var)
substate.mark_dirty()
substate._mark_dirty()
def clean(self):
def _clean(self):
"""Reset the dirty vars."""
# Recursively clean the substates.
for substate in self.dirty_substates:
self.substates[substate].clean()
self.substates[substate]._clean()
# Clean this state.
self.dirty_vars = set()
@ -882,7 +919,7 @@ class State(Base, ABC, extra=pydantic.Extra.allow):
# Apply dirty variables down into substates to allow never-cached ComputedVar to
# trigger recalculation of dependent vars
self.dirty_vars.update(self._always_dirty_computed_vars())
self.mark_dirty()
self._mark_dirty()
base_vars = {
prop_name: self.get_value(getattr(self, prop_name))

View File

@ -365,7 +365,7 @@ class AppHarness:
delta = state.get_delta()
if delta:
update = reflex.state.StateUpdate(delta=delta, events=[], final=True)
state.clean()
state._clean()
# Emit the event.
pending.append(
event_ns.emit(

View File

@ -498,7 +498,7 @@ def test_set_dirty_var(test_state):
assert test_state.dirty_vars == {"num1", "num2", "sum"}
# Cleaning the state should remove all dirty vars.
test_state.clean()
test_state._clean()
assert test_state.dirty_vars == set()
@ -524,7 +524,7 @@ def test_set_dirty_substate(test_state, child_state, child_state2, grandchild_st
assert child_state.dirty_substates == set()
# Cleaning the parent state should remove the dirty substate.
test_state.clean()
test_state._clean()
assert test_state.dirty_substates == set()
assert child_state.dirty_vars == set()
@ -534,7 +534,7 @@ def test_set_dirty_substate(test_state, child_state, child_state2, grandchild_st
assert test_state.dirty_substates == {"child_state"}
# Cleaning the middle state should keep the parent state dirty.
child_state.clean()
child_state._clean()
assert test_state.dirty_substates == {"child_state"}
assert child_state.dirty_substates == set()
assert grandchild_state.dirty_vars == set()
@ -626,7 +626,7 @@ async def test_process_event_substate(test_state, child_state, grandchild_state)
"test_state": {"sum": 3.14, "upper": ""},
"test_state.child_state": {"value": "HI", "count": 24},
}
test_state.clean()
test_state._clean()
# Test with the granchild state.
assert grandchild_state.value2 == ""
@ -1044,23 +1044,23 @@ def test_computed_var_cached_depends_on_non_cached():
cs = ComputedState()
assert cs.dirty_vars == set()
assert cs.get_delta() == {cs.get_name(): {"no_cache_v": 0, "dep_v": 0}}
cs.clean()
cs._clean()
assert cs.dirty_vars == set()
assert cs.get_delta() == {cs.get_name(): {"no_cache_v": 0, "dep_v": 0}}
cs.clean()
cs._clean()
assert cs.dirty_vars == set()
cs.v = 1
assert cs.dirty_vars == {"v", "comp_v", "dep_v", "no_cache_v"}
assert cs.get_delta() == {
cs.get_name(): {"v": 1, "no_cache_v": 1, "dep_v": 1, "comp_v": 1}
}
cs.clean()
cs._clean()
assert cs.dirty_vars == set()
assert cs.get_delta() == {cs.get_name(): {"no_cache_v": 1, "dep_v": 1}}
cs.clean()
cs._clean()
assert cs.dirty_vars == set()
assert cs.get_delta() == {cs.get_name(): {"no_cache_v": 1, "dep_v": 1}}
cs.clean()
cs._clean()
assert cs.dirty_vars == set()
@ -1191,3 +1191,17 @@ def test_setattr_of_mutable_types(mutable_state):
assert isinstance(hashmap["mod_third_key"], ReflexDict)
assert isinstance(test_set, ReflexSet)
def test_error_on_state_method_shadow():
"""Test that an error is thrown when an event handler shadows a state method."""
with pytest.raises(NameError) as err:
class InvalidTest(rx.State):
def reset(self):
pass
assert (
err.value.args[0]
== f"The event handler name `reset` shadows a builtin State method; use a different name instead"
)