From f66c6c3361f81eb20e18e3eacc24c20b4714cce6 Mon Sep 17 00:00:00 2001 From: Masen Furer Date: Thu, 26 Oct 2023 17:54:48 -0700 Subject: [PATCH] Support callback for rx.call_script (#2045) --- integration/test_call_script.py | 341 +++++++++++++++++++++++++++ reflex/.templates/web/utils/state.js | 12 +- reflex/event.py | 35 ++- reflex/utils/format.py | 2 +- 4 files changed, 385 insertions(+), 5 deletions(-) create mode 100644 integration/test_call_script.py diff --git a/integration/test_call_script.py b/integration/test_call_script.py new file mode 100644 index 000000000..d2162d34c --- /dev/null +++ b/integration/test_call_script.py @@ -0,0 +1,341 @@ +"""Integration tests for client side storage.""" +from __future__ import annotations + +from typing import Generator + +import pytest +from selenium.webdriver.common.by import By +from selenium.webdriver.remote.webdriver import WebDriver + +from reflex.testing import AppHarness + + +def CallScript(): + """A test app for browser javascript integration.""" + import reflex as rx + + inline_scripts = """ + let inline_counter = 0 + function inline1() { + inline_counter += 1 + return "inline1" + } + function inline2() { + inline_counter += 1 + console.log("inline2") + } + function inline3() { + inline_counter += 1 + return {inline3: 42, a: [1, 2, 3], s: 'js', o: {a: 1, b: 2}} + } + """ + + external_scripts = inline_scripts.replace("inline", "external") + + class CallScriptState(rx.State): + results: list[str | dict | list | None] = [] + inline_counter: int = 0 + external_counter: int = 0 + + def call_script_callback(self, result): + self.results.append(result) + + def call_script_callback_other_arg(self, result, other_arg): + self.results.append([other_arg, result]) + + def call_scripts_inline_yield(self): + yield rx.call_script("inline1()") + yield rx.call_script("inline2()") + yield rx.call_script("inline3()") + + def call_script_inline_return(self): + return rx.call_script("inline2()") + + def call_scripts_inline_yield_callback(self): + yield rx.call_script( + "inline1()", callback=CallScriptState.call_script_callback + ) + yield rx.call_script( + "inline2()", callback=CallScriptState.call_script_callback + ) + yield rx.call_script( + "inline3()", callback=CallScriptState.call_script_callback + ) + + def call_script_inline_return_callback(self): + return rx.call_script( + "inline3()", callback=CallScriptState.call_script_callback + ) + + def call_script_inline_return_lambda(self): + return rx.call_script( + "inline2()", + callback=lambda result: CallScriptState.call_script_callback_other_arg( # type: ignore + result, "lambda" + ), + ) + + def get_inline_counter(self): + return rx.call_script( + "inline_counter", + callback=CallScriptState.set_inline_counter, # type: ignore + ) + + def call_scripts_external_yield(self): + yield rx.call_script("external1()") + yield rx.call_script("external2()") + yield rx.call_script("external3()") + + def call_script_external_return(self): + return rx.call_script("external2()") + + def call_scripts_external_yield_callback(self): + yield rx.call_script( + "external1()", callback=CallScriptState.call_script_callback + ) + yield rx.call_script( + "external2()", callback=CallScriptState.call_script_callback + ) + yield rx.call_script( + "external3()", callback=CallScriptState.call_script_callback + ) + + def call_script_external_return_callback(self): + return rx.call_script( + "external3()", callback=CallScriptState.call_script_callback + ) + + def call_script_external_return_lambda(self): + return rx.call_script( + "external2()", + callback=lambda result: CallScriptState.call_script_callback_other_arg( # type: ignore + result, "lambda" + ), + ) + + def get_external_counter(self): + return rx.call_script( + "external_counter", + callback=CallScriptState.set_external_counter, # type: ignore + ) + + def reset_(self): + yield rx.call_script("inline_counter = 0; external_counter = 0") + self.reset() + + app = rx.App(state=CallScriptState) + with open("assets/external.js", "w") as f: + f.write(external_scripts) + + @app.add_page + def index(): + return rx.vstack( + rx.input( + value=CallScriptState.router.session.client_token, + is_read_only=True, + id="token", + ), + rx.input( + value=CallScriptState.inline_counter.to(str), # type: ignore + id="inline_counter", + is_read_only=True, + ), + rx.input( + value=CallScriptState.external_counter.to(str), # type: ignore + id="external_counter", + is_read_only=True, + ), + rx.text_area( + value=CallScriptState.results.to_string(), # type: ignore + id="results", + is_read_only=True, + ), + rx.script(inline_scripts), + rx.script(src="/external.js"), + rx.button( + "call_scripts_inline_yield", + on_click=CallScriptState.call_scripts_inline_yield, + id="inline_yield", + ), + rx.button( + "call_script_inline_return", + on_click=CallScriptState.call_script_inline_return, + id="inline_return", + ), + rx.button( + "call_scripts_inline_yield_callback", + on_click=CallScriptState.call_scripts_inline_yield_callback, + id="inline_yield_callback", + ), + rx.button( + "call_script_inline_return_callback", + on_click=CallScriptState.call_script_inline_return_callback, + id="inline_return_callback", + ), + rx.button( + "call_script_inline_return_lambda", + on_click=CallScriptState.call_script_inline_return_lambda, + id="inline_return_lambda", + ), + rx.button( + "call_scripts_external_yield", + on_click=CallScriptState.call_scripts_external_yield, + id="external_yield", + ), + rx.button( + "call_script_external_return", + on_click=CallScriptState.call_script_external_return, + id="external_return", + ), + rx.button( + "call_scripts_external_yield_callback", + on_click=CallScriptState.call_scripts_external_yield_callback, + id="external_yield_callback", + ), + rx.button( + "call_script_external_return_callback", + on_click=CallScriptState.call_script_external_return_callback, + id="external_return_callback", + ), + rx.button( + "call_script_external_return_lambda", + on_click=CallScriptState.call_script_external_return_lambda, + id="external_return_lambda", + ), + rx.button( + "Update Inline Counter", + on_click=CallScriptState.get_inline_counter, + id="update_inline_counter", + ), + rx.button( + "Update External Counter", + on_click=CallScriptState.get_external_counter, + id="update_external_counter", + ), + rx.button("Reset", id="reset", on_click=CallScriptState.reset_), + ) + + app.compile() + + +@pytest.fixture(scope="session") +def call_script(tmp_path_factory) -> Generator[AppHarness, None, None]: + """Start CallScript app at tmp_path via AppHarness. + + Args: + tmp_path_factory: pytest tmp_path_factory fixture + + Yields: + running AppHarness instance + """ + with AppHarness.create( + root=tmp_path_factory.mktemp("call_script"), + app_source=CallScript, # type: ignore + ) as harness: + yield harness + + +@pytest.fixture +def driver(call_script: AppHarness) -> Generator[WebDriver, None, None]: + """Get an instance of the browser open to the call_script app. + + Args: + call_script: harness for CallScript app + + Yields: + WebDriver instance. + """ + assert call_script.app_instance is not None, "app is not running" + driver = call_script.frontend() + try: + yield driver + finally: + driver.quit() + + +def assert_token(call_script: AppHarness, driver: WebDriver) -> str: + """Get the token associated with backend state. + + Args: + call_script: harness for CallScript app. + driver: WebDriver instance. + + Returns: + The token visible in the driver browser. + """ + assert call_script.app_instance is not None + token_input = driver.find_element(By.ID, "token") + assert token_input + + # wait for the backend connection to send the token + token = call_script.poll_for_value(token_input) + assert token is not None + + return token + + +@pytest.mark.parametrize("script", ["inline", "external"]) +def test_call_script( + call_script: AppHarness, + driver: WebDriver, + script: str, +): + """Test calling javascript functions from python. + + Args: + call_script: harness for CallScript app. + driver: WebDriver instance. + script: The type of script to test. + """ + assert_token(call_script, driver) + reset_button = driver.find_element(By.ID, "reset") + update_counter_button = driver.find_element(By.ID, f"update_{script}_counter") + counter = driver.find_element(By.ID, f"{script}_counter") + results = driver.find_element(By.ID, "results") + yield_button = driver.find_element(By.ID, f"{script}_yield") + return_button = driver.find_element(By.ID, f"{script}_return") + yield_callback_button = driver.find_element(By.ID, f"{script}_yield_callback") + return_callback_button = driver.find_element(By.ID, f"{script}_return_callback") + return_lambda_button = driver.find_element(By.ID, f"{script}_return_lambda") + + yield_button.click() + update_counter_button.click() + assert call_script.poll_for_value(counter, exp_not_equal="0") == "3" + reset_button.click() + assert call_script.poll_for_value(counter, exp_not_equal="3") == "0" + return_button.click() + update_counter_button.click() + assert call_script.poll_for_value(counter, exp_not_equal="0") == "1" + reset_button.click() + assert call_script.poll_for_value(counter, exp_not_equal="1") == "0" + + yield_callback_button.click() + update_counter_button.click() + assert call_script.poll_for_value(counter, exp_not_equal="0") == "3" + assert call_script.poll_for_value( + results, exp_not_equal="[]" + ) == '["%s1",null,{"%s3":42,"a":[1,2,3],"s":"js","o":{"a":1,"b":2}}]' % ( + script, + script, + ) + reset_button.click() + assert call_script.poll_for_value(counter, exp_not_equal="3") == "0" + + return_callback_button.click() + update_counter_button.click() + assert call_script.poll_for_value(counter, exp_not_equal="0") == "1" + assert ( + call_script.poll_for_value(results, exp_not_equal="[]") + == '[{"%s3":42,"a":[1,2,3],"s":"js","o":{"a":1,"b":2}}]' % script + ) + reset_button.click() + assert call_script.poll_for_value(counter, exp_not_equal="1") == "0" + + return_lambda_button.click() + update_counter_button.click() + assert call_script.poll_for_value(counter, exp_not_equal="0") == "1" + assert ( + call_script.poll_for_value(results, exp_not_equal="[]") == '[["lambda",null]]' + ) + reset_button.click() + assert call_script.poll_for_value(counter, exp_not_equal="1") == "0" diff --git a/reflex/.templates/web/utils/state.js b/reflex/.templates/web/utils/state.js index 7edc35eff..1c0c63774 100644 --- a/reflex/.templates/web/utils/state.js +++ b/reflex/.templates/web/utils/state.js @@ -198,7 +198,10 @@ export const applyEvent = async (event, socket) => { if (event.name == "_call_script") { try { - eval(event.payload.javascript_code); + const eval_result = eval(event.payload.javascript_code); + if (event.payload.callback) { + eval(event.payload.callback)(eval_result) + } } catch (e) { console.log("_call_script", e); } @@ -213,7 +216,7 @@ export const applyEvent = async (event, socket) => { // Send the event to the server. if (socket) { - socket.emit("event", JSON.stringify(event)); + socket.emit("event", JSON.stringify(event, (k, v) => v === undefined ? null : v)); return true; } @@ -407,7 +410,10 @@ export const hydrateClientStorage = (client_storage) => { for (const state_key in client_storage.cookies) { const cookie_options = client_storage.cookies[state_key] const cookie_name = cookie_options.name || state_key - client_storage_values.cookies[state_key] = cookies.get(cookie_name) + const cookie_value = cookies.get(cookie_name) + if (cookie_value !== undefined) { + client_storage_values.cookies[state_key] = cookies.get(cookie_name) + } } } if (client_storage.local_storage && (typeof window !== 'undefined')) { diff --git a/reflex/event.py b/reflex/event.py index 683919a00..422183235 100644 --- a/reflex/event.py +++ b/reflex/event.py @@ -2,6 +2,7 @@ from __future__ import annotations import inspect +from types import FunctionType from typing import ( TYPE_CHECKING, Any, @@ -408,19 +409,51 @@ def download(url: str, filename: Optional[str] = None) -> EventSpec: ) -def call_script(javascript_code: str) -> EventSpec: +def _callback_arg_spec(eval_result): + """ArgSpec for call_script callback function. + + Args: + eval_result: The result of the javascript execution. + + Returns: + Args for the callback function + """ + return [eval_result] + + +def call_script( + javascript_code: str, + callback: EventHandler | Callable | None = None, +) -> EventSpec: """Create an event handler that executes arbitrary javascript code. Args: javascript_code: The code to execute. + callback: EventHandler that will receive the result of evaluating the javascript code. Returns: EventSpec: An event that will execute the client side javascript. + + Raises: + ValueError: If the callback is not a valid event handler. """ + callback_kwargs = {} + if callback is not None: + arg_name = parse_args_spec(_callback_arg_spec)[0]._var_name + if isinstance(callback, EventHandler): + event_spec = call_event_handler(callback, _callback_arg_spec) + elif isinstance(callback, FunctionType): + event_spec = call_event_fn(callback, _callback_arg_spec)[0] + else: + raise ValueError("Cannot use {callback!r} as a call_script callback.") + callback_kwargs = { + "callback": f"({arg_name}) => queueEvents([{format.format_event(event_spec)}], {constants.CompileVars.SOCKET})" + } return server_side( "_call_script", get_fn_signature(call_script), javascript_code=javascript_code, + **callback_kwargs, ) diff --git a/reflex/utils/format.py b/reflex/utils/format.py index df3c354df..1ec24dd37 100644 --- a/reflex/utils/format.py +++ b/reflex/utils/format.py @@ -417,7 +417,7 @@ def format_event(event_spec: EventSpec) -> str: ":".join( ( name._var_name, - wrap(json.dumps(val._var_name).strip('"'), "`") + wrap(json.dumps(val._var_name).strip('"').replace("`", "\\`"), "`") if val._var_is_string else val._var_full_name, )