diff --git a/integration/__init__.py b/integration/__init__.py new file mode 100644 index 000000000..58fbe1b85 --- /dev/null +++ b/integration/__init__.py @@ -0,0 +1 @@ +"""Package for integration tests.""" diff --git a/integration/test_client_storage.py b/integration/test_client_storage.py new file mode 100644 index 000000000..fd0290f22 --- /dev/null +++ b/integration/test_client_storage.py @@ -0,0 +1,515 @@ +"""Integration tests for client side storage.""" +from __future__ import annotations + +import time +from typing import Generator + +import pytest +from selenium.webdriver.common.by import By +from selenium.webdriver.remote.webdriver import WebDriver + +from reflex.testing import AppHarness + +from . import utils + + +def ClientSide(): + """App for testing client-side state.""" + import reflex as rx + + class ClientSideState(rx.State): + state_var: str = "" + input_value: str = "" + + @rx.var + def token(self) -> str: + return self.get_token() + + class ClientSideSubState(ClientSideState): + # cookies with default settings + c1: str = rx.Cookie() + c2: rx.Cookie = "c2 default" # type: ignore + + # cookies with custom settings + c3: str = rx.Cookie(max_age=2) # expires after 2 second + c4: rx.Cookie = rx.Cookie(same_site="strict") + c5: str = rx.Cookie(path="/foo/") # only accessible on `/foo/` + c6: str = rx.Cookie(name="c6") + c7: str = rx.Cookie("c7 default") + + # local storage with default settings + l1: str = rx.LocalStorage() + l2: rx.LocalStorage = "l2 default" # type: ignore + + # local storage with custom settings + l3: str = rx.LocalStorage(name="l3") + l4: str = rx.LocalStorage("l4 default") + + def set_var(self): + setattr(self, self.state_var, self.input_value) + self.state_var = self.input_value = "" + + class ClientSideSubSubState(ClientSideSubState): + c1s: str = rx.Cookie() + l1s: str = rx.LocalStorage() + + def set_var(self): + setattr(self, self.state_var, self.input_value) + self.state_var = self.input_value = "" + + def index(): + return rx.fragment( + rx.input(value=ClientSideState.token, is_read_only=True, id="token"), + rx.input( + placeholder="state var", + value=ClientSideState.state_var, + on_change=ClientSideState.set_state_var, # type: ignore + id="state_var", + ), + rx.input( + placeholder="input value", + value=ClientSideState.input_value, + on_change=ClientSideState.set_input_value, # type: ignore + id="input_value", + ), + rx.button( + "Set ClientSideSubState", + on_click=ClientSideSubState.set_var, + id="set_sub_state", + ), + rx.button( + "Set ClientSideSubSubState", + on_click=ClientSideSubSubState.set_var, + id="set_sub_sub_state", + ), + rx.box(ClientSideSubState.c1, id="c1"), + rx.box(ClientSideSubState.c2, id="c2"), + rx.box(ClientSideSubState.c3, id="c3"), + rx.box(ClientSideSubState.c4, id="c4"), + rx.box(ClientSideSubState.c5, id="c5"), + rx.box(ClientSideSubState.c6, id="c6"), + rx.box(ClientSideSubState.c7, id="c7"), + rx.box(ClientSideSubState.l1, id="l1"), + rx.box(ClientSideSubState.l2, id="l2"), + rx.box(ClientSideSubState.l3, id="l3"), + rx.box(ClientSideSubState.l4, id="l4"), + rx.box(ClientSideSubSubState.c1s, id="c1s"), + rx.box(ClientSideSubSubState.l1s, id="l1s"), + ) + + app = rx.App(state=ClientSideState) + app.add_page(index) + app.add_page(index, route="/foo") + app.compile() + + +@pytest.fixture(scope="session") +def client_side(tmp_path_factory) -> Generator[AppHarness, None, None]: + """Start ClientSide app at tmp_path via AppHarness. + + Args: + tmp_path_factory: pytest tmp_path_factory fixture + + Yields: + running AppHarness instance + """ + with AppHarness.create( + root=tmp_path_factory.mktemp("client_side"), + app_source=ClientSide, # type: ignore + ) as harness: + yield harness + + +@pytest.fixture +def driver(client_side: AppHarness) -> Generator[WebDriver, None, None]: + """Get an instance of the browser open to the client_side app. + + Args: + client_side: harness for ClientSide app + + Yields: + WebDriver instance. + """ + assert client_side.app_instance is not None, "app is not running" + driver = client_side.frontend() + try: + assert client_side.poll_for_clients() + yield driver + finally: + driver.quit() + + +@pytest.fixture() +def local_storage(driver: WebDriver) -> Generator[utils.LocalStorage, None, None]: + """Get an instance of the local storage helper. + + Args: + driver: WebDriver instance. + + Yields: + Local storage helper. + """ + ls = utils.LocalStorage(driver) + yield ls + ls.clear() + + +@pytest.fixture(autouse=True) +def delete_all_cookies(driver: WebDriver) -> Generator[None, None, None]: + """Delete all cookies after each test. + + Args: + driver: WebDriver instance. + + Yields: + None + """ + yield + driver.delete_all_cookies() + + +def test_client_side_state( + client_side: AppHarness, driver: WebDriver, local_storage: utils.LocalStorage +): + """Test client side state. + + Args: + client_side: harness for ClientSide app. + driver: WebDriver instance. + local_storage: Local storage helper. + """ + assert client_side.app_instance is not None + assert client_side.frontend_url is not None + token_input = driver.find_element(By.ID, "token") + assert token_input + + # wait for the backend connection to send the token + token = client_side.poll_for_value(token_input) + assert token is not None + + backend_state = client_side.app_instance.state_manager.states[token] + + # get a reference to the cookie manipulation form + state_var_input = driver.find_element(By.ID, "state_var") + input_value_input = driver.find_element(By.ID, "input_value") + set_sub_state_button = driver.find_element(By.ID, "set_sub_state") + set_sub_sub_state_button = driver.find_element(By.ID, "set_sub_sub_state") + + # get a reference to all cookie and local storage elements + c1 = driver.find_element(By.ID, "c1") + c2 = driver.find_element(By.ID, "c2") + c3 = driver.find_element(By.ID, "c3") + c4 = driver.find_element(By.ID, "c4") + c5 = driver.find_element(By.ID, "c5") + c6 = driver.find_element(By.ID, "c6") + c7 = driver.find_element(By.ID, "c7") + l1 = driver.find_element(By.ID, "l1") + l2 = driver.find_element(By.ID, "l2") + l3 = driver.find_element(By.ID, "l3") + l4 = driver.find_element(By.ID, "l4") + c1s = driver.find_element(By.ID, "c1s") + l1s = driver.find_element(By.ID, "l1s") + + # assert on defaults where present + assert c1.text == "" + assert c2.text == "c2 default" + assert c3.text == "" + assert c4.text == "" + assert c5.text == "" + assert c6.text == "" + assert c7.text == "c7 default" + assert l1.text == "" + assert l2.text == "l2 default" + assert l3.text == "" + assert l4.text == "l4 default" + assert c1s.text == "" + assert l1s.text == "" + + # no cookies should be set yet! + assert not driver.get_cookies() + local_storage_items = local_storage.items() + local_storage_items.pop("chakra-ui-color-mode", None) + assert not local_storage_items + + # set some cookies and local storage values + state_var_input.send_keys("c1") + input_value_input.send_keys("c1 value") + set_sub_state_button.click() + state_var_input.send_keys("c2") + input_value_input.send_keys("c2 value") + set_sub_state_button.click() + state_var_input.send_keys("c3") + input_value_input.send_keys("c3 value") + set_sub_state_button.click() + state_var_input.send_keys("c4") + input_value_input.send_keys("c4 value") + set_sub_state_button.click() + state_var_input.send_keys("c5") + input_value_input.send_keys("c5 value") + set_sub_state_button.click() + state_var_input.send_keys("c6") + input_value_input.send_keys("c6 value") + set_sub_state_button.click() + state_var_input.send_keys("c7") + input_value_input.send_keys("c7 value") + set_sub_state_button.click() + + state_var_input.send_keys("l1") + input_value_input.send_keys("l1 value") + set_sub_state_button.click() + state_var_input.send_keys("l2") + input_value_input.send_keys("l2 value") + set_sub_state_button.click() + state_var_input.send_keys("l3") + input_value_input.send_keys("l3 value") + set_sub_state_button.click() + state_var_input.send_keys("l4") + input_value_input.send_keys("l4 value") + set_sub_state_button.click() + + state_var_input.send_keys("c1s") + input_value_input.send_keys("c1s value") + set_sub_sub_state_button.click() + state_var_input.send_keys("l1s") + input_value_input.send_keys("l1s value") + set_sub_sub_state_button.click() + + cookies = {cookie_info["name"]: cookie_info for cookie_info in driver.get_cookies()} + assert cookies.pop("client_side_state.client_side_sub_state.c1") == { + "domain": "localhost", + "httpOnly": False, + "name": "client_side_state.client_side_sub_state.c1", + "path": "/", + "sameSite": "Lax", + "secure": False, + "value": "c1%20value", + } + assert cookies.pop("client_side_state.client_side_sub_state.c2") == { + "domain": "localhost", + "httpOnly": False, + "name": "client_side_state.client_side_sub_state.c2", + "path": "/", + "sameSite": "Lax", + "secure": False, + "value": "c2%20value", + } + c3_cookie = cookies.pop("client_side_state.client_side_sub_state.c3") + assert c3_cookie.pop("expiry") is not None + assert c3_cookie == { + "domain": "localhost", + "httpOnly": False, + "name": "client_side_state.client_side_sub_state.c3", + "path": "/", + "sameSite": "Lax", + "secure": False, + "value": "c3%20value", + } + assert cookies.pop("client_side_state.client_side_sub_state.c4") == { + "domain": "localhost", + "httpOnly": False, + "name": "client_side_state.client_side_sub_state.c4", + "path": "/", + "sameSite": "Strict", + "secure": False, + "value": "c4%20value", + } + assert cookies.pop("c6") == { + "domain": "localhost", + "httpOnly": False, + "name": "c6", + "path": "/", + "sameSite": "Lax", + "secure": False, + "value": "c6%20value", + } + assert cookies.pop("client_side_state.client_side_sub_state.c7") == { + "domain": "localhost", + "httpOnly": False, + "name": "client_side_state.client_side_sub_state.c7", + "path": "/", + "sameSite": "Lax", + "secure": False, + "value": "c7%20value", + } + assert cookies.pop( + "client_side_state.client_side_sub_state.client_side_sub_sub_state.c1s" + ) == { + "domain": "localhost", + "httpOnly": False, + "name": "client_side_state.client_side_sub_state.client_side_sub_sub_state.c1s", + "path": "/", + "sameSite": "Lax", + "secure": False, + "value": "c1s%20value", + } + # assert all cookies have been popped for this page + assert not cookies + time.sleep(2) # wait for c3 to expire + assert "client_side_state.client_side_sub_state.c3" not in { + cookie_info["name"] for cookie_info in driver.get_cookies() + } + + local_storage_items = local_storage.items() + local_storage_items.pop("chakra-ui-color-mode", None) + assert ( + local_storage_items.pop("client_side_state.client_side_sub_state.l1") + == "l1 value" + ) + assert ( + local_storage_items.pop("client_side_state.client_side_sub_state.l2") + == "l2 value" + ) + assert local_storage_items.pop("l3") == "l3 value" + assert ( + local_storage_items.pop("client_side_state.client_side_sub_state.l4") + == "l4 value" + ) + assert ( + local_storage_items.pop( + "client_side_state.client_side_sub_state.client_side_sub_sub_state.l1s" + ) + == "l1s value" + ) + assert not local_storage_items + + assert c1.text == "c1 value" + assert c2.text == "c2 value" + assert c3.text == "c3 value" + assert c4.text == "c4 value" + assert c5.text == "c5 value" + assert c6.text == "c6 value" + assert c7.text == "c7 value" + assert l1.text == "l1 value" + assert l2.text == "l2 value" + assert l3.text == "l3 value" + assert l4.text == "l4 value" + assert c1s.text == "c1s value" + assert l1s.text == "l1s value" + + # navigate to the /foo route + with utils.poll_for_navigation(driver): + driver.get(client_side.frontend_url + "/foo") + + # get new references to all cookie and local storage elements + c1 = driver.find_element(By.ID, "c1") + c2 = driver.find_element(By.ID, "c2") + c3 = driver.find_element(By.ID, "c3") + c4 = driver.find_element(By.ID, "c4") + c5 = driver.find_element(By.ID, "c5") + c6 = driver.find_element(By.ID, "c6") + c7 = driver.find_element(By.ID, "c7") + l1 = driver.find_element(By.ID, "l1") + l2 = driver.find_element(By.ID, "l2") + l3 = driver.find_element(By.ID, "l3") + l4 = driver.find_element(By.ID, "l4") + c1s = driver.find_element(By.ID, "c1s") + l1s = driver.find_element(By.ID, "l1s") + + assert c1.text == "c1 value" + assert c2.text == "c2 value" + assert c3.text == "" # cookie expired so value removed from state + assert c4.text == "c4 value" + assert c5.text == "c5 value" + assert c6.text == "c6 value" + assert c7.text == "c7 value" + assert l1.text == "l1 value" + assert l2.text == "l2 value" + assert l3.text == "l3 value" + assert l4.text == "l4 value" + assert c1s.text == "c1s value" + assert l1s.text == "l1s value" + + # reset the backend state to force refresh from client storage + backend_state.reset() + driver.refresh() + + # wait for the backend connection to send the token (again) + token_input = driver.find_element(By.ID, "token") + assert token_input + token = client_side.poll_for_value(token_input) + assert token is not None + + # get new references to all cookie and local storage elements (again) + c1 = driver.find_element(By.ID, "c1") + c2 = driver.find_element(By.ID, "c2") + c3 = driver.find_element(By.ID, "c3") + c4 = driver.find_element(By.ID, "c4") + c5 = driver.find_element(By.ID, "c5") + c6 = driver.find_element(By.ID, "c6") + c7 = driver.find_element(By.ID, "c7") + l1 = driver.find_element(By.ID, "l1") + l2 = driver.find_element(By.ID, "l2") + l3 = driver.find_element(By.ID, "l3") + l4 = driver.find_element(By.ID, "l4") + c1s = driver.find_element(By.ID, "c1s") + l1s = driver.find_element(By.ID, "l1s") + + assert c1.text == "c1 value" + assert c2.text == "c2 value" + assert c3.text == "" # temporary cookie expired after reset state! + assert c4.text == "c4 value" + assert c5.text == "c5 value" + assert c6.text == "c6 value" + assert c7.text == "c7 value" + assert l1.text == "l1 value" + assert l2.text == "l2 value" + assert l3.text == "l3 value" + assert l4.text == "l4 value" + assert c1s.text == "c1s value" + assert l1s.text == "l1s value" + + # make sure c5 cookie shows up on the `/foo` route + cookies = {cookie_info["name"]: cookie_info for cookie_info in driver.get_cookies()} + + assert cookies["client_side_state.client_side_sub_state.c5"] == { + "domain": "localhost", + "httpOnly": False, + "name": "client_side_state.client_side_sub_state.c5", + "path": "/foo/", + "sameSite": "Lax", + "secure": False, + "value": "c5%20value", + } + + # clear the cookie jar and local storage, ensure state reset to default + driver.delete_all_cookies() + local_storage.clear() + + # refresh the page to trigger re-hydrate + driver.refresh() + + # wait for the backend connection to send the token (again) + token_input = driver.find_element(By.ID, "token") + assert token_input + token = client_side.poll_for_value(token_input) + assert token is not None + + # all values should be back to their defaults + c1 = driver.find_element(By.ID, "c1") + c2 = driver.find_element(By.ID, "c2") + c3 = driver.find_element(By.ID, "c3") + c4 = driver.find_element(By.ID, "c4") + c5 = driver.find_element(By.ID, "c5") + c6 = driver.find_element(By.ID, "c6") + c7 = driver.find_element(By.ID, "c7") + l1 = driver.find_element(By.ID, "l1") + l2 = driver.find_element(By.ID, "l2") + l3 = driver.find_element(By.ID, "l3") + l4 = driver.find_element(By.ID, "l4") + c1s = driver.find_element(By.ID, "c1s") + l1s = driver.find_element(By.ID, "l1s") + + # assert on defaults where present + assert c1.text == "" + assert c2.text == "c2 default" + assert c3.text == "" + assert c4.text == "" + assert c5.text == "" + assert c6.text == "" + assert c7.text == "c7 default" + assert l1.text == "" + assert l2.text == "l2 default" + assert l3.text == "" + assert l4.text == "l4 default" + assert c1s.text == "" + assert l1s.text == "" diff --git a/integration/test_dynamic_routes.py b/integration/test_dynamic_routes.py index a42db3716..661e6ce79 100644 --- a/integration/test_dynamic_routes.py +++ b/integration/test_dynamic_routes.py @@ -1,6 +1,5 @@ """Integration tests for dynamic route page behavior.""" import time -from contextlib import contextmanager from typing import Generator from urllib.parse import urlsplit @@ -9,6 +8,8 @@ from selenium.webdriver.common.by import By from reflex.testing import AppHarness +from .utils import poll_for_navigation + def DynamicRoute(): """App for testing dynamic routes.""" @@ -89,27 +90,6 @@ def driver(dynamic_route: AppHarness): driver.quit() -@contextmanager -def poll_for_navigation(driver, timeout: int = 5) -> Generator[None, None, None]: - """Wait for driver url to change. - - Use as a contextmanager, and apply the navigation event inside the context - block, polling will occur after the context block exits. - - Args: - driver: WebDriver instance. - timeout: Time to wait for url to change. - - Yields: - None - """ - prev_url = driver.current_url - - yield - - AppHarness._poll_for(lambda: prev_url != driver.current_url, timeout=timeout) - - def test_on_load_navigate(dynamic_route: AppHarness, driver): """Click links to navigate between dynamic pages with on_load event. diff --git a/integration/utils.py b/integration/utils.py new file mode 100644 index 000000000..273094c84 --- /dev/null +++ b/integration/utils.py @@ -0,0 +1,173 @@ +"""Helper utilities for integration tests.""" +from __future__ import annotations + +from contextlib import contextmanager +from typing import Generator, Iterator + +from selenium.webdriver.remote.webdriver import WebDriver + +from reflex.testing import AppHarness + + +@contextmanager +def poll_for_navigation( + driver: WebDriver, timeout: int = 5 +) -> Generator[None, None, None]: + """Wait for driver url to change. + + Use as a contextmanager, and apply the navigation event inside the context + block, polling will occur after the context block exits. + + Args: + driver: WebDriver instance. + timeout: Time to wait for url to change. + + Yields: + None + """ + prev_url = driver.current_url + + yield + + AppHarness._poll_for(lambda: prev_url != driver.current_url, timeout=timeout) + + +class LocalStorage: + """Class to access local storage. + + https://stackoverflow.com/a/46361900 + """ + + def __init__(self, driver: WebDriver): + """Initialize the class. + + Args: + driver: WebDriver instance. + """ + self.driver = driver + + def __len__(self) -> int: + """Get the number of items in local storage. + + Returns: + The number of items in local storage. + """ + return int(self.driver.execute_script("return window.localStorage.length;")) + + def items(self) -> dict[str, str]: + """Get all items in local storage. + + Returns: + A dict mapping keys to values. + """ + return self.driver.execute_script( + "var ls = window.localStorage, items = {}; " + "for (var i = 0, k; i < ls.length; ++i) " + " items[k = ls.key(i)] = ls.getItem(k); " + "return items; " + ) + + def keys(self) -> list[str]: + """Get all keys in local storage. + + Returns: + A list of keys. + """ + return self.driver.execute_script( + "var ls = window.localStorage, keys = []; " + "for (var i = 0; i < ls.length; ++i) " + " keys[i] = ls.key(i); " + "return keys; " + ) + + def get(self, key) -> str: + """Get a key from local storage. + + Args: + key: The key to get. + + Returns: + The value of the key. + """ + return self.driver.execute_script( + "return window.localStorage.getItem(arguments[0]);", key + ) + + def set(self, key, value) -> None: + """Set a key in local storage. + + Args: + key: The key to set. + value: The value to set the key to. + """ + self.driver.execute_script( + "window.localStorage.setItem(arguments[0], arguments[1]);", key, value + ) + + def has(self, key) -> bool: + """Check if key is in local storage. + + Args: + key: The key to check. + + Returns: + True if key is in local storage, False otherwise. + """ + return key in self + + def remove(self, key) -> None: + """Remove a key from local storage. + + Args: + key: The key to remove. + """ + self.driver.execute_script("window.localStorage.removeItem(arguments[0]);", key) + + def clear(self) -> None: + """Clear all local storage.""" + self.driver.execute_script("window.localStorage.clear();") + + def __getitem__(self, key) -> str: + """Get a key from local storage. + + Args: + key: The key to get. + + Returns: + The value of the key. + + Raises: + KeyError: If key is not in local storage. + """ + value = self.get(key) + if value is None: + raise KeyError(key) + return value + + def __setitem__(self, key, value) -> None: + """Set a key in local storage. + + Args: + key: The key to set. + value: The value to set the key to. + """ + self.set(key, value) + + def __contains__(self, key) -> bool: + """Check if key is in local storage. + + Args: + key: The key to check. + + Returns: + True if key is in local storage, False otherwise. + """ + return self.has(key) + + def __iter__(self) -> Iterator[str]: + """Iterate over the keys in local storage. + + Returns: + An iterator over the items in local storage. + """ + return iter(self.keys()) diff --git a/poetry.lock b/poetry.lock index 514b8ffcc..dd6d2efaa 100644 --- a/poetry.lock +++ b/poetry.lock @@ -1520,13 +1520,13 @@ files = [ [[package]] name = "selenium" -version = "4.10.0" +version = "4.11.2" description = "" optional = false python-versions = ">=3.7" files = [ - {file = "selenium-4.10.0-py3-none-any.whl", hash = "sha256:40241b9d872f58959e9b34e258488bf11844cd86142fd68182bd41db9991fc5c"}, - {file = "selenium-4.10.0.tar.gz", hash = "sha256:871bf800c4934f745b909c8dfc7d15c65cf45bd2e943abd54451c810ada395e3"}, + {file = "selenium-4.11.2-py3-none-any.whl", hash = "sha256:98e72117b194b3fa9c69b48998f44bf7dd4152c7bd98544911a1753b9f03cc7d"}, + {file = "selenium-4.11.2.tar.gz", hash = "sha256:9f9a5ed586280a3594f7461eb1d9dab3eac9d91e28572f365e9b98d9d03e02b5"}, ] [package.dependencies] @@ -2125,4 +2125,4 @@ testing = ["big-O", "flake8 (<5)", "jaraco.functools", "jaraco.itertools", "more [metadata] lock-version = "2.0" python-versions = "^3.7" -content-hash = "2b00be45f1c3b5118e2d54b315991c37f65e9da3fa081dab6adb4c7bb1205c74" +content-hash = "44cce3d4423be203bf6b1ddc046cbdd9061924523b86baea8a42cd954dc86b36" diff --git a/pyproject.toml b/pyproject.toml index 2929b33f9..99626ef13 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -65,7 +65,7 @@ pandas = [ ] asynctest = "^0.13.0" pre-commit = {version = "^3.2.1", python = ">=3.8,<4.0"} -selenium = "^4.10.0" +selenium = "^4.11.0" [tool.poetry.scripts] reflex = "reflex.reflex:cli" diff --git a/reflex/.templates/jinja/web/utils/context.js.jinja2 b/reflex/.templates/jinja/web/utils/context.js.jinja2 index cd9f20ffd..657c64e27 100644 --- a/reflex/.templates/jinja/web/utils/context.js.jinja2 +++ b/reflex/.templates/jinja/web/utils/context.js.jinja2 @@ -1,7 +1,10 @@ import { createContext } from "react" -import { E } from "/utils/state.js" +import { E, hydrateClientStorage } from "/utils/state.js" export const initialState = {{ initial_state|json_dumps }} -export const initialEvents = [E('{{state_name}}.{{const.hydrate}}', {})] export const StateContext = createContext(null); -export const EventLoopContext = createContext(null); \ No newline at end of file +export const EventLoopContext = createContext(null); +export const clientStorage = {{ client_storage|json_dumps }} +export const initialEvents = [ + E('{{state_name}}.{{const.hydrate}}', hydrateClientStorage(clientStorage)), +] \ No newline at end of file diff --git a/reflex/.templates/web/pages/_app.js b/reflex/.templates/web/pages/_app.js index 6411ee9b5..4039f5ae2 100644 --- a/reflex/.templates/web/pages/_app.js +++ b/reflex/.templates/web/pages/_app.js @@ -1,7 +1,7 @@ import { ChakraProvider, extendTheme } from "@chakra-ui/react"; import { Global, css } from "@emotion/react"; import theme from "/utils/theme"; -import { initialEvents, initialState, StateContext, EventLoopContext } from "/utils/context.js"; +import { clientStorage, initialEvents, initialState, StateContext, EventLoopContext } from "/utils/context.js"; import { useEventLoop } from "utils/state"; import '../styles/tailwind.css' @@ -18,6 +18,7 @@ function EventLoopProvider({ children }) { const [state, Event, connectError] = useEventLoop( initialState, initialEvents, + clientStorage, ) return ( diff --git a/reflex/.templates/web/utils/state.js b/reflex/.templates/web/utils/state.js index 9e1848436..50cb8e442 100644 --- a/reflex/.templates/web/utils/state.js +++ b/reflex/.templates/web/utils/state.js @@ -125,12 +125,12 @@ export const applyEvent = async (event, socket) => { } if (event.name == "_set_cookie") { - cookies.set(event.payload.key, event.payload.value); + cookies.set(event.payload.key, event.payload.value, { path: "/" }); return false; } if (event.name == "_remove_cookie") { - cookies.remove(event.payload.key, event.payload.options) + cookies.remove(event.payload.key, { path: "/", ...event.payload.options }) return false; } @@ -257,6 +257,7 @@ export const processEvent = async ( * @param transports The transports to use. * @param setConnectError The function to update connection error value. * @param initial_events Array of events to seed the queue after connecting. + * @param client_storage The client storage object from context.js */ export const connect = async ( socket, @@ -264,6 +265,7 @@ export const connect = async ( transports, setConnectError, initial_events = [], + client_storage = {}, ) => { // Get backend URL object from the endpoint. const endpoint = new URL(EVENTURL); @@ -288,6 +290,7 @@ export const connect = async ( socket.current.on("event", message => { const update = JSON5.parse(message) dispatch(update.delta) + applyClientStorageDelta(client_storage, update.delta) event_processing = !update.final if (update.events) { queueEvents(update.events, socket) @@ -357,10 +360,77 @@ export const E = (name, payload = {}, handler = null) => { return { name, payload, handler }; }; +/** + * Package client-side storage values as payload to send to the + * backend with the hydrate event + * @param client_storage The client storage object from context.js + * @returns payload dict of client storage values + */ +export const hydrateClientStorage = (client_storage) => { + const client_storage_values = { + "cookies": {}, + "local_storage": {} + } + if (client_storage.cookies) { + for (const state_key in client_storage.cookies) { + const cookie_options = client_storage.cookies[state_key] + const cookie_name = cookie_options.name || state_key + client_storage_values.cookies[state_key] = cookies.get(cookie_name) + } + } + if (client_storage.local_storage && (typeof window !== 'undefined')) { + for (const state_key in client_storage.local_storage) { + const options = client_storage.local_storage[state_key] + const local_storage_value = localStorage.getItem(options.name || state_key) + if (local_storage_value !== null) { + client_storage_values.local_storage[state_key] = local_storage_value + } + } + } + if (client_storage.cookies || client_storage.local_storage) { + return client_storage_values + } + return {} +}; + +/** + * Update client storage values based on backend state delta. + * @param client_storage The client storage object from context.js + * @param delta The state update from the backend + */ +const applyClientStorageDelta = (client_storage, delta) => { + // find the main state and check for is_hydrated + const unqualified_states = Object.keys(delta).filter((key) => key.split(".").length === 1); + if (unqualified_states.length === 1) { + const main_state = delta[unqualified_states[0]] + if (main_state.is_hydrated !== undefined && !main_state.is_hydrated) { + // skip if the state is not hydrated yet, since all client storage + // values are sent in the hydrate event + return; + } + } + // Save known client storage values to cookies and localStorage. + for (const substate in delta) { + for (const key in delta[substate]) { + const state_key = `${substate}.${key}` + if (client_storage.cookies && state_key in client_storage.cookies) { + const cookie_options = client_storage.cookies[state_key] + const cookie_name = cookie_options.name || state_key + delete cookie_options.name // name is not a valid cookie option + cookies.set(cookie_name, delta[substate][key], cookie_options); + } else if (client_storage.local_storage && state_key in client_storage.local_storage && (typeof window !== 'undefined')) { + const options = client_storage.local_storage[state_key] + localStorage.setItem(options.name || state_key, delta[substate][key]); + } + } + } +} + /** * Establish websocket event loop for a NextJS page. * @param initial_state The initial page state. * @param initial_events Array of events to seed the queue after connecting. + * @param client_storage The client storage object from context.js * * @returns [state, Event, connectError] - * state is a reactive dict, @@ -370,6 +440,7 @@ export const E = (name, payload = {}, handler = null) => { export const useEventLoop = ( initial_state = {}, initial_events = [], + client_storage = {}, ) => { const socket = useRef(null) const router = useRouter() @@ -391,7 +462,7 @@ export const useEventLoop = ( // Initialize the websocket connection. if (!socket.current) { - connect(socket, dispatch, ['websocket', 'polling'], setConnectError, initial_events) + connect(socket, dispatch, ['websocket', 'polling'], setConnectError, initial_events, client_storage) } (async () => { // Process all outstanding events. diff --git a/reflex/__init__.py b/reflex/__init__.py index e38f96499..53d0e0a38 100644 --- a/reflex/__init__.py +++ b/reflex/__init__.py @@ -38,6 +38,8 @@ from .model import session as session from .page import page as page from .route import route as route from .state import ComputedVar as var +from .state import Cookie as Cookie +from .state import LocalStorage as LocalStorage from .state import State as State from .style import color_mode as color_mode from .style import toggle_color_mode as toggle_color_mode diff --git a/reflex/compiler/compiler.py b/reflex/compiler/compiler.py index 0d5a325fb..de5fbad8b 100644 --- a/reflex/compiler/compiler.py +++ b/reflex/compiler/compiler.py @@ -83,6 +83,7 @@ def _compile_contexts(state: Type[State]) -> str: return templates.CONTEXT.render( initial_state=utils.compile_state(state), state_name=state.get_name(), + client_storage=utils.compile_client_storage(state), ) diff --git a/reflex/compiler/utils.py b/reflex/compiler/utils.py index 5a234668b..f9a43b708 100644 --- a/reflex/compiler/utils.py +++ b/reflex/compiler/utils.py @@ -1,7 +1,10 @@ """Common utility functions used in the compiler.""" +from __future__ import annotations import os -from typing import Dict, List, Optional, Set, Tuple, Type +from typing import Any, Dict, List, Optional, Set, Tuple, Type + +from pydantic.fields import ModelField from reflex import constants from reflex.components.base import ( @@ -19,7 +22,7 @@ from reflex.components.base import ( Title, ) from reflex.components.component import Component, ComponentStyle, CustomComponent -from reflex.state import State +from reflex.state import Cookie, LocalStorage, State from reflex.style import Style from reflex.utils import format, imports, path_ops from reflex.vars import ImportVar @@ -129,6 +132,83 @@ def compile_state(state: Type[State]) -> Dict: return format.format_state(initial_state) +def _compile_client_storage_field( + field: ModelField, +) -> tuple[Type[Cookie] | Type[LocalStorage] | None, dict[str, Any] | None]: + """Compile the given cookie or local_storage field. + + Args: + field: The possible cookie field to compile. + + Returns: + A dictionary of the compiled cookie or None if the field is not cookie-like. + """ + for field_type in (Cookie, LocalStorage): + if isinstance(field.default, field_type): + cs_obj = field.default + elif isinstance(field.type_, type) and issubclass(field.type_, field_type): + cs_obj = field.type_() + else: + continue + return field_type, cs_obj.options() + return None, None + + +def _compile_client_storage_recursive( + state: Type[State], +) -> tuple[dict[str, dict], dict[str, dict[str, str]]]: + """Compile the client-side storage for the given state recursively. + + Args: + state: The app state object. + + Returns: + A tuple of the compiled client-side storage info: + ( + cookies: dict[str, dict], + local_storage: dict[str, dict[str, str]] + ) + """ + cookies = {} + local_storage = {} + state_name = state.get_full_name() + for name, field in state.__fields__.items(): + if name in state.inherited_vars: + # only include vars defined in this state + continue + state_key = f"{state_name}.{name}" + field_type, options = _compile_client_storage_field(field) + if field_type is Cookie: + cookies[state_key] = options + elif field_type is LocalStorage: + local_storage[state_key] = options + else: + continue + for substate in state.get_substates(): + substate_cookies, substate_local_storage = _compile_client_storage_recursive( + substate + ) + cookies.update(substate_cookies) + local_storage.update(substate_local_storage) + return cookies, local_storage + + +def compile_client_storage(state: Type[State]) -> dict[str, dict]: + """Compile the client-side storage for the given state. + + Args: + state: The app state object. + + Returns: + A dictionary of the compiled client-side storage info. + """ + cookies, local_storage = _compile_client_storage_recursive(state) + return { + constants.COOKIES: cookies, + constants.LOCAL_STORAGE: local_storage, + } + + def compile_custom_component( component: CustomComponent, ) -> Tuple[dict, imports.ImportDict]: diff --git a/reflex/constants.py b/reflex/constants.py index 202ca275c..2f26e12a4 100644 --- a/reflex/constants.py +++ b/reflex/constants.py @@ -359,6 +359,10 @@ PING_TIMEOUT = 120 # Alembic migrations ALEMBIC_CONFIG = os.environ.get("ALEMBIC_CONFIG", "alembic.ini") +# Keys in the client_side_storage dict +COOKIES = "cookies" +LOCAL_STORAGE = "local_storage" + # Names of event handlers on all components mapped to useEffect ON_MOUNT = "on_mount" ON_UNMOUNT = "on_unmount" diff --git a/reflex/middleware/hydrate_middleware.py b/reflex/middleware/hydrate_middleware.py index 588315ae9..7fd465f78 100644 --- a/reflex/middleware/hydrate_middleware.py +++ b/reflex/middleware/hydrate_middleware.py @@ -36,8 +36,21 @@ class HydrateMiddleware(Middleware): if event.name != get_hydrate_event(state): return None - # Get the initial state. + # Clear client storage, to respect clearing cookies + state._reset_client_storage() + + # Mark state as not hydrated (until on_loads are complete) setattr(state, constants.IS_HYDRATED, False) + + # Apply client side storage values to state + for storage_type in (constants.COOKIES, constants.LOCAL_STORAGE): + if storage_type in event.payload: + for key, value in event.payload[storage_type].items(): + state_name, _, var_name = key.rpartition(".") + var_state = state.get_substate(state_name.split(".")) + setattr(var_state, var_name, value) + + # Get the initial state. delta = format.format_state({state.get_name(): state.dict()}) # since a full dict was captured, clean any dirtiness state._clean() diff --git a/reflex/state.py b/reflex/state.py index 7cd789fa0..e2df18881 100644 --- a/reflex/state.py +++ b/reflex/state.py @@ -688,6 +688,23 @@ class State(Base, ABC, extra=pydantic.Extra.allow): for substate in self.substates.values(): substate.reset() + def _reset_client_storage(self): + """Reset client storage base vars to their default values.""" + # Client-side storage is reset during hydrate so that clearing cookies + # on the browser also resets the values on the backend. + fields = self.get_fields() + for prop_name in self.base_vars: + field = fields[prop_name] + if isinstance(field.default, ClientStorageBase) or ( + isinstance(field.type_, type) + and issubclass(field.type_, ClientStorageBase) + ): + setattr(self, prop_name, field.default) + + # Recursively reset the substates. + for substate in self.substates.values(): + substate.reset() + def get_substate(self, path: Sequence[str]) -> Optional[State]: """Get the substate. @@ -1110,3 +1127,104 @@ def _convert_mutable_datatypes( ) return field_value + + +class ClientStorageBase: + """Base class for client-side storage.""" + + def options(self) -> dict[str, Any]: + """Get the options for the storage. + + Returns: + All set options for the storage (not None). + """ + return { + format.to_camel_case(k): v for k, v in vars(self).items() if v is not None + } + + +class Cookie(ClientStorageBase, str): + """Represents a state Var that is stored as a cookie in the browser.""" + + name: str | None + path: str + max_age: int | None + domain: str | None + secure: bool | None + same_site: str + + def __new__( + cls, + object: Any = "", + encoding: str | None = None, + errors: str | None = None, + /, + name: str | None = None, + path: str = "/", + max_age: int | None = None, + domain: str | None = None, + secure: bool | None = None, + same_site: str = "lax", + ): + """Create a client-side Cookie (str). + + Args: + object: The initial object. + encoding: The encoding to use. + errors: The error handling scheme to use. + name: The name of the cookie on the client side. + path: Cookie path. Use / as the path if the cookie should be accessible on all pages. + max_age: Relative max age of the cookie in seconds from when the client receives it. + domain: Domain for the cookie (sub.domain.com or .allsubdomains.com). + secure: Is the cookie only accessible through HTTPS? + same_site: Whether the cookie is sent with third party requests. + One of (true|false|none|lax|strict) + + Returns: + The client-side Cookie object. + + Note: expires (absolute Date) is not supported at this time. + """ + if encoding or errors: + inst = super().__new__(cls, object, encoding or "utf-8", errors or "strict") + else: + inst = super().__new__(cls, object) + inst.name = name + inst.path = path + inst.max_age = max_age + inst.domain = domain + inst.secure = secure + inst.same_site = same_site + return inst + + +class LocalStorage(ClientStorageBase, str): + """Represents a state Var that is stored in localStorage in the browser.""" + + name: str | None + + def __new__( + cls, + object: Any = "", + encoding: str | None = None, + errors: str | None = None, + /, + name: str | None = None, + ) -> "LocalStorage": + """Create a client-side localStorage (str). + + Args: + object: The initial object. + encoding: The encoding to use. + errors: The error handling scheme to use. + name: The name of the storage key on the client side. + + Returns: + The client-side localStorage object. + """ + if encoding or errors: + inst = super().__new__(cls, object, encoding or "utf-8", errors or "strict") + else: + inst = super().__new__(cls, object) + inst.name = name + return inst