Track state usage (#2441)
* rebase * pass include_children kwarg in radix FormRoot * respect include_children * ruff fixes * readd statemanager init, run pyi gen * minor performance imporovements, fix for state changes * fix pyi and pyright * pass include_children for chakra * remove old state detection * add test for unused states in stateless app --------- Co-authored-by: Masen Furer <m_github@0x26.net>
This commit is contained in:
parent
5d647a498f
commit
19a5cdd408
@ -28,6 +28,9 @@ def TailwindApp(
|
|||||||
import reflex as rx
|
import reflex as rx
|
||||||
import reflex.components.radix.themes as rdxt
|
import reflex.components.radix.themes as rdxt
|
||||||
|
|
||||||
|
class UnusedState(rx.State):
|
||||||
|
pass
|
||||||
|
|
||||||
def index():
|
def index():
|
||||||
return rx.el.div(
|
return rx.el.div(
|
||||||
rx.chakra.text(paragraph_text, class_name=paragraph_class_name),
|
rx.chakra.text(paragraph_text, class_name=paragraph_class_name),
|
||||||
|
152
reflex/app.py
152
reflex/app.py
@ -91,6 +91,12 @@ def default_overlay_component() -> Component:
|
|||||||
return Fragment.create(connection_pulser(), connection_modal())
|
return Fragment.create(connection_pulser(), connection_modal())
|
||||||
|
|
||||||
|
|
||||||
|
class OverlayFragment(Fragment):
|
||||||
|
"""Alias for Fragment, used to wrap the overlay_component."""
|
||||||
|
|
||||||
|
pass
|
||||||
|
|
||||||
|
|
||||||
class App(Base):
|
class App(Base):
|
||||||
"""A Reflex application."""
|
"""A Reflex application."""
|
||||||
|
|
||||||
@ -159,7 +165,7 @@ class App(Base):
|
|||||||
|
|
||||||
Raises:
|
Raises:
|
||||||
ValueError: If the event namespace is not provided in the config.
|
ValueError: If the event namespace is not provided in the config.
|
||||||
Also, if there are multiple client subclasses of rx.State(Subclasses of rx.State should consist
|
Also, if there are multiple client subclasses of rx.BaseState(Subclasses of rx.BaseState should consist
|
||||||
of the DefaultState and the client app state).
|
of the DefaultState and the client app state).
|
||||||
"""
|
"""
|
||||||
if "connect_error_component" in kwargs:
|
if "connect_error_component" in kwargs:
|
||||||
@ -167,12 +173,12 @@ class App(Base):
|
|||||||
"`connect_error_component` is deprecated, use `overlay_component` instead"
|
"`connect_error_component` is deprecated, use `overlay_component` instead"
|
||||||
)
|
)
|
||||||
super().__init__(*args, **kwargs)
|
super().__init__(*args, **kwargs)
|
||||||
state_subclasses = BaseState.__subclasses__()
|
base_state_subclasses = BaseState.__subclasses__()
|
||||||
|
|
||||||
# Special case to allow test cases have multiple subclasses of rx.BaseState.
|
# Special case to allow test cases have multiple subclasses of rx.BaseState.
|
||||||
if not is_testing_env():
|
if not is_testing_env():
|
||||||
# Only one Base State class is allowed.
|
# Only one Base State class is allowed.
|
||||||
if len(state_subclasses) > 1:
|
if len(base_state_subclasses) > 1:
|
||||||
raise ValueError(
|
raise ValueError(
|
||||||
"rx.BaseState cannot be subclassed multiple times. use rx.State instead"
|
"rx.BaseState cannot be subclassed multiple times. use rx.State instead"
|
||||||
)
|
)
|
||||||
@ -184,12 +190,6 @@ class App(Base):
|
|||||||
deprecation_version="0.3.5",
|
deprecation_version="0.3.5",
|
||||||
removal_version="0.5.0",
|
removal_version="0.5.0",
|
||||||
)
|
)
|
||||||
# 2 substates are built-in and not considered when determining if app is stateless.
|
|
||||||
if len(State.class_subclasses) > 2:
|
|
||||||
self.state = State
|
|
||||||
# Get the config
|
|
||||||
config = get_config()
|
|
||||||
|
|
||||||
# Add middleware.
|
# Add middleware.
|
||||||
self.middleware.append(HydrateMiddleware())
|
self.middleware.append(HydrateMiddleware())
|
||||||
|
|
||||||
@ -198,45 +198,60 @@ class App(Base):
|
|||||||
self.add_cors()
|
self.add_cors()
|
||||||
self.add_default_endpoints()
|
self.add_default_endpoints()
|
||||||
|
|
||||||
if self.state:
|
self.setup_state()
|
||||||
# Set up the state manager.
|
|
||||||
self._state_manager = StateManager.create(state=self.state)
|
|
||||||
|
|
||||||
# Set up the Socket.IO AsyncServer.
|
|
||||||
self.sio = AsyncServer(
|
|
||||||
async_mode="asgi",
|
|
||||||
cors_allowed_origins=(
|
|
||||||
"*"
|
|
||||||
if config.cors_allowed_origins == ["*"]
|
|
||||||
else config.cors_allowed_origins
|
|
||||||
),
|
|
||||||
cors_credentials=True,
|
|
||||||
max_http_buffer_size=constants.POLLING_MAX_HTTP_BUFFER_SIZE,
|
|
||||||
ping_interval=constants.Ping.INTERVAL,
|
|
||||||
ping_timeout=constants.Ping.TIMEOUT,
|
|
||||||
)
|
|
||||||
|
|
||||||
# Create the socket app. Note event endpoint constant replaces the default 'socket.io' path.
|
|
||||||
self.socket_app = ASGIApp(self.sio, socketio_path="")
|
|
||||||
namespace = config.get_event_namespace()
|
|
||||||
|
|
||||||
if not namespace:
|
|
||||||
raise ValueError("event namespace must be provided in the config.")
|
|
||||||
|
|
||||||
# Create the event namespace and attach the main app. Not related to any paths.
|
|
||||||
self.event_namespace = EventNamespace(namespace, self)
|
|
||||||
|
|
||||||
# Register the event namespace with the socket.
|
|
||||||
self.sio.register_namespace(self.event_namespace)
|
|
||||||
# Mount the socket app with the API.
|
|
||||||
self.api.mount(str(constants.Endpoint.EVENT), self.socket_app)
|
|
||||||
|
|
||||||
# Set up the admin dash.
|
# Set up the admin dash.
|
||||||
self.setup_admin_dash()
|
self.setup_admin_dash()
|
||||||
|
|
||||||
# If a State is not used and no overlay_component is specified, do not render the connection modal
|
def enable_state(self) -> None:
|
||||||
if self.state is None and self.overlay_component is default_overlay_component:
|
"""Enable state for the app."""
|
||||||
self.overlay_component = None
|
if not self.state:
|
||||||
|
self.state = State
|
||||||
|
self.setup_state()
|
||||||
|
|
||||||
|
def setup_state(self) -> None:
|
||||||
|
"""Set up the state for the app.
|
||||||
|
|
||||||
|
Raises:
|
||||||
|
ValueError: If the event namespace is not provided in the config.
|
||||||
|
If the state has not been enabled.
|
||||||
|
"""
|
||||||
|
if not self.state:
|
||||||
|
return
|
||||||
|
|
||||||
|
config = get_config()
|
||||||
|
|
||||||
|
# Set up the state manager.
|
||||||
|
self._state_manager = StateManager.create(state=self.state)
|
||||||
|
|
||||||
|
# Set up the Socket.IO AsyncServer.
|
||||||
|
self.sio = AsyncServer(
|
||||||
|
async_mode="asgi",
|
||||||
|
cors_allowed_origins=(
|
||||||
|
"*"
|
||||||
|
if config.cors_allowed_origins == ["*"]
|
||||||
|
else config.cors_allowed_origins
|
||||||
|
),
|
||||||
|
cors_credentials=True,
|
||||||
|
max_http_buffer_size=constants.POLLING_MAX_HTTP_BUFFER_SIZE,
|
||||||
|
ping_interval=constants.Ping.INTERVAL,
|
||||||
|
ping_timeout=constants.Ping.TIMEOUT,
|
||||||
|
)
|
||||||
|
|
||||||
|
# Create the socket app. Note event endpoint constant replaces the default 'socket.io' path.
|
||||||
|
self.socket_app = ASGIApp(self.sio, socketio_path="")
|
||||||
|
namespace = config.get_event_namespace()
|
||||||
|
|
||||||
|
if not namespace:
|
||||||
|
raise ValueError("event namespace must be provided in the config.")
|
||||||
|
|
||||||
|
# Create the event namespace and attach the main app. Not related to any paths.
|
||||||
|
self.event_namespace = EventNamespace(namespace, self)
|
||||||
|
|
||||||
|
# Register the event namespace with the socket.
|
||||||
|
self.sio.register_namespace(self.event_namespace)
|
||||||
|
# Mount the socket app with the API.
|
||||||
|
self.api.mount(str(constants.Endpoint.EVENT), self.socket_app)
|
||||||
|
|
||||||
def __repr__(self) -> str:
|
def __repr__(self) -> str:
|
||||||
"""Get the string representation of the app.
|
"""Get the string representation of the app.
|
||||||
@ -430,21 +445,24 @@ class App(Base):
|
|||||||
# Check if the route given is valid
|
# Check if the route given is valid
|
||||||
verify_route_validity(route)
|
verify_route_validity(route)
|
||||||
|
|
||||||
# Apply dynamic args to the route.
|
# Setup dynamic args for the route.
|
||||||
if self.state:
|
# this state assignment is only required for tests using the deprecated state kwarg for App
|
||||||
self.state.setup_dynamic_args(get_route_args(route))
|
state = self.state if self.state else State
|
||||||
|
state.setup_dynamic_args(get_route_args(route))
|
||||||
|
|
||||||
# Generate the component if it is a callable.
|
# Generate the component if it is a callable.
|
||||||
component = self._generate_component(component)
|
component = self._generate_component(component)
|
||||||
|
|
||||||
# Wrap the component in a fragment with optional overlay.
|
if self.state is None:
|
||||||
if self.overlay_component is not None:
|
for var in component._get_vars(include_children=True):
|
||||||
component = Fragment.create(
|
if not var._var_data:
|
||||||
self._generate_component(self.overlay_component),
|
continue
|
||||||
component,
|
if not var._var_data.state:
|
||||||
)
|
continue
|
||||||
else:
|
self.enable_state()
|
||||||
component = Fragment.create(component)
|
break
|
||||||
|
|
||||||
|
component = OverlayFragment.create(component)
|
||||||
|
|
||||||
# Add meta information to the component.
|
# Add meta information to the component.
|
||||||
compiler_utils.add_meta(
|
compiler_utils.add_meta(
|
||||||
@ -649,6 +667,28 @@ class App(Base):
|
|||||||
# By default, compile the app.
|
# By default, compile the app.
|
||||||
return True
|
return True
|
||||||
|
|
||||||
|
def _add_overlay_to_component(self, component: Component) -> Component:
|
||||||
|
if self.overlay_component is None:
|
||||||
|
return component
|
||||||
|
|
||||||
|
children = component.children
|
||||||
|
overlay_component = self._generate_component(self.overlay_component)
|
||||||
|
|
||||||
|
if children[0] == overlay_component:
|
||||||
|
return component
|
||||||
|
|
||||||
|
# recreate OverlayFragment with overlay_component as first child
|
||||||
|
component = OverlayFragment.create(overlay_component, *children)
|
||||||
|
|
||||||
|
return component
|
||||||
|
|
||||||
|
def _setup_overlay_component(self):
|
||||||
|
"""If a State is not used and no overlay_component is specified, do not render the connection modal."""
|
||||||
|
if self.state is None and self.overlay_component is default_overlay_component:
|
||||||
|
self.overlay_component = None
|
||||||
|
for k, component in self.pages.items():
|
||||||
|
self.pages[k] = self._add_overlay_to_component(component)
|
||||||
|
|
||||||
def compile(self):
|
def compile(self):
|
||||||
"""compile_() is the new function for performing compilation.
|
"""compile_() is the new function for performing compilation.
|
||||||
Reflex framework will call it automatically as needed.
|
Reflex framework will call it automatically as needed.
|
||||||
@ -682,6 +722,8 @@ class App(Base):
|
|||||||
if not self._should_compile():
|
if not self._should_compile():
|
||||||
return
|
return
|
||||||
|
|
||||||
|
self._setup_overlay_component()
|
||||||
|
|
||||||
# Create a progress bar.
|
# Create a progress bar.
|
||||||
progress = Progress(
|
progress = Progress(
|
||||||
*Progress.get_default_columns()[:-1],
|
*Progress.get_default_columns()[:-1],
|
||||||
|
@ -64,6 +64,11 @@ Reducer = Callable[[Event], Coroutine[Any, Any, StateUpdate]]
|
|||||||
|
|
||||||
def default_overlay_component() -> Component: ...
|
def default_overlay_component() -> Component: ...
|
||||||
|
|
||||||
|
class OverlayFragment(Fragment):
|
||||||
|
@overload
|
||||||
|
@classmethod
|
||||||
|
def create(cls, *children, **props) -> "OverlayFragment": ... # type: ignore
|
||||||
|
|
||||||
class App(Base):
|
class App(Base):
|
||||||
pages: Dict[str, Component]
|
pages: Dict[str, Component]
|
||||||
stylesheets: List[str]
|
stylesheets: List[str]
|
||||||
@ -122,6 +127,7 @@ class App(Base):
|
|||||||
def compile(self) -> None: ...
|
def compile(self) -> None: ...
|
||||||
def compile_(self) -> None: ...
|
def compile_(self) -> None: ...
|
||||||
def modify_state(self, token: str) -> AsyncContextManager[State]: ...
|
def modify_state(self, token: str) -> AsyncContextManager[State]: ...
|
||||||
|
def _setup_overlay_component(self) -> None: ...
|
||||||
def _process_background(
|
def _process_background(
|
||||||
self, state: State, event: Event
|
self, state: State, event: Event
|
||||||
) -> asyncio.Task | None: ...
|
) -> asyncio.Task | None: ...
|
||||||
|
@ -33,9 +33,12 @@ class Bare(Component):
|
|||||||
def _render(self) -> Tag:
|
def _render(self) -> Tag:
|
||||||
return Tagless(contents=str(self.contents))
|
return Tagless(contents=str(self.contents))
|
||||||
|
|
||||||
def _get_vars(self) -> Iterator[Var]:
|
def _get_vars(self, include_children: bool = False) -> Iterator[Var]:
|
||||||
"""Walk all Vars used in this component.
|
"""Walk all Vars used in this component.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
include_children: Whether to include Vars from children.
|
||||||
|
|
||||||
Yields:
|
Yields:
|
||||||
The contents if it is a Var, otherwise nothing.
|
The contents if it is a Var, otherwise nothing.
|
||||||
"""
|
"""
|
||||||
|
@ -812,9 +812,12 @@ class Component(BaseComponent, ABC):
|
|||||||
event_args.extend(args)
|
event_args.extend(args)
|
||||||
yield event_trigger, event_args
|
yield event_trigger, event_args
|
||||||
|
|
||||||
def _get_vars(self) -> list[Var]:
|
def _get_vars(self, include_children: bool = False) -> list[Var]:
|
||||||
"""Walk all Vars used in this component.
|
"""Walk all Vars used in this component.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
include_children: Whether to include Vars from children.
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
Each var referenced by the component (props, styles, event handlers).
|
Each var referenced by the component (props, styles, event handlers).
|
||||||
"""
|
"""
|
||||||
@ -860,6 +863,14 @@ class Component(BaseComponent, ABC):
|
|||||||
var = Var.create_safe(comp_prop)
|
var = Var.create_safe(comp_prop)
|
||||||
if var._var_data is not None:
|
if var._var_data is not None:
|
||||||
vars.append(var)
|
vars.append(var)
|
||||||
|
|
||||||
|
# Get Vars associated with children.
|
||||||
|
if include_children:
|
||||||
|
for child in self.children:
|
||||||
|
if not isinstance(child, Component):
|
||||||
|
continue
|
||||||
|
vars.extend(child._get_vars(include_children=include_children))
|
||||||
|
|
||||||
return vars
|
return vars
|
||||||
|
|
||||||
def _get_custom_code(self) -> str | None:
|
def _get_custom_code(self) -> str | None:
|
||||||
|
@ -222,8 +222,8 @@ class Form(BaseHTML):
|
|||||||
)._replace(merge_var_data=ref_var._var_data)
|
)._replace(merge_var_data=ref_var._var_data)
|
||||||
return form_refs
|
return form_refs
|
||||||
|
|
||||||
def _get_vars(self) -> Iterator[Var]:
|
def _get_vars(self, include_children: bool = True) -> Iterator[Var]:
|
||||||
yield from super()._get_vars()
|
yield from super()._get_vars(include_children=include_children)
|
||||||
yield from self._get_form_refs().values()
|
yield from self._get_form_refs().values()
|
||||||
|
|
||||||
def _exclude_props(self) -> list[str]:
|
def _exclude_props(self) -> list[str]:
|
||||||
|
@ -219,6 +219,15 @@ class Config(Base):
|
|||||||
self._non_default_attributes.update(kwargs)
|
self._non_default_attributes.update(kwargs)
|
||||||
self._replace_defaults(**kwargs)
|
self._replace_defaults(**kwargs)
|
||||||
|
|
||||||
|
@property
|
||||||
|
def module(self) -> str:
|
||||||
|
"""Get the module name of the app.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
The module name.
|
||||||
|
"""
|
||||||
|
return ".".join([self.app_name, self.app_name])
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def check_deprecated_values(**kwargs):
|
def check_deprecated_values(**kwargs):
|
||||||
"""Check for deprecated config values.
|
"""Check for deprecated config values.
|
||||||
|
@ -99,6 +99,8 @@ class Config(Base):
|
|||||||
gunicorn_worker_class: Optional[str] = None,
|
gunicorn_worker_class: Optional[str] = None,
|
||||||
**kwargs
|
**kwargs
|
||||||
) -> None: ...
|
) -> None: ...
|
||||||
|
@property
|
||||||
|
def module(self) -> str: ...
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def check_deprecated_values(**kwargs) -> None: ...
|
def check_deprecated_values(**kwargs) -> None: ...
|
||||||
def update_from_env(self) -> None: ...
|
def update_from_env(self) -> None: ...
|
||||||
|
@ -2918,3 +2918,19 @@ def code_uses_state_contexts(javascript_code: str) -> bool:
|
|||||||
True if the code attempts to access a member of StateContexts.
|
True if the code attempts to access a member of StateContexts.
|
||||||
"""
|
"""
|
||||||
return bool("useContext(StateContexts" in javascript_code)
|
return bool("useContext(StateContexts" in javascript_code)
|
||||||
|
|
||||||
|
|
||||||
|
def reload_state_module(
|
||||||
|
module: str,
|
||||||
|
state: Type[BaseState] = State,
|
||||||
|
) -> None:
|
||||||
|
"""Reset rx.State subclasses to avoid conflict when reloading.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
module: The module to reload.
|
||||||
|
state: Recursive argument for the state class to reload.
|
||||||
|
"""
|
||||||
|
for subclass in tuple(state.class_subclasses):
|
||||||
|
reload_state_module(module=module, state=subclass)
|
||||||
|
if subclass.__module__ == module and module is not None:
|
||||||
|
state.class_subclasses.remove(subclass)
|
||||||
|
@ -232,7 +232,6 @@ class AppHarness:
|
|||||||
State.get_class_substate.cache_clear()
|
State.get_class_substate.cache_clear()
|
||||||
# Ensure the AppHarness test does not skip State assignment due to running via pytest
|
# Ensure the AppHarness test does not skip State assignment due to running via pytest
|
||||||
os.environ.pop(reflex.constants.PYTEST_CURRENT_TEST, None)
|
os.environ.pop(reflex.constants.PYTEST_CURRENT_TEST, None)
|
||||||
# self.app_module.app.
|
|
||||||
self.app_module = reflex.utils.prerequisites.get_compiled_app(reload=True)
|
self.app_module = reflex.utils.prerequisites.get_compiled_app(reload=True)
|
||||||
self.app_instance = self.app_module.app
|
self.app_instance = self.app_module.app
|
||||||
if isinstance(self.app_instance._state_manager, StateManagerRedis):
|
if isinstance(self.app_instance._state_manager, StateManagerRedis):
|
||||||
|
@ -185,16 +185,16 @@ def get_app(reload: bool = False) -> ModuleType:
|
|||||||
"Cannot get the app module because `app_name` is not set in rxconfig! "
|
"Cannot get the app module because `app_name` is not set in rxconfig! "
|
||||||
"If this error occurs in a reflex test case, ensure that `get_app` is mocked."
|
"If this error occurs in a reflex test case, ensure that `get_app` is mocked."
|
||||||
)
|
)
|
||||||
module = ".".join([config.app_name, config.app_name])
|
module = config.module
|
||||||
sys.path.insert(0, os.getcwd())
|
sys.path.insert(0, os.getcwd())
|
||||||
app = __import__(module, fromlist=(constants.CompileVars.APP,))
|
app = __import__(module, fromlist=(constants.CompileVars.APP,))
|
||||||
|
|
||||||
if reload:
|
if reload:
|
||||||
from reflex.state import State
|
from reflex.state import reload_state_module
|
||||||
|
|
||||||
# Reset rx.State subclasses to avoid conflict when reloading.
|
# Reset rx.State subclasses to avoid conflict when reloading.
|
||||||
for subclass in tuple(State.class_subclasses):
|
reload_state_module(module=module)
|
||||||
if subclass.__module__ == module:
|
|
||||||
State.class_subclasses.remove(subclass)
|
|
||||||
# Reload the app module.
|
# Reload the app module.
|
||||||
importlib.reload(app)
|
importlib.reload(app)
|
||||||
|
|
||||||
|
@ -20,6 +20,7 @@ from reflex import AdminDash, constants
|
|||||||
from reflex.app import (
|
from reflex.app import (
|
||||||
App,
|
App,
|
||||||
ComponentCallable,
|
ComponentCallable,
|
||||||
|
OverlayFragment,
|
||||||
default_overlay_component,
|
default_overlay_component,
|
||||||
process,
|
process,
|
||||||
upload,
|
upload,
|
||||||
@ -1182,12 +1183,13 @@ def test_overlay_component(
|
|||||||
exp_page_child: The type of the expected child in the page fragment.
|
exp_page_child: The type of the expected child in the page fragment.
|
||||||
"""
|
"""
|
||||||
app = App(state=state, overlay_component=overlay_component)
|
app = App(state=state, overlay_component=overlay_component)
|
||||||
|
app._setup_overlay_component()
|
||||||
if exp_page_child is None:
|
if exp_page_child is None:
|
||||||
assert app.overlay_component is None
|
assert app.overlay_component is None
|
||||||
elif isinstance(exp_page_child, Fragment):
|
elif isinstance(exp_page_child, OverlayFragment):
|
||||||
assert app.overlay_component is not None
|
assert app.overlay_component is not None
|
||||||
generated_component = app._generate_component(app.overlay_component) # type: ignore
|
generated_component = app._generate_component(app.overlay_component) # type: ignore
|
||||||
assert isinstance(generated_component, Fragment)
|
assert isinstance(generated_component, OverlayFragment)
|
||||||
assert isinstance(
|
assert isinstance(
|
||||||
generated_component.children[0],
|
generated_component.children[0],
|
||||||
Cond, # ConnectionModal is a Cond under the hood
|
Cond, # ConnectionModal is a Cond under the hood
|
||||||
@ -1200,7 +1202,10 @@ def test_overlay_component(
|
|||||||
)
|
)
|
||||||
|
|
||||||
app.add_page(rx.box("Index"), route="/test")
|
app.add_page(rx.box("Index"), route="/test")
|
||||||
|
# overlay components are wrapped during compile only
|
||||||
|
app._setup_overlay_component()
|
||||||
page = app.pages["test"]
|
page = app.pages["test"]
|
||||||
|
|
||||||
if exp_page_child is not None:
|
if exp_page_child is not None:
|
||||||
assert len(page.children) == 3
|
assert len(page.children) == 3
|
||||||
children_types = (type(child) for child in page.children)
|
children_types = (type(child) for child in page.children)
|
||||||
|
Loading…
Reference in New Issue
Block a user