From a8756cb0f83f00bfe024b0da3b4c740688ac6aaa Mon Sep 17 00:00:00 2001 From: jackie-pc Date: Fri, 12 Jan 2024 14:22:18 -0800 Subject: [PATCH] use process pool to compile faster (#2377) --- reflex/app.py | 184 ++++++++++++++++++++---------------- reflex/compiler/compiler.py | 110 +++++++++++++++++++++ 2 files changed, 213 insertions(+), 81 deletions(-) diff --git a/reflex/app.py b/reflex/app.py index fabafd31e..4c99b3982 100644 --- a/reflex/app.py +++ b/reflex/app.py @@ -6,7 +6,9 @@ import concurrent.futures import contextlib import copy import functools +import multiprocessing import os +import platform from typing import ( Any, AsyncIterator, @@ -35,6 +37,7 @@ from reflex.admin import AdminDash from reflex.base import Base from reflex.compiler import compiler from reflex.compiler import utils as compiler_utils +from reflex.compiler.compiler import ExecutorSafeFunctions from reflex.components import connection_modal from reflex.components.base.app_wrap import AppWrap from reflex.components.base.fragment import Fragment @@ -661,15 +664,24 @@ class App(Base): TimeElapsedColumn(), ) + # try to be somewhat accurate - but still not 100% + adhoc_steps_without_executor = 6 + fixed_pages_within_executor = 7 + progress.start() + task = progress.add_task( + "Compiling:", + total=len(self.pages) + + fixed_pages_within_executor + + adhoc_steps_without_executor, + ) + # Get the env mode. config = get_config() # Store the compile results. compile_results = [] - # Compile the pages in parallel. custom_components = set() - # TODO Anecdotally, processes=2 works 10% faster (cpu_count=12) all_imports = {} app_wrappers: Dict[tuple[int, str], Component] = { # Default app wrap component renders {children} @@ -679,127 +691,137 @@ class App(Base): # If a theme component was provided, wrap the app with it app_wrappers[(20, "Theme")] = self.theme - with progress, concurrent.futures.ThreadPoolExecutor() as thread_pool: - fixed_pages = 7 - task = progress.add_task("Compiling:", total=len(self.pages) + fixed_pages) + progress.advance(task) - def mark_complete(_=None): - progress.advance(task) + for _route, component in self.pages.items(): + # Merge the component style with the app style. + component.add_style(self.style) - for _route, component in self.pages.items(): - # Merge the component style with the app style. - component.add_style(self.style) + component.apply_theme(self.theme) - component.apply_theme(self.theme) + # Add component.get_imports() to all_imports. + all_imports.update(component.get_imports()) - # Add component.get_imports() to all_imports. - all_imports.update(component.get_imports()) + # Add the app wrappers from this component. + app_wrappers.update(component.get_app_wrap_components()) - # Add the app wrappers from this component. - app_wrappers.update(component.get_app_wrap_components()) + # Add the custom components from the page to the set. + custom_components |= component.get_custom_components() - # Add the custom components from the page to the set. - custom_components |= component.get_custom_components() + progress.advance(task) - # Perform auto-memoization of stateful components. - ( - stateful_components_path, - stateful_components_code, - page_components, - ) = compiler.compile_stateful_components(self.pages.values()) + # Perform auto-memoization of stateful components. + ( + stateful_components_path, + stateful_components_code, + page_components, + ) = compiler.compile_stateful_components(self.pages.values()) - # Catch "static" apps (that do not define a rx.State subclass) which are trying to access rx.State. - if ( - code_uses_state_contexts(stateful_components_code) - and self.state is None - ): - raise RuntimeError( - "To access rx.State in frontend components, at least one " - "subclass of rx.State must be defined in the app." - ) - compile_results.append((stateful_components_path, stateful_components_code)) + progress.advance(task) + # Catch "static" apps (that do not define a rx.State subclass) which are trying to access rx.State. + if code_uses_state_contexts(stateful_components_code) and self.state is None: + raise RuntimeError( + "To access rx.State in frontend components, at least one " + "subclass of rx.State must be defined in the app." + ) + compile_results.append((stateful_components_path, stateful_components_code)) + + app_root = self._app_root(app_wrappers=app_wrappers) + + progress.advance(task) + + # Prepopulate the global ExecutorSafeFunctions class with input data required by the compile functions. + # This is required for multiprocessing to work, in presence of non-picklable inputs. + for route, component in zip(self.pages, page_components): + ExecutorSafeFunctions.COMPILE_PAGE_ARGS_BY_ROUTE[route] = ( + route, + component, + self.state, + ) + + ExecutorSafeFunctions.COMPILE_APP_APP_ROOT = app_root + ExecutorSafeFunctions.CUSTOM_COMPONENTS = custom_components + ExecutorSafeFunctions.HEAD_COMPONENTS = self.head_components + ExecutorSafeFunctions.STYLE = self.style + ExecutorSafeFunctions.STATE = self.state + + # Use a forking process pool, if possible. Much faster, especially for large sites. + # Fallback to ThreadPoolExecutor as something that will always work. + executor = None + if platform.system() in ("Linux", "Darwin"): + executor = concurrent.futures.ProcessPoolExecutor( + mp_context=multiprocessing.get_context("fork") + ) + else: + executor = concurrent.futures.ThreadPoolExecutor() + + with executor: result_futures = [] - def submit_work(fn, *args, **kwargs): - """Submit work to the thread pool and add a callback to mark the task as complete. + def _mark_complete(_=None): + progress.advance(task) - The Future will be added to the `result_futures` list. - - Args: - fn: The function to submit. - *args: The args to submit. - **kwargs: The kwargs to submit. - """ - f = thread_pool.submit(fn, *args, **kwargs) - f.add_done_callback(mark_complete) + def _submit_work(fn, *args, **kwargs): + f = executor.submit(fn, *args, **kwargs) + f.add_done_callback(_mark_complete) result_futures.append(f) # Compile all page components. - for route, component in zip(self.pages, page_components): - submit_work( - compiler.compile_page, - route, - component, - self.state, - ) + for route in self.pages: + _submit_work(ExecutorSafeFunctions.compile_page, route) # Compile the app wrapper. - app_root = self._app_root(app_wrappers=app_wrappers) - submit_work(compiler.compile_app, app_root) + _submit_work(ExecutorSafeFunctions.compile_app) # Compile the custom components. - submit_work(compiler.compile_components, custom_components) + _submit_work(ExecutorSafeFunctions.compile_custom_components) # Compile the root stylesheet with base styles. - submit_work(compiler.compile_root_stylesheet, self.stylesheets) + _submit_work(compiler.compile_root_stylesheet, self.stylesheets) # Compile the root document. - submit_work(compiler.compile_document_root, self.head_components) + _submit_work(ExecutorSafeFunctions.compile_document_root) # Compile the theme. - submit_work(compiler.compile_theme, style=self.style) + _submit_work(ExecutorSafeFunctions.compile_theme) # Compile the contexts. - submit_work(compiler.compile_contexts, self.state) + _submit_work(ExecutorSafeFunctions.compile_contexts) # Compile the Tailwind config. if config.tailwind is not None: config.tailwind["content"] = config.tailwind.get( "content", constants.Tailwind.CONTENT ) - submit_work(compiler.compile_tailwind, config.tailwind) + _submit_work(compiler.compile_tailwind, config.tailwind) else: - submit_work(compiler.remove_tailwind_from_postcss) - - # Get imports from AppWrap components. - all_imports.update(app_root.get_imports()) - - # Iterate through all the custom components and add their imports to the all_imports. - for component in custom_components: - all_imports.update(component.get_imports()) + _submit_work(compiler.remove_tailwind_from_postcss) # Wait for all compilation tasks to complete. for future in concurrent.futures.as_completed(result_futures): compile_results.append(future.result()) - # Empty the .web pages directory. - compiler.purge_web_pages_dir() + # Get imports from AppWrap components. + all_imports.update(app_root.get_imports()) - # Avoid flickering when installing frontend packages - progress.stop() + # Iterate through all the custom components and add their imports to the all_imports. + for component in custom_components: + all_imports.update(component.get_imports()) - # Install frontend packages. - self.get_frontend_packages(all_imports) + progress.advance(task) - # Write the pages at the end to trigger the NextJS hot reload only once. - write_page_futures = [] - for output_path, code in compile_results: - write_page_futures.append( - thread_pool.submit(compiler_utils.write_page, output_path, code) - ) - for future in concurrent.futures.as_completed(write_page_futures): - future.result() + # Empty the .web pages directory. + compiler.purge_web_pages_dir() + + progress.advance(task) + progress.stop() + + # Install frontend packages. + self.get_frontend_packages(all_imports) + + for output_path, code in compile_results: + compiler_utils.write_page(output_path, code) @contextlib.asynccontextmanager async def modify_state(self, token: str) -> AsyncIterator[BaseState]: diff --git a/reflex/compiler/compiler.py b/reflex/compiler/compiler.py index 51101ac6a..2b73a3689 100644 --- a/reflex/compiler/compiler.py +++ b/reflex/compiler/compiler.py @@ -454,3 +454,113 @@ def remove_tailwind_from_postcss() -> tuple[str, str]: def purge_web_pages_dir(): """Empty out .web directory.""" utils.empty_dir(constants.Dirs.WEB_PAGES, keep_files=["_app.js"]) + + +class ExecutorSafeFunctions: + """Helper class to allow parallelisation of parts of the compilation process. + + This class (and its class attributes) are available at global scope. + + In a multiprocessing context (like when using a ProcessPoolExecutor), the content of this + global class is logically replicated to any FORKED process. + + How it works: + * Before the child process is forked, ensure that we stash any input data required by any future + function call in the child process. + * After the child process is forked, the child process will have a copy of the global class, which + includes the previously stashed input data. + * Any task submitted to the child process simply needs a way to communicate which input data the + requested function call requires. + + Why do we need this? Passing input data directly to child process often not possible because the input data is not picklable. + The mechanic described here removes the need to pickle the input data at all. + + Limitations: + * This can never support returning unpicklable OUTPUT data. + * Any object mutations done by the child process will not propagate back to the parent process (fork goes one way!). + + """ + + COMPILE_PAGE_ARGS_BY_ROUTE = {} + COMPILE_APP_APP_ROOT: Component | None = None + CUSTOM_COMPONENTS: set[CustomComponent] | None = None + HEAD_COMPONENTS: list[Component] | None = None + STYLE: ComponentStyle | None = None + STATE: type[BaseState] | None = None + + @classmethod + def compile_page(cls, route: str): + """Compile a page. + + Args: + route: The route of the page to compile. + + Returns: + The path and code of the compiled page. + """ + return compile_page(*cls.COMPILE_PAGE_ARGS_BY_ROUTE[route]) + + @classmethod + def compile_app(cls): + """Compile the app. + + Returns: + The path and code of the compiled app. + + Raises: + ValueError: If the app root is not set. + """ + if cls.COMPILE_APP_APP_ROOT is None: + raise ValueError("COMPILE_APP_APP_ROOT should be set") + return compile_app(cls.COMPILE_APP_APP_ROOT) + + @classmethod + def compile_custom_components(cls): + """Compile the custom components. + + Returns: + The path and code of the compiled custom components. + + Raises: + ValueError: If the custom components are not set. + """ + if cls.CUSTOM_COMPONENTS is None: + raise ValueError("CUSTOM_COMPONENTS should be set") + return compile_components(cls.CUSTOM_COMPONENTS) + + @classmethod + def compile_document_root(cls): + """Compile the document root. + + Returns: + The path and code of the compiled document root. + + Raises: + ValueError: If the head components are not set. + """ + if cls.HEAD_COMPONENTS is None: + raise ValueError("HEAD_COMPONENTS should be set") + return compile_document_root(cls.HEAD_COMPONENTS) + + @classmethod + def compile_theme(cls): + """Compile the theme. + + Returns: + The path and code of the compiled theme. + + Raises: + ValueError: If the style is not set. + """ + if cls.STYLE is None: + raise ValueError("STYLE should be set") + return compile_theme(cls.STYLE) + + @classmethod + def compile_contexts(cls): + """Compile the contexts. + + Returns: + The path and code of the compiled contexts. + """ + return compile_contexts(cls.STATE)