diff --git a/reflex/app.py b/reflex/app.py index d399a9941..aebbbc3a0 100644 --- a/reflex/app.py +++ b/reflex/app.py @@ -3,12 +3,14 @@ from __future__ import annotations import asyncio +import concurrent.futures import contextlib import copy import dataclasses import functools import inspect import io +import multiprocessing import os import platform import sys @@ -34,7 +36,6 @@ from fastapi import FastAPI, HTTPException, Request, UploadFile from fastapi.middleware import cors from fastapi.responses import StreamingResponse from fastapi.staticfiles import StaticFiles -from pathos import pools from rich.progress import MofNCompleteColumn, Progress, TimeElapsedColumn from socketio import ASGIApp, AsyncNamespace, AsyncServer from starlette_admin.contrib.sqla.admin import Admin @@ -48,6 +49,7 @@ from reflex.compiler import compiler from reflex.compiler import utils as compiler_utils from reflex.compiler.compiler import ( ExecutorSafeFunctions, + compile_theme, ) from reflex.components.base.app_wrap import AppWrap from reflex.components.base.error_boundary import ErrorBoundary @@ -934,9 +936,27 @@ class App(MiddlewareMixin, LifespanMixin, Base): for route in self.uncompiled_pages: self._compile_page(route) - executor = pools.ProcessPool() + executor = concurrent.futures.ProcessPoolExecutor( + max_workers=int(os.environ.get("REFLEX_COMPILE_PROCESSES", 0)) or None, + mp_context=multiprocessing.get_context("fork"), + ) else: - executor = pools.ThreadPool() + executor = concurrent.futures.ThreadPoolExecutor( + max_workers=int(os.environ.get("REFLEX_COMPILE_THREADS", 0)) or None, + ) + + for route, component in self.pages.items(): + component._add_style_recursive(self.style, self.theme) + + ExecutorSafeFunctions.COMPONENTS[route] = component + + for route, page in self.uncompiled_pages.items(): + if route in self.pages: + continue + + ExecutorSafeFunctions.UNCOMPILED_PAGES[route] = page + + ExecutorSafeFunctions.STATE = self.state pages_results = [] @@ -945,39 +965,36 @@ class App(MiddlewareMixin, LifespanMixin, Base): pages_futures = [] def _submit_work(fn, *args, **kwargs): - f = executor.apipe(fn, *args, **kwargs) + f = executor.submit(fn, *args, **kwargs) + # f = executor.apipe(fn, *args, **kwargs) result_futures.append(f) # Compile all page components. - for route, page in self.uncompiled_pages.items(): + for route in self.uncompiled_pages: if route in self.pages: continue - f = executor.apipe( + f = executor.submit( ExecutorSafeFunctions.compile_uncompiled_page, route, - page, self.state, self.style, self.theme, ) - pages_futures.append((route, f)) + pages_futures.append(f) # Compile the pre-compiled pages. - for route, component in self.pages.items(): - component._add_style_recursive(self.style, self.theme) + for route in self.pages: _submit_work( ExecutorSafeFunctions.compile_page, route, - component, - self.state, ) # Compile the root stylesheet with base styles. _submit_work(compiler.compile_root_stylesheet, self.stylesheets) # Compile the theme. - _submit_work(ExecutorSafeFunctions.compile_theme, self.style) + _submit_work(compile_theme, self.style) # Compile the Tailwind config. if config.tailwind is not None: @@ -989,12 +1006,12 @@ class App(MiddlewareMixin, LifespanMixin, Base): _submit_work(compiler.remove_tailwind_from_postcss) # Wait for all compilation tasks to complete. - for future in result_futures: - compile_results.append(future.get()) + for future in concurrent.futures.as_completed(result_futures): + compile_results.append(future.result()) progress.advance(task) - for _, future in pages_futures: - pages_results.append(future.get()) + for future in concurrent.futures.as_completed(pages_futures): + pages_results.append(future.result()) progress.advance(task) for route, component, compiled_page in pages_results: diff --git a/reflex/compiler/compiler.py b/reflex/compiler/compiler.py index 425132f68..b99a862d8 100644 --- a/reflex/compiler/compiler.py +++ b/reflex/compiler/compiler.py @@ -587,10 +587,12 @@ class ExecutorSafeFunctions: """ + COMPONENTS: Dict[str, Component] = {} + UNCOMPILED_PAGES: Dict[str, UncompiledPage] = {} + STATE: Optional[Type[BaseState]] = None + @classmethod - def compile_page( - cls, route: str, component: Component, state: Type[BaseState] - ) -> tuple[str, str]: + def compile_page(cls, route: str) -> tuple[str, str]: """Compile a page. Args: @@ -601,13 +603,12 @@ class ExecutorSafeFunctions: Returns: The path and code of the compiled page. """ - return compile_page(route, component, state) + return compile_page(route, cls.COMPONENTS[route], cls.STATE) @classmethod def compile_uncompiled_page( cls, route: str, - page: UncompiledPage, state: Type[BaseState], style: ComponentStyle, theme: Component, @@ -624,7 +625,7 @@ class ExecutorSafeFunctions: Returns: The route, compiled component, and compiled page. """ - component = compile_uncompiled_page_helper(route, page) + component = compile_uncompiled_page_helper(route, cls.UNCOMPILED_PAGES[route]) component = component if isinstance(component, Component) else component() component._add_style_recursive(style, theme) return route, component, compile_page(route, component, state) diff --git a/reflex/style.py b/reflex/style.py index a2083f634..0a8bb8db4 100644 --- a/reflex/style.py +++ b/reflex/style.py @@ -267,7 +267,10 @@ class Style(dict): _var = LiteralVar.create(value) if _var is not None: # Carry the imports/hooks when setting a Var as a value. - self._var_data = VarData.merge(self._var_data, _var._get_all_var_data()) + self._var_data = VarData.merge( + self._var_data, + _var._get_all_var_data(), + ) super().__setitem__(key, value)