diff --git a/reflex/app.py b/reflex/app.py index bb74d862f..633d6957c 100644 --- a/reflex/app.py +++ b/reflex/app.py @@ -54,6 +54,7 @@ from reflex.route import ( ) from reflex.state import ( DefaultState, + RouterData, State, StateManager, StateManagerMemory, @@ -803,6 +804,7 @@ async def process( # assignment will recurse into substates and force recalculation of # dependent ComputedVar (dynamic route variables) state.router_data = router_data + state.router = RouterData(router_data) # Preprocess the event. update = await app.preprocess(state, event) diff --git a/reflex/constants/__init__.py b/reflex/constants/__init__.py index 07adb9a8d..b901d2a88 100644 --- a/reflex/constants/__init__.py +++ b/reflex/constants/__init__.py @@ -39,6 +39,7 @@ from .installer import ( ) from .route import ( ROUTE_NOT_FOUND, + ROUTER, ROUTER_DATA, DefaultPage, Page404, @@ -77,9 +78,10 @@ __ALL__ = [ PYTEST_CURRENT_TEST, PRODUCTION_BACKEND_URL, Reflex, - RouteVar, - RouteRegex, RouteArgType, + RouteRegex, + RouteVar, + ROUTER, ROUTER_DATA, ROUTE_NOT_FOUND, SETTER_PREFIX, diff --git a/reflex/constants/route.py b/reflex/constants/route.py index c01b2e1b1..fad285f2f 100644 --- a/reflex/constants/route.py +++ b/reflex/constants/route.py @@ -13,6 +13,7 @@ class RouteArgType(SimpleNamespace): # the name of the backend var containing path and client information +ROUTER = "router" ROUTER_DATA = "router_data" diff --git a/reflex/state.py b/reflex/state.py index f9ab8a8c3..fa356b0ab 100644 --- a/reflex/state.py +++ b/reflex/state.py @@ -48,6 +48,99 @@ from reflex.vars import BaseVar, ComputedVar, Var Delta = Dict[str, Any] +class HeaderData(Base): + """An object containing headers data.""" + + host: str = "" + origin: str = "" + upgrade: str = "" + connection: str = "" + pragma: str = "" + cache_control: str = "" + user_agent: str = "" + sec_websocket_version: str = "" + sec_websocket_key: str = "" + sec_websocket_extensions: str = "" + accept_encoding: str = "" + accept_language: str = "" + + def __init__(self, router_data: Optional[dict] = None): + """Initalize the HeaderData object based on router_data. + + Args: + router_data: the router_data dict. + """ + super().__init__() + if router_data: + for k, v in router_data.get(constants.RouteVar.HEADERS, {}).items(): + setattr(self, format.to_snake_case(k), v) + + +class PageData(Base): + """An object containing page data.""" + + host: str = "" # repeated with self.headers.origin (remove or keep the duplicate?) + path: str = "" + raw_path: str = "" + full_path: str = "" + full_raw_path: str = "" + params: dict = {} + + def __init__(self, router_data: Optional[dict] = None): + """Initalize the PageData object based on router_data. + + Args: + router_data: the router_data dict. + """ + super().__init__() + if router_data: + self.host = router_data.get(constants.RouteVar.HEADERS, {}).get("origin") + self.path = router_data.get(constants.RouteVar.PATH, "") + self.raw_path = router_data.get(constants.RouteVar.ORIGIN, "") + self.full_path = f"{self.host}{self.path}" + self.full_raw_path = f"{self.host}{self.raw_path}" + self.params = router_data.get(constants.RouteVar.QUERY, {}) + + +class SessionData(Base): + """An object containing session data.""" + + client_token: str = "" + client_ip: str = "" + session_id: str = "" + + def __init__(self, router_data: Optional[dict] = None): + """Initalize the SessionData object based on router_data. + + Args: + router_data: the router_data dict. + """ + super().__init__() + if router_data: + self.client_token = router_data.get(constants.RouteVar.CLIENT_TOKEN, "") + self.client_ip = router_data.get(constants.RouteVar.CLIENT_IP, "") + self.session_id = router_data.get(constants.RouteVar.SESSION_ID, "") + + +class RouterData(Base): + """An object containing RouterData.""" + + session: SessionData = SessionData() + headers: HeaderData = HeaderData() + page: PageData = PageData() + + def __init__(self, router_data: Optional[dict] = None): + """Initialize the RouterData object. + + Args: + router_data: the router_data dict. + """ + super().__init__() + self.session = SessionData(router_data) + self.headers = HeaderData(router_data) + self.page = PageData(router_data) + + class State(Base, ABC, extra=pydantic.Extra.allow): """The state of the app.""" @@ -96,6 +189,9 @@ class State(Base, ABC, extra=pydantic.Extra.allow): # Per-instance copy of backend variable values _backend_vars: Dict[str, Any] = {} + # The router data for the current page + router: RouterData = RouterData() + def __init__(self, *args, parent_state: State | None = None, **kwargs): """Initialize the state. @@ -494,6 +590,12 @@ class State(Base, ABC, extra=pydantic.Extra.allow): Returns: The token of the client. """ + console.deprecate( + feature_name="get_token", + reason="replaced by `State.router.session.client_token`", + deprecation_version="0.3.0", + removal_version="0.3.1", + ) return self.router_data.get(constants.RouteVar.CLIENT_TOKEN, "") def get_sid(self) -> str: @@ -502,6 +604,12 @@ class State(Base, ABC, extra=pydantic.Extra.allow): Returns: The session ID of the client. """ + console.deprecate( + feature_name="get_sid", + reason="replaced by `State.router.session.session_id`", + deprecation_version="0.3.0", + removal_version="0.3.1", + ) return self.router_data.get(constants.RouteVar.SESSION_ID, "") def get_headers(self) -> Dict: @@ -510,6 +618,12 @@ class State(Base, ABC, extra=pydantic.Extra.allow): Returns: The headers of the client. """ + console.deprecate( + feature_name="get_headers", + reason="replaced by `State.router.headers`", + deprecation_version="0.3.0", + removal_version="0.3.1", + ) return self.router_data.get(constants.RouteVar.HEADERS, {}) def get_client_ip(self) -> str: @@ -518,6 +632,12 @@ class State(Base, ABC, extra=pydantic.Extra.allow): Returns: The IP of the client. """ + console.deprecate( + feature_name="get_client_ip", + reason="replaced by `State.router.session.client_ip`", + deprecation_version="0.3.0", + removal_version="0.3.1", + ) return self.router_data.get(constants.RouteVar.CLIENT_IP, "") def get_current_page(self, origin=False) -> str: @@ -529,10 +649,14 @@ class State(Base, ABC, extra=pydantic.Extra.allow): Returns: The current page. """ - if origin: - return self.router_data.get(constants.RouteVar.ORIGIN, "") - else: - return self.router_data.get(constants.RouteVar.PATH, "") + console.deprecate( + feature_name="get_current_page", + reason="replaced by State.router.page / self.router.page", + deprecation_version="0.3.0", + removal_version="0.3.1", + ) + + return self.router.page.raw_path if origin else self.router.page.path def get_query_params(self) -> dict[str, str]: """Obtain the query parameters for the queried page. @@ -542,6 +666,12 @@ class State(Base, ABC, extra=pydantic.Extra.allow): Returns: The dict of query parameters. """ + console.deprecate( + feature_name="get_query_params", + reason="replaced by `State.router.page.params`", + deprecation_version="0.3.0", + removal_version="0.3.1", + ) return self.router_data.get(constants.RouteVar.QUERY, {}) def get_cookies(self) -> dict[str, str]: @@ -583,14 +713,14 @@ class State(Base, ABC, extra=pydantic.Extra.allow): def argsingle_factory(param): @ComputedVar def inner_func(self) -> str: - return self.get_query_params().get(param, "") + return self.router.page.params.get(param, "") return inner_func def arglist_factory(param): @ComputedVar def inner_func(self) -> List: - return self.get_query_params().get(param, []) + return self.router.page.params.get(param, []) return inner_func diff --git a/tests/test_app.py b/tests/test_app.py index 5158b93e5..a67ad0180 100644 --- a/tests/test_app.py +++ b/tests/test_app.py @@ -34,7 +34,7 @@ from reflex.components import Box, Component, Cond, Fragment, Text from reflex.event import Event, get_hydrate_event from reflex.middleware import HydrateMiddleware from reflex.model import Model -from reflex.state import State, StateManagerRedis, StateUpdate +from reflex.state import RouterData, State, StateManagerRedis, StateUpdate from reflex.style import Style from reflex.utils import format from reflex.vars import ComputedVar @@ -255,9 +255,9 @@ def test_add_page_set_route_dynamic(app: App, index_page, windows_platform: bool assert set(app.pages.keys()) == {"test/[dynamic]"} assert "dynamic" in app.state.computed_vars assert app.state.computed_vars["dynamic"]._deps(objclass=DefaultState) == { - constants.ROUTER_DATA + constants.ROUTER } - assert constants.ROUTER_DATA in app.state().computed_var_dependencies + assert constants.ROUTER in app.state().computed_var_dependencies def test_add_page_set_route_nested(app: App, index_page, windows_platform: bool): @@ -874,9 +874,9 @@ async def test_dynamic_route_var_route_change_completed_on_load( assert arg_name in app.state.vars assert arg_name in app.state.computed_vars assert app.state.computed_vars[arg_name]._deps(objclass=DynamicState) == { - constants.ROUTER_DATA + constants.ROUTER } - assert constants.ROUTER_DATA in app.state().computed_var_dependencies + assert constants.ROUTER in app.state().computed_var_dependencies sid = "mock_sid" client_ip = "127.0.0.1" @@ -912,6 +912,7 @@ async def test_dynamic_route_var_route_change_completed_on_load( "token": token, **hydrate_event.router_data, } + exp_router = RouterData(exp_router_data) process_coro = process( app, event=hydrate_event, @@ -920,7 +921,6 @@ async def test_dynamic_route_var_route_change_completed_on_load( client_ip=client_ip, ) update = await process_coro.__anext__() # type: ignore - # route change triggers: [full state dict, call on_load events, call set_is_hydrated(True)] assert update == StateUpdate( delta={ @@ -930,6 +930,7 @@ async def test_dynamic_route_var_route_change_completed_on_load( constants.CompileVars.IS_HYDRATED: False, "loaded": exp_index, "counter": exp_index, + "router": exp_router, # "side_effect_counter": exp_index, } }, diff --git a/tests/test_state.py b/tests/test_state.py index 0995f3403..b2910b377 100644 --- a/tests/test_state.py +++ b/tests/test_state.py @@ -22,6 +22,7 @@ from reflex.state import ( ImmutableStateError, LockExpiredError, MutableProxy, + RouterData, State, StateManager, StateManagerMemory, @@ -40,6 +41,33 @@ LOCK_EXPIRATION = 2000 if CI else 100 LOCK_EXPIRE_SLEEP = 2.5 if CI else 0.2 +formatted_router = { + "session": {"client_token": "", "client_ip": "", "session_id": ""}, + "headers": { + "host": "", + "origin": "", + "upgrade": "", + "connection": "", + "pragma": "", + "cache_control": "", + "user_agent": "", + "sec_websocket_version": "", + "sec_websocket_key": "", + "sec_websocket_extensions": "", + "accept_encoding": "", + "accept_language": "", + }, + "page": { + "host": "", + "path": "", + "raw_path": "", + "full_path": "", + "full_raw_path": "", + "params": {}, + }, +} + + class Object(Base): """A test object fixture.""" @@ -226,6 +254,7 @@ def test_class_vars(test_state): cls = type(test_state) assert set(cls.vars.keys()) == { CompileVars.IS_HYDRATED, # added by hydrate_middleware to all State + "router", "num1", "num2", "key", @@ -614,6 +643,7 @@ def test_reset(test_state, child_state): "map_key", "mapping", "dt", + "router", } # The dirty vars should be reset. @@ -787,7 +817,7 @@ def test_get_current_page(test_state): assert test_state.get_current_page() == "" route = "mypage/subpage" - test_state.router_data = {RouteVar.PATH: route} + test_state.router = RouterData({RouteVar.PATH: route}) assert test_state.get_current_page() == route @@ -1131,16 +1161,19 @@ def test_computed_var_depends_on_parent_non_cached(): cs.get_name(): {"dep_v": 2}, "no_cache_v": 1, CompileVars.IS_HYDRATED: False, + "router": formatted_router, } assert ps.dict() == { cs.get_name(): {"dep_v": 4}, "no_cache_v": 3, CompileVars.IS_HYDRATED: False, + "router": formatted_router, } assert ps.dict() == { cs.get_name(): {"dep_v": 6}, "no_cache_v": 5, CompileVars.IS_HYDRATED: False, + "router": formatted_router, } assert counter == 6 @@ -2114,7 +2147,12 @@ def test_json_dumps_with_mutables(): dict_val = MutableContainsBase().dict() assert isinstance(dict_val["items"][0], dict) val = json_dumps(dict_val) - assert val == '{"is_hydrated": false, "items": [{"tags": ["123", "456"]}]}' + f_items = '[{"tags": ["123", "456"]}]' + f_formatted_router = str(formatted_router).replace("'", '"') + assert ( + val + == f'{{"is_hydrated": false, "items": {f_items}, "router": {f_formatted_router}}}' + ) def test_reset_with_mutables(): diff --git a/tests/utils/test_format.py b/tests/utils/test_format.py index d2a19b799..ab12e8a62 100644 --- a/tests/utils/test_format.py +++ b/tests/utils/test_format.py @@ -446,6 +446,33 @@ def test_format_query_params(input, output): assert format.format_query_params(input) == output +formatted_router = { + "session": {"client_token": "", "client_ip": "", "session_id": ""}, + "headers": { + "host": "", + "origin": "", + "upgrade": "", + "connection": "", + "pragma": "", + "cache_control": "", + "user_agent": "", + "sec_websocket_version": "", + "sec_websocket_key": "", + "sec_websocket_extensions": "", + "accept_encoding": "", + "accept_language": "", + }, + "page": { + "host": "", + "path": "", + "raw_path": "", + "full_path": "", + "full_raw_path": "", + "params": {}, + }, +} + + @pytest.mark.parametrize( "input, output", [ @@ -474,6 +501,7 @@ def test_format_query_params(input, output): "obj": {"prop1": 42, "prop2": "hello"}, "sum": 3.14, "upper": "", + "router": formatted_router, }, ), ( @@ -484,6 +512,7 @@ def test_format_query_params(input, output): "is_hydrated": False, "t": "18:53:00+01:00", "td": "11 days, 0:11:00", + "router": formatted_router, }, ), ],