diff --git a/integration/test_event_chain.py b/integration/test_event_chain.py index 89a7b64f9..f762f16ca 100644 --- a/integration/test_event_chain.py +++ b/integration/test_event_chain.py @@ -155,8 +155,35 @@ def EventChain(): rx.input(value=State.token, readonly=True, id="token"), ) + def on_mount_return_chain(): + return rx.fragment( + rx.text( + "return", + on_mount=State.on_load_return_chain, + on_unmount=lambda: State.event_arg("unmount"), # type: ignore + ), + rx.input(value=State.token, readonly=True, id="token"), + rx.button("Unmount", on_click=rx.redirect("/"), id="unmount"), + ) + + def on_mount_yield_chain(): + return rx.fragment( + rx.text( + "yield", + on_mount=[ + State.on_load_yield_chain, + lambda: State.event_arg("mount"), # type: ignore + ], + on_unmount=State.event_no_args, + ), + rx.input(value=State.token, readonly=True, id="token"), + rx.button("Unmount", on_click=rx.redirect("/"), id="unmount"), + ) + app.add_page(on_load_return_chain, on_load=State.on_load_return_chain) # type: ignore app.add_page(on_load_yield_chain, on_load=State.on_load_yield_chain) # type: ignore + app.add_page(on_mount_return_chain) + app.add_page(on_mount_yield_chain) app.compile() @@ -330,3 +357,69 @@ def test_event_chain_on_load(event_chain, driver, uri, exp_event_order): time.sleep(0.5) backend_state = event_chain.app_instance.state_manager.states[token] assert backend_state.event_order == exp_event_order + + +@pytest.mark.parametrize( + ("uri", "exp_event_order"), + [ + ( + "/on-mount-return-chain", + [ + "on_load_return_chain", + "event_arg:unmount", + "on_load_return_chain", + "event_arg:1", + "event_arg:2", + "event_arg:3", + "event_arg:1", + "event_arg:2", + "event_arg:3", + "event_arg:unmount", + ], + ), + ( + "/on-mount-yield-chain", + [ + "on_load_yield_chain", + "event_arg:mount", + "event_no_args", + "on_load_yield_chain", + "event_arg:mount", + "event_arg:4", + "event_arg:5", + "event_arg:6", + "event_arg:4", + "event_arg:5", + "event_arg:6", + "event_no_args", + ], + ), + ], +) +def test_event_chain_on_mount(event_chain, driver, uri, exp_event_order): + """Load the URI, assert that the events are handled in the correct order. + + These pages use `on_mount` and `on_unmount`, which get fired twice in dev mode + due to react StrictMode being used. + + In prod mode, these events are only fired once. + + Args: + event_chain: AppHarness for the event_chain app + driver: selenium WebDriver open to the app + uri: the page to load + exp_event_order: the expected events recorded in the State + """ + driver.get(event_chain.frontend_url + uri) + token_input = driver.find_element(By.ID, "token") + assert token_input + + token = event_chain.poll_for_value(token_input) + + unmount_button = driver.find_element(By.ID, "unmount") + assert unmount_button + unmount_button.click() + + time.sleep(1) + backend_state = event_chain.app_instance.state_manager.states[token] + assert backend_state.event_order == exp_event_order diff --git a/reflex/.templates/web/utils/state.js b/reflex/.templates/web/utils/state.js index e39ac8cce..9e1848436 100644 --- a/reflex/.templates/web/utils/state.js +++ b/reflex/.templates/web/utils/state.js @@ -218,6 +218,11 @@ export const queueEvents = async (events, socket) => { export const processEvent = async ( socket ) => { + // Only proceed if the socket is up, otherwise we throw the event into the void + if (!socket) { + return; + } + // Only proceed if we're not already processing an event. if (event_queue.length === 0 || event_processing) { return; diff --git a/reflex/components/component.py b/reflex/components/component.py index 2b8f89bf3..415fd3358 100644 --- a/reflex/components/component.py +++ b/reflex/components/component.py @@ -286,7 +286,11 @@ class Component(Base, ABC): Returns: The event triggers. """ - return EVENT_TRIGGERS | set(self.get_controlled_triggers()) + return ( + EVENT_TRIGGERS + | set(self.get_controlled_triggers()) + | set((constants.ON_MOUNT, constants.ON_UNMOUNT)) + ) def get_controlled_triggers(self) -> Dict[str, Var]: """Get the event triggers that pass the component's value to the handler. @@ -525,16 +529,63 @@ class Component(Base, ABC): self._get_imports(), *[child.get_imports() for child in self.children] ) - def _get_hooks(self) -> Optional[str]: - """Get the React hooks for this component. + def _get_mount_lifecycle_hook(self) -> str | None: + """Generate the component lifecycle hook. Returns: - The hooks for just this component. + The useEffect hook for managing `on_mount` and `on_unmount` events. + """ + # pop on_mount and on_unmount from event_triggers since these are handled by + # hooks, not as actually props in the component + on_mount = self.event_triggers.pop(constants.ON_MOUNT, None) + on_unmount = self.event_triggers.pop(constants.ON_UNMOUNT, None) + if on_mount: + on_mount = format.format_event_chain(on_mount) + if on_unmount: + on_unmount = format.format_event_chain(on_unmount) + if on_mount or on_unmount: + return f""" + useEffect(() => {{ + {on_mount or ""} + return () => {{ + {on_unmount or ""} + }} + }}, []);""" + + def _get_ref_hook(self) -> str | None: + """Generate the ref hook for the component. + + Returns: + The useRef hook for managing refs. """ ref = self.get_ref() if ref is not None: return f"const {ref} = useRef(null); refs['{ref}'] = {ref};" - return None + + def _get_hooks_internal(self) -> Set[str]: + """Get the React hooks for this component managed by the framework. + + Downstream components should NOT override this method to avoid breaking + framework functionality. + + Returns: + Set of internally managed hooks. + """ + return set( + hook + for hook in [self._get_mount_lifecycle_hook(), self._get_ref_hook()] + if hook + ) + + def _get_hooks(self) -> Optional[str]: + """Get the React hooks for this component. + + Downstream components should override this method to add their own hooks. + + Returns: + The hooks for just this component. + """ + return def get_hooks(self) -> Set[str]: """Get the React hooks for this component and its children. @@ -543,7 +594,7 @@ class Component(Base, ABC): The code that should appear just before returning the rendered component. """ # Store the code in a set to avoid duplicates. - code = set() + code = self._get_hooks_internal() # Add the hook code for this component. hooks = self._get_hooks() diff --git a/reflex/components/forms/pininput.py b/reflex/components/forms/pininput.py index 81ec13996..0e889cc23 100644 --- a/reflex/components/forms/pininput.py +++ b/reflex/components/forms/pininput.py @@ -76,8 +76,8 @@ class PinInput(ChakraComponent): """ return None - def _get_hooks(self) -> Optional[str]: - """Override the base get_hooks to handle array refs. + def _get_ref_hook(self) -> Optional[str]: + """Override the base _get_ref_hook to handle array refs. Returns: The overrided hooks. @@ -86,7 +86,7 @@ class PinInput(ChakraComponent): ref = format.format_array_ref(self.id, None) if ref: return f"const {ref} = Array.from({{length:{self.length}}}, () => useRef(null));" - return super()._get_hooks() + return super()._get_ref_hook() @classmethod def create(cls, *children, **props) -> Component: @@ -130,7 +130,7 @@ class PinInputField(ChakraComponent): # Default to None because it is assigned by PinInput when created. index: Optional[Var[int]] = None - def _get_hooks(self) -> Optional[str]: + def _get_ref_hook(self) -> Optional[str]: return None def get_ref(self): diff --git a/reflex/components/forms/rangeslider.py b/reflex/components/forms/rangeslider.py index b27e33205..c72c0f0ef 100644 --- a/reflex/components/forms/rangeslider.py +++ b/reflex/components/forms/rangeslider.py @@ -64,8 +64,8 @@ class RangeSlider(ChakraComponent): """ return None - def _get_hooks(self) -> Optional[str]: - """Override the base get_hooks to handle array refs. + def _get_ref_hook(self) -> Optional[str]: + """Override the base _get_ref_hook to handle array refs. Returns: The overrided hooks. @@ -74,7 +74,7 @@ class RangeSlider(ChakraComponent): ref = format.format_array_ref(self.id, None) if ref: return f"const {ref} = Array.from({{length:2}}, () => useRef(null));" - return super()._get_hooks() + return super()._get_ref_hook() @classmethod def create(cls, *children, **props) -> Component: @@ -130,7 +130,7 @@ class RangeSliderThumb(ChakraComponent): # The position of the thumb. index: Var[int] - def _get_hooks(self) -> Optional[str]: + def _get_ref_hook(self) -> Optional[str]: # hook is None because RangeSlider is handling it. return None diff --git a/reflex/constants.py b/reflex/constants.py index 8b87d717d..202ca275c 100644 --- a/reflex/constants.py +++ b/reflex/constants.py @@ -359,5 +359,9 @@ PING_TIMEOUT = 120 # Alembic migrations ALEMBIC_CONFIG = os.environ.get("ALEMBIC_CONFIG", "alembic.ini") +# Names of event handlers on all components mapped to useEffect +ON_MOUNT = "on_mount" +ON_UNMOUNT = "on_unmount" + # If this env var is set to "yes", App.compile will be a no-op SKIP_COMPILE_ENV_VAR = "__REFLEX_SKIP_COMPILE" diff --git a/reflex/utils/format.py b/reflex/utils/format.py index 7eef1c194..26a41b2a0 100644 --- a/reflex/utils/format.py +++ b/reflex/utils/format.py @@ -16,10 +16,11 @@ from plotly.io import to_json from reflex import constants from reflex.utils import types +from reflex.vars import Var if TYPE_CHECKING: from reflex.components.component import ComponentStyle - from reflex.event import EventHandler, EventSpec + from reflex.event import EventChain, EventHandler, EventSpec WRAP_MAP = { "{": "}", @@ -182,6 +183,24 @@ def format_string(string: str) -> str: return string +def format_var(var: Var) -> str: + """Format the given Var as a javascript value. + + Args: + var: The Var to format. + + Returns: + The formatted Var. + """ + if not var.is_local or var.is_string: + return str(var) + if types._issubclass(var.type_, str): + return format_string(var.full_name) + if is_wrapped(var.full_name, "{"): + return var.full_name + return json_dumps(var.full_name) + + def format_route(route: str) -> str: """Format the given route. @@ -311,6 +330,46 @@ def format_event(event_spec: EventSpec) -> str: return f"E({', '.join(event_args)})" +def format_event_chain( + event_chain: EventChain | Var[EventChain], + event_arg: Var | None = None, +) -> str: + """Format an event chain as a javascript invocation. + + Args: + event_chain: The event chain to queue on the frontend. + event_arg: The browser-native event (only used to preventDefault). + + Returns: + Compiled javascript code to queue the given event chain on the frontend. + + Raises: + ValueError: When the given event chain is not a valid event chain. + """ + if isinstance(event_chain, Var): + from reflex.event import EventChain + + if event_chain.type_ is not EventChain: + raise ValueError(f"Invalid event chain: {event_chain}") + return "".join( + [ + "(() => {", + format_var(event_chain), + f"; preventDefault({format_var(event_arg)})" if event_arg else "", + "})()", + ] + ) + + chain = ",".join([format_event(event) for event in event_chain.events]) + return "".join( + [ + f"Event([{chain}]", + f", {format_var(event_arg)}" if event_arg else "", + ")", + ] + ) + + def format_query_params(router_data: Dict[str, Any]) -> Dict[str, str]: """Convert back query params name to python-friendly case. diff --git a/tests/components/test_component.py b/tests/components/test_component.py index c676e543a..8410e7bf4 100644 --- a/tests/components/test_component.py +++ b/tests/components/test_component.py @@ -5,6 +5,7 @@ import pytest import reflex as rx from reflex.components.component import Component, CustomComponent, custom_component from reflex.components.layout.box import Box +from reflex.constants import ON_MOUNT, ON_UNMOUNT from reflex.event import EVENT_ARG, EVENT_TRIGGERS, EventHandler from reflex.state import State from reflex.style import Style @@ -377,8 +378,9 @@ def test_get_triggers(component1, component2): component1: A test component. component2: A test component. """ - assert component1().get_triggers() == EVENT_TRIGGERS - assert component2().get_triggers() == {"on_open", "on_close"} | EVENT_TRIGGERS + default_triggers = {ON_MOUNT, ON_UNMOUNT} | EVENT_TRIGGERS + assert component1().get_triggers() == default_triggers + assert component2().get_triggers() == {"on_open", "on_close"} | default_triggers def test_create_custom_component(my_component):