Hello !" })}/>'
@@ -32,10 +32,10 @@ def test_html_fstring_create():
html = Html.create(f"
")
assert (
- str(html.dangerouslySetInnerHTML) # type: ignore
+ str(html.dangerouslySetInnerHTML) # pyright: ignore [reportAttributeAccessIssue]
== f'({{ ["__html"] : ("
' # type: ignore
+ == f'
' # pyright: ignore [reportAttributeAccessIssue]
)
diff --git a/tests/units/components/core/test_match.py b/tests/units/components/core/test_match.py
index f09e800e5..11602b77a 100644
--- a/tests/units/components/core/test_match.py
+++ b/tests/units/components/core/test_match.py
@@ -1,8 +1,9 @@
-from typing import Dict, List, Tuple
+from typing import List, Mapping, Tuple
import pytest
import reflex as rx
+from reflex.components.component import Component
from reflex.components.core.match import Match
from reflex.state import BaseState
from reflex.utils.exceptions import MatchTypeError
@@ -29,7 +30,9 @@ def test_match_components():
rx.text("default value"),
)
match_comp = Match.create(MatchState.value, *match_case_tuples)
- match_dict = match_comp.render() # type: ignore
+
+ assert isinstance(match_comp, Component)
+ match_dict = match_comp.render()
assert match_dict["name"] == "Fragment"
[match_child] = match_dict["children"]
@@ -42,7 +45,7 @@ def test_match_components():
assert match_cases[0][0]._js_expr == "1"
assert match_cases[0][0]._var_type is int
- first_return_value_render = match_cases[0][1].render()
+ first_return_value_render = match_cases[0][1]
assert first_return_value_render["name"] == "RadixThemesText"
assert first_return_value_render["children"][0]["contents"] == '{"first value"}'
@@ -50,35 +53,35 @@ def test_match_components():
assert match_cases[1][0]._var_type is int
assert match_cases[1][1]._js_expr == "3"
assert match_cases[1][1]._var_type is int
- second_return_value_render = match_cases[1][2].render()
+ second_return_value_render = match_cases[1][2]
assert second_return_value_render["name"] == "RadixThemesText"
assert second_return_value_render["children"][0]["contents"] == '{"second value"}'
assert match_cases[2][0]._js_expr == "[1, 2]"
assert match_cases[2][0]._var_type == List[int]
- third_return_value_render = match_cases[2][1].render()
+ third_return_value_render = match_cases[2][1]
assert third_return_value_render["name"] == "RadixThemesText"
assert third_return_value_render["children"][0]["contents"] == '{"third value"}'
assert match_cases[3][0]._js_expr == '"random"'
assert match_cases[3][0]._var_type is str
- fourth_return_value_render = match_cases[3][1].render()
+ fourth_return_value_render = match_cases[3][1]
assert fourth_return_value_render["name"] == "RadixThemesText"
assert fourth_return_value_render["children"][0]["contents"] == '{"fourth value"}'
assert match_cases[4][0]._js_expr == '({ ["foo"] : "bar" })'
- assert match_cases[4][0]._var_type == Dict[str, str]
- fifth_return_value_render = match_cases[4][1].render()
+ assert match_cases[4][0]._var_type == Mapping[str, str]
+ fifth_return_value_render = match_cases[4][1]
assert fifth_return_value_render["name"] == "RadixThemesText"
assert fifth_return_value_render["children"][0]["contents"] == '{"fifth value"}'
assert match_cases[5][0]._js_expr == f"({MatchState.get_name()}.num + 1)"
assert match_cases[5][0]._var_type is int
- fifth_return_value_render = match_cases[5][1].render()
+ fifth_return_value_render = match_cases[5][1]
assert fifth_return_value_render["name"] == "RadixThemesText"
assert fifth_return_value_render["children"][0]["contents"] == '{"sixth value"}'
- default = match_child["default"].render()
+ default = match_child["default"]
assert default["name"] == "RadixThemesText"
assert default["children"][0]["contents"] == '{"default value"}'
@@ -151,9 +154,10 @@ def test_match_on_component_without_default():
)
match_comp = Match.create(MatchState.value, *match_case_tuples)
- default = match_comp.render()["children"][0]["default"] # type: ignore
+ assert isinstance(match_comp, Component)
+ default = match_comp.render()["children"][0]["default"]
- assert isinstance(default, Fragment)
+ assert isinstance(default, dict) and default["name"] == Fragment.__name__
def test_match_on_var_no_default():
diff --git a/tests/units/components/core/test_upload.py b/tests/units/components/core/test_upload.py
index 710baa161..efade7b63 100644
--- a/tests/units/components/core/test_upload.py
+++ b/tests/units/components/core/test_upload.py
@@ -5,7 +5,7 @@ from reflex.components.core.upload import (
StyledUpload,
Upload,
UploadNamespace,
- _on_drop_spec, # type: ignore
+ _on_drop_spec, # pyright: ignore [reportAttributeAccessIssue]
cancel_upload,
get_upload_url,
)
@@ -60,7 +60,7 @@ def test_upload_create():
up_comp_2 = Upload.create(
id="foo_id",
- on_drop=UploadStateTest.drop_handler([]), # type: ignore
+ on_drop=UploadStateTest.drop_handler([]),
)
assert isinstance(up_comp_2, Upload)
assert up_comp_2.is_used
@@ -80,7 +80,7 @@ def test_upload_create():
up_comp_4 = Upload.create(
id="foo_id",
- on_drop=UploadStateTest.not_drop_handler([]), # type: ignore
+ on_drop=UploadStateTest.not_drop_handler([]),
)
assert isinstance(up_comp_4, Upload)
assert up_comp_4.is_used
@@ -96,7 +96,7 @@ def test_styled_upload_create():
styled_up_comp_2 = StyledUpload.create(
id="foo_id",
- on_drop=UploadStateTest.drop_handler([]), # type: ignore
+ on_drop=UploadStateTest.drop_handler([]),
)
assert isinstance(styled_up_comp_2, StyledUpload)
assert styled_up_comp_2.is_used
@@ -116,7 +116,7 @@ def test_styled_upload_create():
styled_up_comp_4 = StyledUpload.create(
id="foo_id",
- on_drop=UploadStateTest.not_drop_handler([]), # type: ignore
+ on_drop=UploadStateTest.not_drop_handler([]),
)
assert isinstance(styled_up_comp_4, StyledUpload)
assert styled_up_comp_4.is_used
diff --git a/tests/units/components/datadisplay/conftest.py b/tests/units/components/datadisplay/conftest.py
index 13c571c8c..188e887c4 100644
--- a/tests/units/components/datadisplay/conftest.py
+++ b/tests/units/components/datadisplay/conftest.py
@@ -1,7 +1,5 @@
"""Data display component tests fixtures."""
-from typing import List
-
import pandas as pd
import pytest
@@ -54,11 +52,11 @@ def data_table_state3():
"""
class DataTableState(BaseState):
- _data: List = []
- _columns: List = ["col1", "col2"]
+ _data: list = []
+ _columns: list = ["col1", "col2"]
@rx.var
- def data(self) -> List:
+ def data(self) -> list:
return self._data
@rx.var
@@ -77,15 +75,15 @@ def data_table_state4():
"""
class DataTableState(BaseState):
- _data: List = []
- _columns: List = ["col1", "col2"]
+ _data: list = []
+ _columns: list[str] = ["col1", "col2"]
@rx.var
def data(self):
return self._data
@rx.var
- def columns(self) -> List:
+ def columns(self) -> list:
return self._columns
return DataTableState
diff --git a/tests/units/components/datadisplay/test_code.py b/tests/units/components/datadisplay/test_code.py
index 6b7168756..db0120fe1 100644
--- a/tests/units/components/datadisplay/test_code.py
+++ b/tests/units/components/datadisplay/test_code.py
@@ -10,4 +10,4 @@ from reflex.components.datadisplay.code import CodeBlock, Theme
def test_code_light_dark_theme(theme, expected):
code_block = CodeBlock.create(theme=theme)
- assert code_block.theme._js_expr == expected # type: ignore
+ assert code_block.theme._js_expr == expected # pyright: ignore [reportAttributeAccessIssue]
diff --git a/tests/units/components/datadisplay/test_datatable.py b/tests/units/components/datadisplay/test_datatable.py
index b3d31ea32..2dece464a 100644
--- a/tests/units/components/datadisplay/test_datatable.py
+++ b/tests/units/components/datadisplay/test_datatable.py
@@ -4,6 +4,7 @@ import pytest
import reflex as rx
from reflex.components.gridjs.datatable import DataTable
from reflex.utils import types
+from reflex.utils.exceptions import UntypedComputedVarError
from reflex.utils.serializers import serialize, serialize_dataframe
@@ -13,7 +14,8 @@ from reflex.utils.serializers import serialize, serialize_dataframe
pytest.param(
{
"data": pd.DataFrame(
- [["foo", "bar"], ["foo1", "bar1"]], columns=["column1", "column2"]
+ [["foo", "bar"], ["foo1", "bar1"]],
+ columns=["column1", "column2"], # pyright: ignore [reportArgumentType]
)
},
"data",
@@ -75,17 +77,17 @@ def test_invalid_props(props):
[
(
"data_table_state2",
- "Annotation of the computed var assigned to the data field should be provided.",
+ "Computed var 'data' must have a type annotation.",
True,
),
(
"data_table_state3",
- "Annotation of the computed var assigned to the column field should be provided.",
+ "Computed var 'columns' must have a type annotation.",
False,
),
(
"data_table_state4",
- "Annotation of the computed var assigned to the data field should be provided.",
+ "Computed var 'data' must have a type annotation.",
False,
),
],
@@ -99,7 +101,7 @@ def test_computed_var_without_annotation(fixture, request, err_msg, is_data_fram
err_msg: expected error message.
is_data_frame: whether data field is a pandas dataframe.
"""
- with pytest.raises(ValueError) as err:
+ with pytest.raises(UntypedComputedVarError) as err:
if is_data_frame:
DataTable.create(data=request.getfixturevalue(fixture).data)
else:
@@ -113,7 +115,8 @@ def test_computed_var_without_annotation(fixture, request, err_msg, is_data_fram
def test_serialize_dataframe():
"""Test if dataframe is serialized correctly."""
df = pd.DataFrame(
- [["foo", "bar"], ["foo1", "bar1"]], columns=["column1", "column2"]
+ [["foo", "bar"], ["foo1", "bar1"]],
+ columns=["column1", "column2"], # pyright: ignore [reportArgumentType]
)
value = serialize(df)
assert value == serialize_dataframe(df)
diff --git a/tests/units/components/datadisplay/test_shiki_code.py b/tests/units/components/datadisplay/test_shiki_code.py
index eb473ba06..e1c7984f1 100644
--- a/tests/units/components/datadisplay/test_shiki_code.py
+++ b/tests/units/components/datadisplay/test_shiki_code.py
@@ -11,6 +11,7 @@ from reflex.components.lucide.icon import Icon
from reflex.components.radix.themes.layout.box import Box
from reflex.style import Style
from reflex.vars import Var
+from reflex.vars.base import LiteralVar
@pytest.mark.parametrize(
@@ -95,11 +96,13 @@ def test_create_shiki_code_block(
# Test that the first child is the code
code_block_component = component.children[0]
- assert code_block_component.code._var_value == expected_first_child # type: ignore
+ assert code_block_component.code._var_value == expected_first_child # pyright: ignore [reportAttributeAccessIssue]
applied_styles = component.style
for key, value in expected_styles.items():
- assert Var.create(applied_styles[key])._var_value == value
+ var = Var.create(applied_styles[key])
+ assert isinstance(var, LiteralVar)
+ assert var._var_value == value
@pytest.mark.parametrize(
@@ -128,12 +131,12 @@ def test_create_shiki_high_level_code_block(
# Test that the first child is the code block component
code_block_component = component.children[0]
- assert code_block_component.code._var_value == children[0] # type: ignore
+ assert code_block_component.code._var_value == children[0] # pyright: ignore [reportAttributeAccessIssue]
# Check if the transformer is set correctly if expected
if expected_transformers:
exp_trans_names = [t.__name__ for t in expected_transformers]
- for transformer in code_block_component.transformers._var_value: # type: ignore
+ for transformer in code_block_component.transformers._var_value: # pyright: ignore [reportAttributeAccessIssue]
assert type(transformer).__name__ in exp_trans_names
# Check if the second child is the copy button if can_copy is True
@@ -161,12 +164,12 @@ def test_shiki_high_level_code_block_theme_language_mapping(children, props):
if "theme" in props:
assert component.children[
0
- ].theme._var_value == ShikiHighLevelCodeBlock._map_themes(props["theme"]) # type: ignore
+ ].theme._var_value == ShikiHighLevelCodeBlock._map_themes(props["theme"]) # pyright: ignore [reportAttributeAccessIssue]
# Test that the language is mapped correctly
if "language" in props:
assert component.children[
0
- ].language._var_value == ShikiHighLevelCodeBlock._map_languages( # type: ignore
+ ].language._var_value == ShikiHighLevelCodeBlock._map_languages( # pyright: ignore [reportAttributeAccessIssue]
props["language"]
)
diff --git a/tests/units/components/forms/test_form.py b/tests/units/components/forms/test_form.py
index 5f3ba2d37..69b5e7b63 100644
--- a/tests/units/components/forms/test_form.py
+++ b/tests/units/components/forms/test_form.py
@@ -10,7 +10,7 @@ def test_render_on_submit():
_var_type=EventChain,
)
f = Form.create(on_submit=submit_it)
- exp_submit_name = f"handleSubmit_{f.handle_submit_unique_name}" # type: ignore
+ exp_submit_name = f"handleSubmit_{f.handle_submit_unique_name}" # pyright: ignore [reportAttributeAccessIssue]
assert f"onSubmit={{{exp_submit_name}}}" in f.render()["props"]
diff --git a/tests/units/components/lucide/test_icon.py b/tests/units/components/lucide/test_icon.py
index b0a3475dd..19bea7a7f 100644
--- a/tests/units/components/lucide/test_icon.py
+++ b/tests/units/components/lucide/test_icon.py
@@ -1,13 +1,19 @@
import pytest
-from reflex.components.lucide.icon import LUCIDE_ICON_LIST, Icon
+from reflex.components.lucide.icon import (
+ LUCIDE_ICON_LIST,
+ LUCIDE_ICON_MAPPING_OVERRIDE,
+ Icon,
+)
from reflex.utils import format
@pytest.mark.parametrize("tag", LUCIDE_ICON_LIST)
def test_icon(tag):
icon = Icon.create(tag)
- assert icon.alias == f"Lucide{format.to_title_case(tag)}Icon"
+ assert icon.alias == "Lucide" + LUCIDE_ICON_MAPPING_OVERRIDE.get(
+ tag, f"{format.to_title_case(tag)}Icon"
+ )
def test_icon_missing_tag():
diff --git a/tests/units/components/markdown/test_markdown.py b/tests/units/components/markdown/test_markdown.py
index 866f32ae1..15d662ef6 100644
--- a/tests/units/components/markdown/test_markdown.py
+++ b/tests/units/components/markdown/test_markdown.py
@@ -148,7 +148,7 @@ def test_create_map_fn_var_subclass(cls, fn_body, fn_args, explicit_return, expe
(
"code",
{},
- """(({node, inline, className, children, ...props}) => { const match = (className || '').match(/language-(?.*)/); const _language = match ? match[1] : ''; if (_language) { (async () => { try { const module = await import(`react-syntax-highlighter/dist/cjs/languages/prism/${_language}`); SyntaxHighlighter.registerLanguage(_language, module.default); } catch (error) { console.error(`Error importing language module for ${_language}:`, error); } })(); } ; return inline ? ( {children} ) : ( ); })""",
+ r"""(({node, inline, className, children, ...props}) => { const match = (className || '').match(/language-(?.*)/); let _language = match ? match[1] : ''; if (_language) { if (!["abap", "abnf", "actionscript", "ada", "agda", "al", "antlr4", "apacheconf", "apex", "apl", "applescript", "aql", "arduino", "arff", "asciidoc", "asm6502", "asmatmel", "aspnet", "autohotkey", "autoit", "avisynth", "avro-idl", "bash", "basic", "batch", "bbcode", "bicep", "birb", "bison", "bnf", "brainfuck", "brightscript", "bro", "bsl", "c", "cfscript", "chaiscript", "cil", "clike", "clojure", "cmake", "cobol", "coffeescript", "concurnas", "coq", "core", "cpp", "crystal", "csharp", "cshtml", "csp", "css", "css-extras", "csv", "cypher", "d", "dart", "dataweave", "dax", "dhall", "diff", "django", "dns-zone-file", "docker", "dot", "ebnf", "editorconfig", "eiffel", "ejs", "elixir", "elm", "erb", "erlang", "etlua", "excel-formula", "factor", "false", "firestore-security-rules", "flow", "fortran", "fsharp", "ftl", "gap", "gcode", "gdscript", "gedcom", "gherkin", "git", "glsl", "gml", "gn", "go", "go-module", "graphql", "groovy", "haml", "handlebars", "haskell", "haxe", "hcl", "hlsl", "hoon", "hpkp", "hsts", "http", "ichigojam", "icon", "icu-message-format", "idris", "iecst", "ignore", "index", "inform7", "ini", "io", "j", "java", "javadoc", "javadoclike", "javascript", "javastacktrace", "jexl", "jolie", "jq", "js-extras", "js-templates", "jsdoc", "json", "json5", "jsonp", "jsstacktrace", "jsx", "julia", "keepalived", "keyman", "kotlin", "kumir", "kusto", "latex", "latte", "less", "lilypond", "liquid", "lisp", "livescript", "llvm", "log", "lolcode", "lua", "magma", "makefile", "markdown", "markup", "markup-templating", "matlab", "maxscript", "mel", "mermaid", "mizar", "mongodb", "monkey", "moonscript", "n1ql", "n4js", "nand2tetris-hdl", "naniscript", "nasm", "neon", "nevod", "nginx", "nim", "nix", "nsis", "objectivec", "ocaml", "opencl", "openqasm", "oz", "parigp", "parser", "pascal", "pascaligo", "pcaxis", "peoplecode", "perl", "php", "php-extras", "phpdoc", "plsql", "powerquery", "powershell", "processing", "prolog", "promql", "properties", "protobuf", "psl", "pug", "puppet", "pure", "purebasic", "purescript", "python", "q", "qml", "qore", "qsharp", "r", "racket", "reason", "regex", "rego", "renpy", "rest", "rip", "roboconf", "robotframework", "ruby", "rust", "sas", "sass", "scala", "scheme", "scss", "shell-session", "smali", "smalltalk", "smarty", "sml", "solidity", "solution-file", "soy", "sparql", "splunk-spl", "sqf", "sql", "squirrel", "stan", "stylus", "swift", "systemd", "t4-cs", "t4-templating", "t4-vb", "tap", "tcl", "textile", "toml", "tremor", "tsx", "tt2", "turtle", "twig", "typescript", "typoscript", "unrealscript", "uorazor", "uri", "v", "vala", "vbnet", "velocity", "verilog", "vhdl", "vim", "visual-basic", "warpscript", "wasm", "web-idl", "wiki", "wolfram", "wren", "xeora", "xml-doc", "xojo", "xquery", "yaml", "yang", "zig"].includes(_language)) { console.warn(`Language \`${_language}\` is not supported for code blocks inside of markdown.`); _language = ''; } else { (async () => { try { const module = await import(`react-syntax-highlighter/dist/cjs/languages/prism/${_language}`); SyntaxHighlighter.registerLanguage(_language, module.default); } catch (error) { console.error(`Language ${_language} is not supported for code blocks inside of markdown: `, error); } })(); } } ; return inline ? ( {children} ) : ( ); })""",
),
(
"code",
@@ -157,7 +157,7 @@ def test_create_map_fn_var_subclass(cls, fn_body, fn_args, explicit_return, expe
value, **props
)
},
- """(({node, inline, className, children, ...props}) => { const match = (className || '').match(/language-(?.*)/); const _language = match ? match[1] : ''; ; return inline ? ( {children} ) : ( ); })""",
+ r"""(({node, inline, className, children, ...props}) => { const match = (className || '').match(/language-(?.*)/); let _language = match ? match[1] : ''; ; return inline ? ( {children} ) : ( ); })""",
),
(
"h1",
@@ -171,7 +171,7 @@ def test_create_map_fn_var_subclass(cls, fn_body, fn_args, explicit_return, expe
(
"code",
{"codeblock": syntax_highlighter_memoized_component(CodeBlock)},
- """(({node, inline, className, children, ...props}) => { const match = (className || '').match(/language-(?.*)/); const _language = match ? match[1] : ''; if (_language) { (async () => { try { const module = await import(`react-syntax-highlighter/dist/cjs/languages/prism/${_language}`); SyntaxHighlighter.registerLanguage(_language, module.default); } catch (error) { console.error(`Error importing language module for ${_language}:`, error); } })(); } ; return inline ? ( {children} ) : ( ); })""",
+ r"""(({node, inline, className, children, ...props}) => { const match = (className || '').match(/language-(?.*)/); let _language = match ? match[1] : ''; if (_language) { if (!["abap", "abnf", "actionscript", "ada", "agda", "al", "antlr4", "apacheconf", "apex", "apl", "applescript", "aql", "arduino", "arff", "asciidoc", "asm6502", "asmatmel", "aspnet", "autohotkey", "autoit", "avisynth", "avro-idl", "bash", "basic", "batch", "bbcode", "bicep", "birb", "bison", "bnf", "brainfuck", "brightscript", "bro", "bsl", "c", "cfscript", "chaiscript", "cil", "clike", "clojure", "cmake", "cobol", "coffeescript", "concurnas", "coq", "core", "cpp", "crystal", "csharp", "cshtml", "csp", "css", "css-extras", "csv", "cypher", "d", "dart", "dataweave", "dax", "dhall", "diff", "django", "dns-zone-file", "docker", "dot", "ebnf", "editorconfig", "eiffel", "ejs", "elixir", "elm", "erb", "erlang", "etlua", "excel-formula", "factor", "false", "firestore-security-rules", "flow", "fortran", "fsharp", "ftl", "gap", "gcode", "gdscript", "gedcom", "gherkin", "git", "glsl", "gml", "gn", "go", "go-module", "graphql", "groovy", "haml", "handlebars", "haskell", "haxe", "hcl", "hlsl", "hoon", "hpkp", "hsts", "http", "ichigojam", "icon", "icu-message-format", "idris", "iecst", "ignore", "index", "inform7", "ini", "io", "j", "java", "javadoc", "javadoclike", "javascript", "javastacktrace", "jexl", "jolie", "jq", "js-extras", "js-templates", "jsdoc", "json", "json5", "jsonp", "jsstacktrace", "jsx", "julia", "keepalived", "keyman", "kotlin", "kumir", "kusto", "latex", "latte", "less", "lilypond", "liquid", "lisp", "livescript", "llvm", "log", "lolcode", "lua", "magma", "makefile", "markdown", "markup", "markup-templating", "matlab", "maxscript", "mel", "mermaid", "mizar", "mongodb", "monkey", "moonscript", "n1ql", "n4js", "nand2tetris-hdl", "naniscript", "nasm", "neon", "nevod", "nginx", "nim", "nix", "nsis", "objectivec", "ocaml", "opencl", "openqasm", "oz", "parigp", "parser", "pascal", "pascaligo", "pcaxis", "peoplecode", "perl", "php", "php-extras", "phpdoc", "plsql", "powerquery", "powershell", "processing", "prolog", "promql", "properties", "protobuf", "psl", "pug", "puppet", "pure", "purebasic", "purescript", "python", "q", "qml", "qore", "qsharp", "r", "racket", "reason", "regex", "rego", "renpy", "rest", "rip", "roboconf", "robotframework", "ruby", "rust", "sas", "sass", "scala", "scheme", "scss", "shell-session", "smali", "smalltalk", "smarty", "sml", "solidity", "solution-file", "soy", "sparql", "splunk-spl", "sqf", "sql", "squirrel", "stan", "stylus", "swift", "systemd", "t4-cs", "t4-templating", "t4-vb", "tap", "tcl", "textile", "toml", "tremor", "tsx", "tt2", "turtle", "twig", "typescript", "typoscript", "unrealscript", "uorazor", "uri", "v", "vala", "vbnet", "velocity", "verilog", "vhdl", "vim", "visual-basic", "warpscript", "wasm", "web-idl", "wiki", "wolfram", "wren", "xeora", "xml-doc", "xojo", "xquery", "yaml", "yang", "zig"].includes(_language)) { console.warn(`Language \`${_language}\` is not supported for code blocks inside of markdown.`); _language = ''; } else { (async () => { try { const module = await import(`react-syntax-highlighter/dist/cjs/languages/prism/${_language}`); SyntaxHighlighter.registerLanguage(_language, module.default); } catch (error) { console.error(`Language ${_language} is not supported for code blocks inside of markdown: `, error); } })(); } } ; return inline ? ( {children} ) : ( ); })""",
),
(
"code",
@@ -180,11 +180,12 @@ def test_create_map_fn_var_subclass(cls, fn_body, fn_args, explicit_return, expe
ShikiHighLevelCodeBlock
)
},
- """(({node, inline, className, children, ...props}) => { const match = (className || '').match(/language-(?.*)/); const _language = match ? match[1] : ''; ; return inline ? ( {children} ) : ( ); })""",
+ r"""(({node, inline, className, children, ...props}) => { const match = (className || '').match(/language-(?.*)/); let _language = match ? match[1] : ''; ; return inline ? ( {children} ) : ( ); })""",
),
],
)
def test_markdown_format_component(key, component_map, expected):
markdown = Markdown.create("# header", component_map=component_map)
result = markdown.format_component_map()
+ print(str(result[key]))
assert str(result[key]) == expected
diff --git a/tests/units/components/media/test_image.py b/tests/units/components/media/test_image.py
index 742bd8c38..519ca735e 100644
--- a/tests/units/components/media/test_image.py
+++ b/tests/units/components/media/test_image.py
@@ -4,7 +4,7 @@ import pytest
from PIL.Image import Image as Img
import reflex as rx
-from reflex.components.next.image import Image # type: ignore
+from reflex.components.next.image import Image
from reflex.utils.serializers import serialize, serialize_image
from reflex.vars.sequence import StringVar
@@ -17,7 +17,7 @@ def pil_image() -> Img:
A random PIL image.
"""
imarray = np.random.rand(100, 100, 3) * 255
- return PIL.Image.fromarray(imarray.astype("uint8")).convert("RGBA") # type: ignore
+ return PIL.Image.fromarray(imarray.astype("uint8")).convert("RGBA") # pyright: ignore [reportAttributeAccessIssue]
def test_serialize_image(pil_image: Img):
@@ -36,13 +36,13 @@ def test_set_src_str():
"""Test that setting the src works."""
image = rx.image(src="pic2.jpeg")
# when using next/image, we explicitly create a _var_is_str Var
- assert str(image.src) in ( # type: ignore
+ assert str(image.src) in ( # pyright: ignore [reportAttributeAccessIssue]
'"pic2.jpeg"',
"'pic2.jpeg'",
"`pic2.jpeg`",
)
# For plain rx.el.img, an explicit var is not created, so the quoting happens later
- # assert str(image.src) == "pic2.jpeg" # type: ignore #noqa: ERA001
+ # assert str(image.src) == "pic2.jpeg" #noqa: ERA001
def test_set_src_img(pil_image: Img):
@@ -52,7 +52,7 @@ def test_set_src_img(pil_image: Img):
pil_image: The image to serialize.
"""
image = Image.create(src=pil_image)
- assert str(image.src._js_expr) == '"' + serialize_image(pil_image) + '"' # type: ignore
+ assert str(image.src._js_expr) == '"' + serialize_image(pil_image) + '"' # pyright: ignore [reportAttributeAccessIssue]
def test_render(pil_image: Img):
@@ -62,4 +62,4 @@ def test_render(pil_image: Img):
pil_image: The image to serialize.
"""
image = Image.create(src=pil_image)
- assert isinstance(image.src, StringVar) # type: ignore
+ assert isinstance(image.src, StringVar) # pyright: ignore [reportAttributeAccessIssue]
diff --git a/tests/units/components/test_component.py b/tests/units/components/test_component.py
index 674873b69..d333a45b4 100644
--- a/tests/units/components/test_component.py
+++ b/tests/units/components/test_component.py
@@ -19,6 +19,7 @@ from reflex.constants import EventTriggers
from reflex.event import (
EventChain,
EventHandler,
+ JavascriptInputEvent,
input_event,
no_args_event_spec,
parse_args_spec,
@@ -27,10 +28,15 @@ from reflex.event import (
from reflex.state import BaseState
from reflex.style import Style
from reflex.utils import imports
-from reflex.utils.exceptions import EventFnArgMismatch
+from reflex.utils.exceptions import (
+ ChildrenTypeError,
+ EventFnArgMismatchError,
+ EventHandlerArgTypeMismatchError,
+)
from reflex.utils.imports import ImportDict, ImportVar, ParsedImportDict, parse_imports
from reflex.vars import VarData
from reflex.vars.base import LiteralVar, Var
+from reflex.vars.object import ObjectVar
@pytest.fixture
@@ -94,11 +100,14 @@ def component2() -> Type[Component]:
A test component.
"""
+ def on_prop_event_spec(e0: Any):
+ return [e0]
+
class TestComponent2(Component):
# A test list prop.
arr: Var[List[str]]
- on_prop_event: EventHandler[lambda e0: [e0]]
+ on_prop_event: EventHandler[on_prop_event_spec]
def get_event_triggers(self) -> Dict[str, Any]:
"""Test controlled triggers.
@@ -444,8 +453,8 @@ def test_add_style(component1, component2):
component1: Style({"color": "white"}),
component2: Style({"color": "black"}),
}
- c1 = component1()._add_style_recursive(style) # type: ignore
- c2 = component2()._add_style_recursive(style) # type: ignore
+ c1 = component1()._add_style_recursive(style)
+ c2 = component2()._add_style_recursive(style)
assert str(c1.style["color"]) == '"white"'
assert str(c2.style["color"]) == '"black"'
@@ -461,8 +470,8 @@ def test_add_style_create(component1, component2):
component1.create: Style({"color": "white"}),
component2.create: Style({"color": "black"}),
}
- c1 = component1()._add_style_recursive(style) # type: ignore
- c2 = component2()._add_style_recursive(style) # type: ignore
+ c1 = component1()._add_style_recursive(style)
+ c2 = component2()._add_style_recursive(style)
assert str(c1.style["color"]) == '"white"'
assert str(c2.style["color"]) == '"black"'
@@ -642,17 +651,20 @@ def test_create_filters_none_props(test_component):
# Assert that the style prop is present in the component's props
assert str(component.style["color"]) == '"white"'
- assert str(component.style["text-align"]) == '"center"'
+ assert str(component.style["textAlign"]) == '"center"'
-@pytest.mark.parametrize("children", [((None,),), ("foo", ("bar", (None,)))])
+@pytest.mark.parametrize(
+ "children",
+ [
+ ((None,),),
+ ("foo", ("bar", (None,))),
+ ({"foo": "bar"},),
+ ],
+)
def test_component_create_unallowed_types(children, test_component):
- with pytest.raises(TypeError) as err:
+ with pytest.raises(ChildrenTypeError):
test_component.create(*children)
- assert (
- err.value.args[0]
- == "Children of Reflex components must be other components, state vars, or primitive Python types. Got child None of type ."
- )
@pytest.mark.parametrize(
@@ -815,10 +827,14 @@ def test_component_create_unpack_tuple_child(test_component, element, expected):
assert fragment_wrapper.render() == expected
+class _Obj(Base):
+ custom: int = 0
+
+
class C1State(BaseState):
"""State for testing C1 component."""
- def mock_handler(self, _e, _bravo, _charlie):
+ def mock_handler(self, _e: JavascriptInputEvent, _bravo: dict, _charlie: _Obj):
"""Mock handler."""
pass
@@ -826,11 +842,13 @@ class C1State(BaseState):
def test_component_event_trigger_arbitrary_args():
"""Test that we can define arbitrary types for the args of an event trigger."""
- class Obj(Base):
- custom: int = 0
-
- def on_foo_spec(_e, alpha: str, bravo: Dict[str, Any], charlie: Obj):
- return [_e.target.value, bravo["nested"], charlie.custom + 42]
+ def on_foo_spec(
+ _e: ObjectVar[JavascriptInputEvent],
+ alpha: Var[str],
+ bravo: dict[str, Any],
+ charlie: ObjectVar[_Obj],
+ ):
+ return [_e.target.value, bravo["nested"], charlie.custom.to(int) + 42]
class C1(Component):
library = "/local"
@@ -842,13 +860,7 @@ def test_component_event_trigger_arbitrary_args():
"on_foo": on_foo_spec,
}
- comp = C1.create(on_foo=C1State.mock_handler)
-
- assert comp.render()["props"][0] == (
- "onFoo={((__e, _alpha, _bravo, _charlie) => (addEvents("
- f'[(Event("{C1State.get_full_name()}.mock_handler", ({{ ["_e"] : __e["target"]["value"], ["_bravo"] : _bravo["nested"], ["_charlie"] : (_charlie["custom"] + 42) }}), ({{ }})))], '
- "[__e, _alpha, _bravo, _charlie], ({ }))))}"
- )
+ C1.create(on_foo=C1State.mock_handler)
def test_create_custom_component(my_component):
@@ -859,7 +871,7 @@ def test_create_custom_component(my_component):
"""
component = CustomComponent(component_fn=my_component, prop1="test", prop2=1)
assert component.tag == "MyComponent"
- assert component.get_props() == set()
+ assert component.get_props() == {"prop1", "prop2"}
assert component._get_all_custom_components() == {component}
@@ -905,30 +917,29 @@ def test_invalid_event_handler_args(component2, test_state):
test_state: A test state.
"""
# EventHandler args must match
- with pytest.raises(EventFnArgMismatch):
+ with pytest.raises(EventFnArgMismatchError):
component2.create(on_click=test_state.do_something_arg)
# Multiple EventHandler args: all must match
- with pytest.raises(EventFnArgMismatch):
+ with pytest.raises(EventFnArgMismatchError):
component2.create(
on_click=[test_state.do_something_arg, test_state.do_something]
)
- # Enable when 0.7.0 happens
# # Event Handler types must match
- # with pytest.raises(EventHandlerArgTypeMismatch):
- # component2.create(
- # on_user_visited_count_changed=test_state.do_something_with_bool # noqa: ERA001 RUF100
- # ) # noqa: ERA001 RUF100
- # with pytest.raises(EventHandlerArgTypeMismatch):
- # component2.create(on_user_list_changed=test_state.do_something_with_int) #noqa: ERA001
- # with pytest.raises(EventHandlerArgTypeMismatch):
- # component2.create(on_user_list_changed=test_state.do_something_with_list_int) #noqa: ERA001
+ with pytest.raises(EventHandlerArgTypeMismatchError):
+ component2.create(
+ on_user_visited_count_changed=test_state.do_something_with_bool
+ )
+ with pytest.raises(EventHandlerArgTypeMismatchError):
+ component2.create(on_user_list_changed=test_state.do_something_with_int)
+ with pytest.raises(EventHandlerArgTypeMismatchError):
+ component2.create(on_user_list_changed=test_state.do_something_with_list_int)
- # component2.create(on_open=test_state.do_something_with_int) #noqa: ERA001
- # component2.create(on_open=test_state.do_something_with_bool) #noqa: ERA001
- # component2.create(on_user_visited_count_changed=test_state.do_something_with_int) #noqa: ERA001
- # component2.create(on_user_list_changed=test_state.do_something_with_list_str) #noqa: ERA001
+ component2.create(on_open=test_state.do_something_with_int)
+ component2.create(on_open=test_state.do_something_with_bool)
+ component2.create(on_user_visited_count_changed=test_state.do_something_with_int)
+ component2.create(on_user_list_changed=test_state.do_something_with_list_str)
# lambda cannot return weird values.
with pytest.raises(ValueError):
@@ -941,15 +952,15 @@ def test_invalid_event_handler_args(component2, test_state):
)
# lambda signature must match event trigger.
- with pytest.raises(EventFnArgMismatch):
+ with pytest.raises(EventFnArgMismatchError):
component2.create(on_click=lambda _: test_state.do_something_arg(1))
# lambda returning EventHandler must match spec
- with pytest.raises(EventFnArgMismatch):
+ with pytest.raises(EventFnArgMismatchError):
component2.create(on_click=lambda: test_state.do_something_arg)
# Mixed EventSpec and EventHandler must match spec.
- with pytest.raises(EventFnArgMismatch):
+ with pytest.raises(EventFnArgMismatchError):
component2.create(
on_click=lambda: [
test_state.do_something_arg(1),
@@ -1318,7 +1329,7 @@ class EventState(rx.State):
),
pytest.param(
rx.fragment(class_name=[TEST_VAR, "other-class"]),
- [LiteralVar.create([TEST_VAR, "other-class"]).join(" ")],
+ [Var.create([TEST_VAR, "other-class"]).join(" ")],
id="fstring-dual-class_name",
),
pytest.param(
@@ -1353,17 +1364,17 @@ class EventState(rx.State):
id="fstring-background_color",
),
pytest.param(
- rx.fragment(style={"background_color": TEST_VAR}), # type: ignore
+ rx.fragment(style={"background_color": TEST_VAR}), # pyright: ignore [reportArgumentType]
[STYLE_VAR],
id="direct-style-background_color",
),
pytest.param(
- rx.fragment(style={"background_color": f"foo{TEST_VAR}bar"}), # type: ignore
+ rx.fragment(style={"background_color": f"foo{TEST_VAR}bar"}), # pyright: ignore [reportArgumentType]
[STYLE_VAR],
id="fstring-style-background_color",
),
pytest.param(
- rx.fragment(on_click=EVENT_CHAIN_VAR), # type: ignore
+ rx.fragment(on_click=EVENT_CHAIN_VAR),
[EVENT_CHAIN_VAR],
id="direct-event-chain",
),
@@ -1373,17 +1384,17 @@ class EventState(rx.State):
id="direct-event-handler",
),
pytest.param(
- rx.fragment(on_click=EventState.handler2(TEST_VAR)), # type: ignore
+ rx.fragment(on_click=EventState.handler2(TEST_VAR)), # pyright: ignore [reportCallIssue]
[ARG_VAR, TEST_VAR],
id="direct-event-handler-arg",
),
pytest.param(
- rx.fragment(on_click=EventState.handler2(EventState.v)), # type: ignore
+ rx.fragment(on_click=EventState.handler2(EventState.v)), # pyright: ignore [reportCallIssue]
[ARG_VAR, EventState.v],
id="direct-event-handler-arg2",
),
pytest.param(
- rx.fragment(on_click=lambda: EventState.handler2(TEST_VAR)), # type: ignore
+ rx.fragment(on_click=lambda: EventState.handler2(TEST_VAR)), # pyright: ignore [reportCallIssue]
[ARG_VAR, TEST_VAR],
id="direct-event-handler-lambda",
),
@@ -1436,6 +1447,7 @@ def test_get_vars(component, exp_vars):
for comp_var, exp_var in zip(
comp_vars,
sorted(exp_vars, key=lambda v: v._js_expr),
+ strict=True,
):
assert comp_var.equals(exp_var)
@@ -1471,7 +1483,7 @@ def test_instantiate_all_components():
comp_name
for submodule_list in component_nested_list
for comp_name in submodule_list
- ]: # type: ignore
+ ]:
if component_name in untested_components:
continue
component = getattr(
@@ -1544,11 +1556,11 @@ def test_validate_valid_children():
)
valid_component1(
- rx.cond( # type: ignore
+ rx.cond(
True,
rx.fragment(valid_component2()),
rx.fragment(
- rx.foreach(LiteralVar.create([1, 2, 3]), lambda x: valid_component2(x)) # type: ignore
+ rx.foreach(LiteralVar.create([1, 2, 3]), lambda x: valid_component2(x))
),
)
)
@@ -1603,12 +1615,12 @@ def test_validate_valid_parents():
)
valid_component2(
- rx.cond( # type: ignore
+ rx.cond(
True,
rx.fragment(valid_component3()),
rx.fragment(
rx.foreach(
- LiteralVar.create([1, 2, 3]), # type: ignore
+ LiteralVar.create([1, 2, 3]),
lambda x: valid_component2(valid_component3(x)),
)
),
@@ -1671,13 +1683,13 @@ def test_validate_invalid_children():
with pytest.raises(ValueError):
valid_component4(
- rx.cond( # type: ignore
+ rx.cond(
True,
rx.fragment(invalid_component()),
rx.fragment(
rx.foreach(
LiteralVar.create([1, 2, 3]), lambda x: invalid_component(x)
- ) # type: ignore
+ )
),
)
)
@@ -1798,21 +1810,15 @@ def test_custom_component_declare_event_handlers_in_fields():
"""
return {
**super().get_event_triggers(),
- "on_a": lambda e0: [e0],
"on_b": input_event,
- "on_c": lambda e0: [],
"on_d": lambda: [],
"on_e": lambda: [],
- "on_f": lambda a, b, c: [c, b, a],
}
class TestComponent(Component):
- on_a: EventHandler[lambda e0: [e0]]
on_b: EventHandler[input_event]
- on_c: EventHandler[no_args_event_spec]
on_d: EventHandler[no_args_event_spec]
on_e: EventHandler
- on_f: EventHandler[lambda a, b, c: [c, b, a]]
custom_component = ReferenceComponent.create()
test_component = TestComponent.create()
@@ -1823,6 +1829,7 @@ def test_custom_component_declare_event_handlers_in_fields():
for v1, v2 in zip(
parse_args_spec(test_triggers[trigger_name]),
parse_args_spec(custom_triggers[trigger_name]),
+ strict=True,
):
assert v1.equals(v2)
@@ -1864,7 +1871,7 @@ def test_invalid_event_trigger():
)
def test_component_add_imports(tags):
class BaseComponent(Component):
- def _get_imports(self) -> ImportDict:
+ def _get_imports(self) -> ImportDict: # pyright: ignore [reportIncompatibleMethodOverride]
return {}
class Reference(Component):
@@ -1876,7 +1883,7 @@ def test_component_add_imports(tags):
)
class TestBase(Component):
- def add_imports(
+ def add_imports( # pyright: ignore [reportIncompatibleMethodOverride]
self,
) -> Dict[str, Union[str, ImportVar, List[str], List[ImportVar]]]:
return {"foo": "bar"}
@@ -1908,7 +1915,7 @@ def test_component_add_hooks():
pass
class GrandchildComponent1(ChildComponent1):
- def add_hooks(self):
+ def add_hooks(self): # pyright: ignore [reportIncompatibleMethodOverride]
return [
"const hook2 = 43",
"const hook3 = 44",
@@ -1921,11 +1928,11 @@ def test_component_add_hooks():
]
class GrandchildComponent2(ChildComponent1):
- def _get_hooks(self):
+ def _get_hooks(self): # pyright: ignore [reportIncompatibleMethodOverride]
return "const hook5 = 46"
class GreatGrandchildComponent2(GrandchildComponent2):
- def add_hooks(self):
+ def add_hooks(self): # pyright: ignore [reportIncompatibleMethodOverride]
return [
"const hook2 = 43",
"const hook6 = 47",
@@ -2000,7 +2007,7 @@ def test_component_add_custom_code():
]
class GrandchildComponent2(ChildComponent1):
- def _get_custom_code(self):
+ def _get_custom_code(self): # pyright: ignore [reportIncompatibleMethodOverride]
return "const custom_code5 = 46"
class GreatGrandchildComponent2(GrandchildComponent2):
@@ -2096,11 +2103,11 @@ def test_add_style_embedded_vars(test_state: BaseState):
test_state: A test state.
"""
v0 = LiteralVar.create("parent")._replace(
- merge_var_data=VarData(hooks={"useParent": None}), # type: ignore
+ merge_var_data=VarData(hooks={"useParent": None}),
)
v1 = rx.color("plum", 10)
v2 = LiteralVar.create("text")._replace(
- merge_var_data=VarData(hooks={"useText": None}), # type: ignore
+ merge_var_data=VarData(hooks={"useText": None}),
)
class ParentComponent(Component):
@@ -2114,7 +2121,7 @@ def test_add_style_embedded_vars(test_state: BaseState):
class StyledComponent(ParentComponent):
tag = "StyledComponent"
- def add_style(self):
+ def add_style(self): # pyright: ignore [reportIncompatibleMethodOverride]
return {
"color": v1,
"fake": v2,
diff --git a/tests/units/components/test_props.py b/tests/units/components/test_props.py
index 8ab07f135..8ed49d58a 100644
--- a/tests/units/components/test_props.py
+++ b/tests/units/components/test_props.py
@@ -1,13 +1,9 @@
import pytest
+from pydantic.v1 import ValidationError
from reflex.components.props import NoExtrasAllowedProps
from reflex.utils.exceptions import InvalidPropValueError
-try:
- from pydantic.v1 import ValidationError
-except ModuleNotFoundError:
- from pydantic import ValidationError
-
class PropA(NoExtrasAllowedProps):
"""Base prop class."""
diff --git a/tests/units/components/typography/test_markdown.py b/tests/units/components/typography/test_markdown.py
index 5e9abbb1f..12f3b0dbe 100644
--- a/tests/units/components/typography/test_markdown.py
+++ b/tests/units/components/typography/test_markdown.py
@@ -29,8 +29,8 @@ def test_get_component(tag, expected):
expected: The expected component.
"""
md = Markdown.create("# Hello")
- assert tag in md.component_map # type: ignore
- assert md.get_component(tag).tag == expected # type: ignore
+ assert tag in md.component_map # pyright: ignore [reportAttributeAccessIssue]
+ assert md.get_component(tag).tag == expected
def test_set_component_map():
@@ -42,8 +42,8 @@ def test_set_component_map():
md = Markdown.create("# Hello", component_map=component_map)
# Check that the new tags have been added.
- assert md.get_component("h1").tag == "Box" # type: ignore
- assert md.get_component("p").tag == "Box" # type: ignore
+ assert md.get_component("h1").tag == "Box"
+ assert md.get_component("p").tag == "Box"
# Make sure the old tags are still there.
- assert md.get_component("h2").tag == "Heading" # type: ignore
+ assert md.get_component("h2").tag == "Heading"
diff --git a/tests/units/conftest.py b/tests/units/conftest.py
index fb6229aca..2ee290ea3 100644
--- a/tests/units/conftest.py
+++ b/tests/units/conftest.py
@@ -1,11 +1,8 @@
"""Test fixtures."""
import asyncio
-import contextlib
-import os
import platform
import uuid
-from pathlib import Path
from typing import Dict, Generator, Type
from unittest import mock
@@ -14,6 +11,7 @@ import pytest
from reflex.app import App
from reflex.event import EventSpec
from reflex.model import ModelRegistry
+from reflex.testing import chdir
from reflex.utils import prerequisites
from .states import (
@@ -97,7 +95,7 @@ def upload_sub_state_event_spec():
Returns:
Event Spec.
"""
- return EventSpec(handler=SubUploadState.handle_upload, upload=True) # type: ignore
+ return EventSpec(handler=SubUploadState.handle_upload, upload=True) # pyright: ignore [reportCallIssue]
@pytest.fixture
@@ -107,7 +105,7 @@ def upload_event_spec():
Returns:
Event Spec.
"""
- return EventSpec(handler=UploadState.handle_upload1, upload=True) # type: ignore
+ return EventSpec(handler=UploadState.handle_upload1, upload=True) # pyright: ignore [reportCallIssue]
@pytest.fixture
@@ -145,7 +143,7 @@ def sqlite_db_config_values(base_db_config_values) -> Dict:
@pytest.fixture
-def router_data_headers() -> Dict[str, str]:
+def router_data_headers() -> dict[str, str]:
"""Router data headers.
Returns:
@@ -172,7 +170,7 @@ def router_data_headers() -> Dict[str, str]:
@pytest.fixture
-def router_data(router_data_headers) -> Dict[str, str]:
+def router_data(router_data_headers: dict[str, str]) -> dict[str, str | dict]:
"""Router data.
Args:
@@ -181,7 +179,7 @@ def router_data(router_data_headers) -> Dict[str, str]:
Returns:
Dict of router data.
"""
- return { # type: ignore
+ return {
"pathname": "/",
"query": {},
"token": "b181904c-3953-4a79-dc18-ae9518c22f05",
@@ -191,33 +189,6 @@ def router_data(router_data_headers) -> Dict[str, str]:
}
-# borrowed from py3.11
-class chdir(contextlib.AbstractContextManager):
- """Non thread-safe context manager to change the current working directory."""
-
- def __init__(self, path):
- """Prepare contextmanager.
-
- Args:
- path: the path to change to
- """
- self.path = path
- self._old_cwd = []
-
- def __enter__(self):
- """Save current directory and perform chdir."""
- self._old_cwd.append(Path.cwd())
- os.chdir(self.path)
-
- def __exit__(self, *excinfo):
- """Change back to previous directory on stack.
-
- Args:
- excinfo: sys.exc_info captured in the context block
- """
- os.chdir(self._old_cwd.pop())
-
-
@pytest.fixture
def tmp_working_dir(tmp_path):
"""Create a temporary directory and chdir to it.
diff --git a/tests/units/middleware/test_hydrate_middleware.py b/tests/units/middleware/test_hydrate_middleware.py
index 9ee8d8d25..7b02f8515 100644
--- a/tests/units/middleware/test_hydrate_middleware.py
+++ b/tests/units/middleware/test_hydrate_middleware.py
@@ -41,7 +41,7 @@ async def test_preprocess_no_events(hydrate_middleware, event1, mocker):
mocker.patch("reflex.state.State.class_subclasses", {TestState})
state = State()
update = await hydrate_middleware.preprocess(
- app=App(state=State),
+ app=App(_state=State),
event=event1,
state=state,
)
diff --git a/tests/units/states/mutation.py b/tests/units/states/mutation.py
index b05f558a1..ad658bbd0 100644
--- a/tests/units/states/mutation.py
+++ b/tests/units/states/mutation.py
@@ -18,7 +18,7 @@ class DictMutationTestState(BaseState):
def add_age(self):
"""Add an age to the dict."""
- self.details.update({"age": 20}) # type: ignore
+ self.details.update({"age": 20}) # pyright: ignore [reportCallIssue, reportArgumentType]
def change_name(self):
"""Change the name in the dict."""
diff --git a/tests/units/test_app.py b/tests/units/test_app.py
index 48a4bdda1..88cb36509 100644
--- a/tests/units/test_app.py
+++ b/tests/units/test_app.py
@@ -133,7 +133,7 @@ def test_model() -> Type[Model]:
A default model.
"""
- class TestModel(Model, table=True): # type: ignore
+ class TestModel(Model, table=True):
pass
return TestModel
@@ -147,7 +147,7 @@ def test_model_auth() -> Type[Model]:
A default model.
"""
- class TestModelAuth(Model, table=True): # type: ignore
+ class TestModelAuth(Model, table=True):
"""A test model with auth."""
pass
@@ -185,19 +185,19 @@ def test_custom_auth_admin() -> Type[AuthProvider]:
login_path: str = "/login"
logout_path: str = "/logout"
- def login(self):
+ def login(self): # pyright: ignore [reportIncompatibleMethodOverride]
"""Login."""
pass
- def is_authenticated(self):
+ def is_authenticated(self): # pyright: ignore [reportIncompatibleMethodOverride]
"""Is authenticated."""
pass
- def get_admin_user(self):
+ def get_admin_user(self): # pyright: ignore [reportIncompatibleMethodOverride]
"""Get admin user."""
pass
- def logout(self):
+ def logout(self): # pyright: ignore [reportIncompatibleMethodOverride]
"""Logout."""
pass
@@ -236,14 +236,14 @@ def test_add_page_default_route(app: App, index_page, about_page):
index_page: The index page.
about_page: The about page.
"""
- assert app.pages == {}
- assert app.unevaluated_pages == {}
+ assert app._pages == {}
+ assert app._unevaluated_pages == {}
app.add_page(index_page)
app._compile_page("index")
- assert app.pages.keys() == {"index"}
+ assert app._pages.keys() == {"index"}
app.add_page(about_page)
app._compile_page("about")
- assert app.pages.keys() == {"index", "about"}
+ assert app._pages.keys() == {"index", "about"}
def test_add_page_set_route(app: App, index_page, windows_platform: bool):
@@ -255,10 +255,10 @@ def test_add_page_set_route(app: App, index_page, windows_platform: bool):
windows_platform: Whether the system is windows.
"""
route = "test" if windows_platform else "/test"
- assert app.unevaluated_pages == {}
+ assert app._unevaluated_pages == {}
app.add_page(index_page, route=route)
app._compile_page("test")
- assert app.pages.keys() == {"test"}
+ assert app._pages.keys() == {"test"}
def test_add_page_set_route_dynamic(index_page, windows_platform: bool):
@@ -268,18 +268,18 @@ def test_add_page_set_route_dynamic(index_page, windows_platform: bool):
index_page: The index page.
windows_platform: Whether the system is windows.
"""
- app = App(state=EmptyState)
- assert app.state is not None
+ app = App(_state=EmptyState)
+ assert app._state is not None
route = "/test/[dynamic]"
- assert app.unevaluated_pages == {}
+ assert app._unevaluated_pages == {}
app.add_page(index_page, route=route)
app._compile_page("test/[dynamic]")
- assert app.pages.keys() == {"test/[dynamic]"}
- assert "dynamic" in app.state.computed_vars
- assert app.state.computed_vars["dynamic"]._deps(objclass=EmptyState) == {
- constants.ROUTER
+ assert app._pages.keys() == {"test/[dynamic]"}
+ assert "dynamic" in app._state.computed_vars
+ assert app._state.computed_vars["dynamic"]._deps(objclass=EmptyState) == {
+ EmptyState.get_full_name(): {constants.ROUTER},
}
- assert constants.ROUTER in app.state()._computed_var_dependencies
+ assert constants.ROUTER in app._state()._var_dependencies
def test_add_page_set_route_nested(app: App, index_page, windows_platform: bool):
@@ -291,9 +291,9 @@ def test_add_page_set_route_nested(app: App, index_page, windows_platform: bool)
windows_platform: Whether the system is windows.
"""
route = "test\\nested" if windows_platform else "/test/nested"
- assert app.unevaluated_pages == {}
+ assert app._unevaluated_pages == {}
app.add_page(index_page, route=route)
- assert app.unevaluated_pages.keys() == {route.strip(os.path.sep)}
+ assert app._unevaluated_pages.keys() == {route.strip(os.path.sep)}
def test_add_page_invalid_api_route(app: App, index_page):
@@ -413,13 +413,13 @@ async def test_initialize_with_state(test_state: Type[ATestState], token: str):
test_state: The default state.
token: a Token.
"""
- app = App(state=test_state)
- assert app.state == test_state
+ app = App(_state=test_state)
+ assert app._state == test_state
# Get a state for a given token.
state = await app.state_manager.get_state(_substate_key(token, test_state))
assert isinstance(state, test_state)
- assert state.var == 0 # type: ignore
+ assert state.var == 0
if isinstance(app.state_manager, StateManagerRedis):
await app.state_manager.close()
@@ -432,7 +432,7 @@ async def test_set_and_get_state(test_state):
Args:
test_state: The default state.
"""
- app = App(state=test_state)
+ app = App(_state=test_state)
# Create two tokens.
token1 = str(uuid.uuid4()) + f"_{test_state.get_full_name()}"
@@ -441,8 +441,8 @@ async def test_set_and_get_state(test_state):
# Get the default state for each token.
state1 = await app.state_manager.get_state(token1)
state2 = await app.state_manager.get_state(token2)
- assert state1.var == 0 # type: ignore
- assert state2.var == 0 # type: ignore
+ assert state1.var == 0
+ assert state2.var == 0
# Set the vars to different values.
state1.var = 1
@@ -453,8 +453,8 @@ async def test_set_and_get_state(test_state):
# Get the states again and check the values.
state1 = await app.state_manager.get_state(token1)
state2 = await app.state_manager.get_state(token2)
- assert state1.var == 1 # type: ignore
- assert state2.var == 2 # type: ignore
+ assert state1.var == 1
+ assert state2.var == 2
if isinstance(app.state_manager, StateManagerRedis):
await app.state_manager.close()
@@ -469,17 +469,17 @@ async def test_dynamic_var_event(test_state: Type[ATestState], token: str):
test_state: State Fixture.
token: a Token.
"""
- state = test_state() # type: ignore
+ state = test_state() # pyright: ignore [reportCallIssue]
state.add_var("int_val", int, 0)
- result = await state._process(
+ async for result in state._process(
Event(
token=token,
name=f"{test_state.get_name()}.set_int_val",
router_data={"pathname": "/", "query": {}},
payload={"value": 50},
)
- ).__anext__()
- assert result.delta == {test_state.get_name(): {"int_val": 50}}
+ ):
+ assert result.delta == {test_state.get_name(): {"int_val": 50}}
@pytest.mark.asyncio
@@ -583,18 +583,17 @@ async def test_list_mutation_detection__plain_list(
token: a Token.
"""
for event_name, expected_delta in event_tuples:
- result = await list_mutation_state._process(
+ async for result in list_mutation_state._process(
Event(
token=token,
name=f"{list_mutation_state.get_name()}.{event_name}",
router_data={"pathname": "/", "query": {}},
payload={},
)
- ).__anext__()
-
- # prefix keys in expected_delta with the state name
- expected_delta = {list_mutation_state.get_name(): expected_delta}
- assert result.delta == expected_delta
+ ):
+ # prefix keys in expected_delta with the state name
+ expected_delta = {list_mutation_state.get_name(): expected_delta}
+ assert result.delta == expected_delta
@pytest.mark.asyncio
@@ -709,19 +708,18 @@ async def test_dict_mutation_detection__plain_list(
token: a Token.
"""
for event_name, expected_delta in event_tuples:
- result = await dict_mutation_state._process(
+ async for result in dict_mutation_state._process(
Event(
token=token,
name=f"{dict_mutation_state.get_name()}.{event_name}",
router_data={"pathname": "/", "query": {}},
payload={},
)
- ).__anext__()
+ ):
+ # prefix keys in expected_delta with the state name
+ expected_delta = {dict_mutation_state.get_name(): expected_delta}
- # prefix keys in expected_delta with the state name
- expected_delta = {dict_mutation_state.get_name(): expected_delta}
-
- assert result.delta == expected_delta
+ assert result.delta == expected_delta
@pytest.mark.asyncio
@@ -772,7 +770,7 @@ async def test_upload_file(tmp_path, state, delta, token: str, mocker):
# The App state must be the "root" of the state tree
app = App()
app._enable_state()
- app.event_namespace.emit = AsyncMock() # type: ignore
+ app.event_namespace.emit = AsyncMock() # pyright: ignore [reportOptionalMemberAccess]
current_state = await app.state_manager.get_state(_substate_key(token, state))
data = b"This is binary data"
@@ -795,7 +793,7 @@ async def test_upload_file(tmp_path, state, delta, token: str, mocker):
file=bio,
)
upload_fn = upload(app)
- streaming_response = await upload_fn(request_mock, [file1, file2])
+ streaming_response = await upload_fn(request_mock, [file1, file2]) # pyright: ignore [reportFunctionMemberAccess]
async for state_update in streaming_response.body_iterator:
assert (
state_update
@@ -827,7 +825,7 @@ async def test_upload_file_without_annotation(state, tmp_path, token):
token: a Token.
"""
state._tmp_path = tmp_path
- app = App(state=State)
+ app = App(_state=State)
request_mock = unittest.mock.Mock()
request_mock.headers = {
@@ -861,7 +859,7 @@ async def test_upload_file_background(state, tmp_path, token):
token: a Token.
"""
state._tmp_path = tmp_path
- app = App(state=State)
+ app = App(_state=State)
request_mock = unittest.mock.Mock()
request_mock.headers = {
@@ -899,6 +897,7 @@ class DynamicState(BaseState):
loaded: int = 0
counter: int = 0
+ @rx.event
def on_load(self):
"""Event handler for page on_load, should trigger for all navigation events."""
self.loaded = self.loaded + 1
@@ -908,7 +907,7 @@ class DynamicState(BaseState):
"""Increment the counter var."""
self.counter = self.counter + 1
- @computed_var(cache=True)
+ @computed_var
def comp_dynamic(self) -> str:
"""A computed var that depends on the dynamic var.
@@ -917,7 +916,7 @@ class DynamicState(BaseState):
"""
return self.dynamic
- on_load_internal = OnLoadInternalState.on_load_internal.fn
+ on_load_internal = OnLoadInternalState.on_load_internal.fn # pyright: ignore [reportFunctionMemberAccess]
def test_dynamic_arg_shadow(
@@ -938,10 +937,10 @@ def test_dynamic_arg_shadow(
"""
arg_name = "counter"
route = f"/test/[{arg_name}]"
- app = app_module_mock.app = App(state=DynamicState)
- assert app.state is not None
+ app = app_module_mock.app = App(_state=DynamicState)
+ assert app._state is not None
with pytest.raises(NameError):
- app.add_page(index_page, route=route, on_load=DynamicState.on_load) # type: ignore
+ app.add_page(index_page, route=route, on_load=DynamicState.on_load)
def test_multiple_dynamic_args(
@@ -963,7 +962,7 @@ def test_multiple_dynamic_args(
arg_name = "my_arg"
route = f"/test/[{arg_name}]"
route2 = f"/test2/[{arg_name}]"
- app = app_module_mock.app = App(state=EmptyState)
+ app = app_module_mock.app = App(_state=EmptyState)
app.add_page(index_page, route=route)
app.add_page(index_page, route=route2)
@@ -990,16 +989,16 @@ async def test_dynamic_route_var_route_change_completed_on_load(
"""
arg_name = "dynamic"
route = f"/test/[{arg_name}]"
- app = app_module_mock.app = App(state=DynamicState)
- assert app.state is not None
- assert arg_name not in app.state.vars
- app.add_page(index_page, route=route, on_load=DynamicState.on_load) # type: ignore
- assert arg_name in app.state.vars
- assert arg_name in app.state.computed_vars
- assert app.state.computed_vars[arg_name]._deps(objclass=DynamicState) == {
- constants.ROUTER
+ app = app_module_mock.app = App(_state=DynamicState)
+ assert app._state is not None
+ assert arg_name not in app._state.vars
+ app.add_page(index_page, route=route, on_load=DynamicState.on_load)
+ assert arg_name in app._state.vars
+ assert arg_name in app._state.computed_vars
+ assert app._state.computed_vars[arg_name]._deps(objclass=DynamicState) == {
+ DynamicState.get_full_name(): {constants.ROUTER},
}
- assert constants.ROUTER in app.state()._computed_var_dependencies
+ assert constants.ROUTER in app._state()._var_dependencies
substate_token = _substate_key(token, DynamicState)
sid = "mock_sid"
@@ -1022,7 +1021,7 @@ async def test_dynamic_route_var_route_change_completed_on_load(
def _dynamic_state_event(name, val, **kwargs):
return _event(
- name=format.format_event_handler(getattr(DynamicState, name)), # type: ignore
+ name=format.format_event_handler(getattr(DynamicState, name)),
val=val,
**kwargs,
)
@@ -1174,7 +1173,7 @@ async def test_process_events(mocker, token: str):
"headers": {},
"ip": "127.0.0.1",
}
- app = App(state=GenState)
+ app = App(_state=GenState)
mocker.patch.object(app, "_postprocess", AsyncMock())
event = Event(
@@ -1190,7 +1189,7 @@ async def test_process_events(mocker, token: str):
pass
assert (await app.state_manager.get_state(event.substate_token)).value == 5
- assert app._postprocess.call_count == 6
+ assert app._postprocess.call_count == 6 # pyright: ignore [reportFunctionMemberAccess]
if isinstance(app.state_manager, StateManagerRedis):
await app.state_manager.close()
@@ -1220,13 +1219,13 @@ def test_overlay_component(
overlay_component: The overlay_component to pass to App.
exp_page_child: The type of the expected child in the page fragment.
"""
- app = App(state=state, overlay_component=overlay_component)
+ app = App(_state=state, overlay_component=overlay_component)
app._setup_overlay_component()
if exp_page_child is None:
assert app.overlay_component is None
elif isinstance(exp_page_child, OverlayFragment):
assert app.overlay_component is not None
- generated_component = app._generate_component(app.overlay_component) # type: ignore
+ generated_component = app._generate_component(app.overlay_component)
assert isinstance(generated_component, OverlayFragment)
assert isinstance(
generated_component.children[0],
@@ -1235,7 +1234,7 @@ def test_overlay_component(
else:
assert app.overlay_component is not None
assert isinstance(
- app._generate_component(app.overlay_component), # type: ignore
+ app._generate_component(app.overlay_component),
exp_page_child,
)
@@ -1243,12 +1242,12 @@ def test_overlay_component(
# overlay components are wrapped during compile only
app._compile_page("test")
app._setup_overlay_component()
- page = app.pages["test"]
+ page = app._pages["test"]
if exp_page_child is not None:
assert len(page.children) == 3
children_types = (type(child) for child in page.children)
- assert exp_page_child in children_types
+ assert exp_page_child in children_types # pyright: ignore [reportOperatorIssue]
else:
assert len(page.children) == 2
@@ -1276,12 +1275,23 @@ def compilable_app(tmp_path) -> Generator[tuple[App, Path], None, None]:
yield app, web_dir
-def test_app_wrap_compile_theme(compilable_app: tuple[App, Path]):
+@pytest.mark.parametrize(
+ "react_strict_mode",
+ [True, False],
+)
+def test_app_wrap_compile_theme(
+ react_strict_mode: bool, compilable_app: tuple[App, Path], mocker
+):
"""Test that the radix theme component wraps the app.
Args:
+ react_strict_mode: Whether to use React Strict Mode.
compilable_app: compilable_app fixture.
+ mocker: pytest mocker object.
"""
+ conf = rx.Config(app_name="testing", react_strict_mode=react_strict_mode)
+ mocker.patch("reflex.config._get_config", return_value=conf)
+
app, web_dir = compilable_app
app.theme = rx.theme(accent_color="plum")
app._compile()
@@ -1289,45 +1299,62 @@ def test_app_wrap_compile_theme(compilable_app: tuple[App, Path]):
app_js_lines = [
line.strip() for line in app_js_contents.splitlines() if line.strip()
]
+ lines = "".join(app_js_lines)
assert (
"function AppWrap({children}) {"
"return ("
- ""
+ + ("" if react_strict_mode else "")
+ + ""
""
""
+ ""
+ ""
"{children}"
""
+ ""
""
""
- ")"
+ + ("" if react_strict_mode else "")
+ + ")"
"}"
- ) in "".join(app_js_lines)
+ ) in lines
-def test_app_wrap_priority(compilable_app: tuple[App, Path]):
+@pytest.mark.parametrize(
+ "react_strict_mode",
+ [True, False],
+)
+def test_app_wrap_priority(
+ react_strict_mode: bool, compilable_app: tuple[App, Path], mocker
+):
"""Test that the app wrap components are wrapped in the correct order.
Args:
+ react_strict_mode: Whether to use React Strict Mode.
compilable_app: compilable_app fixture.
+ mocker: pytest mocker object.
"""
+ conf = rx.Config(app_name="testing", react_strict_mode=react_strict_mode)
+ mocker.patch("reflex.config._get_config", return_value=conf)
+
app, web_dir = compilable_app
class Fragment1(Component):
tag = "Fragment1"
- def _get_app_wrap_components(self) -> dict[tuple[int, str], Component]:
+ def _get_app_wrap_components(self) -> dict[tuple[int, str], Component]: # pyright: ignore [reportIncompatibleMethodOverride]
return {(99, "Box"): rx.box()}
class Fragment2(Component):
tag = "Fragment2"
- def _get_app_wrap_components(self) -> dict[tuple[int, str], Component]:
+ def _get_app_wrap_components(self) -> dict[tuple[int, str], Component]: # pyright: ignore [reportIncompatibleMethodOverride]
return {(50, "Text"): rx.text()}
class Fragment3(Component):
tag = "Fragment3"
- def _get_app_wrap_components(self) -> dict[tuple[int, str], Component]:
+ def _get_app_wrap_components(self) -> dict[tuple[int, str], Component]: # pyright: ignore [reportIncompatibleMethodOverride]
return {(10, "Fragment2"): Fragment2.create()}
def page():
@@ -1339,74 +1366,75 @@ def test_app_wrap_priority(compilable_app: tuple[App, Path]):
app_js_lines = [
line.strip() for line in app_js_contents.splitlines() if line.strip()
]
+ lines = "".join(app_js_lines)
assert (
"function AppWrap({children}) {"
- "return ("
- ""
+ "return (" + ("" if react_strict_mode else "") + ""
''
""
""
""
+ ""
+ ""
"{children}"
""
+ ""
""
""
""
- ""
- ")"
- "}"
- ) in "".join(app_js_lines)
+ "" + ("" if react_strict_mode else "")
+ ) in lines
def test_app_state_determination():
"""Test that the stateless status of an app is determined correctly."""
a1 = App()
- assert a1.state is None
+ assert a1._state is None
# No state, no router, no event handlers.
a1.add_page(rx.box("Index"), route="/")
- assert a1.state is None
+ assert a1._state is None
# Add a page with `on_load` enables state.
a1.add_page(rx.box("About"), route="/about", on_load=rx.console_log(""))
a1._compile_page("about")
- assert a1.state is not None
+ assert a1._state is not None
a2 = App()
- assert a2.state is None
+ assert a2._state is None
# Referencing a state Var enables state.
a2.add_page(rx.box(rx.text(GenState.value)), route="/")
a2._compile_page("index")
- assert a2.state is not None
+ assert a2._state is not None
a3 = App()
- assert a3.state is None
+ assert a3._state is None
# Referencing router enables state.
a3.add_page(rx.box(rx.text(State.router.page.full_path)), route="/")
a3._compile_page("index")
- assert a3.state is not None
+ assert a3._state is not None
a4 = App()
- assert a4.state is None
+ assert a4._state is None
a4.add_page(rx.box(rx.button("Click", on_click=rx.console_log(""))), route="/")
- assert a4.state is None
+ assert a4._state is None
a4.add_page(
rx.box(rx.button("Click", on_click=DynamicState.on_counter)), route="/page2"
)
a4._compile_page("page2")
- assert a4.state is not None
+ assert a4._state is not None
def test_raise_on_state():
"""Test that the state is set."""
# state kwargs is deprecated, we just make sure the app is created anyway.
- _app = App(state=State)
- assert _app.state is not None
- assert issubclass(_app.state, State)
+ _app = App(_state=State)
+ assert _app._state is not None
+ assert issubclass(_app._state, State)
def test_call_app():
@@ -1448,11 +1476,11 @@ def test_generate_component():
"Bar",
)
- comp = App._generate_component(index) # type: ignore
+ comp = App._generate_component(index)
assert isinstance(comp, Component)
with pytest.raises(exceptions.MatchTypeError):
- App._generate_component(index_mismatch) # type: ignore
+ App._generate_component(index_mismatch) # pyright: ignore [reportArgumentType]
def test_add_page_component_returning_tuple():
@@ -1467,27 +1495,27 @@ def test_add_page_component_returning_tuple():
def page2():
return (rx.text("third"),)
- app.add_page(index) # type: ignore
- app.add_page(page2) # type: ignore
+ app.add_page(index) # pyright: ignore [reportArgumentType]
+ app.add_page(page2) # pyright: ignore [reportArgumentType]
app._compile_page("index")
app._compile_page("page2")
- fragment_wrapper = app.pages["index"].children[0]
+ fragment_wrapper = app._pages["index"].children[0]
assert isinstance(fragment_wrapper, Fragment)
first_text = fragment_wrapper.children[0]
assert isinstance(first_text, Text)
- assert str(first_text.children[0].contents) == '"first"' # type: ignore
+ assert str(first_text.children[0].contents) == '"first"' # pyright: ignore [reportAttributeAccessIssue]
second_text = fragment_wrapper.children[1]
assert isinstance(second_text, Text)
- assert str(second_text.children[0].contents) == '"second"' # type: ignore
+ assert str(second_text.children[0].contents) == '"second"' # pyright: ignore [reportAttributeAccessIssue]
# Test page with trailing comma.
- page2_fragment_wrapper = app.pages["page2"].children[0]
+ page2_fragment_wrapper = app._pages["page2"].children[0]
assert isinstance(page2_fragment_wrapper, Fragment)
third_text = page2_fragment_wrapper.children[0]
assert isinstance(third_text, Text)
- assert str(third_text.children[0].contents) == '"third"' # type: ignore
+ assert str(third_text.children[0].contents) == '"third"' # pyright: ignore [reportAttributeAccessIssue]
@pytest.mark.parametrize("export", (True, False))
@@ -1525,7 +1553,7 @@ def test_app_with_transpile_packages(compilable_app: tuple[App, Path], export: b
next_config = (web_dir / "next.config.js").read_text()
transpile_packages_match = re.search(r"transpilePackages: (\[.*?\])", next_config)
- transpile_packages_json = transpile_packages_match.group(1) # type: ignore
+ transpile_packages_json = transpile_packages_match.group(1) # pyright: ignore [reportOptionalMemberAccess]
transpile_packages = sorted(json.loads(transpile_packages_json))
assert transpile_packages == [
@@ -1549,15 +1577,25 @@ def test_app_with_valid_var_dependencies(compilable_app: tuple[App, Path]):
base: int = 0
_backend: int = 0
- @computed_var(cache=True)
+ @computed_var()
def foo(self) -> str:
return "foo"
- @computed_var(deps=["_backend", "base", foo], cache=True)
+ @computed_var(deps=["_backend", "base", foo])
def bar(self) -> str:
return "bar"
- app.state = ValidDepState
+ class Child1(ValidDepState):
+ @computed_var(deps=["base", ValidDepState.bar])
+ def other(self) -> str:
+ return "other"
+
+ class Child2(ValidDepState):
+ @computed_var(deps=["base", Child1.other])
+ def other(self) -> str:
+ return "other"
+
+ app._state = ValidDepState
app._compile()
@@ -1565,11 +1603,11 @@ def test_app_with_invalid_var_dependencies(compilable_app: tuple[App, Path]):
app, _ = compilable_app
class InvalidDepState(BaseState):
- @computed_var(deps=["foolksjdf"], cache=True)
+ @computed_var(deps=["foolksjdf"])
def bar(self) -> str:
return "bar"
- app.state = InvalidDepState
+ app._state = InvalidDepState
with pytest.raises(exceptions.VarDependencyError):
app._compile()
diff --git a/tests/units/test_config.py b/tests/units/test_config.py
index e5d4622bd..18d8cd90c 100644
--- a/tests/units/test_config.py
+++ b/tests/units/test_config.py
@@ -21,7 +21,7 @@ from reflex.constants import Endpoint, Env
def test_requires_app_name():
"""Test that a config requires an app_name."""
with pytest.raises(ValueError):
- rx.Config() # type: ignore
+ rx.Config()
def test_set_app_name(base_config_values):
@@ -207,7 +207,7 @@ def test_replace_defaults(
exp_config_values: The expected config values.
"""
mock_os_env = os.environ.copy()
- monkeypatch.setattr(reflex.config.os, "environ", mock_os_env) # type: ignore
+ monkeypatch.setattr(reflex.config.os, "environ", mock_os_env)
mock_os_env.update({k: str(v) for k, v in env_vars.items()})
c = rx.Config(app_name="a", **config_kwargs)
c._set_persistent(**set_persistent_vars)
@@ -252,6 +252,7 @@ def test_env_var():
BLUBB: EnvVar[str] = env_var("default")
INTERNAL: EnvVar[str] = env_var("default", internal=True)
BOOLEAN: EnvVar[bool] = env_var(False)
+ LIST: EnvVar[list[int]] = env_var([1, 2, 3])
assert TestEnv.BLUBB.get() == "default"
assert TestEnv.BLUBB.name == "BLUBB"
@@ -280,3 +281,11 @@ def test_env_var():
assert TestEnv.BOOLEAN.get() is False
TestEnv.BOOLEAN.set(None)
assert "BOOLEAN" not in os.environ
+
+ assert TestEnv.LIST.get() == [1, 2, 3]
+ assert TestEnv.LIST.name == "LIST"
+ TestEnv.LIST.set([4, 5, 6])
+ assert os.environ.get("LIST") == "4:5:6"
+ assert TestEnv.LIST.get() == [4, 5, 6]
+ TestEnv.LIST.set(None)
+ assert "LIST" not in os.environ
diff --git a/tests/units/test_event.py b/tests/units/test_event.py
index d7e993efa..df4f282cf 100644
--- a/tests/units/test_event.py
+++ b/tests/units/test_event.py
@@ -3,6 +3,7 @@ from typing import Callable, List
import pytest
import reflex as rx
+from reflex.constants.compiler import Hooks, Imports
from reflex.event import (
Event,
EventChain,
@@ -14,7 +15,7 @@ from reflex.event import (
)
from reflex.state import BaseState
from reflex.utils import format
-from reflex.vars.base import Field, LiteralVar, Var, field
+from reflex.vars.base import Field, LiteralVar, Var, VarData, field
def make_var(value) -> Var:
@@ -72,7 +73,7 @@ def test_call_event_handler():
)
# Passing args as strings should format differently.
- event_spec = handler("first", "second") # type: ignore
+ event_spec = handler("first", "second")
assert (
format.format_event(event_spec)
== 'Event("test_fn_with_args", {arg1:"first",arg2:"second"})'
@@ -80,7 +81,7 @@ def test_call_event_handler():
first, second = 123, "456"
handler = EventHandler(fn=test_fn_with_args)
- event_spec = handler(first, second) # type: ignore
+ event_spec = handler(first, second)
assert (
format.format_event(event_spec)
== 'Event("test_fn_with_args", {arg1:123,arg2:"456"})'
@@ -94,7 +95,7 @@ def test_call_event_handler():
handler = EventHandler(fn=test_fn_with_args)
with pytest.raises(TypeError):
- handler(test_fn) # type: ignore
+ handler(test_fn)
def test_call_event_handler_partial():
@@ -199,16 +200,15 @@ def test_event_redirect(input, output):
input: The input for running the test.
output: The expected output to validate the test.
"""
- path, external, replace = input
+ path, is_external, replace = input
kwargs = {}
- if external is not None:
- kwargs["external"] = external
+ if is_external is not None:
+ kwargs["is_external"] = is_external
if replace is not None:
kwargs["replace"] = replace
spec = event.redirect(path, **kwargs)
assert isinstance(spec, EventSpec)
assert spec.handler.fn.__qualname__ == "_redirect"
-
assert format.format_event(spec) == output
@@ -417,7 +417,7 @@ def test_event_actions_on_state():
assert isinstance(handler, EventHandler)
assert not handler.event_actions
- sp_handler = EventActionState.handler.stop_propagation
+ sp_handler = EventActionState.handler.stop_propagation # pyright: ignore [reportFunctionMemberAccess]
assert sp_handler.event_actions == {"stopPropagation": True}
# should NOT affect other references to the handler
assert not handler.event_actions
@@ -444,9 +444,28 @@ def test_event_var_data():
return (value,)
# Ensure chain carries _var_data
- chain_var = Var.create(EventChain(events=[S.s(S.x)], args_spec=_args_spec))
+ chain_var = Var.create(
+ EventChain(
+ events=[S.s(S.x)],
+ args_spec=_args_spec,
+ invocation=rx.vars.FunctionStringVar.create(""),
+ )
+ )
assert chain_var._get_all_var_data() == S.x._get_all_var_data()
+ chain_var_data = Var.create(
+ EventChain(
+ events=[],
+ args_spec=_args_spec,
+ )
+ )._get_all_var_data()
+ assert chain_var_data is not None
+
+ assert chain_var_data == VarData(
+ imports=Imports.EVENTS,
+ hooks={Hooks.EVENTS: None},
+ )
+
def test_event_bound_method() -> None:
class S(BaseState):
@@ -455,7 +474,7 @@ def test_event_bound_method() -> None:
print(arg)
class Wrapper:
- def get_handler(self, arg: str):
+ def get_handler(self, arg: Var[str]):
return S.e(arg)
w = Wrapper()
diff --git a/tests/units/test_health_endpoint.py b/tests/units/test_health_endpoint.py
index 6d12d79d6..5b3aedc00 100644
--- a/tests/units/test_health_endpoint.py
+++ b/tests/units/test_health_endpoint.py
@@ -122,9 +122,9 @@ async def test_health(
# Call the async health function
response = await health()
- print(json.loads(response.body))
+ print(json.loads(response.body)) # pyright: ignore [reportArgumentType]
print(expected_status)
# Verify the response content and status code
assert response.status_code == expected_code
- assert json.loads(response.body) == expected_status
+ assert json.loads(response.body) == expected_status # pyright: ignore [reportArgumentType]
diff --git a/tests/units/test_model.py b/tests/units/test_model.py
index 0a83f39ec..b17538248 100644
--- a/tests/units/test_model.py
+++ b/tests/units/test_model.py
@@ -86,7 +86,7 @@ def test_automigration(
assert versions.exists()
# initial table
- class AlembicThing(Model, table=True): # type: ignore
+ class AlembicThing(Model, table=True): # pyright: ignore [reportRedeclaration]
t1: str
with Model.get_db_engine().connect() as connection:
@@ -105,7 +105,7 @@ def test_automigration(
model_registry.get_metadata().clear()
# Create column t2, mark t1 as optional with default
- class AlembicThing(Model, table=True): # type: ignore
+ class AlembicThing(Model, table=True): # pyright: ignore [reportRedeclaration]
t1: Optional[str] = "default"
t2: str = "bar"
@@ -125,7 +125,7 @@ def test_automigration(
model_registry.get_metadata().clear()
# Drop column t1
- class AlembicThing(Model, table=True): # type: ignore
+ class AlembicThing(Model, table=True): # pyright: ignore [reportRedeclaration]
t2: str = "bar"
assert Model.migrate(autogenerate=True)
@@ -138,7 +138,7 @@ def test_automigration(
assert result[1].t2 == "baz"
# Add table
- class AlembicSecond(Model, table=True): # type: ignore
+ class AlembicSecond(Model, table=True):
a: int = 42
b: float = 4.2
@@ -160,14 +160,14 @@ def test_automigration(
# drop table (AlembicSecond)
model_registry.get_metadata().clear()
- class AlembicThing(Model, table=True): # type: ignore
+ class AlembicThing(Model, table=True): # pyright: ignore [reportRedeclaration]
t2: str = "bar"
assert Model.migrate(autogenerate=True)
assert len(list(versions.glob("*.py"))) == 5
with reflex.model.session() as session:
- with pytest.raises(sqlalchemy.exc.OperationalError) as errctx: # type: ignore
+ with pytest.raises(sqlalchemy.exc.OperationalError) as errctx:
session.exec(sqlmodel.select(AlembicSecond)).all()
assert errctx.match(r"no such table: alembicsecond")
# first table should still exist
@@ -178,7 +178,7 @@ def test_automigration(
model_registry.get_metadata().clear()
- class AlembicThing(Model, table=True): # type: ignore
+ class AlembicThing(Model, table=True):
# changing column type not supported by default
t2: int = 42
diff --git a/tests/units/test_prerequisites.py b/tests/units/test_prerequisites.py
index 90afe0963..4723d8648 100644
--- a/tests/units/test_prerequisites.py
+++ b/tests/units/test_prerequisites.py
@@ -1,20 +1,28 @@
import json
import re
+import shutil
import tempfile
+from pathlib import Path
from unittest.mock import Mock, mock_open
import pytest
+from typer.testing import CliRunner
from reflex import constants
from reflex.config import Config
+from reflex.reflex import cli
+from reflex.testing import chdir
from reflex.utils.prerequisites import (
CpuInfo,
_update_next_config,
cached_procedure,
get_cpu_info,
initialize_requirements_txt,
+ rename_imports_and_app_name,
)
+runner = CliRunner()
+
@pytest.mark.parametrize(
"config, export, expected_output",
@@ -24,7 +32,7 @@ from reflex.utils.prerequisites import (
app_name="test",
),
False,
- 'module.exports = {basePath: "", compress: true, reactStrictMode: true, trailingSlash: true, staticPageGenerationTimeout: 60};',
+ 'module.exports = {basePath: "", compress: true, trailingSlash: true, staticPageGenerationTimeout: 60};',
),
(
Config(
@@ -32,7 +40,7 @@ from reflex.utils.prerequisites import (
static_page_generation_timeout=30,
),
False,
- 'module.exports = {basePath: "", compress: true, reactStrictMode: true, trailingSlash: true, staticPageGenerationTimeout: 30};',
+ 'module.exports = {basePath: "", compress: true, trailingSlash: true, staticPageGenerationTimeout: 30};',
),
(
Config(
@@ -40,7 +48,7 @@ from reflex.utils.prerequisites import (
next_compression=False,
),
False,
- 'module.exports = {basePath: "", compress: false, reactStrictMode: true, trailingSlash: true, staticPageGenerationTimeout: 60};',
+ 'module.exports = {basePath: "", compress: false, trailingSlash: true, staticPageGenerationTimeout: 60};',
),
(
Config(
@@ -48,7 +56,7 @@ from reflex.utils.prerequisites import (
frontend_path="/test",
),
False,
- 'module.exports = {basePath: "/test", compress: true, reactStrictMode: true, trailingSlash: true, staticPageGenerationTimeout: 60};',
+ 'module.exports = {basePath: "/test", compress: true, trailingSlash: true, staticPageGenerationTimeout: 60};',
),
(
Config(
@@ -57,14 +65,14 @@ from reflex.utils.prerequisites import (
next_compression=False,
),
False,
- 'module.exports = {basePath: "/test", compress: false, reactStrictMode: true, trailingSlash: true, staticPageGenerationTimeout: 60};',
+ 'module.exports = {basePath: "/test", compress: false, trailingSlash: true, staticPageGenerationTimeout: 60};',
),
(
Config(
app_name="test",
),
True,
- 'module.exports = {basePath: "", compress: true, reactStrictMode: true, trailingSlash: true, staticPageGenerationTimeout: 60, output: "export", distDir: "_static"};',
+ 'module.exports = {basePath: "", compress: true, trailingSlash: true, staticPageGenerationTimeout: 60, output: "export", distDir: "_static"};',
),
],
)
@@ -92,7 +100,7 @@ def test_transpile_packages(transpile_packages, expected_transpile_packages):
transpile_packages=transpile_packages,
)
transpile_packages_match = re.search(r"transpilePackages: (\[.*?\])", output)
- transpile_packages_json = transpile_packages_match.group(1) # type: ignore
+ transpile_packages_json = transpile_packages_match.group(1) # pyright: ignore [reportOptionalMemberAccess]
actual_transpile_packages = sorted(json.loads(transpile_packages_json))
assert actual_transpile_packages == expected_transpile_packages
@@ -224,3 +232,156 @@ def test_get_cpu_info():
for attr in ("manufacturer_id", "model_name", "address_width"):
value = getattr(cpu_info, attr)
assert value.strip() if attr != "address_width" else value
+
+
+@pytest.fixture
+def temp_directory():
+ temp_dir = tempfile.mkdtemp()
+ yield Path(temp_dir)
+ shutil.rmtree(temp_dir)
+
+
+@pytest.mark.parametrize(
+ "config_code,expected",
+ [
+ ("rx.Config(app_name='old_name')", 'rx.Config(app_name="new_name")'),
+ ('rx.Config(app_name="old_name")', 'rx.Config(app_name="new_name")'),
+ ("rx.Config('old_name')", 'rx.Config("new_name")'),
+ ('rx.Config("old_name")', 'rx.Config("new_name")'),
+ ],
+)
+def test_rename_imports_and_app_name(temp_directory, config_code, expected):
+ file_path = temp_directory / "rxconfig.py"
+ content = f"""
+config = {config_code}
+"""
+ file_path.write_text(content)
+
+ rename_imports_and_app_name(file_path, "old_name", "new_name")
+
+ updated_content = file_path.read_text()
+ expected_content = f"""
+config = {expected}
+"""
+ assert updated_content == expected_content
+
+
+def test_regex_edge_cases(temp_directory):
+ file_path = temp_directory / "example.py"
+ content = """
+from old_name.module import something
+import old_name
+from old_name import something_else as alias
+from old_name
+"""
+ file_path.write_text(content)
+
+ rename_imports_and_app_name(file_path, "old_name", "new_name")
+
+ updated_content = file_path.read_text()
+ expected_content = """
+from new_name.module import something
+import new_name
+from new_name import something_else as alias
+from new_name
+"""
+ assert updated_content == expected_content
+
+
+def test_cli_rename_command(temp_directory):
+ foo_dir = temp_directory / "foo"
+ foo_dir.mkdir()
+ (foo_dir / "__init__").touch()
+ (foo_dir / ".web").mkdir()
+ (foo_dir / "assets").mkdir()
+ (foo_dir / "foo").mkdir()
+ (foo_dir / "foo" / "__init__.py").touch()
+ (foo_dir / "rxconfig.py").touch()
+ (foo_dir / "rxconfig.py").write_text(
+ """
+import reflex as rx
+
+config = rx.Config(
+ app_name="foo",
+)
+"""
+ )
+ (foo_dir / "foo" / "components").mkdir()
+ (foo_dir / "foo" / "components" / "__init__.py").touch()
+ (foo_dir / "foo" / "components" / "base.py").touch()
+ (foo_dir / "foo" / "components" / "views.py").touch()
+ (foo_dir / "foo" / "components" / "base.py").write_text(
+ """
+import reflex as rx
+from foo.components import views
+from foo.components.views import *
+from .base import *
+
+def random_component():
+ return rx.fragment()
+"""
+ )
+ (foo_dir / "foo" / "foo.py").touch()
+ (foo_dir / "foo" / "foo.py").write_text(
+ """
+import reflex as rx
+import foo.components.base
+from foo.components.base import random_component
+
+class State(rx.State):
+ pass
+
+
+def index():
+ return rx.text("Hello, World!")
+
+app = rx.App()
+app.add_page(index)
+"""
+ )
+
+ with chdir(temp_directory / "foo"):
+ result = runner.invoke(cli, ["rename", "bar"])
+
+ assert result.exit_code == 0
+ assert (foo_dir / "rxconfig.py").read_text() == (
+ """
+import reflex as rx
+
+config = rx.Config(
+ app_name="bar",
+)
+"""
+ )
+ assert (foo_dir / "bar").exists()
+ assert not (foo_dir / "foo").exists()
+ assert (foo_dir / "bar" / "components" / "base.py").read_text() == (
+ """
+import reflex as rx
+from bar.components import views
+from bar.components.views import *
+from .base import *
+
+def random_component():
+ return rx.fragment()
+"""
+ )
+ assert (foo_dir / "bar" / "bar.py").exists()
+ assert not (foo_dir / "bar" / "foo.py").exists()
+ assert (foo_dir / "bar" / "bar.py").read_text() == (
+ """
+import reflex as rx
+import bar.components.base
+from bar.components.base import random_component
+
+class State(rx.State):
+ pass
+
+
+def index():
+ return rx.text("Hello, World!")
+
+app = rx.App()
+app.add_page(index)
+"""
+ )
diff --git a/tests/units/test_route.py b/tests/units/test_route.py
index 851c9cf35..62f1788d3 100644
--- a/tests/units/test_route.py
+++ b/tests/units/test_route.py
@@ -89,7 +89,7 @@ def app():
],
)
def test_check_routes_conflict_invalid(mocker, app, route1, route2):
- mocker.patch.object(app, "pages", {route1: []})
+ mocker.patch.object(app, "_pages", {route1: []})
with pytest.raises(ValueError):
app._check_routes_conflict(route2)
@@ -117,6 +117,6 @@ def test_check_routes_conflict_invalid(mocker, app, route1, route2):
],
)
def test_check_routes_conflict_valid(mocker, app, route1, route2):
- mocker.patch.object(app, "pages", {route1: []})
+ mocker.patch.object(app, "_pages", {route1: []})
# test that running this does not throw an error.
app._check_routes_conflict(route2)
diff --git a/tests/units/test_sqlalchemy.py b/tests/units/test_sqlalchemy.py
index 23e315785..4434f5ee1 100644
--- a/tests/units/test_sqlalchemy.py
+++ b/tests/units/test_sqlalchemy.py
@@ -59,7 +59,7 @@ def test_automigration(
id: Mapped[Optional[int]] = mapped_column(primary_key=True, default=None)
# initial table
- class AlembicThing(ModelBase): # pyright: ignore[reportGeneralTypeIssues]
+ class AlembicThing(ModelBase): # pyright: ignore[reportRedeclaration]
t1: Mapped[str] = mapped_column(default="")
with Model.get_db_engine().connect() as connection:
@@ -78,7 +78,7 @@ def test_automigration(
model_registry.get_metadata().clear()
# Create column t2, mark t1 as optional with default
- class AlembicThing(ModelBase): # pyright: ignore[reportGeneralTypeIssues]
+ class AlembicThing(ModelBase): # pyright: ignore[reportRedeclaration]
t1: Mapped[Optional[str]] = mapped_column(default="default")
t2: Mapped[str] = mapped_column(default="bar")
@@ -98,7 +98,7 @@ def test_automigration(
model_registry.get_metadata().clear()
# Drop column t1
- class AlembicThing(ModelBase): # pyright: ignore[reportGeneralTypeIssues]
+ class AlembicThing(ModelBase): # pyright: ignore[reportRedeclaration]
t2: Mapped[str] = mapped_column(default="bar")
assert Model.migrate(autogenerate=True)
@@ -133,7 +133,7 @@ def test_automigration(
# drop table (AlembicSecond)
model_registry.get_metadata().clear()
- class AlembicThing(ModelBase): # pyright: ignore[reportGeneralTypeIssues]
+ class AlembicThing(ModelBase): # pyright: ignore[reportRedeclaration]
t2: Mapped[str] = mapped_column(default="bar")
assert Model.migrate(autogenerate=True)
diff --git a/tests/units/test_state.py b/tests/units/test_state.py
index 41fac443e..e0390c5ac 100644
--- a/tests/units/test_state.py
+++ b/tests/units/test_state.py
@@ -14,6 +14,7 @@ from typing import (
Any,
AsyncGenerator,
Callable,
+ ClassVar,
Dict,
List,
Optional,
@@ -202,7 +203,7 @@ class GrandchildState(ChildState):
class GrandchildState2(ChildState2):
"""A grandchild state fixture."""
- @rx.var(cache=True)
+ @rx.var
def cached(self) -> str:
"""A cached var.
@@ -215,7 +216,7 @@ class GrandchildState2(ChildState2):
class GrandchildState3(ChildState3):
"""A great grandchild state fixture."""
- @rx.var
+ @rx.var(cache=False)
def computed(self) -> str:
"""A computed var.
@@ -241,7 +242,7 @@ def test_state() -> TestState:
Returns:
A test state.
"""
- return TestState() # type: ignore
+ return TestState() # pyright: ignore [reportCallIssue]
@pytest.fixture
@@ -431,10 +432,10 @@ def test_default_setters(test_state):
def test_class_indexing_with_vars():
"""Test that we can index into a state var with another var."""
- prop = TestState.array[TestState.num1]
+ prop = TestState.array[TestState.num1] # pyright: ignore [reportCallIssue, reportArgumentType]
assert str(prop) == f"{TestState.get_name()}.array.at({TestState.get_name()}.num1)"
- prop = TestState.mapping["a"][TestState.num1]
+ prop = TestState.mapping["a"][TestState.num1] # pyright: ignore [reportCallIssue, reportArgumentType]
assert (
str(prop)
== f'{TestState.get_name()}.mapping["a"].at({TestState.get_name()}.num1)'
@@ -554,9 +555,9 @@ def test_get_class_var():
def test_set_class_var():
"""Test setting the var of a class."""
with pytest.raises(AttributeError):
- TestState.num3 # type: ignore
+ TestState.num3 # pyright: ignore [reportAttributeAccessIssue]
TestState._set_var(Var(_js_expr="num3", _var_type=int)._var_set_state(TestState))
- var = TestState.num3 # type: ignore
+ var = TestState.num3 # pyright: ignore [reportAttributeAccessIssue]
assert var._js_expr == TestState.get_full_name() + ".num3"
assert var._var_type is int
assert var._var_state == TestState.get_full_name()
@@ -789,17 +790,16 @@ async def test_process_event_simple(test_state):
assert test_state.num1 == 0
event = Event(token="t", name="set_num1", payload={"value": 69})
- update = await test_state._process(event).__anext__()
+ async for update in test_state._process(event):
+ # The event should update the value.
+ assert test_state.num1 == 69
- # The event should update the value.
- assert test_state.num1 == 69
-
- # The delta should contain the changes, including computed vars.
- assert update.delta == {
- TestState.get_full_name(): {"num1": 69, "sum": 72.14, "upper": ""},
- GrandchildState3.get_full_name(): {"computed": ""},
- }
- assert update.events == []
+ # The delta should contain the changes, including computed vars.
+ assert update.delta == {
+ TestState.get_full_name(): {"num1": 69, "sum": 72.14},
+ GrandchildState3.get_full_name(): {"computed": ""},
+ }
+ assert update.events == []
@pytest.mark.asyncio
@@ -819,15 +819,15 @@ async def test_process_event_substate(test_state, child_state, grandchild_state)
name=f"{ChildState.get_name()}.change_both",
payload={"value": "hi", "count": 12},
)
- update = await test_state._process(event).__anext__()
- assert child_state.value == "HI"
- assert child_state.count == 24
- assert update.delta == {
- TestState.get_full_name(): {"sum": 3.14, "upper": ""},
- ChildState.get_full_name(): {"value": "HI", "count": 24},
- GrandchildState3.get_full_name(): {"computed": ""},
- }
- test_state._clean()
+ async for update in test_state._process(event):
+ assert child_state.value == "HI"
+ assert child_state.count == 24
+ assert update.delta == {
+ # TestState.get_full_name(): {"sum": 3.14, "upper": ""},
+ ChildState.get_full_name(): {"value": "HI", "count": 24},
+ GrandchildState3.get_full_name(): {"computed": ""},
+ }
+ test_state._clean()
# Test with the granchild state.
assert grandchild_state.value2 == ""
@@ -836,19 +836,19 @@ async def test_process_event_substate(test_state, child_state, grandchild_state)
name=f"{GrandchildState.get_full_name()}.set_value2",
payload={"value": "new"},
)
- update = await test_state._process(event).__anext__()
- assert grandchild_state.value2 == "new"
- assert update.delta == {
- TestState.get_full_name(): {"sum": 3.14, "upper": ""},
- GrandchildState.get_full_name(): {"value2": "new"},
- GrandchildState3.get_full_name(): {"computed": ""},
- }
+ async for update in test_state._process(event):
+ assert grandchild_state.value2 == "new"
+ assert update.delta == {
+ # TestState.get_full_name(): {"sum": 3.14, "upper": ""},
+ GrandchildState.get_full_name(): {"value2": "new"},
+ GrandchildState3.get_full_name(): {"computed": ""},
+ }
@pytest.mark.asyncio
async def test_process_event_generator():
"""Test event handlers that generate multiple updates."""
- gen_state = GenState() # type: ignore
+ gen_state = GenState() # pyright: ignore [reportCallIssue]
event = Event(
token="t",
name="go",
@@ -948,12 +948,12 @@ def test_add_var():
assert not hasattr(ds1, "dynamic_int")
ds1.add_var("dynamic_int", int, 42)
# Existing instances get the BaseVar
- assert ds1.dynamic_int.equals(DynamicState.dynamic_int) # type: ignore
+ assert ds1.dynamic_int.equals(DynamicState.dynamic_int) # pyright: ignore [reportAttributeAccessIssue]
# New instances get an actual value with the default
assert DynamicState().dynamic_int == 42
ds1.add_var("dynamic_list", List[int], [5, 10])
- assert ds1.dynamic_list.equals(DynamicState.dynamic_list) # type: ignore
+ assert ds1.dynamic_list.equals(DynamicState.dynamic_list) # pyright: ignore [reportAttributeAccessIssue]
ds2 = DynamicState()
assert ds2.dynamic_list == [5, 10]
ds2.dynamic_list.append(15)
@@ -961,8 +961,8 @@ def test_add_var():
assert DynamicState().dynamic_list == [5, 10]
ds1.add_var("dynamic_dict", Dict[str, int], {"k1": 5, "k2": 10})
- assert ds1.dynamic_dict.equals(DynamicState.dynamic_dict) # type: ignore
- assert ds2.dynamic_dict.equals(DynamicState.dynamic_dict) # type: ignore
+ assert ds1.dynamic_dict.equals(DynamicState.dynamic_dict) # pyright: ignore [reportAttributeAccessIssue]
+ assert ds2.dynamic_dict.equals(DynamicState.dynamic_dict) # pyright: ignore [reportAttributeAccessIssue]
assert DynamicState().dynamic_dict == {"k1": 5, "k2": 10}
assert DynamicState().dynamic_dict == {"k1": 5, "k2": 10}
@@ -989,7 +989,7 @@ class InterdependentState(BaseState):
v1: int = 0
_v2: int = 1
- @rx.var(cache=True)
+ @rx.var
def v1x2(self) -> int:
"""Depends on var v1.
@@ -998,7 +998,7 @@ class InterdependentState(BaseState):
"""
return self.v1 * 2
- @rx.var(cache=True)
+ @rx.var
def v2x2(self) -> int:
"""Depends on backend var _v2.
@@ -1007,7 +1007,7 @@ class InterdependentState(BaseState):
"""
return self._v2 * 2
- @rx.var(cache=True, backend=True)
+ @rx.var(backend=True)
def v2x2_backend(self) -> int:
"""Depends on backend var _v2.
@@ -1016,16 +1016,16 @@ class InterdependentState(BaseState):
"""
return self._v2 * 2
- @rx.var(cache=True)
+ @rx.var
def v1x2x2(self) -> int:
"""Depends on ComputedVar v1x2.
Returns:
ComputedVar v1x2 multiplied by 2
"""
- return self.v1x2 * 2 # type: ignore
+ return self.v1x2 * 2
- @rx.var(cache=True)
+ @rx.var
def _v3(self) -> int:
"""Depends on backend var _v2.
@@ -1034,7 +1034,7 @@ class InterdependentState(BaseState):
"""
return self._v2
- @rx.var(cache=True)
+ @rx.var
def v3x2(self) -> int:
"""Depends on ComputedVar _v3.
@@ -1144,7 +1144,7 @@ def test_child_state():
class ChildState(MainState):
@computed_var
- def rendered_var(self):
+ def rendered_var(self) -> int:
return self.v
ms = MainState()
@@ -1170,13 +1170,17 @@ def test_conditional_computed_vars():
ms = MainState()
# Initially there are no dirty computed vars.
- assert ms._dirty_computed_vars(from_vars={"flag"}) == {"rendered_var"}
- assert ms._dirty_computed_vars(from_vars={"t2"}) == {"rendered_var"}
- assert ms._dirty_computed_vars(from_vars={"t1"}) == {"rendered_var"}
+ assert ms._dirty_computed_vars(from_vars={"flag"}) == {
+ (MainState.get_full_name(), "rendered_var")
+ }
+ assert ms._dirty_computed_vars(from_vars={"t2"}) == {
+ (MainState.get_full_name(), "rendered_var")
+ }
+ assert ms._dirty_computed_vars(from_vars={"t1"}) == {
+ (MainState.get_full_name(), "rendered_var")
+ }
assert ms.computed_vars["rendered_var"]._deps(objclass=MainState) == {
- "flag",
- "t1",
- "t2",
+ MainState.get_full_name(): {"flag", "t1", "t2"}
}
@@ -1239,7 +1243,7 @@ def test_computed_var_cached():
class ComputedState(BaseState):
v: int = 0
- @rx.var(cache=True)
+ @rx.var
def comp_v(self) -> int:
nonlocal comp_v_calls
comp_v_calls += 1
@@ -1264,15 +1268,15 @@ def test_computed_var_cached_depends_on_non_cached():
class ComputedState(BaseState):
v: int = 0
- @rx.var
+ @rx.var(cache=False)
def no_cache_v(self) -> int:
return self.v
- @rx.var(cache=True)
+ @rx.var
def dep_v(self) -> int:
- return self.no_cache_v # type: ignore
+ return self.no_cache_v
- @rx.var(cache=True)
+ @rx.var
def comp_v(self) -> int:
return self.v
@@ -1304,16 +1308,16 @@ def test_computed_var_depends_on_parent_non_cached():
counter = 0
class ParentState(BaseState):
- @rx.var
+ @rx.var(cache=False)
def no_cache_v(self) -> int:
nonlocal counter
counter += 1
return counter
class ChildState(ParentState):
- @rx.var(cache=True)
+ @rx.var
def dep_v(self) -> int:
- return self.no_cache_v # type: ignore
+ return self.no_cache_v
ps = ParentState()
cs = ps.substates[ChildState.get_name()]
@@ -1357,7 +1361,7 @@ def test_cached_var_depends_on_event_handler(use_partial: bool):
def handler(self):
self.x = self.x + 1
- @rx.var(cache=True)
+ @rx.var
def cached_x_side_effect(self) -> int:
self.handler()
nonlocal counter
@@ -1365,13 +1369,16 @@ def test_cached_var_depends_on_event_handler(use_partial: bool):
return counter
if use_partial:
- HandlerState.handler = functools.partial(HandlerState.handler.fn)
+ HandlerState.handler = functools.partial(HandlerState.handler.fn) # pyright: ignore [reportFunctionMemberAccess]
assert isinstance(HandlerState.handler, functools.partial)
else:
assert isinstance(HandlerState.handler, EventHandler)
s = HandlerState()
- assert "cached_x_side_effect" in s._computed_var_dependencies["x"]
+ assert (
+ HandlerState.get_full_name(),
+ "cached_x_side_effect",
+ ) in s._var_dependencies["x"]
assert s.cached_x_side_effect == 1
assert s.x == 43
s.handler()
@@ -1393,7 +1400,7 @@ def test_computed_var_dependencies():
def testprop(self) -> int:
return self.v
- @rx.var(cache=True)
+ @rx.var
def comp_v(self) -> int:
"""Direct access.
@@ -1402,7 +1409,7 @@ def test_computed_var_dependencies():
"""
return self.v
- @rx.var(cache=True, backend=True)
+ @rx.var(backend=True)
def comp_v_backend(self) -> int:
"""Direct access backend var.
@@ -1411,7 +1418,7 @@ def test_computed_var_dependencies():
"""
return self.v
- @rx.var(cache=True)
+ @rx.var
def comp_v_via_property(self) -> int:
"""Access v via property.
@@ -1420,8 +1427,8 @@ def test_computed_var_dependencies():
"""
return self.testprop
- @rx.var(cache=True)
- def comp_w(self):
+ @rx.var
+ def comp_w(self) -> Callable[[], int]:
"""Nested lambda.
Returns:
@@ -1429,8 +1436,8 @@ def test_computed_var_dependencies():
"""
return lambda: self.w
- @rx.var(cache=True)
- def comp_x(self):
+ @rx.var
+ def comp_x(self) -> Callable[[], int]:
"""Nested function.
Returns:
@@ -1442,8 +1449,8 @@ def test_computed_var_dependencies():
return _
- @rx.var(cache=True)
- def comp_y(self) -> List[int]:
+ @rx.var
+ def comp_y(self) -> list[int]:
"""Comprehension iterating over attribute.
Returns:
@@ -1451,7 +1458,7 @@ def test_computed_var_dependencies():
"""
return [round(y) for y in self.y]
- @rx.var(cache=True)
+ @rx.var
def comp_z(self) -> List[bool]:
"""Comprehension accesses attribute.
@@ -1461,15 +1468,15 @@ def test_computed_var_dependencies():
return [z in self._z for z in range(5)]
cs = ComputedState()
- assert cs._computed_var_dependencies["v"] == {
- "comp_v",
- "comp_v_backend",
- "comp_v_via_property",
+ assert cs._var_dependencies["v"] == {
+ (ComputedState.get_full_name(), "comp_v"),
+ (ComputedState.get_full_name(), "comp_v_backend"),
+ (ComputedState.get_full_name(), "comp_v_via_property"),
}
- assert cs._computed_var_dependencies["w"] == {"comp_w"}
- assert cs._computed_var_dependencies["x"] == {"comp_x"}
- assert cs._computed_var_dependencies["y"] == {"comp_y"}
- assert cs._computed_var_dependencies["_z"] == {"comp_z"}
+ assert cs._var_dependencies["w"] == {(ComputedState.get_full_name(), "comp_w")}
+ assert cs._var_dependencies["x"] == {(ComputedState.get_full_name(), "comp_x")}
+ assert cs._var_dependencies["y"] == {(ComputedState.get_full_name(), "comp_y")}
+ assert cs._var_dependencies["_z"] == {(ComputedState.get_full_name(), "comp_z")}
def test_backend_method():
@@ -1616,7 +1623,7 @@ async def test_state_with_invalid_yield(capsys, mock_app):
id="backend_error",
position="top-center",
style={"width": "500px"},
- ) # type: ignore
+ )
],
token="",
)
@@ -1907,13 +1914,13 @@ def mock_app_simple(monkeypatch) -> rx.App:
Returns:
The app, after mocking out prerequisites.get_app()
"""
- app = App(state=TestState)
+ app = App(_state=TestState)
app_module = Mock()
setattr(app_module, CompileVars.APP, app)
- app.state = TestState
- app.event_namespace.emit = CopyingAsyncMock() # type: ignore
+ app._state = TestState
+ app.event_namespace.emit = CopyingAsyncMock() # pyright: ignore [reportOptionalMemberAccess]
def _mock_get_app(*args, **kwargs):
return app_module
@@ -2022,15 +2029,11 @@ async def test_state_proxy(grandchild_state: GrandchildState, mock_app: rx.App):
# ensure state update was emitted
assert mock_app.event_namespace is not None
- mock_app.event_namespace.emit.assert_called_once()
- mcall = mock_app.event_namespace.emit.mock_calls[0]
+ mock_app.event_namespace.emit.assert_called_once() # pyright: ignore [reportFunctionMemberAccess]
+ mcall = mock_app.event_namespace.emit.mock_calls[0] # pyright: ignore [reportFunctionMemberAccess]
assert mcall.args[0] == str(SocketEvent.EVENT)
assert mcall.args[1] == StateUpdate(
delta={
- parent_state.get_full_name(): {
- "upper": "",
- "sum": 3.14,
- },
grandchild_state.get_full_name(): {
"value2": "42",
},
@@ -2053,7 +2056,7 @@ class BackgroundTaskState(BaseState):
super().__init__(**kwargs)
self.router_data = {"simulate": "hydrate"}
- @rx.var
+ @rx.var(cache=False)
def computed_order(self) -> List[str]:
"""Get the order as a computed var.
@@ -2157,8 +2160,8 @@ async def test_background_task_no_block(mock_app: rx.App, token: str):
token: A token.
"""
router_data = {"query": {}}
- mock_app.state_manager.state = mock_app.state = BackgroundTaskState
- async for update in rx.app.process( # type: ignore
+ mock_app.state_manager.state = mock_app._state = BackgroundTaskState
+ async for update in rx.app.process(
mock_app,
Event(
token=token,
@@ -2175,10 +2178,10 @@ async def test_background_task_no_block(mock_app: rx.App, token: str):
# wait for the coroutine to start
await asyncio.sleep(0.5 if CI else 0.1)
- assert len(mock_app.background_tasks) == 1
+ assert len(mock_app._background_tasks) == 1
# Process another normal event
- async for update in rx.app.process( # type: ignore
+ async for update in rx.app.process(
mock_app,
Event(
token=token,
@@ -2207,9 +2210,9 @@ async def test_background_task_no_block(mock_app: rx.App, token: str):
)
# Explicit wait for background tasks
- for task in tuple(mock_app.background_tasks):
+ for task in tuple(mock_app._background_tasks):
await task
- assert not mock_app.background_tasks
+ assert not mock_app._background_tasks
exp_order = [
"background_task:start",
@@ -2228,7 +2231,7 @@ async def test_background_task_no_block(mock_app: rx.App, token: str):
assert mock_app.event_namespace is not None
emit_mock = mock_app.event_namespace.emit
- first_ws_message = emit_mock.mock_calls[0].args[1]
+ first_ws_message = emit_mock.mock_calls[0].args[1] # pyright: ignore [reportFunctionMemberAccess]
assert (
first_ws_message.delta[BackgroundTaskState.get_full_name()].pop("router")
is not None
@@ -2243,7 +2246,7 @@ async def test_background_task_no_block(mock_app: rx.App, token: str):
events=[],
final=True,
)
- for call in emit_mock.mock_calls[1:5]:
+ for call in emit_mock.mock_calls[1:5]: # pyright: ignore [reportFunctionMemberAccess]
assert call.args[1] == StateUpdate(
delta={
BackgroundTaskState.get_full_name(): {
@@ -2253,7 +2256,7 @@ async def test_background_task_no_block(mock_app: rx.App, token: str):
events=[],
final=True,
)
- assert emit_mock.mock_calls[-2].args[1] == StateUpdate(
+ assert emit_mock.mock_calls[-2].args[1] == StateUpdate( # pyright: ignore [reportFunctionMemberAccess]
delta={
BackgroundTaskState.get_full_name(): {
"order": exp_order,
@@ -2264,7 +2267,7 @@ async def test_background_task_no_block(mock_app: rx.App, token: str):
events=[],
final=True,
)
- assert emit_mock.mock_calls[-1].args[1] == StateUpdate(
+ assert emit_mock.mock_calls[-1].args[1] == StateUpdate( # pyright: ignore [reportFunctionMemberAccess]
delta={
BackgroundTaskState.get_full_name(): {
"computed_order": exp_order,
@@ -2284,8 +2287,8 @@ async def test_background_task_reset(mock_app: rx.App, token: str):
token: A token.
"""
router_data = {"query": {}}
- mock_app.state_manager.state = mock_app.state = BackgroundTaskState
- async for update in rx.app.process( # type: ignore
+ mock_app.state_manager.state = mock_app._state = BackgroundTaskState
+ async for update in rx.app.process(
mock_app,
Event(
token=token,
@@ -2301,9 +2304,9 @@ async def test_background_task_reset(mock_app: rx.App, token: str):
assert update == StateUpdate()
# Explicit wait for background tasks
- for task in tuple(mock_app.background_tasks):
+ for task in tuple(mock_app._background_tasks):
await task
- assert not mock_app.background_tasks
+ assert not mock_app._background_tasks
assert (
await mock_app.state_manager.get_state(
@@ -2627,10 +2630,10 @@ def test_duplicate_substate_class(mocker):
class TestState(BaseState):
pass
- class ChildTestState(TestState): # type: ignore
+ class ChildTestState(TestState): # pyright: ignore [reportRedeclaration]
pass
- class ChildTestState(TestState): # type: ignore # noqa
+ class ChildTestState(TestState): # noqa: F811
pass
return TestState
@@ -2668,21 +2671,21 @@ def test_reset_with_mutables():
items: List[List[int]] = default
instance = MutableResetState()
- assert instance.items.__wrapped__ is not default # type: ignore
+ assert instance.items.__wrapped__ is not default # pyright: ignore [reportAttributeAccessIssue]
assert instance.items == default == copied_default
instance.items.append([3, 3])
assert instance.items != default
assert instance.items != copied_default
instance.reset()
- assert instance.items.__wrapped__ is not default # type: ignore
+ assert instance.items.__wrapped__ is not default # pyright: ignore [reportAttributeAccessIssue]
assert instance.items == default == copied_default
instance.items.append([3, 3])
assert instance.items != default
assert instance.items != copied_default
instance.reset()
- assert instance.items.__wrapped__ is not default # type: ignore
+ assert instance.items.__wrapped__ is not default # pyright: ignore [reportAttributeAccessIssue]
assert instance.items == default == copied_default
instance.items.append([3, 3])
assert instance.items != default
@@ -2744,30 +2747,30 @@ def test_state_union_optional():
c3r: Custom3 = Custom3(c2r=Custom2(c1r=Custom1(foo="")))
custom_union: Union[Custom1, Custom2, Custom3] = Custom1(foo="")
- assert str(UnionState.c3.c2) == f'{UnionState.c3!s}?.["c2"]' # type: ignore
- assert str(UnionState.c3.c2.c1) == f'{UnionState.c3!s}?.["c2"]?.["c1"]' # type: ignore
+ assert str(UnionState.c3.c2) == f'{UnionState.c3!s}?.["c2"]' # pyright: ignore [reportOptionalMemberAccess]
+ assert str(UnionState.c3.c2.c1) == f'{UnionState.c3!s}?.["c2"]?.["c1"]' # pyright: ignore [reportOptionalMemberAccess]
assert (
- str(UnionState.c3.c2.c1.foo) == f'{UnionState.c3!s}?.["c2"]?.["c1"]?.["foo"]' # type: ignore
+ str(UnionState.c3.c2.c1.foo) == f'{UnionState.c3!s}?.["c2"]?.["c1"]?.["foo"]' # pyright: ignore [reportOptionalMemberAccess]
)
assert (
- str(UnionState.c3.c2.c1r.foo) == f'{UnionState.c3!s}?.["c2"]?.["c1r"]["foo"]' # type: ignore
+ str(UnionState.c3.c2.c1r.foo) == f'{UnionState.c3!s}?.["c2"]?.["c1r"]["foo"]' # pyright: ignore [reportOptionalMemberAccess]
)
- assert str(UnionState.c3.c2r.c1) == f'{UnionState.c3!s}?.["c2r"]["c1"]' # type: ignore
+ assert str(UnionState.c3.c2r.c1) == f'{UnionState.c3!s}?.["c2r"]["c1"]' # pyright: ignore [reportOptionalMemberAccess]
assert (
- str(UnionState.c3.c2r.c1.foo) == f'{UnionState.c3!s}?.["c2r"]["c1"]?.["foo"]' # type: ignore
+ str(UnionState.c3.c2r.c1.foo) == f'{UnionState.c3!s}?.["c2r"]["c1"]?.["foo"]' # pyright: ignore [reportOptionalMemberAccess]
)
assert (
- str(UnionState.c3.c2r.c1r.foo) == f'{UnionState.c3!s}?.["c2r"]["c1r"]["foo"]' # type: ignore
+ str(UnionState.c3.c2r.c1r.foo) == f'{UnionState.c3!s}?.["c2r"]["c1r"]["foo"]' # pyright: ignore [reportOptionalMemberAccess]
)
- assert str(UnionState.c3i.c2) == f'{UnionState.c3i!s}["c2"]' # type: ignore
- assert str(UnionState.c3r.c2) == f'{UnionState.c3r!s}["c2"]' # type: ignore
- assert UnionState.custom_union.foo is not None # type: ignore
- assert UnionState.custom_union.c1 is not None # type: ignore
- assert UnionState.custom_union.c1r is not None # type: ignore
- assert UnionState.custom_union.c2 is not None # type: ignore
- assert UnionState.custom_union.c2r is not None # type: ignore
- assert types.is_optional(UnionState.opt_int._var_type) # type: ignore
- assert types.is_union(UnionState.int_float._var_type) # type: ignore
+ assert str(UnionState.c3i.c2) == f'{UnionState.c3i!s}["c2"]'
+ assert str(UnionState.c3r.c2) == f'{UnionState.c3r!s}["c2"]'
+ assert UnionState.custom_union.foo is not None # pyright: ignore [reportAttributeAccessIssue]
+ assert UnionState.custom_union.c1 is not None # pyright: ignore [reportAttributeAccessIssue]
+ assert UnionState.custom_union.c1r is not None # pyright: ignore [reportAttributeAccessIssue]
+ assert UnionState.custom_union.c2 is not None # pyright: ignore [reportAttributeAccessIssue]
+ assert UnionState.custom_union.c2r is not None # pyright: ignore [reportAttributeAccessIssue]
+ assert types.is_optional(UnionState.opt_int._var_type) # pyright: ignore [reportAttributeAccessIssue, reportOptionalMemberAccess]
+ assert types.is_union(UnionState.int_float._var_type) # pyright: ignore [reportAttributeAccessIssue]
def test_set_base_field_via_setter():
@@ -2888,7 +2891,7 @@ async def test_preprocess(app_module_mock, token, test_state, expected, mocker):
"reflex.state.State.class_subclasses", {test_state, OnLoadInternalState}
)
app = app_module_mock.app = App(
- state=State, load_events={"index": [test_state.test_handler]}
+ _state=State, _load_events={"index": [test_state.test_handler]}
)
async with app.state_manager.modify_state(_substate_key(token, State)) as state:
state.router_data = {"simulate": "hydrate"}
@@ -2913,10 +2916,10 @@ async def test_preprocess(app_module_mock, token, test_state, expected, mocker):
events = updates[0].events
assert len(events) == 2
- assert (await state._process(events[0]).__anext__()).delta == {
- test_state.get_full_name(): {"num": 1}
- }
- assert (await state._process(events[1]).__anext__()).delta == exp_is_hydrated(state)
+ async for update in state._process(events[0]):
+ assert update.delta == {test_state.get_full_name(): {"num": 1}}
+ async for update in state._process(events[1]):
+ assert update.delta == exp_is_hydrated(state)
if isinstance(app.state_manager, StateManagerRedis):
await app.state_manager.close()
@@ -2935,8 +2938,8 @@ async def test_preprocess_multiple_load_events(app_module_mock, token, mocker):
"reflex.state.State.class_subclasses", {OnLoadState, OnLoadInternalState}
)
app = app_module_mock.app = App(
- state=State,
- load_events={"index": [OnLoadState.test_handler, OnLoadState.test_handler]},
+ _state=State,
+ _load_events={"index": [OnLoadState.test_handler, OnLoadState.test_handler]},
)
async with app.state_manager.modify_state(_substate_key(token, State)) as state:
state.router_data = {"simulate": "hydrate"}
@@ -2961,13 +2964,12 @@ async def test_preprocess_multiple_load_events(app_module_mock, token, mocker):
events = updates[0].events
assert len(events) == 3
- assert (await state._process(events[0]).__anext__()).delta == {
- OnLoadState.get_full_name(): {"num": 1}
- }
- assert (await state._process(events[1]).__anext__()).delta == {
- OnLoadState.get_full_name(): {"num": 2}
- }
- assert (await state._process(events[2]).__anext__()).delta == exp_is_hydrated(state)
+ async for update in state._process(events[0]):
+ assert update.delta == {OnLoadState.get_full_name(): {"num": 1}}
+ async for update in state._process(events[1]):
+ assert update.delta == {OnLoadState.get_full_name(): {"num": 2}}
+ async for update in state._process(events[2]):
+ assert update.delta == exp_is_hydrated(state)
if isinstance(app.state_manager, StateManagerRedis):
await app.state_manager.close()
@@ -2981,7 +2983,7 @@ async def test_get_state(mock_app: rx.App, token: str):
mock_app: An app that will be returned by `get_app()`
token: A token.
"""
- mock_app.state_manager.state = mock_app.state = TestState
+ mock_app.state_manager.state = mock_app._state = TestState
# Get instance of ChildState2.
test_state = await mock_app.state_manager.get_state(
@@ -3040,10 +3042,6 @@ async def test_get_state(mock_app: rx.App, token: str):
grandchild_state.value2 = "set_value"
assert test_state.get_delta() == {
- TestState.get_full_name(): {
- "sum": 3.14,
- "upper": "",
- },
GrandchildState.get_full_name(): {
"value2": "set_value",
},
@@ -3081,10 +3079,6 @@ async def test_get_state(mock_app: rx.App, token: str):
child_state2.value = "set_c2_value"
assert new_test_state.get_delta() == {
- TestState.get_full_name(): {
- "sum": 3.14,
- "upper": "",
- },
ChildState2.get_full_name(): {
"value": "set_c2_value",
},
@@ -3139,8 +3133,8 @@ async def test_get_state_from_sibling_not_cached(mock_app: rx.App, token: str):
child3_var: int = 0
- @rx.var
- def v(self):
+ @rx.var(cache=False)
+ def v(self) -> None:
pass
class Grandchild3(Child3):
@@ -3159,7 +3153,7 @@ async def test_get_state_from_sibling_not_cached(mock_app: rx.App, token: str):
pass
- mock_app.state_manager.state = mock_app.state = Parent
+ mock_app.state_manager.state = mock_app._state = Parent
# Get the top level state via unconnected sibling.
root = await mock_app.state_manager.get_state(_substate_key(token, Child))
@@ -3194,7 +3188,7 @@ async def test_get_state_from_sibling_not_cached(mock_app: rx.App, token: str):
RxState = State
-def test_potentially_dirty_substates():
+def test_potentially_dirty_states():
"""Test that potentially_dirty_substates returns the correct substates.
Even if the name "State" is shadowed, it should still work correctly.
@@ -3210,13 +3204,19 @@ def test_potentially_dirty_substates():
def bar(self) -> str:
return ""
- assert RxState._potentially_dirty_substates() == {State}
- assert State._potentially_dirty_substates() == {C1}
- assert C1._potentially_dirty_substates() == set()
+ assert RxState._get_potentially_dirty_states() == set()
+ assert State._get_potentially_dirty_states() == set()
+ assert C1._get_potentially_dirty_states() == set()
-def test_router_var_dep() -> None:
- """Test that router var dependencies are correctly tracked."""
+@pytest.mark.asyncio
+async def test_router_var_dep(state_manager: StateManager, token: str) -> None:
+ """Test that router var dependencies are correctly tracked.
+
+ Args:
+ state_manager: A state manager.
+ token: A token.
+ """
class RouterVarParentState(State):
"""A parent state for testing router var dependency."""
@@ -3226,37 +3226,34 @@ def test_router_var_dep() -> None:
class RouterVarDepState(RouterVarParentState):
"""A state with a router var dependency."""
- @rx.var(cache=True)
+ @rx.var
def foo(self) -> str:
return self.router.page.params.get("foo", "")
foo = RouterVarDepState.computed_vars["foo"]
State._init_var_dependency_dicts()
- assert foo._deps(objclass=RouterVarDepState) == {"router"}
- assert RouterVarParentState._potentially_dirty_substates() == {RouterVarDepState}
- assert RouterVarParentState._substate_var_dependencies == {
- "router": {RouterVarDepState.get_name()}
- }
- assert RouterVarDepState._computed_var_dependencies == {
- "router": {"foo"},
+ assert foo._deps(objclass=RouterVarDepState) == {
+ RouterVarDepState.get_full_name(): {"router"}
}
+ assert (RouterVarDepState.get_full_name(), "foo") in State._var_dependencies[
+ "router"
+ ]
- rx_state = State()
- parent_state = RouterVarParentState()
- state = RouterVarDepState()
-
- # link states
- rx_state.substates = {RouterVarParentState.get_name(): parent_state}
- parent_state.parent_state = rx_state
- state.parent_state = parent_state
- parent_state.substates = {RouterVarDepState.get_name(): state}
+ # Get state from state manager.
+ state_manager.state = State
+ rx_state = await state_manager.get_state(_substate_key(token, State))
+ assert RouterVarParentState.get_name() in rx_state.substates
+ parent_state = rx_state.substates[RouterVarParentState.get_name()]
+ assert RouterVarDepState.get_name() in parent_state.substates
+ state = parent_state.substates[RouterVarDepState.get_name()]
assert state.dirty_vars == set()
# Reassign router var
state.router = state.router
- assert state.dirty_vars == {"foo", "router"}
+ assert rx_state.dirty_vars == {"router"}
+ assert state.dirty_vars == {"foo"}
assert parent_state.dirty_substates == {RouterVarDepState.get_name()}
@@ -3372,9 +3369,9 @@ config = rx.Config(
from reflex.state import State, StateManager
state_manager = StateManager.create(state=State)
- assert state_manager.lock_expiration == expected_values[0] # type: ignore
- assert state_manager.token_expiration == expected_values[1] # type: ignore
- assert state_manager.lock_warning_threshold == expected_values[2] # type: ignore
+ assert state_manager.lock_expiration == expected_values[0] # pyright: ignore [reportAttributeAccessIssue]
+ assert state_manager.token_expiration == expected_values[1] # pyright: ignore [reportAttributeAccessIssue]
+ assert state_manager.lock_warning_threshold == expected_values[2] # pyright: ignore [reportAttributeAccessIssue]
@pytest.mark.skipif("REDIS_URL" not in os.environ, reason="Test requires redis")
@@ -3421,7 +3418,7 @@ class MixinState(State, mixin=True):
_backend: int = 0
_backend_no_default: dict
- @rx.var(cache=True)
+ @rx.var
def computed(self) -> str:
"""A computed var on mixin state.
@@ -3453,7 +3450,7 @@ def test_mixin_state() -> None:
assert "computed" in UsesMixinState.vars
assert (
- UsesMixinState(_reflex_internal_init=True)._backend_no_default # type: ignore
+ UsesMixinState(_reflex_internal_init=True)._backend_no_default # pyright: ignore [reportCallIssue]
is not UsesMixinState.backend_vars["_backend_no_default"]
)
@@ -3473,7 +3470,7 @@ def test_assignment_to_undeclared_vars():
class State(BaseState):
val: str
_val: str
- __val: str # type: ignore
+ __val: str # pyright: ignore [reportGeneralTypeIssues]
def handle_supported_regular_vars(self):
self.val = "no underscore"
@@ -3493,8 +3490,8 @@ def test_assignment_to_undeclared_vars():
def handle_var(self):
self.value = 20
- state = State() # type: ignore
- sub_state = Substate() # type: ignore
+ state = State() # pyright: ignore [reportCallIssue]
+ sub_state = Substate() # pyright: ignore [reportCallIssue]
with pytest.raises(SetUndefinedStateVarError):
state.handle_regular_var()
@@ -3556,7 +3553,7 @@ def test_fallback_pickle():
_f: Optional[Callable] = None
_g: Any = None
- state = DillState(_reflex_internal_init=True) # type: ignore
+ state = DillState(_reflex_internal_init=True) # pyright: ignore [reportCallIssue]
state._o = Obj(_f=lambda: 42)
state._f = lambda: 420
@@ -3567,14 +3564,14 @@ def test_fallback_pickle():
assert unpickled_state._o._f() == 42
# Threading locks are unpicklable normally, and raise TypeError instead of PicklingError.
- state2 = DillState(_reflex_internal_init=True) # type: ignore
+ state2 = DillState(_reflex_internal_init=True) # pyright: ignore [reportCallIssue]
state2._g = threading.Lock()
pk2 = state2._serialize()
unpickled_state2 = BaseState._deserialize(pk2)
assert isinstance(unpickled_state2._g, type(threading.Lock()))
# Some object, like generator, are still unpicklable with dill.
- state3 = DillState(_reflex_internal_init=True) # type: ignore
+ state3 = DillState(_reflex_internal_init=True) # pyright: ignore [reportCallIssue]
state3._g = (i for i in range(10))
with pytest.raises(StateSerializationError):
@@ -3749,7 +3746,7 @@ class UpcastState(rx.State):
assert isinstance(a, list)
self.passed = True
- def py_unresolvable(self, u: "Unresolvable"): # noqa: D102, F821 # type: ignore
+ def py_unresolvable(self, u: "Unresolvable"): # noqa: D102, F821 # pyright: ignore [reportUndefinedVariable]
assert isinstance(u, list)
self.passed = True
@@ -3815,3 +3812,128 @@ async def test_get_var_value(state_manager: StateManager, substate_token: str):
# Generic Var with no state
with pytest.raises(UnretrievableVarValueError):
await state.get_var_value(rx.Var("undefined"))
+
+
+@pytest.mark.asyncio
+async def test_async_computed_var_get_state(mock_app: rx.App, token: str):
+ """A test where an async computed var depends on a var in another state.
+
+ Args:
+ mock_app: An app that will be returned by `get_app()`
+ token: A token.
+ """
+
+ class Parent(BaseState):
+ """A root state like rx.State."""
+
+ parent_var: int = 0
+
+ class Child2(Parent):
+ """An unconnected child state."""
+
+ pass
+
+ class Child3(Parent):
+ """A child state with a computed var causing it to be pre-fetched.
+
+ If child3_var gets set to a value, and `get_state` erroneously
+ re-fetches it from redis, the value will be lost.
+ """
+
+ child3_var: int = 0
+
+ @rx.var(cache=True)
+ def v(self) -> int:
+ return self.child3_var
+
+ class Child(Parent):
+ """A state simulating UpdateVarsInternalState."""
+
+ @rx.var(cache=True)
+ async def v(self) -> int:
+ p = await self.get_state(Parent)
+ child3 = await self.get_state(Child3)
+ return child3.child3_var + p.parent_var
+
+ mock_app.state_manager.state = mock_app._state = Parent
+
+ # Get the top level state via unconnected sibling.
+ root = await mock_app.state_manager.get_state(_substate_key(token, Child))
+ # Set value in parent_var to assert it does not get refetched later.
+ root.parent_var = 1
+
+ if isinstance(mock_app.state_manager, StateManagerRedis):
+ # When redis is used, only states with uncached computed vars are pre-fetched.
+ assert Child2.get_name() not in root.substates
+ assert Child3.get_name() not in root.substates
+
+ # Get the unconnected sibling state, which will be used to `get_state` other instances.
+ child = root.get_substate(Child.get_full_name().split("."))
+
+ # Get an uncached child state.
+ child2 = await child.get_state(Child2)
+ assert child2.parent_var == 1
+
+ # Set value on already-cached Child3 state (prefetched because it has a Computed Var).
+ child3 = await child.get_state(Child3)
+ child3.child3_var = 1
+
+ assert await child.v == 2
+ assert await child.v == 2
+ root.parent_var = 2
+ assert await child.v == 3
+
+
+class Table(rx.ComponentState):
+ """A table state."""
+
+ data: ClassVar[Var]
+
+ @rx.var(cache=True, auto_deps=False)
+ async def rows(self) -> List[Dict[str, Any]]:
+ """Computed var over the given rows.
+
+ Returns:
+ The data rows.
+ """
+ return await self.get_var_value(self.data)
+
+ @classmethod
+ def get_component(cls, data: Var) -> rx.Component:
+ """Get the component for the table.
+
+ Args:
+ data: The data var.
+
+ Returns:
+ The component.
+ """
+ cls.data = data
+ cls.computed_vars["rows"].add_dependency(cls, data)
+ return rx.foreach(data, lambda d: rx.text(d.to_string()))
+
+
+@pytest.mark.asyncio
+async def test_async_computed_var_get_var_value(mock_app: rx.App, token: str):
+ """A test where an async computed var depends on a var in another state.
+
+ Args:
+ mock_app: An app that will be returned by `get_app()`
+ token: A token.
+ """
+
+ class OtherState(rx.State):
+ """A state with a var."""
+
+ data: List[Dict[str, Any]] = [{"foo": "bar"}]
+
+ mock_app.state_manager.state = mock_app._state = rx.State
+ comp = Table.create(data=OtherState.data)
+ state = await mock_app.state_manager.get_state(_substate_key(token, OtherState))
+ other_state = await state.get_state(OtherState)
+ assert comp.State is not None
+ comp_state = await state.get_state(comp.State)
+ assert comp_state.dirty_vars == set()
+
+ other_state.data.append({"foo": "baz"})
+ assert "rows" in comp_state.dirty_vars
diff --git a/tests/units/test_state_tree.py b/tests/units/test_state_tree.py
index ebdd877de..70ef71cb8 100644
--- a/tests/units/test_state_tree.py
+++ b/tests/units/test_state_tree.py
@@ -42,7 +42,7 @@ class SubA_A_A_A(SubA_A_A):
class SubA_A_A_B(SubA_A_A):
"""SubA_A_A_B is a child of SubA_A_A."""
- @rx.var(cache=True)
+ @rx.var
def sub_a_a_a_cached(self) -> int:
"""A cached var.
@@ -117,7 +117,7 @@ class TreeD(Root):
d: int
- @rx.var
+ @rx.var(cache=False)
def d_var(self) -> int:
"""A computed var.
@@ -156,7 +156,7 @@ class SubE_A_A_A_A(SubE_A_A_A):
sub_e_a_a_a_a: int
- @rx.var
+ @rx.var(cache=False)
def sub_e_a_a_a_a_var(self) -> int:
"""A computed var.
@@ -183,7 +183,7 @@ class SubE_A_A_A_D(SubE_A_A_A):
sub_e_a_a_a_d: int
- @rx.var(cache=True)
+ @rx.var
def sub_e_a_a_a_d_var(self) -> int:
"""A computed var.
@@ -222,7 +222,7 @@ async def state_manager_redis(
Yields:
A state manager instance
"""
- app_module_mock.app = rx.App(state=Root)
+ app_module_mock.app = rx.App(_state=Root)
state_manager = app_module_mock.app.state_manager
if not isinstance(state_manager, StateManagerRedis):
diff --git a/tests/units/test_style.py b/tests/units/test_style.py
index e1d652798..e8ff5bd01 100644
--- a/tests/units/test_style.py
+++ b/tests/units/test_style.py
@@ -1,6 +1,6 @@
from __future__ import annotations
-from typing import Any, Dict
+from typing import Any, Mapping
import pytest
@@ -356,7 +356,7 @@ def test_style_via_component(
style_dict: The style_dict to pass to the component.
expected_get_style: The expected style dict.
"""
- comp = rx.el.div(style=style_dict, **kwargs) # type: ignore
+ comp = rx.el.div(style=style_dict, **kwargs) # pyright: ignore [reportArgumentType]
compare_dict_of_var(comp._get_style(), expected_get_style)
@@ -379,7 +379,7 @@ class StyleState(rx.State):
{
"css": Var(
_js_expr=f'({{ ["color"] : ("dark"+{StyleState.color}) }})'
- ).to(Dict[str, str])
+ ).to(Mapping[str, str])
},
),
(
@@ -515,17 +515,17 @@ def test_evaluate_style_namespaces():
"""Test that namespaces get converted to component create functions."""
style_dict = {rx.text: {"color": "blue"}}
assert rx.text.__call__ not in style_dict
- style_dict = evaluate_style_namespaces(style_dict) # type: ignore
+ style_dict = evaluate_style_namespaces(style_dict) # pyright: ignore [reportArgumentType]
assert rx.text.__call__ in style_dict
def test_style_update_with_var_data():
"""Test that .update with a Style containing VarData works."""
red_var = LiteralVar.create("red")._replace(
- merge_var_data=VarData(hooks={"const red = true": None}), # type: ignore
+ merge_var_data=VarData(hooks={"const red = true": None}),
)
blue_var = LiteralVar.create("blue")._replace(
- merge_var_data=VarData(hooks={"const blue = true": None}), # type: ignore
+ merge_var_data=VarData(hooks={"const blue = true": None}),
)
s1 = Style(
@@ -541,3 +541,7 @@ def test_style_update_with_var_data():
assert s2._var_data is not None
assert "const red = true" in s2._var_data.hooks
assert "const blue = true" in s2._var_data.hooks
+
+ s3 = s1 | s2
+ assert s3._var_data is not None
+ assert "_varData" not in s3
diff --git a/tests/units/test_testing.py b/tests/units/test_testing.py
index 83a03ad83..8c8f1461b 100644
--- a/tests/units/test_testing.py
+++ b/tests/units/test_testing.py
@@ -23,7 +23,7 @@ def test_app_harness(tmp_path):
class State(rx.State):
pass
- app = rx.App(state=State)
+ app = rx.App(_state=State)
app.add_page(lambda: rx.text("Basic App"), route="/", title="index")
app._compile()
diff --git a/tests/units/test_var.py b/tests/units/test_var.py
index bfa8aa35a..8fcd288e6 100644
--- a/tests/units/test_var.py
+++ b/tests/units/test_var.py
@@ -1,17 +1,20 @@
import json
import math
-import sys
import typing
-from typing import Dict, List, Optional, Set, Tuple, Union, cast
+from typing import Dict, List, Mapping, Optional, Set, Tuple, Union, cast
import pytest
from pandas import DataFrame
import reflex as rx
from reflex.base import Base
+from reflex.config import PerformanceMode
from reflex.constants.base import REFLEX_VAR_CLOSING_TAG, REFLEX_VAR_OPENING_TAG
from reflex.state import BaseState
-from reflex.utils.exceptions import PrimitiveUnserializableToJSON
+from reflex.utils.exceptions import (
+ PrimitiveUnserializableToJSONError,
+ UntypedComputedVarError,
+)
from reflex.utils.imports import ImportVar
from reflex.vars import VarData
from reflex.vars.base import (
@@ -185,6 +188,7 @@ def ChildWithRuntimeOnlyVar(StateWithRuntimeOnlyVar):
"state.local",
"local2",
],
+ strict=True,
),
)
def test_full_name(prop, expected):
@@ -202,6 +206,7 @@ def test_full_name(prop, expected):
zip(
test_vars,
["prop1", "key", "state.value", "state.local", "local2"],
+ strict=True,
),
)
def test_str(prop, expected):
@@ -248,6 +253,7 @@ def test_default_value(prop: Var, expected):
"state.set_local",
"set_local2",
],
+ strict=True,
),
)
def test_get_setter(prop: Var, expected):
@@ -270,7 +276,7 @@ def test_get_setter(prop: Var, expected):
([1, 2, 3], Var(_js_expr="[1, 2, 3]", _var_type=List[int])),
(
{"a": 1, "b": 2},
- Var(_js_expr='({ ["a"] : 1, ["b"] : 2 })', _var_type=Dict[str, int]),
+ Var(_js_expr='({ ["a"] : 1, ["b"] : 2 })', _var_type=Mapping[str, int]),
),
],
)
@@ -282,7 +288,7 @@ def test_create(value, expected):
expected: The expected name of the setter function.
"""
prop = LiteralVar.create(value)
- assert prop.equals(expected) # type: ignore
+ assert prop.equals(expected)
def test_create_type_error():
@@ -416,19 +422,13 @@ class Bar(rx.Base):
@pytest.mark.parametrize(
("var", "var_type"),
- (
- [
- (Var(_js_expr="", _var_type=Foo | Bar).guess_type(), Foo | Bar),
- (Var(_js_expr="", _var_type=Foo | Bar).guess_type().bar, Union[int, str]),
- ]
- if sys.version_info >= (3, 10)
- else []
- )
- + [
- (Var(_js_expr="", _var_type=Union[Foo, Bar]).guess_type(), Union[Foo, Bar]),
- (Var(_js_expr="", _var_type=Union[Foo, Bar]).guess_type().baz, str),
+ [
+ (Var(_js_expr="").to(Foo | Bar), Foo | Bar),
+ (Var(_js_expr="").to(Foo | Bar).bar, Union[int, str]),
+ (Var(_js_expr="").to(Union[Foo, Bar]), Union[Foo, Bar]),
+ (Var(_js_expr="").to(Union[Foo, Bar]).baz, str),
(
- Var(_js_expr="", _var_type=Union[Foo, Bar]).guess_type().foo,
+ Var(_js_expr="").to(Union[Foo, Bar]).foo,
Union[int, None],
),
],
@@ -804,7 +804,7 @@ def test_shadow_computed_var_error(request: pytest.FixtureRequest, fixture: str)
request: Fixture Request.
fixture: The state fixture.
"""
- with pytest.raises(NameError):
+ with pytest.raises(UntypedComputedVarError):
state = request.getfixturevalue(fixture)
state.var_without_annotation.foo
@@ -1004,7 +1004,7 @@ def test_all_number_operations():
assert (
str(even_more_complicated_number)
- == "!(((Math.abs(Math.floor(((Math.floor(((-((-5.4 + 1)) * 2) / 3) / 2) % 3) ** 2))) || (2 && Math.round(((Math.floor(((-((-5.4 + 1)) * 2) / 3) / 2) % 3) ** 2)))) !== 0))"
+ == "!(isTrue((Math.abs(Math.floor(((Math.floor(((-((-5.4 + 1)) * 2) / 3) / 2) % 3) ** 2))) || (2 && Math.round(((Math.floor(((-((-5.4 + 1)) * 2) / 3) / 2) % 3) ** 2))))))"
)
assert str(LiteralNumberVar.create(5) > False) == "(5 > 0)"
@@ -1058,7 +1058,7 @@ def test_inf_and_nan(var, expected_js):
assert str(var) == expected_js
assert isinstance(var, NumberVar)
assert isinstance(var, LiteralVar)
- with pytest.raises(PrimitiveUnserializableToJSON):
+ with pytest.raises(PrimitiveUnserializableToJSONError):
var.json()
@@ -1070,19 +1070,19 @@ def test_array_operations():
assert str(array_var.reverse()) == "[1, 2, 3, 4, 5].slice().reverse()"
assert (
str(ArrayVar.range(10))
- == "Array.from({ length: (10 - 0) / 1 }, (_, i) => 0 + i * 1)"
+ == "Array.from({ length: Math.ceil((10 - 0) / 1) }, (_, i) => 0 + i * 1)"
)
assert (
str(ArrayVar.range(1, 10))
- == "Array.from({ length: (10 - 1) / 1 }, (_, i) => 1 + i * 1)"
+ == "Array.from({ length: Math.ceil((10 - 1) / 1) }, (_, i) => 1 + i * 1)"
)
assert (
str(ArrayVar.range(1, 10, 2))
- == "Array.from({ length: (10 - 1) / 2 }, (_, i) => 1 + i * 2)"
+ == "Array.from({ length: Math.ceil((10 - 1) / 2) }, (_, i) => 1 + i * 2)"
)
assert (
str(ArrayVar.range(1, 10, -1))
- == "Array.from({ length: (10 - 1) / -1 }, (_, i) => 1 + i * -1)"
+ == "Array.from({ length: Math.ceil((10 - 1) / -1) }, (_, i) => 1 + i * -1)"
)
@@ -1127,7 +1127,7 @@ def test_var_component():
for _, imported_objects in var_data.imports
)
- has_eval_react_component(ComponentVarState.field_var) # type: ignore
+ has_eval_react_component(ComponentVarState.field_var) # pyright: ignore [reportArgumentType]
has_eval_react_component(ComponentVarState.computed_var)
@@ -1139,15 +1139,15 @@ def test_type_chains():
List[int],
)
assert (
- str(object_var.keys()[0].upper()) # type: ignore
+ str(object_var.keys()[0].upper())
== 'Object.keys(({ ["a"] : 1, ["b"] : 2, ["c"] : 3 })).at(0).toUpperCase()'
)
assert (
- str(object_var.entries()[1][1] - 1) # type: ignore
+ str(object_var.entries()[1][1] - 1)
== '(Object.entries(({ ["a"] : 1, ["b"] : 2, ["c"] : 3 })).at(1).at(1) - 1)'
)
assert (
- str(object_var["c"] + object_var["b"]) # type: ignore
+ str(object_var["c"] + object_var["b"]) # pyright: ignore [reportCallIssue, reportOperatorIssue]
== '(({ ["a"] : 1, ["b"] : 2, ["c"] : 3 })["c"] + ({ ["a"] : 1, ["b"] : 2, ["c"] : 3 })["b"])'
)
@@ -1156,7 +1156,7 @@ def test_nested_dict():
arr = LiteralArrayVar.create([{"bar": ["foo", "bar"]}], List[Dict[str, List[str]]])
assert (
- str(arr[0]["bar"][0]) == '[({ ["bar"] : ["foo", "bar"] })].at(0)["bar"].at(0)'
+ str(arr[0]["bar"][0]) == '[({ ["bar"] : ["foo", "bar"] })].at(0)["bar"].at(0)' # pyright: ignore [reportIndexIssue]
)
@@ -1352,7 +1352,7 @@ def test_unsupported_types_for_contains(var: Var):
var: The base var.
"""
with pytest.raises(TypeError) as err:
- assert var.contains(1)
+ assert var.contains(1) # pyright: ignore [reportAttributeAccessIssue]
assert (
err.value.args[0]
== f"Var of type {var._var_type} does not support contains check."
@@ -1382,7 +1382,7 @@ def test_unsupported_types_for_string_contains(other):
def test_unsupported_default_contains():
with pytest.raises(TypeError) as err:
- assert 1 in Var(_js_expr="var", _var_type=str).guess_type()
+ assert 1 in Var(_js_expr="var", _var_type=str).guess_type() # pyright: ignore [reportOperatorIssue]
assert (
err.value.args[0]
== "'in' operator not supported for Var types, use Var.contains() instead."
@@ -1808,16 +1808,13 @@ def cv_fget(state: BaseState) -> int:
@pytest.mark.parametrize(
"deps,expected",
[
- (["a"], {"a"}),
- (["b"], {"b"}),
- ([ComputedVar(fget=cv_fget)], {"cv_fget"}),
+ (["a"], {None: {"a"}}),
+ (["b"], {None: {"b"}}),
+ ([ComputedVar(fget=cv_fget)], {None: {"cv_fget"}}),
],
)
def test_computed_var_deps(deps: List[Union[str, Var]], expected: Set[str]):
- @computed_var(
- deps=deps,
- cache=True,
- )
+ @computed_var(deps=deps)
def test_var(state) -> int:
return 1
@@ -1835,10 +1832,7 @@ def test_computed_var_deps(deps: List[Union[str, Var]], expected: Set[str]):
def test_invalid_computed_var_deps(deps: List):
with pytest.raises(TypeError):
- @computed_var(
- deps=deps,
- cache=True,
- )
+ @computed_var(deps=deps)
def test_var(state) -> int:
return 1
@@ -1862,3 +1856,65 @@ def test_to_string_operation():
single_var = Var.create(Email())
assert single_var._var_type == Email
+
+
+@pytest.mark.asyncio
+async def test_async_computed_var():
+ side_effect_counter = 0
+
+ class AsyncComputedVarState(BaseState):
+ v: int = 1
+
+ @computed_var(cache=True)
+ async def async_computed_var(self) -> int:
+ nonlocal side_effect_counter
+ side_effect_counter += 1
+ return self.v + 1
+
+ my_state = AsyncComputedVarState()
+ assert await my_state.async_computed_var == 2
+ assert await my_state.async_computed_var == 2
+ my_state.v = 2
+ assert await my_state.async_computed_var == 3
+ assert await my_state.async_computed_var == 3
+ assert side_effect_counter == 2
+
+
+def test_var_data_hooks():
+ var_data_str = VarData(hooks="what")
+ var_data_list = VarData(hooks=["what"])
+ var_data_dict = VarData(hooks={"what": None})
+ assert var_data_str == var_data_list == var_data_dict
+
+ var_data_list_multiple = VarData(hooks=["what", "whot"])
+ var_data_dict_multiple = VarData(hooks={"what": None, "whot": None})
+ assert var_data_list_multiple == var_data_dict_multiple
+
+
+def test_var_data_with_hooks_value():
+ var_data = VarData(hooks={"what": VarData(hooks={"whot": VarData(hooks="whott")})})
+ assert var_data == VarData(hooks=["what", "whot", "whott"])
+
+
+def test_str_var_in_components(mocker):
+ class StateWithVar(rx.State):
+ field: int = 1
+
+ mocker.patch(
+ "reflex.components.base.bare.get_performance_mode",
+ return_value=PerformanceMode.RAISE,
+ )
+
+ with pytest.raises(ValueError):
+ rx.vstack(
+ str(StateWithVar.field),
+ )
+
+ mocker.patch(
+ "reflex.components.base.bare.get_performance_mode",
+ return_value=PerformanceMode.OFF,
+ )
+
+ rx.vstack(
+ str(StateWithVar.field),
+ )
diff --git a/tests/units/utils/test_format.py b/tests/units/utils/test_format.py
index 2a2aa8259..053d5a3ae 100644
--- a/tests/units/utils/test_format.py
+++ b/tests/units/utils/test_format.py
@@ -189,11 +189,11 @@ def test_to_snake_case(input: str, output: str):
("kebab-case", "kebabCase"),
("kebab-case-two", "kebabCaseTwo"),
("snake_kebab-case", "snakeKebabCase"),
- ("_hover", "_hover"),
- ("-starts-with-hyphen", "-startsWithHyphen"),
- ("--starts-with-double-hyphen", "--startsWithDoubleHyphen"),
- ("_starts_with_underscore", "_startsWithUnderscore"),
- ("__starts_with_double_underscore", "__startsWithDoubleUnderscore"),
+ ("_hover", "Hover"),
+ ("-starts-with-hyphen", "StartsWithHyphen"),
+ ("--starts-with-double-hyphen", "StartsWithDoubleHyphen"),
+ ("_starts_with_underscore", "StartsWithUnderscore"),
+ ("__starts_with_double_underscore", "StartsWithDoubleUnderscore"),
(":start-with-colon", ":startWithColon"),
(":-start-with-colon-dash", ":StartWithColonDash"),
],
@@ -523,7 +523,7 @@ def test_format_event_handler(input, output):
input: The event handler input.
output: The expected output.
"""
- assert format.format_event_handler(input) == output # type: ignore
+ assert format.format_event_handler(input) == output
@pytest.mark.parametrize(
@@ -582,7 +582,7 @@ formatted_router = {
"input, output",
[
(
- TestState(_reflex_internal_init=True).dict(), # type: ignore
+ TestState(_reflex_internal_init=True).dict(), # pyright: ignore [reportCallIssue]
{
TestState.get_full_name(): {
"array": [1, 2, 3.14],
@@ -615,7 +615,7 @@ formatted_router = {
},
),
(
- DateTimeState(_reflex_internal_init=True).dict(), # type: ignore
+ DateTimeState(_reflex_internal_init=True).dict(), # pyright: ignore [reportCallIssue]
{
DateTimeState.get_full_name(): {
"d": "1989-11-09",
diff --git a/tests/units/utils/test_utils.py b/tests/units/utils/test_utils.py
index f8573111c..74dcf79b0 100644
--- a/tests/units/utils/test_utils.py
+++ b/tests/units/utils/test_utils.py
@@ -31,7 +31,7 @@ def get_above_max_version():
"""
semantic_version_list = constants.Bun.VERSION.split(".")
- semantic_version_list[-1] = str(int(semantic_version_list[-1]) + 1) # type: ignore
+ semantic_version_list[-1] = str(int(semantic_version_list[-1]) + 1) # pyright: ignore [reportArgumentType, reportCallIssue]
return ".".join(semantic_version_list)
@@ -115,16 +115,29 @@ def test_typehint_issubclass(subclass, superclass, expected):
assert types.typehint_issubclass(subclass, superclass) == expected
-def test_validate_invalid_bun_path(mocker):
+def test_validate_none_bun_path(mocker):
+ """Test that an error is thrown when a bun path is not specified.
+
+ Args:
+ mocker: Pytest mocker object.
+ """
+ mocker.patch("reflex.utils.path_ops.get_bun_path", return_value=None)
+ # with pytest.raises(typer.Exit):
+ prerequisites.validate_bun()
+
+
+def test_validate_invalid_bun_path(
+ mocker,
+):
"""Test that an error is thrown when a custom specified bun path is not valid
or does not exist.
Args:
mocker: Pytest mocker object.
"""
- mock = mocker.Mock()
- mocker.patch.object(mock, "bun_path", return_value="/mock/path")
- mocker.patch("reflex.utils.prerequisites.get_config", mock)
+ mock_path = mocker.Mock()
+ mocker.patch("reflex.utils.path_ops.get_bun_path", return_value=mock_path)
+ mocker.patch("reflex.utils.path_ops.samefile", return_value=False)
mocker.patch("reflex.utils.prerequisites.get_bun_version", return_value=None)
with pytest.raises(typer.Exit):
@@ -137,9 +150,10 @@ def test_validate_bun_path_incompatible_version(mocker):
Args:
mocker: Pytest mocker object.
"""
- mock = mocker.Mock()
- mocker.patch.object(mock, "bun_path", return_value="/mock/path")
- mocker.patch("reflex.utils.prerequisites.get_config", mock)
+ mock_path = mocker.Mock()
+ mock_path.samefile.return_value = False
+ mocker.patch("reflex.utils.path_ops.get_bun_path", return_value=mock_path)
+ mocker.patch("reflex.utils.path_ops.samefile", return_value=False)
mocker.patch(
"reflex.utils.prerequisites.get_bun_version",
return_value=version.parse("0.6.5"),
@@ -587,9 +601,7 @@ def test_style_prop_with_event_handler_value(callable):
}
with pytest.raises(ReflexError):
- rx.box(
- style=style, # type: ignore
- )
+ rx.box(style=style) # pyright: ignore [reportArgumentType]
def test_is_prod_mode() -> None:
diff --git a/tests/units/vars/test_base.py b/tests/units/vars/test_base.py
index 68bc0c38e..8f9e99fe4 100644
--- a/tests/units/vars/test_base.py
+++ b/tests/units/vars/test_base.py
@@ -1,8 +1,9 @@
-from typing import Dict, List, Union
+from typing import List, Mapping, Union
import pytest
-from reflex.vars.base import figure_out_type
+from reflex.state import State
+from reflex.vars.base import computed_var, figure_out_type
class CustomDict(dict[str, str]):
@@ -37,13 +38,26 @@ class ChildGenericDict(GenericDict):
("a", str),
([1, 2, 3], List[int]),
([1, 2.0, "a"], List[Union[int, float, str]]),
- ({"a": 1, "b": 2}, Dict[str, int]),
- ({"a": 1, 2: "b"}, Dict[Union[int, str], Union[str, int]]),
+ ({"a": 1, "b": 2}, Mapping[str, int]),
+ ({"a": 1, 2: "b"}, Mapping[Union[int, str], Union[str, int]]),
(CustomDict(), CustomDict),
(ChildCustomDict(), ChildCustomDict),
- (GenericDict({1: 1}), Dict[int, int]),
- (ChildGenericDict({1: 1}), Dict[int, int]),
+ (GenericDict({1: 1}), Mapping[int, int]),
+ (ChildGenericDict({1: 1}), Mapping[int, int]),
],
)
def test_figure_out_type(value, expected):
assert figure_out_type(value) == expected
+
+
+def test_computed_var_replace() -> None:
+ class StateTest(State):
+ @computed_var(cache=True)
+ def cv(self) -> int:
+ return 1
+
+ cv = StateTest.cv
+ assert cv._var_type is int
+
+ replaced = cv._replace(_var_type=float)
+ assert replaced._var_type is float
diff --git a/tests/units/vars/test_object.py b/tests/units/vars/test_object.py
index efcb21166..89ace55bb 100644
--- a/tests/units/vars/test_object.py
+++ b/tests/units/vars/test_object.py
@@ -1,10 +1,14 @@
+import dataclasses
+
import pytest
+from sqlalchemy.orm import DeclarativeBase, Mapped, MappedAsDataclass, mapped_column
from typing_extensions import assert_type
import reflex as rx
from reflex.utils.types import GenericType
from reflex.vars.base import Var
from reflex.vars.object import LiteralObjectVar, ObjectVar
+from reflex.vars.sequence import ArrayVar
class Bare:
@@ -32,24 +36,54 @@ class Base(rx.Base):
quantity: int = 0
+class SqlaBase(DeclarativeBase, MappedAsDataclass):
+ """Sqlalchemy declarative mapping base class."""
+
+ pass
+
+
+class SqlaModel(SqlaBase):
+ """A sqlalchemy model with a single attribute."""
+
+ __tablename__: str = "sqla_model"
+
+ id: Mapped[int] = mapped_column(primary_key=True, autoincrement=True, init=False)
+ quantity: Mapped[int] = mapped_column(default=0)
+
+
+@dataclasses.dataclass
+class Dataclass:
+ """A dataclass with a single attribute."""
+
+ quantity: int = 0
+
+
class ObjectState(rx.State):
- """A reflex state with bare and base objects."""
+ """A reflex state with bare, base and sqlalchemy base vars."""
bare: rx.Field[Bare] = rx.field(Bare())
+ bare_optional: rx.Field[Bare | None] = rx.field(None)
base: rx.Field[Base] = rx.field(Base())
+ base_optional: rx.Field[Base | None] = rx.field(None)
+ sqlamodel: rx.Field[SqlaModel] = rx.field(SqlaModel())
+ sqlamodel_optional: rx.Field[SqlaModel | None] = rx.field(None)
+ dataclass: rx.Field[Dataclass] = rx.field(Dataclass())
+ dataclass_optional: rx.Field[Dataclass | None] = rx.field(None)
+
+ base_list: rx.Field[list[Base]] = rx.field([Base()])
-@pytest.mark.parametrize("type_", [Base, Bare])
-def test_var_create(type_: GenericType) -> None:
+@pytest.mark.parametrize("type_", [Base, Bare, SqlaModel, Dataclass])
+def test_var_create(type_: type[Base | Bare | SqlaModel | Dataclass]) -> None:
my_object = type_()
var = Var.create(my_object)
assert var._var_type is type_
-
+ assert isinstance(var, ObjectVar)
quantity = var.quantity
assert quantity._var_type is int
-@pytest.mark.parametrize("type_", [Base, Bare])
+@pytest.mark.parametrize("type_", [Base, Bare, SqlaModel, Dataclass])
def test_literal_create(type_: GenericType) -> None:
my_object = type_()
var = LiteralObjectVar.create(my_object)
@@ -59,18 +93,18 @@ def test_literal_create(type_: GenericType) -> None:
assert quantity._var_type is int
-@pytest.mark.parametrize("type_", [Base, Bare])
-def test_guess(type_: GenericType) -> None:
+@pytest.mark.parametrize("type_", [Base, Bare, SqlaModel, Dataclass])
+def test_guess(type_: type[Base | Bare | SqlaModel | Dataclass]) -> None:
my_object = type_()
var = Var.create(my_object)
var = var.guess_type()
assert var._var_type is type_
-
+ assert isinstance(var, ObjectVar)
quantity = var.quantity
assert quantity._var_type is int
-@pytest.mark.parametrize("type_", [Base, Bare])
+@pytest.mark.parametrize("type_", [Base, Bare, SqlaModel, Dataclass])
def test_state(type_: GenericType) -> None:
attr_name = type_.__name__.lower()
var = getattr(ObjectState, attr_name)
@@ -80,7 +114,7 @@ def test_state(type_: GenericType) -> None:
assert quantity._var_type is int
-@pytest.mark.parametrize("type_", [Base, Bare])
+@pytest.mark.parametrize("type_", [Base, Bare, SqlaModel, Dataclass])
def test_state_to_operation(type_: GenericType) -> None:
attr_name = type_.__name__.lower()
original_var = getattr(ObjectState, attr_name)
@@ -100,3 +134,29 @@ def test_typing() -> None:
# Base
var = ObjectState.base
_ = assert_type(var, ObjectVar[Base])
+ optional_var = ObjectState.base_optional
+ _ = assert_type(optional_var, ObjectVar[Base | None])
+ list_var = ObjectState.base_list
+ _ = assert_type(list_var, ArrayVar[list[Base]])
+ list_var_0 = list_var[0]
+ _ = assert_type(list_var_0, ObjectVar[Base])
+
+ # Sqla
+ var = ObjectState.sqlamodel
+ _ = assert_type(var, ObjectVar[SqlaModel])
+ optional_var = ObjectState.sqlamodel_optional
+ _ = assert_type(optional_var, ObjectVar[SqlaModel | None])
+ list_var = ObjectState.base_list
+ _ = assert_type(list_var, ArrayVar[list[Base]])
+ list_var_0 = list_var[0]
+ _ = assert_type(list_var_0, ObjectVar[Base])
+
+ # Dataclass
+ var = ObjectState.dataclass
+ _ = assert_type(var, ObjectVar[Dataclass])
+ optional_var = ObjectState.dataclass_optional
+ _ = assert_type(optional_var, ObjectVar[Dataclass | None])
+ list_var = ObjectState.base_list
+ _ = assert_type(list_var, ArrayVar[list[Base]])
+ list_var_0 = list_var[0]
+ _ = assert_type(list_var_0, ObjectVar[Base])