Track state usage (#2441)

* rebase

* pass include_children kwarg in radix FormRoot

* respect include_children

* ruff fixes

* readd statemanager init, run pyi gen

* minor performance imporovements, fix for state changes

* fix pyi and pyright

* pass include_children for chakra

* remove old state detection

* add test for unused states in stateless app

---------

Co-authored-by: Masen Furer <m_github@0x26.net>
This commit is contained in:
benedikt-bartscher 2024-03-07 23:25:55 +01:00 committed by GitHub
parent 5d647a498f
commit 19a5cdd408
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
12 changed files with 163 additions and 67 deletions

View File

@ -28,6 +28,9 @@ def TailwindApp(
import reflex as rx
import reflex.components.radix.themes as rdxt
class UnusedState(rx.State):
pass
def index():
return rx.el.div(
rx.chakra.text(paragraph_text, class_name=paragraph_class_name),

View File

@ -91,6 +91,12 @@ def default_overlay_component() -> Component:
return Fragment.create(connection_pulser(), connection_modal())
class OverlayFragment(Fragment):
"""Alias for Fragment, used to wrap the overlay_component."""
pass
class App(Base):
"""A Reflex application."""
@ -159,7 +165,7 @@ class App(Base):
Raises:
ValueError: If the event namespace is not provided in the config.
Also, if there are multiple client subclasses of rx.State(Subclasses of rx.State should consist
Also, if there are multiple client subclasses of rx.BaseState(Subclasses of rx.BaseState should consist
of the DefaultState and the client app state).
"""
if "connect_error_component" in kwargs:
@ -167,12 +173,12 @@ class App(Base):
"`connect_error_component` is deprecated, use `overlay_component` instead"
)
super().__init__(*args, **kwargs)
state_subclasses = BaseState.__subclasses__()
base_state_subclasses = BaseState.__subclasses__()
# Special case to allow test cases have multiple subclasses of rx.BaseState.
if not is_testing_env():
# Only one Base State class is allowed.
if len(state_subclasses) > 1:
if len(base_state_subclasses) > 1:
raise ValueError(
"rx.BaseState cannot be subclassed multiple times. use rx.State instead"
)
@ -184,12 +190,6 @@ class App(Base):
deprecation_version="0.3.5",
removal_version="0.5.0",
)
# 2 substates are built-in and not considered when determining if app is stateless.
if len(State.class_subclasses) > 2:
self.state = State
# Get the config
config = get_config()
# Add middleware.
self.middleware.append(HydrateMiddleware())
@ -198,45 +198,60 @@ class App(Base):
self.add_cors()
self.add_default_endpoints()
if self.state:
# Set up the state manager.
self._state_manager = StateManager.create(state=self.state)
# Set up the Socket.IO AsyncServer.
self.sio = AsyncServer(
async_mode="asgi",
cors_allowed_origins=(
"*"
if config.cors_allowed_origins == ["*"]
else config.cors_allowed_origins
),
cors_credentials=True,
max_http_buffer_size=constants.POLLING_MAX_HTTP_BUFFER_SIZE,
ping_interval=constants.Ping.INTERVAL,
ping_timeout=constants.Ping.TIMEOUT,
)
# Create the socket app. Note event endpoint constant replaces the default 'socket.io' path.
self.socket_app = ASGIApp(self.sio, socketio_path="")
namespace = config.get_event_namespace()
if not namespace:
raise ValueError("event namespace must be provided in the config.")
# Create the event namespace and attach the main app. Not related to any paths.
self.event_namespace = EventNamespace(namespace, self)
# Register the event namespace with the socket.
self.sio.register_namespace(self.event_namespace)
# Mount the socket app with the API.
self.api.mount(str(constants.Endpoint.EVENT), self.socket_app)
self.setup_state()
# Set up the admin dash.
self.setup_admin_dash()
# If a State is not used and no overlay_component is specified, do not render the connection modal
if self.state is None and self.overlay_component is default_overlay_component:
self.overlay_component = None
def enable_state(self) -> None:
"""Enable state for the app."""
if not self.state:
self.state = State
self.setup_state()
def setup_state(self) -> None:
"""Set up the state for the app.
Raises:
ValueError: If the event namespace is not provided in the config.
If the state has not been enabled.
"""
if not self.state:
return
config = get_config()
# Set up the state manager.
self._state_manager = StateManager.create(state=self.state)
# Set up the Socket.IO AsyncServer.
self.sio = AsyncServer(
async_mode="asgi",
cors_allowed_origins=(
"*"
if config.cors_allowed_origins == ["*"]
else config.cors_allowed_origins
),
cors_credentials=True,
max_http_buffer_size=constants.POLLING_MAX_HTTP_BUFFER_SIZE,
ping_interval=constants.Ping.INTERVAL,
ping_timeout=constants.Ping.TIMEOUT,
)
# Create the socket app. Note event endpoint constant replaces the default 'socket.io' path.
self.socket_app = ASGIApp(self.sio, socketio_path="")
namespace = config.get_event_namespace()
if not namespace:
raise ValueError("event namespace must be provided in the config.")
# Create the event namespace and attach the main app. Not related to any paths.
self.event_namespace = EventNamespace(namespace, self)
# Register the event namespace with the socket.
self.sio.register_namespace(self.event_namespace)
# Mount the socket app with the API.
self.api.mount(str(constants.Endpoint.EVENT), self.socket_app)
def __repr__(self) -> str:
"""Get the string representation of the app.
@ -430,21 +445,24 @@ class App(Base):
# Check if the route given is valid
verify_route_validity(route)
# Apply dynamic args to the route.
if self.state:
self.state.setup_dynamic_args(get_route_args(route))
# Setup dynamic args for the route.
# this state assignment is only required for tests using the deprecated state kwarg for App
state = self.state if self.state else State
state.setup_dynamic_args(get_route_args(route))
# Generate the component if it is a callable.
component = self._generate_component(component)
# Wrap the component in a fragment with optional overlay.
if self.overlay_component is not None:
component = Fragment.create(
self._generate_component(self.overlay_component),
component,
)
else:
component = Fragment.create(component)
if self.state is None:
for var in component._get_vars(include_children=True):
if not var._var_data:
continue
if not var._var_data.state:
continue
self.enable_state()
break
component = OverlayFragment.create(component)
# Add meta information to the component.
compiler_utils.add_meta(
@ -649,6 +667,28 @@ class App(Base):
# By default, compile the app.
return True
def _add_overlay_to_component(self, component: Component) -> Component:
if self.overlay_component is None:
return component
children = component.children
overlay_component = self._generate_component(self.overlay_component)
if children[0] == overlay_component:
return component
# recreate OverlayFragment with overlay_component as first child
component = OverlayFragment.create(overlay_component, *children)
return component
def _setup_overlay_component(self):
"""If a State is not used and no overlay_component is specified, do not render the connection modal."""
if self.state is None and self.overlay_component is default_overlay_component:
self.overlay_component = None
for k, component in self.pages.items():
self.pages[k] = self._add_overlay_to_component(component)
def compile(self):
"""compile_() is the new function for performing compilation.
Reflex framework will call it automatically as needed.
@ -682,6 +722,8 @@ class App(Base):
if not self._should_compile():
return
self._setup_overlay_component()
# Create a progress bar.
progress = Progress(
*Progress.get_default_columns()[:-1],

View File

@ -64,6 +64,11 @@ Reducer = Callable[[Event], Coroutine[Any, Any, StateUpdate]]
def default_overlay_component() -> Component: ...
class OverlayFragment(Fragment):
@overload
@classmethod
def create(cls, *children, **props) -> "OverlayFragment": ... # type: ignore
class App(Base):
pages: Dict[str, Component]
stylesheets: List[str]
@ -122,6 +127,7 @@ class App(Base):
def compile(self) -> None: ...
def compile_(self) -> None: ...
def modify_state(self, token: str) -> AsyncContextManager[State]: ...
def _setup_overlay_component(self) -> None: ...
def _process_background(
self, state: State, event: Event
) -> asyncio.Task | None: ...

View File

@ -33,9 +33,12 @@ class Bare(Component):
def _render(self) -> Tag:
return Tagless(contents=str(self.contents))
def _get_vars(self) -> Iterator[Var]:
def _get_vars(self, include_children: bool = False) -> Iterator[Var]:
"""Walk all Vars used in this component.
Args:
include_children: Whether to include Vars from children.
Yields:
The contents if it is a Var, otherwise nothing.
"""

View File

@ -812,9 +812,12 @@ class Component(BaseComponent, ABC):
event_args.extend(args)
yield event_trigger, event_args
def _get_vars(self) -> list[Var]:
def _get_vars(self, include_children: bool = False) -> list[Var]:
"""Walk all Vars used in this component.
Args:
include_children: Whether to include Vars from children.
Returns:
Each var referenced by the component (props, styles, event handlers).
"""
@ -860,6 +863,14 @@ class Component(BaseComponent, ABC):
var = Var.create_safe(comp_prop)
if var._var_data is not None:
vars.append(var)
# Get Vars associated with children.
if include_children:
for child in self.children:
if not isinstance(child, Component):
continue
vars.extend(child._get_vars(include_children=include_children))
return vars
def _get_custom_code(self) -> str | None:

View File

@ -222,8 +222,8 @@ class Form(BaseHTML):
)._replace(merge_var_data=ref_var._var_data)
return form_refs
def _get_vars(self) -> Iterator[Var]:
yield from super()._get_vars()
def _get_vars(self, include_children: bool = True) -> Iterator[Var]:
yield from super()._get_vars(include_children=include_children)
yield from self._get_form_refs().values()
def _exclude_props(self) -> list[str]:

View File

@ -219,6 +219,15 @@ class Config(Base):
self._non_default_attributes.update(kwargs)
self._replace_defaults(**kwargs)
@property
def module(self) -> str:
"""Get the module name of the app.
Returns:
The module name.
"""
return ".".join([self.app_name, self.app_name])
@staticmethod
def check_deprecated_values(**kwargs):
"""Check for deprecated config values.

View File

@ -99,6 +99,8 @@ class Config(Base):
gunicorn_worker_class: Optional[str] = None,
**kwargs
) -> None: ...
@property
def module(self) -> str: ...
@staticmethod
def check_deprecated_values(**kwargs) -> None: ...
def update_from_env(self) -> None: ...

View File

@ -2918,3 +2918,19 @@ def code_uses_state_contexts(javascript_code: str) -> bool:
True if the code attempts to access a member of StateContexts.
"""
return bool("useContext(StateContexts" in javascript_code)
def reload_state_module(
module: str,
state: Type[BaseState] = State,
) -> None:
"""Reset rx.State subclasses to avoid conflict when reloading.
Args:
module: The module to reload.
state: Recursive argument for the state class to reload.
"""
for subclass in tuple(state.class_subclasses):
reload_state_module(module=module, state=subclass)
if subclass.__module__ == module and module is not None:
state.class_subclasses.remove(subclass)

View File

@ -232,7 +232,6 @@ class AppHarness:
State.get_class_substate.cache_clear()
# Ensure the AppHarness test does not skip State assignment due to running via pytest
os.environ.pop(reflex.constants.PYTEST_CURRENT_TEST, None)
# self.app_module.app.
self.app_module = reflex.utils.prerequisites.get_compiled_app(reload=True)
self.app_instance = self.app_module.app
if isinstance(self.app_instance._state_manager, StateManagerRedis):

View File

@ -185,16 +185,16 @@ def get_app(reload: bool = False) -> ModuleType:
"Cannot get the app module because `app_name` is not set in rxconfig! "
"If this error occurs in a reflex test case, ensure that `get_app` is mocked."
)
module = ".".join([config.app_name, config.app_name])
module = config.module
sys.path.insert(0, os.getcwd())
app = __import__(module, fromlist=(constants.CompileVars.APP,))
if reload:
from reflex.state import State
from reflex.state import reload_state_module
# Reset rx.State subclasses to avoid conflict when reloading.
for subclass in tuple(State.class_subclasses):
if subclass.__module__ == module:
State.class_subclasses.remove(subclass)
reload_state_module(module=module)
# Reload the app module.
importlib.reload(app)

View File

@ -20,6 +20,7 @@ from reflex import AdminDash, constants
from reflex.app import (
App,
ComponentCallable,
OverlayFragment,
default_overlay_component,
process,
upload,
@ -1182,12 +1183,13 @@ def test_overlay_component(
exp_page_child: The type of the expected child in the page fragment.
"""
app = App(state=state, overlay_component=overlay_component)
app._setup_overlay_component()
if exp_page_child is None:
assert app.overlay_component is None
elif isinstance(exp_page_child, Fragment):
elif isinstance(exp_page_child, OverlayFragment):
assert app.overlay_component is not None
generated_component = app._generate_component(app.overlay_component) # type: ignore
assert isinstance(generated_component, Fragment)
assert isinstance(generated_component, OverlayFragment)
assert isinstance(
generated_component.children[0],
Cond, # ConnectionModal is a Cond under the hood
@ -1200,7 +1202,10 @@ def test_overlay_component(
)
app.add_page(rx.box("Index"), route="/test")
# overlay components are wrapped during compile only
app._setup_overlay_component()
page = app.pages["test"]
if exp_page_child is not None:
assert len(page.children) == 3
children_types = (type(child) for child in page.children)