From 19a5cdd4087760c6a293070b3d937a68bf9ed6ab Mon Sep 17 00:00:00 2001 From: benedikt-bartscher <31854409+benedikt-bartscher@users.noreply.github.com> Date: Thu, 7 Mar 2024 23:25:55 +0100 Subject: [PATCH] Track state usage (#2441) * rebase * pass include_children kwarg in radix FormRoot * respect include_children * ruff fixes * readd statemanager init, run pyi gen * minor performance imporovements, fix for state changes * fix pyi and pyright * pass include_children for chakra * remove old state detection * add test for unused states in stateless app --------- Co-authored-by: Masen Furer --- integration/test_tailwind.py | 3 + reflex/app.py | 152 ++++++++++++++++--------- reflex/app.pyi | 6 + reflex/components/base/bare.py | 5 +- reflex/components/component.py | 13 ++- reflex/components/el/elements/forms.py | 4 +- reflex/config.py | 9 ++ reflex/config.pyi | 2 + reflex/state.py | 16 +++ reflex/testing.py | 1 - reflex/utils/prerequisites.py | 10 +- tests/test_app.py | 9 +- 12 files changed, 163 insertions(+), 67 deletions(-) diff --git a/integration/test_tailwind.py b/integration/test_tailwind.py index d5fe764d1..e6efbcd53 100644 --- a/integration/test_tailwind.py +++ b/integration/test_tailwind.py @@ -28,6 +28,9 @@ def TailwindApp( import reflex as rx import reflex.components.radix.themes as rdxt + class UnusedState(rx.State): + pass + def index(): return rx.el.div( rx.chakra.text(paragraph_text, class_name=paragraph_class_name), diff --git a/reflex/app.py b/reflex/app.py index 638c01a4e..d2b9d868e 100644 --- a/reflex/app.py +++ b/reflex/app.py @@ -91,6 +91,12 @@ def default_overlay_component() -> Component: return Fragment.create(connection_pulser(), connection_modal()) +class OverlayFragment(Fragment): + """Alias for Fragment, used to wrap the overlay_component.""" + + pass + + class App(Base): """A Reflex application.""" @@ -159,7 +165,7 @@ class App(Base): Raises: ValueError: If the event namespace is not provided in the config. - Also, if there are multiple client subclasses of rx.State(Subclasses of rx.State should consist + Also, if there are multiple client subclasses of rx.BaseState(Subclasses of rx.BaseState should consist of the DefaultState and the client app state). """ if "connect_error_component" in kwargs: @@ -167,12 +173,12 @@ class App(Base): "`connect_error_component` is deprecated, use `overlay_component` instead" ) super().__init__(*args, **kwargs) - state_subclasses = BaseState.__subclasses__() + base_state_subclasses = BaseState.__subclasses__() # Special case to allow test cases have multiple subclasses of rx.BaseState. if not is_testing_env(): # Only one Base State class is allowed. - if len(state_subclasses) > 1: + if len(base_state_subclasses) > 1: raise ValueError( "rx.BaseState cannot be subclassed multiple times. use rx.State instead" ) @@ -184,12 +190,6 @@ class App(Base): deprecation_version="0.3.5", removal_version="0.5.0", ) - # 2 substates are built-in and not considered when determining if app is stateless. - if len(State.class_subclasses) > 2: - self.state = State - # Get the config - config = get_config() - # Add middleware. self.middleware.append(HydrateMiddleware()) @@ -198,45 +198,60 @@ class App(Base): self.add_cors() self.add_default_endpoints() - if self.state: - # Set up the state manager. - self._state_manager = StateManager.create(state=self.state) - - # Set up the Socket.IO AsyncServer. - self.sio = AsyncServer( - async_mode="asgi", - cors_allowed_origins=( - "*" - if config.cors_allowed_origins == ["*"] - else config.cors_allowed_origins - ), - cors_credentials=True, - max_http_buffer_size=constants.POLLING_MAX_HTTP_BUFFER_SIZE, - ping_interval=constants.Ping.INTERVAL, - ping_timeout=constants.Ping.TIMEOUT, - ) - - # Create the socket app. Note event endpoint constant replaces the default 'socket.io' path. - self.socket_app = ASGIApp(self.sio, socketio_path="") - namespace = config.get_event_namespace() - - if not namespace: - raise ValueError("event namespace must be provided in the config.") - - # Create the event namespace and attach the main app. Not related to any paths. - self.event_namespace = EventNamespace(namespace, self) - - # Register the event namespace with the socket. - self.sio.register_namespace(self.event_namespace) - # Mount the socket app with the API. - self.api.mount(str(constants.Endpoint.EVENT), self.socket_app) + self.setup_state() # Set up the admin dash. self.setup_admin_dash() - # If a State is not used and no overlay_component is specified, do not render the connection modal - if self.state is None and self.overlay_component is default_overlay_component: - self.overlay_component = None + def enable_state(self) -> None: + """Enable state for the app.""" + if not self.state: + self.state = State + self.setup_state() + + def setup_state(self) -> None: + """Set up the state for the app. + + Raises: + ValueError: If the event namespace is not provided in the config. + If the state has not been enabled. + """ + if not self.state: + return + + config = get_config() + + # Set up the state manager. + self._state_manager = StateManager.create(state=self.state) + + # Set up the Socket.IO AsyncServer. + self.sio = AsyncServer( + async_mode="asgi", + cors_allowed_origins=( + "*" + if config.cors_allowed_origins == ["*"] + else config.cors_allowed_origins + ), + cors_credentials=True, + max_http_buffer_size=constants.POLLING_MAX_HTTP_BUFFER_SIZE, + ping_interval=constants.Ping.INTERVAL, + ping_timeout=constants.Ping.TIMEOUT, + ) + + # Create the socket app. Note event endpoint constant replaces the default 'socket.io' path. + self.socket_app = ASGIApp(self.sio, socketio_path="") + namespace = config.get_event_namespace() + + if not namespace: + raise ValueError("event namespace must be provided in the config.") + + # Create the event namespace and attach the main app. Not related to any paths. + self.event_namespace = EventNamespace(namespace, self) + + # Register the event namespace with the socket. + self.sio.register_namespace(self.event_namespace) + # Mount the socket app with the API. + self.api.mount(str(constants.Endpoint.EVENT), self.socket_app) def __repr__(self) -> str: """Get the string representation of the app. @@ -430,21 +445,24 @@ class App(Base): # Check if the route given is valid verify_route_validity(route) - # Apply dynamic args to the route. - if self.state: - self.state.setup_dynamic_args(get_route_args(route)) + # Setup dynamic args for the route. + # this state assignment is only required for tests using the deprecated state kwarg for App + state = self.state if self.state else State + state.setup_dynamic_args(get_route_args(route)) # Generate the component if it is a callable. component = self._generate_component(component) - # Wrap the component in a fragment with optional overlay. - if self.overlay_component is not None: - component = Fragment.create( - self._generate_component(self.overlay_component), - component, - ) - else: - component = Fragment.create(component) + if self.state is None: + for var in component._get_vars(include_children=True): + if not var._var_data: + continue + if not var._var_data.state: + continue + self.enable_state() + break + + component = OverlayFragment.create(component) # Add meta information to the component. compiler_utils.add_meta( @@ -649,6 +667,28 @@ class App(Base): # By default, compile the app. return True + def _add_overlay_to_component(self, component: Component) -> Component: + if self.overlay_component is None: + return component + + children = component.children + overlay_component = self._generate_component(self.overlay_component) + + if children[0] == overlay_component: + return component + + # recreate OverlayFragment with overlay_component as first child + component = OverlayFragment.create(overlay_component, *children) + + return component + + def _setup_overlay_component(self): + """If a State is not used and no overlay_component is specified, do not render the connection modal.""" + if self.state is None and self.overlay_component is default_overlay_component: + self.overlay_component = None + for k, component in self.pages.items(): + self.pages[k] = self._add_overlay_to_component(component) + def compile(self): """compile_() is the new function for performing compilation. Reflex framework will call it automatically as needed. @@ -682,6 +722,8 @@ class App(Base): if not self._should_compile(): return + self._setup_overlay_component() + # Create a progress bar. progress = Progress( *Progress.get_default_columns()[:-1], diff --git a/reflex/app.pyi b/reflex/app.pyi index d826a5033..f71c5b013 100644 --- a/reflex/app.pyi +++ b/reflex/app.pyi @@ -64,6 +64,11 @@ Reducer = Callable[[Event], Coroutine[Any, Any, StateUpdate]] def default_overlay_component() -> Component: ... +class OverlayFragment(Fragment): + @overload + @classmethod + def create(cls, *children, **props) -> "OverlayFragment": ... # type: ignore + class App(Base): pages: Dict[str, Component] stylesheets: List[str] @@ -122,6 +127,7 @@ class App(Base): def compile(self) -> None: ... def compile_(self) -> None: ... def modify_state(self, token: str) -> AsyncContextManager[State]: ... + def _setup_overlay_component(self) -> None: ... def _process_background( self, state: State, event: Event ) -> asyncio.Task | None: ... diff --git a/reflex/components/base/bare.py b/reflex/components/base/bare.py index ee66ecd4d..24bf83b36 100644 --- a/reflex/components/base/bare.py +++ b/reflex/components/base/bare.py @@ -33,9 +33,12 @@ class Bare(Component): def _render(self) -> Tag: return Tagless(contents=str(self.contents)) - def _get_vars(self) -> Iterator[Var]: + def _get_vars(self, include_children: bool = False) -> Iterator[Var]: """Walk all Vars used in this component. + Args: + include_children: Whether to include Vars from children. + Yields: The contents if it is a Var, otherwise nothing. """ diff --git a/reflex/components/component.py b/reflex/components/component.py index cf6f67c82..250cc1894 100644 --- a/reflex/components/component.py +++ b/reflex/components/component.py @@ -812,9 +812,12 @@ class Component(BaseComponent, ABC): event_args.extend(args) yield event_trigger, event_args - def _get_vars(self) -> list[Var]: + def _get_vars(self, include_children: bool = False) -> list[Var]: """Walk all Vars used in this component. + Args: + include_children: Whether to include Vars from children. + Returns: Each var referenced by the component (props, styles, event handlers). """ @@ -860,6 +863,14 @@ class Component(BaseComponent, ABC): var = Var.create_safe(comp_prop) if var._var_data is not None: vars.append(var) + + # Get Vars associated with children. + if include_children: + for child in self.children: + if not isinstance(child, Component): + continue + vars.extend(child._get_vars(include_children=include_children)) + return vars def _get_custom_code(self) -> str | None: diff --git a/reflex/components/el/elements/forms.py b/reflex/components/el/elements/forms.py index 26baeacd9..7aa73e7b5 100644 --- a/reflex/components/el/elements/forms.py +++ b/reflex/components/el/elements/forms.py @@ -222,8 +222,8 @@ class Form(BaseHTML): )._replace(merge_var_data=ref_var._var_data) return form_refs - def _get_vars(self) -> Iterator[Var]: - yield from super()._get_vars() + def _get_vars(self, include_children: bool = True) -> Iterator[Var]: + yield from super()._get_vars(include_children=include_children) yield from self._get_form_refs().values() def _exclude_props(self) -> list[str]: diff --git a/reflex/config.py b/reflex/config.py index 9739c8ec2..cef3a76bd 100644 --- a/reflex/config.py +++ b/reflex/config.py @@ -219,6 +219,15 @@ class Config(Base): self._non_default_attributes.update(kwargs) self._replace_defaults(**kwargs) + @property + def module(self) -> str: + """Get the module name of the app. + + Returns: + The module name. + """ + return ".".join([self.app_name, self.app_name]) + @staticmethod def check_deprecated_values(**kwargs): """Check for deprecated config values. diff --git a/reflex/config.pyi b/reflex/config.pyi index 0193d6533..ce3dbcae7 100644 --- a/reflex/config.pyi +++ b/reflex/config.pyi @@ -99,6 +99,8 @@ class Config(Base): gunicorn_worker_class: Optional[str] = None, **kwargs ) -> None: ... + @property + def module(self) -> str: ... @staticmethod def check_deprecated_values(**kwargs) -> None: ... def update_from_env(self) -> None: ... diff --git a/reflex/state.py b/reflex/state.py index 20552249b..980b36c15 100644 --- a/reflex/state.py +++ b/reflex/state.py @@ -2918,3 +2918,19 @@ def code_uses_state_contexts(javascript_code: str) -> bool: True if the code attempts to access a member of StateContexts. """ return bool("useContext(StateContexts" in javascript_code) + + +def reload_state_module( + module: str, + state: Type[BaseState] = State, +) -> None: + """Reset rx.State subclasses to avoid conflict when reloading. + + Args: + module: The module to reload. + state: Recursive argument for the state class to reload. + """ + for subclass in tuple(state.class_subclasses): + reload_state_module(module=module, state=subclass) + if subclass.__module__ == module and module is not None: + state.class_subclasses.remove(subclass) diff --git a/reflex/testing.py b/reflex/testing.py index 22d95f401..9a8f40f54 100644 --- a/reflex/testing.py +++ b/reflex/testing.py @@ -232,7 +232,6 @@ class AppHarness: State.get_class_substate.cache_clear() # Ensure the AppHarness test does not skip State assignment due to running via pytest os.environ.pop(reflex.constants.PYTEST_CURRENT_TEST, None) - # self.app_module.app. self.app_module = reflex.utils.prerequisites.get_compiled_app(reload=True) self.app_instance = self.app_module.app if isinstance(self.app_instance._state_manager, StateManagerRedis): diff --git a/reflex/utils/prerequisites.py b/reflex/utils/prerequisites.py index 5808726a9..7a286ca36 100644 --- a/reflex/utils/prerequisites.py +++ b/reflex/utils/prerequisites.py @@ -185,16 +185,16 @@ def get_app(reload: bool = False) -> ModuleType: "Cannot get the app module because `app_name` is not set in rxconfig! " "If this error occurs in a reflex test case, ensure that `get_app` is mocked." ) - module = ".".join([config.app_name, config.app_name]) + module = config.module sys.path.insert(0, os.getcwd()) app = __import__(module, fromlist=(constants.CompileVars.APP,)) + if reload: - from reflex.state import State + from reflex.state import reload_state_module # Reset rx.State subclasses to avoid conflict when reloading. - for subclass in tuple(State.class_subclasses): - if subclass.__module__ == module: - State.class_subclasses.remove(subclass) + reload_state_module(module=module) + # Reload the app module. importlib.reload(app) diff --git a/tests/test_app.py b/tests/test_app.py index 0da6c11f9..0838d4a90 100644 --- a/tests/test_app.py +++ b/tests/test_app.py @@ -20,6 +20,7 @@ from reflex import AdminDash, constants from reflex.app import ( App, ComponentCallable, + OverlayFragment, default_overlay_component, process, upload, @@ -1182,12 +1183,13 @@ def test_overlay_component( exp_page_child: The type of the expected child in the page fragment. """ app = App(state=state, overlay_component=overlay_component) + app._setup_overlay_component() if exp_page_child is None: assert app.overlay_component is None - elif isinstance(exp_page_child, Fragment): + elif isinstance(exp_page_child, OverlayFragment): assert app.overlay_component is not None generated_component = app._generate_component(app.overlay_component) # type: ignore - assert isinstance(generated_component, Fragment) + assert isinstance(generated_component, OverlayFragment) assert isinstance( generated_component.children[0], Cond, # ConnectionModal is a Cond under the hood @@ -1200,7 +1202,10 @@ def test_overlay_component( ) app.add_page(rx.box("Index"), route="/test") + # overlay components are wrapped during compile only + app._setup_overlay_component() page = app.pages["test"] + if exp_page_child is not None: assert len(page.children) == 3 children_types = (type(child) for child in page.children)