diff --git a/reflex/middleware/hydrate_middleware.py b/reflex/middleware/hydrate_middleware.py index 5d1c419f8..588315ae9 100644 --- a/reflex/middleware/hydrate_middleware.py +++ b/reflex/middleware/hydrate_middleware.py @@ -40,7 +40,7 @@ class HydrateMiddleware(Middleware): setattr(state, constants.IS_HYDRATED, False) delta = format.format_state({state.get_name(): state.dict()}) # since a full dict was captured, clean any dirtiness - state.clean() + state._clean() # Get the route for on_load events. route = event.router_data.get(constants.RouteVar.PATH, "") diff --git a/reflex/state.py b/reflex/state.py index 571a6931d..f5906197e 100644 --- a/reflex/state.py +++ b/reflex/state.py @@ -105,7 +105,6 @@ class State(Base, ABC, extra=pydantic.Extra.allow): # Setup the substates. for substate in self.get_substates(): self.substates[substate.get_name()] = substate(parent_state=self) - # Convert the event handlers to functions. for name, event_handler in self.event_handlers.items(): fn = functools.partial(event_handler.fn, self) @@ -154,7 +153,7 @@ class State(Base, ABC, extra=pydantic.Extra.allow): if types._issubclass(field.type_, Union[List, Dict]): setattr(self, field.name, value_in_rx_data) - self.clean() + self._clean() def _reassign_field(self, field_name: str): """Reassign the given field. @@ -186,6 +185,8 @@ class State(Base, ABC, extra=pydantic.Extra.allow): **kwargs: The kwargs to pass to the pydantic init_subclass method. """ super().__init_subclass__(**kwargs) + # Event handlers should not shadow builtin state methods. + cls._check_overridden_methods() # Get the parent vars. parent_state = cls.get_parent_state() @@ -238,6 +239,29 @@ class State(Base, ABC, extra=pydantic.Extra.allow): cls.event_handlers[name] = handler setattr(cls, name, handler) + @classmethod + def _check_overridden_methods(cls): + """Check for shadow methods and raise error if any. + + Raises: + NameError: When an event handler shadows an inbuilt state method. + """ + overridden_methods = set() + state_base_functions = cls._get_base_functions() + for name, method in inspect.getmembers(cls, inspect.isfunction): + # Check if the method is overridden and not a dunder method + if ( + not name.startswith("__") + and method.__name__ in state_base_functions + and state_base_functions[method.__name__] != method + ): + overridden_methods.add(method.__name__) + + for method_name in overridden_methods: + raise NameError( + f"The event handler name `{method_name}` shadows a builtin State method; use a different name instead" + ) + @classmethod def get_skip_vars(cls) -> Set[str]: """Get the vars to skip when serializing. @@ -444,6 +468,19 @@ class State(Base, ABC, extra=pydantic.Extra.allow): field.required = False field.default = default_value + @staticmethod + def _get_base_functions() -> Dict[str, FunctionType]: + """Get all functions of the state class excluding dunder methods. + + Returns: + The functions of rx.State class as a dict. + """ + return { + func[0]: func[1] + for func in inspect.getmembers(State, predicate=inspect.isfunction) + if not func[0].startswith("__") + } + def get_token(self) -> str: """Return the token of the client associated with this state. @@ -598,7 +635,7 @@ class State(Base, ABC, extra=pydantic.Extra.allow): if types.is_backend_variable(name) and name != "_backend_vars": self._backend_vars.__setitem__(name, value) self.dirty_vars.add(name) - self.mark_dirty() + self._mark_dirty() return # Make sure lists and dicts are converted to ReflexList, ReflexDict and ReflexSet. @@ -611,12 +648,12 @@ class State(Base, ABC, extra=pydantic.Extra.allow): # Add the var to the dirty list. if name in self.vars or name in self.computed_var_dependencies: self.dirty_vars.add(name) - self.mark_dirty() + self._mark_dirty() # For now, handle router_data updates as a special case if name == constants.ROUTER_DATA: self.dirty_vars.add(name) - self.mark_dirty() + self._mark_dirty() # propagate router_data updates down the state tree for substate in self.substates.values(): setattr(substate, name, value) @@ -685,7 +722,7 @@ class State(Base, ABC, extra=pydantic.Extra.allow): ) # Clean the state before processing the event. - self.clean() + self._clean() # Run the event generator and return state updates. async for events, final in event_iter: @@ -699,7 +736,7 @@ class State(Base, ABC, extra=pydantic.Extra.allow): yield StateUpdate(delta=delta, events=events, final=final) # Clean the state to prepare for the next event. - self.clean() + self._clean() async def _process_event( self, handler: EventHandler, state: State, payload: Dict @@ -806,7 +843,7 @@ class State(Base, ABC, extra=pydantic.Extra.allow): # Apply dirty variables down into substates self.dirty_vars.update(self._always_dirty_computed_vars()) - self.mark_dirty() + self._mark_dirty() # Return the dirty vars for this instance, any cached/dependent computed vars, # and always dirty computed vars (cache=False) @@ -835,7 +872,7 @@ class State(Base, ABC, extra=pydantic.Extra.allow): # Return the delta. return delta - def mark_dirty(self): + def _mark_dirty(self): """Mark the substate and all parent states as dirty.""" state_name = self.get_name() if ( @@ -843,7 +880,7 @@ class State(Base, ABC, extra=pydantic.Extra.allow): and state_name not in self.parent_state.dirty_substates ): self.parent_state.dirty_substates.add(self.get_name()) - self.parent_state.mark_dirty() + self.parent_state._mark_dirty() # have to mark computed vars dirty to allow access to newly computed # values within the same ComputedVar function @@ -856,13 +893,13 @@ class State(Base, ABC, extra=pydantic.Extra.allow): self.dirty_substates.add(substate_name) substate = substates[substate_name] substate.dirty_vars.add(var) - substate.mark_dirty() + substate._mark_dirty() - def clean(self): + def _clean(self): """Reset the dirty vars.""" # Recursively clean the substates. for substate in self.dirty_substates: - self.substates[substate].clean() + self.substates[substate]._clean() # Clean this state. self.dirty_vars = set() @@ -882,7 +919,7 @@ class State(Base, ABC, extra=pydantic.Extra.allow): # Apply dirty variables down into substates to allow never-cached ComputedVar to # trigger recalculation of dependent vars self.dirty_vars.update(self._always_dirty_computed_vars()) - self.mark_dirty() + self._mark_dirty() base_vars = { prop_name: self.get_value(getattr(self, prop_name)) diff --git a/reflex/testing.py b/reflex/testing.py index d65e7bf58..fec968e30 100644 --- a/reflex/testing.py +++ b/reflex/testing.py @@ -365,7 +365,7 @@ class AppHarness: delta = state.get_delta() if delta: update = reflex.state.StateUpdate(delta=delta, events=[], final=True) - state.clean() + state._clean() # Emit the event. pending.append( event_ns.emit( diff --git a/tests/test_state.py b/tests/test_state.py index 352ba6156..9288fba2b 100644 --- a/tests/test_state.py +++ b/tests/test_state.py @@ -498,7 +498,7 @@ def test_set_dirty_var(test_state): assert test_state.dirty_vars == {"num1", "num2", "sum"} # Cleaning the state should remove all dirty vars. - test_state.clean() + test_state._clean() assert test_state.dirty_vars == set() @@ -524,7 +524,7 @@ def test_set_dirty_substate(test_state, child_state, child_state2, grandchild_st assert child_state.dirty_substates == set() # Cleaning the parent state should remove the dirty substate. - test_state.clean() + test_state._clean() assert test_state.dirty_substates == set() assert child_state.dirty_vars == set() @@ -534,7 +534,7 @@ def test_set_dirty_substate(test_state, child_state, child_state2, grandchild_st assert test_state.dirty_substates == {"child_state"} # Cleaning the middle state should keep the parent state dirty. - child_state.clean() + child_state._clean() assert test_state.dirty_substates == {"child_state"} assert child_state.dirty_substates == set() assert grandchild_state.dirty_vars == set() @@ -626,7 +626,7 @@ async def test_process_event_substate(test_state, child_state, grandchild_state) "test_state": {"sum": 3.14, "upper": ""}, "test_state.child_state": {"value": "HI", "count": 24}, } - test_state.clean() + test_state._clean() # Test with the granchild state. assert grandchild_state.value2 == "" @@ -1044,23 +1044,23 @@ def test_computed_var_cached_depends_on_non_cached(): cs = ComputedState() assert cs.dirty_vars == set() assert cs.get_delta() == {cs.get_name(): {"no_cache_v": 0, "dep_v": 0}} - cs.clean() + cs._clean() assert cs.dirty_vars == set() assert cs.get_delta() == {cs.get_name(): {"no_cache_v": 0, "dep_v": 0}} - cs.clean() + cs._clean() assert cs.dirty_vars == set() cs.v = 1 assert cs.dirty_vars == {"v", "comp_v", "dep_v", "no_cache_v"} assert cs.get_delta() == { cs.get_name(): {"v": 1, "no_cache_v": 1, "dep_v": 1, "comp_v": 1} } - cs.clean() + cs._clean() assert cs.dirty_vars == set() assert cs.get_delta() == {cs.get_name(): {"no_cache_v": 1, "dep_v": 1}} - cs.clean() + cs._clean() assert cs.dirty_vars == set() assert cs.get_delta() == {cs.get_name(): {"no_cache_v": 1, "dep_v": 1}} - cs.clean() + cs._clean() assert cs.dirty_vars == set() @@ -1191,3 +1191,17 @@ def test_setattr_of_mutable_types(mutable_state): assert isinstance(hashmap["mod_third_key"], ReflexDict) assert isinstance(test_set, ReflexSet) + + +def test_error_on_state_method_shadow(): + """Test that an error is thrown when an event handler shadows a state method.""" + with pytest.raises(NameError) as err: + + class InvalidTest(rx.State): + def reset(self): + pass + + assert ( + err.value.args[0] + == f"The event handler name `reset` shadows a builtin State method; use a different name instead" + )