diff --git a/reflex/.templates/web/utils/state.js b/reflex/.templates/web/utils/state.js index 67ed43f20..98d8db273 100644 --- a/reflex/.templates/web/utils/state.js +++ b/reflex/.templates/web/utils/state.js @@ -67,6 +67,24 @@ export const getToken = () => { return token; }; +/** + * Get the URL for the websocket connection + * @returns The websocket URL object. + */ +export const getEventURL = () => { + // Get backend URL object from the endpoint. + const endpoint = new URL(EVENTURL); + if (endpoint.hostname === "localhost") { + // If the backend URL references localhost, and the frontend is not on localhost, + // then use the frontend host. + const frontend_hostname = window.location.hostname; + if (frontend_hostname !== "localhost") { + endpoint.hostname = frontend_hostname; + } + } + return endpoint +} + /** * Apply a delta to the state. * @param state The state to apply the delta to. @@ -289,9 +307,10 @@ export const connect = async ( client_storage = {}, ) => { // Get backend URL object from the endpoint. - const endpoint = new URL(EVENTURL); + const endpoint = getEventURL() + // Create the socket. - socket.current = io(EVENTURL, { + socket.current = io(endpoint.href, { path: endpoint["pathname"], transports: transports, autoUnref: false, diff --git a/reflex/components/overlay/banner.py b/reflex/components/overlay/banner.py index fd874f94b..4a37346c9 100644 --- a/reflex/components/overlay/banner.py +++ b/reflex/components/overlay/banner.py @@ -3,11 +3,13 @@ from __future__ import annotations from typing import Optional +from reflex.components.base.bare import Bare from reflex.components.component import Component from reflex.components.layout import Box, Cond from reflex.components.overlay.modal import Modal from reflex.components.typography import Text -from reflex.vars import Var +from reflex.utils import imports +from reflex.vars import ImportVar, Var connection_error: Var = Var.create_safe( value="(connectError !== null) ? connectError.message : ''", @@ -21,19 +23,35 @@ has_connection_error: Var = Var.create_safe( has_connection_error.type_ = bool -def default_connection_error() -> list[str | Var]: +class WebsocketTargetURL(Bare): + """A component that renders the websocket target URL.""" + + def _get_imports(self) -> imports.ImportDict: + return { + "/utils/state.js": {ImportVar(tag="getEventURL")}, + } + + @classmethod + def create(cls) -> Component: + """Create a websocket target URL component. + + Returns: + The websocket target URL component. + """ + return super().create(contents="{getEventURL().href}") + + +def default_connection_error() -> list[str | Var | Component]: """Get the default connection error message. Returns: The default connection error message. """ - from reflex.config import get_config - return [ "Cannot connect to server: ", connection_error, ". Check if server is reachable at ", - get_config().api_url or "", + WebsocketTargetURL.create(), ] diff --git a/reflex/config.py b/reflex/config.py index 8caa29741..625b4add3 100644 --- a/reflex/config.py +++ b/reflex/config.py @@ -6,7 +6,9 @@ import importlib import os import sys import urllib.parse -from typing import Any, Dict, List, Optional +from typing import Any, Dict, List, Optional, Set + +import pydantic from reflex import constants from reflex.base import Base @@ -191,6 +193,9 @@ class Config(Base): # The username. username: Optional[str] = None + # Attributes that were explicitly set by the user. + _non_default_attributes: Set[str] = pydantic.PrivateAttr(set()) + def __init__(self, *args, **kwargs): """Initialize the config values. @@ -204,7 +209,14 @@ class Config(Base): self.check_deprecated_values(**kwargs) # Update the config from environment variables. - self.update_from_env() + env_kwargs = self.update_from_env() + for key, env_value in env_kwargs.items(): + setattr(self, key, env_value) + + # Update default URLs if ports were set + kwargs.update(env_kwargs) + self._non_default_attributes.update(kwargs) + self._replace_defaults(**kwargs) @staticmethod def check_deprecated_values(**kwargs): @@ -227,13 +239,16 @@ class Config(Base): "env_path is deprecated - use environment variables instead" ) - def update_from_env(self): + def update_from_env(self) -> dict[str, Any]: """Update the config from environment variables. + Returns: + The updated config values. Raises: ValueError: If an environment variable is set to an invalid type. """ + updated_values = {} # Iterate over the fields. for key, field in self.__fields__.items(): # The env var name is the key in uppercase. @@ -260,7 +275,9 @@ class Config(Base): raise # Set the value. - setattr(self, key, env_var) + updated_values[key] = env_var + + return updated_values def get_event_namespace(self) -> str | None: """Get the websocket event namespace. @@ -274,6 +291,34 @@ class Config(Base): event_url = constants.Endpoint.EVENT.get_url() return urllib.parse.urlsplit(event_url).path + def _replace_defaults(self, **kwargs): + """Replace formatted defaults when the caller provides updates. + + Args: + **kwargs: The kwargs passed to the config or from the env. + """ + if "api_url" not in self._non_default_attributes and "backend_port" in kwargs: + self.api_url = f"http://localhost:{kwargs['backend_port']}" + + if ( + "deploy_url" not in self._non_default_attributes + and "frontend_port" in kwargs + ): + self.deploy_url = f"http://localhost:{kwargs['frontend_port']}" + + def _set_persistent(self, **kwargs): + """Set values in this config and in the environment so they persist into subprocess. + + Args: + **kwargs: The kwargs passed to the config. + """ + for key, value in kwargs.items(): + if value is not None: + os.environ[key.upper()] = str(value) + setattr(self, key, value) + self._non_default_attributes.update(kwargs) + self._replace_defaults(**kwargs) + def get_config(reload: bool = False) -> Config: """Get the app config. diff --git a/reflex/config.pyi b/reflex/config.pyi index 70390dca9..73c9f6766 100644 --- a/reflex/config.pyi +++ b/reflex/config.pyi @@ -97,5 +97,6 @@ class Config(Base): def check_deprecated_values(**kwargs) -> None: ... def update_from_env(self) -> None: ... def get_event_namespace(self) -> str | None: ... + def _set_persistent(self, **kwargs) -> None: ... def get_config(reload: bool = ...) -> Config: ... diff --git a/reflex/reflex.py b/reflex/reflex.py index 0fb65f233..1c4ce739d 100644 --- a/reflex/reflex.py +++ b/reflex/reflex.py @@ -140,6 +140,12 @@ def run( if backend and processes.is_process_on_port(backend_port): backend_port = processes.change_or_terminate_port(backend_port, "backend") + # Apply the new ports to the config. + if frontend_port != str(config.frontend_port): + config._set_persistent(frontend_port=frontend_port) + if backend_port != str(config.backend_port): + config._set_persistent(backend_port=backend_port) + console.rule("[bold]Starting Reflex App") if frontend: diff --git a/tests/test_config.py b/tests/test_config.py index ee62bc1c3..5f17a1135 100644 --- a/tests/test_config.py +++ b/tests/test_config.py @@ -108,3 +108,95 @@ def test_event_namespace(mocker, kwargs, expected): config = reflex.config.get_config() assert conf == config assert config.get_event_namespace() == expected + + +DEFAULT_CONFIG = rx.Config(app_name="a") + + +@pytest.mark.parametrize( + ("config_kwargs", "env_vars", "set_persistent_vars", "exp_config_values"), + [ + ( + {}, + {}, + {}, + { + "api_url": DEFAULT_CONFIG.api_url, + "backend_port": DEFAULT_CONFIG.backend_port, + "deploy_url": DEFAULT_CONFIG.deploy_url, + "frontend_port": DEFAULT_CONFIG.frontend_port, + }, + ), + # Ports set in config kwargs + ( + {"backend_port": 8001, "frontend_port": 3001}, + {}, + {}, + { + "api_url": "http://localhost:8001", + "backend_port": 8001, + "deploy_url": "http://localhost:3001", + "frontend_port": 3001, + }, + ), + # Ports set in environment take precendence + ( + {"backend_port": 8001, "frontend_port": 3001}, + {"BACKEND_PORT": 8002}, + {}, + { + "api_url": "http://localhost:8002", + "backend_port": 8002, + "deploy_url": "http://localhost:3001", + "frontend_port": 3001, + }, + ), + # Ports set on the command line take precendence + ( + {"backend_port": 8001, "frontend_port": 3001}, + {"BACKEND_PORT": 8002}, + {"frontend_port": "3005"}, + { + "api_url": "http://localhost:8002", + "backend_port": 8002, + "deploy_url": "http://localhost:3005", + "frontend_port": 3005, + }, + ), + # api_url / deploy_url already set should not be overridden + ( + {"api_url": "http://foo.bar:8900", "deploy_url": "http://foo.bar:3001"}, + {"BACKEND_PORT": 8002}, + {"frontend_port": "3005"}, + { + "api_url": "http://foo.bar:8900", + "backend_port": 8002, + "deploy_url": "http://foo.bar:3001", + "frontend_port": 3005, + }, + ), + ], +) +def test_replace_defaults( + monkeypatch, + config_kwargs, + env_vars, + set_persistent_vars, + exp_config_values, +): + """Test that the config replaces defaults with values from the environment. + + Args: + monkeypatch: The pytest monkeypatch object. + config_kwargs: The config kwargs. + env_vars: The environment variables. + set_persistent_vars: The values passed to config._set_persistent variables. + exp_config_values: The expected config values. + """ + mock_os_env = os.environ.copy() + monkeypatch.setattr(reflex.config.os, "environ", mock_os_env) # type: ignore + mock_os_env.update({k: str(v) for k, v in env_vars.items()}) + c = rx.Config(app_name="a", **config_kwargs) + c._set_persistent(**set_persistent_vars) + for key, value in exp_config_values.items(): + assert getattr(c, key) == value