Feat: Add Session storage to store data on client storage (#3420)

This commit is contained in:
Kelechi Ebiri 2024-06-17 22:31:36 +01:00 committed by GitHub
parent b78fa6f210
commit 2b2cdf9847
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
10 changed files with 221 additions and 14 deletions

View File

@ -31,4 +31,4 @@ repos:
always_run: true always_run: true
language: system language: system
description: 'Update pyi files as needed' description: 'Update pyi files as needed'
entry: python scripts/make_pyi.py entry: python3 scripts/make_pyi.py

View File

@ -46,6 +46,11 @@ def ClientSide():
l5: str = rx.LocalStorage(sync=True) l5: str = rx.LocalStorage(sync=True)
l6: str = rx.LocalStorage(sync=True, name="l6") l6: str = rx.LocalStorage(sync=True, name="l6")
# Session storage
s1: str = rx.SessionStorage()
s2: rx.SessionStorage = "s2 default" # type: ignore
s3: str = rx.SessionStorage(name="s3")
def set_l6(self, my_param: str): def set_l6(self, my_param: str):
self.l6 = my_param self.l6 = my_param
@ -56,6 +61,7 @@ def ClientSide():
class ClientSideSubSubState(ClientSideSubState): class ClientSideSubSubState(ClientSideSubState):
c1s: str = rx.Cookie() c1s: str = rx.Cookie()
l1s: str = rx.LocalStorage() l1s: str = rx.LocalStorage()
s1s: str = rx.SessionStorage()
def set_var(self): def set_var(self):
setattr(self, self.state_var, self.input_value) setattr(self, self.state_var, self.input_value)
@ -103,8 +109,12 @@ def ClientSide():
rx.box(ClientSideSubState.l4, id="l4"), rx.box(ClientSideSubState.l4, id="l4"),
rx.box(ClientSideSubState.l5, id="l5"), rx.box(ClientSideSubState.l5, id="l5"),
rx.box(ClientSideSubState.l6, id="l6"), rx.box(ClientSideSubState.l6, id="l6"),
rx.box(ClientSideSubState.s1, id="s1"),
rx.box(ClientSideSubState.s2, id="s2"),
rx.box(ClientSideSubState.s3, id="s3"),
rx.box(ClientSideSubSubState.c1s, id="c1s"), rx.box(ClientSideSubSubState.c1s, id="c1s"),
rx.box(ClientSideSubSubState.l1s, id="l1s"), rx.box(ClientSideSubSubState.l1s, id="l1s"),
rx.box(ClientSideSubSubState.s1s, id="s1s"),
) )
app = rx.App(state=rx.State) app = rx.App(state=rx.State)
@ -162,6 +172,21 @@ def local_storage(driver: WebDriver) -> Generator[utils.LocalStorage, None, None
ls.clear() ls.clear()
@pytest.fixture()
def session_storage(driver: WebDriver) -> Generator[utils.SessionStorage, None, None]:
"""Get an instance of the session storage helper.
Args:
driver: WebDriver instance.
Yields:
Session storage helper.
"""
ss = utils.SessionStorage(driver)
yield ss
ss.clear()
@pytest.fixture(autouse=True) @pytest.fixture(autouse=True)
def delete_all_cookies(driver: WebDriver) -> Generator[None, None, None]: def delete_all_cookies(driver: WebDriver) -> Generator[None, None, None]:
"""Delete all cookies after each test. """Delete all cookies after each test.
@ -190,7 +215,10 @@ def cookie_info_map(driver: WebDriver) -> dict[str, dict[str, str]]:
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_client_side_state( async def test_client_side_state(
client_side: AppHarness, driver: WebDriver, local_storage: utils.LocalStorage client_side: AppHarness,
driver: WebDriver,
local_storage: utils.LocalStorage,
session_storage: utils.SessionStorage,
): ):
"""Test client side state. """Test client side state.
@ -198,6 +226,7 @@ async def test_client_side_state(
client_side: harness for ClientSide app. client_side: harness for ClientSide app.
driver: WebDriver instance. driver: WebDriver instance.
local_storage: Local storage helper. local_storage: Local storage helper.
session_storage: Session storage helper.
""" """
assert client_side.app_instance is not None assert client_side.app_instance is not None
assert client_side.frontend_url is not None assert client_side.frontend_url is not None
@ -251,8 +280,12 @@ async def test_client_side_state(
l2 = driver.find_element(By.ID, "l2") l2 = driver.find_element(By.ID, "l2")
l3 = driver.find_element(By.ID, "l3") l3 = driver.find_element(By.ID, "l3")
l4 = driver.find_element(By.ID, "l4") l4 = driver.find_element(By.ID, "l4")
s1 = driver.find_element(By.ID, "s1")
s2 = driver.find_element(By.ID, "s2")
s3 = driver.find_element(By.ID, "s3")
c1s = driver.find_element(By.ID, "c1s") c1s = driver.find_element(By.ID, "c1s")
l1s = driver.find_element(By.ID, "l1s") l1s = driver.find_element(By.ID, "l1s")
s1s = driver.find_element(By.ID, "s1s")
# assert on defaults where present # assert on defaults where present
assert c1.text == "" assert c1.text == ""
@ -266,8 +299,12 @@ async def test_client_side_state(
assert l2.text == "l2 default" assert l2.text == "l2 default"
assert l3.text == "" assert l3.text == ""
assert l4.text == "l4 default" assert l4.text == "l4 default"
assert s1.text == ""
assert s2.text == "s2 default"
assert s3.text == ""
assert c1s.text == "" assert c1s.text == ""
assert l1s.text == "" assert l1s.text == ""
assert s1s.text == ""
# no cookies should be set yet! # no cookies should be set yet!
assert not driver.get_cookies() assert not driver.get_cookies()
@ -287,8 +324,12 @@ async def test_client_side_state(
set_sub("l2", "l2 value") set_sub("l2", "l2 value")
set_sub("l3", "l3 value") set_sub("l3", "l3 value")
set_sub("l4", "l4 value") set_sub("l4", "l4 value")
set_sub("s1", "s1 value")
set_sub("s2", "s2 value")
set_sub("s3", "s3 value")
set_sub_sub("c1s", "c1s value") set_sub_sub("c1s", "c1s value")
set_sub_sub("l1s", "l1s value") set_sub_sub("l1s", "l1s value")
set_sub_sub("s1s", "s1s value")
exp_cookies = { exp_cookies = {
"state.client_side_state.client_side_sub_state.c1": { "state.client_side_state.client_side_sub_state.c1": {
@ -405,6 +446,25 @@ async def test_client_side_state(
) )
assert not local_storage_items assert not local_storage_items
session_storage_items = session_storage.items()
session_storage_items.pop("token", None)
assert (
session_storage_items.pop("state.client_side_state.client_side_sub_state.s1")
== "s1 value"
)
assert (
session_storage_items.pop("state.client_side_state.client_side_sub_state.s2")
== "s2 value"
)
assert session_storage_items.pop("s3") == "s3 value"
assert (
session_storage_items.pop(
"state.client_side_state.client_side_sub_state.client_side_sub_sub_state.s1s"
)
== "s1s value"
)
assert not session_storage_items
assert c1.text == "c1 value" assert c1.text == "c1 value"
assert c2.text == "c2 value" assert c2.text == "c2 value"
assert c3.text == "c3 value" assert c3.text == "c3 value"
@ -416,8 +476,12 @@ async def test_client_side_state(
assert l2.text == "l2 value" assert l2.text == "l2 value"
assert l3.text == "l3 value" assert l3.text == "l3 value"
assert l4.text == "l4 value" assert l4.text == "l4 value"
assert s1.text == "s1 value"
assert s2.text == "s2 value"
assert s3.text == "s3 value"
assert c1s.text == "c1s value" assert c1s.text == "c1s value"
assert l1s.text == "l1s value" assert l1s.text == "l1s value"
assert s1s.text == "s1s value"
# navigate to the /foo route # navigate to the /foo route
with utils.poll_for_navigation(driver): with utils.poll_for_navigation(driver):
@ -435,8 +499,12 @@ async def test_client_side_state(
l2 = driver.find_element(By.ID, "l2") l2 = driver.find_element(By.ID, "l2")
l3 = driver.find_element(By.ID, "l3") l3 = driver.find_element(By.ID, "l3")
l4 = driver.find_element(By.ID, "l4") l4 = driver.find_element(By.ID, "l4")
s1 = driver.find_element(By.ID, "s1")
s2 = driver.find_element(By.ID, "s2")
s3 = driver.find_element(By.ID, "s3")
c1s = driver.find_element(By.ID, "c1s") c1s = driver.find_element(By.ID, "c1s")
l1s = driver.find_element(By.ID, "l1s") l1s = driver.find_element(By.ID, "l1s")
s1s = driver.find_element(By.ID, "s1s")
assert c1.text == "c1 value" assert c1.text == "c1 value"
assert c2.text == "c2 value" assert c2.text == "c2 value"
@ -449,8 +517,12 @@ async def test_client_side_state(
assert l2.text == "l2 value" assert l2.text == "l2 value"
assert l3.text == "l3 value" assert l3.text == "l3 value"
assert l4.text == "l4 value" assert l4.text == "l4 value"
assert s1.text == "s1 value"
assert s2.text == "s2 value"
assert s3.text == "s3 value"
assert c1s.text == "c1s value" assert c1s.text == "c1s value"
assert l1s.text == "l1s value" assert l1s.text == "l1s value"
assert s1s.text == "s1s value"
# reset the backend state to force refresh from client storage # reset the backend state to force refresh from client storage
async with client_side.modify_state(f"{token}_state.client_side_state") as state: async with client_side.modify_state(f"{token}_state.client_side_state") as state:
@ -475,8 +547,12 @@ async def test_client_side_state(
l2 = driver.find_element(By.ID, "l2") l2 = driver.find_element(By.ID, "l2")
l3 = driver.find_element(By.ID, "l3") l3 = driver.find_element(By.ID, "l3")
l4 = driver.find_element(By.ID, "l4") l4 = driver.find_element(By.ID, "l4")
s1 = driver.find_element(By.ID, "s1")
s2 = driver.find_element(By.ID, "s2")
s3 = driver.find_element(By.ID, "s3")
c1s = driver.find_element(By.ID, "c1s") c1s = driver.find_element(By.ID, "c1s")
l1s = driver.find_element(By.ID, "l1s") l1s = driver.find_element(By.ID, "l1s")
s1s = driver.find_element(By.ID, "s1s")
assert c1.text == "c1 value" assert c1.text == "c1 value"
assert c2.text == "c2 value" assert c2.text == "c2 value"
@ -489,8 +565,12 @@ async def test_client_side_state(
assert l2.text == "l2 value" assert l2.text == "l2 value"
assert l3.text == "l3 value" assert l3.text == "l3 value"
assert l4.text == "l4 value" assert l4.text == "l4 value"
assert s1.text == "s1 value"
assert s2.text == "s2 value"
assert s3.text == "s3 value"
assert c1s.text == "c1s value" assert c1s.text == "c1s value"
assert l1s.text == "l1s value" assert l1s.text == "l1s value"
assert s1s.text == "s1s value"
# make sure c5 cookie shows up on the `/foo` route # make sure c5 cookie shows up on the `/foo` route
AppHarness._poll_for( AppHarness._poll_for(
@ -525,6 +605,15 @@ async def test_client_side_state(
assert AppHarness._poll_for(lambda: l6.text == "l6 value") assert AppHarness._poll_for(lambda: l6.text == "l6 value")
assert l5.text == "l5 value" assert l5.text == "l5 value"
# Set session storage values in the new tab
set_sub("s1", "other tab s1")
s1 = driver.find_element(By.ID, "s1")
s2 = driver.find_element(By.ID, "s2")
s3 = driver.find_element(By.ID, "s3")
assert AppHarness._poll_for(lambda: s1.text == "other tab s1")
assert s2.text == "s2 default"
assert s3.text == ""
# Switch back to main window. # Switch back to main window.
driver.switch_to.window(main_tab) driver.switch_to.window(main_tab)
@ -534,6 +623,13 @@ async def test_client_side_state(
assert AppHarness._poll_for(lambda: l6.text == "l6 value") assert AppHarness._poll_for(lambda: l6.text == "l6 value")
assert l5.text == "l5 value" assert l5.text == "l5 value"
s1 = driver.find_element(By.ID, "s1")
s2 = driver.find_element(By.ID, "s2")
s3 = driver.find_element(By.ID, "s3")
assert AppHarness._poll_for(lambda: s1.text == "s1 value")
assert s2.text == "s2 value"
assert s3.text == "s3 value"
# clear the cookie jar and local storage, ensure state reset to default # clear the cookie jar and local storage, ensure state reset to default
driver.delete_all_cookies() driver.delete_all_cookies()
local_storage.clear() local_storage.clear()

View File

@ -185,6 +185,18 @@ export const applyEvent = async (event, socket) => {
return false; return false;
} }
if (event.name == "_clear_session_storage") {
sessionStorage.clear();
queueEvents(initialEvents(), socket);
return false;
}
if (event.name == "_remove_session_storage") {
sessionStorage.removeItem(event.payload.key);
queueEvents(initialEvents(), socket);
return false;
}
if (event.name == "_set_clipboard") { if (event.name == "_set_clipboard") {
const content = event.payload.content; const content = event.payload.content;
navigator.clipboard.writeText(content); navigator.clipboard.writeText(content);
@ -538,7 +550,18 @@ export const hydrateClientStorage = (client_storage) => {
} }
} }
} }
if (client_storage.cookies || client_storage.local_storage) { if (client_storage.session_storage && typeof window != "undefined") {
for (const state_key in client_storage.session_storage) {
const session_options = client_storage.session_storage[state_key];
const session_storage_value = sessionStorage.getItem(
session_options.name || state_key
);
if (session_storage_value != null) {
client_storage_values[state_key] = session_storage_value;
}
}
}
if (client_storage.cookies || client_storage.local_storage || client_storage.session_storage) {
return client_storage_values; return client_storage_values;
} }
return {}; return {};
@ -578,7 +601,15 @@ const applyClientStorageDelta = (client_storage, delta) => {
) { ) {
const options = client_storage.local_storage[state_key]; const options = client_storage.local_storage[state_key];
localStorage.setItem(options.name || state_key, delta[substate][key]); localStorage.setItem(options.name || state_key, delta[substate][key]);
} else if(
client_storage.session_storage &&
state_key in client_storage.session_storage &&
typeof window !== "undefined"
) {
const session_options = client_storage.session_storage[state_key];
sessionStorage.setItem(session_options.name || state_key, delta[substate][key]);
} }
} }
} }
}; };

View File

@ -287,12 +287,14 @@ _MAPPING: dict = {
"background", "background",
"call_script", "call_script",
"clear_local_storage", "clear_local_storage",
"clear_session_storage",
"console_log", "console_log",
"download", "download",
"prevent_default", "prevent_default",
"redirect", "redirect",
"remove_cookie", "remove_cookie",
"remove_local_storage", "remove_local_storage",
"remove_session_storage",
"set_clipboard", "set_clipboard",
"set_focus", "set_focus",
"scroll_to", "scroll_to",
@ -307,6 +309,7 @@ _MAPPING: dict = {
"var", "var",
"Cookie", "Cookie",
"LocalStorage", "LocalStorage",
"SessionStorage",
"ComponentState", "ComponentState",
"State", "State",
], ],

View File

@ -157,12 +157,14 @@ from .event import EventHandler as EventHandler
from .event import background as background from .event import background as background
from .event import call_script as call_script from .event import call_script as call_script
from .event import clear_local_storage as clear_local_storage from .event import clear_local_storage as clear_local_storage
from .event import clear_session_storage as clear_session_storage
from .event import console_log as console_log from .event import console_log as console_log
from .event import download as download from .event import download as download
from .event import prevent_default as prevent_default from .event import prevent_default as prevent_default
from .event import redirect as redirect from .event import redirect as redirect
from .event import remove_cookie as remove_cookie from .event import remove_cookie as remove_cookie
from .event import remove_local_storage as remove_local_storage from .event import remove_local_storage as remove_local_storage
from .event import remove_session_storage as remove_session_storage
from .event import set_clipboard as set_clipboard from .event import set_clipboard as set_clipboard
from .event import set_focus as set_focus from .event import set_focus as set_focus
from .event import scroll_to as scroll_to from .event import scroll_to as scroll_to
@ -177,6 +179,7 @@ from .model import Model as Model
from .state import var as var from .state import var as var
from .state import Cookie as Cookie from .state import Cookie as Cookie
from .state import LocalStorage as LocalStorage from .state import LocalStorage as LocalStorage
from .state import SessionStorage as SessionStorage
from .state import ComponentState as ComponentState from .state import ComponentState as ComponentState
from .state import State as State from .state import State as State
from .style import Style as Style from .style import Style as Style

View File

@ -28,7 +28,7 @@ from reflex.components.base import (
Title, Title,
) )
from reflex.components.component import Component, ComponentStyle, CustomComponent from reflex.components.component import Component, ComponentStyle, CustomComponent
from reflex.state import BaseState, Cookie, LocalStorage from reflex.state import BaseState, Cookie, LocalStorage, SessionStorage
from reflex.style import Style from reflex.style import Style
from reflex.utils import console, format, imports, path_ops from reflex.utils import console, format, imports, path_ops
from reflex.utils.imports import ImportVar, ParsedImportDict from reflex.utils.imports import ImportVar, ParsedImportDict
@ -158,8 +158,11 @@ def compile_state(state: Type[BaseState]) -> dict:
def _compile_client_storage_field( def _compile_client_storage_field(
field: ModelField, field: ModelField,
) -> tuple[Type[Cookie] | Type[LocalStorage] | None, dict[str, Any] | None]: ) -> tuple[
"""Compile the given cookie or local_storage field. Type[Cookie] | Type[LocalStorage] | Type[SessionStorage] | None,
dict[str, Any] | None,
]:
"""Compile the given cookie, local_storage or session_storage field.
Args: Args:
field: The possible cookie field to compile. field: The possible cookie field to compile.
@ -167,7 +170,7 @@ def _compile_client_storage_field(
Returns: Returns:
A dictionary of the compiled cookie or None if the field is not cookie-like. A dictionary of the compiled cookie or None if the field is not cookie-like.
""" """
for field_type in (Cookie, LocalStorage): for field_type in (Cookie, LocalStorage, SessionStorage):
if isinstance(field.default, field_type): if isinstance(field.default, field_type):
cs_obj = field.default cs_obj = field.default
elif isinstance(field.type_, type) and issubclass(field.type_, field_type): elif isinstance(field.type_, type) and issubclass(field.type_, field_type):
@ -180,7 +183,7 @@ def _compile_client_storage_field(
def _compile_client_storage_recursive( def _compile_client_storage_recursive(
state: Type[BaseState], state: Type[BaseState],
) -> tuple[dict[str, dict], dict[str, dict[str, str]]]: ) -> tuple[dict[str, dict], dict[str, dict], dict[str, dict]]:
"""Compile the client-side storage for the given state recursively. """Compile the client-side storage for the given state recursively.
Args: Args:
@ -191,10 +194,12 @@ def _compile_client_storage_recursive(
( (
cookies: dict[str, dict], cookies: dict[str, dict],
local_storage: dict[str, dict[str, str]] local_storage: dict[str, dict[str, str]]
) session_storage: dict[str, dict[str, str]]
).
""" """
cookies = {} cookies = {}
local_storage = {} local_storage = {}
session_storage = {}
state_name = state.get_full_name() state_name = state.get_full_name()
for name, field in state.__fields__.items(): for name, field in state.__fields__.items():
if name in state.inherited_vars: if name in state.inherited_vars:
@ -206,15 +211,20 @@ def _compile_client_storage_recursive(
cookies[state_key] = options cookies[state_key] = options
elif field_type is LocalStorage: elif field_type is LocalStorage:
local_storage[state_key] = options local_storage[state_key] = options
elif field_type is SessionStorage:
session_storage[state_key] = options
else: else:
continue continue
for substate in state.get_substates(): for substate in state.get_substates():
substate_cookies, substate_local_storage = _compile_client_storage_recursive( (
substate substate_cookies,
) substate_local_storage,
substate_session_storage,
) = _compile_client_storage_recursive(substate)
cookies.update(substate_cookies) cookies.update(substate_cookies)
local_storage.update(substate_local_storage) local_storage.update(substate_local_storage)
return cookies, local_storage session_storage.update(substate_session_storage)
return cookies, local_storage, session_storage
def compile_client_storage(state: Type[BaseState]) -> dict[str, dict]: def compile_client_storage(state: Type[BaseState]) -> dict[str, dict]:
@ -226,10 +236,11 @@ def compile_client_storage(state: Type[BaseState]) -> dict[str, dict]:
Returns: Returns:
A dictionary of the compiled client-side storage info. A dictionary of the compiled client-side storage info.
""" """
cookies, local_storage = _compile_client_storage_recursive(state) cookies, local_storage, session_storage = _compile_client_storage_recursive(state)
return { return {
constants.COOKIES: cookies, constants.COOKIES: cookies,
constants.LOCAL_STORAGE: local_storage, constants.LOCAL_STORAGE: local_storage,
constants.SESSION_STORAGE: session_storage,
} }

View File

@ -10,6 +10,7 @@ from .base import (
REFLEX_VAR_CLOSING_TAG, REFLEX_VAR_CLOSING_TAG,
REFLEX_VAR_OPENING_TAG, REFLEX_VAR_OPENING_TAG,
RELOAD_CONFIG, RELOAD_CONFIG,
SESSION_STORAGE,
SKIP_COMPILE_ENV_VAR, SKIP_COMPILE_ENV_VAR,
ColorMode, ColorMode,
Dirs, Dirs,
@ -88,6 +89,7 @@ __ALL__ = [
Imports, Imports,
IS_WINDOWS, IS_WINDOWS,
LOCAL_STORAGE, LOCAL_STORAGE,
SESSION_STORAGE,
LogLevel, LogLevel,
MemoizationDisposition, MemoizationDisposition,
MemoizationMode, MemoizationMode,

View File

@ -178,6 +178,7 @@ class Ping(SimpleNamespace):
# Keys in the client_side_storage dict # Keys in the client_side_storage dict
COOKIES = "cookies" COOKIES = "cookies"
LOCAL_STORAGE = "local_storage" LOCAL_STORAGE = "local_storage"
SESSION_STORAGE = "session_storage"
# If this env var is set to "yes", App.compile will be a no-op # If this env var is set to "yes", App.compile will be a no-op
SKIP_COMPILE_ENV_VAR = "__REFLEX_SKIP_COMPILE" SKIP_COMPILE_ENV_VAR = "__REFLEX_SKIP_COMPILE"

View File

@ -617,6 +617,34 @@ def remove_local_storage(key: str) -> EventSpec:
) )
def clear_session_storage() -> EventSpec:
"""Set a value in the session storage on the frontend.
Returns:
EventSpec: An event to clear the session storage.
"""
return server_side(
"_clear_session_storage",
get_fn_signature(clear_session_storage),
)
def remove_session_storage(key: str) -> EventSpec:
"""Set a value in the session storage on the frontend.
Args:
key: The key identifying the variable in the session storage to remove.
Returns:
EventSpec: An event to remove an item based on the provided key in session storage.
"""
return server_side(
"_remove_session_storage",
get_fn_signature(remove_session_storage),
key=key,
)
def set_clipboard(content: str) -> EventSpec: def set_clipboard(content: str) -> EventSpec:
"""Set the text in content in the clipboard. """Set the text in content in the clipboard.

View File

@ -2835,6 +2835,38 @@ class LocalStorage(ClientStorageBase, str):
return inst return inst
class SessionStorage(ClientStorageBase, str):
"""Represents a state Var that is stored in sessionStorage in the browser."""
name: str | None
def __new__(
cls,
object: Any = "",
encoding: str | None = None,
errors: str | None = None,
/,
name: str | None = None,
) -> "SessionStorage":
"""Create a client-side sessionStorage (str).
Args:
object: The initial object.
encoding: The encoding to use.
errors: The error handling scheme to use
name: The name of the storage on the client side
Returns:
The client-side sessionStorage object.
"""
if encoding or errors:
inst = super().__new__(cls, object, encoding or "utf-8", errors or "strict")
else:
inst = super().__new__(cls, object)
inst.name = name
return inst
class MutableProxy(wrapt.ObjectProxy): class MutableProxy(wrapt.ObjectProxy):
"""A proxy for a mutable object that tracks changes.""" """A proxy for a mutable object that tracks changes."""