Event Handlers should not shadow base state methods (#1543)
This commit is contained in:
parent
cebc5982f3
commit
2fa087a0fa
@ -40,7 +40,7 @@ class HydrateMiddleware(Middleware):
|
|||||||
setattr(state, constants.IS_HYDRATED, False)
|
setattr(state, constants.IS_HYDRATED, False)
|
||||||
delta = format.format_state({state.get_name(): state.dict()})
|
delta = format.format_state({state.get_name(): state.dict()})
|
||||||
# since a full dict was captured, clean any dirtiness
|
# since a full dict was captured, clean any dirtiness
|
||||||
state.clean()
|
state._clean()
|
||||||
|
|
||||||
# Get the route for on_load events.
|
# Get the route for on_load events.
|
||||||
route = event.router_data.get(constants.RouteVar.PATH, "")
|
route = event.router_data.get(constants.RouteVar.PATH, "")
|
||||||
|
@ -105,7 +105,6 @@ class State(Base, ABC, extra=pydantic.Extra.allow):
|
|||||||
# Setup the substates.
|
# Setup the substates.
|
||||||
for substate in self.get_substates():
|
for substate in self.get_substates():
|
||||||
self.substates[substate.get_name()] = substate(parent_state=self)
|
self.substates[substate.get_name()] = substate(parent_state=self)
|
||||||
|
|
||||||
# Convert the event handlers to functions.
|
# Convert the event handlers to functions.
|
||||||
for name, event_handler in self.event_handlers.items():
|
for name, event_handler in self.event_handlers.items():
|
||||||
fn = functools.partial(event_handler.fn, self)
|
fn = functools.partial(event_handler.fn, self)
|
||||||
@ -154,7 +153,7 @@ class State(Base, ABC, extra=pydantic.Extra.allow):
|
|||||||
if types._issubclass(field.type_, Union[List, Dict]):
|
if types._issubclass(field.type_, Union[List, Dict]):
|
||||||
setattr(self, field.name, value_in_rx_data)
|
setattr(self, field.name, value_in_rx_data)
|
||||||
|
|
||||||
self.clean()
|
self._clean()
|
||||||
|
|
||||||
def _reassign_field(self, field_name: str):
|
def _reassign_field(self, field_name: str):
|
||||||
"""Reassign the given field.
|
"""Reassign the given field.
|
||||||
@ -186,6 +185,8 @@ class State(Base, ABC, extra=pydantic.Extra.allow):
|
|||||||
**kwargs: The kwargs to pass to the pydantic init_subclass method.
|
**kwargs: The kwargs to pass to the pydantic init_subclass method.
|
||||||
"""
|
"""
|
||||||
super().__init_subclass__(**kwargs)
|
super().__init_subclass__(**kwargs)
|
||||||
|
# Event handlers should not shadow builtin state methods.
|
||||||
|
cls._check_overridden_methods()
|
||||||
|
|
||||||
# Get the parent vars.
|
# Get the parent vars.
|
||||||
parent_state = cls.get_parent_state()
|
parent_state = cls.get_parent_state()
|
||||||
@ -238,6 +239,29 @@ class State(Base, ABC, extra=pydantic.Extra.allow):
|
|||||||
cls.event_handlers[name] = handler
|
cls.event_handlers[name] = handler
|
||||||
setattr(cls, name, handler)
|
setattr(cls, name, handler)
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def _check_overridden_methods(cls):
|
||||||
|
"""Check for shadow methods and raise error if any.
|
||||||
|
|
||||||
|
Raises:
|
||||||
|
NameError: When an event handler shadows an inbuilt state method.
|
||||||
|
"""
|
||||||
|
overridden_methods = set()
|
||||||
|
state_base_functions = cls._get_base_functions()
|
||||||
|
for name, method in inspect.getmembers(cls, inspect.isfunction):
|
||||||
|
# Check if the method is overridden and not a dunder method
|
||||||
|
if (
|
||||||
|
not name.startswith("__")
|
||||||
|
and method.__name__ in state_base_functions
|
||||||
|
and state_base_functions[method.__name__] != method
|
||||||
|
):
|
||||||
|
overridden_methods.add(method.__name__)
|
||||||
|
|
||||||
|
for method_name in overridden_methods:
|
||||||
|
raise NameError(
|
||||||
|
f"The event handler name `{method_name}` shadows a builtin State method; use a different name instead"
|
||||||
|
)
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def get_skip_vars(cls) -> Set[str]:
|
def get_skip_vars(cls) -> Set[str]:
|
||||||
"""Get the vars to skip when serializing.
|
"""Get the vars to skip when serializing.
|
||||||
@ -444,6 +468,19 @@ class State(Base, ABC, extra=pydantic.Extra.allow):
|
|||||||
field.required = False
|
field.required = False
|
||||||
field.default = default_value
|
field.default = default_value
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def _get_base_functions() -> Dict[str, FunctionType]:
|
||||||
|
"""Get all functions of the state class excluding dunder methods.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
The functions of rx.State class as a dict.
|
||||||
|
"""
|
||||||
|
return {
|
||||||
|
func[0]: func[1]
|
||||||
|
for func in inspect.getmembers(State, predicate=inspect.isfunction)
|
||||||
|
if not func[0].startswith("__")
|
||||||
|
}
|
||||||
|
|
||||||
def get_token(self) -> str:
|
def get_token(self) -> str:
|
||||||
"""Return the token of the client associated with this state.
|
"""Return the token of the client associated with this state.
|
||||||
|
|
||||||
@ -598,7 +635,7 @@ class State(Base, ABC, extra=pydantic.Extra.allow):
|
|||||||
if types.is_backend_variable(name) and name != "_backend_vars":
|
if types.is_backend_variable(name) and name != "_backend_vars":
|
||||||
self._backend_vars.__setitem__(name, value)
|
self._backend_vars.__setitem__(name, value)
|
||||||
self.dirty_vars.add(name)
|
self.dirty_vars.add(name)
|
||||||
self.mark_dirty()
|
self._mark_dirty()
|
||||||
return
|
return
|
||||||
|
|
||||||
# Make sure lists and dicts are converted to ReflexList, ReflexDict and ReflexSet.
|
# Make sure lists and dicts are converted to ReflexList, ReflexDict and ReflexSet.
|
||||||
@ -611,12 +648,12 @@ class State(Base, ABC, extra=pydantic.Extra.allow):
|
|||||||
# Add the var to the dirty list.
|
# Add the var to the dirty list.
|
||||||
if name in self.vars or name in self.computed_var_dependencies:
|
if name in self.vars or name in self.computed_var_dependencies:
|
||||||
self.dirty_vars.add(name)
|
self.dirty_vars.add(name)
|
||||||
self.mark_dirty()
|
self._mark_dirty()
|
||||||
|
|
||||||
# For now, handle router_data updates as a special case
|
# For now, handle router_data updates as a special case
|
||||||
if name == constants.ROUTER_DATA:
|
if name == constants.ROUTER_DATA:
|
||||||
self.dirty_vars.add(name)
|
self.dirty_vars.add(name)
|
||||||
self.mark_dirty()
|
self._mark_dirty()
|
||||||
# propagate router_data updates down the state tree
|
# propagate router_data updates down the state tree
|
||||||
for substate in self.substates.values():
|
for substate in self.substates.values():
|
||||||
setattr(substate, name, value)
|
setattr(substate, name, value)
|
||||||
@ -685,7 +722,7 @@ class State(Base, ABC, extra=pydantic.Extra.allow):
|
|||||||
)
|
)
|
||||||
|
|
||||||
# Clean the state before processing the event.
|
# Clean the state before processing the event.
|
||||||
self.clean()
|
self._clean()
|
||||||
|
|
||||||
# Run the event generator and return state updates.
|
# Run the event generator and return state updates.
|
||||||
async for events, final in event_iter:
|
async for events, final in event_iter:
|
||||||
@ -699,7 +736,7 @@ class State(Base, ABC, extra=pydantic.Extra.allow):
|
|||||||
yield StateUpdate(delta=delta, events=events, final=final)
|
yield StateUpdate(delta=delta, events=events, final=final)
|
||||||
|
|
||||||
# Clean the state to prepare for the next event.
|
# Clean the state to prepare for the next event.
|
||||||
self.clean()
|
self._clean()
|
||||||
|
|
||||||
async def _process_event(
|
async def _process_event(
|
||||||
self, handler: EventHandler, state: State, payload: Dict
|
self, handler: EventHandler, state: State, payload: Dict
|
||||||
@ -806,7 +843,7 @@ class State(Base, ABC, extra=pydantic.Extra.allow):
|
|||||||
|
|
||||||
# Apply dirty variables down into substates
|
# Apply dirty variables down into substates
|
||||||
self.dirty_vars.update(self._always_dirty_computed_vars())
|
self.dirty_vars.update(self._always_dirty_computed_vars())
|
||||||
self.mark_dirty()
|
self._mark_dirty()
|
||||||
|
|
||||||
# Return the dirty vars for this instance, any cached/dependent computed vars,
|
# Return the dirty vars for this instance, any cached/dependent computed vars,
|
||||||
# and always dirty computed vars (cache=False)
|
# and always dirty computed vars (cache=False)
|
||||||
@ -835,7 +872,7 @@ class State(Base, ABC, extra=pydantic.Extra.allow):
|
|||||||
# Return the delta.
|
# Return the delta.
|
||||||
return delta
|
return delta
|
||||||
|
|
||||||
def mark_dirty(self):
|
def _mark_dirty(self):
|
||||||
"""Mark the substate and all parent states as dirty."""
|
"""Mark the substate and all parent states as dirty."""
|
||||||
state_name = self.get_name()
|
state_name = self.get_name()
|
||||||
if (
|
if (
|
||||||
@ -843,7 +880,7 @@ class State(Base, ABC, extra=pydantic.Extra.allow):
|
|||||||
and state_name not in self.parent_state.dirty_substates
|
and state_name not in self.parent_state.dirty_substates
|
||||||
):
|
):
|
||||||
self.parent_state.dirty_substates.add(self.get_name())
|
self.parent_state.dirty_substates.add(self.get_name())
|
||||||
self.parent_state.mark_dirty()
|
self.parent_state._mark_dirty()
|
||||||
|
|
||||||
# have to mark computed vars dirty to allow access to newly computed
|
# have to mark computed vars dirty to allow access to newly computed
|
||||||
# values within the same ComputedVar function
|
# values within the same ComputedVar function
|
||||||
@ -856,13 +893,13 @@ class State(Base, ABC, extra=pydantic.Extra.allow):
|
|||||||
self.dirty_substates.add(substate_name)
|
self.dirty_substates.add(substate_name)
|
||||||
substate = substates[substate_name]
|
substate = substates[substate_name]
|
||||||
substate.dirty_vars.add(var)
|
substate.dirty_vars.add(var)
|
||||||
substate.mark_dirty()
|
substate._mark_dirty()
|
||||||
|
|
||||||
def clean(self):
|
def _clean(self):
|
||||||
"""Reset the dirty vars."""
|
"""Reset the dirty vars."""
|
||||||
# Recursively clean the substates.
|
# Recursively clean the substates.
|
||||||
for substate in self.dirty_substates:
|
for substate in self.dirty_substates:
|
||||||
self.substates[substate].clean()
|
self.substates[substate]._clean()
|
||||||
|
|
||||||
# Clean this state.
|
# Clean this state.
|
||||||
self.dirty_vars = set()
|
self.dirty_vars = set()
|
||||||
@ -882,7 +919,7 @@ class State(Base, ABC, extra=pydantic.Extra.allow):
|
|||||||
# Apply dirty variables down into substates to allow never-cached ComputedVar to
|
# Apply dirty variables down into substates to allow never-cached ComputedVar to
|
||||||
# trigger recalculation of dependent vars
|
# trigger recalculation of dependent vars
|
||||||
self.dirty_vars.update(self._always_dirty_computed_vars())
|
self.dirty_vars.update(self._always_dirty_computed_vars())
|
||||||
self.mark_dirty()
|
self._mark_dirty()
|
||||||
|
|
||||||
base_vars = {
|
base_vars = {
|
||||||
prop_name: self.get_value(getattr(self, prop_name))
|
prop_name: self.get_value(getattr(self, prop_name))
|
||||||
|
@ -365,7 +365,7 @@ class AppHarness:
|
|||||||
delta = state.get_delta()
|
delta = state.get_delta()
|
||||||
if delta:
|
if delta:
|
||||||
update = reflex.state.StateUpdate(delta=delta, events=[], final=True)
|
update = reflex.state.StateUpdate(delta=delta, events=[], final=True)
|
||||||
state.clean()
|
state._clean()
|
||||||
# Emit the event.
|
# Emit the event.
|
||||||
pending.append(
|
pending.append(
|
||||||
event_ns.emit(
|
event_ns.emit(
|
||||||
|
@ -498,7 +498,7 @@ def test_set_dirty_var(test_state):
|
|||||||
assert test_state.dirty_vars == {"num1", "num2", "sum"}
|
assert test_state.dirty_vars == {"num1", "num2", "sum"}
|
||||||
|
|
||||||
# Cleaning the state should remove all dirty vars.
|
# Cleaning the state should remove all dirty vars.
|
||||||
test_state.clean()
|
test_state._clean()
|
||||||
assert test_state.dirty_vars == set()
|
assert test_state.dirty_vars == set()
|
||||||
|
|
||||||
|
|
||||||
@ -524,7 +524,7 @@ def test_set_dirty_substate(test_state, child_state, child_state2, grandchild_st
|
|||||||
assert child_state.dirty_substates == set()
|
assert child_state.dirty_substates == set()
|
||||||
|
|
||||||
# Cleaning the parent state should remove the dirty substate.
|
# Cleaning the parent state should remove the dirty substate.
|
||||||
test_state.clean()
|
test_state._clean()
|
||||||
assert test_state.dirty_substates == set()
|
assert test_state.dirty_substates == set()
|
||||||
assert child_state.dirty_vars == set()
|
assert child_state.dirty_vars == set()
|
||||||
|
|
||||||
@ -534,7 +534,7 @@ def test_set_dirty_substate(test_state, child_state, child_state2, grandchild_st
|
|||||||
assert test_state.dirty_substates == {"child_state"}
|
assert test_state.dirty_substates == {"child_state"}
|
||||||
|
|
||||||
# Cleaning the middle state should keep the parent state dirty.
|
# Cleaning the middle state should keep the parent state dirty.
|
||||||
child_state.clean()
|
child_state._clean()
|
||||||
assert test_state.dirty_substates == {"child_state"}
|
assert test_state.dirty_substates == {"child_state"}
|
||||||
assert child_state.dirty_substates == set()
|
assert child_state.dirty_substates == set()
|
||||||
assert grandchild_state.dirty_vars == set()
|
assert grandchild_state.dirty_vars == set()
|
||||||
@ -626,7 +626,7 @@ async def test_process_event_substate(test_state, child_state, grandchild_state)
|
|||||||
"test_state": {"sum": 3.14, "upper": ""},
|
"test_state": {"sum": 3.14, "upper": ""},
|
||||||
"test_state.child_state": {"value": "HI", "count": 24},
|
"test_state.child_state": {"value": "HI", "count": 24},
|
||||||
}
|
}
|
||||||
test_state.clean()
|
test_state._clean()
|
||||||
|
|
||||||
# Test with the granchild state.
|
# Test with the granchild state.
|
||||||
assert grandchild_state.value2 == ""
|
assert grandchild_state.value2 == ""
|
||||||
@ -1044,23 +1044,23 @@ def test_computed_var_cached_depends_on_non_cached():
|
|||||||
cs = ComputedState()
|
cs = ComputedState()
|
||||||
assert cs.dirty_vars == set()
|
assert cs.dirty_vars == set()
|
||||||
assert cs.get_delta() == {cs.get_name(): {"no_cache_v": 0, "dep_v": 0}}
|
assert cs.get_delta() == {cs.get_name(): {"no_cache_v": 0, "dep_v": 0}}
|
||||||
cs.clean()
|
cs._clean()
|
||||||
assert cs.dirty_vars == set()
|
assert cs.dirty_vars == set()
|
||||||
assert cs.get_delta() == {cs.get_name(): {"no_cache_v": 0, "dep_v": 0}}
|
assert cs.get_delta() == {cs.get_name(): {"no_cache_v": 0, "dep_v": 0}}
|
||||||
cs.clean()
|
cs._clean()
|
||||||
assert cs.dirty_vars == set()
|
assert cs.dirty_vars == set()
|
||||||
cs.v = 1
|
cs.v = 1
|
||||||
assert cs.dirty_vars == {"v", "comp_v", "dep_v", "no_cache_v"}
|
assert cs.dirty_vars == {"v", "comp_v", "dep_v", "no_cache_v"}
|
||||||
assert cs.get_delta() == {
|
assert cs.get_delta() == {
|
||||||
cs.get_name(): {"v": 1, "no_cache_v": 1, "dep_v": 1, "comp_v": 1}
|
cs.get_name(): {"v": 1, "no_cache_v": 1, "dep_v": 1, "comp_v": 1}
|
||||||
}
|
}
|
||||||
cs.clean()
|
cs._clean()
|
||||||
assert cs.dirty_vars == set()
|
assert cs.dirty_vars == set()
|
||||||
assert cs.get_delta() == {cs.get_name(): {"no_cache_v": 1, "dep_v": 1}}
|
assert cs.get_delta() == {cs.get_name(): {"no_cache_v": 1, "dep_v": 1}}
|
||||||
cs.clean()
|
cs._clean()
|
||||||
assert cs.dirty_vars == set()
|
assert cs.dirty_vars == set()
|
||||||
assert cs.get_delta() == {cs.get_name(): {"no_cache_v": 1, "dep_v": 1}}
|
assert cs.get_delta() == {cs.get_name(): {"no_cache_v": 1, "dep_v": 1}}
|
||||||
cs.clean()
|
cs._clean()
|
||||||
assert cs.dirty_vars == set()
|
assert cs.dirty_vars == set()
|
||||||
|
|
||||||
|
|
||||||
@ -1191,3 +1191,17 @@ def test_setattr_of_mutable_types(mutable_state):
|
|||||||
assert isinstance(hashmap["mod_third_key"], ReflexDict)
|
assert isinstance(hashmap["mod_third_key"], ReflexDict)
|
||||||
|
|
||||||
assert isinstance(test_set, ReflexSet)
|
assert isinstance(test_set, ReflexSet)
|
||||||
|
|
||||||
|
|
||||||
|
def test_error_on_state_method_shadow():
|
||||||
|
"""Test that an error is thrown when an event handler shadows a state method."""
|
||||||
|
with pytest.raises(NameError) as err:
|
||||||
|
|
||||||
|
class InvalidTest(rx.State):
|
||||||
|
def reset(self):
|
||||||
|
pass
|
||||||
|
|
||||||
|
assert (
|
||||||
|
err.value.args[0]
|
||||||
|
== f"The event handler name `reset` shadows a builtin State method; use a different name instead"
|
||||||
|
)
|
||||||
|
Loading…
Reference in New Issue
Block a user