diff --git a/integration/test_var_operations.py b/integration/test_var_operations.py
index 6eac82c4c..7da9d0bd0 100644
--- a/integration/test_var_operations.py
+++ b/integration/test_var_operations.py
@@ -28,6 +28,7 @@ def VarOperations():
str_var4: str = "a long string"
dict1: dict = {1: 2}
dict2: dict = {3: 4}
+ html_str: str = "
hello
"
app = rx.App(state=VarOperationState)
@@ -522,6 +523,19 @@ def VarOperations():
rx.text(VarOperationState.str_var4.split(" ").to_string(), id="str_split"),
rx.text(VarOperationState.list3.join(""), id="list_join"),
rx.text(VarOperationState.list3.join(","), id="list_join_comma"),
+ # Index from an op var
+ rx.text(
+ VarOperationState.list3[VarOperationState.int_var1 % 3],
+ id="list_index_mod",
+ ),
+ rx.html(
+ VarOperationState.html_str,
+ id="html_str",
+ ),
+ rx.highlight(
+ "second",
+ query=[VarOperationState.str_var2],
+ ),
rx.text(rx.Var.range(2, 5).join(","), id="list_join_range1"),
rx.text(rx.Var.range(2, 10, 2).join(","), id="list_join_range2"),
rx.text(rx.Var.range(5, 0, -1).join(","), id="list_join_range3"),
@@ -713,7 +727,14 @@ def test_var_operations(driver, var_operations: AppHarness):
("dict_eq_dict", "false"),
("dict_neq_dict", "true"),
("dict_contains", "true"),
+ # index from an op var
+ ("list_index_mod", "second"),
+ # html component with var
+ ("html_str", "hello"),
]
for tag, expected in tests:
assert driver.find_element(By.ID, tag).text == expected
+
+ # Highlight component with var query (does not plumb ID)
+ assert driver.find_element(By.TAG_NAME, "mark").text == "second"
diff --git a/reflex/.templates/jinja/web/pages/_app.js.jinja2 b/reflex/.templates/jinja/web/pages/_app.js.jinja2
index 4d3dff89a..deaf1a02b 100644
--- a/reflex/.templates/jinja/web/pages/_app.js.jinja2
+++ b/reflex/.templates/jinja/web/pages/_app.js.jinja2
@@ -1,7 +1,7 @@
{% extends "web/pages/base_page.js.jinja2" %}
{% block declaration %}
-import { EventLoopProvider } from "/utils/context.js";
+import { EventLoopProvider, StateProvider } from "/utils/context.js";
import { ThemeProvider } from 'next-themes'
{% for custom_code in custom_codes %}
@@ -25,12 +25,14 @@ export default function MyApp({ Component, pageProps }) {
return (
-
-
-
+
+
+
+
+
);
}
-{% endblock %}
\ No newline at end of file
+{% endblock %}
diff --git a/reflex/.templates/jinja/web/pages/index.js.jinja2 b/reflex/.templates/jinja/web/pages/index.js.jinja2
index 6f73c70c4..efb086ef5 100644
--- a/reflex/.templates/jinja/web/pages/index.js.jinja2
+++ b/reflex/.templates/jinja/web/pages/index.js.jinja2
@@ -8,32 +8,6 @@
{% block export %}
export default function Component() {
-{% if state_name %}
- const {{state_name}} = useContext(StateContext)
-{% endif %}
- const {{const.router}} = useRouter()
- const [ {{const.color_mode}}, {{const.toggle_color_mode}} ] = useContext(ColorModeContext)
- const focusRef = useRef();
-
- // Main event loop.
- const [addEvents, connectError] = useContext(EventLoopContext)
-
- // Set focus to the specified element.
- useEffect(() => {
- if (focusRef.current) {
- focusRef.current.focus();
- }
- })
-
- // Route after the initial page hydration.
- useEffect(() => {
- const change_complete = () => addEvents(initialEvents())
- {{const.router}}.events.on('routeChangeComplete', change_complete)
- return () => {
- {{const.router}}.events.off('routeChangeComplete', change_complete)
- }
- }, [{{const.router}}])
-
{% for hook in hooks %}
{{ hook }}
{% endfor %}
diff --git a/reflex/.templates/jinja/web/utils/context.js.jinja2 b/reflex/.templates/jinja/web/utils/context.js.jinja2
index c931b7515..53d7d4e58 100644
--- a/reflex/.templates/jinja/web/utils/context.js.jinja2
+++ b/reflex/.templates/jinja/web/utils/context.js.jinja2
@@ -1,5 +1,5 @@
-import { createContext, useState } from "react"
-import { Event, hydrateClientStorage, useEventLoop } from "/utils/state.js"
+import { createContext, useContext, useMemo, useReducer, useState } from "react"
+import { applyDelta, Event, hydrateClientStorage, useEventLoop } from "/utils/state.js"
{% if initial_state %}
export const initialState = {{ initial_state|json_dumps }}
@@ -8,7 +8,12 @@ export const initialState = {}
{% endif %}
export const ColorModeContext = createContext(null);
-export const StateContext = createContext(null);
+export const DispatchContext = createContext(null);
+export const StateContexts = {
+ {% for state_name in initial_state %}
+ {{state_name|var_name}}: createContext(null),
+ {% endfor %}
+}
export const EventLoopContext = createContext(null);
{% if client_storage %}
export const clientStorage = {{ client_storage|json_dumps }}
@@ -27,16 +32,40 @@ export const initialEvents = () => []
export const isDevMode = {{ is_dev_mode|json_dumps }}
export function EventLoopProvider({ children }) {
- const [state, addEvents, connectError] = useEventLoop(
- initialState,
+ const dispatch = useContext(DispatchContext)
+ const [addEvents, connectError] = useEventLoop(
+ dispatch,
initialEvents,
clientStorage,
)
return (
-
- {children}
-
+ {children}
)
-}
\ No newline at end of file
+}
+
+export function StateProvider({ children }) {
+ {% for state_name in initial_state %}
+ const [{{state_name|var_name}}, dispatch_{{state_name|var_name}}] = useReducer(applyDelta, initialState["{{state_name}}"])
+ {% endfor %}
+ const dispatchers = useMemo(() => {
+ return {
+ {% for state_name in initial_state %}
+ "{{state_name}}": dispatch_{{state_name|var_name}},
+ {% endfor %}
+ }
+ }, [])
+
+ return (
+ {% for state_name in initial_state %}
+
+ {% endfor %}
+
+ {children}
+
+ {% for state_name in initial_state|reverse %}
+
+ {% endfor %}
+ )
+}
diff --git a/reflex/.templates/web/utils/state.js b/reflex/.templates/web/utils/state.js
index 22c853dea..9ae46d3c9 100644
--- a/reflex/.templates/web/utils/state.js
+++ b/reflex/.templates/web/utils/state.js
@@ -6,7 +6,7 @@ import env from "env.json";
import Cookies from "universal-cookie";
import { useEffect, useReducer, useRef, useState } from "react";
import Router, { useRouter } from "next/router";
-import { initialEvents } from "utils/context.js"
+import { initialEvents, initialState } from "utils/context.js"
// Endpoint URLs.
const EVENTURL = env.EVENT
@@ -100,37 +100,10 @@ export const getEventURL = () => {
* @param delta The delta to apply.
*/
export const applyDelta = (state, delta) => {
- const new_state = { ...state }
- for (const substate in delta) {
- let s = new_state;
- const path = substate.split(".").slice(1);
- while (path.length > 0) {
- s = s[path.shift()];
- }
- for (const key in delta[substate]) {
- s[key] = delta[substate][key];
- }
- }
- return new_state
+ return { ...state, ...delta }
};
-/**
- * Get all local storage items in a key-value object.
- * @returns object of items in local storage.
- */
-export const getAllLocalStorageItems = () => {
- var localStorageItems = {};
-
- for (var i = 0, len = localStorage.length; i < len; i++) {
- var key = localStorage.key(i);
- localStorageItems[key] = localStorage.getItem(key);
- }
-
- return localStorageItems;
-}
-
-
/**
* Handle frontend event or send the event to the backend via Websocket.
* @param event The event to send.
@@ -346,7 +319,9 @@ export const connect = async (
// On each received message, queue the updates and events.
socket.current.on("event", message => {
const update = JSON5.parse(message)
- dispatch(update.delta)
+ for (const substate in update.delta) {
+ dispatch[substate](update.delta[substate])
+ }
applyClientStorageDelta(client_storage, update.delta)
event_processing = !update.final
if (update.events) {
@@ -524,23 +499,21 @@ const applyClientStorageDelta = (client_storage, delta) => {
/**
* Establish websocket event loop for a NextJS page.
- * @param initial_state The initial app state.
- * @param initial_events Function that returns the initial app events.
+ * @param dispatch The reducer dispatch function to update state.
+ * @param initial_events The initial app events.
* @param client_storage The client storage object from context.js
*
- * @returns [state, addEvents, connectError] -
- * state is a reactive dict,
+ * @returns [addEvents, connectError] -
* addEvents is used to queue an event, and
* connectError is a reactive js error from the websocket connection (or null if connected).
*/
export const useEventLoop = (
- initial_state = {},
+ dispatch,
initial_events = () => [],
client_storage = {},
) => {
const socket = useRef(null)
const router = useRouter()
- const [state, dispatch] = useReducer(applyDelta, initial_state)
const [connectError, setConnectError] = useState(null)
// Function to add new events to the event queue.
@@ -570,7 +543,7 @@ export const useEventLoop = (
return;
}
// only use websockets if state is present
- if (Object.keys(state).length > 0) {
+ if (Object.keys(initialState).length > 0) {
// Initialize the websocket connection.
if (!socket.current) {
connect(socket, dispatch, ['websocket', 'polling'], setConnectError, client_storage)
@@ -583,7 +556,17 @@ export const useEventLoop = (
})()
}
})
- return [state, addEvents, connectError]
+
+ // Route after the initial page hydration.
+ useEffect(() => {
+ const change_complete = () => addEvents(initial_events())
+ router.events.on('routeChangeComplete', change_complete)
+ return () => {
+ router.events.off('routeChangeComplete', change_complete)
+ }
+ }, [router])
+
+ return [addEvents, connectError]
}
/***
diff --git a/reflex/app.py b/reflex/app.py
index 6248bcec0..01851d48d 100644
--- a/reflex/app.py
+++ b/reflex/app.py
@@ -63,7 +63,7 @@ from reflex.state import (
StateUpdate,
)
from reflex.utils import console, format, prerequisites, types
-from reflex.vars import ImportVar
+from reflex.utils.imports import ImportVar
# Define custom types.
ComponentCallable = Callable[[], Component]
diff --git a/reflex/compiler/compiler.py b/reflex/compiler/compiler.py
index 4ffd4a876..3b9d7db3e 100644
--- a/reflex/compiler/compiler.py
+++ b/reflex/compiler/compiler.py
@@ -10,40 +10,10 @@ from reflex.compiler import templates, utils
from reflex.components.component import Component, ComponentStyle, CustomComponent
from reflex.config import get_config
from reflex.state import State
-from reflex.utils import imports
-from reflex.vars import ImportVar
+from reflex.utils.imports import ImportDict, ImportVar
# Imports to be included in every Reflex app.
-DEFAULT_IMPORTS: imports.ImportDict = {
- "react": [
- ImportVar(tag="Fragment"),
- ImportVar(tag="useEffect"),
- ImportVar(tag="useRef"),
- ImportVar(tag="useState"),
- ImportVar(tag="useContext"),
- ],
- "next/router": [ImportVar(tag="useRouter")],
- f"/{constants.Dirs.STATE_PATH}": [
- ImportVar(tag="uploadFiles"),
- ImportVar(tag="Event"),
- ImportVar(tag="isTrue"),
- ImportVar(tag="spreadArraysOrObjects"),
- ImportVar(tag="preventDefault"),
- ImportVar(tag="refs"),
- ImportVar(tag="getRefValue"),
- ImportVar(tag="getRefValues"),
- ImportVar(tag="getAllLocalStorageItems"),
- ImportVar(tag="useEventLoop"),
- ],
- "/utils/context.js": [
- ImportVar(tag="EventLoopContext"),
- ImportVar(tag="initialEvents"),
- ImportVar(tag="StateContext"),
- ImportVar(tag="ColorModeContext"),
- ],
- "/utils/helpers/range.js": [
- ImportVar(tag="range", is_default=True),
- ],
+DEFAULT_IMPORTS: ImportDict = {
"": [ImportVar(tag="focus-visible/dist/focus-visible", install=False)],
}
diff --git a/reflex/compiler/templates.py b/reflex/compiler/templates.py
index f2d1272aa..57bbb44b8 100644
--- a/reflex/compiler/templates.py
+++ b/reflex/compiler/templates.py
@@ -3,7 +3,7 @@
from jinja2 import Environment, FileSystemLoader, Template
from reflex import constants
-from reflex.utils.format import json_dumps
+from reflex.utils.format import format_state_name, json_dumps
class ReflexJinjaEnvironment(Environment):
@@ -19,6 +19,7 @@ class ReflexJinjaEnvironment(Environment):
)
self.filters["json_dumps"] = json_dumps
self.filters["react_setter"] = lambda state: f"set{state.capitalize()}"
+ self.filters["var_name"] = format_state_name
self.loader = FileSystemLoader(constants.Templates.Dirs.JINJA_TEMPLATE)
self.globals["const"] = {
"socket": constants.CompileVars.SOCKET,
diff --git a/reflex/compiler/utils.py b/reflex/compiler/utils.py
index b6967eaef..10a4691a4 100644
--- a/reflex/compiler/utils.py
+++ b/reflex/compiler/utils.py
@@ -24,13 +24,12 @@ from reflex.components.component import Component, ComponentStyle, CustomCompone
from reflex.state import Cookie, LocalStorage, State
from reflex.style import Style
from reflex.utils import console, format, imports, path_ops
-from reflex.vars import ImportVar
# To re-export this function.
merge_imports = imports.merge_imports
-def compile_import_statement(fields: list[ImportVar]) -> tuple[str, list[str]]:
+def compile_import_statement(fields: list[imports.ImportVar]) -> tuple[str, list[str]]:
"""Compile an import statement.
Args:
@@ -343,7 +342,9 @@ def get_context_path() -> str:
Returns:
The path of the context module.
"""
- return os.path.join(constants.Dirs.WEB_UTILS, "context" + constants.Ext.JS)
+ return os.path.join(
+ constants.Dirs.WEB, constants.Dirs.CONTEXTS_PATH + constants.Ext.JS
+ )
def get_components_path() -> str:
diff --git a/reflex/components/base/bare.py b/reflex/components/base/bare.py
index 190e95e6a..ee66ecd4d 100644
--- a/reflex/components/base/bare.py
+++ b/reflex/components/base/bare.py
@@ -1,7 +1,7 @@
"""A bare component."""
from __future__ import annotations
-from typing import Any
+from typing import Any, Iterator
from reflex.components.component import Component
from reflex.components.tags import Tag
@@ -24,7 +24,21 @@ class Bare(Component):
Returns:
The component.
"""
- return cls(contents=str(contents)) # type: ignore
+ if isinstance(contents, Var) and contents._var_data:
+ contents = contents.to(str)
+ else:
+ contents = str(contents)
+ return cls(contents=contents) # type: ignore
def _render(self) -> Tag:
return Tagless(contents=str(self.contents))
+
+ def _get_vars(self) -> Iterator[Var]:
+ """Walk all Vars used in this component.
+
+ Yields:
+ The contents if it is a Var, otherwise nothing.
+ """
+ if isinstance(self.contents, Var):
+ # Fast path for Bare text components.
+ yield self.contents
diff --git a/reflex/components/component.py b/reflex/components/component.py
index d274fbc37..70097cf60 100644
--- a/reflex/components/component.py
+++ b/reflex/components/component.py
@@ -5,11 +5,11 @@ from __future__ import annotations
import typing
from abc import ABC
from functools import lru_cache, wraps
-from typing import Any, Callable, Dict, List, Optional, Set, Type, Union
+from typing import Any, Callable, Dict, Iterator, List, Optional, Set, Type, Union
from reflex.base import Base
from reflex.components.tags import Tag
-from reflex.constants import Dirs, EventTriggers
+from reflex.constants import Dirs, EventTriggers, Hooks, Imports
from reflex.event import (
EventChain,
EventHandler,
@@ -20,8 +20,9 @@ from reflex.event import (
)
from reflex.style import Style
from reflex.utils import console, format, imports, types
+from reflex.utils.imports import ImportVar
from reflex.utils.serializers import serializer
-from reflex.vars import BaseVar, ImportVar, Var
+from reflex.vars import BaseVar, Var
class Component(Base, ABC):
@@ -388,7 +389,11 @@ class Component(Base, ABC):
props = props.copy()
props.update(
- self.event_triggers,
+ **{
+ trigger: handler
+ for trigger, handler in self.event_triggers.items()
+ if trigger not in {EventTriggers.ON_MOUNT, EventTriggers.ON_UNMOUNT}
+ },
key=self.key,
id=self.id,
class_name=self.class_name,
@@ -488,7 +493,7 @@ class Component(Base, ABC):
"""
if type(self) in style:
# Extract the style for this component.
- component_style = Style(style[type(self)])
+ component_style = style[type(self)]
# Only add style props that are not overridden.
component_style = {
@@ -564,6 +569,78 @@ class Component(Base, ABC):
if self._valid_children:
validate_valid_child(name)
+ @staticmethod
+ def _get_vars_from_event_triggers(
+ event_triggers: dict[str, EventChain | Var],
+ ) -> Iterator[tuple[str, list[Var]]]:
+ """Get the Vars associated with each event trigger.
+
+ Args:
+ event_triggers: The event triggers from the component instance.
+
+ Yields:
+ tuple of (event_name, event_vars)
+ """
+ for event_trigger, event in event_triggers.items():
+ if isinstance(event, Var):
+ yield event_trigger, [event]
+ elif isinstance(event, EventChain):
+ event_args = []
+ for spec in event.events:
+ for args in spec.args:
+ event_args.extend(args)
+ yield event_trigger, event_args
+
+ def _get_vars(self) -> list[Var]:
+ """Walk all Vars used in this component.
+
+ Returns:
+ Each var referenced by the component (props, styles, event handlers).
+ """
+ vars = getattr(self, "__vars", None)
+ if vars is not None:
+ return vars
+ vars = self.__vars = []
+ # Get Vars associated with event trigger arguments.
+ for _, event_vars in self._get_vars_from_event_triggers(self.event_triggers):
+ vars.extend(event_vars)
+
+ # Get Vars associated with component props.
+ for prop in self.get_props():
+ prop_var = getattr(self, prop)
+ if isinstance(prop_var, Var):
+ vars.append(prop_var)
+
+ # Style keeps track of its own VarData instance, so embed in a temp Var that is yielded.
+ if self.style:
+ vars.append(
+ BaseVar(
+ _var_name="style",
+ _var_type=str,
+ _var_data=self.style._var_data,
+ )
+ )
+
+ # Special props are always Var instances.
+ vars.extend(self.special_props)
+
+ # Get Vars associated with common Component props.
+ for comp_prop in (
+ self.class_name,
+ self.id,
+ self.key,
+ self.autofocus,
+ *self.custom_attrs.values(),
+ ):
+ if isinstance(comp_prop, Var):
+ vars.append(comp_prop)
+ elif isinstance(comp_prop, str):
+ # Collapse VarData encoded in f-strings.
+ var = Var.create_safe(comp_prop)
+ if var._var_data is not None:
+ vars.append(var)
+ return vars
+
def _get_custom_code(self) -> str | None:
"""Get custom code for the component.
@@ -644,6 +721,33 @@ class Component(Base, ABC):
dep: [ImportVar(tag=None, render=False)] for dep in self.lib_dependencies
}
+ def _get_hooks_imports(self) -> imports.ImportDict:
+ """Get the imports required by certain hooks.
+
+ Returns:
+ The imports required for all selected hooks.
+ """
+ _imports = {}
+
+ if self._get_ref_hook():
+ # Handle hooks needed for attaching react refs to DOM nodes.
+ _imports.setdefault("react", set()).add(ImportVar(tag="useRef"))
+ _imports.setdefault(f"/{Dirs.STATE_PATH}", set()).add(ImportVar(tag="refs"))
+
+ if self._get_mount_lifecycle_hook():
+ # Handle hooks for `on_mount` / `on_unmount`.
+ _imports.setdefault("react", set()).add(ImportVar(tag="useEffect"))
+
+ if self._get_special_hooks():
+ # Handle additional internal hooks (autofocus, etc).
+ _imports.setdefault("react", set()).update(
+ {
+ ImportVar(tag="useRef"),
+ ImportVar(tag="useEffect"),
+ },
+ )
+ return _imports
+
def _get_imports(self) -> imports.ImportDict:
"""Get all the libraries and fields that are used by the component.
@@ -651,13 +755,26 @@ class Component(Base, ABC):
The imports needed by the component.
"""
_imports = {}
+
+ # Import this component's tag from the main library.
if self.library is not None and self.tag is not None:
_imports[self.library] = {self.import_var}
+ # Get static imports required for event processing.
+ event_imports = Imports.EVENTS if self.event_triggers else {}
+
+ # Collect imports from Vars used directly by this component.
+ var_imports = [
+ var._var_data.imports for var in self._get_vars() if var._var_data
+ ]
+
return imports.merge_imports(
*self._get_props_imports(),
self._get_dependencies_imports(),
+ self._get_hooks_imports(),
_imports,
+ event_imports,
+ *var_imports,
)
def get_imports(self) -> imports.ImportDict:
@@ -678,13 +795,13 @@ class Component(Base, ABC):
"""
# pop on_mount and on_unmount from event_triggers since these are handled by
# hooks, not as actually props in the component
- on_mount = self.event_triggers.pop(EventTriggers.ON_MOUNT, None)
- on_unmount = self.event_triggers.pop(EventTriggers.ON_UNMOUNT, None)
- if on_mount:
+ on_mount = self.event_triggers.get(EventTriggers.ON_MOUNT, None)
+ on_unmount = self.event_triggers.get(EventTriggers.ON_UNMOUNT, None)
+ if on_mount is not None:
on_mount = format.format_event_chain(on_mount)
- if on_unmount:
+ if on_unmount is not None:
on_unmount = format.format_event_chain(on_unmount)
- if on_mount or on_unmount:
+ if on_mount is not None or on_unmount is not None:
return f"""
useEffect(() => {{
{on_mount or ""}
@@ -703,6 +820,47 @@ class Component(Base, ABC):
if ref is not None:
return f"const {ref} = useRef(null); refs['{ref}'] = {ref};"
+ def _get_vars_hooks(self) -> set[str]:
+ """Get the hooks required by vars referenced in this component.
+
+ Returns:
+ The hooks for the vars.
+ """
+ vars_hooks = set()
+ for var in self._get_vars():
+ if var._var_data:
+ vars_hooks.update(var._var_data.hooks)
+ return vars_hooks
+
+ def _get_events_hooks(self) -> set[str]:
+ """Get the hooks required by events referenced in this component.
+
+ Returns:
+ The hooks for the events.
+ """
+ if self.event_triggers:
+ return {Hooks.EVENTS}
+ return set()
+
+ def _get_special_hooks(self) -> set[str]:
+ """Get the hooks required by special actions referenced in this component.
+
+ Returns:
+ The hooks for special actions.
+ """
+ if self.autofocus:
+ return {
+ """
+ // Set focus to the specified element.
+ const focusRef = useRef(null)
+ useEffect(() => {
+ if (focusRef.current) {
+ focusRef.current.focus();
+ }
+ })""",
+ }
+ return set()
+
def _get_hooks_internal(self) -> Set[str]:
"""Get the React hooks for this component managed by the framework.
@@ -712,10 +870,15 @@ class Component(Base, ABC):
Returns:
Set of internally managed hooks.
"""
- return set(
- hook
- for hook in [self._get_mount_lifecycle_hook(), self._get_ref_hook()]
- if hook
+ return (
+ set(
+ hook
+ for hook in [self._get_mount_lifecycle_hook(), self._get_ref_hook()]
+ if hook
+ )
+ | self._get_vars_hooks()
+ | self._get_events_hooks()
+ | self._get_special_hooks()
)
def _get_hooks(self) -> str | None:
@@ -1018,11 +1181,24 @@ class NoSSRComponent(Component):
"""A dynamic component that is not rendered on the server."""
def _get_imports(self) -> imports.ImportDict:
- dynamic_import = {"next/dynamic": {ImportVar(tag="dynamic", is_default=True)}}
+ """Get the imports for the component.
+
+ Returns:
+ The imports for dynamically importing the component at module load time.
+ """
+ # Next.js dynamic import mechanism.
+ dynamic_import = {"next/dynamic": [ImportVar(tag="dynamic", is_default=True)]}
+
+ # The normal imports for this component.
+ _imports = super()._get_imports()
+
+ # Do NOT import the main library/tag statically.
+ if self.library is not None:
+ _imports[self.library] = [imports.ImportVar(tag=None, render=False)]
return imports.merge_imports(
dynamic_import,
- {self.library: {ImportVar(tag=None, render=False)}},
+ _imports,
self._get_dependencies_imports(),
)
diff --git a/reflex/components/datadisplay/code.py b/reflex/components/datadisplay/code.py
index c56eff8d0..0c796cd08 100644
--- a/reflex/components/datadisplay/code.py
+++ b/reflex/components/datadisplay/code.py
@@ -12,7 +12,8 @@ from reflex.components.media import Icon
from reflex.event import set_clipboard
from reflex.style import Style
from reflex.utils import format, imports
-from reflex.vars import ImportVar, Var
+from reflex.utils.imports import ImportVar
+from reflex.vars import Var
LiteralCodeBlockTheme = Literal[
"a11y-dark",
diff --git a/reflex/components/datadisplay/code.pyi b/reflex/components/datadisplay/code.pyi
index 79e756ed0..5a88faeb3 100644
--- a/reflex/components/datadisplay/code.pyi
+++ b/reflex/components/datadisplay/code.pyi
@@ -16,7 +16,8 @@ from reflex.components.media import Icon
from reflex.event import set_clipboard
from reflex.style import Style
from reflex.utils import format, imports
-from reflex.vars import ImportVar, Var
+from reflex.utils.imports import ImportVar
+from reflex.vars import Var
LiteralCodeBlockTheme = Literal[
"a11y-dark",
diff --git a/reflex/components/datadisplay/dataeditor.py b/reflex/components/datadisplay/dataeditor.py
index e666d284c..f4b327aa6 100644
--- a/reflex/components/datadisplay/dataeditor.py
+++ b/reflex/components/datadisplay/dataeditor.py
@@ -8,8 +8,9 @@ from reflex.base import Base
from reflex.components.component import Component, NoSSRComponent
from reflex.components.literals import LiteralRowMarker
from reflex.utils import console, format, imports, types
+from reflex.utils.imports import ImportVar
from reflex.utils.serializers import serializer
-from reflex.vars import ImportVar, Var, get_unique_variable_name
+from reflex.vars import Var, get_unique_variable_name
# TODO: Fix the serialization issue for custom types.
diff --git a/reflex/components/datadisplay/dataeditor.pyi b/reflex/components/datadisplay/dataeditor.pyi
index c951e5c77..af209f422 100644
--- a/reflex/components/datadisplay/dataeditor.pyi
+++ b/reflex/components/datadisplay/dataeditor.pyi
@@ -13,8 +13,9 @@ from reflex.base import Base
from reflex.components.component import Component, NoSSRComponent
from reflex.components.literals import LiteralRowMarker
from reflex.utils import console, format, imports, types
+from reflex.utils.imports import ImportVar
from reflex.utils.serializers import serializer
-from reflex.vars import ImportVar, Var, get_unique_variable_name
+from reflex.vars import Var, get_unique_variable_name
class GridColumnIcons(Enum):
Array = "array"
diff --git a/reflex/components/datadisplay/datatable.py b/reflex/components/datadisplay/datatable.py
index 52bd45282..5b66fa54b 100644
--- a/reflex/components/datadisplay/datatable.py
+++ b/reflex/components/datadisplay/datatable.py
@@ -8,7 +8,7 @@ from reflex.components.component import Component
from reflex.components.tags import Tag
from reflex.utils import imports, types
from reflex.utils.serializers import serialize, serializer
-from reflex.vars import BaseVar, ComputedVar, ImportVar, Var
+from reflex.vars import BaseVar, ComputedVar, Var
class Gridjs(Component):
@@ -105,7 +105,7 @@ class DataTable(Gridjs):
def _get_imports(self) -> imports.ImportDict:
return imports.merge_imports(
super()._get_imports(),
- {"": {ImportVar(tag="gridjs/dist/theme/mermaid.css")}},
+ {"": {imports.ImportVar(tag="gridjs/dist/theme/mermaid.css")}},
)
def _render(self) -> Tag:
@@ -113,13 +113,13 @@ class DataTable(Gridjs):
self.columns = BaseVar(
_var_name=f"{self.data._var_name}.columns",
_var_type=List[Any],
- _var_state=self.data._var_state,
- )
+ _var_full_name_needs_state_prefix=True,
+ )._replace(merge_var_data=self.data._var_data)
self.data = BaseVar(
_var_name=f"{self.data._var_name}.data",
_var_type=List[List[Any]],
- _var_state=self.data._var_state,
- )
+ _var_full_name_needs_state_prefix=True,
+ )._replace(merge_var_data=self.data._var_data)
if types.is_dataframe(type(self.data)):
# If given a pandas df break up the data and columns
data = serialize(self.data)
diff --git a/reflex/components/datadisplay/datatable.pyi b/reflex/components/datadisplay/datatable.pyi
index 1a119dcc9..49e3ad752 100644
--- a/reflex/components/datadisplay/datatable.pyi
+++ b/reflex/components/datadisplay/datatable.pyi
@@ -12,7 +12,7 @@ from reflex.components.component import Component
from reflex.components.tags import Tag
from reflex.utils import imports, types
from reflex.utils.serializers import serialize, serializer
-from reflex.vars import BaseVar, ComputedVar, ImportVar, Var
+from reflex.vars import BaseVar, ComputedVar, Var
class Gridjs(Component):
@overload
diff --git a/reflex/components/datadisplay/moment.py b/reflex/components/datadisplay/moment.py
index 9ae8381bc..31ffb5ffa 100644
--- a/reflex/components/datadisplay/moment.py
+++ b/reflex/components/datadisplay/moment.py
@@ -4,7 +4,7 @@ from typing import Any, Dict, List
from reflex.components.component import Component, NoSSRComponent
from reflex.utils import imports
-from reflex.vars import ImportVar, Var
+from reflex.vars import Var
class Moment(NoSSRComponent):
@@ -78,7 +78,7 @@ class Moment(NoSSRComponent):
if self.tz is not None:
merged_imports = imports.merge_imports(
merged_imports,
- {"moment-timezone": {ImportVar(tag="")}},
+ {"moment-timezone": {imports.ImportVar(tag="")}},
)
return merged_imports
diff --git a/reflex/components/datadisplay/moment.pyi b/reflex/components/datadisplay/moment.pyi
index 0e9fcc4c4..2c7696037 100644
--- a/reflex/components/datadisplay/moment.pyi
+++ b/reflex/components/datadisplay/moment.pyi
@@ -10,7 +10,7 @@ from reflex.style import Style
from typing import Any, Dict, List
from reflex.components.component import Component, NoSSRComponent
from reflex.utils import imports
-from reflex.vars import ImportVar, Var
+from reflex.vars import Var
class Moment(NoSSRComponent):
def get_event_triggers(self) -> Dict[str, Any]: ...
diff --git a/reflex/components/forms/colormodeswitch.py b/reflex/components/forms/colormodeswitch.py
index 0aeefcd33..ca071a2a7 100644
--- a/reflex/components/forms/colormodeswitch.py
+++ b/reflex/components/forms/colormodeswitch.py
@@ -22,7 +22,7 @@ from reflex.components.component import Component
from reflex.components.layout.cond import Cond, cond
from reflex.components.media.icon import Icon
from reflex.style import color_mode, toggle_color_mode
-from reflex.vars import BaseVar
+from reflex.vars import Var
from .button import Button
from .switch import Switch
@@ -32,7 +32,7 @@ DEFAULT_LIGHT_ICON: Icon = Icon.create(tag="sun")
DEFAULT_DARK_ICON: Icon = Icon.create(tag="moon")
-def color_mode_cond(light: Any, dark: Any = None) -> BaseVar | Component:
+def color_mode_cond(light: Any, dark: Any = None) -> Var | Component:
"""Create a component or Prop based on color_mode.
Args:
diff --git a/reflex/components/forms/colormodeswitch.pyi b/reflex/components/forms/colormodeswitch.pyi
index 478bf5c37..83af8f20c 100644
--- a/reflex/components/forms/colormodeswitch.pyi
+++ b/reflex/components/forms/colormodeswitch.pyi
@@ -12,7 +12,7 @@ from reflex.components.component import Component
from reflex.components.layout.cond import Cond, cond
from reflex.components.media.icon import Icon
from reflex.style import color_mode, toggle_color_mode
-from reflex.vars import BaseVar
+from reflex.vars import Var
from .button import Button
from .switch import Switch
@@ -20,7 +20,7 @@ DEFAULT_COLOR_MODE: str
DEFAULT_LIGHT_ICON: Icon
DEFAULT_DARK_ICON: Icon
-def color_mode_cond(light: Any, dark: Any = None) -> BaseVar | Component: ...
+def color_mode_cond(light: Any, dark: Any = None) -> Var | Component: ...
class ColorModeIcon(Cond):
@overload
diff --git a/reflex/components/forms/debounce.py b/reflex/components/forms/debounce.py
index da318a1cb..1618d29ba 100644
--- a/reflex/components/forms/debounce.py
+++ b/reflex/components/forms/debounce.py
@@ -1,10 +1,11 @@
"""Wrapper around react-debounce-input."""
from __future__ import annotations
-from typing import Any
+from typing import Any, Set
from reflex.components import Component
from reflex.components.tags import Tag
+from reflex.utils import imports
from reflex.vars import Var
@@ -77,6 +78,17 @@ class DebounceInput(Component):
object.__setattr__(child, "render", lambda: "")
return tag
+ def _get_imports(self) -> imports.ImportDict:
+ return imports.merge_imports(
+ super()._get_imports(), *[c._get_imports() for c in self.children]
+ )
+
+ def _get_hooks_internal(self) -> Set[str]:
+ hooks = super()._get_hooks_internal()
+ for child in self.children:
+ hooks.update(child._get_hooks_internal())
+ return hooks
+
def props_not_none(c: Component) -> dict[str, Any]:
"""Get all properties of the component that are not None.
diff --git a/reflex/components/forms/debounce.pyi b/reflex/components/forms/debounce.pyi
index 8c2688f94..975aa8be2 100644
--- a/reflex/components/forms/debounce.pyi
+++ b/reflex/components/forms/debounce.pyi
@@ -7,9 +7,10 @@ from typing import Any, Dict, Literal, Optional, Union, overload
from reflex.vars import Var, BaseVar, ComputedVar
from reflex.event import EventChain, EventHandler, EventSpec
from reflex.style import Style
-from typing import Any
+from typing import Any, Set
from reflex.components import Component
from reflex.components.tags import Tag
+from reflex.utils import imports
from reflex.vars import Var
class DebounceInput(Component):
diff --git a/reflex/components/forms/editor.py b/reflex/components/forms/editor.py
index b7bfb08a5..92a1e80c3 100644
--- a/reflex/components/forms/editor.py
+++ b/reflex/components/forms/editor.py
@@ -8,7 +8,8 @@ from reflex.base import Base
from reflex.components.component import Component, NoSSRComponent
from reflex.constants import EventTriggers
from reflex.utils.format import to_camel_case
-from reflex.vars import ImportVar, Var
+from reflex.utils.imports import ImportVar
+from reflex.vars import Var
class EditorButtonList(list, enum.Enum):
diff --git a/reflex/components/forms/editor.pyi b/reflex/components/forms/editor.pyi
index 2b806fdef..1eaeea038 100644
--- a/reflex/components/forms/editor.pyi
+++ b/reflex/components/forms/editor.pyi
@@ -13,7 +13,8 @@ from reflex.base import Base
from reflex.components.component import Component, NoSSRComponent
from reflex.constants import EventTriggers
from reflex.utils.format import to_camel_case
-from reflex.vars import ImportVar, Var
+from reflex.utils.imports import ImportVar
+from reflex.vars import Var
class EditorButtonList(list, enum.Enum):
BASIC = [["font", "fontSize"], ["fontColor"], ["horizontalRule"], ["link", "image"]]
diff --git a/reflex/components/forms/form.py b/reflex/components/forms/form.py
index ee793da2f..301fec8e0 100644
--- a/reflex/components/forms/form.py
+++ b/reflex/components/forms/form.py
@@ -8,7 +8,7 @@ from jinja2 import Environment
from reflex.components.component import Component
from reflex.components.libs.chakra import ChakraComponent
from reflex.components.tags import Tag
-from reflex.constants import EventTriggers
+from reflex.constants import Dirs, EventTriggers
from reflex.event import EventChain
from reflex.utils import imports
from reflex.utils.format import format_event_chain, to_camel_case
@@ -65,7 +65,13 @@ class Form(ChakraComponent):
def _get_imports(self) -> imports.ImportDict:
return imports.merge_imports(
super()._get_imports(),
- {"react": {imports.ImportVar(tag="useCallback")}},
+ {
+ "react": {imports.ImportVar(tag="useCallback")},
+ f"/{Dirs.STATE_PATH}": {
+ imports.ImportVar(tag="getRefValue"),
+ imports.ImportVar(tag="getRefValues"),
+ },
+ },
)
def _get_hooks(self) -> str | None:
diff --git a/reflex/components/forms/form.pyi b/reflex/components/forms/form.pyi
index 827b2273d..e1a9c325a 100644
--- a/reflex/components/forms/form.pyi
+++ b/reflex/components/forms/form.pyi
@@ -12,7 +12,7 @@ from jinja2 import Environment
from reflex.components.component import Component
from reflex.components.libs.chakra import ChakraComponent
from reflex.components.tags import Tag
-from reflex.constants import EventTriggers
+from reflex.constants import Dirs, EventTriggers
from reflex.event import EventChain
from reflex.utils import imports
from reflex.utils.format import format_event_chain, to_camel_case
diff --git a/reflex/components/forms/input.py b/reflex/components/forms/input.py
index b4af66b02..7a2f5e4c8 100644
--- a/reflex/components/forms/input.py
+++ b/reflex/components/forms/input.py
@@ -11,7 +11,7 @@ from reflex.components.libs.chakra import (
)
from reflex.constants import EventTriggers
from reflex.utils import imports
-from reflex.vars import ImportVar, Var
+from reflex.vars import Var
class Input(ChakraComponent):
@@ -61,7 +61,7 @@ class Input(ChakraComponent):
def _get_imports(self) -> imports.ImportDict:
return imports.merge_imports(
super()._get_imports(),
- {"/utils/state": {ImportVar(tag="set_val")}},
+ {"/utils/state": {imports.ImportVar(tag="set_val")}},
)
def get_event_triggers(self) -> Dict[str, Any]:
diff --git a/reflex/components/forms/input.pyi b/reflex/components/forms/input.pyi
index 6ec79fe1f..d49a6a6df 100644
--- a/reflex/components/forms/input.pyi
+++ b/reflex/components/forms/input.pyi
@@ -17,7 +17,7 @@ from reflex.components.libs.chakra import (
)
from reflex.constants import EventTriggers
from reflex.utils import imports
-from reflex.vars import ImportVar, Var
+from reflex.vars import Var
class Input(ChakraComponent):
def get_event_triggers(self) -> Dict[str, Any]: ...
diff --git a/reflex/components/forms/pininput.py b/reflex/components/forms/pininput.py
index c090e1571..83323c25c 100644
--- a/reflex/components/forms/pininput.py
+++ b/reflex/components/forms/pininput.py
@@ -68,9 +68,11 @@ class PinInput(ChakraComponent):
Returns:
The merged import dict.
"""
+ range_var = Var.range(0)
return merge_imports(
super()._get_imports(),
PinInputField().get_imports(), # type: ignore
+ range_var._var_data.imports if range_var._var_data is not None else {},
)
def get_event_triggers(self) -> dict[str, Union[Var, Any]]:
@@ -117,7 +119,7 @@ class PinInput(ChakraComponent):
)
refs_declaration._var_is_local = True
if ref:
- return f"const {ref} = {refs_declaration}"
+ return f"const {ref} = {str(refs_declaration)}"
return super()._get_ref_hook()
def _render(self) -> Tag:
diff --git a/reflex/components/forms/upload.py b/reflex/components/forms/upload.py
index bd01d01c6..b79e0b106 100644
--- a/reflex/components/forms/upload.py
+++ b/reflex/components/forms/upload.py
@@ -7,9 +7,10 @@ from reflex import constants
from reflex.components.component import Component
from reflex.components.forms.input import Input
from reflex.components.layout.box import Box
+from reflex.constants import Dirs
from reflex.event import CallableEventSpec, EventChain, EventSpec, call_script
from reflex.utils import imports
-from reflex.vars import BaseVar, CallableVar, ImportVar, Var
+from reflex.vars import BaseVar, CallableVar, Var, VarData
DEFAULT_UPLOAD_ID: str = "default"
@@ -30,6 +31,13 @@ def upload_file(id_: str = DEFAULT_UPLOAD_ID) -> BaseVar:
return BaseVar(
_var_name=f"e => upload_files.{id_}[1]((files) => e)",
_var_type=EventChain,
+ _var_data=VarData( # type: ignore
+ imports={
+ f"/{Dirs.STATE_PATH}": {
+ imports.ImportVar(tag="upload_files"),
+ },
+ },
+ ),
)
@@ -46,6 +54,13 @@ def selected_files(id_: str = DEFAULT_UPLOAD_ID) -> BaseVar:
return BaseVar(
_var_name=f"(upload_files.{id_} ? upload_files.{id_}[0]?.map((f) => (f.path || f.name)) : [])",
_var_type=List[str],
+ _var_data=VarData( # type: ignore
+ imports={
+ f"/{Dirs.STATE_PATH}": {
+ imports.ImportVar(tag="upload_files"),
+ },
+ },
+ ),
)
@@ -166,14 +181,16 @@ class Upload(Component):
def _get_hooks(self) -> str | None:
return (
- (super()._get_hooks() or "")
- + f"""
- upload_files.{self.id or DEFAULT_UPLOAD_ID} = useState([]);
- """
- )
+ super()._get_hooks() or ""
+ ) + f"upload_files.{self.id or DEFAULT_UPLOAD_ID} = useState([]);"
def _get_imports(self) -> imports.ImportDict:
- return {
- **super()._get_imports(),
- f"/{constants.Dirs.STATE_PATH}": [ImportVar(tag="upload_files")],
- }
+ return imports.merge_imports(
+ super()._get_imports(),
+ {
+ "react": {imports.ImportVar(tag="useState")},
+ f"/{constants.Dirs.STATE_PATH}": [
+ imports.ImportVar(tag="upload_files")
+ ],
+ },
+ )
diff --git a/reflex/components/forms/upload.pyi b/reflex/components/forms/upload.pyi
index 87e5d58c7..5486ee90d 100644
--- a/reflex/components/forms/upload.pyi
+++ b/reflex/components/forms/upload.pyi
@@ -12,9 +12,10 @@ from reflex import constants
from reflex.components.component import Component
from reflex.components.forms.input import Input
from reflex.components.layout.box import Box
+from reflex.constants import Dirs
from reflex.event import CallableEventSpec, EventChain, EventSpec, call_script
from reflex.utils import imports
-from reflex.vars import BaseVar, CallableVar, ImportVar, Var
+from reflex.vars import BaseVar, CallableVar, Var, VarData
DEFAULT_UPLOAD_ID: str
diff --git a/reflex/components/layout/cond.py b/reflex/components/layout/cond.py
index 2455433d7..37b86fcc2 100644
--- a/reflex/components/layout/cond.py
+++ b/reflex/components/layout/cond.py
@@ -1,13 +1,18 @@
"""Create a list of components from an iterable."""
from __future__ import annotations
-from typing import Any, Dict, Optional
+from typing import Any, Dict, Optional, overload
from reflex.components.component import Component
from reflex.components.layout.fragment import Fragment
from reflex.components.tags import CondTag, Tag
-from reflex.utils import format
-from reflex.vars import Var
+from reflex.constants import Dirs
+from reflex.utils import format, imports
+from reflex.vars import BaseVar, Var, VarData
+
+_IS_TRUE_IMPORT = {
+ f"/{Dirs.STATE_PATH}": {imports.ImportVar(tag="isTrue")},
+}
class Cond(Component):
@@ -88,6 +93,28 @@ class Cond(Component):
cond_state=f"isTrue({self.cond._var_full_name})",
)
+ def _get_imports(self) -> imports.ImportDict:
+ return imports.merge_imports(
+ super()._get_imports(),
+ getattr(self.cond._var_data, "imports", {}),
+ _IS_TRUE_IMPORT,
+ )
+
+
+@overload
+def cond(condition: Any, c1: Component, c2: Any) -> Component:
+ ...
+
+
+@overload
+def cond(condition: Any, c1: Component) -> Component:
+ ...
+
+
+@overload
+def cond(condition: Any, c1: Any, c2: Any) -> Var:
+ ...
+
def cond(condition: Any, c1: Any, c2: Any = None):
"""Create a conditional component or Prop.
@@ -103,8 +130,11 @@ def cond(condition: Any, c1: Any, c2: Any = None):
Raises:
ValueError: If the arguments are invalid.
"""
- # Import here to avoid circular imports.
- from reflex.vars import BaseVar, Var
+ var_datas: list[VarData | None] = [
+ VarData( # type: ignore
+ imports=_IS_TRUE_IMPORT,
+ ),
+ ]
# Convert the condition to a Var.
cond_var = Var.create(condition)
@@ -116,16 +146,20 @@ def cond(condition: Any, c1: Any, c2: Any = None):
c2, Component
), "Both arguments must be components."
return Cond.create(cond_var, c1, c2)
+ if isinstance(c1, Var):
+ var_datas.append(c1._var_data)
- # Otherwise, create a conditionl Var.
+ # Otherwise, create a conditional Var.
# Check that the second argument is valid.
if isinstance(c2, Component):
raise ValueError("Both arguments must be props.")
if c2 is None:
raise ValueError("For conditional vars, the second argument must be set.")
+ if isinstance(c2, Var):
+ var_datas.append(c2._var_data)
# Create the conditional var.
- return BaseVar(
+ return cond_var._replace(
_var_name=format.format_cond(
cond=cond_var._var_full_name,
true_value=c1,
@@ -133,4 +167,7 @@ def cond(condition: Any, c1: Any, c2: Any = None):
is_prop=True,
),
_var_type=c1._var_type if isinstance(c1, BaseVar) else type(c1),
+ _var_is_local=False,
+ _var_full_name_needs_state_prefix=False,
+ merge_var_data=VarData.merge(*var_datas),
)
diff --git a/reflex/components/layout/html.py b/reflex/components/layout/html.py
index 893e2ecca..3a4ba76ad 100644
--- a/reflex/components/layout/html.py
+++ b/reflex/components/layout/html.py
@@ -1,8 +1,8 @@
"""A html component."""
-
-from typing import Any
+from typing import Dict
from reflex.components.layout.box import Box
+from reflex.vars import Var
class Html(Box):
@@ -13,7 +13,7 @@ class Html(Box):
"""
# The HTML to render.
- dangerouslySetInnerHTML: Any
+ dangerouslySetInnerHTML: Var[Dict[str, str]]
@classmethod
def create(cls, *children, **props):
diff --git a/reflex/components/layout/html.pyi b/reflex/components/layout/html.pyi
index ec513ca29..aec46e2f9 100644
--- a/reflex/components/layout/html.pyi
+++ b/reflex/components/layout/html.pyi
@@ -7,8 +7,9 @@ from typing import Any, Dict, Literal, Optional, Union, overload
from reflex.vars import Var, BaseVar, ComputedVar
from reflex.event import EventChain, EventHandler, EventSpec
from reflex.style import Style
-from typing import Any
+from typing import Dict
from reflex.components.layout.box import Box
+from reflex.vars import Var
class Html(Box):
@overload
@@ -16,7 +17,9 @@ class Html(Box):
def create( # type: ignore
cls,
*children,
- dangerouslySetInnerHTML: Optional[Any] = None,
+ dangerouslySetInnerHTML: Optional[
+ Union[Var[Dict[str, str]], Dict[str, str]]
+ ] = None,
element: Optional[Union[Var[str], str]] = None,
src: Optional[Union[Var[str], str]] = None,
alt: Optional[Union[Var[str], str]] = None,
diff --git a/reflex/components/libs/chakra.py b/reflex/components/libs/chakra.py
index fdab62fd4..34bfff1d5 100644
--- a/reflex/components/libs/chakra.py
+++ b/reflex/components/libs/chakra.py
@@ -6,7 +6,7 @@ from typing import List, Literal
from reflex.components.component import Component
from reflex.utils import imports
-from reflex.vars import ImportVar, Var
+from reflex.vars import Var
class ChakraComponent(Component):
@@ -34,7 +34,7 @@ class ChakraComponent(Component):
The dependencies imports of the component.
"""
return {
- dep: [ImportVar(tag=None, render=False)]
+ dep: [imports.ImportVar(tag=None, render=False)]
for dep in [
"@chakra-ui/system@2.5.7",
"framer-motion@10.16.4",
@@ -75,17 +75,17 @@ class ChakraProvider(ChakraComponent):
)
def _get_imports(self) -> imports.ImportDict:
- imports = super()._get_imports()
- imports.setdefault(self.__fields__["library"].default, []).append(
- ImportVar(tag="extendTheme", is_default=False),
+ _imports = super()._get_imports()
+ _imports.setdefault(self.__fields__["library"].default, []).append(
+ imports.ImportVar(tag="extendTheme", is_default=False),
)
- imports.setdefault("/utils/theme.js", []).append(
- ImportVar(tag="theme", is_default=True),
+ _imports.setdefault("/utils/theme.js", []).append(
+ imports.ImportVar(tag="theme", is_default=True),
)
- imports.setdefault(Global.__fields__["library"].default, []).append(
- ImportVar(tag="css", is_default=False),
+ _imports.setdefault(Global.__fields__["library"].default, []).append(
+ imports.ImportVar(tag="css", is_default=False),
)
- return imports
+ return _imports
def _get_custom_code(self) -> str | None:
return """
diff --git a/reflex/components/libs/chakra.pyi b/reflex/components/libs/chakra.pyi
index 967dcda35..d410a1c9f 100644
--- a/reflex/components/libs/chakra.pyi
+++ b/reflex/components/libs/chakra.pyi
@@ -11,7 +11,7 @@ from functools import lru_cache
from typing import List, Literal
from reflex.components.component import Component
from reflex.utils import imports
-from reflex.vars import ImportVar, Var
+from reflex.vars import Var
class ChakraComponent(Component):
@overload
diff --git a/reflex/components/navigation/client_side_routing.py b/reflex/components/navigation/client_side_routing.py
index 22620e54b..99f0f6cfd 100644
--- a/reflex/components/navigation/client_side_routing.py
+++ b/reflex/components/navigation/client_side_routing.py
@@ -10,10 +10,9 @@ routeNotFound becomes true.
from __future__ import annotations
from reflex import constants
-
-from ...vars import Var
-from ..component import Component
-from ..layout.cond import Cond
+from reflex.components.component import Component
+from reflex.components.layout.cond import cond
+from reflex.vars import Var
route_not_found: Var = Var.create_safe(constants.ROUTE_NOT_FOUND)
@@ -52,10 +51,10 @@ def wait_for_client_redirect(component) -> Component:
Returns:
The conditionally rendered component.
"""
- return Cond.create(
- cond=route_not_found,
- comp1=component,
- comp2=ClientSideRouting.create(),
+ return cond(
+ condition=route_not_found,
+ c1=component,
+ c2=ClientSideRouting.create(),
)
diff --git a/reflex/components/navigation/client_side_routing.pyi b/reflex/components/navigation/client_side_routing.pyi
index 100d0adb2..b7801246e 100644
--- a/reflex/components/navigation/client_side_routing.pyi
+++ b/reflex/components/navigation/client_side_routing.pyi
@@ -8,9 +8,9 @@ from reflex.vars import Var, BaseVar, ComputedVar
from reflex.event import EventChain, EventHandler, EventSpec
from reflex.style import Style
from reflex import constants
-from ...vars import Var
-from ..component import Component
-from ..layout.cond import Cond
+from reflex.components.component import Component
+from reflex.components.layout.cond import cond
+from reflex.vars import Var
route_not_found: Var
diff --git a/reflex/components/overlay/banner.py b/reflex/components/overlay/banner.py
index cdb01c063..d690f3a9c 100644
--- a/reflex/components/overlay/banner.py
+++ b/reflex/components/overlay/banner.py
@@ -5,22 +5,27 @@ from typing import Optional
from reflex.components.base.bare import Bare
from reflex.components.component import Component
-from reflex.components.layout import Box, Cond
+from reflex.components.layout import Box, cond
from reflex.components.overlay.modal import Modal
from reflex.components.typography import Text
+from reflex.constants import Hooks, Imports
from reflex.utils import imports
-from reflex.vars import ImportVar, Var
+from reflex.vars import Var, VarData
+
+connect_error_var_data: VarData = VarData( # type: ignore
+ imports=Imports.EVENTS,
+ hooks={Hooks.EVENTS},
+)
connection_error: Var = Var.create_safe(
value="(connectError !== null) ? connectError.message : ''",
_var_is_local=False,
_var_is_string=False,
-)
+)._replace(merge_var_data=connect_error_var_data)
has_connection_error: Var = Var.create_safe(
value="connectError !== null",
_var_is_string=False,
-)
-has_connection_error._var_type = bool
+)._replace(_var_type=bool, merge_var_data=connect_error_var_data)
class WebsocketTargetURL(Bare):
@@ -28,7 +33,7 @@ class WebsocketTargetURL(Bare):
def _get_imports(self) -> imports.ImportDict:
return {
- "/utils/state.js": [ImportVar(tag="getEventURL")],
+ "/utils/state.js": [imports.ImportVar(tag="getEventURL")],
}
@classmethod
@@ -78,7 +83,7 @@ class ConnectionBanner(Component):
textAlign="center",
)
- return Cond.create(has_connection_error, comp)
+ return cond(has_connection_error, comp)
class ConnectionModal(Component):
@@ -96,7 +101,7 @@ class ConnectionModal(Component):
"""
if not comp:
comp = Text.create(*default_connection_error())
- return Cond.create(
+ return cond(
has_connection_error,
Modal.create(
header="Connection Error",
diff --git a/reflex/components/overlay/banner.pyi b/reflex/components/overlay/banner.pyi
index 4e855ae1d..db3cafd04 100644
--- a/reflex/components/overlay/banner.pyi
+++ b/reflex/components/overlay/banner.pyi
@@ -10,15 +10,16 @@ from reflex.style import Style
from typing import Optional
from reflex.components.base.bare import Bare
from reflex.components.component import Component
-from reflex.components.layout import Box, Cond
+from reflex.components.layout import Box, cond
from reflex.components.overlay.modal import Modal
from reflex.components.typography import Text
+from reflex.constants import Hooks, Imports
from reflex.utils import imports
-from reflex.vars import ImportVar, Var
+from reflex.vars import Var, VarData
+connect_error_var_data: VarData
connection_error: Var
has_connection_error: Var
-has_connection_error._var_type = bool
class WebsocketTargetURL(Bare):
@overload
diff --git a/reflex/components/radix/themes/base.py b/reflex/components/radix/themes/base.py
index 3e589faca..c75a426a4 100644
--- a/reflex/components/radix/themes/base.py
+++ b/reflex/components/radix/themes/base.py
@@ -6,7 +6,7 @@ from typing import Literal
from reflex.components import Component
from reflex.utils import imports
-from reflex.vars import ImportVar, Var
+from reflex.vars import Var
LiteralAlign = Literal["start", "center", "end", "baseline", "stretch"]
LiteralJustify = Literal["start", "center", "end", "between"]
@@ -147,7 +147,7 @@ class Theme(RadixThemesComponent):
def _get_imports(self) -> imports.ImportDict:
return {
**super()._get_imports(),
- "": [ImportVar(tag="@radix-ui/themes/styles.css", install=False)],
+ "": [imports.ImportVar(tag="@radix-ui/themes/styles.css", install=False)],
}
diff --git a/reflex/components/radix/themes/base.pyi b/reflex/components/radix/themes/base.pyi
index eb8b8cb30..9f840f2a8 100644
--- a/reflex/components/radix/themes/base.pyi
+++ b/reflex/components/radix/themes/base.pyi
@@ -10,7 +10,7 @@ from reflex.style import Style
from typing import Literal
from reflex.components import Component
from reflex.utils import imports
-from reflex.vars import ImportVar, Var
+from reflex.vars import Var
LiteralAlign = Literal["start", "center", "end", "baseline", "stretch"]
LiteralJustify = Literal["start", "center", "end", "between"]
diff --git a/reflex/components/typography/markdown.py b/reflex/components/typography/markdown.py
index f414f1781..f8841a75d 100644
--- a/reflex/components/typography/markdown.py
+++ b/reflex/components/typography/markdown.py
@@ -14,7 +14,8 @@ from reflex.components.typography.heading import Heading
from reflex.components.typography.text import Text
from reflex.style import Style
from reflex.utils import console, imports, types
-from reflex.vars import ImportVar, Var
+from reflex.utils.imports import ImportVar
+from reflex.vars import Var
# Special vars used in the component map.
_CHILDREN = Var.create_safe("children", _var_is_local=False)
diff --git a/reflex/components/typography/markdown.pyi b/reflex/components/typography/markdown.pyi
index ee11bf255..cb9140435 100644
--- a/reflex/components/typography/markdown.pyi
+++ b/reflex/components/typography/markdown.pyi
@@ -18,7 +18,8 @@ from reflex.components.typography.heading import Heading
from reflex.components.typography.text import Text
from reflex.style import Style
from reflex.utils import console, imports, types
-from reflex.vars import ImportVar, Var
+from reflex.utils.imports import ImportVar
+from reflex.vars import Var
_CHILDREN = Var.create_safe("children", _var_is_local=False)
_PROPS = Var.create_safe("...props", _var_is_local=False)
diff --git a/reflex/constants/__init__.py b/reflex/constants/__init__.py
index 628f7511b..ca408c6a6 100644
--- a/reflex/constants/__init__.py
+++ b/reflex/constants/__init__.py
@@ -22,6 +22,8 @@ from .compiler import (
CompileVars,
ComponentName,
Ext,
+ Hooks,
+ Imports,
PageNames,
)
from .config import (
@@ -68,7 +70,9 @@ __ALL__ = [
Ext,
Fnm,
GitIgnore,
+ Hooks,
RequirementsTxt,
+ Imports,
IS_WINDOWS,
LOCAL_STORAGE,
LogLevel,
diff --git a/reflex/constants/base.py b/reflex/constants/base.py
index 8957edfb7..0b28e18cb 100644
--- a/reflex/constants/base.py
+++ b/reflex/constants/base.py
@@ -29,6 +29,8 @@ class Dirs(SimpleNamespace):
STATE_PATH = "/".join([UTILS, "state"])
# The name of the components file.
COMPONENTS_PATH = "/".join([UTILS, "components"])
+ # The name of the contexts file.
+ CONTEXTS_PATH = "/".join([UTILS, "context"])
# The directory where the app pages are compiled to.
WEB_PAGES = os.path.join(WEB, "pages")
# The directory where the static build is located.
diff --git a/reflex/constants/compiler.py b/reflex/constants/compiler.py
index 4a9d09d4c..e309c5d4a 100644
--- a/reflex/constants/compiler.py
+++ b/reflex/constants/compiler.py
@@ -2,6 +2,9 @@
from enum import Enum
from types import SimpleNamespace
+from reflex.constants import Dirs
+from reflex.utils.imports import ImportVar
+
# The prefix used to create setters for state vars.
SETTER_PREFIX = "set_"
@@ -47,6 +50,12 @@ class CompileVars(SimpleNamespace):
HYDRATE = "hydrate"
# The name of the is_hydrated variable.
IS_HYDRATED = "is_hydrated"
+ # The name of the function to add events to the queue.
+ ADD_EVENTS = "addEvents"
+ # The name of the var storing any connection error.
+ CONNECT_ERROR = "connectError"
+ # The name of the function for converting a dict to an event.
+ TO_EVENT = "Event"
class PageNames(SimpleNamespace):
@@ -77,3 +86,19 @@ class ComponentName(Enum):
The lower-case filename with zip extension.
"""
return self.value.lower() + Ext.ZIP
+
+
+class Imports(SimpleNamespace):
+ """Common sets of import vars."""
+
+ EVENTS = {
+ "react": {ImportVar(tag="useContext")},
+ f"/{Dirs.CONTEXTS_PATH}": {ImportVar(tag="EventLoopContext")},
+ f"/{Dirs.STATE_PATH}": {ImportVar(tag=CompileVars.TO_EVENT)},
+ }
+
+
+class Hooks(SimpleNamespace):
+ """Common sets of hook declarations."""
+
+ EVENTS = f"const [{CompileVars.ADD_EVENTS}, {CompileVars.CONNECT_ERROR}] = useContext(EventLoopContext);"
diff --git a/reflex/middleware/hydrate_middleware.py b/reflex/middleware/hydrate_middleware.py
index 3992919a2..38d5fb14f 100644
--- a/reflex/middleware/hydrate_middleware.py
+++ b/reflex/middleware/hydrate_middleware.py
@@ -48,7 +48,7 @@ class HydrateMiddleware(Middleware):
setattr(var_state, var_name, value)
# Get the initial state.
- delta = format.format_state({state.get_name(): state.dict()})
+ delta = format.format_state(state.dict())
# since a full dict was captured, clean any dirtiness
state._clean()
diff --git a/reflex/state.py b/reflex/state.py
index 11bb5e9a4..424ead87c 100644
--- a/reflex/state.py
+++ b/reflex/state.py
@@ -1211,12 +1211,16 @@ class State(Base, ABC, extra=pydantic.Extra.allow):
if include_computed
else {}
)
- substate_vars = {
- k: v.dict(include_computed=include_computed, **kwargs)
- for k, v in self.substates.items()
+ variables = {**base_vars, **computed_vars}
+ d = {
+ self.get_full_name(): {k: variables[k] for k in sorted(variables)},
}
- variables = {**base_vars, **computed_vars, **substate_vars}
- return {k: variables[k] for k in sorted(variables)}
+ for substate_d in [
+ v.dict(include_computed=include_computed, **kwargs)
+ for v in self.substates.values()
+ ]:
+ d.update(substate_d)
+ return d
async def __aenter__(self) -> State:
"""Enter the async context manager protocol.
diff --git a/reflex/style.py b/reflex/style.py
index bb8320163..d67029992 100644
--- a/reflex/style.py
+++ b/reflex/style.py
@@ -2,13 +2,38 @@
from __future__ import annotations
+from typing import Any
+
from reflex import constants
from reflex.event import EventChain
from reflex.utils import format
-from reflex.vars import BaseVar, Var
+from reflex.utils.imports import ImportVar
+from reflex.vars import BaseVar, Var, VarData
-color_mode = BaseVar(_var_name=constants.ColorMode.NAME, _var_type="str")
-toggle_color_mode = BaseVar(_var_name=constants.ColorMode.TOGGLE, _var_type=EventChain)
+VarData.update_forward_refs() # Ensure all type definitions are resolved
+
+# Reference the global ColorModeContext
+color_mode_var_data = VarData( # type: ignore
+ imports={
+ f"/{constants.Dirs.CONTEXTS_PATH}": {ImportVar(tag="ColorModeContext")},
+ "react": {ImportVar(tag="useContext")},
+ },
+ hooks={
+ f"const [ {constants.ColorMode.NAME}, {constants.ColorMode.TOGGLE} ] = useContext(ColorModeContext)",
+ },
+)
+# Var resolves to the current color mode for the app ("light" or "dark")
+color_mode = BaseVar(
+ _var_name=constants.ColorMode.NAME,
+ _var_type="str",
+ _var_data=color_mode_var_data,
+)
+# Var resolves to a function invocation that toggles the color mode
+toggle_color_mode = BaseVar(
+ _var_name=constants.ColorMode.TOGGLE,
+ _var_type=EventChain,
+ _var_data=color_mode_var_data,
+)
def convert(style_dict):
@@ -20,16 +45,27 @@ def convert(style_dict):
Returns:
The formatted style dictionary.
"""
+ var_data = None # Track import/hook data from any Vars in the style dict.
out = {}
for key, value in style_dict.items():
key = format.to_camel_case(key)
+ new_var_data = None
if isinstance(value, dict):
- out[key] = convert(value)
+ # Recursively format nested style dictionaries.
+ out[key], new_var_data = convert(value)
elif isinstance(value, Var):
+ # If the value is a Var, extract the var_data and cast as str.
+ new_var_data = value._var_data
out[key] = str(value)
else:
+ # Otherwise, convert to Var to collapse VarData encoded in f-string.
+ new_var = Var.create(value)
+ if new_var is not None:
+ new_var_data = new_var._var_data
out[key] = value
- return out
+ # Combine all the collected VarData instances.
+ var_data = VarData.merge(var_data, new_var_data)
+ return out, var_data
class Style(dict):
@@ -41,5 +77,33 @@ class Style(dict):
Args:
style_dict: The style dictionary.
"""
- style_dict = style_dict or {}
- super().__init__(convert(style_dict))
+ style_dict, self._var_data = convert(style_dict or {})
+ super().__init__(style_dict)
+
+ def update(self, style_dict: dict | None, **kwargs):
+ """Update the style.
+
+ Args:
+ style_dict: The style dictionary.
+ kwargs: Other key value pairs to apply to the dict update.
+ """
+ if kwargs:
+ style_dict = {**(style_dict or {}), **kwargs}
+ converted_dict = type(self)(style_dict)
+ # Combine our VarData with that of any Vars in the style_dict that was passed.
+ self._var_data = VarData.merge(self._var_data, converted_dict._var_data)
+ super().update(converted_dict)
+
+ def __setitem__(self, key: str, value: Any):
+ """Set an item in the style.
+
+ Args:
+ key: The key to set.
+ value: The value to set.
+ """
+ # Create a Var to collapse VarData encoded in f-string.
+ _var = Var.create(value)
+ if _var is not None:
+ # Carry the imports/hooks when setting a Var as a value.
+ self._var_data = VarData.merge(self._var_data, _var._var_data)
+ super().__setitem__(key, value)
diff --git a/reflex/utils/format.py b/reflex/utils/format.py
index 03ee5c9c7..4050706ea 100644
--- a/reflex/utils/format.py
+++ b/reflex/utils/format.py
@@ -232,9 +232,9 @@ def format_route(route: str, format_case=True) -> str:
def format_cond(
- cond: str,
- true_value: str,
- false_value: str = '""',
+ cond: str | Var,
+ true_value: str | Var,
+ false_value: str | Var = '""',
is_prop=False,
) -> str:
"""Format a conditional expression.
@@ -248,9 +248,6 @@ def format_cond(
Returns:
The formatted conditional expression.
"""
- # Import here to avoid circular imports.
- from reflex.vars import Var
-
# Use Python truthiness.
cond = f"isTrue({cond})"
@@ -266,6 +263,7 @@ def format_cond(
_var_is_string=type(false_value) is str,
)
prop2._var_is_local = True
+ prop1, prop2 = str(prop1), str(prop2) # avoid f-string semantics for Var
return f"{cond} ? {prop1} : {prop2}".replace("{", "").replace("}", "")
# Format component conds.
@@ -517,6 +515,21 @@ def format_state(value: Any) -> Any:
raise TypeError(f"No JSON serializer found for var {value} of type {type(value)}.")
+def format_state_name(state_name: str) -> str:
+ """Format a state name, replacing dots with double underscore.
+
+ This allows individual substates to be accessed independently as javascript vars
+ without using dot notation.
+
+ Args:
+ state_name: The state name to format.
+
+ Returns:
+ The formatted state name.
+ """
+ return state_name.replace(".", "__")
+
+
def format_ref(ref: str) -> str:
"""Format a ref.
diff --git a/reflex/utils/imports.py b/reflex/utils/imports.py
index f20ec141d..26faa4820 100644
--- a/reflex/utils/imports.py
+++ b/reflex/utils/imports.py
@@ -3,11 +3,9 @@
from __future__ import annotations
from collections import defaultdict
-from typing import Dict, List
+from typing import Dict, List, Optional
-from reflex.vars import ImportVar
-
-ImportDict = Dict[str, List[ImportVar]]
+from reflex.base import Base
def merge_imports(*imports) -> ImportDict:
@@ -24,3 +22,42 @@ def merge_imports(*imports) -> ImportDict:
for lib, fields in import_dict.items():
all_imports[lib].extend(fields)
return all_imports
+
+
+class ImportVar(Base):
+ """An import var."""
+
+ # The name of the import tag.
+ tag: Optional[str]
+
+ # whether the import is default or named.
+ is_default: Optional[bool] = False
+
+ # The tag alias.
+ alias: Optional[str] = None
+
+ # Whether this import need to install the associated lib
+ install: Optional[bool] = True
+
+ # whether this import should be rendered or not
+ render: Optional[bool] = True
+
+ @property
+ def name(self) -> str:
+ """The name of the import.
+
+ Returns:
+ The name(tag name with alias) of tag.
+ """
+ return self.tag if not self.alias else " as ".join([self.tag, self.alias]) # type: ignore
+
+ def __hash__(self) -> int:
+ """Define a hash function for the import var.
+
+ Returns:
+ The hash of the var.
+ """
+ return hash((self.tag, self.is_default, self.alias, self.install, self.render))
+
+
+ImportDict = Dict[str, List[ImportVar]]
diff --git a/reflex/utils/types.py b/reflex/utils/types.py
index c114e94ac..806684954 100644
--- a/reflex/utils/types.py
+++ b/reflex/utils/types.py
@@ -27,6 +27,7 @@ from reflex.utils import serializers
GenericType = Union[Type, _GenericAlias]
# Valid state var types.
+JSONType = {str, int, float, bool}
PrimitiveType = Union[int, float, bool, str, list, dict, set, tuple]
StateVar = Union[PrimitiveType, Base, None]
StateIterVar = Union[list, set, tuple]
diff --git a/reflex/vars.py b/reflex/vars.py
index 468dc2a24..d4d785b7a 100644
--- a/reflex/vars.py
+++ b/reflex/vars.py
@@ -7,6 +7,7 @@ import dis
import inspect
import json
import random
+import re
import string
import sys
from types import CodeType, FunctionType
@@ -15,9 +16,11 @@ from typing import (
Any,
Callable,
Dict,
+ Iterable,
List,
Literal,
Optional,
+ Set,
Tuple,
Type,
Union,
@@ -30,7 +33,10 @@ from typing import (
from reflex import constants
from reflex.base import Base
-from reflex.utils import console, format, serializers, types
+from reflex.utils import console, format, imports, serializers, types
+
+# This module used to export ImportVar itself, so we still import it for export here
+from reflex.utils.imports import ImportDict, ImportVar
if TYPE_CHECKING:
from reflex.state import State
@@ -71,7 +77,7 @@ OPERATION_MAPPING = {
REPLACED_NAMES = {
"full_name": "_var_full_name",
"name": "_var_name",
- "state": "_var_state",
+ "state": "_var_data.state",
"type_": "_var_type",
"is_local": "_var_is_local",
"is_string": "_var_is_string",
@@ -93,6 +99,131 @@ def get_unique_variable_name() -> str:
return get_unique_variable_name()
+class VarData(Base):
+ """Metadata associated with a Var."""
+
+ # The name of the enclosing state.
+ state: str = ""
+
+ # Imports needed to render this var
+ imports: ImportDict = {}
+
+ # Hooks that need to be present in the component to render this var
+ hooks: Set[str] = set()
+
+ @classmethod
+ def merge(cls, *others: VarData | None) -> VarData | None:
+ """Merge multiple var data objects.
+
+ Args:
+ *others: The var data objects to merge.
+
+ Returns:
+ The merged var data object.
+ """
+ state = ""
+ _imports = {}
+ hooks = set()
+ for var_data in others:
+ if var_data is None:
+ continue
+ state = state or var_data.state
+ _imports = imports.merge_imports(_imports, var_data.imports)
+ hooks.update(var_data.hooks)
+ return (
+ cls(
+ state=state,
+ imports=_imports,
+ hooks=hooks,
+ )
+ or None
+ )
+
+ def __bool__(self) -> bool:
+ """Check if the var data is non-empty.
+
+ Returns:
+ True if any field is set to a non-default value.
+ """
+ return bool(self.state or self.imports or self.hooks)
+
+ def dict(self) -> dict:
+ """Convert the var data to a dictionary.
+
+ Returns:
+ The var data dictionary.
+ """
+ return {
+ "state": self.state,
+ "imports": {
+ lib: [import_var.dict() for import_var in import_vars]
+ for lib, import_vars in self.imports.items()
+ },
+ "hooks": list(self.hooks),
+ }
+
+
+def _encode_var(value: Var) -> str:
+ """Encode the state name into a formatted var.
+
+ Args:
+ value: The value to encode the state name into.
+
+ Returns:
+ The encoded var.
+ """
+ if value._var_data:
+ return f"{value._var_data.json()}" + str(value)
+ return str(value)
+
+
+def _decode_var(value: str) -> tuple[VarData | None, str]:
+ """Decode the state name from a formatted var.
+
+ Args:
+ value: The value to extract the state name from.
+
+ Returns:
+ The extracted state name and the value without the state name.
+ """
+ var_datas = []
+ if isinstance(value, str):
+ # Extract the state name from a formatted var
+ while m := re.match(r"(.*)(.*)(.*)", value):
+ value = m.group(1) + m.group(3)
+ var_datas.append(VarData.parse_raw(m.group(2)))
+ if var_datas:
+ return VarData.merge(*var_datas), value
+ return None, value
+
+
+def _extract_var_data(value: Iterable) -> list[VarData | None]:
+ """Extract the var imports and hooks from an iterable containing a Var.
+
+ Args:
+ value: The iterable to extract the VarData from
+
+ Returns:
+ The extracted VarDatas.
+ """
+ var_datas = []
+ with contextlib.suppress(TypeError):
+ for sub in value:
+ if isinstance(sub, Var):
+ var_datas.append(sub._var_data)
+ elif not isinstance(sub, str):
+ # Recurse into dict values.
+ if hasattr(sub, "values") and callable(sub.values):
+ var_datas.extend(_extract_var_data(sub.values()))
+ # Recurse into iterable values (or dict keys).
+ var_datas.extend(_extract_var_data(sub))
+ # Recurse when value is a dict itself.
+ values = getattr(value, "values", None)
+ if callable(values):
+ var_datas.extend(_extract_var_data(values()))
+ return var_datas
+
+
class Var:
"""An abstract var."""
@@ -102,15 +233,18 @@ class Var:
# The type of the var.
_var_type: Type
- # The name of the enclosing state.
- _var_state: str
-
# Whether this is a local javascript variable.
_var_is_local: bool
# Whether the var is a string literal.
_var_is_string: bool
+ # _var_full_name should be prefixed with _var_state
+ _var_full_name_needs_state_prefix: bool
+
+ # Extra metadata associated with the Var
+ _var_data: Optional[VarData]
+
@classmethod
def create(
cls, value: Any, _var_is_local: bool = True, _var_is_string: bool = False
@@ -136,9 +270,14 @@ class Var:
if isinstance(value, Var):
return value
+ # Try to pull the imports and hooks from contained values.
+ _var_data = None
+ if not isinstance(value, str):
+ _var_data = VarData.merge(*_extract_var_data(value))
+
# Try to serialize the value.
type_ = type(value)
- name = serializers.serialize(value)
+ name = value if type_ in types.JSONType else serializers.serialize(value)
if name is None:
raise TypeError(
f"No JSON serializer found for var {value} of type {type_}."
@@ -150,6 +289,7 @@ class Var:
_var_type=type_,
_var_is_local=_var_is_local,
_var_is_string=_var_is_string,
+ _var_data=_var_data,
)
@classmethod
@@ -186,6 +326,39 @@ class Var:
"""
return _GenericAlias(cls, type_)
+ def __post_init__(self) -> None:
+ """Post-initialize the var."""
+ # Decode any inline Var markup and apply it to the instance
+ _var_data, _var_name = _decode_var(self._var_name)
+ if _var_data:
+ self._var_name = _var_name
+ self._var_data = VarData.merge(self._var_data, _var_data)
+
+ def _replace(self, merge_var_data=None, **kwargs: Any) -> Var:
+ """Make a copy of this Var with updated fields.
+
+ Args:
+ merge_var_data: VarData to merge into the existing VarData.
+ **kwargs: Var fields to update.
+
+ Returns:
+ A new BaseVar with the updated fields overwriting the corresponding fields in this Var.
+ """
+ field_values = dict(
+ _var_name=kwargs.pop("_var_name", self._var_name),
+ _var_type=kwargs.pop("_var_type", self._var_type),
+ _var_is_local=kwargs.pop("_var_is_local", self._var_is_local),
+ _var_is_string=kwargs.pop("_var_is_string", self._var_is_string),
+ _var_full_name_needs_state_prefix=kwargs.pop(
+ "_var_full_name_needs_state_prefix",
+ self._var_full_name_needs_state_prefix,
+ ),
+ _var_data=VarData.merge(
+ kwargs.get("_var_data", self._var_data), merge_var_data
+ ),
+ )
+ return BaseVar(**field_values)
+
def _decode(self) -> Any:
"""Decode Var as a python value.
@@ -195,8 +368,6 @@ class Var:
Returns:
The decoded value or the Var name.
"""
- if self._var_state:
- return self._var_full_name
if self._var_is_string:
return self._var_name
try:
@@ -216,8 +387,10 @@ class Var:
return (
self._var_name == other._var_name
and self._var_type == other._var_type
- and self._var_state == other._var_state
and self._var_is_local == other._var_is_local
+ and self._var_full_name_needs_state_prefix
+ == other._var_full_name_needs_state_prefix
+ and self._var_data == other._var_data
)
def to_string(self, json: bool = True) -> Var:
@@ -285,9 +458,11 @@ class Var:
Returns:
The formatted var.
"""
+ # Encode the _var_data into the formatted output for tracking purposes.
+ str_self = _encode_var(self)
if self._var_is_local:
- return str(self)
- return f"${str(self)}"
+ return str_self
+ return f"${str_self}"
def __getitem__(self, i: Any) -> Var:
"""Index into a var.
@@ -320,12 +495,7 @@ class Var:
# Convert any vars to local vars.
if isinstance(i, Var):
- i = BaseVar(
- _var_name=i._var_name,
- _var_type=i._var_type,
- _var_state=i._var_state,
- _var_is_local=True,
- )
+ i = i._replace(_var_is_local=True)
# Handle list/tuple/str indexing.
if types._issubclass(self._var_type, Union[List, Tuple, str]):
@@ -344,11 +514,9 @@ class Var:
stop = i.stop or "undefined"
# Use the slice function.
- return BaseVar(
+ return self._replace(
_var_name=f"{self._var_name}.slice({start}, {stop})",
- _var_type=self._var_type,
- _var_state=self._var_state,
- _var_is_local=self._var_is_local,
+ _var_is_string=False,
)
# Get the type of the indexed var.
@@ -359,11 +527,10 @@ class Var:
)
# Use `at` to support negative indices.
- return BaseVar(
+ return self._replace(
_var_name=f"{self._var_name}.at({i})",
_var_type=type_,
- _var_state=self._var_state,
- _var_is_local=self._var_is_local,
+ _var_is_string=False,
)
# Dictionary / dataframe indexing.
@@ -393,11 +560,10 @@ class Var:
)
# Use normal indexing here.
- return BaseVar(
+ return self._replace(
_var_name=f"{self._var_name}[{i}]",
_var_type=type_,
- _var_state=self._var_state,
- _var_is_local=self._var_is_local,
+ _var_is_string=False,
)
def __getattr__(self, name: str) -> Var:
@@ -423,11 +589,10 @@ class Var:
type_ = types.get_attribute_access_type(self._var_type, name)
if type_ is not None:
- return BaseVar(
+ return self._replace(
_var_name=f"{self._var_name}{'?' if is_optional else ''}.{name}",
_var_type=type_,
- _var_state=self._var_state,
- _var_is_local=self._var_is_local,
+ _var_is_string=False,
)
if name in REPLACED_NAMES:
@@ -519,10 +684,12 @@ class Var:
else f"{self._var_full_name}.{fn}()"
)
- return BaseVar(
+ return self._replace(
_var_name=operation_name,
_var_type=type_,
- _var_is_local=self._var_is_local,
+ _var_is_string=False,
+ _var_full_name_needs_state_prefix=False,
+ merge_var_data=other._var_data if other is not None else None,
)
@staticmethod
@@ -602,10 +769,10 @@ class Var:
"""
if not types._issubclass(self._var_type, List):
raise TypeError(f"Cannot get length of non-list var {self}.")
- return BaseVar(
- _var_name=f"{self._var_full_name}.length",
+ return self._replace(
+ _var_name=f"{self._var_name}.length",
_var_type=int,
- _var_is_local=self._var_is_local,
+ _var_is_string=False,
)
def __eq__(self, other: Var) -> Var:
@@ -692,7 +859,17 @@ class Var:
types.get_base_class(self._var_type) == list
and types.get_base_class(other_type) == list
):
- return self.operation(",", other, fn="spreadArraysOrObjects", flip=flip)
+ return self.operation(
+ ",", other, fn="spreadArraysOrObjects", flip=flip
+ )._replace(
+ merge_var_data=VarData(
+ imports={
+ f"/{constants.Dirs.STATE_PATH}": [
+ ImportVar(tag="spreadArraysOrObjects")
+ ]
+ },
+ ),
+ )
return self.operation("+", other, flip=flip)
def __radd__(self, other: Var) -> Var:
@@ -755,10 +932,11 @@ class Var:
]:
other_name = other._var_full_name if isinstance(other, Var) else other
name = f"Array({other_name}).fill().map(() => {self._var_full_name}).flat()"
- return BaseVar(
+ return self._replace(
_var_name=name,
_var_type=str,
- _var_is_local=self._var_is_local,
+ _var_is_string=False,
+ _var_full_name_needs_state_prefix=False,
)
return self.operation("*", other)
@@ -1003,10 +1181,11 @@ class Var:
elif not isinstance(other, Var):
other = Var.create(other)
if types._issubclass(self._var_type, Dict):
- return BaseVar(
- _var_name=f"{self._var_full_name}.{method}({other._var_full_name})",
+ return self._replace(
+ _var_name=f"{self._var_name}.{method}({other._var_full_name})",
_var_type=bool,
- _var_is_local=self._var_is_local,
+ _var_is_string=False,
+ merge_var_data=other._var_data,
)
else: # str, list, tuple
# For strings, the left operand must be a string.
@@ -1016,10 +1195,11 @@ class Var:
raise TypeError(
f"'in ' requires string as left operand, not {other._var_type}"
)
- return BaseVar(
- _var_name=f"{self._var_full_name}.includes({other._var_full_name})",
+ return self._replace(
+ _var_name=f"{self._var_name}.includes({other._var_full_name})",
_var_type=bool,
- _var_is_local=self._var_is_local,
+ _var_is_string=False,
+ merge_var_data=other._var_data,
)
def reverse(self) -> Var:
@@ -1034,10 +1214,10 @@ class Var:
if not types._issubclass(self._var_type, list):
raise TypeError(f"Cannot reverse non-list var {self._var_full_name}.")
- return BaseVar(
+ return self._replace(
_var_name=f"[...{self._var_full_name}].reverse()",
- _var_type=self._var_type,
- _var_is_local=self._var_is_local,
+ _var_is_string=False,
+ _var_full_name_needs_state_prefix=False,
)
def lower(self) -> Var:
@@ -1054,10 +1234,10 @@ class Var:
f"Cannot convert non-string var {self._var_full_name} to lowercase."
)
- return BaseVar(
- _var_name=f"{self._var_full_name}.toLowerCase()",
+ return self._replace(
+ _var_name=f"{self._var_name}.toLowerCase()",
+ _var_is_string=False,
_var_type=str,
- _var_is_local=self._var_is_local,
)
def upper(self) -> Var:
@@ -1074,10 +1254,10 @@ class Var:
f"Cannot convert non-string var {self._var_full_name} to uppercase."
)
- return BaseVar(
- _var_name=f"{self._var_full_name}.toUpperCase()",
+ return self._replace(
+ _var_name=f"{self._var_name}.toUpperCase()",
+ _var_is_string=False,
_var_type=str,
- _var_is_local=self._var_is_local,
)
def split(self, other: str | Var[str] = " ") -> Var:
@@ -1097,10 +1277,11 @@ class Var:
other = Var.create_safe(json.dumps(other)) if isinstance(other, str) else other
- return BaseVar(
- _var_name=f"{self._var_full_name}.split({other._var_full_name})",
+ return self._replace(
+ _var_name=f"{self._var_name}.split({other._var_full_name})",
+ _var_is_string=False,
_var_type=list[str],
- _var_is_local=self._var_is_local,
+ merge_var_data=other._var_data,
)
def join(self, other: str | Var[str] | None = None) -> Var:
@@ -1125,10 +1306,11 @@ class Var:
else:
other = Var.create_safe(other)
- return BaseVar(
- _var_name=f"{self._var_full_name}.join({other._var_full_name})",
+ return self._replace(
+ _var_name=f"{self._var_name}.join({other._var_full_name})",
+ _var_is_string=False,
_var_type=str,
- _var_is_local=self._var_is_local,
+ merge_var_data=other._var_data,
)
def foreach(self, fn: Callable) -> Var:
@@ -1159,10 +1341,9 @@ class Var:
fn_signature = inspect.signature(fn)
fn_args = (arg, index)
fn_ret = fn(*fn_args[: len(fn_signature.parameters)])
- return BaseVar(
+ return self._replace(
_var_name=f"{self._var_full_name}.map(({arg._var_name}, {index._var_name}) => {fn_ret})",
- _var_type=self._var_type,
- _var_is_local=self._var_is_local,
+ _var_is_string=False,
)
@classmethod
@@ -1207,6 +1388,18 @@ class Var:
_var_name=f"Array.from(range({v1._var_full_name}, {v2._var_full_name}, {step._var_name}))",
_var_type=list[int],
_var_is_local=False,
+ _var_data=VarData.merge(
+ v1._var_data,
+ v2._var_data,
+ step._var_data,
+ VarData(
+ imports={
+ "/utils/helpers/range.js": [
+ ImportVar(tag="range", is_default=True),
+ ],
+ },
+ ),
+ ),
)
def to(self, type_: Type) -> Var:
@@ -1218,12 +1411,7 @@ class Var:
Returns:
The converted var.
"""
- return BaseVar(
- _var_name=self._var_name,
- _var_type=type_,
- _var_state=self._var_state,
- _var_is_local=self._var_is_local,
- )
+ return self._replace(_var_type=type_)
@property
def _var_full_name(self) -> str:
@@ -1232,24 +1420,51 @@ class Var:
Returns:
The full name of the var.
"""
+ if not self._var_full_name_needs_state_prefix:
+ return self._var_name
return (
self._var_name
- if self._var_state == ""
- else ".".join([self._var_state, self._var_name])
+ if self._var_data is None or self._var_data.state == ""
+ else ".".join(
+ [format.format_state_name(self._var_data.state), self._var_name]
+ )
)
- def _var_set_state(self, state: Type[State]) -> Any:
+ def _var_set_state(self, state: Type[State] | str) -> Any:
"""Set the state of the var.
Args:
- state: The state to set.
+ state: The state to set or the full name of the state.
Returns:
The var with the set state.
"""
- self._var_state = state.get_full_name()
+ state_name = state if isinstance(state, str) else state.get_full_name()
+ new_var_data = VarData(
+ state=state_name,
+ hooks={
+ "const {0} = useContext(StateContexts.{0})".format(
+ format.format_state_name(state_name)
+ )
+ },
+ imports={
+ f"/{constants.Dirs.CONTEXTS_PATH}": [ImportVar(tag="StateContexts")],
+ "react": [ImportVar(tag="useContext")],
+ },
+ )
+ self._var_data = VarData.merge(self._var_data, new_var_data)
+ self._var_full_name_needs_state_prefix = True
return self
+ @property
+ def _var_state(self) -> str:
+ """Compat method for getting the state.
+
+ Returns:
+ The state name associated with the var.
+ """
+ return self._var_data.state if self._var_data else ""
+
@dataclasses.dataclass(
eq=False,
@@ -1264,15 +1479,18 @@ class BaseVar(Var):
# The type of the var.
_var_type: Type = dataclasses.field(default=Any)
- # The name of the enclosing state.
- _var_state: str = dataclasses.field(default="")
-
# Whether this is a local javascript variable.
_var_is_local: bool = dataclasses.field(default=False)
# Whether the var is a string literal.
_var_is_string: bool = dataclasses.field(default=False)
+ # _var_full_name should be prefixed with _var_state
+ _var_full_name_needs_state_prefix: bool = dataclasses.field(default=False)
+
+ # Extra metadata associated with the Var
+ _var_data: Optional[VarData] = dataclasses.field(default=None)
+
def __hash__(self) -> int:
"""Define a hash function for a var.
@@ -1334,9 +1552,11 @@ class BaseVar(Var):
The name of the setter function.
"""
setter = constants.SETTER_PREFIX + self._var_name
- if not include_state or self._var_state == "":
+ if self._var_data is None:
return setter
- return ".".join((self._var_state, setter))
+ if not include_state or self._var_data.state == "":
+ return setter
+ return ".".join((self._var_data.state, setter))
def get_setter(self) -> Callable[[State, Any], None]:
"""Get the var's setter function.
@@ -1550,48 +1770,6 @@ def cached_var(fget: Callable[[Any], Any]) -> ComputedVar:
return cvar
-class ImportVar(Base):
- """An import var."""
-
- # The name of the import tag.
- tag: Optional[str]
-
- # whether the import is default or named.
- is_default: Optional[bool] = False
-
- # The tag alias.
- alias: Optional[str] = None
-
- # Whether this import need to install the associated lib
- install: Optional[bool] = True
-
- # whether this import should be rendered or not
- render: Optional[bool] = True
-
- @property
- def name(self) -> str:
- """The name of the import.
-
- Returns:
- The name(tag name with alias) of tag.
- """
- return self.tag if not self.alias else " as ".join([self.tag, self.alias]) # type: ignore
-
- def __hash__(self) -> int:
- """Define a hash function for the import var.
-
- Returns:
- The hash of the var.
- """
- return hash((self.tag, self.is_default, self.alias, self.install, self.render))
-
-
-class NoRenderImportVar(ImportVar):
- """A import that doesn't need to be rendered."""
-
- render: Optional[bool] = False
-
-
class CallableVar(BaseVar):
"""Decorate a Var-returning function to act as both a Var and a function.
diff --git a/reflex/vars.pyi b/reflex/vars.pyi
index 001ba9aa3..9208013db 100644
--- a/reflex/vars.pyi
+++ b/reflex/vars.pyi
@@ -6,11 +6,13 @@ from reflex import constants as constants
from reflex.base import Base as Base
from reflex.state import State as State
from reflex.utils import console as console, format as format, types as types
+from reflex.utils.imports import ImportVar
from types import FunctionType
from typing import (
Any,
Callable,
Dict,
+ Iterable,
List,
Optional,
Set,
@@ -22,13 +24,24 @@ from typing import (
USED_VARIABLES: Incomplete
def get_unique_variable_name() -> str: ...
+def _encode_var(value: Var) -> str: ...
+def _decode_var(value: str) -> tuple[VarData, str]: ...
+def _extract_var_data(value: Iterable) -> list[VarData | None]: ...
+
+class VarData(Base):
+ state: str
+ imports: dict[str, set[ImportVar]]
+ hooks: set[str]
+ @classmethod
+ def merge(cls, *others: VarData | None) -> VarData | None: ...
class Var:
_var_name: str
_var_type: Type
- _var_state: str = ""
_var_is_local: bool = False
_var_is_string: bool = False
+ _var_full_name_needs_state_prefix: bool = False
+ _var_data: VarData | None = None
@classmethod
def create(
cls, value: Any, _var_is_local: bool = False, _var_is_string: bool = False
@@ -38,7 +51,8 @@ class Var:
cls, value: Any, _var_is_local: bool = False, _var_is_string: bool = False
) -> Var: ...
@classmethod
- def __class_getitem__(cls, type_: str) -> _GenericAlias: ...
+ def __class_getitem__(cls, type_: Type) -> _GenericAlias: ...
+ def _replace(self, merge_var_data=None, **kwargs: Any) -> Var: ...
def equals(self, other: Var) -> bool: ...
def to_string(self) -> Var: ...
def __hash__(self) -> int: ...
@@ -95,15 +109,16 @@ class Var:
def to(self, type_: Type) -> Var: ...
@property
def _var_full_name(self) -> str: ...
- def _var_set_state(self, state: Type[State]) -> Any: ...
+ def _var_set_state(self, state: Type[State] | str) -> Any: ...
@dataclass(eq=False)
class BaseVar(Var):
_var_name: str
_var_type: Any
- _var_state: str = ""
_var_is_local: bool = False
_var_is_string: bool = False
+ _var_full_name_needs_state_prefix: bool = False
+ _var_data: VarData | None = None
def __hash__(self) -> int: ...
def get_default_value(self) -> Any: ...
def get_setter_name(self, include_state: bool = ...) -> str: ...
@@ -123,21 +138,6 @@ class ComputedVar(Var):
def cached_var(fget: Callable[[Any], Any]) -> ComputedVar: ...
-class ImportVar(Base):
- tag: Optional[str]
- is_default: Optional[bool] = False
- alias: Optional[str] = None
- install: Optional[bool] = True
- render: Optional[bool] = True
- @property
- def name(self) -> str: ...
- def __hash__(self) -> int: ...
-
-class NoRenderImportVar(ImportVar):
- """A import that doesn't need to be rendered."""
-
-def get_local_storage(key: Optional[Union[Var, str]] = ...) -> BaseVar: ...
-
class CallableVar(BaseVar):
def __init__(self, fn: Callable[..., BaseVar]): ...
def __call__(self, *args, **kwargs) -> BaseVar: ...
diff --git a/tests/compiler/test_compiler.py b/tests/compiler/test_compiler.py
index e8423d97f..1def5ca4f 100644
--- a/tests/compiler/test_compiler.py
+++ b/tests/compiler/test_compiler.py
@@ -5,7 +5,7 @@ import pytest
from reflex.compiler import compiler, utils
from reflex.utils import imports
-from reflex.vars import ImportVar
+from reflex.utils.imports import ImportVar
@pytest.mark.parametrize(
diff --git a/tests/components/layout/test_cond.py b/tests/components/layout/test_cond.py
index 3bf373bb4..00cf4de7d 100644
--- a/tests/components/layout/test_cond.py
+++ b/tests/components/layout/test_cond.py
@@ -110,7 +110,7 @@ def test_cond_no_else():
# Props do not support the use of cond without else
with pytest.raises(ValueError):
- cond(True, "hello")
+ cond(True, "hello") # type: ignore
def test_mobile_only():
diff --git a/tests/components/test_component.py b/tests/components/test_component.py
index abf6c5a06..6d6b0a0cc 100644
--- a/tests/components/test_component.py
+++ b/tests/components/test_component.py
@@ -4,14 +4,16 @@ import pytest
import reflex as rx
from reflex.base import Base
+from reflex.components.base.bare import Bare
from reflex.components.component import Component, CustomComponent, custom_component
from reflex.components.layout.box import Box
from reflex.constants import EventTriggers
-from reflex.event import EventHandler
+from reflex.event import EventChain, EventHandler
from reflex.state import State
from reflex.style import Style
from reflex.utils import imports
-from reflex.vars import ImportVar, Var
+from reflex.utils.imports import ImportVar
+from reflex.vars import Var, VarData
@pytest.fixture
@@ -600,3 +602,161 @@ def test_format_component(component, rendered):
rendered: The expected rendered component.
"""
assert str(component) == rendered
+
+
+TEST_VAR = Var.create_safe("test")._replace(
+ merge_var_data=VarData(
+ hooks={"useTest"}, imports={"test": {ImportVar(tag="test")}}, state="Test"
+ )
+)
+FORMATTED_TEST_VAR = Var.create(f"foo{TEST_VAR}bar")
+STYLE_VAR = TEST_VAR._replace(_var_name="style", _var_is_local=False)
+EVENT_CHAIN_VAR = TEST_VAR._replace(_var_type=EventChain)
+ARG_VAR = Var.create("arg")
+
+
+class EventState(rx.State):
+ """State for testing event handlers with _get_vars."""
+
+ v: int = 42
+
+ def handler(self):
+ """A handler that does nothing."""
+
+ def handler2(self, arg):
+ """A handler that takes an arg.
+
+ Args:
+ arg: An arg.
+ """
+
+
+@pytest.mark.parametrize(
+ ("component", "exp_vars"),
+ (
+ pytest.param(
+ Bare.create(TEST_VAR),
+ [TEST_VAR],
+ id="direct-bare",
+ ),
+ pytest.param(
+ Bare.create(f"foo{TEST_VAR}bar"),
+ [FORMATTED_TEST_VAR],
+ id="fstring-bare",
+ ),
+ pytest.param(
+ rx.text(as_=TEST_VAR),
+ [TEST_VAR],
+ id="direct-prop",
+ ),
+ pytest.param(
+ rx.text(as_=f"foo{TEST_VAR}bar"),
+ [FORMATTED_TEST_VAR],
+ id="fstring-prop",
+ ),
+ pytest.param(
+ rx.fragment(id=TEST_VAR),
+ [TEST_VAR],
+ id="direct-id",
+ ),
+ pytest.param(
+ rx.fragment(id=f"foo{TEST_VAR}bar"),
+ [FORMATTED_TEST_VAR],
+ id="fstring-id",
+ ),
+ pytest.param(
+ rx.fragment(key=TEST_VAR),
+ [TEST_VAR],
+ id="direct-key",
+ ),
+ pytest.param(
+ rx.fragment(key=f"foo{TEST_VAR}bar"),
+ [FORMATTED_TEST_VAR],
+ id="fstring-key",
+ ),
+ pytest.param(
+ rx.fragment(class_name=TEST_VAR),
+ [TEST_VAR],
+ id="direct-class_name",
+ ),
+ pytest.param(
+ rx.fragment(class_name=f"foo{TEST_VAR}bar"),
+ [FORMATTED_TEST_VAR],
+ id="fstring-class_name",
+ ),
+ pytest.param(
+ rx.fragment(special_props={TEST_VAR}),
+ [TEST_VAR],
+ id="direct-special_props",
+ ),
+ pytest.param(
+ rx.fragment(special_props={Var.create(f"foo{TEST_VAR}bar")}),
+ [FORMATTED_TEST_VAR],
+ id="fstring-special_props",
+ ),
+ pytest.param(
+ # custom_attrs cannot accept a Var directly as a value
+ rx.fragment(custom_attrs={"href": f"{TEST_VAR}"}),
+ [TEST_VAR],
+ id="fstring-custom_attrs-nofmt",
+ ),
+ pytest.param(
+ rx.fragment(custom_attrs={"href": f"foo{TEST_VAR}bar"}),
+ [FORMATTED_TEST_VAR],
+ id="fstring-custom_attrs",
+ ),
+ pytest.param(
+ rx.fragment(background_color=TEST_VAR),
+ [STYLE_VAR],
+ id="direct-background_color",
+ ),
+ pytest.param(
+ rx.fragment(background_color=f"foo{TEST_VAR}bar"),
+ [STYLE_VAR],
+ id="fstring-background_color",
+ ),
+ pytest.param(
+ rx.fragment(style={"background_color": TEST_VAR}), # type: ignore
+ [STYLE_VAR],
+ id="direct-style-background_color",
+ ),
+ pytest.param(
+ rx.fragment(style={"background_color": f"foo{TEST_VAR}bar"}), # type: ignore
+ [STYLE_VAR],
+ id="fstring-style-background_color",
+ ),
+ pytest.param(
+ rx.fragment(on_click=EVENT_CHAIN_VAR), # type: ignore
+ [EVENT_CHAIN_VAR],
+ id="direct-event-chain",
+ ),
+ pytest.param(
+ rx.fragment(on_click=EventState.handler),
+ [],
+ id="direct-event-handler",
+ ),
+ pytest.param(
+ rx.fragment(on_click=EventState.handler2(TEST_VAR)), # type: ignore
+ [ARG_VAR, TEST_VAR],
+ id="direct-event-handler-arg",
+ ),
+ pytest.param(
+ rx.fragment(on_click=EventState.handler2(EventState.v)), # type: ignore
+ [ARG_VAR, EventState.v],
+ id="direct-event-handler-arg2",
+ ),
+ pytest.param(
+ rx.fragment(on_click=lambda: EventState.handler2(TEST_VAR)), # type: ignore
+ [ARG_VAR, TEST_VAR],
+ id="direct-event-handler-lambda",
+ ),
+ ),
+)
+def test_get_vars(component, exp_vars):
+ comp_vars = sorted(component._get_vars(), key=lambda v: v._var_name)
+ assert len(comp_vars) == len(exp_vars)
+ for comp_var, exp_var in zip(
+ comp_vars,
+ sorted(exp_vars, key=lambda v: v._var_name),
+ ):
+ assert comp_var.equals(exp_var)
diff --git a/tests/middleware/test_hydrate_middleware.py b/tests/middleware/test_hydrate_middleware.py
index 150083bd5..7767dcf8b 100644
--- a/tests/middleware/test_hydrate_middleware.py
+++ b/tests/middleware/test_hydrate_middleware.py
@@ -104,7 +104,7 @@ async def test_preprocess(
app=app, event=request.getfixturevalue(event_fixture), state=state
)
assert isinstance(update, StateUpdate)
- assert update.delta == {state.get_name(): state.dict()}
+ assert update.delta == state.dict()
events = update.events
assert len(events) == 2
@@ -133,7 +133,7 @@ async def test_preprocess_multiple_load_events(hydrate_middleware, event1):
update = await hydrate_middleware.preprocess(app=app, event=event1, state=state)
assert isinstance(update, StateUpdate)
- assert update.delta == {"test_state": state.dict()}
+ assert update.delta == state.dict()
assert len(update.events) == 3
# Apply the events.
@@ -163,7 +163,7 @@ async def test_preprocess_no_events(hydrate_middleware, event1):
state=state,
)
assert isinstance(update, StateUpdate)
- assert update.delta == {"test_state": state.dict()}
+ assert update.delta == state.dict()
assert len(update.events) == 1
assert isinstance(update, StateUpdate)
diff --git a/tests/test_app.py b/tests/test_app.py
index 434eeb60f..e7ac03661 100644
--- a/tests/test_app.py
+++ b/tests/test_app.py
@@ -769,9 +769,7 @@ async def test_upload_file(tmp_path, state, delta, token: str):
)
current_state = await app.state_manager.get_state(token)
- state_dict = current_state.dict()
- for substate in state.get_full_name().split(".")[1:]:
- state_dict = state_dict[substate]
+ state_dict = current_state.dict()[state.get_full_name()]
assert state_dict["img_list"] == [
"image1.jpg",
"image2.jpg",
diff --git a/tests/test_state.py b/tests/test_state.py
index 29065ed11..d85f0e59f 100644
--- a/tests/test_state.py
+++ b/tests/test_state.py
@@ -324,11 +324,17 @@ def test_dict(test_state):
Args:
test_state: A state.
"""
- substates = {"child_state", "child_state2"}
- assert set(test_state.dict().keys()) == set(test_state.vars.keys()) | substates
- assert (
- set(test_state.dict(include_computed=False).keys())
- == set(test_state.base_vars) | substates
+ substates = {
+ "test_state",
+ "test_state.child_state",
+ "test_state.child_state.grandchild_state",
+ "test_state.child_state2",
+ }
+ test_state_dict = test_state.dict()
+ assert set(test_state_dict) == substates
+ assert set(test_state_dict[test_state.get_name()]) == set(test_state.vars)
+ assert set(test_state.dict(include_computed=False)[test_state.get_name()]) == set(
+ test_state.base_vars
)
@@ -1081,9 +1087,9 @@ def test_computed_var_cached():
return self.v
cs = ComputedState()
- assert cs.dict()["v"] == 0
+ assert cs.dict()[cs.get_full_name()]["v"] == 0
assert comp_v_calls == 1
- assert cs.dict()["comp_v"] == 0
+ assert cs.dict()[cs.get_full_name()]["comp_v"] == 0
assert comp_v_calls == 1
assert cs.comp_v == 0
assert comp_v_calls == 1
@@ -1156,24 +1162,27 @@ def test_computed_var_depends_on_parent_non_cached():
assert ps.dirty_vars == set()
assert cs.dirty_vars == set()
- assert ps.dict() == {
- cs.get_name(): {"dep_v": 2},
+ dict1 = ps.dict()
+ assert dict1[ps.get_full_name()] == {
"no_cache_v": 1,
CompileVars.IS_HYDRATED: False,
"router": formatted_router,
}
- assert ps.dict() == {
- cs.get_name(): {"dep_v": 4},
+ assert dict1[cs.get_full_name()] == {"dep_v": 2}
+ dict2 = ps.dict()
+ assert dict2[ps.get_full_name()] == {
"no_cache_v": 3,
CompileVars.IS_HYDRATED: False,
"router": formatted_router,
}
- assert ps.dict() == {
- cs.get_name(): {"dep_v": 6},
+ assert dict2[cs.get_full_name()] == {"dep_v": 4}
+ dict3 = ps.dict()
+ assert dict3[ps.get_full_name()] == {
"no_cache_v": 5,
CompileVars.IS_HYDRATED: False,
"router": formatted_router,
}
+ assert dict3[cs.get_full_name()] == {"dep_v": 6}
assert counter == 6
@@ -2201,13 +2210,13 @@ def test_json_dumps_with_mutables():
items: List[Foo] = [Foo()]
dict_val = MutableContainsBase().dict()
- assert isinstance(dict_val["items"][0], dict)
+ assert isinstance(dict_val[MutableContainsBase.get_full_name()]["items"][0], dict)
val = json_dumps(dict_val)
f_items = '[{"tags": ["123", "456"]}]'
f_formatted_router = str(formatted_router).replace("'", '"')
assert (
val
- == f'{{"is_hydrated": false, "items": {f_items}, "router": {f_formatted_router}}}'
+ == f'{{"{MutableContainsBase.get_full_name()}": {{"is_hydrated": false, "items": {f_items}, "router": {f_formatted_router}}}}}'
)
diff --git a/tests/test_style.py b/tests/test_style.py
index 8b09f9ac0..a8fcf6839 100644
--- a/tests/test_style.py
+++ b/tests/test_style.py
@@ -22,7 +22,8 @@ def test_convert(style_dict, expected):
style_dict: The style to check.
expected: The expected formatted style.
"""
- assert style.convert(style_dict) == expected
+ converted_dict, _var_data = style.convert(style_dict)
+ assert converted_dict == expected
@pytest.mark.parametrize(
diff --git a/tests/test_var.py b/tests/test_var.py
index 9efb5fb78..e3cb28a1c 100644
--- a/tests/test_var.py
+++ b/tests/test_var.py
@@ -7,18 +7,20 @@ from pandas import DataFrame
from reflex.base import Base
from reflex.state import State
+from reflex.utils.imports import ImportVar
from reflex.vars import (
BaseVar,
ComputedVar,
- ImportVar,
Var,
)
test_vars = [
BaseVar(_var_name="prop1", _var_type=int),
BaseVar(_var_name="key", _var_type=str),
- BaseVar(_var_name="value", _var_type=str, _var_state="state"),
- BaseVar(_var_name="local", _var_type=str, _var_state="state", _var_is_local=True),
+ BaseVar(_var_name="value", _var_type=str)._var_set_state("state"),
+ BaseVar(_var_name="local", _var_type=str, _var_is_local=True)._var_set_state(
+ "state"
+ ),
BaseVar(_var_name="local2", _var_type=str, _var_is_local=True),
]
@@ -263,7 +265,7 @@ def test_basic_operations(TestObj):
assert str(v([1, 2, 3])[v(0)]) == "{[1, 2, 3].at(0)}"
assert str(v({"a": 1, "b": 2})["a"]) == '{{"a": 1, "b": 2}["a"]}'
assert (
- str(BaseVar(_var_name="foo", _var_state="state", _var_type=TestObj).bar)
+ str(BaseVar(_var_name="foo", _var_type=TestObj)._var_set_state("state").bar)
== "{state.foo.bar}"
)
assert str(abs(v(1))) == "{Math.abs(1)}"
@@ -274,7 +276,7 @@ def test_basic_operations(TestObj):
assert str(v([1, 2, 3]).reverse()) == "{[...[1, 2, 3]].reverse()}"
assert str(v(["1", "2", "3"]).reverse()) == '{[...["1", "2", "3"]].reverse()}'
assert (
- str(BaseVar(_var_name="foo", _var_state="state", _var_type=list).reverse())
+ str(BaseVar(_var_name="foo", _var_type=list)._var_set_state("state").reverse())
== "{[...state.foo].reverse()}"
)
assert (
@@ -288,11 +290,14 @@ def test_basic_operations(TestObj):
[
(v([1, 2, 3]), "[1, 2, 3]"),
(v(["1", "2", "3"]), '["1", "2", "3"]'),
- (BaseVar(_var_name="foo", _var_state="state", _var_type=list), "state.foo"),
+ (BaseVar(_var_name="foo", _var_type=list)._var_set_state("state"), "state.foo"),
(BaseVar(_var_name="foo", _var_type=list), "foo"),
(v((1, 2, 3)), "[1, 2, 3]"),
(v(("1", "2", "3")), '["1", "2", "3"]'),
- (BaseVar(_var_name="foo", _var_state="state", _var_type=tuple), "state.foo"),
+ (
+ BaseVar(_var_name="foo", _var_type=tuple)._var_set_state("state"),
+ "state.foo",
+ ),
(BaseVar(_var_name="foo", _var_type=tuple), "foo"),
],
)
@@ -301,7 +306,7 @@ def test_list_tuple_contains(var, expected):
assert str(var.contains("1")) == f'{{{expected}.includes("1")}}'
assert str(var.contains(v(1))) == f"{{{expected}.includes(1)}}"
assert str(var.contains(v("1"))) == f'{{{expected}.includes("1")}}'
- other_state_var = BaseVar(_var_name="other", _var_state="state", _var_type=str)
+ other_state_var = BaseVar(_var_name="other", _var_type=str)._var_set_state("state")
other_var = BaseVar(_var_name="other", _var_type=str)
assert str(var.contains(other_state_var)) == f"{{{expected}.includes(state.other)}}"
assert str(var.contains(other_var)) == f"{{{expected}.includes(other)}}"
@@ -311,14 +316,14 @@ def test_list_tuple_contains(var, expected):
"var, expected",
[
(v("123"), json.dumps("123")),
- (BaseVar(_var_name="foo", _var_state="state", _var_type=str), "state.foo"),
+ (BaseVar(_var_name="foo", _var_type=str)._var_set_state("state"), "state.foo"),
(BaseVar(_var_name="foo", _var_type=str), "foo"),
],
)
def test_str_contains(var, expected):
assert str(var.contains("1")) == f'{{{expected}.includes("1")}}'
assert str(var.contains(v("1"))) == f'{{{expected}.includes("1")}}'
- other_state_var = BaseVar(_var_name="other", _var_state="state", _var_type=str)
+ other_state_var = BaseVar(_var_name="other", _var_type=str)._var_set_state("state")
other_var = BaseVar(_var_name="other", _var_type=str)
assert str(var.contains(other_state_var)) == f"{{{expected}.includes(state.other)}}"
assert str(var.contains(other_var)) == f"{{{expected}.includes(other)}}"
@@ -328,7 +333,7 @@ def test_str_contains(var, expected):
"var, expected",
[
(v({"a": 1, "b": 2}), '{"a": 1, "b": 2}'),
- (BaseVar(_var_name="foo", _var_state="state", _var_type=dict), "state.foo"),
+ (BaseVar(_var_name="foo", _var_type=dict)._var_set_state("state"), "state.foo"),
(BaseVar(_var_name="foo", _var_type=dict), "foo"),
],
)
@@ -337,7 +342,7 @@ def test_dict_contains(var, expected):
assert str(var.contains("1")) == f'{{{expected}.hasOwnProperty("1")}}'
assert str(var.contains(v(1))) == f"{{{expected}.hasOwnProperty(1)}}"
assert str(var.contains(v("1"))) == f'{{{expected}.hasOwnProperty("1")}}'
- other_state_var = BaseVar(_var_name="other", _var_state="state", _var_type=str)
+ other_state_var = BaseVar(_var_name="other", _var_type=str)._var_set_state("state")
other_var = BaseVar(_var_name="other", _var_type=str)
assert (
str(var.contains(other_state_var))
@@ -548,10 +553,10 @@ def test_var_unsupported_indexing_dicts(var, index):
"fixture,full_name",
[
("ParentState", "parent_state.var_without_annotation"),
- ("ChildState", "parent_state.child_state.var_without_annotation"),
+ ("ChildState", "parent_state__child_state.var_without_annotation"),
(
"GrandChildState",
- "parent_state.child_state.grand_child_state.var_without_annotation",
+ "parent_state__child_state__grand_child_state.var_without_annotation",
),
("StateWithAnyVar", "state_with_any_var.var_without_annotation"),
],
@@ -630,8 +635,8 @@ def test_import_var(import_var, expected):
[
(f"{BaseVar(_var_name='var', _var_type=str)}", "${var}"),
(
- f"testing f-string with {BaseVar(_var_name='myvar', _var_state='state', _var_type=int)}",
- "testing f-string with ${state.myvar}",
+ f"testing f-string with {BaseVar(_var_name='myvar', _var_type=int)._var_set_state('state')}",
+ 'testing f-string with ${"state": "state", "imports": {"/utils/context": [{"tag": "StateContexts", "is_default": false, "alias": null, "install": true, "render": true}], "react": [{"tag": "useContext", "is_default": false, "alias": null, "install": true, "render": true}]}, "hooks": ["const state = useContext(StateContexts.state)"]}{state.myvar}',
),
(
f"testing local f-string {BaseVar(_var_name='x', _var_is_local=True, _var_type=str)}",
@@ -643,6 +648,35 @@ def test_fstrings(out, expected):
assert out == expected
+@pytest.mark.parametrize(
+ ("value", "expect_state"),
+ [
+ ([1], ""),
+ ({"a": 1}, ""),
+ ([Var.create_safe(1)._var_set_state("foo")], "foo"),
+ ({"a": Var.create_safe(1)._var_set_state("foo")}, "foo"),
+ ],
+)
+def test_extract_state_from_container(value, expect_state):
+ """Test that _var_state is extracted from containers containing BaseVar.
+
+ Args:
+ value: The value to create a var from.
+ expect_state: The expected state.
+ """
+ assert Var.create_safe(value)._var_state == expect_state
+
+
+def test_fstring_roundtrip():
+ """Test that f-string roundtrip carries state."""
+ var = BaseVar.create_safe("var")._var_set_state("state")
+ rt_var = Var.create_safe(f"{var}")
+ assert var._var_state == rt_var._var_state
+ assert var._var_full_name_needs_state_prefix
+ assert not rt_var._var_full_name_needs_state_prefix
+ assert rt_var._var_name == var._var_full_name
+
+
@pytest.mark.parametrize(
"var",
[
diff --git a/tests/utils/test_format.py b/tests/utils/test_format.py
index c83fd8c79..257f37514 100644
--- a/tests/utils/test_format.py
+++ b/tests/utils/test_format.py
@@ -8,7 +8,13 @@ from reflex.event import EventChain, EventHandler, EventSpec, FrontendEvent
from reflex.style import Style
from reflex.utils import format
from reflex.vars import BaseVar, Var
-from tests.test_state import ChildState, DateTimeState, GrandchildState, TestState
+from tests.test_state import (
+ ChildState,
+ ChildState2,
+ DateTimeState,
+ GrandchildState,
+ TestState,
+)
def mock_event(arg):
@@ -349,7 +355,6 @@ def test_format_cond(condition: str, true_value: str, false_value: str, expected
BaseVar(
_var_name="_",
_var_type=Any,
- _var_state="",
_var_is_local=True,
_var_is_string=False,
),
@@ -515,40 +520,44 @@ formatted_router = {
(
TestState().dict(), # type: ignore
{
- "array": [1, 2, 3.14],
- "child_state": {
+ TestState.get_full_name(): {
+ "array": [1, 2, 3.14],
+ "complex": {
+ 1: {"prop1": 42, "prop2": "hello"},
+ 2: {"prop1": 42, "prop2": "hello"},
+ },
+ "dt": "1989-11-09 18:53:00+01:00",
+ "fig": [],
+ "is_hydrated": False,
+ "key": "",
+ "map_key": "a",
+ "mapping": {"a": [1, 2, 3], "b": [4, 5, 6]},
+ "num1": 0,
+ "num2": 3.14,
+ "obj": {"prop1": 42, "prop2": "hello"},
+ "sum": 3.14,
+ "upper": "",
+ "router": formatted_router,
+ },
+ ChildState.get_full_name(): {
"count": 23,
- "grandchild_state": {"value2": ""},
"value": "",
},
- "child_state2": {"value": ""},
- "complex": {
- 1: {"prop1": 42, "prop2": "hello"},
- 2: {"prop1": 42, "prop2": "hello"},
- },
- "dt": "1989-11-09 18:53:00+01:00",
- "fig": [],
- "is_hydrated": False,
- "key": "",
- "map_key": "a",
- "mapping": {"a": [1, 2, 3], "b": [4, 5, 6]},
- "num1": 0,
- "num2": 3.14,
- "obj": {"prop1": 42, "prop2": "hello"},
- "sum": 3.14,
- "upper": "",
- "router": formatted_router,
+ ChildState2.get_full_name(): {"value": ""},
+ GrandchildState.get_full_name(): {"value2": ""},
},
),
(
DateTimeState().dict(),
{
- "d": "1989-11-09",
- "dt": "1989-11-09 18:53:00+01:00",
- "is_hydrated": False,
- "t": "18:53:00+01:00",
- "td": "11 days, 0:11:00",
- "router": formatted_router,
+ DateTimeState.get_full_name(): {
+ "d": "1989-11-09",
+ "dt": "1989-11-09 18:53:00+01:00",
+ "is_hydrated": False,
+ "t": "18:53:00+01:00",
+ "td": "11 days, 0:11:00",
+ "router": formatted_router,
+ },
},
),
],