Implement on_mount and on_unmount for all components. (#1636)

This commit is contained in:
Masen Furer 2023-08-30 09:50:39 -07:00 committed by GitHub
parent 161a77ca23
commit 2392c52928
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
8 changed files with 231 additions and 17 deletions

View File

@ -155,8 +155,35 @@ def EventChain():
rx.input(value=State.token, readonly=True, id="token"),
)
def on_mount_return_chain():
return rx.fragment(
rx.text(
"return",
on_mount=State.on_load_return_chain,
on_unmount=lambda: State.event_arg("unmount"), # type: ignore
),
rx.input(value=State.token, readonly=True, id="token"),
rx.button("Unmount", on_click=rx.redirect("/"), id="unmount"),
)
def on_mount_yield_chain():
return rx.fragment(
rx.text(
"yield",
on_mount=[
State.on_load_yield_chain,
lambda: State.event_arg("mount"), # type: ignore
],
on_unmount=State.event_no_args,
),
rx.input(value=State.token, readonly=True, id="token"),
rx.button("Unmount", on_click=rx.redirect("/"), id="unmount"),
)
app.add_page(on_load_return_chain, on_load=State.on_load_return_chain) # type: ignore
app.add_page(on_load_yield_chain, on_load=State.on_load_yield_chain) # type: ignore
app.add_page(on_mount_return_chain)
app.add_page(on_mount_yield_chain)
app.compile()
@ -330,3 +357,69 @@ def test_event_chain_on_load(event_chain, driver, uri, exp_event_order):
time.sleep(0.5)
backend_state = event_chain.app_instance.state_manager.states[token]
assert backend_state.event_order == exp_event_order
@pytest.mark.parametrize(
("uri", "exp_event_order"),
[
(
"/on-mount-return-chain",
[
"on_load_return_chain",
"event_arg:unmount",
"on_load_return_chain",
"event_arg:1",
"event_arg:2",
"event_arg:3",
"event_arg:1",
"event_arg:2",
"event_arg:3",
"event_arg:unmount",
],
),
(
"/on-mount-yield-chain",
[
"on_load_yield_chain",
"event_arg:mount",
"event_no_args",
"on_load_yield_chain",
"event_arg:mount",
"event_arg:4",
"event_arg:5",
"event_arg:6",
"event_arg:4",
"event_arg:5",
"event_arg:6",
"event_no_args",
],
),
],
)
def test_event_chain_on_mount(event_chain, driver, uri, exp_event_order):
"""Load the URI, assert that the events are handled in the correct order.
These pages use `on_mount` and `on_unmount`, which get fired twice in dev mode
due to react StrictMode being used.
In prod mode, these events are only fired once.
Args:
event_chain: AppHarness for the event_chain app
driver: selenium WebDriver open to the app
uri: the page to load
exp_event_order: the expected events recorded in the State
"""
driver.get(event_chain.frontend_url + uri)
token_input = driver.find_element(By.ID, "token")
assert token_input
token = event_chain.poll_for_value(token_input)
unmount_button = driver.find_element(By.ID, "unmount")
assert unmount_button
unmount_button.click()
time.sleep(1)
backend_state = event_chain.app_instance.state_manager.states[token]
assert backend_state.event_order == exp_event_order

View File

@ -218,6 +218,11 @@ export const queueEvents = async (events, socket) => {
export const processEvent = async (
socket
) => {
// Only proceed if the socket is up, otherwise we throw the event into the void
if (!socket) {
return;
}
// Only proceed if we're not already processing an event.
if (event_queue.length === 0 || event_processing) {
return;

View File

@ -286,7 +286,11 @@ class Component(Base, ABC):
Returns:
The event triggers.
"""
return EVENT_TRIGGERS | set(self.get_controlled_triggers())
return (
EVENT_TRIGGERS
| set(self.get_controlled_triggers())
| set((constants.ON_MOUNT, constants.ON_UNMOUNT))
)
def get_controlled_triggers(self) -> Dict[str, Var]:
"""Get the event triggers that pass the component's value to the handler.
@ -525,16 +529,63 @@ class Component(Base, ABC):
self._get_imports(), *[child.get_imports() for child in self.children]
)
def _get_hooks(self) -> Optional[str]:
"""Get the React hooks for this component.
def _get_mount_lifecycle_hook(self) -> str | None:
"""Generate the component lifecycle hook.
Returns:
The hooks for just this component.
The useEffect hook for managing `on_mount` and `on_unmount` events.
"""
# pop on_mount and on_unmount from event_triggers since these are handled by
# hooks, not as actually props in the component
on_mount = self.event_triggers.pop(constants.ON_MOUNT, None)
on_unmount = self.event_triggers.pop(constants.ON_UNMOUNT, None)
if on_mount:
on_mount = format.format_event_chain(on_mount)
if on_unmount:
on_unmount = format.format_event_chain(on_unmount)
if on_mount or on_unmount:
return f"""
useEffect(() => {{
{on_mount or ""}
return () => {{
{on_unmount or ""}
}}
}}, []);"""
def _get_ref_hook(self) -> str | None:
"""Generate the ref hook for the component.
Returns:
The useRef hook for managing refs.
"""
ref = self.get_ref()
if ref is not None:
return f"const {ref} = useRef(null); refs['{ref}'] = {ref};"
return None
def _get_hooks_internal(self) -> Set[str]:
"""Get the React hooks for this component managed by the framework.
Downstream components should NOT override this method to avoid breaking
framework functionality.
Returns:
Set of internally managed hooks.
"""
return set(
hook
for hook in [self._get_mount_lifecycle_hook(), self._get_ref_hook()]
if hook
)
def _get_hooks(self) -> Optional[str]:
"""Get the React hooks for this component.
Downstream components should override this method to add their own hooks.
Returns:
The hooks for just this component.
"""
return
def get_hooks(self) -> Set[str]:
"""Get the React hooks for this component and its children.
@ -543,7 +594,7 @@ class Component(Base, ABC):
The code that should appear just before returning the rendered component.
"""
# Store the code in a set to avoid duplicates.
code = set()
code = self._get_hooks_internal()
# Add the hook code for this component.
hooks = self._get_hooks()

View File

@ -76,8 +76,8 @@ class PinInput(ChakraComponent):
"""
return None
def _get_hooks(self) -> Optional[str]:
"""Override the base get_hooks to handle array refs.
def _get_ref_hook(self) -> Optional[str]:
"""Override the base _get_ref_hook to handle array refs.
Returns:
The overrided hooks.
@ -86,7 +86,7 @@ class PinInput(ChakraComponent):
ref = format.format_array_ref(self.id, None)
if ref:
return f"const {ref} = Array.from({{length:{self.length}}}, () => useRef(null));"
return super()._get_hooks()
return super()._get_ref_hook()
@classmethod
def create(cls, *children, **props) -> Component:
@ -130,7 +130,7 @@ class PinInputField(ChakraComponent):
# Default to None because it is assigned by PinInput when created.
index: Optional[Var[int]] = None
def _get_hooks(self) -> Optional[str]:
def _get_ref_hook(self) -> Optional[str]:
return None
def get_ref(self):

View File

@ -64,8 +64,8 @@ class RangeSlider(ChakraComponent):
"""
return None
def _get_hooks(self) -> Optional[str]:
"""Override the base get_hooks to handle array refs.
def _get_ref_hook(self) -> Optional[str]:
"""Override the base _get_ref_hook to handle array refs.
Returns:
The overrided hooks.
@ -74,7 +74,7 @@ class RangeSlider(ChakraComponent):
ref = format.format_array_ref(self.id, None)
if ref:
return f"const {ref} = Array.from({{length:2}}, () => useRef(null));"
return super()._get_hooks()
return super()._get_ref_hook()
@classmethod
def create(cls, *children, **props) -> Component:
@ -130,7 +130,7 @@ class RangeSliderThumb(ChakraComponent):
# The position of the thumb.
index: Var[int]
def _get_hooks(self) -> Optional[str]:
def _get_ref_hook(self) -> Optional[str]:
# hook is None because RangeSlider is handling it.
return None

View File

@ -359,5 +359,9 @@ PING_TIMEOUT = 120
# Alembic migrations
ALEMBIC_CONFIG = os.environ.get("ALEMBIC_CONFIG", "alembic.ini")
# Names of event handlers on all components mapped to useEffect
ON_MOUNT = "on_mount"
ON_UNMOUNT = "on_unmount"
# If this env var is set to "yes", App.compile will be a no-op
SKIP_COMPILE_ENV_VAR = "__REFLEX_SKIP_COMPILE"

View File

@ -16,10 +16,11 @@ from plotly.io import to_json
from reflex import constants
from reflex.utils import types
from reflex.vars import Var
if TYPE_CHECKING:
from reflex.components.component import ComponentStyle
from reflex.event import EventHandler, EventSpec
from reflex.event import EventChain, EventHandler, EventSpec
WRAP_MAP = {
"{": "}",
@ -182,6 +183,24 @@ def format_string(string: str) -> str:
return string
def format_var(var: Var) -> str:
"""Format the given Var as a javascript value.
Args:
var: The Var to format.
Returns:
The formatted Var.
"""
if not var.is_local or var.is_string:
return str(var)
if types._issubclass(var.type_, str):
return format_string(var.full_name)
if is_wrapped(var.full_name, "{"):
return var.full_name
return json_dumps(var.full_name)
def format_route(route: str) -> str:
"""Format the given route.
@ -311,6 +330,46 @@ def format_event(event_spec: EventSpec) -> str:
return f"E({', '.join(event_args)})"
def format_event_chain(
event_chain: EventChain | Var[EventChain],
event_arg: Var | None = None,
) -> str:
"""Format an event chain as a javascript invocation.
Args:
event_chain: The event chain to queue on the frontend.
event_arg: The browser-native event (only used to preventDefault).
Returns:
Compiled javascript code to queue the given event chain on the frontend.
Raises:
ValueError: When the given event chain is not a valid event chain.
"""
if isinstance(event_chain, Var):
from reflex.event import EventChain
if event_chain.type_ is not EventChain:
raise ValueError(f"Invalid event chain: {event_chain}")
return "".join(
[
"(() => {",
format_var(event_chain),
f"; preventDefault({format_var(event_arg)})" if event_arg else "",
"})()",
]
)
chain = ",".join([format_event(event) for event in event_chain.events])
return "".join(
[
f"Event([{chain}]",
f", {format_var(event_arg)}" if event_arg else "",
")",
]
)
def format_query_params(router_data: Dict[str, Any]) -> Dict[str, str]:
"""Convert back query params name to python-friendly case.

View File

@ -5,6 +5,7 @@ import pytest
import reflex as rx
from reflex.components.component import Component, CustomComponent, custom_component
from reflex.components.layout.box import Box
from reflex.constants import ON_MOUNT, ON_UNMOUNT
from reflex.event import EVENT_ARG, EVENT_TRIGGERS, EventHandler
from reflex.state import State
from reflex.style import Style
@ -377,8 +378,9 @@ def test_get_triggers(component1, component2):
component1: A test component.
component2: A test component.
"""
assert component1().get_triggers() == EVENT_TRIGGERS
assert component2().get_triggers() == {"on_open", "on_close"} | EVENT_TRIGGERS
default_triggers = {ON_MOUNT, ON_UNMOUNT} | EVENT_TRIGGERS
assert component1().get_triggers() == default_triggers
assert component2().get_triggers() == {"on_open", "on_close"} | default_triggers
def test_create_custom_component(my_component):