[REF-843] Automatically update api_url and deploy_url (#1954)
This commit is contained in:
parent
d0cb5b07e7
commit
684912e33b
@ -67,6 +67,24 @@ export const getToken = () => {
|
|||||||
return token;
|
return token;
|
||||||
};
|
};
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Get the URL for the websocket connection
|
||||||
|
* @returns The websocket URL object.
|
||||||
|
*/
|
||||||
|
export const getEventURL = () => {
|
||||||
|
// Get backend URL object from the endpoint.
|
||||||
|
const endpoint = new URL(EVENTURL);
|
||||||
|
if (endpoint.hostname === "localhost") {
|
||||||
|
// If the backend URL references localhost, and the frontend is not on localhost,
|
||||||
|
// then use the frontend host.
|
||||||
|
const frontend_hostname = window.location.hostname;
|
||||||
|
if (frontend_hostname !== "localhost") {
|
||||||
|
endpoint.hostname = frontend_hostname;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return endpoint
|
||||||
|
}
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Apply a delta to the state.
|
* Apply a delta to the state.
|
||||||
* @param state The state to apply the delta to.
|
* @param state The state to apply the delta to.
|
||||||
@ -289,9 +307,10 @@ export const connect = async (
|
|||||||
client_storage = {},
|
client_storage = {},
|
||||||
) => {
|
) => {
|
||||||
// Get backend URL object from the endpoint.
|
// Get backend URL object from the endpoint.
|
||||||
const endpoint = new URL(EVENTURL);
|
const endpoint = getEventURL()
|
||||||
|
|
||||||
// Create the socket.
|
// Create the socket.
|
||||||
socket.current = io(EVENTURL, {
|
socket.current = io(endpoint.href, {
|
||||||
path: endpoint["pathname"],
|
path: endpoint["pathname"],
|
||||||
transports: transports,
|
transports: transports,
|
||||||
autoUnref: false,
|
autoUnref: false,
|
||||||
|
@ -3,11 +3,13 @@ from __future__ import annotations
|
|||||||
|
|
||||||
from typing import Optional
|
from typing import Optional
|
||||||
|
|
||||||
|
from reflex.components.base.bare import Bare
|
||||||
from reflex.components.component import Component
|
from reflex.components.component import Component
|
||||||
from reflex.components.layout import Box, Cond
|
from reflex.components.layout import Box, Cond
|
||||||
from reflex.components.overlay.modal import Modal
|
from reflex.components.overlay.modal import Modal
|
||||||
from reflex.components.typography import Text
|
from reflex.components.typography import Text
|
||||||
from reflex.vars import Var
|
from reflex.utils import imports
|
||||||
|
from reflex.vars import ImportVar, Var
|
||||||
|
|
||||||
connection_error: Var = Var.create_safe(
|
connection_error: Var = Var.create_safe(
|
||||||
value="(connectError !== null) ? connectError.message : ''",
|
value="(connectError !== null) ? connectError.message : ''",
|
||||||
@ -21,19 +23,35 @@ has_connection_error: Var = Var.create_safe(
|
|||||||
has_connection_error.type_ = bool
|
has_connection_error.type_ = bool
|
||||||
|
|
||||||
|
|
||||||
def default_connection_error() -> list[str | Var]:
|
class WebsocketTargetURL(Bare):
|
||||||
|
"""A component that renders the websocket target URL."""
|
||||||
|
|
||||||
|
def _get_imports(self) -> imports.ImportDict:
|
||||||
|
return {
|
||||||
|
"/utils/state.js": {ImportVar(tag="getEventURL")},
|
||||||
|
}
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def create(cls) -> Component:
|
||||||
|
"""Create a websocket target URL component.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
The websocket target URL component.
|
||||||
|
"""
|
||||||
|
return super().create(contents="{getEventURL().href}")
|
||||||
|
|
||||||
|
|
||||||
|
def default_connection_error() -> list[str | Var | Component]:
|
||||||
"""Get the default connection error message.
|
"""Get the default connection error message.
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
The default connection error message.
|
The default connection error message.
|
||||||
"""
|
"""
|
||||||
from reflex.config import get_config
|
|
||||||
|
|
||||||
return [
|
return [
|
||||||
"Cannot connect to server: ",
|
"Cannot connect to server: ",
|
||||||
connection_error,
|
connection_error,
|
||||||
". Check if server is reachable at ",
|
". Check if server is reachable at ",
|
||||||
get_config().api_url or "<API_URL not set>",
|
WebsocketTargetURL.create(),
|
||||||
]
|
]
|
||||||
|
|
||||||
|
|
||||||
|
@ -6,7 +6,9 @@ import importlib
|
|||||||
import os
|
import os
|
||||||
import sys
|
import sys
|
||||||
import urllib.parse
|
import urllib.parse
|
||||||
from typing import Any, Dict, List, Optional
|
from typing import Any, Dict, List, Optional, Set
|
||||||
|
|
||||||
|
import pydantic
|
||||||
|
|
||||||
from reflex import constants
|
from reflex import constants
|
||||||
from reflex.base import Base
|
from reflex.base import Base
|
||||||
@ -191,6 +193,9 @@ class Config(Base):
|
|||||||
# The username.
|
# The username.
|
||||||
username: Optional[str] = None
|
username: Optional[str] = None
|
||||||
|
|
||||||
|
# Attributes that were explicitly set by the user.
|
||||||
|
_non_default_attributes: Set[str] = pydantic.PrivateAttr(set())
|
||||||
|
|
||||||
def __init__(self, *args, **kwargs):
|
def __init__(self, *args, **kwargs):
|
||||||
"""Initialize the config values.
|
"""Initialize the config values.
|
||||||
|
|
||||||
@ -204,7 +209,14 @@ class Config(Base):
|
|||||||
self.check_deprecated_values(**kwargs)
|
self.check_deprecated_values(**kwargs)
|
||||||
|
|
||||||
# Update the config from environment variables.
|
# Update the config from environment variables.
|
||||||
self.update_from_env()
|
env_kwargs = self.update_from_env()
|
||||||
|
for key, env_value in env_kwargs.items():
|
||||||
|
setattr(self, key, env_value)
|
||||||
|
|
||||||
|
# Update default URLs if ports were set
|
||||||
|
kwargs.update(env_kwargs)
|
||||||
|
self._non_default_attributes.update(kwargs)
|
||||||
|
self._replace_defaults(**kwargs)
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def check_deprecated_values(**kwargs):
|
def check_deprecated_values(**kwargs):
|
||||||
@ -227,13 +239,16 @@ class Config(Base):
|
|||||||
"env_path is deprecated - use environment variables instead"
|
"env_path is deprecated - use environment variables instead"
|
||||||
)
|
)
|
||||||
|
|
||||||
def update_from_env(self):
|
def update_from_env(self) -> dict[str, Any]:
|
||||||
"""Update the config from environment variables.
|
"""Update the config from environment variables.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
The updated config values.
|
||||||
|
|
||||||
Raises:
|
Raises:
|
||||||
ValueError: If an environment variable is set to an invalid type.
|
ValueError: If an environment variable is set to an invalid type.
|
||||||
"""
|
"""
|
||||||
|
updated_values = {}
|
||||||
# Iterate over the fields.
|
# Iterate over the fields.
|
||||||
for key, field in self.__fields__.items():
|
for key, field in self.__fields__.items():
|
||||||
# The env var name is the key in uppercase.
|
# The env var name is the key in uppercase.
|
||||||
@ -260,7 +275,9 @@ class Config(Base):
|
|||||||
raise
|
raise
|
||||||
|
|
||||||
# Set the value.
|
# Set the value.
|
||||||
setattr(self, key, env_var)
|
updated_values[key] = env_var
|
||||||
|
|
||||||
|
return updated_values
|
||||||
|
|
||||||
def get_event_namespace(self) -> str | None:
|
def get_event_namespace(self) -> str | None:
|
||||||
"""Get the websocket event namespace.
|
"""Get the websocket event namespace.
|
||||||
@ -274,6 +291,34 @@ class Config(Base):
|
|||||||
event_url = constants.Endpoint.EVENT.get_url()
|
event_url = constants.Endpoint.EVENT.get_url()
|
||||||
return urllib.parse.urlsplit(event_url).path
|
return urllib.parse.urlsplit(event_url).path
|
||||||
|
|
||||||
|
def _replace_defaults(self, **kwargs):
|
||||||
|
"""Replace formatted defaults when the caller provides updates.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
**kwargs: The kwargs passed to the config or from the env.
|
||||||
|
"""
|
||||||
|
if "api_url" not in self._non_default_attributes and "backend_port" in kwargs:
|
||||||
|
self.api_url = f"http://localhost:{kwargs['backend_port']}"
|
||||||
|
|
||||||
|
if (
|
||||||
|
"deploy_url" not in self._non_default_attributes
|
||||||
|
and "frontend_port" in kwargs
|
||||||
|
):
|
||||||
|
self.deploy_url = f"http://localhost:{kwargs['frontend_port']}"
|
||||||
|
|
||||||
|
def _set_persistent(self, **kwargs):
|
||||||
|
"""Set values in this config and in the environment so they persist into subprocess.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
**kwargs: The kwargs passed to the config.
|
||||||
|
"""
|
||||||
|
for key, value in kwargs.items():
|
||||||
|
if value is not None:
|
||||||
|
os.environ[key.upper()] = str(value)
|
||||||
|
setattr(self, key, value)
|
||||||
|
self._non_default_attributes.update(kwargs)
|
||||||
|
self._replace_defaults(**kwargs)
|
||||||
|
|
||||||
|
|
||||||
def get_config(reload: bool = False) -> Config:
|
def get_config(reload: bool = False) -> Config:
|
||||||
"""Get the app config.
|
"""Get the app config.
|
||||||
|
@ -97,5 +97,6 @@ class Config(Base):
|
|||||||
def check_deprecated_values(**kwargs) -> None: ...
|
def check_deprecated_values(**kwargs) -> None: ...
|
||||||
def update_from_env(self) -> None: ...
|
def update_from_env(self) -> None: ...
|
||||||
def get_event_namespace(self) -> str | None: ...
|
def get_event_namespace(self) -> str | None: ...
|
||||||
|
def _set_persistent(self, **kwargs) -> None: ...
|
||||||
|
|
||||||
def get_config(reload: bool = ...) -> Config: ...
|
def get_config(reload: bool = ...) -> Config: ...
|
||||||
|
@ -140,6 +140,12 @@ def run(
|
|||||||
if backend and processes.is_process_on_port(backend_port):
|
if backend and processes.is_process_on_port(backend_port):
|
||||||
backend_port = processes.change_or_terminate_port(backend_port, "backend")
|
backend_port = processes.change_or_terminate_port(backend_port, "backend")
|
||||||
|
|
||||||
|
# Apply the new ports to the config.
|
||||||
|
if frontend_port != str(config.frontend_port):
|
||||||
|
config._set_persistent(frontend_port=frontend_port)
|
||||||
|
if backend_port != str(config.backend_port):
|
||||||
|
config._set_persistent(backend_port=backend_port)
|
||||||
|
|
||||||
console.rule("[bold]Starting Reflex App")
|
console.rule("[bold]Starting Reflex App")
|
||||||
|
|
||||||
if frontend:
|
if frontend:
|
||||||
|
@ -108,3 +108,95 @@ def test_event_namespace(mocker, kwargs, expected):
|
|||||||
config = reflex.config.get_config()
|
config = reflex.config.get_config()
|
||||||
assert conf == config
|
assert conf == config
|
||||||
assert config.get_event_namespace() == expected
|
assert config.get_event_namespace() == expected
|
||||||
|
|
||||||
|
|
||||||
|
DEFAULT_CONFIG = rx.Config(app_name="a")
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.parametrize(
|
||||||
|
("config_kwargs", "env_vars", "set_persistent_vars", "exp_config_values"),
|
||||||
|
[
|
||||||
|
(
|
||||||
|
{},
|
||||||
|
{},
|
||||||
|
{},
|
||||||
|
{
|
||||||
|
"api_url": DEFAULT_CONFIG.api_url,
|
||||||
|
"backend_port": DEFAULT_CONFIG.backend_port,
|
||||||
|
"deploy_url": DEFAULT_CONFIG.deploy_url,
|
||||||
|
"frontend_port": DEFAULT_CONFIG.frontend_port,
|
||||||
|
},
|
||||||
|
),
|
||||||
|
# Ports set in config kwargs
|
||||||
|
(
|
||||||
|
{"backend_port": 8001, "frontend_port": 3001},
|
||||||
|
{},
|
||||||
|
{},
|
||||||
|
{
|
||||||
|
"api_url": "http://localhost:8001",
|
||||||
|
"backend_port": 8001,
|
||||||
|
"deploy_url": "http://localhost:3001",
|
||||||
|
"frontend_port": 3001,
|
||||||
|
},
|
||||||
|
),
|
||||||
|
# Ports set in environment take precendence
|
||||||
|
(
|
||||||
|
{"backend_port": 8001, "frontend_port": 3001},
|
||||||
|
{"BACKEND_PORT": 8002},
|
||||||
|
{},
|
||||||
|
{
|
||||||
|
"api_url": "http://localhost:8002",
|
||||||
|
"backend_port": 8002,
|
||||||
|
"deploy_url": "http://localhost:3001",
|
||||||
|
"frontend_port": 3001,
|
||||||
|
},
|
||||||
|
),
|
||||||
|
# Ports set on the command line take precendence
|
||||||
|
(
|
||||||
|
{"backend_port": 8001, "frontend_port": 3001},
|
||||||
|
{"BACKEND_PORT": 8002},
|
||||||
|
{"frontend_port": "3005"},
|
||||||
|
{
|
||||||
|
"api_url": "http://localhost:8002",
|
||||||
|
"backend_port": 8002,
|
||||||
|
"deploy_url": "http://localhost:3005",
|
||||||
|
"frontend_port": 3005,
|
||||||
|
},
|
||||||
|
),
|
||||||
|
# api_url / deploy_url already set should not be overridden
|
||||||
|
(
|
||||||
|
{"api_url": "http://foo.bar:8900", "deploy_url": "http://foo.bar:3001"},
|
||||||
|
{"BACKEND_PORT": 8002},
|
||||||
|
{"frontend_port": "3005"},
|
||||||
|
{
|
||||||
|
"api_url": "http://foo.bar:8900",
|
||||||
|
"backend_port": 8002,
|
||||||
|
"deploy_url": "http://foo.bar:3001",
|
||||||
|
"frontend_port": 3005,
|
||||||
|
},
|
||||||
|
),
|
||||||
|
],
|
||||||
|
)
|
||||||
|
def test_replace_defaults(
|
||||||
|
monkeypatch,
|
||||||
|
config_kwargs,
|
||||||
|
env_vars,
|
||||||
|
set_persistent_vars,
|
||||||
|
exp_config_values,
|
||||||
|
):
|
||||||
|
"""Test that the config replaces defaults with values from the environment.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
monkeypatch: The pytest monkeypatch object.
|
||||||
|
config_kwargs: The config kwargs.
|
||||||
|
env_vars: The environment variables.
|
||||||
|
set_persistent_vars: The values passed to config._set_persistent variables.
|
||||||
|
exp_config_values: The expected config values.
|
||||||
|
"""
|
||||||
|
mock_os_env = os.environ.copy()
|
||||||
|
monkeypatch.setattr(reflex.config.os, "environ", mock_os_env) # type: ignore
|
||||||
|
mock_os_env.update({k: str(v) for k, v in env_vars.items()})
|
||||||
|
c = rx.Config(app_name="a", **config_kwargs)
|
||||||
|
c._set_persistent(**set_persistent_vars)
|
||||||
|
for key, value in exp_config_values.items():
|
||||||
|
assert getattr(c, key) == value
|
||||||
|
Loading…
Reference in New Issue
Block a user