Track state usage (#2441)
* rebase * pass include_children kwarg in radix FormRoot * respect include_children * ruff fixes * readd statemanager init, run pyi gen * minor performance imporovements, fix for state changes * fix pyi and pyright * pass include_children for chakra * remove old state detection * add test for unused states in stateless app --------- Co-authored-by: Masen Furer <m_github@0x26.net>
This commit is contained in:
parent
5d647a498f
commit
19a5cdd408
@ -28,6 +28,9 @@ def TailwindApp(
|
||||
import reflex as rx
|
||||
import reflex.components.radix.themes as rdxt
|
||||
|
||||
class UnusedState(rx.State):
|
||||
pass
|
||||
|
||||
def index():
|
||||
return rx.el.div(
|
||||
rx.chakra.text(paragraph_text, class_name=paragraph_class_name),
|
||||
|
152
reflex/app.py
152
reflex/app.py
@ -91,6 +91,12 @@ def default_overlay_component() -> Component:
|
||||
return Fragment.create(connection_pulser(), connection_modal())
|
||||
|
||||
|
||||
class OverlayFragment(Fragment):
|
||||
"""Alias for Fragment, used to wrap the overlay_component."""
|
||||
|
||||
pass
|
||||
|
||||
|
||||
class App(Base):
|
||||
"""A Reflex application."""
|
||||
|
||||
@ -159,7 +165,7 @@ class App(Base):
|
||||
|
||||
Raises:
|
||||
ValueError: If the event namespace is not provided in the config.
|
||||
Also, if there are multiple client subclasses of rx.State(Subclasses of rx.State should consist
|
||||
Also, if there are multiple client subclasses of rx.BaseState(Subclasses of rx.BaseState should consist
|
||||
of the DefaultState and the client app state).
|
||||
"""
|
||||
if "connect_error_component" in kwargs:
|
||||
@ -167,12 +173,12 @@ class App(Base):
|
||||
"`connect_error_component` is deprecated, use `overlay_component` instead"
|
||||
)
|
||||
super().__init__(*args, **kwargs)
|
||||
state_subclasses = BaseState.__subclasses__()
|
||||
base_state_subclasses = BaseState.__subclasses__()
|
||||
|
||||
# Special case to allow test cases have multiple subclasses of rx.BaseState.
|
||||
if not is_testing_env():
|
||||
# Only one Base State class is allowed.
|
||||
if len(state_subclasses) > 1:
|
||||
if len(base_state_subclasses) > 1:
|
||||
raise ValueError(
|
||||
"rx.BaseState cannot be subclassed multiple times. use rx.State instead"
|
||||
)
|
||||
@ -184,12 +190,6 @@ class App(Base):
|
||||
deprecation_version="0.3.5",
|
||||
removal_version="0.5.0",
|
||||
)
|
||||
# 2 substates are built-in and not considered when determining if app is stateless.
|
||||
if len(State.class_subclasses) > 2:
|
||||
self.state = State
|
||||
# Get the config
|
||||
config = get_config()
|
||||
|
||||
# Add middleware.
|
||||
self.middleware.append(HydrateMiddleware())
|
||||
|
||||
@ -198,45 +198,60 @@ class App(Base):
|
||||
self.add_cors()
|
||||
self.add_default_endpoints()
|
||||
|
||||
if self.state:
|
||||
# Set up the state manager.
|
||||
self._state_manager = StateManager.create(state=self.state)
|
||||
|
||||
# Set up the Socket.IO AsyncServer.
|
||||
self.sio = AsyncServer(
|
||||
async_mode="asgi",
|
||||
cors_allowed_origins=(
|
||||
"*"
|
||||
if config.cors_allowed_origins == ["*"]
|
||||
else config.cors_allowed_origins
|
||||
),
|
||||
cors_credentials=True,
|
||||
max_http_buffer_size=constants.POLLING_MAX_HTTP_BUFFER_SIZE,
|
||||
ping_interval=constants.Ping.INTERVAL,
|
||||
ping_timeout=constants.Ping.TIMEOUT,
|
||||
)
|
||||
|
||||
# Create the socket app. Note event endpoint constant replaces the default 'socket.io' path.
|
||||
self.socket_app = ASGIApp(self.sio, socketio_path="")
|
||||
namespace = config.get_event_namespace()
|
||||
|
||||
if not namespace:
|
||||
raise ValueError("event namespace must be provided in the config.")
|
||||
|
||||
# Create the event namespace and attach the main app. Not related to any paths.
|
||||
self.event_namespace = EventNamespace(namespace, self)
|
||||
|
||||
# Register the event namespace with the socket.
|
||||
self.sio.register_namespace(self.event_namespace)
|
||||
# Mount the socket app with the API.
|
||||
self.api.mount(str(constants.Endpoint.EVENT), self.socket_app)
|
||||
self.setup_state()
|
||||
|
||||
# Set up the admin dash.
|
||||
self.setup_admin_dash()
|
||||
|
||||
# If a State is not used and no overlay_component is specified, do not render the connection modal
|
||||
if self.state is None and self.overlay_component is default_overlay_component:
|
||||
self.overlay_component = None
|
||||
def enable_state(self) -> None:
|
||||
"""Enable state for the app."""
|
||||
if not self.state:
|
||||
self.state = State
|
||||
self.setup_state()
|
||||
|
||||
def setup_state(self) -> None:
|
||||
"""Set up the state for the app.
|
||||
|
||||
Raises:
|
||||
ValueError: If the event namespace is not provided in the config.
|
||||
If the state has not been enabled.
|
||||
"""
|
||||
if not self.state:
|
||||
return
|
||||
|
||||
config = get_config()
|
||||
|
||||
# Set up the state manager.
|
||||
self._state_manager = StateManager.create(state=self.state)
|
||||
|
||||
# Set up the Socket.IO AsyncServer.
|
||||
self.sio = AsyncServer(
|
||||
async_mode="asgi",
|
||||
cors_allowed_origins=(
|
||||
"*"
|
||||
if config.cors_allowed_origins == ["*"]
|
||||
else config.cors_allowed_origins
|
||||
),
|
||||
cors_credentials=True,
|
||||
max_http_buffer_size=constants.POLLING_MAX_HTTP_BUFFER_SIZE,
|
||||
ping_interval=constants.Ping.INTERVAL,
|
||||
ping_timeout=constants.Ping.TIMEOUT,
|
||||
)
|
||||
|
||||
# Create the socket app. Note event endpoint constant replaces the default 'socket.io' path.
|
||||
self.socket_app = ASGIApp(self.sio, socketio_path="")
|
||||
namespace = config.get_event_namespace()
|
||||
|
||||
if not namespace:
|
||||
raise ValueError("event namespace must be provided in the config.")
|
||||
|
||||
# Create the event namespace and attach the main app. Not related to any paths.
|
||||
self.event_namespace = EventNamespace(namespace, self)
|
||||
|
||||
# Register the event namespace with the socket.
|
||||
self.sio.register_namespace(self.event_namespace)
|
||||
# Mount the socket app with the API.
|
||||
self.api.mount(str(constants.Endpoint.EVENT), self.socket_app)
|
||||
|
||||
def __repr__(self) -> str:
|
||||
"""Get the string representation of the app.
|
||||
@ -430,21 +445,24 @@ class App(Base):
|
||||
# Check if the route given is valid
|
||||
verify_route_validity(route)
|
||||
|
||||
# Apply dynamic args to the route.
|
||||
if self.state:
|
||||
self.state.setup_dynamic_args(get_route_args(route))
|
||||
# Setup dynamic args for the route.
|
||||
# this state assignment is only required for tests using the deprecated state kwarg for App
|
||||
state = self.state if self.state else State
|
||||
state.setup_dynamic_args(get_route_args(route))
|
||||
|
||||
# Generate the component if it is a callable.
|
||||
component = self._generate_component(component)
|
||||
|
||||
# Wrap the component in a fragment with optional overlay.
|
||||
if self.overlay_component is not None:
|
||||
component = Fragment.create(
|
||||
self._generate_component(self.overlay_component),
|
||||
component,
|
||||
)
|
||||
else:
|
||||
component = Fragment.create(component)
|
||||
if self.state is None:
|
||||
for var in component._get_vars(include_children=True):
|
||||
if not var._var_data:
|
||||
continue
|
||||
if not var._var_data.state:
|
||||
continue
|
||||
self.enable_state()
|
||||
break
|
||||
|
||||
component = OverlayFragment.create(component)
|
||||
|
||||
# Add meta information to the component.
|
||||
compiler_utils.add_meta(
|
||||
@ -649,6 +667,28 @@ class App(Base):
|
||||
# By default, compile the app.
|
||||
return True
|
||||
|
||||
def _add_overlay_to_component(self, component: Component) -> Component:
|
||||
if self.overlay_component is None:
|
||||
return component
|
||||
|
||||
children = component.children
|
||||
overlay_component = self._generate_component(self.overlay_component)
|
||||
|
||||
if children[0] == overlay_component:
|
||||
return component
|
||||
|
||||
# recreate OverlayFragment with overlay_component as first child
|
||||
component = OverlayFragment.create(overlay_component, *children)
|
||||
|
||||
return component
|
||||
|
||||
def _setup_overlay_component(self):
|
||||
"""If a State is not used and no overlay_component is specified, do not render the connection modal."""
|
||||
if self.state is None and self.overlay_component is default_overlay_component:
|
||||
self.overlay_component = None
|
||||
for k, component in self.pages.items():
|
||||
self.pages[k] = self._add_overlay_to_component(component)
|
||||
|
||||
def compile(self):
|
||||
"""compile_() is the new function for performing compilation.
|
||||
Reflex framework will call it automatically as needed.
|
||||
@ -682,6 +722,8 @@ class App(Base):
|
||||
if not self._should_compile():
|
||||
return
|
||||
|
||||
self._setup_overlay_component()
|
||||
|
||||
# Create a progress bar.
|
||||
progress = Progress(
|
||||
*Progress.get_default_columns()[:-1],
|
||||
|
@ -64,6 +64,11 @@ Reducer = Callable[[Event], Coroutine[Any, Any, StateUpdate]]
|
||||
|
||||
def default_overlay_component() -> Component: ...
|
||||
|
||||
class OverlayFragment(Fragment):
|
||||
@overload
|
||||
@classmethod
|
||||
def create(cls, *children, **props) -> "OverlayFragment": ... # type: ignore
|
||||
|
||||
class App(Base):
|
||||
pages: Dict[str, Component]
|
||||
stylesheets: List[str]
|
||||
@ -122,6 +127,7 @@ class App(Base):
|
||||
def compile(self) -> None: ...
|
||||
def compile_(self) -> None: ...
|
||||
def modify_state(self, token: str) -> AsyncContextManager[State]: ...
|
||||
def _setup_overlay_component(self) -> None: ...
|
||||
def _process_background(
|
||||
self, state: State, event: Event
|
||||
) -> asyncio.Task | None: ...
|
||||
|
@ -33,9 +33,12 @@ class Bare(Component):
|
||||
def _render(self) -> Tag:
|
||||
return Tagless(contents=str(self.contents))
|
||||
|
||||
def _get_vars(self) -> Iterator[Var]:
|
||||
def _get_vars(self, include_children: bool = False) -> Iterator[Var]:
|
||||
"""Walk all Vars used in this component.
|
||||
|
||||
Args:
|
||||
include_children: Whether to include Vars from children.
|
||||
|
||||
Yields:
|
||||
The contents if it is a Var, otherwise nothing.
|
||||
"""
|
||||
|
@ -812,9 +812,12 @@ class Component(BaseComponent, ABC):
|
||||
event_args.extend(args)
|
||||
yield event_trigger, event_args
|
||||
|
||||
def _get_vars(self) -> list[Var]:
|
||||
def _get_vars(self, include_children: bool = False) -> list[Var]:
|
||||
"""Walk all Vars used in this component.
|
||||
|
||||
Args:
|
||||
include_children: Whether to include Vars from children.
|
||||
|
||||
Returns:
|
||||
Each var referenced by the component (props, styles, event handlers).
|
||||
"""
|
||||
@ -860,6 +863,14 @@ class Component(BaseComponent, ABC):
|
||||
var = Var.create_safe(comp_prop)
|
||||
if var._var_data is not None:
|
||||
vars.append(var)
|
||||
|
||||
# Get Vars associated with children.
|
||||
if include_children:
|
||||
for child in self.children:
|
||||
if not isinstance(child, Component):
|
||||
continue
|
||||
vars.extend(child._get_vars(include_children=include_children))
|
||||
|
||||
return vars
|
||||
|
||||
def _get_custom_code(self) -> str | None:
|
||||
|
@ -222,8 +222,8 @@ class Form(BaseHTML):
|
||||
)._replace(merge_var_data=ref_var._var_data)
|
||||
return form_refs
|
||||
|
||||
def _get_vars(self) -> Iterator[Var]:
|
||||
yield from super()._get_vars()
|
||||
def _get_vars(self, include_children: bool = True) -> Iterator[Var]:
|
||||
yield from super()._get_vars(include_children=include_children)
|
||||
yield from self._get_form_refs().values()
|
||||
|
||||
def _exclude_props(self) -> list[str]:
|
||||
|
@ -219,6 +219,15 @@ class Config(Base):
|
||||
self._non_default_attributes.update(kwargs)
|
||||
self._replace_defaults(**kwargs)
|
||||
|
||||
@property
|
||||
def module(self) -> str:
|
||||
"""Get the module name of the app.
|
||||
|
||||
Returns:
|
||||
The module name.
|
||||
"""
|
||||
return ".".join([self.app_name, self.app_name])
|
||||
|
||||
@staticmethod
|
||||
def check_deprecated_values(**kwargs):
|
||||
"""Check for deprecated config values.
|
||||
|
@ -99,6 +99,8 @@ class Config(Base):
|
||||
gunicorn_worker_class: Optional[str] = None,
|
||||
**kwargs
|
||||
) -> None: ...
|
||||
@property
|
||||
def module(self) -> str: ...
|
||||
@staticmethod
|
||||
def check_deprecated_values(**kwargs) -> None: ...
|
||||
def update_from_env(self) -> None: ...
|
||||
|
@ -2918,3 +2918,19 @@ def code_uses_state_contexts(javascript_code: str) -> bool:
|
||||
True if the code attempts to access a member of StateContexts.
|
||||
"""
|
||||
return bool("useContext(StateContexts" in javascript_code)
|
||||
|
||||
|
||||
def reload_state_module(
|
||||
module: str,
|
||||
state: Type[BaseState] = State,
|
||||
) -> None:
|
||||
"""Reset rx.State subclasses to avoid conflict when reloading.
|
||||
|
||||
Args:
|
||||
module: The module to reload.
|
||||
state: Recursive argument for the state class to reload.
|
||||
"""
|
||||
for subclass in tuple(state.class_subclasses):
|
||||
reload_state_module(module=module, state=subclass)
|
||||
if subclass.__module__ == module and module is not None:
|
||||
state.class_subclasses.remove(subclass)
|
||||
|
@ -232,7 +232,6 @@ class AppHarness:
|
||||
State.get_class_substate.cache_clear()
|
||||
# Ensure the AppHarness test does not skip State assignment due to running via pytest
|
||||
os.environ.pop(reflex.constants.PYTEST_CURRENT_TEST, None)
|
||||
# self.app_module.app.
|
||||
self.app_module = reflex.utils.prerequisites.get_compiled_app(reload=True)
|
||||
self.app_instance = self.app_module.app
|
||||
if isinstance(self.app_instance._state_manager, StateManagerRedis):
|
||||
|
@ -185,16 +185,16 @@ def get_app(reload: bool = False) -> ModuleType:
|
||||
"Cannot get the app module because `app_name` is not set in rxconfig! "
|
||||
"If this error occurs in a reflex test case, ensure that `get_app` is mocked."
|
||||
)
|
||||
module = ".".join([config.app_name, config.app_name])
|
||||
module = config.module
|
||||
sys.path.insert(0, os.getcwd())
|
||||
app = __import__(module, fromlist=(constants.CompileVars.APP,))
|
||||
|
||||
if reload:
|
||||
from reflex.state import State
|
||||
from reflex.state import reload_state_module
|
||||
|
||||
# Reset rx.State subclasses to avoid conflict when reloading.
|
||||
for subclass in tuple(State.class_subclasses):
|
||||
if subclass.__module__ == module:
|
||||
State.class_subclasses.remove(subclass)
|
||||
reload_state_module(module=module)
|
||||
|
||||
# Reload the app module.
|
||||
importlib.reload(app)
|
||||
|
||||
|
@ -20,6 +20,7 @@ from reflex import AdminDash, constants
|
||||
from reflex.app import (
|
||||
App,
|
||||
ComponentCallable,
|
||||
OverlayFragment,
|
||||
default_overlay_component,
|
||||
process,
|
||||
upload,
|
||||
@ -1182,12 +1183,13 @@ def test_overlay_component(
|
||||
exp_page_child: The type of the expected child in the page fragment.
|
||||
"""
|
||||
app = App(state=state, overlay_component=overlay_component)
|
||||
app._setup_overlay_component()
|
||||
if exp_page_child is None:
|
||||
assert app.overlay_component is None
|
||||
elif isinstance(exp_page_child, Fragment):
|
||||
elif isinstance(exp_page_child, OverlayFragment):
|
||||
assert app.overlay_component is not None
|
||||
generated_component = app._generate_component(app.overlay_component) # type: ignore
|
||||
assert isinstance(generated_component, Fragment)
|
||||
assert isinstance(generated_component, OverlayFragment)
|
||||
assert isinstance(
|
||||
generated_component.children[0],
|
||||
Cond, # ConnectionModal is a Cond under the hood
|
||||
@ -1200,7 +1202,10 @@ def test_overlay_component(
|
||||
)
|
||||
|
||||
app.add_page(rx.box("Index"), route="/test")
|
||||
# overlay components are wrapped during compile only
|
||||
app._setup_overlay_component()
|
||||
page = app.pages["test"]
|
||||
|
||||
if exp_page_child is not None:
|
||||
assert len(page.children) == 3
|
||||
children_types = (type(child) for child in page.children)
|
||||
|
Loading…
Reference in New Issue
Block a user