From 1603144c7dd0fa328bedc74db96202796e077552 Mon Sep 17 00:00:00 2001 From: Masen Furer Date: Tue, 21 Nov 2023 11:52:06 -0800 Subject: [PATCH] [REF-889] useContext per substate (#2149) --- integration/test_var_operations.py | 21 + .../.templates/jinja/web/pages/_app.js.jinja2 | 12 +- .../jinja/web/pages/index.js.jinja2 | 26 -- .../jinja/web/utils/context.js.jinja2 | 47 +- reflex/.templates/web/utils/state.js | 59 +-- reflex/app.py | 2 +- reflex/compiler/compiler.py | 34 +- reflex/compiler/templates.py | 3 +- reflex/compiler/utils.py | 7 +- reflex/components/base/bare.py | 18 +- reflex/components/component.py | 208 ++++++++- reflex/components/datadisplay/code.py | 3 +- reflex/components/datadisplay/code.pyi | 3 +- reflex/components/datadisplay/dataeditor.py | 3 +- reflex/components/datadisplay/dataeditor.pyi | 3 +- reflex/components/datadisplay/datatable.py | 12 +- reflex/components/datadisplay/datatable.pyi | 2 +- reflex/components/datadisplay/moment.py | 4 +- reflex/components/datadisplay/moment.pyi | 2 +- reflex/components/forms/colormodeswitch.py | 4 +- reflex/components/forms/colormodeswitch.pyi | 4 +- reflex/components/forms/debounce.py | 14 +- reflex/components/forms/debounce.pyi | 3 +- reflex/components/forms/editor.py | 3 +- reflex/components/forms/editor.pyi | 3 +- reflex/components/forms/form.py | 10 +- reflex/components/forms/form.pyi | 2 +- reflex/components/forms/input.py | 4 +- reflex/components/forms/input.pyi | 2 +- reflex/components/forms/pininput.py | 4 +- reflex/components/forms/upload.py | 37 +- reflex/components/forms/upload.pyi | 3 +- reflex/components/layout/cond.py | 51 ++- reflex/components/layout/html.py | 6 +- reflex/components/layout/html.pyi | 7 +- reflex/components/libs/chakra.py | 20 +- reflex/components/libs/chakra.pyi | 2 +- .../navigation/client_side_routing.py | 15 +- .../navigation/client_side_routing.pyi | 6 +- reflex/components/overlay/banner.py | 21 +- reflex/components/overlay/banner.pyi | 7 +- reflex/components/radix/themes/base.py | 4 +- reflex/components/radix/themes/base.pyi | 2 +- reflex/components/typography/markdown.py | 3 +- reflex/components/typography/markdown.pyi | 3 +- reflex/constants/__init__.py | 4 + reflex/constants/base.py | 2 + reflex/constants/compiler.py | 25 ++ reflex/middleware/hydrate_middleware.py | 2 +- reflex/state.py | 14 +- reflex/style.py | 78 +++- reflex/utils/format.py | 25 +- reflex/utils/imports.py | 45 +- reflex/utils/types.py | 1 + reflex/vars.py | 418 +++++++++++++----- reflex/vars.pyi | 38 +- tests/compiler/test_compiler.py | 2 +- tests/components/layout/test_cond.py | 2 +- tests/components/test_component.py | 164 ++++++- tests/middleware/test_hydrate_middleware.py | 6 +- tests/test_app.py | 4 +- tests/test_state.py | 39 +- tests/test_style.py | 3 +- tests/test_var.py | 66 ++- tests/utils/test_format.py | 65 +-- 65 files changed, 1257 insertions(+), 455 deletions(-) diff --git a/integration/test_var_operations.py b/integration/test_var_operations.py index 6eac82c4c..7da9d0bd0 100644 --- a/integration/test_var_operations.py +++ b/integration/test_var_operations.py @@ -28,6 +28,7 @@ def VarOperations(): str_var4: str = "a long string" dict1: dict = {1: 2} dict2: dict = {3: 4} + html_str: str = "
hello
" app = rx.App(state=VarOperationState) @@ -522,6 +523,19 @@ def VarOperations(): rx.text(VarOperationState.str_var4.split(" ").to_string(), id="str_split"), rx.text(VarOperationState.list3.join(""), id="list_join"), rx.text(VarOperationState.list3.join(","), id="list_join_comma"), + # Index from an op var + rx.text( + VarOperationState.list3[VarOperationState.int_var1 % 3], + id="list_index_mod", + ), + rx.html( + VarOperationState.html_str, + id="html_str", + ), + rx.highlight( + "second", + query=[VarOperationState.str_var2], + ), rx.text(rx.Var.range(2, 5).join(","), id="list_join_range1"), rx.text(rx.Var.range(2, 10, 2).join(","), id="list_join_range2"), rx.text(rx.Var.range(5, 0, -1).join(","), id="list_join_range3"), @@ -713,7 +727,14 @@ def test_var_operations(driver, var_operations: AppHarness): ("dict_eq_dict", "false"), ("dict_neq_dict", "true"), ("dict_contains", "true"), + # index from an op var + ("list_index_mod", "second"), + # html component with var + ("html_str", "hello"), ] for tag, expected in tests: assert driver.find_element(By.ID, tag).text == expected + + # Highlight component with var query (does not plumb ID) + assert driver.find_element(By.TAG_NAME, "mark").text == "second" diff --git a/reflex/.templates/jinja/web/pages/_app.js.jinja2 b/reflex/.templates/jinja/web/pages/_app.js.jinja2 index 4d3dff89a..deaf1a02b 100644 --- a/reflex/.templates/jinja/web/pages/_app.js.jinja2 +++ b/reflex/.templates/jinja/web/pages/_app.js.jinja2 @@ -1,7 +1,7 @@ {% extends "web/pages/base_page.js.jinja2" %} {% block declaration %} -import { EventLoopProvider } from "/utils/context.js"; +import { EventLoopProvider, StateProvider } from "/utils/context.js"; import { ThemeProvider } from 'next-themes' {% for custom_code in custom_codes %} @@ -25,12 +25,14 @@ export default function MyApp({ Component, pageProps }) { return ( - - - + + + + + ); } -{% endblock %} \ No newline at end of file +{% endblock %} diff --git a/reflex/.templates/jinja/web/pages/index.js.jinja2 b/reflex/.templates/jinja/web/pages/index.js.jinja2 index 6f73c70c4..efb086ef5 100644 --- a/reflex/.templates/jinja/web/pages/index.js.jinja2 +++ b/reflex/.templates/jinja/web/pages/index.js.jinja2 @@ -8,32 +8,6 @@ {% block export %} export default function Component() { -{% if state_name %} - const {{state_name}} = useContext(StateContext) -{% endif %} - const {{const.router}} = useRouter() - const [ {{const.color_mode}}, {{const.toggle_color_mode}} ] = useContext(ColorModeContext) - const focusRef = useRef(); - - // Main event loop. - const [addEvents, connectError] = useContext(EventLoopContext) - - // Set focus to the specified element. - useEffect(() => { - if (focusRef.current) { - focusRef.current.focus(); - } - }) - - // Route after the initial page hydration. - useEffect(() => { - const change_complete = () => addEvents(initialEvents()) - {{const.router}}.events.on('routeChangeComplete', change_complete) - return () => { - {{const.router}}.events.off('routeChangeComplete', change_complete) - } - }, [{{const.router}}]) - {% for hook in hooks %} {{ hook }} {% endfor %} diff --git a/reflex/.templates/jinja/web/utils/context.js.jinja2 b/reflex/.templates/jinja/web/utils/context.js.jinja2 index c931b7515..53d7d4e58 100644 --- a/reflex/.templates/jinja/web/utils/context.js.jinja2 +++ b/reflex/.templates/jinja/web/utils/context.js.jinja2 @@ -1,5 +1,5 @@ -import { createContext, useState } from "react" -import { Event, hydrateClientStorage, useEventLoop } from "/utils/state.js" +import { createContext, useContext, useMemo, useReducer, useState } from "react" +import { applyDelta, Event, hydrateClientStorage, useEventLoop } from "/utils/state.js" {% if initial_state %} export const initialState = {{ initial_state|json_dumps }} @@ -8,7 +8,12 @@ export const initialState = {} {% endif %} export const ColorModeContext = createContext(null); -export const StateContext = createContext(null); +export const DispatchContext = createContext(null); +export const StateContexts = { + {% for state_name in initial_state %} + {{state_name|var_name}}: createContext(null), + {% endfor %} +} export const EventLoopContext = createContext(null); {% if client_storage %} export const clientStorage = {{ client_storage|json_dumps }} @@ -27,16 +32,40 @@ export const initialEvents = () => [] export const isDevMode = {{ is_dev_mode|json_dumps }} export function EventLoopProvider({ children }) { - const [state, addEvents, connectError] = useEventLoop( - initialState, + const dispatch = useContext(DispatchContext) + const [addEvents, connectError] = useEventLoop( + dispatch, initialEvents, clientStorage, ) return ( - - {children} - + {children} ) -} \ No newline at end of file +} + +export function StateProvider({ children }) { + {% for state_name in initial_state %} + const [{{state_name|var_name}}, dispatch_{{state_name|var_name}}] = useReducer(applyDelta, initialState["{{state_name}}"]) + {% endfor %} + const dispatchers = useMemo(() => { + return { + {% for state_name in initial_state %} + "{{state_name}}": dispatch_{{state_name|var_name}}, + {% endfor %} + } + }, []) + + return ( + {% for state_name in initial_state %} + + {% endfor %} + + {children} + + {% for state_name in initial_state|reverse %} + + {% endfor %} + ) +} diff --git a/reflex/.templates/web/utils/state.js b/reflex/.templates/web/utils/state.js index 22c853dea..9ae46d3c9 100644 --- a/reflex/.templates/web/utils/state.js +++ b/reflex/.templates/web/utils/state.js @@ -6,7 +6,7 @@ import env from "env.json"; import Cookies from "universal-cookie"; import { useEffect, useReducer, useRef, useState } from "react"; import Router, { useRouter } from "next/router"; -import { initialEvents } from "utils/context.js" +import { initialEvents, initialState } from "utils/context.js" // Endpoint URLs. const EVENTURL = env.EVENT @@ -100,37 +100,10 @@ export const getEventURL = () => { * @param delta The delta to apply. */ export const applyDelta = (state, delta) => { - const new_state = { ...state } - for (const substate in delta) { - let s = new_state; - const path = substate.split(".").slice(1); - while (path.length > 0) { - s = s[path.shift()]; - } - for (const key in delta[substate]) { - s[key] = delta[substate][key]; - } - } - return new_state + return { ...state, ...delta } }; -/** - * Get all local storage items in a key-value object. - * @returns object of items in local storage. - */ -export const getAllLocalStorageItems = () => { - var localStorageItems = {}; - - for (var i = 0, len = localStorage.length; i < len; i++) { - var key = localStorage.key(i); - localStorageItems[key] = localStorage.getItem(key); - } - - return localStorageItems; -} - - /** * Handle frontend event or send the event to the backend via Websocket. * @param event The event to send. @@ -346,7 +319,9 @@ export const connect = async ( // On each received message, queue the updates and events. socket.current.on("event", message => { const update = JSON5.parse(message) - dispatch(update.delta) + for (const substate in update.delta) { + dispatch[substate](update.delta[substate]) + } applyClientStorageDelta(client_storage, update.delta) event_processing = !update.final if (update.events) { @@ -524,23 +499,21 @@ const applyClientStorageDelta = (client_storage, delta) => { /** * Establish websocket event loop for a NextJS page. - * @param initial_state The initial app state. - * @param initial_events Function that returns the initial app events. + * @param dispatch The reducer dispatch function to update state. + * @param initial_events The initial app events. * @param client_storage The client storage object from context.js * - * @returns [state, addEvents, connectError] - - * state is a reactive dict, + * @returns [addEvents, connectError] - * addEvents is used to queue an event, and * connectError is a reactive js error from the websocket connection (or null if connected). */ export const useEventLoop = ( - initial_state = {}, + dispatch, initial_events = () => [], client_storage = {}, ) => { const socket = useRef(null) const router = useRouter() - const [state, dispatch] = useReducer(applyDelta, initial_state) const [connectError, setConnectError] = useState(null) // Function to add new events to the event queue. @@ -570,7 +543,7 @@ export const useEventLoop = ( return; } // only use websockets if state is present - if (Object.keys(state).length > 0) { + if (Object.keys(initialState).length > 0) { // Initialize the websocket connection. if (!socket.current) { connect(socket, dispatch, ['websocket', 'polling'], setConnectError, client_storage) @@ -583,7 +556,17 @@ export const useEventLoop = ( })() } }) - return [state, addEvents, connectError] + + // Route after the initial page hydration. + useEffect(() => { + const change_complete = () => addEvents(initial_events()) + router.events.on('routeChangeComplete', change_complete) + return () => { + router.events.off('routeChangeComplete', change_complete) + } + }, [router]) + + return [addEvents, connectError] } /*** diff --git a/reflex/app.py b/reflex/app.py index 6248bcec0..01851d48d 100644 --- a/reflex/app.py +++ b/reflex/app.py @@ -63,7 +63,7 @@ from reflex.state import ( StateUpdate, ) from reflex.utils import console, format, prerequisites, types -from reflex.vars import ImportVar +from reflex.utils.imports import ImportVar # Define custom types. ComponentCallable = Callable[[], Component] diff --git a/reflex/compiler/compiler.py b/reflex/compiler/compiler.py index 4ffd4a876..3b9d7db3e 100644 --- a/reflex/compiler/compiler.py +++ b/reflex/compiler/compiler.py @@ -10,40 +10,10 @@ from reflex.compiler import templates, utils from reflex.components.component import Component, ComponentStyle, CustomComponent from reflex.config import get_config from reflex.state import State -from reflex.utils import imports -from reflex.vars import ImportVar +from reflex.utils.imports import ImportDict, ImportVar # Imports to be included in every Reflex app. -DEFAULT_IMPORTS: imports.ImportDict = { - "react": [ - ImportVar(tag="Fragment"), - ImportVar(tag="useEffect"), - ImportVar(tag="useRef"), - ImportVar(tag="useState"), - ImportVar(tag="useContext"), - ], - "next/router": [ImportVar(tag="useRouter")], - f"/{constants.Dirs.STATE_PATH}": [ - ImportVar(tag="uploadFiles"), - ImportVar(tag="Event"), - ImportVar(tag="isTrue"), - ImportVar(tag="spreadArraysOrObjects"), - ImportVar(tag="preventDefault"), - ImportVar(tag="refs"), - ImportVar(tag="getRefValue"), - ImportVar(tag="getRefValues"), - ImportVar(tag="getAllLocalStorageItems"), - ImportVar(tag="useEventLoop"), - ], - "/utils/context.js": [ - ImportVar(tag="EventLoopContext"), - ImportVar(tag="initialEvents"), - ImportVar(tag="StateContext"), - ImportVar(tag="ColorModeContext"), - ], - "/utils/helpers/range.js": [ - ImportVar(tag="range", is_default=True), - ], +DEFAULT_IMPORTS: ImportDict = { "": [ImportVar(tag="focus-visible/dist/focus-visible", install=False)], } diff --git a/reflex/compiler/templates.py b/reflex/compiler/templates.py index f2d1272aa..57bbb44b8 100644 --- a/reflex/compiler/templates.py +++ b/reflex/compiler/templates.py @@ -3,7 +3,7 @@ from jinja2 import Environment, FileSystemLoader, Template from reflex import constants -from reflex.utils.format import json_dumps +from reflex.utils.format import format_state_name, json_dumps class ReflexJinjaEnvironment(Environment): @@ -19,6 +19,7 @@ class ReflexJinjaEnvironment(Environment): ) self.filters["json_dumps"] = json_dumps self.filters["react_setter"] = lambda state: f"set{state.capitalize()}" + self.filters["var_name"] = format_state_name self.loader = FileSystemLoader(constants.Templates.Dirs.JINJA_TEMPLATE) self.globals["const"] = { "socket": constants.CompileVars.SOCKET, diff --git a/reflex/compiler/utils.py b/reflex/compiler/utils.py index b6967eaef..10a4691a4 100644 --- a/reflex/compiler/utils.py +++ b/reflex/compiler/utils.py @@ -24,13 +24,12 @@ from reflex.components.component import Component, ComponentStyle, CustomCompone from reflex.state import Cookie, LocalStorage, State from reflex.style import Style from reflex.utils import console, format, imports, path_ops -from reflex.vars import ImportVar # To re-export this function. merge_imports = imports.merge_imports -def compile_import_statement(fields: list[ImportVar]) -> tuple[str, list[str]]: +def compile_import_statement(fields: list[imports.ImportVar]) -> tuple[str, list[str]]: """Compile an import statement. Args: @@ -343,7 +342,9 @@ def get_context_path() -> str: Returns: The path of the context module. """ - return os.path.join(constants.Dirs.WEB_UTILS, "context" + constants.Ext.JS) + return os.path.join( + constants.Dirs.WEB, constants.Dirs.CONTEXTS_PATH + constants.Ext.JS + ) def get_components_path() -> str: diff --git a/reflex/components/base/bare.py b/reflex/components/base/bare.py index 190e95e6a..ee66ecd4d 100644 --- a/reflex/components/base/bare.py +++ b/reflex/components/base/bare.py @@ -1,7 +1,7 @@ """A bare component.""" from __future__ import annotations -from typing import Any +from typing import Any, Iterator from reflex.components.component import Component from reflex.components.tags import Tag @@ -24,7 +24,21 @@ class Bare(Component): Returns: The component. """ - return cls(contents=str(contents)) # type: ignore + if isinstance(contents, Var) and contents._var_data: + contents = contents.to(str) + else: + contents = str(contents) + return cls(contents=contents) # type: ignore def _render(self) -> Tag: return Tagless(contents=str(self.contents)) + + def _get_vars(self) -> Iterator[Var]: + """Walk all Vars used in this component. + + Yields: + The contents if it is a Var, otherwise nothing. + """ + if isinstance(self.contents, Var): + # Fast path for Bare text components. + yield self.contents diff --git a/reflex/components/component.py b/reflex/components/component.py index d274fbc37..70097cf60 100644 --- a/reflex/components/component.py +++ b/reflex/components/component.py @@ -5,11 +5,11 @@ from __future__ import annotations import typing from abc import ABC from functools import lru_cache, wraps -from typing import Any, Callable, Dict, List, Optional, Set, Type, Union +from typing import Any, Callable, Dict, Iterator, List, Optional, Set, Type, Union from reflex.base import Base from reflex.components.tags import Tag -from reflex.constants import Dirs, EventTriggers +from reflex.constants import Dirs, EventTriggers, Hooks, Imports from reflex.event import ( EventChain, EventHandler, @@ -20,8 +20,9 @@ from reflex.event import ( ) from reflex.style import Style from reflex.utils import console, format, imports, types +from reflex.utils.imports import ImportVar from reflex.utils.serializers import serializer -from reflex.vars import BaseVar, ImportVar, Var +from reflex.vars import BaseVar, Var class Component(Base, ABC): @@ -388,7 +389,11 @@ class Component(Base, ABC): props = props.copy() props.update( - self.event_triggers, + **{ + trigger: handler + for trigger, handler in self.event_triggers.items() + if trigger not in {EventTriggers.ON_MOUNT, EventTriggers.ON_UNMOUNT} + }, key=self.key, id=self.id, class_name=self.class_name, @@ -488,7 +493,7 @@ class Component(Base, ABC): """ if type(self) in style: # Extract the style for this component. - component_style = Style(style[type(self)]) + component_style = style[type(self)] # Only add style props that are not overridden. component_style = { @@ -564,6 +569,78 @@ class Component(Base, ABC): if self._valid_children: validate_valid_child(name) + @staticmethod + def _get_vars_from_event_triggers( + event_triggers: dict[str, EventChain | Var], + ) -> Iterator[tuple[str, list[Var]]]: + """Get the Vars associated with each event trigger. + + Args: + event_triggers: The event triggers from the component instance. + + Yields: + tuple of (event_name, event_vars) + """ + for event_trigger, event in event_triggers.items(): + if isinstance(event, Var): + yield event_trigger, [event] + elif isinstance(event, EventChain): + event_args = [] + for spec in event.events: + for args in spec.args: + event_args.extend(args) + yield event_trigger, event_args + + def _get_vars(self) -> list[Var]: + """Walk all Vars used in this component. + + Returns: + Each var referenced by the component (props, styles, event handlers). + """ + vars = getattr(self, "__vars", None) + if vars is not None: + return vars + vars = self.__vars = [] + # Get Vars associated with event trigger arguments. + for _, event_vars in self._get_vars_from_event_triggers(self.event_triggers): + vars.extend(event_vars) + + # Get Vars associated with component props. + for prop in self.get_props(): + prop_var = getattr(self, prop) + if isinstance(prop_var, Var): + vars.append(prop_var) + + # Style keeps track of its own VarData instance, so embed in a temp Var that is yielded. + if self.style: + vars.append( + BaseVar( + _var_name="style", + _var_type=str, + _var_data=self.style._var_data, + ) + ) + + # Special props are always Var instances. + vars.extend(self.special_props) + + # Get Vars associated with common Component props. + for comp_prop in ( + self.class_name, + self.id, + self.key, + self.autofocus, + *self.custom_attrs.values(), + ): + if isinstance(comp_prop, Var): + vars.append(comp_prop) + elif isinstance(comp_prop, str): + # Collapse VarData encoded in f-strings. + var = Var.create_safe(comp_prop) + if var._var_data is not None: + vars.append(var) + return vars + def _get_custom_code(self) -> str | None: """Get custom code for the component. @@ -644,6 +721,33 @@ class Component(Base, ABC): dep: [ImportVar(tag=None, render=False)] for dep in self.lib_dependencies } + def _get_hooks_imports(self) -> imports.ImportDict: + """Get the imports required by certain hooks. + + Returns: + The imports required for all selected hooks. + """ + _imports = {} + + if self._get_ref_hook(): + # Handle hooks needed for attaching react refs to DOM nodes. + _imports.setdefault("react", set()).add(ImportVar(tag="useRef")) + _imports.setdefault(f"/{Dirs.STATE_PATH}", set()).add(ImportVar(tag="refs")) + + if self._get_mount_lifecycle_hook(): + # Handle hooks for `on_mount` / `on_unmount`. + _imports.setdefault("react", set()).add(ImportVar(tag="useEffect")) + + if self._get_special_hooks(): + # Handle additional internal hooks (autofocus, etc). + _imports.setdefault("react", set()).update( + { + ImportVar(tag="useRef"), + ImportVar(tag="useEffect"), + }, + ) + return _imports + def _get_imports(self) -> imports.ImportDict: """Get all the libraries and fields that are used by the component. @@ -651,13 +755,26 @@ class Component(Base, ABC): The imports needed by the component. """ _imports = {} + + # Import this component's tag from the main library. if self.library is not None and self.tag is not None: _imports[self.library] = {self.import_var} + # Get static imports required for event processing. + event_imports = Imports.EVENTS if self.event_triggers else {} + + # Collect imports from Vars used directly by this component. + var_imports = [ + var._var_data.imports for var in self._get_vars() if var._var_data + ] + return imports.merge_imports( *self._get_props_imports(), self._get_dependencies_imports(), + self._get_hooks_imports(), _imports, + event_imports, + *var_imports, ) def get_imports(self) -> imports.ImportDict: @@ -678,13 +795,13 @@ class Component(Base, ABC): """ # pop on_mount and on_unmount from event_triggers since these are handled by # hooks, not as actually props in the component - on_mount = self.event_triggers.pop(EventTriggers.ON_MOUNT, None) - on_unmount = self.event_triggers.pop(EventTriggers.ON_UNMOUNT, None) - if on_mount: + on_mount = self.event_triggers.get(EventTriggers.ON_MOUNT, None) + on_unmount = self.event_triggers.get(EventTriggers.ON_UNMOUNT, None) + if on_mount is not None: on_mount = format.format_event_chain(on_mount) - if on_unmount: + if on_unmount is not None: on_unmount = format.format_event_chain(on_unmount) - if on_mount or on_unmount: + if on_mount is not None or on_unmount is not None: return f""" useEffect(() => {{ {on_mount or ""} @@ -703,6 +820,47 @@ class Component(Base, ABC): if ref is not None: return f"const {ref} = useRef(null); refs['{ref}'] = {ref};" + def _get_vars_hooks(self) -> set[str]: + """Get the hooks required by vars referenced in this component. + + Returns: + The hooks for the vars. + """ + vars_hooks = set() + for var in self._get_vars(): + if var._var_data: + vars_hooks.update(var._var_data.hooks) + return vars_hooks + + def _get_events_hooks(self) -> set[str]: + """Get the hooks required by events referenced in this component. + + Returns: + The hooks for the events. + """ + if self.event_triggers: + return {Hooks.EVENTS} + return set() + + def _get_special_hooks(self) -> set[str]: + """Get the hooks required by special actions referenced in this component. + + Returns: + The hooks for special actions. + """ + if self.autofocus: + return { + """ + // Set focus to the specified element. + const focusRef = useRef(null) + useEffect(() => { + if (focusRef.current) { + focusRef.current.focus(); + } + })""", + } + return set() + def _get_hooks_internal(self) -> Set[str]: """Get the React hooks for this component managed by the framework. @@ -712,10 +870,15 @@ class Component(Base, ABC): Returns: Set of internally managed hooks. """ - return set( - hook - for hook in [self._get_mount_lifecycle_hook(), self._get_ref_hook()] - if hook + return ( + set( + hook + for hook in [self._get_mount_lifecycle_hook(), self._get_ref_hook()] + if hook + ) + | self._get_vars_hooks() + | self._get_events_hooks() + | self._get_special_hooks() ) def _get_hooks(self) -> str | None: @@ -1018,11 +1181,24 @@ class NoSSRComponent(Component): """A dynamic component that is not rendered on the server.""" def _get_imports(self) -> imports.ImportDict: - dynamic_import = {"next/dynamic": {ImportVar(tag="dynamic", is_default=True)}} + """Get the imports for the component. + + Returns: + The imports for dynamically importing the component at module load time. + """ + # Next.js dynamic import mechanism. + dynamic_import = {"next/dynamic": [ImportVar(tag="dynamic", is_default=True)]} + + # The normal imports for this component. + _imports = super()._get_imports() + + # Do NOT import the main library/tag statically. + if self.library is not None: + _imports[self.library] = [imports.ImportVar(tag=None, render=False)] return imports.merge_imports( dynamic_import, - {self.library: {ImportVar(tag=None, render=False)}}, + _imports, self._get_dependencies_imports(), ) diff --git a/reflex/components/datadisplay/code.py b/reflex/components/datadisplay/code.py index c56eff8d0..0c796cd08 100644 --- a/reflex/components/datadisplay/code.py +++ b/reflex/components/datadisplay/code.py @@ -12,7 +12,8 @@ from reflex.components.media import Icon from reflex.event import set_clipboard from reflex.style import Style from reflex.utils import format, imports -from reflex.vars import ImportVar, Var +from reflex.utils.imports import ImportVar +from reflex.vars import Var LiteralCodeBlockTheme = Literal[ "a11y-dark", diff --git a/reflex/components/datadisplay/code.pyi b/reflex/components/datadisplay/code.pyi index 79e756ed0..5a88faeb3 100644 --- a/reflex/components/datadisplay/code.pyi +++ b/reflex/components/datadisplay/code.pyi @@ -16,7 +16,8 @@ from reflex.components.media import Icon from reflex.event import set_clipboard from reflex.style import Style from reflex.utils import format, imports -from reflex.vars import ImportVar, Var +from reflex.utils.imports import ImportVar +from reflex.vars import Var LiteralCodeBlockTheme = Literal[ "a11y-dark", diff --git a/reflex/components/datadisplay/dataeditor.py b/reflex/components/datadisplay/dataeditor.py index e666d284c..f4b327aa6 100644 --- a/reflex/components/datadisplay/dataeditor.py +++ b/reflex/components/datadisplay/dataeditor.py @@ -8,8 +8,9 @@ from reflex.base import Base from reflex.components.component import Component, NoSSRComponent from reflex.components.literals import LiteralRowMarker from reflex.utils import console, format, imports, types +from reflex.utils.imports import ImportVar from reflex.utils.serializers import serializer -from reflex.vars import ImportVar, Var, get_unique_variable_name +from reflex.vars import Var, get_unique_variable_name # TODO: Fix the serialization issue for custom types. diff --git a/reflex/components/datadisplay/dataeditor.pyi b/reflex/components/datadisplay/dataeditor.pyi index c951e5c77..af209f422 100644 --- a/reflex/components/datadisplay/dataeditor.pyi +++ b/reflex/components/datadisplay/dataeditor.pyi @@ -13,8 +13,9 @@ from reflex.base import Base from reflex.components.component import Component, NoSSRComponent from reflex.components.literals import LiteralRowMarker from reflex.utils import console, format, imports, types +from reflex.utils.imports import ImportVar from reflex.utils.serializers import serializer -from reflex.vars import ImportVar, Var, get_unique_variable_name +from reflex.vars import Var, get_unique_variable_name class GridColumnIcons(Enum): Array = "array" diff --git a/reflex/components/datadisplay/datatable.py b/reflex/components/datadisplay/datatable.py index 52bd45282..5b66fa54b 100644 --- a/reflex/components/datadisplay/datatable.py +++ b/reflex/components/datadisplay/datatable.py @@ -8,7 +8,7 @@ from reflex.components.component import Component from reflex.components.tags import Tag from reflex.utils import imports, types from reflex.utils.serializers import serialize, serializer -from reflex.vars import BaseVar, ComputedVar, ImportVar, Var +from reflex.vars import BaseVar, ComputedVar, Var class Gridjs(Component): @@ -105,7 +105,7 @@ class DataTable(Gridjs): def _get_imports(self) -> imports.ImportDict: return imports.merge_imports( super()._get_imports(), - {"": {ImportVar(tag="gridjs/dist/theme/mermaid.css")}}, + {"": {imports.ImportVar(tag="gridjs/dist/theme/mermaid.css")}}, ) def _render(self) -> Tag: @@ -113,13 +113,13 @@ class DataTable(Gridjs): self.columns = BaseVar( _var_name=f"{self.data._var_name}.columns", _var_type=List[Any], - _var_state=self.data._var_state, - ) + _var_full_name_needs_state_prefix=True, + )._replace(merge_var_data=self.data._var_data) self.data = BaseVar( _var_name=f"{self.data._var_name}.data", _var_type=List[List[Any]], - _var_state=self.data._var_state, - ) + _var_full_name_needs_state_prefix=True, + )._replace(merge_var_data=self.data._var_data) if types.is_dataframe(type(self.data)): # If given a pandas df break up the data and columns data = serialize(self.data) diff --git a/reflex/components/datadisplay/datatable.pyi b/reflex/components/datadisplay/datatable.pyi index 1a119dcc9..49e3ad752 100644 --- a/reflex/components/datadisplay/datatable.pyi +++ b/reflex/components/datadisplay/datatable.pyi @@ -12,7 +12,7 @@ from reflex.components.component import Component from reflex.components.tags import Tag from reflex.utils import imports, types from reflex.utils.serializers import serialize, serializer -from reflex.vars import BaseVar, ComputedVar, ImportVar, Var +from reflex.vars import BaseVar, ComputedVar, Var class Gridjs(Component): @overload diff --git a/reflex/components/datadisplay/moment.py b/reflex/components/datadisplay/moment.py index 9ae8381bc..31ffb5ffa 100644 --- a/reflex/components/datadisplay/moment.py +++ b/reflex/components/datadisplay/moment.py @@ -4,7 +4,7 @@ from typing import Any, Dict, List from reflex.components.component import Component, NoSSRComponent from reflex.utils import imports -from reflex.vars import ImportVar, Var +from reflex.vars import Var class Moment(NoSSRComponent): @@ -78,7 +78,7 @@ class Moment(NoSSRComponent): if self.tz is not None: merged_imports = imports.merge_imports( merged_imports, - {"moment-timezone": {ImportVar(tag="")}}, + {"moment-timezone": {imports.ImportVar(tag="")}}, ) return merged_imports diff --git a/reflex/components/datadisplay/moment.pyi b/reflex/components/datadisplay/moment.pyi index 0e9fcc4c4..2c7696037 100644 --- a/reflex/components/datadisplay/moment.pyi +++ b/reflex/components/datadisplay/moment.pyi @@ -10,7 +10,7 @@ from reflex.style import Style from typing import Any, Dict, List from reflex.components.component import Component, NoSSRComponent from reflex.utils import imports -from reflex.vars import ImportVar, Var +from reflex.vars import Var class Moment(NoSSRComponent): def get_event_triggers(self) -> Dict[str, Any]: ... diff --git a/reflex/components/forms/colormodeswitch.py b/reflex/components/forms/colormodeswitch.py index 0aeefcd33..ca071a2a7 100644 --- a/reflex/components/forms/colormodeswitch.py +++ b/reflex/components/forms/colormodeswitch.py @@ -22,7 +22,7 @@ from reflex.components.component import Component from reflex.components.layout.cond import Cond, cond from reflex.components.media.icon import Icon from reflex.style import color_mode, toggle_color_mode -from reflex.vars import BaseVar +from reflex.vars import Var from .button import Button from .switch import Switch @@ -32,7 +32,7 @@ DEFAULT_LIGHT_ICON: Icon = Icon.create(tag="sun") DEFAULT_DARK_ICON: Icon = Icon.create(tag="moon") -def color_mode_cond(light: Any, dark: Any = None) -> BaseVar | Component: +def color_mode_cond(light: Any, dark: Any = None) -> Var | Component: """Create a component or Prop based on color_mode. Args: diff --git a/reflex/components/forms/colormodeswitch.pyi b/reflex/components/forms/colormodeswitch.pyi index 478bf5c37..83af8f20c 100644 --- a/reflex/components/forms/colormodeswitch.pyi +++ b/reflex/components/forms/colormodeswitch.pyi @@ -12,7 +12,7 @@ from reflex.components.component import Component from reflex.components.layout.cond import Cond, cond from reflex.components.media.icon import Icon from reflex.style import color_mode, toggle_color_mode -from reflex.vars import BaseVar +from reflex.vars import Var from .button import Button from .switch import Switch @@ -20,7 +20,7 @@ DEFAULT_COLOR_MODE: str DEFAULT_LIGHT_ICON: Icon DEFAULT_DARK_ICON: Icon -def color_mode_cond(light: Any, dark: Any = None) -> BaseVar | Component: ... +def color_mode_cond(light: Any, dark: Any = None) -> Var | Component: ... class ColorModeIcon(Cond): @overload diff --git a/reflex/components/forms/debounce.py b/reflex/components/forms/debounce.py index da318a1cb..1618d29ba 100644 --- a/reflex/components/forms/debounce.py +++ b/reflex/components/forms/debounce.py @@ -1,10 +1,11 @@ """Wrapper around react-debounce-input.""" from __future__ import annotations -from typing import Any +from typing import Any, Set from reflex.components import Component from reflex.components.tags import Tag +from reflex.utils import imports from reflex.vars import Var @@ -77,6 +78,17 @@ class DebounceInput(Component): object.__setattr__(child, "render", lambda: "") return tag + def _get_imports(self) -> imports.ImportDict: + return imports.merge_imports( + super()._get_imports(), *[c._get_imports() for c in self.children] + ) + + def _get_hooks_internal(self) -> Set[str]: + hooks = super()._get_hooks_internal() + for child in self.children: + hooks.update(child._get_hooks_internal()) + return hooks + def props_not_none(c: Component) -> dict[str, Any]: """Get all properties of the component that are not None. diff --git a/reflex/components/forms/debounce.pyi b/reflex/components/forms/debounce.pyi index 8c2688f94..975aa8be2 100644 --- a/reflex/components/forms/debounce.pyi +++ b/reflex/components/forms/debounce.pyi @@ -7,9 +7,10 @@ from typing import Any, Dict, Literal, Optional, Union, overload from reflex.vars import Var, BaseVar, ComputedVar from reflex.event import EventChain, EventHandler, EventSpec from reflex.style import Style -from typing import Any +from typing import Any, Set from reflex.components import Component from reflex.components.tags import Tag +from reflex.utils import imports from reflex.vars import Var class DebounceInput(Component): diff --git a/reflex/components/forms/editor.py b/reflex/components/forms/editor.py index b7bfb08a5..92a1e80c3 100644 --- a/reflex/components/forms/editor.py +++ b/reflex/components/forms/editor.py @@ -8,7 +8,8 @@ from reflex.base import Base from reflex.components.component import Component, NoSSRComponent from reflex.constants import EventTriggers from reflex.utils.format import to_camel_case -from reflex.vars import ImportVar, Var +from reflex.utils.imports import ImportVar +from reflex.vars import Var class EditorButtonList(list, enum.Enum): diff --git a/reflex/components/forms/editor.pyi b/reflex/components/forms/editor.pyi index 2b806fdef..1eaeea038 100644 --- a/reflex/components/forms/editor.pyi +++ b/reflex/components/forms/editor.pyi @@ -13,7 +13,8 @@ from reflex.base import Base from reflex.components.component import Component, NoSSRComponent from reflex.constants import EventTriggers from reflex.utils.format import to_camel_case -from reflex.vars import ImportVar, Var +from reflex.utils.imports import ImportVar +from reflex.vars import Var class EditorButtonList(list, enum.Enum): BASIC = [["font", "fontSize"], ["fontColor"], ["horizontalRule"], ["link", "image"]] diff --git a/reflex/components/forms/form.py b/reflex/components/forms/form.py index ee793da2f..301fec8e0 100644 --- a/reflex/components/forms/form.py +++ b/reflex/components/forms/form.py @@ -8,7 +8,7 @@ from jinja2 import Environment from reflex.components.component import Component from reflex.components.libs.chakra import ChakraComponent from reflex.components.tags import Tag -from reflex.constants import EventTriggers +from reflex.constants import Dirs, EventTriggers from reflex.event import EventChain from reflex.utils import imports from reflex.utils.format import format_event_chain, to_camel_case @@ -65,7 +65,13 @@ class Form(ChakraComponent): def _get_imports(self) -> imports.ImportDict: return imports.merge_imports( super()._get_imports(), - {"react": {imports.ImportVar(tag="useCallback")}}, + { + "react": {imports.ImportVar(tag="useCallback")}, + f"/{Dirs.STATE_PATH}": { + imports.ImportVar(tag="getRefValue"), + imports.ImportVar(tag="getRefValues"), + }, + }, ) def _get_hooks(self) -> str | None: diff --git a/reflex/components/forms/form.pyi b/reflex/components/forms/form.pyi index 827b2273d..e1a9c325a 100644 --- a/reflex/components/forms/form.pyi +++ b/reflex/components/forms/form.pyi @@ -12,7 +12,7 @@ from jinja2 import Environment from reflex.components.component import Component from reflex.components.libs.chakra import ChakraComponent from reflex.components.tags import Tag -from reflex.constants import EventTriggers +from reflex.constants import Dirs, EventTriggers from reflex.event import EventChain from reflex.utils import imports from reflex.utils.format import format_event_chain, to_camel_case diff --git a/reflex/components/forms/input.py b/reflex/components/forms/input.py index b4af66b02..7a2f5e4c8 100644 --- a/reflex/components/forms/input.py +++ b/reflex/components/forms/input.py @@ -11,7 +11,7 @@ from reflex.components.libs.chakra import ( ) from reflex.constants import EventTriggers from reflex.utils import imports -from reflex.vars import ImportVar, Var +from reflex.vars import Var class Input(ChakraComponent): @@ -61,7 +61,7 @@ class Input(ChakraComponent): def _get_imports(self) -> imports.ImportDict: return imports.merge_imports( super()._get_imports(), - {"/utils/state": {ImportVar(tag="set_val")}}, + {"/utils/state": {imports.ImportVar(tag="set_val")}}, ) def get_event_triggers(self) -> Dict[str, Any]: diff --git a/reflex/components/forms/input.pyi b/reflex/components/forms/input.pyi index 6ec79fe1f..d49a6a6df 100644 --- a/reflex/components/forms/input.pyi +++ b/reflex/components/forms/input.pyi @@ -17,7 +17,7 @@ from reflex.components.libs.chakra import ( ) from reflex.constants import EventTriggers from reflex.utils import imports -from reflex.vars import ImportVar, Var +from reflex.vars import Var class Input(ChakraComponent): def get_event_triggers(self) -> Dict[str, Any]: ... diff --git a/reflex/components/forms/pininput.py b/reflex/components/forms/pininput.py index c090e1571..83323c25c 100644 --- a/reflex/components/forms/pininput.py +++ b/reflex/components/forms/pininput.py @@ -68,9 +68,11 @@ class PinInput(ChakraComponent): Returns: The merged import dict. """ + range_var = Var.range(0) return merge_imports( super()._get_imports(), PinInputField().get_imports(), # type: ignore + range_var._var_data.imports if range_var._var_data is not None else {}, ) def get_event_triggers(self) -> dict[str, Union[Var, Any]]: @@ -117,7 +119,7 @@ class PinInput(ChakraComponent): ) refs_declaration._var_is_local = True if ref: - return f"const {ref} = {refs_declaration}" + return f"const {ref} = {str(refs_declaration)}" return super()._get_ref_hook() def _render(self) -> Tag: diff --git a/reflex/components/forms/upload.py b/reflex/components/forms/upload.py index bd01d01c6..b79e0b106 100644 --- a/reflex/components/forms/upload.py +++ b/reflex/components/forms/upload.py @@ -7,9 +7,10 @@ from reflex import constants from reflex.components.component import Component from reflex.components.forms.input import Input from reflex.components.layout.box import Box +from reflex.constants import Dirs from reflex.event import CallableEventSpec, EventChain, EventSpec, call_script from reflex.utils import imports -from reflex.vars import BaseVar, CallableVar, ImportVar, Var +from reflex.vars import BaseVar, CallableVar, Var, VarData DEFAULT_UPLOAD_ID: str = "default" @@ -30,6 +31,13 @@ def upload_file(id_: str = DEFAULT_UPLOAD_ID) -> BaseVar: return BaseVar( _var_name=f"e => upload_files.{id_}[1]((files) => e)", _var_type=EventChain, + _var_data=VarData( # type: ignore + imports={ + f"/{Dirs.STATE_PATH}": { + imports.ImportVar(tag="upload_files"), + }, + }, + ), ) @@ -46,6 +54,13 @@ def selected_files(id_: str = DEFAULT_UPLOAD_ID) -> BaseVar: return BaseVar( _var_name=f"(upload_files.{id_} ? upload_files.{id_}[0]?.map((f) => (f.path || f.name)) : [])", _var_type=List[str], + _var_data=VarData( # type: ignore + imports={ + f"/{Dirs.STATE_PATH}": { + imports.ImportVar(tag="upload_files"), + }, + }, + ), ) @@ -166,14 +181,16 @@ class Upload(Component): def _get_hooks(self) -> str | None: return ( - (super()._get_hooks() or "") - + f""" - upload_files.{self.id or DEFAULT_UPLOAD_ID} = useState([]); - """ - ) + super()._get_hooks() or "" + ) + f"upload_files.{self.id or DEFAULT_UPLOAD_ID} = useState([]);" def _get_imports(self) -> imports.ImportDict: - return { - **super()._get_imports(), - f"/{constants.Dirs.STATE_PATH}": [ImportVar(tag="upload_files")], - } + return imports.merge_imports( + super()._get_imports(), + { + "react": {imports.ImportVar(tag="useState")}, + f"/{constants.Dirs.STATE_PATH}": [ + imports.ImportVar(tag="upload_files") + ], + }, + ) diff --git a/reflex/components/forms/upload.pyi b/reflex/components/forms/upload.pyi index 87e5d58c7..5486ee90d 100644 --- a/reflex/components/forms/upload.pyi +++ b/reflex/components/forms/upload.pyi @@ -12,9 +12,10 @@ from reflex import constants from reflex.components.component import Component from reflex.components.forms.input import Input from reflex.components.layout.box import Box +from reflex.constants import Dirs from reflex.event import CallableEventSpec, EventChain, EventSpec, call_script from reflex.utils import imports -from reflex.vars import BaseVar, CallableVar, ImportVar, Var +from reflex.vars import BaseVar, CallableVar, Var, VarData DEFAULT_UPLOAD_ID: str diff --git a/reflex/components/layout/cond.py b/reflex/components/layout/cond.py index 2455433d7..37b86fcc2 100644 --- a/reflex/components/layout/cond.py +++ b/reflex/components/layout/cond.py @@ -1,13 +1,18 @@ """Create a list of components from an iterable.""" from __future__ import annotations -from typing import Any, Dict, Optional +from typing import Any, Dict, Optional, overload from reflex.components.component import Component from reflex.components.layout.fragment import Fragment from reflex.components.tags import CondTag, Tag -from reflex.utils import format -from reflex.vars import Var +from reflex.constants import Dirs +from reflex.utils import format, imports +from reflex.vars import BaseVar, Var, VarData + +_IS_TRUE_IMPORT = { + f"/{Dirs.STATE_PATH}": {imports.ImportVar(tag="isTrue")}, +} class Cond(Component): @@ -88,6 +93,28 @@ class Cond(Component): cond_state=f"isTrue({self.cond._var_full_name})", ) + def _get_imports(self) -> imports.ImportDict: + return imports.merge_imports( + super()._get_imports(), + getattr(self.cond._var_data, "imports", {}), + _IS_TRUE_IMPORT, + ) + + +@overload +def cond(condition: Any, c1: Component, c2: Any) -> Component: + ... + + +@overload +def cond(condition: Any, c1: Component) -> Component: + ... + + +@overload +def cond(condition: Any, c1: Any, c2: Any) -> Var: + ... + def cond(condition: Any, c1: Any, c2: Any = None): """Create a conditional component or Prop. @@ -103,8 +130,11 @@ def cond(condition: Any, c1: Any, c2: Any = None): Raises: ValueError: If the arguments are invalid. """ - # Import here to avoid circular imports. - from reflex.vars import BaseVar, Var + var_datas: list[VarData | None] = [ + VarData( # type: ignore + imports=_IS_TRUE_IMPORT, + ), + ] # Convert the condition to a Var. cond_var = Var.create(condition) @@ -116,16 +146,20 @@ def cond(condition: Any, c1: Any, c2: Any = None): c2, Component ), "Both arguments must be components." return Cond.create(cond_var, c1, c2) + if isinstance(c1, Var): + var_datas.append(c1._var_data) - # Otherwise, create a conditionl Var. + # Otherwise, create a conditional Var. # Check that the second argument is valid. if isinstance(c2, Component): raise ValueError("Both arguments must be props.") if c2 is None: raise ValueError("For conditional vars, the second argument must be set.") + if isinstance(c2, Var): + var_datas.append(c2._var_data) # Create the conditional var. - return BaseVar( + return cond_var._replace( _var_name=format.format_cond( cond=cond_var._var_full_name, true_value=c1, @@ -133,4 +167,7 @@ def cond(condition: Any, c1: Any, c2: Any = None): is_prop=True, ), _var_type=c1._var_type if isinstance(c1, BaseVar) else type(c1), + _var_is_local=False, + _var_full_name_needs_state_prefix=False, + merge_var_data=VarData.merge(*var_datas), ) diff --git a/reflex/components/layout/html.py b/reflex/components/layout/html.py index 893e2ecca..3a4ba76ad 100644 --- a/reflex/components/layout/html.py +++ b/reflex/components/layout/html.py @@ -1,8 +1,8 @@ """A html component.""" - -from typing import Any +from typing import Dict from reflex.components.layout.box import Box +from reflex.vars import Var class Html(Box): @@ -13,7 +13,7 @@ class Html(Box): """ # The HTML to render. - dangerouslySetInnerHTML: Any + dangerouslySetInnerHTML: Var[Dict[str, str]] @classmethod def create(cls, *children, **props): diff --git a/reflex/components/layout/html.pyi b/reflex/components/layout/html.pyi index ec513ca29..aec46e2f9 100644 --- a/reflex/components/layout/html.pyi +++ b/reflex/components/layout/html.pyi @@ -7,8 +7,9 @@ from typing import Any, Dict, Literal, Optional, Union, overload from reflex.vars import Var, BaseVar, ComputedVar from reflex.event import EventChain, EventHandler, EventSpec from reflex.style import Style -from typing import Any +from typing import Dict from reflex.components.layout.box import Box +from reflex.vars import Var class Html(Box): @overload @@ -16,7 +17,9 @@ class Html(Box): def create( # type: ignore cls, *children, - dangerouslySetInnerHTML: Optional[Any] = None, + dangerouslySetInnerHTML: Optional[ + Union[Var[Dict[str, str]], Dict[str, str]] + ] = None, element: Optional[Union[Var[str], str]] = None, src: Optional[Union[Var[str], str]] = None, alt: Optional[Union[Var[str], str]] = None, diff --git a/reflex/components/libs/chakra.py b/reflex/components/libs/chakra.py index fdab62fd4..34bfff1d5 100644 --- a/reflex/components/libs/chakra.py +++ b/reflex/components/libs/chakra.py @@ -6,7 +6,7 @@ from typing import List, Literal from reflex.components.component import Component from reflex.utils import imports -from reflex.vars import ImportVar, Var +from reflex.vars import Var class ChakraComponent(Component): @@ -34,7 +34,7 @@ class ChakraComponent(Component): The dependencies imports of the component. """ return { - dep: [ImportVar(tag=None, render=False)] + dep: [imports.ImportVar(tag=None, render=False)] for dep in [ "@chakra-ui/system@2.5.7", "framer-motion@10.16.4", @@ -75,17 +75,17 @@ class ChakraProvider(ChakraComponent): ) def _get_imports(self) -> imports.ImportDict: - imports = super()._get_imports() - imports.setdefault(self.__fields__["library"].default, []).append( - ImportVar(tag="extendTheme", is_default=False), + _imports = super()._get_imports() + _imports.setdefault(self.__fields__["library"].default, []).append( + imports.ImportVar(tag="extendTheme", is_default=False), ) - imports.setdefault("/utils/theme.js", []).append( - ImportVar(tag="theme", is_default=True), + _imports.setdefault("/utils/theme.js", []).append( + imports.ImportVar(tag="theme", is_default=True), ) - imports.setdefault(Global.__fields__["library"].default, []).append( - ImportVar(tag="css", is_default=False), + _imports.setdefault(Global.__fields__["library"].default, []).append( + imports.ImportVar(tag="css", is_default=False), ) - return imports + return _imports def _get_custom_code(self) -> str | None: return """ diff --git a/reflex/components/libs/chakra.pyi b/reflex/components/libs/chakra.pyi index 967dcda35..d410a1c9f 100644 --- a/reflex/components/libs/chakra.pyi +++ b/reflex/components/libs/chakra.pyi @@ -11,7 +11,7 @@ from functools import lru_cache from typing import List, Literal from reflex.components.component import Component from reflex.utils import imports -from reflex.vars import ImportVar, Var +from reflex.vars import Var class ChakraComponent(Component): @overload diff --git a/reflex/components/navigation/client_side_routing.py b/reflex/components/navigation/client_side_routing.py index 22620e54b..99f0f6cfd 100644 --- a/reflex/components/navigation/client_side_routing.py +++ b/reflex/components/navigation/client_side_routing.py @@ -10,10 +10,9 @@ routeNotFound becomes true. from __future__ import annotations from reflex import constants - -from ...vars import Var -from ..component import Component -from ..layout.cond import Cond +from reflex.components.component import Component +from reflex.components.layout.cond import cond +from reflex.vars import Var route_not_found: Var = Var.create_safe(constants.ROUTE_NOT_FOUND) @@ -52,10 +51,10 @@ def wait_for_client_redirect(component) -> Component: Returns: The conditionally rendered component. """ - return Cond.create( - cond=route_not_found, - comp1=component, - comp2=ClientSideRouting.create(), + return cond( + condition=route_not_found, + c1=component, + c2=ClientSideRouting.create(), ) diff --git a/reflex/components/navigation/client_side_routing.pyi b/reflex/components/navigation/client_side_routing.pyi index 100d0adb2..b7801246e 100644 --- a/reflex/components/navigation/client_side_routing.pyi +++ b/reflex/components/navigation/client_side_routing.pyi @@ -8,9 +8,9 @@ from reflex.vars import Var, BaseVar, ComputedVar from reflex.event import EventChain, EventHandler, EventSpec from reflex.style import Style from reflex import constants -from ...vars import Var -from ..component import Component -from ..layout.cond import Cond +from reflex.components.component import Component +from reflex.components.layout.cond import cond +from reflex.vars import Var route_not_found: Var diff --git a/reflex/components/overlay/banner.py b/reflex/components/overlay/banner.py index cdb01c063..d690f3a9c 100644 --- a/reflex/components/overlay/banner.py +++ b/reflex/components/overlay/banner.py @@ -5,22 +5,27 @@ from typing import Optional from reflex.components.base.bare import Bare from reflex.components.component import Component -from reflex.components.layout import Box, Cond +from reflex.components.layout import Box, cond from reflex.components.overlay.modal import Modal from reflex.components.typography import Text +from reflex.constants import Hooks, Imports from reflex.utils import imports -from reflex.vars import ImportVar, Var +from reflex.vars import Var, VarData + +connect_error_var_data: VarData = VarData( # type: ignore + imports=Imports.EVENTS, + hooks={Hooks.EVENTS}, +) connection_error: Var = Var.create_safe( value="(connectError !== null) ? connectError.message : ''", _var_is_local=False, _var_is_string=False, -) +)._replace(merge_var_data=connect_error_var_data) has_connection_error: Var = Var.create_safe( value="connectError !== null", _var_is_string=False, -) -has_connection_error._var_type = bool +)._replace(_var_type=bool, merge_var_data=connect_error_var_data) class WebsocketTargetURL(Bare): @@ -28,7 +33,7 @@ class WebsocketTargetURL(Bare): def _get_imports(self) -> imports.ImportDict: return { - "/utils/state.js": [ImportVar(tag="getEventURL")], + "/utils/state.js": [imports.ImportVar(tag="getEventURL")], } @classmethod @@ -78,7 +83,7 @@ class ConnectionBanner(Component): textAlign="center", ) - return Cond.create(has_connection_error, comp) + return cond(has_connection_error, comp) class ConnectionModal(Component): @@ -96,7 +101,7 @@ class ConnectionModal(Component): """ if not comp: comp = Text.create(*default_connection_error()) - return Cond.create( + return cond( has_connection_error, Modal.create( header="Connection Error", diff --git a/reflex/components/overlay/banner.pyi b/reflex/components/overlay/banner.pyi index 4e855ae1d..db3cafd04 100644 --- a/reflex/components/overlay/banner.pyi +++ b/reflex/components/overlay/banner.pyi @@ -10,15 +10,16 @@ from reflex.style import Style from typing import Optional from reflex.components.base.bare import Bare from reflex.components.component import Component -from reflex.components.layout import Box, Cond +from reflex.components.layout import Box, cond from reflex.components.overlay.modal import Modal from reflex.components.typography import Text +from reflex.constants import Hooks, Imports from reflex.utils import imports -from reflex.vars import ImportVar, Var +from reflex.vars import Var, VarData +connect_error_var_data: VarData connection_error: Var has_connection_error: Var -has_connection_error._var_type = bool class WebsocketTargetURL(Bare): @overload diff --git a/reflex/components/radix/themes/base.py b/reflex/components/radix/themes/base.py index 3e589faca..c75a426a4 100644 --- a/reflex/components/radix/themes/base.py +++ b/reflex/components/radix/themes/base.py @@ -6,7 +6,7 @@ from typing import Literal from reflex.components import Component from reflex.utils import imports -from reflex.vars import ImportVar, Var +from reflex.vars import Var LiteralAlign = Literal["start", "center", "end", "baseline", "stretch"] LiteralJustify = Literal["start", "center", "end", "between"] @@ -147,7 +147,7 @@ class Theme(RadixThemesComponent): def _get_imports(self) -> imports.ImportDict: return { **super()._get_imports(), - "": [ImportVar(tag="@radix-ui/themes/styles.css", install=False)], + "": [imports.ImportVar(tag="@radix-ui/themes/styles.css", install=False)], } diff --git a/reflex/components/radix/themes/base.pyi b/reflex/components/radix/themes/base.pyi index eb8b8cb30..9f840f2a8 100644 --- a/reflex/components/radix/themes/base.pyi +++ b/reflex/components/radix/themes/base.pyi @@ -10,7 +10,7 @@ from reflex.style import Style from typing import Literal from reflex.components import Component from reflex.utils import imports -from reflex.vars import ImportVar, Var +from reflex.vars import Var LiteralAlign = Literal["start", "center", "end", "baseline", "stretch"] LiteralJustify = Literal["start", "center", "end", "between"] diff --git a/reflex/components/typography/markdown.py b/reflex/components/typography/markdown.py index f414f1781..f8841a75d 100644 --- a/reflex/components/typography/markdown.py +++ b/reflex/components/typography/markdown.py @@ -14,7 +14,8 @@ from reflex.components.typography.heading import Heading from reflex.components.typography.text import Text from reflex.style import Style from reflex.utils import console, imports, types -from reflex.vars import ImportVar, Var +from reflex.utils.imports import ImportVar +from reflex.vars import Var # Special vars used in the component map. _CHILDREN = Var.create_safe("children", _var_is_local=False) diff --git a/reflex/components/typography/markdown.pyi b/reflex/components/typography/markdown.pyi index ee11bf255..cb9140435 100644 --- a/reflex/components/typography/markdown.pyi +++ b/reflex/components/typography/markdown.pyi @@ -18,7 +18,8 @@ from reflex.components.typography.heading import Heading from reflex.components.typography.text import Text from reflex.style import Style from reflex.utils import console, imports, types -from reflex.vars import ImportVar, Var +from reflex.utils.imports import ImportVar +from reflex.vars import Var _CHILDREN = Var.create_safe("children", _var_is_local=False) _PROPS = Var.create_safe("...props", _var_is_local=False) diff --git a/reflex/constants/__init__.py b/reflex/constants/__init__.py index 628f7511b..ca408c6a6 100644 --- a/reflex/constants/__init__.py +++ b/reflex/constants/__init__.py @@ -22,6 +22,8 @@ from .compiler import ( CompileVars, ComponentName, Ext, + Hooks, + Imports, PageNames, ) from .config import ( @@ -68,7 +70,9 @@ __ALL__ = [ Ext, Fnm, GitIgnore, + Hooks, RequirementsTxt, + Imports, IS_WINDOWS, LOCAL_STORAGE, LogLevel, diff --git a/reflex/constants/base.py b/reflex/constants/base.py index 8957edfb7..0b28e18cb 100644 --- a/reflex/constants/base.py +++ b/reflex/constants/base.py @@ -29,6 +29,8 @@ class Dirs(SimpleNamespace): STATE_PATH = "/".join([UTILS, "state"]) # The name of the components file. COMPONENTS_PATH = "/".join([UTILS, "components"]) + # The name of the contexts file. + CONTEXTS_PATH = "/".join([UTILS, "context"]) # The directory where the app pages are compiled to. WEB_PAGES = os.path.join(WEB, "pages") # The directory where the static build is located. diff --git a/reflex/constants/compiler.py b/reflex/constants/compiler.py index 4a9d09d4c..e309c5d4a 100644 --- a/reflex/constants/compiler.py +++ b/reflex/constants/compiler.py @@ -2,6 +2,9 @@ from enum import Enum from types import SimpleNamespace +from reflex.constants import Dirs +from reflex.utils.imports import ImportVar + # The prefix used to create setters for state vars. SETTER_PREFIX = "set_" @@ -47,6 +50,12 @@ class CompileVars(SimpleNamespace): HYDRATE = "hydrate" # The name of the is_hydrated variable. IS_HYDRATED = "is_hydrated" + # The name of the function to add events to the queue. + ADD_EVENTS = "addEvents" + # The name of the var storing any connection error. + CONNECT_ERROR = "connectError" + # The name of the function for converting a dict to an event. + TO_EVENT = "Event" class PageNames(SimpleNamespace): @@ -77,3 +86,19 @@ class ComponentName(Enum): The lower-case filename with zip extension. """ return self.value.lower() + Ext.ZIP + + +class Imports(SimpleNamespace): + """Common sets of import vars.""" + + EVENTS = { + "react": {ImportVar(tag="useContext")}, + f"/{Dirs.CONTEXTS_PATH}": {ImportVar(tag="EventLoopContext")}, + f"/{Dirs.STATE_PATH}": {ImportVar(tag=CompileVars.TO_EVENT)}, + } + + +class Hooks(SimpleNamespace): + """Common sets of hook declarations.""" + + EVENTS = f"const [{CompileVars.ADD_EVENTS}, {CompileVars.CONNECT_ERROR}] = useContext(EventLoopContext);" diff --git a/reflex/middleware/hydrate_middleware.py b/reflex/middleware/hydrate_middleware.py index 3992919a2..38d5fb14f 100644 --- a/reflex/middleware/hydrate_middleware.py +++ b/reflex/middleware/hydrate_middleware.py @@ -48,7 +48,7 @@ class HydrateMiddleware(Middleware): setattr(var_state, var_name, value) # Get the initial state. - delta = format.format_state({state.get_name(): state.dict()}) + delta = format.format_state(state.dict()) # since a full dict was captured, clean any dirtiness state._clean() diff --git a/reflex/state.py b/reflex/state.py index 11bb5e9a4..424ead87c 100644 --- a/reflex/state.py +++ b/reflex/state.py @@ -1211,12 +1211,16 @@ class State(Base, ABC, extra=pydantic.Extra.allow): if include_computed else {} ) - substate_vars = { - k: v.dict(include_computed=include_computed, **kwargs) - for k, v in self.substates.items() + variables = {**base_vars, **computed_vars} + d = { + self.get_full_name(): {k: variables[k] for k in sorted(variables)}, } - variables = {**base_vars, **computed_vars, **substate_vars} - return {k: variables[k] for k in sorted(variables)} + for substate_d in [ + v.dict(include_computed=include_computed, **kwargs) + for v in self.substates.values() + ]: + d.update(substate_d) + return d async def __aenter__(self) -> State: """Enter the async context manager protocol. diff --git a/reflex/style.py b/reflex/style.py index bb8320163..d67029992 100644 --- a/reflex/style.py +++ b/reflex/style.py @@ -2,13 +2,38 @@ from __future__ import annotations +from typing import Any + from reflex import constants from reflex.event import EventChain from reflex.utils import format -from reflex.vars import BaseVar, Var +from reflex.utils.imports import ImportVar +from reflex.vars import BaseVar, Var, VarData -color_mode = BaseVar(_var_name=constants.ColorMode.NAME, _var_type="str") -toggle_color_mode = BaseVar(_var_name=constants.ColorMode.TOGGLE, _var_type=EventChain) +VarData.update_forward_refs() # Ensure all type definitions are resolved + +# Reference the global ColorModeContext +color_mode_var_data = VarData( # type: ignore + imports={ + f"/{constants.Dirs.CONTEXTS_PATH}": {ImportVar(tag="ColorModeContext")}, + "react": {ImportVar(tag="useContext")}, + }, + hooks={ + f"const [ {constants.ColorMode.NAME}, {constants.ColorMode.TOGGLE} ] = useContext(ColorModeContext)", + }, +) +# Var resolves to the current color mode for the app ("light" or "dark") +color_mode = BaseVar( + _var_name=constants.ColorMode.NAME, + _var_type="str", + _var_data=color_mode_var_data, +) +# Var resolves to a function invocation that toggles the color mode +toggle_color_mode = BaseVar( + _var_name=constants.ColorMode.TOGGLE, + _var_type=EventChain, + _var_data=color_mode_var_data, +) def convert(style_dict): @@ -20,16 +45,27 @@ def convert(style_dict): Returns: The formatted style dictionary. """ + var_data = None # Track import/hook data from any Vars in the style dict. out = {} for key, value in style_dict.items(): key = format.to_camel_case(key) + new_var_data = None if isinstance(value, dict): - out[key] = convert(value) + # Recursively format nested style dictionaries. + out[key], new_var_data = convert(value) elif isinstance(value, Var): + # If the value is a Var, extract the var_data and cast as str. + new_var_data = value._var_data out[key] = str(value) else: + # Otherwise, convert to Var to collapse VarData encoded in f-string. + new_var = Var.create(value) + if new_var is not None: + new_var_data = new_var._var_data out[key] = value - return out + # Combine all the collected VarData instances. + var_data = VarData.merge(var_data, new_var_data) + return out, var_data class Style(dict): @@ -41,5 +77,33 @@ class Style(dict): Args: style_dict: The style dictionary. """ - style_dict = style_dict or {} - super().__init__(convert(style_dict)) + style_dict, self._var_data = convert(style_dict or {}) + super().__init__(style_dict) + + def update(self, style_dict: dict | None, **kwargs): + """Update the style. + + Args: + style_dict: The style dictionary. + kwargs: Other key value pairs to apply to the dict update. + """ + if kwargs: + style_dict = {**(style_dict or {}), **kwargs} + converted_dict = type(self)(style_dict) + # Combine our VarData with that of any Vars in the style_dict that was passed. + self._var_data = VarData.merge(self._var_data, converted_dict._var_data) + super().update(converted_dict) + + def __setitem__(self, key: str, value: Any): + """Set an item in the style. + + Args: + key: The key to set. + value: The value to set. + """ + # Create a Var to collapse VarData encoded in f-string. + _var = Var.create(value) + if _var is not None: + # Carry the imports/hooks when setting a Var as a value. + self._var_data = VarData.merge(self._var_data, _var._var_data) + super().__setitem__(key, value) diff --git a/reflex/utils/format.py b/reflex/utils/format.py index 03ee5c9c7..4050706ea 100644 --- a/reflex/utils/format.py +++ b/reflex/utils/format.py @@ -232,9 +232,9 @@ def format_route(route: str, format_case=True) -> str: def format_cond( - cond: str, - true_value: str, - false_value: str = '""', + cond: str | Var, + true_value: str | Var, + false_value: str | Var = '""', is_prop=False, ) -> str: """Format a conditional expression. @@ -248,9 +248,6 @@ def format_cond( Returns: The formatted conditional expression. """ - # Import here to avoid circular imports. - from reflex.vars import Var - # Use Python truthiness. cond = f"isTrue({cond})" @@ -266,6 +263,7 @@ def format_cond( _var_is_string=type(false_value) is str, ) prop2._var_is_local = True + prop1, prop2 = str(prop1), str(prop2) # avoid f-string semantics for Var return f"{cond} ? {prop1} : {prop2}".replace("{", "").replace("}", "") # Format component conds. @@ -517,6 +515,21 @@ def format_state(value: Any) -> Any: raise TypeError(f"No JSON serializer found for var {value} of type {type(value)}.") +def format_state_name(state_name: str) -> str: + """Format a state name, replacing dots with double underscore. + + This allows individual substates to be accessed independently as javascript vars + without using dot notation. + + Args: + state_name: The state name to format. + + Returns: + The formatted state name. + """ + return state_name.replace(".", "__") + + def format_ref(ref: str) -> str: """Format a ref. diff --git a/reflex/utils/imports.py b/reflex/utils/imports.py index f20ec141d..26faa4820 100644 --- a/reflex/utils/imports.py +++ b/reflex/utils/imports.py @@ -3,11 +3,9 @@ from __future__ import annotations from collections import defaultdict -from typing import Dict, List +from typing import Dict, List, Optional -from reflex.vars import ImportVar - -ImportDict = Dict[str, List[ImportVar]] +from reflex.base import Base def merge_imports(*imports) -> ImportDict: @@ -24,3 +22,42 @@ def merge_imports(*imports) -> ImportDict: for lib, fields in import_dict.items(): all_imports[lib].extend(fields) return all_imports + + +class ImportVar(Base): + """An import var.""" + + # The name of the import tag. + tag: Optional[str] + + # whether the import is default or named. + is_default: Optional[bool] = False + + # The tag alias. + alias: Optional[str] = None + + # Whether this import need to install the associated lib + install: Optional[bool] = True + + # whether this import should be rendered or not + render: Optional[bool] = True + + @property + def name(self) -> str: + """The name of the import. + + Returns: + The name(tag name with alias) of tag. + """ + return self.tag if not self.alias else " as ".join([self.tag, self.alias]) # type: ignore + + def __hash__(self) -> int: + """Define a hash function for the import var. + + Returns: + The hash of the var. + """ + return hash((self.tag, self.is_default, self.alias, self.install, self.render)) + + +ImportDict = Dict[str, List[ImportVar]] diff --git a/reflex/utils/types.py b/reflex/utils/types.py index c114e94ac..806684954 100644 --- a/reflex/utils/types.py +++ b/reflex/utils/types.py @@ -27,6 +27,7 @@ from reflex.utils import serializers GenericType = Union[Type, _GenericAlias] # Valid state var types. +JSONType = {str, int, float, bool} PrimitiveType = Union[int, float, bool, str, list, dict, set, tuple] StateVar = Union[PrimitiveType, Base, None] StateIterVar = Union[list, set, tuple] diff --git a/reflex/vars.py b/reflex/vars.py index 468dc2a24..d4d785b7a 100644 --- a/reflex/vars.py +++ b/reflex/vars.py @@ -7,6 +7,7 @@ import dis import inspect import json import random +import re import string import sys from types import CodeType, FunctionType @@ -15,9 +16,11 @@ from typing import ( Any, Callable, Dict, + Iterable, List, Literal, Optional, + Set, Tuple, Type, Union, @@ -30,7 +33,10 @@ from typing import ( from reflex import constants from reflex.base import Base -from reflex.utils import console, format, serializers, types +from reflex.utils import console, format, imports, serializers, types + +# This module used to export ImportVar itself, so we still import it for export here +from reflex.utils.imports import ImportDict, ImportVar if TYPE_CHECKING: from reflex.state import State @@ -71,7 +77,7 @@ OPERATION_MAPPING = { REPLACED_NAMES = { "full_name": "_var_full_name", "name": "_var_name", - "state": "_var_state", + "state": "_var_data.state", "type_": "_var_type", "is_local": "_var_is_local", "is_string": "_var_is_string", @@ -93,6 +99,131 @@ def get_unique_variable_name() -> str: return get_unique_variable_name() +class VarData(Base): + """Metadata associated with a Var.""" + + # The name of the enclosing state. + state: str = "" + + # Imports needed to render this var + imports: ImportDict = {} + + # Hooks that need to be present in the component to render this var + hooks: Set[str] = set() + + @classmethod + def merge(cls, *others: VarData | None) -> VarData | None: + """Merge multiple var data objects. + + Args: + *others: The var data objects to merge. + + Returns: + The merged var data object. + """ + state = "" + _imports = {} + hooks = set() + for var_data in others: + if var_data is None: + continue + state = state or var_data.state + _imports = imports.merge_imports(_imports, var_data.imports) + hooks.update(var_data.hooks) + return ( + cls( + state=state, + imports=_imports, + hooks=hooks, + ) + or None + ) + + def __bool__(self) -> bool: + """Check if the var data is non-empty. + + Returns: + True if any field is set to a non-default value. + """ + return bool(self.state or self.imports or self.hooks) + + def dict(self) -> dict: + """Convert the var data to a dictionary. + + Returns: + The var data dictionary. + """ + return { + "state": self.state, + "imports": { + lib: [import_var.dict() for import_var in import_vars] + for lib, import_vars in self.imports.items() + }, + "hooks": list(self.hooks), + } + + +def _encode_var(value: Var) -> str: + """Encode the state name into a formatted var. + + Args: + value: The value to encode the state name into. + + Returns: + The encoded var. + """ + if value._var_data: + return f"{value._var_data.json()}" + str(value) + return str(value) + + +def _decode_var(value: str) -> tuple[VarData | None, str]: + """Decode the state name from a formatted var. + + Args: + value: The value to extract the state name from. + + Returns: + The extracted state name and the value without the state name. + """ + var_datas = [] + if isinstance(value, str): + # Extract the state name from a formatted var + while m := re.match(r"(.*)(.*)(.*)", value): + value = m.group(1) + m.group(3) + var_datas.append(VarData.parse_raw(m.group(2))) + if var_datas: + return VarData.merge(*var_datas), value + return None, value + + +def _extract_var_data(value: Iterable) -> list[VarData | None]: + """Extract the var imports and hooks from an iterable containing a Var. + + Args: + value: The iterable to extract the VarData from + + Returns: + The extracted VarDatas. + """ + var_datas = [] + with contextlib.suppress(TypeError): + for sub in value: + if isinstance(sub, Var): + var_datas.append(sub._var_data) + elif not isinstance(sub, str): + # Recurse into dict values. + if hasattr(sub, "values") and callable(sub.values): + var_datas.extend(_extract_var_data(sub.values())) + # Recurse into iterable values (or dict keys). + var_datas.extend(_extract_var_data(sub)) + # Recurse when value is a dict itself. + values = getattr(value, "values", None) + if callable(values): + var_datas.extend(_extract_var_data(values())) + return var_datas + + class Var: """An abstract var.""" @@ -102,15 +233,18 @@ class Var: # The type of the var. _var_type: Type - # The name of the enclosing state. - _var_state: str - # Whether this is a local javascript variable. _var_is_local: bool # Whether the var is a string literal. _var_is_string: bool + # _var_full_name should be prefixed with _var_state + _var_full_name_needs_state_prefix: bool + + # Extra metadata associated with the Var + _var_data: Optional[VarData] + @classmethod def create( cls, value: Any, _var_is_local: bool = True, _var_is_string: bool = False @@ -136,9 +270,14 @@ class Var: if isinstance(value, Var): return value + # Try to pull the imports and hooks from contained values. + _var_data = None + if not isinstance(value, str): + _var_data = VarData.merge(*_extract_var_data(value)) + # Try to serialize the value. type_ = type(value) - name = serializers.serialize(value) + name = value if type_ in types.JSONType else serializers.serialize(value) if name is None: raise TypeError( f"No JSON serializer found for var {value} of type {type_}." @@ -150,6 +289,7 @@ class Var: _var_type=type_, _var_is_local=_var_is_local, _var_is_string=_var_is_string, + _var_data=_var_data, ) @classmethod @@ -186,6 +326,39 @@ class Var: """ return _GenericAlias(cls, type_) + def __post_init__(self) -> None: + """Post-initialize the var.""" + # Decode any inline Var markup and apply it to the instance + _var_data, _var_name = _decode_var(self._var_name) + if _var_data: + self._var_name = _var_name + self._var_data = VarData.merge(self._var_data, _var_data) + + def _replace(self, merge_var_data=None, **kwargs: Any) -> Var: + """Make a copy of this Var with updated fields. + + Args: + merge_var_data: VarData to merge into the existing VarData. + **kwargs: Var fields to update. + + Returns: + A new BaseVar with the updated fields overwriting the corresponding fields in this Var. + """ + field_values = dict( + _var_name=kwargs.pop("_var_name", self._var_name), + _var_type=kwargs.pop("_var_type", self._var_type), + _var_is_local=kwargs.pop("_var_is_local", self._var_is_local), + _var_is_string=kwargs.pop("_var_is_string", self._var_is_string), + _var_full_name_needs_state_prefix=kwargs.pop( + "_var_full_name_needs_state_prefix", + self._var_full_name_needs_state_prefix, + ), + _var_data=VarData.merge( + kwargs.get("_var_data", self._var_data), merge_var_data + ), + ) + return BaseVar(**field_values) + def _decode(self) -> Any: """Decode Var as a python value. @@ -195,8 +368,6 @@ class Var: Returns: The decoded value or the Var name. """ - if self._var_state: - return self._var_full_name if self._var_is_string: return self._var_name try: @@ -216,8 +387,10 @@ class Var: return ( self._var_name == other._var_name and self._var_type == other._var_type - and self._var_state == other._var_state and self._var_is_local == other._var_is_local + and self._var_full_name_needs_state_prefix + == other._var_full_name_needs_state_prefix + and self._var_data == other._var_data ) def to_string(self, json: bool = True) -> Var: @@ -285,9 +458,11 @@ class Var: Returns: The formatted var. """ + # Encode the _var_data into the formatted output for tracking purposes. + str_self = _encode_var(self) if self._var_is_local: - return str(self) - return f"${str(self)}" + return str_self + return f"${str_self}" def __getitem__(self, i: Any) -> Var: """Index into a var. @@ -320,12 +495,7 @@ class Var: # Convert any vars to local vars. if isinstance(i, Var): - i = BaseVar( - _var_name=i._var_name, - _var_type=i._var_type, - _var_state=i._var_state, - _var_is_local=True, - ) + i = i._replace(_var_is_local=True) # Handle list/tuple/str indexing. if types._issubclass(self._var_type, Union[List, Tuple, str]): @@ -344,11 +514,9 @@ class Var: stop = i.stop or "undefined" # Use the slice function. - return BaseVar( + return self._replace( _var_name=f"{self._var_name}.slice({start}, {stop})", - _var_type=self._var_type, - _var_state=self._var_state, - _var_is_local=self._var_is_local, + _var_is_string=False, ) # Get the type of the indexed var. @@ -359,11 +527,10 @@ class Var: ) # Use `at` to support negative indices. - return BaseVar( + return self._replace( _var_name=f"{self._var_name}.at({i})", _var_type=type_, - _var_state=self._var_state, - _var_is_local=self._var_is_local, + _var_is_string=False, ) # Dictionary / dataframe indexing. @@ -393,11 +560,10 @@ class Var: ) # Use normal indexing here. - return BaseVar( + return self._replace( _var_name=f"{self._var_name}[{i}]", _var_type=type_, - _var_state=self._var_state, - _var_is_local=self._var_is_local, + _var_is_string=False, ) def __getattr__(self, name: str) -> Var: @@ -423,11 +589,10 @@ class Var: type_ = types.get_attribute_access_type(self._var_type, name) if type_ is not None: - return BaseVar( + return self._replace( _var_name=f"{self._var_name}{'?' if is_optional else ''}.{name}", _var_type=type_, - _var_state=self._var_state, - _var_is_local=self._var_is_local, + _var_is_string=False, ) if name in REPLACED_NAMES: @@ -519,10 +684,12 @@ class Var: else f"{self._var_full_name}.{fn}()" ) - return BaseVar( + return self._replace( _var_name=operation_name, _var_type=type_, - _var_is_local=self._var_is_local, + _var_is_string=False, + _var_full_name_needs_state_prefix=False, + merge_var_data=other._var_data if other is not None else None, ) @staticmethod @@ -602,10 +769,10 @@ class Var: """ if not types._issubclass(self._var_type, List): raise TypeError(f"Cannot get length of non-list var {self}.") - return BaseVar( - _var_name=f"{self._var_full_name}.length", + return self._replace( + _var_name=f"{self._var_name}.length", _var_type=int, - _var_is_local=self._var_is_local, + _var_is_string=False, ) def __eq__(self, other: Var) -> Var: @@ -692,7 +859,17 @@ class Var: types.get_base_class(self._var_type) == list and types.get_base_class(other_type) == list ): - return self.operation(",", other, fn="spreadArraysOrObjects", flip=flip) + return self.operation( + ",", other, fn="spreadArraysOrObjects", flip=flip + )._replace( + merge_var_data=VarData( + imports={ + f"/{constants.Dirs.STATE_PATH}": [ + ImportVar(tag="spreadArraysOrObjects") + ] + }, + ), + ) return self.operation("+", other, flip=flip) def __radd__(self, other: Var) -> Var: @@ -755,10 +932,11 @@ class Var: ]: other_name = other._var_full_name if isinstance(other, Var) else other name = f"Array({other_name}).fill().map(() => {self._var_full_name}).flat()" - return BaseVar( + return self._replace( _var_name=name, _var_type=str, - _var_is_local=self._var_is_local, + _var_is_string=False, + _var_full_name_needs_state_prefix=False, ) return self.operation("*", other) @@ -1003,10 +1181,11 @@ class Var: elif not isinstance(other, Var): other = Var.create(other) if types._issubclass(self._var_type, Dict): - return BaseVar( - _var_name=f"{self._var_full_name}.{method}({other._var_full_name})", + return self._replace( + _var_name=f"{self._var_name}.{method}({other._var_full_name})", _var_type=bool, - _var_is_local=self._var_is_local, + _var_is_string=False, + merge_var_data=other._var_data, ) else: # str, list, tuple # For strings, the left operand must be a string. @@ -1016,10 +1195,11 @@ class Var: raise TypeError( f"'in ' requires string as left operand, not {other._var_type}" ) - return BaseVar( - _var_name=f"{self._var_full_name}.includes({other._var_full_name})", + return self._replace( + _var_name=f"{self._var_name}.includes({other._var_full_name})", _var_type=bool, - _var_is_local=self._var_is_local, + _var_is_string=False, + merge_var_data=other._var_data, ) def reverse(self) -> Var: @@ -1034,10 +1214,10 @@ class Var: if not types._issubclass(self._var_type, list): raise TypeError(f"Cannot reverse non-list var {self._var_full_name}.") - return BaseVar( + return self._replace( _var_name=f"[...{self._var_full_name}].reverse()", - _var_type=self._var_type, - _var_is_local=self._var_is_local, + _var_is_string=False, + _var_full_name_needs_state_prefix=False, ) def lower(self) -> Var: @@ -1054,10 +1234,10 @@ class Var: f"Cannot convert non-string var {self._var_full_name} to lowercase." ) - return BaseVar( - _var_name=f"{self._var_full_name}.toLowerCase()", + return self._replace( + _var_name=f"{self._var_name}.toLowerCase()", + _var_is_string=False, _var_type=str, - _var_is_local=self._var_is_local, ) def upper(self) -> Var: @@ -1074,10 +1254,10 @@ class Var: f"Cannot convert non-string var {self._var_full_name} to uppercase." ) - return BaseVar( - _var_name=f"{self._var_full_name}.toUpperCase()", + return self._replace( + _var_name=f"{self._var_name}.toUpperCase()", + _var_is_string=False, _var_type=str, - _var_is_local=self._var_is_local, ) def split(self, other: str | Var[str] = " ") -> Var: @@ -1097,10 +1277,11 @@ class Var: other = Var.create_safe(json.dumps(other)) if isinstance(other, str) else other - return BaseVar( - _var_name=f"{self._var_full_name}.split({other._var_full_name})", + return self._replace( + _var_name=f"{self._var_name}.split({other._var_full_name})", + _var_is_string=False, _var_type=list[str], - _var_is_local=self._var_is_local, + merge_var_data=other._var_data, ) def join(self, other: str | Var[str] | None = None) -> Var: @@ -1125,10 +1306,11 @@ class Var: else: other = Var.create_safe(other) - return BaseVar( - _var_name=f"{self._var_full_name}.join({other._var_full_name})", + return self._replace( + _var_name=f"{self._var_name}.join({other._var_full_name})", + _var_is_string=False, _var_type=str, - _var_is_local=self._var_is_local, + merge_var_data=other._var_data, ) def foreach(self, fn: Callable) -> Var: @@ -1159,10 +1341,9 @@ class Var: fn_signature = inspect.signature(fn) fn_args = (arg, index) fn_ret = fn(*fn_args[: len(fn_signature.parameters)]) - return BaseVar( + return self._replace( _var_name=f"{self._var_full_name}.map(({arg._var_name}, {index._var_name}) => {fn_ret})", - _var_type=self._var_type, - _var_is_local=self._var_is_local, + _var_is_string=False, ) @classmethod @@ -1207,6 +1388,18 @@ class Var: _var_name=f"Array.from(range({v1._var_full_name}, {v2._var_full_name}, {step._var_name}))", _var_type=list[int], _var_is_local=False, + _var_data=VarData.merge( + v1._var_data, + v2._var_data, + step._var_data, + VarData( + imports={ + "/utils/helpers/range.js": [ + ImportVar(tag="range", is_default=True), + ], + }, + ), + ), ) def to(self, type_: Type) -> Var: @@ -1218,12 +1411,7 @@ class Var: Returns: The converted var. """ - return BaseVar( - _var_name=self._var_name, - _var_type=type_, - _var_state=self._var_state, - _var_is_local=self._var_is_local, - ) + return self._replace(_var_type=type_) @property def _var_full_name(self) -> str: @@ -1232,24 +1420,51 @@ class Var: Returns: The full name of the var. """ + if not self._var_full_name_needs_state_prefix: + return self._var_name return ( self._var_name - if self._var_state == "" - else ".".join([self._var_state, self._var_name]) + if self._var_data is None or self._var_data.state == "" + else ".".join( + [format.format_state_name(self._var_data.state), self._var_name] + ) ) - def _var_set_state(self, state: Type[State]) -> Any: + def _var_set_state(self, state: Type[State] | str) -> Any: """Set the state of the var. Args: - state: The state to set. + state: The state to set or the full name of the state. Returns: The var with the set state. """ - self._var_state = state.get_full_name() + state_name = state if isinstance(state, str) else state.get_full_name() + new_var_data = VarData( + state=state_name, + hooks={ + "const {0} = useContext(StateContexts.{0})".format( + format.format_state_name(state_name) + ) + }, + imports={ + f"/{constants.Dirs.CONTEXTS_PATH}": [ImportVar(tag="StateContexts")], + "react": [ImportVar(tag="useContext")], + }, + ) + self._var_data = VarData.merge(self._var_data, new_var_data) + self._var_full_name_needs_state_prefix = True return self + @property + def _var_state(self) -> str: + """Compat method for getting the state. + + Returns: + The state name associated with the var. + """ + return self._var_data.state if self._var_data else "" + @dataclasses.dataclass( eq=False, @@ -1264,15 +1479,18 @@ class BaseVar(Var): # The type of the var. _var_type: Type = dataclasses.field(default=Any) - # The name of the enclosing state. - _var_state: str = dataclasses.field(default="") - # Whether this is a local javascript variable. _var_is_local: bool = dataclasses.field(default=False) # Whether the var is a string literal. _var_is_string: bool = dataclasses.field(default=False) + # _var_full_name should be prefixed with _var_state + _var_full_name_needs_state_prefix: bool = dataclasses.field(default=False) + + # Extra metadata associated with the Var + _var_data: Optional[VarData] = dataclasses.field(default=None) + def __hash__(self) -> int: """Define a hash function for a var. @@ -1334,9 +1552,11 @@ class BaseVar(Var): The name of the setter function. """ setter = constants.SETTER_PREFIX + self._var_name - if not include_state or self._var_state == "": + if self._var_data is None: return setter - return ".".join((self._var_state, setter)) + if not include_state or self._var_data.state == "": + return setter + return ".".join((self._var_data.state, setter)) def get_setter(self) -> Callable[[State, Any], None]: """Get the var's setter function. @@ -1550,48 +1770,6 @@ def cached_var(fget: Callable[[Any], Any]) -> ComputedVar: return cvar -class ImportVar(Base): - """An import var.""" - - # The name of the import tag. - tag: Optional[str] - - # whether the import is default or named. - is_default: Optional[bool] = False - - # The tag alias. - alias: Optional[str] = None - - # Whether this import need to install the associated lib - install: Optional[bool] = True - - # whether this import should be rendered or not - render: Optional[bool] = True - - @property - def name(self) -> str: - """The name of the import. - - Returns: - The name(tag name with alias) of tag. - """ - return self.tag if not self.alias else " as ".join([self.tag, self.alias]) # type: ignore - - def __hash__(self) -> int: - """Define a hash function for the import var. - - Returns: - The hash of the var. - """ - return hash((self.tag, self.is_default, self.alias, self.install, self.render)) - - -class NoRenderImportVar(ImportVar): - """A import that doesn't need to be rendered.""" - - render: Optional[bool] = False - - class CallableVar(BaseVar): """Decorate a Var-returning function to act as both a Var and a function. diff --git a/reflex/vars.pyi b/reflex/vars.pyi index 001ba9aa3..9208013db 100644 --- a/reflex/vars.pyi +++ b/reflex/vars.pyi @@ -6,11 +6,13 @@ from reflex import constants as constants from reflex.base import Base as Base from reflex.state import State as State from reflex.utils import console as console, format as format, types as types +from reflex.utils.imports import ImportVar from types import FunctionType from typing import ( Any, Callable, Dict, + Iterable, List, Optional, Set, @@ -22,13 +24,24 @@ from typing import ( USED_VARIABLES: Incomplete def get_unique_variable_name() -> str: ... +def _encode_var(value: Var) -> str: ... +def _decode_var(value: str) -> tuple[VarData, str]: ... +def _extract_var_data(value: Iterable) -> list[VarData | None]: ... + +class VarData(Base): + state: str + imports: dict[str, set[ImportVar]] + hooks: set[str] + @classmethod + def merge(cls, *others: VarData | None) -> VarData | None: ... class Var: _var_name: str _var_type: Type - _var_state: str = "" _var_is_local: bool = False _var_is_string: bool = False + _var_full_name_needs_state_prefix: bool = False + _var_data: VarData | None = None @classmethod def create( cls, value: Any, _var_is_local: bool = False, _var_is_string: bool = False @@ -38,7 +51,8 @@ class Var: cls, value: Any, _var_is_local: bool = False, _var_is_string: bool = False ) -> Var: ... @classmethod - def __class_getitem__(cls, type_: str) -> _GenericAlias: ... + def __class_getitem__(cls, type_: Type) -> _GenericAlias: ... + def _replace(self, merge_var_data=None, **kwargs: Any) -> Var: ... def equals(self, other: Var) -> bool: ... def to_string(self) -> Var: ... def __hash__(self) -> int: ... @@ -95,15 +109,16 @@ class Var: def to(self, type_: Type) -> Var: ... @property def _var_full_name(self) -> str: ... - def _var_set_state(self, state: Type[State]) -> Any: ... + def _var_set_state(self, state: Type[State] | str) -> Any: ... @dataclass(eq=False) class BaseVar(Var): _var_name: str _var_type: Any - _var_state: str = "" _var_is_local: bool = False _var_is_string: bool = False + _var_full_name_needs_state_prefix: bool = False + _var_data: VarData | None = None def __hash__(self) -> int: ... def get_default_value(self) -> Any: ... def get_setter_name(self, include_state: bool = ...) -> str: ... @@ -123,21 +138,6 @@ class ComputedVar(Var): def cached_var(fget: Callable[[Any], Any]) -> ComputedVar: ... -class ImportVar(Base): - tag: Optional[str] - is_default: Optional[bool] = False - alias: Optional[str] = None - install: Optional[bool] = True - render: Optional[bool] = True - @property - def name(self) -> str: ... - def __hash__(self) -> int: ... - -class NoRenderImportVar(ImportVar): - """A import that doesn't need to be rendered.""" - -def get_local_storage(key: Optional[Union[Var, str]] = ...) -> BaseVar: ... - class CallableVar(BaseVar): def __init__(self, fn: Callable[..., BaseVar]): ... def __call__(self, *args, **kwargs) -> BaseVar: ... diff --git a/tests/compiler/test_compiler.py b/tests/compiler/test_compiler.py index e8423d97f..1def5ca4f 100644 --- a/tests/compiler/test_compiler.py +++ b/tests/compiler/test_compiler.py @@ -5,7 +5,7 @@ import pytest from reflex.compiler import compiler, utils from reflex.utils import imports -from reflex.vars import ImportVar +from reflex.utils.imports import ImportVar @pytest.mark.parametrize( diff --git a/tests/components/layout/test_cond.py b/tests/components/layout/test_cond.py index 3bf373bb4..00cf4de7d 100644 --- a/tests/components/layout/test_cond.py +++ b/tests/components/layout/test_cond.py @@ -110,7 +110,7 @@ def test_cond_no_else(): # Props do not support the use of cond without else with pytest.raises(ValueError): - cond(True, "hello") + cond(True, "hello") # type: ignore def test_mobile_only(): diff --git a/tests/components/test_component.py b/tests/components/test_component.py index abf6c5a06..6d6b0a0cc 100644 --- a/tests/components/test_component.py +++ b/tests/components/test_component.py @@ -4,14 +4,16 @@ import pytest import reflex as rx from reflex.base import Base +from reflex.components.base.bare import Bare from reflex.components.component import Component, CustomComponent, custom_component from reflex.components.layout.box import Box from reflex.constants import EventTriggers -from reflex.event import EventHandler +from reflex.event import EventChain, EventHandler from reflex.state import State from reflex.style import Style from reflex.utils import imports -from reflex.vars import ImportVar, Var +from reflex.utils.imports import ImportVar +from reflex.vars import Var, VarData @pytest.fixture @@ -600,3 +602,161 @@ def test_format_component(component, rendered): rendered: The expected rendered component. """ assert str(component) == rendered + + +TEST_VAR = Var.create_safe("test")._replace( + merge_var_data=VarData( + hooks={"useTest"}, imports={"test": {ImportVar(tag="test")}}, state="Test" + ) +) +FORMATTED_TEST_VAR = Var.create(f"foo{TEST_VAR}bar") +STYLE_VAR = TEST_VAR._replace(_var_name="style", _var_is_local=False) +EVENT_CHAIN_VAR = TEST_VAR._replace(_var_type=EventChain) +ARG_VAR = Var.create("arg") + + +class EventState(rx.State): + """State for testing event handlers with _get_vars.""" + + v: int = 42 + + def handler(self): + """A handler that does nothing.""" + + def handler2(self, arg): + """A handler that takes an arg. + + Args: + arg: An arg. + """ + + +@pytest.mark.parametrize( + ("component", "exp_vars"), + ( + pytest.param( + Bare.create(TEST_VAR), + [TEST_VAR], + id="direct-bare", + ), + pytest.param( + Bare.create(f"foo{TEST_VAR}bar"), + [FORMATTED_TEST_VAR], + id="fstring-bare", + ), + pytest.param( + rx.text(as_=TEST_VAR), + [TEST_VAR], + id="direct-prop", + ), + pytest.param( + rx.text(as_=f"foo{TEST_VAR}bar"), + [FORMATTED_TEST_VAR], + id="fstring-prop", + ), + pytest.param( + rx.fragment(id=TEST_VAR), + [TEST_VAR], + id="direct-id", + ), + pytest.param( + rx.fragment(id=f"foo{TEST_VAR}bar"), + [FORMATTED_TEST_VAR], + id="fstring-id", + ), + pytest.param( + rx.fragment(key=TEST_VAR), + [TEST_VAR], + id="direct-key", + ), + pytest.param( + rx.fragment(key=f"foo{TEST_VAR}bar"), + [FORMATTED_TEST_VAR], + id="fstring-key", + ), + pytest.param( + rx.fragment(class_name=TEST_VAR), + [TEST_VAR], + id="direct-class_name", + ), + pytest.param( + rx.fragment(class_name=f"foo{TEST_VAR}bar"), + [FORMATTED_TEST_VAR], + id="fstring-class_name", + ), + pytest.param( + rx.fragment(special_props={TEST_VAR}), + [TEST_VAR], + id="direct-special_props", + ), + pytest.param( + rx.fragment(special_props={Var.create(f"foo{TEST_VAR}bar")}), + [FORMATTED_TEST_VAR], + id="fstring-special_props", + ), + pytest.param( + # custom_attrs cannot accept a Var directly as a value + rx.fragment(custom_attrs={"href": f"{TEST_VAR}"}), + [TEST_VAR], + id="fstring-custom_attrs-nofmt", + ), + pytest.param( + rx.fragment(custom_attrs={"href": f"foo{TEST_VAR}bar"}), + [FORMATTED_TEST_VAR], + id="fstring-custom_attrs", + ), + pytest.param( + rx.fragment(background_color=TEST_VAR), + [STYLE_VAR], + id="direct-background_color", + ), + pytest.param( + rx.fragment(background_color=f"foo{TEST_VAR}bar"), + [STYLE_VAR], + id="fstring-background_color", + ), + pytest.param( + rx.fragment(style={"background_color": TEST_VAR}), # type: ignore + [STYLE_VAR], + id="direct-style-background_color", + ), + pytest.param( + rx.fragment(style={"background_color": f"foo{TEST_VAR}bar"}), # type: ignore + [STYLE_VAR], + id="fstring-style-background_color", + ), + pytest.param( + rx.fragment(on_click=EVENT_CHAIN_VAR), # type: ignore + [EVENT_CHAIN_VAR], + id="direct-event-chain", + ), + pytest.param( + rx.fragment(on_click=EventState.handler), + [], + id="direct-event-handler", + ), + pytest.param( + rx.fragment(on_click=EventState.handler2(TEST_VAR)), # type: ignore + [ARG_VAR, TEST_VAR], + id="direct-event-handler-arg", + ), + pytest.param( + rx.fragment(on_click=EventState.handler2(EventState.v)), # type: ignore + [ARG_VAR, EventState.v], + id="direct-event-handler-arg2", + ), + pytest.param( + rx.fragment(on_click=lambda: EventState.handler2(TEST_VAR)), # type: ignore + [ARG_VAR, TEST_VAR], + id="direct-event-handler-lambda", + ), + ), +) +def test_get_vars(component, exp_vars): + comp_vars = sorted(component._get_vars(), key=lambda v: v._var_name) + assert len(comp_vars) == len(exp_vars) + for comp_var, exp_var in zip( + comp_vars, + sorted(exp_vars, key=lambda v: v._var_name), + ): + assert comp_var.equals(exp_var) diff --git a/tests/middleware/test_hydrate_middleware.py b/tests/middleware/test_hydrate_middleware.py index 150083bd5..7767dcf8b 100644 --- a/tests/middleware/test_hydrate_middleware.py +++ b/tests/middleware/test_hydrate_middleware.py @@ -104,7 +104,7 @@ async def test_preprocess( app=app, event=request.getfixturevalue(event_fixture), state=state ) assert isinstance(update, StateUpdate) - assert update.delta == {state.get_name(): state.dict()} + assert update.delta == state.dict() events = update.events assert len(events) == 2 @@ -133,7 +133,7 @@ async def test_preprocess_multiple_load_events(hydrate_middleware, event1): update = await hydrate_middleware.preprocess(app=app, event=event1, state=state) assert isinstance(update, StateUpdate) - assert update.delta == {"test_state": state.dict()} + assert update.delta == state.dict() assert len(update.events) == 3 # Apply the events. @@ -163,7 +163,7 @@ async def test_preprocess_no_events(hydrate_middleware, event1): state=state, ) assert isinstance(update, StateUpdate) - assert update.delta == {"test_state": state.dict()} + assert update.delta == state.dict() assert len(update.events) == 1 assert isinstance(update, StateUpdate) diff --git a/tests/test_app.py b/tests/test_app.py index 434eeb60f..e7ac03661 100644 --- a/tests/test_app.py +++ b/tests/test_app.py @@ -769,9 +769,7 @@ async def test_upload_file(tmp_path, state, delta, token: str): ) current_state = await app.state_manager.get_state(token) - state_dict = current_state.dict() - for substate in state.get_full_name().split(".")[1:]: - state_dict = state_dict[substate] + state_dict = current_state.dict()[state.get_full_name()] assert state_dict["img_list"] == [ "image1.jpg", "image2.jpg", diff --git a/tests/test_state.py b/tests/test_state.py index 29065ed11..d85f0e59f 100644 --- a/tests/test_state.py +++ b/tests/test_state.py @@ -324,11 +324,17 @@ def test_dict(test_state): Args: test_state: A state. """ - substates = {"child_state", "child_state2"} - assert set(test_state.dict().keys()) == set(test_state.vars.keys()) | substates - assert ( - set(test_state.dict(include_computed=False).keys()) - == set(test_state.base_vars) | substates + substates = { + "test_state", + "test_state.child_state", + "test_state.child_state.grandchild_state", + "test_state.child_state2", + } + test_state_dict = test_state.dict() + assert set(test_state_dict) == substates + assert set(test_state_dict[test_state.get_name()]) == set(test_state.vars) + assert set(test_state.dict(include_computed=False)[test_state.get_name()]) == set( + test_state.base_vars ) @@ -1081,9 +1087,9 @@ def test_computed_var_cached(): return self.v cs = ComputedState() - assert cs.dict()["v"] == 0 + assert cs.dict()[cs.get_full_name()]["v"] == 0 assert comp_v_calls == 1 - assert cs.dict()["comp_v"] == 0 + assert cs.dict()[cs.get_full_name()]["comp_v"] == 0 assert comp_v_calls == 1 assert cs.comp_v == 0 assert comp_v_calls == 1 @@ -1156,24 +1162,27 @@ def test_computed_var_depends_on_parent_non_cached(): assert ps.dirty_vars == set() assert cs.dirty_vars == set() - assert ps.dict() == { - cs.get_name(): {"dep_v": 2}, + dict1 = ps.dict() + assert dict1[ps.get_full_name()] == { "no_cache_v": 1, CompileVars.IS_HYDRATED: False, "router": formatted_router, } - assert ps.dict() == { - cs.get_name(): {"dep_v": 4}, + assert dict1[cs.get_full_name()] == {"dep_v": 2} + dict2 = ps.dict() + assert dict2[ps.get_full_name()] == { "no_cache_v": 3, CompileVars.IS_HYDRATED: False, "router": formatted_router, } - assert ps.dict() == { - cs.get_name(): {"dep_v": 6}, + assert dict2[cs.get_full_name()] == {"dep_v": 4} + dict3 = ps.dict() + assert dict3[ps.get_full_name()] == { "no_cache_v": 5, CompileVars.IS_HYDRATED: False, "router": formatted_router, } + assert dict3[cs.get_full_name()] == {"dep_v": 6} assert counter == 6 @@ -2201,13 +2210,13 @@ def test_json_dumps_with_mutables(): items: List[Foo] = [Foo()] dict_val = MutableContainsBase().dict() - assert isinstance(dict_val["items"][0], dict) + assert isinstance(dict_val[MutableContainsBase.get_full_name()]["items"][0], dict) val = json_dumps(dict_val) f_items = '[{"tags": ["123", "456"]}]' f_formatted_router = str(formatted_router).replace("'", '"') assert ( val - == f'{{"is_hydrated": false, "items": {f_items}, "router": {f_formatted_router}}}' + == f'{{"{MutableContainsBase.get_full_name()}": {{"is_hydrated": false, "items": {f_items}, "router": {f_formatted_router}}}}}' ) diff --git a/tests/test_style.py b/tests/test_style.py index 8b09f9ac0..a8fcf6839 100644 --- a/tests/test_style.py +++ b/tests/test_style.py @@ -22,7 +22,8 @@ def test_convert(style_dict, expected): style_dict: The style to check. expected: The expected formatted style. """ - assert style.convert(style_dict) == expected + converted_dict, _var_data = style.convert(style_dict) + assert converted_dict == expected @pytest.mark.parametrize( diff --git a/tests/test_var.py b/tests/test_var.py index 9efb5fb78..e3cb28a1c 100644 --- a/tests/test_var.py +++ b/tests/test_var.py @@ -7,18 +7,20 @@ from pandas import DataFrame from reflex.base import Base from reflex.state import State +from reflex.utils.imports import ImportVar from reflex.vars import ( BaseVar, ComputedVar, - ImportVar, Var, ) test_vars = [ BaseVar(_var_name="prop1", _var_type=int), BaseVar(_var_name="key", _var_type=str), - BaseVar(_var_name="value", _var_type=str, _var_state="state"), - BaseVar(_var_name="local", _var_type=str, _var_state="state", _var_is_local=True), + BaseVar(_var_name="value", _var_type=str)._var_set_state("state"), + BaseVar(_var_name="local", _var_type=str, _var_is_local=True)._var_set_state( + "state" + ), BaseVar(_var_name="local2", _var_type=str, _var_is_local=True), ] @@ -263,7 +265,7 @@ def test_basic_operations(TestObj): assert str(v([1, 2, 3])[v(0)]) == "{[1, 2, 3].at(0)}" assert str(v({"a": 1, "b": 2})["a"]) == '{{"a": 1, "b": 2}["a"]}' assert ( - str(BaseVar(_var_name="foo", _var_state="state", _var_type=TestObj).bar) + str(BaseVar(_var_name="foo", _var_type=TestObj)._var_set_state("state").bar) == "{state.foo.bar}" ) assert str(abs(v(1))) == "{Math.abs(1)}" @@ -274,7 +276,7 @@ def test_basic_operations(TestObj): assert str(v([1, 2, 3]).reverse()) == "{[...[1, 2, 3]].reverse()}" assert str(v(["1", "2", "3"]).reverse()) == '{[...["1", "2", "3"]].reverse()}' assert ( - str(BaseVar(_var_name="foo", _var_state="state", _var_type=list).reverse()) + str(BaseVar(_var_name="foo", _var_type=list)._var_set_state("state").reverse()) == "{[...state.foo].reverse()}" ) assert ( @@ -288,11 +290,14 @@ def test_basic_operations(TestObj): [ (v([1, 2, 3]), "[1, 2, 3]"), (v(["1", "2", "3"]), '["1", "2", "3"]'), - (BaseVar(_var_name="foo", _var_state="state", _var_type=list), "state.foo"), + (BaseVar(_var_name="foo", _var_type=list)._var_set_state("state"), "state.foo"), (BaseVar(_var_name="foo", _var_type=list), "foo"), (v((1, 2, 3)), "[1, 2, 3]"), (v(("1", "2", "3")), '["1", "2", "3"]'), - (BaseVar(_var_name="foo", _var_state="state", _var_type=tuple), "state.foo"), + ( + BaseVar(_var_name="foo", _var_type=tuple)._var_set_state("state"), + "state.foo", + ), (BaseVar(_var_name="foo", _var_type=tuple), "foo"), ], ) @@ -301,7 +306,7 @@ def test_list_tuple_contains(var, expected): assert str(var.contains("1")) == f'{{{expected}.includes("1")}}' assert str(var.contains(v(1))) == f"{{{expected}.includes(1)}}" assert str(var.contains(v("1"))) == f'{{{expected}.includes("1")}}' - other_state_var = BaseVar(_var_name="other", _var_state="state", _var_type=str) + other_state_var = BaseVar(_var_name="other", _var_type=str)._var_set_state("state") other_var = BaseVar(_var_name="other", _var_type=str) assert str(var.contains(other_state_var)) == f"{{{expected}.includes(state.other)}}" assert str(var.contains(other_var)) == f"{{{expected}.includes(other)}}" @@ -311,14 +316,14 @@ def test_list_tuple_contains(var, expected): "var, expected", [ (v("123"), json.dumps("123")), - (BaseVar(_var_name="foo", _var_state="state", _var_type=str), "state.foo"), + (BaseVar(_var_name="foo", _var_type=str)._var_set_state("state"), "state.foo"), (BaseVar(_var_name="foo", _var_type=str), "foo"), ], ) def test_str_contains(var, expected): assert str(var.contains("1")) == f'{{{expected}.includes("1")}}' assert str(var.contains(v("1"))) == f'{{{expected}.includes("1")}}' - other_state_var = BaseVar(_var_name="other", _var_state="state", _var_type=str) + other_state_var = BaseVar(_var_name="other", _var_type=str)._var_set_state("state") other_var = BaseVar(_var_name="other", _var_type=str) assert str(var.contains(other_state_var)) == f"{{{expected}.includes(state.other)}}" assert str(var.contains(other_var)) == f"{{{expected}.includes(other)}}" @@ -328,7 +333,7 @@ def test_str_contains(var, expected): "var, expected", [ (v({"a": 1, "b": 2}), '{"a": 1, "b": 2}'), - (BaseVar(_var_name="foo", _var_state="state", _var_type=dict), "state.foo"), + (BaseVar(_var_name="foo", _var_type=dict)._var_set_state("state"), "state.foo"), (BaseVar(_var_name="foo", _var_type=dict), "foo"), ], ) @@ -337,7 +342,7 @@ def test_dict_contains(var, expected): assert str(var.contains("1")) == f'{{{expected}.hasOwnProperty("1")}}' assert str(var.contains(v(1))) == f"{{{expected}.hasOwnProperty(1)}}" assert str(var.contains(v("1"))) == f'{{{expected}.hasOwnProperty("1")}}' - other_state_var = BaseVar(_var_name="other", _var_state="state", _var_type=str) + other_state_var = BaseVar(_var_name="other", _var_type=str)._var_set_state("state") other_var = BaseVar(_var_name="other", _var_type=str) assert ( str(var.contains(other_state_var)) @@ -548,10 +553,10 @@ def test_var_unsupported_indexing_dicts(var, index): "fixture,full_name", [ ("ParentState", "parent_state.var_without_annotation"), - ("ChildState", "parent_state.child_state.var_without_annotation"), + ("ChildState", "parent_state__child_state.var_without_annotation"), ( "GrandChildState", - "parent_state.child_state.grand_child_state.var_without_annotation", + "parent_state__child_state__grand_child_state.var_without_annotation", ), ("StateWithAnyVar", "state_with_any_var.var_without_annotation"), ], @@ -630,8 +635,8 @@ def test_import_var(import_var, expected): [ (f"{BaseVar(_var_name='var', _var_type=str)}", "${var}"), ( - f"testing f-string with {BaseVar(_var_name='myvar', _var_state='state', _var_type=int)}", - "testing f-string with ${state.myvar}", + f"testing f-string with {BaseVar(_var_name='myvar', _var_type=int)._var_set_state('state')}", + 'testing f-string with ${"state": "state", "imports": {"/utils/context": [{"tag": "StateContexts", "is_default": false, "alias": null, "install": true, "render": true}], "react": [{"tag": "useContext", "is_default": false, "alias": null, "install": true, "render": true}]}, "hooks": ["const state = useContext(StateContexts.state)"]}{state.myvar}', ), ( f"testing local f-string {BaseVar(_var_name='x', _var_is_local=True, _var_type=str)}", @@ -643,6 +648,35 @@ def test_fstrings(out, expected): assert out == expected +@pytest.mark.parametrize( + ("value", "expect_state"), + [ + ([1], ""), + ({"a": 1}, ""), + ([Var.create_safe(1)._var_set_state("foo")], "foo"), + ({"a": Var.create_safe(1)._var_set_state("foo")}, "foo"), + ], +) +def test_extract_state_from_container(value, expect_state): + """Test that _var_state is extracted from containers containing BaseVar. + + Args: + value: The value to create a var from. + expect_state: The expected state. + """ + assert Var.create_safe(value)._var_state == expect_state + + +def test_fstring_roundtrip(): + """Test that f-string roundtrip carries state.""" + var = BaseVar.create_safe("var")._var_set_state("state") + rt_var = Var.create_safe(f"{var}") + assert var._var_state == rt_var._var_state + assert var._var_full_name_needs_state_prefix + assert not rt_var._var_full_name_needs_state_prefix + assert rt_var._var_name == var._var_full_name + + @pytest.mark.parametrize( "var", [ diff --git a/tests/utils/test_format.py b/tests/utils/test_format.py index c83fd8c79..257f37514 100644 --- a/tests/utils/test_format.py +++ b/tests/utils/test_format.py @@ -8,7 +8,13 @@ from reflex.event import EventChain, EventHandler, EventSpec, FrontendEvent from reflex.style import Style from reflex.utils import format from reflex.vars import BaseVar, Var -from tests.test_state import ChildState, DateTimeState, GrandchildState, TestState +from tests.test_state import ( + ChildState, + ChildState2, + DateTimeState, + GrandchildState, + TestState, +) def mock_event(arg): @@ -349,7 +355,6 @@ def test_format_cond(condition: str, true_value: str, false_value: str, expected BaseVar( _var_name="_", _var_type=Any, - _var_state="", _var_is_local=True, _var_is_string=False, ), @@ -515,40 +520,44 @@ formatted_router = { ( TestState().dict(), # type: ignore { - "array": [1, 2, 3.14], - "child_state": { + TestState.get_full_name(): { + "array": [1, 2, 3.14], + "complex": { + 1: {"prop1": 42, "prop2": "hello"}, + 2: {"prop1": 42, "prop2": "hello"}, + }, + "dt": "1989-11-09 18:53:00+01:00", + "fig": [], + "is_hydrated": False, + "key": "", + "map_key": "a", + "mapping": {"a": [1, 2, 3], "b": [4, 5, 6]}, + "num1": 0, + "num2": 3.14, + "obj": {"prop1": 42, "prop2": "hello"}, + "sum": 3.14, + "upper": "", + "router": formatted_router, + }, + ChildState.get_full_name(): { "count": 23, - "grandchild_state": {"value2": ""}, "value": "", }, - "child_state2": {"value": ""}, - "complex": { - 1: {"prop1": 42, "prop2": "hello"}, - 2: {"prop1": 42, "prop2": "hello"}, - }, - "dt": "1989-11-09 18:53:00+01:00", - "fig": [], - "is_hydrated": False, - "key": "", - "map_key": "a", - "mapping": {"a": [1, 2, 3], "b": [4, 5, 6]}, - "num1": 0, - "num2": 3.14, - "obj": {"prop1": 42, "prop2": "hello"}, - "sum": 3.14, - "upper": "", - "router": formatted_router, + ChildState2.get_full_name(): {"value": ""}, + GrandchildState.get_full_name(): {"value2": ""}, }, ), ( DateTimeState().dict(), { - "d": "1989-11-09", - "dt": "1989-11-09 18:53:00+01:00", - "is_hydrated": False, - "t": "18:53:00+01:00", - "td": "11 days, 0:11:00", - "router": formatted_router, + DateTimeState.get_full_name(): { + "d": "1989-11-09", + "dt": "1989-11-09 18:53:00+01:00", + "is_hydrated": False, + "t": "18:53:00+01:00", + "td": "11 days, 0:11:00", + "router": formatted_router, + }, }, ), ],