diff --git a/docker-example/Caddyfile b/docker-example/Caddyfile
index 3ab6b0815..275140b56 100644
--- a/docker-example/Caddyfile
+++ b/docker-example/Caddyfile
@@ -7,12 +7,8 @@ handle @backend_routes {
reverse_proxy app:8000
}
+root * /srv
route {
- try_files {path} {path}.html
- file_server {
- root /srv
- pass_thru
- }
- # proxy dynamic routes to nextjs server
- reverse_proxy app:3000
+ try_files {path} {path}/ /404.html
+ file_server
}
diff --git a/docker-example/Dockerfile b/docker-example/Dockerfile
index 11fde7b99..a0c68944c 100644
--- a/docker-example/Dockerfile
+++ b/docker-example/Dockerfile
@@ -8,9 +8,6 @@ ARG API_URL
WORKDIR /app
COPY . .
-# Reflex will install bun, nvm, and node to `$HOME/.reflex` (/app/.reflex)
-ENV HOME=/app
-
# Create virtualenv which will be copied into final container
ENV VIRTUAL_ENV=/app/.venv
ENV PATH="$VIRTUAL_ENV/bin:$PATH"
@@ -22,9 +19,13 @@ RUN pip install -r requirements.txt
# Deploy templates and prepare app
RUN reflex init
-# Export static copy of frontend to /app/.web/_static (and pre-install frontend packages)
+# Export static copy of frontend to /app/.web/_static
RUN reflex export --frontend-only --no-zip
+# Copy static files out of /app to save space in backend image
+RUN mv .web/_static /tmp/_static
+RUN rm -rf .web && mkdir .web
+RUN mv /tmp/_static .web/_static
# Stage 2: copy artifacts into slim image
FROM python:3.11-slim
@@ -35,4 +36,4 @@ COPY --chown=reflex --from=init /app /app
USER reflex
ENV PATH="/app/.venv/bin:$PATH" API_URL=$API_URL
-CMD reflex db migrate && reflex run --env prod
+CMD reflex db migrate && reflex run --env prod --backend-only
diff --git a/integration/conftest.py b/integration/conftest.py
index 92e0613fb..2a13161a3 100644
--- a/integration/conftest.py
+++ b/integration/conftest.py
@@ -5,6 +5,8 @@ from pathlib import Path
import pytest
+from reflex.testing import AppHarness, AppHarnessProd
+
DISPLAY = None
XVFB_DIMENSIONS = (800, 600)
@@ -57,3 +59,18 @@ def pytest_exception_interact(node, call, report):
)
except Exception as e:
print(f"Failed to take screenshot for {node}: {e}")
+
+
+@pytest.fixture(
+ scope="session", params=[AppHarness, AppHarnessProd], ids=["dev", "prod"]
+)
+def app_harness_env(request):
+ """Parametrize the AppHarness class to use for the test, either dev or prod.
+
+ Args:
+ request: The pytest fixture request object.
+
+ Returns:
+ The AppHarness class to use for the test.
+ """
+ return request.param
diff --git a/integration/test_dynamic_routes.py b/integration/test_dynamic_routes.py
index e654151dc..c76993a24 100644
--- a/integration/test_dynamic_routes.py
+++ b/integration/test_dynamic_routes.py
@@ -1,12 +1,12 @@
"""Integration tests for dynamic route page behavior."""
-import time
-from typing import Generator
+from typing import Callable, Generator, Type
from urllib.parse import urlsplit
import pytest
from selenium.webdriver.common.by import By
-from reflex.testing import AppHarness
+from reflex import State
+from reflex.testing import AppHarness, AppHarnessProd, WebDriver
from .utils import poll_for_navigation
@@ -20,7 +20,14 @@ def DynamicRoute():
page_id: str = ""
def on_load(self):
- self.order.append(self.page_id or "no page id")
+ self.order.append(
+ f"{self.get_current_page()}-{self.page_id or 'no page id'}"
+ )
+
+ def on_load_redir(self):
+ query_params = self.get_query_params()
+ self.order.append(f"on_load_redir-{query_params}")
+ return rx.redirect(f"/page/{query_params['page_id']}")
@rx.var
def next_page(self) -> str:
@@ -42,37 +49,46 @@ def DynamicRoute():
rx.link(
"next", href="/page/" + DynamicState.next_page, id="link_page_next" # type: ignore
),
+ rx.link("missing", href="/missing", id="link_missing"),
rx.list(
rx.foreach(DynamicState.order, lambda i: rx.list_item(rx.text(i))), # type: ignore
),
)
+ @rx.page(route="/redirect-page/[page_id]", on_load=DynamicState.on_load_redir) # type: ignore
+ def redirect_page():
+ return rx.fragment(rx.text("redirecting..."))
+
app = rx.App(state=DynamicState)
app.add_page(index)
app.add_page(index, route="/page/[page_id]", on_load=DynamicState.on_load) # type: ignore
app.add_page(index, route="/static/x", on_load=DynamicState.on_load) # type: ignore
+ app.add_custom_404_page(on_load=DynamicState.on_load) # type: ignore
app.compile()
@pytest.fixture(scope="session")
-def dynamic_route(tmp_path_factory) -> Generator[AppHarness, None, None]:
+def dynamic_route(
+ app_harness_env: Type[AppHarness], tmp_path_factory
+) -> Generator[AppHarness, None, None]:
"""Start DynamicRoute app at tmp_path via AppHarness.
Args:
+ app_harness_env: either AppHarness (dev) or AppHarnessProd (prod)
tmp_path_factory: pytest tmp_path_factory fixture
Yields:
running AppHarness instance
"""
- with AppHarness.create(
- root=tmp_path_factory.mktemp("dynamic_route"),
+ with app_harness_env.create(
+ root=tmp_path_factory.mktemp(f"dynamic_route"),
app_source=DynamicRoute, # type: ignore
) as harness:
yield harness
@pytest.fixture
-def driver(dynamic_route: AppHarness):
+def driver(dynamic_route: AppHarness) -> Generator[WebDriver, None, None]:
"""Get an instance of the browser open to the dynamic_route app.
Args:
@@ -90,22 +106,70 @@ def driver(dynamic_route: AppHarness):
driver.quit()
-def test_on_load_navigate(dynamic_route: AppHarness, driver):
+@pytest.fixture()
+def backend_state(dynamic_route: AppHarness, driver: WebDriver) -> State:
+ """Get the backend state.
+
+ Args:
+ dynamic_route: harness for DynamicRoute app.
+ driver: WebDriver instance.
+
+ Returns:
+ The backend state associated with the token visible in the driver browser.
+ """
+ assert dynamic_route.app_instance is not None
+ token_input = driver.find_element(By.ID, "token")
+ assert token_input
+
+ # wait for the backend connection to send the token
+ token = dynamic_route.poll_for_value(token_input)
+ assert token is not None
+
+ # look up the backend state from the state manager
+ return dynamic_route.app_instance.state_manager.states[token]
+
+
+@pytest.fixture()
+def poll_for_order(
+ dynamic_route: AppHarness, backend_state: State
+) -> Callable[[list[str]], None]:
+ """Poll for the order list to match the expected order.
+
+ Args:
+ dynamic_route: harness for DynamicRoute app.
+ backend_state: The backend state associated with the token visible in the driver browser.
+
+ Returns:
+ A function that polls for the order list to match the expected order.
+ """
+
+ def _poll_for_order(exp_order: list[str]):
+ dynamic_route._poll_for(lambda: backend_state.order == exp_order)
+ assert backend_state.order == exp_order
+
+ return _poll_for_order
+
+
+def test_on_load_navigate(
+ dynamic_route: AppHarness,
+ driver: WebDriver,
+ backend_state: State,
+ poll_for_order: Callable[[list[str]], None],
+):
"""Click links to navigate between dynamic pages with on_load event.
Args:
dynamic_route: harness for DynamicRoute app.
driver: WebDriver instance.
+ backend_state: The backend state associated with the token visible in the driver browser.
+ poll_for_order: function that polls for the order list to match the expected order.
"""
assert dynamic_route.app_instance is not None
- token_input = driver.find_element(By.ID, "token")
+ is_prod = isinstance(dynamic_route, AppHarnessProd)
link = driver.find_element(By.ID, "link_page_next")
- assert token_input
assert link
- # wait for the backend connection to send the token
- token = dynamic_route.poll_for_value(token_input)
- assert token is not None
+ exp_order = [f"/page/[page-id]-{ix}" for ix in range(10)]
# click the link a few times
for ix in range(10):
@@ -121,40 +185,84 @@ def test_on_load_navigate(dynamic_route: AppHarness, driver):
assert page_id_input
assert dynamic_route.poll_for_value(page_id_input) == str(ix)
+ poll_for_order(exp_order)
- # look up the backend state and assert that `on_load` was called for all
- # navigation events
- backend_state = dynamic_route.app_instance.state_manager.states[token]
- time.sleep(0.2)
- assert backend_state.order == [str(ix) for ix in range(10)]
+ # manually load the next page to trigger client side routing in prod mode
+ if is_prod:
+ exp_order += ["/404-no page id"]
+ exp_order += ["/page/[page-id]-10"]
+ with poll_for_navigation(driver):
+ driver.get(f"{dynamic_route.frontend_url}/page/10/")
+ poll_for_order(exp_order)
+
+ # make sure internal nav still hydrates after redirect
+ exp_order += ["/page/[page-id]-11"]
+ link = driver.find_element(By.ID, "link_page_next")
+ with poll_for_navigation(driver):
+ link.click()
+ poll_for_order(exp_order)
+
+ # load same page with a query param and make sure it passes through
+ if is_prod:
+ exp_order += ["/404-no page id"]
+ exp_order += ["/page/[page-id]-11"]
+ with poll_for_navigation(driver):
+ driver.get(f"{driver.current_url}?foo=bar")
+ poll_for_order(exp_order)
+ assert backend_state.get_query_params()["foo"] == "bar"
+
+ # hit a 404 and ensure we still hydrate
+ exp_order += ["/404-no page id"]
+ with poll_for_navigation(driver):
+ driver.get(f"{dynamic_route.frontend_url}/missing")
+ poll_for_order(exp_order)
+
+ # browser nav should still trigger hydration
+ if is_prod:
+ exp_order += ["/404-no page id"]
+ exp_order += ["/page/[page-id]-11"]
+ with poll_for_navigation(driver):
+ driver.back()
+ poll_for_order(exp_order)
+
+ # next/link to a 404 and ensure we still hydrate
+ exp_order += ["/404-no page id"]
+ link = driver.find_element(By.ID, "link_missing")
+ with poll_for_navigation(driver):
+ link.click()
+ poll_for_order(exp_order)
+
+ # hit a page that redirects back to dynamic page
+ if is_prod:
+ exp_order += ["/404-no page id"]
+ exp_order += ["on_load_redir-{'foo': 'bar', 'page_id': '0'}", "/page/[page-id]-0"]
+ with poll_for_navigation(driver):
+ driver.get(f"{dynamic_route.frontend_url}/redirect-page/0/?foo=bar")
+ poll_for_order(exp_order)
+ # should have redirected back to page 0
+ assert urlsplit(driver.current_url).path == "/page/0/"
-def test_on_load_navigate_non_dynamic(dynamic_route: AppHarness, driver):
+def test_on_load_navigate_non_dynamic(
+ dynamic_route: AppHarness,
+ driver: WebDriver,
+ poll_for_order: Callable[[list[str]], None],
+):
"""Click links to navigate between static pages with on_load event.
-
Args:
dynamic_route: harness for DynamicRoute app.
driver: WebDriver instance.
+ poll_for_order: function that polls for the order list to match the expected order.
"""
assert dynamic_route.app_instance is not None
- token_input = driver.find_element(By.ID, "token")
link = driver.find_element(By.ID, "link_page_x")
- assert token_input
assert link
- # wait for the backend connection to send the token
- token = dynamic_route.poll_for_value(token_input)
- assert token is not None
-
with poll_for_navigation(driver):
link.click()
assert urlsplit(driver.current_url).path == "/static/x/"
-
- # look up the backend state and assert that `on_load` was called once
- backend_state = dynamic_route.app_instance.state_manager.states[token]
- time.sleep(0.2)
- assert backend_state.order == ["no page id"]
+ poll_for_order(["/static/x-no page id"])
# go back to the index and navigate back to the static route
link = driver.find_element(By.ID, "link_index")
@@ -166,5 +274,4 @@ def test_on_load_navigate_non_dynamic(dynamic_route: AppHarness, driver):
with poll_for_navigation(driver):
link.click()
assert urlsplit(driver.current_url).path == "/static/x/"
- time.sleep(0.2)
- assert backend_state.order == ["no page id", "no page id"]
+ poll_for_order(["/static/x-no page id", "/static/x-no page id"])
diff --git a/reflex/.templates/web/pages/404.js b/reflex/.templates/web/pages/404.js
deleted file mode 100644
index dc03e1e89..000000000
--- a/reflex/.templates/web/pages/404.js
+++ /dev/null
@@ -1,19 +0,0 @@
-import Router from "next/router";
-import { useEffect, useState } from "react";
-
-export default function Custom404() {
- const [isNotFound, setIsNotFound] = useState(false);
-
- useEffect(() => {
- const pathNameArray = window.location.pathname.split("/");
- if (pathNameArray.length == 2 && pathNameArray[1] == "404") {
- setIsNotFound(true);
- } else {
- Router.replace(window.location.pathname);
- }
- }, []);
-
- if (isNotFound) return
404 - Page Not Found
;
-
- return null;
-}
diff --git a/reflex/.templates/web/utils/client_side_routing.js b/reflex/.templates/web/utils/client_side_routing.js
new file mode 100644
index 000000000..75fb581c8
--- /dev/null
+++ b/reflex/.templates/web/utils/client_side_routing.js
@@ -0,0 +1,36 @@
+import { useEffect, useRef, useState } from "react";
+import { useRouter } from "next/router";
+
+/**
+ * React hook for use in /404 page to enable client-side routing.
+ *
+ * Uses the next/router to redirect to the provided URL when loading
+ * the 404 page (for example as a fallback in static hosting situations).
+ *
+ * @returns {boolean} routeNotFound - true if the current route is an actual 404
+ */
+export const useClientSideRouting = () => {
+ const [routeNotFound, setRouteNotFound] = useState(false)
+ const didRedirect = useRef(false)
+ const router = useRouter()
+ useEffect(() => {
+ if (
+ router.isReady &&
+ !didRedirect.current // have not tried redirecting yet
+ ) {
+ didRedirect.current = true // never redirect twice to avoid "Hard Navigate" error
+ // attempt to redirect to the route in the browser address bar once
+ router.replace({
+ pathname: window.location.pathname,
+ query: window.location.search.slice(1),
+ })
+ .catch((e) => {
+ setRouteNotFound(true) // navigation failed, so this is a real 404
+ })
+ }
+ }, [router.isReady]);
+
+ // Return the reactive bool, to avoid flashing 404 page until we know for sure
+ // the route is not found.
+ return routeNotFound
+}
\ No newline at end of file
diff --git a/reflex/.templates/web/utils/state.js b/reflex/.templates/web/utils/state.js
index beb8236bd..9b90145a0 100644
--- a/reflex/.templates/web/utils/state.js
+++ b/reflex/.templates/web/utils/state.js
@@ -173,10 +173,13 @@ export const applyEvent = async (event, socket) => {
return false;
}
- // Send the event to the server.
- event.token = getToken();
- event.router_data = (({ pathname, query, asPath }) => ({ pathname, query, asPath }))(Router);
+ // Update token and router data (if missing).
+ event.token = getToken()
+ if (event.router_data === undefined || Object.keys(event.router_data).length === 0) {
+ event.router_data = (({ pathname, query, asPath }) => ({ pathname, query, asPath }))(Router)
+ }
+ // Send the event to the server.
if (socket) {
socket.emit("event", JSON.stringify(event));
return true;
@@ -255,7 +258,6 @@ export const processEvent = async (
* @param dispatch The function to queue state update
* @param transports The transports to use.
* @param setConnectError The function to update connection error value.
- * @param initial_events Array of events to seed the queue after connecting.
* @param client_storage The client storage object from context.js
*/
export const connect = async (
@@ -263,7 +265,6 @@ export const connect = async (
dispatch,
transports,
setConnectError,
- initial_events = [],
client_storage = {},
) => {
// Get backend URL object from the endpoint.
@@ -277,7 +278,6 @@ export const connect = async (
// Once the socket is open, hydrate the page.
socket.current.on("connect", () => {
- queueEvents(initial_events, socket)
setConnectError(null)
});
@@ -427,8 +427,8 @@ const applyClientStorageDelta = (client_storage, delta) => {
/**
* Establish websocket event loop for a NextJS page.
- * @param initial_state The initial page state.
- * @param initial_events Array of events to seed the queue after connecting.
+ * @param initial_state The initial app state.
+ * @param initial_events The initial app events.
* @param client_storage The client storage object from context.js
*
* @returns [state, Event, connectError] -
@@ -452,6 +452,15 @@ export const useEventLoop = (
queueEvents(events, socket)
}
+ const sentHydrate = useRef(false); // Avoid double-hydrate due to React strict-mode
+ // initial state hydrate
+ useEffect(() => {
+ if (router.isReady && !sentHydrate.current) {
+ Event(initial_events.map((e) => ({...e})))
+ sentHydrate.current = true
+ }
+ }, [router.isReady])
+
// Main event loop.
useEffect(() => {
// Skip if the router is not ready.
@@ -461,7 +470,7 @@ export const useEventLoop = (
// Initialize the websocket connection.
if (!socket.current) {
- connect(socket, dispatch, ['websocket', 'polling'], setConnectError, initial_events, client_storage)
+ connect(socket, dispatch, ['websocket', 'polling'], setConnectError, client_storage)
}
(async () => {
// Process all outstanding events.
diff --git a/reflex/app.py b/reflex/app.py
index 84b0227be..a1ba828ec 100644
--- a/reflex/app.py
+++ b/reflex/app.py
@@ -32,6 +32,10 @@ from reflex.compiler import utils as compiler_utils
from reflex.components import connection_modal
from reflex.components.component import Component, ComponentStyle
from reflex.components.layout.fragment import Fragment
+from reflex.components.navigation.client_side_routing import (
+ Default404Page,
+ wait_for_client_redirect,
+)
from reflex.config import get_config
from reflex.event import Event, EventHandler, EventSpec
from reflex.middleware import HydrateMiddleware, Middleware
@@ -451,8 +455,10 @@ class App(Base):
on_load: The event handler(s) that will be called each time the page load.
meta: The metadata of the page.
"""
+ if component is None:
+ component = Default404Page.create()
self.add_page(
- component=component if component else Fragment.create(),
+ component=wait_for_client_redirect(self._generate_component(component)),
route=constants.SLUG_404,
title=title or constants.TITLE_404,
image=image or constants.FAVICON_404,
@@ -533,6 +539,10 @@ class App(Base):
for render, kwargs in DECORATED_PAGES:
self.add_page(render, **kwargs)
+ # Render a default 404 page if the user didn't supply one
+ if constants.SLUG_404 not in self.pages:
+ self.add_custom_404_page()
+
task = progress.add_task("Compiling: ", total=len(self.pages))
# TODO: include all work done in progress indicator, not just self.pages
diff --git a/reflex/compiler/compiler.py b/reflex/compiler/compiler.py
index de5fbad8b..a177bdd71 100644
--- a/reflex/compiler/compiler.py
+++ b/reflex/compiler/compiler.py
@@ -276,5 +276,5 @@ def compile_tailwind(
def purge_web_pages_dir():
"""Empty out .web directory."""
- template_files = ["_app.js", "404.js"]
+ template_files = ["_app.js"]
utils.empty_dir(constants.WEB_PAGES_DIR, keep_files=template_files)
diff --git a/reflex/components/navigation/client_side_routing.py b/reflex/components/navigation/client_side_routing.py
new file mode 100644
index 000000000..b37f82e0f
--- /dev/null
+++ b/reflex/components/navigation/client_side_routing.py
@@ -0,0 +1,69 @@
+"""Handle dynamic routes in static exports via client-side routing.
+
+Works with /utils/client_side_routing.js to handle the redirect and state.
+
+When the user hits a 404 accessing a route, redirect them to the same page,
+setting a reactive state var "routeNotFound" to true if the redirect fails. The
+`wait_for_client_redirect` function will render the component only after
+routeNotFound becomes true.
+"""
+from __future__ import annotations
+
+from reflex import constants
+
+from ...vars import Var
+from ..component import Component
+from ..layout.cond import Cond
+
+route_not_found = Var.create_safe(constants.ROUTE_NOT_FOUND)
+
+
+class ClientSideRouting(Component):
+ """The client-side routing component."""
+
+ library = "/utils/client_side_routing"
+ tag = "useClientSideRouting"
+
+ def _get_hooks(self) -> str:
+ """Get the hooks to render.
+
+ Returns:
+ The useClientSideRouting hook.
+ """
+ return f"const {constants.ROUTE_NOT_FOUND} = {self.tag}()"
+
+ def render(self) -> str:
+ """Render the component.
+
+ Returns:
+ Empty string, because this component is only used for its hooks.
+ """
+ return ""
+
+
+def wait_for_client_redirect(component) -> Component:
+ """Wait for a redirect to occur before rendering a component.
+
+ This prevents the 404 page from flashing while the redirect is happening.
+
+ Args:
+ component: The component to render after the redirect.
+
+ Returns:
+ The conditionally rendered component.
+ """
+ return Cond.create(
+ cond=route_not_found,
+ comp1=component,
+ comp2=ClientSideRouting.create(),
+ )
+
+
+class Default404Page(Component):
+ """The NextJS default 404 page."""
+
+ library = "next/error"
+ tag = "Error"
+ is_default = True
+
+ status_code: Var[int] = 404 # type: ignore
diff --git a/reflex/constants.py b/reflex/constants.py
index 9d3d3ec05..ca4358f1b 100644
--- a/reflex/constants.py
+++ b/reflex/constants.py
@@ -344,6 +344,7 @@ SLUG_404 = "404"
TITLE_404 = "404 - Not Found"
FAVICON_404 = "favicon.ico"
DESCRIPTION_404 = "The page was not found"
+ROUTE_NOT_FOUND = "routeNotFound"
# Color mode variables
USE_COLOR_MODE = "useColorMode"
diff --git a/reflex/event.py b/reflex/event.py
index 4dc076baa..d098b0941 100644
--- a/reflex/event.py
+++ b/reflex/event.py
@@ -449,12 +449,17 @@ def get_handler_args(event_spec: EventSpec, arg: Var) -> tuple[tuple[Var, Var],
return event_spec.args if len(args) > 1 else tuple()
-def fix_events(events: list[EventHandler | EventSpec], token: str) -> list[Event]:
+def fix_events(
+ events: list[EventHandler | EventSpec],
+ token: str,
+ router_data: dict[str, Any] | None = None,
+) -> list[Event]:
"""Fix a list of events returned by an event handler.
Args:
events: The events to fix.
token: The user token.
+ router_data: The optional router data to set in the event.
Returns:
The fixed events.
@@ -485,6 +490,7 @@ def fix_events(events: list[EventHandler | EventSpec], token: str) -> list[Event
token=token,
name=name,
payload=payload,
+ router_data=router_data or {},
)
)
diff --git a/reflex/middleware/hydrate_middleware.py b/reflex/middleware/hydrate_middleware.py
index 7fd465f78..77e0b3391 100644
--- a/reflex/middleware/hydrate_middleware.py
+++ b/reflex/middleware/hydrate_middleware.py
@@ -60,7 +60,7 @@ class HydrateMiddleware(Middleware):
# Add the on_load events and set is_hydrated to True.
events = [*app.get_load_events(route), type(state).set_is_hydrated(True)] # type: ignore
- events = fix_events(events, event.token)
+ events = fix_events(events, event.token, router_data=event.router_data)
# Return the state update.
return StateUpdate(delta=delta, events=events)
diff --git a/reflex/testing.py b/reflex/testing.py
index 168f33b9d..c40f7744d 100644
--- a/reflex/testing.py
+++ b/reflex/testing.py
@@ -10,11 +10,13 @@ import platform
import re
import signal
import socket
+import socketserver
import subprocess
import textwrap
import threading
import time
import types
+from http.server import SimpleHTTPRequestHandler
from typing import (
TYPE_CHECKING,
Any,
@@ -156,9 +158,9 @@ class AppHarness:
)
self.app_module_path.write_text(source_code)
with chdir(self.app_path):
- # ensure config is reloaded when testing different app
+ # ensure config and app are reloaded when testing different app
reflex.config.get_config(reload=True)
- self.app_module = reflex.utils.prerequisites.get_app()
+ self.app_module = reflex.utils.prerequisites.get_app(reload=True)
self.app_instance = self.app_module.app
def _start_backend(self):
@@ -461,3 +463,161 @@ class AppHarness:
):
raise TimeoutError("No states were observed while polling.")
return state_manager.states
+
+
+class SimpleHTTPRequestHandlerCustomErrors(SimpleHTTPRequestHandler):
+ """SimpleHTTPRequestHandler with custom error page handling."""
+
+ def __init__(self, *args, error_page_map: dict[int, pathlib.Path], **kwargs):
+ """Initialize the handler.
+
+ Args:
+ error_page_map: map of error code to error page path
+ *args: passed through to superclass
+ **kwargs: passed through to superclass
+ """
+ self.error_page_map = error_page_map
+ super().__init__(*args, **kwargs)
+
+ def send_error(
+ self, code: int, message: str | None = None, explain: str | None = None
+ ) -> None:
+ """Send the error page for the given error code.
+
+ If the code matches a custom error page, then message and explain are
+ ignored.
+
+ Args:
+ code: the error code
+ message: the error message
+ explain: the error explanation
+ """
+ error_page = self.error_page_map.get(code)
+ if error_page:
+ self.send_response(code, message)
+ self.send_header("Connection", "close")
+ body = error_page.read_bytes()
+ self.send_header("Content-Type", self.error_content_type)
+ self.send_header("Content-Length", str(len(body)))
+ self.end_headers()
+ self.wfile.write(body)
+ else:
+ super().send_error(code, message, explain)
+
+
+class Subdir404TCPServer(socketserver.TCPServer):
+ """TCPServer for SimpleHTTPRequestHandlerCustomErrors that serves from a subdir."""
+
+ def __init__(
+ self,
+ *args,
+ root: pathlib.Path,
+ error_page_map: dict[int, pathlib.Path] | None,
+ **kwargs,
+ ):
+ """Initialize the server.
+
+ Args:
+ root: the root directory to serve from
+ error_page_map: map of error code to error page path
+ *args: passed through to superclass
+ **kwargs: passed through to superclass
+ """
+ self.root = root
+ self.error_page_map = error_page_map or {}
+ super().__init__(*args, **kwargs)
+
+ def finish_request(self, request: socket.socket, client_address: tuple[str, int]):
+ """Finish one request by instantiating RequestHandlerClass.
+
+ Args:
+ request: the requesting socket
+ client_address: (host, port) referring to the client’s address.
+ """
+ print(client_address, type(client_address))
+ self.RequestHandlerClass(
+ request,
+ client_address,
+ self,
+ directory=str(self.root), # type: ignore
+ error_page_map=self.error_page_map, # type: ignore
+ )
+
+
+class AppHarnessProd(AppHarness):
+ """AppHarnessProd executes a reflex app in-process for testing.
+
+ In prod mode, instead of running `next dev` the app is exported as static
+ files and served via the builtin python http.server with custom 404 redirect
+ handling. Additionally, the backend runs in multi-worker mode.
+ """
+
+ frontend_thread: Optional[threading.Thread] = None
+ frontend_server: Optional[Subdir404TCPServer] = None
+
+ def _run_frontend(self):
+ web_root = self.app_path / reflex.constants.WEB_DIR / "_static"
+ error_page_map = {
+ 404: web_root / "404.html",
+ }
+ with Subdir404TCPServer(
+ ("", 0),
+ SimpleHTTPRequestHandlerCustomErrors,
+ root=web_root,
+ error_page_map=error_page_map,
+ ) as self.frontend_server:
+ self.frontend_url = "http://localhost:{1}".format(
+ *self.frontend_server.socket.getsockname()
+ )
+ self.frontend_server.serve_forever()
+
+ def _start_frontend(self):
+ # Set up the frontend.
+ with chdir(self.app_path):
+ config = reflex.config.get_config()
+ config.api_url = "http://{0}:{1}".format(
+ *self._poll_for_servers().getsockname(),
+ )
+ reflex.reflex.export(
+ zipping=False,
+ frontend=True,
+ backend=False,
+ loglevel=reflex.constants.LogLevel.INFO,
+ )
+
+ self.frontend_thread = threading.Thread(target=self._run_frontend)
+ self.frontend_thread.start()
+
+ def _wait_frontend(self):
+ self._poll_for(lambda: self.frontend_server is not None)
+ if self.frontend_server is None or not self.frontend_server.socket.fileno():
+ raise RuntimeError("Frontend did not start")
+
+ def _start_backend(self):
+ if self.app_instance is None:
+ raise RuntimeError("App was not initialized.")
+ os.environ[reflex.constants.SKIP_COMPILE_ENV_VAR] = "yes"
+ self.backend = uvicorn.Server(
+ uvicorn.Config(
+ app=self.app_instance,
+ host="127.0.0.1",
+ port=0,
+ workers=reflex.utils.processes.get_num_workers(),
+ ),
+ )
+ self.backend_thread = threading.Thread(target=self.backend.run)
+ self.backend_thread.start()
+
+ def _poll_for_servers(self, timeout: TimeoutType = None) -> socket.socket:
+ try:
+ return super()._poll_for_servers(timeout)
+ finally:
+ os.environ.pop(reflex.constants.SKIP_COMPILE_ENV_VAR, None)
+
+ def stop(self):
+ """Stop the frontend python webserver."""
+ super().stop()
+ if self.frontend_server is not None:
+ self.frontend_server.shutdown()
+ if self.frontend_thread is not None:
+ self.frontend_thread.join()
diff --git a/reflex/utils/prerequisites.py b/reflex/utils/prerequisites.py
index 54999d84b..28c6d268e 100644
--- a/reflex/utils/prerequisites.py
+++ b/reflex/utils/prerequisites.py
@@ -3,6 +3,7 @@
from __future__ import annotations
import glob
+import importlib
import json
import os
import platform
@@ -97,16 +98,22 @@ def get_package_manager() -> str | None:
return path_ops.get_npm_path()
-def get_app() -> ModuleType:
+def get_app(reload: bool = False) -> ModuleType:
"""Get the app module based on the default config.
+ Args:
+ reload: Re-import the app module from disk
+
Returns:
The app based on the default config.
"""
config = get_config()
module = ".".join([config.app_name, config.app_name])
sys.path.insert(0, os.getcwd())
- return __import__(module, fromlist=(constants.APP_VAR,))
+ app = __import__(module, fromlist=(constants.APP_VAR,))
+ if reload:
+ importlib.reload(app)
+ return app
def get_redis() -> Redis | None:
diff --git a/tests/test_app.py b/tests/test_app.py
index e190f4e22..1e032500d 100644
--- a/tests/test_app.py
+++ b/tests/test_app.py
@@ -809,9 +809,17 @@ async def test_dynamic_route_var_route_change_completed_on_load(
)
for exp_index, exp_val in enumerate(exp_vals):
+ hydrate_event = _event(name=get_hydrate_event(state), val=exp_val)
+ exp_router_data = {
+ "headers": {},
+ "ip": client_ip,
+ "sid": sid,
+ "token": token,
+ **hydrate_event.router_data,
+ }
update = await process(
app,
- event=_event(name=get_hydrate_event(state), val=exp_val),
+ event=hydrate_event,
sid=sid,
headers={},
client_ip=client_ip,
@@ -830,12 +838,16 @@ async def test_dynamic_route_var_route_change_completed_on_load(
}
},
events=[
- _dynamic_state_event(name="on_load", val=exp_val, router_data={}),
+ _dynamic_state_event(
+ name="on_load",
+ val=exp_val,
+ router_data=exp_router_data,
+ ),
_dynamic_state_event(
name="set_is_hydrated",
payload={"value": True},
val=exp_val,
- router_data={},
+ router_data=exp_router_data,
),
],
)