Compare commits
101 Commits
main
...
add-valida
Author | SHA1 | Date | |
---|---|---|---|
![]() |
c74313992f | ||
![]() |
d1ff6d51a2 | ||
![]() |
3080cadd31 | ||
![]() |
a2074b9081 | ||
![]() |
cabb9b43ad | ||
![]() |
aef6aced03 | ||
![]() |
1bb8ba5585 | ||
![]() |
06b751f679 | ||
![]() |
749577f0bc | ||
![]() |
46c66b2adf | ||
![]() |
8fab30eb69 | ||
![]() |
2ffa698c6b | ||
![]() |
0d746bf762 | ||
![]() |
d3b12a84fa | ||
![]() |
6604784ca1 | ||
![]() |
24f341d125 | ||
![]() |
29fc4b020a | ||
![]() |
19dd15bd44 | ||
![]() |
6a50b3a29e | ||
![]() |
b2a27cb8c1 | ||
![]() |
540382dd3e | ||
![]() |
c6e2368c95 | ||
![]() |
e10cf07506 | ||
![]() |
84d3a2bb97 | ||
![]() |
8173e10698 | ||
![]() |
1f9fbd88de | ||
![]() |
68b8c12127 | ||
![]() |
5e45ca509b | ||
![]() |
be92063421 | ||
![]() |
5d6b51c561 | ||
![]() |
b10f6e836d | ||
![]() |
db89a712e9 | ||
![]() |
9e7eeb2a6e | ||
![]() |
b7579f4d8d | ||
![]() |
392c5b5a69 | ||
![]() |
b11fc5a8ef | ||
![]() |
00019daa27 | ||
![]() |
36af8255d3 | ||
![]() |
aadd8b56bf | ||
![]() |
fa6c12e8b3 | ||
![]() |
f4aa122950 | ||
![]() |
9a987caf76 | ||
![]() |
c6f05bb320 | ||
![]() |
0e539a208c | ||
![]() |
a7230f1f45 | ||
![]() |
0798cb8f60 | ||
![]() |
270fcb996d | ||
![]() |
57d8ea02e9 | ||
![]() |
112b2ed948 | ||
![]() |
3d73f561b7 | ||
![]() |
72f1fa7cb4 | ||
![]() |
94b4443afc | ||
![]() |
f9d45d5562 | ||
![]() |
94c9e52474 | ||
![]() |
4300f338d8 | ||
![]() |
f42d1f4b0f | ||
![]() |
a488fe0c49 | ||
![]() |
076cfea6ae | ||
![]() |
d0208e678c | ||
![]() |
19b6fe9efc | ||
![]() |
1aa728ee4c | ||
![]() |
713f907bf0 | ||
![]() |
990bf131c6 | ||
![]() |
f257122934 | ||
![]() |
f0f84d5410 | ||
![]() |
45dde0072e | ||
![]() |
2a02e96d87 | ||
![]() |
056de9e277 | ||
![]() |
d31510c655 | ||
![]() |
a5526afaeb | ||
![]() |
99a3090784 | ||
![]() |
bd2ea5b417 | ||
![]() |
8830d5ab77 | ||
![]() |
06eb04f005 | ||
![]() |
2e1bc057a4 | ||
![]() |
2b05ee98ed | ||
![]() |
ed1ae0d3a2 | ||
![]() |
7d0a4f7133 | ||
![]() |
079cc56f59 | ||
![]() |
92b1232806 | ||
![]() |
7f1dc7c841 | ||
![]() |
eac54d60d2 | ||
![]() |
702670ff26 | ||
![]() |
88cfb3b7e2 | ||
![]() |
53b98543cc | ||
![]() |
5f0546f32e | ||
![]() |
2c04153013 | ||
![]() |
a9db61b371 | ||
![]() |
3cdd2097b6 | ||
![]() |
9d7e353ed3 | ||
![]() |
f4aa1f58c3 | ||
![]() |
ebc81811c0 | ||
![]() |
1e9743dcd6 | ||
![]() |
05bd41c040 | ||
![]() |
f9b24fe5bd | ||
![]() |
6745d6cb9d | ||
![]() |
7ada0ea5b9 | ||
![]() |
48951dbabd | ||
![]() |
7aa9245514 | ||
![]() |
9b06d684cd | ||
![]() |
56f0d6375b |
@ -23,7 +23,7 @@ fastapi = ">=0.96.0,!=0.111.0,!=0.111.1"
|
||||
gunicorn = ">=20.1.0,<24.0"
|
||||
jinja2 = ">=3.1.2,<4.0"
|
||||
psutil = ">=5.9.4,<7.0"
|
||||
pydantic = ">=1.10.2,<3.0"
|
||||
pydantic = ">=1.10.17,<3.0"
|
||||
python-multipart = ">=0.0.5,<0.1"
|
||||
python-socketio = ">=5.7.0,<6.0"
|
||||
redis = ">=4.3.5,<6.0"
|
||||
@ -55,7 +55,7 @@ typing_extensions = ">=4.6.0"
|
||||
[tool.poetry.group.dev.dependencies]
|
||||
pytest = ">=7.1.2,<9.0"
|
||||
pytest-mock = ">=3.10.0,<4.0"
|
||||
pyright = ">=1.1.392, <1.2"
|
||||
pyright = ">=1.1.392.post0,<1.2"
|
||||
darglint = ">=1.8.1,<2.0"
|
||||
dill = ">=0.3.8"
|
||||
toml = ">=0.10.2,<1.0"
|
||||
|
@ -6,12 +6,6 @@
|
||||
{% filter indent(width=indent_width) %}
|
||||
{%- if component is not mapping %}
|
||||
{{- component }}
|
||||
{%- elif "iterable" in component %}
|
||||
{{- render_iterable_tag(component) }}
|
||||
{%- elif component.name == "match"%}
|
||||
{{- render_match_tag(component) }}
|
||||
{%- elif "cond" in component %}
|
||||
{{- render_condition_tag(component) }}
|
||||
{%- elif component.children|length %}
|
||||
{{- render_tag(component) }}
|
||||
{%- else %}
|
||||
@ -44,30 +38,6 @@
|
||||
{%- endmacro %}
|
||||
|
||||
|
||||
{# Rendering condition component. #}
|
||||
{# Args: #}
|
||||
{# component: component dictionary #}
|
||||
{% macro render_condition_tag(component) %}
|
||||
{ {{- component.cond_state }} ? (
|
||||
{{ render(component.true_value) }}
|
||||
) : (
|
||||
{{ render(component.false_value) }}
|
||||
)}
|
||||
{%- endmacro %}
|
||||
|
||||
|
||||
{# Rendering iterable component. #}
|
||||
{# Args: #}
|
||||
{# component: component dictionary #}
|
||||
{% macro render_iterable_tag(component) %}
|
||||
<>{ {{ component.iterable_state }}.map(({{ component.arg_name }}, {{ component.arg_index }}) => (
|
||||
{% for child in component.children %}
|
||||
{{ render(child) }}
|
||||
{% endfor %}
|
||||
))}</>
|
||||
{%- endmacro %}
|
||||
|
||||
|
||||
{# Rendering props of a component. #}
|
||||
{# Args: #}
|
||||
{# component: component dictionary #}
|
||||
@ -75,29 +45,6 @@
|
||||
{% if props|length %} {{ props|join(" ") }}{% endif %}
|
||||
{% endmacro %}
|
||||
|
||||
{# Rendering Match component. #}
|
||||
{# Args: #}
|
||||
{# component: component dictionary #}
|
||||
{% macro render_match_tag(component) %}
|
||||
{
|
||||
(() => {
|
||||
switch (JSON.stringify({{ component.cond._js_expr }})) {
|
||||
{% for case in component.match_cases %}
|
||||
{% for condition in case[:-1] %}
|
||||
case JSON.stringify({{ condition._js_expr }}):
|
||||
{% endfor %}
|
||||
return {{ render(case[-1]) }};
|
||||
break;
|
||||
{% endfor %}
|
||||
default:
|
||||
return {{ render(component.default) }};
|
||||
break;
|
||||
}
|
||||
})()
|
||||
}
|
||||
{%- endmacro %}
|
||||
|
||||
|
||||
{# Rendering content with args. #}
|
||||
{# Args: #}
|
||||
{# component: component dictionary #}
|
||||
|
@ -1,43 +1,43 @@
|
||||
/**
|
||||
* Simulate the python range() builtin function.
|
||||
* inspired by https://dev.to/guyariely/using-python-range-in-javascript-337p
|
||||
*
|
||||
*
|
||||
* If needed outside of an iterator context, use `Array.from(range(10))` or
|
||||
* spread syntax `[...range(10)]` to get an array.
|
||||
*
|
||||
*
|
||||
* @param {number} start: the start or end of the range.
|
||||
* @param {number} stop: the end of the range.
|
||||
* @param {number} step: the step of the range.
|
||||
* @returns {object} an object with a Symbol.iterator method over the range
|
||||
*/
|
||||
export default function range(start, stop, step) {
|
||||
return {
|
||||
[Symbol.iterator]() {
|
||||
if (stop === undefined) {
|
||||
stop = start;
|
||||
start = 0;
|
||||
}
|
||||
if (step === undefined) {
|
||||
step = 1;
|
||||
}
|
||||
|
||||
let i = start - step;
|
||||
|
||||
return {
|
||||
next() {
|
||||
i += step;
|
||||
if ((step > 0 && i < stop) || (step < 0 && i > stop)) {
|
||||
return {
|
||||
value: i,
|
||||
done: false,
|
||||
};
|
||||
}
|
||||
return {
|
||||
[Symbol.iterator]() {
|
||||
if ((stop ?? undefined) === undefined) {
|
||||
stop = start;
|
||||
start = 0;
|
||||
}
|
||||
if ((step ?? undefined) === undefined) {
|
||||
step = 1;
|
||||
}
|
||||
|
||||
let i = start - step;
|
||||
|
||||
return {
|
||||
next() {
|
||||
i += step;
|
||||
if ((step > 0 && i < stop) || (step < 0 && i > stop)) {
|
||||
return {
|
||||
value: undefined,
|
||||
done: true,
|
||||
value: i,
|
||||
done: false,
|
||||
};
|
||||
},
|
||||
};
|
||||
},
|
||||
};
|
||||
}
|
||||
}
|
||||
return {
|
||||
value: undefined,
|
||||
done: true,
|
||||
};
|
||||
},
|
||||
};
|
||||
},
|
||||
};
|
||||
}
|
||||
|
@ -926,6 +926,45 @@ export const isTrue = (val) => {
|
||||
return Boolean(val);
|
||||
};
|
||||
|
||||
/**
|
||||
* Returns a copy of a section of an array.
|
||||
* @param {Array | string} arrayLike The array to slice.
|
||||
* @param {[number, number, number]} slice The slice to apply.
|
||||
* @returns The sliced array.
|
||||
*/
|
||||
export const atSlice = (arrayLike, slice) => {
|
||||
const array = [...arrayLike];
|
||||
const [startSlice, endSlice, stepSlice] = slice;
|
||||
if (stepSlice ?? null === null) {
|
||||
return array.slice(startSlice ?? undefined, endSlice ?? undefined);
|
||||
}
|
||||
const step = stepSlice ?? 1;
|
||||
if (step > 0) {
|
||||
return array
|
||||
.slice(startSlice ?? undefined, endSlice ?? undefined)
|
||||
.filter((_, i) => i % step === 0);
|
||||
}
|
||||
const actualStart = (endSlice ?? null) === null ? 0 : endSlice + 1;
|
||||
const actualEnd =
|
||||
(startSlice ?? null) === null ? array.length : startSlice + 1;
|
||||
return array
|
||||
.slice(actualStart, actualEnd)
|
||||
.reverse()
|
||||
.filter((_, i) => i % step === 0);
|
||||
};
|
||||
|
||||
/**
|
||||
* Get the value at a slice or index.
|
||||
* @param {Array | string} arrayLike The array to get the value from.
|
||||
* @param {number | [number, number, number]} sliceOrIndex The slice or index to get the value at.
|
||||
* @returns The value at the slice or index.
|
||||
*/
|
||||
export const atSliceOrIndex = (arrayLike, sliceOrIndex) => {
|
||||
return Array.isArray(sliceOrIndex)
|
||||
? atSlice(arrayLike, sliceOrIndex)
|
||||
: arrayLike.at(sliceOrIndex);
|
||||
};
|
||||
|
||||
/**
|
||||
* Get the value from a ref.
|
||||
* @param ref The ref to get the value from.
|
||||
|
@ -5,15 +5,9 @@ from __future__ import annotations
|
||||
import os
|
||||
from typing import TYPE_CHECKING, Any, List, Type
|
||||
|
||||
try:
|
||||
import pydantic.v1.main as pydantic_main
|
||||
from pydantic.v1 import BaseModel
|
||||
from pydantic.v1.fields import ModelField
|
||||
except ModuleNotFoundError:
|
||||
if not TYPE_CHECKING:
|
||||
import pydantic.main as pydantic_main
|
||||
from pydantic import BaseModel
|
||||
from pydantic.fields import ModelField
|
||||
import pydantic.v1.main as pydantic_main
|
||||
from pydantic.v1 import BaseModel
|
||||
from pydantic.v1.fields import ModelField
|
||||
|
||||
|
||||
def validate_field_name(bases: List[Type["BaseModel"]], field_name: str) -> None:
|
||||
@ -50,7 +44,7 @@ if TYPE_CHECKING:
|
||||
from reflex.vars import Var
|
||||
|
||||
|
||||
class Base(BaseModel): # pyright: ignore [reportPossiblyUnboundVariable]
|
||||
class Base(BaseModel):
|
||||
"""The base class subclassed by all Reflex classes.
|
||||
|
||||
This class wraps Pydantic and provides common methods such as
|
||||
|
@ -4,7 +4,7 @@ from __future__ import annotations
|
||||
|
||||
from typing import Any, Iterator
|
||||
|
||||
from reflex.components.component import Component, LiteralComponentVar
|
||||
from reflex.components.component import Component, ComponentStyle
|
||||
from reflex.components.tags import Tag
|
||||
from reflex.components.tags.tagless import Tagless
|
||||
from reflex.config import PerformanceMode, environment
|
||||
@ -12,7 +12,7 @@ from reflex.utils import console
|
||||
from reflex.utils.decorator import once
|
||||
from reflex.utils.imports import ParsedImportDict
|
||||
from reflex.vars import BooleanVar, ObjectVar, Var
|
||||
from reflex.vars.base import VarData
|
||||
from reflex.vars.base import GLOBAL_CACHE, VarData
|
||||
from reflex.vars.sequence import LiteralStringVar
|
||||
|
||||
|
||||
@ -80,8 +80,11 @@ class Bare(Component):
|
||||
The hooks for the component.
|
||||
"""
|
||||
hooks = super()._get_all_hooks_internal()
|
||||
if isinstance(self.contents, LiteralComponentVar):
|
||||
hooks |= self.contents._var_value._get_all_hooks_internal()
|
||||
if isinstance(self.contents, Var):
|
||||
var_data = self.contents._get_all_var_data()
|
||||
if var_data:
|
||||
for component in var_data.components:
|
||||
hooks |= component._get_all_hooks_internal()
|
||||
return hooks
|
||||
|
||||
def _get_all_hooks(self) -> dict[str, VarData | None]:
|
||||
@ -91,18 +94,24 @@ class Bare(Component):
|
||||
The hooks for the component.
|
||||
"""
|
||||
hooks = super()._get_all_hooks()
|
||||
if isinstance(self.contents, LiteralComponentVar):
|
||||
hooks |= self.contents._var_value._get_all_hooks()
|
||||
if isinstance(self.contents, Var):
|
||||
var_data = self.contents._get_all_var_data()
|
||||
if var_data:
|
||||
for component in var_data.components:
|
||||
hooks |= component._get_all_hooks()
|
||||
return hooks
|
||||
|
||||
def _get_all_imports(self) -> ParsedImportDict:
|
||||
def _get_all_imports(self, collapse: bool = False) -> ParsedImportDict:
|
||||
"""Include the imports for the component.
|
||||
|
||||
Args:
|
||||
collapse: Whether to collapse the imports.
|
||||
|
||||
Returns:
|
||||
The imports for the component.
|
||||
"""
|
||||
imports = super()._get_all_imports()
|
||||
if isinstance(self.contents, LiteralComponentVar):
|
||||
imports = super()._get_all_imports(collapse=collapse)
|
||||
if isinstance(self.contents, Var):
|
||||
var_data = self.contents._get_all_var_data()
|
||||
if var_data:
|
||||
imports |= {k: list(v) for k, v in var_data.imports}
|
||||
@ -115,8 +124,11 @@ class Bare(Component):
|
||||
The dynamic imports.
|
||||
"""
|
||||
dynamic_imports = super()._get_all_dynamic_imports()
|
||||
if isinstance(self.contents, LiteralComponentVar):
|
||||
dynamic_imports |= self.contents._var_value._get_all_dynamic_imports()
|
||||
if isinstance(self.contents, Var):
|
||||
var_data = self.contents._get_all_var_data()
|
||||
if var_data:
|
||||
for component in var_data.components:
|
||||
dynamic_imports |= component._get_all_dynamic_imports()
|
||||
return dynamic_imports
|
||||
|
||||
def _get_all_custom_code(self) -> set[str]:
|
||||
@ -126,10 +138,28 @@ class Bare(Component):
|
||||
The custom code.
|
||||
"""
|
||||
custom_code = super()._get_all_custom_code()
|
||||
if isinstance(self.contents, LiteralComponentVar):
|
||||
custom_code |= self.contents._var_value._get_all_custom_code()
|
||||
if isinstance(self.contents, Var):
|
||||
var_data = self.contents._get_all_var_data()
|
||||
if var_data:
|
||||
for component in var_data.components:
|
||||
custom_code |= component._get_all_custom_code()
|
||||
return custom_code
|
||||
|
||||
def _get_all_app_wrap_components(self) -> dict[tuple[int, str], Component]:
|
||||
"""Get the components that should be wrapped in the app.
|
||||
|
||||
Returns:
|
||||
The components that should be wrapped in the app.
|
||||
"""
|
||||
app_wrap_components = super()._get_all_app_wrap_components()
|
||||
if isinstance(self.contents, Var):
|
||||
var_data = self.contents._get_all_var_data()
|
||||
if var_data:
|
||||
for component in var_data.components:
|
||||
if isinstance(component, Component):
|
||||
app_wrap_components |= component._get_all_app_wrap_components()
|
||||
return app_wrap_components
|
||||
|
||||
def _get_all_refs(self) -> set[str]:
|
||||
"""Get the refs for the children of the component.
|
||||
|
||||
@ -137,8 +167,11 @@ class Bare(Component):
|
||||
The refs for the children.
|
||||
"""
|
||||
refs = super()._get_all_refs()
|
||||
if isinstance(self.contents, LiteralComponentVar):
|
||||
refs |= self.contents._var_value._get_all_refs()
|
||||
if isinstance(self.contents, Var):
|
||||
var_data = self.contents._get_all_var_data()
|
||||
if var_data:
|
||||
for component in var_data.components:
|
||||
refs |= component._get_all_refs()
|
||||
return refs
|
||||
|
||||
def _render(self) -> Tag:
|
||||
@ -148,6 +181,30 @@ class Bare(Component):
|
||||
return Tagless(contents=f"{{{self.contents!s}}}")
|
||||
return Tagless(contents=str(self.contents))
|
||||
|
||||
def _add_style_recursive(
|
||||
self, style: ComponentStyle, theme: Component | None = None
|
||||
) -> Component:
|
||||
"""Add style to the component and its children.
|
||||
|
||||
Args:
|
||||
style: The style to add.
|
||||
theme: The theme to add.
|
||||
|
||||
Returns:
|
||||
The component with the style added.
|
||||
"""
|
||||
new_self = super()._add_style_recursive(style, theme)
|
||||
if isinstance(self.contents, Var):
|
||||
var_data = self.contents._get_all_var_data()
|
||||
if var_data:
|
||||
for component in var_data.components:
|
||||
if isinstance(component, Component):
|
||||
component._add_style_recursive(style, theme)
|
||||
|
||||
GLOBAL_CACHE.clear()
|
||||
|
||||
return new_self
|
||||
|
||||
def _get_vars(
|
||||
self, include_children: bool = False, ignore_ids: set[int] | None = None
|
||||
) -> Iterator[Var]:
|
||||
|
@ -65,8 +65,7 @@ from reflex.vars.base import (
|
||||
Var,
|
||||
cached_property_no_lock,
|
||||
)
|
||||
from reflex.vars.function import ArgsFunctionOperation, FunctionStringVar
|
||||
from reflex.vars.number import ternary_operation
|
||||
from reflex.vars.function import FunctionStringVar
|
||||
from reflex.vars.object import ObjectVar
|
||||
from reflex.vars.sequence import LiteralArrayVar
|
||||
|
||||
@ -889,10 +888,8 @@ class Component(BaseComponent, ABC):
|
||||
children: The children of the component.
|
||||
|
||||
"""
|
||||
from reflex.components.base.bare import Bare
|
||||
from reflex.components.base.fragment import Fragment
|
||||
from reflex.components.core.cond import Cond
|
||||
from reflex.components.core.foreach import Foreach
|
||||
from reflex.components.core.match import Match
|
||||
|
||||
no_valid_parents_defined = all(child._valid_parents == [] for child in children)
|
||||
if (
|
||||
@ -903,9 +900,7 @@ class Component(BaseComponent, ABC):
|
||||
return
|
||||
|
||||
comp_name = type(self).__name__
|
||||
allowed_components = [
|
||||
comp.__name__ for comp in (Fragment, Foreach, Cond, Match)
|
||||
]
|
||||
allowed_components = [comp.__name__ for comp in (Fragment,)]
|
||||
|
||||
def validate_child(child: Any):
|
||||
child_name = type(child).__name__
|
||||
@ -915,24 +910,39 @@ class Component(BaseComponent, ABC):
|
||||
for c in child.children:
|
||||
validate_child(c)
|
||||
|
||||
if isinstance(child, Cond):
|
||||
validate_child(child.comp1)
|
||||
validate_child(child.comp2)
|
||||
|
||||
if isinstance(child, Match):
|
||||
for cases in child.match_cases:
|
||||
validate_child(cases[-1])
|
||||
validate_child(child.default)
|
||||
if (
|
||||
isinstance(child, Bare)
|
||||
and child.contents is not None
|
||||
and isinstance(child.contents, Var)
|
||||
):
|
||||
var_data = child.contents._get_all_var_data()
|
||||
if var_data is not None:
|
||||
for c in var_data.components:
|
||||
validate_child(c)
|
||||
|
||||
if self._invalid_children and child_name in self._invalid_children:
|
||||
raise ValueError(
|
||||
f"The component `{comp_name}` cannot have `{child_name}` as a child component"
|
||||
)
|
||||
|
||||
if self._valid_children and child_name not in [
|
||||
*self._valid_children,
|
||||
*allowed_components,
|
||||
]:
|
||||
valid_children = self._valid_children + allowed_components
|
||||
|
||||
def child_is_in_valid(child_component: Any):
|
||||
if type(child_component).__name__ in valid_children:
|
||||
return True
|
||||
|
||||
if (
|
||||
not isinstance(child_component, Bare)
|
||||
or child_component.contents is None
|
||||
or not isinstance(child_component.contents, Var)
|
||||
or (var_data := child_component.contents._get_all_var_data())
|
||||
is None
|
||||
):
|
||||
return False
|
||||
|
||||
return all(child_is_in_valid(c) for c in var_data.components)
|
||||
|
||||
if self._valid_children and not child_is_in_valid(child):
|
||||
valid_child_list = ", ".join(
|
||||
[f"`{v_child}`" for v_child in self._valid_children]
|
||||
)
|
||||
@ -1918,8 +1928,6 @@ class StatefulComponent(BaseComponent):
|
||||
Returns:
|
||||
The stateful component or None if the component should not be memoized.
|
||||
"""
|
||||
from reflex.components.core.foreach import Foreach
|
||||
|
||||
if component._memoization_mode.disposition == MemoizationDisposition.NEVER:
|
||||
# Never memoize this component.
|
||||
return None
|
||||
@ -1948,10 +1956,6 @@ class StatefulComponent(BaseComponent):
|
||||
# Skip BaseComponent and StatefulComponent children.
|
||||
if not isinstance(child, Component):
|
||||
continue
|
||||
# Always consider Foreach something that must be memoized by the parent.
|
||||
if isinstance(child, Foreach):
|
||||
should_memoize = True
|
||||
break
|
||||
child = cls._child_var(child)
|
||||
if isinstance(child, Var) and child._get_all_var_data():
|
||||
should_memoize = True
|
||||
@ -2001,18 +2005,9 @@ class StatefulComponent(BaseComponent):
|
||||
The Var from the child component or the child itself (for regular cases).
|
||||
"""
|
||||
from reflex.components.base.bare import Bare
|
||||
from reflex.components.core.cond import Cond
|
||||
from reflex.components.core.foreach import Foreach
|
||||
from reflex.components.core.match import Match
|
||||
|
||||
if isinstance(child, Bare):
|
||||
return child.contents
|
||||
if isinstance(child, Cond):
|
||||
return child.cond
|
||||
if isinstance(child, Foreach):
|
||||
return child.iterable
|
||||
if isinstance(child, Match):
|
||||
return child.cond
|
||||
return child
|
||||
|
||||
@classmethod
|
||||
@ -2359,53 +2354,6 @@ def render_dict_to_var(tag: dict | Component | str, imported_names: set[str]) ->
|
||||
return render_dict_to_var(tag.render(), imported_names)
|
||||
return Var.create(tag)
|
||||
|
||||
if "iterable" in tag:
|
||||
function_return = Var.create(
|
||||
[
|
||||
render_dict_to_var(child.render(), imported_names)
|
||||
for child in tag["children"]
|
||||
]
|
||||
)
|
||||
|
||||
func = ArgsFunctionOperation.create(
|
||||
(tag["arg_var_name"], tag["index_var_name"]),
|
||||
function_return,
|
||||
)
|
||||
|
||||
return FunctionStringVar.create("Array.prototype.map.call").call(
|
||||
tag["iterable"]
|
||||
if not isinstance(tag["iterable"], ObjectVar)
|
||||
else tag["iterable"].items(),
|
||||
func,
|
||||
)
|
||||
|
||||
if tag["name"] == "match":
|
||||
element = tag["cond"]
|
||||
|
||||
conditionals = render_dict_to_var(tag["default"], imported_names)
|
||||
|
||||
for case in tag["match_cases"][::-1]:
|
||||
condition = case[0].to_string() == element.to_string()
|
||||
for pattern in case[1:-1]:
|
||||
condition = condition | (pattern.to_string() == element.to_string())
|
||||
|
||||
conditionals = ternary_operation(
|
||||
condition,
|
||||
render_dict_to_var(case[-1], imported_names),
|
||||
conditionals,
|
||||
)
|
||||
|
||||
return conditionals
|
||||
|
||||
if "cond" in tag:
|
||||
return ternary_operation(
|
||||
tag["cond"],
|
||||
render_dict_to_var(tag["true_value"], imported_names),
|
||||
render_dict_to_var(tag["false_value"], imported_names)
|
||||
if tag["false_value"] is not None
|
||||
else Var.create(None),
|
||||
)
|
||||
|
||||
props = {}
|
||||
|
||||
special_props = []
|
||||
@ -2485,17 +2433,14 @@ class LiteralComponentVar(CachedVarOperation, LiteralVar, ComponentVar):
|
||||
"@emotion/react": [
|
||||
ImportVar(tag="jsx"),
|
||||
],
|
||||
}
|
||||
),
|
||||
VarData(
|
||||
imports=self._var_value._get_all_imports(),
|
||||
),
|
||||
VarData(
|
||||
imports={
|
||||
"react": [
|
||||
ImportVar(tag="Fragment"),
|
||||
],
|
||||
}
|
||||
},
|
||||
components=(self._var_value,),
|
||||
),
|
||||
VarData(
|
||||
imports=self._var_value._get_all_imports(),
|
||||
),
|
||||
)
|
||||
|
||||
|
@ -21,16 +21,14 @@ _SUBMOD_ATTRS: dict[str, list[str]] = {
|
||||
"colors": [
|
||||
"color",
|
||||
],
|
||||
"cond": ["Cond", "color_mode_cond", "cond"],
|
||||
"cond": ["color_mode_cond", "cond"],
|
||||
"debounce": ["DebounceInput", "debounce_input"],
|
||||
"foreach": [
|
||||
"foreach",
|
||||
"Foreach",
|
||||
],
|
||||
"html": ["html", "Html"],
|
||||
"match": [
|
||||
"match",
|
||||
"Match",
|
||||
],
|
||||
"breakpoints": ["breakpoints", "set_breakpoints"],
|
||||
"responsive": [
|
||||
|
@ -17,16 +17,13 @@ from .breakpoints import set_breakpoints as set_breakpoints
|
||||
from .clipboard import Clipboard as Clipboard
|
||||
from .clipboard import clipboard as clipboard
|
||||
from .colors import color as color
|
||||
from .cond import Cond as Cond
|
||||
from .cond import color_mode_cond as color_mode_cond
|
||||
from .cond import cond as cond
|
||||
from .debounce import DebounceInput as DebounceInput
|
||||
from .debounce import debounce_input as debounce_input
|
||||
from .foreach import Foreach as Foreach
|
||||
from .foreach import foreach as foreach
|
||||
from .html import Html as Html
|
||||
from .html import html as html
|
||||
from .match import Match as Match
|
||||
from .match import match as match
|
||||
from .responsive import desktop_only as desktop_only
|
||||
from .responsive import mobile_and_tablet as mobile_and_tablet
|
||||
|
@ -5,6 +5,7 @@ from __future__ import annotations
|
||||
from typing import Optional
|
||||
|
||||
from reflex import constants
|
||||
from reflex.components.base.fragment import Fragment
|
||||
from reflex.components.component import Component
|
||||
from reflex.components.core.cond import cond
|
||||
from reflex.components.el.elements.typography import Div
|
||||
@ -163,7 +164,7 @@ class ConnectionToaster(Toaster):
|
||||
return super().create(*children, **props)
|
||||
|
||||
|
||||
class ConnectionBanner(Component):
|
||||
class ConnectionBanner(Fragment):
|
||||
"""A connection banner component."""
|
||||
|
||||
@classmethod
|
||||
@ -190,10 +191,10 @@ class ConnectionBanner(Component):
|
||||
position="fixed",
|
||||
)
|
||||
|
||||
return cond(has_connection_errors, comp)
|
||||
return super().create(cond(has_connection_errors, comp))
|
||||
|
||||
|
||||
class ConnectionModal(Component):
|
||||
class ConnectionModal(Fragment):
|
||||
"""A connection status modal window."""
|
||||
|
||||
@classmethod
|
||||
@ -208,16 +209,18 @@ class ConnectionModal(Component):
|
||||
"""
|
||||
if not comp:
|
||||
comp = Text.create(*default_connection_error())
|
||||
return cond(
|
||||
has_too_many_connection_errors,
|
||||
DialogRoot.create(
|
||||
DialogContent.create(
|
||||
DialogTitle.create("Connection Error"),
|
||||
comp,
|
||||
return super().create(
|
||||
cond(
|
||||
has_too_many_connection_errors,
|
||||
DialogRoot.create(
|
||||
DialogContent.create(
|
||||
DialogTitle.create("Connection Error"),
|
||||
comp,
|
||||
),
|
||||
open=has_too_many_connection_errors,
|
||||
z_index=9999,
|
||||
),
|
||||
open=has_too_many_connection_errors,
|
||||
z_index=9999,
|
||||
),
|
||||
)
|
||||
)
|
||||
|
||||
|
||||
|
@ -5,6 +5,7 @@
|
||||
# ------------------------------------------------------
|
||||
from typing import Any, Dict, Literal, Optional, Union, overload
|
||||
|
||||
from reflex.components.base.fragment import Fragment
|
||||
from reflex.components.component import Component
|
||||
from reflex.components.el.elements.typography import Div
|
||||
from reflex.components.lucide.icon import Icon
|
||||
@ -137,7 +138,7 @@ class ConnectionToaster(Toaster):
|
||||
"""
|
||||
...
|
||||
|
||||
class ConnectionBanner(Component):
|
||||
class ConnectionBanner(Fragment):
|
||||
@overload
|
||||
@classmethod
|
||||
def create( # type: ignore
|
||||
@ -176,7 +177,7 @@ class ConnectionBanner(Component):
|
||||
"""
|
||||
...
|
||||
|
||||
class ConnectionModal(Component):
|
||||
class ConnectionModal(Fragment):
|
||||
@overload
|
||||
@classmethod
|
||||
def create( # type: ignore
|
||||
|
@ -41,7 +41,7 @@ class ClientSideRouting(Component):
|
||||
return ""
|
||||
|
||||
|
||||
def wait_for_client_redirect(component: Component) -> Component:
|
||||
def wait_for_client_redirect(component: Component) -> Var[Component]:
|
||||
"""Wait for a redirect to occur before rendering a component.
|
||||
|
||||
This prevents the 404 page from flashing while the redirect is happening.
|
||||
@ -53,9 +53,9 @@ def wait_for_client_redirect(component: Component) -> Component:
|
||||
The conditionally rendered component.
|
||||
"""
|
||||
return cond(
|
||||
condition=route_not_found,
|
||||
c1=component,
|
||||
c2=ClientSideRouting.create(),
|
||||
route_not_found,
|
||||
component,
|
||||
ClientSideRouting.create(),
|
||||
)
|
||||
|
||||
|
||||
|
@ -2,126 +2,31 @@
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import Any, Dict, Optional, overload
|
||||
from typing import Any, TypeVar, Union, overload
|
||||
|
||||
from reflex.components.base.fragment import Fragment
|
||||
from reflex.components.component import BaseComponent, Component, MemoizationLeaf
|
||||
from reflex.components.tags import CondTag, Tag
|
||||
from reflex.constants import Dirs
|
||||
from reflex.components.component import BaseComponent, Component
|
||||
from reflex.style import LIGHT_COLOR_MODE, resolved_color_mode
|
||||
from reflex.utils.imports import ImportDict, ImportVar
|
||||
from reflex.vars import VarData
|
||||
from reflex.utils import types
|
||||
from reflex.vars.base import LiteralVar, Var
|
||||
from reflex.vars.number import ternary_operation
|
||||
|
||||
_IS_TRUE_IMPORT: ImportDict = {
|
||||
f"$/{Dirs.STATE_PATH}": [ImportVar(tag="isTrue")],
|
||||
}
|
||||
|
||||
@overload
|
||||
def cond(
|
||||
condition: Any, c1: BaseComponent | Var[BaseComponent], c2: Any = None, /
|
||||
) -> Var[Component]: ...
|
||||
|
||||
|
||||
class Cond(MemoizationLeaf):
|
||||
"""Render one of two components based on a condition."""
|
||||
|
||||
# The cond to determine which component to render.
|
||||
cond: Var[Any]
|
||||
|
||||
# The component to render if the cond is true.
|
||||
comp1: BaseComponent | None = None
|
||||
# The component to render if the cond is false.
|
||||
comp2: BaseComponent | None = None
|
||||
|
||||
@classmethod
|
||||
def create(
|
||||
cls,
|
||||
cond: Var,
|
||||
comp1: BaseComponent,
|
||||
comp2: Optional[BaseComponent] = None,
|
||||
) -> Component:
|
||||
"""Create a conditional component.
|
||||
|
||||
Args:
|
||||
cond: The cond to determine which component to render.
|
||||
comp1: The component to render if the cond is true.
|
||||
comp2: The component to render if the cond is false.
|
||||
|
||||
Returns:
|
||||
The conditional component.
|
||||
"""
|
||||
# Wrap everything in fragments.
|
||||
if type(comp1).__name__ != "Fragment":
|
||||
comp1 = Fragment.create(comp1)
|
||||
if comp2 is None or type(comp2).__name__ != "Fragment":
|
||||
comp2 = Fragment.create(comp2) if comp2 else Fragment.create()
|
||||
return Fragment.create(
|
||||
cls(
|
||||
cond=cond,
|
||||
comp1=comp1,
|
||||
comp2=comp2,
|
||||
children=[comp1, comp2],
|
||||
)
|
||||
)
|
||||
|
||||
def _get_props_imports(self):
|
||||
"""Get the imports needed for component's props.
|
||||
|
||||
Returns:
|
||||
The imports for the component's props of the component.
|
||||
"""
|
||||
return []
|
||||
|
||||
def _render(self) -> Tag:
|
||||
return CondTag(
|
||||
cond=self.cond,
|
||||
true_value=self.comp1.render(), # pyright: ignore [reportOptionalMemberAccess]
|
||||
false_value=self.comp2.render(), # pyright: ignore [reportOptionalMemberAccess]
|
||||
)
|
||||
|
||||
def render(self) -> Dict:
|
||||
"""Render the component.
|
||||
|
||||
Returns:
|
||||
The dictionary for template of component.
|
||||
"""
|
||||
tag = self._render()
|
||||
return dict(
|
||||
tag.add_props(
|
||||
**self.event_triggers,
|
||||
key=self.key,
|
||||
sx=self.style,
|
||||
id=self.id,
|
||||
class_name=self.class_name,
|
||||
).set(
|
||||
props=tag.format_props(),
|
||||
),
|
||||
cond_state=f"isTrue({self.cond!s})",
|
||||
)
|
||||
|
||||
def add_imports(self) -> ImportDict:
|
||||
"""Add imports for the Cond component.
|
||||
|
||||
Returns:
|
||||
The import dict for the component.
|
||||
"""
|
||||
var_data = VarData.merge(self.cond._get_all_var_data())
|
||||
|
||||
imports = var_data.old_school_imports() if var_data else {}
|
||||
|
||||
return {**imports, **_IS_TRUE_IMPORT}
|
||||
T = TypeVar("T")
|
||||
V = TypeVar("V")
|
||||
|
||||
|
||||
@overload
|
||||
def cond(condition: Any, c1: Component, c2: Any) -> Component: ... # pyright: ignore [reportOverlappingOverload]
|
||||
def cond(condition: Any, c1: T | Var[T], c2: V | Var[V], /) -> Var[T | V]: ...
|
||||
|
||||
|
||||
@overload
|
||||
def cond(condition: Any, c1: Component) -> Component: ...
|
||||
|
||||
|
||||
@overload
|
||||
def cond(condition: Any, c1: Any, c2: Any) -> Var: ...
|
||||
|
||||
|
||||
def cond(condition: Any, c1: Any, c2: Any = None) -> Component | Var:
|
||||
def cond(condition: Any, c1: Any, c2: Any = None, /) -> Var:
|
||||
"""Create a conditional component or Prop.
|
||||
|
||||
Args:
|
||||
@ -137,48 +42,40 @@ def cond(condition: Any, c1: Any, c2: Any = None) -> Component | Var:
|
||||
"""
|
||||
# Convert the condition to a Var.
|
||||
cond_var = LiteralVar.create(condition)
|
||||
if cond_var is None:
|
||||
raise ValueError("The condition must be set.")
|
||||
|
||||
# If the first component is a component, create a Cond component.
|
||||
if isinstance(c1, BaseComponent):
|
||||
if c2 is not None and not isinstance(c2, BaseComponent):
|
||||
raise ValueError("Both arguments must be components.")
|
||||
return Cond.create(cond_var, c1, c2)
|
||||
# If the first component is a component, create a Fragment if the second component is not set.
|
||||
if isinstance(c1, BaseComponent) or (
|
||||
isinstance(c1, Var)
|
||||
and types.safe_typehint_issubclass(
|
||||
c1._var_type, Union[BaseComponent, list[BaseComponent]]
|
||||
)
|
||||
):
|
||||
c2 = c2 if c2 is not None else Fragment.create()
|
||||
|
||||
# Otherwise, create a conditional Var.
|
||||
# Check that the second argument is valid.
|
||||
if isinstance(c2, BaseComponent):
|
||||
raise ValueError("Both arguments must be props.")
|
||||
if c2 is None:
|
||||
raise ValueError("For conditional vars, the second argument must be set.")
|
||||
|
||||
def create_var(cond_part: Any) -> Var[Any]:
|
||||
return LiteralVar.create(cond_part)
|
||||
|
||||
# convert the truth and false cond parts into vars so the _var_data can be obtained.
|
||||
c1 = create_var(c1)
|
||||
c2 = create_var(c2)
|
||||
|
||||
# Create the conditional var.
|
||||
return ternary_operation(
|
||||
cond_var.bool()._replace(
|
||||
merge_var_data=VarData(imports=_IS_TRUE_IMPORT),
|
||||
),
|
||||
cond_var.bool(),
|
||||
c1,
|
||||
c2,
|
||||
)
|
||||
|
||||
|
||||
@overload
|
||||
def color_mode_cond(light: Component, dark: Component | None = None) -> Component: ... # pyright: ignore [reportOverlappingOverload]
|
||||
def color_mode_cond(
|
||||
light: BaseComponent | Var[BaseComponent],
|
||||
dark: BaseComponent | Var[BaseComponent] | None = ...,
|
||||
) -> Var[Component]: ...
|
||||
|
||||
|
||||
@overload
|
||||
def color_mode_cond(light: Any, dark: Any = None) -> Var: ...
|
||||
def color_mode_cond(light: T | Var[T], dark: V | Var[V]) -> Var[T | V]: ...
|
||||
|
||||
|
||||
def color_mode_cond(light: Any, dark: Any = None) -> Var | Component:
|
||||
def color_mode_cond(light: Any, dark: Any = None) -> Var:
|
||||
"""Create a component or Prop based on color_mode.
|
||||
|
||||
Args:
|
||||
@ -193,3 +90,9 @@ def color_mode_cond(light: Any, dark: Any = None) -> Var | Component:
|
||||
light,
|
||||
dark,
|
||||
)
|
||||
|
||||
|
||||
class Cond:
|
||||
"""Create a conditional component or Prop."""
|
||||
|
||||
create = staticmethod(cond)
|
||||
|
@ -2,16 +2,9 @@
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import functools
|
||||
import inspect
|
||||
from typing import Any, Callable, Iterable
|
||||
from typing import Callable, Iterable
|
||||
|
||||
from reflex.components.base.fragment import Fragment
|
||||
from reflex.components.component import Component
|
||||
from reflex.components.tags import IterTag
|
||||
from reflex.constants import MemoizationMode
|
||||
from reflex.state import ComponentState
|
||||
from reflex.utils.exceptions import UntypedVarError
|
||||
from reflex.vars import ArrayVar, ObjectVar, StringVar
|
||||
from reflex.vars.base import LiteralVar, Var
|
||||
|
||||
|
||||
@ -23,149 +16,42 @@ class ForeachRenderError(TypeError):
|
||||
"""Raised when there is an error with the foreach render function."""
|
||||
|
||||
|
||||
class Foreach(Component):
|
||||
"""A component that takes in an iterable and a render function and renders a list of components."""
|
||||
def foreach(
|
||||
iterable: Var[Iterable] | Iterable,
|
||||
render_fn: Callable,
|
||||
) -> Var:
|
||||
"""Create a foreach component.
|
||||
|
||||
_memoization_mode = MemoizationMode(recursive=False)
|
||||
Args:
|
||||
iterable: The iterable to create components from.
|
||||
render_fn: A function from the render args to the component.
|
||||
|
||||
# The iterable to create components from.
|
||||
iterable: Var[Iterable]
|
||||
Returns:
|
||||
The foreach component.
|
||||
|
||||
# A function from the render args to the component.
|
||||
render_fn: Callable = Fragment.create
|
||||
Raises:
|
||||
ForeachVarError: If the iterable is of type Any.
|
||||
TypeError: If the render function is a ComponentState.
|
||||
UntypedVarError: If the iterable is of type Any without a type annotation.
|
||||
"""
|
||||
iterable = LiteralVar.create(iterable).guess_type()
|
||||
|
||||
@classmethod
|
||||
def create(
|
||||
cls,
|
||||
iterable: Var[Iterable] | Iterable,
|
||||
render_fn: Callable,
|
||||
) -> Foreach:
|
||||
"""Create a foreach component.
|
||||
if isinstance(iterable, ObjectVar):
|
||||
iterable = iterable.entries()
|
||||
|
||||
Args:
|
||||
iterable: The iterable to create components from.
|
||||
render_fn: A function from the render args to the component.
|
||||
if isinstance(iterable, StringVar):
|
||||
iterable = iterable.split()
|
||||
|
||||
Returns:
|
||||
The foreach component.
|
||||
|
||||
Raises:
|
||||
ForeachVarError: If the iterable is of type Any.
|
||||
TypeError: If the render function is a ComponentState.
|
||||
UntypedVarError: If the iterable is of type Any without a type annotation.
|
||||
"""
|
||||
from reflex.vars import ArrayVar, ObjectVar, StringVar
|
||||
|
||||
iterable = LiteralVar.create(iterable).guess_type()
|
||||
|
||||
if iterable._var_type == Any:
|
||||
raise ForeachVarError(
|
||||
f"Could not foreach over var `{iterable!s}` of type Any. "
|
||||
"(If you are trying to foreach over a state var, add a type annotation to the var). "
|
||||
"See https://reflex.dev/docs/library/dynamic-rendering/foreach/"
|
||||
)
|
||||
|
||||
if (
|
||||
hasattr(render_fn, "__qualname__")
|
||||
and render_fn.__qualname__ == ComponentState.create.__qualname__
|
||||
):
|
||||
raise TypeError(
|
||||
"Using a ComponentState as `render_fn` inside `rx.foreach` is not supported yet."
|
||||
)
|
||||
|
||||
if isinstance(iterable, ObjectVar):
|
||||
iterable = iterable.entries()
|
||||
|
||||
if isinstance(iterable, StringVar):
|
||||
iterable = iterable.split()
|
||||
|
||||
if not isinstance(iterable, ArrayVar):
|
||||
raise ForeachVarError(
|
||||
f"Could not foreach over var `{iterable!s}` of type {iterable._var_type}. "
|
||||
"See https://reflex.dev/docs/library/dynamic-rendering/foreach/"
|
||||
)
|
||||
|
||||
component = cls(
|
||||
iterable=iterable,
|
||||
render_fn=render_fn,
|
||||
)
|
||||
try:
|
||||
# Keep a ref to a rendered component to determine correct imports/hooks/styles.
|
||||
component.children = [component._render().render_component()]
|
||||
except UntypedVarError as e:
|
||||
raise UntypedVarError(
|
||||
f"Could not foreach over var `{iterable!s}` without a type annotation. "
|
||||
"See https://reflex.dev/docs/library/dynamic-rendering/foreach/"
|
||||
) from e
|
||||
return component
|
||||
|
||||
def _render(self) -> IterTag:
|
||||
props = {}
|
||||
|
||||
render_sig = inspect.signature(self.render_fn)
|
||||
params = list(render_sig.parameters.values())
|
||||
|
||||
# Validate the render function signature.
|
||||
if len(params) == 0 or len(params) > 2:
|
||||
raise ForeachRenderError(
|
||||
"Expected 1 or 2 parameters in foreach render function, got "
|
||||
f"{[p.name for p in params]}. See "
|
||||
"https://reflex.dev/docs/library/dynamic-rendering/foreach/"
|
||||
)
|
||||
|
||||
if len(params) >= 1:
|
||||
# Determine the arg var name based on the params accepted by render_fn.
|
||||
props["arg_var_name"] = params[0].name
|
||||
|
||||
if len(params) == 2:
|
||||
# Determine the index var name based on the params accepted by render_fn.
|
||||
props["index_var_name"] = params[1].name
|
||||
else:
|
||||
render_fn = self.render_fn
|
||||
# Otherwise, use a deterministic index, based on the render function bytecode.
|
||||
code_hash = (
|
||||
hash(
|
||||
getattr(
|
||||
render_fn,
|
||||
"__code__",
|
||||
(
|
||||
repr(self.render_fn)
|
||||
if not isinstance(render_fn, functools.partial)
|
||||
else render_fn.func.__code__
|
||||
),
|
||||
)
|
||||
)
|
||||
.to_bytes(
|
||||
length=8,
|
||||
byteorder="big",
|
||||
signed=True,
|
||||
)
|
||||
.hex()
|
||||
)
|
||||
props["index_var_name"] = f"index_{code_hash}"
|
||||
|
||||
return IterTag(
|
||||
iterable=self.iterable,
|
||||
render_fn=self.render_fn,
|
||||
children=self.children,
|
||||
**props,
|
||||
if not isinstance(iterable, ArrayVar):
|
||||
raise ForeachVarError(
|
||||
f"Could not foreach over var `{iterable!s}` of type {iterable._var_type}. "
|
||||
"See https://reflex.dev/docs/library/dynamic-rendering/foreach/"
|
||||
)
|
||||
|
||||
def render(self):
|
||||
"""Render the component.
|
||||
|
||||
Returns:
|
||||
The dictionary for template of component.
|
||||
"""
|
||||
tag = self._render()
|
||||
|
||||
return dict(
|
||||
tag,
|
||||
iterable_state=str(tag.iterable),
|
||||
arg_name=tag.arg_var_name,
|
||||
arg_index=tag.get_index_var_arg(),
|
||||
iterable_type=tag.iterable._var_type.mro()[0].__name__,
|
||||
)
|
||||
return iterable.foreach(render_fn)
|
||||
|
||||
|
||||
foreach = Foreach.create
|
||||
class Foreach:
|
||||
"""Create a foreach component."""
|
||||
|
||||
create = staticmethod(foreach)
|
||||
|
@ -1,274 +1,161 @@
|
||||
"""rx.match."""
|
||||
|
||||
import textwrap
|
||||
from typing import Any, Dict, List, Optional, Tuple, Union
|
||||
from typing import Any, Union, cast
|
||||
|
||||
from typing_extensions import Unpack
|
||||
|
||||
from reflex.components.base import Fragment
|
||||
from reflex.components.component import BaseComponent, Component, MemoizationLeaf
|
||||
from reflex.components.tags import MatchTag, Tag
|
||||
from reflex.style import Style
|
||||
from reflex.utils import format, types
|
||||
from reflex.components.component import BaseComponent
|
||||
from reflex.utils import types
|
||||
from reflex.utils.exceptions import MatchTypeError
|
||||
from reflex.utils.imports import ImportDict
|
||||
from reflex.vars import VarData
|
||||
from reflex.vars.base import LiteralVar, Var
|
||||
from reflex.vars.base import VAR_TYPE, Var
|
||||
from reflex.vars.number import MatchOperation
|
||||
|
||||
CASE_TYPE = tuple[Unpack[tuple[Any, ...]], Var[VAR_TYPE] | VAR_TYPE]
|
||||
|
||||
|
||||
class Match(MemoizationLeaf):
|
||||
"""Match cases based on a condition."""
|
||||
def _process_match_cases(cases: tuple[CASE_TYPE[VAR_TYPE], ...]):
|
||||
"""Process the individual match cases.
|
||||
|
||||
# The condition to determine which case to match.
|
||||
cond: Var[Any]
|
||||
Args:
|
||||
cases: The match cases.
|
||||
|
||||
# The list of match cases to be matched.
|
||||
match_cases: List[Any] = []
|
||||
|
||||
# The catchall case to match.
|
||||
default: Any
|
||||
|
||||
@classmethod
|
||||
def create(cls, cond: Any, *cases) -> Union[Component, Var]:
|
||||
"""Create a Match Component.
|
||||
|
||||
Args:
|
||||
cond: The condition to determine which case to match.
|
||||
cases: This list of cases to match.
|
||||
|
||||
Returns:
|
||||
The match component.
|
||||
|
||||
Raises:
|
||||
ValueError: When a default case is not provided for cases with Var return types.
|
||||
"""
|
||||
match_cond_var = cls._create_condition_var(cond)
|
||||
cases, default = cls._process_cases(list(cases))
|
||||
match_cases = cls._process_match_cases(cases)
|
||||
|
||||
cls._validate_return_types(match_cases)
|
||||
|
||||
if default is None and types._issubclass(type(match_cases[0][-1]), Var):
|
||||
Raises:
|
||||
ValueError: If the default case is not the last case or the tuple elements are less than 2.
|
||||
"""
|
||||
for case in cases:
|
||||
if not isinstance(case, tuple):
|
||||
raise ValueError(
|
||||
"For cases with return types as Vars, a default case must be provided"
|
||||
"rx.match should have tuples of cases and a default case as the last argument."
|
||||
)
|
||||
|
||||
return cls._create_match_cond_var_or_component(
|
||||
match_cond_var, match_cases, default
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def _create_condition_var(cls, cond: Any) -> Var:
|
||||
"""Convert the condition to a Var.
|
||||
|
||||
Args:
|
||||
cond: The condition.
|
||||
|
||||
Returns:
|
||||
The condition as a base var
|
||||
|
||||
Raises:
|
||||
ValueError: If the condition is not provided.
|
||||
"""
|
||||
match_cond_var = LiteralVar.create(cond)
|
||||
|
||||
if match_cond_var is None:
|
||||
raise ValueError("The condition must be set")
|
||||
return match_cond_var
|
||||
|
||||
@classmethod
|
||||
def _process_cases(
|
||||
cls, cases: List
|
||||
) -> Tuple[List, Optional[Union[Var, BaseComponent]]]:
|
||||
"""Process the list of match cases and the catchall default case.
|
||||
|
||||
Args:
|
||||
cases: The list of match cases.
|
||||
|
||||
Returns:
|
||||
The default case and the list of match case tuples.
|
||||
|
||||
Raises:
|
||||
ValueError: If there are multiple default cases.
|
||||
"""
|
||||
default = None
|
||||
|
||||
if len([case for case in cases if not isinstance(case, tuple)]) > 1:
|
||||
raise ValueError("rx.match can only have one default case.")
|
||||
|
||||
if not cases:
|
||||
raise ValueError("rx.match should have at least one case.")
|
||||
|
||||
# Get the default case which should be the last non-tuple arg
|
||||
if not isinstance(cases[-1], tuple):
|
||||
default = cases.pop()
|
||||
default = (
|
||||
cls._create_case_var_with_var_data(default)
|
||||
if not isinstance(default, BaseComponent)
|
||||
else default
|
||||
# There should be at least two elements in a case tuple(a condition and return value)
|
||||
if len(case) < 2:
|
||||
raise ValueError(
|
||||
"A case tuple should have at least a match case element and a return value."
|
||||
)
|
||||
|
||||
return cases, default
|
||||
|
||||
@classmethod
|
||||
def _create_case_var_with_var_data(cls, case_element: Any) -> Var:
|
||||
"""Convert a case element into a Var.If the case
|
||||
is a Style type, we extract the var data and merge it with the
|
||||
newly created Var.
|
||||
def _validate_return_types(*return_values: Any) -> bool:
|
||||
"""Validate that match cases have the same return types.
|
||||
|
||||
Args:
|
||||
case_element: The case element.
|
||||
Args:
|
||||
return_values: The return values of the match cases.
|
||||
|
||||
Returns:
|
||||
The case element Var.
|
||||
"""
|
||||
_var_data = case_element._var_data if isinstance(case_element, Style) else None
|
||||
case_element = LiteralVar.create(case_element, _var_data=_var_data)
|
||||
return case_element
|
||||
Returns:
|
||||
True if all cases have the same return types.
|
||||
|
||||
@classmethod
|
||||
def _process_match_cases(cls, cases: List) -> List[List[Var]]:
|
||||
"""Process the individual match cases.
|
||||
Raises:
|
||||
MatchTypeError: If the return types of cases are different.
|
||||
"""
|
||||
|
||||
Args:
|
||||
cases: The match cases.
|
||||
|
||||
Returns:
|
||||
The processed match cases.
|
||||
|
||||
Raises:
|
||||
ValueError: If the default case is not the last case or the tuple elements are less than 2.
|
||||
"""
|
||||
match_cases = []
|
||||
for case in cases:
|
||||
if not isinstance(case, tuple):
|
||||
raise ValueError(
|
||||
"rx.match should have tuples of cases and a default case as the last argument."
|
||||
)
|
||||
# There should be at least two elements in a case tuple(a condition and return value)
|
||||
if len(case) < 2:
|
||||
raise ValueError(
|
||||
"A case tuple should have at least a match case element and a return value."
|
||||
)
|
||||
|
||||
case_list = []
|
||||
for element in case:
|
||||
# convert all non component element to vars.
|
||||
el = (
|
||||
cls._create_case_var_with_var_data(element)
|
||||
if not isinstance(element, BaseComponent)
|
||||
else element
|
||||
)
|
||||
if not isinstance(el, (Var, BaseComponent)):
|
||||
raise ValueError("Case element must be a var or component")
|
||||
case_list.append(el)
|
||||
|
||||
match_cases.append(case_list)
|
||||
|
||||
return match_cases
|
||||
|
||||
@classmethod
|
||||
def _validate_return_types(cls, match_cases: List[List[Var]]) -> None:
|
||||
"""Validate that match cases have the same return types.
|
||||
|
||||
Args:
|
||||
match_cases: The match cases.
|
||||
|
||||
Raises:
|
||||
MatchTypeError: If the return types of cases are different.
|
||||
"""
|
||||
first_case_return = match_cases[0][-1]
|
||||
return_type = type(first_case_return)
|
||||
|
||||
if isinstance(first_case_return, BaseComponent):
|
||||
return_type = BaseComponent
|
||||
elif isinstance(first_case_return, Var):
|
||||
return_type = Var
|
||||
|
||||
for index, case in enumerate(match_cases):
|
||||
if not types._issubclass(type(case[-1]), return_type):
|
||||
raise MatchTypeError(
|
||||
f"Match cases should have the same return types. Case {index} with return "
|
||||
f"value `{case[-1]._js_expr if isinstance(case[-1], Var) else textwrap.shorten(str(case[-1]), width=250)}`"
|
||||
f" of type {type(case[-1])!r} is not {return_type}"
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def _create_match_cond_var_or_component(
|
||||
cls,
|
||||
match_cond_var: Var,
|
||||
match_cases: List[List[Var]],
|
||||
default: Optional[Union[Var, BaseComponent]],
|
||||
) -> Union[Component, Var]:
|
||||
"""Create and return the match condition var or component.
|
||||
|
||||
Args:
|
||||
match_cond_var: The match condition.
|
||||
match_cases: The list of match cases.
|
||||
default: The default case.
|
||||
|
||||
Returns:
|
||||
The match component wrapped in a fragment or the match var.
|
||||
|
||||
Raises:
|
||||
ValueError: If the return types are not vars when creating a match var for Var types.
|
||||
"""
|
||||
if default is None and types._issubclass(
|
||||
type(match_cases[0][-1]), BaseComponent
|
||||
):
|
||||
default = Fragment.create()
|
||||
|
||||
if types._issubclass(type(match_cases[0][-1]), BaseComponent):
|
||||
return Fragment.create(
|
||||
cls(
|
||||
cond=match_cond_var,
|
||||
match_cases=match_cases,
|
||||
default=default,
|
||||
children=[case[-1] for case in match_cases] + [default], # pyright: ignore [reportArgumentType]
|
||||
)
|
||||
def is_component_or_component_var(obj: Any) -> bool:
|
||||
return types._isinstance(obj, BaseComponent) or (
|
||||
isinstance(obj, Var)
|
||||
and types.safe_typehint_issubclass(
|
||||
obj._var_type, Union[list[BaseComponent], BaseComponent]
|
||||
)
|
||||
|
||||
# Validate the match cases (as well as the default case) to have Var return types.
|
||||
if any(
|
||||
case for case in match_cases if not isinstance(case[-1], Var)
|
||||
) or not isinstance(default, Var):
|
||||
raise ValueError("Return types of match cases should be Vars.")
|
||||
|
||||
return Var(
|
||||
_js_expr=format.format_match(
|
||||
cond=str(match_cond_var),
|
||||
match_cases=match_cases,
|
||||
default=default, # pyright: ignore [reportArgumentType]
|
||||
),
|
||||
_var_type=default._var_type, # pyright: ignore [reportAttributeAccessIssue,reportOptionalMemberAccess]
|
||||
_var_data=VarData.merge(
|
||||
match_cond_var._get_all_var_data(),
|
||||
*[el._get_all_var_data() for case in match_cases for el in case],
|
||||
default._get_all_var_data(), # pyright: ignore [reportAttributeAccessIssue, reportOptionalMemberAccess]
|
||||
),
|
||||
)
|
||||
|
||||
def _render(self) -> Tag:
|
||||
return MatchTag(
|
||||
cond=self.cond, match_cases=self.match_cases, default=self.default
|
||||
def type_of_return_value(obj: Any) -> Any:
|
||||
if isinstance(obj, Var):
|
||||
return obj._var_type
|
||||
return type(obj)
|
||||
|
||||
is_return_type_component = [
|
||||
is_component_or_component_var(return_type) for return_type in return_values
|
||||
]
|
||||
|
||||
if any(is_return_type_component) and not all(is_return_type_component):
|
||||
non_component_return_types = [
|
||||
(type_of_return_value(return_value), i)
|
||||
for i, return_value in enumerate(return_values)
|
||||
if not is_return_type_component[i]
|
||||
]
|
||||
raise MatchTypeError(
|
||||
"Match cases should have the same return types. "
|
||||
+ "Expected return types to be of type Component or Var[Component]. "
|
||||
+ ". ".join(
|
||||
[
|
||||
f"Return type of case {i} is {return_type}"
|
||||
for return_type, i in non_component_return_types
|
||||
]
|
||||
)
|
||||
)
|
||||
|
||||
def render(self) -> Dict:
|
||||
"""Render the component.
|
||||
|
||||
Returns:
|
||||
The dictionary for template of component.
|
||||
"""
|
||||
tag = self._render()
|
||||
tag.name = "match"
|
||||
return dict(tag)
|
||||
|
||||
def add_imports(self) -> ImportDict:
|
||||
"""Add imports for the Match component.
|
||||
|
||||
Returns:
|
||||
The import dict.
|
||||
"""
|
||||
var_data = VarData.merge(self.cond._get_all_var_data())
|
||||
return var_data.old_school_imports() if var_data else {}
|
||||
return all(is_return_type_component)
|
||||
|
||||
|
||||
match = Match.create
|
||||
def _create_match_var(
|
||||
match_cond_var: Var,
|
||||
match_cases: tuple[CASE_TYPE[VAR_TYPE], ...],
|
||||
default: VAR_TYPE | Var[VAR_TYPE],
|
||||
) -> Var[VAR_TYPE]:
|
||||
"""Create the match var.
|
||||
|
||||
Args:
|
||||
match_cond_var: The match condition var.
|
||||
match_cases: The match cases.
|
||||
default: The default case.
|
||||
|
||||
Returns:
|
||||
The match var.
|
||||
"""
|
||||
return MatchOperation.create(match_cond_var, match_cases, default)
|
||||
|
||||
|
||||
def match(
|
||||
cond: Any,
|
||||
*cases: Unpack[
|
||||
tuple[Unpack[tuple[CASE_TYPE[VAR_TYPE], ...]], Var[VAR_TYPE] | VAR_TYPE]
|
||||
],
|
||||
) -> Var[VAR_TYPE]:
|
||||
"""Create a match var.
|
||||
|
||||
Args:
|
||||
cond: The condition to match.
|
||||
cases: The match cases. Each case should be a tuple with the first elements as the match case and the last element as the return value. The last argument should be the default case.
|
||||
|
||||
Returns:
|
||||
The match var.
|
||||
|
||||
Raises:
|
||||
ValueError: If the default case is not the last case or the tuple elements are less than 2.
|
||||
"""
|
||||
default = types.Unset()
|
||||
|
||||
if len([case for case in cases if not isinstance(case, tuple)]) > 1:
|
||||
raise ValueError("rx.match can only have one default case.")
|
||||
|
||||
if not cases:
|
||||
raise ValueError("rx.match should have at least one case.")
|
||||
|
||||
# Get the default case which should be the last non-tuple arg
|
||||
if not isinstance(cases[-1], tuple):
|
||||
default = cases[-1]
|
||||
actual_cases = cases[:-1]
|
||||
else:
|
||||
actual_cases = cast(tuple[CASE_TYPE[VAR_TYPE], ...], cases)
|
||||
|
||||
_process_match_cases(actual_cases)
|
||||
|
||||
is_component_match = _validate_return_types(
|
||||
*[case[-1] for case in actual_cases],
|
||||
*([default] if not isinstance(default, types.Unset) else []),
|
||||
)
|
||||
|
||||
if isinstance(default, types.Unset) and not is_component_match:
|
||||
raise ValueError(
|
||||
"For cases with return types as Vars, a default case must be provided"
|
||||
)
|
||||
|
||||
if isinstance(default, types.Unset):
|
||||
default = Fragment.create()
|
||||
|
||||
default = cast(Var[VAR_TYPE] | VAR_TYPE, default)
|
||||
|
||||
return _create_match_var(
|
||||
cond,
|
||||
actual_cases,
|
||||
default,
|
||||
)
|
||||
|
@ -29,7 +29,7 @@ from reflex.event import (
|
||||
from reflex.utils import format
|
||||
from reflex.utils.imports import ImportVar
|
||||
from reflex.vars import VarData
|
||||
from reflex.vars.base import CallableVar, Var, get_unique_variable_name
|
||||
from reflex.vars.base import Var, get_unique_variable_name
|
||||
from reflex.vars.sequence import LiteralStringVar
|
||||
|
||||
DEFAULT_UPLOAD_ID: str = "default"
|
||||
@ -45,7 +45,6 @@ upload_files_context_var_data: VarData = VarData(
|
||||
)
|
||||
|
||||
|
||||
@CallableVar
|
||||
def upload_file(id_: str = DEFAULT_UPLOAD_ID) -> Var:
|
||||
"""Get the file upload drop trigger.
|
||||
|
||||
@ -75,7 +74,6 @@ def upload_file(id_: str = DEFAULT_UPLOAD_ID) -> Var:
|
||||
)
|
||||
|
||||
|
||||
@CallableVar
|
||||
def selected_files(id_: str = DEFAULT_UPLOAD_ID) -> Var:
|
||||
"""Get the list of selected files.
|
||||
|
||||
|
@ -13,14 +13,12 @@ from reflex.event import CallableEventSpec, EventSpec, EventType
|
||||
from reflex.style import Style
|
||||
from reflex.utils.imports import ImportVar
|
||||
from reflex.vars import VarData
|
||||
from reflex.vars.base import CallableVar, Var
|
||||
from reflex.vars.base import Var
|
||||
|
||||
DEFAULT_UPLOAD_ID: str
|
||||
upload_files_context_var_data: VarData
|
||||
|
||||
@CallableVar
|
||||
def upload_file(id_: str = DEFAULT_UPLOAD_ID) -> Var: ...
|
||||
@CallableVar
|
||||
def selected_files(id_: str = DEFAULT_UPLOAD_ID) -> Var: ...
|
||||
@CallableEventSpec
|
||||
def clear_selected_files(id_: str = DEFAULT_UPLOAD_ID) -> EventSpec: ...
|
||||
|
@ -303,7 +303,7 @@ class DataEditor(NoSSRComponent):
|
||||
|
||||
# Fired when editing is finished.
|
||||
on_finished_editing: EventHandler[
|
||||
passthrough_event_spec(Union[GridCell, None], tuple[int, int])
|
||||
passthrough_event_spec(Union[GridCell, None], tuple[int, int]) # pyright: ignore[reportArgumentType]
|
||||
]
|
||||
|
||||
# Fired when a row is appended.
|
||||
|
@ -828,7 +828,7 @@ class ShikiHighLevelCodeBlock(ShikiCodeBlock):
|
||||
if isinstance(code, Var):
|
||||
return string_replace_operation(
|
||||
code, StringVar(_js_expr=f"/{regex_pattern}/g", _var_type=str), ""
|
||||
)
|
||||
).guess_type()
|
||||
if isinstance(code, str):
|
||||
return re.sub(regex_pattern, "", code)
|
||||
|
||||
|
@ -114,8 +114,8 @@ class MarkdownComponentMap:
|
||||
explicit_return = explicit_return or cls._explicit_return
|
||||
|
||||
return ArgsFunctionOperation.create(
|
||||
args_names=(DestructuredArg(fields=tuple(fn_args)),),
|
||||
return_expr=fn_body,
|
||||
(DestructuredArg(fields=tuple(fn_args)),),
|
||||
fn_body,
|
||||
explicit_return=explicit_return,
|
||||
_var_data=var_data,
|
||||
)
|
||||
|
@ -188,7 +188,7 @@ class Slider(ComponentNamespace):
|
||||
else:
|
||||
children = [
|
||||
track,
|
||||
# Foreach.create(props.get("value"), lambda e: SliderThumb.create()), # foreach doesn't render Thumbs properly # noqa: ERA001
|
||||
# foreach(props.get("value"), lambda e: SliderThumb.create()), # foreach doesn't render Thumbs properly # noqa: ERA001
|
||||
]
|
||||
|
||||
return SliderRoot.create(*children, **props)
|
||||
|
@ -20,7 +20,7 @@ from __future__ import annotations
|
||||
from typing import Any, Dict, List, Literal, Optional, Union, get_args
|
||||
|
||||
from reflex.components.component import BaseComponent
|
||||
from reflex.components.core.cond import Cond, color_mode_cond, cond
|
||||
from reflex.components.core.cond import color_mode_cond, cond
|
||||
from reflex.components.lucide.icon import Icon
|
||||
from reflex.components.radix.themes.components.dropdown_menu import dropdown_menu
|
||||
from reflex.components.radix.themes.components.switch import Switch
|
||||
@ -40,28 +40,23 @@ DEFAULT_LIGHT_ICON: Icon = Icon.create(tag="sun")
|
||||
DEFAULT_DARK_ICON: Icon = Icon.create(tag="moon")
|
||||
|
||||
|
||||
class ColorModeIcon(Cond):
|
||||
"""Displays the current color mode as an icon."""
|
||||
def color_mode_icon(
|
||||
light_component: BaseComponent | None = None,
|
||||
dark_component: BaseComponent | None = None,
|
||||
):
|
||||
"""Create a color mode icon component.
|
||||
|
||||
@classmethod
|
||||
def create(
|
||||
cls,
|
||||
light_component: BaseComponent | None = None,
|
||||
dark_component: BaseComponent | None = None,
|
||||
):
|
||||
"""Create an icon component based on color_mode.
|
||||
Args:
|
||||
light_component: The component to render in light mode.
|
||||
dark_component: The component to render in dark mode.
|
||||
|
||||
Args:
|
||||
light_component: the component to display when color mode is default
|
||||
dark_component: the component to display when color mode is dark (non-default)
|
||||
|
||||
Returns:
|
||||
The conditionally rendered component
|
||||
"""
|
||||
return color_mode_cond(
|
||||
light=light_component or DEFAULT_LIGHT_ICON,
|
||||
dark=dark_component or DEFAULT_DARK_ICON,
|
||||
)
|
||||
Returns:
|
||||
The color mode icon component.
|
||||
"""
|
||||
return color_mode_cond(
|
||||
light=light_component or DEFAULT_LIGHT_ICON,
|
||||
dark=dark_component or DEFAULT_DARK_ICON,
|
||||
)
|
||||
|
||||
|
||||
LiteralPosition = Literal["top-left", "top-right", "bottom-left", "bottom-right"]
|
||||
@ -144,7 +139,7 @@ class ColorModeIconButton(IconButton):
|
||||
|
||||
if allow_system:
|
||||
|
||||
def color_mode_item(_color_mode: str):
|
||||
def color_mode_item(_color_mode: Literal["light", "dark", "system"]):
|
||||
return dropdown_menu.item(
|
||||
_color_mode.title(), on_click=set_color_mode(_color_mode)
|
||||
)
|
||||
@ -152,7 +147,7 @@ class ColorModeIconButton(IconButton):
|
||||
return dropdown_menu.root(
|
||||
dropdown_menu.trigger(
|
||||
super().create(
|
||||
ColorModeIcon.create(),
|
||||
color_mode_icon(),
|
||||
),
|
||||
**props,
|
||||
),
|
||||
@ -163,7 +158,7 @@ class ColorModeIconButton(IconButton):
|
||||
),
|
||||
)
|
||||
return IconButton.create(
|
||||
ColorModeIcon.create(),
|
||||
color_mode_icon(),
|
||||
on_click=toggle_color_mode,
|
||||
**props,
|
||||
)
|
||||
@ -197,7 +192,7 @@ class ColorModeSwitch(Switch):
|
||||
class ColorModeNamespace(Var):
|
||||
"""Namespace for color mode components."""
|
||||
|
||||
icon = staticmethod(ColorModeIcon.create)
|
||||
icon = staticmethod(color_mode_icon)
|
||||
button = staticmethod(ColorModeIconButton.create)
|
||||
switch = staticmethod(ColorModeSwitch.create)
|
||||
|
||||
|
@ -7,7 +7,6 @@ from typing import Any, Dict, List, Literal, Optional, Union, overload
|
||||
|
||||
from reflex.components.component import BaseComponent
|
||||
from reflex.components.core.breakpoints import Breakpoints
|
||||
from reflex.components.core.cond import Cond
|
||||
from reflex.components.lucide.icon import Icon
|
||||
from reflex.components.radix.themes.components.switch import Switch
|
||||
from reflex.event import EventType
|
||||
@ -19,48 +18,10 @@ from .components.icon_button import IconButton
|
||||
DEFAULT_LIGHT_ICON: Icon
|
||||
DEFAULT_DARK_ICON: Icon
|
||||
|
||||
class ColorModeIcon(Cond):
|
||||
@overload
|
||||
@classmethod
|
||||
def create( # type: ignore
|
||||
cls,
|
||||
*children,
|
||||
cond: Optional[Union[Any, Var[Any]]] = None,
|
||||
comp1: Optional[BaseComponent] = None,
|
||||
comp2: Optional[BaseComponent] = None,
|
||||
style: Optional[Style] = None,
|
||||
key: Optional[Any] = None,
|
||||
id: Optional[Any] = None,
|
||||
class_name: Optional[Any] = None,
|
||||
autofocus: Optional[bool] = None,
|
||||
custom_attrs: Optional[Dict[str, Union[Var, Any]]] = None,
|
||||
on_blur: Optional[EventType[()]] = None,
|
||||
on_click: Optional[EventType[()]] = None,
|
||||
on_context_menu: Optional[EventType[()]] = None,
|
||||
on_double_click: Optional[EventType[()]] = None,
|
||||
on_focus: Optional[EventType[()]] = None,
|
||||
on_mount: Optional[EventType[()]] = None,
|
||||
on_mouse_down: Optional[EventType[()]] = None,
|
||||
on_mouse_enter: Optional[EventType[()]] = None,
|
||||
on_mouse_leave: Optional[EventType[()]] = None,
|
||||
on_mouse_move: Optional[EventType[()]] = None,
|
||||
on_mouse_out: Optional[EventType[()]] = None,
|
||||
on_mouse_over: Optional[EventType[()]] = None,
|
||||
on_mouse_up: Optional[EventType[()]] = None,
|
||||
on_scroll: Optional[EventType[()]] = None,
|
||||
on_unmount: Optional[EventType[()]] = None,
|
||||
**props,
|
||||
) -> "ColorModeIcon":
|
||||
"""Create an icon component based on color_mode.
|
||||
|
||||
Args:
|
||||
light_component: the component to display when color mode is default
|
||||
dark_component: the component to display when color mode is dark (non-default)
|
||||
|
||||
Returns:
|
||||
The conditionally rendered component
|
||||
"""
|
||||
...
|
||||
def color_mode_icon(
|
||||
light_component: BaseComponent | None = None,
|
||||
dark_component: BaseComponent | None = None,
|
||||
): ...
|
||||
|
||||
LiteralPosition = Literal["top-left", "top-right", "bottom-left", "bottom-right"]
|
||||
position_values: List[str]
|
||||
@ -440,7 +401,7 @@ class ColorModeSwitch(Switch):
|
||||
...
|
||||
|
||||
class ColorModeNamespace(Var):
|
||||
icon = staticmethod(ColorModeIcon.create)
|
||||
icon = staticmethod(color_mode_icon)
|
||||
button = staticmethod(ColorModeIconButton.create)
|
||||
switch = staticmethod(ColorModeSwitch.create)
|
||||
|
||||
|
@ -6,7 +6,7 @@ from typing import Literal
|
||||
|
||||
from reflex.components.component import Component
|
||||
from reflex.components.core.breakpoints import Responsive
|
||||
from reflex.components.core.match import Match
|
||||
from reflex.components.core.match import match
|
||||
from reflex.components.el import elements
|
||||
from reflex.components.lucide import Icon
|
||||
from reflex.style import Style
|
||||
@ -77,7 +77,7 @@ class IconButton(elements.Button, RadixLoadingProp, RadixThemesComponent):
|
||||
if isinstance(props["size"], str):
|
||||
children[0].size = RADIX_TO_LUCIDE_SIZE[props["size"]]
|
||||
else:
|
||||
size_map_var = Match.create(
|
||||
size_map_var = match(
|
||||
props["size"],
|
||||
*list(RADIX_TO_LUCIDE_SIZE.items()),
|
||||
12,
|
||||
|
@ -5,7 +5,7 @@ from __future__ import annotations
|
||||
from typing import Any, Iterable, Literal, Union
|
||||
|
||||
from reflex.components.component import Component, ComponentNamespace
|
||||
from reflex.components.core.foreach import Foreach
|
||||
from reflex.components.core.foreach import foreach
|
||||
from reflex.components.el.elements.typography import Li, Ol, Ul
|
||||
from reflex.components.lucide.icon import Icon
|
||||
from reflex.components.markdown.markdown import MarkdownComponentMap
|
||||
@ -70,7 +70,7 @@ class BaseList(Component, MarkdownComponentMap):
|
||||
|
||||
if not children and items is not None:
|
||||
if isinstance(items, Var):
|
||||
children = [Foreach.create(items, ListItem.create)]
|
||||
children = [foreach(items, ListItem.create)]
|
||||
else:
|
||||
children = [ListItem.create(item) for item in items]
|
||||
props["direction"] = "column"
|
||||
|
@ -1,6 +1,3 @@
|
||||
"""Representations for React tags."""
|
||||
|
||||
from .cond_tag import CondTag
|
||||
from .iter_tag import IterTag
|
||||
from .match_tag import MatchTag
|
||||
from .tag import Tag
|
||||
|
@ -1,21 +0,0 @@
|
||||
"""Tag to conditionally render components."""
|
||||
|
||||
import dataclasses
|
||||
from typing import Any, Dict, Optional
|
||||
|
||||
from reflex.components.tags.tag import Tag
|
||||
from reflex.vars.base import Var
|
||||
|
||||
|
||||
@dataclasses.dataclass()
|
||||
class CondTag(Tag):
|
||||
"""A conditional tag."""
|
||||
|
||||
# The condition to determine which component to render.
|
||||
cond: Var[Any] = dataclasses.field(default_factory=lambda: Var.create(True))
|
||||
|
||||
# The code to render if the condition is true.
|
||||
true_value: Dict = dataclasses.field(default_factory=dict)
|
||||
|
||||
# The code to render if the condition is false.
|
||||
false_value: Optional[Dict] = None
|
@ -1,145 +0,0 @@
|
||||
"""Tag to loop through a list of components."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import dataclasses
|
||||
import inspect
|
||||
from typing import TYPE_CHECKING, Any, Callable, Iterable, Tuple, Type, Union, get_args
|
||||
|
||||
from reflex.components.tags.tag import Tag
|
||||
from reflex.vars import LiteralArrayVar, Var, get_unique_variable_name
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from reflex.components.component import Component
|
||||
|
||||
|
||||
@dataclasses.dataclass()
|
||||
class IterTag(Tag):
|
||||
"""An iterator tag."""
|
||||
|
||||
# The var to iterate over.
|
||||
iterable: Var[Iterable] = dataclasses.field(
|
||||
default_factory=lambda: LiteralArrayVar.create([])
|
||||
)
|
||||
|
||||
# The component render function for each item in the iterable.
|
||||
render_fn: Callable = dataclasses.field(default_factory=lambda: lambda x: x)
|
||||
|
||||
# The name of the arg var.
|
||||
arg_var_name: str = dataclasses.field(default_factory=get_unique_variable_name)
|
||||
|
||||
# The name of the index var.
|
||||
index_var_name: str = dataclasses.field(default_factory=get_unique_variable_name)
|
||||
|
||||
def get_iterable_var_type(self) -> Type:
|
||||
"""Get the type of the iterable var.
|
||||
|
||||
Returns:
|
||||
The type of the iterable var.
|
||||
"""
|
||||
iterable = self.iterable
|
||||
try:
|
||||
if iterable._var_type.mro()[0] is dict:
|
||||
# Arg is a tuple of (key, value).
|
||||
return Tuple[get_args(iterable._var_type)] # pyright: ignore [reportReturnType]
|
||||
elif iterable._var_type.mro()[0] is tuple:
|
||||
# Arg is a union of any possible values in the tuple.
|
||||
return Union[get_args(iterable._var_type)] # pyright: ignore [reportReturnType]
|
||||
else:
|
||||
return get_args(iterable._var_type)[0]
|
||||
except Exception:
|
||||
return Any # pyright: ignore [reportReturnType]
|
||||
|
||||
def get_index_var(self) -> Var:
|
||||
"""Get the index var for the tag (with curly braces).
|
||||
|
||||
This is used to reference the index var within the tag.
|
||||
|
||||
Returns:
|
||||
The index var.
|
||||
"""
|
||||
return Var(
|
||||
_js_expr=self.index_var_name,
|
||||
_var_type=int,
|
||||
).guess_type()
|
||||
|
||||
def get_arg_var(self) -> Var:
|
||||
"""Get the arg var for the tag (with curly braces).
|
||||
|
||||
This is used to reference the arg var within the tag.
|
||||
|
||||
Returns:
|
||||
The arg var.
|
||||
"""
|
||||
return Var(
|
||||
_js_expr=self.arg_var_name,
|
||||
_var_type=self.get_iterable_var_type(),
|
||||
).guess_type()
|
||||
|
||||
def get_index_var_arg(self) -> Var:
|
||||
"""Get the index var for the tag (without curly braces).
|
||||
|
||||
This is used to render the index var in the .map() function.
|
||||
|
||||
Returns:
|
||||
The index var.
|
||||
"""
|
||||
return Var(
|
||||
_js_expr=self.index_var_name,
|
||||
_var_type=int,
|
||||
).guess_type()
|
||||
|
||||
def get_arg_var_arg(self) -> Var:
|
||||
"""Get the arg var for the tag (without curly braces).
|
||||
|
||||
This is used to render the arg var in the .map() function.
|
||||
|
||||
Returns:
|
||||
The arg var.
|
||||
"""
|
||||
return Var(
|
||||
_js_expr=self.arg_var_name,
|
||||
_var_type=self.get_iterable_var_type(),
|
||||
).guess_type()
|
||||
|
||||
def render_component(self) -> Component:
|
||||
"""Render the component.
|
||||
|
||||
Raises:
|
||||
ValueError: If the render function takes more than 2 arguments.
|
||||
|
||||
Returns:
|
||||
The rendered component.
|
||||
"""
|
||||
# Import here to avoid circular imports.
|
||||
from reflex.components.base.fragment import Fragment
|
||||
from reflex.components.core.cond import Cond
|
||||
from reflex.components.core.foreach import Foreach
|
||||
|
||||
# Get the render function arguments.
|
||||
args = inspect.getfullargspec(self.render_fn).args
|
||||
arg = self.get_arg_var()
|
||||
index = self.get_index_var()
|
||||
|
||||
if len(args) == 1:
|
||||
# If the render function doesn't take the index as an argument.
|
||||
component = self.render_fn(arg)
|
||||
else:
|
||||
# If the render function takes the index as an argument.
|
||||
if len(args) != 2:
|
||||
raise ValueError("The render function must take 2 arguments.")
|
||||
component = self.render_fn(arg, index)
|
||||
|
||||
# Nested foreach components or cond must be wrapped in fragments.
|
||||
if isinstance(component, (Foreach, Cond)):
|
||||
component = Fragment.create(component)
|
||||
|
||||
# If the component is a tuple, unpack and wrap it in a fragment.
|
||||
if isinstance(component, tuple):
|
||||
component = Fragment.create(*component)
|
||||
|
||||
# Set the component key.
|
||||
if component.key is None:
|
||||
component.key = index
|
||||
|
||||
return component
|
@ -1,21 +0,0 @@
|
||||
"""Tag to conditionally match cases."""
|
||||
|
||||
import dataclasses
|
||||
from typing import Any, List
|
||||
|
||||
from reflex.components.tags.tag import Tag
|
||||
from reflex.vars.base import Var
|
||||
|
||||
|
||||
@dataclasses.dataclass()
|
||||
class MatchTag(Tag):
|
||||
"""A match tag."""
|
||||
|
||||
# The condition to determine which case to match.
|
||||
cond: Var[Any] = dataclasses.field(default_factory=lambda: Var.create(True))
|
||||
|
||||
# The list of match cases to be matched.
|
||||
match_cases: List[Any] = dataclasses.field(default_factory=list)
|
||||
|
||||
# The catchall case to match.
|
||||
default: Any = dataclasses.field(default=Var.create(None))
|
@ -600,14 +600,16 @@ stop_propagation = EventChain(events=[], args_spec=no_args_event_spec).stop_prop
|
||||
prevent_default = EventChain(events=[], args_spec=no_args_event_spec).prevent_default
|
||||
|
||||
|
||||
T = TypeVar("T")
|
||||
U = TypeVar("U")
|
||||
EVENT_T = TypeVar("EVENT_T")
|
||||
EVENT_U = TypeVar("EVENT_U")
|
||||
|
||||
Ts = TypeVarTuple("Ts")
|
||||
|
||||
|
||||
class IdentityEventReturn(Generic[T], Protocol):
|
||||
class IdentityEventReturn(Generic[Unpack[Ts]], Protocol):
|
||||
"""Protocol for an identity event return."""
|
||||
|
||||
def __call__(self, *values: Var[T]) -> Tuple[Var[T], ...]:
|
||||
def __call__(self, *values: Unpack[Ts]) -> tuple[Unpack[Ts]]:
|
||||
"""Return the input values.
|
||||
|
||||
Args:
|
||||
@ -620,22 +622,26 @@ class IdentityEventReturn(Generic[T], Protocol):
|
||||
|
||||
|
||||
@overload
|
||||
def passthrough_event_spec( # pyright: ignore [reportOverlappingOverload]
|
||||
event_type: Type[T], /
|
||||
) -> Callable[[Var[T]], Tuple[Var[T]]]: ...
|
||||
def passthrough_event_spec(
|
||||
event_type: Type[EVENT_T], /
|
||||
) -> IdentityEventReturn[Var[EVENT_T]]: ...
|
||||
|
||||
|
||||
@overload
|
||||
def passthrough_event_spec(
|
||||
event_type_1: Type[T], event_type2: Type[U], /
|
||||
) -> Callable[[Var[T], Var[U]], Tuple[Var[T], Var[U]]]: ...
|
||||
event_type_1: Type[EVENT_T], event_type2: Type[EVENT_U], /
|
||||
) -> IdentityEventReturn[Var[EVENT_T], Var[EVENT_U]]: ...
|
||||
|
||||
|
||||
@overload
|
||||
def passthrough_event_spec(*event_types: Type[T]) -> IdentityEventReturn[T]: ...
|
||||
def passthrough_event_spec(
|
||||
*event_types: Unpack[tuple[Type[EVENT_T]]],
|
||||
) -> IdentityEventReturn[Unpack[tuple[Var[EVENT_T], ...]]]: ...
|
||||
|
||||
|
||||
def passthrough_event_spec(*event_types: Type[T]) -> IdentityEventReturn[T]: # pyright: ignore [reportInconsistentOverload]
|
||||
def passthrough_event_spec( # pyright: ignore[reportInconsistentOverload]
|
||||
*event_types: Type[EVENT_T],
|
||||
) -> IdentityEventReturn[Unpack[tuple[Var[EVENT_T], ...]]]:
|
||||
"""A helper function that returns the input event as output.
|
||||
|
||||
Args:
|
||||
@ -645,7 +651,7 @@ def passthrough_event_spec(*event_types: Type[T]) -> IdentityEventReturn[T]: #
|
||||
A function that returns the input event as output.
|
||||
"""
|
||||
|
||||
def inner(*values: Var[T]) -> Tuple[Var[T], ...]:
|
||||
def inner(*values: Var[EVENT_T]) -> Tuple[Var[EVENT_T], ...]:
|
||||
return values
|
||||
|
||||
inner_type = tuple(Var[event_type] for event_type in event_types)
|
||||
@ -780,7 +786,7 @@ def server_side(name: str, sig: inspect.Signature, **kwargs) -> EventSpec:
|
||||
return None
|
||||
|
||||
fn.__qualname__ = name
|
||||
fn.__signature__ = sig # pyright: ignore [reportFunctionMemberAccess]
|
||||
fn.__signature__ = sig # pyright: ignore[reportFunctionMemberAccess]
|
||||
return EventSpec(
|
||||
handler=EventHandler(fn=fn, state_full_name=FRONTEND_EVENT_STATE),
|
||||
args=tuple(
|
||||
|
@ -607,8 +607,8 @@ class BaseState(Base, ABC, extra=pydantic.Extra.allow):
|
||||
if cls._item_is_event_handler(name, fn)
|
||||
}
|
||||
|
||||
for mixin in cls._mixins(): # pyright: ignore [reportAssignmentType]
|
||||
for name, value in mixin.__dict__.items():
|
||||
for mixin_class in cls._mixins():
|
||||
for name, value in mixin_class.__dict__.items():
|
||||
if name in cls.inherited_vars:
|
||||
continue
|
||||
if is_computed_var(value):
|
||||
@ -619,7 +619,7 @@ class BaseState(Base, ABC, extra=pydantic.Extra.allow):
|
||||
cls.computed_vars[newcv._js_expr] = newcv
|
||||
cls.vars[newcv._js_expr] = newcv
|
||||
continue
|
||||
if types.is_backend_base_variable(name, mixin): # pyright: ignore [reportArgumentType]
|
||||
if types.is_backend_base_variable(name, mixin_class):
|
||||
cls.backend_vars[name] = copy.deepcopy(value)
|
||||
continue
|
||||
if events.get(name) is not None:
|
||||
@ -3653,6 +3653,9 @@ def get_state_manager() -> StateManager:
|
||||
return prerequisites.get_and_validate_app().app.state_manager
|
||||
|
||||
|
||||
DATACLASS_FIELDS = getattr(dataclasses, "_FIELDS", "__dataclass_fields__")
|
||||
|
||||
|
||||
class MutableProxy(wrapt.ObjectProxy):
|
||||
"""A proxy for a mutable object that tracks changes."""
|
||||
|
||||
@ -3724,12 +3727,7 @@ class MutableProxy(wrapt.ObjectProxy):
|
||||
cls.__dataclass_proxies__[wrapper_cls_name] = type(
|
||||
wrapper_cls_name,
|
||||
(cls,),
|
||||
{
|
||||
dataclasses._FIELDS: getattr( # pyright: ignore [reportAttributeAccessIssue]
|
||||
wrapped_cls,
|
||||
dataclasses._FIELDS, # pyright: ignore [reportAttributeAccessIssue]
|
||||
),
|
||||
},
|
||||
{DATACLASS_FIELDS: getattr(wrapped_cls, DATACLASS_FIELDS)},
|
||||
)
|
||||
cls = cls.__dataclass_proxies__[wrapper_cls_name]
|
||||
return super().__new__(cls)
|
||||
@ -3878,11 +3876,11 @@ class MutableProxy(wrapt.ObjectProxy):
|
||||
if (
|
||||
isinstance(self.__wrapped__, Base)
|
||||
and __name not in self.__never_wrap_base_attrs__
|
||||
and hasattr(value, "__func__")
|
||||
and (value_func := getattr(value, "__func__", None))
|
||||
):
|
||||
# Wrap methods called on Base subclasses, which might do _anything_
|
||||
return wrapt.FunctionWrapper(
|
||||
functools.partial(value.__func__, self), # pyright: ignore [reportFunctionMemberAccess]
|
||||
functools.partial(value_func, self),
|
||||
self._wrap_recursive_decorator,
|
||||
)
|
||||
|
||||
|
@ -12,7 +12,7 @@ from reflex.utils.exceptions import ReflexError
|
||||
from reflex.utils.imports import ImportVar
|
||||
from reflex.utils.types import get_origin
|
||||
from reflex.vars import VarData
|
||||
from reflex.vars.base import CallableVar, LiteralVar, Var
|
||||
from reflex.vars.base import LiteralVar, Var
|
||||
from reflex.vars.function import FunctionVar
|
||||
from reflex.vars.object import ObjectVar
|
||||
|
||||
@ -48,7 +48,6 @@ def _color_mode_var(_js_expr: str, _var_type: Type = str) -> Var:
|
||||
).guess_type()
|
||||
|
||||
|
||||
@CallableVar
|
||||
def set_color_mode(
|
||||
new_color_mode: LiteralColorMode | Var[LiteralColorMode] | None = None,
|
||||
) -> Var[EventChain]:
|
||||
|
@ -68,10 +68,8 @@ try:
|
||||
from selenium.webdriver.remote.webelement import ( # pyright: ignore [reportMissingImports]
|
||||
WebElement,
|
||||
)
|
||||
|
||||
has_selenium = True
|
||||
except ImportError:
|
||||
has_selenium = False
|
||||
webdriver = None
|
||||
|
||||
# The timeout (minutes) to check for the port.
|
||||
DEFAULT_TIMEOUT = 15
|
||||
@ -296,11 +294,13 @@ class AppHarness:
|
||||
if p not in before_decorated_pages
|
||||
]
|
||||
self.app_instance = self.app_module.app
|
||||
if self.app_instance and isinstance(
|
||||
self.app_instance._state_manager, StateManagerRedis
|
||||
):
|
||||
if self.app_instance is None:
|
||||
raise RuntimeError("App was not initialized.")
|
||||
if isinstance(self.app_instance._state_manager, StateManagerRedis):
|
||||
# Create our own redis connection for testing.
|
||||
self.state_manager = StateManagerRedis.create(self.app_instance._state) # pyright: ignore [reportArgumentType]
|
||||
if self.app_instance._state is None:
|
||||
raise RuntimeError("App state is not initialized.")
|
||||
self.state_manager = StateManagerRedis.create(self.app_instance._state)
|
||||
else:
|
||||
self.state_manager = (
|
||||
self.app_instance._state_manager if self.app_instance else None
|
||||
@ -615,7 +615,7 @@ class AppHarness:
|
||||
Raises:
|
||||
RuntimeError: when selenium is not importable or frontend is not running
|
||||
"""
|
||||
if not has_selenium:
|
||||
if webdriver is None:
|
||||
raise RuntimeError(
|
||||
"Frontend functionality requires `selenium` to be installed, "
|
||||
"and it could not be imported."
|
||||
|
@ -201,10 +201,13 @@ def _get_first_non_framework_frame() -> FrameType | None:
|
||||
# Exclude utility modules that should never be the source of deprecated reflex usage.
|
||||
exclude_modules = [click, rx, typer, typing_extensions]
|
||||
exclude_roots = [
|
||||
p.parent.resolve()
|
||||
if (p := Path(m.__file__)).name == "__init__.py" # pyright: ignore [reportArgumentType]
|
||||
else p.resolve()
|
||||
(
|
||||
p.parent.resolve()
|
||||
if (p := Path(m.__file__)).name == "__init__.py"
|
||||
else p.resolve()
|
||||
)
|
||||
for m in exclude_modules
|
||||
if m.__file__
|
||||
]
|
||||
# Specifically exclude the reflex cli module.
|
||||
if reflex_bin := shutil.which(b"reflex"):
|
||||
|
@ -4,9 +4,8 @@ from __future__ import annotations
|
||||
|
||||
import inspect
|
||||
import json
|
||||
import os
|
||||
import re
|
||||
from typing import TYPE_CHECKING, Any, List, Optional, Union
|
||||
from typing import TYPE_CHECKING, Any, Optional, Union
|
||||
|
||||
from reflex import constants
|
||||
from reflex.constants.state import FRONTEND_EVENT_STATE
|
||||
@ -136,22 +135,6 @@ def wrap(
|
||||
return f"{open * num}{text}{close * num}"
|
||||
|
||||
|
||||
def indent(text: str, indent_level: int = 2) -> str:
|
||||
"""Indent the given text by the given indent level.
|
||||
|
||||
Args:
|
||||
text: The text to indent.
|
||||
indent_level: The indent level.
|
||||
|
||||
Returns:
|
||||
The indented text.
|
||||
"""
|
||||
lines = text.splitlines()
|
||||
if len(lines) < 2:
|
||||
return text
|
||||
return os.linesep.join(f"{' ' * indent_level}{line}" for line in lines) + os.linesep
|
||||
|
||||
|
||||
def to_snake_case(text: str) -> str:
|
||||
"""Convert a string to snake case.
|
||||
|
||||
@ -239,80 +222,6 @@ def make_default_page_title(app_name: str, route: str) -> str:
|
||||
return to_title_case(title)
|
||||
|
||||
|
||||
def _escape_js_string(string: str) -> str:
|
||||
"""Escape the string for use as a JS string literal.
|
||||
|
||||
Args:
|
||||
string: The string to escape.
|
||||
|
||||
Returns:
|
||||
The escaped string.
|
||||
"""
|
||||
|
||||
# TODO: we may need to re-vist this logic after new Var API is implemented.
|
||||
def escape_outside_segments(segment: str):
|
||||
"""Escape backticks in segments outside of `${}`.
|
||||
|
||||
Args:
|
||||
segment: The part of the string to escape.
|
||||
|
||||
Returns:
|
||||
The escaped or unescaped segment.
|
||||
"""
|
||||
if segment.startswith("${") and segment.endswith("}"):
|
||||
# Return the `${}` segment unchanged
|
||||
return segment
|
||||
else:
|
||||
# Escape backticks in the segment
|
||||
segment = segment.replace(r"\`", "`")
|
||||
segment = segment.replace("`", r"\`")
|
||||
return segment
|
||||
|
||||
# Split the string into parts, keeping the `${}` segments
|
||||
parts = re.split(r"(\$\{.*?\})", string)
|
||||
escaped_parts = [escape_outside_segments(part) for part in parts]
|
||||
escaped_string = "".join(escaped_parts)
|
||||
return escaped_string
|
||||
|
||||
|
||||
def _wrap_js_string(string: str) -> str:
|
||||
"""Wrap string so it looks like {`string`}.
|
||||
|
||||
Args:
|
||||
string: The string to wrap.
|
||||
|
||||
Returns:
|
||||
The wrapped string.
|
||||
"""
|
||||
string = wrap(string, "`")
|
||||
string = wrap(string, "{")
|
||||
return string
|
||||
|
||||
|
||||
def format_string(string: str) -> str:
|
||||
"""Format the given string as a JS string literal..
|
||||
|
||||
Args:
|
||||
string: The string to format.
|
||||
|
||||
Returns:
|
||||
The formatted string.
|
||||
"""
|
||||
return _wrap_js_string(_escape_js_string(string))
|
||||
|
||||
|
||||
def format_var(var: Var) -> str:
|
||||
"""Format the given Var as a javascript value.
|
||||
|
||||
Args:
|
||||
var: The Var to format.
|
||||
|
||||
Returns:
|
||||
The formatted Var.
|
||||
"""
|
||||
return str(var)
|
||||
|
||||
|
||||
def format_route(route: str, format_case: bool = True) -> str:
|
||||
"""Format the given route.
|
||||
|
||||
@ -335,40 +244,6 @@ def format_route(route: str, format_case: bool = True) -> str:
|
||||
return route
|
||||
|
||||
|
||||
def format_match(
|
||||
cond: str | Var,
|
||||
match_cases: List[List[Var]],
|
||||
default: Var,
|
||||
) -> str:
|
||||
"""Format a match expression whose return type is a Var.
|
||||
|
||||
Args:
|
||||
cond: The condition.
|
||||
match_cases: The list of cases to match.
|
||||
default: The default case.
|
||||
|
||||
Returns:
|
||||
The formatted match expression
|
||||
|
||||
"""
|
||||
switch_code = f"(() => {{ switch (JSON.stringify({cond})) {{"
|
||||
|
||||
for case in match_cases:
|
||||
conditions = case[:-1]
|
||||
return_value = case[-1]
|
||||
|
||||
case_conditions = " ".join(
|
||||
[f"case JSON.stringify({condition!s}):" for condition in conditions]
|
||||
)
|
||||
case_code = f"{case_conditions} return ({return_value!s}); break;"
|
||||
switch_code += case_code
|
||||
|
||||
switch_code += f"default: return ({default!s}); break;"
|
||||
switch_code += "};})()"
|
||||
|
||||
return switch_code
|
||||
|
||||
|
||||
def format_prop(
|
||||
prop: Union[Var, EventChain, ComponentStyle, str],
|
||||
) -> Union[int, float, str]:
|
||||
|
@ -16,7 +16,7 @@ from itertools import chain
|
||||
from multiprocessing import Pool, cpu_count
|
||||
from pathlib import Path
|
||||
from types import ModuleType, SimpleNamespace
|
||||
from typing import Any, Callable, Iterable, Sequence, Type, get_args, get_origin
|
||||
from typing import Any, Callable, Iterable, Sequence, Type, cast, get_args, get_origin
|
||||
|
||||
from reflex.components.component import Component
|
||||
from reflex.utils import types as rx_types
|
||||
@ -230,7 +230,9 @@ def _generate_imports(
|
||||
"""
|
||||
return [
|
||||
*[
|
||||
ast.ImportFrom(module=name, names=[ast.alias(name=val) for val in values]) # pyright: ignore [reportCallIssue]
|
||||
ast.ImportFrom(
|
||||
module=name, names=[ast.alias(name=val) for val in values], level=0
|
||||
)
|
||||
for name, values in DEFAULT_IMPORTS.items()
|
||||
],
|
||||
ast.Import([ast.alias("reflex")]),
|
||||
@ -429,18 +431,15 @@ def type_to_ast(typ: Any, cls: type) -> ast.AST:
|
||||
return ast.Name(id=base_name)
|
||||
|
||||
# Convert all type arguments recursively
|
||||
arg_nodes = [type_to_ast(arg, cls) for arg in args]
|
||||
arg_nodes = cast(list[ast.expr], [type_to_ast(arg, cls) for arg in args])
|
||||
|
||||
# Special case for single-argument types (like List[T] or Optional[T])
|
||||
if len(arg_nodes) == 1:
|
||||
slice_value = arg_nodes[0]
|
||||
else:
|
||||
slice_value = ast.Tuple(elts=arg_nodes, ctx=ast.Load()) # pyright: ignore [reportArgumentType]
|
||||
|
||||
slice_value = ast.Tuple(elts=arg_nodes, ctx=ast.Load())
|
||||
return ast.Subscript(
|
||||
value=ast.Name(id=base_name),
|
||||
slice=ast.Index(value=slice_value), # pyright: ignore [reportArgumentType]
|
||||
ctx=ast.Load(),
|
||||
value=ast.Name(id=base_name), slice=slice_value, ctx=ast.Load()
|
||||
)
|
||||
|
||||
|
||||
@ -635,7 +634,7 @@ def _generate_component_create_functiondef(
|
||||
),
|
||||
),
|
||||
ast.Expr(
|
||||
value=ast.Constant(value=Ellipsis),
|
||||
value=ast.Constant(...),
|
||||
),
|
||||
],
|
||||
decorator_list=[
|
||||
@ -646,8 +645,8 @@ def _generate_component_create_functiondef(
|
||||
else [ast.Name(id="classmethod")]
|
||||
),
|
||||
],
|
||||
lineno=node.lineno if node is not None else None, # pyright: ignore [reportArgumentType]
|
||||
returns=ast.Constant(value=clz.__name__),
|
||||
lineno=node.lineno if node is not None else None, # pyright: ignore[reportArgumentType]
|
||||
)
|
||||
return definition
|
||||
|
||||
@ -695,7 +694,6 @@ def _generate_staticmethod_call_functiondef(
|
||||
),
|
||||
],
|
||||
decorator_list=[ast.Name(id="staticmethod")],
|
||||
lineno=node.lineno if node is not None else None, # pyright: ignore [reportArgumentType]
|
||||
returns=ast.Constant(
|
||||
value=_get_type_hint(
|
||||
typing.get_type_hints(clz.__call__).get("return", None),
|
||||
@ -703,6 +701,7 @@ def _generate_staticmethod_call_functiondef(
|
||||
is_optional=False,
|
||||
)
|
||||
),
|
||||
lineno=node.lineno if node is not None else None, # pyright: ignore[reportArgumentType]
|
||||
)
|
||||
return definition
|
||||
|
||||
@ -723,6 +722,9 @@ def _generate_namespace_call_functiondef(
|
||||
|
||||
Returns:
|
||||
The create functiondef node for the ast.
|
||||
|
||||
Raises:
|
||||
TypeError: If the __call__ method does not have a __func__.
|
||||
"""
|
||||
# add the imports needed by get_type_hint later
|
||||
type_hint_globals.update(
|
||||
@ -737,7 +739,12 @@ def _generate_namespace_call_functiondef(
|
||||
# Determine which class is wrapped by the namespace __call__ method
|
||||
component_clz = clz.__call__.__self__
|
||||
|
||||
if clz.__call__.__func__.__name__ != "create": # pyright: ignore [reportFunctionMemberAccess]
|
||||
func = getattr(clz.__call__, "__func__", None)
|
||||
|
||||
if func is None:
|
||||
raise TypeError(f"__call__ method on {clz_name} does not have a __func__")
|
||||
|
||||
if func.__name__ != "create":
|
||||
return None
|
||||
|
||||
definition = _generate_component_create_functiondef(
|
||||
@ -920,7 +927,7 @@ class StubGenerator(ast.NodeTransformer):
|
||||
node.body.append(call_definition)
|
||||
if not node.body:
|
||||
# We should never return an empty body.
|
||||
node.body.append(ast.Expr(value=ast.Constant(value=Ellipsis)))
|
||||
node.body.append(ast.Expr(value=ast.Constant(...)))
|
||||
self.current_class = None
|
||||
return node
|
||||
|
||||
@ -947,9 +954,9 @@ class StubGenerator(ast.NodeTransformer):
|
||||
if node.name.startswith("_") and node.name != "__call__":
|
||||
return None # remove private methods
|
||||
|
||||
if node.body[-1] != ast.Expr(value=ast.Constant(value=Ellipsis)):
|
||||
if node.body[-1] != ast.Expr(value=ast.Constant(...)):
|
||||
# Blank out the function body for public functions.
|
||||
node.body = [ast.Expr(value=ast.Constant(value=Ellipsis))]
|
||||
node.body = [ast.Expr(value=ast.Constant(...))]
|
||||
return node
|
||||
|
||||
def visit_Assign(self, node: ast.Assign) -> ast.Assign | None:
|
||||
|
@ -7,6 +7,7 @@ import dataclasses
|
||||
import inspect
|
||||
import sys
|
||||
import types
|
||||
from collections import abc
|
||||
from functools import cached_property, lru_cache, wraps
|
||||
from typing import (
|
||||
TYPE_CHECKING,
|
||||
@ -23,6 +24,7 @@ from typing import (
|
||||
Sequence,
|
||||
Tuple,
|
||||
Type,
|
||||
TypeVar,
|
||||
Union,
|
||||
_GenericAlias, # pyright: ignore [reportAttributeAccessIssue]
|
||||
get_args,
|
||||
@ -31,6 +33,7 @@ from typing import (
|
||||
from typing import get_origin as get_origin_og
|
||||
|
||||
import sqlalchemy
|
||||
import typing_extensions
|
||||
from typing_extensions import is_typeddict
|
||||
|
||||
import reflex
|
||||
@ -68,13 +71,13 @@ else:
|
||||
|
||||
|
||||
# Potential GenericAlias types for isinstance checks.
|
||||
GenericAliasTypes = [_GenericAlias]
|
||||
_GenericAliasTypes: list[type] = [_GenericAlias]
|
||||
|
||||
with contextlib.suppress(ImportError):
|
||||
# For newer versions of Python.
|
||||
from types import GenericAlias
|
||||
|
||||
GenericAliasTypes.append(GenericAlias)
|
||||
_GenericAliasTypes.append(GenericAlias)
|
||||
|
||||
with contextlib.suppress(ImportError):
|
||||
# For older versions of Python.
|
||||
@ -82,9 +85,9 @@ with contextlib.suppress(ImportError):
|
||||
_SpecialGenericAlias, # pyright: ignore [reportAttributeAccessIssue]
|
||||
)
|
||||
|
||||
GenericAliasTypes.append(_SpecialGenericAlias)
|
||||
_GenericAliasTypes.append(_SpecialGenericAlias)
|
||||
|
||||
GenericAliasTypes = tuple(GenericAliasTypes)
|
||||
GenericAliasTypes = tuple(_GenericAliasTypes)
|
||||
|
||||
# Potential Union types for isinstance checks (UnionType added in py3.10).
|
||||
UnionTypes = (Union, types.UnionType) if hasattr(types, "UnionType") else (Union,)
|
||||
@ -183,7 +186,7 @@ def is_generic_alias(cls: GenericType) -> bool:
|
||||
return isinstance(cls, GenericAliasTypes) # pyright: ignore [reportArgumentType]
|
||||
|
||||
|
||||
def unionize(*args: GenericType) -> Type:
|
||||
def unionize(*args: GenericType) -> GenericType:
|
||||
"""Unionize the types.
|
||||
|
||||
Args:
|
||||
@ -417,7 +420,7 @@ def get_attribute_access_type(cls: GenericType, name: str) -> GenericType | None
|
||||
|
||||
|
||||
@lru_cache()
|
||||
def get_base_class(cls: GenericType) -> Type:
|
||||
def get_base_class(cls: GenericType) -> Type | tuple[Type, ...]:
|
||||
"""Get the base class of a class.
|
||||
|
||||
Args:
|
||||
@ -437,7 +440,14 @@ def get_base_class(cls: GenericType) -> Type:
|
||||
return type(get_args(cls)[0])
|
||||
|
||||
if is_union(cls):
|
||||
return tuple(get_base_class(arg) for arg in get_args(cls)) # pyright: ignore [reportReturnType]
|
||||
base_classes = []
|
||||
for arg in get_args(cls):
|
||||
sub_base_classes = get_base_class(arg)
|
||||
if isinstance(sub_base_classes, tuple):
|
||||
base_classes.extend(sub_base_classes)
|
||||
else:
|
||||
base_classes.append(sub_base_classes)
|
||||
return tuple(base_classes)
|
||||
|
||||
return get_base_class(cls.__origin__) if is_generic_alias(cls) else cls
|
||||
|
||||
@ -847,18 +857,22 @@ StateBases = get_base_class(StateVar)
|
||||
StateIterBases = get_base_class(StateIterVar)
|
||||
|
||||
|
||||
def safe_issubclass(cls: Type, cls_check: Type | Tuple[Type, ...]):
|
||||
"""Check if a class is a subclass of another class. Returns False if internal error occurs.
|
||||
def safe_issubclass(cls: Any, class_or_tuple: Any, /) -> bool:
|
||||
"""Check if a class is a subclass of another class or a tuple of classes.
|
||||
|
||||
Args:
|
||||
cls: The class to check.
|
||||
cls_check: The class to check against.
|
||||
class_or_tuple: The class or tuple of classes to check against.
|
||||
|
||||
Returns:
|
||||
Whether the class is a subclass of the other class.
|
||||
Whether the class is a subclass of the other class or tuple of classes.
|
||||
"""
|
||||
if cls is class_or_tuple or (
|
||||
isinstance(class_or_tuple, tuple) and cls in class_or_tuple
|
||||
):
|
||||
return True
|
||||
try:
|
||||
return issubclass(cls, cls_check)
|
||||
return issubclass(cls, class_or_tuple)
|
||||
except TypeError:
|
||||
return False
|
||||
|
||||
@ -873,17 +887,32 @@ def typehint_issubclass(possible_subclass: Any, possible_superclass: Any) -> boo
|
||||
Returns:
|
||||
Whether the type hint is a subclass of the other type hint.
|
||||
"""
|
||||
if isinstance(possible_subclass, Sequence) and isinstance(
|
||||
possible_superclass, Sequence
|
||||
):
|
||||
return all(
|
||||
typehint_issubclass(subclass, superclass)
|
||||
for subclass, superclass in zip(
|
||||
possible_subclass, possible_superclass, strict=False
|
||||
)
|
||||
)
|
||||
if possible_subclass is possible_superclass:
|
||||
return True
|
||||
if possible_superclass is Any:
|
||||
return True
|
||||
if possible_subclass is Any:
|
||||
return False
|
||||
if isinstance(
|
||||
possible_subclass, (TypeVar, typing_extensions.TypeVar)
|
||||
) or isinstance(possible_superclass, (TypeVar, typing_extensions.TypeVar)):
|
||||
return True
|
||||
|
||||
provided_type_origin = get_origin(possible_subclass)
|
||||
accepted_type_origin = get_origin(possible_superclass)
|
||||
|
||||
if provided_type_origin is None and accepted_type_origin is None:
|
||||
# In this case, we are dealing with a non-generic type, so we can use issubclass
|
||||
return issubclass(possible_subclass, possible_superclass)
|
||||
return safe_issubclass(possible_subclass, possible_superclass)
|
||||
|
||||
# Remove this check when Python 3.10 is the minimum supported version
|
||||
if hasattr(types, "UnionType"):
|
||||
@ -898,24 +927,64 @@ def typehint_issubclass(possible_subclass: Any, possible_superclass: Any) -> boo
|
||||
provided_args = get_args(possible_subclass)
|
||||
accepted_args = get_args(possible_superclass)
|
||||
|
||||
if accepted_type_origin is Union:
|
||||
if provided_type_origin is not Union:
|
||||
return any(
|
||||
typehint_issubclass(possible_subclass, accepted_arg)
|
||||
for accepted_arg in accepted_args
|
||||
)
|
||||
if provided_type_origin is Union:
|
||||
return all(
|
||||
any(
|
||||
typehint_issubclass(provided_arg, accepted_arg)
|
||||
for accepted_arg in accepted_args
|
||||
)
|
||||
typehint_issubclass(provided_arg, possible_superclass)
|
||||
for provided_arg in provided_args
|
||||
)
|
||||
|
||||
if accepted_type_origin is Union:
|
||||
return any(
|
||||
typehint_issubclass(possible_subclass, accepted_arg)
|
||||
for accepted_arg in accepted_args
|
||||
)
|
||||
|
||||
# Check specifically for Sequence and Iterable
|
||||
if (accepted_type_origin or possible_superclass) in (
|
||||
Sequence,
|
||||
abc.Sequence,
|
||||
Iterable,
|
||||
abc.Iterable,
|
||||
):
|
||||
iterable_type = accepted_args[0] if accepted_args else Any
|
||||
|
||||
if provided_type_origin is None:
|
||||
if not safe_issubclass(
|
||||
possible_subclass, (accepted_type_origin or possible_superclass)
|
||||
):
|
||||
return False
|
||||
|
||||
if safe_issubclass(possible_subclass, str) and not isinstance(
|
||||
iterable_type, TypeVar
|
||||
):
|
||||
return typehint_issubclass(str, iterable_type)
|
||||
return True
|
||||
|
||||
if not safe_issubclass(
|
||||
provided_type_origin, (accepted_type_origin or possible_superclass)
|
||||
):
|
||||
return False
|
||||
|
||||
if not isinstance(iterable_type, (TypeVar, typing_extensions.TypeVar)):
|
||||
if provided_type_origin in (list, tuple, set):
|
||||
# Ensure all specific types are compatible with accepted types
|
||||
return all(
|
||||
typehint_issubclass(provided_arg, iterable_type)
|
||||
for provided_arg in provided_args
|
||||
if provided_arg is not ... # Ellipsis in Tuples
|
||||
)
|
||||
if possible_subclass is dict:
|
||||
# Ensure all specific types are compatible with accepted types
|
||||
return all(
|
||||
typehint_issubclass(provided_arg, iterable_type)
|
||||
for provided_arg in provided_args[:1]
|
||||
)
|
||||
return True
|
||||
|
||||
# Check if the origin of both types is the same (e.g., list for List[int])
|
||||
# This probably should be issubclass instead of ==
|
||||
if (provided_type_origin or possible_subclass) != (
|
||||
accepted_type_origin or possible_superclass
|
||||
if not safe_issubclass(
|
||||
provided_type_origin or possible_subclass,
|
||||
accepted_type_origin or possible_superclass,
|
||||
):
|
||||
return False
|
||||
|
||||
@ -927,5 +996,21 @@ def typehint_issubclass(possible_subclass: Any, possible_superclass: Any) -> boo
|
||||
for provided_arg, accepted_arg in zip(
|
||||
provided_args, accepted_args, strict=False
|
||||
)
|
||||
if accepted_arg is not Any
|
||||
if accepted_arg is not Any and not isinstance(accepted_arg, TypeVar)
|
||||
)
|
||||
|
||||
|
||||
def safe_typehint_issubclass(possible_subclass: Any, possible_superclass: Any) -> bool:
|
||||
"""Check if a type hint is a subclass of another type hint.
|
||||
|
||||
Args:
|
||||
possible_subclass: The type hint to check.
|
||||
possible_superclass: The type hint to check against.
|
||||
|
||||
Returns:
|
||||
Whether the type hint is a subclass of the other type hint.
|
||||
"""
|
||||
try:
|
||||
return typehint_issubclass(possible_subclass, possible_superclass)
|
||||
except Exception:
|
||||
return False
|
||||
|
File diff suppressed because it is too large
Load Diff
@ -4,10 +4,7 @@ from __future__ import annotations
|
||||
|
||||
import dataclasses
|
||||
from datetime import date, datetime
|
||||
from typing import Any, NoReturn, TypeVar, Union, overload
|
||||
|
||||
from reflex.utils.exceptions import VarTypeError
|
||||
from reflex.vars.number import BooleanVar
|
||||
from typing import TypeVar, Union
|
||||
|
||||
from .base import (
|
||||
CustomVarOperationReturn,
|
||||
@ -23,156 +20,11 @@ DATETIME_T = TypeVar("DATETIME_T", datetime, date)
|
||||
datetime_types = Union[datetime, date]
|
||||
|
||||
|
||||
def raise_var_type_error():
|
||||
"""Raise a VarTypeError.
|
||||
|
||||
Raises:
|
||||
VarTypeError: Cannot compare a datetime object with a non-datetime object.
|
||||
"""
|
||||
raise VarTypeError("Cannot compare a datetime object with a non-datetime object.")
|
||||
|
||||
|
||||
class DateTimeVar(Var[DATETIME_T], python_types=(datetime, date)):
|
||||
"""A variable that holds a datetime or date object."""
|
||||
|
||||
@overload
|
||||
def __lt__(self, other: datetime_types) -> BooleanVar: ...
|
||||
|
||||
@overload
|
||||
def __lt__(self, other: NoReturn) -> NoReturn: ... # pyright: ignore [reportOverlappingOverload]
|
||||
|
||||
def __lt__(self, other: Any):
|
||||
"""Less than comparison.
|
||||
|
||||
Args:
|
||||
other: The other datetime to compare.
|
||||
|
||||
Returns:
|
||||
The result of the comparison.
|
||||
"""
|
||||
if not isinstance(other, DATETIME_TYPES):
|
||||
raise_var_type_error()
|
||||
return date_lt_operation(self, other)
|
||||
|
||||
@overload
|
||||
def __le__(self, other: datetime_types) -> BooleanVar: ...
|
||||
|
||||
@overload
|
||||
def __le__(self, other: NoReturn) -> NoReturn: ... # pyright: ignore [reportOverlappingOverload]
|
||||
|
||||
def __le__(self, other: Any):
|
||||
"""Less than or equal comparison.
|
||||
|
||||
Args:
|
||||
other: The other datetime to compare.
|
||||
|
||||
Returns:
|
||||
The result of the comparison.
|
||||
"""
|
||||
if not isinstance(other, DATETIME_TYPES):
|
||||
raise_var_type_error()
|
||||
return date_le_operation(self, other)
|
||||
|
||||
@overload
|
||||
def __gt__(self, other: datetime_types) -> BooleanVar: ...
|
||||
|
||||
@overload
|
||||
def __gt__(self, other: NoReturn) -> NoReturn: ... # pyright: ignore [reportOverlappingOverload]
|
||||
|
||||
def __gt__(self, other: Any):
|
||||
"""Greater than comparison.
|
||||
|
||||
Args:
|
||||
other: The other datetime to compare.
|
||||
|
||||
Returns:
|
||||
The result of the comparison.
|
||||
"""
|
||||
if not isinstance(other, DATETIME_TYPES):
|
||||
raise_var_type_error()
|
||||
return date_gt_operation(self, other)
|
||||
|
||||
@overload
|
||||
def __ge__(self, other: datetime_types) -> BooleanVar: ...
|
||||
|
||||
@overload
|
||||
def __ge__(self, other: NoReturn) -> NoReturn: ... # pyright: ignore [reportOverlappingOverload]
|
||||
|
||||
def __ge__(self, other: Any):
|
||||
"""Greater than or equal comparison.
|
||||
|
||||
Args:
|
||||
other: The other datetime to compare.
|
||||
|
||||
Returns:
|
||||
The result of the comparison.
|
||||
"""
|
||||
if not isinstance(other, DATETIME_TYPES):
|
||||
raise_var_type_error()
|
||||
return date_ge_operation(self, other)
|
||||
|
||||
|
||||
@var_operation
|
||||
def date_gt_operation(lhs: Var | Any, rhs: Var | Any) -> CustomVarOperationReturn:
|
||||
"""Greater than comparison.
|
||||
|
||||
Args:
|
||||
lhs: The left-hand side of the operation.
|
||||
rhs: The right-hand side of the operation.
|
||||
|
||||
Returns:
|
||||
The result of the operation.
|
||||
"""
|
||||
return date_compare_operation(rhs, lhs, strict=True)
|
||||
|
||||
|
||||
@var_operation
|
||||
def date_lt_operation(lhs: Var | Any, rhs: Var | Any) -> CustomVarOperationReturn:
|
||||
"""Less than comparison.
|
||||
|
||||
Args:
|
||||
lhs: The left-hand side of the operation.
|
||||
rhs: The right-hand side of the operation.
|
||||
|
||||
Returns:
|
||||
The result of the operation.
|
||||
"""
|
||||
return date_compare_operation(lhs, rhs, strict=True)
|
||||
|
||||
|
||||
@var_operation
|
||||
def date_le_operation(lhs: Var | Any, rhs: Var | Any) -> CustomVarOperationReturn:
|
||||
"""Less than or equal comparison.
|
||||
|
||||
Args:
|
||||
lhs: The left-hand side of the operation.
|
||||
rhs: The right-hand side of the operation.
|
||||
|
||||
Returns:
|
||||
The result of the operation.
|
||||
"""
|
||||
return date_compare_operation(lhs, rhs)
|
||||
|
||||
|
||||
@var_operation
|
||||
def date_ge_operation(lhs: Var | Any, rhs: Var | Any) -> CustomVarOperationReturn:
|
||||
"""Greater than or equal comparison.
|
||||
|
||||
Args:
|
||||
lhs: The left-hand side of the operation.
|
||||
rhs: The right-hand side of the operation.
|
||||
|
||||
Returns:
|
||||
The result of the operation.
|
||||
"""
|
||||
return date_compare_operation(rhs, lhs)
|
||||
|
||||
|
||||
def date_compare_operation(
|
||||
lhs: DateTimeVar[DATETIME_T] | Any,
|
||||
rhs: DateTimeVar[DATETIME_T] | Any,
|
||||
lhs: Var[datetime_types],
|
||||
rhs: Var[datetime_types],
|
||||
strict: bool = False,
|
||||
) -> CustomVarOperationReturn:
|
||||
) -> CustomVarOperationReturn[bool]:
|
||||
"""Check if the value is less than the other value.
|
||||
|
||||
Args:
|
||||
@ -189,6 +41,84 @@ def date_compare_operation(
|
||||
)
|
||||
|
||||
|
||||
@var_operation
|
||||
def date_gt_operation(
|
||||
lhs: Var[datetime_types],
|
||||
rhs: Var[datetime_types],
|
||||
) -> CustomVarOperationReturn:
|
||||
"""Greater than comparison.
|
||||
|
||||
Args:
|
||||
lhs: The left-hand side of the operation.
|
||||
rhs: The right-hand side of the operation.
|
||||
|
||||
Returns:
|
||||
The result of the operation.
|
||||
"""
|
||||
return date_compare_operation(rhs, lhs, strict=True)
|
||||
|
||||
|
||||
@var_operation
|
||||
def date_lt_operation(
|
||||
lhs: Var[datetime_types],
|
||||
rhs: Var[datetime_types],
|
||||
) -> CustomVarOperationReturn:
|
||||
"""Less than comparison.
|
||||
|
||||
Args:
|
||||
lhs: The left-hand side of the operation.
|
||||
rhs: The right-hand side of the operation.
|
||||
|
||||
Returns:
|
||||
The result of the operation.
|
||||
"""
|
||||
return date_compare_operation(lhs, rhs, strict=True)
|
||||
|
||||
|
||||
@var_operation
|
||||
def date_le_operation(
|
||||
lhs: Var[datetime_types], rhs: Var[datetime_types]
|
||||
) -> CustomVarOperationReturn:
|
||||
"""Less than or equal comparison.
|
||||
|
||||
Args:
|
||||
lhs: The left-hand side of the operation.
|
||||
rhs: The right-hand side of the operation.
|
||||
|
||||
Returns:
|
||||
The result of the operation.
|
||||
"""
|
||||
return date_compare_operation(lhs, rhs)
|
||||
|
||||
|
||||
@var_operation
|
||||
def date_ge_operation(
|
||||
lhs: Var[datetime_types], rhs: Var[datetime_types]
|
||||
) -> CustomVarOperationReturn:
|
||||
"""Greater than or equal comparison.
|
||||
|
||||
Args:
|
||||
lhs: The left-hand side of the operation.
|
||||
rhs: The right-hand side of the operation.
|
||||
|
||||
Returns:
|
||||
The result of the operation.
|
||||
"""
|
||||
return date_compare_operation(rhs, lhs)
|
||||
|
||||
|
||||
class DateTimeVar(Var[DATETIME_T], python_types=(datetime, date)):
|
||||
"""A variable that holds a datetime or date object."""
|
||||
|
||||
__lt__ = date_lt_operation
|
||||
|
||||
__le__ = date_le_operation
|
||||
|
||||
__gt__ = date_gt_operation
|
||||
|
||||
__ge__ = date_ge_operation
|
||||
|
||||
|
||||
@dataclasses.dataclass(
|
||||
eq=False,
|
||||
frozen=True,
|
||||
|
File diff suppressed because it is too large
Load Diff
@ -3,6 +3,7 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import dataclasses
|
||||
import functools
|
||||
import json
|
||||
import math
|
||||
from typing import (
|
||||
@ -10,21 +11,30 @@ from typing import (
|
||||
Any,
|
||||
Callable,
|
||||
NoReturn,
|
||||
Type,
|
||||
Sequence,
|
||||
TypeVar,
|
||||
Union,
|
||||
cast,
|
||||
overload,
|
||||
)
|
||||
|
||||
from typing_extensions import Unpack
|
||||
|
||||
from reflex.constants.base import Dirs
|
||||
from reflex.utils.exceptions import PrimitiveUnserializableToJSONError, VarTypeError
|
||||
from reflex.utils.imports import ImportDict, ImportVar
|
||||
|
||||
from .base import (
|
||||
VAR_TYPE,
|
||||
CachedVarOperation,
|
||||
CustomVarOperationReturn,
|
||||
LiteralVar,
|
||||
ReflexCallable,
|
||||
Var,
|
||||
VarData,
|
||||
cached_property_no_lock,
|
||||
nary_type_computer,
|
||||
passthrough_unary_type_computer,
|
||||
unionize,
|
||||
var_operation,
|
||||
var_operation_return,
|
||||
@ -33,6 +43,7 @@ from .base import (
|
||||
NUMBER_T = TypeVar("NUMBER_T", int, float, bool)
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from .function import FunctionVar
|
||||
from .sequence import ArrayVar
|
||||
|
||||
|
||||
@ -56,13 +67,7 @@ def raise_unsupported_operand_types(
|
||||
class NumberVar(Var[NUMBER_T], python_types=(int, float)):
|
||||
"""Base class for immutable number vars."""
|
||||
|
||||
@overload
|
||||
def __add__(self, other: number_types) -> NumberVar: ...
|
||||
|
||||
@overload
|
||||
def __add__(self, other: NoReturn) -> NoReturn: ... # pyright: ignore [reportOverlappingOverload]
|
||||
|
||||
def __add__(self, other: Any):
|
||||
def __add__(self, other: number_types) -> NumberVar:
|
||||
"""Add two numbers.
|
||||
|
||||
Args:
|
||||
@ -73,15 +78,9 @@ class NumberVar(Var[NUMBER_T], python_types=(int, float)):
|
||||
"""
|
||||
if not isinstance(other, NUMBER_TYPES):
|
||||
raise_unsupported_operand_types("+", (type(self), type(other)))
|
||||
return number_add_operation(self, +other)
|
||||
return number_add_operation(self, +other).guess_type()
|
||||
|
||||
@overload
|
||||
def __radd__(self, other: number_types) -> NumberVar: ...
|
||||
|
||||
@overload
|
||||
def __radd__(self, other: NoReturn) -> NoReturn: ... # pyright: ignore [reportOverlappingOverload]
|
||||
|
||||
def __radd__(self, other: Any):
|
||||
def __radd__(self, other: number_types) -> NumberVar:
|
||||
"""Add two numbers.
|
||||
|
||||
Args:
|
||||
@ -92,15 +91,9 @@ class NumberVar(Var[NUMBER_T], python_types=(int, float)):
|
||||
"""
|
||||
if not isinstance(other, NUMBER_TYPES):
|
||||
raise_unsupported_operand_types("+", (type(other), type(self)))
|
||||
return number_add_operation(+other, self)
|
||||
return number_add_operation(+other, self).guess_type()
|
||||
|
||||
@overload
|
||||
def __sub__(self, other: number_types) -> NumberVar: ...
|
||||
|
||||
@overload
|
||||
def __sub__(self, other: NoReturn) -> NoReturn: ... # pyright: ignore [reportOverlappingOverload]
|
||||
|
||||
def __sub__(self, other: Any):
|
||||
def __sub__(self, other: number_types) -> NumberVar:
|
||||
"""Subtract two numbers.
|
||||
|
||||
Args:
|
||||
@ -112,15 +105,9 @@ class NumberVar(Var[NUMBER_T], python_types=(int, float)):
|
||||
if not isinstance(other, NUMBER_TYPES):
|
||||
raise_unsupported_operand_types("-", (type(self), type(other)))
|
||||
|
||||
return number_subtract_operation(self, +other)
|
||||
return number_subtract_operation(self, +other).guess_type()
|
||||
|
||||
@overload
|
||||
def __rsub__(self, other: number_types) -> NumberVar: ...
|
||||
|
||||
@overload
|
||||
def __rsub__(self, other: NoReturn) -> NoReturn: ... # pyright: ignore [reportOverlappingOverload]
|
||||
|
||||
def __rsub__(self, other: Any):
|
||||
def __rsub__(self, other: number_types) -> NumberVar:
|
||||
"""Subtract two numbers.
|
||||
|
||||
Args:
|
||||
@ -132,7 +119,7 @@ class NumberVar(Var[NUMBER_T], python_types=(int, float)):
|
||||
if not isinstance(other, NUMBER_TYPES):
|
||||
raise_unsupported_operand_types("-", (type(other), type(self)))
|
||||
|
||||
return number_subtract_operation(+other, self)
|
||||
return number_subtract_operation(+other, self).guess_type()
|
||||
|
||||
def __abs__(self):
|
||||
"""Get the absolute value of the number.
|
||||
@ -167,7 +154,7 @@ class NumberVar(Var[NUMBER_T], python_types=(int, float)):
|
||||
if not isinstance(other, NUMBER_TYPES):
|
||||
raise_unsupported_operand_types("*", (type(self), type(other)))
|
||||
|
||||
return number_multiply_operation(self, +other)
|
||||
return number_multiply_operation(self, +other).guess_type()
|
||||
|
||||
@overload
|
||||
def __rmul__(self, other: number_types | boolean_types) -> NumberVar: ...
|
||||
@ -194,15 +181,9 @@ class NumberVar(Var[NUMBER_T], python_types=(int, float)):
|
||||
if not isinstance(other, NUMBER_TYPES):
|
||||
raise_unsupported_operand_types("*", (type(other), type(self)))
|
||||
|
||||
return number_multiply_operation(+other, self)
|
||||
return number_multiply_operation(+other, self).guess_type()
|
||||
|
||||
@overload
|
||||
def __truediv__(self, other: number_types) -> NumberVar: ...
|
||||
|
||||
@overload
|
||||
def __truediv__(self, other: NoReturn) -> NoReturn: ... # pyright: ignore [reportOverlappingOverload]
|
||||
|
||||
def __truediv__(self, other: Any):
|
||||
def __truediv__(self, other: number_types) -> NumberVar:
|
||||
"""Divide two numbers.
|
||||
|
||||
Args:
|
||||
@ -214,15 +195,9 @@ class NumberVar(Var[NUMBER_T], python_types=(int, float)):
|
||||
if not isinstance(other, NUMBER_TYPES):
|
||||
raise_unsupported_operand_types("/", (type(self), type(other)))
|
||||
|
||||
return number_true_division_operation(self, +other)
|
||||
return number_true_division_operation(self, +other).guess_type()
|
||||
|
||||
@overload
|
||||
def __rtruediv__(self, other: number_types) -> NumberVar: ...
|
||||
|
||||
@overload
|
||||
def __rtruediv__(self, other: NoReturn) -> NoReturn: ... # pyright: ignore [reportOverlappingOverload]
|
||||
|
||||
def __rtruediv__(self, other: Any):
|
||||
def __rtruediv__(self, other: number_types) -> NumberVar:
|
||||
"""Divide two numbers.
|
||||
|
||||
Args:
|
||||
@ -234,15 +209,9 @@ class NumberVar(Var[NUMBER_T], python_types=(int, float)):
|
||||
if not isinstance(other, NUMBER_TYPES):
|
||||
raise_unsupported_operand_types("/", (type(other), type(self)))
|
||||
|
||||
return number_true_division_operation(+other, self)
|
||||
return number_true_division_operation(+other, self).guess_type()
|
||||
|
||||
@overload
|
||||
def __floordiv__(self, other: number_types) -> NumberVar: ...
|
||||
|
||||
@overload
|
||||
def __floordiv__(self, other: NoReturn) -> NoReturn: ... # pyright: ignore [reportOverlappingOverload]
|
||||
|
||||
def __floordiv__(self, other: Any):
|
||||
def __floordiv__(self, other: number_types) -> NumberVar:
|
||||
"""Floor divide two numbers.
|
||||
|
||||
Args:
|
||||
@ -254,15 +223,9 @@ class NumberVar(Var[NUMBER_T], python_types=(int, float)):
|
||||
if not isinstance(other, NUMBER_TYPES):
|
||||
raise_unsupported_operand_types("//", (type(self), type(other)))
|
||||
|
||||
return number_floor_division_operation(self, +other)
|
||||
return number_floor_division_operation(self, +other).guess_type()
|
||||
|
||||
@overload
|
||||
def __rfloordiv__(self, other: number_types) -> NumberVar: ...
|
||||
|
||||
@overload
|
||||
def __rfloordiv__(self, other: NoReturn) -> NoReturn: ... # pyright: ignore [reportOverlappingOverload]
|
||||
|
||||
def __rfloordiv__(self, other: Any):
|
||||
def __rfloordiv__(self, other: number_types) -> NumberVar:
|
||||
"""Floor divide two numbers.
|
||||
|
||||
Args:
|
||||
@ -274,15 +237,9 @@ class NumberVar(Var[NUMBER_T], python_types=(int, float)):
|
||||
if not isinstance(other, NUMBER_TYPES):
|
||||
raise_unsupported_operand_types("//", (type(other), type(self)))
|
||||
|
||||
return number_floor_division_operation(+other, self)
|
||||
return number_floor_division_operation(+other, self).guess_type()
|
||||
|
||||
@overload
|
||||
def __mod__(self, other: number_types) -> NumberVar: ...
|
||||
|
||||
@overload
|
||||
def __mod__(self, other: NoReturn) -> NoReturn: ... # pyright: ignore [reportOverlappingOverload]
|
||||
|
||||
def __mod__(self, other: Any):
|
||||
def __mod__(self, other: number_types) -> NumberVar:
|
||||
"""Modulo two numbers.
|
||||
|
||||
Args:
|
||||
@ -294,15 +251,9 @@ class NumberVar(Var[NUMBER_T], python_types=(int, float)):
|
||||
if not isinstance(other, NUMBER_TYPES):
|
||||
raise_unsupported_operand_types("%", (type(self), type(other)))
|
||||
|
||||
return number_modulo_operation(self, +other)
|
||||
return number_modulo_operation(self, +other).guess_type()
|
||||
|
||||
@overload
|
||||
def __rmod__(self, other: number_types) -> NumberVar: ...
|
||||
|
||||
@overload
|
||||
def __rmod__(self, other: NoReturn) -> NoReturn: ... # pyright: ignore [reportOverlappingOverload]
|
||||
|
||||
def __rmod__(self, other: Any):
|
||||
def __rmod__(self, other: number_types) -> NumberVar:
|
||||
"""Modulo two numbers.
|
||||
|
||||
Args:
|
||||
@ -314,15 +265,9 @@ class NumberVar(Var[NUMBER_T], python_types=(int, float)):
|
||||
if not isinstance(other, NUMBER_TYPES):
|
||||
raise_unsupported_operand_types("%", (type(other), type(self)))
|
||||
|
||||
return number_modulo_operation(+other, self)
|
||||
return number_modulo_operation(+other, self).guess_type()
|
||||
|
||||
@overload
|
||||
def __pow__(self, other: number_types) -> NumberVar: ...
|
||||
|
||||
@overload
|
||||
def __pow__(self, other: NoReturn) -> NoReturn: ... # pyright: ignore [reportOverlappingOverload]
|
||||
|
||||
def __pow__(self, other: Any):
|
||||
def __pow__(self, other: number_types) -> NumberVar:
|
||||
"""Exponentiate two numbers.
|
||||
|
||||
Args:
|
||||
@ -334,15 +279,9 @@ class NumberVar(Var[NUMBER_T], python_types=(int, float)):
|
||||
if not isinstance(other, NUMBER_TYPES):
|
||||
raise_unsupported_operand_types("**", (type(self), type(other)))
|
||||
|
||||
return number_exponent_operation(self, +other)
|
||||
return number_exponent_operation(self, +other).guess_type()
|
||||
|
||||
@overload
|
||||
def __rpow__(self, other: number_types) -> NumberVar: ...
|
||||
|
||||
@overload
|
||||
def __rpow__(self, other: NoReturn) -> NoReturn: ... # pyright: ignore [reportOverlappingOverload]
|
||||
|
||||
def __rpow__(self, other: Any):
|
||||
def __rpow__(self, other: number_types) -> NumberVar:
|
||||
"""Exponentiate two numbers.
|
||||
|
||||
Args:
|
||||
@ -354,7 +293,7 @@ class NumberVar(Var[NUMBER_T], python_types=(int, float)):
|
||||
if not isinstance(other, NUMBER_TYPES):
|
||||
raise_unsupported_operand_types("**", (type(other), type(self)))
|
||||
|
||||
return number_exponent_operation(+other, self)
|
||||
return number_exponent_operation(+other, self).guess_type()
|
||||
|
||||
def __neg__(self):
|
||||
"""Negate the number.
|
||||
@ -362,7 +301,7 @@ class NumberVar(Var[NUMBER_T], python_types=(int, float)):
|
||||
Returns:
|
||||
The number negation operation.
|
||||
"""
|
||||
return number_negate_operation(self)
|
||||
return number_negate_operation(self).guess_type()
|
||||
|
||||
def __invert__(self):
|
||||
"""Boolean NOT the number.
|
||||
@ -370,7 +309,7 @@ class NumberVar(Var[NUMBER_T], python_types=(int, float)):
|
||||
Returns:
|
||||
The boolean NOT operation.
|
||||
"""
|
||||
return boolean_not_operation(self.bool())
|
||||
return boolean_not_operation(self.bool()).guess_type()
|
||||
|
||||
def __pos__(self) -> NumberVar:
|
||||
"""Positive the number.
|
||||
@ -386,7 +325,7 @@ class NumberVar(Var[NUMBER_T], python_types=(int, float)):
|
||||
Returns:
|
||||
The number round operation.
|
||||
"""
|
||||
return number_round_operation(self)
|
||||
return number_round_operation(self).guess_type()
|
||||
|
||||
def __ceil__(self):
|
||||
"""Ceil the number.
|
||||
@ -394,7 +333,7 @@ class NumberVar(Var[NUMBER_T], python_types=(int, float)):
|
||||
Returns:
|
||||
The number ceil operation.
|
||||
"""
|
||||
return number_ceil_operation(self)
|
||||
return number_ceil_operation(self).guess_type()
|
||||
|
||||
def __floor__(self):
|
||||
"""Floor the number.
|
||||
@ -402,7 +341,7 @@ class NumberVar(Var[NUMBER_T], python_types=(int, float)):
|
||||
Returns:
|
||||
The number floor operation.
|
||||
"""
|
||||
return number_floor_operation(self)
|
||||
return number_floor_operation(self).guess_type()
|
||||
|
||||
def __trunc__(self):
|
||||
"""Trunc the number.
|
||||
@ -410,15 +349,9 @@ class NumberVar(Var[NUMBER_T], python_types=(int, float)):
|
||||
Returns:
|
||||
The number trunc operation.
|
||||
"""
|
||||
return number_trunc_operation(self)
|
||||
return number_trunc_operation(self).guess_type()
|
||||
|
||||
@overload
|
||||
def __lt__(self, other: number_types) -> BooleanVar: ...
|
||||
|
||||
@overload
|
||||
def __lt__(self, other: NoReturn) -> NoReturn: ... # pyright: ignore [reportOverlappingOverload]
|
||||
|
||||
def __lt__(self, other: Any):
|
||||
def __lt__(self, other: number_types) -> BooleanVar:
|
||||
"""Less than comparison.
|
||||
|
||||
Args:
|
||||
@ -429,15 +362,9 @@ class NumberVar(Var[NUMBER_T], python_types=(int, float)):
|
||||
"""
|
||||
if not isinstance(other, NUMBER_TYPES):
|
||||
raise_unsupported_operand_types("<", (type(self), type(other)))
|
||||
return less_than_operation(+self, +other)
|
||||
return less_than_operation(+self, +other).guess_type()
|
||||
|
||||
@overload
|
||||
def __le__(self, other: number_types) -> BooleanVar: ...
|
||||
|
||||
@overload
|
||||
def __le__(self, other: NoReturn) -> NoReturn: ... # pyright: ignore [reportOverlappingOverload]
|
||||
|
||||
def __le__(self, other: Any):
|
||||
def __le__(self, other: number_types) -> BooleanVar:
|
||||
"""Less than or equal comparison.
|
||||
|
||||
Args:
|
||||
@ -448,9 +375,9 @@ class NumberVar(Var[NUMBER_T], python_types=(int, float)):
|
||||
"""
|
||||
if not isinstance(other, NUMBER_TYPES):
|
||||
raise_unsupported_operand_types("<=", (type(self), type(other)))
|
||||
return less_than_or_equal_operation(+self, +other)
|
||||
return less_than_or_equal_operation(+self, +other).guess_type()
|
||||
|
||||
def __eq__(self, other: Any):
|
||||
def __eq__(self, other: Any) -> BooleanVar:
|
||||
"""Equal comparison.
|
||||
|
||||
Args:
|
||||
@ -460,10 +387,10 @@ class NumberVar(Var[NUMBER_T], python_types=(int, float)):
|
||||
The result of the comparison.
|
||||
"""
|
||||
if isinstance(other, NUMBER_TYPES):
|
||||
return equal_operation(+self, +other)
|
||||
return equal_operation(self, other)
|
||||
return equal_operation(+self, +other).guess_type()
|
||||
return equal_operation(self, other).guess_type()
|
||||
|
||||
def __ne__(self, other: Any):
|
||||
def __ne__(self, other: Any) -> BooleanVar:
|
||||
"""Not equal comparison.
|
||||
|
||||
Args:
|
||||
@ -473,16 +400,10 @@ class NumberVar(Var[NUMBER_T], python_types=(int, float)):
|
||||
The result of the comparison.
|
||||
"""
|
||||
if isinstance(other, NUMBER_TYPES):
|
||||
return not_equal_operation(+self, +other)
|
||||
return not_equal_operation(self, other)
|
||||
return not_equal_operation(+self, +other).guess_type()
|
||||
return not_equal_operation(self, other).guess_type()
|
||||
|
||||
@overload
|
||||
def __gt__(self, other: number_types) -> BooleanVar: ...
|
||||
|
||||
@overload
|
||||
def __gt__(self, other: NoReturn) -> NoReturn: ... # pyright: ignore [reportOverlappingOverload]
|
||||
|
||||
def __gt__(self, other: Any):
|
||||
def __gt__(self, other: number_types) -> BooleanVar:
|
||||
"""Greater than comparison.
|
||||
|
||||
Args:
|
||||
@ -493,15 +414,9 @@ class NumberVar(Var[NUMBER_T], python_types=(int, float)):
|
||||
"""
|
||||
if not isinstance(other, NUMBER_TYPES):
|
||||
raise_unsupported_operand_types(">", (type(self), type(other)))
|
||||
return greater_than_operation(+self, +other)
|
||||
return greater_than_operation(+self, +other).guess_type()
|
||||
|
||||
@overload
|
||||
def __ge__(self, other: number_types) -> BooleanVar: ...
|
||||
|
||||
@overload
|
||||
def __ge__(self, other: NoReturn) -> NoReturn: ... # pyright: ignore [reportOverlappingOverload]
|
||||
|
||||
def __ge__(self, other: Any):
|
||||
def __ge__(self, other: number_types) -> BooleanVar:
|
||||
"""Greater than or equal comparison.
|
||||
|
||||
Args:
|
||||
@ -512,7 +427,7 @@ class NumberVar(Var[NUMBER_T], python_types=(int, float)):
|
||||
"""
|
||||
if not isinstance(other, NUMBER_TYPES):
|
||||
raise_unsupported_operand_types(">=", (type(self), type(other)))
|
||||
return greater_than_or_equal_operation(+self, +other)
|
||||
return greater_than_or_equal_operation(+self, +other).guess_type()
|
||||
|
||||
def _is_strict_float(self) -> bool:
|
||||
"""Check if the number is a float.
|
||||
@ -532,8 +447,8 @@ class NumberVar(Var[NUMBER_T], python_types=(int, float)):
|
||||
|
||||
|
||||
def binary_number_operation(
|
||||
func: Callable[[NumberVar, NumberVar], str],
|
||||
) -> Callable[[number_types, number_types], NumberVar]:
|
||||
func: Callable[[Var[int | float], Var[int | float]], str],
|
||||
):
|
||||
"""Decorator to create a binary number operation.
|
||||
|
||||
Args:
|
||||
@ -543,30 +458,37 @@ def binary_number_operation(
|
||||
The binary number operation.
|
||||
"""
|
||||
|
||||
@var_operation
|
||||
def operation(lhs: NumberVar, rhs: NumberVar):
|
||||
def operation(
|
||||
lhs: Var[int | float], rhs: Var[int | float]
|
||||
) -> CustomVarOperationReturn[int | float]:
|
||||
def type_computer(*args: Var):
|
||||
if not args:
|
||||
return (
|
||||
ReflexCallable[[int | float, int | float], int | float],
|
||||
type_computer,
|
||||
)
|
||||
if len(args) == 1:
|
||||
return (
|
||||
ReflexCallable[[int | float], int | float],
|
||||
functools.partial(type_computer, args[0]),
|
||||
)
|
||||
return (
|
||||
ReflexCallable[[], unionize(args[0]._var_type, args[1]._var_type)],
|
||||
None,
|
||||
)
|
||||
|
||||
return var_operation_return(
|
||||
js_expression=func(lhs, rhs),
|
||||
var_type=unionize(lhs._var_type, rhs._var_type),
|
||||
type_computer=type_computer,
|
||||
)
|
||||
|
||||
def wrapper(lhs: number_types, rhs: number_types) -> NumberVar:
|
||||
"""Create the binary number operation.
|
||||
operation.__name__ = func.__name__
|
||||
|
||||
Args:
|
||||
lhs: The first number.
|
||||
rhs: The second number.
|
||||
|
||||
Returns:
|
||||
The binary number operation.
|
||||
"""
|
||||
return operation(lhs, rhs) # pyright: ignore [reportReturnType, reportArgumentType]
|
||||
|
||||
return wrapper
|
||||
return var_operation(operation)
|
||||
|
||||
|
||||
@binary_number_operation
|
||||
def number_add_operation(lhs: NumberVar, rhs: NumberVar):
|
||||
def number_add_operation(lhs: Var[int | float], rhs: Var[int | float]):
|
||||
"""Add two numbers.
|
||||
|
||||
Args:
|
||||
@ -580,7 +502,7 @@ def number_add_operation(lhs: NumberVar, rhs: NumberVar):
|
||||
|
||||
|
||||
@binary_number_operation
|
||||
def number_subtract_operation(lhs: NumberVar, rhs: NumberVar):
|
||||
def number_subtract_operation(lhs: Var[int | float], rhs: Var[int | float]):
|
||||
"""Subtract two numbers.
|
||||
|
||||
Args:
|
||||
@ -593,8 +515,15 @@ def number_subtract_operation(lhs: NumberVar, rhs: NumberVar):
|
||||
return f"({lhs} - {rhs})"
|
||||
|
||||
|
||||
unary_operation_type_computer = passthrough_unary_type_computer(
|
||||
ReflexCallable[[int | float], int | float]
|
||||
)
|
||||
|
||||
|
||||
@var_operation
|
||||
def number_abs_operation(value: NumberVar):
|
||||
def number_abs_operation(
|
||||
value: Var[int | float],
|
||||
) -> CustomVarOperationReturn[int | float]:
|
||||
"""Get the absolute value of the number.
|
||||
|
||||
Args:
|
||||
@ -604,12 +533,14 @@ def number_abs_operation(value: NumberVar):
|
||||
The number absolute operation.
|
||||
"""
|
||||
return var_operation_return(
|
||||
js_expression=f"Math.abs({value})", var_type=value._var_type
|
||||
js_expression=f"Math.abs({value})",
|
||||
type_computer=unary_operation_type_computer,
|
||||
_raw_js_function="Math.abs",
|
||||
)
|
||||
|
||||
|
||||
@binary_number_operation
|
||||
def number_multiply_operation(lhs: NumberVar, rhs: NumberVar):
|
||||
def number_multiply_operation(lhs: Var[int | float], rhs: Var[int | float]):
|
||||
"""Multiply two numbers.
|
||||
|
||||
Args:
|
||||
@ -624,7 +555,7 @@ def number_multiply_operation(lhs: NumberVar, rhs: NumberVar):
|
||||
|
||||
@var_operation
|
||||
def number_negate_operation(
|
||||
value: NumberVar[NUMBER_T],
|
||||
value: Var[NUMBER_T],
|
||||
) -> CustomVarOperationReturn[NUMBER_T]:
|
||||
"""Negate the number.
|
||||
|
||||
@ -634,11 +565,13 @@ def number_negate_operation(
|
||||
Returns:
|
||||
The number negation operation.
|
||||
"""
|
||||
return var_operation_return(js_expression=f"-({value})", var_type=value._var_type)
|
||||
return var_operation_return(
|
||||
js_expression=f"-({value})", type_computer=unary_operation_type_computer
|
||||
)
|
||||
|
||||
|
||||
@binary_number_operation
|
||||
def number_true_division_operation(lhs: NumberVar, rhs: NumberVar):
|
||||
def number_true_division_operation(lhs: Var[int | float], rhs: Var[int | float]):
|
||||
"""Divide two numbers.
|
||||
|
||||
Args:
|
||||
@ -652,7 +585,7 @@ def number_true_division_operation(lhs: NumberVar, rhs: NumberVar):
|
||||
|
||||
|
||||
@binary_number_operation
|
||||
def number_floor_division_operation(lhs: NumberVar, rhs: NumberVar):
|
||||
def number_floor_division_operation(lhs: Var[int | float], rhs: Var[int | float]):
|
||||
"""Floor divide two numbers.
|
||||
|
||||
Args:
|
||||
@ -666,7 +599,7 @@ def number_floor_division_operation(lhs: NumberVar, rhs: NumberVar):
|
||||
|
||||
|
||||
@binary_number_operation
|
||||
def number_modulo_operation(lhs: NumberVar, rhs: NumberVar):
|
||||
def number_modulo_operation(lhs: Var[int | float], rhs: Var[int | float]):
|
||||
"""Modulo two numbers.
|
||||
|
||||
Args:
|
||||
@ -680,7 +613,7 @@ def number_modulo_operation(lhs: NumberVar, rhs: NumberVar):
|
||||
|
||||
|
||||
@binary_number_operation
|
||||
def number_exponent_operation(lhs: NumberVar, rhs: NumberVar):
|
||||
def number_exponent_operation(lhs: Var[int | float], rhs: Var[int | float]):
|
||||
"""Exponentiate two numbers.
|
||||
|
||||
Args:
|
||||
@ -694,7 +627,7 @@ def number_exponent_operation(lhs: NumberVar, rhs: NumberVar):
|
||||
|
||||
|
||||
@var_operation
|
||||
def number_round_operation(value: NumberVar):
|
||||
def number_round_operation(value: Var[int | float]):
|
||||
"""Round the number.
|
||||
|
||||
Args:
|
||||
@ -707,7 +640,7 @@ def number_round_operation(value: NumberVar):
|
||||
|
||||
|
||||
@var_operation
|
||||
def number_ceil_operation(value: NumberVar):
|
||||
def number_ceil_operation(value: Var[int | float]):
|
||||
"""Ceil the number.
|
||||
|
||||
Args:
|
||||
@ -720,7 +653,7 @@ def number_ceil_operation(value: NumberVar):
|
||||
|
||||
|
||||
@var_operation
|
||||
def number_floor_operation(value: NumberVar):
|
||||
def number_floor_operation(value: Var[int | float]):
|
||||
"""Floor the number.
|
||||
|
||||
Args:
|
||||
@ -729,11 +662,15 @@ def number_floor_operation(value: NumberVar):
|
||||
Returns:
|
||||
The number floor operation.
|
||||
"""
|
||||
return var_operation_return(js_expression=f"Math.floor({value})", var_type=int)
|
||||
return var_operation_return(
|
||||
js_expression=f"Math.floor({value})",
|
||||
var_type=int,
|
||||
_raw_js_function="Math.floor",
|
||||
)
|
||||
|
||||
|
||||
@var_operation
|
||||
def number_trunc_operation(value: NumberVar):
|
||||
def number_trunc_operation(value: Var[int | float]):
|
||||
"""Trunc the number.
|
||||
|
||||
Args:
|
||||
@ -754,7 +691,7 @@ class BooleanVar(NumberVar[bool], python_types=bool):
|
||||
Returns:
|
||||
The boolean NOT operation.
|
||||
"""
|
||||
return boolean_not_operation(self)
|
||||
return boolean_not_operation(self).guess_type()
|
||||
|
||||
def __int__(self):
|
||||
"""Convert the boolean to an int.
|
||||
@ -762,7 +699,7 @@ class BooleanVar(NumberVar[bool], python_types=bool):
|
||||
Returns:
|
||||
The boolean to int operation.
|
||||
"""
|
||||
return boolean_to_number_operation(self)
|
||||
return boolean_to_number_operation(self).guess_type()
|
||||
|
||||
def __pos__(self):
|
||||
"""Convert the boolean to an int.
|
||||
@ -770,7 +707,7 @@ class BooleanVar(NumberVar[bool], python_types=bool):
|
||||
Returns:
|
||||
The boolean to int operation.
|
||||
"""
|
||||
return boolean_to_number_operation(self)
|
||||
return boolean_to_number_operation(self).guess_type()
|
||||
|
||||
def bool(self) -> BooleanVar:
|
||||
"""Boolean conversion.
|
||||
@ -826,7 +763,7 @@ class BooleanVar(NumberVar[bool], python_types=bool):
|
||||
|
||||
|
||||
@var_operation
|
||||
def boolean_to_number_operation(value: BooleanVar):
|
||||
def boolean_to_number_operation(value: Var[bool]):
|
||||
"""Convert the boolean to a number.
|
||||
|
||||
Args:
|
||||
@ -835,12 +772,14 @@ def boolean_to_number_operation(value: BooleanVar):
|
||||
Returns:
|
||||
The boolean to number operation.
|
||||
"""
|
||||
return var_operation_return(js_expression=f"Number({value})", var_type=int)
|
||||
return var_operation_return(
|
||||
js_expression=f"Number({value})", var_type=int, _raw_js_function="Number"
|
||||
)
|
||||
|
||||
|
||||
def comparison_operator(
|
||||
func: Callable[[Var, Var], str],
|
||||
) -> Callable[[Var | Any, Var | Any], BooleanVar]:
|
||||
) -> FunctionVar[ReflexCallable[[Any, Any], bool]]:
|
||||
"""Decorator to create a comparison operation.
|
||||
|
||||
Args:
|
||||
@ -850,26 +789,15 @@ def comparison_operator(
|
||||
The comparison operation.
|
||||
"""
|
||||
|
||||
@var_operation
|
||||
def operation(lhs: Var, rhs: Var):
|
||||
def operation(lhs: Var[Any], rhs: Var[Any]):
|
||||
return var_operation_return(
|
||||
js_expression=func(lhs, rhs),
|
||||
var_type=bool,
|
||||
)
|
||||
|
||||
def wrapper(lhs: Var | Any, rhs: Var | Any) -> BooleanVar:
|
||||
"""Create the comparison operation.
|
||||
operation.__name__ = func.__name__
|
||||
|
||||
Args:
|
||||
lhs: The first value.
|
||||
rhs: The second value.
|
||||
|
||||
Returns:
|
||||
The comparison operation.
|
||||
"""
|
||||
return operation(lhs, rhs)
|
||||
|
||||
return wrapper
|
||||
return var_operation(operation)
|
||||
|
||||
|
||||
@comparison_operator
|
||||
@ -957,7 +885,7 @@ def not_equal_operation(lhs: Var, rhs: Var):
|
||||
|
||||
|
||||
@var_operation
|
||||
def boolean_not_operation(value: BooleanVar):
|
||||
def boolean_not_operation(value: Var[bool]):
|
||||
"""Boolean NOT the boolean.
|
||||
|
||||
Args:
|
||||
@ -1081,6 +1009,18 @@ _IS_TRUE_IMPORT: ImportDict = {
|
||||
f"$/{Dirs.STATE_PATH}": [ImportVar(tag="isTrue")],
|
||||
}
|
||||
|
||||
_AT_SLICE_IMPORT: ImportDict = {
|
||||
f"$/{Dirs.STATE_PATH}": [ImportVar(tag="atSlice")],
|
||||
}
|
||||
|
||||
_AT_SLICE_OR_INDEX: ImportDict = {
|
||||
f"$/{Dirs.STATE_PATH}": [ImportVar(tag="atSliceOrIndex")],
|
||||
}
|
||||
|
||||
_RANGE_IMPORT: ImportDict = {
|
||||
f"$/{Dirs.UTILS}/helpers/range": [ImportVar(tag="range", is_default=True)],
|
||||
}
|
||||
|
||||
|
||||
@var_operation
|
||||
def boolify(value: Var):
|
||||
@ -1096,16 +1036,17 @@ def boolify(value: Var):
|
||||
js_expression=f"isTrue({value})",
|
||||
var_type=bool,
|
||||
var_data=VarData(imports=_IS_TRUE_IMPORT),
|
||||
_raw_js_function="isTrue",
|
||||
)
|
||||
|
||||
|
||||
T = TypeVar("T")
|
||||
U = TypeVar("U")
|
||||
T = TypeVar("T", bound=Any)
|
||||
U = TypeVar("U", bound=Any)
|
||||
|
||||
|
||||
@var_operation
|
||||
def ternary_operation(
|
||||
condition: BooleanVar, if_true: Var[T], if_false: Var[U]
|
||||
condition: Var[bool], if_true: Var[T], if_false: Var[U]
|
||||
) -> CustomVarOperationReturn[Union[T, U]]:
|
||||
"""Create a ternary operation.
|
||||
|
||||
@ -1117,14 +1058,125 @@ def ternary_operation(
|
||||
Returns:
|
||||
The ternary operation.
|
||||
"""
|
||||
type_value: Union[Type[T], Type[U]] = unionize(
|
||||
if_true._var_type, if_false._var_type
|
||||
)
|
||||
value: CustomVarOperationReturn[Union[T, U]] = var_operation_return(
|
||||
js_expression=f"({condition} ? {if_true} : {if_false})",
|
||||
var_type=type_value,
|
||||
type_computer=nary_type_computer(
|
||||
ReflexCallable[[bool, Any, Any], Any],
|
||||
ReflexCallable[[Any, Any], Any],
|
||||
ReflexCallable[[Any], Any],
|
||||
computer=lambda args: unionize(args[1]._var_type, args[2]._var_type),
|
||||
),
|
||||
)
|
||||
return value
|
||||
|
||||
|
||||
TUPLE_ENDS_IN_VAR = tuple[Unpack[tuple[Var[Any], ...]], Var[VAR_TYPE]]
|
||||
|
||||
TUPLE_ENDS_IN_VAR_RELAXED = tuple[
|
||||
Unpack[tuple[Var[Any] | Any, ...]], Var[VAR_TYPE] | VAR_TYPE
|
||||
]
|
||||
|
||||
|
||||
@dataclasses.dataclass(
|
||||
eq=False,
|
||||
frozen=True,
|
||||
slots=True,
|
||||
)
|
||||
class MatchOperation(CachedVarOperation, Var[VAR_TYPE]):
|
||||
"""Base class for immutable match operations."""
|
||||
|
||||
_cond: Var[bool] = dataclasses.field(
|
||||
default_factory=lambda: LiteralBooleanVar.create(True)
|
||||
)
|
||||
_cases: tuple[TUPLE_ENDS_IN_VAR[VAR_TYPE], ...] = dataclasses.field(
|
||||
default_factory=tuple
|
||||
)
|
||||
_default: Var[VAR_TYPE] = dataclasses.field( # pyright: ignore[reportAssignmentType]
|
||||
default_factory=lambda: Var.create(None)
|
||||
)
|
||||
|
||||
@cached_property_no_lock
|
||||
def _cached_var_name(self) -> str:
|
||||
"""Get the name of the var.
|
||||
|
||||
Returns:
|
||||
The name of the var.
|
||||
"""
|
||||
switch_code = f"(() => {{ switch (JSON.stringify({self._cond!s})) {{"
|
||||
|
||||
for case in self._cases:
|
||||
conditions = case[:-1]
|
||||
return_value = case[-1]
|
||||
|
||||
case_conditions = " ".join(
|
||||
[f"case JSON.stringify({condition!s}):" for condition in conditions]
|
||||
)
|
||||
case_code = f"{case_conditions} return ({return_value!s}); break;"
|
||||
switch_code += case_code
|
||||
|
||||
switch_code += f"default: return ({self._default!s}); break;"
|
||||
switch_code += "};})()"
|
||||
|
||||
return switch_code
|
||||
|
||||
@cached_property_no_lock
|
||||
def _cached_get_all_var_data(self) -> VarData | None:
|
||||
"""Get the VarData for the var.
|
||||
|
||||
Returns:
|
||||
The VarData for the var.
|
||||
"""
|
||||
return VarData.merge(
|
||||
self._cond._get_all_var_data(),
|
||||
*(
|
||||
cond_or_return._get_all_var_data()
|
||||
for case in self._cases
|
||||
for cond_or_return in case
|
||||
),
|
||||
self._default._get_all_var_data(),
|
||||
self._var_data,
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def create(
|
||||
cls,
|
||||
cond: Any,
|
||||
cases: Sequence[TUPLE_ENDS_IN_VAR_RELAXED[VAR_TYPE]],
|
||||
default: Var[VAR_TYPE] | VAR_TYPE,
|
||||
_var_data: VarData | None = None,
|
||||
_var_type: type[VAR_TYPE] | None = None,
|
||||
):
|
||||
"""Create the match operation.
|
||||
|
||||
Args:
|
||||
cond: The condition.
|
||||
cases: The cases.
|
||||
default: The default case.
|
||||
_var_data: Additional hooks and imports associated with the Var.
|
||||
_var_type: The type of the Var.
|
||||
|
||||
Returns:
|
||||
The match operation.
|
||||
"""
|
||||
cond = Var.create(cond)
|
||||
cases = cast(
|
||||
tuple[TUPLE_ENDS_IN_VAR[VAR_TYPE], ...],
|
||||
tuple(tuple(Var.create(c) for c in case) for case in cases),
|
||||
)
|
||||
|
||||
_default = cast(Var[VAR_TYPE], Var.create(default))
|
||||
var_type = _var_type or unionize(
|
||||
*(case[-1]._var_type for case in cases),
|
||||
_default._var_type,
|
||||
)
|
||||
return cls(
|
||||
_js_expr="",
|
||||
_var_data=_var_data,
|
||||
_var_type=var_type,
|
||||
_cond=cond,
|
||||
_cases=cases,
|
||||
_default=_default,
|
||||
)
|
||||
|
||||
|
||||
NUMBER_TYPES = (int, float, NumberVar)
|
||||
|
@ -10,6 +10,7 @@ from typing import (
|
||||
List,
|
||||
Mapping,
|
||||
NoReturn,
|
||||
Sequence,
|
||||
Tuple,
|
||||
Type,
|
||||
TypeVar,
|
||||
@ -27,15 +28,19 @@ from reflex.utils.types import (
|
||||
get_attribute_access_type,
|
||||
get_origin,
|
||||
safe_issubclass,
|
||||
unionize,
|
||||
)
|
||||
|
||||
from .base import (
|
||||
CachedVarOperation,
|
||||
LiteralVar,
|
||||
ReflexCallable,
|
||||
Var,
|
||||
VarData,
|
||||
cached_property_no_lock,
|
||||
figure_out_type,
|
||||
nary_type_computer,
|
||||
unary_type_computer,
|
||||
var_operation,
|
||||
var_operation_return,
|
||||
)
|
||||
@ -69,9 +74,9 @@ class ObjectVar(Var[OBJECT_TYPE], python_types=Mapping):
|
||||
) -> Type[VALUE_TYPE]: ...
|
||||
|
||||
@overload
|
||||
def _value_type(self) -> Type: ...
|
||||
def _value_type(self) -> GenericType: ...
|
||||
|
||||
def _value_type(self) -> Type:
|
||||
def _value_type(self) -> GenericType:
|
||||
"""Get the type of the values of the object.
|
||||
|
||||
Returns:
|
||||
@ -83,18 +88,18 @@ class ObjectVar(Var[OBJECT_TYPE], python_types=Mapping):
|
||||
args = get_args(self._var_type) if issubclass(fixed_type, Mapping) else ()
|
||||
return args[1] if args else Any # pyright: ignore [reportReturnType]
|
||||
|
||||
def keys(self) -> ArrayVar[List[str]]:
|
||||
def keys(self) -> ArrayVar[Sequence[str]]:
|
||||
"""Get the keys of the object.
|
||||
|
||||
Returns:
|
||||
The keys of the object.
|
||||
"""
|
||||
return object_keys_operation(self)
|
||||
return object_keys_operation(self).guess_type()
|
||||
|
||||
@overload
|
||||
def values(
|
||||
self: ObjectVar[Mapping[Any, VALUE_TYPE]],
|
||||
) -> ArrayVar[List[VALUE_TYPE]]: ...
|
||||
) -> ArrayVar[Sequence[VALUE_TYPE]]: ...
|
||||
|
||||
@overload
|
||||
def values(self) -> ArrayVar: ...
|
||||
@ -105,12 +110,12 @@ class ObjectVar(Var[OBJECT_TYPE], python_types=Mapping):
|
||||
Returns:
|
||||
The values of the object.
|
||||
"""
|
||||
return object_values_operation(self)
|
||||
return object_values_operation(self).guess_type()
|
||||
|
||||
@overload
|
||||
def entries(
|
||||
self: ObjectVar[Mapping[Any, VALUE_TYPE]],
|
||||
) -> ArrayVar[List[Tuple[str, VALUE_TYPE]]]: ...
|
||||
) -> ArrayVar[Sequence[Tuple[str, VALUE_TYPE]]]: ...
|
||||
|
||||
@overload
|
||||
def entries(self) -> ArrayVar: ...
|
||||
@ -121,7 +126,7 @@ class ObjectVar(Var[OBJECT_TYPE], python_types=Mapping):
|
||||
Returns:
|
||||
The entries of the object.
|
||||
"""
|
||||
return object_entries_operation(self)
|
||||
return object_entries_operation(self).guess_type()
|
||||
|
||||
items = entries
|
||||
|
||||
@ -167,15 +172,10 @@ class ObjectVar(Var[OBJECT_TYPE], python_types=Mapping):
|
||||
|
||||
@overload
|
||||
def __getitem__(
|
||||
self: ObjectVar[Mapping[Any, list[ARRAY_INNER_TYPE]]],
|
||||
self: ObjectVar[Mapping[Any, Sequence[ARRAY_INNER_TYPE]]]
|
||||
| ObjectVar[Mapping[Any, List[ARRAY_INNER_TYPE]]],
|
||||
key: Var | Any,
|
||||
) -> ArrayVar[list[ARRAY_INNER_TYPE]]: ...
|
||||
|
||||
@overload
|
||||
def __getitem__(
|
||||
self: ObjectVar[Mapping[Any, tuple[ARRAY_INNER_TYPE, ...]]],
|
||||
key: Var | Any,
|
||||
) -> ArrayVar[tuple[ARRAY_INNER_TYPE, ...]]: ...
|
||||
) -> ArrayVar[Sequence[ARRAY_INNER_TYPE]]: ...
|
||||
|
||||
@overload
|
||||
def __getitem__(
|
||||
@ -227,15 +227,9 @@ class ObjectVar(Var[OBJECT_TYPE], python_types=Mapping):
|
||||
|
||||
@overload
|
||||
def __getattr__(
|
||||
self: ObjectVar[Mapping[Any, list[ARRAY_INNER_TYPE]]],
|
||||
self: ObjectVar[Mapping[Any, Sequence[ARRAY_INNER_TYPE]]],
|
||||
name: str,
|
||||
) -> ArrayVar[list[ARRAY_INNER_TYPE]]: ...
|
||||
|
||||
@overload
|
||||
def __getattr__(
|
||||
self: ObjectVar[Mapping[Any, tuple[ARRAY_INNER_TYPE, ...]]],
|
||||
name: str,
|
||||
) -> ArrayVar[tuple[ARRAY_INNER_TYPE, ...]]: ...
|
||||
) -> ArrayVar[Sequence[ARRAY_INNER_TYPE]]: ...
|
||||
|
||||
@overload
|
||||
def __getattr__(
|
||||
@ -295,7 +289,7 @@ class ObjectVar(Var[OBJECT_TYPE], python_types=Mapping):
|
||||
Returns:
|
||||
The result of the check.
|
||||
"""
|
||||
return object_has_own_property_operation(self, key)
|
||||
return object_has_own_property_operation(self, key).guess_type()
|
||||
|
||||
|
||||
@dataclasses.dataclass(
|
||||
@ -310,7 +304,7 @@ class LiteralObjectVar(CachedVarOperation, ObjectVar[OBJECT_TYPE], LiteralVar):
|
||||
default_factory=dict
|
||||
)
|
||||
|
||||
def _key_type(self) -> Type:
|
||||
def _key_type(self) -> GenericType:
|
||||
"""Get the type of the keys of the object.
|
||||
|
||||
Returns:
|
||||
@ -319,7 +313,7 @@ class LiteralObjectVar(CachedVarOperation, ObjectVar[OBJECT_TYPE], LiteralVar):
|
||||
args_list = typing.get_args(self._var_type)
|
||||
return args_list[0] if args_list else Any # pyright: ignore [reportReturnType]
|
||||
|
||||
def _value_type(self) -> Type:
|
||||
def _value_type(self) -> GenericType:
|
||||
"""Get the type of the values of the object.
|
||||
|
||||
Returns:
|
||||
@ -416,7 +410,7 @@ class LiteralObjectVar(CachedVarOperation, ObjectVar[OBJECT_TYPE], LiteralVar):
|
||||
|
||||
|
||||
@var_operation
|
||||
def object_keys_operation(value: ObjectVar):
|
||||
def object_keys_operation(value: Var):
|
||||
"""Get the keys of an object.
|
||||
|
||||
Args:
|
||||
@ -428,11 +422,12 @@ def object_keys_operation(value: ObjectVar):
|
||||
return var_operation_return(
|
||||
js_expression=f"Object.keys({value})",
|
||||
var_type=List[str],
|
||||
_raw_js_function="Object.keys",
|
||||
)
|
||||
|
||||
|
||||
@var_operation
|
||||
def object_values_operation(value: ObjectVar):
|
||||
def object_values_operation(value: Var):
|
||||
"""Get the values of an object.
|
||||
|
||||
Args:
|
||||
@ -443,12 +438,17 @@ def object_values_operation(value: ObjectVar):
|
||||
"""
|
||||
return var_operation_return(
|
||||
js_expression=f"Object.values({value})",
|
||||
var_type=List[value._value_type()],
|
||||
type_computer=unary_type_computer(
|
||||
ReflexCallable[[Any], List[Any]],
|
||||
lambda x: List[x.to(ObjectVar)._value_type()],
|
||||
),
|
||||
var_type=List[Any],
|
||||
_raw_js_function="Object.values",
|
||||
)
|
||||
|
||||
|
||||
@var_operation
|
||||
def object_entries_operation(value: ObjectVar):
|
||||
def object_entries_operation(value: Var):
|
||||
"""Get the entries of an object.
|
||||
|
||||
Args:
|
||||
@ -457,14 +457,20 @@ def object_entries_operation(value: ObjectVar):
|
||||
Returns:
|
||||
The entries of the object.
|
||||
"""
|
||||
value = value.to(ObjectVar)
|
||||
return var_operation_return(
|
||||
js_expression=f"Object.entries({value})",
|
||||
var_type=List[Tuple[str, value._value_type()]],
|
||||
type_computer=unary_type_computer(
|
||||
ReflexCallable[[Any], List[Tuple[str, Any]]],
|
||||
lambda x: List[Tuple[str, x.to(ObjectVar)._value_type()]],
|
||||
),
|
||||
var_type=List[Tuple[str, Any]],
|
||||
_raw_js_function="Object.entries",
|
||||
)
|
||||
|
||||
|
||||
@var_operation
|
||||
def object_merge_operation(lhs: ObjectVar, rhs: ObjectVar):
|
||||
def object_merge_operation(lhs: Var, rhs: Var):
|
||||
"""Merge two objects.
|
||||
|
||||
Args:
|
||||
@ -476,10 +482,15 @@ def object_merge_operation(lhs: ObjectVar, rhs: ObjectVar):
|
||||
"""
|
||||
return var_operation_return(
|
||||
js_expression=f"({{...{lhs}, ...{rhs}}})",
|
||||
var_type=Mapping[
|
||||
Union[lhs._key_type(), rhs._key_type()],
|
||||
Union[lhs._value_type(), rhs._value_type()],
|
||||
],
|
||||
type_computer=nary_type_computer(
|
||||
ReflexCallable[[Any, Any], Mapping[Any, Any]],
|
||||
ReflexCallable[[Any], Mapping[Any, Any]],
|
||||
computer=lambda args: Mapping[
|
||||
unionize(*[arg.to(ObjectVar)._key_type() for arg in args]),
|
||||
unionize(*[arg.to(ObjectVar)._value_type() for arg in args]),
|
||||
],
|
||||
),
|
||||
var_type=Mapping[Any, Any],
|
||||
)
|
||||
|
||||
|
||||
@ -536,7 +547,7 @@ class ObjectItemOperation(CachedVarOperation, Var):
|
||||
|
||||
|
||||
@var_operation
|
||||
def object_has_own_property_operation(object: ObjectVar, key: Var):
|
||||
def object_has_own_property_operation(object: Var, key: Var):
|
||||
"""Check if an object has a key.
|
||||
|
||||
Args:
|
||||
|
File diff suppressed because it is too large
Load Diff
@ -28,9 +28,11 @@ def TestEventAction():
|
||||
def on_click2(self):
|
||||
self.order.append("on_click2")
|
||||
|
||||
@rx.event
|
||||
def on_click_throttle(self):
|
||||
self.order.append("on_click_throttle")
|
||||
|
||||
@rx.event
|
||||
def on_click_debounce(self):
|
||||
self.order.append("on_click_debounce")
|
||||
|
||||
|
@ -17,45 +17,68 @@ def LifespanApp():
|
||||
|
||||
import reflex as rx
|
||||
|
||||
lifespan_task_global = 0
|
||||
lifespan_context_global = 0
|
||||
def create_tasks():
|
||||
lifespan_task_global = 0
|
||||
lifespan_context_global = 0
|
||||
|
||||
@asynccontextmanager
|
||||
async def lifespan_context(app, inc: int = 1):
|
||||
global lifespan_context_global
|
||||
print(f"Lifespan context entered: {app}.")
|
||||
lifespan_context_global += inc # pyright: ignore[reportUnboundVariable]
|
||||
try:
|
||||
yield
|
||||
finally:
|
||||
print("Lifespan context exited.")
|
||||
lifespan_context_global += inc
|
||||
|
||||
async def lifespan_task(inc: int = 1):
|
||||
global lifespan_task_global
|
||||
print("Lifespan global started.")
|
||||
try:
|
||||
while True:
|
||||
lifespan_task_global += inc # pyright: ignore[reportUnboundVariable, reportPossiblyUnboundVariable]
|
||||
await asyncio.sleep(0.1)
|
||||
except asyncio.CancelledError as ce:
|
||||
print(f"Lifespan global cancelled: {ce}.")
|
||||
lifespan_task_global = 0
|
||||
|
||||
class LifespanState(rx.State):
|
||||
interval: int = 100
|
||||
|
||||
@rx.var(cache=False)
|
||||
def task_global(self) -> int:
|
||||
return lifespan_task_global
|
||||
|
||||
@rx.var(cache=False)
|
||||
def context_global(self) -> int:
|
||||
def lifespan_context_global_getter():
|
||||
return lifespan_context_global
|
||||
|
||||
@rx.event
|
||||
def tick(self, date):
|
||||
pass
|
||||
def lifespan_task_global_getter():
|
||||
return lifespan_task_global
|
||||
|
||||
@asynccontextmanager
|
||||
async def lifespan_context(app, inc: int = 1):
|
||||
nonlocal lifespan_context_global
|
||||
print(f"Lifespan context entered: {app}.")
|
||||
lifespan_context_global += inc
|
||||
try:
|
||||
yield
|
||||
finally:
|
||||
print("Lifespan context exited.")
|
||||
lifespan_context_global += inc
|
||||
|
||||
async def lifespan_task(inc: int = 1):
|
||||
nonlocal lifespan_task_global
|
||||
print("Lifespan global started.")
|
||||
try:
|
||||
while True:
|
||||
lifespan_task_global += inc
|
||||
await asyncio.sleep(0.1)
|
||||
except asyncio.CancelledError as ce:
|
||||
print(f"Lifespan global cancelled: {ce}.")
|
||||
lifespan_task_global = 0
|
||||
|
||||
class LifespanState(rx.State):
|
||||
interval: int = 100
|
||||
|
||||
@rx.var(cache=False)
|
||||
def task_global(self) -> int:
|
||||
return lifespan_task_global
|
||||
|
||||
@rx.var(cache=False)
|
||||
def context_global(self) -> int:
|
||||
return lifespan_context_global
|
||||
|
||||
@rx.event
|
||||
def tick(self, date):
|
||||
pass
|
||||
|
||||
return (
|
||||
lifespan_task,
|
||||
lifespan_context,
|
||||
LifespanState,
|
||||
lifespan_task_global_getter,
|
||||
lifespan_context_global_getter,
|
||||
)
|
||||
|
||||
(
|
||||
lifespan_task,
|
||||
lifespan_context,
|
||||
LifespanState,
|
||||
lifespan_task_global_getter,
|
||||
lifespan_context_global_getter,
|
||||
) = create_tasks()
|
||||
|
||||
def index():
|
||||
return rx.vstack(
|
||||
@ -113,13 +136,16 @@ async def test_lifespan(lifespan_app: AppHarness):
|
||||
task_global = driver.find_element(By.ID, "task_global")
|
||||
|
||||
assert context_global.text == "2"
|
||||
assert lifespan_app.app_module.lifespan_context_global == 2
|
||||
assert lifespan_app.app_module.lifespan_context_global_getter() == 2
|
||||
|
||||
original_task_global_text = task_global.text
|
||||
original_task_global_value = int(original_task_global_text)
|
||||
lifespan_app.poll_for_content(task_global, exp_not_equal=original_task_global_text)
|
||||
driver.find_element(By.ID, "toggle-tick").click() # avoid teardown errors
|
||||
assert lifespan_app.app_module.lifespan_task_global > original_task_global_value
|
||||
assert (
|
||||
lifespan_app.app_module.lifespan_task_global_getter()
|
||||
> original_task_global_value
|
||||
)
|
||||
assert int(task_global.text) > original_task_global_value
|
||||
|
||||
# Kill the backend
|
||||
@ -129,5 +155,5 @@ async def test_lifespan(lifespan_app: AppHarness):
|
||||
lifespan_app.backend_thread.join()
|
||||
|
||||
# Check that the lifespan tasks have been cancelled
|
||||
assert lifespan_app.app_module.lifespan_task_global == 0
|
||||
assert lifespan_app.app_module.lifespan_context_global == 4
|
||||
assert lifespan_app.app_module.lifespan_task_global_getter() == 0
|
||||
assert lifespan_app.app_module.lifespan_context_global_getter() == 4
|
||||
|
@ -87,7 +87,7 @@ def UploadFile():
|
||||
),
|
||||
rx.box(
|
||||
rx.foreach(
|
||||
rx.selected_files,
|
||||
rx.selected_files(),
|
||||
lambda f: rx.text(f, as_="p"),
|
||||
),
|
||||
id="selected_files",
|
||||
|
@ -61,7 +61,7 @@ def ColorToggleApp():
|
||||
rx.icon(tag="moon", size=20),
|
||||
value="dark",
|
||||
),
|
||||
on_change=set_color_mode,
|
||||
on_change=set_color_mode(),
|
||||
variant="classic",
|
||||
radius="large",
|
||||
value=color_mode,
|
||||
|
@ -25,6 +25,7 @@ def test_connection_banner():
|
||||
"react",
|
||||
"$/utils/context",
|
||||
"$/utils/state",
|
||||
"@emotion/react",
|
||||
RadixThemesComponent().library or "",
|
||||
"$/env.json",
|
||||
)
|
||||
@ -43,6 +44,7 @@ def test_connection_modal():
|
||||
"react",
|
||||
"$/utils/context",
|
||||
"$/utils/state",
|
||||
"@emotion/react",
|
||||
RadixThemesComponent().library or "",
|
||||
"$/env.json",
|
||||
)
|
||||
|
@ -3,8 +3,7 @@ from typing import Any, Union
|
||||
|
||||
import pytest
|
||||
|
||||
from reflex.components.base.fragment import Fragment
|
||||
from reflex.components.core.cond import Cond, cond
|
||||
from reflex.components.core.cond import cond
|
||||
from reflex.components.radix.themes.typography.text import Text
|
||||
from reflex.state import BaseState
|
||||
from reflex.utils.format import format_state_name
|
||||
@ -40,32 +39,23 @@ def test_validate_cond(cond_state: BaseState):
|
||||
Args:
|
||||
cond_state: A fixture.
|
||||
"""
|
||||
cond_component = cond(
|
||||
first_component = Text.create("cond is True")
|
||||
second_component = Text.create("cond is False")
|
||||
cond_var = cond(
|
||||
cond_state.value,
|
||||
Text.create("cond is True"),
|
||||
Text.create("cond is False"),
|
||||
first_component,
|
||||
second_component,
|
||||
)
|
||||
cond_dict = cond_component.render() if type(cond_component) is Fragment else {}
|
||||
assert cond_dict["name"] == "Fragment"
|
||||
|
||||
[condition] = cond_dict["children"]
|
||||
assert condition["cond_state"] == f"isTrue({cond_state.get_full_name()}.value)"
|
||||
assert isinstance(cond_var, Var)
|
||||
assert (
|
||||
str(cond_var)
|
||||
== f'({cond_state.value.bool()!s} ? (jsx(RadixThemesText, ({{ ["as"] : "p" }}), (jsx(Fragment, ({{ }}), "cond is True")))) : (jsx(RadixThemesText, ({{ ["as"] : "p" }}), (jsx(Fragment, ({{ }}), "cond is False")))))'
|
||||
)
|
||||
|
||||
# true value
|
||||
true_value = condition["true_value"]
|
||||
assert true_value["name"] == "Fragment"
|
||||
|
||||
[true_value_text] = true_value["children"]
|
||||
assert true_value_text["name"] == "RadixThemesText"
|
||||
assert true_value_text["children"][0]["contents"] == '{"cond is True"}'
|
||||
|
||||
# false value
|
||||
false_value = condition["false_value"]
|
||||
assert false_value["name"] == "Fragment"
|
||||
|
||||
[false_value_text] = false_value["children"]
|
||||
assert false_value_text["name"] == "RadixThemesText"
|
||||
assert false_value_text["children"][0]["contents"] == '{"cond is False"}'
|
||||
var_data = cond_var._get_all_var_data()
|
||||
assert var_data is not None
|
||||
assert var_data.components == (first_component, second_component)
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
@ -99,22 +89,25 @@ def test_prop_cond(c1: Any, c2: Any):
|
||||
assert str(prop_cond) == f"(true ? {c1!s} : {c2!s})"
|
||||
|
||||
|
||||
def test_cond_no_mix():
|
||||
"""Test if cond can't mix components and props."""
|
||||
with pytest.raises(ValueError):
|
||||
cond(True, LiteralVar.create("hello"), Text.create("world"))
|
||||
def test_cond_mix():
|
||||
"""Test if cond can mix components and props."""
|
||||
x = cond(True, LiteralVar.create("hello"), Text.create("world"))
|
||||
assert isinstance(x, Var)
|
||||
assert (
|
||||
str(x)
|
||||
== '(true ? "hello" : (jsx(RadixThemesText, ({ ["as"] : "p" }), (jsx(Fragment, ({ }), "world")))))'
|
||||
)
|
||||
|
||||
|
||||
def test_cond_no_else():
|
||||
"""Test if cond can be used without else."""
|
||||
# Components should support the use of cond without else
|
||||
comp = cond(True, Text.create("hello"))
|
||||
assert isinstance(comp, Fragment)
|
||||
comp = comp.children[0]
|
||||
assert isinstance(comp, Cond)
|
||||
assert comp.cond._decode() is True
|
||||
assert comp.comp1.render() == Fragment.create(Text.create("hello")).render() # pyright: ignore [reportOptionalMemberAccess]
|
||||
assert comp.comp2 == Fragment.create()
|
||||
assert isinstance(comp, Var)
|
||||
assert (
|
||||
str(comp)
|
||||
== '(true ? (jsx(RadixThemesText, ({ ["as"] : "p" }), (jsx(Fragment, ({ }), "hello")))) : (jsx(Fragment, ({ }))))'
|
||||
)
|
||||
|
||||
# Props do not support the use of cond without else
|
||||
with pytest.raises(ValueError):
|
||||
|
@ -1,4 +1,4 @@
|
||||
from typing import Dict, List, Set, Tuple, Union
|
||||
from typing import Dict, List, Sequence, Set, Tuple, Union
|
||||
|
||||
import pydantic.v1
|
||||
import pytest
|
||||
@ -6,16 +6,11 @@ import pytest
|
||||
from reflex import el
|
||||
from reflex.base import Base
|
||||
from reflex.components.component import Component
|
||||
from reflex.components.core.foreach import (
|
||||
Foreach,
|
||||
ForeachRenderError,
|
||||
ForeachVarError,
|
||||
foreach,
|
||||
)
|
||||
from reflex.components.core.foreach import ForeachVarError, foreach
|
||||
from reflex.components.radix.themes.layout.box import box
|
||||
from reflex.components.radix.themes.typography.text import text
|
||||
from reflex.state import BaseState, ComponentState
|
||||
from reflex.vars.base import Var
|
||||
from reflex.utils.exceptions import VarTypeError
|
||||
from reflex.vars.number import NumberVar
|
||||
from reflex.vars.sequence import ArrayVar
|
||||
|
||||
@ -125,7 +120,9 @@ def display_colors_set(color):
|
||||
return box(text(color))
|
||||
|
||||
|
||||
def display_nested_list_element(element: ArrayVar[List[str]], index: NumberVar[int]):
|
||||
def display_nested_list_element(
|
||||
element: ArrayVar[Sequence[str]], index: NumberVar[int]
|
||||
):
|
||||
assert element._var_type == List[str]
|
||||
assert index._var_type is int
|
||||
return box(text(element[index]))
|
||||
@ -139,143 +136,35 @@ def display_color_index_tuple(color):
|
||||
seen_index_vars = set()
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"state_var, render_fn, render_dict",
|
||||
[
|
||||
(
|
||||
ForEachState.colors_list,
|
||||
display_color,
|
||||
{
|
||||
"iterable_state": f"{ForEachState.get_full_name()}.colors_list",
|
||||
"iterable_type": "list",
|
||||
},
|
||||
),
|
||||
(
|
||||
ForEachState.colors_dict_list,
|
||||
display_color_name,
|
||||
{
|
||||
"iterable_state": f"{ForEachState.get_full_name()}.colors_dict_list",
|
||||
"iterable_type": "list",
|
||||
},
|
||||
),
|
||||
(
|
||||
ForEachState.colors_nested_dict_list,
|
||||
display_shade,
|
||||
{
|
||||
"iterable_state": f"{ForEachState.get_full_name()}.colors_nested_dict_list",
|
||||
"iterable_type": "list",
|
||||
},
|
||||
),
|
||||
(
|
||||
ForEachState.primary_color,
|
||||
display_primary_colors,
|
||||
{
|
||||
"iterable_state": f"Object.entries({ForEachState.get_full_name()}.primary_color)",
|
||||
"iterable_type": "list",
|
||||
},
|
||||
),
|
||||
(
|
||||
ForEachState.color_with_shades,
|
||||
display_color_with_shades,
|
||||
{
|
||||
"iterable_state": f"Object.entries({ForEachState.get_full_name()}.color_with_shades)",
|
||||
"iterable_type": "list",
|
||||
},
|
||||
),
|
||||
(
|
||||
ForEachState.nested_colors_with_shades,
|
||||
display_nested_color_with_shades,
|
||||
{
|
||||
"iterable_state": f"Object.entries({ForEachState.get_full_name()}.nested_colors_with_shades)",
|
||||
"iterable_type": "list",
|
||||
},
|
||||
),
|
||||
(
|
||||
ForEachState.nested_colors_with_shades,
|
||||
display_nested_color_with_shades_v2,
|
||||
{
|
||||
"iterable_state": f"Object.entries({ForEachState.get_full_name()}.nested_colors_with_shades)",
|
||||
"iterable_type": "list",
|
||||
},
|
||||
),
|
||||
(
|
||||
ForEachState.color_tuple,
|
||||
display_color_tuple,
|
||||
{
|
||||
"iterable_state": f"{ForEachState.get_full_name()}.color_tuple",
|
||||
"iterable_type": "tuple",
|
||||
},
|
||||
),
|
||||
(
|
||||
ForEachState.colors_set,
|
||||
display_colors_set,
|
||||
{
|
||||
"iterable_state": f"{ForEachState.get_full_name()}.colors_set",
|
||||
"iterable_type": "set",
|
||||
},
|
||||
),
|
||||
(
|
||||
ForEachState.nested_colors_list,
|
||||
lambda el, i: display_nested_list_element(el, i),
|
||||
{
|
||||
"iterable_state": f"{ForEachState.get_full_name()}.nested_colors_list",
|
||||
"iterable_type": "list",
|
||||
},
|
||||
),
|
||||
(
|
||||
ForEachState.color_index_tuple,
|
||||
display_color_index_tuple,
|
||||
{
|
||||
"iterable_state": f"{ForEachState.get_full_name()}.color_index_tuple",
|
||||
"iterable_type": "tuple",
|
||||
},
|
||||
),
|
||||
],
|
||||
)
|
||||
def test_foreach_render(state_var, render_fn, render_dict):
|
||||
"""Test that the foreach component renders without error.
|
||||
|
||||
Args:
|
||||
state_var: the state var.
|
||||
render_fn: The render callable
|
||||
render_dict: return dict on calling `component.render`
|
||||
"""
|
||||
component = Foreach.create(state_var, render_fn)
|
||||
|
||||
rend = component.render()
|
||||
assert rend["iterable_state"] == render_dict["iterable_state"]
|
||||
assert rend["iterable_type"] == render_dict["iterable_type"]
|
||||
|
||||
# Make sure the index vars are unique.
|
||||
arg_index = rend["arg_index"]
|
||||
assert isinstance(arg_index, Var)
|
||||
assert arg_index._js_expr not in seen_index_vars
|
||||
assert arg_index._var_type is int
|
||||
seen_index_vars.add(arg_index._js_expr)
|
||||
|
||||
|
||||
def test_foreach_bad_annotations():
|
||||
"""Test that the foreach component raises a ForeachVarError if the iterable is of type Any."""
|
||||
with pytest.raises(ForeachVarError):
|
||||
Foreach.create(
|
||||
foreach(
|
||||
ForEachState.bad_annotation_list,
|
||||
lambda sublist: Foreach.create(sublist, lambda color: text(color)),
|
||||
lambda sublist: foreach(sublist, lambda color: text(color)),
|
||||
)
|
||||
|
||||
|
||||
def test_foreach_no_param_in_signature():
|
||||
"""Test that the foreach component raises a ForeachRenderError if no parameters are passed."""
|
||||
with pytest.raises(ForeachRenderError):
|
||||
Foreach.create(
|
||||
ForEachState.colors_list,
|
||||
lambda: text("color"),
|
||||
)
|
||||
"""Test that the foreach component DOES NOT raise an error if no parameters are passed."""
|
||||
foreach(
|
||||
ForEachState.colors_list,
|
||||
lambda: text("color"),
|
||||
)
|
||||
|
||||
|
||||
def test_foreach_with_index():
|
||||
"""Test that the foreach component works with an index."""
|
||||
foreach(
|
||||
ForEachState.colors_list,
|
||||
lambda color, index: text(color, index),
|
||||
)
|
||||
|
||||
|
||||
def test_foreach_too_many_params_in_signature():
|
||||
"""Test that the foreach component raises a ForeachRenderError if too many parameters are passed."""
|
||||
with pytest.raises(ForeachRenderError):
|
||||
Foreach.create(
|
||||
with pytest.raises(VarTypeError):
|
||||
foreach(
|
||||
ForEachState.colors_list,
|
||||
lambda color, index, extra: text(color),
|
||||
)
|
||||
@ -290,13 +179,13 @@ def test_foreach_component_styles():
|
||||
)
|
||||
)
|
||||
component._add_style_recursive({box: {"color": "red"}})
|
||||
assert 'css={({ ["color"] : "red" })}' in str(component)
|
||||
assert '{ ["css"] : ({ ["color"] : "red" }) }' in str(component)
|
||||
|
||||
|
||||
def test_foreach_component_state():
|
||||
"""Test that using a component state to render in the foreach raises an error."""
|
||||
with pytest.raises(TypeError):
|
||||
Foreach.create(
|
||||
foreach(
|
||||
ForEachState.colors_list,
|
||||
ComponentStateTest.create,
|
||||
)
|
||||
@ -304,7 +193,7 @@ def test_foreach_component_state():
|
||||
|
||||
def test_foreach_default_factory():
|
||||
"""Test that the default factory is called."""
|
||||
_ = Foreach.create(
|
||||
_ = foreach(
|
||||
ForEachState.default_factory_list,
|
||||
lambda tag: text(tag.name),
|
||||
)
|
||||
|
@ -1,10 +1,10 @@
|
||||
from typing import List, Mapping, Tuple
|
||||
import re
|
||||
from typing import Tuple
|
||||
|
||||
import pytest
|
||||
|
||||
import reflex as rx
|
||||
from reflex.components.component import Component
|
||||
from reflex.components.core.match import Match
|
||||
from reflex.components.core.match import match
|
||||
from reflex.state import BaseState
|
||||
from reflex.utils.exceptions import MatchTypeError
|
||||
from reflex.vars.base import Var
|
||||
@ -18,75 +18,6 @@ class MatchState(BaseState):
|
||||
string: str = "random string"
|
||||
|
||||
|
||||
def test_match_components():
|
||||
"""Test matching cases with return values as components."""
|
||||
match_case_tuples = (
|
||||
(1, rx.text("first value")),
|
||||
(2, 3, rx.text("second value")),
|
||||
([1, 2], rx.text("third value")),
|
||||
("random", rx.text("fourth value")),
|
||||
({"foo": "bar"}, rx.text("fifth value")),
|
||||
(MatchState.num + 1, rx.text("sixth value")),
|
||||
rx.text("default value"),
|
||||
)
|
||||
match_comp = Match.create(MatchState.value, *match_case_tuples)
|
||||
|
||||
assert isinstance(match_comp, Component)
|
||||
match_dict = match_comp.render()
|
||||
assert match_dict["name"] == "Fragment"
|
||||
|
||||
[match_child] = match_dict["children"]
|
||||
|
||||
assert match_child["name"] == "match"
|
||||
assert str(match_child["cond"]) == f"{MatchState.get_name()}.value"
|
||||
|
||||
match_cases = match_child["match_cases"]
|
||||
assert len(match_cases) == 6
|
||||
|
||||
assert match_cases[0][0]._js_expr == "1"
|
||||
assert match_cases[0][0]._var_type is int
|
||||
first_return_value_render = match_cases[0][1]
|
||||
assert first_return_value_render["name"] == "RadixThemesText"
|
||||
assert first_return_value_render["children"][0]["contents"] == '{"first value"}'
|
||||
|
||||
assert match_cases[1][0]._js_expr == "2"
|
||||
assert match_cases[1][0]._var_type is int
|
||||
assert match_cases[1][1]._js_expr == "3"
|
||||
assert match_cases[1][1]._var_type is int
|
||||
second_return_value_render = match_cases[1][2]
|
||||
assert second_return_value_render["name"] == "RadixThemesText"
|
||||
assert second_return_value_render["children"][0]["contents"] == '{"second value"}'
|
||||
|
||||
assert match_cases[2][0]._js_expr == "[1, 2]"
|
||||
assert match_cases[2][0]._var_type == List[int]
|
||||
third_return_value_render = match_cases[2][1]
|
||||
assert third_return_value_render["name"] == "RadixThemesText"
|
||||
assert third_return_value_render["children"][0]["contents"] == '{"third value"}'
|
||||
|
||||
assert match_cases[3][0]._js_expr == '"random"'
|
||||
assert match_cases[3][0]._var_type is str
|
||||
fourth_return_value_render = match_cases[3][1]
|
||||
assert fourth_return_value_render["name"] == "RadixThemesText"
|
||||
assert fourth_return_value_render["children"][0]["contents"] == '{"fourth value"}'
|
||||
|
||||
assert match_cases[4][0]._js_expr == '({ ["foo"] : "bar" })'
|
||||
assert match_cases[4][0]._var_type == Mapping[str, str]
|
||||
fifth_return_value_render = match_cases[4][1]
|
||||
assert fifth_return_value_render["name"] == "RadixThemesText"
|
||||
assert fifth_return_value_render["children"][0]["contents"] == '{"fifth value"}'
|
||||
|
||||
assert match_cases[5][0]._js_expr == f"({MatchState.get_name()}.num + 1)"
|
||||
assert match_cases[5][0]._var_type is int
|
||||
fifth_return_value_render = match_cases[5][1]
|
||||
assert fifth_return_value_render["name"] == "RadixThemesText"
|
||||
assert fifth_return_value_render["children"][0]["contents"] == '{"sixth value"}'
|
||||
|
||||
default = match_child["default"]
|
||||
|
||||
assert default["name"] == "RadixThemesText"
|
||||
assert default["children"][0]["contents"] == '{"default value"}'
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"cases, expected",
|
||||
[
|
||||
@ -137,7 +68,7 @@ def test_match_vars(cases, expected):
|
||||
cases: The match cases.
|
||||
expected: The expected var full name.
|
||||
"""
|
||||
match_comp = Match.create(MatchState.value, *cases)
|
||||
match_comp = match(MatchState.value, *cases) # pyright: ignore[reportCallIssue]
|
||||
assert isinstance(match_comp, Var)
|
||||
assert str(match_comp) == expected
|
||||
|
||||
@ -146,18 +77,14 @@ def test_match_on_component_without_default():
|
||||
"""Test that matching cases with return values as components returns a Fragment
|
||||
as the default case if not provided.
|
||||
"""
|
||||
from reflex.components.base.fragment import Fragment
|
||||
|
||||
match_case_tuples = (
|
||||
(1, rx.text("first value")),
|
||||
(2, 3, rx.text("second value")),
|
||||
)
|
||||
|
||||
match_comp = Match.create(MatchState.value, *match_case_tuples)
|
||||
assert isinstance(match_comp, Component)
|
||||
default = match_comp.render()["children"][0]["default"]
|
||||
match_comp = match(MatchState.value, *match_case_tuples)
|
||||
|
||||
assert isinstance(default, dict) and default["name"] == Fragment.__name__
|
||||
assert isinstance(match_comp, Var)
|
||||
|
||||
|
||||
def test_match_on_var_no_default():
|
||||
@ -172,7 +99,7 @@ def test_match_on_var_no_default():
|
||||
ValueError,
|
||||
match="For cases with return types as Vars, a default case must be provided",
|
||||
):
|
||||
Match.create(MatchState.value, *match_case_tuples)
|
||||
match(MatchState.value, *match_case_tuples)
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
@ -205,7 +132,7 @@ def test_match_default_not_last_arg(match_case):
|
||||
ValueError,
|
||||
match="rx.match should have tuples of cases and a default case as the last argument.",
|
||||
):
|
||||
Match.create(MatchState.value, *match_case)
|
||||
match(MatchState.value, *match_case) # pyright: ignore[reportCallIssue]
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
@ -235,7 +162,7 @@ def test_match_case_tuple_elements(match_case):
|
||||
ValueError,
|
||||
match="A case tuple should have at least a match case element and a return value.",
|
||||
):
|
||||
Match.create(MatchState.value, *match_case)
|
||||
match(MatchState.value, *match_case) # pyright: ignore[reportCallIssue]
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
@ -251,8 +178,7 @@ def test_match_case_tuple_elements(match_case):
|
||||
(MatchState.num + 1, "black"),
|
||||
rx.text("default value"),
|
||||
),
|
||||
'Match cases should have the same return types. Case 3 with return value `"red"` of type '
|
||||
"<class 'reflex.vars.sequence.LiteralStringVar'> is not <class 'reflex.components.component.BaseComponent'>",
|
||||
"Match cases should have the same return types. Expected return types to be of type Component or Var[Component]. Return type of case 3 is <class 'str'>. Return type of case 4 is <class 'str'>. Return type of case 5 is <class 'str'>",
|
||||
),
|
||||
(
|
||||
(
|
||||
@ -264,8 +190,7 @@ def test_match_case_tuple_elements(match_case):
|
||||
([1, 2], rx.text("third value")),
|
||||
rx.text("default value"),
|
||||
),
|
||||
'Match cases should have the same return types. Case 3 with return value `<RadixThemesText as={"p"}> {"first value"} </RadixThemesText>` '
|
||||
"of type <class 'reflex.components.radix.themes.typography.text.Text'> is not <class 'reflex.vars.base.Var'>",
|
||||
"Match cases should have the same return types. Expected return types to be of type Component or Var[Component]. Return type of case 0 is <class 'str'>. Return type of case 1 is <class 'str'>. Return type of case 2 is <class 'str'>",
|
||||
),
|
||||
],
|
||||
)
|
||||
@ -276,8 +201,8 @@ def test_match_different_return_types(cases: Tuple, error_msg: str):
|
||||
cases: The match cases.
|
||||
error_msg: Expected error message.
|
||||
"""
|
||||
with pytest.raises(MatchTypeError, match=error_msg):
|
||||
Match.create(MatchState.value, *cases)
|
||||
with pytest.raises(MatchTypeError, match=re.escape(error_msg)):
|
||||
match(MatchState.value, *cases) # pyright: ignore[reportCallIssue]
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
@ -309,9 +234,9 @@ def test_match_multiple_default_cases(match_case):
|
||||
match_case: the cases to match.
|
||||
"""
|
||||
with pytest.raises(ValueError, match="rx.match can only have one default case."):
|
||||
Match.create(MatchState.value, *match_case)
|
||||
match(MatchState.value, *match_case) # pyright: ignore[reportCallIssue]
|
||||
|
||||
|
||||
def test_match_no_cond():
|
||||
with pytest.raises(ValueError):
|
||||
_ = Match.create(None)
|
||||
_ = match(None) # pyright: ignore[reportCallIssue]
|
||||
|
@ -1457,13 +1457,10 @@ def test_instantiate_all_components():
|
||||
# These components all have required arguments and cannot be trivially instantiated.
|
||||
untested_components = {
|
||||
"Card",
|
||||
"Cond",
|
||||
"DebounceInput",
|
||||
"Foreach",
|
||||
"FormControl",
|
||||
"Html",
|
||||
"Icon",
|
||||
"Match",
|
||||
"Markdown",
|
||||
"MultiSelect",
|
||||
"Option",
|
||||
@ -2156,14 +2153,11 @@ def test_add_style_foreach():
|
||||
page = rx.vstack(rx.foreach(Var.range(3), lambda i: StyledComponent.create(i)))
|
||||
page._add_style_recursive(Style())
|
||||
|
||||
# Expect only a single child of the foreach on the python side
|
||||
assert len(page.children[0].children) == 1
|
||||
|
||||
# Expect the style to be added to the child of the foreach
|
||||
assert 'css={({ ["color"] : "red" })}' in str(page.children[0].children[0])
|
||||
assert '({ ["css"] : ({ ["color"] : "red" }) }),' in str(page.children[0])
|
||||
|
||||
# Expect only one instance of this CSS dict in the rendered page
|
||||
assert str(page).count('css={({ ["color"] : "red" })}') == 1
|
||||
assert str(page).count('({ ["css"] : ({ ["color"] : "red" }) }),') == 1
|
||||
|
||||
|
||||
class TriggerState(rx.State):
|
||||
|
@ -2,7 +2,7 @@ from typing import Dict, List
|
||||
|
||||
import pytest
|
||||
|
||||
from reflex.components.tags import CondTag, Tag, tagless
|
||||
from reflex.components.tags import Tag, tagless
|
||||
from reflex.vars.base import LiteralVar, Var
|
||||
|
||||
|
||||
@ -105,29 +105,6 @@ def test_format_tag(tag: Tag, expected: Dict):
|
||||
assert prop_value.equals(LiteralVar.create(expected["props"][prop]))
|
||||
|
||||
|
||||
def test_format_cond_tag():
|
||||
"""Test that the cond tag dict is correct."""
|
||||
tag = CondTag(
|
||||
true_value=dict(Tag(name="h1", contents="True content")),
|
||||
false_value=dict(Tag(name="h2", contents="False content")),
|
||||
cond=Var(_js_expr="logged_in", _var_type=bool),
|
||||
)
|
||||
tag_dict = dict(tag)
|
||||
cond, true_value, false_value = (
|
||||
tag_dict["cond"],
|
||||
tag_dict["true_value"],
|
||||
tag_dict["false_value"],
|
||||
)
|
||||
assert cond._js_expr == "logged_in"
|
||||
assert cond._var_type is bool
|
||||
|
||||
assert true_value["name"] == "h1"
|
||||
assert true_value["contents"] == "True content"
|
||||
|
||||
assert false_value["name"] == "h2"
|
||||
assert false_value["contents"] == "False content"
|
||||
|
||||
|
||||
def test_tagless_string_representation():
|
||||
"""Test that the string representation of a tagless is correct."""
|
||||
tag = tagless.Tagless(contents="Hello world")
|
||||
|
@ -9,7 +9,7 @@ import unittest.mock
|
||||
import uuid
|
||||
from contextlib import nullcontext as does_not_raise
|
||||
from pathlib import Path
|
||||
from typing import Generator, List, Tuple, Type
|
||||
from typing import Generator, List, Tuple, Type, cast
|
||||
from unittest.mock import AsyncMock
|
||||
|
||||
import pytest
|
||||
@ -31,9 +31,8 @@ from reflex.app import (
|
||||
)
|
||||
from reflex.components import Component
|
||||
from reflex.components.base.fragment import Fragment
|
||||
from reflex.components.core.cond import Cond
|
||||
from reflex.components.radix.themes.typography.text import Text
|
||||
from reflex.event import Event
|
||||
from reflex.event import Event, EventHandler
|
||||
from reflex.middleware import HydrateMiddleware
|
||||
from reflex.model import Model
|
||||
from reflex.state import (
|
||||
@ -916,7 +915,7 @@ class DynamicState(BaseState):
|
||||
"""
|
||||
return self.dynamic
|
||||
|
||||
on_load_internal = OnLoadInternalState.on_load_internal.fn # pyright: ignore [reportFunctionMemberAccess]
|
||||
on_load_internal = cast(EventHandler, OnLoadInternalState.on_load_internal).fn
|
||||
|
||||
|
||||
def test_dynamic_arg_shadow(
|
||||
@ -1189,7 +1188,7 @@ async def test_process_events(mocker, token: str):
|
||||
pass
|
||||
|
||||
assert (await app.state_manager.get_state(event.substate_token)).value == 5
|
||||
assert app._postprocess.call_count == 6 # pyright: ignore [reportFunctionMemberAccess]
|
||||
assert getattr(app._postprocess, "call_count", None) == 6
|
||||
|
||||
if isinstance(app.state_manager, StateManagerRedis):
|
||||
await app.state_manager.close()
|
||||
@ -1227,10 +1226,6 @@ def test_overlay_component(
|
||||
assert app.overlay_component is not None
|
||||
generated_component = app._generate_component(app.overlay_component)
|
||||
assert isinstance(generated_component, OverlayFragment)
|
||||
assert isinstance(
|
||||
generated_component.children[0],
|
||||
Cond, # ConnectionModal is a Cond under the hood
|
||||
)
|
||||
else:
|
||||
assert app.overlay_component is not None
|
||||
assert isinstance(
|
||||
@ -1246,8 +1241,8 @@ def test_overlay_component(
|
||||
|
||||
if exp_page_child is not None:
|
||||
assert len(page.children) == 3
|
||||
children_types = (type(child) for child in page.children)
|
||||
assert exp_page_child in children_types # pyright: ignore [reportOperatorIssue]
|
||||
children_types = [type(child) for child in page.children]
|
||||
assert exp_page_child in children_types
|
||||
else:
|
||||
assert len(page.children) == 2
|
||||
|
||||
|
@ -6,6 +6,7 @@ import reflex as rx
|
||||
from reflex.constants.compiler import Hooks, Imports
|
||||
from reflex.event import (
|
||||
Event,
|
||||
EventActionsMixin,
|
||||
EventChain,
|
||||
EventHandler,
|
||||
EventSpec,
|
||||
@ -410,6 +411,7 @@ def test_event_actions():
|
||||
|
||||
def test_event_actions_on_state():
|
||||
class EventActionState(BaseState):
|
||||
@rx.event
|
||||
def handler(self):
|
||||
pass
|
||||
|
||||
@ -417,7 +419,8 @@ def test_event_actions_on_state():
|
||||
assert isinstance(handler, EventHandler)
|
||||
assert not handler.event_actions
|
||||
|
||||
sp_handler = EventActionState.handler.stop_propagation # pyright: ignore [reportFunctionMemberAccess]
|
||||
sp_handler = EventActionState.handler.stop_propagation
|
||||
assert isinstance(sp_handler, EventActionsMixin)
|
||||
assert sp_handler.event_actions == {"stopPropagation": True}
|
||||
# should NOT affect other references to the handler
|
||||
assert not handler.event_actions
|
||||
|
@ -122,9 +122,12 @@ async def test_health(
|
||||
# Call the async health function
|
||||
response = await health()
|
||||
|
||||
print(json.loads(response.body)) # pyright: ignore [reportArgumentType]
|
||||
body = response.body
|
||||
assert isinstance(body, bytes)
|
||||
|
||||
print(json.loads(body))
|
||||
print(expected_status)
|
||||
|
||||
# Verify the response content and status code
|
||||
assert response.status_code == expected_code
|
||||
assert json.loads(response.body) == expected_status # pyright: ignore [reportArgumentType]
|
||||
assert json.loads(body) == expected_status
|
||||
|
@ -18,6 +18,7 @@ from typing import (
|
||||
Dict,
|
||||
List,
|
||||
Optional,
|
||||
Sequence,
|
||||
Set,
|
||||
Tuple,
|
||||
Union,
|
||||
@ -121,8 +122,8 @@ class TestState(BaseState):
|
||||
num2: float = 3.14
|
||||
key: str
|
||||
map_key: str = "a"
|
||||
array: List[float] = [1, 2, 3.14]
|
||||
mapping: Dict[str, List[int]] = {"a": [1, 2, 3], "b": [4, 5, 6]}
|
||||
array: rx.Field[List[float]] = rx.field([1, 2, 3.14])
|
||||
mapping: rx.Field[Dict[str, List[int]]] = rx.field({"a": [1, 2, 3], "b": [4, 5, 6]})
|
||||
obj: Object = Object()
|
||||
complex: Dict[int, Object] = {1: Object(), 2: Object()}
|
||||
fig: Figure = Figure()
|
||||
@ -432,13 +433,16 @@ def test_default_setters(test_state):
|
||||
|
||||
def test_class_indexing_with_vars():
|
||||
"""Test that we can index into a state var with another var."""
|
||||
prop = TestState.array[TestState.num1] # pyright: ignore [reportCallIssue, reportArgumentType]
|
||||
assert str(prop) == f"{TestState.get_name()}.array.at({TestState.get_name()}.num1)"
|
||||
prop = TestState.array[TestState.num1]
|
||||
assert (
|
||||
str(prop)
|
||||
== f"(((...args) => (((_array, _index_or_slice) => atSliceOrIndex(_array, _index_or_slice))({TestState.get_name()}.array, ...args)))({TestState.get_name()}.num1))"
|
||||
)
|
||||
|
||||
prop = TestState.mapping["a"][TestState.num1] # pyright: ignore [reportCallIssue, reportArgumentType]
|
||||
assert (
|
||||
str(prop)
|
||||
== f'{TestState.get_name()}.mapping["a"].at({TestState.get_name()}.num1)'
|
||||
== f'(((...args) => (((_array, _index_or_slice) => atSliceOrIndex(_array, _index_or_slice))({TestState.get_name()}.mapping["a"], ...args)))({TestState.get_name()}.num1))'
|
||||
)
|
||||
|
||||
prop = TestState.mapping[TestState.map_key]
|
||||
@ -1358,6 +1362,7 @@ def test_cached_var_depends_on_event_handler(use_partial: bool):
|
||||
class HandlerState(BaseState):
|
||||
x: int = 42
|
||||
|
||||
@rx.event
|
||||
def handler(self):
|
||||
self.x = self.x + 1
|
||||
|
||||
@ -1368,11 +1373,11 @@ def test_cached_var_depends_on_event_handler(use_partial: bool):
|
||||
counter += 1
|
||||
return counter
|
||||
|
||||
assert isinstance(HandlerState.handler, EventHandler)
|
||||
if use_partial:
|
||||
HandlerState.handler = functools.partial(HandlerState.handler.fn) # pyright: ignore [reportFunctionMemberAccess]
|
||||
partial_guy = functools.partial(HandlerState.handler.fn)
|
||||
HandlerState.handler = partial_guy # pyright: ignore[reportAttributeAccessIssue]
|
||||
assert isinstance(HandlerState.handler, functools.partial)
|
||||
else:
|
||||
assert isinstance(HandlerState.handler, EventHandler)
|
||||
|
||||
s = HandlerState()
|
||||
assert (
|
||||
@ -2029,8 +2034,11 @@ async def test_state_proxy(grandchild_state: GrandchildState, mock_app: rx.App):
|
||||
|
||||
# ensure state update was emitted
|
||||
assert mock_app.event_namespace is not None
|
||||
mock_app.event_namespace.emit.assert_called_once() # pyright: ignore [reportFunctionMemberAccess]
|
||||
mcall = mock_app.event_namespace.emit.mock_calls[0] # pyright: ignore [reportFunctionMemberAccess]
|
||||
mock_app.event_namespace.emit.assert_called_once() # pyright: ignore[reportFunctionMemberAccess]
|
||||
mock_calls = getattr(mock_app.event_namespace.emit, "mock_calls", None)
|
||||
assert mock_calls is not None
|
||||
assert isinstance(mock_calls, Sequence)
|
||||
mcall = mock_calls[0]
|
||||
assert mcall.args[0] == str(SocketEvent.EVENT)
|
||||
assert mcall.args[1] == StateUpdate(
|
||||
delta={
|
||||
@ -2231,7 +2239,11 @@ async def test_background_task_no_block(mock_app: rx.App, token: str):
|
||||
assert mock_app.event_namespace is not None
|
||||
emit_mock = mock_app.event_namespace.emit
|
||||
|
||||
first_ws_message = emit_mock.mock_calls[0].args[1] # pyright: ignore [reportFunctionMemberAccess]
|
||||
mock_calls = getattr(emit_mock, "mock_calls", None)
|
||||
assert mock_calls is not None
|
||||
assert isinstance(mock_calls, Sequence)
|
||||
|
||||
first_ws_message = mock_calls[0].args[1]
|
||||
assert (
|
||||
first_ws_message.delta[BackgroundTaskState.get_full_name()].pop("router")
|
||||
is not None
|
||||
@ -2246,7 +2258,7 @@ async def test_background_task_no_block(mock_app: rx.App, token: str):
|
||||
events=[],
|
||||
final=True,
|
||||
)
|
||||
for call in emit_mock.mock_calls[1:5]: # pyright: ignore [reportFunctionMemberAccess]
|
||||
for call in mock_calls[1:5]:
|
||||
assert call.args[1] == StateUpdate(
|
||||
delta={
|
||||
BackgroundTaskState.get_full_name(): {
|
||||
@ -2256,7 +2268,7 @@ async def test_background_task_no_block(mock_app: rx.App, token: str):
|
||||
events=[],
|
||||
final=True,
|
||||
)
|
||||
assert emit_mock.mock_calls[-2].args[1] == StateUpdate( # pyright: ignore [reportFunctionMemberAccess]
|
||||
assert mock_calls[-2].args[1] == StateUpdate(
|
||||
delta={
|
||||
BackgroundTaskState.get_full_name(): {
|
||||
"order": exp_order,
|
||||
@ -2267,7 +2279,7 @@ async def test_background_task_no_block(mock_app: rx.App, token: str):
|
||||
events=[],
|
||||
final=True,
|
||||
)
|
||||
assert emit_mock.mock_calls[-1].args[1] == StateUpdate( # pyright: ignore [reportFunctionMemberAccess]
|
||||
assert mock_calls[-1].args[1] == StateUpdate(
|
||||
delta={
|
||||
BackgroundTaskState.get_full_name(): {
|
||||
"computed_order": exp_order,
|
||||
|
@ -328,7 +328,10 @@ def test_basic_operations(TestObj):
|
||||
assert str(LiteralNumberVar.create(1) ** 2) == "(1 ** 2)"
|
||||
assert str(LiteralNumberVar.create(1) & v(2)) == "(1 && 2)"
|
||||
assert str(LiteralNumberVar.create(1) | v(2)) == "(1 || 2)"
|
||||
assert str(LiteralArrayVar.create([1, 2, 3])[0]) == "[1, 2, 3].at(0)"
|
||||
assert (
|
||||
str(LiteralArrayVar.create([1, 2, 3])[0])
|
||||
== "(((...args) => (((_array, _index_or_slice) => atSliceOrIndex(_array, _index_or_slice))([1, 2, 3], ...args)))(0))"
|
||||
)
|
||||
assert (
|
||||
str(LiteralObjectVar.create({"a": 1, "b": 2})["a"])
|
||||
== '({ ["a"] : 1, ["b"] : 2 })["a"]'
|
||||
@ -350,27 +353,33 @@ def test_basic_operations(TestObj):
|
||||
str(Var(_js_expr="foo").to(ObjectVar, TestObj)._var_set_state("state").bar)
|
||||
== 'state.foo["bar"]'
|
||||
)
|
||||
assert str(abs(LiteralNumberVar.create(1))) == "Math.abs(1)"
|
||||
assert str(LiteralArrayVar.create([1, 2, 3]).length()) == "[1, 2, 3].length"
|
||||
assert str(abs(LiteralNumberVar.create(1))) == "(Math.abs(1))"
|
||||
assert (
|
||||
str(LiteralArrayVar.create([1, 2, 3]).length())
|
||||
== "(((...args) => (((_array) => _array.length)([1, 2, 3], ...args)))())"
|
||||
)
|
||||
assert (
|
||||
str(LiteralArrayVar.create([1, 2]) + LiteralArrayVar.create([3, 4]))
|
||||
== "[...[1, 2], ...[3, 4]]"
|
||||
== "(((...args) => (((_lhs, _rhs) => [..._lhs, ..._rhs])([1, 2], ...args)))([3, 4]))"
|
||||
)
|
||||
|
||||
# Tests for reverse operation
|
||||
assert (
|
||||
str(LiteralArrayVar.create([1, 2, 3]).reverse())
|
||||
== "[1, 2, 3].slice().reverse()"
|
||||
== "(((...args) => (((_array) => _array.slice().reverse())([1, 2, 3], ...args)))())"
|
||||
)
|
||||
assert (
|
||||
str(LiteralArrayVar.create(["1", "2", "3"]).reverse())
|
||||
== '["1", "2", "3"].slice().reverse()'
|
||||
== '(((...args) => (((_array) => _array.slice().reverse())(["1", "2", "3"], ...args)))())'
|
||||
)
|
||||
assert (
|
||||
str(Var(_js_expr="foo")._var_set_state("state").to(list).reverse())
|
||||
== "state.foo.slice().reverse()"
|
||||
== "(((...args) => (((_array) => _array.slice().reverse())(state.foo, ...args)))())"
|
||||
)
|
||||
assert (
|
||||
str(Var(_js_expr="foo").to(list).reverse())
|
||||
== "(((...args) => (((_array) => _array.slice().reverse())(foo, ...args)))())"
|
||||
)
|
||||
assert str(Var(_js_expr="foo").to(list).reverse()) == "foo.slice().reverse()"
|
||||
assert str(Var(_js_expr="foo", _var_type=str).js_type()) == "(typeof(foo))"
|
||||
|
||||
|
||||
@ -395,14 +404,32 @@ def test_basic_operations(TestObj):
|
||||
],
|
||||
)
|
||||
def test_list_tuple_contains(var, expected):
|
||||
assert str(var.contains(1)) == f"{expected}.includes(1)"
|
||||
assert str(var.contains("1")) == f'{expected}.includes("1")'
|
||||
assert str(var.contains(v(1))) == f"{expected}.includes(1)"
|
||||
assert str(var.contains(v("1"))) == f'{expected}.includes("1")'
|
||||
assert (
|
||||
str(var.contains(1))
|
||||
== f'(((...args) => (((_haystack, _needle, _field = "") => isTrue(_field) ? _haystack.some(obj => obj[_field] === _needle) : _haystack.some(obj => obj === _needle))({expected!s}, ...args)))(1))'
|
||||
)
|
||||
assert (
|
||||
str(var.contains("1"))
|
||||
== f'(((...args) => (((_haystack, _needle, _field = "") => isTrue(_field) ? _haystack.some(obj => obj[_field] === _needle) : _haystack.some(obj => obj === _needle))({expected!s}, ...args)))("1"))'
|
||||
)
|
||||
assert (
|
||||
str(var.contains(v(1)))
|
||||
== f'(((...args) => (((_haystack, _needle, _field = "") => isTrue(_field) ? _haystack.some(obj => obj[_field] === _needle) : _haystack.some(obj => obj === _needle))({expected!s}, ...args)))(1))'
|
||||
)
|
||||
assert (
|
||||
str(var.contains(v("1")))
|
||||
== f'(((...args) => (((_haystack, _needle, _field = "") => isTrue(_field) ? _haystack.some(obj => obj[_field] === _needle) : _haystack.some(obj => obj === _needle))({expected!s}, ...args)))("1"))'
|
||||
)
|
||||
other_state_var = Var(_js_expr="other", _var_type=str)._var_set_state("state")
|
||||
other_var = Var(_js_expr="other", _var_type=str)
|
||||
assert str(var.contains(other_state_var)) == f"{expected}.includes(state.other)"
|
||||
assert str(var.contains(other_var)) == f"{expected}.includes(other)"
|
||||
assert (
|
||||
str(var.contains(other_state_var))
|
||||
== f'(((...args) => (((_haystack, _needle, _field = "") => isTrue(_field) ? _haystack.some(obj => obj[_field] === _needle) : _haystack.some(obj => obj === _needle))({expected!s}, ...args)))(state.other))'
|
||||
)
|
||||
assert (
|
||||
str(var.contains(other_var))
|
||||
== f'(((...args) => (((_haystack, _needle, _field = "") => isTrue(_field) ? _haystack.some(obj => obj[_field] === _needle) : _haystack.some(obj => obj === _needle))({expected!s}, ...args)))(other))'
|
||||
)
|
||||
|
||||
|
||||
class Foo(rx.Base):
|
||||
@ -446,15 +473,23 @@ def test_var_types(var, var_type):
|
||||
],
|
||||
)
|
||||
def test_str_contains(var, expected):
|
||||
assert str(var.contains("1")) == f'{expected}.includes("1")'
|
||||
assert str(var.contains(v("1"))) == f'{expected}.includes("1")'
|
||||
assert (
|
||||
str(var.contains("1"))
|
||||
== f'(((...args) => (((_haystack, _needle) => _haystack.includes(_needle))({expected!s}, ...args)))("1"))'
|
||||
)
|
||||
assert (
|
||||
str(var.contains(v("1")))
|
||||
== f'(((...args) => (((_haystack, _needle) => _haystack.includes(_needle))({expected!s}, ...args)))("1"))'
|
||||
)
|
||||
other_state_var = Var(_js_expr="other")._var_set_state("state").to(str)
|
||||
other_var = Var(_js_expr="other").to(str)
|
||||
assert str(var.contains(other_state_var)) == f"{expected}.includes(state.other)"
|
||||
assert str(var.contains(other_var)) == f"{expected}.includes(other)"
|
||||
assert (
|
||||
str(var.contains("1", "hello"))
|
||||
== f'{expected}.some(obj => obj["hello"] === "1")'
|
||||
str(var.contains(other_state_var))
|
||||
== f"(((...args) => (((_haystack, _needle) => _haystack.includes(_needle))({expected!s}, ...args)))(state.other))"
|
||||
)
|
||||
assert (
|
||||
str(var.contains(other_var))
|
||||
== f"(((...args) => (((_haystack, _needle) => _haystack.includes(_needle))({expected!s}, ...args)))(other))"
|
||||
)
|
||||
|
||||
|
||||
@ -467,16 +502,17 @@ def test_str_contains(var, expected):
|
||||
],
|
||||
)
|
||||
def test_dict_contains(var, expected):
|
||||
assert str(var.contains(1)) == f"{expected}.hasOwnProperty(1)"
|
||||
assert str(var.contains("1")) == f'{expected}.hasOwnProperty("1")'
|
||||
assert str(var.contains(v(1))) == f"{expected}.hasOwnProperty(1)"
|
||||
assert str(var.contains(v("1"))) == f'{expected}.hasOwnProperty("1")'
|
||||
assert str(var.contains(1)) == f"{expected!s}.hasOwnProperty(1)"
|
||||
assert str(var.contains("1")) == f'{expected!s}.hasOwnProperty("1")'
|
||||
assert str(var.contains(v(1))) == f"{expected!s}.hasOwnProperty(1)"
|
||||
assert str(var.contains(v("1"))) == f'{expected!s}.hasOwnProperty("1")'
|
||||
other_state_var = Var(_js_expr="other")._var_set_state("state").to(str)
|
||||
other_var = Var(_js_expr="other").to(str)
|
||||
assert (
|
||||
str(var.contains(other_state_var)) == f"{expected}.hasOwnProperty(state.other)"
|
||||
str(var.contains(other_state_var))
|
||||
== f"{expected!s}.hasOwnProperty(state.other)"
|
||||
)
|
||||
assert str(var.contains(other_var)) == f"{expected}.hasOwnProperty(other)"
|
||||
assert str(var.contains(other_var)) == f"{expected!s}.hasOwnProperty(other)"
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
@ -484,7 +520,6 @@ def test_dict_contains(var, expected):
|
||||
[
|
||||
Var(_js_expr="list", _var_type=List[int]).guess_type(),
|
||||
Var(_js_expr="tuple", _var_type=Tuple[int, int]).guess_type(),
|
||||
Var(_js_expr="str", _var_type=str).guess_type(),
|
||||
],
|
||||
)
|
||||
def test_var_indexing_lists(var):
|
||||
@ -494,11 +529,20 @@ def test_var_indexing_lists(var):
|
||||
var : The str, list or tuple base var.
|
||||
"""
|
||||
# Test basic indexing.
|
||||
assert str(var[0]) == f"{var._js_expr}.at(0)"
|
||||
assert str(var[1]) == f"{var._js_expr}.at(1)"
|
||||
assert (
|
||||
str(var[0])
|
||||
== f"(((...args) => (((_array, _index_or_slice) => atSliceOrIndex(_array, _index_or_slice))({var!s}, ...args)))(0))"
|
||||
)
|
||||
assert (
|
||||
str(var[1])
|
||||
== f"(((...args) => (((_array, _index_or_slice) => atSliceOrIndex(_array, _index_or_slice))({var!s}, ...args)))(1))"
|
||||
)
|
||||
|
||||
# Test negative indexing.
|
||||
assert str(var[-1]) == f"{var._js_expr}.at(-1)"
|
||||
assert (
|
||||
str(var[-1])
|
||||
== f"(((...args) => (((_array, _index_or_slice) => atSliceOrIndex(_array, _index_or_slice))({var!s}, ...args)))(-1))"
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
@ -532,11 +576,20 @@ def test_var_indexing_str():
|
||||
assert str_var[0]._var_type is str
|
||||
|
||||
# Test basic indexing.
|
||||
assert str(str_var[0]) == "str.at(0)"
|
||||
assert str(str_var[1]) == "str.at(1)"
|
||||
assert (
|
||||
str(str_var[0])
|
||||
== "(((...args) => (((_string, _index_or_slice) => Array.prototype.join.apply(atSliceOrIndex(_string, _index_or_slice), ['']))(str, ...args)))(0))"
|
||||
)
|
||||
assert (
|
||||
str(str_var[1])
|
||||
== "(((...args) => (((_string, _index_or_slice) => Array.prototype.join.apply(atSliceOrIndex(_string, _index_or_slice), ['']))(str, ...args)))(1))"
|
||||
)
|
||||
|
||||
# Test negative indexing.
|
||||
assert str(str_var[-1]) == "str.at(-1)"
|
||||
assert (
|
||||
str(str_var[-1])
|
||||
== "(((...args) => (((_string, _index_or_slice) => Array.prototype.join.apply(atSliceOrIndex(_string, _index_or_slice), ['']))(str, ...args)))(-1))"
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
@ -651,9 +704,18 @@ def test_var_list_slicing(var):
|
||||
Args:
|
||||
var : The str, list or tuple base var.
|
||||
"""
|
||||
assert str(var[:1]) == f"{var._js_expr}.slice(undefined, 1)"
|
||||
assert str(var[1:]) == f"{var._js_expr}.slice(1, undefined)"
|
||||
assert str(var[:]) == f"{var._js_expr}.slice(undefined, undefined)"
|
||||
assert (
|
||||
str(var[:1])
|
||||
== f"(((...args) => (((_array, _index_or_slice) => atSliceOrIndex(_array, _index_or_slice))({var!s}, ...args)))([null, 1, null]))"
|
||||
)
|
||||
assert (
|
||||
str(var[1:])
|
||||
== f"(((...args) => (((_array, _index_or_slice) => atSliceOrIndex(_array, _index_or_slice))({var!s}, ...args)))([1, null, null]))"
|
||||
)
|
||||
assert (
|
||||
str(var[:])
|
||||
== f"(((...args) => (((_array, _index_or_slice) => atSliceOrIndex(_array, _index_or_slice))({var!s}, ...args)))([null, null, null]))"
|
||||
)
|
||||
|
||||
|
||||
def test_str_var_slicing():
|
||||
@ -665,16 +727,40 @@ def test_str_var_slicing():
|
||||
assert str_var[:1]._var_type is str
|
||||
|
||||
# Test basic slicing.
|
||||
assert str(str_var[:1]) == 'str.split("").slice(undefined, 1).join("")'
|
||||
assert str(str_var[1:]) == 'str.split("").slice(1, undefined).join("")'
|
||||
assert str(str_var[:]) == 'str.split("").slice(undefined, undefined).join("")'
|
||||
assert str(str_var[1:2]) == 'str.split("").slice(1, 2).join("")'
|
||||
assert (
|
||||
str(str_var[:1])
|
||||
== "(((...args) => (((_string, _index_or_slice) => Array.prototype.join.apply(atSliceOrIndex(_string, _index_or_slice), ['']))(str, ...args)))([null, 1, null]))"
|
||||
)
|
||||
assert (
|
||||
str(str_var[1:])
|
||||
== "(((...args) => (((_string, _index_or_slice) => Array.prototype.join.apply(atSliceOrIndex(_string, _index_or_slice), ['']))(str, ...args)))([1, null, null]))"
|
||||
)
|
||||
assert (
|
||||
str(str_var[:])
|
||||
== "(((...args) => (((_string, _index_or_slice) => Array.prototype.join.apply(atSliceOrIndex(_string, _index_or_slice), ['']))(str, ...args)))([null, null, null]))"
|
||||
)
|
||||
assert (
|
||||
str(str_var[1:2])
|
||||
== "(((...args) => (((_string, _index_or_slice) => Array.prototype.join.apply(atSliceOrIndex(_string, _index_or_slice), ['']))(str, ...args)))([1, 2, null]))"
|
||||
)
|
||||
|
||||
# Test negative slicing.
|
||||
assert str(str_var[:-1]) == 'str.split("").slice(undefined, -1).join("")'
|
||||
assert str(str_var[-1:]) == 'str.split("").slice(-1, undefined).join("")'
|
||||
assert str(str_var[:-2]) == 'str.split("").slice(undefined, -2).join("")'
|
||||
assert str(str_var[-2:]) == 'str.split("").slice(-2, undefined).join("")'
|
||||
assert (
|
||||
str(str_var[:-1])
|
||||
== "(((...args) => (((_string, _index_or_slice) => Array.prototype.join.apply(atSliceOrIndex(_string, _index_or_slice), ['']))(str, ...args)))([null, -1, null]))"
|
||||
)
|
||||
assert (
|
||||
str(str_var[-1:])
|
||||
== "(((...args) => (((_string, _index_or_slice) => Array.prototype.join.apply(atSliceOrIndex(_string, _index_or_slice), ['']))(str, ...args)))([-1, null, null]))"
|
||||
)
|
||||
assert (
|
||||
str(str_var[:-2])
|
||||
== "(((...args) => (((_string, _index_or_slice) => Array.prototype.join.apply(atSliceOrIndex(_string, _index_or_slice), ['']))(str, ...args)))([null, -2, null]))"
|
||||
)
|
||||
assert (
|
||||
str(str_var[-2:])
|
||||
== "(((...args) => (((_string, _index_or_slice) => Array.prototype.join.apply(atSliceOrIndex(_string, _index_or_slice), ['']))(str, ...args)))([-2, null, null]))"
|
||||
)
|
||||
|
||||
|
||||
def test_dict_indexing():
|
||||
@ -963,11 +1049,11 @@ def test_function_var():
|
||||
|
||||
def test_var_operation():
|
||||
@var_operation
|
||||
def add(a: Union[NumberVar, int], b: Union[NumberVar, int]):
|
||||
def add(a: Var[int], b: Var[int]):
|
||||
return var_operation_return(js_expression=f"({a} + {b})", var_type=int)
|
||||
|
||||
assert str(add(1, 2)) == "(1 + 2)"
|
||||
assert str(add(a=4, b=-9)) == "(4 + -9)"
|
||||
assert str(add(4, -9)) == "(4 + -9)"
|
||||
|
||||
five = LiteralNumberVar.create(5)
|
||||
seven = add(2, five)
|
||||
@ -978,13 +1064,29 @@ def test_var_operation():
|
||||
def test_string_operations():
|
||||
basic_string = LiteralStringVar.create("Hello, World!")
|
||||
|
||||
assert str(basic_string.length()) == '"Hello, World!".split("").length'
|
||||
assert str(basic_string.lower()) == '"Hello, World!".toLowerCase()'
|
||||
assert str(basic_string.upper()) == '"Hello, World!".toUpperCase()'
|
||||
assert str(basic_string.strip()) == '"Hello, World!".trim()'
|
||||
assert str(basic_string.contains("World")) == '"Hello, World!".includes("World")'
|
||||
assert (
|
||||
str(basic_string.split(" ").join(",")) == '"Hello, World!".split(" ").join(",")'
|
||||
str(basic_string.length())
|
||||
== '(((...args) => (((...arg) => (((_array) => _array.length)((((_string, _sep = "") => isTrue(_sep) ? _string.split(_sep) : [..._string])(...args)))))("Hello, World!", ...args)))())'
|
||||
)
|
||||
assert (
|
||||
str(basic_string.lower())
|
||||
== '(((...args) => (((_string) => String.prototype.toLowerCase.apply(_string))("Hello, World!", ...args)))())'
|
||||
)
|
||||
assert (
|
||||
str(basic_string.upper())
|
||||
== '(((...args) => (((_string) => String.prototype.toUpperCase.apply(_string))("Hello, World!", ...args)))())'
|
||||
)
|
||||
assert (
|
||||
str(basic_string.strip())
|
||||
== '(((...args) => (((_string) => String.prototype.trim.apply(_string))("Hello, World!", ...args)))())'
|
||||
)
|
||||
assert (
|
||||
str(basic_string.contains("World"))
|
||||
== '(((...args) => (((_haystack, _needle) => _haystack.includes(_needle))("Hello, World!", ...args)))("World"))'
|
||||
)
|
||||
assert (
|
||||
str(basic_string.split(" ").join(","))
|
||||
== '(((...args) => (((_array, _sep = "") => Array.prototype.join.apply(_array,[_sep]))((((...args) => (((_string, _sep = "") => isTrue(_sep) ? _string.split(_sep) : [..._string])("Hello, World!", ...args)))(" ")), ...args)))(","))'
|
||||
)
|
||||
|
||||
|
||||
@ -1004,14 +1106,14 @@ def test_all_number_operations():
|
||||
|
||||
assert (
|
||||
str(even_more_complicated_number)
|
||||
== "!(isTrue((Math.abs(Math.floor(((Math.floor(((-((-5.4 + 1)) * 2) / 3) / 2) % 3) ** 2))) || (2 && Math.round(((Math.floor(((-((-5.4 + 1)) * 2) / 3) / 2) % 3) ** 2))))))"
|
||||
== "!((isTrue(((Math.abs((Math.floor(((Math.floor(((-((-5.4 + 1)) * 2) / 3) / 2) % 3) ** 2))))) || (2 && Math.round(((Math.floor(((-((-5.4 + 1)) * 2) / 3) / 2) % 3) ** 2)))))))"
|
||||
)
|
||||
|
||||
assert str(LiteralNumberVar.create(5) > False) == "(5 > 0)"
|
||||
assert str(LiteralBooleanVar.create(False) < 5) == "(Number(false) < 5)"
|
||||
assert str(LiteralBooleanVar.create(False) < 5) == "((Number(false)) < 5)"
|
||||
assert (
|
||||
str(LiteralBooleanVar.create(False) < LiteralBooleanVar.create(True))
|
||||
== "(Number(false) < Number(true))"
|
||||
== "((Number(false)) < (Number(true)))"
|
||||
)
|
||||
|
||||
|
||||
@ -1020,10 +1122,10 @@ def test_all_number_operations():
|
||||
[
|
||||
(Var.create(False), "false"),
|
||||
(Var.create(True), "true"),
|
||||
(Var.create("false"), 'isTrue("false")'),
|
||||
(Var.create([1, 2, 3]), "isTrue([1, 2, 3])"),
|
||||
(Var.create({"a": 1, "b": 2}), 'isTrue(({ ["a"] : 1, ["b"] : 2 }))'),
|
||||
(Var("mysterious_var"), "isTrue(mysterious_var)"),
|
||||
(Var.create("false"), '(isTrue("false"))'),
|
||||
(Var.create([1, 2, 3]), "(isTrue([1, 2, 3]))"),
|
||||
(Var.create({"a": 1, "b": 2}), '(isTrue(({ ["a"] : 1, ["b"] : 2 })))'),
|
||||
(Var("mysterious_var"), "(isTrue(mysterious_var))"),
|
||||
],
|
||||
)
|
||||
def test_boolify_operations(var, expected):
|
||||
@ -1032,18 +1134,30 @@ def test_boolify_operations(var, expected):
|
||||
|
||||
def test_index_operation():
|
||||
array_var = LiteralArrayVar.create([1, 2, 3, 4, 5])
|
||||
assert str(array_var[0]) == "[1, 2, 3, 4, 5].at(0)"
|
||||
assert str(array_var[1:2]) == "[1, 2, 3, 4, 5].slice(1, 2)"
|
||||
assert (
|
||||
str(array_var[0])
|
||||
== "(((...args) => (((_array, _index_or_slice) => atSliceOrIndex(_array, _index_or_slice))([1, 2, 3, 4, 5], ...args)))(0))"
|
||||
)
|
||||
assert (
|
||||
str(array_var[1:2])
|
||||
== "(((...args) => (((_array, _index_or_slice) => atSliceOrIndex(_array, _index_or_slice))([1, 2, 3, 4, 5], ...args)))([1, 2, null]))"
|
||||
)
|
||||
assert (
|
||||
str(array_var[1:4:2])
|
||||
== "[1, 2, 3, 4, 5].slice(1, 4).filter((_, i) => i % 2 === 0)"
|
||||
== "(((...args) => (((_array, _index_or_slice) => atSliceOrIndex(_array, _index_or_slice))([1, 2, 3, 4, 5], ...args)))([1, 4, 2]))"
|
||||
)
|
||||
assert (
|
||||
str(array_var[::-1])
|
||||
== "[1, 2, 3, 4, 5].slice(0, [1, 2, 3, 4, 5].length).slice().reverse().slice(undefined, undefined).filter((_, i) => i % 1 === 0)"
|
||||
== "(((...args) => (((_array, _index_or_slice) => atSliceOrIndex(_array, _index_or_slice))([1, 2, 3, 4, 5], ...args)))([null, null, -1]))"
|
||||
)
|
||||
assert (
|
||||
str(array_var.reverse())
|
||||
== "(((...args) => (((_array) => _array.slice().reverse())([1, 2, 3, 4, 5], ...args)))())"
|
||||
)
|
||||
assert (
|
||||
str(array_var[0].to(NumberVar) + 9)
|
||||
== "((((...args) => (((_array, _index_or_slice) => atSliceOrIndex(_array, _index_or_slice))([1, 2, 3, 4, 5], ...args)))(0)) + 9)"
|
||||
)
|
||||
assert str(array_var.reverse()) == "[1, 2, 3, 4, 5].slice().reverse()"
|
||||
assert str(array_var[0].to(NumberVar) + 9) == "([1, 2, 3, 4, 5].at(0) + 9)"
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
@ -1065,40 +1179,37 @@ def test_inf_and_nan(var, expected_js):
|
||||
def test_array_operations():
|
||||
array_var = LiteralArrayVar.create([1, 2, 3, 4, 5])
|
||||
|
||||
assert str(array_var.length()) == "[1, 2, 3, 4, 5].length"
|
||||
assert str(array_var.contains(3)) == "[1, 2, 3, 4, 5].includes(3)"
|
||||
assert str(array_var.reverse()) == "[1, 2, 3, 4, 5].slice().reverse()"
|
||||
assert (
|
||||
str(ArrayVar.range(10))
|
||||
== "Array.from({ length: Math.ceil((10 - 0) / 1) }, (_, i) => 0 + i * 1)"
|
||||
str(array_var.length())
|
||||
== "(((...args) => (((_array) => _array.length)([1, 2, 3, 4, 5], ...args)))())"
|
||||
)
|
||||
assert (
|
||||
str(ArrayVar.range(1, 10))
|
||||
== "Array.from({ length: Math.ceil((10 - 1) / 1) }, (_, i) => 1 + i * 1)"
|
||||
str(array_var.contains(3))
|
||||
== '(((...args) => (((_haystack, _needle, _field = "") => isTrue(_field) ? _haystack.some(obj => obj[_field] === _needle) : _haystack.some(obj => obj === _needle))([1, 2, 3, 4, 5], ...args)))(3))'
|
||||
)
|
||||
assert (
|
||||
str(ArrayVar.range(1, 10, 2))
|
||||
== "Array.from({ length: Math.ceil((10 - 1) / 2) }, (_, i) => 1 + i * 2)"
|
||||
)
|
||||
assert (
|
||||
str(ArrayVar.range(1, 10, -1))
|
||||
== "Array.from({ length: Math.ceil((10 - 1) / -1) }, (_, i) => 1 + i * -1)"
|
||||
str(array_var.reverse())
|
||||
== "(((...args) => (((_array) => _array.slice().reverse())([1, 2, 3, 4, 5], ...args)))())"
|
||||
)
|
||||
assert str(ArrayVar.range(10)) == "[...range(10, null, 1)]"
|
||||
assert str(ArrayVar.range(1, 10)) == "[...range(1, 10, 1)]"
|
||||
assert str(ArrayVar.range(1, 10, 2)) == "[...range(1, 10, 2)]"
|
||||
assert str(ArrayVar.range(1, 10, -1)) == "[...range(1, 10, -1)]"
|
||||
|
||||
|
||||
def test_object_operations():
|
||||
object_var = LiteralObjectVar.create({"a": 1, "b": 2, "c": 3})
|
||||
|
||||
assert (
|
||||
str(object_var.keys()) == 'Object.keys(({ ["a"] : 1, ["b"] : 2, ["c"] : 3 }))'
|
||||
str(object_var.keys()) == '(Object.keys(({ ["a"] : 1, ["b"] : 2, ["c"] : 3 })))'
|
||||
)
|
||||
assert (
|
||||
str(object_var.values())
|
||||
== 'Object.values(({ ["a"] : 1, ["b"] : 2, ["c"] : 3 }))'
|
||||
== '(Object.values(({ ["a"] : 1, ["b"] : 2, ["c"] : 3 })))'
|
||||
)
|
||||
assert (
|
||||
str(object_var.entries())
|
||||
== 'Object.entries(({ ["a"] : 1, ["b"] : 2, ["c"] : 3 }))'
|
||||
== '(Object.entries(({ ["a"] : 1, ["b"] : 2, ["c"] : 3 })))'
|
||||
)
|
||||
assert str(object_var.a) == '({ ["a"] : 1, ["b"] : 2, ["c"] : 3 })["a"]'
|
||||
assert str(object_var["a"]) == '({ ["a"] : 1, ["b"] : 2, ["c"] : 3 })["a"]'
|
||||
@ -1139,12 +1250,12 @@ def test_type_chains():
|
||||
List[int],
|
||||
)
|
||||
assert (
|
||||
str(object_var.keys()[0].upper())
|
||||
== 'Object.keys(({ ["a"] : 1, ["b"] : 2, ["c"] : 3 })).at(0).toUpperCase()'
|
||||
str(object_var.keys()[0].upper()) # pyright: ignore [reportAttributeAccessIssue]
|
||||
== '(((...args) => (((_string) => String.prototype.toUpperCase.apply(_string))((((...args) => (((_array, _index_or_slice) => atSliceOrIndex(_array, _index_or_slice))((Object.keys(({ ["a"] : 1, ["b"] : 2, ["c"] : 3 }))), ...args)))(0)), ...args)))())'
|
||||
)
|
||||
assert (
|
||||
str(object_var.entries()[1][1] - 1)
|
||||
== '(Object.entries(({ ["a"] : 1, ["b"] : 2, ["c"] : 3 })).at(1).at(1) - 1)'
|
||||
str(object_var.entries()[1][1] - 1) # pyright: ignore [reportCallIssue, reportOperatorIssue]
|
||||
== '((((...args) => (((_array, _index_or_slice) => atSliceOrIndex(_array, _index_or_slice))((((...args) => (((_array, _index_or_slice) => atSliceOrIndex(_array, _index_or_slice))((Object.entries(({ ["a"] : 1, ["b"] : 2, ["c"] : 3 }))), ...args)))(1)), ...args)))(1)) - 1)'
|
||||
)
|
||||
assert (
|
||||
str(object_var["c"] + object_var["b"]) # pyright: ignore [reportCallIssue, reportOperatorIssue]
|
||||
@ -1153,10 +1264,14 @@ def test_type_chains():
|
||||
|
||||
|
||||
def test_nested_dict():
|
||||
arr = LiteralArrayVar.create([{"bar": ["foo", "bar"]}], List[Dict[str, List[str]]])
|
||||
arr = Var.create([{"bar": ["foo", "bar"]}]).to(List[Dict[str, List[str]]])
|
||||
first_dict = arr.at(0)
|
||||
bar_element = first_dict["bar"]
|
||||
first_bar_element = bar_element[0]
|
||||
|
||||
assert (
|
||||
str(arr[0]["bar"][0]) == '[({ ["bar"] : ["foo", "bar"] })].at(0)["bar"].at(0)' # pyright: ignore [reportIndexIssue]
|
||||
str(first_bar_element)
|
||||
== '(((...args) => (((_array, _index_or_slice) => atSliceOrIndex(_array, _index_or_slice))((((...args) => (((_array, _index) => _array.at(_index))([({ ["bar"] : ["foo", "bar"] })], ...args)))(0))["bar"], ...args)))(0))' # pyright: ignore [reportIndexIssue]
|
||||
)
|
||||
|
||||
|
||||
@ -1331,9 +1446,9 @@ def test_unsupported_types_for_reverse(var):
|
||||
Args:
|
||||
var: The base var.
|
||||
"""
|
||||
with pytest.raises(TypeError) as err:
|
||||
with pytest.raises(AttributeError) as err:
|
||||
var.reverse()
|
||||
assert err.value.args[0] == "Cannot reverse non-list var."
|
||||
assert err.value.args[0] == "'Var' object has no attribute 'reverse'"
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
@ -1351,12 +1466,9 @@ def test_unsupported_types_for_contains(var: Var):
|
||||
Args:
|
||||
var: The base var.
|
||||
"""
|
||||
with pytest.raises(TypeError) as err:
|
||||
with pytest.raises(AttributeError) as err:
|
||||
assert var.contains(1) # pyright: ignore [reportAttributeAccessIssue]
|
||||
assert (
|
||||
err.value.args[0]
|
||||
== f"Var of type {var._var_type} does not support contains check."
|
||||
)
|
||||
assert err.value.args[0] == "'Var' object has no attribute 'contains'"
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
@ -1376,7 +1488,7 @@ def test_unsupported_types_for_string_contains(other):
|
||||
assert Var(_js_expr="var").to(str).contains(other)
|
||||
assert (
|
||||
err.value.args[0]
|
||||
== f"Unsupported Operand type(s) for contains: ToStringOperation, {type(other).__name__}"
|
||||
== f"Invalid argument other provided to argument 0 in var operation. Expected <class 'str'> but got {other._var_type}."
|
||||
)
|
||||
|
||||
|
||||
@ -1608,17 +1720,12 @@ def test_valid_var_operations(operand1_var: Var, operand2_var, operators: List[s
|
||||
LiteralVar.create([10, 20]),
|
||||
LiteralVar.create("5"),
|
||||
[
|
||||
"+",
|
||||
"-",
|
||||
"/",
|
||||
"//",
|
||||
"*",
|
||||
"%",
|
||||
"**",
|
||||
">",
|
||||
"<",
|
||||
"<=",
|
||||
">=",
|
||||
"^",
|
||||
"<<",
|
||||
">>",
|
||||
|
@ -2,7 +2,7 @@ from __future__ import annotations
|
||||
|
||||
import datetime
|
||||
import json
|
||||
from typing import Any, List
|
||||
from typing import Any
|
||||
|
||||
import plotly.graph_objects as go
|
||||
import pytest
|
||||
@ -98,60 +98,6 @@ def test_wrap(text: str, open: str, expected: str, check_first: bool, num: int):
|
||||
assert format.wrap(text, open, check_first=check_first, num=num) == expected
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"string,expected_output",
|
||||
[
|
||||
("This is a random string", "This is a random string"),
|
||||
(
|
||||
"This is a random string with `backticks`",
|
||||
"This is a random string with \\`backticks\\`",
|
||||
),
|
||||
(
|
||||
"This is a random string with `backticks`",
|
||||
"This is a random string with \\`backticks\\`",
|
||||
),
|
||||
(
|
||||
"This is a string with ${someValue[`string interpolation`]} unescaped",
|
||||
"This is a string with ${someValue[`string interpolation`]} unescaped",
|
||||
),
|
||||
(
|
||||
"This is a string with `backticks` and ${someValue[`string interpolation`]} unescaped",
|
||||
"This is a string with \\`backticks\\` and ${someValue[`string interpolation`]} unescaped",
|
||||
),
|
||||
(
|
||||
"This is a string with `backticks`, ${someValue[`the first string interpolation`]} and ${someValue[`the second`]}",
|
||||
"This is a string with \\`backticks\\`, ${someValue[`the first string interpolation`]} and ${someValue[`the second`]}",
|
||||
),
|
||||
],
|
||||
)
|
||||
def test_escape_js_string(string, expected_output):
|
||||
assert format._escape_js_string(string) == expected_output
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"text,indent_level,expected",
|
||||
[
|
||||
("", 2, ""),
|
||||
("hello", 2, "hello"),
|
||||
("hello\nworld", 2, " hello\n world\n"),
|
||||
("hello\nworld", 4, " hello\n world\n"),
|
||||
(" hello\n world", 2, " hello\n world\n"),
|
||||
],
|
||||
)
|
||||
def test_indent(text: str, indent_level: int, expected: str, windows_platform: bool):
|
||||
"""Test indenting a string.
|
||||
|
||||
Args:
|
||||
text: The text to indent.
|
||||
indent_level: The number of spaces to indent by.
|
||||
expected: The expected output string.
|
||||
windows_platform: Whether the system is windows.
|
||||
"""
|
||||
assert format.indent(text, indent_level) == (
|
||||
expected.replace("\n", "\r\n") if windows_platform else expected
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"input,output",
|
||||
[
|
||||
@ -252,25 +198,6 @@ def test_to_kebab_case(input: str, output: str):
|
||||
assert format.to_kebab_case(input) == output
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"input,output",
|
||||
[
|
||||
("", "{``}"),
|
||||
("hello", "{`hello`}"),
|
||||
("hello world", "{`hello world`}"),
|
||||
("hello=`world`", "{`hello=\\`world\\``}"),
|
||||
],
|
||||
)
|
||||
def test_format_string(input: str, output: str):
|
||||
"""Test formatting the input as JS string literal.
|
||||
|
||||
Args:
|
||||
input: the input string.
|
||||
output: the output string.
|
||||
"""
|
||||
assert format.format_string(input) == output
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"input,output",
|
||||
[
|
||||
@ -310,45 +237,6 @@ def test_format_route(route: str, format_case: bool, expected: bool):
|
||||
assert format.format_route(route, format_case=format_case) == expected
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"condition, match_cases, default,expected",
|
||||
[
|
||||
(
|
||||
"state__state.value",
|
||||
[
|
||||
[LiteralVar.create(1), LiteralVar.create("red")],
|
||||
[LiteralVar.create(2), LiteralVar.create(3), LiteralVar.create("blue")],
|
||||
[TestState.mapping, TestState.num1],
|
||||
[
|
||||
LiteralVar.create(f"{TestState.map_key}-key"),
|
||||
LiteralVar.create("return-key"),
|
||||
],
|
||||
],
|
||||
LiteralVar.create("yellow"),
|
||||
'(() => { switch (JSON.stringify(state__state.value)) {case JSON.stringify(1): return ("red"); break;case JSON.stringify(2): case JSON.stringify(3): '
|
||||
f'return ("blue"); break;case JSON.stringify({TestState.get_full_name()}.mapping): return '
|
||||
f'({TestState.get_full_name()}.num1); break;case JSON.stringify(({TestState.get_full_name()}.map_key+"-key")): return ("return-key");'
|
||||
' break;default: return ("yellow"); break;};})()',
|
||||
)
|
||||
],
|
||||
)
|
||||
def test_format_match(
|
||||
condition: str,
|
||||
match_cases: List[List[Var]],
|
||||
default: Var,
|
||||
expected: str,
|
||||
):
|
||||
"""Test formatting a match statement.
|
||||
|
||||
Args:
|
||||
condition: The condition to match.
|
||||
match_cases: List of match cases to be matched.
|
||||
default: Catchall case for the match statement.
|
||||
expected: The expected string output.
|
||||
"""
|
||||
assert format.format_match(condition, match_cases, default) == expected
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"prop,formatted",
|
||||
[
|
||||
|
@ -2,7 +2,7 @@ import os
|
||||
import typing
|
||||
from functools import cached_property
|
||||
from pathlib import Path
|
||||
from typing import Any, ClassVar, Dict, List, Literal, Type, Union
|
||||
from typing import Any, ClassVar, Dict, List, Literal, Sequence, Tuple, Type, Union
|
||||
|
||||
import pytest
|
||||
import typer
|
||||
@ -109,10 +109,20 @@ def test_is_generic_alias(cls: type, expected: bool):
|
||||
(Dict[str, str], dict[str, str], True),
|
||||
(Dict[str, str], dict[str, Any], True),
|
||||
(Dict[str, Any], dict[str, Any], True),
|
||||
(List[int], Sequence[int], True),
|
||||
(List[str], Sequence[int], False),
|
||||
(Tuple[int], Sequence[int], True),
|
||||
(Tuple[int, str], Sequence[int], False),
|
||||
(Tuple[int, ...], Sequence[int], True),
|
||||
(str, Sequence[int], False),
|
||||
(str, Sequence[str], True),
|
||||
],
|
||||
)
|
||||
def test_typehint_issubclass(subclass, superclass, expected):
|
||||
assert types.typehint_issubclass(subclass, superclass) == expected
|
||||
if expected:
|
||||
assert types.typehint_issubclass(subclass, superclass)
|
||||
else:
|
||||
assert not types.typehint_issubclass(subclass, superclass)
|
||||
|
||||
|
||||
def test_validate_none_bun_path(mocker):
|
||||
|
Loading…
Reference in New Issue
Block a user