Expose app.add_all_routes_endpoint for flexgen
This commit is contained in:
parent
5950336745
commit
e9cf4ce1a5
@ -614,6 +614,8 @@ class App(MiddlewareMixin, LifespanMixin):
|
||||
self.api.get(str(constants.Endpoint.AUTH_CODESPACE))(
|
||||
codespaces.auth_codespace
|
||||
)
|
||||
if environment.REFLEX_ADD_ALL_ROUTES_ENDPOINT.get():
|
||||
self.add_all_routes_endpoint()
|
||||
|
||||
def _add_cors(self):
|
||||
"""Add CORS middleware to the app."""
|
||||
@ -1360,7 +1362,7 @@ class App(MiddlewareMixin, LifespanMixin):
|
||||
for output_path, code in compile_results:
|
||||
compiler_utils.write_page(output_path, code)
|
||||
|
||||
# Pickle dynamic states
|
||||
# Write list of routes that create dynamic states for backend to use.
|
||||
if self._state is not None:
|
||||
stateful_pages_marker = (
|
||||
prerequisites.get_backend_dir() / constants.Dirs.STATEFUL_PAGES
|
||||
@ -1369,6 +1371,15 @@ class App(MiddlewareMixin, LifespanMixin):
|
||||
with stateful_pages_marker.open("w") as f:
|
||||
json.dump(list(self._stateful_pages), f)
|
||||
|
||||
def add_all_routes_endpoint(self):
|
||||
"""Add an endpoint to the app that returns all the routes."""
|
||||
if not self.api:
|
||||
return
|
||||
|
||||
@self.api.get(str(constants.Endpoint.ALL_ROUTES))
|
||||
async def all_routes():
|
||||
return list(self._unevaluated_pages.keys())
|
||||
|
||||
@contextlib.asynccontextmanager
|
||||
async def modify_state(self, token: str) -> AsyncIterator[BaseState]:
|
||||
"""Modify the state out of band.
|
||||
|
@ -712,6 +712,9 @@ class EnvironmentVariables:
|
||||
# Paths to exclude from the hot reload. Takes precedence over include paths. Separated by a colon.
|
||||
REFLEX_HOT_RELOAD_EXCLUDE_PATHS: EnvVar[List[Path]] = env_var([])
|
||||
|
||||
# Used by flexgen to enumerate the pages.
|
||||
REFLEX_ADD_ALL_ROUTES_ENDPOINT: EnvVar[bool] = env_var(False)
|
||||
|
||||
|
||||
environment = EnvironmentVariables()
|
||||
|
||||
|
@ -12,6 +12,7 @@ class Endpoint(Enum):
|
||||
UPLOAD = "_upload"
|
||||
AUTH_CODESPACE = "auth-codespace"
|
||||
HEALTH = "_health"
|
||||
ALL_ROUTES = "_all_routes"
|
||||
|
||||
def __str__(self) -> str:
|
||||
"""Get the string representation of the endpoint.
|
||||
|
Loading…
Reference in New Issue
Block a user