[REF-889] useContext per substate (#2149)

This commit is contained in:
Masen Furer 2023-11-21 11:52:06 -08:00 committed by GitHub
parent e9437ad941
commit 1603144c7d
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
65 changed files with 1257 additions and 455 deletions

View File

@ -28,6 +28,7 @@ def VarOperations():
str_var4: str = "a long string"
dict1: dict = {1: 2}
dict2: dict = {3: 4}
html_str: str = "<div>hello</div>"
app = rx.App(state=VarOperationState)
@ -522,6 +523,19 @@ def VarOperations():
rx.text(VarOperationState.str_var4.split(" ").to_string(), id="str_split"),
rx.text(VarOperationState.list3.join(""), id="list_join"),
rx.text(VarOperationState.list3.join(","), id="list_join_comma"),
# Index from an op var
rx.text(
VarOperationState.list3[VarOperationState.int_var1 % 3],
id="list_index_mod",
),
rx.html(
VarOperationState.html_str,
id="html_str",
),
rx.highlight(
"second",
query=[VarOperationState.str_var2],
),
rx.text(rx.Var.range(2, 5).join(","), id="list_join_range1"),
rx.text(rx.Var.range(2, 10, 2).join(","), id="list_join_range2"),
rx.text(rx.Var.range(5, 0, -1).join(","), id="list_join_range3"),
@ -713,7 +727,14 @@ def test_var_operations(driver, var_operations: AppHarness):
("dict_eq_dict", "false"),
("dict_neq_dict", "true"),
("dict_contains", "true"),
# index from an op var
("list_index_mod", "second"),
# html component with var
("html_str", "hello"),
]
for tag, expected in tests:
assert driver.find_element(By.ID, tag).text == expected
# Highlight component with var query (does not plumb ID)
assert driver.find_element(By.TAG_NAME, "mark").text == "second"

View File

@ -1,7 +1,7 @@
{% extends "web/pages/base_page.js.jinja2" %}
{% block declaration %}
import { EventLoopProvider } from "/utils/context.js";
import { EventLoopProvider, StateProvider } from "/utils/context.js";
import { ThemeProvider } from 'next-themes'
{% for custom_code in custom_codes %}
@ -25,12 +25,14 @@ export default function MyApp({ Component, pageProps }) {
return (
<ThemeProvider defaultTheme="light" storageKey="chakra-ui-color-mode" attribute="class">
<AppWrap>
<EventLoopProvider>
<Component {...pageProps} />
</EventLoopProvider>
<StateProvider>
<EventLoopProvider>
<Component {...pageProps} />
</EventLoopProvider>
</StateProvider>
</AppWrap>
</ThemeProvider>
);
}
{% endblock %}
{% endblock %}

View File

@ -8,32 +8,6 @@
{% block export %}
export default function Component() {
{% if state_name %}
const {{state_name}} = useContext(StateContext)
{% endif %}
const {{const.router}} = useRouter()
const [ {{const.color_mode}}, {{const.toggle_color_mode}} ] = useContext(ColorModeContext)
const focusRef = useRef();
// Main event loop.
const [addEvents, connectError] = useContext(EventLoopContext)
// Set focus to the specified element.
useEffect(() => {
if (focusRef.current) {
focusRef.current.focus();
}
})
// Route after the initial page hydration.
useEffect(() => {
const change_complete = () => addEvents(initialEvents())
{{const.router}}.events.on('routeChangeComplete', change_complete)
return () => {
{{const.router}}.events.off('routeChangeComplete', change_complete)
}
}, [{{const.router}}])
{% for hook in hooks %}
{{ hook }}
{% endfor %}

View File

@ -1,5 +1,5 @@
import { createContext, useState } from "react"
import { Event, hydrateClientStorage, useEventLoop } from "/utils/state.js"
import { createContext, useContext, useMemo, useReducer, useState } from "react"
import { applyDelta, Event, hydrateClientStorage, useEventLoop } from "/utils/state.js"
{% if initial_state %}
export const initialState = {{ initial_state|json_dumps }}
@ -8,7 +8,12 @@ export const initialState = {}
{% endif %}
export const ColorModeContext = createContext(null);
export const StateContext = createContext(null);
export const DispatchContext = createContext(null);
export const StateContexts = {
{% for state_name in initial_state %}
{{state_name|var_name}}: createContext(null),
{% endfor %}
}
export const EventLoopContext = createContext(null);
{% if client_storage %}
export const clientStorage = {{ client_storage|json_dumps }}
@ -27,16 +32,40 @@ export const initialEvents = () => []
export const isDevMode = {{ is_dev_mode|json_dumps }}
export function EventLoopProvider({ children }) {
const [state, addEvents, connectError] = useEventLoop(
initialState,
const dispatch = useContext(DispatchContext)
const [addEvents, connectError] = useEventLoop(
dispatch,
initialEvents,
clientStorage,
)
return (
<EventLoopContext.Provider value={[addEvents, connectError]}>
<StateContext.Provider value={state}>
{children}
</StateContext.Provider>
{children}
</EventLoopContext.Provider>
)
}
}
export function StateProvider({ children }) {
{% for state_name in initial_state %}
const [{{state_name|var_name}}, dispatch_{{state_name|var_name}}] = useReducer(applyDelta, initialState["{{state_name}}"])
{% endfor %}
const dispatchers = useMemo(() => {
return {
{% for state_name in initial_state %}
"{{state_name}}": dispatch_{{state_name|var_name}},
{% endfor %}
}
}, [])
return (
{% for state_name in initial_state %}
<StateContexts.{{state_name|var_name}}.Provider value={ {{state_name|var_name}} }>
{% endfor %}
<DispatchContext.Provider value={dispatchers}>
{children}
</DispatchContext.Provider>
{% for state_name in initial_state|reverse %}
</StateContexts.{{state_name|var_name}}.Provider>
{% endfor %}
)
}

View File

@ -6,7 +6,7 @@ import env from "env.json";
import Cookies from "universal-cookie";
import { useEffect, useReducer, useRef, useState } from "react";
import Router, { useRouter } from "next/router";
import { initialEvents } from "utils/context.js"
import { initialEvents, initialState } from "utils/context.js"
// Endpoint URLs.
const EVENTURL = env.EVENT
@ -100,37 +100,10 @@ export const getEventURL = () => {
* @param delta The delta to apply.
*/
export const applyDelta = (state, delta) => {
const new_state = { ...state }
for (const substate in delta) {
let s = new_state;
const path = substate.split(".").slice(1);
while (path.length > 0) {
s = s[path.shift()];
}
for (const key in delta[substate]) {
s[key] = delta[substate][key];
}
}
return new_state
return { ...state, ...delta }
};
/**
* Get all local storage items in a key-value object.
* @returns object of items in local storage.
*/
export const getAllLocalStorageItems = () => {
var localStorageItems = {};
for (var i = 0, len = localStorage.length; i < len; i++) {
var key = localStorage.key(i);
localStorageItems[key] = localStorage.getItem(key);
}
return localStorageItems;
}
/**
* Handle frontend event or send the event to the backend via Websocket.
* @param event The event to send.
@ -346,7 +319,9 @@ export const connect = async (
// On each received message, queue the updates and events.
socket.current.on("event", message => {
const update = JSON5.parse(message)
dispatch(update.delta)
for (const substate in update.delta) {
dispatch[substate](update.delta[substate])
}
applyClientStorageDelta(client_storage, update.delta)
event_processing = !update.final
if (update.events) {
@ -524,23 +499,21 @@ const applyClientStorageDelta = (client_storage, delta) => {
/**
* Establish websocket event loop for a NextJS page.
* @param initial_state The initial app state.
* @param initial_events Function that returns the initial app events.
* @param dispatch The reducer dispatch function to update state.
* @param initial_events The initial app events.
* @param client_storage The client storage object from context.js
*
* @returns [state, addEvents, connectError] -
* state is a reactive dict,
* @returns [addEvents, connectError] -
* addEvents is used to queue an event, and
* connectError is a reactive js error from the websocket connection (or null if connected).
*/
export const useEventLoop = (
initial_state = {},
dispatch,
initial_events = () => [],
client_storage = {},
) => {
const socket = useRef(null)
const router = useRouter()
const [state, dispatch] = useReducer(applyDelta, initial_state)
const [connectError, setConnectError] = useState(null)
// Function to add new events to the event queue.
@ -570,7 +543,7 @@ export const useEventLoop = (
return;
}
// only use websockets if state is present
if (Object.keys(state).length > 0) {
if (Object.keys(initialState).length > 0) {
// Initialize the websocket connection.
if (!socket.current) {
connect(socket, dispatch, ['websocket', 'polling'], setConnectError, client_storage)
@ -583,7 +556,17 @@ export const useEventLoop = (
})()
}
})
return [state, addEvents, connectError]
// Route after the initial page hydration.
useEffect(() => {
const change_complete = () => addEvents(initial_events())
router.events.on('routeChangeComplete', change_complete)
return () => {
router.events.off('routeChangeComplete', change_complete)
}
}, [router])
return [addEvents, connectError]
}
/***

View File

@ -63,7 +63,7 @@ from reflex.state import (
StateUpdate,
)
from reflex.utils import console, format, prerequisites, types
from reflex.vars import ImportVar
from reflex.utils.imports import ImportVar
# Define custom types.
ComponentCallable = Callable[[], Component]

View File

@ -10,40 +10,10 @@ from reflex.compiler import templates, utils
from reflex.components.component import Component, ComponentStyle, CustomComponent
from reflex.config import get_config
from reflex.state import State
from reflex.utils import imports
from reflex.vars import ImportVar
from reflex.utils.imports import ImportDict, ImportVar
# Imports to be included in every Reflex app.
DEFAULT_IMPORTS: imports.ImportDict = {
"react": [
ImportVar(tag="Fragment"),
ImportVar(tag="useEffect"),
ImportVar(tag="useRef"),
ImportVar(tag="useState"),
ImportVar(tag="useContext"),
],
"next/router": [ImportVar(tag="useRouter")],
f"/{constants.Dirs.STATE_PATH}": [
ImportVar(tag="uploadFiles"),
ImportVar(tag="Event"),
ImportVar(tag="isTrue"),
ImportVar(tag="spreadArraysOrObjects"),
ImportVar(tag="preventDefault"),
ImportVar(tag="refs"),
ImportVar(tag="getRefValue"),
ImportVar(tag="getRefValues"),
ImportVar(tag="getAllLocalStorageItems"),
ImportVar(tag="useEventLoop"),
],
"/utils/context.js": [
ImportVar(tag="EventLoopContext"),
ImportVar(tag="initialEvents"),
ImportVar(tag="StateContext"),
ImportVar(tag="ColorModeContext"),
],
"/utils/helpers/range.js": [
ImportVar(tag="range", is_default=True),
],
DEFAULT_IMPORTS: ImportDict = {
"": [ImportVar(tag="focus-visible/dist/focus-visible", install=False)],
}

View File

@ -3,7 +3,7 @@
from jinja2 import Environment, FileSystemLoader, Template
from reflex import constants
from reflex.utils.format import json_dumps
from reflex.utils.format import format_state_name, json_dumps
class ReflexJinjaEnvironment(Environment):
@ -19,6 +19,7 @@ class ReflexJinjaEnvironment(Environment):
)
self.filters["json_dumps"] = json_dumps
self.filters["react_setter"] = lambda state: f"set{state.capitalize()}"
self.filters["var_name"] = format_state_name
self.loader = FileSystemLoader(constants.Templates.Dirs.JINJA_TEMPLATE)
self.globals["const"] = {
"socket": constants.CompileVars.SOCKET,

View File

@ -24,13 +24,12 @@ from reflex.components.component import Component, ComponentStyle, CustomCompone
from reflex.state import Cookie, LocalStorage, State
from reflex.style import Style
from reflex.utils import console, format, imports, path_ops
from reflex.vars import ImportVar
# To re-export this function.
merge_imports = imports.merge_imports
def compile_import_statement(fields: list[ImportVar]) -> tuple[str, list[str]]:
def compile_import_statement(fields: list[imports.ImportVar]) -> tuple[str, list[str]]:
"""Compile an import statement.
Args:
@ -343,7 +342,9 @@ def get_context_path() -> str:
Returns:
The path of the context module.
"""
return os.path.join(constants.Dirs.WEB_UTILS, "context" + constants.Ext.JS)
return os.path.join(
constants.Dirs.WEB, constants.Dirs.CONTEXTS_PATH + constants.Ext.JS
)
def get_components_path() -> str:

View File

@ -1,7 +1,7 @@
"""A bare component."""
from __future__ import annotations
from typing import Any
from typing import Any, Iterator
from reflex.components.component import Component
from reflex.components.tags import Tag
@ -24,7 +24,21 @@ class Bare(Component):
Returns:
The component.
"""
return cls(contents=str(contents)) # type: ignore
if isinstance(contents, Var) and contents._var_data:
contents = contents.to(str)
else:
contents = str(contents)
return cls(contents=contents) # type: ignore
def _render(self) -> Tag:
return Tagless(contents=str(self.contents))
def _get_vars(self) -> Iterator[Var]:
"""Walk all Vars used in this component.
Yields:
The contents if it is a Var, otherwise nothing.
"""
if isinstance(self.contents, Var):
# Fast path for Bare text components.
yield self.contents

View File

@ -5,11 +5,11 @@ from __future__ import annotations
import typing
from abc import ABC
from functools import lru_cache, wraps
from typing import Any, Callable, Dict, List, Optional, Set, Type, Union
from typing import Any, Callable, Dict, Iterator, List, Optional, Set, Type, Union
from reflex.base import Base
from reflex.components.tags import Tag
from reflex.constants import Dirs, EventTriggers
from reflex.constants import Dirs, EventTriggers, Hooks, Imports
from reflex.event import (
EventChain,
EventHandler,
@ -20,8 +20,9 @@ from reflex.event import (
)
from reflex.style import Style
from reflex.utils import console, format, imports, types
from reflex.utils.imports import ImportVar
from reflex.utils.serializers import serializer
from reflex.vars import BaseVar, ImportVar, Var
from reflex.vars import BaseVar, Var
class Component(Base, ABC):
@ -388,7 +389,11 @@ class Component(Base, ABC):
props = props.copy()
props.update(
self.event_triggers,
**{
trigger: handler
for trigger, handler in self.event_triggers.items()
if trigger not in {EventTriggers.ON_MOUNT, EventTriggers.ON_UNMOUNT}
},
key=self.key,
id=self.id,
class_name=self.class_name,
@ -488,7 +493,7 @@ class Component(Base, ABC):
"""
if type(self) in style:
# Extract the style for this component.
component_style = Style(style[type(self)])
component_style = style[type(self)]
# Only add style props that are not overridden.
component_style = {
@ -564,6 +569,78 @@ class Component(Base, ABC):
if self._valid_children:
validate_valid_child(name)
@staticmethod
def _get_vars_from_event_triggers(
event_triggers: dict[str, EventChain | Var],
) -> Iterator[tuple[str, list[Var]]]:
"""Get the Vars associated with each event trigger.
Args:
event_triggers: The event triggers from the component instance.
Yields:
tuple of (event_name, event_vars)
"""
for event_trigger, event in event_triggers.items():
if isinstance(event, Var):
yield event_trigger, [event]
elif isinstance(event, EventChain):
event_args = []
for spec in event.events:
for args in spec.args:
event_args.extend(args)
yield event_trigger, event_args
def _get_vars(self) -> list[Var]:
"""Walk all Vars used in this component.
Returns:
Each var referenced by the component (props, styles, event handlers).
"""
vars = getattr(self, "__vars", None)
if vars is not None:
return vars
vars = self.__vars = []
# Get Vars associated with event trigger arguments.
for _, event_vars in self._get_vars_from_event_triggers(self.event_triggers):
vars.extend(event_vars)
# Get Vars associated with component props.
for prop in self.get_props():
prop_var = getattr(self, prop)
if isinstance(prop_var, Var):
vars.append(prop_var)
# Style keeps track of its own VarData instance, so embed in a temp Var that is yielded.
if self.style:
vars.append(
BaseVar(
_var_name="style",
_var_type=str,
_var_data=self.style._var_data,
)
)
# Special props are always Var instances.
vars.extend(self.special_props)
# Get Vars associated with common Component props.
for comp_prop in (
self.class_name,
self.id,
self.key,
self.autofocus,
*self.custom_attrs.values(),
):
if isinstance(comp_prop, Var):
vars.append(comp_prop)
elif isinstance(comp_prop, str):
# Collapse VarData encoded in f-strings.
var = Var.create_safe(comp_prop)
if var._var_data is not None:
vars.append(var)
return vars
def _get_custom_code(self) -> str | None:
"""Get custom code for the component.
@ -644,6 +721,33 @@ class Component(Base, ABC):
dep: [ImportVar(tag=None, render=False)] for dep in self.lib_dependencies
}
def _get_hooks_imports(self) -> imports.ImportDict:
"""Get the imports required by certain hooks.
Returns:
The imports required for all selected hooks.
"""
_imports = {}
if self._get_ref_hook():
# Handle hooks needed for attaching react refs to DOM nodes.
_imports.setdefault("react", set()).add(ImportVar(tag="useRef"))
_imports.setdefault(f"/{Dirs.STATE_PATH}", set()).add(ImportVar(tag="refs"))
if self._get_mount_lifecycle_hook():
# Handle hooks for `on_mount` / `on_unmount`.
_imports.setdefault("react", set()).add(ImportVar(tag="useEffect"))
if self._get_special_hooks():
# Handle additional internal hooks (autofocus, etc).
_imports.setdefault("react", set()).update(
{
ImportVar(tag="useRef"),
ImportVar(tag="useEffect"),
},
)
return _imports
def _get_imports(self) -> imports.ImportDict:
"""Get all the libraries and fields that are used by the component.
@ -651,13 +755,26 @@ class Component(Base, ABC):
The imports needed by the component.
"""
_imports = {}
# Import this component's tag from the main library.
if self.library is not None and self.tag is not None:
_imports[self.library] = {self.import_var}
# Get static imports required for event processing.
event_imports = Imports.EVENTS if self.event_triggers else {}
# Collect imports from Vars used directly by this component.
var_imports = [
var._var_data.imports for var in self._get_vars() if var._var_data
]
return imports.merge_imports(
*self._get_props_imports(),
self._get_dependencies_imports(),
self._get_hooks_imports(),
_imports,
event_imports,
*var_imports,
)
def get_imports(self) -> imports.ImportDict:
@ -678,13 +795,13 @@ class Component(Base, ABC):
"""
# pop on_mount and on_unmount from event_triggers since these are handled by
# hooks, not as actually props in the component
on_mount = self.event_triggers.pop(EventTriggers.ON_MOUNT, None)
on_unmount = self.event_triggers.pop(EventTriggers.ON_UNMOUNT, None)
if on_mount:
on_mount = self.event_triggers.get(EventTriggers.ON_MOUNT, None)
on_unmount = self.event_triggers.get(EventTriggers.ON_UNMOUNT, None)
if on_mount is not None:
on_mount = format.format_event_chain(on_mount)
if on_unmount:
if on_unmount is not None:
on_unmount = format.format_event_chain(on_unmount)
if on_mount or on_unmount:
if on_mount is not None or on_unmount is not None:
return f"""
useEffect(() => {{
{on_mount or ""}
@ -703,6 +820,47 @@ class Component(Base, ABC):
if ref is not None:
return f"const {ref} = useRef(null); refs['{ref}'] = {ref};"
def _get_vars_hooks(self) -> set[str]:
"""Get the hooks required by vars referenced in this component.
Returns:
The hooks for the vars.
"""
vars_hooks = set()
for var in self._get_vars():
if var._var_data:
vars_hooks.update(var._var_data.hooks)
return vars_hooks
def _get_events_hooks(self) -> set[str]:
"""Get the hooks required by events referenced in this component.
Returns:
The hooks for the events.
"""
if self.event_triggers:
return {Hooks.EVENTS}
return set()
def _get_special_hooks(self) -> set[str]:
"""Get the hooks required by special actions referenced in this component.
Returns:
The hooks for special actions.
"""
if self.autofocus:
return {
"""
// Set focus to the specified element.
const focusRef = useRef(null)
useEffect(() => {
if (focusRef.current) {
focusRef.current.focus();
}
})""",
}
return set()
def _get_hooks_internal(self) -> Set[str]:
"""Get the React hooks for this component managed by the framework.
@ -712,10 +870,15 @@ class Component(Base, ABC):
Returns:
Set of internally managed hooks.
"""
return set(
hook
for hook in [self._get_mount_lifecycle_hook(), self._get_ref_hook()]
if hook
return (
set(
hook
for hook in [self._get_mount_lifecycle_hook(), self._get_ref_hook()]
if hook
)
| self._get_vars_hooks()
| self._get_events_hooks()
| self._get_special_hooks()
)
def _get_hooks(self) -> str | None:
@ -1018,11 +1181,24 @@ class NoSSRComponent(Component):
"""A dynamic component that is not rendered on the server."""
def _get_imports(self) -> imports.ImportDict:
dynamic_import = {"next/dynamic": {ImportVar(tag="dynamic", is_default=True)}}
"""Get the imports for the component.
Returns:
The imports for dynamically importing the component at module load time.
"""
# Next.js dynamic import mechanism.
dynamic_import = {"next/dynamic": [ImportVar(tag="dynamic", is_default=True)]}
# The normal imports for this component.
_imports = super()._get_imports()
# Do NOT import the main library/tag statically.
if self.library is not None:
_imports[self.library] = [imports.ImportVar(tag=None, render=False)]
return imports.merge_imports(
dynamic_import,
{self.library: {ImportVar(tag=None, render=False)}},
_imports,
self._get_dependencies_imports(),
)

View File

@ -12,7 +12,8 @@ from reflex.components.media import Icon
from reflex.event import set_clipboard
from reflex.style import Style
from reflex.utils import format, imports
from reflex.vars import ImportVar, Var
from reflex.utils.imports import ImportVar
from reflex.vars import Var
LiteralCodeBlockTheme = Literal[
"a11y-dark",

View File

@ -16,7 +16,8 @@ from reflex.components.media import Icon
from reflex.event import set_clipboard
from reflex.style import Style
from reflex.utils import format, imports
from reflex.vars import ImportVar, Var
from reflex.utils.imports import ImportVar
from reflex.vars import Var
LiteralCodeBlockTheme = Literal[
"a11y-dark",

View File

@ -8,8 +8,9 @@ from reflex.base import Base
from reflex.components.component import Component, NoSSRComponent
from reflex.components.literals import LiteralRowMarker
from reflex.utils import console, format, imports, types
from reflex.utils.imports import ImportVar
from reflex.utils.serializers import serializer
from reflex.vars import ImportVar, Var, get_unique_variable_name
from reflex.vars import Var, get_unique_variable_name
# TODO: Fix the serialization issue for custom types.

View File

@ -13,8 +13,9 @@ from reflex.base import Base
from reflex.components.component import Component, NoSSRComponent
from reflex.components.literals import LiteralRowMarker
from reflex.utils import console, format, imports, types
from reflex.utils.imports import ImportVar
from reflex.utils.serializers import serializer
from reflex.vars import ImportVar, Var, get_unique_variable_name
from reflex.vars import Var, get_unique_variable_name
class GridColumnIcons(Enum):
Array = "array"

View File

@ -8,7 +8,7 @@ from reflex.components.component import Component
from reflex.components.tags import Tag
from reflex.utils import imports, types
from reflex.utils.serializers import serialize, serializer
from reflex.vars import BaseVar, ComputedVar, ImportVar, Var
from reflex.vars import BaseVar, ComputedVar, Var
class Gridjs(Component):
@ -105,7 +105,7 @@ class DataTable(Gridjs):
def _get_imports(self) -> imports.ImportDict:
return imports.merge_imports(
super()._get_imports(),
{"": {ImportVar(tag="gridjs/dist/theme/mermaid.css")}},
{"": {imports.ImportVar(tag="gridjs/dist/theme/mermaid.css")}},
)
def _render(self) -> Tag:
@ -113,13 +113,13 @@ class DataTable(Gridjs):
self.columns = BaseVar(
_var_name=f"{self.data._var_name}.columns",
_var_type=List[Any],
_var_state=self.data._var_state,
)
_var_full_name_needs_state_prefix=True,
)._replace(merge_var_data=self.data._var_data)
self.data = BaseVar(
_var_name=f"{self.data._var_name}.data",
_var_type=List[List[Any]],
_var_state=self.data._var_state,
)
_var_full_name_needs_state_prefix=True,
)._replace(merge_var_data=self.data._var_data)
if types.is_dataframe(type(self.data)):
# If given a pandas df break up the data and columns
data = serialize(self.data)

View File

@ -12,7 +12,7 @@ from reflex.components.component import Component
from reflex.components.tags import Tag
from reflex.utils import imports, types
from reflex.utils.serializers import serialize, serializer
from reflex.vars import BaseVar, ComputedVar, ImportVar, Var
from reflex.vars import BaseVar, ComputedVar, Var
class Gridjs(Component):
@overload

View File

@ -4,7 +4,7 @@ from typing import Any, Dict, List
from reflex.components.component import Component, NoSSRComponent
from reflex.utils import imports
from reflex.vars import ImportVar, Var
from reflex.vars import Var
class Moment(NoSSRComponent):
@ -78,7 +78,7 @@ class Moment(NoSSRComponent):
if self.tz is not None:
merged_imports = imports.merge_imports(
merged_imports,
{"moment-timezone": {ImportVar(tag="")}},
{"moment-timezone": {imports.ImportVar(tag="")}},
)
return merged_imports

View File

@ -10,7 +10,7 @@ from reflex.style import Style
from typing import Any, Dict, List
from reflex.components.component import Component, NoSSRComponent
from reflex.utils import imports
from reflex.vars import ImportVar, Var
from reflex.vars import Var
class Moment(NoSSRComponent):
def get_event_triggers(self) -> Dict[str, Any]: ...

View File

@ -22,7 +22,7 @@ from reflex.components.component import Component
from reflex.components.layout.cond import Cond, cond
from reflex.components.media.icon import Icon
from reflex.style import color_mode, toggle_color_mode
from reflex.vars import BaseVar
from reflex.vars import Var
from .button import Button
from .switch import Switch
@ -32,7 +32,7 @@ DEFAULT_LIGHT_ICON: Icon = Icon.create(tag="sun")
DEFAULT_DARK_ICON: Icon = Icon.create(tag="moon")
def color_mode_cond(light: Any, dark: Any = None) -> BaseVar | Component:
def color_mode_cond(light: Any, dark: Any = None) -> Var | Component:
"""Create a component or Prop based on color_mode.
Args:

View File

@ -12,7 +12,7 @@ from reflex.components.component import Component
from reflex.components.layout.cond import Cond, cond
from reflex.components.media.icon import Icon
from reflex.style import color_mode, toggle_color_mode
from reflex.vars import BaseVar
from reflex.vars import Var
from .button import Button
from .switch import Switch
@ -20,7 +20,7 @@ DEFAULT_COLOR_MODE: str
DEFAULT_LIGHT_ICON: Icon
DEFAULT_DARK_ICON: Icon
def color_mode_cond(light: Any, dark: Any = None) -> BaseVar | Component: ...
def color_mode_cond(light: Any, dark: Any = None) -> Var | Component: ...
class ColorModeIcon(Cond):
@overload

View File

@ -1,10 +1,11 @@
"""Wrapper around react-debounce-input."""
from __future__ import annotations
from typing import Any
from typing import Any, Set
from reflex.components import Component
from reflex.components.tags import Tag
from reflex.utils import imports
from reflex.vars import Var
@ -77,6 +78,17 @@ class DebounceInput(Component):
object.__setattr__(child, "render", lambda: "")
return tag
def _get_imports(self) -> imports.ImportDict:
return imports.merge_imports(
super()._get_imports(), *[c._get_imports() for c in self.children]
)
def _get_hooks_internal(self) -> Set[str]:
hooks = super()._get_hooks_internal()
for child in self.children:
hooks.update(child._get_hooks_internal())
return hooks
def props_not_none(c: Component) -> dict[str, Any]:
"""Get all properties of the component that are not None.

View File

@ -7,9 +7,10 @@ from typing import Any, Dict, Literal, Optional, Union, overload
from reflex.vars import Var, BaseVar, ComputedVar
from reflex.event import EventChain, EventHandler, EventSpec
from reflex.style import Style
from typing import Any
from typing import Any, Set
from reflex.components import Component
from reflex.components.tags import Tag
from reflex.utils import imports
from reflex.vars import Var
class DebounceInput(Component):

View File

@ -8,7 +8,8 @@ from reflex.base import Base
from reflex.components.component import Component, NoSSRComponent
from reflex.constants import EventTriggers
from reflex.utils.format import to_camel_case
from reflex.vars import ImportVar, Var
from reflex.utils.imports import ImportVar
from reflex.vars import Var
class EditorButtonList(list, enum.Enum):

View File

@ -13,7 +13,8 @@ from reflex.base import Base
from reflex.components.component import Component, NoSSRComponent
from reflex.constants import EventTriggers
from reflex.utils.format import to_camel_case
from reflex.vars import ImportVar, Var
from reflex.utils.imports import ImportVar
from reflex.vars import Var
class EditorButtonList(list, enum.Enum):
BASIC = [["font", "fontSize"], ["fontColor"], ["horizontalRule"], ["link", "image"]]

View File

@ -8,7 +8,7 @@ from jinja2 import Environment
from reflex.components.component import Component
from reflex.components.libs.chakra import ChakraComponent
from reflex.components.tags import Tag
from reflex.constants import EventTriggers
from reflex.constants import Dirs, EventTriggers
from reflex.event import EventChain
from reflex.utils import imports
from reflex.utils.format import format_event_chain, to_camel_case
@ -65,7 +65,13 @@ class Form(ChakraComponent):
def _get_imports(self) -> imports.ImportDict:
return imports.merge_imports(
super()._get_imports(),
{"react": {imports.ImportVar(tag="useCallback")}},
{
"react": {imports.ImportVar(tag="useCallback")},
f"/{Dirs.STATE_PATH}": {
imports.ImportVar(tag="getRefValue"),
imports.ImportVar(tag="getRefValues"),
},
},
)
def _get_hooks(self) -> str | None:

View File

@ -12,7 +12,7 @@ from jinja2 import Environment
from reflex.components.component import Component
from reflex.components.libs.chakra import ChakraComponent
from reflex.components.tags import Tag
from reflex.constants import EventTriggers
from reflex.constants import Dirs, EventTriggers
from reflex.event import EventChain
from reflex.utils import imports
from reflex.utils.format import format_event_chain, to_camel_case

View File

@ -11,7 +11,7 @@ from reflex.components.libs.chakra import (
)
from reflex.constants import EventTriggers
from reflex.utils import imports
from reflex.vars import ImportVar, Var
from reflex.vars import Var
class Input(ChakraComponent):
@ -61,7 +61,7 @@ class Input(ChakraComponent):
def _get_imports(self) -> imports.ImportDict:
return imports.merge_imports(
super()._get_imports(),
{"/utils/state": {ImportVar(tag="set_val")}},
{"/utils/state": {imports.ImportVar(tag="set_val")}},
)
def get_event_triggers(self) -> Dict[str, Any]:

View File

@ -17,7 +17,7 @@ from reflex.components.libs.chakra import (
)
from reflex.constants import EventTriggers
from reflex.utils import imports
from reflex.vars import ImportVar, Var
from reflex.vars import Var
class Input(ChakraComponent):
def get_event_triggers(self) -> Dict[str, Any]: ...

View File

@ -68,9 +68,11 @@ class PinInput(ChakraComponent):
Returns:
The merged import dict.
"""
range_var = Var.range(0)
return merge_imports(
super()._get_imports(),
PinInputField().get_imports(), # type: ignore
range_var._var_data.imports if range_var._var_data is not None else {},
)
def get_event_triggers(self) -> dict[str, Union[Var, Any]]:
@ -117,7 +119,7 @@ class PinInput(ChakraComponent):
)
refs_declaration._var_is_local = True
if ref:
return f"const {ref} = {refs_declaration}"
return f"const {ref} = {str(refs_declaration)}"
return super()._get_ref_hook()
def _render(self) -> Tag:

View File

@ -7,9 +7,10 @@ from reflex import constants
from reflex.components.component import Component
from reflex.components.forms.input import Input
from reflex.components.layout.box import Box
from reflex.constants import Dirs
from reflex.event import CallableEventSpec, EventChain, EventSpec, call_script
from reflex.utils import imports
from reflex.vars import BaseVar, CallableVar, ImportVar, Var
from reflex.vars import BaseVar, CallableVar, Var, VarData
DEFAULT_UPLOAD_ID: str = "default"
@ -30,6 +31,13 @@ def upload_file(id_: str = DEFAULT_UPLOAD_ID) -> BaseVar:
return BaseVar(
_var_name=f"e => upload_files.{id_}[1]((files) => e)",
_var_type=EventChain,
_var_data=VarData( # type: ignore
imports={
f"/{Dirs.STATE_PATH}": {
imports.ImportVar(tag="upload_files"),
},
},
),
)
@ -46,6 +54,13 @@ def selected_files(id_: str = DEFAULT_UPLOAD_ID) -> BaseVar:
return BaseVar(
_var_name=f"(upload_files.{id_} ? upload_files.{id_}[0]?.map((f) => (f.path || f.name)) : [])",
_var_type=List[str],
_var_data=VarData( # type: ignore
imports={
f"/{Dirs.STATE_PATH}": {
imports.ImportVar(tag="upload_files"),
},
},
),
)
@ -166,14 +181,16 @@ class Upload(Component):
def _get_hooks(self) -> str | None:
return (
(super()._get_hooks() or "")
+ f"""
upload_files.{self.id or DEFAULT_UPLOAD_ID} = useState([]);
"""
)
super()._get_hooks() or ""
) + f"upload_files.{self.id or DEFAULT_UPLOAD_ID} = useState([]);"
def _get_imports(self) -> imports.ImportDict:
return {
**super()._get_imports(),
f"/{constants.Dirs.STATE_PATH}": [ImportVar(tag="upload_files")],
}
return imports.merge_imports(
super()._get_imports(),
{
"react": {imports.ImportVar(tag="useState")},
f"/{constants.Dirs.STATE_PATH}": [
imports.ImportVar(tag="upload_files")
],
},
)

View File

@ -12,9 +12,10 @@ from reflex import constants
from reflex.components.component import Component
from reflex.components.forms.input import Input
from reflex.components.layout.box import Box
from reflex.constants import Dirs
from reflex.event import CallableEventSpec, EventChain, EventSpec, call_script
from reflex.utils import imports
from reflex.vars import BaseVar, CallableVar, ImportVar, Var
from reflex.vars import BaseVar, CallableVar, Var, VarData
DEFAULT_UPLOAD_ID: str

View File

@ -1,13 +1,18 @@
"""Create a list of components from an iterable."""
from __future__ import annotations
from typing import Any, Dict, Optional
from typing import Any, Dict, Optional, overload
from reflex.components.component import Component
from reflex.components.layout.fragment import Fragment
from reflex.components.tags import CondTag, Tag
from reflex.utils import format
from reflex.vars import Var
from reflex.constants import Dirs
from reflex.utils import format, imports
from reflex.vars import BaseVar, Var, VarData
_IS_TRUE_IMPORT = {
f"/{Dirs.STATE_PATH}": {imports.ImportVar(tag="isTrue")},
}
class Cond(Component):
@ -88,6 +93,28 @@ class Cond(Component):
cond_state=f"isTrue({self.cond._var_full_name})",
)
def _get_imports(self) -> imports.ImportDict:
return imports.merge_imports(
super()._get_imports(),
getattr(self.cond._var_data, "imports", {}),
_IS_TRUE_IMPORT,
)
@overload
def cond(condition: Any, c1: Component, c2: Any) -> Component:
...
@overload
def cond(condition: Any, c1: Component) -> Component:
...
@overload
def cond(condition: Any, c1: Any, c2: Any) -> Var:
...
def cond(condition: Any, c1: Any, c2: Any = None):
"""Create a conditional component or Prop.
@ -103,8 +130,11 @@ def cond(condition: Any, c1: Any, c2: Any = None):
Raises:
ValueError: If the arguments are invalid.
"""
# Import here to avoid circular imports.
from reflex.vars import BaseVar, Var
var_datas: list[VarData | None] = [
VarData( # type: ignore
imports=_IS_TRUE_IMPORT,
),
]
# Convert the condition to a Var.
cond_var = Var.create(condition)
@ -116,16 +146,20 @@ def cond(condition: Any, c1: Any, c2: Any = None):
c2, Component
), "Both arguments must be components."
return Cond.create(cond_var, c1, c2)
if isinstance(c1, Var):
var_datas.append(c1._var_data)
# Otherwise, create a conditionl Var.
# Otherwise, create a conditional Var.
# Check that the second argument is valid.
if isinstance(c2, Component):
raise ValueError("Both arguments must be props.")
if c2 is None:
raise ValueError("For conditional vars, the second argument must be set.")
if isinstance(c2, Var):
var_datas.append(c2._var_data)
# Create the conditional var.
return BaseVar(
return cond_var._replace(
_var_name=format.format_cond(
cond=cond_var._var_full_name,
true_value=c1,
@ -133,4 +167,7 @@ def cond(condition: Any, c1: Any, c2: Any = None):
is_prop=True,
),
_var_type=c1._var_type if isinstance(c1, BaseVar) else type(c1),
_var_is_local=False,
_var_full_name_needs_state_prefix=False,
merge_var_data=VarData.merge(*var_datas),
)

View File

@ -1,8 +1,8 @@
"""A html component."""
from typing import Any
from typing import Dict
from reflex.components.layout.box import Box
from reflex.vars import Var
class Html(Box):
@ -13,7 +13,7 @@ class Html(Box):
"""
# The HTML to render.
dangerouslySetInnerHTML: Any
dangerouslySetInnerHTML: Var[Dict[str, str]]
@classmethod
def create(cls, *children, **props):

View File

@ -7,8 +7,9 @@ from typing import Any, Dict, Literal, Optional, Union, overload
from reflex.vars import Var, BaseVar, ComputedVar
from reflex.event import EventChain, EventHandler, EventSpec
from reflex.style import Style
from typing import Any
from typing import Dict
from reflex.components.layout.box import Box
from reflex.vars import Var
class Html(Box):
@overload
@ -16,7 +17,9 @@ class Html(Box):
def create( # type: ignore
cls,
*children,
dangerouslySetInnerHTML: Optional[Any] = None,
dangerouslySetInnerHTML: Optional[
Union[Var[Dict[str, str]], Dict[str, str]]
] = None,
element: Optional[Union[Var[str], str]] = None,
src: Optional[Union[Var[str], str]] = None,
alt: Optional[Union[Var[str], str]] = None,

View File

@ -6,7 +6,7 @@ from typing import List, Literal
from reflex.components.component import Component
from reflex.utils import imports
from reflex.vars import ImportVar, Var
from reflex.vars import Var
class ChakraComponent(Component):
@ -34,7 +34,7 @@ class ChakraComponent(Component):
The dependencies imports of the component.
"""
return {
dep: [ImportVar(tag=None, render=False)]
dep: [imports.ImportVar(tag=None, render=False)]
for dep in [
"@chakra-ui/system@2.5.7",
"framer-motion@10.16.4",
@ -75,17 +75,17 @@ class ChakraProvider(ChakraComponent):
)
def _get_imports(self) -> imports.ImportDict:
imports = super()._get_imports()
imports.setdefault(self.__fields__["library"].default, []).append(
ImportVar(tag="extendTheme", is_default=False),
_imports = super()._get_imports()
_imports.setdefault(self.__fields__["library"].default, []).append(
imports.ImportVar(tag="extendTheme", is_default=False),
)
imports.setdefault("/utils/theme.js", []).append(
ImportVar(tag="theme", is_default=True),
_imports.setdefault("/utils/theme.js", []).append(
imports.ImportVar(tag="theme", is_default=True),
)
imports.setdefault(Global.__fields__["library"].default, []).append(
ImportVar(tag="css", is_default=False),
_imports.setdefault(Global.__fields__["library"].default, []).append(
imports.ImportVar(tag="css", is_default=False),
)
return imports
return _imports
def _get_custom_code(self) -> str | None:
return """

View File

@ -11,7 +11,7 @@ from functools import lru_cache
from typing import List, Literal
from reflex.components.component import Component
from reflex.utils import imports
from reflex.vars import ImportVar, Var
from reflex.vars import Var
class ChakraComponent(Component):
@overload

View File

@ -10,10 +10,9 @@ routeNotFound becomes true.
from __future__ import annotations
from reflex import constants
from ...vars import Var
from ..component import Component
from ..layout.cond import Cond
from reflex.components.component import Component
from reflex.components.layout.cond import cond
from reflex.vars import Var
route_not_found: Var = Var.create_safe(constants.ROUTE_NOT_FOUND)
@ -52,10 +51,10 @@ def wait_for_client_redirect(component) -> Component:
Returns:
The conditionally rendered component.
"""
return Cond.create(
cond=route_not_found,
comp1=component,
comp2=ClientSideRouting.create(),
return cond(
condition=route_not_found,
c1=component,
c2=ClientSideRouting.create(),
)

View File

@ -8,9 +8,9 @@ from reflex.vars import Var, BaseVar, ComputedVar
from reflex.event import EventChain, EventHandler, EventSpec
from reflex.style import Style
from reflex import constants
from ...vars import Var
from ..component import Component
from ..layout.cond import Cond
from reflex.components.component import Component
from reflex.components.layout.cond import cond
from reflex.vars import Var
route_not_found: Var

View File

@ -5,22 +5,27 @@ from typing import Optional
from reflex.components.base.bare import Bare
from reflex.components.component import Component
from reflex.components.layout import Box, Cond
from reflex.components.layout import Box, cond
from reflex.components.overlay.modal import Modal
from reflex.components.typography import Text
from reflex.constants import Hooks, Imports
from reflex.utils import imports
from reflex.vars import ImportVar, Var
from reflex.vars import Var, VarData
connect_error_var_data: VarData = VarData( # type: ignore
imports=Imports.EVENTS,
hooks={Hooks.EVENTS},
)
connection_error: Var = Var.create_safe(
value="(connectError !== null) ? connectError.message : ''",
_var_is_local=False,
_var_is_string=False,
)
)._replace(merge_var_data=connect_error_var_data)
has_connection_error: Var = Var.create_safe(
value="connectError !== null",
_var_is_string=False,
)
has_connection_error._var_type = bool
)._replace(_var_type=bool, merge_var_data=connect_error_var_data)
class WebsocketTargetURL(Bare):
@ -28,7 +33,7 @@ class WebsocketTargetURL(Bare):
def _get_imports(self) -> imports.ImportDict:
return {
"/utils/state.js": [ImportVar(tag="getEventURL")],
"/utils/state.js": [imports.ImportVar(tag="getEventURL")],
}
@classmethod
@ -78,7 +83,7 @@ class ConnectionBanner(Component):
textAlign="center",
)
return Cond.create(has_connection_error, comp)
return cond(has_connection_error, comp)
class ConnectionModal(Component):
@ -96,7 +101,7 @@ class ConnectionModal(Component):
"""
if not comp:
comp = Text.create(*default_connection_error())
return Cond.create(
return cond(
has_connection_error,
Modal.create(
header="Connection Error",

View File

@ -10,15 +10,16 @@ from reflex.style import Style
from typing import Optional
from reflex.components.base.bare import Bare
from reflex.components.component import Component
from reflex.components.layout import Box, Cond
from reflex.components.layout import Box, cond
from reflex.components.overlay.modal import Modal
from reflex.components.typography import Text
from reflex.constants import Hooks, Imports
from reflex.utils import imports
from reflex.vars import ImportVar, Var
from reflex.vars import Var, VarData
connect_error_var_data: VarData
connection_error: Var
has_connection_error: Var
has_connection_error._var_type = bool
class WebsocketTargetURL(Bare):
@overload

View File

@ -6,7 +6,7 @@ from typing import Literal
from reflex.components import Component
from reflex.utils import imports
from reflex.vars import ImportVar, Var
from reflex.vars import Var
LiteralAlign = Literal["start", "center", "end", "baseline", "stretch"]
LiteralJustify = Literal["start", "center", "end", "between"]
@ -147,7 +147,7 @@ class Theme(RadixThemesComponent):
def _get_imports(self) -> imports.ImportDict:
return {
**super()._get_imports(),
"": [ImportVar(tag="@radix-ui/themes/styles.css", install=False)],
"": [imports.ImportVar(tag="@radix-ui/themes/styles.css", install=False)],
}

View File

@ -10,7 +10,7 @@ from reflex.style import Style
from typing import Literal
from reflex.components import Component
from reflex.utils import imports
from reflex.vars import ImportVar, Var
from reflex.vars import Var
LiteralAlign = Literal["start", "center", "end", "baseline", "stretch"]
LiteralJustify = Literal["start", "center", "end", "between"]

View File

@ -14,7 +14,8 @@ from reflex.components.typography.heading import Heading
from reflex.components.typography.text import Text
from reflex.style import Style
from reflex.utils import console, imports, types
from reflex.vars import ImportVar, Var
from reflex.utils.imports import ImportVar
from reflex.vars import Var
# Special vars used in the component map.
_CHILDREN = Var.create_safe("children", _var_is_local=False)

View File

@ -18,7 +18,8 @@ from reflex.components.typography.heading import Heading
from reflex.components.typography.text import Text
from reflex.style import Style
from reflex.utils import console, imports, types
from reflex.vars import ImportVar, Var
from reflex.utils.imports import ImportVar
from reflex.vars import Var
_CHILDREN = Var.create_safe("children", _var_is_local=False)
_PROPS = Var.create_safe("...props", _var_is_local=False)

View File

@ -22,6 +22,8 @@ from .compiler import (
CompileVars,
ComponentName,
Ext,
Hooks,
Imports,
PageNames,
)
from .config import (
@ -68,7 +70,9 @@ __ALL__ = [
Ext,
Fnm,
GitIgnore,
Hooks,
RequirementsTxt,
Imports,
IS_WINDOWS,
LOCAL_STORAGE,
LogLevel,

View File

@ -29,6 +29,8 @@ class Dirs(SimpleNamespace):
STATE_PATH = "/".join([UTILS, "state"])
# The name of the components file.
COMPONENTS_PATH = "/".join([UTILS, "components"])
# The name of the contexts file.
CONTEXTS_PATH = "/".join([UTILS, "context"])
# The directory where the app pages are compiled to.
WEB_PAGES = os.path.join(WEB, "pages")
# The directory where the static build is located.

View File

@ -2,6 +2,9 @@
from enum import Enum
from types import SimpleNamespace
from reflex.constants import Dirs
from reflex.utils.imports import ImportVar
# The prefix used to create setters for state vars.
SETTER_PREFIX = "set_"
@ -47,6 +50,12 @@ class CompileVars(SimpleNamespace):
HYDRATE = "hydrate"
# The name of the is_hydrated variable.
IS_HYDRATED = "is_hydrated"
# The name of the function to add events to the queue.
ADD_EVENTS = "addEvents"
# The name of the var storing any connection error.
CONNECT_ERROR = "connectError"
# The name of the function for converting a dict to an event.
TO_EVENT = "Event"
class PageNames(SimpleNamespace):
@ -77,3 +86,19 @@ class ComponentName(Enum):
The lower-case filename with zip extension.
"""
return self.value.lower() + Ext.ZIP
class Imports(SimpleNamespace):
"""Common sets of import vars."""
EVENTS = {
"react": {ImportVar(tag="useContext")},
f"/{Dirs.CONTEXTS_PATH}": {ImportVar(tag="EventLoopContext")},
f"/{Dirs.STATE_PATH}": {ImportVar(tag=CompileVars.TO_EVENT)},
}
class Hooks(SimpleNamespace):
"""Common sets of hook declarations."""
EVENTS = f"const [{CompileVars.ADD_EVENTS}, {CompileVars.CONNECT_ERROR}] = useContext(EventLoopContext);"

View File

@ -48,7 +48,7 @@ class HydrateMiddleware(Middleware):
setattr(var_state, var_name, value)
# Get the initial state.
delta = format.format_state({state.get_name(): state.dict()})
delta = format.format_state(state.dict())
# since a full dict was captured, clean any dirtiness
state._clean()

View File

@ -1211,12 +1211,16 @@ class State(Base, ABC, extra=pydantic.Extra.allow):
if include_computed
else {}
)
substate_vars = {
k: v.dict(include_computed=include_computed, **kwargs)
for k, v in self.substates.items()
variables = {**base_vars, **computed_vars}
d = {
self.get_full_name(): {k: variables[k] for k in sorted(variables)},
}
variables = {**base_vars, **computed_vars, **substate_vars}
return {k: variables[k] for k in sorted(variables)}
for substate_d in [
v.dict(include_computed=include_computed, **kwargs)
for v in self.substates.values()
]:
d.update(substate_d)
return d
async def __aenter__(self) -> State:
"""Enter the async context manager protocol.

View File

@ -2,13 +2,38 @@
from __future__ import annotations
from typing import Any
from reflex import constants
from reflex.event import EventChain
from reflex.utils import format
from reflex.vars import BaseVar, Var
from reflex.utils.imports import ImportVar
from reflex.vars import BaseVar, Var, VarData
color_mode = BaseVar(_var_name=constants.ColorMode.NAME, _var_type="str")
toggle_color_mode = BaseVar(_var_name=constants.ColorMode.TOGGLE, _var_type=EventChain)
VarData.update_forward_refs() # Ensure all type definitions are resolved
# Reference the global ColorModeContext
color_mode_var_data = VarData( # type: ignore
imports={
f"/{constants.Dirs.CONTEXTS_PATH}": {ImportVar(tag="ColorModeContext")},
"react": {ImportVar(tag="useContext")},
},
hooks={
f"const [ {constants.ColorMode.NAME}, {constants.ColorMode.TOGGLE} ] = useContext(ColorModeContext)",
},
)
# Var resolves to the current color mode for the app ("light" or "dark")
color_mode = BaseVar(
_var_name=constants.ColorMode.NAME,
_var_type="str",
_var_data=color_mode_var_data,
)
# Var resolves to a function invocation that toggles the color mode
toggle_color_mode = BaseVar(
_var_name=constants.ColorMode.TOGGLE,
_var_type=EventChain,
_var_data=color_mode_var_data,
)
def convert(style_dict):
@ -20,16 +45,27 @@ def convert(style_dict):
Returns:
The formatted style dictionary.
"""
var_data = None # Track import/hook data from any Vars in the style dict.
out = {}
for key, value in style_dict.items():
key = format.to_camel_case(key)
new_var_data = None
if isinstance(value, dict):
out[key] = convert(value)
# Recursively format nested style dictionaries.
out[key], new_var_data = convert(value)
elif isinstance(value, Var):
# If the value is a Var, extract the var_data and cast as str.
new_var_data = value._var_data
out[key] = str(value)
else:
# Otherwise, convert to Var to collapse VarData encoded in f-string.
new_var = Var.create(value)
if new_var is not None:
new_var_data = new_var._var_data
out[key] = value
return out
# Combine all the collected VarData instances.
var_data = VarData.merge(var_data, new_var_data)
return out, var_data
class Style(dict):
@ -41,5 +77,33 @@ class Style(dict):
Args:
style_dict: The style dictionary.
"""
style_dict = style_dict or {}
super().__init__(convert(style_dict))
style_dict, self._var_data = convert(style_dict or {})
super().__init__(style_dict)
def update(self, style_dict: dict | None, **kwargs):
"""Update the style.
Args:
style_dict: The style dictionary.
kwargs: Other key value pairs to apply to the dict update.
"""
if kwargs:
style_dict = {**(style_dict or {}), **kwargs}
converted_dict = type(self)(style_dict)
# Combine our VarData with that of any Vars in the style_dict that was passed.
self._var_data = VarData.merge(self._var_data, converted_dict._var_data)
super().update(converted_dict)
def __setitem__(self, key: str, value: Any):
"""Set an item in the style.
Args:
key: The key to set.
value: The value to set.
"""
# Create a Var to collapse VarData encoded in f-string.
_var = Var.create(value)
if _var is not None:
# Carry the imports/hooks when setting a Var as a value.
self._var_data = VarData.merge(self._var_data, _var._var_data)
super().__setitem__(key, value)

View File

@ -232,9 +232,9 @@ def format_route(route: str, format_case=True) -> str:
def format_cond(
cond: str,
true_value: str,
false_value: str = '""',
cond: str | Var,
true_value: str | Var,
false_value: str | Var = '""',
is_prop=False,
) -> str:
"""Format a conditional expression.
@ -248,9 +248,6 @@ def format_cond(
Returns:
The formatted conditional expression.
"""
# Import here to avoid circular imports.
from reflex.vars import Var
# Use Python truthiness.
cond = f"isTrue({cond})"
@ -266,6 +263,7 @@ def format_cond(
_var_is_string=type(false_value) is str,
)
prop2._var_is_local = True
prop1, prop2 = str(prop1), str(prop2) # avoid f-string semantics for Var
return f"{cond} ? {prop1} : {prop2}".replace("{", "").replace("}", "")
# Format component conds.
@ -517,6 +515,21 @@ def format_state(value: Any) -> Any:
raise TypeError(f"No JSON serializer found for var {value} of type {type(value)}.")
def format_state_name(state_name: str) -> str:
"""Format a state name, replacing dots with double underscore.
This allows individual substates to be accessed independently as javascript vars
without using dot notation.
Args:
state_name: The state name to format.
Returns:
The formatted state name.
"""
return state_name.replace(".", "__")
def format_ref(ref: str) -> str:
"""Format a ref.

View File

@ -3,11 +3,9 @@
from __future__ import annotations
from collections import defaultdict
from typing import Dict, List
from typing import Dict, List, Optional
from reflex.vars import ImportVar
ImportDict = Dict[str, List[ImportVar]]
from reflex.base import Base
def merge_imports(*imports) -> ImportDict:
@ -24,3 +22,42 @@ def merge_imports(*imports) -> ImportDict:
for lib, fields in import_dict.items():
all_imports[lib].extend(fields)
return all_imports
class ImportVar(Base):
"""An import var."""
# The name of the import tag.
tag: Optional[str]
# whether the import is default or named.
is_default: Optional[bool] = False
# The tag alias.
alias: Optional[str] = None
# Whether this import need to install the associated lib
install: Optional[bool] = True
# whether this import should be rendered or not
render: Optional[bool] = True
@property
def name(self) -> str:
"""The name of the import.
Returns:
The name(tag name with alias) of tag.
"""
return self.tag if not self.alias else " as ".join([self.tag, self.alias]) # type: ignore
def __hash__(self) -> int:
"""Define a hash function for the import var.
Returns:
The hash of the var.
"""
return hash((self.tag, self.is_default, self.alias, self.install, self.render))
ImportDict = Dict[str, List[ImportVar]]

View File

@ -27,6 +27,7 @@ from reflex.utils import serializers
GenericType = Union[Type, _GenericAlias]
# Valid state var types.
JSONType = {str, int, float, bool}
PrimitiveType = Union[int, float, bool, str, list, dict, set, tuple]
StateVar = Union[PrimitiveType, Base, None]
StateIterVar = Union[list, set, tuple]

View File

@ -7,6 +7,7 @@ import dis
import inspect
import json
import random
import re
import string
import sys
from types import CodeType, FunctionType
@ -15,9 +16,11 @@ from typing import (
Any,
Callable,
Dict,
Iterable,
List,
Literal,
Optional,
Set,
Tuple,
Type,
Union,
@ -30,7 +33,10 @@ from typing import (
from reflex import constants
from reflex.base import Base
from reflex.utils import console, format, serializers, types
from reflex.utils import console, format, imports, serializers, types
# This module used to export ImportVar itself, so we still import it for export here
from reflex.utils.imports import ImportDict, ImportVar
if TYPE_CHECKING:
from reflex.state import State
@ -71,7 +77,7 @@ OPERATION_MAPPING = {
REPLACED_NAMES = {
"full_name": "_var_full_name",
"name": "_var_name",
"state": "_var_state",
"state": "_var_data.state",
"type_": "_var_type",
"is_local": "_var_is_local",
"is_string": "_var_is_string",
@ -93,6 +99,131 @@ def get_unique_variable_name() -> str:
return get_unique_variable_name()
class VarData(Base):
"""Metadata associated with a Var."""
# The name of the enclosing state.
state: str = ""
# Imports needed to render this var
imports: ImportDict = {}
# Hooks that need to be present in the component to render this var
hooks: Set[str] = set()
@classmethod
def merge(cls, *others: VarData | None) -> VarData | None:
"""Merge multiple var data objects.
Args:
*others: The var data objects to merge.
Returns:
The merged var data object.
"""
state = ""
_imports = {}
hooks = set()
for var_data in others:
if var_data is None:
continue
state = state or var_data.state
_imports = imports.merge_imports(_imports, var_data.imports)
hooks.update(var_data.hooks)
return (
cls(
state=state,
imports=_imports,
hooks=hooks,
)
or None
)
def __bool__(self) -> bool:
"""Check if the var data is non-empty.
Returns:
True if any field is set to a non-default value.
"""
return bool(self.state or self.imports or self.hooks)
def dict(self) -> dict:
"""Convert the var data to a dictionary.
Returns:
The var data dictionary.
"""
return {
"state": self.state,
"imports": {
lib: [import_var.dict() for import_var in import_vars]
for lib, import_vars in self.imports.items()
},
"hooks": list(self.hooks),
}
def _encode_var(value: Var) -> str:
"""Encode the state name into a formatted var.
Args:
value: The value to encode the state name into.
Returns:
The encoded var.
"""
if value._var_data:
return f"<reflex.Var>{value._var_data.json()}</reflex.Var>" + str(value)
return str(value)
def _decode_var(value: str) -> tuple[VarData | None, str]:
"""Decode the state name from a formatted var.
Args:
value: The value to extract the state name from.
Returns:
The extracted state name and the value without the state name.
"""
var_datas = []
if isinstance(value, str):
# Extract the state name from a formatted var
while m := re.match(r"(.*)<reflex.Var>(.*)</reflex.Var>(.*)", value):
value = m.group(1) + m.group(3)
var_datas.append(VarData.parse_raw(m.group(2)))
if var_datas:
return VarData.merge(*var_datas), value
return None, value
def _extract_var_data(value: Iterable) -> list[VarData | None]:
"""Extract the var imports and hooks from an iterable containing a Var.
Args:
value: The iterable to extract the VarData from
Returns:
The extracted VarDatas.
"""
var_datas = []
with contextlib.suppress(TypeError):
for sub in value:
if isinstance(sub, Var):
var_datas.append(sub._var_data)
elif not isinstance(sub, str):
# Recurse into dict values.
if hasattr(sub, "values") and callable(sub.values):
var_datas.extend(_extract_var_data(sub.values()))
# Recurse into iterable values (or dict keys).
var_datas.extend(_extract_var_data(sub))
# Recurse when value is a dict itself.
values = getattr(value, "values", None)
if callable(values):
var_datas.extend(_extract_var_data(values()))
return var_datas
class Var:
"""An abstract var."""
@ -102,15 +233,18 @@ class Var:
# The type of the var.
_var_type: Type
# The name of the enclosing state.
_var_state: str
# Whether this is a local javascript variable.
_var_is_local: bool
# Whether the var is a string literal.
_var_is_string: bool
# _var_full_name should be prefixed with _var_state
_var_full_name_needs_state_prefix: bool
# Extra metadata associated with the Var
_var_data: Optional[VarData]
@classmethod
def create(
cls, value: Any, _var_is_local: bool = True, _var_is_string: bool = False
@ -136,9 +270,14 @@ class Var:
if isinstance(value, Var):
return value
# Try to pull the imports and hooks from contained values.
_var_data = None
if not isinstance(value, str):
_var_data = VarData.merge(*_extract_var_data(value))
# Try to serialize the value.
type_ = type(value)
name = serializers.serialize(value)
name = value if type_ in types.JSONType else serializers.serialize(value)
if name is None:
raise TypeError(
f"No JSON serializer found for var {value} of type {type_}."
@ -150,6 +289,7 @@ class Var:
_var_type=type_,
_var_is_local=_var_is_local,
_var_is_string=_var_is_string,
_var_data=_var_data,
)
@classmethod
@ -186,6 +326,39 @@ class Var:
"""
return _GenericAlias(cls, type_)
def __post_init__(self) -> None:
"""Post-initialize the var."""
# Decode any inline Var markup and apply it to the instance
_var_data, _var_name = _decode_var(self._var_name)
if _var_data:
self._var_name = _var_name
self._var_data = VarData.merge(self._var_data, _var_data)
def _replace(self, merge_var_data=None, **kwargs: Any) -> Var:
"""Make a copy of this Var with updated fields.
Args:
merge_var_data: VarData to merge into the existing VarData.
**kwargs: Var fields to update.
Returns:
A new BaseVar with the updated fields overwriting the corresponding fields in this Var.
"""
field_values = dict(
_var_name=kwargs.pop("_var_name", self._var_name),
_var_type=kwargs.pop("_var_type", self._var_type),
_var_is_local=kwargs.pop("_var_is_local", self._var_is_local),
_var_is_string=kwargs.pop("_var_is_string", self._var_is_string),
_var_full_name_needs_state_prefix=kwargs.pop(
"_var_full_name_needs_state_prefix",
self._var_full_name_needs_state_prefix,
),
_var_data=VarData.merge(
kwargs.get("_var_data", self._var_data), merge_var_data
),
)
return BaseVar(**field_values)
def _decode(self) -> Any:
"""Decode Var as a python value.
@ -195,8 +368,6 @@ class Var:
Returns:
The decoded value or the Var name.
"""
if self._var_state:
return self._var_full_name
if self._var_is_string:
return self._var_name
try:
@ -216,8 +387,10 @@ class Var:
return (
self._var_name == other._var_name
and self._var_type == other._var_type
and self._var_state == other._var_state
and self._var_is_local == other._var_is_local
and self._var_full_name_needs_state_prefix
== other._var_full_name_needs_state_prefix
and self._var_data == other._var_data
)
def to_string(self, json: bool = True) -> Var:
@ -285,9 +458,11 @@ class Var:
Returns:
The formatted var.
"""
# Encode the _var_data into the formatted output for tracking purposes.
str_self = _encode_var(self)
if self._var_is_local:
return str(self)
return f"${str(self)}"
return str_self
return f"${str_self}"
def __getitem__(self, i: Any) -> Var:
"""Index into a var.
@ -320,12 +495,7 @@ class Var:
# Convert any vars to local vars.
if isinstance(i, Var):
i = BaseVar(
_var_name=i._var_name,
_var_type=i._var_type,
_var_state=i._var_state,
_var_is_local=True,
)
i = i._replace(_var_is_local=True)
# Handle list/tuple/str indexing.
if types._issubclass(self._var_type, Union[List, Tuple, str]):
@ -344,11 +514,9 @@ class Var:
stop = i.stop or "undefined"
# Use the slice function.
return BaseVar(
return self._replace(
_var_name=f"{self._var_name}.slice({start}, {stop})",
_var_type=self._var_type,
_var_state=self._var_state,
_var_is_local=self._var_is_local,
_var_is_string=False,
)
# Get the type of the indexed var.
@ -359,11 +527,10 @@ class Var:
)
# Use `at` to support negative indices.
return BaseVar(
return self._replace(
_var_name=f"{self._var_name}.at({i})",
_var_type=type_,
_var_state=self._var_state,
_var_is_local=self._var_is_local,
_var_is_string=False,
)
# Dictionary / dataframe indexing.
@ -393,11 +560,10 @@ class Var:
)
# Use normal indexing here.
return BaseVar(
return self._replace(
_var_name=f"{self._var_name}[{i}]",
_var_type=type_,
_var_state=self._var_state,
_var_is_local=self._var_is_local,
_var_is_string=False,
)
def __getattr__(self, name: str) -> Var:
@ -423,11 +589,10 @@ class Var:
type_ = types.get_attribute_access_type(self._var_type, name)
if type_ is not None:
return BaseVar(
return self._replace(
_var_name=f"{self._var_name}{'?' if is_optional else ''}.{name}",
_var_type=type_,
_var_state=self._var_state,
_var_is_local=self._var_is_local,
_var_is_string=False,
)
if name in REPLACED_NAMES:
@ -519,10 +684,12 @@ class Var:
else f"{self._var_full_name}.{fn}()"
)
return BaseVar(
return self._replace(
_var_name=operation_name,
_var_type=type_,
_var_is_local=self._var_is_local,
_var_is_string=False,
_var_full_name_needs_state_prefix=False,
merge_var_data=other._var_data if other is not None else None,
)
@staticmethod
@ -602,10 +769,10 @@ class Var:
"""
if not types._issubclass(self._var_type, List):
raise TypeError(f"Cannot get length of non-list var {self}.")
return BaseVar(
_var_name=f"{self._var_full_name}.length",
return self._replace(
_var_name=f"{self._var_name}.length",
_var_type=int,
_var_is_local=self._var_is_local,
_var_is_string=False,
)
def __eq__(self, other: Var) -> Var:
@ -692,7 +859,17 @@ class Var:
types.get_base_class(self._var_type) == list
and types.get_base_class(other_type) == list
):
return self.operation(",", other, fn="spreadArraysOrObjects", flip=flip)
return self.operation(
",", other, fn="spreadArraysOrObjects", flip=flip
)._replace(
merge_var_data=VarData(
imports={
f"/{constants.Dirs.STATE_PATH}": [
ImportVar(tag="spreadArraysOrObjects")
]
},
),
)
return self.operation("+", other, flip=flip)
def __radd__(self, other: Var) -> Var:
@ -755,10 +932,11 @@ class Var:
]:
other_name = other._var_full_name if isinstance(other, Var) else other
name = f"Array({other_name}).fill().map(() => {self._var_full_name}).flat()"
return BaseVar(
return self._replace(
_var_name=name,
_var_type=str,
_var_is_local=self._var_is_local,
_var_is_string=False,
_var_full_name_needs_state_prefix=False,
)
return self.operation("*", other)
@ -1003,10 +1181,11 @@ class Var:
elif not isinstance(other, Var):
other = Var.create(other)
if types._issubclass(self._var_type, Dict):
return BaseVar(
_var_name=f"{self._var_full_name}.{method}({other._var_full_name})",
return self._replace(
_var_name=f"{self._var_name}.{method}({other._var_full_name})",
_var_type=bool,
_var_is_local=self._var_is_local,
_var_is_string=False,
merge_var_data=other._var_data,
)
else: # str, list, tuple
# For strings, the left operand must be a string.
@ -1016,10 +1195,11 @@ class Var:
raise TypeError(
f"'in <string>' requires string as left operand, not {other._var_type}"
)
return BaseVar(
_var_name=f"{self._var_full_name}.includes({other._var_full_name})",
return self._replace(
_var_name=f"{self._var_name}.includes({other._var_full_name})",
_var_type=bool,
_var_is_local=self._var_is_local,
_var_is_string=False,
merge_var_data=other._var_data,
)
def reverse(self) -> Var:
@ -1034,10 +1214,10 @@ class Var:
if not types._issubclass(self._var_type, list):
raise TypeError(f"Cannot reverse non-list var {self._var_full_name}.")
return BaseVar(
return self._replace(
_var_name=f"[...{self._var_full_name}].reverse()",
_var_type=self._var_type,
_var_is_local=self._var_is_local,
_var_is_string=False,
_var_full_name_needs_state_prefix=False,
)
def lower(self) -> Var:
@ -1054,10 +1234,10 @@ class Var:
f"Cannot convert non-string var {self._var_full_name} to lowercase."
)
return BaseVar(
_var_name=f"{self._var_full_name}.toLowerCase()",
return self._replace(
_var_name=f"{self._var_name}.toLowerCase()",
_var_is_string=False,
_var_type=str,
_var_is_local=self._var_is_local,
)
def upper(self) -> Var:
@ -1074,10 +1254,10 @@ class Var:
f"Cannot convert non-string var {self._var_full_name} to uppercase."
)
return BaseVar(
_var_name=f"{self._var_full_name}.toUpperCase()",
return self._replace(
_var_name=f"{self._var_name}.toUpperCase()",
_var_is_string=False,
_var_type=str,
_var_is_local=self._var_is_local,
)
def split(self, other: str | Var[str] = " ") -> Var:
@ -1097,10 +1277,11 @@ class Var:
other = Var.create_safe(json.dumps(other)) if isinstance(other, str) else other
return BaseVar(
_var_name=f"{self._var_full_name}.split({other._var_full_name})",
return self._replace(
_var_name=f"{self._var_name}.split({other._var_full_name})",
_var_is_string=False,
_var_type=list[str],
_var_is_local=self._var_is_local,
merge_var_data=other._var_data,
)
def join(self, other: str | Var[str] | None = None) -> Var:
@ -1125,10 +1306,11 @@ class Var:
else:
other = Var.create_safe(other)
return BaseVar(
_var_name=f"{self._var_full_name}.join({other._var_full_name})",
return self._replace(
_var_name=f"{self._var_name}.join({other._var_full_name})",
_var_is_string=False,
_var_type=str,
_var_is_local=self._var_is_local,
merge_var_data=other._var_data,
)
def foreach(self, fn: Callable) -> Var:
@ -1159,10 +1341,9 @@ class Var:
fn_signature = inspect.signature(fn)
fn_args = (arg, index)
fn_ret = fn(*fn_args[: len(fn_signature.parameters)])
return BaseVar(
return self._replace(
_var_name=f"{self._var_full_name}.map(({arg._var_name}, {index._var_name}) => {fn_ret})",
_var_type=self._var_type,
_var_is_local=self._var_is_local,
_var_is_string=False,
)
@classmethod
@ -1207,6 +1388,18 @@ class Var:
_var_name=f"Array.from(range({v1._var_full_name}, {v2._var_full_name}, {step._var_name}))",
_var_type=list[int],
_var_is_local=False,
_var_data=VarData.merge(
v1._var_data,
v2._var_data,
step._var_data,
VarData(
imports={
"/utils/helpers/range.js": [
ImportVar(tag="range", is_default=True),
],
},
),
),
)
def to(self, type_: Type) -> Var:
@ -1218,12 +1411,7 @@ class Var:
Returns:
The converted var.
"""
return BaseVar(
_var_name=self._var_name,
_var_type=type_,
_var_state=self._var_state,
_var_is_local=self._var_is_local,
)
return self._replace(_var_type=type_)
@property
def _var_full_name(self) -> str:
@ -1232,24 +1420,51 @@ class Var:
Returns:
The full name of the var.
"""
if not self._var_full_name_needs_state_prefix:
return self._var_name
return (
self._var_name
if self._var_state == ""
else ".".join([self._var_state, self._var_name])
if self._var_data is None or self._var_data.state == ""
else ".".join(
[format.format_state_name(self._var_data.state), self._var_name]
)
)
def _var_set_state(self, state: Type[State]) -> Any:
def _var_set_state(self, state: Type[State] | str) -> Any:
"""Set the state of the var.
Args:
state: The state to set.
state: The state to set or the full name of the state.
Returns:
The var with the set state.
"""
self._var_state = state.get_full_name()
state_name = state if isinstance(state, str) else state.get_full_name()
new_var_data = VarData(
state=state_name,
hooks={
"const {0} = useContext(StateContexts.{0})".format(
format.format_state_name(state_name)
)
},
imports={
f"/{constants.Dirs.CONTEXTS_PATH}": [ImportVar(tag="StateContexts")],
"react": [ImportVar(tag="useContext")],
},
)
self._var_data = VarData.merge(self._var_data, new_var_data)
self._var_full_name_needs_state_prefix = True
return self
@property
def _var_state(self) -> str:
"""Compat method for getting the state.
Returns:
The state name associated with the var.
"""
return self._var_data.state if self._var_data else ""
@dataclasses.dataclass(
eq=False,
@ -1264,15 +1479,18 @@ class BaseVar(Var):
# The type of the var.
_var_type: Type = dataclasses.field(default=Any)
# The name of the enclosing state.
_var_state: str = dataclasses.field(default="")
# Whether this is a local javascript variable.
_var_is_local: bool = dataclasses.field(default=False)
# Whether the var is a string literal.
_var_is_string: bool = dataclasses.field(default=False)
# _var_full_name should be prefixed with _var_state
_var_full_name_needs_state_prefix: bool = dataclasses.field(default=False)
# Extra metadata associated with the Var
_var_data: Optional[VarData] = dataclasses.field(default=None)
def __hash__(self) -> int:
"""Define a hash function for a var.
@ -1334,9 +1552,11 @@ class BaseVar(Var):
The name of the setter function.
"""
setter = constants.SETTER_PREFIX + self._var_name
if not include_state or self._var_state == "":
if self._var_data is None:
return setter
return ".".join((self._var_state, setter))
if not include_state or self._var_data.state == "":
return setter
return ".".join((self._var_data.state, setter))
def get_setter(self) -> Callable[[State, Any], None]:
"""Get the var's setter function.
@ -1550,48 +1770,6 @@ def cached_var(fget: Callable[[Any], Any]) -> ComputedVar:
return cvar
class ImportVar(Base):
"""An import var."""
# The name of the import tag.
tag: Optional[str]
# whether the import is default or named.
is_default: Optional[bool] = False
# The tag alias.
alias: Optional[str] = None
# Whether this import need to install the associated lib
install: Optional[bool] = True
# whether this import should be rendered or not
render: Optional[bool] = True
@property
def name(self) -> str:
"""The name of the import.
Returns:
The name(tag name with alias) of tag.
"""
return self.tag if not self.alias else " as ".join([self.tag, self.alias]) # type: ignore
def __hash__(self) -> int:
"""Define a hash function for the import var.
Returns:
The hash of the var.
"""
return hash((self.tag, self.is_default, self.alias, self.install, self.render))
class NoRenderImportVar(ImportVar):
"""A import that doesn't need to be rendered."""
render: Optional[bool] = False
class CallableVar(BaseVar):
"""Decorate a Var-returning function to act as both a Var and a function.

View File

@ -6,11 +6,13 @@ from reflex import constants as constants
from reflex.base import Base as Base
from reflex.state import State as State
from reflex.utils import console as console, format as format, types as types
from reflex.utils.imports import ImportVar
from types import FunctionType
from typing import (
Any,
Callable,
Dict,
Iterable,
List,
Optional,
Set,
@ -22,13 +24,24 @@ from typing import (
USED_VARIABLES: Incomplete
def get_unique_variable_name() -> str: ...
def _encode_var(value: Var) -> str: ...
def _decode_var(value: str) -> tuple[VarData, str]: ...
def _extract_var_data(value: Iterable) -> list[VarData | None]: ...
class VarData(Base):
state: str
imports: dict[str, set[ImportVar]]
hooks: set[str]
@classmethod
def merge(cls, *others: VarData | None) -> VarData | None: ...
class Var:
_var_name: str
_var_type: Type
_var_state: str = ""
_var_is_local: bool = False
_var_is_string: bool = False
_var_full_name_needs_state_prefix: bool = False
_var_data: VarData | None = None
@classmethod
def create(
cls, value: Any, _var_is_local: bool = False, _var_is_string: bool = False
@ -38,7 +51,8 @@ class Var:
cls, value: Any, _var_is_local: bool = False, _var_is_string: bool = False
) -> Var: ...
@classmethod
def __class_getitem__(cls, type_: str) -> _GenericAlias: ...
def __class_getitem__(cls, type_: Type) -> _GenericAlias: ...
def _replace(self, merge_var_data=None, **kwargs: Any) -> Var: ...
def equals(self, other: Var) -> bool: ...
def to_string(self) -> Var: ...
def __hash__(self) -> int: ...
@ -95,15 +109,16 @@ class Var:
def to(self, type_: Type) -> Var: ...
@property
def _var_full_name(self) -> str: ...
def _var_set_state(self, state: Type[State]) -> Any: ...
def _var_set_state(self, state: Type[State] | str) -> Any: ...
@dataclass(eq=False)
class BaseVar(Var):
_var_name: str
_var_type: Any
_var_state: str = ""
_var_is_local: bool = False
_var_is_string: bool = False
_var_full_name_needs_state_prefix: bool = False
_var_data: VarData | None = None
def __hash__(self) -> int: ...
def get_default_value(self) -> Any: ...
def get_setter_name(self, include_state: bool = ...) -> str: ...
@ -123,21 +138,6 @@ class ComputedVar(Var):
def cached_var(fget: Callable[[Any], Any]) -> ComputedVar: ...
class ImportVar(Base):
tag: Optional[str]
is_default: Optional[bool] = False
alias: Optional[str] = None
install: Optional[bool] = True
render: Optional[bool] = True
@property
def name(self) -> str: ...
def __hash__(self) -> int: ...
class NoRenderImportVar(ImportVar):
"""A import that doesn't need to be rendered."""
def get_local_storage(key: Optional[Union[Var, str]] = ...) -> BaseVar: ...
class CallableVar(BaseVar):
def __init__(self, fn: Callable[..., BaseVar]): ...
def __call__(self, *args, **kwargs) -> BaseVar: ...

View File

@ -5,7 +5,7 @@ import pytest
from reflex.compiler import compiler, utils
from reflex.utils import imports
from reflex.vars import ImportVar
from reflex.utils.imports import ImportVar
@pytest.mark.parametrize(

View File

@ -110,7 +110,7 @@ def test_cond_no_else():
# Props do not support the use of cond without else
with pytest.raises(ValueError):
cond(True, "hello")
cond(True, "hello") # type: ignore
def test_mobile_only():

View File

@ -4,14 +4,16 @@ import pytest
import reflex as rx
from reflex.base import Base
from reflex.components.base.bare import Bare
from reflex.components.component import Component, CustomComponent, custom_component
from reflex.components.layout.box import Box
from reflex.constants import EventTriggers
from reflex.event import EventHandler
from reflex.event import EventChain, EventHandler
from reflex.state import State
from reflex.style import Style
from reflex.utils import imports
from reflex.vars import ImportVar, Var
from reflex.utils.imports import ImportVar
from reflex.vars import Var, VarData
@pytest.fixture
@ -600,3 +602,161 @@ def test_format_component(component, rendered):
rendered: The expected rendered component.
"""
assert str(component) == rendered
TEST_VAR = Var.create_safe("test")._replace(
merge_var_data=VarData(
hooks={"useTest"}, imports={"test": {ImportVar(tag="test")}}, state="Test"
)
)
FORMATTED_TEST_VAR = Var.create(f"foo{TEST_VAR}bar")
STYLE_VAR = TEST_VAR._replace(_var_name="style", _var_is_local=False)
EVENT_CHAIN_VAR = TEST_VAR._replace(_var_type=EventChain)
ARG_VAR = Var.create("arg")
class EventState(rx.State):
"""State for testing event handlers with _get_vars."""
v: int = 42
def handler(self):
"""A handler that does nothing."""
def handler2(self, arg):
"""A handler that takes an arg.
Args:
arg: An arg.
"""
@pytest.mark.parametrize(
("component", "exp_vars"),
(
pytest.param(
Bare.create(TEST_VAR),
[TEST_VAR],
id="direct-bare",
),
pytest.param(
Bare.create(f"foo{TEST_VAR}bar"),
[FORMATTED_TEST_VAR],
id="fstring-bare",
),
pytest.param(
rx.text(as_=TEST_VAR),
[TEST_VAR],
id="direct-prop",
),
pytest.param(
rx.text(as_=f"foo{TEST_VAR}bar"),
[FORMATTED_TEST_VAR],
id="fstring-prop",
),
pytest.param(
rx.fragment(id=TEST_VAR),
[TEST_VAR],
id="direct-id",
),
pytest.param(
rx.fragment(id=f"foo{TEST_VAR}bar"),
[FORMATTED_TEST_VAR],
id="fstring-id",
),
pytest.param(
rx.fragment(key=TEST_VAR),
[TEST_VAR],
id="direct-key",
),
pytest.param(
rx.fragment(key=f"foo{TEST_VAR}bar"),
[FORMATTED_TEST_VAR],
id="fstring-key",
),
pytest.param(
rx.fragment(class_name=TEST_VAR),
[TEST_VAR],
id="direct-class_name",
),
pytest.param(
rx.fragment(class_name=f"foo{TEST_VAR}bar"),
[FORMATTED_TEST_VAR],
id="fstring-class_name",
),
pytest.param(
rx.fragment(special_props={TEST_VAR}),
[TEST_VAR],
id="direct-special_props",
),
pytest.param(
rx.fragment(special_props={Var.create(f"foo{TEST_VAR}bar")}),
[FORMATTED_TEST_VAR],
id="fstring-special_props",
),
pytest.param(
# custom_attrs cannot accept a Var directly as a value
rx.fragment(custom_attrs={"href": f"{TEST_VAR}"}),
[TEST_VAR],
id="fstring-custom_attrs-nofmt",
),
pytest.param(
rx.fragment(custom_attrs={"href": f"foo{TEST_VAR}bar"}),
[FORMATTED_TEST_VAR],
id="fstring-custom_attrs",
),
pytest.param(
rx.fragment(background_color=TEST_VAR),
[STYLE_VAR],
id="direct-background_color",
),
pytest.param(
rx.fragment(background_color=f"foo{TEST_VAR}bar"),
[STYLE_VAR],
id="fstring-background_color",
),
pytest.param(
rx.fragment(style={"background_color": TEST_VAR}), # type: ignore
[STYLE_VAR],
id="direct-style-background_color",
),
pytest.param(
rx.fragment(style={"background_color": f"foo{TEST_VAR}bar"}), # type: ignore
[STYLE_VAR],
id="fstring-style-background_color",
),
pytest.param(
rx.fragment(on_click=EVENT_CHAIN_VAR), # type: ignore
[EVENT_CHAIN_VAR],
id="direct-event-chain",
),
pytest.param(
rx.fragment(on_click=EventState.handler),
[],
id="direct-event-handler",
),
pytest.param(
rx.fragment(on_click=EventState.handler2(TEST_VAR)), # type: ignore
[ARG_VAR, TEST_VAR],
id="direct-event-handler-arg",
),
pytest.param(
rx.fragment(on_click=EventState.handler2(EventState.v)), # type: ignore
[ARG_VAR, EventState.v],
id="direct-event-handler-arg2",
),
pytest.param(
rx.fragment(on_click=lambda: EventState.handler2(TEST_VAR)), # type: ignore
[ARG_VAR, TEST_VAR],
id="direct-event-handler-lambda",
),
),
)
def test_get_vars(component, exp_vars):
comp_vars = sorted(component._get_vars(), key=lambda v: v._var_name)
assert len(comp_vars) == len(exp_vars)
for comp_var, exp_var in zip(
comp_vars,
sorted(exp_vars, key=lambda v: v._var_name),
):
assert comp_var.equals(exp_var)

View File

@ -104,7 +104,7 @@ async def test_preprocess(
app=app, event=request.getfixturevalue(event_fixture), state=state
)
assert isinstance(update, StateUpdate)
assert update.delta == {state.get_name(): state.dict()}
assert update.delta == state.dict()
events = update.events
assert len(events) == 2
@ -133,7 +133,7 @@ async def test_preprocess_multiple_load_events(hydrate_middleware, event1):
update = await hydrate_middleware.preprocess(app=app, event=event1, state=state)
assert isinstance(update, StateUpdate)
assert update.delta == {"test_state": state.dict()}
assert update.delta == state.dict()
assert len(update.events) == 3
# Apply the events.
@ -163,7 +163,7 @@ async def test_preprocess_no_events(hydrate_middleware, event1):
state=state,
)
assert isinstance(update, StateUpdate)
assert update.delta == {"test_state": state.dict()}
assert update.delta == state.dict()
assert len(update.events) == 1
assert isinstance(update, StateUpdate)

View File

@ -769,9 +769,7 @@ async def test_upload_file(tmp_path, state, delta, token: str):
)
current_state = await app.state_manager.get_state(token)
state_dict = current_state.dict()
for substate in state.get_full_name().split(".")[1:]:
state_dict = state_dict[substate]
state_dict = current_state.dict()[state.get_full_name()]
assert state_dict["img_list"] == [
"image1.jpg",
"image2.jpg",

View File

@ -324,11 +324,17 @@ def test_dict(test_state):
Args:
test_state: A state.
"""
substates = {"child_state", "child_state2"}
assert set(test_state.dict().keys()) == set(test_state.vars.keys()) | substates
assert (
set(test_state.dict(include_computed=False).keys())
== set(test_state.base_vars) | substates
substates = {
"test_state",
"test_state.child_state",
"test_state.child_state.grandchild_state",
"test_state.child_state2",
}
test_state_dict = test_state.dict()
assert set(test_state_dict) == substates
assert set(test_state_dict[test_state.get_name()]) == set(test_state.vars)
assert set(test_state.dict(include_computed=False)[test_state.get_name()]) == set(
test_state.base_vars
)
@ -1081,9 +1087,9 @@ def test_computed_var_cached():
return self.v
cs = ComputedState()
assert cs.dict()["v"] == 0
assert cs.dict()[cs.get_full_name()]["v"] == 0
assert comp_v_calls == 1
assert cs.dict()["comp_v"] == 0
assert cs.dict()[cs.get_full_name()]["comp_v"] == 0
assert comp_v_calls == 1
assert cs.comp_v == 0
assert comp_v_calls == 1
@ -1156,24 +1162,27 @@ def test_computed_var_depends_on_parent_non_cached():
assert ps.dirty_vars == set()
assert cs.dirty_vars == set()
assert ps.dict() == {
cs.get_name(): {"dep_v": 2},
dict1 = ps.dict()
assert dict1[ps.get_full_name()] == {
"no_cache_v": 1,
CompileVars.IS_HYDRATED: False,
"router": formatted_router,
}
assert ps.dict() == {
cs.get_name(): {"dep_v": 4},
assert dict1[cs.get_full_name()] == {"dep_v": 2}
dict2 = ps.dict()
assert dict2[ps.get_full_name()] == {
"no_cache_v": 3,
CompileVars.IS_HYDRATED: False,
"router": formatted_router,
}
assert ps.dict() == {
cs.get_name(): {"dep_v": 6},
assert dict2[cs.get_full_name()] == {"dep_v": 4}
dict3 = ps.dict()
assert dict3[ps.get_full_name()] == {
"no_cache_v": 5,
CompileVars.IS_HYDRATED: False,
"router": formatted_router,
}
assert dict3[cs.get_full_name()] == {"dep_v": 6}
assert counter == 6
@ -2201,13 +2210,13 @@ def test_json_dumps_with_mutables():
items: List[Foo] = [Foo()]
dict_val = MutableContainsBase().dict()
assert isinstance(dict_val["items"][0], dict)
assert isinstance(dict_val[MutableContainsBase.get_full_name()]["items"][0], dict)
val = json_dumps(dict_val)
f_items = '[{"tags": ["123", "456"]}]'
f_formatted_router = str(formatted_router).replace("'", '"')
assert (
val
== f'{{"is_hydrated": false, "items": {f_items}, "router": {f_formatted_router}}}'
== f'{{"{MutableContainsBase.get_full_name()}": {{"is_hydrated": false, "items": {f_items}, "router": {f_formatted_router}}}}}'
)

View File

@ -22,7 +22,8 @@ def test_convert(style_dict, expected):
style_dict: The style to check.
expected: The expected formatted style.
"""
assert style.convert(style_dict) == expected
converted_dict, _var_data = style.convert(style_dict)
assert converted_dict == expected
@pytest.mark.parametrize(

View File

@ -7,18 +7,20 @@ from pandas import DataFrame
from reflex.base import Base
from reflex.state import State
from reflex.utils.imports import ImportVar
from reflex.vars import (
BaseVar,
ComputedVar,
ImportVar,
Var,
)
test_vars = [
BaseVar(_var_name="prop1", _var_type=int),
BaseVar(_var_name="key", _var_type=str),
BaseVar(_var_name="value", _var_type=str, _var_state="state"),
BaseVar(_var_name="local", _var_type=str, _var_state="state", _var_is_local=True),
BaseVar(_var_name="value", _var_type=str)._var_set_state("state"),
BaseVar(_var_name="local", _var_type=str, _var_is_local=True)._var_set_state(
"state"
),
BaseVar(_var_name="local2", _var_type=str, _var_is_local=True),
]
@ -263,7 +265,7 @@ def test_basic_operations(TestObj):
assert str(v([1, 2, 3])[v(0)]) == "{[1, 2, 3].at(0)}"
assert str(v({"a": 1, "b": 2})["a"]) == '{{"a": 1, "b": 2}["a"]}'
assert (
str(BaseVar(_var_name="foo", _var_state="state", _var_type=TestObj).bar)
str(BaseVar(_var_name="foo", _var_type=TestObj)._var_set_state("state").bar)
== "{state.foo.bar}"
)
assert str(abs(v(1))) == "{Math.abs(1)}"
@ -274,7 +276,7 @@ def test_basic_operations(TestObj):
assert str(v([1, 2, 3]).reverse()) == "{[...[1, 2, 3]].reverse()}"
assert str(v(["1", "2", "3"]).reverse()) == '{[...["1", "2", "3"]].reverse()}'
assert (
str(BaseVar(_var_name="foo", _var_state="state", _var_type=list).reverse())
str(BaseVar(_var_name="foo", _var_type=list)._var_set_state("state").reverse())
== "{[...state.foo].reverse()}"
)
assert (
@ -288,11 +290,14 @@ def test_basic_operations(TestObj):
[
(v([1, 2, 3]), "[1, 2, 3]"),
(v(["1", "2", "3"]), '["1", "2", "3"]'),
(BaseVar(_var_name="foo", _var_state="state", _var_type=list), "state.foo"),
(BaseVar(_var_name="foo", _var_type=list)._var_set_state("state"), "state.foo"),
(BaseVar(_var_name="foo", _var_type=list), "foo"),
(v((1, 2, 3)), "[1, 2, 3]"),
(v(("1", "2", "3")), '["1", "2", "3"]'),
(BaseVar(_var_name="foo", _var_state="state", _var_type=tuple), "state.foo"),
(
BaseVar(_var_name="foo", _var_type=tuple)._var_set_state("state"),
"state.foo",
),
(BaseVar(_var_name="foo", _var_type=tuple), "foo"),
],
)
@ -301,7 +306,7 @@ def test_list_tuple_contains(var, expected):
assert str(var.contains("1")) == f'{{{expected}.includes("1")}}'
assert str(var.contains(v(1))) == f"{{{expected}.includes(1)}}"
assert str(var.contains(v("1"))) == f'{{{expected}.includes("1")}}'
other_state_var = BaseVar(_var_name="other", _var_state="state", _var_type=str)
other_state_var = BaseVar(_var_name="other", _var_type=str)._var_set_state("state")
other_var = BaseVar(_var_name="other", _var_type=str)
assert str(var.contains(other_state_var)) == f"{{{expected}.includes(state.other)}}"
assert str(var.contains(other_var)) == f"{{{expected}.includes(other)}}"
@ -311,14 +316,14 @@ def test_list_tuple_contains(var, expected):
"var, expected",
[
(v("123"), json.dumps("123")),
(BaseVar(_var_name="foo", _var_state="state", _var_type=str), "state.foo"),
(BaseVar(_var_name="foo", _var_type=str)._var_set_state("state"), "state.foo"),
(BaseVar(_var_name="foo", _var_type=str), "foo"),
],
)
def test_str_contains(var, expected):
assert str(var.contains("1")) == f'{{{expected}.includes("1")}}'
assert str(var.contains(v("1"))) == f'{{{expected}.includes("1")}}'
other_state_var = BaseVar(_var_name="other", _var_state="state", _var_type=str)
other_state_var = BaseVar(_var_name="other", _var_type=str)._var_set_state("state")
other_var = BaseVar(_var_name="other", _var_type=str)
assert str(var.contains(other_state_var)) == f"{{{expected}.includes(state.other)}}"
assert str(var.contains(other_var)) == f"{{{expected}.includes(other)}}"
@ -328,7 +333,7 @@ def test_str_contains(var, expected):
"var, expected",
[
(v({"a": 1, "b": 2}), '{"a": 1, "b": 2}'),
(BaseVar(_var_name="foo", _var_state="state", _var_type=dict), "state.foo"),
(BaseVar(_var_name="foo", _var_type=dict)._var_set_state("state"), "state.foo"),
(BaseVar(_var_name="foo", _var_type=dict), "foo"),
],
)
@ -337,7 +342,7 @@ def test_dict_contains(var, expected):
assert str(var.contains("1")) == f'{{{expected}.hasOwnProperty("1")}}'
assert str(var.contains(v(1))) == f"{{{expected}.hasOwnProperty(1)}}"
assert str(var.contains(v("1"))) == f'{{{expected}.hasOwnProperty("1")}}'
other_state_var = BaseVar(_var_name="other", _var_state="state", _var_type=str)
other_state_var = BaseVar(_var_name="other", _var_type=str)._var_set_state("state")
other_var = BaseVar(_var_name="other", _var_type=str)
assert (
str(var.contains(other_state_var))
@ -548,10 +553,10 @@ def test_var_unsupported_indexing_dicts(var, index):
"fixture,full_name",
[
("ParentState", "parent_state.var_without_annotation"),
("ChildState", "parent_state.child_state.var_without_annotation"),
("ChildState", "parent_state__child_state.var_without_annotation"),
(
"GrandChildState",
"parent_state.child_state.grand_child_state.var_without_annotation",
"parent_state__child_state__grand_child_state.var_without_annotation",
),
("StateWithAnyVar", "state_with_any_var.var_without_annotation"),
],
@ -630,8 +635,8 @@ def test_import_var(import_var, expected):
[
(f"{BaseVar(_var_name='var', _var_type=str)}", "${var}"),
(
f"testing f-string with {BaseVar(_var_name='myvar', _var_state='state', _var_type=int)}",
"testing f-string with ${state.myvar}",
f"testing f-string with {BaseVar(_var_name='myvar', _var_type=int)._var_set_state('state')}",
'testing f-string with $<reflex.Var>{"state": "state", "imports": {"/utils/context": [{"tag": "StateContexts", "is_default": false, "alias": null, "install": true, "render": true}], "react": [{"tag": "useContext", "is_default": false, "alias": null, "install": true, "render": true}]}, "hooks": ["const state = useContext(StateContexts.state)"]}</reflex.Var>{state.myvar}',
),
(
f"testing local f-string {BaseVar(_var_name='x', _var_is_local=True, _var_type=str)}",
@ -643,6 +648,35 @@ def test_fstrings(out, expected):
assert out == expected
@pytest.mark.parametrize(
("value", "expect_state"),
[
([1], ""),
({"a": 1}, ""),
([Var.create_safe(1)._var_set_state("foo")], "foo"),
({"a": Var.create_safe(1)._var_set_state("foo")}, "foo"),
],
)
def test_extract_state_from_container(value, expect_state):
"""Test that _var_state is extracted from containers containing BaseVar.
Args:
value: The value to create a var from.
expect_state: The expected state.
"""
assert Var.create_safe(value)._var_state == expect_state
def test_fstring_roundtrip():
"""Test that f-string roundtrip carries state."""
var = BaseVar.create_safe("var")._var_set_state("state")
rt_var = Var.create_safe(f"{var}")
assert var._var_state == rt_var._var_state
assert var._var_full_name_needs_state_prefix
assert not rt_var._var_full_name_needs_state_prefix
assert rt_var._var_name == var._var_full_name
@pytest.mark.parametrize(
"var",
[

View File

@ -8,7 +8,13 @@ from reflex.event import EventChain, EventHandler, EventSpec, FrontendEvent
from reflex.style import Style
from reflex.utils import format
from reflex.vars import BaseVar, Var
from tests.test_state import ChildState, DateTimeState, GrandchildState, TestState
from tests.test_state import (
ChildState,
ChildState2,
DateTimeState,
GrandchildState,
TestState,
)
def mock_event(arg):
@ -349,7 +355,6 @@ def test_format_cond(condition: str, true_value: str, false_value: str, expected
BaseVar(
_var_name="_",
_var_type=Any,
_var_state="",
_var_is_local=True,
_var_is_string=False,
),
@ -515,40 +520,44 @@ formatted_router = {
(
TestState().dict(), # type: ignore
{
"array": [1, 2, 3.14],
"child_state": {
TestState.get_full_name(): {
"array": [1, 2, 3.14],
"complex": {
1: {"prop1": 42, "prop2": "hello"},
2: {"prop1": 42, "prop2": "hello"},
},
"dt": "1989-11-09 18:53:00+01:00",
"fig": [],
"is_hydrated": False,
"key": "",
"map_key": "a",
"mapping": {"a": [1, 2, 3], "b": [4, 5, 6]},
"num1": 0,
"num2": 3.14,
"obj": {"prop1": 42, "prop2": "hello"},
"sum": 3.14,
"upper": "",
"router": formatted_router,
},
ChildState.get_full_name(): {
"count": 23,
"grandchild_state": {"value2": ""},
"value": "",
},
"child_state2": {"value": ""},
"complex": {
1: {"prop1": 42, "prop2": "hello"},
2: {"prop1": 42, "prop2": "hello"},
},
"dt": "1989-11-09 18:53:00+01:00",
"fig": [],
"is_hydrated": False,
"key": "",
"map_key": "a",
"mapping": {"a": [1, 2, 3], "b": [4, 5, 6]},
"num1": 0,
"num2": 3.14,
"obj": {"prop1": 42, "prop2": "hello"},
"sum": 3.14,
"upper": "",
"router": formatted_router,
ChildState2.get_full_name(): {"value": ""},
GrandchildState.get_full_name(): {"value2": ""},
},
),
(
DateTimeState().dict(),
{
"d": "1989-11-09",
"dt": "1989-11-09 18:53:00+01:00",
"is_hydrated": False,
"t": "18:53:00+01:00",
"td": "11 days, 0:11:00",
"router": formatted_router,
DateTimeState.get_full_name(): {
"d": "1989-11-09",
"dt": "1989-11-09 18:53:00+01:00",
"is_hydrated": False,
"t": "18:53:00+01:00",
"td": "11 days, 0:11:00",
"router": formatted_router,
},
},
),
],