diff --git a/integration/test_event_actions.py b/integration/test_event_actions.py new file mode 100644 index 000000000..d446b6663 --- /dev/null +++ b/integration/test_event_actions.py @@ -0,0 +1,234 @@ +"""Ensure stopPropagation and preventDefault work as expected.""" + +from typing import Callable, Coroutine, Generator + +import pytest +from selenium.webdriver.common.by import By + +from reflex.testing import AppHarness, WebDriver + + +def TestEventAction(): + """App for testing event_actions.""" + import reflex as rx + + class EventActionState(rx.State): + order: list[str] + + def on_click(self, ev): + self.order.append(f"on_click:{ev}") + + def on_click2(self): + self.order.append("on_click2") + + @rx.var + def token(self) -> str: + return self.get_token() + + def index(): + return rx.vstack( + rx.input(value=EventActionState.token, is_read_only=True, id="token"), + rx.button("No events", id="btn-no-events"), + rx.button( + "Stop Prop Only", + id="btn-stop-prop-only", + on_click=rx.stop_propagation, # type: ignore + ), + rx.button( + "Click event", + on_click=EventActionState.on_click("no_event_actions"), # type: ignore + id="btn-click-event", + ), + rx.button( + "Click stop propagation", + on_click=EventActionState.on_click("stop_propagation").stop_propagation, # type: ignore + id="btn-click-stop-propagation", + ), + rx.button( + "Click stop propagation2", + on_click=EventActionState.on_click2.stop_propagation, + id="btn-click-stop-propagation2", + ), + rx.button( + "Click event 2", + on_click=EventActionState.on_click2, + id="btn-click-event2", + ), + rx.link( + "Link", + href="#", + on_click=EventActionState.on_click("link_no_event_actions"), # type: ignore + id="link", + ), + rx.link( + "Link Stop Propagation", + href="#", + on_click=EventActionState.on_click( # type: ignore + "link_stop_propagation" + ).stop_propagation, + id="link-stop-propagation", + ), + rx.link( + "Link Prevent Default Only", + href="/invalid", + on_click=rx.prevent_default, # type: ignore + id="link-prevent-default-only", + ), + rx.link( + "Link Prevent Default", + href="/invalid", + on_click=EventActionState.on_click( # type: ignore + "link_prevent_default" + ).prevent_default, + id="link-prevent-default", + ), + rx.link( + "Link Both", + href="/invalid", + on_click=EventActionState.on_click( # type: ignore + "link_both" + ).stop_propagation.prevent_default, + id="link-stop-propagation-prevent-default", + ), + rx.list( + rx.foreach( + EventActionState.order, # type: ignore + rx.list_item, + ), + ), + on_click=EventActionState.on_click("outer"), # type: ignore + ) + + app = rx.App(state=EventActionState) + app.add_page(index) + app.compile() + + +@pytest.fixture(scope="session") +def event_action(tmp_path_factory) -> Generator[AppHarness, None, None]: + """Start TestEventAction app at tmp_path via AppHarness. + + Args: + tmp_path_factory: pytest tmp_path_factory fixture + + Yields: + running AppHarness instance + """ + with AppHarness.create( + root=tmp_path_factory.mktemp(f"event_action"), + app_source=TestEventAction, # type: ignore + ) as harness: + yield harness + + +@pytest.fixture +def driver(event_action: AppHarness) -> Generator[WebDriver, None, None]: + """Get an instance of the browser open to the event_action app. + + Args: + event_action: harness for TestEventAction app + + Yields: + WebDriver instance. + """ + assert event_action.app_instance is not None, "app is not running" + driver = event_action.frontend() + try: + yield driver + finally: + driver.quit() + + +@pytest.fixture() +def token(event_action: AppHarness, driver: WebDriver) -> str: + """Get the token associated with backend state. + + Args: + event_action: harness for TestEventAction app. + driver: WebDriver instance. + + Returns: + The token visible in the driver browser. + """ + assert event_action.app_instance is not None + token_input = driver.find_element(By.ID, "token") + assert token_input + + # wait for the backend connection to send the token + token = event_action.poll_for_value(token_input) + assert token is not None + + return token + + +@pytest.fixture() +def poll_for_order( + event_action: AppHarness, token: str +) -> Callable[[list[str]], Coroutine[None, None, None]]: + """Poll for the order list to match the expected order. + + Args: + event_action: harness for TestEventAction app. + token: The token visible in the driver browser. + + Returns: + An async function that polls for the order list to match the expected order. + """ + + async def _poll_for_order(exp_order: list[str]): + async def _backend_state(): + return await event_action.get_state(token) + + async def _check(): + return (await _backend_state()).order == exp_order + + await AppHarness._poll_for_async(_check) + assert (await _backend_state()).order == exp_order + + return _poll_for_order + + +@pytest.mark.parametrize( + ("element_id", "exp_order"), + [ + ("btn-no-events", ["on_click:outer"]), + ("btn-stop-prop-only", []), + ("btn-click-event", ["on_click:no_event_actions", "on_click:outer"]), + ("btn-click-stop-propagation", ["on_click:stop_propagation"]), + ("btn-click-stop-propagation2", ["on_click2"]), + ("btn-click-event2", ["on_click2", "on_click:outer"]), + ("link", ["on_click:link_no_event_actions", "on_click:outer"]), + ("link-stop-propagation", ["on_click:link_stop_propagation"]), + ("link-prevent-default", ["on_click:link_prevent_default", "on_click:outer"]), + ("link-prevent-default-only", ["on_click:outer"]), + ("link-stop-propagation-prevent-default", ["on_click:link_both"]), + ], +) +@pytest.mark.usefixtures("token") +@pytest.mark.asyncio +async def test_event_actions( + driver: WebDriver, + poll_for_order: Callable[[list[str]], Coroutine[None, None, None]], + element_id: str, + exp_order: list[str], +): + """Click links and buttons and assert on fired events. + + Args: + driver: WebDriver instance. + poll_for_order: function that polls for the order list to match the expected order. + element_id: The id of the element to click. + exp_order: The expected order of events. + """ + el = driver.find_element(By.ID, element_id) + assert el + + prev_url = driver.current_url + + el.click() + await poll_for_order(exp_order) + + if element_id.startswith("link") and "prevent-default" not in element_id: + assert driver.current_url != prev_url + else: + assert driver.current_url == prev_url diff --git a/integration/test_form_submit.py b/integration/test_form_submit.py index ebaec955b..25296b3d4 100644 --- a/integration/test_form_submit.py +++ b/integration/test_form_submit.py @@ -53,6 +53,7 @@ def FormSubmit(): rx.button("Submit", type_="submit"), ), on_submit=FormState.form_submit, + custom_attrs={"action": "/invalid"}, ), rx.spacer(), height="100vh", @@ -145,6 +146,8 @@ async def test_submit(driver, form_submit: AppHarness): time.sleep(1) + prev_url = driver.current_url + submit_input = driver.find_element(By.CLASS_NAME, "chakra-button") submit_input.click() @@ -166,3 +169,6 @@ async def test_submit(driver, form_submit: AppHarness): assert form_data["select_input"] == "option1" assert form_data["text_area_input"] == "Some\nText" assert form_data["debounce_input"] == "bar baz" + + # submitting the form should NOT change the url (preventDefault on_submit event) + assert driver.current_url == prev_url diff --git a/reflex/.templates/web/utils/state.js b/reflex/.templates/web/utils/state.js index 1c0c63774..44ca56050 100644 --- a/reflex/.templates/web/utils/state.js +++ b/reflex/.templates/web/utils/state.js @@ -486,8 +486,13 @@ export const useEventLoop = ( const [connectError, setConnectError] = useState(null) // Function to add new events to the event queue. - const addEvents = (events, _e) => { - preventDefault(_e); + const addEvents = (events, _e, event_actions) => { + if (event_actions?.preventDefault && _e) { + _e.preventDefault(); + } + if (event_actions?.stopPropagation && _e) { + _e.stopPropagation(); + } queueEvents(events, socket) } @@ -532,16 +537,6 @@ export const isTrue = (val) => { return Array.isArray(val) ? val.length > 0 : !!val; }; -/** - * Prevent the default event for form submission. - * @param event - */ -export const preventDefault = (event) => { - if (event && event.type == "submit") { - event.preventDefault(); - } -}; - /** * Get the value from a ref. * @param ref The ref to get the value from. diff --git a/reflex/__init__.py b/reflex/__init__.py index 34f9e9b0f..017d7b7ec 100644 --- a/reflex/__init__.py +++ b/reflex/__init__.py @@ -24,12 +24,14 @@ from .event import call_script as call_script from .event import clear_local_storage as clear_local_storage from .event import console_log as console_log from .event import download as download +from .event import prevent_default as prevent_default from .event import redirect as redirect from .event import remove_cookie as remove_cookie from .event import remove_local_storage as remove_local_storage from .event import set_clipboard as set_clipboard from .event import set_focus as set_focus from .event import set_value as set_value +from .event import stop_propagation as stop_propagation from .event import window_alert as window_alert from .middleware import Middleware as Middleware from .model import Model as Model diff --git a/reflex/components/component.py b/reflex/components/component.py index ee9c99ad4..654263f9c 100644 --- a/reflex/components/component.py +++ b/reflex/components/component.py @@ -242,6 +242,9 @@ class Component(Base, ABC): if value._var_type is not EventChain: raise ValueError(f"Invalid event chain: {value}") return value + elif isinstance(value, EventChain): + # Trust that the caller knows what they're doing passing an EventChain directly + return value arg_spec = triggers.get(event_trigger, lambda: []) @@ -260,7 +263,7 @@ class Component(Base, ABC): deprecation_version="0.2.8", removal_version="0.3.0", ) - events = [] + events: list[EventSpec] = [] for v in value: if isinstance(v, EventHandler): # Call the event handler to get the event. @@ -291,20 +294,26 @@ class Component(Base, ABC): raise ValueError(f"Invalid event chain: {value}") # Add args to the event specs if necessary. - events = [ - EventSpec( - handler=e.handler, - args=get_handler_args(e), - client_handler_name=e.client_handler_name, - ) - for e in events - ] + events = [e.with_args(get_handler_args(e)) for e in events] + + # Collect event_actions from each spec + event_actions = {} + for e in events: + event_actions.update(e.event_actions) # Return the event chain. if isinstance(arg_spec, Var): - return EventChain(events=events, args_spec=None) + return EventChain( + events=events, + args_spec=None, + event_actions=event_actions, + ) else: - return EventChain(events=events, args_spec=arg_spec) # type: ignore + return EventChain( + events=events, + args_spec=arg_spec, # type: ignore + event_actions=event_actions, + ) def get_event_triggers(self) -> Dict[str, Any]: """Get the event triggers for the component. diff --git a/reflex/components/forms/form.py b/reflex/components/forms/form.py index 2572c5746..f17642ece 100644 --- a/reflex/components/forms/form.py +++ b/reflex/components/forms/form.py @@ -1,10 +1,12 @@ """Form components.""" +from __future__ import annotations -from typing import Any, Dict +from typing import Any, Callable, Dict, List from reflex.components.component import Component from reflex.components.libs.chakra import ChakraComponent from reflex.constants import EventTriggers +from reflex.event import EventChain, EventHandler, EventSpec from reflex.vars import Var @@ -16,6 +18,29 @@ class Form(ChakraComponent): # What the form renders to. as_: Var[str] = "form" # type: ignore + def _create_event_chain( + self, + event_trigger: str, + value: Var + | EventHandler + | EventSpec + | List[EventHandler | EventSpec] + | Callable[..., Any], + ) -> EventChain | Var: + """Override the event chain creation to preventDefault for on_submit. + + Args: + event_trigger: The event trigger. + value: The value of the event trigger. + + Returns: + The event chain. + """ + chain = super()._create_event_chain(event_trigger, value) + if event_trigger == EventTriggers.ON_SUBMIT and isinstance(chain, EventChain): + return chain.prevent_default + return chain + def get_event_triggers(self) -> Dict[str, Any]: """Get the event triggers that pass the component's value to the handler. diff --git a/reflex/event.py b/reflex/event.py index 422183235..6775ae90b 100644 --- a/reflex/event.py +++ b/reflex/event.py @@ -101,7 +101,36 @@ def _no_chain_background_task( raise TypeError(f"{fn} is marked as a background task, but is not async.") -class EventHandler(Base): +class EventActionsMixin(Base): + """Mixin for DOM event actions.""" + + # Whether to `preventDefault` or `stopPropagation` on the event. + event_actions: Dict[str, bool] = {} + + @property + def stop_propagation(self): + """Stop the event from bubbling up the DOM tree. + + Returns: + New EventHandler-like with stopPropagation set to True. + """ + return self.copy( + update={"event_actions": {"stopPropagation": True, **self.event_actions}}, + ) + + @property + def prevent_default(self): + """Prevent the default behavior of the event. + + Returns: + New EventHandler-like with preventDefault set to True. + """ + return self.copy( + update={"event_actions": {"preventDefault": True, **self.event_actions}}, + ) + + +class EventHandler(EventActionsMixin): """An event handler responds to an event to update the state.""" # The function to call in response to the event. @@ -150,6 +179,7 @@ class EventHandler(Base): client_handler_name="uploadFiles", # `files` is defined in the Upload component's _use_hooks args=((Var.create_safe("files"), Var.create_safe("files")),), + event_actions=self.event_actions.copy(), ) # Otherwise, convert to JSON. @@ -162,10 +192,12 @@ class EventHandler(Base): payload = tuple(zip(fn_args, values)) # Return the event spec. - return EventSpec(handler=self, args=payload) + return EventSpec( + handler=self, args=payload, event_actions=self.event_actions.copy() + ) -class EventSpec(Base): +class EventSpec(EventActionsMixin): """An event specification. Whereas an Event object is passed during runtime, a spec is used @@ -187,8 +219,24 @@ class EventSpec(Base): # Required to allow tuple fields. frozen = True + def with_args(self, args: Tuple[Tuple[Var, Var], ...]) -> EventSpec: + """Copy the event spec, with updated args. -class EventChain(Base): + Args: + args: The new args to pass to the function. + + Returns: + A copy of the event spec, with the new args. + """ + return type(self)( + handler=self.handler, + client_handler_name=self.client_handler_name, + args=args, + event_actions=self.event_actions.copy(), + ) + + +class EventChain(EventActionsMixin): """Container for a chain of events that will be executed in order.""" events: List[EventSpec] @@ -196,6 +244,11 @@ class EventChain(Base): args_spec: Optional[Callable] +# These chains can be used for their side effects when no other events are desired. +stop_propagation = EventChain(events=[], args_spec=lambda: []).stop_propagation +prevent_default = EventChain(events=[], args_spec=lambda: []).prevent_default + + class Target(Base): """A Javascript event target.""" diff --git a/reflex/utils/format.py b/reflex/utils/format.py index 1ec24dd37..857a61f82 100644 --- a/reflex/utils/format.py +++ b/reflex/utils/format.py @@ -314,7 +314,7 @@ def format_prop( arg_def = "(_e)" chain = ",".join([format_event(event) for event in prop.events]) - event = f"addEvents([{chain}], {arg_def})" + event = f"addEvents([{chain}], {arg_def}, {json_dumps(prop.event_actions)})" prop = f"{arg_def} => {event}" # Handle other types. diff --git a/tests/components/base/test_script.py b/tests/components/base/test_script.py index ffee64b6e..cc16ab718 100644 --- a/tests/components/base/test_script.py +++ b/tests/components/base/test_script.py @@ -57,14 +57,14 @@ def test_script_event_handler(): ) render_dict = component.render() assert ( - 'onReady={(_e) => addEvents([Event("ev_state.on_ready", {})], (_e))}' + 'onReady={(_e) => addEvents([Event("ev_state.on_ready", {})], (_e), {})}' in render_dict["props"] ) assert ( - 'onLoad={(_e) => addEvents([Event("ev_state.on_load", {})], (_e))}' + 'onLoad={(_e) => addEvents([Event("ev_state.on_load", {})], (_e), {})}' in render_dict["props"] ) assert ( - 'onError={(_e) => addEvents([Event("ev_state.on_error", {})], (_e))}' + 'onError={(_e) => addEvents([Event("ev_state.on_error", {})], (_e), {})}' in render_dict["props"] ) diff --git a/tests/components/test_component.py b/tests/components/test_component.py index 998f55c46..c8f83fc42 100644 --- a/tests/components/test_component.py +++ b/tests/components/test_component.py @@ -425,7 +425,7 @@ def test_component_event_trigger_arbitrary_args(): assert comp.render()["props"][0] == ( "onFoo={(__e,_alpha,_bravo,_charlie) => addEvents(" '[Event("c1_state.mock_handler", {_e:__e.target.value,_bravo:_bravo["nested"],_charlie:(_charlie.custom + 42)})], ' - "(__e,_alpha,_bravo,_charlie))}" + "(__e,_alpha,_bravo,_charlie), {})}" ) diff --git a/tests/test_event.py b/tests/test_event.py index 7dcb3ef18..69e69905f 100644 --- a/tests/test_event.py +++ b/tests/test_event.py @@ -4,6 +4,7 @@ import pytest from reflex import event from reflex.event import Event, EventHandler, EventSpec, fix_events +from reflex.state import State from reflex.utils import format from reflex.vars import Var @@ -261,3 +262,54 @@ def test_remove_local_storage(): assert ( format.format_event(spec) == 'Event("_remove_local_storage", {key:`testkey`})' ) + + +def test_event_actions(): + """Test DOM event actions, like stopPropagation and preventDefault.""" + # EventHandler + handler = EventHandler(fn=lambda: None) + assert not handler.event_actions + sp_handler = handler.stop_propagation + assert handler is not sp_handler + assert sp_handler.event_actions == {"stopPropagation": True} + pd_handler = handler.prevent_default + assert handler is not pd_handler + assert pd_handler.event_actions == {"preventDefault": True} + both_handler = sp_handler.prevent_default + assert both_handler is not sp_handler + assert both_handler.event_actions == { + "stopPropagation": True, + "preventDefault": True, + } + assert not handler.event_actions + + # Convert to EventSpec should carry event actions + sp_handler2 = handler.stop_propagation + spec = sp_handler2() + assert spec.event_actions == {"stopPropagation": True} + assert spec.event_actions == sp_handler2.event_actions + assert spec.event_actions is not sp_handler2.event_actions + # But it should be a copy! + assert spec.event_actions is not sp_handler2.event_actions + spec2 = spec.prevent_default + assert spec is not spec2 + assert spec2.event_actions == {"stopPropagation": True, "preventDefault": True} + assert spec2.event_actions != spec.event_actions + + # The original handler should still not be touched. + assert not handler.event_actions + + +def test_event_actions_on_state(): + class EventActionState(State): + def handler(self): + pass + + handler = EventActionState.handler + assert isinstance(handler, EventHandler) + assert not handler.event_actions + + sp_handler = EventActionState.handler.stop_propagation + assert sp_handler.event_actions == {"stopPropagation": True} + # should NOT affect other references to the handler + assert not handler.event_actions diff --git a/tests/utils/test_format.py b/tests/utils/test_format.py index a6809cf29..768cc353a 100644 --- a/tests/utils/test_format.py +++ b/tests/utils/test_format.py @@ -4,7 +4,7 @@ from typing import Any import pytest from reflex.components.tags.tag import Tag -from reflex.event import EventChain, EventHandler, EventSpec +from reflex.event import EventChain, EventHandler, EventSpec, FrontendEvent from reflex.style import Style from reflex.utils import format from reflex.vars import BaseVar, Var @@ -290,6 +290,49 @@ def test_format_cond(condition: str, true_value: str, false_value: str, expected }, r'{{"a": "foo \"{ \"bar\" }\" baz", "b": val}}', ), + ( + EventChain( + events=[EventSpec(handler=EventHandler(fn=mock_event))], + args_spec=lambda: [], + ), + '{(_e) => addEvents([Event("mock_event", {})], (_e), {})}', + ), + ( + EventChain( + events=[ + EventSpec( + handler=EventHandler(fn=mock_event), + args=( + ( + Var.create_safe("arg"), + BaseVar( + _var_name="_e", + _var_type=FrontendEvent, + ).target.value, + ), + ), + ) + ], + args_spec=lambda: [], + ), + '{(_e) => addEvents([Event("mock_event", {arg:_e.target.value})], (_e), {})}', + ), + ( + EventChain( + events=[EventSpec(handler=EventHandler(fn=mock_event))], + args_spec=lambda: [], + event_actions={"stopPropagation": True}, + ), + '{(_e) => addEvents([Event("mock_event", {})], (_e), {"stopPropagation": true})}', + ), + ( + EventChain( + events=[EventSpec(handler=EventHandler(fn=mock_event))], + args_spec=lambda: [], + event_actions={"preventDefault": True}, + ), + '{(_e) => addEvents([Event("mock_event", {})], (_e), {"preventDefault": true})}', + ), ({"a": "red", "b": "blue"}, '{{"a": "red", "b": "blue"}}'), (BaseVar(_var_name="var", _var_type="int"), "{var}"), (