diff --git a/integration/test_client_storage.py b/integration/test_client_storage.py index 9f883c2e7..c381d3a3e 100644 --- a/integration/test_client_storage.py +++ b/integration/test_client_storage.py @@ -41,6 +41,13 @@ def ClientSide(): l3: str = rx.LocalStorage(name="l3") l4: str = rx.LocalStorage("l4 default") + # Sync'd local storage + l5: str = rx.LocalStorage(sync=True) + l6: str = rx.LocalStorage(sync=True, name="l6") + + def set_l6(self, my_param: str): + self.l6 = my_param + def set_var(self): setattr(self, self.state_var, self.input_value) self.state_var = self.input_value = "" @@ -93,6 +100,8 @@ def ClientSide(): rx.box(ClientSideSubState.l2, id="l2"), rx.box(ClientSideSubState.l3, id="l3"), rx.box(ClientSideSubState.l4, id="l4"), + rx.box(ClientSideSubState.l5, id="l5"), + rx.box(ClientSideSubState.l6, id="l6"), rx.box(ClientSideSubSubState.c1s, id="c1s"), rx.box(ClientSideSubSubState.l1s, id="l1s"), ) @@ -191,33 +200,44 @@ async def test_client_side_state( """ assert client_side.app_instance is not None assert client_side.frontend_url is not None - token_input = driver.find_element(By.ID, "token") - assert token_input - # wait for the backend connection to send the token - token = client_side.poll_for_value(token_input) - assert token is not None + def poll_for_token(): + token_input = driver.find_element(By.ID, "token") + assert token_input - # get a reference to the cookie manipulation form - state_var_input = driver.find_element(By.ID, "state_var") - input_value_input = driver.find_element(By.ID, "input_value") - set_sub_state_button = driver.find_element(By.ID, "set_sub_state") - set_sub_sub_state_button = driver.find_element(By.ID, "set_sub_sub_state") + # wait for the backend connection to send the token + token = client_side.poll_for_value(token_input) + assert token is not None + return token def set_sub(var: str, value: str): + # Get a reference to the cookie manipulation form. + state_var_input = driver.find_element(By.ID, "state_var") + input_value_input = driver.find_element(By.ID, "input_value") + set_sub_state_button = driver.find_element(By.ID, "set_sub_state") AppHarness._poll_for(lambda: state_var_input.get_attribute("value") == "") AppHarness._poll_for(lambda: input_value_input.get_attribute("value") == "") + + # Set the values. state_var_input.send_keys(var) input_value_input.send_keys(value) set_sub_state_button.click() def set_sub_sub(var: str, value: str): + # Get a reference to the cookie manipulation form. + state_var_input = driver.find_element(By.ID, "state_var") + input_value_input = driver.find_element(By.ID, "input_value") + set_sub_sub_state_button = driver.find_element(By.ID, "set_sub_sub_state") AppHarness._poll_for(lambda: state_var_input.get_attribute("value") == "") AppHarness._poll_for(lambda: input_value_input.get_attribute("value") == "") + + # Set the values. state_var_input.send_keys(var) input_value_input.send_keys(value) set_sub_sub_state_button.click() + token = poll_for_token() + # get a reference to all cookie and local storage elements c1 = driver.find_element(By.ID, "c1") c2 = driver.find_element(By.ID, "c2") @@ -485,6 +505,31 @@ async def test_client_side_state( "value": "c5%20value", } + # Open a new tab to check that sync'd local storage is working + main_tab = driver.window_handles[0] + driver.switch_to.new_window("window") + driver.get(client_side.frontend_url) + + # New tab should have a different state token. + assert poll_for_token() != token + + # Set values and check them in the new tab. + set_sub("l5", "l5 value") + set_sub("l6", "l6 value") + l5 = driver.find_element(By.ID, "l5") + l6 = driver.find_element(By.ID, "l6") + assert l5.text == "l5 value" + assert l6.text == "l6 value" + + # Switch back to main window. + driver.switch_to.window(main_tab) + + # The values should have updated automatically. + l5 = driver.find_element(By.ID, "l5") + l6 = driver.find_element(By.ID, "l6") + assert l5.text == "l5 value" + assert l6.text == "l6 value" + # clear the cookie jar and local storage, ensure state reset to default driver.delete_all_cookies() local_storage.clear() diff --git a/reflex/.templates/jinja/web/utils/context.js.jinja2 b/reflex/.templates/jinja/web/utils/context.js.jinja2 index 5ff4a44bd..cc48c4a56 100644 --- a/reflex/.templates/jinja/web/utils/context.js.jinja2 +++ b/reflex/.templates/jinja/web/utils/context.js.jinja2 @@ -23,13 +23,19 @@ export const clientStorage = {} {% endif %} {% if state_name %} -export const onLoadInternalEvent = () => [Event('{{state_name}}.{{const.on_load_internal}}')] +export const state_name = "{{state_name}}" +export const onLoadInternalEvent = () => [ + Event('{{state_name}}.{{const.update_vars_internal}}', {vars: hydrateClientStorage(clientStorage)}), + Event('{{state_name}}.{{const.on_load_internal}}') +] export const initialEvents = () => [ - Event('{{state_name}}.{{const.hydrate}}', hydrateClientStorage(clientStorage)), + Event('{{state_name}}.{{const.hydrate}}'), ...onLoadInternalEvent() ] {% else %} +export const state_name = undefined + export const onLoadInternalEvent = () => [] export const initialEvents = () => [] diff --git a/reflex/.templates/web/utils/state.js b/reflex/.templates/web/utils/state.js index dbd914cf0..088f25626 100644 --- a/reflex/.templates/web/utils/state.js +++ b/reflex/.templates/web/utils/state.js @@ -6,7 +6,7 @@ import env from "/env.json"; import Cookies from "universal-cookie"; import { useEffect, useReducer, useRef, useState } from "react"; import Router, { useRouter } from "next/router"; -import { initialEvents, initialState, onLoadInternalEvent } from "utils/context.js" +import { initialEvents, initialState, onLoadInternalEvent, state_name } from "utils/context.js" // Endpoint URLs. const EVENTURL = env.EVENT @@ -441,17 +441,14 @@ export const Event = (name, payload = {}, handler = null) => { * @returns payload dict of client storage values */ export const hydrateClientStorage = (client_storage) => { - const client_storage_values = { - "cookies": {}, - "local_storage": {} - } + const client_storage_values = {} if (client_storage.cookies) { for (const state_key in client_storage.cookies) { const cookie_options = client_storage.cookies[state_key] const cookie_name = cookie_options.name || state_key const cookie_value = cookies.get(cookie_name) if (cookie_value !== undefined) { - client_storage_values.cookies[state_key] = cookies.get(cookie_name) + client_storage_values[state_key] = cookies.get(cookie_name) } } } @@ -460,7 +457,7 @@ export const hydrateClientStorage = (client_storage) => { const options = client_storage.local_storage[state_key] const local_storage_value = localStorage.getItem(options.name || state_key) if (local_storage_value !== null) { - client_storage_values.local_storage[state_key] = local_storage_value + client_storage_values[state_key] = local_storage_value } } } @@ -568,6 +565,36 @@ export const useEventLoop = ( } }) + + // localStorage event handling + useEffect(() => { + const storage_to_state_map = {}; + + if (client_storage.local_storage && typeof window !== "undefined") { + for (const state_key in client_storage.local_storage) { + const options = client_storage.local_storage[state_key]; + if (options.sync) { + const local_storage_value_key = options.name || state_key; + storage_to_state_map[local_storage_value_key] = state_key; + } + } + } + + // e is StorageEvent + const handleStorage = (e) => { + if (storage_to_state_map[e.key]) { + const vars = {} + vars[storage_to_state_map[e.key]] = e.newValue + const event = Event(`${state_name}.update_vars_internal`, {vars: vars}) + addEvents([event], e); + } + }; + + window.addEventListener("storage", handleStorage); + return () => window.removeEventListener("storage", handleStorage); + }); + + // Route after the initial page hydration. useEffect(() => { const change_complete = () => addEvents(onLoadInternalEvent()) diff --git a/reflex/compiler/templates.py b/reflex/compiler/templates.py index 472d1fbdc..2b71230c8 100644 --- a/reflex/compiler/templates.py +++ b/reflex/compiler/templates.py @@ -41,6 +41,7 @@ class ReflexJinjaEnvironment(Environment): "use_color_mode": constants.ColorMode.USE, "hydrate": constants.CompileVars.HYDRATE, "on_load_internal": constants.CompileVars.ON_LOAD_INTERNAL, + "update_vars_internal": constants.CompileVars.UPDATE_VARS_INTERNAL, } diff --git a/reflex/constants/compiler.py b/reflex/constants/compiler.py index 6a5465448..4efb68e22 100644 --- a/reflex/constants/compiler.py +++ b/reflex/constants/compiler.py @@ -60,6 +60,8 @@ class CompileVars(SimpleNamespace): TO_EVENT = "Event" # The name of the internal on_load event. ON_LOAD_INTERNAL = "on_load_internal" + # The name of the internal event to update generic state vars. + UPDATE_VARS_INTERNAL = "update_vars_internal" class PageNames(SimpleNamespace): diff --git a/reflex/middleware/hydrate_middleware.py b/reflex/middleware/hydrate_middleware.py index dc80971c7..b1617fca7 100644 --- a/reflex/middleware/hydrate_middleware.py +++ b/reflex/middleware/hydrate_middleware.py @@ -39,14 +39,6 @@ class HydrateMiddleware(Middleware): # Mark state as not hydrated (until on_loads are complete) setattr(state, constants.CompileVars.IS_HYDRATED, False) - # Apply client side storage values to state - for storage_type in (constants.COOKIES, constants.LOCAL_STORAGE): - if storage_type in event.payload: - for key, value in event.payload[storage_type].items(): - state_name, _, var_name = key.rpartition(".") - var_state = state.get_substate(state_name.split(".")) - setattr(var_state, var_name, value) - # Get the initial state. delta = format.format_state(state.dict()) # since a full dict was captured, clean any dirtiness diff --git a/reflex/state.py b/reflex/state.py index 33ff12c17..ccdb6649e 100644 --- a/reflex/state.py +++ b/reflex/state.py @@ -1405,6 +1405,23 @@ class State(BaseState): type(self).set_is_hydrated(True), # type: ignore ] + def update_vars_internal(self, vars: dict[str, Any]) -> None: + """Apply updates to fully qualified state vars. + + The keys in `vars` should be in the form of `{state.get_full_name()}.{var_name}`, + and each value will be set on the appropriate substate instance. + + This function is primarily used to apply cookie and local storage + updates from the frontend to the appropriate substate. + + Args: + vars: The fully qualified vars and values to update. + """ + for var, value in vars.items(): + state_name, _, var_name = var.rpartition(".") + var_state = self.get_substate(state_name.split(".")) + setattr(var_state, var_name, value) + class StateProxy(wrapt.ObjectProxy): """Proxy of a state instance to control mutability of vars for a background task. @@ -1949,6 +1966,7 @@ class LocalStorage(ClientStorageBase, str): """Represents a state Var that is stored in localStorage in the browser.""" name: str | None + sync: bool = False def __new__( cls, @@ -1957,6 +1975,7 @@ class LocalStorage(ClientStorageBase, str): errors: str | None = None, /, name: str | None = None, + sync: bool = False, ) -> "LocalStorage": """Create a client-side localStorage (str). @@ -1965,6 +1984,7 @@ class LocalStorage(ClientStorageBase, str): encoding: The encoding to use. errors: The error handling scheme to use. name: The name of the storage key on the client side. + sync: Whether changes should be propagated to other tabs. Returns: The client-side localStorage object. @@ -1974,6 +1994,7 @@ class LocalStorage(ClientStorageBase, str): else: inst = super().__new__(cls, object) inst.name = name + inst.sync = sync return inst