[REF-889] useContext per substate (#2149)
This commit is contained in:
parent
e9437ad941
commit
1603144c7d
@ -28,6 +28,7 @@ def VarOperations():
|
||||
str_var4: str = "a long string"
|
||||
dict1: dict = {1: 2}
|
||||
dict2: dict = {3: 4}
|
||||
html_str: str = "<div>hello</div>"
|
||||
|
||||
app = rx.App(state=VarOperationState)
|
||||
|
||||
@ -522,6 +523,19 @@ def VarOperations():
|
||||
rx.text(VarOperationState.str_var4.split(" ").to_string(), id="str_split"),
|
||||
rx.text(VarOperationState.list3.join(""), id="list_join"),
|
||||
rx.text(VarOperationState.list3.join(","), id="list_join_comma"),
|
||||
# Index from an op var
|
||||
rx.text(
|
||||
VarOperationState.list3[VarOperationState.int_var1 % 3],
|
||||
id="list_index_mod",
|
||||
),
|
||||
rx.html(
|
||||
VarOperationState.html_str,
|
||||
id="html_str",
|
||||
),
|
||||
rx.highlight(
|
||||
"second",
|
||||
query=[VarOperationState.str_var2],
|
||||
),
|
||||
rx.text(rx.Var.range(2, 5).join(","), id="list_join_range1"),
|
||||
rx.text(rx.Var.range(2, 10, 2).join(","), id="list_join_range2"),
|
||||
rx.text(rx.Var.range(5, 0, -1).join(","), id="list_join_range3"),
|
||||
@ -713,7 +727,14 @@ def test_var_operations(driver, var_operations: AppHarness):
|
||||
("dict_eq_dict", "false"),
|
||||
("dict_neq_dict", "true"),
|
||||
("dict_contains", "true"),
|
||||
# index from an op var
|
||||
("list_index_mod", "second"),
|
||||
# html component with var
|
||||
("html_str", "hello"),
|
||||
]
|
||||
|
||||
for tag, expected in tests:
|
||||
assert driver.find_element(By.ID, tag).text == expected
|
||||
|
||||
# Highlight component with var query (does not plumb ID)
|
||||
assert driver.find_element(By.TAG_NAME, "mark").text == "second"
|
||||
|
@ -1,7 +1,7 @@
|
||||
{% extends "web/pages/base_page.js.jinja2" %}
|
||||
|
||||
{% block declaration %}
|
||||
import { EventLoopProvider } from "/utils/context.js";
|
||||
import { EventLoopProvider, StateProvider } from "/utils/context.js";
|
||||
import { ThemeProvider } from 'next-themes'
|
||||
|
||||
{% for custom_code in custom_codes %}
|
||||
@ -25,12 +25,14 @@ export default function MyApp({ Component, pageProps }) {
|
||||
return (
|
||||
<ThemeProvider defaultTheme="light" storageKey="chakra-ui-color-mode" attribute="class">
|
||||
<AppWrap>
|
||||
<EventLoopProvider>
|
||||
<Component {...pageProps} />
|
||||
</EventLoopProvider>
|
||||
<StateProvider>
|
||||
<EventLoopProvider>
|
||||
<Component {...pageProps} />
|
||||
</EventLoopProvider>
|
||||
</StateProvider>
|
||||
</AppWrap>
|
||||
</ThemeProvider>
|
||||
);
|
||||
}
|
||||
|
||||
{% endblock %}
|
||||
{% endblock %}
|
||||
|
@ -8,32 +8,6 @@
|
||||
|
||||
{% block export %}
|
||||
export default function Component() {
|
||||
{% if state_name %}
|
||||
const {{state_name}} = useContext(StateContext)
|
||||
{% endif %}
|
||||
const {{const.router}} = useRouter()
|
||||
const [ {{const.color_mode}}, {{const.toggle_color_mode}} ] = useContext(ColorModeContext)
|
||||
const focusRef = useRef();
|
||||
|
||||
// Main event loop.
|
||||
const [addEvents, connectError] = useContext(EventLoopContext)
|
||||
|
||||
// Set focus to the specified element.
|
||||
useEffect(() => {
|
||||
if (focusRef.current) {
|
||||
focusRef.current.focus();
|
||||
}
|
||||
})
|
||||
|
||||
// Route after the initial page hydration.
|
||||
useEffect(() => {
|
||||
const change_complete = () => addEvents(initialEvents())
|
||||
{{const.router}}.events.on('routeChangeComplete', change_complete)
|
||||
return () => {
|
||||
{{const.router}}.events.off('routeChangeComplete', change_complete)
|
||||
}
|
||||
}, [{{const.router}}])
|
||||
|
||||
{% for hook in hooks %}
|
||||
{{ hook }}
|
||||
{% endfor %}
|
||||
|
@ -1,5 +1,5 @@
|
||||
import { createContext, useState } from "react"
|
||||
import { Event, hydrateClientStorage, useEventLoop } from "/utils/state.js"
|
||||
import { createContext, useContext, useMemo, useReducer, useState } from "react"
|
||||
import { applyDelta, Event, hydrateClientStorage, useEventLoop } from "/utils/state.js"
|
||||
|
||||
{% if initial_state %}
|
||||
export const initialState = {{ initial_state|json_dumps }}
|
||||
@ -8,7 +8,12 @@ export const initialState = {}
|
||||
{% endif %}
|
||||
|
||||
export const ColorModeContext = createContext(null);
|
||||
export const StateContext = createContext(null);
|
||||
export const DispatchContext = createContext(null);
|
||||
export const StateContexts = {
|
||||
{% for state_name in initial_state %}
|
||||
{{state_name|var_name}}: createContext(null),
|
||||
{% endfor %}
|
||||
}
|
||||
export const EventLoopContext = createContext(null);
|
||||
{% if client_storage %}
|
||||
export const clientStorage = {{ client_storage|json_dumps }}
|
||||
@ -27,16 +32,40 @@ export const initialEvents = () => []
|
||||
export const isDevMode = {{ is_dev_mode|json_dumps }}
|
||||
|
||||
export function EventLoopProvider({ children }) {
|
||||
const [state, addEvents, connectError] = useEventLoop(
|
||||
initialState,
|
||||
const dispatch = useContext(DispatchContext)
|
||||
const [addEvents, connectError] = useEventLoop(
|
||||
dispatch,
|
||||
initialEvents,
|
||||
clientStorage,
|
||||
)
|
||||
return (
|
||||
<EventLoopContext.Provider value={[addEvents, connectError]}>
|
||||
<StateContext.Provider value={state}>
|
||||
{children}
|
||||
</StateContext.Provider>
|
||||
{children}
|
||||
</EventLoopContext.Provider>
|
||||
)
|
||||
}
|
||||
}
|
||||
|
||||
export function StateProvider({ children }) {
|
||||
{% for state_name in initial_state %}
|
||||
const [{{state_name|var_name}}, dispatch_{{state_name|var_name}}] = useReducer(applyDelta, initialState["{{state_name}}"])
|
||||
{% endfor %}
|
||||
const dispatchers = useMemo(() => {
|
||||
return {
|
||||
{% for state_name in initial_state %}
|
||||
"{{state_name}}": dispatch_{{state_name|var_name}},
|
||||
{% endfor %}
|
||||
}
|
||||
}, [])
|
||||
|
||||
return (
|
||||
{% for state_name in initial_state %}
|
||||
<StateContexts.{{state_name|var_name}}.Provider value={ {{state_name|var_name}} }>
|
||||
{% endfor %}
|
||||
<DispatchContext.Provider value={dispatchers}>
|
||||
{children}
|
||||
</DispatchContext.Provider>
|
||||
{% for state_name in initial_state|reverse %}
|
||||
</StateContexts.{{state_name|var_name}}.Provider>
|
||||
{% endfor %}
|
||||
)
|
||||
}
|
||||
|
@ -6,7 +6,7 @@ import env from "env.json";
|
||||
import Cookies from "universal-cookie";
|
||||
import { useEffect, useReducer, useRef, useState } from "react";
|
||||
import Router, { useRouter } from "next/router";
|
||||
import { initialEvents } from "utils/context.js"
|
||||
import { initialEvents, initialState } from "utils/context.js"
|
||||
|
||||
// Endpoint URLs.
|
||||
const EVENTURL = env.EVENT
|
||||
@ -100,37 +100,10 @@ export const getEventURL = () => {
|
||||
* @param delta The delta to apply.
|
||||
*/
|
||||
export const applyDelta = (state, delta) => {
|
||||
const new_state = { ...state }
|
||||
for (const substate in delta) {
|
||||
let s = new_state;
|
||||
const path = substate.split(".").slice(1);
|
||||
while (path.length > 0) {
|
||||
s = s[path.shift()];
|
||||
}
|
||||
for (const key in delta[substate]) {
|
||||
s[key] = delta[substate][key];
|
||||
}
|
||||
}
|
||||
return new_state
|
||||
return { ...state, ...delta }
|
||||
};
|
||||
|
||||
|
||||
/**
|
||||
* Get all local storage items in a key-value object.
|
||||
* @returns object of items in local storage.
|
||||
*/
|
||||
export const getAllLocalStorageItems = () => {
|
||||
var localStorageItems = {};
|
||||
|
||||
for (var i = 0, len = localStorage.length; i < len; i++) {
|
||||
var key = localStorage.key(i);
|
||||
localStorageItems[key] = localStorage.getItem(key);
|
||||
}
|
||||
|
||||
return localStorageItems;
|
||||
}
|
||||
|
||||
|
||||
/**
|
||||
* Handle frontend event or send the event to the backend via Websocket.
|
||||
* @param event The event to send.
|
||||
@ -346,7 +319,9 @@ export const connect = async (
|
||||
// On each received message, queue the updates and events.
|
||||
socket.current.on("event", message => {
|
||||
const update = JSON5.parse(message)
|
||||
dispatch(update.delta)
|
||||
for (const substate in update.delta) {
|
||||
dispatch[substate](update.delta[substate])
|
||||
}
|
||||
applyClientStorageDelta(client_storage, update.delta)
|
||||
event_processing = !update.final
|
||||
if (update.events) {
|
||||
@ -524,23 +499,21 @@ const applyClientStorageDelta = (client_storage, delta) => {
|
||||
|
||||
/**
|
||||
* Establish websocket event loop for a NextJS page.
|
||||
* @param initial_state The initial app state.
|
||||
* @param initial_events Function that returns the initial app events.
|
||||
* @param dispatch The reducer dispatch function to update state.
|
||||
* @param initial_events The initial app events.
|
||||
* @param client_storage The client storage object from context.js
|
||||
*
|
||||
* @returns [state, addEvents, connectError] -
|
||||
* state is a reactive dict,
|
||||
* @returns [addEvents, connectError] -
|
||||
* addEvents is used to queue an event, and
|
||||
* connectError is a reactive js error from the websocket connection (or null if connected).
|
||||
*/
|
||||
export const useEventLoop = (
|
||||
initial_state = {},
|
||||
dispatch,
|
||||
initial_events = () => [],
|
||||
client_storage = {},
|
||||
) => {
|
||||
const socket = useRef(null)
|
||||
const router = useRouter()
|
||||
const [state, dispatch] = useReducer(applyDelta, initial_state)
|
||||
const [connectError, setConnectError] = useState(null)
|
||||
|
||||
// Function to add new events to the event queue.
|
||||
@ -570,7 +543,7 @@ export const useEventLoop = (
|
||||
return;
|
||||
}
|
||||
// only use websockets if state is present
|
||||
if (Object.keys(state).length > 0) {
|
||||
if (Object.keys(initialState).length > 0) {
|
||||
// Initialize the websocket connection.
|
||||
if (!socket.current) {
|
||||
connect(socket, dispatch, ['websocket', 'polling'], setConnectError, client_storage)
|
||||
@ -583,7 +556,17 @@ export const useEventLoop = (
|
||||
})()
|
||||
}
|
||||
})
|
||||
return [state, addEvents, connectError]
|
||||
|
||||
// Route after the initial page hydration.
|
||||
useEffect(() => {
|
||||
const change_complete = () => addEvents(initial_events())
|
||||
router.events.on('routeChangeComplete', change_complete)
|
||||
return () => {
|
||||
router.events.off('routeChangeComplete', change_complete)
|
||||
}
|
||||
}, [router])
|
||||
|
||||
return [addEvents, connectError]
|
||||
}
|
||||
|
||||
/***
|
||||
|
@ -63,7 +63,7 @@ from reflex.state import (
|
||||
StateUpdate,
|
||||
)
|
||||
from reflex.utils import console, format, prerequisites, types
|
||||
from reflex.vars import ImportVar
|
||||
from reflex.utils.imports import ImportVar
|
||||
|
||||
# Define custom types.
|
||||
ComponentCallable = Callable[[], Component]
|
||||
|
@ -10,40 +10,10 @@ from reflex.compiler import templates, utils
|
||||
from reflex.components.component import Component, ComponentStyle, CustomComponent
|
||||
from reflex.config import get_config
|
||||
from reflex.state import State
|
||||
from reflex.utils import imports
|
||||
from reflex.vars import ImportVar
|
||||
from reflex.utils.imports import ImportDict, ImportVar
|
||||
|
||||
# Imports to be included in every Reflex app.
|
||||
DEFAULT_IMPORTS: imports.ImportDict = {
|
||||
"react": [
|
||||
ImportVar(tag="Fragment"),
|
||||
ImportVar(tag="useEffect"),
|
||||
ImportVar(tag="useRef"),
|
||||
ImportVar(tag="useState"),
|
||||
ImportVar(tag="useContext"),
|
||||
],
|
||||
"next/router": [ImportVar(tag="useRouter")],
|
||||
f"/{constants.Dirs.STATE_PATH}": [
|
||||
ImportVar(tag="uploadFiles"),
|
||||
ImportVar(tag="Event"),
|
||||
ImportVar(tag="isTrue"),
|
||||
ImportVar(tag="spreadArraysOrObjects"),
|
||||
ImportVar(tag="preventDefault"),
|
||||
ImportVar(tag="refs"),
|
||||
ImportVar(tag="getRefValue"),
|
||||
ImportVar(tag="getRefValues"),
|
||||
ImportVar(tag="getAllLocalStorageItems"),
|
||||
ImportVar(tag="useEventLoop"),
|
||||
],
|
||||
"/utils/context.js": [
|
||||
ImportVar(tag="EventLoopContext"),
|
||||
ImportVar(tag="initialEvents"),
|
||||
ImportVar(tag="StateContext"),
|
||||
ImportVar(tag="ColorModeContext"),
|
||||
],
|
||||
"/utils/helpers/range.js": [
|
||||
ImportVar(tag="range", is_default=True),
|
||||
],
|
||||
DEFAULT_IMPORTS: ImportDict = {
|
||||
"": [ImportVar(tag="focus-visible/dist/focus-visible", install=False)],
|
||||
}
|
||||
|
||||
|
@ -3,7 +3,7 @@
|
||||
from jinja2 import Environment, FileSystemLoader, Template
|
||||
|
||||
from reflex import constants
|
||||
from reflex.utils.format import json_dumps
|
||||
from reflex.utils.format import format_state_name, json_dumps
|
||||
|
||||
|
||||
class ReflexJinjaEnvironment(Environment):
|
||||
@ -19,6 +19,7 @@ class ReflexJinjaEnvironment(Environment):
|
||||
)
|
||||
self.filters["json_dumps"] = json_dumps
|
||||
self.filters["react_setter"] = lambda state: f"set{state.capitalize()}"
|
||||
self.filters["var_name"] = format_state_name
|
||||
self.loader = FileSystemLoader(constants.Templates.Dirs.JINJA_TEMPLATE)
|
||||
self.globals["const"] = {
|
||||
"socket": constants.CompileVars.SOCKET,
|
||||
|
@ -24,13 +24,12 @@ from reflex.components.component import Component, ComponentStyle, CustomCompone
|
||||
from reflex.state import Cookie, LocalStorage, State
|
||||
from reflex.style import Style
|
||||
from reflex.utils import console, format, imports, path_ops
|
||||
from reflex.vars import ImportVar
|
||||
|
||||
# To re-export this function.
|
||||
merge_imports = imports.merge_imports
|
||||
|
||||
|
||||
def compile_import_statement(fields: list[ImportVar]) -> tuple[str, list[str]]:
|
||||
def compile_import_statement(fields: list[imports.ImportVar]) -> tuple[str, list[str]]:
|
||||
"""Compile an import statement.
|
||||
|
||||
Args:
|
||||
@ -343,7 +342,9 @@ def get_context_path() -> str:
|
||||
Returns:
|
||||
The path of the context module.
|
||||
"""
|
||||
return os.path.join(constants.Dirs.WEB_UTILS, "context" + constants.Ext.JS)
|
||||
return os.path.join(
|
||||
constants.Dirs.WEB, constants.Dirs.CONTEXTS_PATH + constants.Ext.JS
|
||||
)
|
||||
|
||||
|
||||
def get_components_path() -> str:
|
||||
|
@ -1,7 +1,7 @@
|
||||
"""A bare component."""
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import Any
|
||||
from typing import Any, Iterator
|
||||
|
||||
from reflex.components.component import Component
|
||||
from reflex.components.tags import Tag
|
||||
@ -24,7 +24,21 @@ class Bare(Component):
|
||||
Returns:
|
||||
The component.
|
||||
"""
|
||||
return cls(contents=str(contents)) # type: ignore
|
||||
if isinstance(contents, Var) and contents._var_data:
|
||||
contents = contents.to(str)
|
||||
else:
|
||||
contents = str(contents)
|
||||
return cls(contents=contents) # type: ignore
|
||||
|
||||
def _render(self) -> Tag:
|
||||
return Tagless(contents=str(self.contents))
|
||||
|
||||
def _get_vars(self) -> Iterator[Var]:
|
||||
"""Walk all Vars used in this component.
|
||||
|
||||
Yields:
|
||||
The contents if it is a Var, otherwise nothing.
|
||||
"""
|
||||
if isinstance(self.contents, Var):
|
||||
# Fast path for Bare text components.
|
||||
yield self.contents
|
||||
|
@ -5,11 +5,11 @@ from __future__ import annotations
|
||||
import typing
|
||||
from abc import ABC
|
||||
from functools import lru_cache, wraps
|
||||
from typing import Any, Callable, Dict, List, Optional, Set, Type, Union
|
||||
from typing import Any, Callable, Dict, Iterator, List, Optional, Set, Type, Union
|
||||
|
||||
from reflex.base import Base
|
||||
from reflex.components.tags import Tag
|
||||
from reflex.constants import Dirs, EventTriggers
|
||||
from reflex.constants import Dirs, EventTriggers, Hooks, Imports
|
||||
from reflex.event import (
|
||||
EventChain,
|
||||
EventHandler,
|
||||
@ -20,8 +20,9 @@ from reflex.event import (
|
||||
)
|
||||
from reflex.style import Style
|
||||
from reflex.utils import console, format, imports, types
|
||||
from reflex.utils.imports import ImportVar
|
||||
from reflex.utils.serializers import serializer
|
||||
from reflex.vars import BaseVar, ImportVar, Var
|
||||
from reflex.vars import BaseVar, Var
|
||||
|
||||
|
||||
class Component(Base, ABC):
|
||||
@ -388,7 +389,11 @@ class Component(Base, ABC):
|
||||
props = props.copy()
|
||||
|
||||
props.update(
|
||||
self.event_triggers,
|
||||
**{
|
||||
trigger: handler
|
||||
for trigger, handler in self.event_triggers.items()
|
||||
if trigger not in {EventTriggers.ON_MOUNT, EventTriggers.ON_UNMOUNT}
|
||||
},
|
||||
key=self.key,
|
||||
id=self.id,
|
||||
class_name=self.class_name,
|
||||
@ -488,7 +493,7 @@ class Component(Base, ABC):
|
||||
"""
|
||||
if type(self) in style:
|
||||
# Extract the style for this component.
|
||||
component_style = Style(style[type(self)])
|
||||
component_style = style[type(self)]
|
||||
|
||||
# Only add style props that are not overridden.
|
||||
component_style = {
|
||||
@ -564,6 +569,78 @@ class Component(Base, ABC):
|
||||
if self._valid_children:
|
||||
validate_valid_child(name)
|
||||
|
||||
@staticmethod
|
||||
def _get_vars_from_event_triggers(
|
||||
event_triggers: dict[str, EventChain | Var],
|
||||
) -> Iterator[tuple[str, list[Var]]]:
|
||||
"""Get the Vars associated with each event trigger.
|
||||
|
||||
Args:
|
||||
event_triggers: The event triggers from the component instance.
|
||||
|
||||
Yields:
|
||||
tuple of (event_name, event_vars)
|
||||
"""
|
||||
for event_trigger, event in event_triggers.items():
|
||||
if isinstance(event, Var):
|
||||
yield event_trigger, [event]
|
||||
elif isinstance(event, EventChain):
|
||||
event_args = []
|
||||
for spec in event.events:
|
||||
for args in spec.args:
|
||||
event_args.extend(args)
|
||||
yield event_trigger, event_args
|
||||
|
||||
def _get_vars(self) -> list[Var]:
|
||||
"""Walk all Vars used in this component.
|
||||
|
||||
Returns:
|
||||
Each var referenced by the component (props, styles, event handlers).
|
||||
"""
|
||||
vars = getattr(self, "__vars", None)
|
||||
if vars is not None:
|
||||
return vars
|
||||
vars = self.__vars = []
|
||||
# Get Vars associated with event trigger arguments.
|
||||
for _, event_vars in self._get_vars_from_event_triggers(self.event_triggers):
|
||||
vars.extend(event_vars)
|
||||
|
||||
# Get Vars associated with component props.
|
||||
for prop in self.get_props():
|
||||
prop_var = getattr(self, prop)
|
||||
if isinstance(prop_var, Var):
|
||||
vars.append(prop_var)
|
||||
|
||||
# Style keeps track of its own VarData instance, so embed in a temp Var that is yielded.
|
||||
if self.style:
|
||||
vars.append(
|
||||
BaseVar(
|
||||
_var_name="style",
|
||||
_var_type=str,
|
||||
_var_data=self.style._var_data,
|
||||
)
|
||||
)
|
||||
|
||||
# Special props are always Var instances.
|
||||
vars.extend(self.special_props)
|
||||
|
||||
# Get Vars associated with common Component props.
|
||||
for comp_prop in (
|
||||
self.class_name,
|
||||
self.id,
|
||||
self.key,
|
||||
self.autofocus,
|
||||
*self.custom_attrs.values(),
|
||||
):
|
||||
if isinstance(comp_prop, Var):
|
||||
vars.append(comp_prop)
|
||||
elif isinstance(comp_prop, str):
|
||||
# Collapse VarData encoded in f-strings.
|
||||
var = Var.create_safe(comp_prop)
|
||||
if var._var_data is not None:
|
||||
vars.append(var)
|
||||
return vars
|
||||
|
||||
def _get_custom_code(self) -> str | None:
|
||||
"""Get custom code for the component.
|
||||
|
||||
@ -644,6 +721,33 @@ class Component(Base, ABC):
|
||||
dep: [ImportVar(tag=None, render=False)] for dep in self.lib_dependencies
|
||||
}
|
||||
|
||||
def _get_hooks_imports(self) -> imports.ImportDict:
|
||||
"""Get the imports required by certain hooks.
|
||||
|
||||
Returns:
|
||||
The imports required for all selected hooks.
|
||||
"""
|
||||
_imports = {}
|
||||
|
||||
if self._get_ref_hook():
|
||||
# Handle hooks needed for attaching react refs to DOM nodes.
|
||||
_imports.setdefault("react", set()).add(ImportVar(tag="useRef"))
|
||||
_imports.setdefault(f"/{Dirs.STATE_PATH}", set()).add(ImportVar(tag="refs"))
|
||||
|
||||
if self._get_mount_lifecycle_hook():
|
||||
# Handle hooks for `on_mount` / `on_unmount`.
|
||||
_imports.setdefault("react", set()).add(ImportVar(tag="useEffect"))
|
||||
|
||||
if self._get_special_hooks():
|
||||
# Handle additional internal hooks (autofocus, etc).
|
||||
_imports.setdefault("react", set()).update(
|
||||
{
|
||||
ImportVar(tag="useRef"),
|
||||
ImportVar(tag="useEffect"),
|
||||
},
|
||||
)
|
||||
return _imports
|
||||
|
||||
def _get_imports(self) -> imports.ImportDict:
|
||||
"""Get all the libraries and fields that are used by the component.
|
||||
|
||||
@ -651,13 +755,26 @@ class Component(Base, ABC):
|
||||
The imports needed by the component.
|
||||
"""
|
||||
_imports = {}
|
||||
|
||||
# Import this component's tag from the main library.
|
||||
if self.library is not None and self.tag is not None:
|
||||
_imports[self.library] = {self.import_var}
|
||||
|
||||
# Get static imports required for event processing.
|
||||
event_imports = Imports.EVENTS if self.event_triggers else {}
|
||||
|
||||
# Collect imports from Vars used directly by this component.
|
||||
var_imports = [
|
||||
var._var_data.imports for var in self._get_vars() if var._var_data
|
||||
]
|
||||
|
||||
return imports.merge_imports(
|
||||
*self._get_props_imports(),
|
||||
self._get_dependencies_imports(),
|
||||
self._get_hooks_imports(),
|
||||
_imports,
|
||||
event_imports,
|
||||
*var_imports,
|
||||
)
|
||||
|
||||
def get_imports(self) -> imports.ImportDict:
|
||||
@ -678,13 +795,13 @@ class Component(Base, ABC):
|
||||
"""
|
||||
# pop on_mount and on_unmount from event_triggers since these are handled by
|
||||
# hooks, not as actually props in the component
|
||||
on_mount = self.event_triggers.pop(EventTriggers.ON_MOUNT, None)
|
||||
on_unmount = self.event_triggers.pop(EventTriggers.ON_UNMOUNT, None)
|
||||
if on_mount:
|
||||
on_mount = self.event_triggers.get(EventTriggers.ON_MOUNT, None)
|
||||
on_unmount = self.event_triggers.get(EventTriggers.ON_UNMOUNT, None)
|
||||
if on_mount is not None:
|
||||
on_mount = format.format_event_chain(on_mount)
|
||||
if on_unmount:
|
||||
if on_unmount is not None:
|
||||
on_unmount = format.format_event_chain(on_unmount)
|
||||
if on_mount or on_unmount:
|
||||
if on_mount is not None or on_unmount is not None:
|
||||
return f"""
|
||||
useEffect(() => {{
|
||||
{on_mount or ""}
|
||||
@ -703,6 +820,47 @@ class Component(Base, ABC):
|
||||
if ref is not None:
|
||||
return f"const {ref} = useRef(null); refs['{ref}'] = {ref};"
|
||||
|
||||
def _get_vars_hooks(self) -> set[str]:
|
||||
"""Get the hooks required by vars referenced in this component.
|
||||
|
||||
Returns:
|
||||
The hooks for the vars.
|
||||
"""
|
||||
vars_hooks = set()
|
||||
for var in self._get_vars():
|
||||
if var._var_data:
|
||||
vars_hooks.update(var._var_data.hooks)
|
||||
return vars_hooks
|
||||
|
||||
def _get_events_hooks(self) -> set[str]:
|
||||
"""Get the hooks required by events referenced in this component.
|
||||
|
||||
Returns:
|
||||
The hooks for the events.
|
||||
"""
|
||||
if self.event_triggers:
|
||||
return {Hooks.EVENTS}
|
||||
return set()
|
||||
|
||||
def _get_special_hooks(self) -> set[str]:
|
||||
"""Get the hooks required by special actions referenced in this component.
|
||||
|
||||
Returns:
|
||||
The hooks for special actions.
|
||||
"""
|
||||
if self.autofocus:
|
||||
return {
|
||||
"""
|
||||
// Set focus to the specified element.
|
||||
const focusRef = useRef(null)
|
||||
useEffect(() => {
|
||||
if (focusRef.current) {
|
||||
focusRef.current.focus();
|
||||
}
|
||||
})""",
|
||||
}
|
||||
return set()
|
||||
|
||||
def _get_hooks_internal(self) -> Set[str]:
|
||||
"""Get the React hooks for this component managed by the framework.
|
||||
|
||||
@ -712,10 +870,15 @@ class Component(Base, ABC):
|
||||
Returns:
|
||||
Set of internally managed hooks.
|
||||
"""
|
||||
return set(
|
||||
hook
|
||||
for hook in [self._get_mount_lifecycle_hook(), self._get_ref_hook()]
|
||||
if hook
|
||||
return (
|
||||
set(
|
||||
hook
|
||||
for hook in [self._get_mount_lifecycle_hook(), self._get_ref_hook()]
|
||||
if hook
|
||||
)
|
||||
| self._get_vars_hooks()
|
||||
| self._get_events_hooks()
|
||||
| self._get_special_hooks()
|
||||
)
|
||||
|
||||
def _get_hooks(self) -> str | None:
|
||||
@ -1018,11 +1181,24 @@ class NoSSRComponent(Component):
|
||||
"""A dynamic component that is not rendered on the server."""
|
||||
|
||||
def _get_imports(self) -> imports.ImportDict:
|
||||
dynamic_import = {"next/dynamic": {ImportVar(tag="dynamic", is_default=True)}}
|
||||
"""Get the imports for the component.
|
||||
|
||||
Returns:
|
||||
The imports for dynamically importing the component at module load time.
|
||||
"""
|
||||
# Next.js dynamic import mechanism.
|
||||
dynamic_import = {"next/dynamic": [ImportVar(tag="dynamic", is_default=True)]}
|
||||
|
||||
# The normal imports for this component.
|
||||
_imports = super()._get_imports()
|
||||
|
||||
# Do NOT import the main library/tag statically.
|
||||
if self.library is not None:
|
||||
_imports[self.library] = [imports.ImportVar(tag=None, render=False)]
|
||||
|
||||
return imports.merge_imports(
|
||||
dynamic_import,
|
||||
{self.library: {ImportVar(tag=None, render=False)}},
|
||||
_imports,
|
||||
self._get_dependencies_imports(),
|
||||
)
|
||||
|
||||
|
@ -12,7 +12,8 @@ from reflex.components.media import Icon
|
||||
from reflex.event import set_clipboard
|
||||
from reflex.style import Style
|
||||
from reflex.utils import format, imports
|
||||
from reflex.vars import ImportVar, Var
|
||||
from reflex.utils.imports import ImportVar
|
||||
from reflex.vars import Var
|
||||
|
||||
LiteralCodeBlockTheme = Literal[
|
||||
"a11y-dark",
|
||||
|
@ -16,7 +16,8 @@ from reflex.components.media import Icon
|
||||
from reflex.event import set_clipboard
|
||||
from reflex.style import Style
|
||||
from reflex.utils import format, imports
|
||||
from reflex.vars import ImportVar, Var
|
||||
from reflex.utils.imports import ImportVar
|
||||
from reflex.vars import Var
|
||||
|
||||
LiteralCodeBlockTheme = Literal[
|
||||
"a11y-dark",
|
||||
|
@ -8,8 +8,9 @@ from reflex.base import Base
|
||||
from reflex.components.component import Component, NoSSRComponent
|
||||
from reflex.components.literals import LiteralRowMarker
|
||||
from reflex.utils import console, format, imports, types
|
||||
from reflex.utils.imports import ImportVar
|
||||
from reflex.utils.serializers import serializer
|
||||
from reflex.vars import ImportVar, Var, get_unique_variable_name
|
||||
from reflex.vars import Var, get_unique_variable_name
|
||||
|
||||
|
||||
# TODO: Fix the serialization issue for custom types.
|
||||
|
@ -13,8 +13,9 @@ from reflex.base import Base
|
||||
from reflex.components.component import Component, NoSSRComponent
|
||||
from reflex.components.literals import LiteralRowMarker
|
||||
from reflex.utils import console, format, imports, types
|
||||
from reflex.utils.imports import ImportVar
|
||||
from reflex.utils.serializers import serializer
|
||||
from reflex.vars import ImportVar, Var, get_unique_variable_name
|
||||
from reflex.vars import Var, get_unique_variable_name
|
||||
|
||||
class GridColumnIcons(Enum):
|
||||
Array = "array"
|
||||
|
@ -8,7 +8,7 @@ from reflex.components.component import Component
|
||||
from reflex.components.tags import Tag
|
||||
from reflex.utils import imports, types
|
||||
from reflex.utils.serializers import serialize, serializer
|
||||
from reflex.vars import BaseVar, ComputedVar, ImportVar, Var
|
||||
from reflex.vars import BaseVar, ComputedVar, Var
|
||||
|
||||
|
||||
class Gridjs(Component):
|
||||
@ -105,7 +105,7 @@ class DataTable(Gridjs):
|
||||
def _get_imports(self) -> imports.ImportDict:
|
||||
return imports.merge_imports(
|
||||
super()._get_imports(),
|
||||
{"": {ImportVar(tag="gridjs/dist/theme/mermaid.css")}},
|
||||
{"": {imports.ImportVar(tag="gridjs/dist/theme/mermaid.css")}},
|
||||
)
|
||||
|
||||
def _render(self) -> Tag:
|
||||
@ -113,13 +113,13 @@ class DataTable(Gridjs):
|
||||
self.columns = BaseVar(
|
||||
_var_name=f"{self.data._var_name}.columns",
|
||||
_var_type=List[Any],
|
||||
_var_state=self.data._var_state,
|
||||
)
|
||||
_var_full_name_needs_state_prefix=True,
|
||||
)._replace(merge_var_data=self.data._var_data)
|
||||
self.data = BaseVar(
|
||||
_var_name=f"{self.data._var_name}.data",
|
||||
_var_type=List[List[Any]],
|
||||
_var_state=self.data._var_state,
|
||||
)
|
||||
_var_full_name_needs_state_prefix=True,
|
||||
)._replace(merge_var_data=self.data._var_data)
|
||||
if types.is_dataframe(type(self.data)):
|
||||
# If given a pandas df break up the data and columns
|
||||
data = serialize(self.data)
|
||||
|
@ -12,7 +12,7 @@ from reflex.components.component import Component
|
||||
from reflex.components.tags import Tag
|
||||
from reflex.utils import imports, types
|
||||
from reflex.utils.serializers import serialize, serializer
|
||||
from reflex.vars import BaseVar, ComputedVar, ImportVar, Var
|
||||
from reflex.vars import BaseVar, ComputedVar, Var
|
||||
|
||||
class Gridjs(Component):
|
||||
@overload
|
||||
|
@ -4,7 +4,7 @@ from typing import Any, Dict, List
|
||||
|
||||
from reflex.components.component import Component, NoSSRComponent
|
||||
from reflex.utils import imports
|
||||
from reflex.vars import ImportVar, Var
|
||||
from reflex.vars import Var
|
||||
|
||||
|
||||
class Moment(NoSSRComponent):
|
||||
@ -78,7 +78,7 @@ class Moment(NoSSRComponent):
|
||||
if self.tz is not None:
|
||||
merged_imports = imports.merge_imports(
|
||||
merged_imports,
|
||||
{"moment-timezone": {ImportVar(tag="")}},
|
||||
{"moment-timezone": {imports.ImportVar(tag="")}},
|
||||
)
|
||||
return merged_imports
|
||||
|
||||
|
@ -10,7 +10,7 @@ from reflex.style import Style
|
||||
from typing import Any, Dict, List
|
||||
from reflex.components.component import Component, NoSSRComponent
|
||||
from reflex.utils import imports
|
||||
from reflex.vars import ImportVar, Var
|
||||
from reflex.vars import Var
|
||||
|
||||
class Moment(NoSSRComponent):
|
||||
def get_event_triggers(self) -> Dict[str, Any]: ...
|
||||
|
@ -22,7 +22,7 @@ from reflex.components.component import Component
|
||||
from reflex.components.layout.cond import Cond, cond
|
||||
from reflex.components.media.icon import Icon
|
||||
from reflex.style import color_mode, toggle_color_mode
|
||||
from reflex.vars import BaseVar
|
||||
from reflex.vars import Var
|
||||
|
||||
from .button import Button
|
||||
from .switch import Switch
|
||||
@ -32,7 +32,7 @@ DEFAULT_LIGHT_ICON: Icon = Icon.create(tag="sun")
|
||||
DEFAULT_DARK_ICON: Icon = Icon.create(tag="moon")
|
||||
|
||||
|
||||
def color_mode_cond(light: Any, dark: Any = None) -> BaseVar | Component:
|
||||
def color_mode_cond(light: Any, dark: Any = None) -> Var | Component:
|
||||
"""Create a component or Prop based on color_mode.
|
||||
|
||||
Args:
|
||||
|
@ -12,7 +12,7 @@ from reflex.components.component import Component
|
||||
from reflex.components.layout.cond import Cond, cond
|
||||
from reflex.components.media.icon import Icon
|
||||
from reflex.style import color_mode, toggle_color_mode
|
||||
from reflex.vars import BaseVar
|
||||
from reflex.vars import Var
|
||||
from .button import Button
|
||||
from .switch import Switch
|
||||
|
||||
@ -20,7 +20,7 @@ DEFAULT_COLOR_MODE: str
|
||||
DEFAULT_LIGHT_ICON: Icon
|
||||
DEFAULT_DARK_ICON: Icon
|
||||
|
||||
def color_mode_cond(light: Any, dark: Any = None) -> BaseVar | Component: ...
|
||||
def color_mode_cond(light: Any, dark: Any = None) -> Var | Component: ...
|
||||
|
||||
class ColorModeIcon(Cond):
|
||||
@overload
|
||||
|
@ -1,10 +1,11 @@
|
||||
"""Wrapper around react-debounce-input."""
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import Any
|
||||
from typing import Any, Set
|
||||
|
||||
from reflex.components import Component
|
||||
from reflex.components.tags import Tag
|
||||
from reflex.utils import imports
|
||||
from reflex.vars import Var
|
||||
|
||||
|
||||
@ -77,6 +78,17 @@ class DebounceInput(Component):
|
||||
object.__setattr__(child, "render", lambda: "")
|
||||
return tag
|
||||
|
||||
def _get_imports(self) -> imports.ImportDict:
|
||||
return imports.merge_imports(
|
||||
super()._get_imports(), *[c._get_imports() for c in self.children]
|
||||
)
|
||||
|
||||
def _get_hooks_internal(self) -> Set[str]:
|
||||
hooks = super()._get_hooks_internal()
|
||||
for child in self.children:
|
||||
hooks.update(child._get_hooks_internal())
|
||||
return hooks
|
||||
|
||||
|
||||
def props_not_none(c: Component) -> dict[str, Any]:
|
||||
"""Get all properties of the component that are not None.
|
||||
|
@ -7,9 +7,10 @@ from typing import Any, Dict, Literal, Optional, Union, overload
|
||||
from reflex.vars import Var, BaseVar, ComputedVar
|
||||
from reflex.event import EventChain, EventHandler, EventSpec
|
||||
from reflex.style import Style
|
||||
from typing import Any
|
||||
from typing import Any, Set
|
||||
from reflex.components import Component
|
||||
from reflex.components.tags import Tag
|
||||
from reflex.utils import imports
|
||||
from reflex.vars import Var
|
||||
|
||||
class DebounceInput(Component):
|
||||
|
@ -8,7 +8,8 @@ from reflex.base import Base
|
||||
from reflex.components.component import Component, NoSSRComponent
|
||||
from reflex.constants import EventTriggers
|
||||
from reflex.utils.format import to_camel_case
|
||||
from reflex.vars import ImportVar, Var
|
||||
from reflex.utils.imports import ImportVar
|
||||
from reflex.vars import Var
|
||||
|
||||
|
||||
class EditorButtonList(list, enum.Enum):
|
||||
|
@ -13,7 +13,8 @@ from reflex.base import Base
|
||||
from reflex.components.component import Component, NoSSRComponent
|
||||
from reflex.constants import EventTriggers
|
||||
from reflex.utils.format import to_camel_case
|
||||
from reflex.vars import ImportVar, Var
|
||||
from reflex.utils.imports import ImportVar
|
||||
from reflex.vars import Var
|
||||
|
||||
class EditorButtonList(list, enum.Enum):
|
||||
BASIC = [["font", "fontSize"], ["fontColor"], ["horizontalRule"], ["link", "image"]]
|
||||
|
@ -8,7 +8,7 @@ from jinja2 import Environment
|
||||
from reflex.components.component import Component
|
||||
from reflex.components.libs.chakra import ChakraComponent
|
||||
from reflex.components.tags import Tag
|
||||
from reflex.constants import EventTriggers
|
||||
from reflex.constants import Dirs, EventTriggers
|
||||
from reflex.event import EventChain
|
||||
from reflex.utils import imports
|
||||
from reflex.utils.format import format_event_chain, to_camel_case
|
||||
@ -65,7 +65,13 @@ class Form(ChakraComponent):
|
||||
def _get_imports(self) -> imports.ImportDict:
|
||||
return imports.merge_imports(
|
||||
super()._get_imports(),
|
||||
{"react": {imports.ImportVar(tag="useCallback")}},
|
||||
{
|
||||
"react": {imports.ImportVar(tag="useCallback")},
|
||||
f"/{Dirs.STATE_PATH}": {
|
||||
imports.ImportVar(tag="getRefValue"),
|
||||
imports.ImportVar(tag="getRefValues"),
|
||||
},
|
||||
},
|
||||
)
|
||||
|
||||
def _get_hooks(self) -> str | None:
|
||||
|
@ -12,7 +12,7 @@ from jinja2 import Environment
|
||||
from reflex.components.component import Component
|
||||
from reflex.components.libs.chakra import ChakraComponent
|
||||
from reflex.components.tags import Tag
|
||||
from reflex.constants import EventTriggers
|
||||
from reflex.constants import Dirs, EventTriggers
|
||||
from reflex.event import EventChain
|
||||
from reflex.utils import imports
|
||||
from reflex.utils.format import format_event_chain, to_camel_case
|
||||
|
@ -11,7 +11,7 @@ from reflex.components.libs.chakra import (
|
||||
)
|
||||
from reflex.constants import EventTriggers
|
||||
from reflex.utils import imports
|
||||
from reflex.vars import ImportVar, Var
|
||||
from reflex.vars import Var
|
||||
|
||||
|
||||
class Input(ChakraComponent):
|
||||
@ -61,7 +61,7 @@ class Input(ChakraComponent):
|
||||
def _get_imports(self) -> imports.ImportDict:
|
||||
return imports.merge_imports(
|
||||
super()._get_imports(),
|
||||
{"/utils/state": {ImportVar(tag="set_val")}},
|
||||
{"/utils/state": {imports.ImportVar(tag="set_val")}},
|
||||
)
|
||||
|
||||
def get_event_triggers(self) -> Dict[str, Any]:
|
||||
|
@ -17,7 +17,7 @@ from reflex.components.libs.chakra import (
|
||||
)
|
||||
from reflex.constants import EventTriggers
|
||||
from reflex.utils import imports
|
||||
from reflex.vars import ImportVar, Var
|
||||
from reflex.vars import Var
|
||||
|
||||
class Input(ChakraComponent):
|
||||
def get_event_triggers(self) -> Dict[str, Any]: ...
|
||||
|
@ -68,9 +68,11 @@ class PinInput(ChakraComponent):
|
||||
Returns:
|
||||
The merged import dict.
|
||||
"""
|
||||
range_var = Var.range(0)
|
||||
return merge_imports(
|
||||
super()._get_imports(),
|
||||
PinInputField().get_imports(), # type: ignore
|
||||
range_var._var_data.imports if range_var._var_data is not None else {},
|
||||
)
|
||||
|
||||
def get_event_triggers(self) -> dict[str, Union[Var, Any]]:
|
||||
@ -117,7 +119,7 @@ class PinInput(ChakraComponent):
|
||||
)
|
||||
refs_declaration._var_is_local = True
|
||||
if ref:
|
||||
return f"const {ref} = {refs_declaration}"
|
||||
return f"const {ref} = {str(refs_declaration)}"
|
||||
return super()._get_ref_hook()
|
||||
|
||||
def _render(self) -> Tag:
|
||||
|
@ -7,9 +7,10 @@ from reflex import constants
|
||||
from reflex.components.component import Component
|
||||
from reflex.components.forms.input import Input
|
||||
from reflex.components.layout.box import Box
|
||||
from reflex.constants import Dirs
|
||||
from reflex.event import CallableEventSpec, EventChain, EventSpec, call_script
|
||||
from reflex.utils import imports
|
||||
from reflex.vars import BaseVar, CallableVar, ImportVar, Var
|
||||
from reflex.vars import BaseVar, CallableVar, Var, VarData
|
||||
|
||||
DEFAULT_UPLOAD_ID: str = "default"
|
||||
|
||||
@ -30,6 +31,13 @@ def upload_file(id_: str = DEFAULT_UPLOAD_ID) -> BaseVar:
|
||||
return BaseVar(
|
||||
_var_name=f"e => upload_files.{id_}[1]((files) => e)",
|
||||
_var_type=EventChain,
|
||||
_var_data=VarData( # type: ignore
|
||||
imports={
|
||||
f"/{Dirs.STATE_PATH}": {
|
||||
imports.ImportVar(tag="upload_files"),
|
||||
},
|
||||
},
|
||||
),
|
||||
)
|
||||
|
||||
|
||||
@ -46,6 +54,13 @@ def selected_files(id_: str = DEFAULT_UPLOAD_ID) -> BaseVar:
|
||||
return BaseVar(
|
||||
_var_name=f"(upload_files.{id_} ? upload_files.{id_}[0]?.map((f) => (f.path || f.name)) : [])",
|
||||
_var_type=List[str],
|
||||
_var_data=VarData( # type: ignore
|
||||
imports={
|
||||
f"/{Dirs.STATE_PATH}": {
|
||||
imports.ImportVar(tag="upload_files"),
|
||||
},
|
||||
},
|
||||
),
|
||||
)
|
||||
|
||||
|
||||
@ -166,14 +181,16 @@ class Upload(Component):
|
||||
|
||||
def _get_hooks(self) -> str | None:
|
||||
return (
|
||||
(super()._get_hooks() or "")
|
||||
+ f"""
|
||||
upload_files.{self.id or DEFAULT_UPLOAD_ID} = useState([]);
|
||||
"""
|
||||
)
|
||||
super()._get_hooks() or ""
|
||||
) + f"upload_files.{self.id or DEFAULT_UPLOAD_ID} = useState([]);"
|
||||
|
||||
def _get_imports(self) -> imports.ImportDict:
|
||||
return {
|
||||
**super()._get_imports(),
|
||||
f"/{constants.Dirs.STATE_PATH}": [ImportVar(tag="upload_files")],
|
||||
}
|
||||
return imports.merge_imports(
|
||||
super()._get_imports(),
|
||||
{
|
||||
"react": {imports.ImportVar(tag="useState")},
|
||||
f"/{constants.Dirs.STATE_PATH}": [
|
||||
imports.ImportVar(tag="upload_files")
|
||||
],
|
||||
},
|
||||
)
|
||||
|
@ -12,9 +12,10 @@ from reflex import constants
|
||||
from reflex.components.component import Component
|
||||
from reflex.components.forms.input import Input
|
||||
from reflex.components.layout.box import Box
|
||||
from reflex.constants import Dirs
|
||||
from reflex.event import CallableEventSpec, EventChain, EventSpec, call_script
|
||||
from reflex.utils import imports
|
||||
from reflex.vars import BaseVar, CallableVar, ImportVar, Var
|
||||
from reflex.vars import BaseVar, CallableVar, Var, VarData
|
||||
|
||||
DEFAULT_UPLOAD_ID: str
|
||||
|
||||
|
@ -1,13 +1,18 @@
|
||||
"""Create a list of components from an iterable."""
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import Any, Dict, Optional
|
||||
from typing import Any, Dict, Optional, overload
|
||||
|
||||
from reflex.components.component import Component
|
||||
from reflex.components.layout.fragment import Fragment
|
||||
from reflex.components.tags import CondTag, Tag
|
||||
from reflex.utils import format
|
||||
from reflex.vars import Var
|
||||
from reflex.constants import Dirs
|
||||
from reflex.utils import format, imports
|
||||
from reflex.vars import BaseVar, Var, VarData
|
||||
|
||||
_IS_TRUE_IMPORT = {
|
||||
f"/{Dirs.STATE_PATH}": {imports.ImportVar(tag="isTrue")},
|
||||
}
|
||||
|
||||
|
||||
class Cond(Component):
|
||||
@ -88,6 +93,28 @@ class Cond(Component):
|
||||
cond_state=f"isTrue({self.cond._var_full_name})",
|
||||
)
|
||||
|
||||
def _get_imports(self) -> imports.ImportDict:
|
||||
return imports.merge_imports(
|
||||
super()._get_imports(),
|
||||
getattr(self.cond._var_data, "imports", {}),
|
||||
_IS_TRUE_IMPORT,
|
||||
)
|
||||
|
||||
|
||||
@overload
|
||||
def cond(condition: Any, c1: Component, c2: Any) -> Component:
|
||||
...
|
||||
|
||||
|
||||
@overload
|
||||
def cond(condition: Any, c1: Component) -> Component:
|
||||
...
|
||||
|
||||
|
||||
@overload
|
||||
def cond(condition: Any, c1: Any, c2: Any) -> Var:
|
||||
...
|
||||
|
||||
|
||||
def cond(condition: Any, c1: Any, c2: Any = None):
|
||||
"""Create a conditional component or Prop.
|
||||
@ -103,8 +130,11 @@ def cond(condition: Any, c1: Any, c2: Any = None):
|
||||
Raises:
|
||||
ValueError: If the arguments are invalid.
|
||||
"""
|
||||
# Import here to avoid circular imports.
|
||||
from reflex.vars import BaseVar, Var
|
||||
var_datas: list[VarData | None] = [
|
||||
VarData( # type: ignore
|
||||
imports=_IS_TRUE_IMPORT,
|
||||
),
|
||||
]
|
||||
|
||||
# Convert the condition to a Var.
|
||||
cond_var = Var.create(condition)
|
||||
@ -116,16 +146,20 @@ def cond(condition: Any, c1: Any, c2: Any = None):
|
||||
c2, Component
|
||||
), "Both arguments must be components."
|
||||
return Cond.create(cond_var, c1, c2)
|
||||
if isinstance(c1, Var):
|
||||
var_datas.append(c1._var_data)
|
||||
|
||||
# Otherwise, create a conditionl Var.
|
||||
# Otherwise, create a conditional Var.
|
||||
# Check that the second argument is valid.
|
||||
if isinstance(c2, Component):
|
||||
raise ValueError("Both arguments must be props.")
|
||||
if c2 is None:
|
||||
raise ValueError("For conditional vars, the second argument must be set.")
|
||||
if isinstance(c2, Var):
|
||||
var_datas.append(c2._var_data)
|
||||
|
||||
# Create the conditional var.
|
||||
return BaseVar(
|
||||
return cond_var._replace(
|
||||
_var_name=format.format_cond(
|
||||
cond=cond_var._var_full_name,
|
||||
true_value=c1,
|
||||
@ -133,4 +167,7 @@ def cond(condition: Any, c1: Any, c2: Any = None):
|
||||
is_prop=True,
|
||||
),
|
||||
_var_type=c1._var_type if isinstance(c1, BaseVar) else type(c1),
|
||||
_var_is_local=False,
|
||||
_var_full_name_needs_state_prefix=False,
|
||||
merge_var_data=VarData.merge(*var_datas),
|
||||
)
|
||||
|
@ -1,8 +1,8 @@
|
||||
"""A html component."""
|
||||
|
||||
from typing import Any
|
||||
from typing import Dict
|
||||
|
||||
from reflex.components.layout.box import Box
|
||||
from reflex.vars import Var
|
||||
|
||||
|
||||
class Html(Box):
|
||||
@ -13,7 +13,7 @@ class Html(Box):
|
||||
"""
|
||||
|
||||
# The HTML to render.
|
||||
dangerouslySetInnerHTML: Any
|
||||
dangerouslySetInnerHTML: Var[Dict[str, str]]
|
||||
|
||||
@classmethod
|
||||
def create(cls, *children, **props):
|
||||
|
@ -7,8 +7,9 @@ from typing import Any, Dict, Literal, Optional, Union, overload
|
||||
from reflex.vars import Var, BaseVar, ComputedVar
|
||||
from reflex.event import EventChain, EventHandler, EventSpec
|
||||
from reflex.style import Style
|
||||
from typing import Any
|
||||
from typing import Dict
|
||||
from reflex.components.layout.box import Box
|
||||
from reflex.vars import Var
|
||||
|
||||
class Html(Box):
|
||||
@overload
|
||||
@ -16,7 +17,9 @@ class Html(Box):
|
||||
def create( # type: ignore
|
||||
cls,
|
||||
*children,
|
||||
dangerouslySetInnerHTML: Optional[Any] = None,
|
||||
dangerouslySetInnerHTML: Optional[
|
||||
Union[Var[Dict[str, str]], Dict[str, str]]
|
||||
] = None,
|
||||
element: Optional[Union[Var[str], str]] = None,
|
||||
src: Optional[Union[Var[str], str]] = None,
|
||||
alt: Optional[Union[Var[str], str]] = None,
|
||||
|
@ -6,7 +6,7 @@ from typing import List, Literal
|
||||
|
||||
from reflex.components.component import Component
|
||||
from reflex.utils import imports
|
||||
from reflex.vars import ImportVar, Var
|
||||
from reflex.vars import Var
|
||||
|
||||
|
||||
class ChakraComponent(Component):
|
||||
@ -34,7 +34,7 @@ class ChakraComponent(Component):
|
||||
The dependencies imports of the component.
|
||||
"""
|
||||
return {
|
||||
dep: [ImportVar(tag=None, render=False)]
|
||||
dep: [imports.ImportVar(tag=None, render=False)]
|
||||
for dep in [
|
||||
"@chakra-ui/system@2.5.7",
|
||||
"framer-motion@10.16.4",
|
||||
@ -75,17 +75,17 @@ class ChakraProvider(ChakraComponent):
|
||||
)
|
||||
|
||||
def _get_imports(self) -> imports.ImportDict:
|
||||
imports = super()._get_imports()
|
||||
imports.setdefault(self.__fields__["library"].default, []).append(
|
||||
ImportVar(tag="extendTheme", is_default=False),
|
||||
_imports = super()._get_imports()
|
||||
_imports.setdefault(self.__fields__["library"].default, []).append(
|
||||
imports.ImportVar(tag="extendTheme", is_default=False),
|
||||
)
|
||||
imports.setdefault("/utils/theme.js", []).append(
|
||||
ImportVar(tag="theme", is_default=True),
|
||||
_imports.setdefault("/utils/theme.js", []).append(
|
||||
imports.ImportVar(tag="theme", is_default=True),
|
||||
)
|
||||
imports.setdefault(Global.__fields__["library"].default, []).append(
|
||||
ImportVar(tag="css", is_default=False),
|
||||
_imports.setdefault(Global.__fields__["library"].default, []).append(
|
||||
imports.ImportVar(tag="css", is_default=False),
|
||||
)
|
||||
return imports
|
||||
return _imports
|
||||
|
||||
def _get_custom_code(self) -> str | None:
|
||||
return """
|
||||
|
@ -11,7 +11,7 @@ from functools import lru_cache
|
||||
from typing import List, Literal
|
||||
from reflex.components.component import Component
|
||||
from reflex.utils import imports
|
||||
from reflex.vars import ImportVar, Var
|
||||
from reflex.vars import Var
|
||||
|
||||
class ChakraComponent(Component):
|
||||
@overload
|
||||
|
@ -10,10 +10,9 @@ routeNotFound becomes true.
|
||||
from __future__ import annotations
|
||||
|
||||
from reflex import constants
|
||||
|
||||
from ...vars import Var
|
||||
from ..component import Component
|
||||
from ..layout.cond import Cond
|
||||
from reflex.components.component import Component
|
||||
from reflex.components.layout.cond import cond
|
||||
from reflex.vars import Var
|
||||
|
||||
route_not_found: Var = Var.create_safe(constants.ROUTE_NOT_FOUND)
|
||||
|
||||
@ -52,10 +51,10 @@ def wait_for_client_redirect(component) -> Component:
|
||||
Returns:
|
||||
The conditionally rendered component.
|
||||
"""
|
||||
return Cond.create(
|
||||
cond=route_not_found,
|
||||
comp1=component,
|
||||
comp2=ClientSideRouting.create(),
|
||||
return cond(
|
||||
condition=route_not_found,
|
||||
c1=component,
|
||||
c2=ClientSideRouting.create(),
|
||||
)
|
||||
|
||||
|
||||
|
@ -8,9 +8,9 @@ from reflex.vars import Var, BaseVar, ComputedVar
|
||||
from reflex.event import EventChain, EventHandler, EventSpec
|
||||
from reflex.style import Style
|
||||
from reflex import constants
|
||||
from ...vars import Var
|
||||
from ..component import Component
|
||||
from ..layout.cond import Cond
|
||||
from reflex.components.component import Component
|
||||
from reflex.components.layout.cond import cond
|
||||
from reflex.vars import Var
|
||||
|
||||
route_not_found: Var
|
||||
|
||||
|
@ -5,22 +5,27 @@ from typing import Optional
|
||||
|
||||
from reflex.components.base.bare import Bare
|
||||
from reflex.components.component import Component
|
||||
from reflex.components.layout import Box, Cond
|
||||
from reflex.components.layout import Box, cond
|
||||
from reflex.components.overlay.modal import Modal
|
||||
from reflex.components.typography import Text
|
||||
from reflex.constants import Hooks, Imports
|
||||
from reflex.utils import imports
|
||||
from reflex.vars import ImportVar, Var
|
||||
from reflex.vars import Var, VarData
|
||||
|
||||
connect_error_var_data: VarData = VarData( # type: ignore
|
||||
imports=Imports.EVENTS,
|
||||
hooks={Hooks.EVENTS},
|
||||
)
|
||||
|
||||
connection_error: Var = Var.create_safe(
|
||||
value="(connectError !== null) ? connectError.message : ''",
|
||||
_var_is_local=False,
|
||||
_var_is_string=False,
|
||||
)
|
||||
)._replace(merge_var_data=connect_error_var_data)
|
||||
has_connection_error: Var = Var.create_safe(
|
||||
value="connectError !== null",
|
||||
_var_is_string=False,
|
||||
)
|
||||
has_connection_error._var_type = bool
|
||||
)._replace(_var_type=bool, merge_var_data=connect_error_var_data)
|
||||
|
||||
|
||||
class WebsocketTargetURL(Bare):
|
||||
@ -28,7 +33,7 @@ class WebsocketTargetURL(Bare):
|
||||
|
||||
def _get_imports(self) -> imports.ImportDict:
|
||||
return {
|
||||
"/utils/state.js": [ImportVar(tag="getEventURL")],
|
||||
"/utils/state.js": [imports.ImportVar(tag="getEventURL")],
|
||||
}
|
||||
|
||||
@classmethod
|
||||
@ -78,7 +83,7 @@ class ConnectionBanner(Component):
|
||||
textAlign="center",
|
||||
)
|
||||
|
||||
return Cond.create(has_connection_error, comp)
|
||||
return cond(has_connection_error, comp)
|
||||
|
||||
|
||||
class ConnectionModal(Component):
|
||||
@ -96,7 +101,7 @@ class ConnectionModal(Component):
|
||||
"""
|
||||
if not comp:
|
||||
comp = Text.create(*default_connection_error())
|
||||
return Cond.create(
|
||||
return cond(
|
||||
has_connection_error,
|
||||
Modal.create(
|
||||
header="Connection Error",
|
||||
|
@ -10,15 +10,16 @@ from reflex.style import Style
|
||||
from typing import Optional
|
||||
from reflex.components.base.bare import Bare
|
||||
from reflex.components.component import Component
|
||||
from reflex.components.layout import Box, Cond
|
||||
from reflex.components.layout import Box, cond
|
||||
from reflex.components.overlay.modal import Modal
|
||||
from reflex.components.typography import Text
|
||||
from reflex.constants import Hooks, Imports
|
||||
from reflex.utils import imports
|
||||
from reflex.vars import ImportVar, Var
|
||||
from reflex.vars import Var, VarData
|
||||
|
||||
connect_error_var_data: VarData
|
||||
connection_error: Var
|
||||
has_connection_error: Var
|
||||
has_connection_error._var_type = bool
|
||||
|
||||
class WebsocketTargetURL(Bare):
|
||||
@overload
|
||||
|
@ -6,7 +6,7 @@ from typing import Literal
|
||||
|
||||
from reflex.components import Component
|
||||
from reflex.utils import imports
|
||||
from reflex.vars import ImportVar, Var
|
||||
from reflex.vars import Var
|
||||
|
||||
LiteralAlign = Literal["start", "center", "end", "baseline", "stretch"]
|
||||
LiteralJustify = Literal["start", "center", "end", "between"]
|
||||
@ -147,7 +147,7 @@ class Theme(RadixThemesComponent):
|
||||
def _get_imports(self) -> imports.ImportDict:
|
||||
return {
|
||||
**super()._get_imports(),
|
||||
"": [ImportVar(tag="@radix-ui/themes/styles.css", install=False)],
|
||||
"": [imports.ImportVar(tag="@radix-ui/themes/styles.css", install=False)],
|
||||
}
|
||||
|
||||
|
||||
|
@ -10,7 +10,7 @@ from reflex.style import Style
|
||||
from typing import Literal
|
||||
from reflex.components import Component
|
||||
from reflex.utils import imports
|
||||
from reflex.vars import ImportVar, Var
|
||||
from reflex.vars import Var
|
||||
|
||||
LiteralAlign = Literal["start", "center", "end", "baseline", "stretch"]
|
||||
LiteralJustify = Literal["start", "center", "end", "between"]
|
||||
|
@ -14,7 +14,8 @@ from reflex.components.typography.heading import Heading
|
||||
from reflex.components.typography.text import Text
|
||||
from reflex.style import Style
|
||||
from reflex.utils import console, imports, types
|
||||
from reflex.vars import ImportVar, Var
|
||||
from reflex.utils.imports import ImportVar
|
||||
from reflex.vars import Var
|
||||
|
||||
# Special vars used in the component map.
|
||||
_CHILDREN = Var.create_safe("children", _var_is_local=False)
|
||||
|
@ -18,7 +18,8 @@ from reflex.components.typography.heading import Heading
|
||||
from reflex.components.typography.text import Text
|
||||
from reflex.style import Style
|
||||
from reflex.utils import console, imports, types
|
||||
from reflex.vars import ImportVar, Var
|
||||
from reflex.utils.imports import ImportVar
|
||||
from reflex.vars import Var
|
||||
|
||||
_CHILDREN = Var.create_safe("children", _var_is_local=False)
|
||||
_PROPS = Var.create_safe("...props", _var_is_local=False)
|
||||
|
@ -22,6 +22,8 @@ from .compiler import (
|
||||
CompileVars,
|
||||
ComponentName,
|
||||
Ext,
|
||||
Hooks,
|
||||
Imports,
|
||||
PageNames,
|
||||
)
|
||||
from .config import (
|
||||
@ -68,7 +70,9 @@ __ALL__ = [
|
||||
Ext,
|
||||
Fnm,
|
||||
GitIgnore,
|
||||
Hooks,
|
||||
RequirementsTxt,
|
||||
Imports,
|
||||
IS_WINDOWS,
|
||||
LOCAL_STORAGE,
|
||||
LogLevel,
|
||||
|
@ -29,6 +29,8 @@ class Dirs(SimpleNamespace):
|
||||
STATE_PATH = "/".join([UTILS, "state"])
|
||||
# The name of the components file.
|
||||
COMPONENTS_PATH = "/".join([UTILS, "components"])
|
||||
# The name of the contexts file.
|
||||
CONTEXTS_PATH = "/".join([UTILS, "context"])
|
||||
# The directory where the app pages are compiled to.
|
||||
WEB_PAGES = os.path.join(WEB, "pages")
|
||||
# The directory where the static build is located.
|
||||
|
@ -2,6 +2,9 @@
|
||||
from enum import Enum
|
||||
from types import SimpleNamespace
|
||||
|
||||
from reflex.constants import Dirs
|
||||
from reflex.utils.imports import ImportVar
|
||||
|
||||
# The prefix used to create setters for state vars.
|
||||
SETTER_PREFIX = "set_"
|
||||
|
||||
@ -47,6 +50,12 @@ class CompileVars(SimpleNamespace):
|
||||
HYDRATE = "hydrate"
|
||||
# The name of the is_hydrated variable.
|
||||
IS_HYDRATED = "is_hydrated"
|
||||
# The name of the function to add events to the queue.
|
||||
ADD_EVENTS = "addEvents"
|
||||
# The name of the var storing any connection error.
|
||||
CONNECT_ERROR = "connectError"
|
||||
# The name of the function for converting a dict to an event.
|
||||
TO_EVENT = "Event"
|
||||
|
||||
|
||||
class PageNames(SimpleNamespace):
|
||||
@ -77,3 +86,19 @@ class ComponentName(Enum):
|
||||
The lower-case filename with zip extension.
|
||||
"""
|
||||
return self.value.lower() + Ext.ZIP
|
||||
|
||||
|
||||
class Imports(SimpleNamespace):
|
||||
"""Common sets of import vars."""
|
||||
|
||||
EVENTS = {
|
||||
"react": {ImportVar(tag="useContext")},
|
||||
f"/{Dirs.CONTEXTS_PATH}": {ImportVar(tag="EventLoopContext")},
|
||||
f"/{Dirs.STATE_PATH}": {ImportVar(tag=CompileVars.TO_EVENT)},
|
||||
}
|
||||
|
||||
|
||||
class Hooks(SimpleNamespace):
|
||||
"""Common sets of hook declarations."""
|
||||
|
||||
EVENTS = f"const [{CompileVars.ADD_EVENTS}, {CompileVars.CONNECT_ERROR}] = useContext(EventLoopContext);"
|
||||
|
@ -48,7 +48,7 @@ class HydrateMiddleware(Middleware):
|
||||
setattr(var_state, var_name, value)
|
||||
|
||||
# Get the initial state.
|
||||
delta = format.format_state({state.get_name(): state.dict()})
|
||||
delta = format.format_state(state.dict())
|
||||
# since a full dict was captured, clean any dirtiness
|
||||
state._clean()
|
||||
|
||||
|
@ -1211,12 +1211,16 @@ class State(Base, ABC, extra=pydantic.Extra.allow):
|
||||
if include_computed
|
||||
else {}
|
||||
)
|
||||
substate_vars = {
|
||||
k: v.dict(include_computed=include_computed, **kwargs)
|
||||
for k, v in self.substates.items()
|
||||
variables = {**base_vars, **computed_vars}
|
||||
d = {
|
||||
self.get_full_name(): {k: variables[k] for k in sorted(variables)},
|
||||
}
|
||||
variables = {**base_vars, **computed_vars, **substate_vars}
|
||||
return {k: variables[k] for k in sorted(variables)}
|
||||
for substate_d in [
|
||||
v.dict(include_computed=include_computed, **kwargs)
|
||||
for v in self.substates.values()
|
||||
]:
|
||||
d.update(substate_d)
|
||||
return d
|
||||
|
||||
async def __aenter__(self) -> State:
|
||||
"""Enter the async context manager protocol.
|
||||
|
@ -2,13 +2,38 @@
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import Any
|
||||
|
||||
from reflex import constants
|
||||
from reflex.event import EventChain
|
||||
from reflex.utils import format
|
||||
from reflex.vars import BaseVar, Var
|
||||
from reflex.utils.imports import ImportVar
|
||||
from reflex.vars import BaseVar, Var, VarData
|
||||
|
||||
color_mode = BaseVar(_var_name=constants.ColorMode.NAME, _var_type="str")
|
||||
toggle_color_mode = BaseVar(_var_name=constants.ColorMode.TOGGLE, _var_type=EventChain)
|
||||
VarData.update_forward_refs() # Ensure all type definitions are resolved
|
||||
|
||||
# Reference the global ColorModeContext
|
||||
color_mode_var_data = VarData( # type: ignore
|
||||
imports={
|
||||
f"/{constants.Dirs.CONTEXTS_PATH}": {ImportVar(tag="ColorModeContext")},
|
||||
"react": {ImportVar(tag="useContext")},
|
||||
},
|
||||
hooks={
|
||||
f"const [ {constants.ColorMode.NAME}, {constants.ColorMode.TOGGLE} ] = useContext(ColorModeContext)",
|
||||
},
|
||||
)
|
||||
# Var resolves to the current color mode for the app ("light" or "dark")
|
||||
color_mode = BaseVar(
|
||||
_var_name=constants.ColorMode.NAME,
|
||||
_var_type="str",
|
||||
_var_data=color_mode_var_data,
|
||||
)
|
||||
# Var resolves to a function invocation that toggles the color mode
|
||||
toggle_color_mode = BaseVar(
|
||||
_var_name=constants.ColorMode.TOGGLE,
|
||||
_var_type=EventChain,
|
||||
_var_data=color_mode_var_data,
|
||||
)
|
||||
|
||||
|
||||
def convert(style_dict):
|
||||
@ -20,16 +45,27 @@ def convert(style_dict):
|
||||
Returns:
|
||||
The formatted style dictionary.
|
||||
"""
|
||||
var_data = None # Track import/hook data from any Vars in the style dict.
|
||||
out = {}
|
||||
for key, value in style_dict.items():
|
||||
key = format.to_camel_case(key)
|
||||
new_var_data = None
|
||||
if isinstance(value, dict):
|
||||
out[key] = convert(value)
|
||||
# Recursively format nested style dictionaries.
|
||||
out[key], new_var_data = convert(value)
|
||||
elif isinstance(value, Var):
|
||||
# If the value is a Var, extract the var_data and cast as str.
|
||||
new_var_data = value._var_data
|
||||
out[key] = str(value)
|
||||
else:
|
||||
# Otherwise, convert to Var to collapse VarData encoded in f-string.
|
||||
new_var = Var.create(value)
|
||||
if new_var is not None:
|
||||
new_var_data = new_var._var_data
|
||||
out[key] = value
|
||||
return out
|
||||
# Combine all the collected VarData instances.
|
||||
var_data = VarData.merge(var_data, new_var_data)
|
||||
return out, var_data
|
||||
|
||||
|
||||
class Style(dict):
|
||||
@ -41,5 +77,33 @@ class Style(dict):
|
||||
Args:
|
||||
style_dict: The style dictionary.
|
||||
"""
|
||||
style_dict = style_dict or {}
|
||||
super().__init__(convert(style_dict))
|
||||
style_dict, self._var_data = convert(style_dict or {})
|
||||
super().__init__(style_dict)
|
||||
|
||||
def update(self, style_dict: dict | None, **kwargs):
|
||||
"""Update the style.
|
||||
|
||||
Args:
|
||||
style_dict: The style dictionary.
|
||||
kwargs: Other key value pairs to apply to the dict update.
|
||||
"""
|
||||
if kwargs:
|
||||
style_dict = {**(style_dict or {}), **kwargs}
|
||||
converted_dict = type(self)(style_dict)
|
||||
# Combine our VarData with that of any Vars in the style_dict that was passed.
|
||||
self._var_data = VarData.merge(self._var_data, converted_dict._var_data)
|
||||
super().update(converted_dict)
|
||||
|
||||
def __setitem__(self, key: str, value: Any):
|
||||
"""Set an item in the style.
|
||||
|
||||
Args:
|
||||
key: The key to set.
|
||||
value: The value to set.
|
||||
"""
|
||||
# Create a Var to collapse VarData encoded in f-string.
|
||||
_var = Var.create(value)
|
||||
if _var is not None:
|
||||
# Carry the imports/hooks when setting a Var as a value.
|
||||
self._var_data = VarData.merge(self._var_data, _var._var_data)
|
||||
super().__setitem__(key, value)
|
||||
|
@ -232,9 +232,9 @@ def format_route(route: str, format_case=True) -> str:
|
||||
|
||||
|
||||
def format_cond(
|
||||
cond: str,
|
||||
true_value: str,
|
||||
false_value: str = '""',
|
||||
cond: str | Var,
|
||||
true_value: str | Var,
|
||||
false_value: str | Var = '""',
|
||||
is_prop=False,
|
||||
) -> str:
|
||||
"""Format a conditional expression.
|
||||
@ -248,9 +248,6 @@ def format_cond(
|
||||
Returns:
|
||||
The formatted conditional expression.
|
||||
"""
|
||||
# Import here to avoid circular imports.
|
||||
from reflex.vars import Var
|
||||
|
||||
# Use Python truthiness.
|
||||
cond = f"isTrue({cond})"
|
||||
|
||||
@ -266,6 +263,7 @@ def format_cond(
|
||||
_var_is_string=type(false_value) is str,
|
||||
)
|
||||
prop2._var_is_local = True
|
||||
prop1, prop2 = str(prop1), str(prop2) # avoid f-string semantics for Var
|
||||
return f"{cond} ? {prop1} : {prop2}".replace("{", "").replace("}", "")
|
||||
|
||||
# Format component conds.
|
||||
@ -517,6 +515,21 @@ def format_state(value: Any) -> Any:
|
||||
raise TypeError(f"No JSON serializer found for var {value} of type {type(value)}.")
|
||||
|
||||
|
||||
def format_state_name(state_name: str) -> str:
|
||||
"""Format a state name, replacing dots with double underscore.
|
||||
|
||||
This allows individual substates to be accessed independently as javascript vars
|
||||
without using dot notation.
|
||||
|
||||
Args:
|
||||
state_name: The state name to format.
|
||||
|
||||
Returns:
|
||||
The formatted state name.
|
||||
"""
|
||||
return state_name.replace(".", "__")
|
||||
|
||||
|
||||
def format_ref(ref: str) -> str:
|
||||
"""Format a ref.
|
||||
|
||||
|
@ -3,11 +3,9 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from collections import defaultdict
|
||||
from typing import Dict, List
|
||||
from typing import Dict, List, Optional
|
||||
|
||||
from reflex.vars import ImportVar
|
||||
|
||||
ImportDict = Dict[str, List[ImportVar]]
|
||||
from reflex.base import Base
|
||||
|
||||
|
||||
def merge_imports(*imports) -> ImportDict:
|
||||
@ -24,3 +22,42 @@ def merge_imports(*imports) -> ImportDict:
|
||||
for lib, fields in import_dict.items():
|
||||
all_imports[lib].extend(fields)
|
||||
return all_imports
|
||||
|
||||
|
||||
class ImportVar(Base):
|
||||
"""An import var."""
|
||||
|
||||
# The name of the import tag.
|
||||
tag: Optional[str]
|
||||
|
||||
# whether the import is default or named.
|
||||
is_default: Optional[bool] = False
|
||||
|
||||
# The tag alias.
|
||||
alias: Optional[str] = None
|
||||
|
||||
# Whether this import need to install the associated lib
|
||||
install: Optional[bool] = True
|
||||
|
||||
# whether this import should be rendered or not
|
||||
render: Optional[bool] = True
|
||||
|
||||
@property
|
||||
def name(self) -> str:
|
||||
"""The name of the import.
|
||||
|
||||
Returns:
|
||||
The name(tag name with alias) of tag.
|
||||
"""
|
||||
return self.tag if not self.alias else " as ".join([self.tag, self.alias]) # type: ignore
|
||||
|
||||
def __hash__(self) -> int:
|
||||
"""Define a hash function for the import var.
|
||||
|
||||
Returns:
|
||||
The hash of the var.
|
||||
"""
|
||||
return hash((self.tag, self.is_default, self.alias, self.install, self.render))
|
||||
|
||||
|
||||
ImportDict = Dict[str, List[ImportVar]]
|
||||
|
@ -27,6 +27,7 @@ from reflex.utils import serializers
|
||||
GenericType = Union[Type, _GenericAlias]
|
||||
|
||||
# Valid state var types.
|
||||
JSONType = {str, int, float, bool}
|
||||
PrimitiveType = Union[int, float, bool, str, list, dict, set, tuple]
|
||||
StateVar = Union[PrimitiveType, Base, None]
|
||||
StateIterVar = Union[list, set, tuple]
|
||||
|
418
reflex/vars.py
418
reflex/vars.py
@ -7,6 +7,7 @@ import dis
|
||||
import inspect
|
||||
import json
|
||||
import random
|
||||
import re
|
||||
import string
|
||||
import sys
|
||||
from types import CodeType, FunctionType
|
||||
@ -15,9 +16,11 @@ from typing import (
|
||||
Any,
|
||||
Callable,
|
||||
Dict,
|
||||
Iterable,
|
||||
List,
|
||||
Literal,
|
||||
Optional,
|
||||
Set,
|
||||
Tuple,
|
||||
Type,
|
||||
Union,
|
||||
@ -30,7 +33,10 @@ from typing import (
|
||||
|
||||
from reflex import constants
|
||||
from reflex.base import Base
|
||||
from reflex.utils import console, format, serializers, types
|
||||
from reflex.utils import console, format, imports, serializers, types
|
||||
|
||||
# This module used to export ImportVar itself, so we still import it for export here
|
||||
from reflex.utils.imports import ImportDict, ImportVar
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from reflex.state import State
|
||||
@ -71,7 +77,7 @@ OPERATION_MAPPING = {
|
||||
REPLACED_NAMES = {
|
||||
"full_name": "_var_full_name",
|
||||
"name": "_var_name",
|
||||
"state": "_var_state",
|
||||
"state": "_var_data.state",
|
||||
"type_": "_var_type",
|
||||
"is_local": "_var_is_local",
|
||||
"is_string": "_var_is_string",
|
||||
@ -93,6 +99,131 @@ def get_unique_variable_name() -> str:
|
||||
return get_unique_variable_name()
|
||||
|
||||
|
||||
class VarData(Base):
|
||||
"""Metadata associated with a Var."""
|
||||
|
||||
# The name of the enclosing state.
|
||||
state: str = ""
|
||||
|
||||
# Imports needed to render this var
|
||||
imports: ImportDict = {}
|
||||
|
||||
# Hooks that need to be present in the component to render this var
|
||||
hooks: Set[str] = set()
|
||||
|
||||
@classmethod
|
||||
def merge(cls, *others: VarData | None) -> VarData | None:
|
||||
"""Merge multiple var data objects.
|
||||
|
||||
Args:
|
||||
*others: The var data objects to merge.
|
||||
|
||||
Returns:
|
||||
The merged var data object.
|
||||
"""
|
||||
state = ""
|
||||
_imports = {}
|
||||
hooks = set()
|
||||
for var_data in others:
|
||||
if var_data is None:
|
||||
continue
|
||||
state = state or var_data.state
|
||||
_imports = imports.merge_imports(_imports, var_data.imports)
|
||||
hooks.update(var_data.hooks)
|
||||
return (
|
||||
cls(
|
||||
state=state,
|
||||
imports=_imports,
|
||||
hooks=hooks,
|
||||
)
|
||||
or None
|
||||
)
|
||||
|
||||
def __bool__(self) -> bool:
|
||||
"""Check if the var data is non-empty.
|
||||
|
||||
Returns:
|
||||
True if any field is set to a non-default value.
|
||||
"""
|
||||
return bool(self.state or self.imports or self.hooks)
|
||||
|
||||
def dict(self) -> dict:
|
||||
"""Convert the var data to a dictionary.
|
||||
|
||||
Returns:
|
||||
The var data dictionary.
|
||||
"""
|
||||
return {
|
||||
"state": self.state,
|
||||
"imports": {
|
||||
lib: [import_var.dict() for import_var in import_vars]
|
||||
for lib, import_vars in self.imports.items()
|
||||
},
|
||||
"hooks": list(self.hooks),
|
||||
}
|
||||
|
||||
|
||||
def _encode_var(value: Var) -> str:
|
||||
"""Encode the state name into a formatted var.
|
||||
|
||||
Args:
|
||||
value: The value to encode the state name into.
|
||||
|
||||
Returns:
|
||||
The encoded var.
|
||||
"""
|
||||
if value._var_data:
|
||||
return f"<reflex.Var>{value._var_data.json()}</reflex.Var>" + str(value)
|
||||
return str(value)
|
||||
|
||||
|
||||
def _decode_var(value: str) -> tuple[VarData | None, str]:
|
||||
"""Decode the state name from a formatted var.
|
||||
|
||||
Args:
|
||||
value: The value to extract the state name from.
|
||||
|
||||
Returns:
|
||||
The extracted state name and the value without the state name.
|
||||
"""
|
||||
var_datas = []
|
||||
if isinstance(value, str):
|
||||
# Extract the state name from a formatted var
|
||||
while m := re.match(r"(.*)<reflex.Var>(.*)</reflex.Var>(.*)", value):
|
||||
value = m.group(1) + m.group(3)
|
||||
var_datas.append(VarData.parse_raw(m.group(2)))
|
||||
if var_datas:
|
||||
return VarData.merge(*var_datas), value
|
||||
return None, value
|
||||
|
||||
|
||||
def _extract_var_data(value: Iterable) -> list[VarData | None]:
|
||||
"""Extract the var imports and hooks from an iterable containing a Var.
|
||||
|
||||
Args:
|
||||
value: The iterable to extract the VarData from
|
||||
|
||||
Returns:
|
||||
The extracted VarDatas.
|
||||
"""
|
||||
var_datas = []
|
||||
with contextlib.suppress(TypeError):
|
||||
for sub in value:
|
||||
if isinstance(sub, Var):
|
||||
var_datas.append(sub._var_data)
|
||||
elif not isinstance(sub, str):
|
||||
# Recurse into dict values.
|
||||
if hasattr(sub, "values") and callable(sub.values):
|
||||
var_datas.extend(_extract_var_data(sub.values()))
|
||||
# Recurse into iterable values (or dict keys).
|
||||
var_datas.extend(_extract_var_data(sub))
|
||||
# Recurse when value is a dict itself.
|
||||
values = getattr(value, "values", None)
|
||||
if callable(values):
|
||||
var_datas.extend(_extract_var_data(values()))
|
||||
return var_datas
|
||||
|
||||
|
||||
class Var:
|
||||
"""An abstract var."""
|
||||
|
||||
@ -102,15 +233,18 @@ class Var:
|
||||
# The type of the var.
|
||||
_var_type: Type
|
||||
|
||||
# The name of the enclosing state.
|
||||
_var_state: str
|
||||
|
||||
# Whether this is a local javascript variable.
|
||||
_var_is_local: bool
|
||||
|
||||
# Whether the var is a string literal.
|
||||
_var_is_string: bool
|
||||
|
||||
# _var_full_name should be prefixed with _var_state
|
||||
_var_full_name_needs_state_prefix: bool
|
||||
|
||||
# Extra metadata associated with the Var
|
||||
_var_data: Optional[VarData]
|
||||
|
||||
@classmethod
|
||||
def create(
|
||||
cls, value: Any, _var_is_local: bool = True, _var_is_string: bool = False
|
||||
@ -136,9 +270,14 @@ class Var:
|
||||
if isinstance(value, Var):
|
||||
return value
|
||||
|
||||
# Try to pull the imports and hooks from contained values.
|
||||
_var_data = None
|
||||
if not isinstance(value, str):
|
||||
_var_data = VarData.merge(*_extract_var_data(value))
|
||||
|
||||
# Try to serialize the value.
|
||||
type_ = type(value)
|
||||
name = serializers.serialize(value)
|
||||
name = value if type_ in types.JSONType else serializers.serialize(value)
|
||||
if name is None:
|
||||
raise TypeError(
|
||||
f"No JSON serializer found for var {value} of type {type_}."
|
||||
@ -150,6 +289,7 @@ class Var:
|
||||
_var_type=type_,
|
||||
_var_is_local=_var_is_local,
|
||||
_var_is_string=_var_is_string,
|
||||
_var_data=_var_data,
|
||||
)
|
||||
|
||||
@classmethod
|
||||
@ -186,6 +326,39 @@ class Var:
|
||||
"""
|
||||
return _GenericAlias(cls, type_)
|
||||
|
||||
def __post_init__(self) -> None:
|
||||
"""Post-initialize the var."""
|
||||
# Decode any inline Var markup and apply it to the instance
|
||||
_var_data, _var_name = _decode_var(self._var_name)
|
||||
if _var_data:
|
||||
self._var_name = _var_name
|
||||
self._var_data = VarData.merge(self._var_data, _var_data)
|
||||
|
||||
def _replace(self, merge_var_data=None, **kwargs: Any) -> Var:
|
||||
"""Make a copy of this Var with updated fields.
|
||||
|
||||
Args:
|
||||
merge_var_data: VarData to merge into the existing VarData.
|
||||
**kwargs: Var fields to update.
|
||||
|
||||
Returns:
|
||||
A new BaseVar with the updated fields overwriting the corresponding fields in this Var.
|
||||
"""
|
||||
field_values = dict(
|
||||
_var_name=kwargs.pop("_var_name", self._var_name),
|
||||
_var_type=kwargs.pop("_var_type", self._var_type),
|
||||
_var_is_local=kwargs.pop("_var_is_local", self._var_is_local),
|
||||
_var_is_string=kwargs.pop("_var_is_string", self._var_is_string),
|
||||
_var_full_name_needs_state_prefix=kwargs.pop(
|
||||
"_var_full_name_needs_state_prefix",
|
||||
self._var_full_name_needs_state_prefix,
|
||||
),
|
||||
_var_data=VarData.merge(
|
||||
kwargs.get("_var_data", self._var_data), merge_var_data
|
||||
),
|
||||
)
|
||||
return BaseVar(**field_values)
|
||||
|
||||
def _decode(self) -> Any:
|
||||
"""Decode Var as a python value.
|
||||
|
||||
@ -195,8 +368,6 @@ class Var:
|
||||
Returns:
|
||||
The decoded value or the Var name.
|
||||
"""
|
||||
if self._var_state:
|
||||
return self._var_full_name
|
||||
if self._var_is_string:
|
||||
return self._var_name
|
||||
try:
|
||||
@ -216,8 +387,10 @@ class Var:
|
||||
return (
|
||||
self._var_name == other._var_name
|
||||
and self._var_type == other._var_type
|
||||
and self._var_state == other._var_state
|
||||
and self._var_is_local == other._var_is_local
|
||||
and self._var_full_name_needs_state_prefix
|
||||
== other._var_full_name_needs_state_prefix
|
||||
and self._var_data == other._var_data
|
||||
)
|
||||
|
||||
def to_string(self, json: bool = True) -> Var:
|
||||
@ -285,9 +458,11 @@ class Var:
|
||||
Returns:
|
||||
The formatted var.
|
||||
"""
|
||||
# Encode the _var_data into the formatted output for tracking purposes.
|
||||
str_self = _encode_var(self)
|
||||
if self._var_is_local:
|
||||
return str(self)
|
||||
return f"${str(self)}"
|
||||
return str_self
|
||||
return f"${str_self}"
|
||||
|
||||
def __getitem__(self, i: Any) -> Var:
|
||||
"""Index into a var.
|
||||
@ -320,12 +495,7 @@ class Var:
|
||||
|
||||
# Convert any vars to local vars.
|
||||
if isinstance(i, Var):
|
||||
i = BaseVar(
|
||||
_var_name=i._var_name,
|
||||
_var_type=i._var_type,
|
||||
_var_state=i._var_state,
|
||||
_var_is_local=True,
|
||||
)
|
||||
i = i._replace(_var_is_local=True)
|
||||
|
||||
# Handle list/tuple/str indexing.
|
||||
if types._issubclass(self._var_type, Union[List, Tuple, str]):
|
||||
@ -344,11 +514,9 @@ class Var:
|
||||
stop = i.stop or "undefined"
|
||||
|
||||
# Use the slice function.
|
||||
return BaseVar(
|
||||
return self._replace(
|
||||
_var_name=f"{self._var_name}.slice({start}, {stop})",
|
||||
_var_type=self._var_type,
|
||||
_var_state=self._var_state,
|
||||
_var_is_local=self._var_is_local,
|
||||
_var_is_string=False,
|
||||
)
|
||||
|
||||
# Get the type of the indexed var.
|
||||
@ -359,11 +527,10 @@ class Var:
|
||||
)
|
||||
|
||||
# Use `at` to support negative indices.
|
||||
return BaseVar(
|
||||
return self._replace(
|
||||
_var_name=f"{self._var_name}.at({i})",
|
||||
_var_type=type_,
|
||||
_var_state=self._var_state,
|
||||
_var_is_local=self._var_is_local,
|
||||
_var_is_string=False,
|
||||
)
|
||||
|
||||
# Dictionary / dataframe indexing.
|
||||
@ -393,11 +560,10 @@ class Var:
|
||||
)
|
||||
|
||||
# Use normal indexing here.
|
||||
return BaseVar(
|
||||
return self._replace(
|
||||
_var_name=f"{self._var_name}[{i}]",
|
||||
_var_type=type_,
|
||||
_var_state=self._var_state,
|
||||
_var_is_local=self._var_is_local,
|
||||
_var_is_string=False,
|
||||
)
|
||||
|
||||
def __getattr__(self, name: str) -> Var:
|
||||
@ -423,11 +589,10 @@ class Var:
|
||||
type_ = types.get_attribute_access_type(self._var_type, name)
|
||||
|
||||
if type_ is not None:
|
||||
return BaseVar(
|
||||
return self._replace(
|
||||
_var_name=f"{self._var_name}{'?' if is_optional else ''}.{name}",
|
||||
_var_type=type_,
|
||||
_var_state=self._var_state,
|
||||
_var_is_local=self._var_is_local,
|
||||
_var_is_string=False,
|
||||
)
|
||||
|
||||
if name in REPLACED_NAMES:
|
||||
@ -519,10 +684,12 @@ class Var:
|
||||
else f"{self._var_full_name}.{fn}()"
|
||||
)
|
||||
|
||||
return BaseVar(
|
||||
return self._replace(
|
||||
_var_name=operation_name,
|
||||
_var_type=type_,
|
||||
_var_is_local=self._var_is_local,
|
||||
_var_is_string=False,
|
||||
_var_full_name_needs_state_prefix=False,
|
||||
merge_var_data=other._var_data if other is not None else None,
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
@ -602,10 +769,10 @@ class Var:
|
||||
"""
|
||||
if not types._issubclass(self._var_type, List):
|
||||
raise TypeError(f"Cannot get length of non-list var {self}.")
|
||||
return BaseVar(
|
||||
_var_name=f"{self._var_full_name}.length",
|
||||
return self._replace(
|
||||
_var_name=f"{self._var_name}.length",
|
||||
_var_type=int,
|
||||
_var_is_local=self._var_is_local,
|
||||
_var_is_string=False,
|
||||
)
|
||||
|
||||
def __eq__(self, other: Var) -> Var:
|
||||
@ -692,7 +859,17 @@ class Var:
|
||||
types.get_base_class(self._var_type) == list
|
||||
and types.get_base_class(other_type) == list
|
||||
):
|
||||
return self.operation(",", other, fn="spreadArraysOrObjects", flip=flip)
|
||||
return self.operation(
|
||||
",", other, fn="spreadArraysOrObjects", flip=flip
|
||||
)._replace(
|
||||
merge_var_data=VarData(
|
||||
imports={
|
||||
f"/{constants.Dirs.STATE_PATH}": [
|
||||
ImportVar(tag="spreadArraysOrObjects")
|
||||
]
|
||||
},
|
||||
),
|
||||
)
|
||||
return self.operation("+", other, flip=flip)
|
||||
|
||||
def __radd__(self, other: Var) -> Var:
|
||||
@ -755,10 +932,11 @@ class Var:
|
||||
]:
|
||||
other_name = other._var_full_name if isinstance(other, Var) else other
|
||||
name = f"Array({other_name}).fill().map(() => {self._var_full_name}).flat()"
|
||||
return BaseVar(
|
||||
return self._replace(
|
||||
_var_name=name,
|
||||
_var_type=str,
|
||||
_var_is_local=self._var_is_local,
|
||||
_var_is_string=False,
|
||||
_var_full_name_needs_state_prefix=False,
|
||||
)
|
||||
|
||||
return self.operation("*", other)
|
||||
@ -1003,10 +1181,11 @@ class Var:
|
||||
elif not isinstance(other, Var):
|
||||
other = Var.create(other)
|
||||
if types._issubclass(self._var_type, Dict):
|
||||
return BaseVar(
|
||||
_var_name=f"{self._var_full_name}.{method}({other._var_full_name})",
|
||||
return self._replace(
|
||||
_var_name=f"{self._var_name}.{method}({other._var_full_name})",
|
||||
_var_type=bool,
|
||||
_var_is_local=self._var_is_local,
|
||||
_var_is_string=False,
|
||||
merge_var_data=other._var_data,
|
||||
)
|
||||
else: # str, list, tuple
|
||||
# For strings, the left operand must be a string.
|
||||
@ -1016,10 +1195,11 @@ class Var:
|
||||
raise TypeError(
|
||||
f"'in <string>' requires string as left operand, not {other._var_type}"
|
||||
)
|
||||
return BaseVar(
|
||||
_var_name=f"{self._var_full_name}.includes({other._var_full_name})",
|
||||
return self._replace(
|
||||
_var_name=f"{self._var_name}.includes({other._var_full_name})",
|
||||
_var_type=bool,
|
||||
_var_is_local=self._var_is_local,
|
||||
_var_is_string=False,
|
||||
merge_var_data=other._var_data,
|
||||
)
|
||||
|
||||
def reverse(self) -> Var:
|
||||
@ -1034,10 +1214,10 @@ class Var:
|
||||
if not types._issubclass(self._var_type, list):
|
||||
raise TypeError(f"Cannot reverse non-list var {self._var_full_name}.")
|
||||
|
||||
return BaseVar(
|
||||
return self._replace(
|
||||
_var_name=f"[...{self._var_full_name}].reverse()",
|
||||
_var_type=self._var_type,
|
||||
_var_is_local=self._var_is_local,
|
||||
_var_is_string=False,
|
||||
_var_full_name_needs_state_prefix=False,
|
||||
)
|
||||
|
||||
def lower(self) -> Var:
|
||||
@ -1054,10 +1234,10 @@ class Var:
|
||||
f"Cannot convert non-string var {self._var_full_name} to lowercase."
|
||||
)
|
||||
|
||||
return BaseVar(
|
||||
_var_name=f"{self._var_full_name}.toLowerCase()",
|
||||
return self._replace(
|
||||
_var_name=f"{self._var_name}.toLowerCase()",
|
||||
_var_is_string=False,
|
||||
_var_type=str,
|
||||
_var_is_local=self._var_is_local,
|
||||
)
|
||||
|
||||
def upper(self) -> Var:
|
||||
@ -1074,10 +1254,10 @@ class Var:
|
||||
f"Cannot convert non-string var {self._var_full_name} to uppercase."
|
||||
)
|
||||
|
||||
return BaseVar(
|
||||
_var_name=f"{self._var_full_name}.toUpperCase()",
|
||||
return self._replace(
|
||||
_var_name=f"{self._var_name}.toUpperCase()",
|
||||
_var_is_string=False,
|
||||
_var_type=str,
|
||||
_var_is_local=self._var_is_local,
|
||||
)
|
||||
|
||||
def split(self, other: str | Var[str] = " ") -> Var:
|
||||
@ -1097,10 +1277,11 @@ class Var:
|
||||
|
||||
other = Var.create_safe(json.dumps(other)) if isinstance(other, str) else other
|
||||
|
||||
return BaseVar(
|
||||
_var_name=f"{self._var_full_name}.split({other._var_full_name})",
|
||||
return self._replace(
|
||||
_var_name=f"{self._var_name}.split({other._var_full_name})",
|
||||
_var_is_string=False,
|
||||
_var_type=list[str],
|
||||
_var_is_local=self._var_is_local,
|
||||
merge_var_data=other._var_data,
|
||||
)
|
||||
|
||||
def join(self, other: str | Var[str] | None = None) -> Var:
|
||||
@ -1125,10 +1306,11 @@ class Var:
|
||||
else:
|
||||
other = Var.create_safe(other)
|
||||
|
||||
return BaseVar(
|
||||
_var_name=f"{self._var_full_name}.join({other._var_full_name})",
|
||||
return self._replace(
|
||||
_var_name=f"{self._var_name}.join({other._var_full_name})",
|
||||
_var_is_string=False,
|
||||
_var_type=str,
|
||||
_var_is_local=self._var_is_local,
|
||||
merge_var_data=other._var_data,
|
||||
)
|
||||
|
||||
def foreach(self, fn: Callable) -> Var:
|
||||
@ -1159,10 +1341,9 @@ class Var:
|
||||
fn_signature = inspect.signature(fn)
|
||||
fn_args = (arg, index)
|
||||
fn_ret = fn(*fn_args[: len(fn_signature.parameters)])
|
||||
return BaseVar(
|
||||
return self._replace(
|
||||
_var_name=f"{self._var_full_name}.map(({arg._var_name}, {index._var_name}) => {fn_ret})",
|
||||
_var_type=self._var_type,
|
||||
_var_is_local=self._var_is_local,
|
||||
_var_is_string=False,
|
||||
)
|
||||
|
||||
@classmethod
|
||||
@ -1207,6 +1388,18 @@ class Var:
|
||||
_var_name=f"Array.from(range({v1._var_full_name}, {v2._var_full_name}, {step._var_name}))",
|
||||
_var_type=list[int],
|
||||
_var_is_local=False,
|
||||
_var_data=VarData.merge(
|
||||
v1._var_data,
|
||||
v2._var_data,
|
||||
step._var_data,
|
||||
VarData(
|
||||
imports={
|
||||
"/utils/helpers/range.js": [
|
||||
ImportVar(tag="range", is_default=True),
|
||||
],
|
||||
},
|
||||
),
|
||||
),
|
||||
)
|
||||
|
||||
def to(self, type_: Type) -> Var:
|
||||
@ -1218,12 +1411,7 @@ class Var:
|
||||
Returns:
|
||||
The converted var.
|
||||
"""
|
||||
return BaseVar(
|
||||
_var_name=self._var_name,
|
||||
_var_type=type_,
|
||||
_var_state=self._var_state,
|
||||
_var_is_local=self._var_is_local,
|
||||
)
|
||||
return self._replace(_var_type=type_)
|
||||
|
||||
@property
|
||||
def _var_full_name(self) -> str:
|
||||
@ -1232,24 +1420,51 @@ class Var:
|
||||
Returns:
|
||||
The full name of the var.
|
||||
"""
|
||||
if not self._var_full_name_needs_state_prefix:
|
||||
return self._var_name
|
||||
return (
|
||||
self._var_name
|
||||
if self._var_state == ""
|
||||
else ".".join([self._var_state, self._var_name])
|
||||
if self._var_data is None or self._var_data.state == ""
|
||||
else ".".join(
|
||||
[format.format_state_name(self._var_data.state), self._var_name]
|
||||
)
|
||||
)
|
||||
|
||||
def _var_set_state(self, state: Type[State]) -> Any:
|
||||
def _var_set_state(self, state: Type[State] | str) -> Any:
|
||||
"""Set the state of the var.
|
||||
|
||||
Args:
|
||||
state: The state to set.
|
||||
state: The state to set or the full name of the state.
|
||||
|
||||
Returns:
|
||||
The var with the set state.
|
||||
"""
|
||||
self._var_state = state.get_full_name()
|
||||
state_name = state if isinstance(state, str) else state.get_full_name()
|
||||
new_var_data = VarData(
|
||||
state=state_name,
|
||||
hooks={
|
||||
"const {0} = useContext(StateContexts.{0})".format(
|
||||
format.format_state_name(state_name)
|
||||
)
|
||||
},
|
||||
imports={
|
||||
f"/{constants.Dirs.CONTEXTS_PATH}": [ImportVar(tag="StateContexts")],
|
||||
"react": [ImportVar(tag="useContext")],
|
||||
},
|
||||
)
|
||||
self._var_data = VarData.merge(self._var_data, new_var_data)
|
||||
self._var_full_name_needs_state_prefix = True
|
||||
return self
|
||||
|
||||
@property
|
||||
def _var_state(self) -> str:
|
||||
"""Compat method for getting the state.
|
||||
|
||||
Returns:
|
||||
The state name associated with the var.
|
||||
"""
|
||||
return self._var_data.state if self._var_data else ""
|
||||
|
||||
|
||||
@dataclasses.dataclass(
|
||||
eq=False,
|
||||
@ -1264,15 +1479,18 @@ class BaseVar(Var):
|
||||
# The type of the var.
|
||||
_var_type: Type = dataclasses.field(default=Any)
|
||||
|
||||
# The name of the enclosing state.
|
||||
_var_state: str = dataclasses.field(default="")
|
||||
|
||||
# Whether this is a local javascript variable.
|
||||
_var_is_local: bool = dataclasses.field(default=False)
|
||||
|
||||
# Whether the var is a string literal.
|
||||
_var_is_string: bool = dataclasses.field(default=False)
|
||||
|
||||
# _var_full_name should be prefixed with _var_state
|
||||
_var_full_name_needs_state_prefix: bool = dataclasses.field(default=False)
|
||||
|
||||
# Extra metadata associated with the Var
|
||||
_var_data: Optional[VarData] = dataclasses.field(default=None)
|
||||
|
||||
def __hash__(self) -> int:
|
||||
"""Define a hash function for a var.
|
||||
|
||||
@ -1334,9 +1552,11 @@ class BaseVar(Var):
|
||||
The name of the setter function.
|
||||
"""
|
||||
setter = constants.SETTER_PREFIX + self._var_name
|
||||
if not include_state or self._var_state == "":
|
||||
if self._var_data is None:
|
||||
return setter
|
||||
return ".".join((self._var_state, setter))
|
||||
if not include_state or self._var_data.state == "":
|
||||
return setter
|
||||
return ".".join((self._var_data.state, setter))
|
||||
|
||||
def get_setter(self) -> Callable[[State, Any], None]:
|
||||
"""Get the var's setter function.
|
||||
@ -1550,48 +1770,6 @@ def cached_var(fget: Callable[[Any], Any]) -> ComputedVar:
|
||||
return cvar
|
||||
|
||||
|
||||
class ImportVar(Base):
|
||||
"""An import var."""
|
||||
|
||||
# The name of the import tag.
|
||||
tag: Optional[str]
|
||||
|
||||
# whether the import is default or named.
|
||||
is_default: Optional[bool] = False
|
||||
|
||||
# The tag alias.
|
||||
alias: Optional[str] = None
|
||||
|
||||
# Whether this import need to install the associated lib
|
||||
install: Optional[bool] = True
|
||||
|
||||
# whether this import should be rendered or not
|
||||
render: Optional[bool] = True
|
||||
|
||||
@property
|
||||
def name(self) -> str:
|
||||
"""The name of the import.
|
||||
|
||||
Returns:
|
||||
The name(tag name with alias) of tag.
|
||||
"""
|
||||
return self.tag if not self.alias else " as ".join([self.tag, self.alias]) # type: ignore
|
||||
|
||||
def __hash__(self) -> int:
|
||||
"""Define a hash function for the import var.
|
||||
|
||||
Returns:
|
||||
The hash of the var.
|
||||
"""
|
||||
return hash((self.tag, self.is_default, self.alias, self.install, self.render))
|
||||
|
||||
|
||||
class NoRenderImportVar(ImportVar):
|
||||
"""A import that doesn't need to be rendered."""
|
||||
|
||||
render: Optional[bool] = False
|
||||
|
||||
|
||||
class CallableVar(BaseVar):
|
||||
"""Decorate a Var-returning function to act as both a Var and a function.
|
||||
|
||||
|
@ -6,11 +6,13 @@ from reflex import constants as constants
|
||||
from reflex.base import Base as Base
|
||||
from reflex.state import State as State
|
||||
from reflex.utils import console as console, format as format, types as types
|
||||
from reflex.utils.imports import ImportVar
|
||||
from types import FunctionType
|
||||
from typing import (
|
||||
Any,
|
||||
Callable,
|
||||
Dict,
|
||||
Iterable,
|
||||
List,
|
||||
Optional,
|
||||
Set,
|
||||
@ -22,13 +24,24 @@ from typing import (
|
||||
USED_VARIABLES: Incomplete
|
||||
|
||||
def get_unique_variable_name() -> str: ...
|
||||
def _encode_var(value: Var) -> str: ...
|
||||
def _decode_var(value: str) -> tuple[VarData, str]: ...
|
||||
def _extract_var_data(value: Iterable) -> list[VarData | None]: ...
|
||||
|
||||
class VarData(Base):
|
||||
state: str
|
||||
imports: dict[str, set[ImportVar]]
|
||||
hooks: set[str]
|
||||
@classmethod
|
||||
def merge(cls, *others: VarData | None) -> VarData | None: ...
|
||||
|
||||
class Var:
|
||||
_var_name: str
|
||||
_var_type: Type
|
||||
_var_state: str = ""
|
||||
_var_is_local: bool = False
|
||||
_var_is_string: bool = False
|
||||
_var_full_name_needs_state_prefix: bool = False
|
||||
_var_data: VarData | None = None
|
||||
@classmethod
|
||||
def create(
|
||||
cls, value: Any, _var_is_local: bool = False, _var_is_string: bool = False
|
||||
@ -38,7 +51,8 @@ class Var:
|
||||
cls, value: Any, _var_is_local: bool = False, _var_is_string: bool = False
|
||||
) -> Var: ...
|
||||
@classmethod
|
||||
def __class_getitem__(cls, type_: str) -> _GenericAlias: ...
|
||||
def __class_getitem__(cls, type_: Type) -> _GenericAlias: ...
|
||||
def _replace(self, merge_var_data=None, **kwargs: Any) -> Var: ...
|
||||
def equals(self, other: Var) -> bool: ...
|
||||
def to_string(self) -> Var: ...
|
||||
def __hash__(self) -> int: ...
|
||||
@ -95,15 +109,16 @@ class Var:
|
||||
def to(self, type_: Type) -> Var: ...
|
||||
@property
|
||||
def _var_full_name(self) -> str: ...
|
||||
def _var_set_state(self, state: Type[State]) -> Any: ...
|
||||
def _var_set_state(self, state: Type[State] | str) -> Any: ...
|
||||
|
||||
@dataclass(eq=False)
|
||||
class BaseVar(Var):
|
||||
_var_name: str
|
||||
_var_type: Any
|
||||
_var_state: str = ""
|
||||
_var_is_local: bool = False
|
||||
_var_is_string: bool = False
|
||||
_var_full_name_needs_state_prefix: bool = False
|
||||
_var_data: VarData | None = None
|
||||
def __hash__(self) -> int: ...
|
||||
def get_default_value(self) -> Any: ...
|
||||
def get_setter_name(self, include_state: bool = ...) -> str: ...
|
||||
@ -123,21 +138,6 @@ class ComputedVar(Var):
|
||||
|
||||
def cached_var(fget: Callable[[Any], Any]) -> ComputedVar: ...
|
||||
|
||||
class ImportVar(Base):
|
||||
tag: Optional[str]
|
||||
is_default: Optional[bool] = False
|
||||
alias: Optional[str] = None
|
||||
install: Optional[bool] = True
|
||||
render: Optional[bool] = True
|
||||
@property
|
||||
def name(self) -> str: ...
|
||||
def __hash__(self) -> int: ...
|
||||
|
||||
class NoRenderImportVar(ImportVar):
|
||||
"""A import that doesn't need to be rendered."""
|
||||
|
||||
def get_local_storage(key: Optional[Union[Var, str]] = ...) -> BaseVar: ...
|
||||
|
||||
class CallableVar(BaseVar):
|
||||
def __init__(self, fn: Callable[..., BaseVar]): ...
|
||||
def __call__(self, *args, **kwargs) -> BaseVar: ...
|
||||
|
@ -5,7 +5,7 @@ import pytest
|
||||
|
||||
from reflex.compiler import compiler, utils
|
||||
from reflex.utils import imports
|
||||
from reflex.vars import ImportVar
|
||||
from reflex.utils.imports import ImportVar
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
|
@ -110,7 +110,7 @@ def test_cond_no_else():
|
||||
|
||||
# Props do not support the use of cond without else
|
||||
with pytest.raises(ValueError):
|
||||
cond(True, "hello")
|
||||
cond(True, "hello") # type: ignore
|
||||
|
||||
|
||||
def test_mobile_only():
|
||||
|
@ -4,14 +4,16 @@ import pytest
|
||||
|
||||
import reflex as rx
|
||||
from reflex.base import Base
|
||||
from reflex.components.base.bare import Bare
|
||||
from reflex.components.component import Component, CustomComponent, custom_component
|
||||
from reflex.components.layout.box import Box
|
||||
from reflex.constants import EventTriggers
|
||||
from reflex.event import EventHandler
|
||||
from reflex.event import EventChain, EventHandler
|
||||
from reflex.state import State
|
||||
from reflex.style import Style
|
||||
from reflex.utils import imports
|
||||
from reflex.vars import ImportVar, Var
|
||||
from reflex.utils.imports import ImportVar
|
||||
from reflex.vars import Var, VarData
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
@ -600,3 +602,161 @@ def test_format_component(component, rendered):
|
||||
rendered: The expected rendered component.
|
||||
"""
|
||||
assert str(component) == rendered
|
||||
|
||||
|
||||
TEST_VAR = Var.create_safe("test")._replace(
|
||||
merge_var_data=VarData(
|
||||
hooks={"useTest"}, imports={"test": {ImportVar(tag="test")}}, state="Test"
|
||||
)
|
||||
)
|
||||
FORMATTED_TEST_VAR = Var.create(f"foo{TEST_VAR}bar")
|
||||
STYLE_VAR = TEST_VAR._replace(_var_name="style", _var_is_local=False)
|
||||
EVENT_CHAIN_VAR = TEST_VAR._replace(_var_type=EventChain)
|
||||
ARG_VAR = Var.create("arg")
|
||||
|
||||
|
||||
class EventState(rx.State):
|
||||
"""State for testing event handlers with _get_vars."""
|
||||
|
||||
v: int = 42
|
||||
|
||||
def handler(self):
|
||||
"""A handler that does nothing."""
|
||||
|
||||
def handler2(self, arg):
|
||||
"""A handler that takes an arg.
|
||||
|
||||
Args:
|
||||
arg: An arg.
|
||||
"""
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
("component", "exp_vars"),
|
||||
(
|
||||
pytest.param(
|
||||
Bare.create(TEST_VAR),
|
||||
[TEST_VAR],
|
||||
id="direct-bare",
|
||||
),
|
||||
pytest.param(
|
||||
Bare.create(f"foo{TEST_VAR}bar"),
|
||||
[FORMATTED_TEST_VAR],
|
||||
id="fstring-bare",
|
||||
),
|
||||
pytest.param(
|
||||
rx.text(as_=TEST_VAR),
|
||||
[TEST_VAR],
|
||||
id="direct-prop",
|
||||
),
|
||||
pytest.param(
|
||||
rx.text(as_=f"foo{TEST_VAR}bar"),
|
||||
[FORMATTED_TEST_VAR],
|
||||
id="fstring-prop",
|
||||
),
|
||||
pytest.param(
|
||||
rx.fragment(id=TEST_VAR),
|
||||
[TEST_VAR],
|
||||
id="direct-id",
|
||||
),
|
||||
pytest.param(
|
||||
rx.fragment(id=f"foo{TEST_VAR}bar"),
|
||||
[FORMATTED_TEST_VAR],
|
||||
id="fstring-id",
|
||||
),
|
||||
pytest.param(
|
||||
rx.fragment(key=TEST_VAR),
|
||||
[TEST_VAR],
|
||||
id="direct-key",
|
||||
),
|
||||
pytest.param(
|
||||
rx.fragment(key=f"foo{TEST_VAR}bar"),
|
||||
[FORMATTED_TEST_VAR],
|
||||
id="fstring-key",
|
||||
),
|
||||
pytest.param(
|
||||
rx.fragment(class_name=TEST_VAR),
|
||||
[TEST_VAR],
|
||||
id="direct-class_name",
|
||||
),
|
||||
pytest.param(
|
||||
rx.fragment(class_name=f"foo{TEST_VAR}bar"),
|
||||
[FORMATTED_TEST_VAR],
|
||||
id="fstring-class_name",
|
||||
),
|
||||
pytest.param(
|
||||
rx.fragment(special_props={TEST_VAR}),
|
||||
[TEST_VAR],
|
||||
id="direct-special_props",
|
||||
),
|
||||
pytest.param(
|
||||
rx.fragment(special_props={Var.create(f"foo{TEST_VAR}bar")}),
|
||||
[FORMATTED_TEST_VAR],
|
||||
id="fstring-special_props",
|
||||
),
|
||||
pytest.param(
|
||||
# custom_attrs cannot accept a Var directly as a value
|
||||
rx.fragment(custom_attrs={"href": f"{TEST_VAR}"}),
|
||||
[TEST_VAR],
|
||||
id="fstring-custom_attrs-nofmt",
|
||||
),
|
||||
pytest.param(
|
||||
rx.fragment(custom_attrs={"href": f"foo{TEST_VAR}bar"}),
|
||||
[FORMATTED_TEST_VAR],
|
||||
id="fstring-custom_attrs",
|
||||
),
|
||||
pytest.param(
|
||||
rx.fragment(background_color=TEST_VAR),
|
||||
[STYLE_VAR],
|
||||
id="direct-background_color",
|
||||
),
|
||||
pytest.param(
|
||||
rx.fragment(background_color=f"foo{TEST_VAR}bar"),
|
||||
[STYLE_VAR],
|
||||
id="fstring-background_color",
|
||||
),
|
||||
pytest.param(
|
||||
rx.fragment(style={"background_color": TEST_VAR}), # type: ignore
|
||||
[STYLE_VAR],
|
||||
id="direct-style-background_color",
|
||||
),
|
||||
pytest.param(
|
||||
rx.fragment(style={"background_color": f"foo{TEST_VAR}bar"}), # type: ignore
|
||||
[STYLE_VAR],
|
||||
id="fstring-style-background_color",
|
||||
),
|
||||
pytest.param(
|
||||
rx.fragment(on_click=EVENT_CHAIN_VAR), # type: ignore
|
||||
[EVENT_CHAIN_VAR],
|
||||
id="direct-event-chain",
|
||||
),
|
||||
pytest.param(
|
||||
rx.fragment(on_click=EventState.handler),
|
||||
[],
|
||||
id="direct-event-handler",
|
||||
),
|
||||
pytest.param(
|
||||
rx.fragment(on_click=EventState.handler2(TEST_VAR)), # type: ignore
|
||||
[ARG_VAR, TEST_VAR],
|
||||
id="direct-event-handler-arg",
|
||||
),
|
||||
pytest.param(
|
||||
rx.fragment(on_click=EventState.handler2(EventState.v)), # type: ignore
|
||||
[ARG_VAR, EventState.v],
|
||||
id="direct-event-handler-arg2",
|
||||
),
|
||||
pytest.param(
|
||||
rx.fragment(on_click=lambda: EventState.handler2(TEST_VAR)), # type: ignore
|
||||
[ARG_VAR, TEST_VAR],
|
||||
id="direct-event-handler-lambda",
|
||||
),
|
||||
),
|
||||
)
|
||||
def test_get_vars(component, exp_vars):
|
||||
comp_vars = sorted(component._get_vars(), key=lambda v: v._var_name)
|
||||
assert len(comp_vars) == len(exp_vars)
|
||||
for comp_var, exp_var in zip(
|
||||
comp_vars,
|
||||
sorted(exp_vars, key=lambda v: v._var_name),
|
||||
):
|
||||
assert comp_var.equals(exp_var)
|
||||
|
@ -104,7 +104,7 @@ async def test_preprocess(
|
||||
app=app, event=request.getfixturevalue(event_fixture), state=state
|
||||
)
|
||||
assert isinstance(update, StateUpdate)
|
||||
assert update.delta == {state.get_name(): state.dict()}
|
||||
assert update.delta == state.dict()
|
||||
events = update.events
|
||||
assert len(events) == 2
|
||||
|
||||
@ -133,7 +133,7 @@ async def test_preprocess_multiple_load_events(hydrate_middleware, event1):
|
||||
|
||||
update = await hydrate_middleware.preprocess(app=app, event=event1, state=state)
|
||||
assert isinstance(update, StateUpdate)
|
||||
assert update.delta == {"test_state": state.dict()}
|
||||
assert update.delta == state.dict()
|
||||
assert len(update.events) == 3
|
||||
|
||||
# Apply the events.
|
||||
@ -163,7 +163,7 @@ async def test_preprocess_no_events(hydrate_middleware, event1):
|
||||
state=state,
|
||||
)
|
||||
assert isinstance(update, StateUpdate)
|
||||
assert update.delta == {"test_state": state.dict()}
|
||||
assert update.delta == state.dict()
|
||||
assert len(update.events) == 1
|
||||
assert isinstance(update, StateUpdate)
|
||||
|
||||
|
@ -769,9 +769,7 @@ async def test_upload_file(tmp_path, state, delta, token: str):
|
||||
)
|
||||
|
||||
current_state = await app.state_manager.get_state(token)
|
||||
state_dict = current_state.dict()
|
||||
for substate in state.get_full_name().split(".")[1:]:
|
||||
state_dict = state_dict[substate]
|
||||
state_dict = current_state.dict()[state.get_full_name()]
|
||||
assert state_dict["img_list"] == [
|
||||
"image1.jpg",
|
||||
"image2.jpg",
|
||||
|
@ -324,11 +324,17 @@ def test_dict(test_state):
|
||||
Args:
|
||||
test_state: A state.
|
||||
"""
|
||||
substates = {"child_state", "child_state2"}
|
||||
assert set(test_state.dict().keys()) == set(test_state.vars.keys()) | substates
|
||||
assert (
|
||||
set(test_state.dict(include_computed=False).keys())
|
||||
== set(test_state.base_vars) | substates
|
||||
substates = {
|
||||
"test_state",
|
||||
"test_state.child_state",
|
||||
"test_state.child_state.grandchild_state",
|
||||
"test_state.child_state2",
|
||||
}
|
||||
test_state_dict = test_state.dict()
|
||||
assert set(test_state_dict) == substates
|
||||
assert set(test_state_dict[test_state.get_name()]) == set(test_state.vars)
|
||||
assert set(test_state.dict(include_computed=False)[test_state.get_name()]) == set(
|
||||
test_state.base_vars
|
||||
)
|
||||
|
||||
|
||||
@ -1081,9 +1087,9 @@ def test_computed_var_cached():
|
||||
return self.v
|
||||
|
||||
cs = ComputedState()
|
||||
assert cs.dict()["v"] == 0
|
||||
assert cs.dict()[cs.get_full_name()]["v"] == 0
|
||||
assert comp_v_calls == 1
|
||||
assert cs.dict()["comp_v"] == 0
|
||||
assert cs.dict()[cs.get_full_name()]["comp_v"] == 0
|
||||
assert comp_v_calls == 1
|
||||
assert cs.comp_v == 0
|
||||
assert comp_v_calls == 1
|
||||
@ -1156,24 +1162,27 @@ def test_computed_var_depends_on_parent_non_cached():
|
||||
assert ps.dirty_vars == set()
|
||||
assert cs.dirty_vars == set()
|
||||
|
||||
assert ps.dict() == {
|
||||
cs.get_name(): {"dep_v": 2},
|
||||
dict1 = ps.dict()
|
||||
assert dict1[ps.get_full_name()] == {
|
||||
"no_cache_v": 1,
|
||||
CompileVars.IS_HYDRATED: False,
|
||||
"router": formatted_router,
|
||||
}
|
||||
assert ps.dict() == {
|
||||
cs.get_name(): {"dep_v": 4},
|
||||
assert dict1[cs.get_full_name()] == {"dep_v": 2}
|
||||
dict2 = ps.dict()
|
||||
assert dict2[ps.get_full_name()] == {
|
||||
"no_cache_v": 3,
|
||||
CompileVars.IS_HYDRATED: False,
|
||||
"router": formatted_router,
|
||||
}
|
||||
assert ps.dict() == {
|
||||
cs.get_name(): {"dep_v": 6},
|
||||
assert dict2[cs.get_full_name()] == {"dep_v": 4}
|
||||
dict3 = ps.dict()
|
||||
assert dict3[ps.get_full_name()] == {
|
||||
"no_cache_v": 5,
|
||||
CompileVars.IS_HYDRATED: False,
|
||||
"router": formatted_router,
|
||||
}
|
||||
assert dict3[cs.get_full_name()] == {"dep_v": 6}
|
||||
assert counter == 6
|
||||
|
||||
|
||||
@ -2201,13 +2210,13 @@ def test_json_dumps_with_mutables():
|
||||
items: List[Foo] = [Foo()]
|
||||
|
||||
dict_val = MutableContainsBase().dict()
|
||||
assert isinstance(dict_val["items"][0], dict)
|
||||
assert isinstance(dict_val[MutableContainsBase.get_full_name()]["items"][0], dict)
|
||||
val = json_dumps(dict_val)
|
||||
f_items = '[{"tags": ["123", "456"]}]'
|
||||
f_formatted_router = str(formatted_router).replace("'", '"')
|
||||
assert (
|
||||
val
|
||||
== f'{{"is_hydrated": false, "items": {f_items}, "router": {f_formatted_router}}}'
|
||||
== f'{{"{MutableContainsBase.get_full_name()}": {{"is_hydrated": false, "items": {f_items}, "router": {f_formatted_router}}}}}'
|
||||
)
|
||||
|
||||
|
||||
|
@ -22,7 +22,8 @@ def test_convert(style_dict, expected):
|
||||
style_dict: The style to check.
|
||||
expected: The expected formatted style.
|
||||
"""
|
||||
assert style.convert(style_dict) == expected
|
||||
converted_dict, _var_data = style.convert(style_dict)
|
||||
assert converted_dict == expected
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
|
@ -7,18 +7,20 @@ from pandas import DataFrame
|
||||
|
||||
from reflex.base import Base
|
||||
from reflex.state import State
|
||||
from reflex.utils.imports import ImportVar
|
||||
from reflex.vars import (
|
||||
BaseVar,
|
||||
ComputedVar,
|
||||
ImportVar,
|
||||
Var,
|
||||
)
|
||||
|
||||
test_vars = [
|
||||
BaseVar(_var_name="prop1", _var_type=int),
|
||||
BaseVar(_var_name="key", _var_type=str),
|
||||
BaseVar(_var_name="value", _var_type=str, _var_state="state"),
|
||||
BaseVar(_var_name="local", _var_type=str, _var_state="state", _var_is_local=True),
|
||||
BaseVar(_var_name="value", _var_type=str)._var_set_state("state"),
|
||||
BaseVar(_var_name="local", _var_type=str, _var_is_local=True)._var_set_state(
|
||||
"state"
|
||||
),
|
||||
BaseVar(_var_name="local2", _var_type=str, _var_is_local=True),
|
||||
]
|
||||
|
||||
@ -263,7 +265,7 @@ def test_basic_operations(TestObj):
|
||||
assert str(v([1, 2, 3])[v(0)]) == "{[1, 2, 3].at(0)}"
|
||||
assert str(v({"a": 1, "b": 2})["a"]) == '{{"a": 1, "b": 2}["a"]}'
|
||||
assert (
|
||||
str(BaseVar(_var_name="foo", _var_state="state", _var_type=TestObj).bar)
|
||||
str(BaseVar(_var_name="foo", _var_type=TestObj)._var_set_state("state").bar)
|
||||
== "{state.foo.bar}"
|
||||
)
|
||||
assert str(abs(v(1))) == "{Math.abs(1)}"
|
||||
@ -274,7 +276,7 @@ def test_basic_operations(TestObj):
|
||||
assert str(v([1, 2, 3]).reverse()) == "{[...[1, 2, 3]].reverse()}"
|
||||
assert str(v(["1", "2", "3"]).reverse()) == '{[...["1", "2", "3"]].reverse()}'
|
||||
assert (
|
||||
str(BaseVar(_var_name="foo", _var_state="state", _var_type=list).reverse())
|
||||
str(BaseVar(_var_name="foo", _var_type=list)._var_set_state("state").reverse())
|
||||
== "{[...state.foo].reverse()}"
|
||||
)
|
||||
assert (
|
||||
@ -288,11 +290,14 @@ def test_basic_operations(TestObj):
|
||||
[
|
||||
(v([1, 2, 3]), "[1, 2, 3]"),
|
||||
(v(["1", "2", "3"]), '["1", "2", "3"]'),
|
||||
(BaseVar(_var_name="foo", _var_state="state", _var_type=list), "state.foo"),
|
||||
(BaseVar(_var_name="foo", _var_type=list)._var_set_state("state"), "state.foo"),
|
||||
(BaseVar(_var_name="foo", _var_type=list), "foo"),
|
||||
(v((1, 2, 3)), "[1, 2, 3]"),
|
||||
(v(("1", "2", "3")), '["1", "2", "3"]'),
|
||||
(BaseVar(_var_name="foo", _var_state="state", _var_type=tuple), "state.foo"),
|
||||
(
|
||||
BaseVar(_var_name="foo", _var_type=tuple)._var_set_state("state"),
|
||||
"state.foo",
|
||||
),
|
||||
(BaseVar(_var_name="foo", _var_type=tuple), "foo"),
|
||||
],
|
||||
)
|
||||
@ -301,7 +306,7 @@ def test_list_tuple_contains(var, expected):
|
||||
assert str(var.contains("1")) == f'{{{expected}.includes("1")}}'
|
||||
assert str(var.contains(v(1))) == f"{{{expected}.includes(1)}}"
|
||||
assert str(var.contains(v("1"))) == f'{{{expected}.includes("1")}}'
|
||||
other_state_var = BaseVar(_var_name="other", _var_state="state", _var_type=str)
|
||||
other_state_var = BaseVar(_var_name="other", _var_type=str)._var_set_state("state")
|
||||
other_var = BaseVar(_var_name="other", _var_type=str)
|
||||
assert str(var.contains(other_state_var)) == f"{{{expected}.includes(state.other)}}"
|
||||
assert str(var.contains(other_var)) == f"{{{expected}.includes(other)}}"
|
||||
@ -311,14 +316,14 @@ def test_list_tuple_contains(var, expected):
|
||||
"var, expected",
|
||||
[
|
||||
(v("123"), json.dumps("123")),
|
||||
(BaseVar(_var_name="foo", _var_state="state", _var_type=str), "state.foo"),
|
||||
(BaseVar(_var_name="foo", _var_type=str)._var_set_state("state"), "state.foo"),
|
||||
(BaseVar(_var_name="foo", _var_type=str), "foo"),
|
||||
],
|
||||
)
|
||||
def test_str_contains(var, expected):
|
||||
assert str(var.contains("1")) == f'{{{expected}.includes("1")}}'
|
||||
assert str(var.contains(v("1"))) == f'{{{expected}.includes("1")}}'
|
||||
other_state_var = BaseVar(_var_name="other", _var_state="state", _var_type=str)
|
||||
other_state_var = BaseVar(_var_name="other", _var_type=str)._var_set_state("state")
|
||||
other_var = BaseVar(_var_name="other", _var_type=str)
|
||||
assert str(var.contains(other_state_var)) == f"{{{expected}.includes(state.other)}}"
|
||||
assert str(var.contains(other_var)) == f"{{{expected}.includes(other)}}"
|
||||
@ -328,7 +333,7 @@ def test_str_contains(var, expected):
|
||||
"var, expected",
|
||||
[
|
||||
(v({"a": 1, "b": 2}), '{"a": 1, "b": 2}'),
|
||||
(BaseVar(_var_name="foo", _var_state="state", _var_type=dict), "state.foo"),
|
||||
(BaseVar(_var_name="foo", _var_type=dict)._var_set_state("state"), "state.foo"),
|
||||
(BaseVar(_var_name="foo", _var_type=dict), "foo"),
|
||||
],
|
||||
)
|
||||
@ -337,7 +342,7 @@ def test_dict_contains(var, expected):
|
||||
assert str(var.contains("1")) == f'{{{expected}.hasOwnProperty("1")}}'
|
||||
assert str(var.contains(v(1))) == f"{{{expected}.hasOwnProperty(1)}}"
|
||||
assert str(var.contains(v("1"))) == f'{{{expected}.hasOwnProperty("1")}}'
|
||||
other_state_var = BaseVar(_var_name="other", _var_state="state", _var_type=str)
|
||||
other_state_var = BaseVar(_var_name="other", _var_type=str)._var_set_state("state")
|
||||
other_var = BaseVar(_var_name="other", _var_type=str)
|
||||
assert (
|
||||
str(var.contains(other_state_var))
|
||||
@ -548,10 +553,10 @@ def test_var_unsupported_indexing_dicts(var, index):
|
||||
"fixture,full_name",
|
||||
[
|
||||
("ParentState", "parent_state.var_without_annotation"),
|
||||
("ChildState", "parent_state.child_state.var_without_annotation"),
|
||||
("ChildState", "parent_state__child_state.var_without_annotation"),
|
||||
(
|
||||
"GrandChildState",
|
||||
"parent_state.child_state.grand_child_state.var_without_annotation",
|
||||
"parent_state__child_state__grand_child_state.var_without_annotation",
|
||||
),
|
||||
("StateWithAnyVar", "state_with_any_var.var_without_annotation"),
|
||||
],
|
||||
@ -630,8 +635,8 @@ def test_import_var(import_var, expected):
|
||||
[
|
||||
(f"{BaseVar(_var_name='var', _var_type=str)}", "${var}"),
|
||||
(
|
||||
f"testing f-string with {BaseVar(_var_name='myvar', _var_state='state', _var_type=int)}",
|
||||
"testing f-string with ${state.myvar}",
|
||||
f"testing f-string with {BaseVar(_var_name='myvar', _var_type=int)._var_set_state('state')}",
|
||||
'testing f-string with $<reflex.Var>{"state": "state", "imports": {"/utils/context": [{"tag": "StateContexts", "is_default": false, "alias": null, "install": true, "render": true}], "react": [{"tag": "useContext", "is_default": false, "alias": null, "install": true, "render": true}]}, "hooks": ["const state = useContext(StateContexts.state)"]}</reflex.Var>{state.myvar}',
|
||||
),
|
||||
(
|
||||
f"testing local f-string {BaseVar(_var_name='x', _var_is_local=True, _var_type=str)}",
|
||||
@ -643,6 +648,35 @@ def test_fstrings(out, expected):
|
||||
assert out == expected
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
("value", "expect_state"),
|
||||
[
|
||||
([1], ""),
|
||||
({"a": 1}, ""),
|
||||
([Var.create_safe(1)._var_set_state("foo")], "foo"),
|
||||
({"a": Var.create_safe(1)._var_set_state("foo")}, "foo"),
|
||||
],
|
||||
)
|
||||
def test_extract_state_from_container(value, expect_state):
|
||||
"""Test that _var_state is extracted from containers containing BaseVar.
|
||||
|
||||
Args:
|
||||
value: The value to create a var from.
|
||||
expect_state: The expected state.
|
||||
"""
|
||||
assert Var.create_safe(value)._var_state == expect_state
|
||||
|
||||
|
||||
def test_fstring_roundtrip():
|
||||
"""Test that f-string roundtrip carries state."""
|
||||
var = BaseVar.create_safe("var")._var_set_state("state")
|
||||
rt_var = Var.create_safe(f"{var}")
|
||||
assert var._var_state == rt_var._var_state
|
||||
assert var._var_full_name_needs_state_prefix
|
||||
assert not rt_var._var_full_name_needs_state_prefix
|
||||
assert rt_var._var_name == var._var_full_name
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"var",
|
||||
[
|
||||
|
@ -8,7 +8,13 @@ from reflex.event import EventChain, EventHandler, EventSpec, FrontendEvent
|
||||
from reflex.style import Style
|
||||
from reflex.utils import format
|
||||
from reflex.vars import BaseVar, Var
|
||||
from tests.test_state import ChildState, DateTimeState, GrandchildState, TestState
|
||||
from tests.test_state import (
|
||||
ChildState,
|
||||
ChildState2,
|
||||
DateTimeState,
|
||||
GrandchildState,
|
||||
TestState,
|
||||
)
|
||||
|
||||
|
||||
def mock_event(arg):
|
||||
@ -349,7 +355,6 @@ def test_format_cond(condition: str, true_value: str, false_value: str, expected
|
||||
BaseVar(
|
||||
_var_name="_",
|
||||
_var_type=Any,
|
||||
_var_state="",
|
||||
_var_is_local=True,
|
||||
_var_is_string=False,
|
||||
),
|
||||
@ -515,40 +520,44 @@ formatted_router = {
|
||||
(
|
||||
TestState().dict(), # type: ignore
|
||||
{
|
||||
"array": [1, 2, 3.14],
|
||||
"child_state": {
|
||||
TestState.get_full_name(): {
|
||||
"array": [1, 2, 3.14],
|
||||
"complex": {
|
||||
1: {"prop1": 42, "prop2": "hello"},
|
||||
2: {"prop1": 42, "prop2": "hello"},
|
||||
},
|
||||
"dt": "1989-11-09 18:53:00+01:00",
|
||||
"fig": [],
|
||||
"is_hydrated": False,
|
||||
"key": "",
|
||||
"map_key": "a",
|
||||
"mapping": {"a": [1, 2, 3], "b": [4, 5, 6]},
|
||||
"num1": 0,
|
||||
"num2": 3.14,
|
||||
"obj": {"prop1": 42, "prop2": "hello"},
|
||||
"sum": 3.14,
|
||||
"upper": "",
|
||||
"router": formatted_router,
|
||||
},
|
||||
ChildState.get_full_name(): {
|
||||
"count": 23,
|
||||
"grandchild_state": {"value2": ""},
|
||||
"value": "",
|
||||
},
|
||||
"child_state2": {"value": ""},
|
||||
"complex": {
|
||||
1: {"prop1": 42, "prop2": "hello"},
|
||||
2: {"prop1": 42, "prop2": "hello"},
|
||||
},
|
||||
"dt": "1989-11-09 18:53:00+01:00",
|
||||
"fig": [],
|
||||
"is_hydrated": False,
|
||||
"key": "",
|
||||
"map_key": "a",
|
||||
"mapping": {"a": [1, 2, 3], "b": [4, 5, 6]},
|
||||
"num1": 0,
|
||||
"num2": 3.14,
|
||||
"obj": {"prop1": 42, "prop2": "hello"},
|
||||
"sum": 3.14,
|
||||
"upper": "",
|
||||
"router": formatted_router,
|
||||
ChildState2.get_full_name(): {"value": ""},
|
||||
GrandchildState.get_full_name(): {"value2": ""},
|
||||
},
|
||||
),
|
||||
(
|
||||
DateTimeState().dict(),
|
||||
{
|
||||
"d": "1989-11-09",
|
||||
"dt": "1989-11-09 18:53:00+01:00",
|
||||
"is_hydrated": False,
|
||||
"t": "18:53:00+01:00",
|
||||
"td": "11 days, 0:11:00",
|
||||
"router": formatted_router,
|
||||
DateTimeState.get_full_name(): {
|
||||
"d": "1989-11-09",
|
||||
"dt": "1989-11-09 18:53:00+01:00",
|
||||
"is_hydrated": False,
|
||||
"t": "18:53:00+01:00",
|
||||
"td": "11 days, 0:11:00",
|
||||
"router": formatted_router,
|
||||
},
|
||||
},
|
||||
),
|
||||
],
|
||||
|
Loading…
Reference in New Issue
Block a user