Compare commits

...

101 Commits

Author SHA1 Message Date
Khaleel Al-Adhami
c74313992f refactor match code 2025-02-13 14:42:11 -08:00
Khaleel Al-Adhami
d1ff6d51a2 bring back UntypedVarError 2025-02-13 13:41:20 -08:00
Khaleel Al-Adhami
3080cadd31 some random fixes after merge 2025-02-13 13:36:45 -08:00
Khaleel Al-Adhami
a2074b9081 Merge branch 'main' into add-validation-to-function-vars 2025-02-13 12:52:00 -08:00
Khaleel Al-Adhami
cabb9b43ad Merge branch 'main' into add-validation-to-function-vars 2025-02-13 11:43:05 -08:00
Khaleel Al-Adhami
aef6aced03 clear cache on add style recursive 2025-01-31 14:38:23 -08:00
Khaleel Al-Adhami
1bb8ba5585 Merge branch 'main' into add-validation-to-function-vars 2025-01-31 14:31:11 -08:00
Khaleel Al-Adhami
06b751f679 remove caching things 2025-01-31 14:28:40 -08:00
Khaleel Al-Adhami
749577f0bc precommit 2025-01-31 14:19:13 -08:00
Khaleel Al-Adhami
46c66b2adf lock poetry 2025-01-31 14:05:11 -08:00
Khaleel Al-Adhami
8fab30eb69 Merge branch 'main' into add-validation-to-function-vars 2025-01-31 14:03:42 -08:00
Khaleel Al-Adhami
2ffa698c6b improve match and cond checking 2025-01-23 15:03:45 -08:00
Khaleel Al-Adhami
0d746bf762 handle default args 2025-01-23 13:09:47 -08:00
Khaleel Al-Adhami
d3b12a84fa optimize var operation output 2025-01-23 11:54:45 -08:00
Khaleel Al-Adhami
6604784ca1 update pydantic requirement 2025-01-23 10:27:34 -08:00
Khaleel Al-Adhami
24f341d125 Merge branch 'main' into add-validation-to-function-vars 2025-01-23 10:26:04 -08:00
Khaleel Al-Adhami
29fc4b020a make the match logic better 2025-01-22 16:39:43 -08:00
Masen Furer
19dd15bd44
Banner components that return Fragment inherit from Fragment 2025-01-22 16:25:44 -08:00
Khaleel Al-Adhami
6a50b3a29e i was a bit silly 2025-01-22 16:24:03 -08:00
Khaleel Al-Adhami
b2a27cb8c1 fix list of component is a component 2025-01-22 16:17:08 -08:00
Khaleel Al-Adhami
540382dd3e kill callable var 2025-01-22 15:04:31 -08:00
Khaleel Al-Adhami
c6e2368c95 callable vars are not good 2025-01-22 13:35:04 -08:00
Khaleel Al-Adhami
e10cf07506 make icon to static methods 2025-01-22 13:14:55 -08:00
Khaleel Al-Adhami
84d3a2bb97 unbreak cond why not 2025-01-22 12:53:16 -08:00
Khaleel Al-Adhami
8173e10698 ok we can unbreak foreach just for you 2025-01-22 12:51:01 -08:00
Khaleel Al-Adhami
1f9fbd88de resolve merge mistakes 2025-01-22 12:47:21 -08:00
Khaleel Al-Adhami
68b8c12127 Merge branch 'main' into add-validation-to-function-vars 2025-01-22 12:11:48 -08:00
Khaleel Al-Adhami
5e45ca509b remove iterable from jinja 2025-01-17 18:23:38 -08:00
Khaleel Al-Adhami
be92063421 dang it darglint 2025-01-17 18:21:28 -08:00
Khaleel Al-Adhami
5d6b51c561 what if i deleted rx.foreach 2025-01-17 18:16:43 -08:00
Khaleel Al-Adhami
b10f6e836d remove even more cond and match 2025-01-17 16:34:48 -08:00
Khaleel Al-Adhami
db89a712e9 reformat color_mode pyi 2025-01-17 16:30:05 -08:00
Khaleel Al-Adhami
9e7eeb2a6e fix imports 2025-01-17 16:29:10 -08:00
Khaleel Al-Adhami
b7579f4d8d why not, remove cond 2025-01-17 16:27:30 -08:00
Khaleel Al-Adhami
392c5b5a69 who likes cond 2025-01-17 16:21:53 -08:00
Khaleel Al-Adhami
b11fc5a8ef update pyi 2025-01-17 16:17:47 -08:00
Khaleel Al-Adhami
00019daa27 get rid of match class 2025-01-17 16:17:24 -08:00
Khaleel Al-Adhami
36af8255d3 remove unused format functions 2025-01-17 16:06:58 -08:00
Khaleel Al-Adhami
aadd8b56bf do silly things 2025-01-17 15:51:11 -08:00
Khaleel Al-Adhami
fa6c12e8b3 remove unnecessary comment 2025-01-17 15:32:34 -08:00
Khaleel Al-Adhami
f4aa122950 don't delete thomas code that's rude 2025-01-17 15:24:32 -08:00
Khaleel Al-Adhami
9a987caf76 dang it darglint 2025-01-17 15:21:33 -08:00
Khaleel Al-Adhami
c6f05bb320 fix lineno 2025-01-17 15:15:09 -08:00
Khaleel Al-Adhami
0e539a208c fix syntax for soy 3.10 2025-01-17 15:08:58 -08:00
Khaleel Al-Adhami
a7230f1f45 resolve pyright issues 2025-01-17 15:07:32 -08:00
Khaleel Al-Adhami
0798cb8f60 fix version python 3.10 2025-01-17 14:31:49 -08:00
Khaleel Al-Adhami
270fcb996d Merge branch 'main' into add-validation-to-function-vars 2025-01-17 14:24:48 -08:00
Khaleel Al-Adhami
57d8ea02e9 down to only two pyright error 2025-01-17 14:16:03 -08:00
Khaleel Al-Adhami
112b2ed948 solve some but not all pyright issues 2025-01-16 20:06:02 -08:00
Khaleel Al-Adhami
3d73f561b7 can't have ohio 2025-01-16 19:16:53 -08:00
Khaleel Al-Adhami
72f1fa7cb4 make the var type actually work 2025-01-16 18:17:15 -08:00
Khaleel Al-Adhami
94b4443afc what if we simply didn't have match 2025-01-16 18:11:48 -08:00
Khaleel Al-Adhami
f9d45d5562 fix tests for cond and var 2025-01-16 16:02:41 -08:00
Khaleel Al-Adhami
94c9e52474 fix convert to component logic 2025-01-16 15:16:48 -08:00
Khaleel Al-Adhami
4300f338d8 add wrap components override 2025-01-16 14:50:37 -08:00
Khaleel Al-Adhami
f42d1f4b0f fix component state 2025-01-16 14:34:35 -08:00
Khaleel Al-Adhami
a488fe0c49 i missed up that 2025-01-16 14:05:38 -08:00
Khaleel Al-Adhami
076cfea6ae fix and and or 2025-01-16 13:48:42 -08:00
Khaleel Al-Adhami
d0208e678c change range a bit 2025-01-15 19:16:36 -08:00
Khaleel Al-Adhami
19b6fe9efc handle component at largest scope 2025-01-15 18:02:44 -08:00
Khaleel Al-Adhami
1aa728ee4c use infallible guy 2025-01-15 17:28:46 -08:00
Khaleel Al-Adhami
713f907bf0 silly me 2025-01-15 17:08:33 -08:00
Khaleel Al-Adhami
990bf131c6 fix that test 2025-01-15 17:04:36 -08:00
Khaleel Al-Adhami
f257122934 Merge branch 'main' into add-validation-to-function-vars 2025-01-15 17:04:28 -08:00
Khaleel Al-Adhami
f0f84d5410 fix some tests 2025-01-15 17:00:06 -08:00
Khaleel Al-Adhami
45dde0072e handle vars with var_type being BaseComponent 2025-01-15 10:32:15 -08:00
Khaleel Al-Adhami
2a02e96d87 make the thing compile again 2025-01-15 10:25:25 -08:00
Khaleel Al-Adhami
056de9e277 poetry update 2025-01-15 10:15:55 -08:00
Khaleel Al-Adhami
d31510c655 Merge branch 'main' into add-validation-to-function-vars 2025-01-14 18:40:59 -08:00
Khaleel Al-Adhami
a5526afaeb update pyright once again 2025-01-02 15:20:03 -08:00
Khaleel Al-Adhami
99a3090784 Merge branch 'main' into add-validation-to-function-vars 2025-01-02 11:44:45 -08:00
Khaleel Al-Adhami
bd2ea5b417 update poetry version 2024-12-12 07:35:18 +03:00
Khaleel Al-Adhami
8830d5ab77 update poetry 2024-12-12 07:33:36 +03:00
Khaleel Al-Adhami
06eb04f005 Merge branch 'main' into add-validation-to-function-vars 2024-12-12 07:33:24 +03:00
Khaleel Al-Adhami
2e1bc057a4 aaaaa 2024-11-15 17:39:14 -08:00
Khaleel Al-Adhami
2b05ee98ed make it handle slice 2024-11-15 17:30:45 -08:00
Khaleel Al-Adhami
ed1ae0d3a2 add missing return 2024-11-15 17:15:43 -08:00
Khaleel Al-Adhami
7d0a4f7133 make safe issubclass 2024-11-15 17:12:42 -08:00
Khaleel Al-Adhami
079cc56f59 more types 2024-11-15 16:59:28 -08:00
Khaleel Al-Adhami
92b1232806 would this fix it? no clue 2024-11-15 16:16:13 -08:00
Khaleel Al-Adhami
7f1dc7c841 remove .bool 2024-11-15 15:58:03 -08:00
Khaleel Al-Adhami
eac54d60d2 call guess type 2024-11-15 15:54:04 -08:00
Khaleel Al-Adhami
702670ff26 add components to var data 2024-11-15 15:47:46 -08:00
Khaleel Al-Adhami
88cfb3b7e2 handle var at top level 2024-11-15 15:11:44 -08:00
Khaleel Al-Adhami
53b98543cc default factory 2024-11-15 14:58:19 -08:00
Khaleel Al-Adhami
5f0546f32e what am i doing anymore 2024-11-15 14:53:51 -08:00
Khaleel Al-Adhami
2c04153013 special case ellipsis types 2024-11-14 09:45:44 -08:00
Khaleel Al-Adhami
a9db61b371 get it right pyright 2024-11-13 19:03:42 -08:00
Khaleel Al-Adhami
3cdd2097b6 fix pyright issues outside of vars 2024-11-13 18:43:20 -08:00
Khaleel Al-Adhami
9d7e353ed3 fix pyright issues 2024-11-13 18:40:06 -08:00
Khaleel Al-Adhami
f4aa1f58c3 implement type computers 2024-11-13 18:17:53 -08:00
Khaleel Al-Adhami
ebc81811c0 fix silly mistakes 2024-11-13 13:51:47 -08:00
Khaleel Al-Adhami
1e9743dcd6 Merge branch 'main' into add-validation-to-function-vars 2024-11-13 13:23:38 -08:00
Khaleel Al-Adhami
05bd41c040 add validation 2024-11-13 13:22:01 -08:00
Khaleel Al-Adhami
f9b24fe5bd get typevar from extensions 2024-11-12 16:38:30 -08:00
Khaleel Al-Adhami
6745d6cb9d don't use Any from extensions 2024-11-12 16:35:04 -08:00
Khaleel Al-Adhami
7ada0ea5b9 special case 3.9 2024-11-12 16:23:06 -08:00
Khaleel Al-Adhami
48951dbabd try importing everything from extensions 2024-11-12 16:11:30 -08:00
Khaleel Al-Adhami
7aa9245514 remove ellipsis as they are not supported in 3.9 2024-11-12 16:07:17 -08:00
Khaleel Al-Adhami
9b06d684cd import ParamSpec from typing_extensions 2024-11-12 16:03:40 -08:00
Khaleel Al-Adhami
56f0d6375b add typing to function vars 2024-11-12 16:01:32 -08:00
60 changed files with 4543 additions and 3878 deletions

View File

@ -23,7 +23,7 @@ fastapi = ">=0.96.0,!=0.111.0,!=0.111.1"
gunicorn = ">=20.1.0,<24.0"
jinja2 = ">=3.1.2,<4.0"
psutil = ">=5.9.4,<7.0"
pydantic = ">=1.10.2,<3.0"
pydantic = ">=1.10.17,<3.0"
python-multipart = ">=0.0.5,<0.1"
python-socketio = ">=5.7.0,<6.0"
redis = ">=4.3.5,<6.0"
@ -55,7 +55,7 @@ typing_extensions = ">=4.6.0"
[tool.poetry.group.dev.dependencies]
pytest = ">=7.1.2,<9.0"
pytest-mock = ">=3.10.0,<4.0"
pyright = ">=1.1.392, <1.2"
pyright = ">=1.1.392.post0,<1.2"
darglint = ">=1.8.1,<2.0"
dill = ">=0.3.8"
toml = ">=0.10.2,<1.0"

View File

@ -6,12 +6,6 @@
{% filter indent(width=indent_width) %}
{%- if component is not mapping %}
{{- component }}
{%- elif "iterable" in component %}
{{- render_iterable_tag(component) }}
{%- elif component.name == "match"%}
{{- render_match_tag(component) }}
{%- elif "cond" in component %}
{{- render_condition_tag(component) }}
{%- elif component.children|length %}
{{- render_tag(component) }}
{%- else %}
@ -44,30 +38,6 @@
{%- endmacro %}
{# Rendering condition component. #}
{# Args: #}
{# component: component dictionary #}
{% macro render_condition_tag(component) %}
{ {{- component.cond_state }} ? (
{{ render(component.true_value) }}
) : (
{{ render(component.false_value) }}
)}
{%- endmacro %}
{# Rendering iterable component. #}
{# Args: #}
{# component: component dictionary #}
{% macro render_iterable_tag(component) %}
<>{ {{ component.iterable_state }}.map(({{ component.arg_name }}, {{ component.arg_index }}) => (
{% for child in component.children %}
{{ render(child) }}
{% endfor %}
))}</>
{%- endmacro %}
{# Rendering props of a component. #}
{# Args: #}
{# component: component dictionary #}
@ -75,29 +45,6 @@
{% if props|length %} {{ props|join(" ") }}{% endif %}
{% endmacro %}
{# Rendering Match component. #}
{# Args: #}
{# component: component dictionary #}
{% macro render_match_tag(component) %}
{
(() => {
switch (JSON.stringify({{ component.cond._js_expr }})) {
{% for case in component.match_cases %}
{% for condition in case[:-1] %}
case JSON.stringify({{ condition._js_expr }}):
{% endfor %}
return {{ render(case[-1]) }};
break;
{% endfor %}
default:
return {{ render(component.default) }};
break;
}
})()
}
{%- endmacro %}
{# Rendering content with args. #}
{# Args: #}
{# component: component dictionary #}

View File

@ -1,43 +1,43 @@
/**
* Simulate the python range() builtin function.
* inspired by https://dev.to/guyariely/using-python-range-in-javascript-337p
*
*
* If needed outside of an iterator context, use `Array.from(range(10))` or
* spread syntax `[...range(10)]` to get an array.
*
*
* @param {number} start: the start or end of the range.
* @param {number} stop: the end of the range.
* @param {number} step: the step of the range.
* @returns {object} an object with a Symbol.iterator method over the range
*/
export default function range(start, stop, step) {
return {
[Symbol.iterator]() {
if (stop === undefined) {
stop = start;
start = 0;
}
if (step === undefined) {
step = 1;
}
let i = start - step;
return {
next() {
i += step;
if ((step > 0 && i < stop) || (step < 0 && i > stop)) {
return {
value: i,
done: false,
};
}
return {
[Symbol.iterator]() {
if ((stop ?? undefined) === undefined) {
stop = start;
start = 0;
}
if ((step ?? undefined) === undefined) {
step = 1;
}
let i = start - step;
return {
next() {
i += step;
if ((step > 0 && i < stop) || (step < 0 && i > stop)) {
return {
value: undefined,
done: true,
value: i,
done: false,
};
},
};
},
};
}
}
return {
value: undefined,
done: true,
};
},
};
},
};
}

View File

@ -926,6 +926,45 @@ export const isTrue = (val) => {
return Boolean(val);
};
/**
* Returns a copy of a section of an array.
* @param {Array | string} arrayLike The array to slice.
* @param {[number, number, number]} slice The slice to apply.
* @returns The sliced array.
*/
export const atSlice = (arrayLike, slice) => {
const array = [...arrayLike];
const [startSlice, endSlice, stepSlice] = slice;
if (stepSlice ?? null === null) {
return array.slice(startSlice ?? undefined, endSlice ?? undefined);
}
const step = stepSlice ?? 1;
if (step > 0) {
return array
.slice(startSlice ?? undefined, endSlice ?? undefined)
.filter((_, i) => i % step === 0);
}
const actualStart = (endSlice ?? null) === null ? 0 : endSlice + 1;
const actualEnd =
(startSlice ?? null) === null ? array.length : startSlice + 1;
return array
.slice(actualStart, actualEnd)
.reverse()
.filter((_, i) => i % step === 0);
};
/**
* Get the value at a slice or index.
* @param {Array | string} arrayLike The array to get the value from.
* @param {number | [number, number, number]} sliceOrIndex The slice or index to get the value at.
* @returns The value at the slice or index.
*/
export const atSliceOrIndex = (arrayLike, sliceOrIndex) => {
return Array.isArray(sliceOrIndex)
? atSlice(arrayLike, sliceOrIndex)
: arrayLike.at(sliceOrIndex);
};
/**
* Get the value from a ref.
* @param ref The ref to get the value from.

View File

@ -5,15 +5,9 @@ from __future__ import annotations
import os
from typing import TYPE_CHECKING, Any, List, Type
try:
import pydantic.v1.main as pydantic_main
from pydantic.v1 import BaseModel
from pydantic.v1.fields import ModelField
except ModuleNotFoundError:
if not TYPE_CHECKING:
import pydantic.main as pydantic_main
from pydantic import BaseModel
from pydantic.fields import ModelField
import pydantic.v1.main as pydantic_main
from pydantic.v1 import BaseModel
from pydantic.v1.fields import ModelField
def validate_field_name(bases: List[Type["BaseModel"]], field_name: str) -> None:
@ -50,7 +44,7 @@ if TYPE_CHECKING:
from reflex.vars import Var
class Base(BaseModel): # pyright: ignore [reportPossiblyUnboundVariable]
class Base(BaseModel):
"""The base class subclassed by all Reflex classes.
This class wraps Pydantic and provides common methods such as

View File

@ -4,7 +4,7 @@ from __future__ import annotations
from typing import Any, Iterator
from reflex.components.component import Component, LiteralComponentVar
from reflex.components.component import Component, ComponentStyle
from reflex.components.tags import Tag
from reflex.components.tags.tagless import Tagless
from reflex.config import PerformanceMode, environment
@ -12,7 +12,7 @@ from reflex.utils import console
from reflex.utils.decorator import once
from reflex.utils.imports import ParsedImportDict
from reflex.vars import BooleanVar, ObjectVar, Var
from reflex.vars.base import VarData
from reflex.vars.base import GLOBAL_CACHE, VarData
from reflex.vars.sequence import LiteralStringVar
@ -80,8 +80,11 @@ class Bare(Component):
The hooks for the component.
"""
hooks = super()._get_all_hooks_internal()
if isinstance(self.contents, LiteralComponentVar):
hooks |= self.contents._var_value._get_all_hooks_internal()
if isinstance(self.contents, Var):
var_data = self.contents._get_all_var_data()
if var_data:
for component in var_data.components:
hooks |= component._get_all_hooks_internal()
return hooks
def _get_all_hooks(self) -> dict[str, VarData | None]:
@ -91,18 +94,24 @@ class Bare(Component):
The hooks for the component.
"""
hooks = super()._get_all_hooks()
if isinstance(self.contents, LiteralComponentVar):
hooks |= self.contents._var_value._get_all_hooks()
if isinstance(self.contents, Var):
var_data = self.contents._get_all_var_data()
if var_data:
for component in var_data.components:
hooks |= component._get_all_hooks()
return hooks
def _get_all_imports(self) -> ParsedImportDict:
def _get_all_imports(self, collapse: bool = False) -> ParsedImportDict:
"""Include the imports for the component.
Args:
collapse: Whether to collapse the imports.
Returns:
The imports for the component.
"""
imports = super()._get_all_imports()
if isinstance(self.contents, LiteralComponentVar):
imports = super()._get_all_imports(collapse=collapse)
if isinstance(self.contents, Var):
var_data = self.contents._get_all_var_data()
if var_data:
imports |= {k: list(v) for k, v in var_data.imports}
@ -115,8 +124,11 @@ class Bare(Component):
The dynamic imports.
"""
dynamic_imports = super()._get_all_dynamic_imports()
if isinstance(self.contents, LiteralComponentVar):
dynamic_imports |= self.contents._var_value._get_all_dynamic_imports()
if isinstance(self.contents, Var):
var_data = self.contents._get_all_var_data()
if var_data:
for component in var_data.components:
dynamic_imports |= component._get_all_dynamic_imports()
return dynamic_imports
def _get_all_custom_code(self) -> set[str]:
@ -126,10 +138,28 @@ class Bare(Component):
The custom code.
"""
custom_code = super()._get_all_custom_code()
if isinstance(self.contents, LiteralComponentVar):
custom_code |= self.contents._var_value._get_all_custom_code()
if isinstance(self.contents, Var):
var_data = self.contents._get_all_var_data()
if var_data:
for component in var_data.components:
custom_code |= component._get_all_custom_code()
return custom_code
def _get_all_app_wrap_components(self) -> dict[tuple[int, str], Component]:
"""Get the components that should be wrapped in the app.
Returns:
The components that should be wrapped in the app.
"""
app_wrap_components = super()._get_all_app_wrap_components()
if isinstance(self.contents, Var):
var_data = self.contents._get_all_var_data()
if var_data:
for component in var_data.components:
if isinstance(component, Component):
app_wrap_components |= component._get_all_app_wrap_components()
return app_wrap_components
def _get_all_refs(self) -> set[str]:
"""Get the refs for the children of the component.
@ -137,8 +167,11 @@ class Bare(Component):
The refs for the children.
"""
refs = super()._get_all_refs()
if isinstance(self.contents, LiteralComponentVar):
refs |= self.contents._var_value._get_all_refs()
if isinstance(self.contents, Var):
var_data = self.contents._get_all_var_data()
if var_data:
for component in var_data.components:
refs |= component._get_all_refs()
return refs
def _render(self) -> Tag:
@ -148,6 +181,30 @@ class Bare(Component):
return Tagless(contents=f"{{{self.contents!s}}}")
return Tagless(contents=str(self.contents))
def _add_style_recursive(
self, style: ComponentStyle, theme: Component | None = None
) -> Component:
"""Add style to the component and its children.
Args:
style: The style to add.
theme: The theme to add.
Returns:
The component with the style added.
"""
new_self = super()._add_style_recursive(style, theme)
if isinstance(self.contents, Var):
var_data = self.contents._get_all_var_data()
if var_data:
for component in var_data.components:
if isinstance(component, Component):
component._add_style_recursive(style, theme)
GLOBAL_CACHE.clear()
return new_self
def _get_vars(
self, include_children: bool = False, ignore_ids: set[int] | None = None
) -> Iterator[Var]:

View File

@ -65,8 +65,7 @@ from reflex.vars.base import (
Var,
cached_property_no_lock,
)
from reflex.vars.function import ArgsFunctionOperation, FunctionStringVar
from reflex.vars.number import ternary_operation
from reflex.vars.function import FunctionStringVar
from reflex.vars.object import ObjectVar
from reflex.vars.sequence import LiteralArrayVar
@ -889,10 +888,8 @@ class Component(BaseComponent, ABC):
children: The children of the component.
"""
from reflex.components.base.bare import Bare
from reflex.components.base.fragment import Fragment
from reflex.components.core.cond import Cond
from reflex.components.core.foreach import Foreach
from reflex.components.core.match import Match
no_valid_parents_defined = all(child._valid_parents == [] for child in children)
if (
@ -903,9 +900,7 @@ class Component(BaseComponent, ABC):
return
comp_name = type(self).__name__
allowed_components = [
comp.__name__ for comp in (Fragment, Foreach, Cond, Match)
]
allowed_components = [comp.__name__ for comp in (Fragment,)]
def validate_child(child: Any):
child_name = type(child).__name__
@ -915,24 +910,39 @@ class Component(BaseComponent, ABC):
for c in child.children:
validate_child(c)
if isinstance(child, Cond):
validate_child(child.comp1)
validate_child(child.comp2)
if isinstance(child, Match):
for cases in child.match_cases:
validate_child(cases[-1])
validate_child(child.default)
if (
isinstance(child, Bare)
and child.contents is not None
and isinstance(child.contents, Var)
):
var_data = child.contents._get_all_var_data()
if var_data is not None:
for c in var_data.components:
validate_child(c)
if self._invalid_children and child_name in self._invalid_children:
raise ValueError(
f"The component `{comp_name}` cannot have `{child_name}` as a child component"
)
if self._valid_children and child_name not in [
*self._valid_children,
*allowed_components,
]:
valid_children = self._valid_children + allowed_components
def child_is_in_valid(child_component: Any):
if type(child_component).__name__ in valid_children:
return True
if (
not isinstance(child_component, Bare)
or child_component.contents is None
or not isinstance(child_component.contents, Var)
or (var_data := child_component.contents._get_all_var_data())
is None
):
return False
return all(child_is_in_valid(c) for c in var_data.components)
if self._valid_children and not child_is_in_valid(child):
valid_child_list = ", ".join(
[f"`{v_child}`" for v_child in self._valid_children]
)
@ -1918,8 +1928,6 @@ class StatefulComponent(BaseComponent):
Returns:
The stateful component or None if the component should not be memoized.
"""
from reflex.components.core.foreach import Foreach
if component._memoization_mode.disposition == MemoizationDisposition.NEVER:
# Never memoize this component.
return None
@ -1948,10 +1956,6 @@ class StatefulComponent(BaseComponent):
# Skip BaseComponent and StatefulComponent children.
if not isinstance(child, Component):
continue
# Always consider Foreach something that must be memoized by the parent.
if isinstance(child, Foreach):
should_memoize = True
break
child = cls._child_var(child)
if isinstance(child, Var) and child._get_all_var_data():
should_memoize = True
@ -2001,18 +2005,9 @@ class StatefulComponent(BaseComponent):
The Var from the child component or the child itself (for regular cases).
"""
from reflex.components.base.bare import Bare
from reflex.components.core.cond import Cond
from reflex.components.core.foreach import Foreach
from reflex.components.core.match import Match
if isinstance(child, Bare):
return child.contents
if isinstance(child, Cond):
return child.cond
if isinstance(child, Foreach):
return child.iterable
if isinstance(child, Match):
return child.cond
return child
@classmethod
@ -2359,53 +2354,6 @@ def render_dict_to_var(tag: dict | Component | str, imported_names: set[str]) ->
return render_dict_to_var(tag.render(), imported_names)
return Var.create(tag)
if "iterable" in tag:
function_return = Var.create(
[
render_dict_to_var(child.render(), imported_names)
for child in tag["children"]
]
)
func = ArgsFunctionOperation.create(
(tag["arg_var_name"], tag["index_var_name"]),
function_return,
)
return FunctionStringVar.create("Array.prototype.map.call").call(
tag["iterable"]
if not isinstance(tag["iterable"], ObjectVar)
else tag["iterable"].items(),
func,
)
if tag["name"] == "match":
element = tag["cond"]
conditionals = render_dict_to_var(tag["default"], imported_names)
for case in tag["match_cases"][::-1]:
condition = case[0].to_string() == element.to_string()
for pattern in case[1:-1]:
condition = condition | (pattern.to_string() == element.to_string())
conditionals = ternary_operation(
condition,
render_dict_to_var(case[-1], imported_names),
conditionals,
)
return conditionals
if "cond" in tag:
return ternary_operation(
tag["cond"],
render_dict_to_var(tag["true_value"], imported_names),
render_dict_to_var(tag["false_value"], imported_names)
if tag["false_value"] is not None
else Var.create(None),
)
props = {}
special_props = []
@ -2485,17 +2433,14 @@ class LiteralComponentVar(CachedVarOperation, LiteralVar, ComponentVar):
"@emotion/react": [
ImportVar(tag="jsx"),
],
}
),
VarData(
imports=self._var_value._get_all_imports(),
),
VarData(
imports={
"react": [
ImportVar(tag="Fragment"),
],
}
},
components=(self._var_value,),
),
VarData(
imports=self._var_value._get_all_imports(),
),
)

View File

@ -21,16 +21,14 @@ _SUBMOD_ATTRS: dict[str, list[str]] = {
"colors": [
"color",
],
"cond": ["Cond", "color_mode_cond", "cond"],
"cond": ["color_mode_cond", "cond"],
"debounce": ["DebounceInput", "debounce_input"],
"foreach": [
"foreach",
"Foreach",
],
"html": ["html", "Html"],
"match": [
"match",
"Match",
],
"breakpoints": ["breakpoints", "set_breakpoints"],
"responsive": [

View File

@ -17,16 +17,13 @@ from .breakpoints import set_breakpoints as set_breakpoints
from .clipboard import Clipboard as Clipboard
from .clipboard import clipboard as clipboard
from .colors import color as color
from .cond import Cond as Cond
from .cond import color_mode_cond as color_mode_cond
from .cond import cond as cond
from .debounce import DebounceInput as DebounceInput
from .debounce import debounce_input as debounce_input
from .foreach import Foreach as Foreach
from .foreach import foreach as foreach
from .html import Html as Html
from .html import html as html
from .match import Match as Match
from .match import match as match
from .responsive import desktop_only as desktop_only
from .responsive import mobile_and_tablet as mobile_and_tablet

View File

@ -5,6 +5,7 @@ from __future__ import annotations
from typing import Optional
from reflex import constants
from reflex.components.base.fragment import Fragment
from reflex.components.component import Component
from reflex.components.core.cond import cond
from reflex.components.el.elements.typography import Div
@ -163,7 +164,7 @@ class ConnectionToaster(Toaster):
return super().create(*children, **props)
class ConnectionBanner(Component):
class ConnectionBanner(Fragment):
"""A connection banner component."""
@classmethod
@ -190,10 +191,10 @@ class ConnectionBanner(Component):
position="fixed",
)
return cond(has_connection_errors, comp)
return super().create(cond(has_connection_errors, comp))
class ConnectionModal(Component):
class ConnectionModal(Fragment):
"""A connection status modal window."""
@classmethod
@ -208,16 +209,18 @@ class ConnectionModal(Component):
"""
if not comp:
comp = Text.create(*default_connection_error())
return cond(
has_too_many_connection_errors,
DialogRoot.create(
DialogContent.create(
DialogTitle.create("Connection Error"),
comp,
return super().create(
cond(
has_too_many_connection_errors,
DialogRoot.create(
DialogContent.create(
DialogTitle.create("Connection Error"),
comp,
),
open=has_too_many_connection_errors,
z_index=9999,
),
open=has_too_many_connection_errors,
z_index=9999,
),
)
)

View File

@ -5,6 +5,7 @@
# ------------------------------------------------------
from typing import Any, Dict, Literal, Optional, Union, overload
from reflex.components.base.fragment import Fragment
from reflex.components.component import Component
from reflex.components.el.elements.typography import Div
from reflex.components.lucide.icon import Icon
@ -137,7 +138,7 @@ class ConnectionToaster(Toaster):
"""
...
class ConnectionBanner(Component):
class ConnectionBanner(Fragment):
@overload
@classmethod
def create( # type: ignore
@ -176,7 +177,7 @@ class ConnectionBanner(Component):
"""
...
class ConnectionModal(Component):
class ConnectionModal(Fragment):
@overload
@classmethod
def create( # type: ignore

View File

@ -41,7 +41,7 @@ class ClientSideRouting(Component):
return ""
def wait_for_client_redirect(component: Component) -> Component:
def wait_for_client_redirect(component: Component) -> Var[Component]:
"""Wait for a redirect to occur before rendering a component.
This prevents the 404 page from flashing while the redirect is happening.
@ -53,9 +53,9 @@ def wait_for_client_redirect(component: Component) -> Component:
The conditionally rendered component.
"""
return cond(
condition=route_not_found,
c1=component,
c2=ClientSideRouting.create(),
route_not_found,
component,
ClientSideRouting.create(),
)

View File

@ -2,126 +2,31 @@
from __future__ import annotations
from typing import Any, Dict, Optional, overload
from typing import Any, TypeVar, Union, overload
from reflex.components.base.fragment import Fragment
from reflex.components.component import BaseComponent, Component, MemoizationLeaf
from reflex.components.tags import CondTag, Tag
from reflex.constants import Dirs
from reflex.components.component import BaseComponent, Component
from reflex.style import LIGHT_COLOR_MODE, resolved_color_mode
from reflex.utils.imports import ImportDict, ImportVar
from reflex.vars import VarData
from reflex.utils import types
from reflex.vars.base import LiteralVar, Var
from reflex.vars.number import ternary_operation
_IS_TRUE_IMPORT: ImportDict = {
f"$/{Dirs.STATE_PATH}": [ImportVar(tag="isTrue")],
}
@overload
def cond(
condition: Any, c1: BaseComponent | Var[BaseComponent], c2: Any = None, /
) -> Var[Component]: ...
class Cond(MemoizationLeaf):
"""Render one of two components based on a condition."""
# The cond to determine which component to render.
cond: Var[Any]
# The component to render if the cond is true.
comp1: BaseComponent | None = None
# The component to render if the cond is false.
comp2: BaseComponent | None = None
@classmethod
def create(
cls,
cond: Var,
comp1: BaseComponent,
comp2: Optional[BaseComponent] = None,
) -> Component:
"""Create a conditional component.
Args:
cond: The cond to determine which component to render.
comp1: The component to render if the cond is true.
comp2: The component to render if the cond is false.
Returns:
The conditional component.
"""
# Wrap everything in fragments.
if type(comp1).__name__ != "Fragment":
comp1 = Fragment.create(comp1)
if comp2 is None or type(comp2).__name__ != "Fragment":
comp2 = Fragment.create(comp2) if comp2 else Fragment.create()
return Fragment.create(
cls(
cond=cond,
comp1=comp1,
comp2=comp2,
children=[comp1, comp2],
)
)
def _get_props_imports(self):
"""Get the imports needed for component's props.
Returns:
The imports for the component's props of the component.
"""
return []
def _render(self) -> Tag:
return CondTag(
cond=self.cond,
true_value=self.comp1.render(), # pyright: ignore [reportOptionalMemberAccess]
false_value=self.comp2.render(), # pyright: ignore [reportOptionalMemberAccess]
)
def render(self) -> Dict:
"""Render the component.
Returns:
The dictionary for template of component.
"""
tag = self._render()
return dict(
tag.add_props(
**self.event_triggers,
key=self.key,
sx=self.style,
id=self.id,
class_name=self.class_name,
).set(
props=tag.format_props(),
),
cond_state=f"isTrue({self.cond!s})",
)
def add_imports(self) -> ImportDict:
"""Add imports for the Cond component.
Returns:
The import dict for the component.
"""
var_data = VarData.merge(self.cond._get_all_var_data())
imports = var_data.old_school_imports() if var_data else {}
return {**imports, **_IS_TRUE_IMPORT}
T = TypeVar("T")
V = TypeVar("V")
@overload
def cond(condition: Any, c1: Component, c2: Any) -> Component: ... # pyright: ignore [reportOverlappingOverload]
def cond(condition: Any, c1: T | Var[T], c2: V | Var[V], /) -> Var[T | V]: ...
@overload
def cond(condition: Any, c1: Component) -> Component: ...
@overload
def cond(condition: Any, c1: Any, c2: Any) -> Var: ...
def cond(condition: Any, c1: Any, c2: Any = None) -> Component | Var:
def cond(condition: Any, c1: Any, c2: Any = None, /) -> Var:
"""Create a conditional component or Prop.
Args:
@ -137,48 +42,40 @@ def cond(condition: Any, c1: Any, c2: Any = None) -> Component | Var:
"""
# Convert the condition to a Var.
cond_var = LiteralVar.create(condition)
if cond_var is None:
raise ValueError("The condition must be set.")
# If the first component is a component, create a Cond component.
if isinstance(c1, BaseComponent):
if c2 is not None and not isinstance(c2, BaseComponent):
raise ValueError("Both arguments must be components.")
return Cond.create(cond_var, c1, c2)
# If the first component is a component, create a Fragment if the second component is not set.
if isinstance(c1, BaseComponent) or (
isinstance(c1, Var)
and types.safe_typehint_issubclass(
c1._var_type, Union[BaseComponent, list[BaseComponent]]
)
):
c2 = c2 if c2 is not None else Fragment.create()
# Otherwise, create a conditional Var.
# Check that the second argument is valid.
if isinstance(c2, BaseComponent):
raise ValueError("Both arguments must be props.")
if c2 is None:
raise ValueError("For conditional vars, the second argument must be set.")
def create_var(cond_part: Any) -> Var[Any]:
return LiteralVar.create(cond_part)
# convert the truth and false cond parts into vars so the _var_data can be obtained.
c1 = create_var(c1)
c2 = create_var(c2)
# Create the conditional var.
return ternary_operation(
cond_var.bool()._replace(
merge_var_data=VarData(imports=_IS_TRUE_IMPORT),
),
cond_var.bool(),
c1,
c2,
)
@overload
def color_mode_cond(light: Component, dark: Component | None = None) -> Component: ... # pyright: ignore [reportOverlappingOverload]
def color_mode_cond(
light: BaseComponent | Var[BaseComponent],
dark: BaseComponent | Var[BaseComponent] | None = ...,
) -> Var[Component]: ...
@overload
def color_mode_cond(light: Any, dark: Any = None) -> Var: ...
def color_mode_cond(light: T | Var[T], dark: V | Var[V]) -> Var[T | V]: ...
def color_mode_cond(light: Any, dark: Any = None) -> Var | Component:
def color_mode_cond(light: Any, dark: Any = None) -> Var:
"""Create a component or Prop based on color_mode.
Args:
@ -193,3 +90,9 @@ def color_mode_cond(light: Any, dark: Any = None) -> Var | Component:
light,
dark,
)
class Cond:
"""Create a conditional component or Prop."""
create = staticmethod(cond)

View File

@ -2,16 +2,9 @@
from __future__ import annotations
import functools
import inspect
from typing import Any, Callable, Iterable
from typing import Callable, Iterable
from reflex.components.base.fragment import Fragment
from reflex.components.component import Component
from reflex.components.tags import IterTag
from reflex.constants import MemoizationMode
from reflex.state import ComponentState
from reflex.utils.exceptions import UntypedVarError
from reflex.vars import ArrayVar, ObjectVar, StringVar
from reflex.vars.base import LiteralVar, Var
@ -23,149 +16,42 @@ class ForeachRenderError(TypeError):
"""Raised when there is an error with the foreach render function."""
class Foreach(Component):
"""A component that takes in an iterable and a render function and renders a list of components."""
def foreach(
iterable: Var[Iterable] | Iterable,
render_fn: Callable,
) -> Var:
"""Create a foreach component.
_memoization_mode = MemoizationMode(recursive=False)
Args:
iterable: The iterable to create components from.
render_fn: A function from the render args to the component.
# The iterable to create components from.
iterable: Var[Iterable]
Returns:
The foreach component.
# A function from the render args to the component.
render_fn: Callable = Fragment.create
Raises:
ForeachVarError: If the iterable is of type Any.
TypeError: If the render function is a ComponentState.
UntypedVarError: If the iterable is of type Any without a type annotation.
"""
iterable = LiteralVar.create(iterable).guess_type()
@classmethod
def create(
cls,
iterable: Var[Iterable] | Iterable,
render_fn: Callable,
) -> Foreach:
"""Create a foreach component.
if isinstance(iterable, ObjectVar):
iterable = iterable.entries()
Args:
iterable: The iterable to create components from.
render_fn: A function from the render args to the component.
if isinstance(iterable, StringVar):
iterable = iterable.split()
Returns:
The foreach component.
Raises:
ForeachVarError: If the iterable is of type Any.
TypeError: If the render function is a ComponentState.
UntypedVarError: If the iterable is of type Any without a type annotation.
"""
from reflex.vars import ArrayVar, ObjectVar, StringVar
iterable = LiteralVar.create(iterable).guess_type()
if iterable._var_type == Any:
raise ForeachVarError(
f"Could not foreach over var `{iterable!s}` of type Any. "
"(If you are trying to foreach over a state var, add a type annotation to the var). "
"See https://reflex.dev/docs/library/dynamic-rendering/foreach/"
)
if (
hasattr(render_fn, "__qualname__")
and render_fn.__qualname__ == ComponentState.create.__qualname__
):
raise TypeError(
"Using a ComponentState as `render_fn` inside `rx.foreach` is not supported yet."
)
if isinstance(iterable, ObjectVar):
iterable = iterable.entries()
if isinstance(iterable, StringVar):
iterable = iterable.split()
if not isinstance(iterable, ArrayVar):
raise ForeachVarError(
f"Could not foreach over var `{iterable!s}` of type {iterable._var_type}. "
"See https://reflex.dev/docs/library/dynamic-rendering/foreach/"
)
component = cls(
iterable=iterable,
render_fn=render_fn,
)
try:
# Keep a ref to a rendered component to determine correct imports/hooks/styles.
component.children = [component._render().render_component()]
except UntypedVarError as e:
raise UntypedVarError(
f"Could not foreach over var `{iterable!s}` without a type annotation. "
"See https://reflex.dev/docs/library/dynamic-rendering/foreach/"
) from e
return component
def _render(self) -> IterTag:
props = {}
render_sig = inspect.signature(self.render_fn)
params = list(render_sig.parameters.values())
# Validate the render function signature.
if len(params) == 0 or len(params) > 2:
raise ForeachRenderError(
"Expected 1 or 2 parameters in foreach render function, got "
f"{[p.name for p in params]}. See "
"https://reflex.dev/docs/library/dynamic-rendering/foreach/"
)
if len(params) >= 1:
# Determine the arg var name based on the params accepted by render_fn.
props["arg_var_name"] = params[0].name
if len(params) == 2:
# Determine the index var name based on the params accepted by render_fn.
props["index_var_name"] = params[1].name
else:
render_fn = self.render_fn
# Otherwise, use a deterministic index, based on the render function bytecode.
code_hash = (
hash(
getattr(
render_fn,
"__code__",
(
repr(self.render_fn)
if not isinstance(render_fn, functools.partial)
else render_fn.func.__code__
),
)
)
.to_bytes(
length=8,
byteorder="big",
signed=True,
)
.hex()
)
props["index_var_name"] = f"index_{code_hash}"
return IterTag(
iterable=self.iterable,
render_fn=self.render_fn,
children=self.children,
**props,
if not isinstance(iterable, ArrayVar):
raise ForeachVarError(
f"Could not foreach over var `{iterable!s}` of type {iterable._var_type}. "
"See https://reflex.dev/docs/library/dynamic-rendering/foreach/"
)
def render(self):
"""Render the component.
Returns:
The dictionary for template of component.
"""
tag = self._render()
return dict(
tag,
iterable_state=str(tag.iterable),
arg_name=tag.arg_var_name,
arg_index=tag.get_index_var_arg(),
iterable_type=tag.iterable._var_type.mro()[0].__name__,
)
return iterable.foreach(render_fn)
foreach = Foreach.create
class Foreach:
"""Create a foreach component."""
create = staticmethod(foreach)

View File

@ -1,274 +1,161 @@
"""rx.match."""
import textwrap
from typing import Any, Dict, List, Optional, Tuple, Union
from typing import Any, Union, cast
from typing_extensions import Unpack
from reflex.components.base import Fragment
from reflex.components.component import BaseComponent, Component, MemoizationLeaf
from reflex.components.tags import MatchTag, Tag
from reflex.style import Style
from reflex.utils import format, types
from reflex.components.component import BaseComponent
from reflex.utils import types
from reflex.utils.exceptions import MatchTypeError
from reflex.utils.imports import ImportDict
from reflex.vars import VarData
from reflex.vars.base import LiteralVar, Var
from reflex.vars.base import VAR_TYPE, Var
from reflex.vars.number import MatchOperation
CASE_TYPE = tuple[Unpack[tuple[Any, ...]], Var[VAR_TYPE] | VAR_TYPE]
class Match(MemoizationLeaf):
"""Match cases based on a condition."""
def _process_match_cases(cases: tuple[CASE_TYPE[VAR_TYPE], ...]):
"""Process the individual match cases.
# The condition to determine which case to match.
cond: Var[Any]
Args:
cases: The match cases.
# The list of match cases to be matched.
match_cases: List[Any] = []
# The catchall case to match.
default: Any
@classmethod
def create(cls, cond: Any, *cases) -> Union[Component, Var]:
"""Create a Match Component.
Args:
cond: The condition to determine which case to match.
cases: This list of cases to match.
Returns:
The match component.
Raises:
ValueError: When a default case is not provided for cases with Var return types.
"""
match_cond_var = cls._create_condition_var(cond)
cases, default = cls._process_cases(list(cases))
match_cases = cls._process_match_cases(cases)
cls._validate_return_types(match_cases)
if default is None and types._issubclass(type(match_cases[0][-1]), Var):
Raises:
ValueError: If the default case is not the last case or the tuple elements are less than 2.
"""
for case in cases:
if not isinstance(case, tuple):
raise ValueError(
"For cases with return types as Vars, a default case must be provided"
"rx.match should have tuples of cases and a default case as the last argument."
)
return cls._create_match_cond_var_or_component(
match_cond_var, match_cases, default
)
@classmethod
def _create_condition_var(cls, cond: Any) -> Var:
"""Convert the condition to a Var.
Args:
cond: The condition.
Returns:
The condition as a base var
Raises:
ValueError: If the condition is not provided.
"""
match_cond_var = LiteralVar.create(cond)
if match_cond_var is None:
raise ValueError("The condition must be set")
return match_cond_var
@classmethod
def _process_cases(
cls, cases: List
) -> Tuple[List, Optional[Union[Var, BaseComponent]]]:
"""Process the list of match cases and the catchall default case.
Args:
cases: The list of match cases.
Returns:
The default case and the list of match case tuples.
Raises:
ValueError: If there are multiple default cases.
"""
default = None
if len([case for case in cases if not isinstance(case, tuple)]) > 1:
raise ValueError("rx.match can only have one default case.")
if not cases:
raise ValueError("rx.match should have at least one case.")
# Get the default case which should be the last non-tuple arg
if not isinstance(cases[-1], tuple):
default = cases.pop()
default = (
cls._create_case_var_with_var_data(default)
if not isinstance(default, BaseComponent)
else default
# There should be at least two elements in a case tuple(a condition and return value)
if len(case) < 2:
raise ValueError(
"A case tuple should have at least a match case element and a return value."
)
return cases, default
@classmethod
def _create_case_var_with_var_data(cls, case_element: Any) -> Var:
"""Convert a case element into a Var.If the case
is a Style type, we extract the var data and merge it with the
newly created Var.
def _validate_return_types(*return_values: Any) -> bool:
"""Validate that match cases have the same return types.
Args:
case_element: The case element.
Args:
return_values: The return values of the match cases.
Returns:
The case element Var.
"""
_var_data = case_element._var_data if isinstance(case_element, Style) else None
case_element = LiteralVar.create(case_element, _var_data=_var_data)
return case_element
Returns:
True if all cases have the same return types.
@classmethod
def _process_match_cases(cls, cases: List) -> List[List[Var]]:
"""Process the individual match cases.
Raises:
MatchTypeError: If the return types of cases are different.
"""
Args:
cases: The match cases.
Returns:
The processed match cases.
Raises:
ValueError: If the default case is not the last case or the tuple elements are less than 2.
"""
match_cases = []
for case in cases:
if not isinstance(case, tuple):
raise ValueError(
"rx.match should have tuples of cases and a default case as the last argument."
)
# There should be at least two elements in a case tuple(a condition and return value)
if len(case) < 2:
raise ValueError(
"A case tuple should have at least a match case element and a return value."
)
case_list = []
for element in case:
# convert all non component element to vars.
el = (
cls._create_case_var_with_var_data(element)
if not isinstance(element, BaseComponent)
else element
)
if not isinstance(el, (Var, BaseComponent)):
raise ValueError("Case element must be a var or component")
case_list.append(el)
match_cases.append(case_list)
return match_cases
@classmethod
def _validate_return_types(cls, match_cases: List[List[Var]]) -> None:
"""Validate that match cases have the same return types.
Args:
match_cases: The match cases.
Raises:
MatchTypeError: If the return types of cases are different.
"""
first_case_return = match_cases[0][-1]
return_type = type(first_case_return)
if isinstance(first_case_return, BaseComponent):
return_type = BaseComponent
elif isinstance(first_case_return, Var):
return_type = Var
for index, case in enumerate(match_cases):
if not types._issubclass(type(case[-1]), return_type):
raise MatchTypeError(
f"Match cases should have the same return types. Case {index} with return "
f"value `{case[-1]._js_expr if isinstance(case[-1], Var) else textwrap.shorten(str(case[-1]), width=250)}`"
f" of type {type(case[-1])!r} is not {return_type}"
)
@classmethod
def _create_match_cond_var_or_component(
cls,
match_cond_var: Var,
match_cases: List[List[Var]],
default: Optional[Union[Var, BaseComponent]],
) -> Union[Component, Var]:
"""Create and return the match condition var or component.
Args:
match_cond_var: The match condition.
match_cases: The list of match cases.
default: The default case.
Returns:
The match component wrapped in a fragment or the match var.
Raises:
ValueError: If the return types are not vars when creating a match var for Var types.
"""
if default is None and types._issubclass(
type(match_cases[0][-1]), BaseComponent
):
default = Fragment.create()
if types._issubclass(type(match_cases[0][-1]), BaseComponent):
return Fragment.create(
cls(
cond=match_cond_var,
match_cases=match_cases,
default=default,
children=[case[-1] for case in match_cases] + [default], # pyright: ignore [reportArgumentType]
)
def is_component_or_component_var(obj: Any) -> bool:
return types._isinstance(obj, BaseComponent) or (
isinstance(obj, Var)
and types.safe_typehint_issubclass(
obj._var_type, Union[list[BaseComponent], BaseComponent]
)
# Validate the match cases (as well as the default case) to have Var return types.
if any(
case for case in match_cases if not isinstance(case[-1], Var)
) or not isinstance(default, Var):
raise ValueError("Return types of match cases should be Vars.")
return Var(
_js_expr=format.format_match(
cond=str(match_cond_var),
match_cases=match_cases,
default=default, # pyright: ignore [reportArgumentType]
),
_var_type=default._var_type, # pyright: ignore [reportAttributeAccessIssue,reportOptionalMemberAccess]
_var_data=VarData.merge(
match_cond_var._get_all_var_data(),
*[el._get_all_var_data() for case in match_cases for el in case],
default._get_all_var_data(), # pyright: ignore [reportAttributeAccessIssue, reportOptionalMemberAccess]
),
)
def _render(self) -> Tag:
return MatchTag(
cond=self.cond, match_cases=self.match_cases, default=self.default
def type_of_return_value(obj: Any) -> Any:
if isinstance(obj, Var):
return obj._var_type
return type(obj)
is_return_type_component = [
is_component_or_component_var(return_type) for return_type in return_values
]
if any(is_return_type_component) and not all(is_return_type_component):
non_component_return_types = [
(type_of_return_value(return_value), i)
for i, return_value in enumerate(return_values)
if not is_return_type_component[i]
]
raise MatchTypeError(
"Match cases should have the same return types. "
+ "Expected return types to be of type Component or Var[Component]. "
+ ". ".join(
[
f"Return type of case {i} is {return_type}"
for return_type, i in non_component_return_types
]
)
)
def render(self) -> Dict:
"""Render the component.
Returns:
The dictionary for template of component.
"""
tag = self._render()
tag.name = "match"
return dict(tag)
def add_imports(self) -> ImportDict:
"""Add imports for the Match component.
Returns:
The import dict.
"""
var_data = VarData.merge(self.cond._get_all_var_data())
return var_data.old_school_imports() if var_data else {}
return all(is_return_type_component)
match = Match.create
def _create_match_var(
match_cond_var: Var,
match_cases: tuple[CASE_TYPE[VAR_TYPE], ...],
default: VAR_TYPE | Var[VAR_TYPE],
) -> Var[VAR_TYPE]:
"""Create the match var.
Args:
match_cond_var: The match condition var.
match_cases: The match cases.
default: The default case.
Returns:
The match var.
"""
return MatchOperation.create(match_cond_var, match_cases, default)
def match(
cond: Any,
*cases: Unpack[
tuple[Unpack[tuple[CASE_TYPE[VAR_TYPE], ...]], Var[VAR_TYPE] | VAR_TYPE]
],
) -> Var[VAR_TYPE]:
"""Create a match var.
Args:
cond: The condition to match.
cases: The match cases. Each case should be a tuple with the first elements as the match case and the last element as the return value. The last argument should be the default case.
Returns:
The match var.
Raises:
ValueError: If the default case is not the last case or the tuple elements are less than 2.
"""
default = types.Unset()
if len([case for case in cases if not isinstance(case, tuple)]) > 1:
raise ValueError("rx.match can only have one default case.")
if not cases:
raise ValueError("rx.match should have at least one case.")
# Get the default case which should be the last non-tuple arg
if not isinstance(cases[-1], tuple):
default = cases[-1]
actual_cases = cases[:-1]
else:
actual_cases = cast(tuple[CASE_TYPE[VAR_TYPE], ...], cases)
_process_match_cases(actual_cases)
is_component_match = _validate_return_types(
*[case[-1] for case in actual_cases],
*([default] if not isinstance(default, types.Unset) else []),
)
if isinstance(default, types.Unset) and not is_component_match:
raise ValueError(
"For cases with return types as Vars, a default case must be provided"
)
if isinstance(default, types.Unset):
default = Fragment.create()
default = cast(Var[VAR_TYPE] | VAR_TYPE, default)
return _create_match_var(
cond,
actual_cases,
default,
)

View File

@ -29,7 +29,7 @@ from reflex.event import (
from reflex.utils import format
from reflex.utils.imports import ImportVar
from reflex.vars import VarData
from reflex.vars.base import CallableVar, Var, get_unique_variable_name
from reflex.vars.base import Var, get_unique_variable_name
from reflex.vars.sequence import LiteralStringVar
DEFAULT_UPLOAD_ID: str = "default"
@ -45,7 +45,6 @@ upload_files_context_var_data: VarData = VarData(
)
@CallableVar
def upload_file(id_: str = DEFAULT_UPLOAD_ID) -> Var:
"""Get the file upload drop trigger.
@ -75,7 +74,6 @@ def upload_file(id_: str = DEFAULT_UPLOAD_ID) -> Var:
)
@CallableVar
def selected_files(id_: str = DEFAULT_UPLOAD_ID) -> Var:
"""Get the list of selected files.

View File

@ -13,14 +13,12 @@ from reflex.event import CallableEventSpec, EventSpec, EventType
from reflex.style import Style
from reflex.utils.imports import ImportVar
from reflex.vars import VarData
from reflex.vars.base import CallableVar, Var
from reflex.vars.base import Var
DEFAULT_UPLOAD_ID: str
upload_files_context_var_data: VarData
@CallableVar
def upload_file(id_: str = DEFAULT_UPLOAD_ID) -> Var: ...
@CallableVar
def selected_files(id_: str = DEFAULT_UPLOAD_ID) -> Var: ...
@CallableEventSpec
def clear_selected_files(id_: str = DEFAULT_UPLOAD_ID) -> EventSpec: ...

View File

@ -303,7 +303,7 @@ class DataEditor(NoSSRComponent):
# Fired when editing is finished.
on_finished_editing: EventHandler[
passthrough_event_spec(Union[GridCell, None], tuple[int, int])
passthrough_event_spec(Union[GridCell, None], tuple[int, int]) # pyright: ignore[reportArgumentType]
]
# Fired when a row is appended.

View File

@ -828,7 +828,7 @@ class ShikiHighLevelCodeBlock(ShikiCodeBlock):
if isinstance(code, Var):
return string_replace_operation(
code, StringVar(_js_expr=f"/{regex_pattern}/g", _var_type=str), ""
)
).guess_type()
if isinstance(code, str):
return re.sub(regex_pattern, "", code)

View File

@ -114,8 +114,8 @@ class MarkdownComponentMap:
explicit_return = explicit_return or cls._explicit_return
return ArgsFunctionOperation.create(
args_names=(DestructuredArg(fields=tuple(fn_args)),),
return_expr=fn_body,
(DestructuredArg(fields=tuple(fn_args)),),
fn_body,
explicit_return=explicit_return,
_var_data=var_data,
)

View File

@ -188,7 +188,7 @@ class Slider(ComponentNamespace):
else:
children = [
track,
# Foreach.create(props.get("value"), lambda e: SliderThumb.create()), # foreach doesn't render Thumbs properly # noqa: ERA001
# foreach(props.get("value"), lambda e: SliderThumb.create()), # foreach doesn't render Thumbs properly # noqa: ERA001
]
return SliderRoot.create(*children, **props)

View File

@ -20,7 +20,7 @@ from __future__ import annotations
from typing import Any, Dict, List, Literal, Optional, Union, get_args
from reflex.components.component import BaseComponent
from reflex.components.core.cond import Cond, color_mode_cond, cond
from reflex.components.core.cond import color_mode_cond, cond
from reflex.components.lucide.icon import Icon
from reflex.components.radix.themes.components.dropdown_menu import dropdown_menu
from reflex.components.radix.themes.components.switch import Switch
@ -40,28 +40,23 @@ DEFAULT_LIGHT_ICON: Icon = Icon.create(tag="sun")
DEFAULT_DARK_ICON: Icon = Icon.create(tag="moon")
class ColorModeIcon(Cond):
"""Displays the current color mode as an icon."""
def color_mode_icon(
light_component: BaseComponent | None = None,
dark_component: BaseComponent | None = None,
):
"""Create a color mode icon component.
@classmethod
def create(
cls,
light_component: BaseComponent | None = None,
dark_component: BaseComponent | None = None,
):
"""Create an icon component based on color_mode.
Args:
light_component: The component to render in light mode.
dark_component: The component to render in dark mode.
Args:
light_component: the component to display when color mode is default
dark_component: the component to display when color mode is dark (non-default)
Returns:
The conditionally rendered component
"""
return color_mode_cond(
light=light_component or DEFAULT_LIGHT_ICON,
dark=dark_component or DEFAULT_DARK_ICON,
)
Returns:
The color mode icon component.
"""
return color_mode_cond(
light=light_component or DEFAULT_LIGHT_ICON,
dark=dark_component or DEFAULT_DARK_ICON,
)
LiteralPosition = Literal["top-left", "top-right", "bottom-left", "bottom-right"]
@ -144,7 +139,7 @@ class ColorModeIconButton(IconButton):
if allow_system:
def color_mode_item(_color_mode: str):
def color_mode_item(_color_mode: Literal["light", "dark", "system"]):
return dropdown_menu.item(
_color_mode.title(), on_click=set_color_mode(_color_mode)
)
@ -152,7 +147,7 @@ class ColorModeIconButton(IconButton):
return dropdown_menu.root(
dropdown_menu.trigger(
super().create(
ColorModeIcon.create(),
color_mode_icon(),
),
**props,
),
@ -163,7 +158,7 @@ class ColorModeIconButton(IconButton):
),
)
return IconButton.create(
ColorModeIcon.create(),
color_mode_icon(),
on_click=toggle_color_mode,
**props,
)
@ -197,7 +192,7 @@ class ColorModeSwitch(Switch):
class ColorModeNamespace(Var):
"""Namespace for color mode components."""
icon = staticmethod(ColorModeIcon.create)
icon = staticmethod(color_mode_icon)
button = staticmethod(ColorModeIconButton.create)
switch = staticmethod(ColorModeSwitch.create)

View File

@ -7,7 +7,6 @@ from typing import Any, Dict, List, Literal, Optional, Union, overload
from reflex.components.component import BaseComponent
from reflex.components.core.breakpoints import Breakpoints
from reflex.components.core.cond import Cond
from reflex.components.lucide.icon import Icon
from reflex.components.radix.themes.components.switch import Switch
from reflex.event import EventType
@ -19,48 +18,10 @@ from .components.icon_button import IconButton
DEFAULT_LIGHT_ICON: Icon
DEFAULT_DARK_ICON: Icon
class ColorModeIcon(Cond):
@overload
@classmethod
def create( # type: ignore
cls,
*children,
cond: Optional[Union[Any, Var[Any]]] = None,
comp1: Optional[BaseComponent] = None,
comp2: Optional[BaseComponent] = None,
style: Optional[Style] = None,
key: Optional[Any] = None,
id: Optional[Any] = None,
class_name: Optional[Any] = None,
autofocus: Optional[bool] = None,
custom_attrs: Optional[Dict[str, Union[Var, Any]]] = None,
on_blur: Optional[EventType[()]] = None,
on_click: Optional[EventType[()]] = None,
on_context_menu: Optional[EventType[()]] = None,
on_double_click: Optional[EventType[()]] = None,
on_focus: Optional[EventType[()]] = None,
on_mount: Optional[EventType[()]] = None,
on_mouse_down: Optional[EventType[()]] = None,
on_mouse_enter: Optional[EventType[()]] = None,
on_mouse_leave: Optional[EventType[()]] = None,
on_mouse_move: Optional[EventType[()]] = None,
on_mouse_out: Optional[EventType[()]] = None,
on_mouse_over: Optional[EventType[()]] = None,
on_mouse_up: Optional[EventType[()]] = None,
on_scroll: Optional[EventType[()]] = None,
on_unmount: Optional[EventType[()]] = None,
**props,
) -> "ColorModeIcon":
"""Create an icon component based on color_mode.
Args:
light_component: the component to display when color mode is default
dark_component: the component to display when color mode is dark (non-default)
Returns:
The conditionally rendered component
"""
...
def color_mode_icon(
light_component: BaseComponent | None = None,
dark_component: BaseComponent | None = None,
): ...
LiteralPosition = Literal["top-left", "top-right", "bottom-left", "bottom-right"]
position_values: List[str]
@ -440,7 +401,7 @@ class ColorModeSwitch(Switch):
...
class ColorModeNamespace(Var):
icon = staticmethod(ColorModeIcon.create)
icon = staticmethod(color_mode_icon)
button = staticmethod(ColorModeIconButton.create)
switch = staticmethod(ColorModeSwitch.create)

View File

@ -6,7 +6,7 @@ from typing import Literal
from reflex.components.component import Component
from reflex.components.core.breakpoints import Responsive
from reflex.components.core.match import Match
from reflex.components.core.match import match
from reflex.components.el import elements
from reflex.components.lucide import Icon
from reflex.style import Style
@ -77,7 +77,7 @@ class IconButton(elements.Button, RadixLoadingProp, RadixThemesComponent):
if isinstance(props["size"], str):
children[0].size = RADIX_TO_LUCIDE_SIZE[props["size"]]
else:
size_map_var = Match.create(
size_map_var = match(
props["size"],
*list(RADIX_TO_LUCIDE_SIZE.items()),
12,

View File

@ -5,7 +5,7 @@ from __future__ import annotations
from typing import Any, Iterable, Literal, Union
from reflex.components.component import Component, ComponentNamespace
from reflex.components.core.foreach import Foreach
from reflex.components.core.foreach import foreach
from reflex.components.el.elements.typography import Li, Ol, Ul
from reflex.components.lucide.icon import Icon
from reflex.components.markdown.markdown import MarkdownComponentMap
@ -70,7 +70,7 @@ class BaseList(Component, MarkdownComponentMap):
if not children and items is not None:
if isinstance(items, Var):
children = [Foreach.create(items, ListItem.create)]
children = [foreach(items, ListItem.create)]
else:
children = [ListItem.create(item) for item in items]
props["direction"] = "column"

View File

@ -1,6 +1,3 @@
"""Representations for React tags."""
from .cond_tag import CondTag
from .iter_tag import IterTag
from .match_tag import MatchTag
from .tag import Tag

View File

@ -1,21 +0,0 @@
"""Tag to conditionally render components."""
import dataclasses
from typing import Any, Dict, Optional
from reflex.components.tags.tag import Tag
from reflex.vars.base import Var
@dataclasses.dataclass()
class CondTag(Tag):
"""A conditional tag."""
# The condition to determine which component to render.
cond: Var[Any] = dataclasses.field(default_factory=lambda: Var.create(True))
# The code to render if the condition is true.
true_value: Dict = dataclasses.field(default_factory=dict)
# The code to render if the condition is false.
false_value: Optional[Dict] = None

View File

@ -1,145 +0,0 @@
"""Tag to loop through a list of components."""
from __future__ import annotations
import dataclasses
import inspect
from typing import TYPE_CHECKING, Any, Callable, Iterable, Tuple, Type, Union, get_args
from reflex.components.tags.tag import Tag
from reflex.vars import LiteralArrayVar, Var, get_unique_variable_name
if TYPE_CHECKING:
from reflex.components.component import Component
@dataclasses.dataclass()
class IterTag(Tag):
"""An iterator tag."""
# The var to iterate over.
iterable: Var[Iterable] = dataclasses.field(
default_factory=lambda: LiteralArrayVar.create([])
)
# The component render function for each item in the iterable.
render_fn: Callable = dataclasses.field(default_factory=lambda: lambda x: x)
# The name of the arg var.
arg_var_name: str = dataclasses.field(default_factory=get_unique_variable_name)
# The name of the index var.
index_var_name: str = dataclasses.field(default_factory=get_unique_variable_name)
def get_iterable_var_type(self) -> Type:
"""Get the type of the iterable var.
Returns:
The type of the iterable var.
"""
iterable = self.iterable
try:
if iterable._var_type.mro()[0] is dict:
# Arg is a tuple of (key, value).
return Tuple[get_args(iterable._var_type)] # pyright: ignore [reportReturnType]
elif iterable._var_type.mro()[0] is tuple:
# Arg is a union of any possible values in the tuple.
return Union[get_args(iterable._var_type)] # pyright: ignore [reportReturnType]
else:
return get_args(iterable._var_type)[0]
except Exception:
return Any # pyright: ignore [reportReturnType]
def get_index_var(self) -> Var:
"""Get the index var for the tag (with curly braces).
This is used to reference the index var within the tag.
Returns:
The index var.
"""
return Var(
_js_expr=self.index_var_name,
_var_type=int,
).guess_type()
def get_arg_var(self) -> Var:
"""Get the arg var for the tag (with curly braces).
This is used to reference the arg var within the tag.
Returns:
The arg var.
"""
return Var(
_js_expr=self.arg_var_name,
_var_type=self.get_iterable_var_type(),
).guess_type()
def get_index_var_arg(self) -> Var:
"""Get the index var for the tag (without curly braces).
This is used to render the index var in the .map() function.
Returns:
The index var.
"""
return Var(
_js_expr=self.index_var_name,
_var_type=int,
).guess_type()
def get_arg_var_arg(self) -> Var:
"""Get the arg var for the tag (without curly braces).
This is used to render the arg var in the .map() function.
Returns:
The arg var.
"""
return Var(
_js_expr=self.arg_var_name,
_var_type=self.get_iterable_var_type(),
).guess_type()
def render_component(self) -> Component:
"""Render the component.
Raises:
ValueError: If the render function takes more than 2 arguments.
Returns:
The rendered component.
"""
# Import here to avoid circular imports.
from reflex.components.base.fragment import Fragment
from reflex.components.core.cond import Cond
from reflex.components.core.foreach import Foreach
# Get the render function arguments.
args = inspect.getfullargspec(self.render_fn).args
arg = self.get_arg_var()
index = self.get_index_var()
if len(args) == 1:
# If the render function doesn't take the index as an argument.
component = self.render_fn(arg)
else:
# If the render function takes the index as an argument.
if len(args) != 2:
raise ValueError("The render function must take 2 arguments.")
component = self.render_fn(arg, index)
# Nested foreach components or cond must be wrapped in fragments.
if isinstance(component, (Foreach, Cond)):
component = Fragment.create(component)
# If the component is a tuple, unpack and wrap it in a fragment.
if isinstance(component, tuple):
component = Fragment.create(*component)
# Set the component key.
if component.key is None:
component.key = index
return component

View File

@ -1,21 +0,0 @@
"""Tag to conditionally match cases."""
import dataclasses
from typing import Any, List
from reflex.components.tags.tag import Tag
from reflex.vars.base import Var
@dataclasses.dataclass()
class MatchTag(Tag):
"""A match tag."""
# The condition to determine which case to match.
cond: Var[Any] = dataclasses.field(default_factory=lambda: Var.create(True))
# The list of match cases to be matched.
match_cases: List[Any] = dataclasses.field(default_factory=list)
# The catchall case to match.
default: Any = dataclasses.field(default=Var.create(None))

View File

@ -600,14 +600,16 @@ stop_propagation = EventChain(events=[], args_spec=no_args_event_spec).stop_prop
prevent_default = EventChain(events=[], args_spec=no_args_event_spec).prevent_default
T = TypeVar("T")
U = TypeVar("U")
EVENT_T = TypeVar("EVENT_T")
EVENT_U = TypeVar("EVENT_U")
Ts = TypeVarTuple("Ts")
class IdentityEventReturn(Generic[T], Protocol):
class IdentityEventReturn(Generic[Unpack[Ts]], Protocol):
"""Protocol for an identity event return."""
def __call__(self, *values: Var[T]) -> Tuple[Var[T], ...]:
def __call__(self, *values: Unpack[Ts]) -> tuple[Unpack[Ts]]:
"""Return the input values.
Args:
@ -620,22 +622,26 @@ class IdentityEventReturn(Generic[T], Protocol):
@overload
def passthrough_event_spec( # pyright: ignore [reportOverlappingOverload]
event_type: Type[T], /
) -> Callable[[Var[T]], Tuple[Var[T]]]: ...
def passthrough_event_spec(
event_type: Type[EVENT_T], /
) -> IdentityEventReturn[Var[EVENT_T]]: ...
@overload
def passthrough_event_spec(
event_type_1: Type[T], event_type2: Type[U], /
) -> Callable[[Var[T], Var[U]], Tuple[Var[T], Var[U]]]: ...
event_type_1: Type[EVENT_T], event_type2: Type[EVENT_U], /
) -> IdentityEventReturn[Var[EVENT_T], Var[EVENT_U]]: ...
@overload
def passthrough_event_spec(*event_types: Type[T]) -> IdentityEventReturn[T]: ...
def passthrough_event_spec(
*event_types: Unpack[tuple[Type[EVENT_T]]],
) -> IdentityEventReturn[Unpack[tuple[Var[EVENT_T], ...]]]: ...
def passthrough_event_spec(*event_types: Type[T]) -> IdentityEventReturn[T]: # pyright: ignore [reportInconsistentOverload]
def passthrough_event_spec( # pyright: ignore[reportInconsistentOverload]
*event_types: Type[EVENT_T],
) -> IdentityEventReturn[Unpack[tuple[Var[EVENT_T], ...]]]:
"""A helper function that returns the input event as output.
Args:
@ -645,7 +651,7 @@ def passthrough_event_spec(*event_types: Type[T]) -> IdentityEventReturn[T]: #
A function that returns the input event as output.
"""
def inner(*values: Var[T]) -> Tuple[Var[T], ...]:
def inner(*values: Var[EVENT_T]) -> Tuple[Var[EVENT_T], ...]:
return values
inner_type = tuple(Var[event_type] for event_type in event_types)
@ -780,7 +786,7 @@ def server_side(name: str, sig: inspect.Signature, **kwargs) -> EventSpec:
return None
fn.__qualname__ = name
fn.__signature__ = sig # pyright: ignore [reportFunctionMemberAccess]
fn.__signature__ = sig # pyright: ignore[reportFunctionMemberAccess]
return EventSpec(
handler=EventHandler(fn=fn, state_full_name=FRONTEND_EVENT_STATE),
args=tuple(

View File

@ -607,8 +607,8 @@ class BaseState(Base, ABC, extra=pydantic.Extra.allow):
if cls._item_is_event_handler(name, fn)
}
for mixin in cls._mixins(): # pyright: ignore [reportAssignmentType]
for name, value in mixin.__dict__.items():
for mixin_class in cls._mixins():
for name, value in mixin_class.__dict__.items():
if name in cls.inherited_vars:
continue
if is_computed_var(value):
@ -619,7 +619,7 @@ class BaseState(Base, ABC, extra=pydantic.Extra.allow):
cls.computed_vars[newcv._js_expr] = newcv
cls.vars[newcv._js_expr] = newcv
continue
if types.is_backend_base_variable(name, mixin): # pyright: ignore [reportArgumentType]
if types.is_backend_base_variable(name, mixin_class):
cls.backend_vars[name] = copy.deepcopy(value)
continue
if events.get(name) is not None:
@ -3653,6 +3653,9 @@ def get_state_manager() -> StateManager:
return prerequisites.get_and_validate_app().app.state_manager
DATACLASS_FIELDS = getattr(dataclasses, "_FIELDS", "__dataclass_fields__")
class MutableProxy(wrapt.ObjectProxy):
"""A proxy for a mutable object that tracks changes."""
@ -3724,12 +3727,7 @@ class MutableProxy(wrapt.ObjectProxy):
cls.__dataclass_proxies__[wrapper_cls_name] = type(
wrapper_cls_name,
(cls,),
{
dataclasses._FIELDS: getattr( # pyright: ignore [reportAttributeAccessIssue]
wrapped_cls,
dataclasses._FIELDS, # pyright: ignore [reportAttributeAccessIssue]
),
},
{DATACLASS_FIELDS: getattr(wrapped_cls, DATACLASS_FIELDS)},
)
cls = cls.__dataclass_proxies__[wrapper_cls_name]
return super().__new__(cls)
@ -3878,11 +3876,11 @@ class MutableProxy(wrapt.ObjectProxy):
if (
isinstance(self.__wrapped__, Base)
and __name not in self.__never_wrap_base_attrs__
and hasattr(value, "__func__")
and (value_func := getattr(value, "__func__", None))
):
# Wrap methods called on Base subclasses, which might do _anything_
return wrapt.FunctionWrapper(
functools.partial(value.__func__, self), # pyright: ignore [reportFunctionMemberAccess]
functools.partial(value_func, self),
self._wrap_recursive_decorator,
)

View File

@ -12,7 +12,7 @@ from reflex.utils.exceptions import ReflexError
from reflex.utils.imports import ImportVar
from reflex.utils.types import get_origin
from reflex.vars import VarData
from reflex.vars.base import CallableVar, LiteralVar, Var
from reflex.vars.base import LiteralVar, Var
from reflex.vars.function import FunctionVar
from reflex.vars.object import ObjectVar
@ -48,7 +48,6 @@ def _color_mode_var(_js_expr: str, _var_type: Type = str) -> Var:
).guess_type()
@CallableVar
def set_color_mode(
new_color_mode: LiteralColorMode | Var[LiteralColorMode] | None = None,
) -> Var[EventChain]:

View File

@ -68,10 +68,8 @@ try:
from selenium.webdriver.remote.webelement import ( # pyright: ignore [reportMissingImports]
WebElement,
)
has_selenium = True
except ImportError:
has_selenium = False
webdriver = None
# The timeout (minutes) to check for the port.
DEFAULT_TIMEOUT = 15
@ -296,11 +294,13 @@ class AppHarness:
if p not in before_decorated_pages
]
self.app_instance = self.app_module.app
if self.app_instance and isinstance(
self.app_instance._state_manager, StateManagerRedis
):
if self.app_instance is None:
raise RuntimeError("App was not initialized.")
if isinstance(self.app_instance._state_manager, StateManagerRedis):
# Create our own redis connection for testing.
self.state_manager = StateManagerRedis.create(self.app_instance._state) # pyright: ignore [reportArgumentType]
if self.app_instance._state is None:
raise RuntimeError("App state is not initialized.")
self.state_manager = StateManagerRedis.create(self.app_instance._state)
else:
self.state_manager = (
self.app_instance._state_manager if self.app_instance else None
@ -615,7 +615,7 @@ class AppHarness:
Raises:
RuntimeError: when selenium is not importable or frontend is not running
"""
if not has_selenium:
if webdriver is None:
raise RuntimeError(
"Frontend functionality requires `selenium` to be installed, "
"and it could not be imported."

View File

@ -201,10 +201,13 @@ def _get_first_non_framework_frame() -> FrameType | None:
# Exclude utility modules that should never be the source of deprecated reflex usage.
exclude_modules = [click, rx, typer, typing_extensions]
exclude_roots = [
p.parent.resolve()
if (p := Path(m.__file__)).name == "__init__.py" # pyright: ignore [reportArgumentType]
else p.resolve()
(
p.parent.resolve()
if (p := Path(m.__file__)).name == "__init__.py"
else p.resolve()
)
for m in exclude_modules
if m.__file__
]
# Specifically exclude the reflex cli module.
if reflex_bin := shutil.which(b"reflex"):

View File

@ -4,9 +4,8 @@ from __future__ import annotations
import inspect
import json
import os
import re
from typing import TYPE_CHECKING, Any, List, Optional, Union
from typing import TYPE_CHECKING, Any, Optional, Union
from reflex import constants
from reflex.constants.state import FRONTEND_EVENT_STATE
@ -136,22 +135,6 @@ def wrap(
return f"{open * num}{text}{close * num}"
def indent(text: str, indent_level: int = 2) -> str:
"""Indent the given text by the given indent level.
Args:
text: The text to indent.
indent_level: The indent level.
Returns:
The indented text.
"""
lines = text.splitlines()
if len(lines) < 2:
return text
return os.linesep.join(f"{' ' * indent_level}{line}" for line in lines) + os.linesep
def to_snake_case(text: str) -> str:
"""Convert a string to snake case.
@ -239,80 +222,6 @@ def make_default_page_title(app_name: str, route: str) -> str:
return to_title_case(title)
def _escape_js_string(string: str) -> str:
"""Escape the string for use as a JS string literal.
Args:
string: The string to escape.
Returns:
The escaped string.
"""
# TODO: we may need to re-vist this logic after new Var API is implemented.
def escape_outside_segments(segment: str):
"""Escape backticks in segments outside of `${}`.
Args:
segment: The part of the string to escape.
Returns:
The escaped or unescaped segment.
"""
if segment.startswith("${") and segment.endswith("}"):
# Return the `${}` segment unchanged
return segment
else:
# Escape backticks in the segment
segment = segment.replace(r"\`", "`")
segment = segment.replace("`", r"\`")
return segment
# Split the string into parts, keeping the `${}` segments
parts = re.split(r"(\$\{.*?\})", string)
escaped_parts = [escape_outside_segments(part) for part in parts]
escaped_string = "".join(escaped_parts)
return escaped_string
def _wrap_js_string(string: str) -> str:
"""Wrap string so it looks like {`string`}.
Args:
string: The string to wrap.
Returns:
The wrapped string.
"""
string = wrap(string, "`")
string = wrap(string, "{")
return string
def format_string(string: str) -> str:
"""Format the given string as a JS string literal..
Args:
string: The string to format.
Returns:
The formatted string.
"""
return _wrap_js_string(_escape_js_string(string))
def format_var(var: Var) -> str:
"""Format the given Var as a javascript value.
Args:
var: The Var to format.
Returns:
The formatted Var.
"""
return str(var)
def format_route(route: str, format_case: bool = True) -> str:
"""Format the given route.
@ -335,40 +244,6 @@ def format_route(route: str, format_case: bool = True) -> str:
return route
def format_match(
cond: str | Var,
match_cases: List[List[Var]],
default: Var,
) -> str:
"""Format a match expression whose return type is a Var.
Args:
cond: The condition.
match_cases: The list of cases to match.
default: The default case.
Returns:
The formatted match expression
"""
switch_code = f"(() => {{ switch (JSON.stringify({cond})) {{"
for case in match_cases:
conditions = case[:-1]
return_value = case[-1]
case_conditions = " ".join(
[f"case JSON.stringify({condition!s}):" for condition in conditions]
)
case_code = f"{case_conditions} return ({return_value!s}); break;"
switch_code += case_code
switch_code += f"default: return ({default!s}); break;"
switch_code += "};})()"
return switch_code
def format_prop(
prop: Union[Var, EventChain, ComponentStyle, str],
) -> Union[int, float, str]:

View File

@ -16,7 +16,7 @@ from itertools import chain
from multiprocessing import Pool, cpu_count
from pathlib import Path
from types import ModuleType, SimpleNamespace
from typing import Any, Callable, Iterable, Sequence, Type, get_args, get_origin
from typing import Any, Callable, Iterable, Sequence, Type, cast, get_args, get_origin
from reflex.components.component import Component
from reflex.utils import types as rx_types
@ -230,7 +230,9 @@ def _generate_imports(
"""
return [
*[
ast.ImportFrom(module=name, names=[ast.alias(name=val) for val in values]) # pyright: ignore [reportCallIssue]
ast.ImportFrom(
module=name, names=[ast.alias(name=val) for val in values], level=0
)
for name, values in DEFAULT_IMPORTS.items()
],
ast.Import([ast.alias("reflex")]),
@ -429,18 +431,15 @@ def type_to_ast(typ: Any, cls: type) -> ast.AST:
return ast.Name(id=base_name)
# Convert all type arguments recursively
arg_nodes = [type_to_ast(arg, cls) for arg in args]
arg_nodes = cast(list[ast.expr], [type_to_ast(arg, cls) for arg in args])
# Special case for single-argument types (like List[T] or Optional[T])
if len(arg_nodes) == 1:
slice_value = arg_nodes[0]
else:
slice_value = ast.Tuple(elts=arg_nodes, ctx=ast.Load()) # pyright: ignore [reportArgumentType]
slice_value = ast.Tuple(elts=arg_nodes, ctx=ast.Load())
return ast.Subscript(
value=ast.Name(id=base_name),
slice=ast.Index(value=slice_value), # pyright: ignore [reportArgumentType]
ctx=ast.Load(),
value=ast.Name(id=base_name), slice=slice_value, ctx=ast.Load()
)
@ -635,7 +634,7 @@ def _generate_component_create_functiondef(
),
),
ast.Expr(
value=ast.Constant(value=Ellipsis),
value=ast.Constant(...),
),
],
decorator_list=[
@ -646,8 +645,8 @@ def _generate_component_create_functiondef(
else [ast.Name(id="classmethod")]
),
],
lineno=node.lineno if node is not None else None, # pyright: ignore [reportArgumentType]
returns=ast.Constant(value=clz.__name__),
lineno=node.lineno if node is not None else None, # pyright: ignore[reportArgumentType]
)
return definition
@ -695,7 +694,6 @@ def _generate_staticmethod_call_functiondef(
),
],
decorator_list=[ast.Name(id="staticmethod")],
lineno=node.lineno if node is not None else None, # pyright: ignore [reportArgumentType]
returns=ast.Constant(
value=_get_type_hint(
typing.get_type_hints(clz.__call__).get("return", None),
@ -703,6 +701,7 @@ def _generate_staticmethod_call_functiondef(
is_optional=False,
)
),
lineno=node.lineno if node is not None else None, # pyright: ignore[reportArgumentType]
)
return definition
@ -723,6 +722,9 @@ def _generate_namespace_call_functiondef(
Returns:
The create functiondef node for the ast.
Raises:
TypeError: If the __call__ method does not have a __func__.
"""
# add the imports needed by get_type_hint later
type_hint_globals.update(
@ -737,7 +739,12 @@ def _generate_namespace_call_functiondef(
# Determine which class is wrapped by the namespace __call__ method
component_clz = clz.__call__.__self__
if clz.__call__.__func__.__name__ != "create": # pyright: ignore [reportFunctionMemberAccess]
func = getattr(clz.__call__, "__func__", None)
if func is None:
raise TypeError(f"__call__ method on {clz_name} does not have a __func__")
if func.__name__ != "create":
return None
definition = _generate_component_create_functiondef(
@ -920,7 +927,7 @@ class StubGenerator(ast.NodeTransformer):
node.body.append(call_definition)
if not node.body:
# We should never return an empty body.
node.body.append(ast.Expr(value=ast.Constant(value=Ellipsis)))
node.body.append(ast.Expr(value=ast.Constant(...)))
self.current_class = None
return node
@ -947,9 +954,9 @@ class StubGenerator(ast.NodeTransformer):
if node.name.startswith("_") and node.name != "__call__":
return None # remove private methods
if node.body[-1] != ast.Expr(value=ast.Constant(value=Ellipsis)):
if node.body[-1] != ast.Expr(value=ast.Constant(...)):
# Blank out the function body for public functions.
node.body = [ast.Expr(value=ast.Constant(value=Ellipsis))]
node.body = [ast.Expr(value=ast.Constant(...))]
return node
def visit_Assign(self, node: ast.Assign) -> ast.Assign | None:

View File

@ -7,6 +7,7 @@ import dataclasses
import inspect
import sys
import types
from collections import abc
from functools import cached_property, lru_cache, wraps
from typing import (
TYPE_CHECKING,
@ -23,6 +24,7 @@ from typing import (
Sequence,
Tuple,
Type,
TypeVar,
Union,
_GenericAlias, # pyright: ignore [reportAttributeAccessIssue]
get_args,
@ -31,6 +33,7 @@ from typing import (
from typing import get_origin as get_origin_og
import sqlalchemy
import typing_extensions
from typing_extensions import is_typeddict
import reflex
@ -68,13 +71,13 @@ else:
# Potential GenericAlias types for isinstance checks.
GenericAliasTypes = [_GenericAlias]
_GenericAliasTypes: list[type] = [_GenericAlias]
with contextlib.suppress(ImportError):
# For newer versions of Python.
from types import GenericAlias
GenericAliasTypes.append(GenericAlias)
_GenericAliasTypes.append(GenericAlias)
with contextlib.suppress(ImportError):
# For older versions of Python.
@ -82,9 +85,9 @@ with contextlib.suppress(ImportError):
_SpecialGenericAlias, # pyright: ignore [reportAttributeAccessIssue]
)
GenericAliasTypes.append(_SpecialGenericAlias)
_GenericAliasTypes.append(_SpecialGenericAlias)
GenericAliasTypes = tuple(GenericAliasTypes)
GenericAliasTypes = tuple(_GenericAliasTypes)
# Potential Union types for isinstance checks (UnionType added in py3.10).
UnionTypes = (Union, types.UnionType) if hasattr(types, "UnionType") else (Union,)
@ -183,7 +186,7 @@ def is_generic_alias(cls: GenericType) -> bool:
return isinstance(cls, GenericAliasTypes) # pyright: ignore [reportArgumentType]
def unionize(*args: GenericType) -> Type:
def unionize(*args: GenericType) -> GenericType:
"""Unionize the types.
Args:
@ -417,7 +420,7 @@ def get_attribute_access_type(cls: GenericType, name: str) -> GenericType | None
@lru_cache()
def get_base_class(cls: GenericType) -> Type:
def get_base_class(cls: GenericType) -> Type | tuple[Type, ...]:
"""Get the base class of a class.
Args:
@ -437,7 +440,14 @@ def get_base_class(cls: GenericType) -> Type:
return type(get_args(cls)[0])
if is_union(cls):
return tuple(get_base_class(arg) for arg in get_args(cls)) # pyright: ignore [reportReturnType]
base_classes = []
for arg in get_args(cls):
sub_base_classes = get_base_class(arg)
if isinstance(sub_base_classes, tuple):
base_classes.extend(sub_base_classes)
else:
base_classes.append(sub_base_classes)
return tuple(base_classes)
return get_base_class(cls.__origin__) if is_generic_alias(cls) else cls
@ -847,18 +857,22 @@ StateBases = get_base_class(StateVar)
StateIterBases = get_base_class(StateIterVar)
def safe_issubclass(cls: Type, cls_check: Type | Tuple[Type, ...]):
"""Check if a class is a subclass of another class. Returns False if internal error occurs.
def safe_issubclass(cls: Any, class_or_tuple: Any, /) -> bool:
"""Check if a class is a subclass of another class or a tuple of classes.
Args:
cls: The class to check.
cls_check: The class to check against.
class_or_tuple: The class or tuple of classes to check against.
Returns:
Whether the class is a subclass of the other class.
Whether the class is a subclass of the other class or tuple of classes.
"""
if cls is class_or_tuple or (
isinstance(class_or_tuple, tuple) and cls in class_or_tuple
):
return True
try:
return issubclass(cls, cls_check)
return issubclass(cls, class_or_tuple)
except TypeError:
return False
@ -873,17 +887,32 @@ def typehint_issubclass(possible_subclass: Any, possible_superclass: Any) -> boo
Returns:
Whether the type hint is a subclass of the other type hint.
"""
if isinstance(possible_subclass, Sequence) and isinstance(
possible_superclass, Sequence
):
return all(
typehint_issubclass(subclass, superclass)
for subclass, superclass in zip(
possible_subclass, possible_superclass, strict=False
)
)
if possible_subclass is possible_superclass:
return True
if possible_superclass is Any:
return True
if possible_subclass is Any:
return False
if isinstance(
possible_subclass, (TypeVar, typing_extensions.TypeVar)
) or isinstance(possible_superclass, (TypeVar, typing_extensions.TypeVar)):
return True
provided_type_origin = get_origin(possible_subclass)
accepted_type_origin = get_origin(possible_superclass)
if provided_type_origin is None and accepted_type_origin is None:
# In this case, we are dealing with a non-generic type, so we can use issubclass
return issubclass(possible_subclass, possible_superclass)
return safe_issubclass(possible_subclass, possible_superclass)
# Remove this check when Python 3.10 is the minimum supported version
if hasattr(types, "UnionType"):
@ -898,24 +927,64 @@ def typehint_issubclass(possible_subclass: Any, possible_superclass: Any) -> boo
provided_args = get_args(possible_subclass)
accepted_args = get_args(possible_superclass)
if accepted_type_origin is Union:
if provided_type_origin is not Union:
return any(
typehint_issubclass(possible_subclass, accepted_arg)
for accepted_arg in accepted_args
)
if provided_type_origin is Union:
return all(
any(
typehint_issubclass(provided_arg, accepted_arg)
for accepted_arg in accepted_args
)
typehint_issubclass(provided_arg, possible_superclass)
for provided_arg in provided_args
)
if accepted_type_origin is Union:
return any(
typehint_issubclass(possible_subclass, accepted_arg)
for accepted_arg in accepted_args
)
# Check specifically for Sequence and Iterable
if (accepted_type_origin or possible_superclass) in (
Sequence,
abc.Sequence,
Iterable,
abc.Iterable,
):
iterable_type = accepted_args[0] if accepted_args else Any
if provided_type_origin is None:
if not safe_issubclass(
possible_subclass, (accepted_type_origin or possible_superclass)
):
return False
if safe_issubclass(possible_subclass, str) and not isinstance(
iterable_type, TypeVar
):
return typehint_issubclass(str, iterable_type)
return True
if not safe_issubclass(
provided_type_origin, (accepted_type_origin or possible_superclass)
):
return False
if not isinstance(iterable_type, (TypeVar, typing_extensions.TypeVar)):
if provided_type_origin in (list, tuple, set):
# Ensure all specific types are compatible with accepted types
return all(
typehint_issubclass(provided_arg, iterable_type)
for provided_arg in provided_args
if provided_arg is not ... # Ellipsis in Tuples
)
if possible_subclass is dict:
# Ensure all specific types are compatible with accepted types
return all(
typehint_issubclass(provided_arg, iterable_type)
for provided_arg in provided_args[:1]
)
return True
# Check if the origin of both types is the same (e.g., list for List[int])
# This probably should be issubclass instead of ==
if (provided_type_origin or possible_subclass) != (
accepted_type_origin or possible_superclass
if not safe_issubclass(
provided_type_origin or possible_subclass,
accepted_type_origin or possible_superclass,
):
return False
@ -927,5 +996,21 @@ def typehint_issubclass(possible_subclass: Any, possible_superclass: Any) -> boo
for provided_arg, accepted_arg in zip(
provided_args, accepted_args, strict=False
)
if accepted_arg is not Any
if accepted_arg is not Any and not isinstance(accepted_arg, TypeVar)
)
def safe_typehint_issubclass(possible_subclass: Any, possible_superclass: Any) -> bool:
"""Check if a type hint is a subclass of another type hint.
Args:
possible_subclass: The type hint to check.
possible_superclass: The type hint to check against.
Returns:
Whether the type hint is a subclass of the other type hint.
"""
try:
return typehint_issubclass(possible_subclass, possible_superclass)
except Exception:
return False

File diff suppressed because it is too large Load Diff

View File

@ -4,10 +4,7 @@ from __future__ import annotations
import dataclasses
from datetime import date, datetime
from typing import Any, NoReturn, TypeVar, Union, overload
from reflex.utils.exceptions import VarTypeError
from reflex.vars.number import BooleanVar
from typing import TypeVar, Union
from .base import (
CustomVarOperationReturn,
@ -23,156 +20,11 @@ DATETIME_T = TypeVar("DATETIME_T", datetime, date)
datetime_types = Union[datetime, date]
def raise_var_type_error():
"""Raise a VarTypeError.
Raises:
VarTypeError: Cannot compare a datetime object with a non-datetime object.
"""
raise VarTypeError("Cannot compare a datetime object with a non-datetime object.")
class DateTimeVar(Var[DATETIME_T], python_types=(datetime, date)):
"""A variable that holds a datetime or date object."""
@overload
def __lt__(self, other: datetime_types) -> BooleanVar: ...
@overload
def __lt__(self, other: NoReturn) -> NoReturn: ... # pyright: ignore [reportOverlappingOverload]
def __lt__(self, other: Any):
"""Less than comparison.
Args:
other: The other datetime to compare.
Returns:
The result of the comparison.
"""
if not isinstance(other, DATETIME_TYPES):
raise_var_type_error()
return date_lt_operation(self, other)
@overload
def __le__(self, other: datetime_types) -> BooleanVar: ...
@overload
def __le__(self, other: NoReturn) -> NoReturn: ... # pyright: ignore [reportOverlappingOverload]
def __le__(self, other: Any):
"""Less than or equal comparison.
Args:
other: The other datetime to compare.
Returns:
The result of the comparison.
"""
if not isinstance(other, DATETIME_TYPES):
raise_var_type_error()
return date_le_operation(self, other)
@overload
def __gt__(self, other: datetime_types) -> BooleanVar: ...
@overload
def __gt__(self, other: NoReturn) -> NoReturn: ... # pyright: ignore [reportOverlappingOverload]
def __gt__(self, other: Any):
"""Greater than comparison.
Args:
other: The other datetime to compare.
Returns:
The result of the comparison.
"""
if not isinstance(other, DATETIME_TYPES):
raise_var_type_error()
return date_gt_operation(self, other)
@overload
def __ge__(self, other: datetime_types) -> BooleanVar: ...
@overload
def __ge__(self, other: NoReturn) -> NoReturn: ... # pyright: ignore [reportOverlappingOverload]
def __ge__(self, other: Any):
"""Greater than or equal comparison.
Args:
other: The other datetime to compare.
Returns:
The result of the comparison.
"""
if not isinstance(other, DATETIME_TYPES):
raise_var_type_error()
return date_ge_operation(self, other)
@var_operation
def date_gt_operation(lhs: Var | Any, rhs: Var | Any) -> CustomVarOperationReturn:
"""Greater than comparison.
Args:
lhs: The left-hand side of the operation.
rhs: The right-hand side of the operation.
Returns:
The result of the operation.
"""
return date_compare_operation(rhs, lhs, strict=True)
@var_operation
def date_lt_operation(lhs: Var | Any, rhs: Var | Any) -> CustomVarOperationReturn:
"""Less than comparison.
Args:
lhs: The left-hand side of the operation.
rhs: The right-hand side of the operation.
Returns:
The result of the operation.
"""
return date_compare_operation(lhs, rhs, strict=True)
@var_operation
def date_le_operation(lhs: Var | Any, rhs: Var | Any) -> CustomVarOperationReturn:
"""Less than or equal comparison.
Args:
lhs: The left-hand side of the operation.
rhs: The right-hand side of the operation.
Returns:
The result of the operation.
"""
return date_compare_operation(lhs, rhs)
@var_operation
def date_ge_operation(lhs: Var | Any, rhs: Var | Any) -> CustomVarOperationReturn:
"""Greater than or equal comparison.
Args:
lhs: The left-hand side of the operation.
rhs: The right-hand side of the operation.
Returns:
The result of the operation.
"""
return date_compare_operation(rhs, lhs)
def date_compare_operation(
lhs: DateTimeVar[DATETIME_T] | Any,
rhs: DateTimeVar[DATETIME_T] | Any,
lhs: Var[datetime_types],
rhs: Var[datetime_types],
strict: bool = False,
) -> CustomVarOperationReturn:
) -> CustomVarOperationReturn[bool]:
"""Check if the value is less than the other value.
Args:
@ -189,6 +41,84 @@ def date_compare_operation(
)
@var_operation
def date_gt_operation(
lhs: Var[datetime_types],
rhs: Var[datetime_types],
) -> CustomVarOperationReturn:
"""Greater than comparison.
Args:
lhs: The left-hand side of the operation.
rhs: The right-hand side of the operation.
Returns:
The result of the operation.
"""
return date_compare_operation(rhs, lhs, strict=True)
@var_operation
def date_lt_operation(
lhs: Var[datetime_types],
rhs: Var[datetime_types],
) -> CustomVarOperationReturn:
"""Less than comparison.
Args:
lhs: The left-hand side of the operation.
rhs: The right-hand side of the operation.
Returns:
The result of the operation.
"""
return date_compare_operation(lhs, rhs, strict=True)
@var_operation
def date_le_operation(
lhs: Var[datetime_types], rhs: Var[datetime_types]
) -> CustomVarOperationReturn:
"""Less than or equal comparison.
Args:
lhs: The left-hand side of the operation.
rhs: The right-hand side of the operation.
Returns:
The result of the operation.
"""
return date_compare_operation(lhs, rhs)
@var_operation
def date_ge_operation(
lhs: Var[datetime_types], rhs: Var[datetime_types]
) -> CustomVarOperationReturn:
"""Greater than or equal comparison.
Args:
lhs: The left-hand side of the operation.
rhs: The right-hand side of the operation.
Returns:
The result of the operation.
"""
return date_compare_operation(rhs, lhs)
class DateTimeVar(Var[DATETIME_T], python_types=(datetime, date)):
"""A variable that holds a datetime or date object."""
__lt__ = date_lt_operation
__le__ = date_le_operation
__gt__ = date_gt_operation
__ge__ = date_ge_operation
@dataclasses.dataclass(
eq=False,
frozen=True,

File diff suppressed because it is too large Load Diff

View File

@ -3,6 +3,7 @@
from __future__ import annotations
import dataclasses
import functools
import json
import math
from typing import (
@ -10,21 +11,30 @@ from typing import (
Any,
Callable,
NoReturn,
Type,
Sequence,
TypeVar,
Union,
cast,
overload,
)
from typing_extensions import Unpack
from reflex.constants.base import Dirs
from reflex.utils.exceptions import PrimitiveUnserializableToJSONError, VarTypeError
from reflex.utils.imports import ImportDict, ImportVar
from .base import (
VAR_TYPE,
CachedVarOperation,
CustomVarOperationReturn,
LiteralVar,
ReflexCallable,
Var,
VarData,
cached_property_no_lock,
nary_type_computer,
passthrough_unary_type_computer,
unionize,
var_operation,
var_operation_return,
@ -33,6 +43,7 @@ from .base import (
NUMBER_T = TypeVar("NUMBER_T", int, float, bool)
if TYPE_CHECKING:
from .function import FunctionVar
from .sequence import ArrayVar
@ -56,13 +67,7 @@ def raise_unsupported_operand_types(
class NumberVar(Var[NUMBER_T], python_types=(int, float)):
"""Base class for immutable number vars."""
@overload
def __add__(self, other: number_types) -> NumberVar: ...
@overload
def __add__(self, other: NoReturn) -> NoReturn: ... # pyright: ignore [reportOverlappingOverload]
def __add__(self, other: Any):
def __add__(self, other: number_types) -> NumberVar:
"""Add two numbers.
Args:
@ -73,15 +78,9 @@ class NumberVar(Var[NUMBER_T], python_types=(int, float)):
"""
if not isinstance(other, NUMBER_TYPES):
raise_unsupported_operand_types("+", (type(self), type(other)))
return number_add_operation(self, +other)
return number_add_operation(self, +other).guess_type()
@overload
def __radd__(self, other: number_types) -> NumberVar: ...
@overload
def __radd__(self, other: NoReturn) -> NoReturn: ... # pyright: ignore [reportOverlappingOverload]
def __radd__(self, other: Any):
def __radd__(self, other: number_types) -> NumberVar:
"""Add two numbers.
Args:
@ -92,15 +91,9 @@ class NumberVar(Var[NUMBER_T], python_types=(int, float)):
"""
if not isinstance(other, NUMBER_TYPES):
raise_unsupported_operand_types("+", (type(other), type(self)))
return number_add_operation(+other, self)
return number_add_operation(+other, self).guess_type()
@overload
def __sub__(self, other: number_types) -> NumberVar: ...
@overload
def __sub__(self, other: NoReturn) -> NoReturn: ... # pyright: ignore [reportOverlappingOverload]
def __sub__(self, other: Any):
def __sub__(self, other: number_types) -> NumberVar:
"""Subtract two numbers.
Args:
@ -112,15 +105,9 @@ class NumberVar(Var[NUMBER_T], python_types=(int, float)):
if not isinstance(other, NUMBER_TYPES):
raise_unsupported_operand_types("-", (type(self), type(other)))
return number_subtract_operation(self, +other)
return number_subtract_operation(self, +other).guess_type()
@overload
def __rsub__(self, other: number_types) -> NumberVar: ...
@overload
def __rsub__(self, other: NoReturn) -> NoReturn: ... # pyright: ignore [reportOverlappingOverload]
def __rsub__(self, other: Any):
def __rsub__(self, other: number_types) -> NumberVar:
"""Subtract two numbers.
Args:
@ -132,7 +119,7 @@ class NumberVar(Var[NUMBER_T], python_types=(int, float)):
if not isinstance(other, NUMBER_TYPES):
raise_unsupported_operand_types("-", (type(other), type(self)))
return number_subtract_operation(+other, self)
return number_subtract_operation(+other, self).guess_type()
def __abs__(self):
"""Get the absolute value of the number.
@ -167,7 +154,7 @@ class NumberVar(Var[NUMBER_T], python_types=(int, float)):
if not isinstance(other, NUMBER_TYPES):
raise_unsupported_operand_types("*", (type(self), type(other)))
return number_multiply_operation(self, +other)
return number_multiply_operation(self, +other).guess_type()
@overload
def __rmul__(self, other: number_types | boolean_types) -> NumberVar: ...
@ -194,15 +181,9 @@ class NumberVar(Var[NUMBER_T], python_types=(int, float)):
if not isinstance(other, NUMBER_TYPES):
raise_unsupported_operand_types("*", (type(other), type(self)))
return number_multiply_operation(+other, self)
return number_multiply_operation(+other, self).guess_type()
@overload
def __truediv__(self, other: number_types) -> NumberVar: ...
@overload
def __truediv__(self, other: NoReturn) -> NoReturn: ... # pyright: ignore [reportOverlappingOverload]
def __truediv__(self, other: Any):
def __truediv__(self, other: number_types) -> NumberVar:
"""Divide two numbers.
Args:
@ -214,15 +195,9 @@ class NumberVar(Var[NUMBER_T], python_types=(int, float)):
if not isinstance(other, NUMBER_TYPES):
raise_unsupported_operand_types("/", (type(self), type(other)))
return number_true_division_operation(self, +other)
return number_true_division_operation(self, +other).guess_type()
@overload
def __rtruediv__(self, other: number_types) -> NumberVar: ...
@overload
def __rtruediv__(self, other: NoReturn) -> NoReturn: ... # pyright: ignore [reportOverlappingOverload]
def __rtruediv__(self, other: Any):
def __rtruediv__(self, other: number_types) -> NumberVar:
"""Divide two numbers.
Args:
@ -234,15 +209,9 @@ class NumberVar(Var[NUMBER_T], python_types=(int, float)):
if not isinstance(other, NUMBER_TYPES):
raise_unsupported_operand_types("/", (type(other), type(self)))
return number_true_division_operation(+other, self)
return number_true_division_operation(+other, self).guess_type()
@overload
def __floordiv__(self, other: number_types) -> NumberVar: ...
@overload
def __floordiv__(self, other: NoReturn) -> NoReturn: ... # pyright: ignore [reportOverlappingOverload]
def __floordiv__(self, other: Any):
def __floordiv__(self, other: number_types) -> NumberVar:
"""Floor divide two numbers.
Args:
@ -254,15 +223,9 @@ class NumberVar(Var[NUMBER_T], python_types=(int, float)):
if not isinstance(other, NUMBER_TYPES):
raise_unsupported_operand_types("//", (type(self), type(other)))
return number_floor_division_operation(self, +other)
return number_floor_division_operation(self, +other).guess_type()
@overload
def __rfloordiv__(self, other: number_types) -> NumberVar: ...
@overload
def __rfloordiv__(self, other: NoReturn) -> NoReturn: ... # pyright: ignore [reportOverlappingOverload]
def __rfloordiv__(self, other: Any):
def __rfloordiv__(self, other: number_types) -> NumberVar:
"""Floor divide two numbers.
Args:
@ -274,15 +237,9 @@ class NumberVar(Var[NUMBER_T], python_types=(int, float)):
if not isinstance(other, NUMBER_TYPES):
raise_unsupported_operand_types("//", (type(other), type(self)))
return number_floor_division_operation(+other, self)
return number_floor_division_operation(+other, self).guess_type()
@overload
def __mod__(self, other: number_types) -> NumberVar: ...
@overload
def __mod__(self, other: NoReturn) -> NoReturn: ... # pyright: ignore [reportOverlappingOverload]
def __mod__(self, other: Any):
def __mod__(self, other: number_types) -> NumberVar:
"""Modulo two numbers.
Args:
@ -294,15 +251,9 @@ class NumberVar(Var[NUMBER_T], python_types=(int, float)):
if not isinstance(other, NUMBER_TYPES):
raise_unsupported_operand_types("%", (type(self), type(other)))
return number_modulo_operation(self, +other)
return number_modulo_operation(self, +other).guess_type()
@overload
def __rmod__(self, other: number_types) -> NumberVar: ...
@overload
def __rmod__(self, other: NoReturn) -> NoReturn: ... # pyright: ignore [reportOverlappingOverload]
def __rmod__(self, other: Any):
def __rmod__(self, other: number_types) -> NumberVar:
"""Modulo two numbers.
Args:
@ -314,15 +265,9 @@ class NumberVar(Var[NUMBER_T], python_types=(int, float)):
if not isinstance(other, NUMBER_TYPES):
raise_unsupported_operand_types("%", (type(other), type(self)))
return number_modulo_operation(+other, self)
return number_modulo_operation(+other, self).guess_type()
@overload
def __pow__(self, other: number_types) -> NumberVar: ...
@overload
def __pow__(self, other: NoReturn) -> NoReturn: ... # pyright: ignore [reportOverlappingOverload]
def __pow__(self, other: Any):
def __pow__(self, other: number_types) -> NumberVar:
"""Exponentiate two numbers.
Args:
@ -334,15 +279,9 @@ class NumberVar(Var[NUMBER_T], python_types=(int, float)):
if not isinstance(other, NUMBER_TYPES):
raise_unsupported_operand_types("**", (type(self), type(other)))
return number_exponent_operation(self, +other)
return number_exponent_operation(self, +other).guess_type()
@overload
def __rpow__(self, other: number_types) -> NumberVar: ...
@overload
def __rpow__(self, other: NoReturn) -> NoReturn: ... # pyright: ignore [reportOverlappingOverload]
def __rpow__(self, other: Any):
def __rpow__(self, other: number_types) -> NumberVar:
"""Exponentiate two numbers.
Args:
@ -354,7 +293,7 @@ class NumberVar(Var[NUMBER_T], python_types=(int, float)):
if not isinstance(other, NUMBER_TYPES):
raise_unsupported_operand_types("**", (type(other), type(self)))
return number_exponent_operation(+other, self)
return number_exponent_operation(+other, self).guess_type()
def __neg__(self):
"""Negate the number.
@ -362,7 +301,7 @@ class NumberVar(Var[NUMBER_T], python_types=(int, float)):
Returns:
The number negation operation.
"""
return number_negate_operation(self)
return number_negate_operation(self).guess_type()
def __invert__(self):
"""Boolean NOT the number.
@ -370,7 +309,7 @@ class NumberVar(Var[NUMBER_T], python_types=(int, float)):
Returns:
The boolean NOT operation.
"""
return boolean_not_operation(self.bool())
return boolean_not_operation(self.bool()).guess_type()
def __pos__(self) -> NumberVar:
"""Positive the number.
@ -386,7 +325,7 @@ class NumberVar(Var[NUMBER_T], python_types=(int, float)):
Returns:
The number round operation.
"""
return number_round_operation(self)
return number_round_operation(self).guess_type()
def __ceil__(self):
"""Ceil the number.
@ -394,7 +333,7 @@ class NumberVar(Var[NUMBER_T], python_types=(int, float)):
Returns:
The number ceil operation.
"""
return number_ceil_operation(self)
return number_ceil_operation(self).guess_type()
def __floor__(self):
"""Floor the number.
@ -402,7 +341,7 @@ class NumberVar(Var[NUMBER_T], python_types=(int, float)):
Returns:
The number floor operation.
"""
return number_floor_operation(self)
return number_floor_operation(self).guess_type()
def __trunc__(self):
"""Trunc the number.
@ -410,15 +349,9 @@ class NumberVar(Var[NUMBER_T], python_types=(int, float)):
Returns:
The number trunc operation.
"""
return number_trunc_operation(self)
return number_trunc_operation(self).guess_type()
@overload
def __lt__(self, other: number_types) -> BooleanVar: ...
@overload
def __lt__(self, other: NoReturn) -> NoReturn: ... # pyright: ignore [reportOverlappingOverload]
def __lt__(self, other: Any):
def __lt__(self, other: number_types) -> BooleanVar:
"""Less than comparison.
Args:
@ -429,15 +362,9 @@ class NumberVar(Var[NUMBER_T], python_types=(int, float)):
"""
if not isinstance(other, NUMBER_TYPES):
raise_unsupported_operand_types("<", (type(self), type(other)))
return less_than_operation(+self, +other)
return less_than_operation(+self, +other).guess_type()
@overload
def __le__(self, other: number_types) -> BooleanVar: ...
@overload
def __le__(self, other: NoReturn) -> NoReturn: ... # pyright: ignore [reportOverlappingOverload]
def __le__(self, other: Any):
def __le__(self, other: number_types) -> BooleanVar:
"""Less than or equal comparison.
Args:
@ -448,9 +375,9 @@ class NumberVar(Var[NUMBER_T], python_types=(int, float)):
"""
if not isinstance(other, NUMBER_TYPES):
raise_unsupported_operand_types("<=", (type(self), type(other)))
return less_than_or_equal_operation(+self, +other)
return less_than_or_equal_operation(+self, +other).guess_type()
def __eq__(self, other: Any):
def __eq__(self, other: Any) -> BooleanVar:
"""Equal comparison.
Args:
@ -460,10 +387,10 @@ class NumberVar(Var[NUMBER_T], python_types=(int, float)):
The result of the comparison.
"""
if isinstance(other, NUMBER_TYPES):
return equal_operation(+self, +other)
return equal_operation(self, other)
return equal_operation(+self, +other).guess_type()
return equal_operation(self, other).guess_type()
def __ne__(self, other: Any):
def __ne__(self, other: Any) -> BooleanVar:
"""Not equal comparison.
Args:
@ -473,16 +400,10 @@ class NumberVar(Var[NUMBER_T], python_types=(int, float)):
The result of the comparison.
"""
if isinstance(other, NUMBER_TYPES):
return not_equal_operation(+self, +other)
return not_equal_operation(self, other)
return not_equal_operation(+self, +other).guess_type()
return not_equal_operation(self, other).guess_type()
@overload
def __gt__(self, other: number_types) -> BooleanVar: ...
@overload
def __gt__(self, other: NoReturn) -> NoReturn: ... # pyright: ignore [reportOverlappingOverload]
def __gt__(self, other: Any):
def __gt__(self, other: number_types) -> BooleanVar:
"""Greater than comparison.
Args:
@ -493,15 +414,9 @@ class NumberVar(Var[NUMBER_T], python_types=(int, float)):
"""
if not isinstance(other, NUMBER_TYPES):
raise_unsupported_operand_types(">", (type(self), type(other)))
return greater_than_operation(+self, +other)
return greater_than_operation(+self, +other).guess_type()
@overload
def __ge__(self, other: number_types) -> BooleanVar: ...
@overload
def __ge__(self, other: NoReturn) -> NoReturn: ... # pyright: ignore [reportOverlappingOverload]
def __ge__(self, other: Any):
def __ge__(self, other: number_types) -> BooleanVar:
"""Greater than or equal comparison.
Args:
@ -512,7 +427,7 @@ class NumberVar(Var[NUMBER_T], python_types=(int, float)):
"""
if not isinstance(other, NUMBER_TYPES):
raise_unsupported_operand_types(">=", (type(self), type(other)))
return greater_than_or_equal_operation(+self, +other)
return greater_than_or_equal_operation(+self, +other).guess_type()
def _is_strict_float(self) -> bool:
"""Check if the number is a float.
@ -532,8 +447,8 @@ class NumberVar(Var[NUMBER_T], python_types=(int, float)):
def binary_number_operation(
func: Callable[[NumberVar, NumberVar], str],
) -> Callable[[number_types, number_types], NumberVar]:
func: Callable[[Var[int | float], Var[int | float]], str],
):
"""Decorator to create a binary number operation.
Args:
@ -543,30 +458,37 @@ def binary_number_operation(
The binary number operation.
"""
@var_operation
def operation(lhs: NumberVar, rhs: NumberVar):
def operation(
lhs: Var[int | float], rhs: Var[int | float]
) -> CustomVarOperationReturn[int | float]:
def type_computer(*args: Var):
if not args:
return (
ReflexCallable[[int | float, int | float], int | float],
type_computer,
)
if len(args) == 1:
return (
ReflexCallable[[int | float], int | float],
functools.partial(type_computer, args[0]),
)
return (
ReflexCallable[[], unionize(args[0]._var_type, args[1]._var_type)],
None,
)
return var_operation_return(
js_expression=func(lhs, rhs),
var_type=unionize(lhs._var_type, rhs._var_type),
type_computer=type_computer,
)
def wrapper(lhs: number_types, rhs: number_types) -> NumberVar:
"""Create the binary number operation.
operation.__name__ = func.__name__
Args:
lhs: The first number.
rhs: The second number.
Returns:
The binary number operation.
"""
return operation(lhs, rhs) # pyright: ignore [reportReturnType, reportArgumentType]
return wrapper
return var_operation(operation)
@binary_number_operation
def number_add_operation(lhs: NumberVar, rhs: NumberVar):
def number_add_operation(lhs: Var[int | float], rhs: Var[int | float]):
"""Add two numbers.
Args:
@ -580,7 +502,7 @@ def number_add_operation(lhs: NumberVar, rhs: NumberVar):
@binary_number_operation
def number_subtract_operation(lhs: NumberVar, rhs: NumberVar):
def number_subtract_operation(lhs: Var[int | float], rhs: Var[int | float]):
"""Subtract two numbers.
Args:
@ -593,8 +515,15 @@ def number_subtract_operation(lhs: NumberVar, rhs: NumberVar):
return f"({lhs} - {rhs})"
unary_operation_type_computer = passthrough_unary_type_computer(
ReflexCallable[[int | float], int | float]
)
@var_operation
def number_abs_operation(value: NumberVar):
def number_abs_operation(
value: Var[int | float],
) -> CustomVarOperationReturn[int | float]:
"""Get the absolute value of the number.
Args:
@ -604,12 +533,14 @@ def number_abs_operation(value: NumberVar):
The number absolute operation.
"""
return var_operation_return(
js_expression=f"Math.abs({value})", var_type=value._var_type
js_expression=f"Math.abs({value})",
type_computer=unary_operation_type_computer,
_raw_js_function="Math.abs",
)
@binary_number_operation
def number_multiply_operation(lhs: NumberVar, rhs: NumberVar):
def number_multiply_operation(lhs: Var[int | float], rhs: Var[int | float]):
"""Multiply two numbers.
Args:
@ -624,7 +555,7 @@ def number_multiply_operation(lhs: NumberVar, rhs: NumberVar):
@var_operation
def number_negate_operation(
value: NumberVar[NUMBER_T],
value: Var[NUMBER_T],
) -> CustomVarOperationReturn[NUMBER_T]:
"""Negate the number.
@ -634,11 +565,13 @@ def number_negate_operation(
Returns:
The number negation operation.
"""
return var_operation_return(js_expression=f"-({value})", var_type=value._var_type)
return var_operation_return(
js_expression=f"-({value})", type_computer=unary_operation_type_computer
)
@binary_number_operation
def number_true_division_operation(lhs: NumberVar, rhs: NumberVar):
def number_true_division_operation(lhs: Var[int | float], rhs: Var[int | float]):
"""Divide two numbers.
Args:
@ -652,7 +585,7 @@ def number_true_division_operation(lhs: NumberVar, rhs: NumberVar):
@binary_number_operation
def number_floor_division_operation(lhs: NumberVar, rhs: NumberVar):
def number_floor_division_operation(lhs: Var[int | float], rhs: Var[int | float]):
"""Floor divide two numbers.
Args:
@ -666,7 +599,7 @@ def number_floor_division_operation(lhs: NumberVar, rhs: NumberVar):
@binary_number_operation
def number_modulo_operation(lhs: NumberVar, rhs: NumberVar):
def number_modulo_operation(lhs: Var[int | float], rhs: Var[int | float]):
"""Modulo two numbers.
Args:
@ -680,7 +613,7 @@ def number_modulo_operation(lhs: NumberVar, rhs: NumberVar):
@binary_number_operation
def number_exponent_operation(lhs: NumberVar, rhs: NumberVar):
def number_exponent_operation(lhs: Var[int | float], rhs: Var[int | float]):
"""Exponentiate two numbers.
Args:
@ -694,7 +627,7 @@ def number_exponent_operation(lhs: NumberVar, rhs: NumberVar):
@var_operation
def number_round_operation(value: NumberVar):
def number_round_operation(value: Var[int | float]):
"""Round the number.
Args:
@ -707,7 +640,7 @@ def number_round_operation(value: NumberVar):
@var_operation
def number_ceil_operation(value: NumberVar):
def number_ceil_operation(value: Var[int | float]):
"""Ceil the number.
Args:
@ -720,7 +653,7 @@ def number_ceil_operation(value: NumberVar):
@var_operation
def number_floor_operation(value: NumberVar):
def number_floor_operation(value: Var[int | float]):
"""Floor the number.
Args:
@ -729,11 +662,15 @@ def number_floor_operation(value: NumberVar):
Returns:
The number floor operation.
"""
return var_operation_return(js_expression=f"Math.floor({value})", var_type=int)
return var_operation_return(
js_expression=f"Math.floor({value})",
var_type=int,
_raw_js_function="Math.floor",
)
@var_operation
def number_trunc_operation(value: NumberVar):
def number_trunc_operation(value: Var[int | float]):
"""Trunc the number.
Args:
@ -754,7 +691,7 @@ class BooleanVar(NumberVar[bool], python_types=bool):
Returns:
The boolean NOT operation.
"""
return boolean_not_operation(self)
return boolean_not_operation(self).guess_type()
def __int__(self):
"""Convert the boolean to an int.
@ -762,7 +699,7 @@ class BooleanVar(NumberVar[bool], python_types=bool):
Returns:
The boolean to int operation.
"""
return boolean_to_number_operation(self)
return boolean_to_number_operation(self).guess_type()
def __pos__(self):
"""Convert the boolean to an int.
@ -770,7 +707,7 @@ class BooleanVar(NumberVar[bool], python_types=bool):
Returns:
The boolean to int operation.
"""
return boolean_to_number_operation(self)
return boolean_to_number_operation(self).guess_type()
def bool(self) -> BooleanVar:
"""Boolean conversion.
@ -826,7 +763,7 @@ class BooleanVar(NumberVar[bool], python_types=bool):
@var_operation
def boolean_to_number_operation(value: BooleanVar):
def boolean_to_number_operation(value: Var[bool]):
"""Convert the boolean to a number.
Args:
@ -835,12 +772,14 @@ def boolean_to_number_operation(value: BooleanVar):
Returns:
The boolean to number operation.
"""
return var_operation_return(js_expression=f"Number({value})", var_type=int)
return var_operation_return(
js_expression=f"Number({value})", var_type=int, _raw_js_function="Number"
)
def comparison_operator(
func: Callable[[Var, Var], str],
) -> Callable[[Var | Any, Var | Any], BooleanVar]:
) -> FunctionVar[ReflexCallable[[Any, Any], bool]]:
"""Decorator to create a comparison operation.
Args:
@ -850,26 +789,15 @@ def comparison_operator(
The comparison operation.
"""
@var_operation
def operation(lhs: Var, rhs: Var):
def operation(lhs: Var[Any], rhs: Var[Any]):
return var_operation_return(
js_expression=func(lhs, rhs),
var_type=bool,
)
def wrapper(lhs: Var | Any, rhs: Var | Any) -> BooleanVar:
"""Create the comparison operation.
operation.__name__ = func.__name__
Args:
lhs: The first value.
rhs: The second value.
Returns:
The comparison operation.
"""
return operation(lhs, rhs)
return wrapper
return var_operation(operation)
@comparison_operator
@ -957,7 +885,7 @@ def not_equal_operation(lhs: Var, rhs: Var):
@var_operation
def boolean_not_operation(value: BooleanVar):
def boolean_not_operation(value: Var[bool]):
"""Boolean NOT the boolean.
Args:
@ -1081,6 +1009,18 @@ _IS_TRUE_IMPORT: ImportDict = {
f"$/{Dirs.STATE_PATH}": [ImportVar(tag="isTrue")],
}
_AT_SLICE_IMPORT: ImportDict = {
f"$/{Dirs.STATE_PATH}": [ImportVar(tag="atSlice")],
}
_AT_SLICE_OR_INDEX: ImportDict = {
f"$/{Dirs.STATE_PATH}": [ImportVar(tag="atSliceOrIndex")],
}
_RANGE_IMPORT: ImportDict = {
f"$/{Dirs.UTILS}/helpers/range": [ImportVar(tag="range", is_default=True)],
}
@var_operation
def boolify(value: Var):
@ -1096,16 +1036,17 @@ def boolify(value: Var):
js_expression=f"isTrue({value})",
var_type=bool,
var_data=VarData(imports=_IS_TRUE_IMPORT),
_raw_js_function="isTrue",
)
T = TypeVar("T")
U = TypeVar("U")
T = TypeVar("T", bound=Any)
U = TypeVar("U", bound=Any)
@var_operation
def ternary_operation(
condition: BooleanVar, if_true: Var[T], if_false: Var[U]
condition: Var[bool], if_true: Var[T], if_false: Var[U]
) -> CustomVarOperationReturn[Union[T, U]]:
"""Create a ternary operation.
@ -1117,14 +1058,125 @@ def ternary_operation(
Returns:
The ternary operation.
"""
type_value: Union[Type[T], Type[U]] = unionize(
if_true._var_type, if_false._var_type
)
value: CustomVarOperationReturn[Union[T, U]] = var_operation_return(
js_expression=f"({condition} ? {if_true} : {if_false})",
var_type=type_value,
type_computer=nary_type_computer(
ReflexCallable[[bool, Any, Any], Any],
ReflexCallable[[Any, Any], Any],
ReflexCallable[[Any], Any],
computer=lambda args: unionize(args[1]._var_type, args[2]._var_type),
),
)
return value
TUPLE_ENDS_IN_VAR = tuple[Unpack[tuple[Var[Any], ...]], Var[VAR_TYPE]]
TUPLE_ENDS_IN_VAR_RELAXED = tuple[
Unpack[tuple[Var[Any] | Any, ...]], Var[VAR_TYPE] | VAR_TYPE
]
@dataclasses.dataclass(
eq=False,
frozen=True,
slots=True,
)
class MatchOperation(CachedVarOperation, Var[VAR_TYPE]):
"""Base class for immutable match operations."""
_cond: Var[bool] = dataclasses.field(
default_factory=lambda: LiteralBooleanVar.create(True)
)
_cases: tuple[TUPLE_ENDS_IN_VAR[VAR_TYPE], ...] = dataclasses.field(
default_factory=tuple
)
_default: Var[VAR_TYPE] = dataclasses.field( # pyright: ignore[reportAssignmentType]
default_factory=lambda: Var.create(None)
)
@cached_property_no_lock
def _cached_var_name(self) -> str:
"""Get the name of the var.
Returns:
The name of the var.
"""
switch_code = f"(() => {{ switch (JSON.stringify({self._cond!s})) {{"
for case in self._cases:
conditions = case[:-1]
return_value = case[-1]
case_conditions = " ".join(
[f"case JSON.stringify({condition!s}):" for condition in conditions]
)
case_code = f"{case_conditions} return ({return_value!s}); break;"
switch_code += case_code
switch_code += f"default: return ({self._default!s}); break;"
switch_code += "};})()"
return switch_code
@cached_property_no_lock
def _cached_get_all_var_data(self) -> VarData | None:
"""Get the VarData for the var.
Returns:
The VarData for the var.
"""
return VarData.merge(
self._cond._get_all_var_data(),
*(
cond_or_return._get_all_var_data()
for case in self._cases
for cond_or_return in case
),
self._default._get_all_var_data(),
self._var_data,
)
@classmethod
def create(
cls,
cond: Any,
cases: Sequence[TUPLE_ENDS_IN_VAR_RELAXED[VAR_TYPE]],
default: Var[VAR_TYPE] | VAR_TYPE,
_var_data: VarData | None = None,
_var_type: type[VAR_TYPE] | None = None,
):
"""Create the match operation.
Args:
cond: The condition.
cases: The cases.
default: The default case.
_var_data: Additional hooks and imports associated with the Var.
_var_type: The type of the Var.
Returns:
The match operation.
"""
cond = Var.create(cond)
cases = cast(
tuple[TUPLE_ENDS_IN_VAR[VAR_TYPE], ...],
tuple(tuple(Var.create(c) for c in case) for case in cases),
)
_default = cast(Var[VAR_TYPE], Var.create(default))
var_type = _var_type or unionize(
*(case[-1]._var_type for case in cases),
_default._var_type,
)
return cls(
_js_expr="",
_var_data=_var_data,
_var_type=var_type,
_cond=cond,
_cases=cases,
_default=_default,
)
NUMBER_TYPES = (int, float, NumberVar)

View File

@ -10,6 +10,7 @@ from typing import (
List,
Mapping,
NoReturn,
Sequence,
Tuple,
Type,
TypeVar,
@ -27,15 +28,19 @@ from reflex.utils.types import (
get_attribute_access_type,
get_origin,
safe_issubclass,
unionize,
)
from .base import (
CachedVarOperation,
LiteralVar,
ReflexCallable,
Var,
VarData,
cached_property_no_lock,
figure_out_type,
nary_type_computer,
unary_type_computer,
var_operation,
var_operation_return,
)
@ -69,9 +74,9 @@ class ObjectVar(Var[OBJECT_TYPE], python_types=Mapping):
) -> Type[VALUE_TYPE]: ...
@overload
def _value_type(self) -> Type: ...
def _value_type(self) -> GenericType: ...
def _value_type(self) -> Type:
def _value_type(self) -> GenericType:
"""Get the type of the values of the object.
Returns:
@ -83,18 +88,18 @@ class ObjectVar(Var[OBJECT_TYPE], python_types=Mapping):
args = get_args(self._var_type) if issubclass(fixed_type, Mapping) else ()
return args[1] if args else Any # pyright: ignore [reportReturnType]
def keys(self) -> ArrayVar[List[str]]:
def keys(self) -> ArrayVar[Sequence[str]]:
"""Get the keys of the object.
Returns:
The keys of the object.
"""
return object_keys_operation(self)
return object_keys_operation(self).guess_type()
@overload
def values(
self: ObjectVar[Mapping[Any, VALUE_TYPE]],
) -> ArrayVar[List[VALUE_TYPE]]: ...
) -> ArrayVar[Sequence[VALUE_TYPE]]: ...
@overload
def values(self) -> ArrayVar: ...
@ -105,12 +110,12 @@ class ObjectVar(Var[OBJECT_TYPE], python_types=Mapping):
Returns:
The values of the object.
"""
return object_values_operation(self)
return object_values_operation(self).guess_type()
@overload
def entries(
self: ObjectVar[Mapping[Any, VALUE_TYPE]],
) -> ArrayVar[List[Tuple[str, VALUE_TYPE]]]: ...
) -> ArrayVar[Sequence[Tuple[str, VALUE_TYPE]]]: ...
@overload
def entries(self) -> ArrayVar: ...
@ -121,7 +126,7 @@ class ObjectVar(Var[OBJECT_TYPE], python_types=Mapping):
Returns:
The entries of the object.
"""
return object_entries_operation(self)
return object_entries_operation(self).guess_type()
items = entries
@ -167,15 +172,10 @@ class ObjectVar(Var[OBJECT_TYPE], python_types=Mapping):
@overload
def __getitem__(
self: ObjectVar[Mapping[Any, list[ARRAY_INNER_TYPE]]],
self: ObjectVar[Mapping[Any, Sequence[ARRAY_INNER_TYPE]]]
| ObjectVar[Mapping[Any, List[ARRAY_INNER_TYPE]]],
key: Var | Any,
) -> ArrayVar[list[ARRAY_INNER_TYPE]]: ...
@overload
def __getitem__(
self: ObjectVar[Mapping[Any, tuple[ARRAY_INNER_TYPE, ...]]],
key: Var | Any,
) -> ArrayVar[tuple[ARRAY_INNER_TYPE, ...]]: ...
) -> ArrayVar[Sequence[ARRAY_INNER_TYPE]]: ...
@overload
def __getitem__(
@ -227,15 +227,9 @@ class ObjectVar(Var[OBJECT_TYPE], python_types=Mapping):
@overload
def __getattr__(
self: ObjectVar[Mapping[Any, list[ARRAY_INNER_TYPE]]],
self: ObjectVar[Mapping[Any, Sequence[ARRAY_INNER_TYPE]]],
name: str,
) -> ArrayVar[list[ARRAY_INNER_TYPE]]: ...
@overload
def __getattr__(
self: ObjectVar[Mapping[Any, tuple[ARRAY_INNER_TYPE, ...]]],
name: str,
) -> ArrayVar[tuple[ARRAY_INNER_TYPE, ...]]: ...
) -> ArrayVar[Sequence[ARRAY_INNER_TYPE]]: ...
@overload
def __getattr__(
@ -295,7 +289,7 @@ class ObjectVar(Var[OBJECT_TYPE], python_types=Mapping):
Returns:
The result of the check.
"""
return object_has_own_property_operation(self, key)
return object_has_own_property_operation(self, key).guess_type()
@dataclasses.dataclass(
@ -310,7 +304,7 @@ class LiteralObjectVar(CachedVarOperation, ObjectVar[OBJECT_TYPE], LiteralVar):
default_factory=dict
)
def _key_type(self) -> Type:
def _key_type(self) -> GenericType:
"""Get the type of the keys of the object.
Returns:
@ -319,7 +313,7 @@ class LiteralObjectVar(CachedVarOperation, ObjectVar[OBJECT_TYPE], LiteralVar):
args_list = typing.get_args(self._var_type)
return args_list[0] if args_list else Any # pyright: ignore [reportReturnType]
def _value_type(self) -> Type:
def _value_type(self) -> GenericType:
"""Get the type of the values of the object.
Returns:
@ -416,7 +410,7 @@ class LiteralObjectVar(CachedVarOperation, ObjectVar[OBJECT_TYPE], LiteralVar):
@var_operation
def object_keys_operation(value: ObjectVar):
def object_keys_operation(value: Var):
"""Get the keys of an object.
Args:
@ -428,11 +422,12 @@ def object_keys_operation(value: ObjectVar):
return var_operation_return(
js_expression=f"Object.keys({value})",
var_type=List[str],
_raw_js_function="Object.keys",
)
@var_operation
def object_values_operation(value: ObjectVar):
def object_values_operation(value: Var):
"""Get the values of an object.
Args:
@ -443,12 +438,17 @@ def object_values_operation(value: ObjectVar):
"""
return var_operation_return(
js_expression=f"Object.values({value})",
var_type=List[value._value_type()],
type_computer=unary_type_computer(
ReflexCallable[[Any], List[Any]],
lambda x: List[x.to(ObjectVar)._value_type()],
),
var_type=List[Any],
_raw_js_function="Object.values",
)
@var_operation
def object_entries_operation(value: ObjectVar):
def object_entries_operation(value: Var):
"""Get the entries of an object.
Args:
@ -457,14 +457,20 @@ def object_entries_operation(value: ObjectVar):
Returns:
The entries of the object.
"""
value = value.to(ObjectVar)
return var_operation_return(
js_expression=f"Object.entries({value})",
var_type=List[Tuple[str, value._value_type()]],
type_computer=unary_type_computer(
ReflexCallable[[Any], List[Tuple[str, Any]]],
lambda x: List[Tuple[str, x.to(ObjectVar)._value_type()]],
),
var_type=List[Tuple[str, Any]],
_raw_js_function="Object.entries",
)
@var_operation
def object_merge_operation(lhs: ObjectVar, rhs: ObjectVar):
def object_merge_operation(lhs: Var, rhs: Var):
"""Merge two objects.
Args:
@ -476,10 +482,15 @@ def object_merge_operation(lhs: ObjectVar, rhs: ObjectVar):
"""
return var_operation_return(
js_expression=f"({{...{lhs}, ...{rhs}}})",
var_type=Mapping[
Union[lhs._key_type(), rhs._key_type()],
Union[lhs._value_type(), rhs._value_type()],
],
type_computer=nary_type_computer(
ReflexCallable[[Any, Any], Mapping[Any, Any]],
ReflexCallable[[Any], Mapping[Any, Any]],
computer=lambda args: Mapping[
unionize(*[arg.to(ObjectVar)._key_type() for arg in args]),
unionize(*[arg.to(ObjectVar)._value_type() for arg in args]),
],
),
var_type=Mapping[Any, Any],
)
@ -536,7 +547,7 @@ class ObjectItemOperation(CachedVarOperation, Var):
@var_operation
def object_has_own_property_operation(object: ObjectVar, key: Var):
def object_has_own_property_operation(object: Var, key: Var):
"""Check if an object has a key.
Args:

File diff suppressed because it is too large Load Diff

View File

@ -28,9 +28,11 @@ def TestEventAction():
def on_click2(self):
self.order.append("on_click2")
@rx.event
def on_click_throttle(self):
self.order.append("on_click_throttle")
@rx.event
def on_click_debounce(self):
self.order.append("on_click_debounce")

View File

@ -17,45 +17,68 @@ def LifespanApp():
import reflex as rx
lifespan_task_global = 0
lifespan_context_global = 0
def create_tasks():
lifespan_task_global = 0
lifespan_context_global = 0
@asynccontextmanager
async def lifespan_context(app, inc: int = 1):
global lifespan_context_global
print(f"Lifespan context entered: {app}.")
lifespan_context_global += inc # pyright: ignore[reportUnboundVariable]
try:
yield
finally:
print("Lifespan context exited.")
lifespan_context_global += inc
async def lifespan_task(inc: int = 1):
global lifespan_task_global
print("Lifespan global started.")
try:
while True:
lifespan_task_global += inc # pyright: ignore[reportUnboundVariable, reportPossiblyUnboundVariable]
await asyncio.sleep(0.1)
except asyncio.CancelledError as ce:
print(f"Lifespan global cancelled: {ce}.")
lifespan_task_global = 0
class LifespanState(rx.State):
interval: int = 100
@rx.var(cache=False)
def task_global(self) -> int:
return lifespan_task_global
@rx.var(cache=False)
def context_global(self) -> int:
def lifespan_context_global_getter():
return lifespan_context_global
@rx.event
def tick(self, date):
pass
def lifespan_task_global_getter():
return lifespan_task_global
@asynccontextmanager
async def lifespan_context(app, inc: int = 1):
nonlocal lifespan_context_global
print(f"Lifespan context entered: {app}.")
lifespan_context_global += inc
try:
yield
finally:
print("Lifespan context exited.")
lifespan_context_global += inc
async def lifespan_task(inc: int = 1):
nonlocal lifespan_task_global
print("Lifespan global started.")
try:
while True:
lifespan_task_global += inc
await asyncio.sleep(0.1)
except asyncio.CancelledError as ce:
print(f"Lifespan global cancelled: {ce}.")
lifespan_task_global = 0
class LifespanState(rx.State):
interval: int = 100
@rx.var(cache=False)
def task_global(self) -> int:
return lifespan_task_global
@rx.var(cache=False)
def context_global(self) -> int:
return lifespan_context_global
@rx.event
def tick(self, date):
pass
return (
lifespan_task,
lifespan_context,
LifespanState,
lifespan_task_global_getter,
lifespan_context_global_getter,
)
(
lifespan_task,
lifespan_context,
LifespanState,
lifespan_task_global_getter,
lifespan_context_global_getter,
) = create_tasks()
def index():
return rx.vstack(
@ -113,13 +136,16 @@ async def test_lifespan(lifespan_app: AppHarness):
task_global = driver.find_element(By.ID, "task_global")
assert context_global.text == "2"
assert lifespan_app.app_module.lifespan_context_global == 2
assert lifespan_app.app_module.lifespan_context_global_getter() == 2
original_task_global_text = task_global.text
original_task_global_value = int(original_task_global_text)
lifespan_app.poll_for_content(task_global, exp_not_equal=original_task_global_text)
driver.find_element(By.ID, "toggle-tick").click() # avoid teardown errors
assert lifespan_app.app_module.lifespan_task_global > original_task_global_value
assert (
lifespan_app.app_module.lifespan_task_global_getter()
> original_task_global_value
)
assert int(task_global.text) > original_task_global_value
# Kill the backend
@ -129,5 +155,5 @@ async def test_lifespan(lifespan_app: AppHarness):
lifespan_app.backend_thread.join()
# Check that the lifespan tasks have been cancelled
assert lifespan_app.app_module.lifespan_task_global == 0
assert lifespan_app.app_module.lifespan_context_global == 4
assert lifespan_app.app_module.lifespan_task_global_getter() == 0
assert lifespan_app.app_module.lifespan_context_global_getter() == 4

View File

@ -87,7 +87,7 @@ def UploadFile():
),
rx.box(
rx.foreach(
rx.selected_files,
rx.selected_files(),
lambda f: rx.text(f, as_="p"),
),
id="selected_files",

View File

@ -61,7 +61,7 @@ def ColorToggleApp():
rx.icon(tag="moon", size=20),
value="dark",
),
on_change=set_color_mode,
on_change=set_color_mode(),
variant="classic",
radius="large",
value=color_mode,

View File

@ -25,6 +25,7 @@ def test_connection_banner():
"react",
"$/utils/context",
"$/utils/state",
"@emotion/react",
RadixThemesComponent().library or "",
"$/env.json",
)
@ -43,6 +44,7 @@ def test_connection_modal():
"react",
"$/utils/context",
"$/utils/state",
"@emotion/react",
RadixThemesComponent().library or "",
"$/env.json",
)

View File

@ -3,8 +3,7 @@ from typing import Any, Union
import pytest
from reflex.components.base.fragment import Fragment
from reflex.components.core.cond import Cond, cond
from reflex.components.core.cond import cond
from reflex.components.radix.themes.typography.text import Text
from reflex.state import BaseState
from reflex.utils.format import format_state_name
@ -40,32 +39,23 @@ def test_validate_cond(cond_state: BaseState):
Args:
cond_state: A fixture.
"""
cond_component = cond(
first_component = Text.create("cond is True")
second_component = Text.create("cond is False")
cond_var = cond(
cond_state.value,
Text.create("cond is True"),
Text.create("cond is False"),
first_component,
second_component,
)
cond_dict = cond_component.render() if type(cond_component) is Fragment else {}
assert cond_dict["name"] == "Fragment"
[condition] = cond_dict["children"]
assert condition["cond_state"] == f"isTrue({cond_state.get_full_name()}.value)"
assert isinstance(cond_var, Var)
assert (
str(cond_var)
== f'({cond_state.value.bool()!s} ? (jsx(RadixThemesText, ({{ ["as"] : "p" }}), (jsx(Fragment, ({{ }}), "cond is True")))) : (jsx(RadixThemesText, ({{ ["as"] : "p" }}), (jsx(Fragment, ({{ }}), "cond is False")))))'
)
# true value
true_value = condition["true_value"]
assert true_value["name"] == "Fragment"
[true_value_text] = true_value["children"]
assert true_value_text["name"] == "RadixThemesText"
assert true_value_text["children"][0]["contents"] == '{"cond is True"}'
# false value
false_value = condition["false_value"]
assert false_value["name"] == "Fragment"
[false_value_text] = false_value["children"]
assert false_value_text["name"] == "RadixThemesText"
assert false_value_text["children"][0]["contents"] == '{"cond is False"}'
var_data = cond_var._get_all_var_data()
assert var_data is not None
assert var_data.components == (first_component, second_component)
@pytest.mark.parametrize(
@ -99,22 +89,25 @@ def test_prop_cond(c1: Any, c2: Any):
assert str(prop_cond) == f"(true ? {c1!s} : {c2!s})"
def test_cond_no_mix():
"""Test if cond can't mix components and props."""
with pytest.raises(ValueError):
cond(True, LiteralVar.create("hello"), Text.create("world"))
def test_cond_mix():
"""Test if cond can mix components and props."""
x = cond(True, LiteralVar.create("hello"), Text.create("world"))
assert isinstance(x, Var)
assert (
str(x)
== '(true ? "hello" : (jsx(RadixThemesText, ({ ["as"] : "p" }), (jsx(Fragment, ({ }), "world")))))'
)
def test_cond_no_else():
"""Test if cond can be used without else."""
# Components should support the use of cond without else
comp = cond(True, Text.create("hello"))
assert isinstance(comp, Fragment)
comp = comp.children[0]
assert isinstance(comp, Cond)
assert comp.cond._decode() is True
assert comp.comp1.render() == Fragment.create(Text.create("hello")).render() # pyright: ignore [reportOptionalMemberAccess]
assert comp.comp2 == Fragment.create()
assert isinstance(comp, Var)
assert (
str(comp)
== '(true ? (jsx(RadixThemesText, ({ ["as"] : "p" }), (jsx(Fragment, ({ }), "hello")))) : (jsx(Fragment, ({ }))))'
)
# Props do not support the use of cond without else
with pytest.raises(ValueError):

View File

@ -1,4 +1,4 @@
from typing import Dict, List, Set, Tuple, Union
from typing import Dict, List, Sequence, Set, Tuple, Union
import pydantic.v1
import pytest
@ -6,16 +6,11 @@ import pytest
from reflex import el
from reflex.base import Base
from reflex.components.component import Component
from reflex.components.core.foreach import (
Foreach,
ForeachRenderError,
ForeachVarError,
foreach,
)
from reflex.components.core.foreach import ForeachVarError, foreach
from reflex.components.radix.themes.layout.box import box
from reflex.components.radix.themes.typography.text import text
from reflex.state import BaseState, ComponentState
from reflex.vars.base import Var
from reflex.utils.exceptions import VarTypeError
from reflex.vars.number import NumberVar
from reflex.vars.sequence import ArrayVar
@ -125,7 +120,9 @@ def display_colors_set(color):
return box(text(color))
def display_nested_list_element(element: ArrayVar[List[str]], index: NumberVar[int]):
def display_nested_list_element(
element: ArrayVar[Sequence[str]], index: NumberVar[int]
):
assert element._var_type == List[str]
assert index._var_type is int
return box(text(element[index]))
@ -139,143 +136,35 @@ def display_color_index_tuple(color):
seen_index_vars = set()
@pytest.mark.parametrize(
"state_var, render_fn, render_dict",
[
(
ForEachState.colors_list,
display_color,
{
"iterable_state": f"{ForEachState.get_full_name()}.colors_list",
"iterable_type": "list",
},
),
(
ForEachState.colors_dict_list,
display_color_name,
{
"iterable_state": f"{ForEachState.get_full_name()}.colors_dict_list",
"iterable_type": "list",
},
),
(
ForEachState.colors_nested_dict_list,
display_shade,
{
"iterable_state": f"{ForEachState.get_full_name()}.colors_nested_dict_list",
"iterable_type": "list",
},
),
(
ForEachState.primary_color,
display_primary_colors,
{
"iterable_state": f"Object.entries({ForEachState.get_full_name()}.primary_color)",
"iterable_type": "list",
},
),
(
ForEachState.color_with_shades,
display_color_with_shades,
{
"iterable_state": f"Object.entries({ForEachState.get_full_name()}.color_with_shades)",
"iterable_type": "list",
},
),
(
ForEachState.nested_colors_with_shades,
display_nested_color_with_shades,
{
"iterable_state": f"Object.entries({ForEachState.get_full_name()}.nested_colors_with_shades)",
"iterable_type": "list",
},
),
(
ForEachState.nested_colors_with_shades,
display_nested_color_with_shades_v2,
{
"iterable_state": f"Object.entries({ForEachState.get_full_name()}.nested_colors_with_shades)",
"iterable_type": "list",
},
),
(
ForEachState.color_tuple,
display_color_tuple,
{
"iterable_state": f"{ForEachState.get_full_name()}.color_tuple",
"iterable_type": "tuple",
},
),
(
ForEachState.colors_set,
display_colors_set,
{
"iterable_state": f"{ForEachState.get_full_name()}.colors_set",
"iterable_type": "set",
},
),
(
ForEachState.nested_colors_list,
lambda el, i: display_nested_list_element(el, i),
{
"iterable_state": f"{ForEachState.get_full_name()}.nested_colors_list",
"iterable_type": "list",
},
),
(
ForEachState.color_index_tuple,
display_color_index_tuple,
{
"iterable_state": f"{ForEachState.get_full_name()}.color_index_tuple",
"iterable_type": "tuple",
},
),
],
)
def test_foreach_render(state_var, render_fn, render_dict):
"""Test that the foreach component renders without error.
Args:
state_var: the state var.
render_fn: The render callable
render_dict: return dict on calling `component.render`
"""
component = Foreach.create(state_var, render_fn)
rend = component.render()
assert rend["iterable_state"] == render_dict["iterable_state"]
assert rend["iterable_type"] == render_dict["iterable_type"]
# Make sure the index vars are unique.
arg_index = rend["arg_index"]
assert isinstance(arg_index, Var)
assert arg_index._js_expr not in seen_index_vars
assert arg_index._var_type is int
seen_index_vars.add(arg_index._js_expr)
def test_foreach_bad_annotations():
"""Test that the foreach component raises a ForeachVarError if the iterable is of type Any."""
with pytest.raises(ForeachVarError):
Foreach.create(
foreach(
ForEachState.bad_annotation_list,
lambda sublist: Foreach.create(sublist, lambda color: text(color)),
lambda sublist: foreach(sublist, lambda color: text(color)),
)
def test_foreach_no_param_in_signature():
"""Test that the foreach component raises a ForeachRenderError if no parameters are passed."""
with pytest.raises(ForeachRenderError):
Foreach.create(
ForEachState.colors_list,
lambda: text("color"),
)
"""Test that the foreach component DOES NOT raise an error if no parameters are passed."""
foreach(
ForEachState.colors_list,
lambda: text("color"),
)
def test_foreach_with_index():
"""Test that the foreach component works with an index."""
foreach(
ForEachState.colors_list,
lambda color, index: text(color, index),
)
def test_foreach_too_many_params_in_signature():
"""Test that the foreach component raises a ForeachRenderError if too many parameters are passed."""
with pytest.raises(ForeachRenderError):
Foreach.create(
with pytest.raises(VarTypeError):
foreach(
ForEachState.colors_list,
lambda color, index, extra: text(color),
)
@ -290,13 +179,13 @@ def test_foreach_component_styles():
)
)
component._add_style_recursive({box: {"color": "red"}})
assert 'css={({ ["color"] : "red" })}' in str(component)
assert '{ ["css"] : ({ ["color"] : "red" }) }' in str(component)
def test_foreach_component_state():
"""Test that using a component state to render in the foreach raises an error."""
with pytest.raises(TypeError):
Foreach.create(
foreach(
ForEachState.colors_list,
ComponentStateTest.create,
)
@ -304,7 +193,7 @@ def test_foreach_component_state():
def test_foreach_default_factory():
"""Test that the default factory is called."""
_ = Foreach.create(
_ = foreach(
ForEachState.default_factory_list,
lambda tag: text(tag.name),
)

View File

@ -1,10 +1,10 @@
from typing import List, Mapping, Tuple
import re
from typing import Tuple
import pytest
import reflex as rx
from reflex.components.component import Component
from reflex.components.core.match import Match
from reflex.components.core.match import match
from reflex.state import BaseState
from reflex.utils.exceptions import MatchTypeError
from reflex.vars.base import Var
@ -18,75 +18,6 @@ class MatchState(BaseState):
string: str = "random string"
def test_match_components():
"""Test matching cases with return values as components."""
match_case_tuples = (
(1, rx.text("first value")),
(2, 3, rx.text("second value")),
([1, 2], rx.text("third value")),
("random", rx.text("fourth value")),
({"foo": "bar"}, rx.text("fifth value")),
(MatchState.num + 1, rx.text("sixth value")),
rx.text("default value"),
)
match_comp = Match.create(MatchState.value, *match_case_tuples)
assert isinstance(match_comp, Component)
match_dict = match_comp.render()
assert match_dict["name"] == "Fragment"
[match_child] = match_dict["children"]
assert match_child["name"] == "match"
assert str(match_child["cond"]) == f"{MatchState.get_name()}.value"
match_cases = match_child["match_cases"]
assert len(match_cases) == 6
assert match_cases[0][0]._js_expr == "1"
assert match_cases[0][0]._var_type is int
first_return_value_render = match_cases[0][1]
assert first_return_value_render["name"] == "RadixThemesText"
assert first_return_value_render["children"][0]["contents"] == '{"first value"}'
assert match_cases[1][0]._js_expr == "2"
assert match_cases[1][0]._var_type is int
assert match_cases[1][1]._js_expr == "3"
assert match_cases[1][1]._var_type is int
second_return_value_render = match_cases[1][2]
assert second_return_value_render["name"] == "RadixThemesText"
assert second_return_value_render["children"][0]["contents"] == '{"second value"}'
assert match_cases[2][0]._js_expr == "[1, 2]"
assert match_cases[2][0]._var_type == List[int]
third_return_value_render = match_cases[2][1]
assert third_return_value_render["name"] == "RadixThemesText"
assert third_return_value_render["children"][0]["contents"] == '{"third value"}'
assert match_cases[3][0]._js_expr == '"random"'
assert match_cases[3][0]._var_type is str
fourth_return_value_render = match_cases[3][1]
assert fourth_return_value_render["name"] == "RadixThemesText"
assert fourth_return_value_render["children"][0]["contents"] == '{"fourth value"}'
assert match_cases[4][0]._js_expr == '({ ["foo"] : "bar" })'
assert match_cases[4][0]._var_type == Mapping[str, str]
fifth_return_value_render = match_cases[4][1]
assert fifth_return_value_render["name"] == "RadixThemesText"
assert fifth_return_value_render["children"][0]["contents"] == '{"fifth value"}'
assert match_cases[5][0]._js_expr == f"({MatchState.get_name()}.num + 1)"
assert match_cases[5][0]._var_type is int
fifth_return_value_render = match_cases[5][1]
assert fifth_return_value_render["name"] == "RadixThemesText"
assert fifth_return_value_render["children"][0]["contents"] == '{"sixth value"}'
default = match_child["default"]
assert default["name"] == "RadixThemesText"
assert default["children"][0]["contents"] == '{"default value"}'
@pytest.mark.parametrize(
"cases, expected",
[
@ -137,7 +68,7 @@ def test_match_vars(cases, expected):
cases: The match cases.
expected: The expected var full name.
"""
match_comp = Match.create(MatchState.value, *cases)
match_comp = match(MatchState.value, *cases) # pyright: ignore[reportCallIssue]
assert isinstance(match_comp, Var)
assert str(match_comp) == expected
@ -146,18 +77,14 @@ def test_match_on_component_without_default():
"""Test that matching cases with return values as components returns a Fragment
as the default case if not provided.
"""
from reflex.components.base.fragment import Fragment
match_case_tuples = (
(1, rx.text("first value")),
(2, 3, rx.text("second value")),
)
match_comp = Match.create(MatchState.value, *match_case_tuples)
assert isinstance(match_comp, Component)
default = match_comp.render()["children"][0]["default"]
match_comp = match(MatchState.value, *match_case_tuples)
assert isinstance(default, dict) and default["name"] == Fragment.__name__
assert isinstance(match_comp, Var)
def test_match_on_var_no_default():
@ -172,7 +99,7 @@ def test_match_on_var_no_default():
ValueError,
match="For cases with return types as Vars, a default case must be provided",
):
Match.create(MatchState.value, *match_case_tuples)
match(MatchState.value, *match_case_tuples)
@pytest.mark.parametrize(
@ -205,7 +132,7 @@ def test_match_default_not_last_arg(match_case):
ValueError,
match="rx.match should have tuples of cases and a default case as the last argument.",
):
Match.create(MatchState.value, *match_case)
match(MatchState.value, *match_case) # pyright: ignore[reportCallIssue]
@pytest.mark.parametrize(
@ -235,7 +162,7 @@ def test_match_case_tuple_elements(match_case):
ValueError,
match="A case tuple should have at least a match case element and a return value.",
):
Match.create(MatchState.value, *match_case)
match(MatchState.value, *match_case) # pyright: ignore[reportCallIssue]
@pytest.mark.parametrize(
@ -251,8 +178,7 @@ def test_match_case_tuple_elements(match_case):
(MatchState.num + 1, "black"),
rx.text("default value"),
),
'Match cases should have the same return types. Case 3 with return value `"red"` of type '
"<class 'reflex.vars.sequence.LiteralStringVar'> is not <class 'reflex.components.component.BaseComponent'>",
"Match cases should have the same return types. Expected return types to be of type Component or Var[Component]. Return type of case 3 is <class 'str'>. Return type of case 4 is <class 'str'>. Return type of case 5 is <class 'str'>",
),
(
(
@ -264,8 +190,7 @@ def test_match_case_tuple_elements(match_case):
([1, 2], rx.text("third value")),
rx.text("default value"),
),
'Match cases should have the same return types. Case 3 with return value `<RadixThemesText as={"p"}> {"first value"} </RadixThemesText>` '
"of type <class 'reflex.components.radix.themes.typography.text.Text'> is not <class 'reflex.vars.base.Var'>",
"Match cases should have the same return types. Expected return types to be of type Component or Var[Component]. Return type of case 0 is <class 'str'>. Return type of case 1 is <class 'str'>. Return type of case 2 is <class 'str'>",
),
],
)
@ -276,8 +201,8 @@ def test_match_different_return_types(cases: Tuple, error_msg: str):
cases: The match cases.
error_msg: Expected error message.
"""
with pytest.raises(MatchTypeError, match=error_msg):
Match.create(MatchState.value, *cases)
with pytest.raises(MatchTypeError, match=re.escape(error_msg)):
match(MatchState.value, *cases) # pyright: ignore[reportCallIssue]
@pytest.mark.parametrize(
@ -309,9 +234,9 @@ def test_match_multiple_default_cases(match_case):
match_case: the cases to match.
"""
with pytest.raises(ValueError, match="rx.match can only have one default case."):
Match.create(MatchState.value, *match_case)
match(MatchState.value, *match_case) # pyright: ignore[reportCallIssue]
def test_match_no_cond():
with pytest.raises(ValueError):
_ = Match.create(None)
_ = match(None) # pyright: ignore[reportCallIssue]

View File

@ -1457,13 +1457,10 @@ def test_instantiate_all_components():
# These components all have required arguments and cannot be trivially instantiated.
untested_components = {
"Card",
"Cond",
"DebounceInput",
"Foreach",
"FormControl",
"Html",
"Icon",
"Match",
"Markdown",
"MultiSelect",
"Option",
@ -2156,14 +2153,11 @@ def test_add_style_foreach():
page = rx.vstack(rx.foreach(Var.range(3), lambda i: StyledComponent.create(i)))
page._add_style_recursive(Style())
# Expect only a single child of the foreach on the python side
assert len(page.children[0].children) == 1
# Expect the style to be added to the child of the foreach
assert 'css={({ ["color"] : "red" })}' in str(page.children[0].children[0])
assert '({ ["css"] : ({ ["color"] : "red" }) }),' in str(page.children[0])
# Expect only one instance of this CSS dict in the rendered page
assert str(page).count('css={({ ["color"] : "red" })}') == 1
assert str(page).count('({ ["css"] : ({ ["color"] : "red" }) }),') == 1
class TriggerState(rx.State):

View File

@ -2,7 +2,7 @@ from typing import Dict, List
import pytest
from reflex.components.tags import CondTag, Tag, tagless
from reflex.components.tags import Tag, tagless
from reflex.vars.base import LiteralVar, Var
@ -105,29 +105,6 @@ def test_format_tag(tag: Tag, expected: Dict):
assert prop_value.equals(LiteralVar.create(expected["props"][prop]))
def test_format_cond_tag():
"""Test that the cond tag dict is correct."""
tag = CondTag(
true_value=dict(Tag(name="h1", contents="True content")),
false_value=dict(Tag(name="h2", contents="False content")),
cond=Var(_js_expr="logged_in", _var_type=bool),
)
tag_dict = dict(tag)
cond, true_value, false_value = (
tag_dict["cond"],
tag_dict["true_value"],
tag_dict["false_value"],
)
assert cond._js_expr == "logged_in"
assert cond._var_type is bool
assert true_value["name"] == "h1"
assert true_value["contents"] == "True content"
assert false_value["name"] == "h2"
assert false_value["contents"] == "False content"
def test_tagless_string_representation():
"""Test that the string representation of a tagless is correct."""
tag = tagless.Tagless(contents="Hello world")

View File

@ -9,7 +9,7 @@ import unittest.mock
import uuid
from contextlib import nullcontext as does_not_raise
from pathlib import Path
from typing import Generator, List, Tuple, Type
from typing import Generator, List, Tuple, Type, cast
from unittest.mock import AsyncMock
import pytest
@ -31,9 +31,8 @@ from reflex.app import (
)
from reflex.components import Component
from reflex.components.base.fragment import Fragment
from reflex.components.core.cond import Cond
from reflex.components.radix.themes.typography.text import Text
from reflex.event import Event
from reflex.event import Event, EventHandler
from reflex.middleware import HydrateMiddleware
from reflex.model import Model
from reflex.state import (
@ -916,7 +915,7 @@ class DynamicState(BaseState):
"""
return self.dynamic
on_load_internal = OnLoadInternalState.on_load_internal.fn # pyright: ignore [reportFunctionMemberAccess]
on_load_internal = cast(EventHandler, OnLoadInternalState.on_load_internal).fn
def test_dynamic_arg_shadow(
@ -1189,7 +1188,7 @@ async def test_process_events(mocker, token: str):
pass
assert (await app.state_manager.get_state(event.substate_token)).value == 5
assert app._postprocess.call_count == 6 # pyright: ignore [reportFunctionMemberAccess]
assert getattr(app._postprocess, "call_count", None) == 6
if isinstance(app.state_manager, StateManagerRedis):
await app.state_manager.close()
@ -1227,10 +1226,6 @@ def test_overlay_component(
assert app.overlay_component is not None
generated_component = app._generate_component(app.overlay_component)
assert isinstance(generated_component, OverlayFragment)
assert isinstance(
generated_component.children[0],
Cond, # ConnectionModal is a Cond under the hood
)
else:
assert app.overlay_component is not None
assert isinstance(
@ -1246,8 +1241,8 @@ def test_overlay_component(
if exp_page_child is not None:
assert len(page.children) == 3
children_types = (type(child) for child in page.children)
assert exp_page_child in children_types # pyright: ignore [reportOperatorIssue]
children_types = [type(child) for child in page.children]
assert exp_page_child in children_types
else:
assert len(page.children) == 2

View File

@ -6,6 +6,7 @@ import reflex as rx
from reflex.constants.compiler import Hooks, Imports
from reflex.event import (
Event,
EventActionsMixin,
EventChain,
EventHandler,
EventSpec,
@ -410,6 +411,7 @@ def test_event_actions():
def test_event_actions_on_state():
class EventActionState(BaseState):
@rx.event
def handler(self):
pass
@ -417,7 +419,8 @@ def test_event_actions_on_state():
assert isinstance(handler, EventHandler)
assert not handler.event_actions
sp_handler = EventActionState.handler.stop_propagation # pyright: ignore [reportFunctionMemberAccess]
sp_handler = EventActionState.handler.stop_propagation
assert isinstance(sp_handler, EventActionsMixin)
assert sp_handler.event_actions == {"stopPropagation": True}
# should NOT affect other references to the handler
assert not handler.event_actions

View File

@ -122,9 +122,12 @@ async def test_health(
# Call the async health function
response = await health()
print(json.loads(response.body)) # pyright: ignore [reportArgumentType]
body = response.body
assert isinstance(body, bytes)
print(json.loads(body))
print(expected_status)
# Verify the response content and status code
assert response.status_code == expected_code
assert json.loads(response.body) == expected_status # pyright: ignore [reportArgumentType]
assert json.loads(body) == expected_status

View File

@ -18,6 +18,7 @@ from typing import (
Dict,
List,
Optional,
Sequence,
Set,
Tuple,
Union,
@ -121,8 +122,8 @@ class TestState(BaseState):
num2: float = 3.14
key: str
map_key: str = "a"
array: List[float] = [1, 2, 3.14]
mapping: Dict[str, List[int]] = {"a": [1, 2, 3], "b": [4, 5, 6]}
array: rx.Field[List[float]] = rx.field([1, 2, 3.14])
mapping: rx.Field[Dict[str, List[int]]] = rx.field({"a": [1, 2, 3], "b": [4, 5, 6]})
obj: Object = Object()
complex: Dict[int, Object] = {1: Object(), 2: Object()}
fig: Figure = Figure()
@ -432,13 +433,16 @@ def test_default_setters(test_state):
def test_class_indexing_with_vars():
"""Test that we can index into a state var with another var."""
prop = TestState.array[TestState.num1] # pyright: ignore [reportCallIssue, reportArgumentType]
assert str(prop) == f"{TestState.get_name()}.array.at({TestState.get_name()}.num1)"
prop = TestState.array[TestState.num1]
assert (
str(prop)
== f"(((...args) => (((_array, _index_or_slice) => atSliceOrIndex(_array, _index_or_slice))({TestState.get_name()}.array, ...args)))({TestState.get_name()}.num1))"
)
prop = TestState.mapping["a"][TestState.num1] # pyright: ignore [reportCallIssue, reportArgumentType]
assert (
str(prop)
== f'{TestState.get_name()}.mapping["a"].at({TestState.get_name()}.num1)'
== f'(((...args) => (((_array, _index_or_slice) => atSliceOrIndex(_array, _index_or_slice))({TestState.get_name()}.mapping["a"], ...args)))({TestState.get_name()}.num1))'
)
prop = TestState.mapping[TestState.map_key]
@ -1358,6 +1362,7 @@ def test_cached_var_depends_on_event_handler(use_partial: bool):
class HandlerState(BaseState):
x: int = 42
@rx.event
def handler(self):
self.x = self.x + 1
@ -1368,11 +1373,11 @@ def test_cached_var_depends_on_event_handler(use_partial: bool):
counter += 1
return counter
assert isinstance(HandlerState.handler, EventHandler)
if use_partial:
HandlerState.handler = functools.partial(HandlerState.handler.fn) # pyright: ignore [reportFunctionMemberAccess]
partial_guy = functools.partial(HandlerState.handler.fn)
HandlerState.handler = partial_guy # pyright: ignore[reportAttributeAccessIssue]
assert isinstance(HandlerState.handler, functools.partial)
else:
assert isinstance(HandlerState.handler, EventHandler)
s = HandlerState()
assert (
@ -2029,8 +2034,11 @@ async def test_state_proxy(grandchild_state: GrandchildState, mock_app: rx.App):
# ensure state update was emitted
assert mock_app.event_namespace is not None
mock_app.event_namespace.emit.assert_called_once() # pyright: ignore [reportFunctionMemberAccess]
mcall = mock_app.event_namespace.emit.mock_calls[0] # pyright: ignore [reportFunctionMemberAccess]
mock_app.event_namespace.emit.assert_called_once() # pyright: ignore[reportFunctionMemberAccess]
mock_calls = getattr(mock_app.event_namespace.emit, "mock_calls", None)
assert mock_calls is not None
assert isinstance(mock_calls, Sequence)
mcall = mock_calls[0]
assert mcall.args[0] == str(SocketEvent.EVENT)
assert mcall.args[1] == StateUpdate(
delta={
@ -2231,7 +2239,11 @@ async def test_background_task_no_block(mock_app: rx.App, token: str):
assert mock_app.event_namespace is not None
emit_mock = mock_app.event_namespace.emit
first_ws_message = emit_mock.mock_calls[0].args[1] # pyright: ignore [reportFunctionMemberAccess]
mock_calls = getattr(emit_mock, "mock_calls", None)
assert mock_calls is not None
assert isinstance(mock_calls, Sequence)
first_ws_message = mock_calls[0].args[1]
assert (
first_ws_message.delta[BackgroundTaskState.get_full_name()].pop("router")
is not None
@ -2246,7 +2258,7 @@ async def test_background_task_no_block(mock_app: rx.App, token: str):
events=[],
final=True,
)
for call in emit_mock.mock_calls[1:5]: # pyright: ignore [reportFunctionMemberAccess]
for call in mock_calls[1:5]:
assert call.args[1] == StateUpdate(
delta={
BackgroundTaskState.get_full_name(): {
@ -2256,7 +2268,7 @@ async def test_background_task_no_block(mock_app: rx.App, token: str):
events=[],
final=True,
)
assert emit_mock.mock_calls[-2].args[1] == StateUpdate( # pyright: ignore [reportFunctionMemberAccess]
assert mock_calls[-2].args[1] == StateUpdate(
delta={
BackgroundTaskState.get_full_name(): {
"order": exp_order,
@ -2267,7 +2279,7 @@ async def test_background_task_no_block(mock_app: rx.App, token: str):
events=[],
final=True,
)
assert emit_mock.mock_calls[-1].args[1] == StateUpdate( # pyright: ignore [reportFunctionMemberAccess]
assert mock_calls[-1].args[1] == StateUpdate(
delta={
BackgroundTaskState.get_full_name(): {
"computed_order": exp_order,

View File

@ -328,7 +328,10 @@ def test_basic_operations(TestObj):
assert str(LiteralNumberVar.create(1) ** 2) == "(1 ** 2)"
assert str(LiteralNumberVar.create(1) & v(2)) == "(1 && 2)"
assert str(LiteralNumberVar.create(1) | v(2)) == "(1 || 2)"
assert str(LiteralArrayVar.create([1, 2, 3])[0]) == "[1, 2, 3].at(0)"
assert (
str(LiteralArrayVar.create([1, 2, 3])[0])
== "(((...args) => (((_array, _index_or_slice) => atSliceOrIndex(_array, _index_or_slice))([1, 2, 3], ...args)))(0))"
)
assert (
str(LiteralObjectVar.create({"a": 1, "b": 2})["a"])
== '({ ["a"] : 1, ["b"] : 2 })["a"]'
@ -350,27 +353,33 @@ def test_basic_operations(TestObj):
str(Var(_js_expr="foo").to(ObjectVar, TestObj)._var_set_state("state").bar)
== 'state.foo["bar"]'
)
assert str(abs(LiteralNumberVar.create(1))) == "Math.abs(1)"
assert str(LiteralArrayVar.create([1, 2, 3]).length()) == "[1, 2, 3].length"
assert str(abs(LiteralNumberVar.create(1))) == "(Math.abs(1))"
assert (
str(LiteralArrayVar.create([1, 2, 3]).length())
== "(((...args) => (((_array) => _array.length)([1, 2, 3], ...args)))())"
)
assert (
str(LiteralArrayVar.create([1, 2]) + LiteralArrayVar.create([3, 4]))
== "[...[1, 2], ...[3, 4]]"
== "(((...args) => (((_lhs, _rhs) => [..._lhs, ..._rhs])([1, 2], ...args)))([3, 4]))"
)
# Tests for reverse operation
assert (
str(LiteralArrayVar.create([1, 2, 3]).reverse())
== "[1, 2, 3].slice().reverse()"
== "(((...args) => (((_array) => _array.slice().reverse())([1, 2, 3], ...args)))())"
)
assert (
str(LiteralArrayVar.create(["1", "2", "3"]).reverse())
== '["1", "2", "3"].slice().reverse()'
== '(((...args) => (((_array) => _array.slice().reverse())(["1", "2", "3"], ...args)))())'
)
assert (
str(Var(_js_expr="foo")._var_set_state("state").to(list).reverse())
== "state.foo.slice().reverse()"
== "(((...args) => (((_array) => _array.slice().reverse())(state.foo, ...args)))())"
)
assert (
str(Var(_js_expr="foo").to(list).reverse())
== "(((...args) => (((_array) => _array.slice().reverse())(foo, ...args)))())"
)
assert str(Var(_js_expr="foo").to(list).reverse()) == "foo.slice().reverse()"
assert str(Var(_js_expr="foo", _var_type=str).js_type()) == "(typeof(foo))"
@ -395,14 +404,32 @@ def test_basic_operations(TestObj):
],
)
def test_list_tuple_contains(var, expected):
assert str(var.contains(1)) == f"{expected}.includes(1)"
assert str(var.contains("1")) == f'{expected}.includes("1")'
assert str(var.contains(v(1))) == f"{expected}.includes(1)"
assert str(var.contains(v("1"))) == f'{expected}.includes("1")'
assert (
str(var.contains(1))
== f'(((...args) => (((_haystack, _needle, _field = "") => isTrue(_field) ? _haystack.some(obj => obj[_field] === _needle) : _haystack.some(obj => obj === _needle))({expected!s}, ...args)))(1))'
)
assert (
str(var.contains("1"))
== f'(((...args) => (((_haystack, _needle, _field = "") => isTrue(_field) ? _haystack.some(obj => obj[_field] === _needle) : _haystack.some(obj => obj === _needle))({expected!s}, ...args)))("1"))'
)
assert (
str(var.contains(v(1)))
== f'(((...args) => (((_haystack, _needle, _field = "") => isTrue(_field) ? _haystack.some(obj => obj[_field] === _needle) : _haystack.some(obj => obj === _needle))({expected!s}, ...args)))(1))'
)
assert (
str(var.contains(v("1")))
== f'(((...args) => (((_haystack, _needle, _field = "") => isTrue(_field) ? _haystack.some(obj => obj[_field] === _needle) : _haystack.some(obj => obj === _needle))({expected!s}, ...args)))("1"))'
)
other_state_var = Var(_js_expr="other", _var_type=str)._var_set_state("state")
other_var = Var(_js_expr="other", _var_type=str)
assert str(var.contains(other_state_var)) == f"{expected}.includes(state.other)"
assert str(var.contains(other_var)) == f"{expected}.includes(other)"
assert (
str(var.contains(other_state_var))
== f'(((...args) => (((_haystack, _needle, _field = "") => isTrue(_field) ? _haystack.some(obj => obj[_field] === _needle) : _haystack.some(obj => obj === _needle))({expected!s}, ...args)))(state.other))'
)
assert (
str(var.contains(other_var))
== f'(((...args) => (((_haystack, _needle, _field = "") => isTrue(_field) ? _haystack.some(obj => obj[_field] === _needle) : _haystack.some(obj => obj === _needle))({expected!s}, ...args)))(other))'
)
class Foo(rx.Base):
@ -446,15 +473,23 @@ def test_var_types(var, var_type):
],
)
def test_str_contains(var, expected):
assert str(var.contains("1")) == f'{expected}.includes("1")'
assert str(var.contains(v("1"))) == f'{expected}.includes("1")'
assert (
str(var.contains("1"))
== f'(((...args) => (((_haystack, _needle) => _haystack.includes(_needle))({expected!s}, ...args)))("1"))'
)
assert (
str(var.contains(v("1")))
== f'(((...args) => (((_haystack, _needle) => _haystack.includes(_needle))({expected!s}, ...args)))("1"))'
)
other_state_var = Var(_js_expr="other")._var_set_state("state").to(str)
other_var = Var(_js_expr="other").to(str)
assert str(var.contains(other_state_var)) == f"{expected}.includes(state.other)"
assert str(var.contains(other_var)) == f"{expected}.includes(other)"
assert (
str(var.contains("1", "hello"))
== f'{expected}.some(obj => obj["hello"] === "1")'
str(var.contains(other_state_var))
== f"(((...args) => (((_haystack, _needle) => _haystack.includes(_needle))({expected!s}, ...args)))(state.other))"
)
assert (
str(var.contains(other_var))
== f"(((...args) => (((_haystack, _needle) => _haystack.includes(_needle))({expected!s}, ...args)))(other))"
)
@ -467,16 +502,17 @@ def test_str_contains(var, expected):
],
)
def test_dict_contains(var, expected):
assert str(var.contains(1)) == f"{expected}.hasOwnProperty(1)"
assert str(var.contains("1")) == f'{expected}.hasOwnProperty("1")'
assert str(var.contains(v(1))) == f"{expected}.hasOwnProperty(1)"
assert str(var.contains(v("1"))) == f'{expected}.hasOwnProperty("1")'
assert str(var.contains(1)) == f"{expected!s}.hasOwnProperty(1)"
assert str(var.contains("1")) == f'{expected!s}.hasOwnProperty("1")'
assert str(var.contains(v(1))) == f"{expected!s}.hasOwnProperty(1)"
assert str(var.contains(v("1"))) == f'{expected!s}.hasOwnProperty("1")'
other_state_var = Var(_js_expr="other")._var_set_state("state").to(str)
other_var = Var(_js_expr="other").to(str)
assert (
str(var.contains(other_state_var)) == f"{expected}.hasOwnProperty(state.other)"
str(var.contains(other_state_var))
== f"{expected!s}.hasOwnProperty(state.other)"
)
assert str(var.contains(other_var)) == f"{expected}.hasOwnProperty(other)"
assert str(var.contains(other_var)) == f"{expected!s}.hasOwnProperty(other)"
@pytest.mark.parametrize(
@ -484,7 +520,6 @@ def test_dict_contains(var, expected):
[
Var(_js_expr="list", _var_type=List[int]).guess_type(),
Var(_js_expr="tuple", _var_type=Tuple[int, int]).guess_type(),
Var(_js_expr="str", _var_type=str).guess_type(),
],
)
def test_var_indexing_lists(var):
@ -494,11 +529,20 @@ def test_var_indexing_lists(var):
var : The str, list or tuple base var.
"""
# Test basic indexing.
assert str(var[0]) == f"{var._js_expr}.at(0)"
assert str(var[1]) == f"{var._js_expr}.at(1)"
assert (
str(var[0])
== f"(((...args) => (((_array, _index_or_slice) => atSliceOrIndex(_array, _index_or_slice))({var!s}, ...args)))(0))"
)
assert (
str(var[1])
== f"(((...args) => (((_array, _index_or_slice) => atSliceOrIndex(_array, _index_or_slice))({var!s}, ...args)))(1))"
)
# Test negative indexing.
assert str(var[-1]) == f"{var._js_expr}.at(-1)"
assert (
str(var[-1])
== f"(((...args) => (((_array, _index_or_slice) => atSliceOrIndex(_array, _index_or_slice))({var!s}, ...args)))(-1))"
)
@pytest.mark.parametrize(
@ -532,11 +576,20 @@ def test_var_indexing_str():
assert str_var[0]._var_type is str
# Test basic indexing.
assert str(str_var[0]) == "str.at(0)"
assert str(str_var[1]) == "str.at(1)"
assert (
str(str_var[0])
== "(((...args) => (((_string, _index_or_slice) => Array.prototype.join.apply(atSliceOrIndex(_string, _index_or_slice), ['']))(str, ...args)))(0))"
)
assert (
str(str_var[1])
== "(((...args) => (((_string, _index_or_slice) => Array.prototype.join.apply(atSliceOrIndex(_string, _index_or_slice), ['']))(str, ...args)))(1))"
)
# Test negative indexing.
assert str(str_var[-1]) == "str.at(-1)"
assert (
str(str_var[-1])
== "(((...args) => (((_string, _index_or_slice) => Array.prototype.join.apply(atSliceOrIndex(_string, _index_or_slice), ['']))(str, ...args)))(-1))"
)
@pytest.mark.parametrize(
@ -651,9 +704,18 @@ def test_var_list_slicing(var):
Args:
var : The str, list or tuple base var.
"""
assert str(var[:1]) == f"{var._js_expr}.slice(undefined, 1)"
assert str(var[1:]) == f"{var._js_expr}.slice(1, undefined)"
assert str(var[:]) == f"{var._js_expr}.slice(undefined, undefined)"
assert (
str(var[:1])
== f"(((...args) => (((_array, _index_or_slice) => atSliceOrIndex(_array, _index_or_slice))({var!s}, ...args)))([null, 1, null]))"
)
assert (
str(var[1:])
== f"(((...args) => (((_array, _index_or_slice) => atSliceOrIndex(_array, _index_or_slice))({var!s}, ...args)))([1, null, null]))"
)
assert (
str(var[:])
== f"(((...args) => (((_array, _index_or_slice) => atSliceOrIndex(_array, _index_or_slice))({var!s}, ...args)))([null, null, null]))"
)
def test_str_var_slicing():
@ -665,16 +727,40 @@ def test_str_var_slicing():
assert str_var[:1]._var_type is str
# Test basic slicing.
assert str(str_var[:1]) == 'str.split("").slice(undefined, 1).join("")'
assert str(str_var[1:]) == 'str.split("").slice(1, undefined).join("")'
assert str(str_var[:]) == 'str.split("").slice(undefined, undefined).join("")'
assert str(str_var[1:2]) == 'str.split("").slice(1, 2).join("")'
assert (
str(str_var[:1])
== "(((...args) => (((_string, _index_or_slice) => Array.prototype.join.apply(atSliceOrIndex(_string, _index_or_slice), ['']))(str, ...args)))([null, 1, null]))"
)
assert (
str(str_var[1:])
== "(((...args) => (((_string, _index_or_slice) => Array.prototype.join.apply(atSliceOrIndex(_string, _index_or_slice), ['']))(str, ...args)))([1, null, null]))"
)
assert (
str(str_var[:])
== "(((...args) => (((_string, _index_or_slice) => Array.prototype.join.apply(atSliceOrIndex(_string, _index_or_slice), ['']))(str, ...args)))([null, null, null]))"
)
assert (
str(str_var[1:2])
== "(((...args) => (((_string, _index_or_slice) => Array.prototype.join.apply(atSliceOrIndex(_string, _index_or_slice), ['']))(str, ...args)))([1, 2, null]))"
)
# Test negative slicing.
assert str(str_var[:-1]) == 'str.split("").slice(undefined, -1).join("")'
assert str(str_var[-1:]) == 'str.split("").slice(-1, undefined).join("")'
assert str(str_var[:-2]) == 'str.split("").slice(undefined, -2).join("")'
assert str(str_var[-2:]) == 'str.split("").slice(-2, undefined).join("")'
assert (
str(str_var[:-1])
== "(((...args) => (((_string, _index_or_slice) => Array.prototype.join.apply(atSliceOrIndex(_string, _index_or_slice), ['']))(str, ...args)))([null, -1, null]))"
)
assert (
str(str_var[-1:])
== "(((...args) => (((_string, _index_or_slice) => Array.prototype.join.apply(atSliceOrIndex(_string, _index_or_slice), ['']))(str, ...args)))([-1, null, null]))"
)
assert (
str(str_var[:-2])
== "(((...args) => (((_string, _index_or_slice) => Array.prototype.join.apply(atSliceOrIndex(_string, _index_or_slice), ['']))(str, ...args)))([null, -2, null]))"
)
assert (
str(str_var[-2:])
== "(((...args) => (((_string, _index_or_slice) => Array.prototype.join.apply(atSliceOrIndex(_string, _index_or_slice), ['']))(str, ...args)))([-2, null, null]))"
)
def test_dict_indexing():
@ -963,11 +1049,11 @@ def test_function_var():
def test_var_operation():
@var_operation
def add(a: Union[NumberVar, int], b: Union[NumberVar, int]):
def add(a: Var[int], b: Var[int]):
return var_operation_return(js_expression=f"({a} + {b})", var_type=int)
assert str(add(1, 2)) == "(1 + 2)"
assert str(add(a=4, b=-9)) == "(4 + -9)"
assert str(add(4, -9)) == "(4 + -9)"
five = LiteralNumberVar.create(5)
seven = add(2, five)
@ -978,13 +1064,29 @@ def test_var_operation():
def test_string_operations():
basic_string = LiteralStringVar.create("Hello, World!")
assert str(basic_string.length()) == '"Hello, World!".split("").length'
assert str(basic_string.lower()) == '"Hello, World!".toLowerCase()'
assert str(basic_string.upper()) == '"Hello, World!".toUpperCase()'
assert str(basic_string.strip()) == '"Hello, World!".trim()'
assert str(basic_string.contains("World")) == '"Hello, World!".includes("World")'
assert (
str(basic_string.split(" ").join(",")) == '"Hello, World!".split(" ").join(",")'
str(basic_string.length())
== '(((...args) => (((...arg) => (((_array) => _array.length)((((_string, _sep = "") => isTrue(_sep) ? _string.split(_sep) : [..._string])(...args)))))("Hello, World!", ...args)))())'
)
assert (
str(basic_string.lower())
== '(((...args) => (((_string) => String.prototype.toLowerCase.apply(_string))("Hello, World!", ...args)))())'
)
assert (
str(basic_string.upper())
== '(((...args) => (((_string) => String.prototype.toUpperCase.apply(_string))("Hello, World!", ...args)))())'
)
assert (
str(basic_string.strip())
== '(((...args) => (((_string) => String.prototype.trim.apply(_string))("Hello, World!", ...args)))())'
)
assert (
str(basic_string.contains("World"))
== '(((...args) => (((_haystack, _needle) => _haystack.includes(_needle))("Hello, World!", ...args)))("World"))'
)
assert (
str(basic_string.split(" ").join(","))
== '(((...args) => (((_array, _sep = "") => Array.prototype.join.apply(_array,[_sep]))((((...args) => (((_string, _sep = "") => isTrue(_sep) ? _string.split(_sep) : [..._string])("Hello, World!", ...args)))(" ")), ...args)))(","))'
)
@ -1004,14 +1106,14 @@ def test_all_number_operations():
assert (
str(even_more_complicated_number)
== "!(isTrue((Math.abs(Math.floor(((Math.floor(((-((-5.4 + 1)) * 2) / 3) / 2) % 3) ** 2))) || (2 && Math.round(((Math.floor(((-((-5.4 + 1)) * 2) / 3) / 2) % 3) ** 2))))))"
== "!((isTrue(((Math.abs((Math.floor(((Math.floor(((-((-5.4 + 1)) * 2) / 3) / 2) % 3) ** 2))))) || (2 && Math.round(((Math.floor(((-((-5.4 + 1)) * 2) / 3) / 2) % 3) ** 2)))))))"
)
assert str(LiteralNumberVar.create(5) > False) == "(5 > 0)"
assert str(LiteralBooleanVar.create(False) < 5) == "(Number(false) < 5)"
assert str(LiteralBooleanVar.create(False) < 5) == "((Number(false)) < 5)"
assert (
str(LiteralBooleanVar.create(False) < LiteralBooleanVar.create(True))
== "(Number(false) < Number(true))"
== "((Number(false)) < (Number(true)))"
)
@ -1020,10 +1122,10 @@ def test_all_number_operations():
[
(Var.create(False), "false"),
(Var.create(True), "true"),
(Var.create("false"), 'isTrue("false")'),
(Var.create([1, 2, 3]), "isTrue([1, 2, 3])"),
(Var.create({"a": 1, "b": 2}), 'isTrue(({ ["a"] : 1, ["b"] : 2 }))'),
(Var("mysterious_var"), "isTrue(mysterious_var)"),
(Var.create("false"), '(isTrue("false"))'),
(Var.create([1, 2, 3]), "(isTrue([1, 2, 3]))"),
(Var.create({"a": 1, "b": 2}), '(isTrue(({ ["a"] : 1, ["b"] : 2 })))'),
(Var("mysterious_var"), "(isTrue(mysterious_var))"),
],
)
def test_boolify_operations(var, expected):
@ -1032,18 +1134,30 @@ def test_boolify_operations(var, expected):
def test_index_operation():
array_var = LiteralArrayVar.create([1, 2, 3, 4, 5])
assert str(array_var[0]) == "[1, 2, 3, 4, 5].at(0)"
assert str(array_var[1:2]) == "[1, 2, 3, 4, 5].slice(1, 2)"
assert (
str(array_var[0])
== "(((...args) => (((_array, _index_or_slice) => atSliceOrIndex(_array, _index_or_slice))([1, 2, 3, 4, 5], ...args)))(0))"
)
assert (
str(array_var[1:2])
== "(((...args) => (((_array, _index_or_slice) => atSliceOrIndex(_array, _index_or_slice))([1, 2, 3, 4, 5], ...args)))([1, 2, null]))"
)
assert (
str(array_var[1:4:2])
== "[1, 2, 3, 4, 5].slice(1, 4).filter((_, i) => i % 2 === 0)"
== "(((...args) => (((_array, _index_or_slice) => atSliceOrIndex(_array, _index_or_slice))([1, 2, 3, 4, 5], ...args)))([1, 4, 2]))"
)
assert (
str(array_var[::-1])
== "[1, 2, 3, 4, 5].slice(0, [1, 2, 3, 4, 5].length).slice().reverse().slice(undefined, undefined).filter((_, i) => i % 1 === 0)"
== "(((...args) => (((_array, _index_or_slice) => atSliceOrIndex(_array, _index_or_slice))([1, 2, 3, 4, 5], ...args)))([null, null, -1]))"
)
assert (
str(array_var.reverse())
== "(((...args) => (((_array) => _array.slice().reverse())([1, 2, 3, 4, 5], ...args)))())"
)
assert (
str(array_var[0].to(NumberVar) + 9)
== "((((...args) => (((_array, _index_or_slice) => atSliceOrIndex(_array, _index_or_slice))([1, 2, 3, 4, 5], ...args)))(0)) + 9)"
)
assert str(array_var.reverse()) == "[1, 2, 3, 4, 5].slice().reverse()"
assert str(array_var[0].to(NumberVar) + 9) == "([1, 2, 3, 4, 5].at(0) + 9)"
@pytest.mark.parametrize(
@ -1065,40 +1179,37 @@ def test_inf_and_nan(var, expected_js):
def test_array_operations():
array_var = LiteralArrayVar.create([1, 2, 3, 4, 5])
assert str(array_var.length()) == "[1, 2, 3, 4, 5].length"
assert str(array_var.contains(3)) == "[1, 2, 3, 4, 5].includes(3)"
assert str(array_var.reverse()) == "[1, 2, 3, 4, 5].slice().reverse()"
assert (
str(ArrayVar.range(10))
== "Array.from({ length: Math.ceil((10 - 0) / 1) }, (_, i) => 0 + i * 1)"
str(array_var.length())
== "(((...args) => (((_array) => _array.length)([1, 2, 3, 4, 5], ...args)))())"
)
assert (
str(ArrayVar.range(1, 10))
== "Array.from({ length: Math.ceil((10 - 1) / 1) }, (_, i) => 1 + i * 1)"
str(array_var.contains(3))
== '(((...args) => (((_haystack, _needle, _field = "") => isTrue(_field) ? _haystack.some(obj => obj[_field] === _needle) : _haystack.some(obj => obj === _needle))([1, 2, 3, 4, 5], ...args)))(3))'
)
assert (
str(ArrayVar.range(1, 10, 2))
== "Array.from({ length: Math.ceil((10 - 1) / 2) }, (_, i) => 1 + i * 2)"
)
assert (
str(ArrayVar.range(1, 10, -1))
== "Array.from({ length: Math.ceil((10 - 1) / -1) }, (_, i) => 1 + i * -1)"
str(array_var.reverse())
== "(((...args) => (((_array) => _array.slice().reverse())([1, 2, 3, 4, 5], ...args)))())"
)
assert str(ArrayVar.range(10)) == "[...range(10, null, 1)]"
assert str(ArrayVar.range(1, 10)) == "[...range(1, 10, 1)]"
assert str(ArrayVar.range(1, 10, 2)) == "[...range(1, 10, 2)]"
assert str(ArrayVar.range(1, 10, -1)) == "[...range(1, 10, -1)]"
def test_object_operations():
object_var = LiteralObjectVar.create({"a": 1, "b": 2, "c": 3})
assert (
str(object_var.keys()) == 'Object.keys(({ ["a"] : 1, ["b"] : 2, ["c"] : 3 }))'
str(object_var.keys()) == '(Object.keys(({ ["a"] : 1, ["b"] : 2, ["c"] : 3 })))'
)
assert (
str(object_var.values())
== 'Object.values(({ ["a"] : 1, ["b"] : 2, ["c"] : 3 }))'
== '(Object.values(({ ["a"] : 1, ["b"] : 2, ["c"] : 3 })))'
)
assert (
str(object_var.entries())
== 'Object.entries(({ ["a"] : 1, ["b"] : 2, ["c"] : 3 }))'
== '(Object.entries(({ ["a"] : 1, ["b"] : 2, ["c"] : 3 })))'
)
assert str(object_var.a) == '({ ["a"] : 1, ["b"] : 2, ["c"] : 3 })["a"]'
assert str(object_var["a"]) == '({ ["a"] : 1, ["b"] : 2, ["c"] : 3 })["a"]'
@ -1139,12 +1250,12 @@ def test_type_chains():
List[int],
)
assert (
str(object_var.keys()[0].upper())
== 'Object.keys(({ ["a"] : 1, ["b"] : 2, ["c"] : 3 })).at(0).toUpperCase()'
str(object_var.keys()[0].upper()) # pyright: ignore [reportAttributeAccessIssue]
== '(((...args) => (((_string) => String.prototype.toUpperCase.apply(_string))((((...args) => (((_array, _index_or_slice) => atSliceOrIndex(_array, _index_or_slice))((Object.keys(({ ["a"] : 1, ["b"] : 2, ["c"] : 3 }))), ...args)))(0)), ...args)))())'
)
assert (
str(object_var.entries()[1][1] - 1)
== '(Object.entries(({ ["a"] : 1, ["b"] : 2, ["c"] : 3 })).at(1).at(1) - 1)'
str(object_var.entries()[1][1] - 1) # pyright: ignore [reportCallIssue, reportOperatorIssue]
== '((((...args) => (((_array, _index_or_slice) => atSliceOrIndex(_array, _index_or_slice))((((...args) => (((_array, _index_or_slice) => atSliceOrIndex(_array, _index_or_slice))((Object.entries(({ ["a"] : 1, ["b"] : 2, ["c"] : 3 }))), ...args)))(1)), ...args)))(1)) - 1)'
)
assert (
str(object_var["c"] + object_var["b"]) # pyright: ignore [reportCallIssue, reportOperatorIssue]
@ -1153,10 +1264,14 @@ def test_type_chains():
def test_nested_dict():
arr = LiteralArrayVar.create([{"bar": ["foo", "bar"]}], List[Dict[str, List[str]]])
arr = Var.create([{"bar": ["foo", "bar"]}]).to(List[Dict[str, List[str]]])
first_dict = arr.at(0)
bar_element = first_dict["bar"]
first_bar_element = bar_element[0]
assert (
str(arr[0]["bar"][0]) == '[({ ["bar"] : ["foo", "bar"] })].at(0)["bar"].at(0)' # pyright: ignore [reportIndexIssue]
str(first_bar_element)
== '(((...args) => (((_array, _index_or_slice) => atSliceOrIndex(_array, _index_or_slice))((((...args) => (((_array, _index) => _array.at(_index))([({ ["bar"] : ["foo", "bar"] })], ...args)))(0))["bar"], ...args)))(0))' # pyright: ignore [reportIndexIssue]
)
@ -1331,9 +1446,9 @@ def test_unsupported_types_for_reverse(var):
Args:
var: The base var.
"""
with pytest.raises(TypeError) as err:
with pytest.raises(AttributeError) as err:
var.reverse()
assert err.value.args[0] == "Cannot reverse non-list var."
assert err.value.args[0] == "'Var' object has no attribute 'reverse'"
@pytest.mark.parametrize(
@ -1351,12 +1466,9 @@ def test_unsupported_types_for_contains(var: Var):
Args:
var: The base var.
"""
with pytest.raises(TypeError) as err:
with pytest.raises(AttributeError) as err:
assert var.contains(1) # pyright: ignore [reportAttributeAccessIssue]
assert (
err.value.args[0]
== f"Var of type {var._var_type} does not support contains check."
)
assert err.value.args[0] == "'Var' object has no attribute 'contains'"
@pytest.mark.parametrize(
@ -1376,7 +1488,7 @@ def test_unsupported_types_for_string_contains(other):
assert Var(_js_expr="var").to(str).contains(other)
assert (
err.value.args[0]
== f"Unsupported Operand type(s) for contains: ToStringOperation, {type(other).__name__}"
== f"Invalid argument other provided to argument 0 in var operation. Expected <class 'str'> but got {other._var_type}."
)
@ -1608,17 +1720,12 @@ def test_valid_var_operations(operand1_var: Var, operand2_var, operators: List[s
LiteralVar.create([10, 20]),
LiteralVar.create("5"),
[
"+",
"-",
"/",
"//",
"*",
"%",
"**",
">",
"<",
"<=",
">=",
"^",
"<<",
">>",

View File

@ -2,7 +2,7 @@ from __future__ import annotations
import datetime
import json
from typing import Any, List
from typing import Any
import plotly.graph_objects as go
import pytest
@ -98,60 +98,6 @@ def test_wrap(text: str, open: str, expected: str, check_first: bool, num: int):
assert format.wrap(text, open, check_first=check_first, num=num) == expected
@pytest.mark.parametrize(
"string,expected_output",
[
("This is a random string", "This is a random string"),
(
"This is a random string with `backticks`",
"This is a random string with \\`backticks\\`",
),
(
"This is a random string with `backticks`",
"This is a random string with \\`backticks\\`",
),
(
"This is a string with ${someValue[`string interpolation`]} unescaped",
"This is a string with ${someValue[`string interpolation`]} unescaped",
),
(
"This is a string with `backticks` and ${someValue[`string interpolation`]} unescaped",
"This is a string with \\`backticks\\` and ${someValue[`string interpolation`]} unescaped",
),
(
"This is a string with `backticks`, ${someValue[`the first string interpolation`]} and ${someValue[`the second`]}",
"This is a string with \\`backticks\\`, ${someValue[`the first string interpolation`]} and ${someValue[`the second`]}",
),
],
)
def test_escape_js_string(string, expected_output):
assert format._escape_js_string(string) == expected_output
@pytest.mark.parametrize(
"text,indent_level,expected",
[
("", 2, ""),
("hello", 2, "hello"),
("hello\nworld", 2, " hello\n world\n"),
("hello\nworld", 4, " hello\n world\n"),
(" hello\n world", 2, " hello\n world\n"),
],
)
def test_indent(text: str, indent_level: int, expected: str, windows_platform: bool):
"""Test indenting a string.
Args:
text: The text to indent.
indent_level: The number of spaces to indent by.
expected: The expected output string.
windows_platform: Whether the system is windows.
"""
assert format.indent(text, indent_level) == (
expected.replace("\n", "\r\n") if windows_platform else expected
)
@pytest.mark.parametrize(
"input,output",
[
@ -252,25 +198,6 @@ def test_to_kebab_case(input: str, output: str):
assert format.to_kebab_case(input) == output
@pytest.mark.parametrize(
"input,output",
[
("", "{``}"),
("hello", "{`hello`}"),
("hello world", "{`hello world`}"),
("hello=`world`", "{`hello=\\`world\\``}"),
],
)
def test_format_string(input: str, output: str):
"""Test formatting the input as JS string literal.
Args:
input: the input string.
output: the output string.
"""
assert format.format_string(input) == output
@pytest.mark.parametrize(
"input,output",
[
@ -310,45 +237,6 @@ def test_format_route(route: str, format_case: bool, expected: bool):
assert format.format_route(route, format_case=format_case) == expected
@pytest.mark.parametrize(
"condition, match_cases, default,expected",
[
(
"state__state.value",
[
[LiteralVar.create(1), LiteralVar.create("red")],
[LiteralVar.create(2), LiteralVar.create(3), LiteralVar.create("blue")],
[TestState.mapping, TestState.num1],
[
LiteralVar.create(f"{TestState.map_key}-key"),
LiteralVar.create("return-key"),
],
],
LiteralVar.create("yellow"),
'(() => { switch (JSON.stringify(state__state.value)) {case JSON.stringify(1): return ("red"); break;case JSON.stringify(2): case JSON.stringify(3): '
f'return ("blue"); break;case JSON.stringify({TestState.get_full_name()}.mapping): return '
f'({TestState.get_full_name()}.num1); break;case JSON.stringify(({TestState.get_full_name()}.map_key+"-key")): return ("return-key");'
' break;default: return ("yellow"); break;};})()',
)
],
)
def test_format_match(
condition: str,
match_cases: List[List[Var]],
default: Var,
expected: str,
):
"""Test formatting a match statement.
Args:
condition: The condition to match.
match_cases: List of match cases to be matched.
default: Catchall case for the match statement.
expected: The expected string output.
"""
assert format.format_match(condition, match_cases, default) == expected
@pytest.mark.parametrize(
"prop,formatted",
[

View File

@ -2,7 +2,7 @@ import os
import typing
from functools import cached_property
from pathlib import Path
from typing import Any, ClassVar, Dict, List, Literal, Type, Union
from typing import Any, ClassVar, Dict, List, Literal, Sequence, Tuple, Type, Union
import pytest
import typer
@ -109,10 +109,20 @@ def test_is_generic_alias(cls: type, expected: bool):
(Dict[str, str], dict[str, str], True),
(Dict[str, str], dict[str, Any], True),
(Dict[str, Any], dict[str, Any], True),
(List[int], Sequence[int], True),
(List[str], Sequence[int], False),
(Tuple[int], Sequence[int], True),
(Tuple[int, str], Sequence[int], False),
(Tuple[int, ...], Sequence[int], True),
(str, Sequence[int], False),
(str, Sequence[str], True),
],
)
def test_typehint_issubclass(subclass, superclass, expected):
assert types.typehint_issubclass(subclass, superclass) == expected
if expected:
assert types.typehint_issubclass(subclass, superclass)
else:
assert not types.typehint_issubclass(subclass, superclass)
def test_validate_none_bun_path(mocker):