Implement on_mount
and on_unmount
for all components. (#1636)
This commit is contained in:
parent
161a77ca23
commit
2392c52928
@ -155,8 +155,35 @@ def EventChain():
|
||||
rx.input(value=State.token, readonly=True, id="token"),
|
||||
)
|
||||
|
||||
def on_mount_return_chain():
|
||||
return rx.fragment(
|
||||
rx.text(
|
||||
"return",
|
||||
on_mount=State.on_load_return_chain,
|
||||
on_unmount=lambda: State.event_arg("unmount"), # type: ignore
|
||||
),
|
||||
rx.input(value=State.token, readonly=True, id="token"),
|
||||
rx.button("Unmount", on_click=rx.redirect("/"), id="unmount"),
|
||||
)
|
||||
|
||||
def on_mount_yield_chain():
|
||||
return rx.fragment(
|
||||
rx.text(
|
||||
"yield",
|
||||
on_mount=[
|
||||
State.on_load_yield_chain,
|
||||
lambda: State.event_arg("mount"), # type: ignore
|
||||
],
|
||||
on_unmount=State.event_no_args,
|
||||
),
|
||||
rx.input(value=State.token, readonly=True, id="token"),
|
||||
rx.button("Unmount", on_click=rx.redirect("/"), id="unmount"),
|
||||
)
|
||||
|
||||
app.add_page(on_load_return_chain, on_load=State.on_load_return_chain) # type: ignore
|
||||
app.add_page(on_load_yield_chain, on_load=State.on_load_yield_chain) # type: ignore
|
||||
app.add_page(on_mount_return_chain)
|
||||
app.add_page(on_mount_yield_chain)
|
||||
|
||||
app.compile()
|
||||
|
||||
@ -330,3 +357,69 @@ def test_event_chain_on_load(event_chain, driver, uri, exp_event_order):
|
||||
time.sleep(0.5)
|
||||
backend_state = event_chain.app_instance.state_manager.states[token]
|
||||
assert backend_state.event_order == exp_event_order
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
("uri", "exp_event_order"),
|
||||
[
|
||||
(
|
||||
"/on-mount-return-chain",
|
||||
[
|
||||
"on_load_return_chain",
|
||||
"event_arg:unmount",
|
||||
"on_load_return_chain",
|
||||
"event_arg:1",
|
||||
"event_arg:2",
|
||||
"event_arg:3",
|
||||
"event_arg:1",
|
||||
"event_arg:2",
|
||||
"event_arg:3",
|
||||
"event_arg:unmount",
|
||||
],
|
||||
),
|
||||
(
|
||||
"/on-mount-yield-chain",
|
||||
[
|
||||
"on_load_yield_chain",
|
||||
"event_arg:mount",
|
||||
"event_no_args",
|
||||
"on_load_yield_chain",
|
||||
"event_arg:mount",
|
||||
"event_arg:4",
|
||||
"event_arg:5",
|
||||
"event_arg:6",
|
||||
"event_arg:4",
|
||||
"event_arg:5",
|
||||
"event_arg:6",
|
||||
"event_no_args",
|
||||
],
|
||||
),
|
||||
],
|
||||
)
|
||||
def test_event_chain_on_mount(event_chain, driver, uri, exp_event_order):
|
||||
"""Load the URI, assert that the events are handled in the correct order.
|
||||
|
||||
These pages use `on_mount` and `on_unmount`, which get fired twice in dev mode
|
||||
due to react StrictMode being used.
|
||||
|
||||
In prod mode, these events are only fired once.
|
||||
|
||||
Args:
|
||||
event_chain: AppHarness for the event_chain app
|
||||
driver: selenium WebDriver open to the app
|
||||
uri: the page to load
|
||||
exp_event_order: the expected events recorded in the State
|
||||
"""
|
||||
driver.get(event_chain.frontend_url + uri)
|
||||
token_input = driver.find_element(By.ID, "token")
|
||||
assert token_input
|
||||
|
||||
token = event_chain.poll_for_value(token_input)
|
||||
|
||||
unmount_button = driver.find_element(By.ID, "unmount")
|
||||
assert unmount_button
|
||||
unmount_button.click()
|
||||
|
||||
time.sleep(1)
|
||||
backend_state = event_chain.app_instance.state_manager.states[token]
|
||||
assert backend_state.event_order == exp_event_order
|
||||
|
@ -218,6 +218,11 @@ export const queueEvents = async (events, socket) => {
|
||||
export const processEvent = async (
|
||||
socket
|
||||
) => {
|
||||
// Only proceed if the socket is up, otherwise we throw the event into the void
|
||||
if (!socket) {
|
||||
return;
|
||||
}
|
||||
|
||||
// Only proceed if we're not already processing an event.
|
||||
if (event_queue.length === 0 || event_processing) {
|
||||
return;
|
||||
|
@ -286,7 +286,11 @@ class Component(Base, ABC):
|
||||
Returns:
|
||||
The event triggers.
|
||||
"""
|
||||
return EVENT_TRIGGERS | set(self.get_controlled_triggers())
|
||||
return (
|
||||
EVENT_TRIGGERS
|
||||
| set(self.get_controlled_triggers())
|
||||
| set((constants.ON_MOUNT, constants.ON_UNMOUNT))
|
||||
)
|
||||
|
||||
def get_controlled_triggers(self) -> Dict[str, Var]:
|
||||
"""Get the event triggers that pass the component's value to the handler.
|
||||
@ -525,16 +529,63 @@ class Component(Base, ABC):
|
||||
self._get_imports(), *[child.get_imports() for child in self.children]
|
||||
)
|
||||
|
||||
def _get_hooks(self) -> Optional[str]:
|
||||
"""Get the React hooks for this component.
|
||||
def _get_mount_lifecycle_hook(self) -> str | None:
|
||||
"""Generate the component lifecycle hook.
|
||||
|
||||
Returns:
|
||||
The hooks for just this component.
|
||||
The useEffect hook for managing `on_mount` and `on_unmount` events.
|
||||
"""
|
||||
# pop on_mount and on_unmount from event_triggers since these are handled by
|
||||
# hooks, not as actually props in the component
|
||||
on_mount = self.event_triggers.pop(constants.ON_MOUNT, None)
|
||||
on_unmount = self.event_triggers.pop(constants.ON_UNMOUNT, None)
|
||||
if on_mount:
|
||||
on_mount = format.format_event_chain(on_mount)
|
||||
if on_unmount:
|
||||
on_unmount = format.format_event_chain(on_unmount)
|
||||
if on_mount or on_unmount:
|
||||
return f"""
|
||||
useEffect(() => {{
|
||||
{on_mount or ""}
|
||||
return () => {{
|
||||
{on_unmount or ""}
|
||||
}}
|
||||
}}, []);"""
|
||||
|
||||
def _get_ref_hook(self) -> str | None:
|
||||
"""Generate the ref hook for the component.
|
||||
|
||||
Returns:
|
||||
The useRef hook for managing refs.
|
||||
"""
|
||||
ref = self.get_ref()
|
||||
if ref is not None:
|
||||
return f"const {ref} = useRef(null); refs['{ref}'] = {ref};"
|
||||
return None
|
||||
|
||||
def _get_hooks_internal(self) -> Set[str]:
|
||||
"""Get the React hooks for this component managed by the framework.
|
||||
|
||||
Downstream components should NOT override this method to avoid breaking
|
||||
framework functionality.
|
||||
|
||||
Returns:
|
||||
Set of internally managed hooks.
|
||||
"""
|
||||
return set(
|
||||
hook
|
||||
for hook in [self._get_mount_lifecycle_hook(), self._get_ref_hook()]
|
||||
if hook
|
||||
)
|
||||
|
||||
def _get_hooks(self) -> Optional[str]:
|
||||
"""Get the React hooks for this component.
|
||||
|
||||
Downstream components should override this method to add their own hooks.
|
||||
|
||||
Returns:
|
||||
The hooks for just this component.
|
||||
"""
|
||||
return
|
||||
|
||||
def get_hooks(self) -> Set[str]:
|
||||
"""Get the React hooks for this component and its children.
|
||||
@ -543,7 +594,7 @@ class Component(Base, ABC):
|
||||
The code that should appear just before returning the rendered component.
|
||||
"""
|
||||
# Store the code in a set to avoid duplicates.
|
||||
code = set()
|
||||
code = self._get_hooks_internal()
|
||||
|
||||
# Add the hook code for this component.
|
||||
hooks = self._get_hooks()
|
||||
|
@ -76,8 +76,8 @@ class PinInput(ChakraComponent):
|
||||
"""
|
||||
return None
|
||||
|
||||
def _get_hooks(self) -> Optional[str]:
|
||||
"""Override the base get_hooks to handle array refs.
|
||||
def _get_ref_hook(self) -> Optional[str]:
|
||||
"""Override the base _get_ref_hook to handle array refs.
|
||||
|
||||
Returns:
|
||||
The overrided hooks.
|
||||
@ -86,7 +86,7 @@ class PinInput(ChakraComponent):
|
||||
ref = format.format_array_ref(self.id, None)
|
||||
if ref:
|
||||
return f"const {ref} = Array.from({{length:{self.length}}}, () => useRef(null));"
|
||||
return super()._get_hooks()
|
||||
return super()._get_ref_hook()
|
||||
|
||||
@classmethod
|
||||
def create(cls, *children, **props) -> Component:
|
||||
@ -130,7 +130,7 @@ class PinInputField(ChakraComponent):
|
||||
# Default to None because it is assigned by PinInput when created.
|
||||
index: Optional[Var[int]] = None
|
||||
|
||||
def _get_hooks(self) -> Optional[str]:
|
||||
def _get_ref_hook(self) -> Optional[str]:
|
||||
return None
|
||||
|
||||
def get_ref(self):
|
||||
|
@ -64,8 +64,8 @@ class RangeSlider(ChakraComponent):
|
||||
"""
|
||||
return None
|
||||
|
||||
def _get_hooks(self) -> Optional[str]:
|
||||
"""Override the base get_hooks to handle array refs.
|
||||
def _get_ref_hook(self) -> Optional[str]:
|
||||
"""Override the base _get_ref_hook to handle array refs.
|
||||
|
||||
Returns:
|
||||
The overrided hooks.
|
||||
@ -74,7 +74,7 @@ class RangeSlider(ChakraComponent):
|
||||
ref = format.format_array_ref(self.id, None)
|
||||
if ref:
|
||||
return f"const {ref} = Array.from({{length:2}}, () => useRef(null));"
|
||||
return super()._get_hooks()
|
||||
return super()._get_ref_hook()
|
||||
|
||||
@classmethod
|
||||
def create(cls, *children, **props) -> Component:
|
||||
@ -130,7 +130,7 @@ class RangeSliderThumb(ChakraComponent):
|
||||
# The position of the thumb.
|
||||
index: Var[int]
|
||||
|
||||
def _get_hooks(self) -> Optional[str]:
|
||||
def _get_ref_hook(self) -> Optional[str]:
|
||||
# hook is None because RangeSlider is handling it.
|
||||
return None
|
||||
|
||||
|
@ -359,5 +359,9 @@ PING_TIMEOUT = 120
|
||||
# Alembic migrations
|
||||
ALEMBIC_CONFIG = os.environ.get("ALEMBIC_CONFIG", "alembic.ini")
|
||||
|
||||
# Names of event handlers on all components mapped to useEffect
|
||||
ON_MOUNT = "on_mount"
|
||||
ON_UNMOUNT = "on_unmount"
|
||||
|
||||
# If this env var is set to "yes", App.compile will be a no-op
|
||||
SKIP_COMPILE_ENV_VAR = "__REFLEX_SKIP_COMPILE"
|
||||
|
@ -16,10 +16,11 @@ from plotly.io import to_json
|
||||
|
||||
from reflex import constants
|
||||
from reflex.utils import types
|
||||
from reflex.vars import Var
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from reflex.components.component import ComponentStyle
|
||||
from reflex.event import EventHandler, EventSpec
|
||||
from reflex.event import EventChain, EventHandler, EventSpec
|
||||
|
||||
WRAP_MAP = {
|
||||
"{": "}",
|
||||
@ -182,6 +183,24 @@ def format_string(string: str) -> str:
|
||||
return string
|
||||
|
||||
|
||||
def format_var(var: Var) -> str:
|
||||
"""Format the given Var as a javascript value.
|
||||
|
||||
Args:
|
||||
var: The Var to format.
|
||||
|
||||
Returns:
|
||||
The formatted Var.
|
||||
"""
|
||||
if not var.is_local or var.is_string:
|
||||
return str(var)
|
||||
if types._issubclass(var.type_, str):
|
||||
return format_string(var.full_name)
|
||||
if is_wrapped(var.full_name, "{"):
|
||||
return var.full_name
|
||||
return json_dumps(var.full_name)
|
||||
|
||||
|
||||
def format_route(route: str) -> str:
|
||||
"""Format the given route.
|
||||
|
||||
@ -311,6 +330,46 @@ def format_event(event_spec: EventSpec) -> str:
|
||||
return f"E({', '.join(event_args)})"
|
||||
|
||||
|
||||
def format_event_chain(
|
||||
event_chain: EventChain | Var[EventChain],
|
||||
event_arg: Var | None = None,
|
||||
) -> str:
|
||||
"""Format an event chain as a javascript invocation.
|
||||
|
||||
Args:
|
||||
event_chain: The event chain to queue on the frontend.
|
||||
event_arg: The browser-native event (only used to preventDefault).
|
||||
|
||||
Returns:
|
||||
Compiled javascript code to queue the given event chain on the frontend.
|
||||
|
||||
Raises:
|
||||
ValueError: When the given event chain is not a valid event chain.
|
||||
"""
|
||||
if isinstance(event_chain, Var):
|
||||
from reflex.event import EventChain
|
||||
|
||||
if event_chain.type_ is not EventChain:
|
||||
raise ValueError(f"Invalid event chain: {event_chain}")
|
||||
return "".join(
|
||||
[
|
||||
"(() => {",
|
||||
format_var(event_chain),
|
||||
f"; preventDefault({format_var(event_arg)})" if event_arg else "",
|
||||
"})()",
|
||||
]
|
||||
)
|
||||
|
||||
chain = ",".join([format_event(event) for event in event_chain.events])
|
||||
return "".join(
|
||||
[
|
||||
f"Event([{chain}]",
|
||||
f", {format_var(event_arg)}" if event_arg else "",
|
||||
")",
|
||||
]
|
||||
)
|
||||
|
||||
|
||||
def format_query_params(router_data: Dict[str, Any]) -> Dict[str, str]:
|
||||
"""Convert back query params name to python-friendly case.
|
||||
|
||||
|
@ -5,6 +5,7 @@ import pytest
|
||||
import reflex as rx
|
||||
from reflex.components.component import Component, CustomComponent, custom_component
|
||||
from reflex.components.layout.box import Box
|
||||
from reflex.constants import ON_MOUNT, ON_UNMOUNT
|
||||
from reflex.event import EVENT_ARG, EVENT_TRIGGERS, EventHandler
|
||||
from reflex.state import State
|
||||
from reflex.style import Style
|
||||
@ -377,8 +378,9 @@ def test_get_triggers(component1, component2):
|
||||
component1: A test component.
|
||||
component2: A test component.
|
||||
"""
|
||||
assert component1().get_triggers() == EVENT_TRIGGERS
|
||||
assert component2().get_triggers() == {"on_open", "on_close"} | EVENT_TRIGGERS
|
||||
default_triggers = {ON_MOUNT, ON_UNMOUNT} | EVENT_TRIGGERS
|
||||
assert component1().get_triggers() == default_triggers
|
||||
assert component2().get_triggers() == {"on_open", "on_close"} | default_triggers
|
||||
|
||||
|
||||
def test_create_custom_component(my_component):
|
||||
|
Loading…
Reference in New Issue
Block a user