diff --git a/reflex/components/component.py b/reflex/components/component.py index 2c60f09d4..754344cd2 100644 --- a/reflex/components/component.py +++ b/reflex/components/component.py @@ -7,7 +7,7 @@ import typing from abc import ABC, abstractmethod from functools import lru_cache, wraps from hashlib import md5 -from types import SimpleNamespace, LambdaType +from types import SimpleNamespace from typing import ( Any, Callable, @@ -1113,38 +1113,35 @@ class Component(BaseComponent, ABC): return vars - def _event_trigger_values_use_state(self) -> bool: + """Check if the values of a component's event trigger use state. + + Returns: + True if any of the component's event trigger values uses State. + """ for trigger in self.event_triggers.values(): if isinstance(trigger, EventChain): for event in trigger.events: - if event.handler.state_full_name or isinstance(event.handler.fn, LambdaType) and event.handler.fn.__name__== (lambda: None).__name__: + if event.handler.state_full_name: return True elif isinstance(trigger, Var) and trigger._var_state: return True return False def _has_stateful_event_triggers(self): + """Check if component or children have any event triggers that use state. + Returns: + True if the component or children have any event triggers that uses state. + """ if self.event_triggers and self._event_trigger_values_use_state(): return True else: for child in self.children: - if isinstance(child, Component) and child._has_stateful_event_triggers(): - return True - return False - - def _has_event_triggers(self) -> bool: - """Check if the component or children have any event triggers. - - Returns: - True if the component or children have any event triggers. - """ - if self.event_triggers: - return True - else: - for child in self.children: - if isinstance(child, Component) and child._has_event_triggers(): + if ( + isinstance(child, Component) + and child._has_stateful_event_triggers() + ): return True return False diff --git a/tests/components/test_component.py b/tests/components/test_component.py index 9680af762..2a0d70154 100644 --- a/tests/components/test_component.py +++ b/tests/components/test_component.py @@ -2082,37 +2082,27 @@ class TriggerState(rx.State): """Sample event handler.""" pass + def do_another_thing(self, value): + """Sample event handler with arg.""" + pass + @pytest.mark.parametrize( - "component, exclude_event_trigger_values, output", + "component, output", [ - (rx.box(rx.text("random text")), None, False), - (rx.box(rx.text("random text", on_click=rx.console_log("log"))), None, True), + (rx.box(rx.text("random text")), False), ( rx.box(rx.text("random text", on_click=rx.console_log("log"))), - ["_console"], False, ), ( rx.box( - rx.text("random text", on_click=rx.console_log("log")), + rx.text("random text", on_click=TriggerState.do_something), rx.text( "random text", on_click=BaseVar(_var_name="toggleColorMode", _var_type=EventChain), ), ), - ["_console", "toggleColorMode"], - False, - ), - ( - rx.box( - rx.text("random text", on_click=rx.console_log("log")), - rx.text( - "random text", - on_click=BaseVar(_var_name="toggleColorMode", _var_type=EventChain), - ), - ), - ["_console"], True, ), ( @@ -2123,17 +2113,10 @@ class TriggerState(rx.State): on_click=BaseVar(_var_name="toggleColorMode", _var_type=EventChain), ), ), - ["toggleColorMode"], - True, - ), - ( - rx.box(rx.text("random text", on_click=TriggerState.do_something)), - ["do_something"], False, ), ( rx.box(rx.text("random text", on_click=TriggerState.do_something)), - ["non_existent"], True, ), ( @@ -2143,26 +2126,27 @@ class TriggerState(rx.State): on_click=[rx.console_log("log"), rx.window_alert("alert")], ), ), - ["_console", "_alert"], False, ), -( + ( rx.box( rx.text( "random text", - on_click=lambda x: x, + on_click=[rx.console_log("log"), TriggerState.do_something], ), ), - ["_console", "_alert"], - False, + True, + ), + ( + rx.box( + rx.text( + "random text", + on_click=lambda val: TriggerState.do_another_thing(val), # type: ignore + ), + ), + True, ), - ], ) -def test_has_event_triggers(component, exclude_event_trigger_values, output): - assert ( - component._has_event_triggers( - exclude_event_trigger_values=exclude_event_trigger_values - ) - == output - ) +def test_has_state_event_triggers(component, output): + assert component._has_stateful_event_triggers() == output