wip: more dynamic jinja contexts, tests for minification

This commit is contained in:
Benedikt Bartscher 2024-07-30 21:17:25 +02:00 committed by Masen Furer
parent 215a8343f4
commit dadfb5663a
No known key found for this signature in database
GPG Key ID: B0008AD22B3B3A95
19 changed files with 371 additions and 175 deletions

View File

@ -117,7 +117,7 @@ export const isStateful = () => {
if (event_queue.length === 0) { if (event_queue.length === 0) {
return false; return false;
} }
return event_queue.some((event) => event.name.startsWith("reflex___state")); return event_queue.some(event => event.name.includes("___"));
}; };
/** /**
@ -810,7 +810,7 @@ export const useEventLoop = (
const vars = {}; const vars = {};
vars[storage_to_state_map[e.key]] = e.newValue; vars[storage_to_state_map[e.key]] = e.newValue;
const event = Event( const event = Event(
`${state_name}.reflex___state____update_vars_internal_state.update_vars_internal`, `${state_name}.{{ update_vars_internal }}`,
{ vars: vars } { vars: vars }
); );
addEvents([event], e); addEvents([event], e);

View File

@ -34,7 +34,7 @@ def _compile_document_root(root: Component) -> str:
Returns: Returns:
The compiled document root. The compiled document root.
""" """
return templates.DOCUMENT_ROOT.render( return templates.document_root().render(
imports=utils.compile_imports(root._get_all_imports()), imports=utils.compile_imports(root._get_all_imports()),
document=root.render(), document=root.render(),
) )
@ -72,7 +72,7 @@ def _compile_app(app_root: Component) -> str:
("utils_state", f"$/{constants.Dirs.UTILS}/state"), ("utils_state", f"$/{constants.Dirs.UTILS}/state"),
] ]
return templates.APP_ROOT.render( return templates.app_root().render(
imports=utils.compile_imports(app_root._get_all_imports()), imports=utils.compile_imports(app_root._get_all_imports()),
custom_codes=app_root._get_all_custom_code(), custom_codes=app_root._get_all_custom_code(),
hooks={**app_root._get_all_hooks_internal(), **app_root._get_all_hooks()}, hooks={**app_root._get_all_hooks_internal(), **app_root._get_all_hooks()},
@ -90,7 +90,7 @@ def _compile_theme(theme: str) -> str:
Returns: Returns:
The compiled theme. The compiled theme.
""" """
return templates.THEME.render(theme=theme) return templates.theme().render(theme=theme)
def _compile_contexts(state: Optional[Type[BaseState]], theme: Component | None) -> str: def _compile_contexts(state: Optional[Type[BaseState]], theme: Component | None) -> str:
@ -109,7 +109,7 @@ def _compile_contexts(state: Optional[Type[BaseState]], theme: Component | None)
last_compiled_time = str(datetime.now()) last_compiled_time = str(datetime.now())
return ( return (
templates.CONTEXT.render( templates.context().render(
initial_state=utils.compile_state(state), initial_state=utils.compile_state(state),
state_name=state.get_name(), state_name=state.get_name(),
client_storage=utils.compile_client_storage(state), client_storage=utils.compile_client_storage(state),
@ -118,7 +118,7 @@ def _compile_contexts(state: Optional[Type[BaseState]], theme: Component | None)
default_color_mode=appearance, default_color_mode=appearance,
) )
if state if state
else templates.CONTEXT.render( else templates.context().render(
is_dev_mode=not is_prod_mode(), is_dev_mode=not is_prod_mode(),
default_color_mode=appearance, default_color_mode=appearance,
last_compiled_time=last_compiled_time, last_compiled_time=last_compiled_time,
@ -145,7 +145,7 @@ def _compile_page(
# Compile the code to render the component. # Compile the code to render the component.
kwargs = {"state_name": state.get_name()} if state is not None else {} kwargs = {"state_name": state.get_name()} if state is not None else {}
return templates.PAGE.render( return templates.page().render(
imports=imports, imports=imports,
dynamic_imports=component._get_all_dynamic_imports(), dynamic_imports=component._get_all_dynamic_imports(),
custom_codes=component._get_all_custom_code(), custom_codes=component._get_all_custom_code(),
@ -201,7 +201,7 @@ def _compile_root_stylesheet(stylesheets: list[str]) -> str:
) )
stylesheet = f"../{constants.Dirs.PUBLIC}/{stylesheet.strip('/')}" stylesheet = f"../{constants.Dirs.PUBLIC}/{stylesheet.strip('/')}"
sheets.append(stylesheet) if stylesheet not in sheets else None sheets.append(stylesheet) if stylesheet not in sheets else None
return templates.STYLE.render(stylesheets=sheets) return templates.style().render(stylesheets=sheets)
def _compile_component(component: Component | StatefulComponent) -> str: def _compile_component(component: Component | StatefulComponent) -> str:
@ -213,7 +213,7 @@ def _compile_component(component: Component | StatefulComponent) -> str:
Returns: Returns:
The compiled component. The compiled component.
""" """
return templates.COMPONENT.render(component=component) return templates.component().render(component=component)
def _compile_components( def _compile_components(
@ -241,7 +241,7 @@ def _compile_components(
# Compile the components page. # Compile the components page.
return ( return (
templates.COMPONENTS.render( templates.components().render(
imports=utils.compile_imports(imports), imports=utils.compile_imports(imports),
components=component_renders, components=component_renders,
), ),
@ -319,7 +319,7 @@ def _compile_stateful_components(
f"$/{constants.Dirs.UTILS}/{constants.PageNames.STATEFUL_COMPONENTS}", None f"$/{constants.Dirs.UTILS}/{constants.PageNames.STATEFUL_COMPONENTS}", None
) )
return templates.STATEFUL_COMPONENTS.render( return templates.stateful_components().render(
imports=utils.compile_imports(all_imports), imports=utils.compile_imports(all_imports),
memoized_code="\n".join(rendered_components), memoized_code="\n".join(rendered_components),
) )
@ -336,7 +336,7 @@ def _compile_tailwind(
Returns: Returns:
The compiled Tailwind config. The compiled Tailwind config.
""" """
return templates.TAILWIND_CONFIG.render( return templates.tailwind_config().render(
**config, **config,
) )

View File

@ -11,6 +11,12 @@ class ReflexJinjaEnvironment(Environment):
def __init__(self) -> None: def __init__(self) -> None:
"""Set default environment.""" """Set default environment."""
from reflex.state import (
FrontendEventExceptionState,
OnLoadInternalState,
UpdateVarsInternalState,
)
extensions = ["jinja2.ext.debug"] extensions = ["jinja2.ext.debug"]
super().__init__( super().__init__(
extensions=extensions, extensions=extensions,
@ -42,9 +48,9 @@ class ReflexJinjaEnvironment(Environment):
"set_color_mode": constants.ColorMode.SET, "set_color_mode": constants.ColorMode.SET,
"use_color_mode": constants.ColorMode.USE, "use_color_mode": constants.ColorMode.USE,
"hydrate": constants.CompileVars.HYDRATE, "hydrate": constants.CompileVars.HYDRATE,
"on_load_internal": constants.CompileVars.ON_LOAD_INTERNAL, "on_load_internal": f"{OnLoadInternalState.get_name()}.on_load_internal",
"update_vars_internal": constants.CompileVars.UPDATE_VARS_INTERNAL, "update_vars_internal": f"{UpdateVarsInternalState.get_name()}.update_vars_internal",
"frontend_exception_state": constants.CompileVars.FRONTEND_EXCEPTION_STATE_FULL, "frontend_exception_state": FrontendEventExceptionState.get_full_name(),
} }
@ -60,61 +66,172 @@ def get_template(name: str) -> Template:
return ReflexJinjaEnvironment().get_template(name=name) return ReflexJinjaEnvironment().get_template(name=name)
# Template for the Reflex config file. def rxconfig():
RXCONFIG = get_template("app/rxconfig.py.jinja2") """Template for the Reflex config file.
# Code to render a NextJS Document root. Returns:
DOCUMENT_ROOT = get_template("web/pages/_document.js.jinja2") Template: The template for the Reflex config file.
"""
return get_template("app/rxconfig.py.jinja2")
# Code to render NextJS App root.
APP_ROOT = get_template("web/pages/_app.js.jinja2")
# Template for the theme file. def document_root():
THEME = get_template("web/utils/theme.js.jinja2") """Code to render a NextJS Document root.
# Template for the context file. Returns:
CONTEXT = get_template("web/utils/context.js.jinja2") Template: The template for the NextJS Document root.
"""
return get_template("web/pages/_document.js.jinja2")
# Template for Tailwind config.
TAILWIND_CONFIG = get_template("web/tailwind.config.js.jinja2")
# Template to render a component tag. def app_root():
COMPONENT = get_template("web/pages/component.js.jinja2") """Code to render NextJS App root.
# Code to render a single NextJS page. Returns:
PAGE = get_template("web/pages/index.js.jinja2") Template: The template for the NextJS App root.
"""
return get_template("web/pages/_app.js.jinja2")
# Code to render the custom components page.
COMPONENTS = get_template("web/pages/custom_component.js.jinja2")
# Code to render Component instances as part of StatefulComponent def theme():
STATEFUL_COMPONENT = get_template("web/pages/stateful_component.js.jinja2") """Template for the theme file.
# Code to render StatefulComponent to an external file to be shared Returns:
STATEFUL_COMPONENTS = get_template("web/pages/stateful_components.js.jinja2") Template: The template for the theme file.
"""
return get_template("web/utils/theme.js.jinja2")
# Sitemap config file.
SITEMAP_CONFIG = "module.exports = {config}".format
# Code to render the root stylesheet. def context():
STYLE = get_template("web/styles/styles.css.jinja2") """Template for the context file.
# Code that generate the package json file Returns:
PACKAGE_JSON = get_template("web/package.json.jinja2") Template: The template for the context file.
"""
return get_template("web/utils/context.js.jinja2")
# Code that generate the pyproject.toml file for custom components.
CUSTOM_COMPONENTS_PYPROJECT_TOML = get_template(
"custom_components/pyproject.toml.jinja2"
)
# Code that generates the README file for custom components. def tailwind_config():
CUSTOM_COMPONENTS_README = get_template("custom_components/README.md.jinja2") """Template for Tailwind config.
# Code that generates the source file for custom components. Returns:
CUSTOM_COMPONENTS_SOURCE = get_template("custom_components/src.py.jinja2") Template: The template for the Tailwind config
"""
return get_template("web/tailwind.config.js.jinja2")
# Code that generates the init file for custom components.
CUSTOM_COMPONENTS_INIT_FILE = get_template("custom_components/__init__.py.jinja2")
# Code that generates the demo app main py file for testing custom components. def component():
CUSTOM_COMPONENTS_DEMO_APP = get_template("custom_components/demo_app.py.jinja2") """Template to render a component tag.
Returns:
Template: The template for the component tag.
"""
return get_template("web/pages/component.js.jinja2")
def page():
"""Code to render a single NextJS page.
Returns:
Template: The template for the NextJS page.
"""
return get_template("web/pages/index.js.jinja2")
def components():
"""Code to render the custom components page.
Returns:
Template: The template for the custom components page.
"""
return get_template("web/pages/custom_component.js.jinja2")
def stateful_component():
"""Code to render Component instances as part of StatefulComponent.
Returns:
Template: The template for the StatefulComponent.
"""
return get_template("web/pages/stateful_component.js.jinja2")
def stateful_components():
"""Code to render StatefulComponent to an external file to be shared.
Returns:
Template: The template for the StatefulComponent.
"""
return get_template("web/pages/stateful_components.js.jinja2")
def sitemap_config():
"""Sitemap config file.
Returns:
Template: The template for the sitemap config file.
"""
return "module.exports = {config}".format
def style():
"""Code to render the root stylesheet.
Returns:
Template: The template for the root stylesheet
"""
return get_template("web/styles/styles.css.jinja2")
def package_json():
"""Code that generate the package json file.
Returns:
Template: The template for the package json file
"""
return get_template("web/package.json.jinja2")
def custom_components_pyproject_toml():
"""Code that generate the pyproject.toml file for custom components.
Returns:
Template: The template for the pyproject.toml file
"""
return get_template("custom_components/pyproject.toml.jinja2")
def custom_components_readme():
"""Code that generates the README file for custom components.
Returns:
Template: The template for the README file
"""
return get_template("custom_components/README.md.jinja2")
def custom_components_source():
"""Code that generates the source file for custom components.
Returns:
Template: The template for the source file
"""
return get_template("custom_components/src.py.jinja2")
def custom_components_init():
"""Code that generates the init file for custom components.
Returns:
Template: The template for the init file
"""
return get_template("custom_components/__init__.py.jinja2")
def custom_components_demo_app():
"""Code that generates the demo app main py file for testing custom components.
Returns:
Template: The template for the demo app main py file
"""
return get_template("custom_components/demo_app.py.jinja2")

View File

@ -24,7 +24,7 @@ from typing import (
import reflex.state import reflex.state
from reflex.base import Base from reflex.base import Base
from reflex.compiler.templates import STATEFUL_COMPONENT from reflex.compiler.templates import stateful_component
from reflex.components.core.breakpoints import Breakpoints from reflex.components.core.breakpoints import Breakpoints
from reflex.components.dynamic import load_dynamic_serializer from reflex.components.dynamic import load_dynamic_serializer
from reflex.components.tags import Tag from reflex.components.tags import Tag
@ -2134,7 +2134,7 @@ class StatefulComponent(BaseComponent):
component.event_triggers[event_trigger] = memo_trigger component.event_triggers[event_trigger] = memo_trigger
# Render the code for this component and hooks. # Render the code for this component and hooks.
return STATEFUL_COMPONENT.render( return stateful_component().render(
tag_name=tag_name, tag_name=tag_name,
memo_trigger_hooks=memo_trigger_hooks, memo_trigger_hooks=memo_trigger_hooks,
component=component, component=component,

View File

@ -80,7 +80,7 @@ def load_dynamic_serializer():
) )
rendered_components[ rendered_components[
templates.STATEFUL_COMPONENT.render( templates.stateful_component().render(
tag_name="MySSRComponent", tag_name="MySSRComponent",
memo_trigger_hooks=[], memo_trigger_hooks=[],
component=component, component=component,
@ -101,10 +101,14 @@ def load_dynamic_serializer():
else: else:
imports[lib] = names imports[lib] = names
module_code_lines = templates.STATEFUL_COMPONENTS.render( module_code_lines = (
templates.stateful_components()
.render(
imports=utils.compile_imports(imports), imports=utils.compile_imports(imports),
memoized_code="\n".join(rendered_components), memoized_code="\n".join(rendered_components),
).splitlines()[1:] )
.splitlines()[1:]
)
# Rewrite imports from `/` to destructure from window # Rewrite imports from `/` to destructure from window
for ix, line in enumerate(module_code_lines[:]): for ix, line in enumerate(module_code_lines[:]):

View File

@ -545,6 +545,9 @@ class EnvironmentVariables:
# Where to save screenshots when tests fail. # Where to save screenshots when tests fail.
SCREENSHOT_DIR: EnvVar[Optional[Path]] = env_var(None) SCREENSHOT_DIR: EnvVar[Optional[Path]] = env_var(None)
# Whether to minify state names.
REFLEX_MINIFY_STATES: EnvVar[Optional[bool]] = env_var(False)
environment = EnvironmentVariables() environment = EnvironmentVariables()

View File

@ -6,7 +6,7 @@ from enum import Enum
from types import SimpleNamespace from types import SimpleNamespace
from reflex.base import Base from reflex.base import Base
from reflex.constants import ENV_MODE_ENV_VAR, Dirs, Env from reflex.constants import Dirs, Env
from reflex.utils.imports import ImportVar from reflex.utils.imports import ImportVar
# The prefix used to create setters for state vars. # The prefix used to create setters for state vars.
@ -40,12 +40,14 @@ def minify_states() -> bool:
Returns: Returns:
True if states should be minified. True if states should be minified.
""" """
env = os.environ.get(ENV_MINIFY_STATES, None) from reflex.config import environment
env = environment.REFLEX_MINIFY_STATES.get()
if env is not None: if env is not None:
return env.lower() == "true" return env
# minify states in prod by default # minify states in prod by default
return os.environ.get(ENV_MODE_ENV_VAR, "") == Env.PROD.value return environment.REFLEX_ENV_MODE.get() == Env.PROD
class CompileVars(SimpleNamespace): class CompileVars(SimpleNamespace):
@ -80,34 +82,14 @@ class CompileVars(SimpleNamespace):
# The name of the function for converting a dict to an event. # The name of the function for converting a dict to an event.
TO_EVENT = "Event" TO_EVENT = "Event"
# Whether to minify states. @classmethod
MINIFY_STATES = minify_states() def MINIFY_STATES(cls) -> bool:
"""Whether to minify states.
# The name of the OnLoadInternal state. Returns:
ON_LOAD_INTERNAL_STATE = ( True if states should be minified.
"l" if MINIFY_STATES else "reflex___state____on_load_internal_state" """
) return minify_states()
# The name of the internal on_load event.
ON_LOAD_INTERNAL = f"{ON_LOAD_INTERNAL_STATE}.on_load_internal"
# The name of the UpdateVarsInternal state.
UPDATE_VARS_INTERNAL_STATE = (
"u" if MINIFY_STATES else "reflex___state____update_vars_internal_state"
)
# The name of the internal event to update generic state vars.
UPDATE_VARS_INTERNAL = f"{UPDATE_VARS_INTERNAL_STATE}.update_vars_internal"
# The name of the frontend event exception state
FRONTEND_EXCEPTION_STATE = (
"e" if MINIFY_STATES else "reflex___state____frontend_event_exception_state"
)
# The full name of the frontend exception state
FRONTEND_EXCEPTION_STATE_FULL = (
f"reflex___state____state.{FRONTEND_EXCEPTION_STATE}"
)
INTERNAL_STATE_NAMES = {
ON_LOAD_INTERNAL_STATE,
UPDATE_VARS_INTERNAL_STATE,
FRONTEND_EXCEPTION_STATE,
}
class PageNames(SimpleNamespace): class PageNames(SimpleNamespace):

View File

@ -65,7 +65,7 @@ def _create_package_config(module_name: str, package_name: str):
pyproject = Path(CustomComponents.PYPROJECT_TOML) pyproject = Path(CustomComponents.PYPROJECT_TOML)
pyproject.write_text( pyproject.write_text(
templates.CUSTOM_COMPONENTS_PYPROJECT_TOML.render( templates.custom_components_pyproject_toml().render(
module_name=module_name, module_name=module_name,
package_name=package_name, package_name=package_name,
reflex_version=constants.Reflex.VERSION, reflex_version=constants.Reflex.VERSION,
@ -106,7 +106,7 @@ def _create_readme(module_name: str, package_name: str):
readme = Path(CustomComponents.PACKAGE_README) readme = Path(CustomComponents.PACKAGE_README)
readme.write_text( readme.write_text(
templates.CUSTOM_COMPONENTS_README.render( templates.custom_components_readme().render(
module_name=module_name, module_name=module_name,
package_name=package_name, package_name=package_name,
) )
@ -129,14 +129,14 @@ def _write_source_and_init_py(
module_path = custom_component_src_dir / f"{module_name}.py" module_path = custom_component_src_dir / f"{module_name}.py"
module_path.write_text( module_path.write_text(
templates.CUSTOM_COMPONENTS_SOURCE.render( templates.custom_components_source().render(
component_class_name=component_class_name, module_name=module_name component_class_name=component_class_name, module_name=module_name
) )
) )
init_path = custom_component_src_dir / CustomComponents.INIT_FILE init_path = custom_component_src_dir / CustomComponents.INIT_FILE
init_path.write_text( init_path.write_text(
templates.CUSTOM_COMPONENTS_INIT_FILE.render(module_name=module_name) templates.custom_components_init.render(module_name=module_name)
) )
@ -164,7 +164,7 @@ def _populate_demo_app(name_variants: NameVariants):
# This source file is rendered using jinja template file. # This source file is rendered using jinja template file.
with open(f"{demo_app_name}/{demo_app_name}.py", "w") as f: with open(f"{demo_app_name}/{demo_app_name}.py", "w") as f:
f.write( f.write(
templates.CUSTOM_COMPONENTS_DEMO_APP.render( templates.custom_components_demo_app().render(
custom_component_module_dir=name_variants.custom_component_module_dir, custom_component_module_dir=name_variants.custom_component_module_dir,
module_name=name_variants.module_name, module_name=name_variants.module_name,
) )

View File

@ -288,7 +288,7 @@ def get_var_for_field(cls: Type[BaseState], f: ModelField):
# Keep track of all state instances to calculate minified state names # Keep track of all state instances to calculate minified state names
state_count: int = 0 state_count: int = 0
all_state_names: Set[str] = set() minified_state_names: Dict[str, str] = {}
def next_minified_state_name() -> str: def next_minified_state_name() -> str:
@ -296,12 +296,8 @@ def next_minified_state_name() -> str:
Returns: Returns:
The next minified state name. The next minified state name.
Raises:
RuntimeError: If the minified state name already exists.
""" """
global state_count global state_count
global all_state_names
num = state_count num = state_count
# All possible chars for minified state name # All possible chars for minified state name
@ -318,25 +314,28 @@ def next_minified_state_name() -> str:
state_count += 1 state_count += 1
if state_name in all_state_names:
raise RuntimeError(f"Minified state name {state_name} already exists")
all_state_names.add(state_name)
return state_name return state_name
def generate_state_name() -> str: def get_minified_state_name(state_name: str) -> str:
"""Generate a minified state name. """Generate a minified state name.
Args:
state_name: The state name to minify.
Returns: Returns:
The minified state name. The minified state name.
Raises: Raises:
ValueError: If no more minified state names are available ValueError: If no more minified state names are available
""" """
if state_name in minified_state_names:
return minified_state_names[state_name]
while name := next_minified_state_name(): while name := next_minified_state_name():
if name in constants.CompileVars.INTERNAL_STATE_NAMES: if name in minified_state_names.values():
continue continue
minified_state_names[state_name] = name
return name return name
raise ValueError("No more minified state names available") raise ValueError("No more minified state names available")
@ -410,9 +409,6 @@ class BaseState(Base, ABC, extra=pydantic.Extra.allow):
# A special event handler for setting base vars. # A special event handler for setting base vars.
setvar: ClassVar[EventHandler] setvar: ClassVar[EventHandler]
# Minified state name
_state_name: ClassVar[Optional[str]] = None
def __init__( def __init__(
self, self,
parent_state: BaseState | None = None, parent_state: BaseState | None = None,
@ -518,10 +514,6 @@ class BaseState(Base, ABC, extra=pydantic.Extra.allow):
if "<locals>" in cls.__qualname__: if "<locals>" in cls.__qualname__:
cls._handle_local_def() cls._handle_local_def()
# Generate a minified state name by converting state count to string
if not cls._state_name or cls._state_name in all_state_names:
cls._state_name = generate_state_name()
# Validate the module name. # Validate the module name.
cls._validate_module_name() cls._validate_module_name()
@ -937,18 +929,12 @@ class BaseState(Base, ABC, extra=pydantic.Extra.allow):
Returns: Returns:
The name of the state. The name of the state.
Raises:
RuntimeError: If the state name is not set.
""" """
if constants.CompileVars.MINIFY_STATES:
if not cls._state_name:
raise RuntimeError(
"State name minification is enabled, but state name is not set."
)
return cls._state_name
module = cls.__module__.replace(".", "___") module = cls.__module__.replace(".", "___")
return format.to_snake_case(f"{module}___{cls.__name__}") state_name = format.to_snake_case(f"{module}___{cls.__name__}")
if constants.compiler.CompileVars.MINIFY_STATES():
return get_minified_state_name(state_name)
return state_name
@classmethod @classmethod
@functools.lru_cache() @functools.lru_cache()
@ -2290,10 +2276,6 @@ def dynamic(func: Callable[[T], Component]):
class FrontendEventExceptionState(State): class FrontendEventExceptionState(State):
"""Substate for handling frontend exceptions.""" """Substate for handling frontend exceptions."""
_state_name: ClassVar[Optional[str]] = (
constants.CompileVars.FRONTEND_EXCEPTION_STATE
)
@event @event
def handle_frontend_exception(self, stack: str, component_stack: str) -> None: def handle_frontend_exception(self, stack: str, component_stack: str) -> None:
"""Handle frontend exceptions. """Handle frontend exceptions.
@ -2313,10 +2295,6 @@ class FrontendEventExceptionState(State):
class UpdateVarsInternalState(State): class UpdateVarsInternalState(State):
"""Substate for handling internal state var updates.""" """Substate for handling internal state var updates."""
_state_name: ClassVar[Optional[str]] = (
constants.CompileVars.UPDATE_VARS_INTERNAL_STATE
)
async def update_vars_internal(self, vars: dict[str, Any]) -> None: async def update_vars_internal(self, vars: dict[str, Any]) -> None:
"""Apply updates to fully qualified state vars. """Apply updates to fully qualified state vars.
@ -2342,8 +2320,6 @@ class OnLoadInternalState(State):
This is a separate substate to avoid deserializing the entire state tree for every page navigation. This is a separate substate to avoid deserializing the entire state tree for every page navigation.
""" """
_state_name: ClassVar[Optional[str]] = constants.CompileVars.ON_LOAD_INTERNAL_STATE
def on_load_internal(self) -> list[Event | EventSpec] | None: def on_load_internal(self) -> list[Event | EventSpec] | None:
"""Queue on_load handlers for the current page. """Queue on_load handlers for the current page.

View File

@ -46,12 +46,14 @@ import reflex.utils.processes
from reflex.config import environment from reflex.config import environment
from reflex.state import ( from reflex.state import (
BaseState, BaseState,
State,
StateManager, StateManager,
StateManagerDisk, StateManagerDisk,
StateManagerMemory, StateManagerMemory,
StateManagerRedis, StateManagerRedis,
reload_state_module, reload_state_module,
) )
from reflex.utils.types import override
try: try:
from selenium import webdriver # pyright: ignore [reportMissingImports] from selenium import webdriver # pyright: ignore [reportMissingImports]
@ -141,7 +143,7 @@ class AppHarness:
types.FunctionType | types.ModuleType | str | functools.partial[Any] types.FunctionType | types.ModuleType | str | functools.partial[Any]
] = None, ] = None,
app_name: Optional[str] = None, app_name: Optional[str] = None,
) -> "AppHarness": ) -> AppHarness:
"""Create an AppHarness instance at root. """Create an AppHarness instance at root.
Args: Args:
@ -191,7 +193,14 @@ class AppHarness:
Returns: Returns:
The state name The state name
Raises:
NotImplementedError: when minified state names are enabled
""" """
if reflex.constants.CompileVars.MINIFY_STATES():
raise NotImplementedError(
"This API is not available with minified state names."
)
return reflex.utils.format.to_snake_case( return reflex.utils.format.to_snake_case(
f"{self.app_name}___{self.app_name}___" + state_cls_name f"{self.app_name}___{self.app_name}___" + state_cls_name
) )
@ -207,7 +216,7 @@ class AppHarness:
""" """
# NOTE: using State.get_name() somehow causes trouble here # NOTE: using State.get_name() somehow causes trouble here
# path = [State.get_name()] + [self.get_state_name(p) for p in path] # path = [State.get_name()] + [self.get_state_name(p) for p in path]
path = ["reflex___state____state"] + [self.get_state_name(p) for p in path] path = [State.get_name()] + [self.get_state_name(p) for p in path]
return ".".join(path) return ".".join(path)
def _get_globals_from_signature(self, func: Any) -> dict[str, Any]: def _get_globals_from_signature(self, func: Any) -> dict[str, Any]:
@ -412,7 +421,7 @@ class AppHarness:
self.frontend_output_thread = threading.Thread(target=consume_frontend_output) self.frontend_output_thread = threading.Thread(target=consume_frontend_output)
self.frontend_output_thread.start() self.frontend_output_thread.start()
def start(self) -> "AppHarness": def start(self) -> AppHarness:
"""Start the backend in a new thread and dev frontend as a separate process. """Start the backend in a new thread and dev frontend as a separate process.
Returns: Returns:
@ -442,7 +451,7 @@ class AppHarness:
return f"{key} = {value!r}" return f"{key} = {value!r}"
return inspect.getsource(value) return inspect.getsource(value)
def __enter__(self) -> "AppHarness": def __enter__(self) -> AppHarness:
"""Contextmanager protocol for `start()`. """Contextmanager protocol for `start()`.
Returns: Returns:
@ -921,6 +930,7 @@ class AppHarnessProd(AppHarness):
) )
self.frontend_server.serve_forever() self.frontend_server.serve_forever()
@override
def _start_frontend(self): def _start_frontend(self):
# Set up the frontend. # Set up the frontend.
with chdir(self.app_path): with chdir(self.app_path):
@ -932,17 +942,19 @@ class AppHarnessProd(AppHarness):
zipping=False, zipping=False,
frontend=True, frontend=True,
backend=False, backend=False,
loglevel=reflex.constants.LogLevel.INFO, loglevel=reflex.constants.base.LogLevel.INFO,
) )
self.frontend_thread = threading.Thread(target=self._run_frontend) self.frontend_thread = threading.Thread(target=self._run_frontend)
self.frontend_thread.start() self.frontend_thread.start()
@override
def _wait_frontend(self): def _wait_frontend(self):
self._poll_for(lambda: self.frontend_server is not None) _ = self._poll_for(lambda: self.frontend_server is not None)
if self.frontend_server is None or not self.frontend_server.socket.fileno(): if self.frontend_server is None or not self.frontend_server.socket.fileno():
raise RuntimeError("Frontend did not start") raise RuntimeError("Frontend did not start")
@override
def _start_backend(self): def _start_backend(self):
if self.app_instance is None: if self.app_instance is None:
raise RuntimeError("App was not initialized.") raise RuntimeError("App was not initialized.")
@ -959,12 +971,25 @@ class AppHarnessProd(AppHarness):
self.backend_thread = threading.Thread(target=self.backend.run) self.backend_thread = threading.Thread(target=self.backend.run)
self.backend_thread.start() self.backend_thread.start()
@override
def _poll_for_servers(self, timeout: TimeoutType = None) -> socket.socket: def _poll_for_servers(self, timeout: TimeoutType = None) -> socket.socket:
try: try:
return super()._poll_for_servers(timeout) return super()._poll_for_servers(timeout)
finally: finally:
environment.REFLEX_SKIP_COMPILE.set(None) environment.REFLEX_SKIP_COMPILE.set(None)
@override
def start(self) -> AppHarnessProd:
"""Start AppHarnessProd instance.
Returns:
self
"""
environment.REFLEX_ENV_MODE.set(reflex.constants.base.Env.PROD)
_ = super().start()
return self
@override
def stop(self): def stop(self):
"""Stop the frontend python webserver.""" """Stop the frontend python webserver."""
super().stop() super().stop()
@ -972,3 +997,4 @@ class AppHarnessProd(AppHarness):
self.frontend_server.shutdown() self.frontend_server.shutdown()
if self.frontend_thread is not None: if self.frontend_thread is not None:
self.frontend_thread.join() self.frontend_thread.join()
environment.REFLEX_ENV_MODE.set(None)

View File

@ -44,7 +44,7 @@ def generate_sitemap_config(deploy_url: str, export=False):
config = json.dumps(config) config = json.dumps(config)
sitemap = prerequisites.get_web_dir() / constants.Next.SITEMAP_CONFIG_FILE sitemap = prerequisites.get_web_dir() / constants.Next.SITEMAP_CONFIG_FILE
sitemap.write_text(templates.SITEMAP_CONFIG(config=config)) sitemap.write_text(templates.sitemap_config()(config=config))
def _zip( def _zip(

View File

@ -436,7 +436,7 @@ def create_config(app_name: str):
config_name = f"{re.sub(r'[^a-zA-Z]', '', app_name).capitalize()}Config" config_name = f"{re.sub(r'[^a-zA-Z]', '', app_name).capitalize()}Config"
with open(constants.Config.FILE, "w") as f: with open(constants.Config.FILE, "w") as f:
console.debug(f"Creating {constants.Config.FILE}") console.debug(f"Creating {constants.Config.FILE}")
f.write(templates.RXCONFIG.render(app_name=app_name, config_name=config_name)) f.write(templates.rxconfig().render(app_name=app_name, config_name=config_name))
def initialize_gitignore( def initialize_gitignore(
@ -604,7 +604,7 @@ def initialize_web_directory():
def _compile_package_json(): def _compile_package_json():
return templates.PACKAGE_JSON.render( return templates.package_json().render(
scripts={ scripts={
"dev": constants.PackageJson.Commands.DEV, "dev": constants.PackageJson.Commands.DEV,
"export": constants.PackageJson.Commands.EXPORT, "export": constants.PackageJson.Commands.EXPORT,

View File

@ -3,9 +3,11 @@
import os import os
import re import re
from pathlib import Path from pathlib import Path
from typing import Generator, Type
import pytest import pytest
import reflex.constants
from reflex.config import environment from reflex.config import environment
from reflex.testing import AppHarness, AppHarnessProd from reflex.testing import AppHarness, AppHarnessProd
@ -64,15 +66,30 @@ def pytest_exception_interact(node, call, report):
@pytest.fixture( @pytest.fixture(
scope="session", params=[AppHarness, AppHarnessProd], ids=["dev", "prod"] scope="session",
params=[
AppHarness,
AppHarnessProd,
],
ids=[
reflex.constants.Env.DEV.value,
reflex.constants.Env.PROD.value,
],
) )
def app_harness_env(request): def app_harness_env(
request: pytest.FixtureRequest,
) -> Generator[Type[AppHarness], None, None]:
"""Parametrize the AppHarness class to use for the test, either dev or prod. """Parametrize the AppHarness class to use for the test, either dev or prod.
Args: Args:
request: The pytest fixture request object. request: The pytest fixture request object.
Returns: Yields:
The AppHarness class to use for the test. The AppHarness class to use for the test.
""" """
return request.param harness: Type[AppHarness] = request.param
if issubclass(harness, AppHarnessProd):
environment.REFLEX_ENV_MODE.set(reflex.constants.Env.PROD)
yield harness
if issubclass(harness, AppHarnessProd):
environment.REFLEX_ENV_MODE.set(None)

View File

@ -106,7 +106,6 @@ def ComputedVars():
), ),
) )
# raise Exception(State.count3._deps(objclass=State))
app = rx.App() app = rx.App()
app.add_page(index) app.add_page(index)

View File

@ -2,20 +2,24 @@
from __future__ import annotations from __future__ import annotations
import time import os
from typing import Generator, Type from functools import partial
from typing import Generator, Optional, Type
import pytest import pytest
from selenium.webdriver.common.by import By from selenium.webdriver.common.by import By
from selenium.webdriver.remote.webdriver import WebDriver from selenium.webdriver.remote.webdriver import WebDriver
from selenium.webdriver.support import expected_conditions as EC
from selenium.webdriver.support.ui import WebDriverWait
from reflex.testing import AppHarness from reflex.constants.compiler import ENV_MINIFY_STATES
from reflex.testing import AppHarness, AppHarnessProd
def TestApp(): def TestApp(minify: bool | None) -> None:
"""A test app for minified state names.""" """A test app for minified state names.
Args:
minify: whether to minify state names
"""
import reflex as rx import reflex as rx
class TestAppState(rx.State): class TestAppState(rx.State):
@ -25,7 +29,6 @@ def TestApp():
app = rx.App() app = rx.App()
@app.add_page
def index(): def index():
return rx.vstack( return rx.vstack(
rx.input( rx.input(
@ -33,27 +36,64 @@ def TestApp():
is_read_only=True, is_read_only=True,
id="token", id="token",
), ),
rx.text(f"minify: {minify}", id="minify"),
rx.text(TestAppState.get_name(), id="state_name"),
rx.text(TestAppState.get_full_name(), id="state_full_name"),
) )
app.add_page(index)
@pytest.fixture(scope="module")
@pytest.fixture(
params=[
pytest.param(False),
pytest.param(True),
pytest.param(None),
],
)
def minify_state_env(
request: pytest.FixtureRequest,
) -> Generator[Optional[bool], None, None]:
"""Set the environment variable to minify state names.
Args:
request: pytest fixture request
Yields:
minify_states: whether to minify state names
"""
minify_states: Optional[bool] = request.param
if minify_states is None:
_ = os.environ.pop(ENV_MINIFY_STATES, None)
else:
os.environ[ENV_MINIFY_STATES] = str(minify_states).lower()
yield minify_states
if minify_states is not None:
os.environ.pop(ENV_MINIFY_STATES, None)
@pytest.fixture
def test_app( def test_app(
app_harness_env: Type[AppHarness], tmp_path_factory: pytest.TempPathFactory app_harness_env: Type[AppHarness],
tmp_path_factory: pytest.TempPathFactory,
minify_state_env: Optional[bool],
) -> Generator[AppHarness, None, None]: ) -> Generator[AppHarness, None, None]:
"""Start TestApp app at tmp_path via AppHarness. """Start TestApp app at tmp_path via AppHarness.
Args: Args:
app_harness_env: either AppHarness (dev) or AppHarnessProd (prod) app_harness_env: either AppHarness (dev) or AppHarnessProd (prod)
tmp_path_factory: pytest tmp_path_factory fixture tmp_path_factory: pytest tmp_path_factory fixture
minify_state_env: need to request this fixture to set env before the app starts
Yields: Yields:
running AppHarness instance running AppHarness instance
""" """
name = f"testapp_{app_harness_env.__name__.lower()}"
with app_harness_env.create( with app_harness_env.create(
root=tmp_path_factory.mktemp("test_app"), root=tmp_path_factory.mktemp(name),
app_name=f"testapp_{app_harness_env.__name__.lower()}", app_name=name,
app_source=TestApp, # type: ignore app_source=partial(TestApp, minify=minify_state_env), # pyright: ignore[reportArgumentType]
) as harness: ) as harness:
yield harness yield harness
@ -80,16 +120,33 @@ def driver(test_app: AppHarness) -> Generator[WebDriver, None, None]:
def test_minified_states( def test_minified_states(
test_app: AppHarness, test_app: AppHarness,
driver: WebDriver, driver: WebDriver,
minify_state_env: Optional[bool],
) -> None: ) -> None:
"""Test minified state names. """Test minified state names.
Args: Args:
test_app: harness for TestApp test_app: harness for TestApp
driver: WebDriver instance. driver: WebDriver instance.
minify_state_env: whether state minification is enabled by env var.
""" """
assert test_app.app_instance is not None, "app is not running" assert test_app.app_instance is not None, "app is not running"
is_prod = isinstance(test_app, AppHarnessProd)
# default to minifying in production
should_minify: bool = is_prod
# env overrides default
if minify_state_env is not None:
should_minify = minify_state_env
# TODO: reload internal states, or refactor VarData to reference state object instead of name
if should_minify:
pytest.skip(
"minify tests are currently not working, because _var_set_states writes the state names during import time"
)
# get a reference to the connected client # get a reference to the connected client
token_input = driver.find_element(By.ID, "token") token_input = driver.find_element(By.ID, "token")
assert token_input assert token_input
@ -97,3 +154,20 @@ def test_minified_states(
# wait for the backend connection to send the token # wait for the backend connection to send the token
token = test_app.poll_for_value(token_input) token = test_app.poll_for_value(token_input)
assert token assert token
state_name_text = driver.find_element(By.ID, "state_name")
assert state_name_text
state_name = state_name_text.text
state_full_name_text = driver.find_element(By.ID, "state_full_name")
assert state_full_name_text
_ = state_full_name_text.text
assert test_app.app_module
module_state_prefix = test_app.app_module.__name__.replace(".", "___")
# prod_module_suffix = "prod" if is_prod else ""
if should_minify:
assert len(state_name) == 1
else:
assert state_name == f"{module_state_prefix}____test_app_state"

View File

@ -1,17 +1,14 @@
from typing import Set from typing import Set
from reflex.state import all_state_names, next_minified_state_name from reflex.state import next_minified_state_name
def test_next_minified_state_name(): def test_next_minified_state_name():
"""Test that the next_minified_state_name function returns unique state names.""" """Test that the next_minified_state_name function returns unique state names."""
current_state_count = len(all_state_names)
state_names: Set[str] = set() state_names: Set[str] = set()
gen: int = 10000 gen = 10000
for _ in range(gen): for _ in range(gen):
state_name = next_minified_state_name() state_name = next_minified_state_name()
assert state_name not in state_names assert state_name not in state_names
state_names.add(state_name) state_names.add(state_name)
assert len(state_names) == gen assert len(state_names) == gen
assert len(all_state_names) == current_state_count + gen

View File

@ -1032,7 +1032,7 @@ async def test_dynamic_route_var_route_change_completed_on_load(
prev_exp_val = "" prev_exp_val = ""
for exp_index, exp_val in enumerate(exp_vals): for exp_index, exp_val in enumerate(exp_vals):
on_load_internal = _event( on_load_internal = _event(
name=f"{state.get_full_name()}.{constants.CompileVars.ON_LOAD_INTERNAL.rpartition('.')[2]}", name=f"{state.get_full_name()}.on_load_internal",
val=exp_val, val=exp_val,
) )
exp_router_data = { exp_router_data = {

View File

@ -54,6 +54,7 @@ CI = bool(os.environ.get("CI", False))
LOCK_EXPIRATION = 2000 if CI else 300 LOCK_EXPIRATION = 2000 if CI else 300
LOCK_EXPIRE_SLEEP = 2.5 if CI else 0.4 LOCK_EXPIRE_SLEEP = 2.5 if CI else 0.4
ON_LOAD_INTERNAL = f"{OnLoadInternalState.get_name()}.on_load_internal"
formatted_router = { formatted_router = {
"session": {"client_token": "", "client_ip": "", "session_id": ""}, "session": {"client_token": "", "client_ip": "", "session_id": ""},
@ -2793,7 +2794,7 @@ async def test_preprocess(app_module_mock, token, test_state, expected, mocker):
app=app, app=app,
event=Event( event=Event(
token=token, token=token,
name=f"{state.get_name()}.{CompileVars.ON_LOAD_INTERNAL}", name=f"{state.get_name()}.{ON_LOAD_INTERNAL}",
router_data={RouteVar.PATH: "/", RouteVar.ORIGIN: "/", RouteVar.QUERY: {}}, router_data={RouteVar.PATH: "/", RouteVar.ORIGIN: "/", RouteVar.QUERY: {}},
), ),
sid="sid", sid="sid",
@ -2840,7 +2841,7 @@ async def test_preprocess_multiple_load_events(app_module_mock, token, mocker):
app=app, app=app,
event=Event( event=Event(
token=token, token=token,
name=f"{state.get_full_name()}.{CompileVars.ON_LOAD_INTERNAL}", name=f"{state.get_full_name()}.{ON_LOAD_INTERNAL}",
router_data={RouteVar.PATH: "/", RouteVar.ORIGIN: "/", RouteVar.QUERY: {}}, router_data={RouteVar.PATH: "/", RouteVar.ORIGIN: "/", RouteVar.QUERY: {}},
), ),
sid="sid", sid="sid",

View File

@ -275,7 +275,7 @@ def test_unsupported_literals(cls: type):
], ],
) )
def test_create_config(app_name, expected_config_name, mocker): def test_create_config(app_name, expected_config_name, mocker):
"""Test templates.RXCONFIG is formatted with correct app name and config class name. """Test templates.rxconfig is formatted with correct app name and config class name.
Args: Args:
app_name: App name. app_name: App name.
@ -283,9 +283,9 @@ def test_create_config(app_name, expected_config_name, mocker):
mocker: Mocker object. mocker: Mocker object.
""" """
mocker.patch("builtins.open") mocker.patch("builtins.open")
tmpl_mock = mocker.patch("reflex.compiler.templates.RXCONFIG") tmpl_mock = mocker.patch("reflex.compiler.templates.rxconfig")
prerequisites.create_config(app_name) prerequisites.create_config(app_name)
tmpl_mock.render.assert_called_with( tmpl_mock().render.assert_called_with(
app_name=app_name, config_name=expected_config_name app_name=app_name, config_name=expected_config_name
) )