From 2b2cdf9847260726c4a7db428cbc5447c6db6666 Mon Sep 17 00:00:00 2001 From: Kelechi Ebiri <56020538+TG199@users.noreply.github.com> Date: Mon, 17 Jun 2024 22:31:36 +0100 Subject: [PATCH] Feat: Add Session storage to store data on client storage (#3420) --- .pre-commit-config.yaml | 2 +- integration/test_client_storage.py | 98 +++++++++++++++++++++++++++- reflex/.templates/web/utils/state.js | 33 +++++++++- reflex/__init__.py | 3 + reflex/__init__.pyi | 3 + reflex/compiler/utils.py | 33 ++++++---- reflex/constants/__init__.py | 2 + reflex/constants/base.py | 1 + reflex/event.py | 28 ++++++++ reflex/state.py | 32 +++++++++ 10 files changed, 221 insertions(+), 14 deletions(-) diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index e36cbdfe6..e9f9473a3 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -31,4 +31,4 @@ repos: always_run: true language: system description: 'Update pyi files as needed' - entry: python scripts/make_pyi.py + entry: python3 scripts/make_pyi.py diff --git a/integration/test_client_storage.py b/integration/test_client_storage.py index 24c3c7be0..e019b1d63 100644 --- a/integration/test_client_storage.py +++ b/integration/test_client_storage.py @@ -46,6 +46,11 @@ def ClientSide(): l5: str = rx.LocalStorage(sync=True) l6: str = rx.LocalStorage(sync=True, name="l6") + # Session storage + s1: str = rx.SessionStorage() + s2: rx.SessionStorage = "s2 default" # type: ignore + s3: str = rx.SessionStorage(name="s3") + def set_l6(self, my_param: str): self.l6 = my_param @@ -56,6 +61,7 @@ def ClientSide(): class ClientSideSubSubState(ClientSideSubState): c1s: str = rx.Cookie() l1s: str = rx.LocalStorage() + s1s: str = rx.SessionStorage() def set_var(self): setattr(self, self.state_var, self.input_value) @@ -103,8 +109,12 @@ def ClientSide(): rx.box(ClientSideSubState.l4, id="l4"), rx.box(ClientSideSubState.l5, id="l5"), rx.box(ClientSideSubState.l6, id="l6"), + rx.box(ClientSideSubState.s1, id="s1"), + rx.box(ClientSideSubState.s2, id="s2"), + rx.box(ClientSideSubState.s3, id="s3"), rx.box(ClientSideSubSubState.c1s, id="c1s"), rx.box(ClientSideSubSubState.l1s, id="l1s"), + rx.box(ClientSideSubSubState.s1s, id="s1s"), ) app = rx.App(state=rx.State) @@ -162,6 +172,21 @@ def local_storage(driver: WebDriver) -> Generator[utils.LocalStorage, None, None ls.clear() +@pytest.fixture() +def session_storage(driver: WebDriver) -> Generator[utils.SessionStorage, None, None]: + """Get an instance of the session storage helper. + + Args: + driver: WebDriver instance. + + Yields: + Session storage helper. + """ + ss = utils.SessionStorage(driver) + yield ss + ss.clear() + + @pytest.fixture(autouse=True) def delete_all_cookies(driver: WebDriver) -> Generator[None, None, None]: """Delete all cookies after each test. @@ -190,7 +215,10 @@ def cookie_info_map(driver: WebDriver) -> dict[str, dict[str, str]]: @pytest.mark.asyncio async def test_client_side_state( - client_side: AppHarness, driver: WebDriver, local_storage: utils.LocalStorage + client_side: AppHarness, + driver: WebDriver, + local_storage: utils.LocalStorage, + session_storage: utils.SessionStorage, ): """Test client side state. @@ -198,6 +226,7 @@ async def test_client_side_state( client_side: harness for ClientSide app. driver: WebDriver instance. local_storage: Local storage helper. + session_storage: Session storage helper. """ assert client_side.app_instance is not None assert client_side.frontend_url is not None @@ -251,8 +280,12 @@ async def test_client_side_state( l2 = driver.find_element(By.ID, "l2") l3 = driver.find_element(By.ID, "l3") l4 = driver.find_element(By.ID, "l4") + s1 = driver.find_element(By.ID, "s1") + s2 = driver.find_element(By.ID, "s2") + s3 = driver.find_element(By.ID, "s3") c1s = driver.find_element(By.ID, "c1s") l1s = driver.find_element(By.ID, "l1s") + s1s = driver.find_element(By.ID, "s1s") # assert on defaults where present assert c1.text == "" @@ -266,8 +299,12 @@ async def test_client_side_state( assert l2.text == "l2 default" assert l3.text == "" assert l4.text == "l4 default" + assert s1.text == "" + assert s2.text == "s2 default" + assert s3.text == "" assert c1s.text == "" assert l1s.text == "" + assert s1s.text == "" # no cookies should be set yet! assert not driver.get_cookies() @@ -287,8 +324,12 @@ async def test_client_side_state( set_sub("l2", "l2 value") set_sub("l3", "l3 value") set_sub("l4", "l4 value") + set_sub("s1", "s1 value") + set_sub("s2", "s2 value") + set_sub("s3", "s3 value") set_sub_sub("c1s", "c1s value") set_sub_sub("l1s", "l1s value") + set_sub_sub("s1s", "s1s value") exp_cookies = { "state.client_side_state.client_side_sub_state.c1": { @@ -405,6 +446,25 @@ async def test_client_side_state( ) assert not local_storage_items + session_storage_items = session_storage.items() + session_storage_items.pop("token", None) + assert ( + session_storage_items.pop("state.client_side_state.client_side_sub_state.s1") + == "s1 value" + ) + assert ( + session_storage_items.pop("state.client_side_state.client_side_sub_state.s2") + == "s2 value" + ) + assert session_storage_items.pop("s3") == "s3 value" + assert ( + session_storage_items.pop( + "state.client_side_state.client_side_sub_state.client_side_sub_sub_state.s1s" + ) + == "s1s value" + ) + assert not session_storage_items + assert c1.text == "c1 value" assert c2.text == "c2 value" assert c3.text == "c3 value" @@ -416,8 +476,12 @@ async def test_client_side_state( assert l2.text == "l2 value" assert l3.text == "l3 value" assert l4.text == "l4 value" + assert s1.text == "s1 value" + assert s2.text == "s2 value" + assert s3.text == "s3 value" assert c1s.text == "c1s value" assert l1s.text == "l1s value" + assert s1s.text == "s1s value" # navigate to the /foo route with utils.poll_for_navigation(driver): @@ -435,8 +499,12 @@ async def test_client_side_state( l2 = driver.find_element(By.ID, "l2") l3 = driver.find_element(By.ID, "l3") l4 = driver.find_element(By.ID, "l4") + s1 = driver.find_element(By.ID, "s1") + s2 = driver.find_element(By.ID, "s2") + s3 = driver.find_element(By.ID, "s3") c1s = driver.find_element(By.ID, "c1s") l1s = driver.find_element(By.ID, "l1s") + s1s = driver.find_element(By.ID, "s1s") assert c1.text == "c1 value" assert c2.text == "c2 value" @@ -449,8 +517,12 @@ async def test_client_side_state( assert l2.text == "l2 value" assert l3.text == "l3 value" assert l4.text == "l4 value" + assert s1.text == "s1 value" + assert s2.text == "s2 value" + assert s3.text == "s3 value" assert c1s.text == "c1s value" assert l1s.text == "l1s value" + assert s1s.text == "s1s value" # reset the backend state to force refresh from client storage async with client_side.modify_state(f"{token}_state.client_side_state") as state: @@ -475,8 +547,12 @@ async def test_client_side_state( l2 = driver.find_element(By.ID, "l2") l3 = driver.find_element(By.ID, "l3") l4 = driver.find_element(By.ID, "l4") + s1 = driver.find_element(By.ID, "s1") + s2 = driver.find_element(By.ID, "s2") + s3 = driver.find_element(By.ID, "s3") c1s = driver.find_element(By.ID, "c1s") l1s = driver.find_element(By.ID, "l1s") + s1s = driver.find_element(By.ID, "s1s") assert c1.text == "c1 value" assert c2.text == "c2 value" @@ -489,8 +565,12 @@ async def test_client_side_state( assert l2.text == "l2 value" assert l3.text == "l3 value" assert l4.text == "l4 value" + assert s1.text == "s1 value" + assert s2.text == "s2 value" + assert s3.text == "s3 value" assert c1s.text == "c1s value" assert l1s.text == "l1s value" + assert s1s.text == "s1s value" # make sure c5 cookie shows up on the `/foo` route AppHarness._poll_for( @@ -525,6 +605,15 @@ async def test_client_side_state( assert AppHarness._poll_for(lambda: l6.text == "l6 value") assert l5.text == "l5 value" + # Set session storage values in the new tab + set_sub("s1", "other tab s1") + s1 = driver.find_element(By.ID, "s1") + s2 = driver.find_element(By.ID, "s2") + s3 = driver.find_element(By.ID, "s3") + assert AppHarness._poll_for(lambda: s1.text == "other tab s1") + assert s2.text == "s2 default" + assert s3.text == "" + # Switch back to main window. driver.switch_to.window(main_tab) @@ -534,6 +623,13 @@ async def test_client_side_state( assert AppHarness._poll_for(lambda: l6.text == "l6 value") assert l5.text == "l5 value" + s1 = driver.find_element(By.ID, "s1") + s2 = driver.find_element(By.ID, "s2") + s3 = driver.find_element(By.ID, "s3") + assert AppHarness._poll_for(lambda: s1.text == "s1 value") + assert s2.text == "s2 value" + assert s3.text == "s3 value" + # clear the cookie jar and local storage, ensure state reset to default driver.delete_all_cookies() local_storage.clear() diff --git a/reflex/.templates/web/utils/state.js b/reflex/.templates/web/utils/state.js index 5c9634d08..09177a7a7 100644 --- a/reflex/.templates/web/utils/state.js +++ b/reflex/.templates/web/utils/state.js @@ -185,6 +185,18 @@ export const applyEvent = async (event, socket) => { return false; } + if (event.name == "_clear_session_storage") { + sessionStorage.clear(); + queueEvents(initialEvents(), socket); + return false; + } + + if (event.name == "_remove_session_storage") { + sessionStorage.removeItem(event.payload.key); + queueEvents(initialEvents(), socket); + return false; + } + if (event.name == "_set_clipboard") { const content = event.payload.content; navigator.clipboard.writeText(content); @@ -538,7 +550,18 @@ export const hydrateClientStorage = (client_storage) => { } } } - if (client_storage.cookies || client_storage.local_storage) { + if (client_storage.session_storage && typeof window != "undefined") { + for (const state_key in client_storage.session_storage) { + const session_options = client_storage.session_storage[state_key]; + const session_storage_value = sessionStorage.getItem( + session_options.name || state_key + ); + if (session_storage_value != null) { + client_storage_values[state_key] = session_storage_value; + } + } + } + if (client_storage.cookies || client_storage.local_storage || client_storage.session_storage) { return client_storage_values; } return {}; @@ -578,7 +601,15 @@ const applyClientStorageDelta = (client_storage, delta) => { ) { const options = client_storage.local_storage[state_key]; localStorage.setItem(options.name || state_key, delta[substate][key]); + } else if( + client_storage.session_storage && + state_key in client_storage.session_storage && + typeof window !== "undefined" + ) { + const session_options = client_storage.session_storage[state_key]; + sessionStorage.setItem(session_options.name || state_key, delta[substate][key]); } + } } }; diff --git a/reflex/__init__.py b/reflex/__init__.py index dd3f141e8..a71a6cc46 100644 --- a/reflex/__init__.py +++ b/reflex/__init__.py @@ -287,12 +287,14 @@ _MAPPING: dict = { "background", "call_script", "clear_local_storage", + "clear_session_storage", "console_log", "download", "prevent_default", "redirect", "remove_cookie", "remove_local_storage", + "remove_session_storage", "set_clipboard", "set_focus", "scroll_to", @@ -307,6 +309,7 @@ _MAPPING: dict = { "var", "Cookie", "LocalStorage", + "SessionStorage", "ComponentState", "State", ], diff --git a/reflex/__init__.pyi b/reflex/__init__.pyi index da47e6a6f..5233e8b50 100644 --- a/reflex/__init__.pyi +++ b/reflex/__init__.pyi @@ -157,12 +157,14 @@ from .event import EventHandler as EventHandler from .event import background as background from .event import call_script as call_script from .event import clear_local_storage as clear_local_storage +from .event import clear_session_storage as clear_session_storage from .event import console_log as console_log from .event import download as download from .event import prevent_default as prevent_default from .event import redirect as redirect from .event import remove_cookie as remove_cookie from .event import remove_local_storage as remove_local_storage +from .event import remove_session_storage as remove_session_storage from .event import set_clipboard as set_clipboard from .event import set_focus as set_focus from .event import scroll_to as scroll_to @@ -177,6 +179,7 @@ from .model import Model as Model from .state import var as var from .state import Cookie as Cookie from .state import LocalStorage as LocalStorage +from .state import SessionStorage as SessionStorage from .state import ComponentState as ComponentState from .state import State as State from .style import Style as Style diff --git a/reflex/compiler/utils.py b/reflex/compiler/utils.py index fde499094..1b69539ac 100644 --- a/reflex/compiler/utils.py +++ b/reflex/compiler/utils.py @@ -28,7 +28,7 @@ from reflex.components.base import ( Title, ) from reflex.components.component import Component, ComponentStyle, CustomComponent -from reflex.state import BaseState, Cookie, LocalStorage +from reflex.state import BaseState, Cookie, LocalStorage, SessionStorage from reflex.style import Style from reflex.utils import console, format, imports, path_ops from reflex.utils.imports import ImportVar, ParsedImportDict @@ -158,8 +158,11 @@ def compile_state(state: Type[BaseState]) -> dict: def _compile_client_storage_field( field: ModelField, -) -> tuple[Type[Cookie] | Type[LocalStorage] | None, dict[str, Any] | None]: - """Compile the given cookie or local_storage field. +) -> tuple[ + Type[Cookie] | Type[LocalStorage] | Type[SessionStorage] | None, + dict[str, Any] | None, +]: + """Compile the given cookie, local_storage or session_storage field. Args: field: The possible cookie field to compile. @@ -167,7 +170,7 @@ def _compile_client_storage_field( Returns: A dictionary of the compiled cookie or None if the field is not cookie-like. """ - for field_type in (Cookie, LocalStorage): + for field_type in (Cookie, LocalStorage, SessionStorage): if isinstance(field.default, field_type): cs_obj = field.default elif isinstance(field.type_, type) and issubclass(field.type_, field_type): @@ -180,7 +183,7 @@ def _compile_client_storage_field( def _compile_client_storage_recursive( state: Type[BaseState], -) -> tuple[dict[str, dict], dict[str, dict[str, str]]]: +) -> tuple[dict[str, dict], dict[str, dict], dict[str, dict]]: """Compile the client-side storage for the given state recursively. Args: @@ -191,10 +194,12 @@ def _compile_client_storage_recursive( ( cookies: dict[str, dict], local_storage: dict[str, dict[str, str]] - ) + session_storage: dict[str, dict[str, str]] + ). """ cookies = {} local_storage = {} + session_storage = {} state_name = state.get_full_name() for name, field in state.__fields__.items(): if name in state.inherited_vars: @@ -206,15 +211,20 @@ def _compile_client_storage_recursive( cookies[state_key] = options elif field_type is LocalStorage: local_storage[state_key] = options + elif field_type is SessionStorage: + session_storage[state_key] = options else: continue for substate in state.get_substates(): - substate_cookies, substate_local_storage = _compile_client_storage_recursive( - substate - ) + ( + substate_cookies, + substate_local_storage, + substate_session_storage, + ) = _compile_client_storage_recursive(substate) cookies.update(substate_cookies) local_storage.update(substate_local_storage) - return cookies, local_storage + session_storage.update(substate_session_storage) + return cookies, local_storage, session_storage def compile_client_storage(state: Type[BaseState]) -> dict[str, dict]: @@ -226,10 +236,11 @@ def compile_client_storage(state: Type[BaseState]) -> dict[str, dict]: Returns: A dictionary of the compiled client-side storage info. """ - cookies, local_storage = _compile_client_storage_recursive(state) + cookies, local_storage, session_storage = _compile_client_storage_recursive(state) return { constants.COOKIES: cookies, constants.LOCAL_STORAGE: local_storage, + constants.SESSION_STORAGE: session_storage, } diff --git a/reflex/constants/__init__.py b/reflex/constants/__init__.py index 6389a8b0a..e974ab915 100644 --- a/reflex/constants/__init__.py +++ b/reflex/constants/__init__.py @@ -10,6 +10,7 @@ from .base import ( REFLEX_VAR_CLOSING_TAG, REFLEX_VAR_OPENING_TAG, RELOAD_CONFIG, + SESSION_STORAGE, SKIP_COMPILE_ENV_VAR, ColorMode, Dirs, @@ -88,6 +89,7 @@ __ALL__ = [ Imports, IS_WINDOWS, LOCAL_STORAGE, + SESSION_STORAGE, LogLevel, MemoizationDisposition, MemoizationMode, diff --git a/reflex/constants/base.py b/reflex/constants/base.py index 3fca45e2f..a4ab92a88 100644 --- a/reflex/constants/base.py +++ b/reflex/constants/base.py @@ -178,6 +178,7 @@ class Ping(SimpleNamespace): # Keys in the client_side_storage dict COOKIES = "cookies" LOCAL_STORAGE = "local_storage" +SESSION_STORAGE = "session_storage" # If this env var is set to "yes", App.compile will be a no-op SKIP_COMPILE_ENV_VAR = "__REFLEX_SKIP_COMPILE" diff --git a/reflex/event.py b/reflex/event.py index c9799a527..0fcb1e794 100644 --- a/reflex/event.py +++ b/reflex/event.py @@ -617,6 +617,34 @@ def remove_local_storage(key: str) -> EventSpec: ) +def clear_session_storage() -> EventSpec: + """Set a value in the session storage on the frontend. + + Returns: + EventSpec: An event to clear the session storage. + """ + return server_side( + "_clear_session_storage", + get_fn_signature(clear_session_storage), + ) + + +def remove_session_storage(key: str) -> EventSpec: + """Set a value in the session storage on the frontend. + + Args: + key: The key identifying the variable in the session storage to remove. + + Returns: + EventSpec: An event to remove an item based on the provided key in session storage. + """ + return server_side( + "_remove_session_storage", + get_fn_signature(remove_session_storage), + key=key, + ) + + def set_clipboard(content: str) -> EventSpec: """Set the text in content in the clipboard. diff --git a/reflex/state.py b/reflex/state.py index 56b28f9e8..156751ac2 100644 --- a/reflex/state.py +++ b/reflex/state.py @@ -2835,6 +2835,38 @@ class LocalStorage(ClientStorageBase, str): return inst +class SessionStorage(ClientStorageBase, str): + """Represents a state Var that is stored in sessionStorage in the browser.""" + + name: str | None + + def __new__( + cls, + object: Any = "", + encoding: str | None = None, + errors: str | None = None, + /, + name: str | None = None, + ) -> "SessionStorage": + """Create a client-side sessionStorage (str). + + Args: + object: The initial object. + encoding: The encoding to use. + errors: The error handling scheme to use + name: The name of the storage on the client side + + Returns: + The client-side sessionStorage object. + """ + if encoding or errors: + inst = super().__new__(cls, object, encoding or "utf-8", errors or "strict") + else: + inst = super().__new__(cls, object) + inst.name = name + return inst + + class MutableProxy(wrapt.ObjectProxy): """A proxy for a mutable object that tracks changes."""