merging
This commit is contained in:
commit
dc9ce91717
16
poetry.lock
generated
16
poetry.lock
generated
@ -1977,20 +1977,6 @@ files = [
|
||||
[package.dependencies]
|
||||
six = ">=1.5"
|
||||
|
||||
[[package]]
|
||||
name = "python-dotenv"
|
||||
version = "1.0.1"
|
||||
description = "Read key-value pairs from a .env file and set them as environment variables"
|
||||
optional = false
|
||||
python-versions = ">=3.8"
|
||||
files = [
|
||||
{file = "python-dotenv-1.0.1.tar.gz", hash = "sha256:e324ee90a023d808f1959c46bcbc04446a10ced277783dc6ee09987c37ec10ca"},
|
||||
{file = "python_dotenv-1.0.1-py3-none-any.whl", hash = "sha256:f7b63ef50f1b690dddf550d03497b66d609393b40b564ed0d674909a68ebf16a"},
|
||||
]
|
||||
|
||||
[package.extras]
|
||||
cli = ["click (>=5.0)"]
|
||||
|
||||
[[package]]
|
||||
name = "python-engineio"
|
||||
version = "4.10.1"
|
||||
@ -3047,4 +3033,4 @@ type = ["pytest-mypy"]
|
||||
[metadata]
|
||||
lock-version = "2.0"
|
||||
python-versions = "^3.9"
|
||||
content-hash = "c5da15520cef58124f6699007c81158036840469d4f9972592d72bd456c45e7e"
|
||||
content-hash = "8090ccaeca173bd8612e17a0b8d157d7492618e49450abd1c8373e2976349db0"
|
||||
|
@ -33,7 +33,6 @@ jinja2 = ">=3.1.2,<4.0"
|
||||
psutil = ">=5.9.4,<7.0"
|
||||
pydantic = ">=1.10.2,<3.0"
|
||||
python-multipart = ">=0.0.5,<0.1"
|
||||
python-dotenv = ">=1.0.1"
|
||||
python-socketio = ">=5.7.0,<6.0"
|
||||
redis = ">=4.3.5,<6.0"
|
||||
rich = ">=13.0.0,<14.0"
|
||||
|
@ -4,7 +4,7 @@ import {
|
||||
ColorModeContext,
|
||||
defaultColorMode,
|
||||
isDevMode,
|
||||
lastCompiledTimeStamp
|
||||
lastCompiledTimeStamp,
|
||||
} from "$/utils/context.js";
|
||||
|
||||
export default function RadixThemesColorModeProvider({ children }) {
|
||||
@ -37,7 +37,7 @@ export default function RadixThemesColorModeProvider({ children }) {
|
||||
const allowedModes = ["light", "dark", "system"];
|
||||
if (!allowedModes.includes(mode)) {
|
||||
console.error(
|
||||
`Invalid color mode "${mode}". Defaulting to "${defaultColorMode}".`,
|
||||
`Invalid color mode "${mode}". Defaulting to "${defaultColorMode}".`
|
||||
);
|
||||
mode = defaultColorMode;
|
||||
}
|
||||
|
@ -315,7 +315,7 @@ def _compile_stateful_components(
|
||||
# Don't import from the file that we're about to create.
|
||||
all_imports = utils.merge_imports(*all_import_dicts)
|
||||
all_imports.pop(
|
||||
f"/{constants.Dirs.UTILS}/{constants.PageNames.STATEFUL_COMPONENTS}", None
|
||||
f"$/{constants.Dirs.UTILS}/{constants.PageNames.STATEFUL_COMPONENTS}", None
|
||||
)
|
||||
|
||||
return templates.STATEFUL_COMPONENTS.render(
|
||||
|
@ -83,6 +83,12 @@ def validate_imports(import_dict: ParsedImportDict):
|
||||
f"{_import.tag}/{_import.alias}" if _import.alias else _import.tag
|
||||
)
|
||||
if import_name in used_tags:
|
||||
already_imported = used_tags[import_name]
|
||||
if (already_imported[0] == "$" and already_imported[1:] == lib) or (
|
||||
lib[0] == "$" and lib[1:] == already_imported
|
||||
):
|
||||
used_tags[import_name] = lib if lib[0] == "$" else already_imported
|
||||
continue
|
||||
raise ValueError(
|
||||
f"Can not compile, the tag {import_name} is used multiple time from {lib} and {used_tags[import_name]}"
|
||||
)
|
||||
|
@ -39,6 +39,7 @@ from reflex.constants import (
|
||||
)
|
||||
from reflex.constants.compiler import SpecialAttributes
|
||||
from reflex.event import (
|
||||
EventCallback,
|
||||
EventChain,
|
||||
EventChainVar,
|
||||
EventHandler,
|
||||
@ -1135,6 +1136,8 @@ class Component(BaseComponent, ABC):
|
||||
for trigger in self.event_triggers.values():
|
||||
if isinstance(trigger, EventChain):
|
||||
for event in trigger.events:
|
||||
if isinstance(event, EventCallback):
|
||||
continue
|
||||
if isinstance(event, EventSpec):
|
||||
if event.handler.state_full_name:
|
||||
return True
|
||||
@ -2241,7 +2244,7 @@ class StatefulComponent(BaseComponent):
|
||||
"""
|
||||
if self.rendered_as_shared:
|
||||
return {
|
||||
f"/{Dirs.UTILS}/{PageNames.STATEFUL_COMPONENTS}": [
|
||||
f"$/{Dirs.UTILS}/{PageNames.STATEFUL_COMPONENTS}": [
|
||||
ImportVar(tag=self.tag)
|
||||
]
|
||||
}
|
||||
|
@ -118,7 +118,7 @@ class DebounceInput(Component):
|
||||
_var_type=Type[Component],
|
||||
_var_data=VarData(
|
||||
imports=child._get_imports(),
|
||||
hooks=child._get_hooks_internal(),
|
||||
hooks=child._get_all_hooks(),
|
||||
),
|
||||
),
|
||||
)
|
||||
@ -128,6 +128,10 @@ class DebounceInput(Component):
|
||||
component.event_triggers.update(child.event_triggers)
|
||||
component.children = child.children
|
||||
component._rename_props = child._rename_props
|
||||
outer_get_all_custom_code = component._get_all_custom_code
|
||||
component._get_all_custom_code = lambda: outer_get_all_custom_code().union(
|
||||
child._get_all_custom_code()
|
||||
)
|
||||
return component
|
||||
|
||||
def _render(self):
|
||||
|
@ -93,7 +93,7 @@ def load_dynamic_serializer():
|
||||
for lib, names in component._get_all_imports().items():
|
||||
formatted_lib_name = format_library_name(lib)
|
||||
if (
|
||||
not lib.startswith((".", "$/"))
|
||||
not lib.startswith((".", "/", "$/"))
|
||||
and not lib.startswith("http")
|
||||
and formatted_lib_name not in libs_in_window
|
||||
):
|
||||
@ -109,7 +109,7 @@ def load_dynamic_serializer():
|
||||
# Rewrite imports from `/` to destructure from window
|
||||
for ix, line in enumerate(module_code_lines[:]):
|
||||
if line.startswith("import "):
|
||||
if 'from "$/' in line:
|
||||
if 'from "$/' in line or 'from "/' in line:
|
||||
module_code_lines[ix] = (
|
||||
line.replace("import ", "const ", 1).replace(
|
||||
" from ", " = window['__reflex'][", 1
|
||||
|
@ -615,6 +615,42 @@ class Textarea(BaseHTML):
|
||||
# Fired when a key is released
|
||||
on_key_up: EventHandler[key_event]
|
||||
|
||||
@classmethod
|
||||
def create(cls, *children, **props):
|
||||
"""Create a textarea component.
|
||||
|
||||
Args:
|
||||
*children: The children of the textarea.
|
||||
**props: The properties of the textarea.
|
||||
|
||||
Returns:
|
||||
The textarea component.
|
||||
|
||||
Raises:
|
||||
ValueError: when `enter_key_submit` is combined with `on_key_down`.
|
||||
"""
|
||||
enter_key_submit = props.get("enter_key_submit")
|
||||
auto_height = props.get("auto_height")
|
||||
custom_attrs = props.setdefault("custom_attrs", {})
|
||||
|
||||
if enter_key_submit is not None:
|
||||
enter_key_submit = Var.create(enter_key_submit)
|
||||
if "on_key_down" in props:
|
||||
raise ValueError(
|
||||
"Cannot combine `enter_key_submit` with `on_key_down`.",
|
||||
)
|
||||
custom_attrs["on_key_down"] = Var(
|
||||
_js_expr=f"(e) => enterKeySubmitOnKeyDown(e, {str(enter_key_submit)})",
|
||||
_var_data=VarData.merge(enter_key_submit._get_all_var_data()),
|
||||
)
|
||||
if auto_height is not None:
|
||||
auto_height = Var.create(auto_height)
|
||||
custom_attrs["on_input"] = Var(
|
||||
_js_expr=f"(e) => autoHeightOnInput(e, {str(auto_height)})",
|
||||
_var_data=VarData.merge(auto_height._get_all_var_data()),
|
||||
)
|
||||
return super().create(*children, **props)
|
||||
|
||||
def _exclude_props(self) -> list[str]:
|
||||
return super()._exclude_props() + [
|
||||
"auto_height",
|
||||
@ -634,28 +670,6 @@ class Textarea(BaseHTML):
|
||||
custom_code.add(ENTER_KEY_SUBMIT_JS)
|
||||
return custom_code
|
||||
|
||||
def _render(self) -> Tag:
|
||||
tag = super()._render()
|
||||
if self.enter_key_submit is not None:
|
||||
if "on_key_down" in self.event_triggers:
|
||||
raise ValueError(
|
||||
"Cannot combine `enter_key_submit` with `on_key_down`.",
|
||||
)
|
||||
tag.add_props(
|
||||
on_key_down=Var(
|
||||
_js_expr=f"(e) => enterKeySubmitOnKeyDown(e, {str(self.enter_key_submit)})",
|
||||
_var_data=VarData.merge(self.enter_key_submit._get_all_var_data()),
|
||||
)
|
||||
)
|
||||
if self.auto_height is not None:
|
||||
tag.add_props(
|
||||
on_input=Var(
|
||||
_js_expr=f"(e) => autoHeightOnInput(e, {str(self.auto_height)})",
|
||||
_var_data=VarData.merge(self.auto_height._get_all_var_data()),
|
||||
)
|
||||
)
|
||||
return tag
|
||||
|
||||
|
||||
button = Button.create
|
||||
fieldset = Fieldset.create
|
||||
|
@ -1376,10 +1376,10 @@ class Textarea(BaseHTML):
|
||||
on_unmount: Optional[EventType[[]]] = None,
|
||||
**props,
|
||||
) -> "Textarea":
|
||||
"""Create the component.
|
||||
"""Create a textarea component.
|
||||
|
||||
Args:
|
||||
*children: The children of the component.
|
||||
*children: The children of the textarea.
|
||||
auto_complete: Whether the form control should have autocomplete enabled
|
||||
auto_focus: Automatically focuses the textarea when the page loads
|
||||
auto_height: Automatically fit the content height to the text (use min-height with this prop)
|
||||
@ -1419,10 +1419,13 @@ class Textarea(BaseHTML):
|
||||
class_name: The class name for the component.
|
||||
autofocus: Whether the component should take the focus once the page is loaded
|
||||
custom_attrs: custom attribute
|
||||
**props: The props of the component.
|
||||
**props: The properties of the textarea.
|
||||
|
||||
Returns:
|
||||
The component.
|
||||
The textarea component.
|
||||
|
||||
Raises:
|
||||
ValueError: when `enter_key_submit` is combined with `on_key_down`.
|
||||
"""
|
||||
...
|
||||
|
||||
|
@ -8,12 +8,12 @@ import os
|
||||
import sys
|
||||
import urllib.parse
|
||||
from pathlib import Path
|
||||
from typing import Any, Dict, List, Optional, Set, Union
|
||||
from typing import Any, Dict, List, Optional, Set
|
||||
|
||||
from typing_extensions import get_type_hints
|
||||
|
||||
from reflex.utils.exceptions import ConfigError, EnvironmentVarValueError
|
||||
from reflex.utils.types import value_inside_optional
|
||||
from reflex.utils.types import GenericType, is_union, value_inside_optional
|
||||
|
||||
try:
|
||||
import pydantic.v1 as pydantic
|
||||
@ -157,11 +157,13 @@ def get_default_value_for_field(field: dataclasses.Field) -> Any:
|
||||
)
|
||||
|
||||
|
||||
def interpret_boolean_env(value: str) -> bool:
|
||||
# TODO: Change all interpret_.* signatures to value: str, field: dataclasses.Field once we migrate rx.Config to dataclasses
|
||||
def interpret_boolean_env(value: str, field_name: str) -> bool:
|
||||
"""Interpret a boolean environment variable value.
|
||||
|
||||
Args:
|
||||
value: The environment variable value.
|
||||
field_name: The field name.
|
||||
|
||||
Returns:
|
||||
The interpreted value.
|
||||
@ -176,14 +178,15 @@ def interpret_boolean_env(value: str) -> bool:
|
||||
return True
|
||||
elif value.lower() in false_values:
|
||||
return False
|
||||
raise EnvironmentVarValueError(f"Invalid boolean value: {value}")
|
||||
raise EnvironmentVarValueError(f"Invalid boolean value: {value} for {field_name}")
|
||||
|
||||
|
||||
def interpret_int_env(value: str) -> int:
|
||||
def interpret_int_env(value: str, field_name: str) -> int:
|
||||
"""Interpret an integer environment variable value.
|
||||
|
||||
Args:
|
||||
value: The environment variable value.
|
||||
field_name: The field name.
|
||||
|
||||
Returns:
|
||||
The interpreted value.
|
||||
@ -194,14 +197,17 @@ def interpret_int_env(value: str) -> int:
|
||||
try:
|
||||
return int(value)
|
||||
except ValueError as ve:
|
||||
raise EnvironmentVarValueError(f"Invalid integer value: {value}") from ve
|
||||
raise EnvironmentVarValueError(
|
||||
f"Invalid integer value: {value} for {field_name}"
|
||||
) from ve
|
||||
|
||||
|
||||
def interpret_path_env(value: str) -> Path:
|
||||
def interpret_path_env(value: str, field_name: str) -> Path:
|
||||
"""Interpret a path environment variable value.
|
||||
|
||||
Args:
|
||||
value: The environment variable value.
|
||||
field_name: The field name.
|
||||
|
||||
Returns:
|
||||
The interpreted value.
|
||||
@ -211,16 +217,19 @@ def interpret_path_env(value: str) -> Path:
|
||||
"""
|
||||
path = Path(value)
|
||||
if not path.exists():
|
||||
raise EnvironmentVarValueError(f"Path does not exist: {path}")
|
||||
raise EnvironmentVarValueError(f"Path does not exist: {path} for {field_name}")
|
||||
return path
|
||||
|
||||
|
||||
def interpret_env_var_value(value: str, field: dataclasses.Field) -> Any:
|
||||
def interpret_env_var_value(
|
||||
value: str, field_type: GenericType, field_name: str
|
||||
) -> Any:
|
||||
"""Interpret an environment variable value based on the field type.
|
||||
|
||||
Args:
|
||||
value: The environment variable value.
|
||||
field: The field.
|
||||
field_type: The field type.
|
||||
field_name: The field name.
|
||||
|
||||
Returns:
|
||||
The interpreted value.
|
||||
@ -228,20 +237,25 @@ def interpret_env_var_value(value: str, field: dataclasses.Field) -> Any:
|
||||
Raises:
|
||||
ValueError: If the value is invalid.
|
||||
"""
|
||||
field_type = value_inside_optional(field.type)
|
||||
field_type = value_inside_optional(field_type)
|
||||
|
||||
if is_union(field_type):
|
||||
raise ValueError(
|
||||
f"Union types are not supported for environment variables: {field_name}."
|
||||
)
|
||||
|
||||
if field_type is bool:
|
||||
return interpret_boolean_env(value)
|
||||
return interpret_boolean_env(value, field_name)
|
||||
elif field_type is str:
|
||||
return value
|
||||
elif field_type is int:
|
||||
return interpret_int_env(value)
|
||||
return interpret_int_env(value, field_name)
|
||||
elif field_type is Path:
|
||||
return interpret_path_env(value)
|
||||
return interpret_path_env(value, field_name)
|
||||
|
||||
else:
|
||||
raise ValueError(
|
||||
f"Invalid type for environment variable {field.name}: {field_type}. This is probably an issue in Reflex."
|
||||
f"Invalid type for environment variable {field_name}: {field_type}. This is probably an issue in Reflex."
|
||||
)
|
||||
|
||||
|
||||
@ -316,7 +330,7 @@ class EnvironmentVariables:
|
||||
field.type = type_hints.get(field.name) or field.type
|
||||
|
||||
value = (
|
||||
interpret_env_var_value(raw_value, field)
|
||||
interpret_env_var_value(raw_value, field.type, field.name)
|
||||
if raw_value is not None
|
||||
else get_default_value_for_field(field)
|
||||
)
|
||||
@ -387,7 +401,7 @@ class Config(Base):
|
||||
telemetry_enabled: bool = True
|
||||
|
||||
# The bun path
|
||||
bun_path: Union[str, Path] = constants.Bun.DEFAULT_PATH
|
||||
bun_path: Path = constants.Bun.DEFAULT_PATH
|
||||
|
||||
# List of origins that are allowed to connect to the backend API.
|
||||
cors_allowed_origins: List[str] = ["*"]
|
||||
@ -484,17 +498,17 @@ class Config(Base):
|
||||
|
||||
Returns:
|
||||
The updated config values.
|
||||
|
||||
Raises:
|
||||
EnvVarValueError: If an environment variable is set to an invalid type.
|
||||
"""
|
||||
from reflex.utils.exceptions import EnvVarValueError
|
||||
|
||||
if self.env_file:
|
||||
from dotenv import load_dotenv
|
||||
try:
|
||||
from dotenv import load_dotenv # type: ignore
|
||||
|
||||
# load env file if exists
|
||||
load_dotenv(self.env_file, override=True)
|
||||
# load env file if exists
|
||||
load_dotenv(self.env_file, override=True)
|
||||
except ImportError:
|
||||
console.error(
|
||||
"""The `python-dotenv` package is required to load environment variables from a file. Run `pip install "python-dotenv>=1.0.1"`."""
|
||||
)
|
||||
|
||||
updated_values = {}
|
||||
# Iterate over the fields.
|
||||
@ -510,21 +524,11 @@ class Config(Base):
|
||||
dedupe=True,
|
||||
)
|
||||
|
||||
# Convert the env var to the expected type.
|
||||
try:
|
||||
if issubclass(field.type_, bool):
|
||||
# special handling for bool values
|
||||
env_var = env_var.lower() in ["true", "1", "yes"]
|
||||
else:
|
||||
env_var = field.type_(env_var)
|
||||
except ValueError as ve:
|
||||
console.error(
|
||||
f"Could not convert {key.upper()}={env_var} to type {field.type_}"
|
||||
)
|
||||
raise EnvVarValueError from ve
|
||||
# Interpret the value.
|
||||
value = interpret_env_var_value(env_var, field.type_, field.name)
|
||||
|
||||
# Set the value.
|
||||
updated_values[key] = env_var
|
||||
updated_values[key] = value
|
||||
|
||||
return updated_values
|
||||
|
||||
|
126
reflex/event.py
126
reflex/event.py
@ -16,6 +16,7 @@ from typing import (
|
||||
Generic,
|
||||
List,
|
||||
Optional,
|
||||
Sequence,
|
||||
Tuple,
|
||||
Type,
|
||||
TypeVar,
|
||||
@ -389,7 +390,9 @@ class CallableEventSpec(EventSpec):
|
||||
class EventChain(EventActionsMixin):
|
||||
"""Container for a chain of events that will be executed in order."""
|
||||
|
||||
events: List[Union[EventSpec, EventVar]] = dataclasses.field(default_factory=list)
|
||||
events: Sequence[Union[EventSpec, EventVar, EventCallback]] = dataclasses.field(
|
||||
default_factory=list
|
||||
)
|
||||
|
||||
args_spec: Optional[Callable] = dataclasses.field(default=None)
|
||||
|
||||
@ -1445,13 +1448,8 @@ class LiteralEventChainVar(ArgsFunctionOperation, LiteralVar, EventChainVar):
|
||||
)
|
||||
|
||||
|
||||
G = ParamSpec("G")
|
||||
|
||||
IndividualEventType = Union[EventSpec, EventHandler, Callable[G, Any], Var[Any]]
|
||||
|
||||
EventType = Union[IndividualEventType[G], List[IndividualEventType[G]]]
|
||||
|
||||
P = ParamSpec("P")
|
||||
Q = ParamSpec("Q")
|
||||
T = TypeVar("T")
|
||||
V = TypeVar("V")
|
||||
V2 = TypeVar("V2")
|
||||
@ -1473,55 +1471,73 @@ if sys.version_info >= (3, 10):
|
||||
"""
|
||||
self.func = func
|
||||
|
||||
@property
|
||||
def prevent_default(self):
|
||||
"""Prevent default behavior.
|
||||
|
||||
Returns:
|
||||
The event callback with prevent default behavior.
|
||||
"""
|
||||
return self
|
||||
|
||||
@property
|
||||
def stop_propagation(self):
|
||||
"""Stop event propagation.
|
||||
|
||||
Returns:
|
||||
The event callback with stop propagation behavior.
|
||||
"""
|
||||
return self
|
||||
|
||||
@overload
|
||||
def __get__(
|
||||
self: EventCallback[[V], T], instance: None, owner
|
||||
) -> Callable[[Union[Var[V], V]], EventSpec]: ...
|
||||
def __call__(
|
||||
self: EventCallback[Concatenate[V, Q], T], value: V | Var[V]
|
||||
) -> EventCallback[Q, T]: ...
|
||||
|
||||
@overload
|
||||
def __call__(
|
||||
self: EventCallback[Concatenate[V, V2, Q], T],
|
||||
value: V | Var[V],
|
||||
value2: V2 | Var[V2],
|
||||
) -> EventCallback[Q, T]: ...
|
||||
|
||||
@overload
|
||||
def __call__(
|
||||
self: EventCallback[Concatenate[V, V2, V3, Q], T],
|
||||
value: V | Var[V],
|
||||
value2: V2 | Var[V2],
|
||||
value3: V3 | Var[V3],
|
||||
) -> EventCallback[Q, T]: ...
|
||||
|
||||
@overload
|
||||
def __call__(
|
||||
self: EventCallback[Concatenate[V, V2, V3, V4, Q], T],
|
||||
value: V | Var[V],
|
||||
value2: V2 | Var[V2],
|
||||
value3: V3 | Var[V3],
|
||||
value4: V4 | Var[V4],
|
||||
) -> EventCallback[Q, T]: ...
|
||||
|
||||
def __call__(self, *values) -> EventCallback: # type: ignore
|
||||
"""Call the function with the values.
|
||||
|
||||
Args:
|
||||
*values: The values to call the function with.
|
||||
|
||||
Returns:
|
||||
The function with the values.
|
||||
"""
|
||||
return self.func(*values) # type: ignore
|
||||
|
||||
@overload
|
||||
def __get__(
|
||||
self: EventCallback[[V, V2], T], instance: None, owner
|
||||
) -> Callable[[Union[Var[V], V], Union[Var[V2], V2]], EventSpec]: ...
|
||||
|
||||
@overload
|
||||
def __get__(
|
||||
self: EventCallback[[V, V2, V3], T], instance: None, owner
|
||||
) -> Callable[
|
||||
[Union[Var[V], V], Union[Var[V2], V2], Union[Var[V3], V3]],
|
||||
EventSpec,
|
||||
]: ...
|
||||
|
||||
@overload
|
||||
def __get__(
|
||||
self: EventCallback[[V, V2, V3, V4], T], instance: None, owner
|
||||
) -> Callable[
|
||||
[
|
||||
Union[Var[V], V],
|
||||
Union[Var[V2], V2],
|
||||
Union[Var[V3], V3],
|
||||
Union[Var[V4], V4],
|
||||
],
|
||||
EventSpec,
|
||||
]: ...
|
||||
|
||||
@overload
|
||||
def __get__(
|
||||
self: EventCallback[[V, V2, V3, V4, V5], T], instance: None, owner
|
||||
) -> Callable[
|
||||
[
|
||||
Union[Var[V], V],
|
||||
Union[Var[V2], V2],
|
||||
Union[Var[V3], V3],
|
||||
Union[Var[V4], V4],
|
||||
Union[Var[V5], V5],
|
||||
],
|
||||
EventSpec,
|
||||
]: ...
|
||||
self: EventCallback[P, T], instance: None, owner
|
||||
) -> EventCallback[P, T]: ...
|
||||
|
||||
@overload
|
||||
def __get__(self, instance, owner) -> Callable[P, T]: ...
|
||||
|
||||
def __get__(self, instance, owner) -> Callable:
|
||||
def __get__(self, instance, owner) -> Callable: # type: ignore
|
||||
"""Get the function with the instance bound to it.
|
||||
|
||||
Args:
|
||||
@ -1548,6 +1564,9 @@ if sys.version_info >= (3, 10):
|
||||
return func # type: ignore
|
||||
else:
|
||||
|
||||
class EventCallback(Generic[P, T]):
|
||||
"""A descriptor that wraps a function to be used as an event."""
|
||||
|
||||
def event_handler(func: Callable[P, T]) -> Callable[P, T]:
|
||||
"""Wrap a function to be used as an event.
|
||||
|
||||
@ -1560,6 +1579,17 @@ else:
|
||||
return func
|
||||
|
||||
|
||||
G = ParamSpec("G")
|
||||
|
||||
IndividualEventType = Union[
|
||||
EventSpec, EventHandler, Callable[G, Any], EventCallback[G, Any], Var[Any]
|
||||
]
|
||||
|
||||
ItemOrList = Union[V, List[V]]
|
||||
|
||||
EventType = ItemOrList[IndividualEventType[G]]
|
||||
|
||||
|
||||
class EventNamespace(types.SimpleNamespace):
|
||||
"""A namespace for event related classes."""
|
||||
|
||||
|
@ -38,7 +38,7 @@ def get_engine(url: str | None = None) -> sqlalchemy.engine.Engine:
|
||||
url = url or conf.db_url
|
||||
if url is None:
|
||||
raise ValueError("No database url configured")
|
||||
if environment.ALEMBIC_CONFIG.exists():
|
||||
if not environment.ALEMBIC_CONFIG.exists():
|
||||
console.warn(
|
||||
"Database is not initialized, run [bold]reflex db init[/bold] first."
|
||||
)
|
||||
|
@ -220,6 +220,7 @@ class EventHandlerSetVar(EventHandler):
|
||||
Raises:
|
||||
AttributeError: If the given Var name does not exist on the state.
|
||||
EventHandlerValueError: If the given Var name is not a str
|
||||
NotImplementedError: If the setter for the given Var is async
|
||||
"""
|
||||
from reflex.utils.exceptions import EventHandlerValueError
|
||||
|
||||
@ -228,11 +229,20 @@ class EventHandlerSetVar(EventHandler):
|
||||
raise EventHandlerValueError(
|
||||
f"Var name must be passed as a string, got {args[0]!r}"
|
||||
)
|
||||
|
||||
handler = getattr(self.state_cls, constants.SETTER_PREFIX + args[0], None)
|
||||
|
||||
# Check that the requested Var setter exists on the State at compile time.
|
||||
if getattr(self.state_cls, constants.SETTER_PREFIX + args[0], None) is None:
|
||||
if handler is None:
|
||||
raise AttributeError(
|
||||
f"Variable `{args[0]}` cannot be set on `{self.state_cls.get_full_name()}`"
|
||||
)
|
||||
|
||||
if asyncio.iscoroutinefunction(handler.fn):
|
||||
raise NotImplementedError(
|
||||
f"Setter for {args[0]} is async, which is not supported."
|
||||
)
|
||||
|
||||
return super().__call__(*args)
|
||||
|
||||
|
||||
@ -2895,9 +2905,13 @@ class StateManagerDisk(StateManager):
|
||||
for substate in state.get_substates():
|
||||
substate_token = _substate_key(client_token, substate)
|
||||
|
||||
fresh_instance = await root_state.get_state(substate)
|
||||
instance = await self.load_state(substate_token)
|
||||
if instance is None:
|
||||
instance = await root_state.get_state(substate)
|
||||
if instance is not None:
|
||||
# Ensure all substates exist, even if they weren't serialized previously.
|
||||
instance.substates = fresh_instance.substates
|
||||
else:
|
||||
instance = fresh_instance
|
||||
state.substates[substate.get_name()] = instance
|
||||
instance.parent_state = state
|
||||
|
||||
|
@ -249,7 +249,8 @@ class AppHarness:
|
||||
return textwrap.dedent(source)
|
||||
|
||||
def _initialize_app(self):
|
||||
os.environ["TELEMETRY_ENABLED"] = "" # disable telemetry reporting for tests
|
||||
# disable telemetry reporting for tests
|
||||
os.environ["TELEMETRY_ENABLED"] = "false"
|
||||
self.app_path.mkdir(parents=True, exist_ok=True)
|
||||
if self.app_source is not None:
|
||||
app_globals = self._get_globals_from_signature(self.app_source)
|
||||
|
@ -23,6 +23,12 @@ def merge_imports(
|
||||
for lib, fields in (
|
||||
import_dict if isinstance(import_dict, tuple) else import_dict.items()
|
||||
):
|
||||
# If the lib is an absolute path, we need to prefix it with a $
|
||||
lib = (
|
||||
"$" + lib
|
||||
if lib.startswith(("/utils/", "/components/", "/styles/", "/public/"))
|
||||
else lib
|
||||
)
|
||||
if isinstance(fields, (list, tuple, set)):
|
||||
all_imports[lib].extend(
|
||||
(
|
||||
|
@ -1,5 +1,7 @@
|
||||
import multiprocessing
|
||||
import os
|
||||
from pathlib import Path
|
||||
from typing import Any, Dict
|
||||
|
||||
import pytest
|
||||
|
||||
@ -42,7 +44,12 @@ def test_set_app_name(base_config_values):
|
||||
("TELEMETRY_ENABLED", True),
|
||||
],
|
||||
)
|
||||
def test_update_from_env(base_config_values, monkeypatch, env_var, value):
|
||||
def test_update_from_env(
|
||||
base_config_values: Dict[str, Any],
|
||||
monkeypatch: pytest.MonkeyPatch,
|
||||
env_var: str,
|
||||
value: Any,
|
||||
):
|
||||
"""Test that environment variables override config values.
|
||||
|
||||
Args:
|
||||
@ -57,6 +64,29 @@ def test_update_from_env(base_config_values, monkeypatch, env_var, value):
|
||||
assert getattr(config, env_var.lower()) == value
|
||||
|
||||
|
||||
def test_update_from_env_path(
|
||||
base_config_values: Dict[str, Any],
|
||||
monkeypatch: pytest.MonkeyPatch,
|
||||
tmp_path: Path,
|
||||
):
|
||||
"""Test that environment variables override config values.
|
||||
|
||||
Args:
|
||||
base_config_values: Config values.
|
||||
monkeypatch: The pytest monkeypatch object.
|
||||
tmp_path: The pytest tmp_path fixture object.
|
||||
"""
|
||||
monkeypatch.setenv("BUN_PATH", "/test")
|
||||
assert os.environ.get("BUN_PATH") == "/test"
|
||||
with pytest.raises(ValueError):
|
||||
rx.Config(**base_config_values)
|
||||
|
||||
monkeypatch.setenv("BUN_PATH", str(tmp_path))
|
||||
assert os.environ.get("BUN_PATH") == str(tmp_path)
|
||||
config = rx.Config(**base_config_values)
|
||||
assert config.bun_path == tmp_path
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"kwargs, expected",
|
||||
[
|
||||
|
@ -106,6 +106,7 @@ class TestState(BaseState):
|
||||
fig: Figure = Figure()
|
||||
dt: datetime.datetime = datetime.datetime.fromisoformat("1989-11-09T18:53:00+01:00")
|
||||
_backend: int = 0
|
||||
asynctest: int = 0
|
||||
|
||||
@ComputedVar
|
||||
def sum(self) -> float:
|
||||
@ -129,6 +130,14 @@ class TestState(BaseState):
|
||||
"""Do something."""
|
||||
pass
|
||||
|
||||
async def set_asynctest(self, value: int):
|
||||
"""Set the asynctest value. Intentionally overwrite the default setter with an async one.
|
||||
|
||||
Args:
|
||||
value: The new value.
|
||||
"""
|
||||
self.asynctest = value
|
||||
|
||||
|
||||
class ChildState(TestState):
|
||||
"""A child state fixture."""
|
||||
@ -313,6 +322,7 @@ def test_class_vars(test_state):
|
||||
"upper",
|
||||
"fig",
|
||||
"dt",
|
||||
"asynctest",
|
||||
}
|
||||
|
||||
|
||||
@ -733,6 +743,7 @@ def test_reset(test_state, child_state):
|
||||
"mapping",
|
||||
"dt",
|
||||
"_backend",
|
||||
"asynctest",
|
||||
}
|
||||
|
||||
# The dirty vars should be reset.
|
||||
@ -3179,6 +3190,13 @@ async def test_setvar(mock_app: rx.App, token: str):
|
||||
TestState.setvar(42, 42)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_setvar_async_setter():
|
||||
"""Test that overridden async setters raise Exception when used with setvar."""
|
||||
with pytest.raises(NotImplementedError):
|
||||
TestState.setvar("asynctest", 42)
|
||||
|
||||
|
||||
@pytest.mark.skipif("REDIS_URL" not in os.environ, reason="Test requires redis")
|
||||
@pytest.mark.parametrize(
|
||||
"expiration_kwargs, expected_values",
|
||||
@ -3313,3 +3331,36 @@ def test_assignment_to_undeclared_vars():
|
||||
|
||||
state.handle_supported_regular_vars()
|
||||
state.handle_non_var()
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_deserialize_gc_state_disk(token):
|
||||
"""Test that a state can be deserialized from disk with a grandchild state.
|
||||
|
||||
Args:
|
||||
token: A token.
|
||||
"""
|
||||
|
||||
class Root(BaseState):
|
||||
pass
|
||||
|
||||
class State(Root):
|
||||
num: int = 42
|
||||
|
||||
class Child(State):
|
||||
foo: str = "bar"
|
||||
|
||||
dsm = StateManagerDisk(state=Root)
|
||||
async with dsm.modify_state(token) as root:
|
||||
s = await root.get_state(State)
|
||||
s.num += 1
|
||||
c = await root.get_state(Child)
|
||||
assert s._get_was_touched()
|
||||
assert not c._get_was_touched()
|
||||
|
||||
dsm2 = StateManagerDisk(state=Root)
|
||||
root = await dsm2.get_state(token)
|
||||
s = await root.get_state(State)
|
||||
assert s.num == 43
|
||||
c = await root.get_state(Child)
|
||||
assert c.foo == "bar"
|
||||
|
@ -601,6 +601,7 @@ formatted_router = {
|
||||
"sum": 3.14,
|
||||
"upper": "",
|
||||
"router": formatted_router,
|
||||
"asynctest": 0,
|
||||
},
|
||||
ChildState.get_full_name(): {
|
||||
"count": 23,
|
||||
|
Loading…
Reference in New Issue
Block a user