Merge branch 'main' into lendemor/builtins_states

This commit is contained in:
Lendemor 2024-10-25 12:49:36 +02:00
commit 76dce8bba3
41 changed files with 471 additions and 194 deletions

View File

@ -51,6 +51,7 @@ jobs:
SCREENSHOT_DIR: /tmp/screenshots SCREENSHOT_DIR: /tmp/screenshots
REDIS_URL: ${{ matrix.state_manager == 'redis' && 'redis://localhost:6379' || '' }} REDIS_URL: ${{ matrix.state_manager == 'redis' && 'redis://localhost:6379' || '' }}
run: | run: |
poetry run playwright install --with-deps
poetry run pytest tests/integration poetry run pytest tests/integration
- uses: actions/upload-artifact@v4 - uses: actions/upload-artifact@v4
name: Upload failed test screenshots name: Upload failed test screenshots

35
poetry.lock generated
View File

@ -521,6 +521,21 @@ files = [
{file = "darglint-1.8.1.tar.gz", hash = "sha256:080d5106df149b199822e7ee7deb9c012b49891538f14a11be681044f0bb20da"}, {file = "darglint-1.8.1.tar.gz", hash = "sha256:080d5106df149b199822e7ee7deb9c012b49891538f14a11be681044f0bb20da"},
] ]
[[package]]
name = "dill"
version = "0.3.9"
description = "serialize all of Python"
optional = false
python-versions = ">=3.8"
files = [
{file = "dill-0.3.9-py3-none-any.whl", hash = "sha256:468dff3b89520b474c0397703366b7b95eebe6303f108adf9b19da1f702be87a"},
{file = "dill-0.3.9.tar.gz", hash = "sha256:81aa267dddf68cbfe8029c42ca9ec6a4ab3b22371d1c450abc54422577b4512c"},
]
[package.extras]
graph = ["objgraph (>=1.7.2)"]
profile = ["gprof2dot (>=2022.7.29)"]
[[package]] [[package]]
name = "distlib" name = "distlib"
version = "0.3.9" version = "0.3.9"
@ -1333,8 +1348,8 @@ files = [
[package.dependencies] [package.dependencies]
numpy = [ numpy = [
{version = ">=1.26.0", markers = "python_version >= \"3.12\""},
{version = ">=1.23.2", markers = "python_version == \"3.11\""}, {version = ">=1.23.2", markers = "python_version == \"3.11\""},
{version = ">=1.26.0", markers = "python_version >= \"3.12\""},
{version = ">=1.22.4", markers = "python_version < \"3.11\""}, {version = ">=1.22.4", markers = "python_version < \"3.11\""},
] ]
python-dateutil = ">=2.8.2" python-dateutil = ">=2.8.2"
@ -1652,8 +1667,8 @@ files = [
annotated-types = ">=0.6.0" annotated-types = ">=0.6.0"
pydantic-core = "2.23.4" pydantic-core = "2.23.4"
typing-extensions = [ typing-extensions = [
{version = ">=4.12.2", markers = "python_version >= \"3.13\""},
{version = ">=4.6.1", markers = "python_version < \"3.13\""}, {version = ">=4.6.1", markers = "python_version < \"3.13\""},
{version = ">=4.12.2", markers = "python_version >= \"3.13\""},
] ]
[package.extras] [package.extras]
@ -1977,20 +1992,6 @@ files = [
[package.dependencies] [package.dependencies]
six = ">=1.5" six = ">=1.5"
[[package]]
name = "python-dotenv"
version = "1.0.1"
description = "Read key-value pairs from a .env file and set them as environment variables"
optional = false
python-versions = ">=3.8"
files = [
{file = "python-dotenv-1.0.1.tar.gz", hash = "sha256:e324ee90a023d808f1959c46bcbc04446a10ced277783dc6ee09987c37ec10ca"},
{file = "python_dotenv-1.0.1-py3-none-any.whl", hash = "sha256:f7b63ef50f1b690dddf550d03497b66d609393b40b564ed0d674909a68ebf16a"},
]
[package.extras]
cli = ["click (>=5.0)"]
[[package]] [[package]]
name = "python-engineio" name = "python-engineio"
version = "4.10.1" version = "4.10.1"
@ -3047,4 +3048,4 @@ type = ["pytest-mypy"]
[metadata] [metadata]
lock-version = "2.0" lock-version = "2.0"
python-versions = "^3.9" python-versions = "^3.9"
content-hash = "c5da15520cef58124f6699007c81158036840469d4f9972592d72bd456c45e7e" content-hash = "e03374b85bf10f0a7bb857969b2d6714f25affa63e14a48a88be9fa154b24326"

View File

@ -33,7 +33,6 @@ jinja2 = ">=3.1.2,<4.0"
psutil = ">=5.9.4,<7.0" psutil = ">=5.9.4,<7.0"
pydantic = ">=1.10.2,<3.0" pydantic = ">=1.10.2,<3.0"
python-multipart = ">=0.0.5,<0.1" python-multipart = ">=0.0.5,<0.1"
python-dotenv = ">=1.0.1"
python-socketio = ">=5.7.0,<6.0" python-socketio = ">=5.7.0,<6.0"
redis = ">=4.3.5,<6.0" redis = ">=4.3.5,<6.0"
rich = ">=13.0.0,<14.0" rich = ">=13.0.0,<14.0"
@ -66,6 +65,7 @@ pytest = ">=7.1.2,<9.0"
pytest-mock = ">=3.10.0,<4.0" pytest-mock = ">=3.10.0,<4.0"
pyright = ">=1.1.229,<1.1.335" pyright = ">=1.1.229,<1.1.335"
darglint = ">=1.8.1,<2.0" darglint = ">=1.8.1,<2.0"
dill = ">=0.3.8"
toml = ">=0.10.2,<1.0" toml = ">=0.10.2,<1.0"
pytest-asyncio = ">=0.24.0" pytest-asyncio = ">=0.24.0"
pytest-cov = ">=4.0.0,<6.0" pytest-cov = ">=4.0.0,<6.0"

View File

@ -1,11 +1,11 @@
{% extends "web/pages/base_page.js.jinja2" %} {% extends "web/pages/base_page.js.jinja2" %}
{% block early_imports %} {% block early_imports %}
import '/styles/styles.css' import '$/styles/styles.css'
{% endblock %} {% endblock %}
{% block declaration %} {% block declaration %}
import { EventLoopProvider, StateProvider, defaultColorMode } from "/utils/context.js"; import { EventLoopProvider, StateProvider, defaultColorMode } from "$/utils/context.js";
import { ThemeProvider } from 'next-themes' import { ThemeProvider } from 'next-themes'
{% for library_alias, library_path in window_libraries %} {% for library_alias, library_path in window_libraries %}
import * as {{library_alias}} from "{{library_path}}"; import * as {{library_alias}} from "{{library_path}}";

View File

@ -1,5 +1,5 @@
import { createContext, useContext, useMemo, useReducer, useState } from "react" import { createContext, useContext, useMemo, useReducer, useState } from "react"
import { applyDelta, Event, hydrateClientStorage, useEventLoop, refs } from "/utils/state.js" import { applyDelta, Event, hydrateClientStorage, useEventLoop, refs } from "$/utils/state.js"
{% if initial_state %} {% if initial_state %}
export const initialState = {{ initial_state|json_dumps }} export const initialState = {{ initial_state|json_dumps }}
@ -59,6 +59,8 @@ export const initialEvents = () => [
{% else %} {% else %}
export const state_name = undefined export const state_name = undefined
export const exception_state_name = undefined
export const onLoadInternalEvent = () => [] export const onLoadInternalEvent = () => []
export const initialEvents = () => [] export const initialEvents = () => []

View File

@ -4,8 +4,8 @@ import {
ColorModeContext, ColorModeContext,
defaultColorMode, defaultColorMode,
isDevMode, isDevMode,
lastCompiledTimeStamp lastCompiledTimeStamp,
} from "/utils/context.js"; } from "$/utils/context.js";
export default function RadixThemesColorModeProvider({ children }) { export default function RadixThemesColorModeProvider({ children }) {
const { theme, resolvedTheme, setTheme } = useTheme(); const { theme, resolvedTheme, setTheme } = useTheme();
@ -37,7 +37,7 @@ export default function RadixThemesColorModeProvider({ children }) {
const allowedModes = ["light", "dark", "system"]; const allowedModes = ["light", "dark", "system"];
if (!allowedModes.includes(mode)) { if (!allowedModes.includes(mode)) {
console.error( console.error(
`Invalid color mode "${mode}". Defaulting to "${defaultColorMode}".`, `Invalid color mode "${mode}". Defaulting to "${defaultColorMode}".`
); );
mode = defaultColorMode; mode = defaultColorMode;
} }

View File

@ -2,7 +2,8 @@
"compilerOptions": { "compilerOptions": {
"baseUrl": ".", "baseUrl": ".",
"paths": { "paths": {
"$/*": ["*"],
"@/*": ["public/*"] "@/*": ["public/*"]
} }
} }
} }

View File

@ -2,7 +2,7 @@
import axios from "axios"; import axios from "axios";
import io from "socket.io-client"; import io from "socket.io-client";
import JSON5 from "json5"; import JSON5 from "json5";
import env from "/env.json"; import env from "$/env.json";
import Cookies from "universal-cookie"; import Cookies from "universal-cookie";
import { useEffect, useRef, useState } from "react"; import { useEffect, useRef, useState } from "react";
import Router, { useRouter } from "next/router"; import Router, { useRouter } from "next/router";
@ -12,9 +12,9 @@ import {
onLoadInternalEvent, onLoadInternalEvent,
state_name, state_name,
exception_state_name, exception_state_name,
} from "utils/context.js"; } from "$/utils/context.js";
import debounce from "/utils/helpers/debounce"; import debounce from "$/utils/helpers/debounce";
import throttle from "/utils/helpers/throttle"; import throttle from "$/utils/helpers/throttle";
import * as Babel from "@babel/standalone"; import * as Babel from "@babel/standalone";
// Endpoint URLs. // Endpoint URLs.

View File

@ -679,7 +679,7 @@ class App(MiddlewareMixin, LifespanMixin, Base):
for i, tags in imports.items() for i, tags in imports.items()
if i not in constants.PackageJson.DEPENDENCIES if i not in constants.PackageJson.DEPENDENCIES
and i not in constants.PackageJson.DEV_DEPENDENCIES and i not in constants.PackageJson.DEV_DEPENDENCIES
and not any(i.startswith(prefix) for prefix in ["/", ".", "next/"]) and not any(i.startswith(prefix) for prefix in ["/", "$/", ".", "next/"])
and i != "" and i != ""
and any(tag.install for tag in tags) and any(tag.install for tag in tags)
} }

View File

@ -67,8 +67,8 @@ def _compile_app(app_root: Component) -> str:
window_libraries = [ window_libraries = [
(_normalize_library_name(name), name) for name in bundled_libraries (_normalize_library_name(name), name) for name in bundled_libraries
] + [ ] + [
("utils_context", f"/{constants.Dirs.UTILS}/context"), ("utils_context", f"$/{constants.Dirs.UTILS}/context"),
("utils_state", f"/{constants.Dirs.UTILS}/state"), ("utils_state", f"$/{constants.Dirs.UTILS}/state"),
] ]
return templates.APP_ROOT.render( return templates.APP_ROOT.render(
@ -228,7 +228,7 @@ def _compile_components(
""" """
imports = { imports = {
"react": [ImportVar(tag="memo")], "react": [ImportVar(tag="memo")],
f"/{constants.Dirs.STATE_PATH}": [ImportVar(tag="E"), ImportVar(tag="isTrue")], f"$/{constants.Dirs.STATE_PATH}": [ImportVar(tag="E"), ImportVar(tag="isTrue")],
} }
component_renders = [] component_renders = []
@ -315,7 +315,7 @@ def _compile_stateful_components(
# Don't import from the file that we're about to create. # Don't import from the file that we're about to create.
all_imports = utils.merge_imports(*all_import_dicts) all_imports = utils.merge_imports(*all_import_dicts)
all_imports.pop( all_imports.pop(
f"/{constants.Dirs.UTILS}/{constants.PageNames.STATEFUL_COMPONENTS}", None f"$/{constants.Dirs.UTILS}/{constants.PageNames.STATEFUL_COMPONENTS}", None
) )
return templates.STATEFUL_COMPONENTS.render( return templates.STATEFUL_COMPONENTS.render(

View File

@ -83,6 +83,12 @@ def validate_imports(import_dict: ParsedImportDict):
f"{_import.tag}/{_import.alias}" if _import.alias else _import.tag f"{_import.tag}/{_import.alias}" if _import.alias else _import.tag
) )
if import_name in used_tags: if import_name in used_tags:
already_imported = used_tags[import_name]
if (already_imported[0] == "$" and already_imported[1:] == lib) or (
lib[0] == "$" and lib[1:] == already_imported
):
used_tags[import_name] = lib if lib[0] == "$" else already_imported
continue
raise ValueError( raise ValueError(
f"Can not compile, the tag {import_name} is used multiple time from {lib} and {used_tags[import_name]}" f"Can not compile, the tag {import_name} is used multiple time from {lib} and {used_tags[import_name]}"
) )

View File

@ -38,6 +38,7 @@ from reflex.constants import (
) )
from reflex.constants.compiler import SpecialAttributes from reflex.constants.compiler import SpecialAttributes
from reflex.event import ( from reflex.event import (
EventCallback,
EventChain, EventChain,
EventChainVar, EventChainVar,
EventHandler, EventHandler,
@ -1126,6 +1127,8 @@ class Component(BaseComponent, ABC):
for trigger in self.event_triggers.values(): for trigger in self.event_triggers.values():
if isinstance(trigger, EventChain): if isinstance(trigger, EventChain):
for event in trigger.events: for event in trigger.events:
if isinstance(event, EventCallback):
continue
if isinstance(event, EventSpec): if isinstance(event, EventSpec):
if event.handler.state_full_name: if event.handler.state_full_name:
return True return True
@ -1305,7 +1308,9 @@ class Component(BaseComponent, ABC):
if self._get_ref_hook(): if self._get_ref_hook():
# Handle hooks needed for attaching react refs to DOM nodes. # Handle hooks needed for attaching react refs to DOM nodes.
_imports.setdefault("react", set()).add(ImportVar(tag="useRef")) _imports.setdefault("react", set()).add(ImportVar(tag="useRef"))
_imports.setdefault(f"/{Dirs.STATE_PATH}", set()).add(ImportVar(tag="refs")) _imports.setdefault(f"$/{Dirs.STATE_PATH}", set()).add(
ImportVar(tag="refs")
)
if self._get_mount_lifecycle_hook(): if self._get_mount_lifecycle_hook():
# Handle hooks for `on_mount` / `on_unmount`. # Handle hooks for `on_mount` / `on_unmount`.
@ -1662,7 +1667,7 @@ class CustomComponent(Component):
"""A custom user-defined component.""" """A custom user-defined component."""
# Use the components library. # Use the components library.
library = f"/{Dirs.COMPONENTS_PATH}" library = f"$/{Dirs.COMPONENTS_PATH}"
# The function that creates the component. # The function that creates the component.
component_fn: Callable[..., Component] = Component.create component_fn: Callable[..., Component] = Component.create
@ -2230,7 +2235,7 @@ class StatefulComponent(BaseComponent):
""" """
if self.rendered_as_shared: if self.rendered_as_shared:
return { return {
f"/{Dirs.UTILS}/{PageNames.STATEFUL_COMPONENTS}": [ f"$/{Dirs.UTILS}/{PageNames.STATEFUL_COMPONENTS}": [
ImportVar(tag=self.tag) ImportVar(tag=self.tag)
] ]
} }

View File

@ -66,8 +66,8 @@ class WebsocketTargetURL(Var):
_js_expr="getBackendURL(env.EVENT).href", _js_expr="getBackendURL(env.EVENT).href",
_var_data=VarData( _var_data=VarData(
imports={ imports={
"/env.json": [ImportVar(tag="env", is_default=True)], "$/env.json": [ImportVar(tag="env", is_default=True)],
f"/{Dirs.STATE_PATH}": [ImportVar(tag="getBackendURL")], f"$/{Dirs.STATE_PATH}": [ImportVar(tag="getBackendURL")],
}, },
), ),
_var_type=WebsocketTargetURL, _var_type=WebsocketTargetURL,

View File

@ -21,7 +21,7 @@ route_not_found: Var = Var(_js_expr=constants.ROUTE_NOT_FOUND)
class ClientSideRouting(Component): class ClientSideRouting(Component):
"""The client-side routing component.""" """The client-side routing component."""
library = "/utils/client_side_routing" library = "$/utils/client_side_routing"
tag = "useClientSideRouting" tag = "useClientSideRouting"
def add_hooks(self) -> list[str]: def add_hooks(self) -> list[str]:

View File

@ -67,7 +67,7 @@ class Clipboard(Fragment):
The import dict for the component. The import dict for the component.
""" """
return { return {
"/utils/helpers/paste.js": ImportVar( "$/utils/helpers/paste.js": ImportVar(
tag="usePasteHandler", is_default=True tag="usePasteHandler", is_default=True
), ),
} }

View File

@ -15,7 +15,7 @@ from reflex.vars.base import LiteralVar, Var
from reflex.vars.number import ternary_operation from reflex.vars.number import ternary_operation
_IS_TRUE_IMPORT: ImportDict = { _IS_TRUE_IMPORT: ImportDict = {
f"/{Dirs.STATE_PATH}": [ImportVar(tag="isTrue")], f"$/{Dirs.STATE_PATH}": [ImportVar(tag="isTrue")],
} }

View File

@ -118,7 +118,7 @@ class DebounceInput(Component):
_var_type=Type[Component], _var_type=Type[Component],
_var_data=VarData( _var_data=VarData(
imports=child._get_imports(), imports=child._get_imports(),
hooks=child._get_hooks_internal(), hooks=child._get_all_hooks(),
), ),
), ),
) )
@ -128,6 +128,10 @@ class DebounceInput(Component):
component.event_triggers.update(child.event_triggers) component.event_triggers.update(child.event_triggers)
component.children = child.children component.children = child.children
component._rename_props = child._rename_props component._rename_props = child._rename_props
outer_get_all_custom_code = component._get_all_custom_code
component._get_all_custom_code = lambda: outer_get_all_custom_code().union(
child._get_all_custom_code()
)
return component return component
def _render(self): def _render(self):

View File

@ -29,7 +29,7 @@ DEFAULT_UPLOAD_ID: str = "default"
upload_files_context_var_data: VarData = VarData( upload_files_context_var_data: VarData = VarData(
imports={ imports={
"react": "useContext", "react": "useContext",
f"/{Dirs.CONTEXTS_PATH}": "UploadFilesContext", f"$/{Dirs.CONTEXTS_PATH}": "UploadFilesContext",
}, },
hooks={ hooks={
"const [filesById, setFilesById] = useContext(UploadFilesContext);": None, "const [filesById, setFilesById] = useContext(UploadFilesContext);": None,
@ -134,8 +134,8 @@ uploaded_files_url_prefix = Var(
_js_expr="getBackendURL(env.UPLOAD)", _js_expr="getBackendURL(env.UPLOAD)",
_var_data=VarData( _var_data=VarData(
imports={ imports={
f"/{Dirs.STATE_PATH}": "getBackendURL", f"$/{Dirs.STATE_PATH}": "getBackendURL",
"/env.json": ImportVar(tag="env", is_default=True), "$/env.json": ImportVar(tag="env", is_default=True),
} }
), ),
).to(str) ).to(str)
@ -170,7 +170,7 @@ def _on_drop_spec(files: Var) -> Tuple[Var[Any]]:
class UploadFilesProvider(Component): class UploadFilesProvider(Component):
"""AppWrap component that provides a dict of selected files by ID via useContext.""" """AppWrap component that provides a dict of selected files by ID via useContext."""
library = f"/{Dirs.CONTEXTS_PATH}" library = f"$/{Dirs.CONTEXTS_PATH}"
tag = "UploadFilesProvider" tag = "UploadFilesProvider"

View File

@ -34,8 +34,8 @@ uploaded_files_url_prefix = Var(
_js_expr="getBackendURL(env.UPLOAD)", _js_expr="getBackendURL(env.UPLOAD)",
_var_data=VarData( _var_data=VarData(
imports={ imports={
f"/{Dirs.STATE_PATH}": "getBackendURL", f"$/{Dirs.STATE_PATH}": "getBackendURL",
"/env.json": ImportVar(tag="env", is_default=True), "$/env.json": ImportVar(tag="env", is_default=True),
} }
), ),
).to(str) ).to(str)

View File

@ -344,7 +344,7 @@ class DataEditor(NoSSRComponent):
return { return {
"": f"{format.format_library_name(self.library)}/dist/index.css", "": f"{format.format_library_name(self.library)}/dist/index.css",
self.library: "GridCellKind", self.library: "GridCellKind",
"/utils/helpers/dataeditor.js": ImportVar( "$/utils/helpers/dataeditor.js": ImportVar(
tag="formatDataEditorCells", is_default=False, install=False tag="formatDataEditorCells", is_default=False, install=False
), ),
} }

View File

@ -90,7 +90,7 @@ def load_dynamic_serializer():
for lib, names in component._get_all_imports().items(): for lib, names in component._get_all_imports().items():
formatted_lib_name = format_library_name(lib) formatted_lib_name = format_library_name(lib)
if ( if (
not lib.startswith((".", "/")) not lib.startswith((".", "/", "$/"))
and not lib.startswith("http") and not lib.startswith("http")
and formatted_lib_name not in libs_in_window and formatted_lib_name not in libs_in_window
): ):
@ -106,7 +106,7 @@ def load_dynamic_serializer():
# Rewrite imports from `/` to destructure from window # Rewrite imports from `/` to destructure from window
for ix, line in enumerate(module_code_lines[:]): for ix, line in enumerate(module_code_lines[:]):
if line.startswith("import "): if line.startswith("import "):
if 'from "/' in line: if 'from "$/' in line or 'from "/' in line:
module_code_lines[ix] = ( module_code_lines[ix] = (
line.replace("import ", "const ", 1).replace( line.replace("import ", "const ", 1).replace(
" from ", " = window['__reflex'][", 1 " from ", " = window['__reflex'][", 1
@ -157,7 +157,7 @@ def load_dynamic_serializer():
merge_var_data=VarData.merge( merge_var_data=VarData.merge(
VarData( VarData(
imports={ imports={
f"/{constants.Dirs.STATE_PATH}": [ f"$/{constants.Dirs.STATE_PATH}": [
imports.ImportVar(tag="evalReactComponent"), imports.ImportVar(tag="evalReactComponent"),
], ],
"react": [ "react": [

View File

@ -187,7 +187,7 @@ class Form(BaseHTML):
""" """
return { return {
"react": "useCallback", "react": "useCallback",
f"/{Dirs.STATE_PATH}": ["getRefValue", "getRefValues"], f"$/{Dirs.STATE_PATH}": ["getRefValue", "getRefValues"],
} }
def add_hooks(self) -> list[str]: def add_hooks(self) -> list[str]:
@ -615,6 +615,42 @@ class Textarea(BaseHTML):
# Fired when a key is released # Fired when a key is released
on_key_up: EventHandler[key_event] on_key_up: EventHandler[key_event]
@classmethod
def create(cls, *children, **props):
"""Create a textarea component.
Args:
*children: The children of the textarea.
**props: The properties of the textarea.
Returns:
The textarea component.
Raises:
ValueError: when `enter_key_submit` is combined with `on_key_down`.
"""
enter_key_submit = props.get("enter_key_submit")
auto_height = props.get("auto_height")
custom_attrs = props.setdefault("custom_attrs", {})
if enter_key_submit is not None:
enter_key_submit = Var.create(enter_key_submit)
if "on_key_down" in props:
raise ValueError(
"Cannot combine `enter_key_submit` with `on_key_down`.",
)
custom_attrs["on_key_down"] = Var(
_js_expr=f"(e) => enterKeySubmitOnKeyDown(e, {str(enter_key_submit)})",
_var_data=VarData.merge(enter_key_submit._get_all_var_data()),
)
if auto_height is not None:
auto_height = Var.create(auto_height)
custom_attrs["on_input"] = Var(
_js_expr=f"(e) => autoHeightOnInput(e, {str(auto_height)})",
_var_data=VarData.merge(auto_height._get_all_var_data()),
)
return super().create(*children, **props)
def _exclude_props(self) -> list[str]: def _exclude_props(self) -> list[str]:
return super()._exclude_props() + [ return super()._exclude_props() + [
"auto_height", "auto_height",
@ -634,28 +670,6 @@ class Textarea(BaseHTML):
custom_code.add(ENTER_KEY_SUBMIT_JS) custom_code.add(ENTER_KEY_SUBMIT_JS)
return custom_code return custom_code
def _render(self) -> Tag:
tag = super()._render()
if self.enter_key_submit is not None:
if "on_key_down" in self.event_triggers:
raise ValueError(
"Cannot combine `enter_key_submit` with `on_key_down`.",
)
tag.add_props(
on_key_down=Var(
_js_expr=f"(e) => enterKeySubmitOnKeyDown(e, {str(self.enter_key_submit)})",
_var_data=VarData.merge(self.enter_key_submit._get_all_var_data()),
)
)
if self.auto_height is not None:
tag.add_props(
on_input=Var(
_js_expr=f"(e) => autoHeightOnInput(e, {str(self.auto_height)})",
_var_data=VarData.merge(self.auto_height._get_all_var_data()),
)
)
return tag
button = Button.create button = Button.create
fieldset = Fieldset.create fieldset = Fieldset.create

View File

@ -1376,10 +1376,10 @@ class Textarea(BaseHTML):
on_unmount: Optional[EventType[[]]] = None, on_unmount: Optional[EventType[[]]] = None,
**props, **props,
) -> "Textarea": ) -> "Textarea":
"""Create the component. """Create a textarea component.
Args: Args:
*children: The children of the component. *children: The children of the textarea.
auto_complete: Whether the form control should have autocomplete enabled auto_complete: Whether the form control should have autocomplete enabled
auto_focus: Automatically focuses the textarea when the page loads auto_focus: Automatically focuses the textarea when the page loads
auto_height: Automatically fit the content height to the text (use min-height with this prop) auto_height: Automatically fit the content height to the text (use min-height with this prop)
@ -1419,10 +1419,13 @@ class Textarea(BaseHTML):
class_name: The class name for the component. class_name: The class name for the component.
autofocus: Whether the component should take the focus once the page is loaded autofocus: Whether the component should take the focus once the page is loaded
custom_attrs: custom attribute custom_attrs: custom attribute
**props: The props of the component. **props: The properties of the textarea.
Returns: Returns:
The component. The textarea component.
Raises:
ValueError: when `enter_key_submit` is combined with `on_key_down`.
""" """
... ...

View File

@ -221,7 +221,7 @@ class Theme(RadixThemesComponent):
The import dict. The import dict.
""" """
_imports: ImportDict = { _imports: ImportDict = {
"/utils/theme.js": [ImportVar(tag="theme", is_default=True)], "$/utils/theme.js": [ImportVar(tag="theme", is_default=True)],
} }
if get_config().tailwind is None: if get_config().tailwind is None:
# When tailwind is disabled, import the radix-ui styles directly because they will # When tailwind is disabled, import the radix-ui styles directly because they will
@ -265,7 +265,7 @@ class ThemePanel(RadixThemesComponent):
class RadixThemesColorModeProvider(Component): class RadixThemesColorModeProvider(Component):
"""Next-themes integration for radix themes components.""" """Next-themes integration for radix themes components."""
library = "/components/reflex/radix_themes_color_mode_provider.js" library = "$/components/reflex/radix_themes_color_mode_provider.js"
tag = "RadixThemesColorModeProvider" tag = "RadixThemesColorModeProvider"
is_default = True is_default = True

View File

@ -251,7 +251,7 @@ class Toaster(Component):
_js_expr=f"{toast_ref} = toast", _js_expr=f"{toast_ref} = toast",
_var_data=VarData( _var_data=VarData(
imports={ imports={
"/utils/state": [ImportVar(tag="refs")], "$/utils/state": [ImportVar(tag="refs")],
self.library: [ImportVar(tag="toast", install=False)], self.library: [ImportVar(tag="toast", install=False)],
} }
), ),

View File

@ -8,12 +8,12 @@ import os
import sys import sys
import urllib.parse import urllib.parse
from pathlib import Path from pathlib import Path
from typing import Any, Dict, List, Optional, Set, Union from typing import Any, Dict, List, Optional, Set
from typing_extensions import get_type_hints from typing_extensions import get_type_hints
from reflex.utils.exceptions import ConfigError, EnvironmentVarValueError from reflex.utils.exceptions import ConfigError, EnvironmentVarValueError
from reflex.utils.types import value_inside_optional from reflex.utils.types import GenericType, is_union, value_inside_optional
try: try:
import pydantic.v1 as pydantic import pydantic.v1 as pydantic
@ -157,11 +157,13 @@ def get_default_value_for_field(field: dataclasses.Field) -> Any:
) )
def interpret_boolean_env(value: str) -> bool: # TODO: Change all interpret_.* signatures to value: str, field: dataclasses.Field once we migrate rx.Config to dataclasses
def interpret_boolean_env(value: str, field_name: str) -> bool:
"""Interpret a boolean environment variable value. """Interpret a boolean environment variable value.
Args: Args:
value: The environment variable value. value: The environment variable value.
field_name: The field name.
Returns: Returns:
The interpreted value. The interpreted value.
@ -176,14 +178,15 @@ def interpret_boolean_env(value: str) -> bool:
return True return True
elif value.lower() in false_values: elif value.lower() in false_values:
return False return False
raise EnvironmentVarValueError(f"Invalid boolean value: {value}") raise EnvironmentVarValueError(f"Invalid boolean value: {value} for {field_name}")
def interpret_int_env(value: str) -> int: def interpret_int_env(value: str, field_name: str) -> int:
"""Interpret an integer environment variable value. """Interpret an integer environment variable value.
Args: Args:
value: The environment variable value. value: The environment variable value.
field_name: The field name.
Returns: Returns:
The interpreted value. The interpreted value.
@ -194,14 +197,17 @@ def interpret_int_env(value: str) -> int:
try: try:
return int(value) return int(value)
except ValueError as ve: except ValueError as ve:
raise EnvironmentVarValueError(f"Invalid integer value: {value}") from ve raise EnvironmentVarValueError(
f"Invalid integer value: {value} for {field_name}"
) from ve
def interpret_path_env(value: str) -> Path: def interpret_path_env(value: str, field_name: str) -> Path:
"""Interpret a path environment variable value. """Interpret a path environment variable value.
Args: Args:
value: The environment variable value. value: The environment variable value.
field_name: The field name.
Returns: Returns:
The interpreted value. The interpreted value.
@ -211,16 +217,19 @@ def interpret_path_env(value: str) -> Path:
""" """
path = Path(value) path = Path(value)
if not path.exists(): if not path.exists():
raise EnvironmentVarValueError(f"Path does not exist: {path}") raise EnvironmentVarValueError(f"Path does not exist: {path} for {field_name}")
return path return path
def interpret_env_var_value(value: str, field: dataclasses.Field) -> Any: def interpret_env_var_value(
value: str, field_type: GenericType, field_name: str
) -> Any:
"""Interpret an environment variable value based on the field type. """Interpret an environment variable value based on the field type.
Args: Args:
value: The environment variable value. value: The environment variable value.
field: The field. field_type: The field type.
field_name: The field name.
Returns: Returns:
The interpreted value. The interpreted value.
@ -228,20 +237,25 @@ def interpret_env_var_value(value: str, field: dataclasses.Field) -> Any:
Raises: Raises:
ValueError: If the value is invalid. ValueError: If the value is invalid.
""" """
field_type = value_inside_optional(field.type) field_type = value_inside_optional(field_type)
if is_union(field_type):
raise ValueError(
f"Union types are not supported for environment variables: {field_name}."
)
if field_type is bool: if field_type is bool:
return interpret_boolean_env(value) return interpret_boolean_env(value, field_name)
elif field_type is str: elif field_type is str:
return value return value
elif field_type is int: elif field_type is int:
return interpret_int_env(value) return interpret_int_env(value, field_name)
elif field_type is Path: elif field_type is Path:
return interpret_path_env(value) return interpret_path_env(value, field_name)
else: else:
raise ValueError( raise ValueError(
f"Invalid type for environment variable {field.name}: {field_type}. This is probably an issue in Reflex." f"Invalid type for environment variable {field_name}: {field_type}. This is probably an issue in Reflex."
) )
@ -316,7 +330,7 @@ class EnvironmentVariables:
field.type = type_hints.get(field.name) or field.type field.type = type_hints.get(field.name) or field.type
value = ( value = (
interpret_env_var_value(raw_value, field) interpret_env_var_value(raw_value, field.type, field.name)
if raw_value is not None if raw_value is not None
else get_default_value_for_field(field) else get_default_value_for_field(field)
) )
@ -387,7 +401,7 @@ class Config(Base):
telemetry_enabled: bool = True telemetry_enabled: bool = True
# The bun path # The bun path
bun_path: Union[str, Path] = constants.Bun.DEFAULT_PATH bun_path: Path = constants.Bun.DEFAULT_PATH
# List of origins that are allowed to connect to the backend API. # List of origins that are allowed to connect to the backend API.
cors_allowed_origins: List[str] = ["*"] cors_allowed_origins: List[str] = ["*"]
@ -484,17 +498,17 @@ class Config(Base):
Returns: Returns:
The updated config values. The updated config values.
Raises:
EnvVarValueError: If an environment variable is set to an invalid type.
""" """
from reflex.utils.exceptions import EnvVarValueError
if self.env_file: if self.env_file:
from dotenv import load_dotenv try:
from dotenv import load_dotenv # type: ignore
# load env file if exists # load env file if exists
load_dotenv(self.env_file, override=True) load_dotenv(self.env_file, override=True)
except ImportError:
console.error(
"""The `python-dotenv` package is required to load environment variables from a file. Run `pip install "python-dotenv>=1.0.1"`."""
)
updated_values = {} updated_values = {}
# Iterate over the fields. # Iterate over the fields.
@ -510,21 +524,11 @@ class Config(Base):
dedupe=True, dedupe=True,
) )
# Convert the env var to the expected type. # Interpret the value.
try: value = interpret_env_var_value(env_var, field.type_, field.name)
if issubclass(field.type_, bool):
# special handling for bool values
env_var = env_var.lower() in ["true", "1", "yes"]
else:
env_var = field.type_(env_var)
except ValueError as ve:
console.error(
f"Could not convert {key.upper()}={env_var} to type {field.type_}"
)
raise EnvVarValueError from ve
# Set the value. # Set the value.
updated_values[key] = env_var updated_values[key] = value
return updated_values return updated_values

View File

@ -118,8 +118,8 @@ class Imports(SimpleNamespace):
EVENTS = { EVENTS = {
"react": [ImportVar(tag="useContext")], "react": [ImportVar(tag="useContext")],
f"/{Dirs.CONTEXTS_PATH}": [ImportVar(tag="EventLoopContext")], f"$/{Dirs.CONTEXTS_PATH}": [ImportVar(tag="EventLoopContext")],
f"/{Dirs.STATE_PATH}": [ImportVar(tag=CompileVars.TO_EVENT)], f"$/{Dirs.STATE_PATH}": [ImportVar(tag=CompileVars.TO_EVENT)],
} }

View File

@ -16,6 +16,7 @@ from typing import (
Generic, Generic,
List, List,
Optional, Optional,
Sequence,
Tuple, Tuple,
Type, Type,
TypeVar, TypeVar,
@ -389,7 +390,9 @@ class CallableEventSpec(EventSpec):
class EventChain(EventActionsMixin): class EventChain(EventActionsMixin):
"""Container for a chain of events that will be executed in order.""" """Container for a chain of events that will be executed in order."""
events: List[Union[EventSpec, EventVar]] = dataclasses.field(default_factory=list) events: Sequence[Union[EventSpec, EventVar, EventCallback]] = dataclasses.field(
default_factory=list
)
args_spec: Optional[Callable] = dataclasses.field(default=None) args_spec: Optional[Callable] = dataclasses.field(default=None)
@ -1445,13 +1448,8 @@ class LiteralEventChainVar(ArgsFunctionOperation, LiteralVar, EventChainVar):
) )
G = ParamSpec("G")
IndividualEventType = Union[EventSpec, EventHandler, Callable[G, Any], Var[Any]]
EventType = Union[IndividualEventType[G], List[IndividualEventType[G]]]
P = ParamSpec("P") P = ParamSpec("P")
Q = ParamSpec("Q")
T = TypeVar("T") T = TypeVar("T")
V = TypeVar("V") V = TypeVar("V")
V2 = TypeVar("V2") V2 = TypeVar("V2")
@ -1473,55 +1471,73 @@ if sys.version_info >= (3, 10):
""" """
self.func = func self.func = func
@property
def prevent_default(self):
"""Prevent default behavior.
Returns:
The event callback with prevent default behavior.
"""
return self
@property
def stop_propagation(self):
"""Stop event propagation.
Returns:
The event callback with stop propagation behavior.
"""
return self
@overload @overload
def __get__( def __call__(
self: EventCallback[[V], T], instance: None, owner self: EventCallback[Concatenate[V, Q], T], value: V | Var[V]
) -> Callable[[Union[Var[V], V]], EventSpec]: ... ) -> EventCallback[Q, T]: ...
@overload
def __call__(
self: EventCallback[Concatenate[V, V2, Q], T],
value: V | Var[V],
value2: V2 | Var[V2],
) -> EventCallback[Q, T]: ...
@overload
def __call__(
self: EventCallback[Concatenate[V, V2, V3, Q], T],
value: V | Var[V],
value2: V2 | Var[V2],
value3: V3 | Var[V3],
) -> EventCallback[Q, T]: ...
@overload
def __call__(
self: EventCallback[Concatenate[V, V2, V3, V4, Q], T],
value: V | Var[V],
value2: V2 | Var[V2],
value3: V3 | Var[V3],
value4: V4 | Var[V4],
) -> EventCallback[Q, T]: ...
def __call__(self, *values) -> EventCallback: # type: ignore
"""Call the function with the values.
Args:
*values: The values to call the function with.
Returns:
The function with the values.
"""
return self.func(*values) # type: ignore
@overload @overload
def __get__( def __get__(
self: EventCallback[[V, V2], T], instance: None, owner self: EventCallback[P, T], instance: None, owner
) -> Callable[[Union[Var[V], V], Union[Var[V2], V2]], EventSpec]: ... ) -> EventCallback[P, T]: ...
@overload
def __get__(
self: EventCallback[[V, V2, V3], T], instance: None, owner
) -> Callable[
[Union[Var[V], V], Union[Var[V2], V2], Union[Var[V3], V3]],
EventSpec,
]: ...
@overload
def __get__(
self: EventCallback[[V, V2, V3, V4], T], instance: None, owner
) -> Callable[
[
Union[Var[V], V],
Union[Var[V2], V2],
Union[Var[V3], V3],
Union[Var[V4], V4],
],
EventSpec,
]: ...
@overload
def __get__(
self: EventCallback[[V, V2, V3, V4, V5], T], instance: None, owner
) -> Callable[
[
Union[Var[V], V],
Union[Var[V2], V2],
Union[Var[V3], V3],
Union[Var[V4], V4],
Union[Var[V5], V5],
],
EventSpec,
]: ...
@overload @overload
def __get__(self, instance, owner) -> Callable[P, T]: ... def __get__(self, instance, owner) -> Callable[P, T]: ...
def __get__(self, instance, owner) -> Callable: def __get__(self, instance, owner) -> Callable: # type: ignore
"""Get the function with the instance bound to it. """Get the function with the instance bound to it.
Args: Args:
@ -1548,6 +1564,9 @@ if sys.version_info >= (3, 10):
return func # type: ignore return func # type: ignore
else: else:
class EventCallback(Generic[P, T]):
"""A descriptor that wraps a function to be used as an event."""
def event_handler(func: Callable[P, T]) -> Callable[P, T]: def event_handler(func: Callable[P, T]) -> Callable[P, T]:
"""Wrap a function to be used as an event. """Wrap a function to be used as an event.
@ -1560,6 +1579,17 @@ else:
return func return func
G = ParamSpec("G")
IndividualEventType = Union[
EventSpec, EventHandler, Callable[G, Any], EventCallback[G, Any], Var[Any]
]
ItemOrList = Union[V, List[V]]
EventType = ItemOrList[IndividualEventType[G]]
class EventNamespace(types.SimpleNamespace): class EventNamespace(types.SimpleNamespace):
"""A namespace for event related classes.""" """A namespace for event related classes."""

View File

@ -21,7 +21,7 @@ NoValue = object()
_refs_import = { _refs_import = {
f"/{constants.Dirs.STATE_PATH}": [ImportVar(tag="refs")], f"$/{constants.Dirs.STATE_PATH}": [ImportVar(tag="refs")],
} }

View File

@ -38,7 +38,7 @@ def get_engine(url: str | None = None) -> sqlalchemy.engine.Engine:
url = url or conf.db_url url = url or conf.db_url
if url is None: if url is None:
raise ValueError("No database url configured") raise ValueError("No database url configured")
if environment.ALEMBIC_CONFIG.exists(): if not environment.ALEMBIC_CONFIG.exists():
console.warn( console.warn(
"Database is not initialized, run [bold]reflex db init[/bold] first." "Database is not initialized, run [bold]reflex db init[/bold] first."
) )

View File

@ -218,6 +218,7 @@ class EventHandlerSetVar(EventHandler):
Raises: Raises:
AttributeError: If the given Var name does not exist on the state. AttributeError: If the given Var name does not exist on the state.
EventHandlerValueError: If the given Var name is not a str EventHandlerValueError: If the given Var name is not a str
NotImplementedError: If the setter for the given Var is async
""" """
from reflex.utils.exceptions import EventHandlerValueError from reflex.utils.exceptions import EventHandlerValueError
@ -226,11 +227,20 @@ class EventHandlerSetVar(EventHandler):
raise EventHandlerValueError( raise EventHandlerValueError(
f"Var name must be passed as a string, got {args[0]!r}" f"Var name must be passed as a string, got {args[0]!r}"
) )
handler = getattr(self.state_cls, constants.SETTER_PREFIX + args[0], None)
# Check that the requested Var setter exists on the State at compile time. # Check that the requested Var setter exists on the State at compile time.
if getattr(self.state_cls, constants.SETTER_PREFIX + args[0], None) is None: if handler is None:
raise AttributeError( raise AttributeError(
f"Variable `{args[0]}` cannot be set on `{self.state_cls.get_full_name()}`" f"Variable `{args[0]}` cannot be set on `{self.state_cls.get_full_name()}`"
) )
if asyncio.iscoroutinefunction(handler.fn):
raise NotImplementedError(
f"Setter for {args[0]} is async, which is not supported."
)
return super().__call__(*args) return super().__call__(*args)
@ -2053,12 +2063,24 @@ class BaseState(Base, ABC, extra=pydantic.Extra.allow):
""" """
try: try:
return pickle.dumps((self._to_schema(), self)) return pickle.dumps((self._to_schema(), self))
except pickle.PicklingError: except (pickle.PicklingError, AttributeError) as og_pickle_error:
console.warn( error = (
f"Failed to serialize state {self.get_full_name()} due to unpicklable object. " f"Failed to serialize state {self.get_full_name()} due to unpicklable object. "
"This state will not be persisted." "This state will not be persisted. "
) )
return b"" try:
import dill
return dill.dumps((self._to_schema(), self))
except ImportError:
error += (
f"Pickle error: {og_pickle_error}. "
"Consider `pip install 'dill>=0.3.8'` for more exotic serialization support."
)
except (pickle.PicklingError, TypeError, ValueError) as ex:
error += f"Dill was also unable to pickle the state: {ex}"
console.warn(error)
return b""
@classmethod @classmethod
def _deserialize( def _deserialize(
@ -2725,9 +2747,13 @@ class StateManagerDisk(StateManager):
for substate in state.get_substates(): for substate in state.get_substates():
substate_token = _substate_key(client_token, substate) substate_token = _substate_key(client_token, substate)
fresh_instance = await root_state.get_state(substate)
instance = await self.load_state(substate_token) instance = await self.load_state(substate_token)
if instance is None: if instance is not None:
instance = await root_state.get_state(substate) # Ensure all substates exist, even if they weren't serialized previously.
instance.substates = fresh_instance.substates
else:
instance = fresh_instance
state.substates[substate.get_name()] = instance state.substates[substate.get_name()] = instance
instance.parent_state = state instance.parent_state = state

View File

@ -23,7 +23,7 @@ LiteralColorMode = Literal["system", "light", "dark"]
# Reference the global ColorModeContext # Reference the global ColorModeContext
color_mode_imports = { color_mode_imports = {
f"/{constants.Dirs.CONTEXTS_PATH}": [ImportVar(tag="ColorModeContext")], f"$/{constants.Dirs.CONTEXTS_PATH}": [ImportVar(tag="ColorModeContext")],
"react": [ImportVar(tag="useContext")], "react": [ImportVar(tag="useContext")],
} }

View File

@ -249,7 +249,8 @@ class AppHarness:
return textwrap.dedent(source) return textwrap.dedent(source)
def _initialize_app(self): def _initialize_app(self):
os.environ["TELEMETRY_ENABLED"] = "" # disable telemetry reporting for tests # disable telemetry reporting for tests
os.environ["TELEMETRY_ENABLED"] = "false"
self.app_path.mkdir(parents=True, exist_ok=True) self.app_path.mkdir(parents=True, exist_ok=True)
if self.app_source is not None: if self.app_source is not None:
app_globals = self._get_globals_from_signature(self.app_source) app_globals = self._get_globals_from_signature(self.app_source)

View File

@ -23,6 +23,12 @@ def merge_imports(
for lib, fields in ( for lib, fields in (
import_dict if isinstance(import_dict, tuple) else import_dict.items() import_dict if isinstance(import_dict, tuple) else import_dict.items()
): ):
# If the lib is an absolute path, we need to prefix it with a $
lib = (
"$" + lib
if lib.startswith(("/utils/", "/components/", "/styles/", "/public/"))
else lib
)
if isinstance(fields, (list, tuple, set)): if isinstance(fields, (list, tuple, set)):
all_imports[lib].extend( all_imports[lib].extend(
( (

View File

@ -217,7 +217,7 @@ class VarData:
): None ): None
}, },
imports={ imports={
f"/{constants.Dirs.CONTEXTS_PATH}": [ImportVar(tag="StateContexts")], f"$/{constants.Dirs.CONTEXTS_PATH}": [ImportVar(tag="StateContexts")],
"react": [ImportVar(tag="useContext")], "react": [ImportVar(tag="useContext")],
}, },
) )
@ -956,7 +956,7 @@ class Var(Generic[VAR_TYPE]):
_js_expr="refs", _js_expr="refs",
_var_data=VarData( _var_data=VarData(
imports={ imports={
f"/{constants.Dirs.STATE_PATH}": [imports.ImportVar(tag="refs")] f"$/{constants.Dirs.STATE_PATH}": [imports.ImportVar(tag="refs")]
} }
), ),
).to(ObjectVar, Dict[str, str]) ).to(ObjectVar, Dict[str, str])
@ -2530,7 +2530,7 @@ def get_uuid_string_var() -> Var:
unique_uuid_var = get_unique_variable_name() unique_uuid_var = get_unique_variable_name()
unique_uuid_var_data = VarData( unique_uuid_var_data = VarData(
imports={ imports={
f"/{constants.Dirs.STATE_PATH}": {ImportVar(tag="generateUUID")}, # type: ignore f"$/{constants.Dirs.STATE_PATH}": {ImportVar(tag="generateUUID")}, # type: ignore
"react": "useMemo", "react": "useMemo",
}, },
hooks={f"const {unique_uuid_var} = useMemo(generateUUID, [])": None}, hooks={f"const {unique_uuid_var} = useMemo(generateUUID, [])": None},

View File

@ -1090,7 +1090,7 @@ boolean_types = Union[BooleanVar, bool]
_IS_TRUE_IMPORT: ImportDict = { _IS_TRUE_IMPORT: ImportDict = {
f"/{Dirs.STATE_PATH}": [ImportVar(tag="isTrue")], f"$/{Dirs.STATE_PATH}": [ImportVar(tag="isTrue")],
} }

View File

@ -0,0 +1,59 @@
"""Integration tests for a stateless app."""
from typing import Generator
import httpx
import pytest
from playwright.sync_api import Page, expect
import reflex as rx
from reflex.testing import AppHarness
def StatelessApp():
"""A stateless app that renders a heading."""
import reflex as rx
def index():
return rx.heading("This is a stateless app")
app = rx.App()
app.add_page(index)
@pytest.fixture(scope="module")
def stateless_app(tmp_path_factory) -> Generator[AppHarness, None, None]:
"""Create a stateless app AppHarness.
Args:
tmp_path_factory: pytest fixture for creating temporary directories.
Yields:
AppHarness: A harness for testing the stateless app.
"""
with AppHarness.create(
root=tmp_path_factory.mktemp("stateless_app"),
app_source=StatelessApp, # type: ignore
) as harness:
yield harness
def test_statelessness(stateless_app: AppHarness, page: Page):
"""Test that the stateless app renders a heading but backend/_event is not mounted.
Args:
stateless_app: A harness for testing the stateless app.
page: A Playwright page.
"""
assert stateless_app.frontend_url is not None
assert stateless_app.backend is not None
assert stateless_app.backend.started
res = httpx.get(rx.config.get_config().api_url + "/_event")
assert res.status_code == 404
res2 = httpx.get(rx.config.get_config().api_url + "/ping")
assert res2.status_code == 200
page.goto(stateless_app.frontend_url)
expect(page.get_by_role("heading")).to_have_text("This is a stateless app")

View File

@ -12,7 +12,7 @@ def test_websocket_target_url():
var_data = url._get_all_var_data() var_data = url._get_all_var_data()
assert var_data is not None assert var_data is not None
assert sorted(tuple((key for key, _ in var_data.imports))) == sorted( assert sorted(tuple((key for key, _ in var_data.imports))) == sorted(
("/utils/state", "/env.json") ("$/utils/state", "$/env.json")
) )
@ -22,10 +22,10 @@ def test_connection_banner():
assert sorted(tuple(_imports)) == sorted( assert sorted(tuple(_imports)) == sorted(
( (
"react", "react",
"/utils/context", "$/utils/context",
"/utils/state", "$/utils/state",
"@radix-ui/themes@^3.0.0", "@radix-ui/themes@^3.0.0",
"/env.json", "$/env.json",
) )
) )
@ -40,10 +40,10 @@ def test_connection_modal():
assert sorted(tuple(_imports)) == sorted( assert sorted(tuple(_imports)) == sorted(
( (
"react", "react",
"/utils/context", "$/utils/context",
"/utils/state", "$/utils/state",
"@radix-ui/themes@^3.0.0", "@radix-ui/themes@^3.0.0",
"/env.json", "$/env.json",
) )
) )

View File

@ -1,5 +1,7 @@
import multiprocessing import multiprocessing
import os import os
from pathlib import Path
from typing import Any, Dict
import pytest import pytest
@ -42,7 +44,12 @@ def test_set_app_name(base_config_values):
("TELEMETRY_ENABLED", True), ("TELEMETRY_ENABLED", True),
], ],
) )
def test_update_from_env(base_config_values, monkeypatch, env_var, value): def test_update_from_env(
base_config_values: Dict[str, Any],
monkeypatch: pytest.MonkeyPatch,
env_var: str,
value: Any,
):
"""Test that environment variables override config values. """Test that environment variables override config values.
Args: Args:
@ -57,6 +64,29 @@ def test_update_from_env(base_config_values, monkeypatch, env_var, value):
assert getattr(config, env_var.lower()) == value assert getattr(config, env_var.lower()) == value
def test_update_from_env_path(
base_config_values: Dict[str, Any],
monkeypatch: pytest.MonkeyPatch,
tmp_path: Path,
):
"""Test that environment variables override config values.
Args:
base_config_values: Config values.
monkeypatch: The pytest monkeypatch object.
tmp_path: The pytest tmp_path fixture object.
"""
monkeypatch.setenv("BUN_PATH", "/test")
assert os.environ.get("BUN_PATH") == "/test"
with pytest.raises(ValueError):
rx.Config(**base_config_values)
monkeypatch.setenv("BUN_PATH", str(tmp_path))
assert os.environ.get("BUN_PATH") == str(tmp_path)
config = rx.Config(**base_config_values)
assert config.bun_path == tmp_path
@pytest.mark.parametrize( @pytest.mark.parametrize(
"kwargs, expected", "kwargs, expected",
[ [

View File

@ -105,6 +105,7 @@ class TestState(BaseState):
fig: Figure = Figure() fig: Figure = Figure()
dt: datetime.datetime = datetime.datetime.fromisoformat("1989-11-09T18:53:00+01:00") dt: datetime.datetime = datetime.datetime.fromisoformat("1989-11-09T18:53:00+01:00")
_backend: int = 0 _backend: int = 0
asynctest: int = 0
@ComputedVar @ComputedVar
def sum(self) -> float: def sum(self) -> float:
@ -128,6 +129,14 @@ class TestState(BaseState):
"""Do something.""" """Do something."""
pass pass
async def set_asynctest(self, value: int):
"""Set the asynctest value. Intentionally overwrite the default setter with an async one.
Args:
value: The new value.
"""
self.asynctest = value
class ChildState(TestState): class ChildState(TestState):
"""A child state fixture.""" """A child state fixture."""
@ -312,6 +321,7 @@ def test_class_vars(test_state):
"upper", "upper",
"fig", "fig",
"dt", "dt",
"asynctest",
} }
@ -732,6 +742,7 @@ def test_reset(test_state, child_state):
"mapping", "mapping",
"dt", "dt",
"_backend", "_backend",
"asynctest",
} }
# The dirty vars should be reset. # The dirty vars should be reset.
@ -3180,6 +3191,13 @@ async def test_setvar(mock_app: rx.App, token: str):
TestState.setvar(42, 42) TestState.setvar(42, 42)
@pytest.mark.asyncio
async def test_setvar_async_setter():
"""Test that overridden async setters raise Exception when used with setvar."""
with pytest.raises(NotImplementedError):
TestState.setvar("asynctest", 42)
@pytest.mark.skipif("REDIS_URL" not in os.environ, reason="Test requires redis") @pytest.mark.skipif("REDIS_URL" not in os.environ, reason="Test requires redis")
@pytest.mark.parametrize( @pytest.mark.parametrize(
"expiration_kwargs, expected_values", "expiration_kwargs, expected_values",
@ -3315,3 +3333,68 @@ def test_assignment_to_undeclared_vars():
state.handle_supported_regular_vars() state.handle_supported_regular_vars()
state.handle_non_var() state.handle_non_var()
@pytest.mark.asyncio
async def test_deserialize_gc_state_disk(token):
"""Test that a state can be deserialized from disk with a grandchild state.
Args:
token: A token.
"""
class Root(BaseState):
pass
class State(Root):
num: int = 42
class Child(State):
foo: str = "bar"
dsm = StateManagerDisk(state=Root)
async with dsm.modify_state(token) as root:
s = await root.get_state(State)
s.num += 1
c = await root.get_state(Child)
assert s._get_was_touched()
assert not c._get_was_touched()
dsm2 = StateManagerDisk(state=Root)
root = await dsm2.get_state(token)
s = await root.get_state(State)
assert s.num == 43
c = await root.get_state(Child)
assert c.foo == "bar"
class Obj(Base):
"""A object containing a callable for testing fallback pickle."""
_f: Callable
def test_fallback_pickle():
"""Test that state serialization will fall back to dill."""
class DillState(BaseState):
_o: Optional[Obj] = None
_f: Optional[Callable] = None
_g: Any = None
state = DillState(_reflex_internal_init=True) # type: ignore
state._o = Obj(_f=lambda: 42)
state._f = lambda: 420
pk = state._serialize()
unpickled_state = BaseState._deserialize(pk)
assert unpickled_state._f() == 420
assert unpickled_state._o._f() == 42
# Some object, like generator, are still unpicklable with dill.
state._g = (i for i in range(10))
pk = state._serialize()
assert len(pk) == 0
with pytest.raises(EOFError):
BaseState._deserialize(pk)

View File

@ -601,6 +601,7 @@ formatted_router = {
"sum": 3.14, "sum": 3.14,
"upper": "", "upper": "",
"router": formatted_router, "router": formatted_router,
"asynctest": 0,
}, },
ChildState.get_full_name(): { ChildState.get_full_name(): {
"count": 23, "count": 23,