From b6ae225455d9d29851fa097e72292c615cbfe4fd Mon Sep 17 00:00:00 2001 From: Elijah Ahianyo Date: Tue, 20 Jun 2023 21:57:33 +0000 Subject: [PATCH] Get Cookies (#1221) --- pynecone/constants.py | 1 + pynecone/state.py | 16 ++++++++++++ tests/conftest.py | 44 ++++++++++++++++++++++++++++++++ tests/test_state.py | 59 +++++++++++++++++++++++++++---------------- 4 files changed, 98 insertions(+), 22 deletions(-) diff --git a/pynecone/constants.py b/pynecone/constants.py index 33e4ee517..af5d87f96 100644 --- a/pynecone/constants.py +++ b/pynecone/constants.py @@ -316,6 +316,7 @@ class RouteVar(SimpleNamespace): PATH = "pathname" SESSION_ID = "sid" QUERY = "query" + COOKIE = "cookie" class RouteRegex(SimpleNamespace): diff --git a/pynecone/state.py b/pynecone/state.py index 2785e76b3..7029a2c83 100644 --- a/pynecone/state.py +++ b/pynecone/state.py @@ -490,6 +490,22 @@ class State(Base, ABC, extra=pydantic.Extra.allow): """ return self.router_data.get(constants.RouteVar.QUERY, {}) + def get_cookies(self) -> Dict[str, str]: + """Obtain the cookies of the client stored in the browser. + + Returns: + The dict of cookies. + """ + headers = self.get_headers().get(constants.RouteVar.COOKIE) + return ( + { + pair[0].strip(): pair[1].strip() + for pair in (item.split("=") for item in headers.split(";")) + } + if headers + else {} + ) + @classmethod def setup_dynamic_args(cls, args: dict[str, str]): """Set up args for easy access in renderer. diff --git a/tests/conftest.py b/tests/conftest.py index 6e13ed6b1..22b0b79b2 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -432,3 +432,47 @@ def gen_state() -> GenState: A test state. """ return GenState # type: ignore + + +@pytest.fixture +def router_data_headers() -> Dict[str, str]: + """Router data headers. + + Returns: + client headers + """ + return { + "host": "localhost:8000", + "connection": "Upgrade", + "pragma": "no-cache", + "cache-control": "no-cache", + "user-agent": "Mock Agent", + "upgrade": "websocket", + "origin": "http://localhost:3000", + "sec-websocket-version": "13", + "accept-encoding": "gzip, deflate, br", + "accept-language": "en-US,en;q=0.9", + "cookie": "csrftoken=mocktoken; name=reflex", + "sec-websocket-key": "mock-websocket-key", + "sec-websocket-extensions": "permessage-deflate; client_max_window_bits", + } + + +@pytest.fixture +def router_data(router_data_headers) -> Dict[str, str]: + """Router data. + + Args: + router_data_headers: Headers fixture. + + Returns: + Dict of router data. + """ + return { # type: ignore + "pathname": "/", + "query": {}, + "token": "b181904c-3953-4a79-dc18-ae9518c22f05", + "sid": "9fpxSzPb9aFMb4wFAAAH", + "headers": router_data_headers, + "ip": "127.0.0.1", + } diff --git a/tests/test_state.py b/tests/test_state.py index d0f40a2d9..8cdb49144 100644 --- a/tests/test_state.py +++ b/tests/test_state.py @@ -674,55 +674,70 @@ def test_format_event_handler(): ) -def test_get_token(test_state): - assert test_state.get_token() == "" +def test_get_token(test_state, mocker, router_data): + """Test that the token obtained from the router_data is correct. - token = "b181904c-3953-4a79-dc18-ae9518c22f05" - test_state.router_data = {RouteVar.CLIENT_TOKEN: token} + Args: + test_state: The test state. + mocker: Pytest Mocker object. + router_data: The router data fixture. + """ + mocker.patch.object(test_state, "router_data", router_data) - assert test_state.get_token() == token + assert test_state.get_token() == "b181904c-3953-4a79-dc18-ae9518c22f05" -def test_get_sid(test_state): +def test_get_sid(test_state, mocker, router_data): """Test getting session id. Args: test_state: A state. + mocker: Pytest Mocker object. + router_data: The router data fixture. """ - assert test_state.get_sid() == "" + mocker.patch.object(test_state, "router_data", router_data) - sid = "9fpxSzPb9aFMb4wFAAAH" - test_state.router_data = {RouteVar.SESSION_ID: sid} - - assert test_state.get_sid() == sid + assert test_state.get_sid() == "9fpxSzPb9aFMb4wFAAAH" -def test_get_headers(test_state): +def test_get_headers(test_state, mocker, router_data, router_data_headers): """Test getting client headers. Args: test_state: A state. + mocker: Pytest Mocker object. + router_data: The router data fixture. + router_data_headers: The expected headers. """ - assert test_state.get_headers() == {} + mocker.patch.object(test_state, "router_data", router_data) - headers = {"host": "localhost:8000", "connection": "keep-alive"} - test_state.router_data = {RouteVar.HEADERS: headers} - - assert test_state.get_headers() == headers + assert test_state.get_headers() == router_data_headers -def test_get_client_ip(test_state): +def test_get_client_ip(test_state, mocker, router_data): """Test getting client IP. Args: test_state: A state. + mocker: Pytest Mocker object. + router_data: The router data fixture. """ - assert test_state.get_client_ip() == "" + mocker.patch.object(test_state, "router_data", router_data) - client_ip = "127.0.0.1" - test_state.router_data = {RouteVar.CLIENT_IP: client_ip} + assert test_state.get_client_ip() == "127.0.0.1" - assert test_state.get_client_ip() == client_ip + +def test_get_cookies(test_state, mocker, router_data): + """Test getting client cookies. + + Args: + test_state: A state. + mocker: Pytest Mocker object. + router_data: The router data fixture. + """ + mocker.patch.object(test_state, "router_data", router_data) + + assert test_state.get_cookies() == {"csrftoken": "mocktoken", "name": "reflex"} def test_get_current_page(test_state):