Disconnect old websockets and avoid duplicating ws during hot reload

This commit is contained in:
Masen Furer 2024-12-20 10:20:06 -08:00
parent 77fe285675
commit 973e1141de
No known key found for this signature in database
GPG Key ID: 2AE2BD5531FF94F4
4 changed files with 38 additions and 8 deletions

View File

@ -398,6 +398,11 @@ export const connect = async (
// Get backend URL object from the endpoint. // Get backend URL object from the endpoint.
const endpoint = getBackendURL(EVENTURL); const endpoint = getBackendURL(EVENTURL);
// Disconnect old socket
if (socket.current && socket.current.connected) {
socket.current.disconnect();
}
// Create the socket. // Create the socket.
socket.current = io(endpoint.href, { socket.current = io(endpoint.href, {
path: endpoint["pathname"], path: endpoint["pathname"],
@ -429,6 +434,7 @@ export const connect = async (
socket.current.on("connect", () => { socket.current.on("connect", () => {
setConnectErrors([]); setConnectErrors([]);
window.addEventListener("pagehide", pagehideHandler); window.addEventListener("pagehide", pagehideHandler);
document.addEventListener("visibilitychange", checkVisibility);
}); });
socket.current.on("connect_error", (error) => { socket.current.on("connect_error", (error) => {
@ -438,7 +444,10 @@ export const connect = async (
// When the socket disconnects reset the event_processing flag // When the socket disconnects reset the event_processing flag
socket.current.on("disconnect", () => { socket.current.on("disconnect", () => {
event_processing = false; event_processing = false;
socket.current.io.skipReconnect = true;
socket.current = null;
window.removeEventListener("pagehide", pagehideHandler); window.removeEventListener("pagehide", pagehideHandler);
document.removeEventListener("visibilitychange", checkVisibility);
}); });
// On each received message, queue the updates and events. // On each received message, queue the updates and events.
@ -457,7 +466,6 @@ export const connect = async (
queueEvents([...initialEvents(), event], socket); queueEvents([...initialEvents(), event], socket);
}); });
document.addEventListener("visibilitychange", checkVisibility);
}; };
/** /**

View File

@ -26,6 +26,7 @@ from typing import (
from typing_extensions import Annotated, get_type_hints from typing_extensions import Annotated, get_type_hints
from reflex.utils.console import set_log_level
from reflex.utils.exceptions import ConfigError, EnvironmentVarValueError from reflex.utils.exceptions import ConfigError, EnvironmentVarValueError
from reflex.utils.types import GenericType, is_union, value_inside_optional from reflex.utils.types import GenericType, is_union, value_inside_optional
@ -599,6 +600,7 @@ class Config(Base):
class Config: class Config:
"""Pydantic config for the config.""" """Pydantic config for the config."""
use_enum_values = False
validate_assignment = True validate_assignment = True
# The name of the app (should match the name of the app directory). # The name of the app (should match the name of the app directory).
@ -718,6 +720,9 @@ class Config(Base):
self._non_default_attributes.update(kwargs) self._non_default_attributes.update(kwargs)
self._replace_defaults(**kwargs) self._replace_defaults(**kwargs)
# Set the log level for this process
set_log_level(self.loglevel)
if ( if (
self.state_manager_mode == constants.StateManagerMode.REDIS self.state_manager_mode == constants.StateManagerMode.REDIS
and not self.redis_url and not self.redis_url

View File

@ -4,7 +4,7 @@ from __future__ import annotations
import asyncio import asyncio
from contextlib import asynccontextmanager from contextlib import asynccontextmanager
from typing import AsyncGenerator from typing import Any, AsyncGenerator
from urllib.parse import urlparse from urllib.parse import urlparse
import aiohttp import aiohttp
@ -34,6 +34,7 @@ except ImportError:
""" """
yield yield
else: else:
MAX_PROXY_RETRY = 25
async def proxy_http_with_retry( async def proxy_http_with_retry(
*, *,
@ -41,25 +42,36 @@ else:
scope: Scope, scope: Scope,
receive: Receive, receive: Receive,
send: Send, send: Send,
) -> None: ) -> Any:
"""Proxy an HTTP request with retries. """Proxy an HTTP request with retries.
Args: Args:
context: The proxy context. context: The proxy context.
scope: The ASGI scope. scope: The request scope.
receive: The receive channel. receive: The receive channel.
send: The send channel. send: The send channel.
Returns:
The response from `proxy_http`.
""" """
for _attempt in range(100): for _attempt in range(MAX_PROXY_RETRY):
try: try:
return await proxy_http( return await proxy_http(
context=context, scope=scope, receive=receive, send=send context=context,
scope=scope,
receive=receive,
send=send,
) )
except aiohttp.client_exceptions.ClientError as err: # noqa: PERF203 except aiohttp.ClientError as err: # noqa: PERF203
console.debug( console.debug(
f"Retrying request {scope['path']} due to client error {err!r}." f"Retrying request {scope['path']} due to client error {err!r}."
) )
await asyncio.sleep(0.3) await asyncio.sleep(0.3)
except Exception as ex:
console.debug(
f"Retrying request {scope['path']} due to unhandled exception {ex!r}."
)
await asyncio.sleep(0.3)
def _get_proxy_app_with_context(frontend_host: str) -> tuple[ProxyContext, ASGIApp]: def _get_proxy_app_with_context(frontend_host: str) -> tuple[ProxyContext, ASGIApp]:
"""Get the proxy app with the given frontend host. """Get the proxy app with the given frontend host.

View File

@ -2,6 +2,8 @@
from __future__ import annotations from __future__ import annotations
import os
from rich.console import Console from rich.console import Console
from rich.progress import MofNCompleteColumn, Progress, TimeElapsedColumn from rich.progress import MofNCompleteColumn, Progress, TimeElapsedColumn
from rich.prompt import Prompt from rich.prompt import Prompt
@ -12,7 +14,7 @@ from reflex.constants import LogLevel
_console = Console() _console = Console()
# The current log level. # The current log level.
_LOG_LEVEL = LogLevel.INFO _LOG_LEVEL = LogLevel.DEFAULT
# Deprecated features who's warning has been printed. # Deprecated features who's warning has been printed.
_EMITTED_DEPRECATION_WARNINGS = set() _EMITTED_DEPRECATION_WARNINGS = set()
@ -61,6 +63,9 @@ def set_log_level(log_level: LogLevel):
raise ValueError(f"Invalid log level: {log_level}") from ae raise ValueError(f"Invalid log level: {log_level}") from ae
global _LOG_LEVEL global _LOG_LEVEL
if log_level != _LOG_LEVEL:
# Set the loglevel persistently for subprocesses
os.environ["LOGLEVEL"] = log_level.value
_LOG_LEVEL = log_level _LOG_LEVEL = log_level