diff --git a/docker-example/Caddyfile b/docker-example/Caddyfile index 3ab6b0815..275140b56 100644 --- a/docker-example/Caddyfile +++ b/docker-example/Caddyfile @@ -7,12 +7,8 @@ handle @backend_routes { reverse_proxy app:8000 } +root * /srv route { - try_files {path} {path}.html - file_server { - root /srv - pass_thru - } - # proxy dynamic routes to nextjs server - reverse_proxy app:3000 + try_files {path} {path}/ /404.html + file_server } diff --git a/docker-example/Dockerfile b/docker-example/Dockerfile index 11fde7b99..a0c68944c 100644 --- a/docker-example/Dockerfile +++ b/docker-example/Dockerfile @@ -8,9 +8,6 @@ ARG API_URL WORKDIR /app COPY . . -# Reflex will install bun, nvm, and node to `$HOME/.reflex` (/app/.reflex) -ENV HOME=/app - # Create virtualenv which will be copied into final container ENV VIRTUAL_ENV=/app/.venv ENV PATH="$VIRTUAL_ENV/bin:$PATH" @@ -22,9 +19,13 @@ RUN pip install -r requirements.txt # Deploy templates and prepare app RUN reflex init -# Export static copy of frontend to /app/.web/_static (and pre-install frontend packages) +# Export static copy of frontend to /app/.web/_static RUN reflex export --frontend-only --no-zip +# Copy static files out of /app to save space in backend image +RUN mv .web/_static /tmp/_static +RUN rm -rf .web && mkdir .web +RUN mv /tmp/_static .web/_static # Stage 2: copy artifacts into slim image FROM python:3.11-slim @@ -35,4 +36,4 @@ COPY --chown=reflex --from=init /app /app USER reflex ENV PATH="/app/.venv/bin:$PATH" API_URL=$API_URL -CMD reflex db migrate && reflex run --env prod +CMD reflex db migrate && reflex run --env prod --backend-only diff --git a/integration/conftest.py b/integration/conftest.py index 92e0613fb..2a13161a3 100644 --- a/integration/conftest.py +++ b/integration/conftest.py @@ -5,6 +5,8 @@ from pathlib import Path import pytest +from reflex.testing import AppHarness, AppHarnessProd + DISPLAY = None XVFB_DIMENSIONS = (800, 600) @@ -57,3 +59,18 @@ def pytest_exception_interact(node, call, report): ) except Exception as e: print(f"Failed to take screenshot for {node}: {e}") + + +@pytest.fixture( + scope="session", params=[AppHarness, AppHarnessProd], ids=["dev", "prod"] +) +def app_harness_env(request): + """Parametrize the AppHarness class to use for the test, either dev or prod. + + Args: + request: The pytest fixture request object. + + Returns: + The AppHarness class to use for the test. + """ + return request.param diff --git a/integration/test_dynamic_routes.py b/integration/test_dynamic_routes.py index e654151dc..c76993a24 100644 --- a/integration/test_dynamic_routes.py +++ b/integration/test_dynamic_routes.py @@ -1,12 +1,12 @@ """Integration tests for dynamic route page behavior.""" -import time -from typing import Generator +from typing import Callable, Generator, Type from urllib.parse import urlsplit import pytest from selenium.webdriver.common.by import By -from reflex.testing import AppHarness +from reflex import State +from reflex.testing import AppHarness, AppHarnessProd, WebDriver from .utils import poll_for_navigation @@ -20,7 +20,14 @@ def DynamicRoute(): page_id: str = "" def on_load(self): - self.order.append(self.page_id or "no page id") + self.order.append( + f"{self.get_current_page()}-{self.page_id or 'no page id'}" + ) + + def on_load_redir(self): + query_params = self.get_query_params() + self.order.append(f"on_load_redir-{query_params}") + return rx.redirect(f"/page/{query_params['page_id']}") @rx.var def next_page(self) -> str: @@ -42,37 +49,46 @@ def DynamicRoute(): rx.link( "next", href="/page/" + DynamicState.next_page, id="link_page_next" # type: ignore ), + rx.link("missing", href="/missing", id="link_missing"), rx.list( rx.foreach(DynamicState.order, lambda i: rx.list_item(rx.text(i))), # type: ignore ), ) + @rx.page(route="/redirect-page/[page_id]", on_load=DynamicState.on_load_redir) # type: ignore + def redirect_page(): + return rx.fragment(rx.text("redirecting...")) + app = rx.App(state=DynamicState) app.add_page(index) app.add_page(index, route="/page/[page_id]", on_load=DynamicState.on_load) # type: ignore app.add_page(index, route="/static/x", on_load=DynamicState.on_load) # type: ignore + app.add_custom_404_page(on_load=DynamicState.on_load) # type: ignore app.compile() @pytest.fixture(scope="session") -def dynamic_route(tmp_path_factory) -> Generator[AppHarness, None, None]: +def dynamic_route( + app_harness_env: Type[AppHarness], tmp_path_factory +) -> Generator[AppHarness, None, None]: """Start DynamicRoute app at tmp_path via AppHarness. Args: + app_harness_env: either AppHarness (dev) or AppHarnessProd (prod) tmp_path_factory: pytest tmp_path_factory fixture Yields: running AppHarness instance """ - with AppHarness.create( - root=tmp_path_factory.mktemp("dynamic_route"), + with app_harness_env.create( + root=tmp_path_factory.mktemp(f"dynamic_route"), app_source=DynamicRoute, # type: ignore ) as harness: yield harness @pytest.fixture -def driver(dynamic_route: AppHarness): +def driver(dynamic_route: AppHarness) -> Generator[WebDriver, None, None]: """Get an instance of the browser open to the dynamic_route app. Args: @@ -90,22 +106,70 @@ def driver(dynamic_route: AppHarness): driver.quit() -def test_on_load_navigate(dynamic_route: AppHarness, driver): +@pytest.fixture() +def backend_state(dynamic_route: AppHarness, driver: WebDriver) -> State: + """Get the backend state. + + Args: + dynamic_route: harness for DynamicRoute app. + driver: WebDriver instance. + + Returns: + The backend state associated with the token visible in the driver browser. + """ + assert dynamic_route.app_instance is not None + token_input = driver.find_element(By.ID, "token") + assert token_input + + # wait for the backend connection to send the token + token = dynamic_route.poll_for_value(token_input) + assert token is not None + + # look up the backend state from the state manager + return dynamic_route.app_instance.state_manager.states[token] + + +@pytest.fixture() +def poll_for_order( + dynamic_route: AppHarness, backend_state: State +) -> Callable[[list[str]], None]: + """Poll for the order list to match the expected order. + + Args: + dynamic_route: harness for DynamicRoute app. + backend_state: The backend state associated with the token visible in the driver browser. + + Returns: + A function that polls for the order list to match the expected order. + """ + + def _poll_for_order(exp_order: list[str]): + dynamic_route._poll_for(lambda: backend_state.order == exp_order) + assert backend_state.order == exp_order + + return _poll_for_order + + +def test_on_load_navigate( + dynamic_route: AppHarness, + driver: WebDriver, + backend_state: State, + poll_for_order: Callable[[list[str]], None], +): """Click links to navigate between dynamic pages with on_load event. Args: dynamic_route: harness for DynamicRoute app. driver: WebDriver instance. + backend_state: The backend state associated with the token visible in the driver browser. + poll_for_order: function that polls for the order list to match the expected order. """ assert dynamic_route.app_instance is not None - token_input = driver.find_element(By.ID, "token") + is_prod = isinstance(dynamic_route, AppHarnessProd) link = driver.find_element(By.ID, "link_page_next") - assert token_input assert link - # wait for the backend connection to send the token - token = dynamic_route.poll_for_value(token_input) - assert token is not None + exp_order = [f"/page/[page-id]-{ix}" for ix in range(10)] # click the link a few times for ix in range(10): @@ -121,40 +185,84 @@ def test_on_load_navigate(dynamic_route: AppHarness, driver): assert page_id_input assert dynamic_route.poll_for_value(page_id_input) == str(ix) + poll_for_order(exp_order) - # look up the backend state and assert that `on_load` was called for all - # navigation events - backend_state = dynamic_route.app_instance.state_manager.states[token] - time.sleep(0.2) - assert backend_state.order == [str(ix) for ix in range(10)] + # manually load the next page to trigger client side routing in prod mode + if is_prod: + exp_order += ["/404-no page id"] + exp_order += ["/page/[page-id]-10"] + with poll_for_navigation(driver): + driver.get(f"{dynamic_route.frontend_url}/page/10/") + poll_for_order(exp_order) + + # make sure internal nav still hydrates after redirect + exp_order += ["/page/[page-id]-11"] + link = driver.find_element(By.ID, "link_page_next") + with poll_for_navigation(driver): + link.click() + poll_for_order(exp_order) + + # load same page with a query param and make sure it passes through + if is_prod: + exp_order += ["/404-no page id"] + exp_order += ["/page/[page-id]-11"] + with poll_for_navigation(driver): + driver.get(f"{driver.current_url}?foo=bar") + poll_for_order(exp_order) + assert backend_state.get_query_params()["foo"] == "bar" + + # hit a 404 and ensure we still hydrate + exp_order += ["/404-no page id"] + with poll_for_navigation(driver): + driver.get(f"{dynamic_route.frontend_url}/missing") + poll_for_order(exp_order) + + # browser nav should still trigger hydration + if is_prod: + exp_order += ["/404-no page id"] + exp_order += ["/page/[page-id]-11"] + with poll_for_navigation(driver): + driver.back() + poll_for_order(exp_order) + + # next/link to a 404 and ensure we still hydrate + exp_order += ["/404-no page id"] + link = driver.find_element(By.ID, "link_missing") + with poll_for_navigation(driver): + link.click() + poll_for_order(exp_order) + + # hit a page that redirects back to dynamic page + if is_prod: + exp_order += ["/404-no page id"] + exp_order += ["on_load_redir-{'foo': 'bar', 'page_id': '0'}", "/page/[page-id]-0"] + with poll_for_navigation(driver): + driver.get(f"{dynamic_route.frontend_url}/redirect-page/0/?foo=bar") + poll_for_order(exp_order) + # should have redirected back to page 0 + assert urlsplit(driver.current_url).path == "/page/0/" -def test_on_load_navigate_non_dynamic(dynamic_route: AppHarness, driver): +def test_on_load_navigate_non_dynamic( + dynamic_route: AppHarness, + driver: WebDriver, + poll_for_order: Callable[[list[str]], None], +): """Click links to navigate between static pages with on_load event. - Args: dynamic_route: harness for DynamicRoute app. driver: WebDriver instance. + poll_for_order: function that polls for the order list to match the expected order. """ assert dynamic_route.app_instance is not None - token_input = driver.find_element(By.ID, "token") link = driver.find_element(By.ID, "link_page_x") - assert token_input assert link - # wait for the backend connection to send the token - token = dynamic_route.poll_for_value(token_input) - assert token is not None - with poll_for_navigation(driver): link.click() assert urlsplit(driver.current_url).path == "/static/x/" - - # look up the backend state and assert that `on_load` was called once - backend_state = dynamic_route.app_instance.state_manager.states[token] - time.sleep(0.2) - assert backend_state.order == ["no page id"] + poll_for_order(["/static/x-no page id"]) # go back to the index and navigate back to the static route link = driver.find_element(By.ID, "link_index") @@ -166,5 +274,4 @@ def test_on_load_navigate_non_dynamic(dynamic_route: AppHarness, driver): with poll_for_navigation(driver): link.click() assert urlsplit(driver.current_url).path == "/static/x/" - time.sleep(0.2) - assert backend_state.order == ["no page id", "no page id"] + poll_for_order(["/static/x-no page id", "/static/x-no page id"]) diff --git a/reflex/.templates/web/pages/404.js b/reflex/.templates/web/pages/404.js deleted file mode 100644 index dc03e1e89..000000000 --- a/reflex/.templates/web/pages/404.js +++ /dev/null @@ -1,19 +0,0 @@ -import Router from "next/router"; -import { useEffect, useState } from "react"; - -export default function Custom404() { - const [isNotFound, setIsNotFound] = useState(false); - - useEffect(() => { - const pathNameArray = window.location.pathname.split("/"); - if (pathNameArray.length == 2 && pathNameArray[1] == "404") { - setIsNotFound(true); - } else { - Router.replace(window.location.pathname); - } - }, []); - - if (isNotFound) return

404 - Page Not Found

; - - return null; -} diff --git a/reflex/.templates/web/utils/client_side_routing.js b/reflex/.templates/web/utils/client_side_routing.js new file mode 100644 index 000000000..75fb581c8 --- /dev/null +++ b/reflex/.templates/web/utils/client_side_routing.js @@ -0,0 +1,36 @@ +import { useEffect, useRef, useState } from "react"; +import { useRouter } from "next/router"; + +/** + * React hook for use in /404 page to enable client-side routing. + * + * Uses the next/router to redirect to the provided URL when loading + * the 404 page (for example as a fallback in static hosting situations). + * + * @returns {boolean} routeNotFound - true if the current route is an actual 404 + */ +export const useClientSideRouting = () => { + const [routeNotFound, setRouteNotFound] = useState(false) + const didRedirect = useRef(false) + const router = useRouter() + useEffect(() => { + if ( + router.isReady && + !didRedirect.current // have not tried redirecting yet + ) { + didRedirect.current = true // never redirect twice to avoid "Hard Navigate" error + // attempt to redirect to the route in the browser address bar once + router.replace({ + pathname: window.location.pathname, + query: window.location.search.slice(1), + }) + .catch((e) => { + setRouteNotFound(true) // navigation failed, so this is a real 404 + }) + } + }, [router.isReady]); + + // Return the reactive bool, to avoid flashing 404 page until we know for sure + // the route is not found. + return routeNotFound +} \ No newline at end of file diff --git a/reflex/.templates/web/utils/state.js b/reflex/.templates/web/utils/state.js index beb8236bd..9b90145a0 100644 --- a/reflex/.templates/web/utils/state.js +++ b/reflex/.templates/web/utils/state.js @@ -173,10 +173,13 @@ export const applyEvent = async (event, socket) => { return false; } - // Send the event to the server. - event.token = getToken(); - event.router_data = (({ pathname, query, asPath }) => ({ pathname, query, asPath }))(Router); + // Update token and router data (if missing). + event.token = getToken() + if (event.router_data === undefined || Object.keys(event.router_data).length === 0) { + event.router_data = (({ pathname, query, asPath }) => ({ pathname, query, asPath }))(Router) + } + // Send the event to the server. if (socket) { socket.emit("event", JSON.stringify(event)); return true; @@ -255,7 +258,6 @@ export const processEvent = async ( * @param dispatch The function to queue state update * @param transports The transports to use. * @param setConnectError The function to update connection error value. - * @param initial_events Array of events to seed the queue after connecting. * @param client_storage The client storage object from context.js */ export const connect = async ( @@ -263,7 +265,6 @@ export const connect = async ( dispatch, transports, setConnectError, - initial_events = [], client_storage = {}, ) => { // Get backend URL object from the endpoint. @@ -277,7 +278,6 @@ export const connect = async ( // Once the socket is open, hydrate the page. socket.current.on("connect", () => { - queueEvents(initial_events, socket) setConnectError(null) }); @@ -427,8 +427,8 @@ const applyClientStorageDelta = (client_storage, delta) => { /** * Establish websocket event loop for a NextJS page. - * @param initial_state The initial page state. - * @param initial_events Array of events to seed the queue after connecting. + * @param initial_state The initial app state. + * @param initial_events The initial app events. * @param client_storage The client storage object from context.js * * @returns [state, Event, connectError] - @@ -452,6 +452,15 @@ export const useEventLoop = ( queueEvents(events, socket) } + const sentHydrate = useRef(false); // Avoid double-hydrate due to React strict-mode + // initial state hydrate + useEffect(() => { + if (router.isReady && !sentHydrate.current) { + Event(initial_events.map((e) => ({...e}))) + sentHydrate.current = true + } + }, [router.isReady]) + // Main event loop. useEffect(() => { // Skip if the router is not ready. @@ -461,7 +470,7 @@ export const useEventLoop = ( // Initialize the websocket connection. if (!socket.current) { - connect(socket, dispatch, ['websocket', 'polling'], setConnectError, initial_events, client_storage) + connect(socket, dispatch, ['websocket', 'polling'], setConnectError, client_storage) } (async () => { // Process all outstanding events. diff --git a/reflex/app.py b/reflex/app.py index 84b0227be..a1ba828ec 100644 --- a/reflex/app.py +++ b/reflex/app.py @@ -32,6 +32,10 @@ from reflex.compiler import utils as compiler_utils from reflex.components import connection_modal from reflex.components.component import Component, ComponentStyle from reflex.components.layout.fragment import Fragment +from reflex.components.navigation.client_side_routing import ( + Default404Page, + wait_for_client_redirect, +) from reflex.config import get_config from reflex.event import Event, EventHandler, EventSpec from reflex.middleware import HydrateMiddleware, Middleware @@ -451,8 +455,10 @@ class App(Base): on_load: The event handler(s) that will be called each time the page load. meta: The metadata of the page. """ + if component is None: + component = Default404Page.create() self.add_page( - component=component if component else Fragment.create(), + component=wait_for_client_redirect(self._generate_component(component)), route=constants.SLUG_404, title=title or constants.TITLE_404, image=image or constants.FAVICON_404, @@ -533,6 +539,10 @@ class App(Base): for render, kwargs in DECORATED_PAGES: self.add_page(render, **kwargs) + # Render a default 404 page if the user didn't supply one + if constants.SLUG_404 not in self.pages: + self.add_custom_404_page() + task = progress.add_task("Compiling: ", total=len(self.pages)) # TODO: include all work done in progress indicator, not just self.pages diff --git a/reflex/compiler/compiler.py b/reflex/compiler/compiler.py index de5fbad8b..a177bdd71 100644 --- a/reflex/compiler/compiler.py +++ b/reflex/compiler/compiler.py @@ -276,5 +276,5 @@ def compile_tailwind( def purge_web_pages_dir(): """Empty out .web directory.""" - template_files = ["_app.js", "404.js"] + template_files = ["_app.js"] utils.empty_dir(constants.WEB_PAGES_DIR, keep_files=template_files) diff --git a/reflex/components/navigation/client_side_routing.py b/reflex/components/navigation/client_side_routing.py new file mode 100644 index 000000000..b37f82e0f --- /dev/null +++ b/reflex/components/navigation/client_side_routing.py @@ -0,0 +1,69 @@ +"""Handle dynamic routes in static exports via client-side routing. + +Works with /utils/client_side_routing.js to handle the redirect and state. + +When the user hits a 404 accessing a route, redirect them to the same page, +setting a reactive state var "routeNotFound" to true if the redirect fails. The +`wait_for_client_redirect` function will render the component only after +routeNotFound becomes true. +""" +from __future__ import annotations + +from reflex import constants + +from ...vars import Var +from ..component import Component +from ..layout.cond import Cond + +route_not_found = Var.create_safe(constants.ROUTE_NOT_FOUND) + + +class ClientSideRouting(Component): + """The client-side routing component.""" + + library = "/utils/client_side_routing" + tag = "useClientSideRouting" + + def _get_hooks(self) -> str: + """Get the hooks to render. + + Returns: + The useClientSideRouting hook. + """ + return f"const {constants.ROUTE_NOT_FOUND} = {self.tag}()" + + def render(self) -> str: + """Render the component. + + Returns: + Empty string, because this component is only used for its hooks. + """ + return "" + + +def wait_for_client_redirect(component) -> Component: + """Wait for a redirect to occur before rendering a component. + + This prevents the 404 page from flashing while the redirect is happening. + + Args: + component: The component to render after the redirect. + + Returns: + The conditionally rendered component. + """ + return Cond.create( + cond=route_not_found, + comp1=component, + comp2=ClientSideRouting.create(), + ) + + +class Default404Page(Component): + """The NextJS default 404 page.""" + + library = "next/error" + tag = "Error" + is_default = True + + status_code: Var[int] = 404 # type: ignore diff --git a/reflex/constants.py b/reflex/constants.py index 9d3d3ec05..ca4358f1b 100644 --- a/reflex/constants.py +++ b/reflex/constants.py @@ -344,6 +344,7 @@ SLUG_404 = "404" TITLE_404 = "404 - Not Found" FAVICON_404 = "favicon.ico" DESCRIPTION_404 = "The page was not found" +ROUTE_NOT_FOUND = "routeNotFound" # Color mode variables USE_COLOR_MODE = "useColorMode" diff --git a/reflex/event.py b/reflex/event.py index 4dc076baa..d098b0941 100644 --- a/reflex/event.py +++ b/reflex/event.py @@ -449,12 +449,17 @@ def get_handler_args(event_spec: EventSpec, arg: Var) -> tuple[tuple[Var, Var], return event_spec.args if len(args) > 1 else tuple() -def fix_events(events: list[EventHandler | EventSpec], token: str) -> list[Event]: +def fix_events( + events: list[EventHandler | EventSpec], + token: str, + router_data: dict[str, Any] | None = None, +) -> list[Event]: """Fix a list of events returned by an event handler. Args: events: The events to fix. token: The user token. + router_data: The optional router data to set in the event. Returns: The fixed events. @@ -485,6 +490,7 @@ def fix_events(events: list[EventHandler | EventSpec], token: str) -> list[Event token=token, name=name, payload=payload, + router_data=router_data or {}, ) ) diff --git a/reflex/middleware/hydrate_middleware.py b/reflex/middleware/hydrate_middleware.py index 7fd465f78..77e0b3391 100644 --- a/reflex/middleware/hydrate_middleware.py +++ b/reflex/middleware/hydrate_middleware.py @@ -60,7 +60,7 @@ class HydrateMiddleware(Middleware): # Add the on_load events and set is_hydrated to True. events = [*app.get_load_events(route), type(state).set_is_hydrated(True)] # type: ignore - events = fix_events(events, event.token) + events = fix_events(events, event.token, router_data=event.router_data) # Return the state update. return StateUpdate(delta=delta, events=events) diff --git a/reflex/testing.py b/reflex/testing.py index 168f33b9d..c40f7744d 100644 --- a/reflex/testing.py +++ b/reflex/testing.py @@ -10,11 +10,13 @@ import platform import re import signal import socket +import socketserver import subprocess import textwrap import threading import time import types +from http.server import SimpleHTTPRequestHandler from typing import ( TYPE_CHECKING, Any, @@ -156,9 +158,9 @@ class AppHarness: ) self.app_module_path.write_text(source_code) with chdir(self.app_path): - # ensure config is reloaded when testing different app + # ensure config and app are reloaded when testing different app reflex.config.get_config(reload=True) - self.app_module = reflex.utils.prerequisites.get_app() + self.app_module = reflex.utils.prerequisites.get_app(reload=True) self.app_instance = self.app_module.app def _start_backend(self): @@ -461,3 +463,161 @@ class AppHarness: ): raise TimeoutError("No states were observed while polling.") return state_manager.states + + +class SimpleHTTPRequestHandlerCustomErrors(SimpleHTTPRequestHandler): + """SimpleHTTPRequestHandler with custom error page handling.""" + + def __init__(self, *args, error_page_map: dict[int, pathlib.Path], **kwargs): + """Initialize the handler. + + Args: + error_page_map: map of error code to error page path + *args: passed through to superclass + **kwargs: passed through to superclass + """ + self.error_page_map = error_page_map + super().__init__(*args, **kwargs) + + def send_error( + self, code: int, message: str | None = None, explain: str | None = None + ) -> None: + """Send the error page for the given error code. + + If the code matches a custom error page, then message and explain are + ignored. + + Args: + code: the error code + message: the error message + explain: the error explanation + """ + error_page = self.error_page_map.get(code) + if error_page: + self.send_response(code, message) + self.send_header("Connection", "close") + body = error_page.read_bytes() + self.send_header("Content-Type", self.error_content_type) + self.send_header("Content-Length", str(len(body))) + self.end_headers() + self.wfile.write(body) + else: + super().send_error(code, message, explain) + + +class Subdir404TCPServer(socketserver.TCPServer): + """TCPServer for SimpleHTTPRequestHandlerCustomErrors that serves from a subdir.""" + + def __init__( + self, + *args, + root: pathlib.Path, + error_page_map: dict[int, pathlib.Path] | None, + **kwargs, + ): + """Initialize the server. + + Args: + root: the root directory to serve from + error_page_map: map of error code to error page path + *args: passed through to superclass + **kwargs: passed through to superclass + """ + self.root = root + self.error_page_map = error_page_map or {} + super().__init__(*args, **kwargs) + + def finish_request(self, request: socket.socket, client_address: tuple[str, int]): + """Finish one request by instantiating RequestHandlerClass. + + Args: + request: the requesting socket + client_address: (host, port) referring to the client’s address. + """ + print(client_address, type(client_address)) + self.RequestHandlerClass( + request, + client_address, + self, + directory=str(self.root), # type: ignore + error_page_map=self.error_page_map, # type: ignore + ) + + +class AppHarnessProd(AppHarness): + """AppHarnessProd executes a reflex app in-process for testing. + + In prod mode, instead of running `next dev` the app is exported as static + files and served via the builtin python http.server with custom 404 redirect + handling. Additionally, the backend runs in multi-worker mode. + """ + + frontend_thread: Optional[threading.Thread] = None + frontend_server: Optional[Subdir404TCPServer] = None + + def _run_frontend(self): + web_root = self.app_path / reflex.constants.WEB_DIR / "_static" + error_page_map = { + 404: web_root / "404.html", + } + with Subdir404TCPServer( + ("", 0), + SimpleHTTPRequestHandlerCustomErrors, + root=web_root, + error_page_map=error_page_map, + ) as self.frontend_server: + self.frontend_url = "http://localhost:{1}".format( + *self.frontend_server.socket.getsockname() + ) + self.frontend_server.serve_forever() + + def _start_frontend(self): + # Set up the frontend. + with chdir(self.app_path): + config = reflex.config.get_config() + config.api_url = "http://{0}:{1}".format( + *self._poll_for_servers().getsockname(), + ) + reflex.reflex.export( + zipping=False, + frontend=True, + backend=False, + loglevel=reflex.constants.LogLevel.INFO, + ) + + self.frontend_thread = threading.Thread(target=self._run_frontend) + self.frontend_thread.start() + + def _wait_frontend(self): + self._poll_for(lambda: self.frontend_server is not None) + if self.frontend_server is None or not self.frontend_server.socket.fileno(): + raise RuntimeError("Frontend did not start") + + def _start_backend(self): + if self.app_instance is None: + raise RuntimeError("App was not initialized.") + os.environ[reflex.constants.SKIP_COMPILE_ENV_VAR] = "yes" + self.backend = uvicorn.Server( + uvicorn.Config( + app=self.app_instance, + host="127.0.0.1", + port=0, + workers=reflex.utils.processes.get_num_workers(), + ), + ) + self.backend_thread = threading.Thread(target=self.backend.run) + self.backend_thread.start() + + def _poll_for_servers(self, timeout: TimeoutType = None) -> socket.socket: + try: + return super()._poll_for_servers(timeout) + finally: + os.environ.pop(reflex.constants.SKIP_COMPILE_ENV_VAR, None) + + def stop(self): + """Stop the frontend python webserver.""" + super().stop() + if self.frontend_server is not None: + self.frontend_server.shutdown() + if self.frontend_thread is not None: + self.frontend_thread.join() diff --git a/reflex/utils/prerequisites.py b/reflex/utils/prerequisites.py index 54999d84b..28c6d268e 100644 --- a/reflex/utils/prerequisites.py +++ b/reflex/utils/prerequisites.py @@ -3,6 +3,7 @@ from __future__ import annotations import glob +import importlib import json import os import platform @@ -97,16 +98,22 @@ def get_package_manager() -> str | None: return path_ops.get_npm_path() -def get_app() -> ModuleType: +def get_app(reload: bool = False) -> ModuleType: """Get the app module based on the default config. + Args: + reload: Re-import the app module from disk + Returns: The app based on the default config. """ config = get_config() module = ".".join([config.app_name, config.app_name]) sys.path.insert(0, os.getcwd()) - return __import__(module, fromlist=(constants.APP_VAR,)) + app = __import__(module, fromlist=(constants.APP_VAR,)) + if reload: + importlib.reload(app) + return app def get_redis() -> Redis | None: diff --git a/tests/test_app.py b/tests/test_app.py index e190f4e22..1e032500d 100644 --- a/tests/test_app.py +++ b/tests/test_app.py @@ -809,9 +809,17 @@ async def test_dynamic_route_var_route_change_completed_on_load( ) for exp_index, exp_val in enumerate(exp_vals): + hydrate_event = _event(name=get_hydrate_event(state), val=exp_val) + exp_router_data = { + "headers": {}, + "ip": client_ip, + "sid": sid, + "token": token, + **hydrate_event.router_data, + } update = await process( app, - event=_event(name=get_hydrate_event(state), val=exp_val), + event=hydrate_event, sid=sid, headers={}, client_ip=client_ip, @@ -830,12 +838,16 @@ async def test_dynamic_route_var_route_change_completed_on_load( } }, events=[ - _dynamic_state_event(name="on_load", val=exp_val, router_data={}), + _dynamic_state_event( + name="on_load", + val=exp_val, + router_data=exp_router_data, + ), _dynamic_state_event( name="set_is_hydrated", payload={"value": True}, val=exp_val, - router_data={}, + router_data=exp_router_data, ), ], )