use original mp

This commit is contained in:
Khaleel Al-Adhami 2024-08-27 15:59:25 -07:00
parent 5af03683e9
commit 5adc10691c
3 changed files with 45 additions and 24 deletions

View File

@ -3,12 +3,14 @@
from __future__ import annotations
import asyncio
import concurrent.futures
import contextlib
import copy
import dataclasses
import functools
import inspect
import io
import multiprocessing
import os
import platform
import sys
@ -34,7 +36,6 @@ from fastapi import FastAPI, HTTPException, Request, UploadFile
from fastapi.middleware import cors
from fastapi.responses import StreamingResponse
from fastapi.staticfiles import StaticFiles
from pathos import pools
from rich.progress import MofNCompleteColumn, Progress, TimeElapsedColumn
from socketio import ASGIApp, AsyncNamespace, AsyncServer
from starlette_admin.contrib.sqla.admin import Admin
@ -48,6 +49,7 @@ from reflex.compiler import compiler
from reflex.compiler import utils as compiler_utils
from reflex.compiler.compiler import (
ExecutorSafeFunctions,
compile_theme,
)
from reflex.components.base.app_wrap import AppWrap
from reflex.components.base.error_boundary import ErrorBoundary
@ -934,9 +936,27 @@ class App(MiddlewareMixin, LifespanMixin, Base):
for route in self.uncompiled_pages:
self._compile_page(route)
executor = pools.ProcessPool()
executor = concurrent.futures.ProcessPoolExecutor(
max_workers=int(os.environ.get("REFLEX_COMPILE_PROCESSES", 0)) or None,
mp_context=multiprocessing.get_context("fork"),
)
else:
executor = pools.ThreadPool()
executor = concurrent.futures.ThreadPoolExecutor(
max_workers=int(os.environ.get("REFLEX_COMPILE_THREADS", 0)) or None,
)
for route, component in self.pages.items():
component._add_style_recursive(self.style, self.theme)
ExecutorSafeFunctions.COMPONENTS[route] = component
for route, page in self.uncompiled_pages.items():
if route in self.pages:
continue
ExecutorSafeFunctions.UNCOMPILED_PAGES[route] = page
ExecutorSafeFunctions.STATE = self.state
pages_results = []
@ -945,39 +965,36 @@ class App(MiddlewareMixin, LifespanMixin, Base):
pages_futures = []
def _submit_work(fn, *args, **kwargs):
f = executor.apipe(fn, *args, **kwargs)
f = executor.submit(fn, *args, **kwargs)
# f = executor.apipe(fn, *args, **kwargs)
result_futures.append(f)
# Compile all page components.
for route, page in self.uncompiled_pages.items():
for route in self.uncompiled_pages:
if route in self.pages:
continue
f = executor.apipe(
f = executor.submit(
ExecutorSafeFunctions.compile_uncompiled_page,
route,
page,
self.state,
self.style,
self.theme,
)
pages_futures.append((route, f))
pages_futures.append(f)
# Compile the pre-compiled pages.
for route, component in self.pages.items():
component._add_style_recursive(self.style, self.theme)
for route in self.pages:
_submit_work(
ExecutorSafeFunctions.compile_page,
route,
component,
self.state,
)
# Compile the root stylesheet with base styles.
_submit_work(compiler.compile_root_stylesheet, self.stylesheets)
# Compile the theme.
_submit_work(ExecutorSafeFunctions.compile_theme, self.style)
_submit_work(compile_theme, self.style)
# Compile the Tailwind config.
if config.tailwind is not None:
@ -989,12 +1006,12 @@ class App(MiddlewareMixin, LifespanMixin, Base):
_submit_work(compiler.remove_tailwind_from_postcss)
# Wait for all compilation tasks to complete.
for future in result_futures:
compile_results.append(future.get())
for future in concurrent.futures.as_completed(result_futures):
compile_results.append(future.result())
progress.advance(task)
for _, future in pages_futures:
pages_results.append(future.get())
for future in concurrent.futures.as_completed(pages_futures):
pages_results.append(future.result())
progress.advance(task)
for route, component, compiled_page in pages_results:

View File

@ -587,10 +587,12 @@ class ExecutorSafeFunctions:
"""
COMPONENTS: Dict[str, Component] = {}
UNCOMPILED_PAGES: Dict[str, UncompiledPage] = {}
STATE: Optional[Type[BaseState]] = None
@classmethod
def compile_page(
cls, route: str, component: Component, state: Type[BaseState]
) -> tuple[str, str]:
def compile_page(cls, route: str) -> tuple[str, str]:
"""Compile a page.
Args:
@ -601,13 +603,12 @@ class ExecutorSafeFunctions:
Returns:
The path and code of the compiled page.
"""
return compile_page(route, component, state)
return compile_page(route, cls.COMPONENTS[route], cls.STATE)
@classmethod
def compile_uncompiled_page(
cls,
route: str,
page: UncompiledPage,
state: Type[BaseState],
style: ComponentStyle,
theme: Component,
@ -624,7 +625,7 @@ class ExecutorSafeFunctions:
Returns:
The route, compiled component, and compiled page.
"""
component = compile_uncompiled_page_helper(route, page)
component = compile_uncompiled_page_helper(route, cls.UNCOMPILED_PAGES[route])
component = component if isinstance(component, Component) else component()
component._add_style_recursive(style, theme)
return route, component, compile_page(route, component, state)

View File

@ -267,7 +267,10 @@ class Style(dict):
_var = LiteralVar.create(value)
if _var is not None:
# Carry the imports/hooks when setting a Var as a value.
self._var_data = VarData.merge(self._var_data, _var._get_all_var_data())
self._var_data = VarData.merge(
self._var_data,
_var._get_all_var_data(),
)
super().__setitem__(key, value)