' # type: ignore
+ == f'
' # type: ignore
)
diff --git a/tests/units/components/core/test_match.py b/tests/units/components/core/test_match.py
index 583bfa1e2..f09e800e5 100644
--- a/tests/units/components/core/test_match.py
+++ b/tests/units/components/core/test_match.py
@@ -41,15 +41,15 @@ def test_match_components():
assert len(match_cases) == 6
assert match_cases[0][0]._js_expr == "1"
- assert match_cases[0][0]._var_type == int
+ assert match_cases[0][0]._var_type is int
first_return_value_render = match_cases[0][1].render()
assert first_return_value_render["name"] == "RadixThemesText"
assert first_return_value_render["children"][0]["contents"] == '{"first value"}'
assert match_cases[1][0]._js_expr == "2"
- assert match_cases[1][0]._var_type == int
+ assert match_cases[1][0]._var_type is int
assert match_cases[1][1]._js_expr == "3"
- assert match_cases[1][1]._var_type == int
+ assert match_cases[1][1]._var_type is int
second_return_value_render = match_cases[1][2].render()
assert second_return_value_render["name"] == "RadixThemesText"
assert second_return_value_render["children"][0]["contents"] == '{"second value"}'
@@ -61,7 +61,7 @@ def test_match_components():
assert third_return_value_render["children"][0]["contents"] == '{"third value"}'
assert match_cases[3][0]._js_expr == '"random"'
- assert match_cases[3][0]._var_type == str
+ assert match_cases[3][0]._var_type is str
fourth_return_value_render = match_cases[3][1].render()
assert fourth_return_value_render["name"] == "RadixThemesText"
assert fourth_return_value_render["children"][0]["contents"] == '{"fourth value"}'
@@ -73,7 +73,7 @@ def test_match_components():
assert fifth_return_value_render["children"][0]["contents"] == '{"fifth value"}'
assert match_cases[5][0]._js_expr == f"({MatchState.get_name()}.num + 1)"
- assert match_cases[5][0]._var_type == int
+ assert match_cases[5][0]._var_type is int
fifth_return_value_render = match_cases[5][1].render()
assert fifth_return_value_render["name"] == "RadixThemesText"
assert fifth_return_value_render["children"][0]["contents"] == '{"sixth value"}'
diff --git a/tests/units/components/core/test_upload.py b/tests/units/components/core/test_upload.py
index 83f04b3e6..710baa161 100644
--- a/tests/units/components/core/test_upload.py
+++ b/tests/units/components/core/test_upload.py
@@ -1,3 +1,6 @@
+from typing import Any
+
+from reflex import event
from reflex.components.core.upload import (
StyledUpload,
Upload,
@@ -11,10 +14,11 @@ from reflex.state import State
from reflex.vars.base import LiteralVar, Var
-class TestUploadState(State):
+class UploadStateTest(State):
"""Test upload state."""
- def drop_handler(self, files):
+ @event
+ def drop_handler(self, files: Any):
"""Handle the drop event.
Args:
@@ -22,7 +26,8 @@ class TestUploadState(State):
"""
pass
- def not_drop_handler(self, not_files):
+ @event
+ def not_drop_handler(self, not_files: Any):
"""Handle the drop event without defining the files argument.
Args:
@@ -42,7 +47,7 @@ def test_get_upload_url():
def test__on_drop_spec():
- assert isinstance(_on_drop_spec(LiteralVar.create([])), list)
+ assert isinstance(_on_drop_spec(LiteralVar.create([])), tuple)
def test_upload_create():
@@ -55,7 +60,7 @@ def test_upload_create():
up_comp_2 = Upload.create(
id="foo_id",
- on_drop=TestUploadState.drop_handler([]), # type: ignore
+ on_drop=UploadStateTest.drop_handler([]), # type: ignore
)
assert isinstance(up_comp_2, Upload)
assert up_comp_2.is_used
@@ -65,7 +70,7 @@ def test_upload_create():
up_comp_3 = Upload.create(
id="foo_id",
- on_drop=TestUploadState.drop_handler,
+ on_drop=UploadStateTest.drop_handler,
)
assert isinstance(up_comp_3, Upload)
assert up_comp_3.is_used
@@ -75,7 +80,7 @@ def test_upload_create():
up_comp_4 = Upload.create(
id="foo_id",
- on_drop=TestUploadState.not_drop_handler([]), # type: ignore
+ on_drop=UploadStateTest.not_drop_handler([]), # type: ignore
)
assert isinstance(up_comp_4, Upload)
assert up_comp_4.is_used
@@ -91,7 +96,7 @@ def test_styled_upload_create():
styled_up_comp_2 = StyledUpload.create(
id="foo_id",
- on_drop=TestUploadState.drop_handler([]), # type: ignore
+ on_drop=UploadStateTest.drop_handler([]), # type: ignore
)
assert isinstance(styled_up_comp_2, StyledUpload)
assert styled_up_comp_2.is_used
@@ -101,7 +106,7 @@ def test_styled_upload_create():
styled_up_comp_3 = StyledUpload.create(
id="foo_id",
- on_drop=TestUploadState.drop_handler,
+ on_drop=UploadStateTest.drop_handler,
)
assert isinstance(styled_up_comp_3, StyledUpload)
assert styled_up_comp_3.is_used
@@ -111,7 +116,7 @@ def test_styled_upload_create():
styled_up_comp_4 = StyledUpload.create(
id="foo_id",
- on_drop=TestUploadState.not_drop_handler([]), # type: ignore
+ on_drop=UploadStateTest.not_drop_handler([]), # type: ignore
)
assert isinstance(styled_up_comp_4, StyledUpload)
assert styled_up_comp_4.is_used
diff --git a/tests/units/components/datadisplay/test_code.py b/tests/units/components/datadisplay/test_code.py
index 809c68fe5..6b7168756 100644
--- a/tests/units/components/datadisplay/test_code.py
+++ b/tests/units/components/datadisplay/test_code.py
@@ -11,22 +11,3 @@ def test_code_light_dark_theme(theme, expected):
code_block = CodeBlock.create(theme=theme)
assert code_block.theme._js_expr == expected # type: ignore
-
-
-def generate_custom_code(language, expected_case):
- return f"SyntaxHighlighter.registerLanguage('{language}', {expected_case})"
-
-
-@pytest.mark.parametrize(
- "language, expected_case",
- [
- ("python", "python"),
- ("firestore-security-rules", "firestoreSecurityRules"),
- ("typescript", "typescript"),
- ],
-)
-def test_get_custom_code(language, expected_case):
- code_block = CodeBlock.create(language=language)
- assert code_block._get_custom_code() == generate_custom_code(
- language, expected_case
- )
diff --git a/tests/units/components/datadisplay/test_shiki_code.py b/tests/units/components/datadisplay/test_shiki_code.py
new file mode 100644
index 000000000..eb473ba06
--- /dev/null
+++ b/tests/units/components/datadisplay/test_shiki_code.py
@@ -0,0 +1,172 @@
+import pytest
+
+from reflex.components.datadisplay.shiki_code_block import (
+ ShikiBaseTransformers,
+ ShikiCodeBlock,
+ ShikiHighLevelCodeBlock,
+ ShikiJsTransformer,
+)
+from reflex.components.el.elements.forms import Button
+from reflex.components.lucide.icon import Icon
+from reflex.components.radix.themes.layout.box import Box
+from reflex.style import Style
+from reflex.vars import Var
+
+
+@pytest.mark.parametrize(
+ "library, fns, expected_output, raises_exception",
+ [
+ ("some_library", ["function_one"], ["function_one"], False),
+ ("some_library", [123], None, True),
+ ("some_library", [], [], False),
+ (
+ "some_library",
+ ["function_one", "function_two"],
+ ["function_one", "function_two"],
+ False,
+ ),
+ ("", ["function_one"], ["function_one"], False),
+ ("some_library", ["function_one", 789], None, True),
+ ("", [], [], False),
+ ],
+)
+def test_create_transformer(library, fns, expected_output, raises_exception):
+ if raises_exception:
+ # Ensure ValueError is raised for invalid cases
+ with pytest.raises(ValueError):
+ ShikiCodeBlock.create_transformer(library, fns)
+ else:
+ transformer = ShikiCodeBlock.create_transformer(library, fns)
+ assert isinstance(transformer, ShikiBaseTransformers)
+ assert transformer.library == library
+
+ # Verify that the functions are correctly wrapped in FunctionStringVar
+ function_names = [str(fn) for fn in transformer.fns]
+ assert function_names == expected_output
+
+
+@pytest.mark.parametrize(
+ "code_block, children, props, expected_first_child, expected_styles",
+ [
+ ("print('Hello')", ["print('Hello')"], {}, "print('Hello')", {}),
+ (
+ "print('Hello')",
+ ["print('Hello')", "More content"],
+ {},
+ "print('Hello')",
+ {},
+ ),
+ (
+ "print('Hello')",
+ ["print('Hello')"],
+ {
+ "transformers": [
+ ShikiBaseTransformers(
+ library="lib", fns=[], style=Style({"color": "red"})
+ )
+ ]
+ },
+ "print('Hello')",
+ {"color": "red"},
+ ),
+ (
+ "print('Hello')",
+ ["print('Hello')"],
+ {
+ "transformers": [
+ ShikiBaseTransformers(
+ library="lib", fns=[], style=Style({"color": "red"})
+ )
+ ],
+ "style": {"background": "blue"},
+ },
+ "print('Hello')",
+ {"color": "red", "background": "blue"},
+ ),
+ ],
+)
+def test_create_shiki_code_block(
+ code_block, children, props, expected_first_child, expected_styles
+):
+ component = ShikiCodeBlock.create(code_block, *children, **props)
+
+ # Test that the created component is a Box
+ assert isinstance(component, Box)
+
+ # Test that the first child is the code
+ code_block_component = component.children[0]
+ assert code_block_component.code._var_value == expected_first_child # type: ignore
+
+ applied_styles = component.style
+ for key, value in expected_styles.items():
+ assert Var.create(applied_styles[key])._var_value == value
+
+
+@pytest.mark.parametrize(
+ "children, props, expected_transformers, expected_button_type",
+ [
+ (["print('Hello')"], {"use_transformers": True}, [ShikiJsTransformer], None),
+ (["print('Hello')"], {"can_copy": True}, None, Button),
+ (
+ ["print('Hello')"],
+ {
+ "can_copy": True,
+ "copy_button": Button.create(Icon.create(tag="a_arrow_down")),
+ },
+ None,
+ Button,
+ ),
+ ],
+)
+def test_create_shiki_high_level_code_block(
+ children, props, expected_transformers, expected_button_type
+):
+ component = ShikiHighLevelCodeBlock.create(*children, **props)
+
+ # Test that the created component is a Box
+ assert isinstance(component, Box)
+
+ # Test that the first child is the code block component
+ code_block_component = component.children[0]
+ assert code_block_component.code._var_value == children[0] # type: ignore
+
+ # Check if the transformer is set correctly if expected
+ if expected_transformers:
+ exp_trans_names = [t.__name__ for t in expected_transformers]
+ for transformer in code_block_component.transformers._var_value: # type: ignore
+ assert type(transformer).__name__ in exp_trans_names
+
+ # Check if the second child is the copy button if can_copy is True
+ if props.get("can_copy", False):
+ if props.get("copy_button"):
+ assert isinstance(component.children[1], expected_button_type)
+ assert component.children[1] == props["copy_button"]
+ else:
+ assert isinstance(component.children[1], expected_button_type)
+ else:
+ assert len(component.children) == 1
+
+
+@pytest.mark.parametrize(
+ "children, props",
+ [
+ (["print('Hello')"], {"theme": "dark"}),
+ (["print('Hello')"], {"language": "javascript"}),
+ ],
+)
+def test_shiki_high_level_code_block_theme_language_mapping(children, props):
+ component = ShikiHighLevelCodeBlock.create(*children, **props)
+
+ # Test that the theme is mapped correctly
+ if "theme" in props:
+ assert component.children[
+ 0
+ ].theme._var_value == ShikiHighLevelCodeBlock._map_themes(props["theme"]) # type: ignore
+
+ # Test that the language is mapped correctly
+ if "language" in props:
+ assert component.children[
+ 0
+ ].language._var_value == ShikiHighLevelCodeBlock._map_languages( # type: ignore
+ props["language"]
+ )
diff --git a/tests/units/components/el/test_svg.py b/tests/units/components/el/test_svg.py
new file mode 100644
index 000000000..29aaa96dd
--- /dev/null
+++ b/tests/units/components/el/test_svg.py
@@ -0,0 +1,74 @@
+from reflex.components.el.elements.media import (
+ Circle,
+ Defs,
+ Ellipse,
+ Line,
+ LinearGradient,
+ Path,
+ Polygon,
+ RadialGradient,
+ Rect,
+ Stop,
+ Svg,
+ Text,
+)
+
+
+def test_circle():
+ circle = Circle.create().render()
+ assert circle["name"] == "circle"
+
+
+def test_defs():
+ defs = Defs.create().render()
+ assert defs["name"] == "defs"
+
+
+def test_ellipse():
+ ellipse = Ellipse.create().render()
+ assert ellipse["name"] == "ellipse"
+
+
+def test_line():
+ line = Line.create().render()
+ assert line["name"] == "line"
+
+
+def test_linear_gradient():
+ linear_gradient = LinearGradient.create().render()
+ assert linear_gradient["name"] == "linearGradient"
+
+
+def test_path():
+ path = Path.create().render()
+ assert path["name"] == "path"
+
+
+def test_polygon():
+ polygon = Polygon.create().render()
+ assert polygon["name"] == "polygon"
+
+
+def test_radial_gradient():
+ radial_gradient = RadialGradient.create().render()
+ assert radial_gradient["name"] == "radialGradient"
+
+
+def test_rect():
+ rect = Rect.create().render()
+ assert rect["name"] == "rect"
+
+
+def test_svg():
+ svg = Svg.create().render()
+ assert svg["name"] == "svg"
+
+
+def test_text():
+ text = Text.create().render()
+ assert text["name"] == "text"
+
+
+def test_stop():
+ stop = Stop.create().render()
+ assert stop["name"] == "stop"
diff --git a/tests/units/components/forms/test_uploads.py b/tests/units/components/forms/test_uploads.py
deleted file mode 100644
index 3b2ee014f..000000000
--- a/tests/units/components/forms/test_uploads.py
+++ /dev/null
@@ -1,205 +0,0 @@
-import pytest
-
-import reflex as rx
-
-
-@pytest.fixture
-def upload_root_component():
- """A test upload component function.
-
- Returns:
- A test upload component function.
- """
-
- def upload_root_component():
- return rx.upload.root(
- rx.button("select file"),
- rx.text("Drag and drop files here or click to select files"),
- border="1px dotted black",
- )
-
- return upload_root_component()
-
-
-@pytest.fixture
-def upload_component():
- """A test upload component function.
-
- Returns:
- A test upload component function.
- """
-
- def upload_component():
- return rx.upload(
- rx.button("select file"),
- rx.text("Drag and drop files here or click to select files"),
- border="1px dotted black",
- )
-
- return upload_component()
-
-
-@pytest.fixture
-def upload_component_id_special():
- def upload_component():
- return rx.upload(
- rx.button("select file"),
- rx.text("Drag and drop files here or click to select files"),
- border="1px dotted black",
- id="#spec!`al-_98ID",
- )
-
- return upload_component()
-
-
-@pytest.fixture
-def upload_component_with_props():
- """A test upload component with props function.
-
- Returns:
- A test upload component with props function.
- """
-
- def upload_component_with_props():
- return rx.upload(
- rx.button("select file"),
- rx.text("Drag and drop files here or click to select files"),
- border="1px dotted black",
- no_drag=True,
- max_files=2,
- )
-
- return upload_component_with_props()
-
-
-def test_upload_root_component_render(upload_root_component):
- """Test that the render function is set correctly.
-
- Args:
- upload_root_component: component fixture
- """
- upload = upload_root_component.render()
-
- # upload
- assert upload["name"] == "ReactDropzone"
- assert upload["props"] == [
- 'id={"default"}',
- "multiple={true}",
- "onDrop={e => setFilesById(filesById => {\n"
- " const updatedFilesById = Object.assign({}, filesById);\n"
- ' updatedFilesById["default"] = e;\n'
- " return updatedFilesById;\n"
- " })\n"
- " }",
- "ref={ref_default}",
- ]
- assert upload["args"] == ("getRootProps", "getInputProps")
-
- # box inside of upload
- [box] = upload["children"]
- assert box["name"] == "RadixThemesBox"
- assert box["props"] == [
- 'className={"rx-Upload"}',
- 'css={({ ["border"] : "1px dotted black" })}',
- "{...getRootProps()}",
- ]
-
- # input, button and text inside of box
- [input, button, text] = box["children"]
- assert input["name"] == "input"
- assert input["props"] == ['type={"file"}', "{...getInputProps()}"]
-
- assert button["name"] == "RadixThemesButton"
- assert button["children"][0]["contents"] == '{"select file"}'
-
- assert text["name"] == "RadixThemesText"
- assert (
- text["children"][0]["contents"]
- == '{"Drag and drop files here or click to select files"}'
- )
-
-
-def test_upload_component_render(upload_component):
- """Test that the render function is set correctly.
-
- Args:
- upload_component: component fixture
- """
- upload = upload_component.render()
-
- # upload
- assert upload["name"] == "ReactDropzone"
- assert upload["props"] == [
- 'id={"default"}',
- "multiple={true}",
- "onDrop={e => setFilesById(filesById => {\n"
- " const updatedFilesById = Object.assign({}, filesById);\n"
- ' updatedFilesById["default"] = e;\n'
- " return updatedFilesById;\n"
- " })\n"
- " }",
- "ref={ref_default}",
- ]
- assert upload["args"] == ("getRootProps", "getInputProps")
-
- # box inside of upload
- [box] = upload["children"]
- assert box["name"] == "RadixThemesBox"
- assert box["props"] == [
- 'className={"rx-Upload"}',
- 'css={({ ["border"] : "1px dotted black", ["padding"] : "5em", ["textAlign"] : "center" })}',
- "{...getRootProps()}",
- ]
-
- # input, button and text inside of box
- [input, button, text] = box["children"]
- assert input["name"] == "input"
- assert input["props"] == ['type={"file"}', "{...getInputProps()}"]
-
- assert button["name"] == "RadixThemesButton"
- assert button["children"][0]["contents"] == '{"select file"}'
-
- assert text["name"] == "RadixThemesText"
- assert (
- text["children"][0]["contents"]
- == '{"Drag and drop files here or click to select files"}'
- )
-
-
-def test_upload_component_with_props_render(upload_component_with_props):
- """Test that the render function is set correctly.
-
- Args:
- upload_component_with_props: component fixture
- """
- upload = upload_component_with_props.render()
-
- assert upload["props"] == [
- 'id={"default"}',
- "maxFiles={2}",
- "multiple={true}",
- "noDrag={true}",
- "onDrop={e => setFilesById(filesById => {\n"
- " const updatedFilesById = Object.assign({}, filesById);\n"
- ' updatedFilesById["default"] = e;\n'
- " return updatedFilesById;\n"
- " })\n"
- " }",
- "ref={ref_default}",
- ]
-
-
-def test_upload_component_id_with_special_chars(upload_component_id_special):
- upload = upload_component_id_special.render()
-
- assert upload["props"] == [
- r'id={"#spec!`al-_98ID"}',
- "multiple={true}",
- "onDrop={e => setFilesById(filesById => {\n"
- " const updatedFilesById = Object.assign({}, filesById);\n"
- ' updatedFilesById["#spec!`al-_98ID"] = e;\n'
- " return updatedFilesById;\n"
- " })\n"
- " }",
- "ref={ref__spec_al__98ID}",
- ]
diff --git a/tests/units/components/markdown/__init__.py b/tests/units/components/markdown/__init__.py
new file mode 100644
index 000000000..e69de29bb
diff --git a/tests/units/components/markdown/test_markdown.py b/tests/units/components/markdown/test_markdown.py
new file mode 100644
index 000000000..866f32ae1
--- /dev/null
+++ b/tests/units/components/markdown/test_markdown.py
@@ -0,0 +1,190 @@
+from typing import Type
+
+import pytest
+
+from reflex.components.component import Component, memo
+from reflex.components.datadisplay.code import CodeBlock
+from reflex.components.datadisplay.shiki_code_block import ShikiHighLevelCodeBlock
+from reflex.components.markdown.markdown import Markdown, MarkdownComponentMap
+from reflex.components.radix.themes.layout.box import Box
+from reflex.components.radix.themes.typography.heading import Heading
+from reflex.vars.base import Var
+
+
+class CustomMarkdownComponent(Component, MarkdownComponentMap):
+ """A custom markdown component."""
+
+ tag = "CustomMarkdownComponent"
+ library = "custom"
+
+ @classmethod
+ def get_fn_args(cls) -> tuple[str, ...]:
+ """Return the function arguments.
+
+ Returns:
+ The function arguments.
+ """
+ return ("custom_node", "custom_children", "custom_props")
+
+ @classmethod
+ def get_fn_body(cls) -> Var:
+ """Return the function body.
+
+ Returns:
+ The function body.
+ """
+ return Var(_js_expr="{return custom_node + custom_children + custom_props}")
+
+
+def syntax_highlighter_memoized_component(codeblock: Type[Component]):
+ @memo
+ def code_block(code: str, language: str):
+ return Box.create(
+ codeblock.create(
+ code,
+ language=language,
+ class_name="code-block",
+ can_copy=True,
+ ),
+ class_name="relative mb-4",
+ )
+
+ def code_block_markdown(*children, **props):
+ return code_block(
+ code=children[0], language=props.pop("language", "plain"), **props
+ )
+
+ return code_block_markdown
+
+
+@pytest.mark.parametrize(
+ "fn_body, fn_args, explicit_return, expected",
+ [
+ (
+ None,
+ None,
+ False,
+ Var(_js_expr="(({node, children, ...props}) => undefined)"),
+ ),
+ ("return node", ("node",), True, Var(_js_expr="(({node}) => {return node})")),
+ (
+ "return node + children",
+ ("node", "children"),
+ True,
+ Var(_js_expr="(({node, children}) => {return node + children})"),
+ ),
+ (
+ "return node + props",
+ ("node", "...props"),
+ True,
+ Var(_js_expr="(({node, ...props}) => {return node + props})"),
+ ),
+ (
+ "return node + children + props",
+ ("node", "children", "...props"),
+ True,
+ Var(
+ _js_expr="(({node, children, ...props}) => {return node + children + props})"
+ ),
+ ),
+ ],
+)
+def test_create_map_fn_var(fn_body, fn_args, explicit_return, expected):
+ result = MarkdownComponentMap.create_map_fn_var(
+ fn_body=Var(_js_expr=fn_body, _var_type=str) if fn_body else None,
+ fn_args=fn_args,
+ explicit_return=explicit_return,
+ )
+ assert result._js_expr == expected._js_expr
+
+
+@pytest.mark.parametrize(
+ ("cls", "fn_body", "fn_args", "explicit_return", "expected"),
+ [
+ (
+ MarkdownComponentMap,
+ None,
+ None,
+ False,
+ Var(_js_expr="(({node, children, ...props}) => undefined)"),
+ ),
+ (
+ MarkdownComponentMap,
+ "return node",
+ ("node",),
+ True,
+ Var(_js_expr="(({node}) => {return node})"),
+ ),
+ (
+ CustomMarkdownComponent,
+ None,
+ None,
+ True,
+ Var(
+ _js_expr="(({custom_node, custom_children, custom_props}) => {return custom_node + custom_children + custom_props})"
+ ),
+ ),
+ (
+ CustomMarkdownComponent,
+ "return custom_node",
+ ("custom_node",),
+ True,
+ Var(_js_expr="(({custom_node}) => {return custom_node})"),
+ ),
+ ],
+)
+def test_create_map_fn_var_subclass(cls, fn_body, fn_args, explicit_return, expected):
+ result = cls.create_map_fn_var(
+ fn_body=Var(_js_expr=fn_body, _var_type=int) if fn_body else None,
+ fn_args=fn_args,
+ explicit_return=explicit_return,
+ )
+ assert result._js_expr == expected._js_expr
+
+
+@pytest.mark.parametrize(
+ "key,component_map, expected",
+ [
+ (
+ "code",
+ {},
+ """(({node, inline, className, children, ...props}) => { const match = (className || '').match(/language-(?.*)/); const _language = match ? match[1] : ''; if (_language) { (async () => { try { const module = await import(`react-syntax-highlighter/dist/cjs/languages/prism/${_language}`); SyntaxHighlighter.registerLanguage(_language, module.default); } catch (error) { console.error(`Error importing language module for ${_language}:`, error); } })(); } ; return inline ? ( {children} ) : ( ); })""",
+ ),
+ (
+ "code",
+ {
+ "codeblock": lambda value, **props: ShikiHighLevelCodeBlock.create(
+ value, **props
+ )
+ },
+ """(({node, inline, className, children, ...props}) => { const match = (className || '').match(/language-(?.*)/); const _language = match ? match[1] : ''; ; return inline ? ( {children} ) : ( ); })""",
+ ),
+ (
+ "h1",
+ {
+ "h1": lambda value: CustomMarkdownComponent.create(
+ Heading.create(value, as_="h1", size="6", margin_y="0.5em")
+ )
+ },
+ """(({custom_node, custom_children, custom_props}) => ({children}))""",
+ ),
+ (
+ "code",
+ {"codeblock": syntax_highlighter_memoized_component(CodeBlock)},
+ """(({node, inline, className, children, ...props}) => { const match = (className || '').match(/language-(?.*)/); const _language = match ? match[1] : ''; if (_language) { (async () => { try { const module = await import(`react-syntax-highlighter/dist/cjs/languages/prism/${_language}`); SyntaxHighlighter.registerLanguage(_language, module.default); } catch (error) { console.error(`Error importing language module for ${_language}:`, error); } })(); } ; return inline ? ( {children} ) : ( ); })""",
+ ),
+ (
+ "code",
+ {
+ "codeblock": syntax_highlighter_memoized_component(
+ ShikiHighLevelCodeBlock
+ )
+ },
+ """(({node, inline, className, children, ...props}) => { const match = (className || '').match(/language-(?.*)/); const _language = match ? match[1] : ''; ; return inline ? ( {children} ) : ( ); })""",
+ ),
+ ],
+)
+def test_markdown_format_component(key, component_map, expected):
+ markdown = Markdown.create("# header", component_map=component_map)
+ result = markdown.format_component_map()
+ assert str(result[key]) == expected
diff --git a/tests/units/components/media/test_image.py b/tests/units/components/media/test_image.py
index f8618347c..742bd8c38 100644
--- a/tests/units/components/media/test_image.py
+++ b/tests/units/components/media/test_image.py
@@ -42,7 +42,7 @@ def test_set_src_str():
"`pic2.jpeg`",
)
# For plain rx.el.img, an explicit var is not created, so the quoting happens later
- # assert str(image.src) == "pic2.jpeg" # type: ignore
+ # assert str(image.src) == "pic2.jpeg" # type: ignore #noqa: ERA001
def test_set_src_img(pil_image: Img):
diff --git a/tests/units/components/test_component.py b/tests/units/components/test_component.py
index 73d3f611b..674873b69 100644
--- a/tests/units/components/test_component.py
+++ b/tests/units/components/test_component.py
@@ -16,11 +16,18 @@ from reflex.components.component import (
)
from reflex.components.radix.themes.layout.box import Box
from reflex.constants import EventTriggers
-from reflex.event import EventChain, EventHandler, parse_args_spec
+from reflex.event import (
+ EventChain,
+ EventHandler,
+ input_event,
+ no_args_event_spec,
+ parse_args_spec,
+ passthrough_event_spec,
+)
from reflex.state import BaseState
from reflex.style import Style
from reflex.utils import imports
-from reflex.utils.exceptions import EventFnArgMismatch, EventHandlerArgMismatch
+from reflex.utils.exceptions import EventFnArgMismatch
from reflex.utils.imports import ImportDict, ImportVar, ParsedImportDict, parse_imports
from reflex.vars import VarData
from reflex.vars.base import LiteralVar, Var
@@ -37,6 +44,18 @@ def test_state():
def do_something_arg(self, arg):
pass
+ def do_something_with_bool(self, arg: bool):
+ pass
+
+ def do_something_with_int(self, arg: int):
+ pass
+
+ def do_something_with_list_int(self, arg: list[int]):
+ pass
+
+ def do_something_with_list_str(self, arg: list[str]):
+ pass
+
return TestState
@@ -89,8 +108,10 @@ def component2() -> Type[Component]:
"""
return {
**super().get_event_triggers(),
- "on_open": lambda e0: [e0],
- "on_close": lambda e0: [e0],
+ "on_open": passthrough_event_spec(bool),
+ "on_close": passthrough_event_spec(bool),
+ "on_user_visited_count_changed": passthrough_event_spec(int),
+ "on_user_list_changed": passthrough_event_spec(List[str]),
}
def _get_imports(self) -> ParsedImportDict:
@@ -576,7 +597,14 @@ def test_get_event_triggers(component1, component2):
assert component1().get_event_triggers().keys() == default_triggers
assert (
component2().get_event_triggers().keys()
- == {"on_open", "on_close", "on_prop_event"} | default_triggers
+ == {
+ "on_open",
+ "on_close",
+ "on_prop_event",
+ "on_user_visited_count_changed",
+ "on_user_list_changed",
+ }
+ | default_triggers
)
@@ -636,21 +664,18 @@ def test_component_create_unallowed_types(children, test_component):
"name": "Fragment",
"props": [],
"contents": "",
- "args": None,
"special_props": [],
"children": [
{
"name": "RadixThemesText",
"props": ['as={"p"}'],
"contents": "",
- "args": None,
"special_props": [],
"children": [
{
"name": "",
"props": [],
"contents": '{"first_text"}',
- "args": None,
"special_props": [],
"children": [],
"autofocus": False,
@@ -665,15 +690,12 @@ def test_component_create_unallowed_types(children, test_component):
(
(rx.text("first_text"), rx.text("second_text")),
{
- "args": None,
"autofocus": False,
"children": [
{
- "args": None,
"autofocus": False,
"children": [
{
- "args": None,
"autofocus": False,
"children": [],
"contents": '{"first_text"}',
@@ -688,11 +710,9 @@ def test_component_create_unallowed_types(children, test_component):
"special_props": [],
},
{
- "args": None,
"autofocus": False,
"children": [
{
- "args": None,
"autofocus": False,
"children": [],
"contents": '{"second_text"}',
@@ -716,15 +736,12 @@ def test_component_create_unallowed_types(children, test_component):
(
(rx.text("first_text"), rx.box((rx.text("second_text"),))),
{
- "args": None,
"autofocus": False,
"children": [
{
- "args": None,
"autofocus": False,
"children": [
{
- "args": None,
"autofocus": False,
"children": [],
"contents": '{"first_text"}',
@@ -739,19 +756,15 @@ def test_component_create_unallowed_types(children, test_component):
"special_props": [],
},
{
- "args": None,
"autofocus": False,
"children": [
{
- "args": None,
"autofocus": False,
"children": [
{
- "args": None,
"autofocus": False,
"children": [
{
- "args": None,
"autofocus": False,
"children": [],
"contents": '{"second_text"}',
@@ -797,7 +810,8 @@ def test_component_create_unpack_tuple_child(test_component, element, expected):
comp = test_component.create(element)
assert len(comp.children) == 1
- assert isinstance((fragment_wrapper := comp.children[0]), Fragment)
+ fragment_wrapper = comp.children[0]
+ assert isinstance(fragment_wrapper, Fragment)
assert fragment_wrapper.render() == expected
@@ -831,9 +845,9 @@ def test_component_event_trigger_arbitrary_args():
comp = C1.create(on_foo=C1State.mock_handler)
assert comp.render()["props"][0] == (
- "onFoo={((__e, _alpha, _bravo, _charlie) => ((addEvents("
- f'[(Event("{C1State.get_full_name()}.mock_handler", ({{ ["_e"] : __e["target"]["value"], ["_bravo"] : _bravo["nested"], ["_charlie"] : (_charlie["custom"] + 42) }})))], '
- "[__e, _alpha, _bravo, _charlie], ({ })))))}"
+ "onFoo={((__e, _alpha, _bravo, _charlie) => (addEvents("
+ f'[(Event("{C1State.get_full_name()}.mock_handler", ({{ ["_e"] : __e["target"]["value"], ["_bravo"] : _bravo["nested"], ["_charlie"] : (_charlie["custom"] + 42) }}), ({{ }})))], '
+ "[__e, _alpha, _bravo, _charlie], ({ }))))}"
)
@@ -891,26 +905,30 @@ def test_invalid_event_handler_args(component2, test_state):
test_state: A test state.
"""
# EventHandler args must match
- with pytest.raises(EventHandlerArgMismatch):
+ with pytest.raises(EventFnArgMismatch):
component2.create(on_click=test_state.do_something_arg)
- with pytest.raises(EventHandlerArgMismatch):
- component2.create(on_open=test_state.do_something)
- with pytest.raises(EventHandlerArgMismatch):
- component2.create(on_prop_event=test_state.do_something)
# Multiple EventHandler args: all must match
- with pytest.raises(EventHandlerArgMismatch):
+ with pytest.raises(EventFnArgMismatch):
component2.create(
on_click=[test_state.do_something_arg, test_state.do_something]
)
- with pytest.raises(EventHandlerArgMismatch):
- component2.create(
- on_open=[test_state.do_something_arg, test_state.do_something]
- )
- with pytest.raises(EventHandlerArgMismatch):
- component2.create(
- on_prop_event=[test_state.do_something_arg, test_state.do_something]
- )
+
+ # Enable when 0.7.0 happens
+ # # Event Handler types must match
+ # with pytest.raises(EventHandlerArgTypeMismatch):
+ # component2.create(
+ # on_user_visited_count_changed=test_state.do_something_with_bool # noqa: ERA001 RUF100
+ # ) # noqa: ERA001 RUF100
+ # with pytest.raises(EventHandlerArgTypeMismatch):
+ # component2.create(on_user_list_changed=test_state.do_something_with_int) #noqa: ERA001
+ # with pytest.raises(EventHandlerArgTypeMismatch):
+ # component2.create(on_user_list_changed=test_state.do_something_with_list_int) #noqa: ERA001
+
+ # component2.create(on_open=test_state.do_something_with_int) #noqa: ERA001
+ # component2.create(on_open=test_state.do_something_with_bool) #noqa: ERA001
+ # component2.create(on_user_visited_count_changed=test_state.do_something_with_int) #noqa: ERA001
+ # component2.create(on_user_list_changed=test_state.do_something_with_list_str) #noqa: ERA001
# lambda cannot return weird values.
with pytest.raises(ValueError):
@@ -925,38 +943,19 @@ def test_invalid_event_handler_args(component2, test_state):
# lambda signature must match event trigger.
with pytest.raises(EventFnArgMismatch):
component2.create(on_click=lambda _: test_state.do_something_arg(1))
- with pytest.raises(EventFnArgMismatch):
- component2.create(on_open=lambda: test_state.do_something)
- with pytest.raises(EventFnArgMismatch):
- component2.create(on_prop_event=lambda: test_state.do_something)
# lambda returning EventHandler must match spec
- with pytest.raises(EventHandlerArgMismatch):
+ with pytest.raises(EventFnArgMismatch):
component2.create(on_click=lambda: test_state.do_something_arg)
- with pytest.raises(EventHandlerArgMismatch):
- component2.create(on_open=lambda _: test_state.do_something)
- with pytest.raises(EventHandlerArgMismatch):
- component2.create(on_prop_event=lambda _: test_state.do_something)
# Mixed EventSpec and EventHandler must match spec.
- with pytest.raises(EventHandlerArgMismatch):
+ with pytest.raises(EventFnArgMismatch):
component2.create(
on_click=lambda: [
test_state.do_something_arg(1),
test_state.do_something_arg,
]
)
- with pytest.raises(EventHandlerArgMismatch):
- component2.create(
- on_open=lambda _: [test_state.do_something_arg(1), test_state.do_something]
- )
- with pytest.raises(EventHandlerArgMismatch):
- component2.create(
- on_prop_event=lambda _: [
- test_state.do_something_arg(1),
- test_state.do_something,
- ]
- )
def test_valid_event_handler_args(component2, test_state):
@@ -970,6 +969,10 @@ def test_valid_event_handler_args(component2, test_state):
component2.create(on_click=test_state.do_something)
component2.create(on_click=test_state.do_something_arg(1))
+ # Does not raise because event handlers are allowed to have less args than the spec.
+ component2.create(on_open=test_state.do_something)
+ component2.create(on_prop_event=test_state.do_something)
+
# Controlled event handlers should take args.
component2.create(on_open=test_state.do_something_arg)
component2.create(on_prop_event=test_state.do_something_arg)
@@ -978,10 +981,20 @@ def test_valid_event_handler_args(component2, test_state):
component2.create(on_open=test_state.do_something())
component2.create(on_prop_event=test_state.do_something())
+ # Multiple EventHandler args: all must match
+ component2.create(on_open=[test_state.do_something_arg, test_state.do_something])
+ component2.create(
+ on_prop_event=[test_state.do_something_arg, test_state.do_something]
+ )
+
# lambda returning EventHandler is okay if the spec matches.
component2.create(on_click=lambda: test_state.do_something)
component2.create(on_open=lambda _: test_state.do_something_arg)
component2.create(on_prop_event=lambda _: test_state.do_something_arg)
+ component2.create(on_open=lambda: test_state.do_something)
+ component2.create(on_prop_event=lambda: test_state.do_something)
+ component2.create(on_open=lambda _: test_state.do_something)
+ component2.create(on_prop_event=lambda _: test_state.do_something)
# lambda can always return an EventSpec.
component2.create(on_click=lambda: test_state.do_something_arg(1))
@@ -1014,6 +1027,15 @@ def test_valid_event_handler_args(component2, test_state):
component2.create(
on_prop_event=lambda _: [test_state.do_something_arg, test_state.do_something()]
)
+ component2.create(
+ on_open=lambda _: [test_state.do_something_arg(1), test_state.do_something]
+ )
+ component2.create(
+ on_prop_event=lambda _: [
+ test_state.do_something_arg(1),
+ test_state.do_something,
+ ]
+ )
def test_get_hooks_nested(component1, component2, component3):
@@ -1111,10 +1133,10 @@ def test_component_with_only_valid_children(fixture, request):
@pytest.mark.parametrize(
"component,rendered",
[
- (rx.text("hi"), '\n {"hi"}\n'),
+ (rx.text("hi"), '\n\n{"hi"}\n'),
(
rx.box(rx.heading("test", size="3")),
- '\n \n {"test"}\n\n',
+ '\n\n\n\n{"test"}\n\n',
),
],
)
@@ -1178,14 +1200,14 @@ TEST_VAR = LiteralVar.create("test")._replace(
)
FORMATTED_TEST_VAR = LiteralVar.create(f"foo{TEST_VAR}bar")
STYLE_VAR = TEST_VAR._replace(_js_expr="style")
-EVENT_CHAIN_VAR = TEST_VAR._replace(_var_type=EventChain)
+EVENT_CHAIN_VAR = TEST_VAR.to(EventChain)
ARG_VAR = Var(_js_expr="arg")
TEST_VAR_DICT_OF_DICT = LiteralVar.create({"a": {"b": "test"}})._replace(
merge_var_data=TEST_VAR._var_data
)
FORMATTED_TEST_VAR_DICT_OF_DICT = LiteralVar.create(
- {"a": {"b": f"footestbar"}}
+ {"a": {"b": "footestbar"}}
)._replace(merge_var_data=TEST_VAR._var_data)
TEST_VAR_LIST_OF_LIST = LiteralVar.create([["test"]])._replace(
@@ -1224,6 +1246,7 @@ class EventState(rx.State):
v: int = 42
+ @rx.event
def handler(self):
"""A handler that does nothing."""
@@ -1414,8 +1437,6 @@ def test_get_vars(component, exp_vars):
comp_vars,
sorted(exp_vars, key=lambda v: v._js_expr),
):
- # print(str(comp_var), str(exp_var))
- # print(comp_var._get_all_var_data(), exp_var._get_all_var_data())
assert comp_var.equals(exp_var)
@@ -1778,7 +1799,7 @@ def test_custom_component_declare_event_handlers_in_fields():
return {
**super().get_event_triggers(),
"on_a": lambda e0: [e0],
- "on_b": lambda e0: [e0.target.value],
+ "on_b": input_event,
"on_c": lambda e0: [],
"on_d": lambda: [],
"on_e": lambda: [],
@@ -1787,9 +1808,9 @@ def test_custom_component_declare_event_handlers_in_fields():
class TestComponent(Component):
on_a: EventHandler[lambda e0: [e0]]
- on_b: EventHandler[lambda e0: [e0.target.value]]
- on_c: EventHandler[lambda e0: []]
- on_d: EventHandler[lambda: []]
+ on_b: EventHandler[input_event]
+ on_c: EventHandler[no_args_event_spec]
+ on_d: EventHandler[no_args_event_spec]
on_e: EventHandler
on_f: EventHandler[lambda a, b, c: [c, b, a]]
@@ -2141,6 +2162,7 @@ def test_add_style_foreach():
class TriggerState(rx.State):
"""Test state with event handlers."""
+ @rx.event
def do_something(self):
"""Sample event handler."""
pass
@@ -2159,7 +2181,7 @@ class TriggerState(rx.State):
rx.text("random text", on_click=TriggerState.do_something),
rx.text(
"random text",
- on_click=Var(_js_expr="toggleColorMode", _var_type=EventChain),
+ on_click=Var(_js_expr="toggleColorMode").to(EventChain),
),
),
True,
@@ -2169,7 +2191,7 @@ class TriggerState(rx.State):
rx.text("random text", on_click=rx.console_log("log")),
rx.text(
"random text",
- on_click=Var(_js_expr="toggleColorMode", _var_type=EventChain),
+ on_click=Var(_js_expr="toggleColorMode").to(EventChain),
),
),
False,
@@ -2209,3 +2231,56 @@ class TriggerState(rx.State):
)
def test_has_state_event_triggers(component, output):
assert component._has_stateful_event_triggers() == output
+
+
+class SpecialComponent(Box):
+ """A special component with custom attributes."""
+
+ data_prop: Var[str]
+ aria_prop: Var[str]
+
+
+@pytest.mark.parametrize(
+ ("component_kwargs", "exp_custom_attrs", "exp_style"),
+ [
+ (
+ {"data_test": "test", "aria_test": "test"},
+ {"data-test": "test", "aria-test": "test"},
+ {},
+ ),
+ (
+ {"data-test": "test", "aria-test": "test"},
+ {"data-test": "test", "aria-test": "test"},
+ {},
+ ),
+ (
+ {"custom_attrs": {"data-existing": "test"}, "data_new": "test"},
+ {"data-existing": "test", "data-new": "test"},
+ {},
+ ),
+ (
+ {"data_test": "test", "data_prop": "prop"},
+ {"data-test": "test"},
+ {},
+ ),
+ (
+ {"aria_test": "test", "aria_prop": "prop"},
+ {"aria-test": "test"},
+ {},
+ ),
+ ],
+)
+def test_special_props(component_kwargs, exp_custom_attrs, exp_style):
+ """Test that data_ and aria_ special props are correctly added to the component.
+
+ Args:
+ component_kwargs: The component kwargs.
+ exp_custom_attrs: The expected custom attributes.
+ exp_style: The expected style.
+ """
+ component = SpecialComponent.create(**component_kwargs)
+ assert component.custom_attrs == exp_custom_attrs
+ assert component.style == exp_style
+ for prop in SpecialComponent.get_props():
+ if prop in component_kwargs:
+ assert getattr(component, prop)._var_value == component_kwargs[prop]
diff --git a/tests/units/components/test_component_future_annotations.py b/tests/units/components/test_component_future_annotations.py
index 37aeb813a..0867a2d37 100644
--- a/tests/units/components/test_component_future_annotations.py
+++ b/tests/units/components/test_component_future_annotations.py
@@ -3,7 +3,7 @@ from __future__ import annotations
from typing import Any
from reflex.components.component import Component
-from reflex.event import EventHandler
+from reflex.event import EventHandler, input_event, no_args_event_spec
# This is a repeat of its namesake in test_component.py.
@@ -25,9 +25,9 @@ def test_custom_component_declare_event_handlers_in_fields():
class TestComponent(Component):
on_a: EventHandler[lambda e0: [e0]]
- on_b: EventHandler[lambda e0: [e0.target.value]]
- on_c: EventHandler[lambda e0: []]
- on_d: EventHandler[lambda: []]
+ on_b: EventHandler[input_event]
+ on_c: EventHandler[no_args_event_spec]
+ on_d: EventHandler[no_args_event_spec]
custom_component = ReferenceComponent.create()
test_component = TestComponent.create()
diff --git a/tests/units/components/test_component_state.py b/tests/units/components/test_component_state.py
index 574997ba5..1b62e35c8 100644
--- a/tests/units/components/test_component_state.py
+++ b/tests/units/components/test_component_state.py
@@ -1,7 +1,10 @@
"""Ensure that Components returned by ComponentState.create have independent State classes."""
+import pytest
+
import reflex as rx
from reflex.components.base.bare import Bare
+from reflex.utils.exceptions import ReflexRuntimeError
def test_component_state():
@@ -40,3 +43,21 @@ def test_component_state():
assert len(cs2.children) == 1
assert cs2.children[0].render() == Bare.create("b").render()
assert cs2.id == "b"
+
+
+def test_init_component_state() -> None:
+ """Ensure that ComponentState subclasses cannot be instantiated directly."""
+
+ class CS(rx.ComponentState):
+ @classmethod
+ def get_component(cls, *children, **props):
+ return rx.el.div()
+
+ with pytest.raises(ReflexRuntimeError):
+ CS()
+
+ class SubCS(CS):
+ pass
+
+ with pytest.raises(ReflexRuntimeError):
+ SubCS()
diff --git a/tests/units/components/test_props.py b/tests/units/components/test_props.py
new file mode 100644
index 000000000..8ab07f135
--- /dev/null
+++ b/tests/units/components/test_props.py
@@ -0,0 +1,63 @@
+import pytest
+
+from reflex.components.props import NoExtrasAllowedProps
+from reflex.utils.exceptions import InvalidPropValueError
+
+try:
+ from pydantic.v1 import ValidationError
+except ModuleNotFoundError:
+ from pydantic import ValidationError
+
+
+class PropA(NoExtrasAllowedProps):
+ """Base prop class."""
+
+ foo: str
+ bar: str
+
+
+class PropB(NoExtrasAllowedProps):
+ """Prop class with nested props."""
+
+ foobar: str
+ foobaz: PropA
+
+
+@pytest.mark.parametrize(
+ "props_class, kwargs, should_raise",
+ [
+ (PropA, {"foo": "value", "bar": "another_value"}, False),
+ (PropA, {"fooz": "value", "bar": "another_value"}, True),
+ (
+ PropB,
+ {
+ "foobaz": {"foo": "value", "bar": "another_value"},
+ "foobar": "foo_bar_value",
+ },
+ False,
+ ),
+ (
+ PropB,
+ {
+ "fooba": {"foo": "value", "bar": "another_value"},
+ "foobar": "foo_bar_value",
+ },
+ True,
+ ),
+ (
+ PropB,
+ {
+ "foobaz": {"foobar": "value", "bar": "another_value"},
+ "foobar": "foo_bar_value",
+ },
+ True,
+ ),
+ ],
+)
+def test_no_extras_allowed_props(props_class, kwargs, should_raise):
+ if should_raise:
+ with pytest.raises((ValidationError, InvalidPropValueError)):
+ props_class(**kwargs)
+ else:
+ props_instance = props_class(**kwargs)
+ assert isinstance(props_instance, props_class)
diff --git a/tests/units/components/test_tag.py b/tests/units/components/test_tag.py
index c41246e3f..a69e40b8b 100644
--- a/tests/units/components/test_tag.py
+++ b/tests/units/components/test_tag.py
@@ -119,7 +119,7 @@ def test_format_cond_tag():
tag_dict["false_value"],
)
assert cond._js_expr == "logged_in"
- assert cond._var_type == bool
+ assert cond._var_type is bool
assert true_value["name"] == "h1"
assert true_value["contents"] == "True content"
diff --git a/tests/units/conftest.py b/tests/units/conftest.py
index 589d35cd7..fb6229aca 100644
--- a/tests/units/conftest.py
+++ b/tests/units/conftest.py
@@ -1,5 +1,6 @@
"""Test fixtures."""
+import asyncio
import contextlib
import os
import platform
@@ -24,6 +25,11 @@ from .states import (
)
+def pytest_configure(config):
+ if config.getoption("asyncio_mode") == "auto":
+ asyncio.set_event_loop_policy(asyncio.DefaultEventLoopPolicy())
+
+
@pytest.fixture
def app() -> App:
"""A base app.
@@ -200,7 +206,7 @@ class chdir(contextlib.AbstractContextManager):
def __enter__(self):
"""Save current directory and perform chdir."""
- self._old_cwd.append(Path(".").resolve())
+ self._old_cwd.append(Path.cwd())
os.chdir(self.path)
def __exit__(self, *excinfo):
diff --git a/tests/units/experimental/test_assets.py b/tests/units/experimental/test_assets.py
deleted file mode 100644
index 8037bcc75..000000000
--- a/tests/units/experimental/test_assets.py
+++ /dev/null
@@ -1,36 +0,0 @@
-import shutil
-from pathlib import Path
-
-import pytest
-
-import reflex as rx
-
-
-def test_asset():
- # Test the asset function.
-
- # The asset function copies a file to the app's external assets directory.
- asset = rx._x.asset("custom_script.js", "subfolder")
- assert asset == "/external/test_assets/subfolder/custom_script.js"
- result_file = Path(
- Path.cwd(), "assets/external/test_assets/subfolder/custom_script.js"
- )
- assert result_file.exists()
-
- # Running a second time should not raise an error.
- asset = rx._x.asset("custom_script.js", "subfolder")
-
- # Test the asset function without a subfolder.
- asset = rx._x.asset("custom_script.js")
- assert asset == "/external/test_assets/custom_script.js"
- result_file = Path(Path.cwd(), "assets/external/test_assets/custom_script.js")
- assert result_file.exists()
-
- # clean up
- shutil.rmtree(Path.cwd() / "assets/external")
-
- with pytest.raises(FileNotFoundError):
- asset = rx._x.asset("non_existent_file.js")
-
- # Nothing is done to assets when file does not exist.
- assert not Path(Path.cwd() / "assets/external").exists()
diff --git a/tests/units/states/upload.py b/tests/units/states/upload.py
index f81e9f235..338025bcd 100644
--- a/tests/units/states/upload.py
+++ b/tests/units/states/upload.py
@@ -71,7 +71,7 @@ class FileUploadState(State):
assert file.filename is not None
self.img_list.append(file.filename)
- @rx.background
+ @rx.event(background=True)
async def bg_upload(self, files: List[rx.UploadFile]):
"""Background task cannot be upload handler.
@@ -119,7 +119,7 @@ class ChildFileUploadState(FileStateBase1):
assert file.filename is not None
self.img_list.append(file.filename)
- @rx.background
+ @rx.event(background=True)
async def bg_upload(self, files: List[rx.UploadFile]):
"""Background task cannot be upload handler.
@@ -167,7 +167,7 @@ class GrandChildFileUploadState(FileStateBase2):
assert file.filename is not None
self.img_list.append(file.filename)
- @rx.background
+ @rx.event(background=True)
async def bg_upload(self, files: List[rx.UploadFile]):
"""Background task cannot be upload handler.
diff --git a/tests/units/test_app.py b/tests/units/test_app.py
index 0c22c38e3..48a4bdda1 100644
--- a/tests/units/test_app.py
+++ b/tests/units/test_app.py
@@ -237,9 +237,12 @@ def test_add_page_default_route(app: App, index_page, about_page):
about_page: The about page.
"""
assert app.pages == {}
+ assert app.unevaluated_pages == {}
app.add_page(index_page)
+ app._compile_page("index")
assert app.pages.keys() == {"index"}
app.add_page(about_page)
+ app._compile_page("about")
assert app.pages.keys() == {"index", "about"}
@@ -252,8 +255,9 @@ def test_add_page_set_route(app: App, index_page, windows_platform: bool):
windows_platform: Whether the system is windows.
"""
route = "test" if windows_platform else "/test"
- assert app.pages == {}
+ assert app.unevaluated_pages == {}
app.add_page(index_page, route=route)
+ app._compile_page("test")
assert app.pages.keys() == {"test"}
@@ -267,8 +271,9 @@ def test_add_page_set_route_dynamic(index_page, windows_platform: bool):
app = App(state=EmptyState)
assert app.state is not None
route = "/test/[dynamic]"
- assert app.pages == {}
+ assert app.unevaluated_pages == {}
app.add_page(index_page, route=route)
+ app._compile_page("test/[dynamic]")
assert app.pages.keys() == {"test/[dynamic]"}
assert "dynamic" in app.state.computed_vars
assert app.state.computed_vars["dynamic"]._deps(objclass=EmptyState) == {
@@ -286,9 +291,9 @@ def test_add_page_set_route_nested(app: App, index_page, windows_platform: bool)
windows_platform: Whether the system is windows.
"""
route = "test\\nested" if windows_platform else "/test/nested"
- assert app.pages == {}
+ assert app.unevaluated_pages == {}
app.add_page(index_page, route=route)
- assert app.pages.keys() == {route.strip(os.path.sep)}
+ assert app.unevaluated_pages.keys() == {route.strip(os.path.sep)}
def test_add_page_invalid_api_route(app: App, index_page):
@@ -765,7 +770,8 @@ async def test_upload_file(tmp_path, state, delta, token: str, mocker):
)
state._tmp_path = tmp_path
# The App state must be the "root" of the state tree
- app = App(state=State)
+ app = App()
+ app._enable_state()
app.event_namespace.emit = AsyncMock() # type: ignore
current_state = await app.state_manager.get_state(_substate_key(token, state))
data = b"This is binary data"
@@ -781,11 +787,11 @@ async def test_upload_file(tmp_path, state, delta, token: str, mocker):
}
file1 = UploadFile(
- filename=f"image1.jpg",
+ filename="image1.jpg",
file=bio,
)
file2 = UploadFile(
- filename=f"image2.jpg",
+ filename="image2.jpg",
file=bio,
)
upload_fn = upload(app)
@@ -868,7 +874,7 @@ async def test_upload_file_background(state, tmp_path, token):
await fn(request_mock, [file_mock])
assert (
err.value.args[0]
- == f"@rx.background is not supported for upload handler `{state.get_full_name()}.bg_upload`."
+ == f"@rx.event(background=True) is not supported for upload handler `{state.get_full_name()}.bg_upload`."
)
if isinstance(app.state_manager, StateManagerRedis):
@@ -893,12 +899,11 @@ class DynamicState(BaseState):
loaded: int = 0
counter: int = 0
- # side_effect_counter: int = 0
-
def on_load(self):
"""Event handler for page on_load, should trigger for all navigation events."""
self.loaded = self.loaded + 1
+ @rx.event
def on_counter(self):
"""Increment the counter var."""
self.counter = self.counter + 1
@@ -910,7 +915,6 @@ class DynamicState(BaseState):
Returns:
same as self.dynamic
"""
- # self.side_effect_counter = self.side_effect_counter + 1
return self.dynamic
on_load_internal = OnLoadInternalState.on_load_internal.fn
@@ -1000,8 +1004,9 @@ async def test_dynamic_route_var_route_change_completed_on_load(
substate_token = _substate_key(token, DynamicState)
sid = "mock_sid"
client_ip = "127.0.0.1"
- state = await app.state_manager.get_state(substate_token)
- assert state.dynamic == ""
+ async with app.state_manager.modify_state(substate_token) as state:
+ state.router_data = {"simulate": "hydrated"}
+ assert state.dynamic == ""
exp_vals = ["foo", "foobar", "baz"]
def _event(name, val, **kwargs):
@@ -1051,7 +1056,6 @@ async def test_dynamic_route_var_route_change_completed_on_load(
arg_name: exp_val,
f"comp_{arg_name}": exp_val,
constants.CompileVars.IS_HYDRATED: False,
- # "side_effect_counter": exp_index,
"router": exp_router,
}
},
@@ -1147,8 +1151,6 @@ async def test_dynamic_route_var_route_change_completed_on_load(
state = await app.state_manager.get_state(substate_token)
assert state.loaded == len(exp_vals)
assert state.counter == len(exp_vals)
- # print(f"Expected {exp_vals} rendering side effects, got {state.side_effect_counter}")
- # assert state.side_effect_counter == len(exp_vals)
if isinstance(app.state_manager, StateManagerRedis):
await app.state_manager.close()
@@ -1173,6 +1175,7 @@ async def test_process_events(mocker, token: str):
"ip": "127.0.0.1",
}
app = App(state=GenState)
+
mocker.patch.object(app, "_postprocess", AsyncMock())
event = Event(
token=token,
@@ -1180,6 +1183,8 @@ async def test_process_events(mocker, token: str):
payload={"c": 5},
router_data=router_data,
)
+ async with app.state_manager.modify_state(event.substate_token) as state:
+ state.router_data = {"simulate": "hydrated"}
async for _update in process(app, event, "mock_sid", {}, "127.0.0.1"):
pass
@@ -1204,7 +1209,7 @@ async def test_process_events(mocker, token: str):
],
)
def test_overlay_component(
- state: State | None,
+ state: Type[State] | None,
overlay_component: Component | ComponentCallable | None,
exp_page_child: Type[Component] | None,
):
@@ -1236,6 +1241,7 @@ def test_overlay_component(
app.add_page(rx.box("Index"), route="/test")
# overlay components are wrapped during compile only
+ app._compile_page("test")
app._setup_overlay_component()
page = app.pages["test"]
@@ -1363,6 +1369,7 @@ def test_app_state_determination():
# Add a page with `on_load` enables state.
a1.add_page(rx.box("About"), route="/about", on_load=rx.console_log(""))
+ a1._compile_page("about")
assert a1.state is not None
a2 = App()
@@ -1370,6 +1377,7 @@ def test_app_state_determination():
# Referencing a state Var enables state.
a2.add_page(rx.box(rx.text(GenState.value)), route="/")
+ a2._compile_page("index")
assert a2.state is not None
a3 = App()
@@ -1377,6 +1385,7 @@ def test_app_state_determination():
# Referencing router enables state.
a3.add_page(rx.box(rx.text(State.router.page.full_path)), route="/")
+ a3._compile_page("index")
assert a3.state is not None
a4 = App()
@@ -1388,16 +1397,10 @@ def test_app_state_determination():
a4.add_page(
rx.box(rx.button("Click", on_click=DynamicState.on_counter)), route="/page2"
)
+ a4._compile_page("page2")
assert a4.state is not None
-# for coverage
-def test_raise_on_connect_error():
- """Test that the connect_error function is called."""
- with pytest.raises(ValueError):
- App(connect_error_component="Foo")
-
-
def test_raise_on_state():
"""Test that the state is set."""
# state kwargs is deprecated, we just make sure the app is created anyway.
@@ -1467,17 +1470,23 @@ def test_add_page_component_returning_tuple():
app.add_page(index) # type: ignore
app.add_page(page2) # type: ignore
- assert isinstance((fragment_wrapper := app.pages["index"].children[0]), Fragment)
- assert isinstance((first_text := fragment_wrapper.children[0]), Text)
+ app._compile_page("index")
+ app._compile_page("page2")
+
+ fragment_wrapper = app.pages["index"].children[0]
+ assert isinstance(fragment_wrapper, Fragment)
+ first_text = fragment_wrapper.children[0]
+ assert isinstance(first_text, Text)
assert str(first_text.children[0].contents) == '"first"' # type: ignore
- assert isinstance((second_text := fragment_wrapper.children[1]), Text)
+ second_text = fragment_wrapper.children[1]
+ assert isinstance(second_text, Text)
assert str(second_text.children[0].contents) == '"second"' # type: ignore
# Test page with trailing comma.
- assert isinstance(
- (page2_fragment_wrapper := app.pages["page2"].children[0]), Fragment
- )
- assert isinstance((third_text := page2_fragment_wrapper.children[0]), Text)
+ page2_fragment_wrapper = app.pages["page2"].children[0]
+ assert isinstance(page2_fragment_wrapper, Fragment)
+ third_text = page2_fragment_wrapper.children[0]
+ assert isinstance(third_text, Text)
assert str(third_text.children[0].contents) == '"third"' # type: ignore
diff --git a/tests/units/test_attribute_access_type.py b/tests/units/test_attribute_access_type.py
index 0d490ec1e..d08c17c8c 100644
--- a/tests/units/test_attribute_access_type.py
+++ b/tests/units/test_attribute_access_type.py
@@ -3,11 +3,19 @@ from __future__ import annotations
from typing import Dict, List, Optional, Type, Union
import attrs
+import pydantic.v1
import pytest
import sqlalchemy
+import sqlmodel
from sqlalchemy import JSON, TypeDecorator
from sqlalchemy.ext.hybrid import hybrid_property
-from sqlalchemy.orm import DeclarativeBase, Mapped, mapped_column, relationship
+from sqlalchemy.orm import (
+ DeclarativeBase,
+ Mapped,
+ MappedAsDataclass,
+ mapped_column,
+ relationship,
+)
import reflex as rx
from reflex.utils.types import GenericType, get_attribute_access_type
@@ -53,6 +61,10 @@ class SQLALabel(SQLABase):
id: Mapped[int] = mapped_column(primary_key=True)
test_id: Mapped[int] = mapped_column(sqlalchemy.ForeignKey("test.id"))
test: Mapped[SQLAClass] = relationship(back_populates="labels")
+ test_dataclass_id: Mapped[int] = mapped_column(
+ sqlalchemy.ForeignKey("test_dataclass.id")
+ )
+ test_dataclass: Mapped[SQLAClassDataclass] = relationship(back_populates="labels")
class SQLAClass(SQLABase):
@@ -104,9 +116,64 @@ class SQLAClass(SQLABase):
return self.labels[0] if self.labels else None
+class SQLAClassDataclass(MappedAsDataclass, SQLABase):
+ """Test sqlalchemy model."""
+
+ id: Mapped[int] = mapped_column(primary_key=True)
+ no_default: Mapped[int] = mapped_column(nullable=True)
+ count: Mapped[int] = mapped_column()
+ name: Mapped[str] = mapped_column()
+ int_list: Mapped[List[int]] = mapped_column(
+ sqlalchemy.types.ARRAY(item_type=sqlalchemy.INTEGER)
+ )
+ str_list: Mapped[List[str]] = mapped_column(
+ sqlalchemy.types.ARRAY(item_type=sqlalchemy.String)
+ )
+ optional_int: Mapped[Optional[int]] = mapped_column(nullable=True)
+ sqla_tag_id: Mapped[int] = mapped_column(sqlalchemy.ForeignKey(SQLATag.id))
+ sqla_tag: Mapped[Optional[SQLATag]] = relationship()
+ labels: Mapped[List[SQLALabel]] = relationship(back_populates="test_dataclass")
+ # do not use lower case dict here!
+ # https://github.com/sqlalchemy/sqlalchemy/issues/9902
+ dict_str_str: Mapped[Dict[str, str]] = mapped_column()
+ default_factory: Mapped[List[int]] = mapped_column(
+ sqlalchemy.types.ARRAY(item_type=sqlalchemy.INTEGER),
+ default_factory=list,
+ )
+ __tablename__: str = "test_dataclass"
+
+ @property
+ def str_property(self) -> str:
+ """String property.
+
+ Returns:
+ Name attribute
+ """
+ return self.name
+
+ @hybrid_property
+ def str_or_int_property(self) -> Union[str, int]:
+ """String or int property.
+
+ Returns:
+ Name attribute
+ """
+ return self.name
+
+ @hybrid_property
+ def first_label(self) -> Optional[SQLALabel]:
+ """First label property.
+
+ Returns:
+ First label
+ """
+ return self.labels[0] if self.labels else None
+
+
class ModelClass(rx.Model):
"""Test reflex model."""
+ no_default: Optional[int] = sqlmodel.Field(nullable=True)
count: int = 0
name: str = "test"
int_list: List[int] = []
@@ -115,6 +182,7 @@ class ModelClass(rx.Model):
sqla_tag: Optional[SQLATag] = None
labels: List[SQLALabel] = []
dict_str_str: Dict[str, str] = {}
+ default_factory: List[int] = sqlmodel.Field(default_factory=list)
@property
def str_property(self) -> str:
@@ -147,6 +215,7 @@ class ModelClass(rx.Model):
class BaseClass(rx.Base):
"""Test rx.Base class."""
+ no_default: Optional[int] = pydantic.v1.Field(required=False)
count: int = 0
name: str = "test"
int_list: List[int] = []
@@ -155,6 +224,7 @@ class BaseClass(rx.Base):
sqla_tag: Optional[SQLATag] = None
labels: List[SQLALabel] = []
dict_str_str: Dict[str, str] = {}
+ default_factory: List[int] = pydantic.v1.Field(default_factory=list)
@property
def str_property(self) -> str:
@@ -236,6 +306,7 @@ class AttrClass:
sqla_tag: Optional[SQLATag] = None
labels: List[SQLALabel] = []
dict_str_str: Dict[str, str] = {}
+ default_factory: List[int] = attrs.field(factory=list)
@property
def str_property(self) -> str:
@@ -265,27 +336,17 @@ class AttrClass:
return self.labels[0] if self.labels else None
-@pytest.fixture(
- params=[
+@pytest.mark.parametrize(
+ "cls",
+ [
SQLAClass,
+ SQLAClassDataclass,
BaseClass,
BareClass,
ModelClass,
AttrClass,
- ]
+ ],
)
-def cls(request: pytest.FixtureRequest) -> type:
- """Fixture for the class to test.
-
- Args:
- request: pytest request object.
-
- Returns:
- Class to test.
- """
- return request.param
-
-
@pytest.mark.parametrize(
"attr, expected",
[
@@ -311,3 +372,38 @@ def test_get_attribute_access_type(cls: type, attr: str, expected: GenericType)
expected: Expected type.
"""
assert get_attribute_access_type(cls, attr) == expected
+
+
+@pytest.mark.parametrize(
+ "cls",
+ [
+ SQLAClassDataclass,
+ BaseClass,
+ ModelClass,
+ AttrClass,
+ ],
+)
+def test_get_attribute_access_type_default_factory(cls: type) -> None:
+ """Test get_attribute_access_type returns the correct type for default factory fields.
+
+ Args:
+ cls: Class to test.
+ """
+ assert get_attribute_access_type(cls, "default_factory") == List[int]
+
+
+@pytest.mark.parametrize(
+ "cls",
+ [
+ SQLAClassDataclass,
+ BaseClass,
+ ModelClass,
+ ],
+)
+def test_get_attribute_access_type_no_default(cls: type) -> None:
+ """Test get_attribute_access_type returns the correct type for fields with no default which are not required.
+
+ Args:
+ cls: Class to test.
+ """
+ assert get_attribute_access_type(cls, "no_default") == Optional[int]
diff --git a/tests/units/test_config.py b/tests/units/test_config.py
index 31dd77649..e5d4622bd 100644
--- a/tests/units/test_config.py
+++ b/tests/units/test_config.py
@@ -1,11 +1,21 @@
import multiprocessing
import os
+from pathlib import Path
+from typing import Any, Dict
import pytest
import reflex as rx
import reflex.config
-from reflex.constants import Endpoint
+from reflex.config import (
+ EnvVar,
+ env_var,
+ environment,
+ interpret_boolean_env,
+ interpret_enum_env,
+ interpret_int_env,
+)
+from reflex.constants import Endpoint, Env
def test_requires_app_name():
@@ -41,7 +51,12 @@ def test_set_app_name(base_config_values):
("TELEMETRY_ENABLED", True),
],
)
-def test_update_from_env(base_config_values, monkeypatch, env_var, value):
+def test_update_from_env(
+ base_config_values: Dict[str, Any],
+ monkeypatch: pytest.MonkeyPatch,
+ env_var: str,
+ value: Any,
+):
"""Test that environment variables override config values.
Args:
@@ -56,6 +71,29 @@ def test_update_from_env(base_config_values, monkeypatch, env_var, value):
assert getattr(config, env_var.lower()) == value
+def test_update_from_env_path(
+ base_config_values: Dict[str, Any],
+ monkeypatch: pytest.MonkeyPatch,
+ tmp_path: Path,
+):
+ """Test that environment variables override config values.
+
+ Args:
+ base_config_values: Config values.
+ monkeypatch: The pytest monkeypatch object.
+ tmp_path: The pytest tmp_path fixture object.
+ """
+ monkeypatch.setenv("BUN_PATH", "/test")
+ assert os.environ.get("BUN_PATH") == "/test"
+ with pytest.raises(ValueError):
+ rx.Config(**base_config_values)
+
+ monkeypatch.setenv("BUN_PATH", str(tmp_path))
+ assert os.environ.get("BUN_PATH") == str(tmp_path)
+ config = rx.Config(**base_config_values)
+ assert config.bun_path == tmp_path
+
+
@pytest.mark.parametrize(
"kwargs, expected",
[
@@ -177,11 +215,11 @@ def test_replace_defaults(
assert getattr(c, key) == value
-def reflex_dir_constant():
- return rx.constants.Reflex.DIR
+def reflex_dir_constant() -> Path:
+ return environment.REFLEX_DIR.get()
-def test_reflex_dir_env_var(monkeypatch, tmp_path):
+def test_reflex_dir_env_var(monkeypatch: pytest.MonkeyPatch, tmp_path: Path) -> None:
"""Test that the REFLEX_DIR environment variable is used to set the Reflex.DIR constant.
Args:
@@ -191,5 +229,54 @@ def test_reflex_dir_env_var(monkeypatch, tmp_path):
monkeypatch.setenv("REFLEX_DIR", str(tmp_path))
mp_ctx = multiprocessing.get_context(method="spawn")
+ assert reflex_dir_constant() == tmp_path
with mp_ctx.Pool(processes=1) as pool:
- assert pool.apply(reflex_dir_constant) == str(tmp_path)
+ assert pool.apply(reflex_dir_constant) == tmp_path
+
+
+def test_interpret_enum_env() -> None:
+ assert interpret_enum_env(Env.PROD.value, Env, "REFLEX_ENV") == Env.PROD
+
+
+def test_interpret_int_env() -> None:
+ assert interpret_int_env("3001", "FRONTEND_PORT") == 3001
+
+
+@pytest.mark.parametrize("value, expected", [("true", True), ("false", False)])
+def test_interpret_bool_env(value: str, expected: bool) -> None:
+ assert interpret_boolean_env(value, "TELEMETRY_ENABLED") == expected
+
+
+def test_env_var():
+ class TestEnv:
+ BLUBB: EnvVar[str] = env_var("default")
+ INTERNAL: EnvVar[str] = env_var("default", internal=True)
+ BOOLEAN: EnvVar[bool] = env_var(False)
+
+ assert TestEnv.BLUBB.get() == "default"
+ assert TestEnv.BLUBB.name == "BLUBB"
+ TestEnv.BLUBB.set("new")
+ assert os.environ.get("BLUBB") == "new"
+ assert TestEnv.BLUBB.get() == "new"
+ TestEnv.BLUBB.set(None)
+ assert "BLUBB" not in os.environ
+
+ assert TestEnv.INTERNAL.get() == "default"
+ assert TestEnv.INTERNAL.name == "__INTERNAL"
+ TestEnv.INTERNAL.set("new")
+ assert os.environ.get("__INTERNAL") == "new"
+ assert TestEnv.INTERNAL.get() == "new"
+ assert TestEnv.INTERNAL.getenv() == "new"
+ TestEnv.INTERNAL.set(None)
+ assert "__INTERNAL" not in os.environ
+
+ assert TestEnv.BOOLEAN.get() is False
+ assert TestEnv.BOOLEAN.name == "BOOLEAN"
+ TestEnv.BOOLEAN.set(True)
+ assert os.environ.get("BOOLEAN") == "True"
+ assert TestEnv.BOOLEAN.get() is True
+ TestEnv.BOOLEAN.set(False)
+ assert os.environ.get("BOOLEAN") == "False"
+ assert TestEnv.BOOLEAN.get() is False
+ TestEnv.BOOLEAN.set(None)
+ assert "BOOLEAN" not in os.environ
diff --git a/tests/units/test_db_config.py b/tests/units/test_db_config.py
index b8d7c07cb..5b716e6bb 100644
--- a/tests/units/test_db_config.py
+++ b/tests/units/test_db_config.py
@@ -164,7 +164,7 @@ def test_constructor_postgresql(username, password, host, port, database, expect
"localhost",
5432,
"db",
- "postgresql+psycopg2://user:pass@localhost:5432/db",
+ "postgresql+psycopg://user:pass@localhost:5432/db",
),
(
"user",
@@ -172,17 +172,17 @@ def test_constructor_postgresql(username, password, host, port, database, expect
"localhost",
None,
"db",
- "postgresql+psycopg2://user@localhost/db",
+ "postgresql+psycopg://user@localhost/db",
),
- ("user", "", "", None, "db", "postgresql+psycopg2://user@/db"),
- ("", "", "localhost", 5432, "db", "postgresql+psycopg2://localhost:5432/db"),
- ("", "", "", None, "db", "postgresql+psycopg2:///db"),
+ ("user", "", "", None, "db", "postgresql+psycopg://user@/db"),
+ ("", "", "localhost", 5432, "db", "postgresql+psycopg://localhost:5432/db"),
+ ("", "", "", None, "db", "postgresql+psycopg:///db"),
],
)
-def test_constructor_postgresql_psycopg2(
+def test_constructor_postgresql_psycopg(
username, password, host, port, database, expected_url
):
- """Test DBConfig.postgresql_psycopg2 constructor creates the instance correctly.
+ """Test DBConfig.postgresql_psycopg constructor creates the instance correctly.
Args:
username: Database username.
@@ -192,10 +192,10 @@ def test_constructor_postgresql_psycopg2(
database: Database name.
expected_url: Expected database URL generated.
"""
- db_config = DBConfig.postgresql_psycopg2(
+ db_config = DBConfig.postgresql_psycopg(
username=username, password=password, host=host, port=port, database=database
)
- assert db_config.engine == "postgresql+psycopg2"
+ assert db_config.engine == "postgresql+psycopg"
assert db_config.username == username
assert db_config.password == password
assert db_config.host == host
diff --git a/tests/units/test_event.py b/tests/units/test_event.py
index 3996a6101..c5198a571 100644
--- a/tests/units/test_event.py
+++ b/tests/units/test_event.py
@@ -1,12 +1,20 @@
-from typing import List
+from typing import Callable, List
import pytest
-from reflex import event
-from reflex.event import Event, EventHandler, EventSpec, call_event_handler, fix_events
+import reflex as rx
+from reflex.event import (
+ Event,
+ EventChain,
+ EventHandler,
+ EventSpec,
+ call_event_handler,
+ event,
+ fix_events,
+)
from reflex.state import BaseState
from reflex.utils import format
-from reflex.vars.base import LiteralVar, Var
+from reflex.vars.base import Field, LiteralVar, Var, field
def make_var(value) -> Var:
@@ -100,7 +108,7 @@ def test_call_event_handler_partial():
def spec(a2: Var[str]) -> List[Var[str]]:
return [a2]
- handler = EventHandler(fn=test_fn_with_args)
+ handler = EventHandler(fn=test_fn_with_args, state_full_name="BigState")
event_spec = handler(make_var("first"))
event_spec2 = call_event_handler(event_spec, spec)
@@ -108,7 +116,10 @@ def test_call_event_handler_partial():
assert len(event_spec.args) == 1
assert event_spec.args[0][0].equals(Var(_js_expr="arg1"))
assert event_spec.args[0][1].equals(Var(_js_expr="first"))
- assert format.format_event(event_spec) == 'Event("test_fn_with_args", {arg1:first})'
+ assert (
+ format.format_event(event_spec)
+ == 'Event("BigState.test_fn_with_args", {arg1:first})'
+ )
assert event_spec2 is not event_spec
assert event_spec2.handler == handler
@@ -119,7 +130,7 @@ def test_call_event_handler_partial():
assert event_spec2.args[1][1].equals(Var(_js_expr="_a2", _var_type=str))
assert (
format.format_event(event_spec2)
- == 'Event("test_fn_with_args", {arg1:first,arg2:_a2})'
+ == 'Event("BigState.test_fn_with_args", {arg1:first,arg2:_a2})'
)
@@ -198,10 +209,6 @@ def test_event_redirect(input, output):
assert isinstance(spec, EventSpec)
assert spec.handler.fn.__qualname__ == "_redirect"
- # this asserts need comment about what it's testing (they fail with Var as input)
- # assert spec.args[0][0].equals(Var(_js_expr="path"))
- # assert spec.args[0][1].equals(Var(_js_expr="/path"))
-
assert format.format_event(spec) == output
@@ -209,24 +216,40 @@ def test_event_console_log():
"""Test the event console log function."""
spec = event.console_log("message")
assert isinstance(spec, EventSpec)
- assert spec.handler.fn.__qualname__ == "_console"
- assert spec.args[0][0].equals(Var(_js_expr="message"))
- assert spec.args[0][1].equals(LiteralVar.create("message"))
- assert format.format_event(spec) == 'Event("_console", {message:"message"})'
+ assert spec.handler.fn.__qualname__ == "_call_function"
+ assert spec.args[0][0].equals(Var(_js_expr="function"))
+ assert spec.args[0][1].equals(
+ Var('(() => (console["log"]("message")))', _var_type=Callable)
+ )
+ assert (
+ format.format_event(spec)
+ == 'Event("_call_function", {function:(() => (console["log"]("message")))})'
+ )
spec = event.console_log(Var(_js_expr="message"))
- assert format.format_event(spec) == 'Event("_console", {message:message})'
+ assert (
+ format.format_event(spec)
+ == 'Event("_call_function", {function:(() => (console["log"](message)))})'
+ )
def test_event_window_alert():
"""Test the event window alert function."""
spec = event.window_alert("message")
assert isinstance(spec, EventSpec)
- assert spec.handler.fn.__qualname__ == "_alert"
- assert spec.args[0][0].equals(Var(_js_expr="message"))
- assert spec.args[0][1].equals(LiteralVar.create("message"))
- assert format.format_event(spec) == 'Event("_alert", {message:"message"})'
+ assert spec.handler.fn.__qualname__ == "_call_function"
+ assert spec.args[0][0].equals(Var(_js_expr="function"))
+ assert spec.args[0][1].equals(
+ Var('(() => (window["alert"]("message")))', _var_type=Callable)
+ )
+ assert (
+ format.format_event(spec)
+ == 'Event("_call_function", {function:(() => (window["alert"]("message")))})'
+ )
spec = event.window_alert(Var(_js_expr="message"))
- assert format.format_event(spec) == 'Event("_alert", {message:message})'
+ assert (
+ format.format_event(spec)
+ == 'Event("_call_function", {function:(() => (window["alert"](message)))})'
+ )
def test_set_focus():
@@ -292,7 +315,7 @@ def test_remove_cookie_with_options():
assert spec.args[1][1].equals(LiteralVar.create(options))
assert (
format.format_event(spec)
- == f'Event("_remove_cookie", {{key:"testkey",options:{str(LiteralVar.create(options))}}})'
+ == f'Event("_remove_cookie", {{key:"testkey",options:{LiteralVar.create(options)!s}}})'
)
@@ -388,3 +411,42 @@ def test_event_actions_on_state():
assert sp_handler.event_actions == {"stopPropagation": True}
# should NOT affect other references to the handler
assert not handler.event_actions
+
+
+def test_event_var_data():
+ class S(BaseState):
+ x: Field[int] = field(0)
+
+ @event
+ def s(self, value: int):
+ pass
+
+ # Handler doesn't have any _var_data because it's just a str
+ handler_var = Var.create(S.s)
+ assert handler_var._get_all_var_data() is None
+
+ # Ensure spec carries _var_data
+ spec_var = Var.create(S.s(S.x))
+ assert spec_var._get_all_var_data() == S.x._get_all_var_data()
+
+ # Needed to instantiate the EventChain
+ def _args_spec(value: Var[int]) -> tuple[Var[int]]:
+ return (value,)
+
+ # Ensure chain carries _var_data
+ chain_var = Var.create(EventChain(events=[S.s(S.x)], args_spec=_args_spec))
+ assert chain_var._get_all_var_data() == S.x._get_all_var_data()
+
+
+def test_event_bound_method() -> None:
+ class S(BaseState):
+ @event
+ def e(self, arg: str):
+ print(arg)
+
+ class Wrapper:
+ def get_handler(self, arg: str):
+ return S.e(arg)
+
+ w = Wrapper()
+ _ = rx.input(on_change=w.get_handler)
diff --git a/tests/units/test_model.py b/tests/units/test_model.py
index ac8187e03..0a83f39ec 100644
--- a/tests/units/test_model.py
+++ b/tests/units/test_model.py
@@ -46,7 +46,7 @@ def test_default_primary_key(model_default_primary: Model):
Args:
model_default_primary: Fixture.
"""
- assert "id" in model_default_primary.__class__.__fields__
+ assert "id" in type(model_default_primary).__fields__
def test_custom_primary_key(model_custom_primary: Model):
@@ -55,7 +55,7 @@ def test_custom_primary_key(model_custom_primary: Model):
Args:
model_custom_primary: Fixture.
"""
- assert "id" not in model_custom_primary.__class__.__fields__
+ assert "id" not in type(model_custom_primary).__fields__
@pytest.mark.filterwarnings(
diff --git a/tests/units/test_prerequisites.py b/tests/units/test_prerequisites.py
index c4f57a998..2497318e7 100644
--- a/tests/units/test_prerequisites.py
+++ b/tests/units/test_prerequisites.py
@@ -24,7 +24,15 @@ from reflex.utils.prerequisites import (
app_name="test",
),
False,
- 'module.exports = {basePath: "", compress: true, reactStrictMode: true, trailingSlash: true};',
+ 'module.exports = {basePath: "", compress: true, reactStrictMode: true, trailingSlash: true, staticPageGenerationTimeout: 60};',
+ ),
+ (
+ Config(
+ app_name="test",
+ static_page_generation_timeout=30,
+ ),
+ False,
+ 'module.exports = {basePath: "", compress: true, reactStrictMode: true, trailingSlash: true, staticPageGenerationTimeout: 30};',
),
(
Config(
@@ -32,7 +40,7 @@ from reflex.utils.prerequisites import (
next_compression=False,
),
False,
- 'module.exports = {basePath: "", compress: false, reactStrictMode: true, trailingSlash: true};',
+ 'module.exports = {basePath: "", compress: false, reactStrictMode: true, trailingSlash: true, staticPageGenerationTimeout: 60};',
),
(
Config(
@@ -40,7 +48,7 @@ from reflex.utils.prerequisites import (
frontend_path="/test",
),
False,
- 'module.exports = {basePath: "/test", compress: true, reactStrictMode: true, trailingSlash: true};',
+ 'module.exports = {basePath: "/test", compress: true, reactStrictMode: true, trailingSlash: true, staticPageGenerationTimeout: 60};',
),
(
Config(
@@ -49,14 +57,14 @@ from reflex.utils.prerequisites import (
next_compression=False,
),
False,
- 'module.exports = {basePath: "/test", compress: false, reactStrictMode: true, trailingSlash: true};',
+ 'module.exports = {basePath: "/test", compress: false, reactStrictMode: true, trailingSlash: true, staticPageGenerationTimeout: 60};',
),
(
Config(
app_name="test",
),
True,
- 'module.exports = {basePath: "", compress: true, reactStrictMode: true, trailingSlash: true, output: "export", distDir: "_static"};',
+ 'module.exports = {basePath: "", compress: true, reactStrictMode: true, trailingSlash: true, staticPageGenerationTimeout: 60, output: "export", distDir: "_static"};',
),
],
)
diff --git a/tests/units/test_sqlalchemy.py b/tests/units/test_sqlalchemy.py
index b18799e0c..23e315785 100644
--- a/tests/units/test_sqlalchemy.py
+++ b/tests/units/test_sqlalchemy.py
@@ -127,8 +127,8 @@ def test_automigration(
assert result[0].b == 4.2
# No-op
- # assert Model.migrate(autogenerate=True)
- # assert len(list(versions.glob("*.py"))) == 4
+ # assert Model.migrate(autogenerate=True) #noqa: ERA001
+ # assert len(list(versions.glob("*.py"))) == 4 #noqa: ERA001
# drop table (AlembicSecond)
model_registry.get_metadata().clear()
diff --git a/tests/units/test_state.py b/tests/units/test_state.py
index 205162b9f..912d72f4f 100644
--- a/tests/units/test_state.py
+++ b/tests/units/test_state.py
@@ -8,12 +8,26 @@ import functools
import json
import os
import sys
+import threading
from textwrap import dedent
-from typing import Any, Callable, Dict, Generator, List, Optional, Union
+from typing import (
+ Any,
+ AsyncGenerator,
+ Callable,
+ Dict,
+ List,
+ Optional,
+ Set,
+ Tuple,
+ Union,
+)
from unittest.mock import AsyncMock, Mock
import pytest
+import pytest_asyncio
from plotly.graph_objects import Figure
+from pydantic import BaseModel as BaseModelV2
+from pydantic.v1 import BaseModel as BaseModelV1
import reflex as rx
import reflex.config
@@ -41,14 +55,22 @@ from reflex.state import (
)
from reflex.testing import chdir
from reflex.utils import format, prerequisites, types
+from reflex.utils.exceptions import (
+ InvalidLockWarningThresholdError,
+ ReflexRuntimeError,
+ SetUndefinedStateVarError,
+ StateSerializationError,
+)
from reflex.utils.format import json_dumps
-from reflex.vars.base import ComputedVar, Var
+from reflex.vars.base import Var, computed_var
from tests.units.states.mutation import MutableSQLAModel, MutableTestState
from .states import GenState
CI = bool(os.environ.get("CI", False))
-LOCK_EXPIRATION = 2000 if CI else 300
+LOCK_EXPIRATION = 2500 if CI else 300
+LOCK_WARNING_THRESHOLD = 1000 if CI else 100
+LOCK_WARN_SLEEP = 1.5 if CI else 0.15
LOCK_EXPIRE_SLEEP = 2.5 if CI else 0.4
@@ -103,8 +125,10 @@ class TestState(BaseState):
complex: Dict[int, Object] = {1: Object(), 2: Object()}
fig: Figure = Figure()
dt: datetime.datetime = datetime.datetime.fromisoformat("1989-11-09T18:53:00+01:00")
+ _backend: int = 0
+ asynctest: int = 0
- @ComputedVar
+ @computed_var
def sum(self) -> float:
"""Dynamically sum the numbers.
@@ -113,7 +137,7 @@ class TestState(BaseState):
"""
return self.num1 + self.num2
- @ComputedVar
+ @computed_var
def upper(self) -> str:
"""Uppercase the key.
@@ -126,6 +150,14 @@ class TestState(BaseState):
"""Do something."""
pass
+ async def set_asynctest(self, value: int):
+ """Set the asynctest value. Intentionally overwrite the default setter with an async one.
+
+ Args:
+ value: The new value.
+ """
+ self.asynctest = value
+
class ChildState(TestState):
"""A child state fixture."""
@@ -272,9 +304,9 @@ def test_base_class_vars(test_state):
assert isinstance(prop, Var)
assert prop._js_expr.split(".")[-1] == field
- assert cls.num1._var_type == int
- assert cls.num2._var_type == float
- assert cls.key._var_type == str
+ assert cls.num1._var_type is int
+ assert cls.num2._var_type is float
+ assert cls.key._var_type is str
def test_computed_class_var(test_state):
@@ -310,6 +342,7 @@ def test_class_vars(test_state):
"upper",
"fig",
"dt",
+ "asynctest",
}
@@ -524,7 +557,7 @@ def test_set_class_var():
TestState._set_var(Var(_js_expr="num3", _var_type=int)._var_set_state(TestState))
var = TestState.num3 # type: ignore
assert var._js_expr == TestState.get_full_name() + ".num3"
- assert var._var_type == int
+ assert var._var_type is int
assert var._var_state == TestState.get_full_name()
@@ -704,6 +737,7 @@ def test_reset(test_state, child_state):
# Set some values.
test_state.num1 = 1
test_state.num2 = 2
+ test_state._backend = 3
child_state.value = "test"
# Reset the state.
@@ -712,6 +746,7 @@ def test_reset(test_state, child_state):
# The values should be reset.
assert test_state.num1 == 0
assert test_state.num2 == 3.14
+ assert test_state._backend == 0
assert child_state.value == ""
expected_dirty_vars = {
@@ -727,6 +762,8 @@ def test_reset(test_state, child_state):
"map_key",
"mapping",
"dt",
+ "_backend",
+ "asynctest",
}
# The dirty vars should be reset.
@@ -757,7 +794,6 @@ async def test_process_event_simple(test_state):
assert test_state.num1 == 69
# The delta should contain the changes, including computed vars.
- # assert update.delta == {"test_state": {"num1": 69, "sum": 72.14}}
assert update.delta == {
TestState.get_full_name(): {"num1": 69, "sum": 72.14, "upper": ""},
GrandchildState3.get_full_name(): {"computed": ""},
@@ -1106,7 +1142,7 @@ def test_child_state():
v: int = 2
class ChildState(MainState):
- @ComputedVar
+ @computed_var
def rendered_var(self):
return self.v
@@ -1125,7 +1161,7 @@ def test_conditional_computed_vars():
t1: str = "a"
t2: str = "b"
- @ComputedVar
+ @computed_var
def rendered_var(self) -> str:
if self.flag:
return self.t1
@@ -1284,19 +1320,19 @@ def test_computed_var_depends_on_parent_non_cached():
assert ps.dirty_vars == set()
assert cs.dirty_vars == set()
- dict1 = ps.dict()
+ dict1 = json.loads(json_dumps(ps.dict()))
assert dict1[ps.get_full_name()] == {
"no_cache_v": 1,
"router": formatted_router,
}
assert dict1[cs.get_full_name()] == {"dep_v": 2}
- dict2 = ps.dict()
+ dict2 = json.loads(json_dumps(ps.dict()))
assert dict2[ps.get_full_name()] == {
"no_cache_v": 3,
"router": formatted_router,
}
assert dict2[cs.get_full_name()] == {"dep_v": 4}
- dict3 = ps.dict()
+ dict3 = json.loads(json_dumps(ps.dict()))
assert dict3[ps.get_full_name()] == {
"no_cache_v": 5,
"router": formatted_router,
@@ -1540,7 +1576,7 @@ def test_error_on_state_method_shadow():
assert (
err.value.args[0]
- == f"The event handler name `reset` shadows a builtin State method; use a different name instead"
+ == "The event handler name `reset` shadows a builtin State method; use a different name instead"
)
@@ -1596,8 +1632,10 @@ async def test_state_with_invalid_yield(capsys, mock_app):
assert "must only return/yield: None, Events or other EventHandlers" in captured.out
-@pytest.fixture(scope="function", params=["in_process", "disk", "redis"])
-def state_manager(request) -> Generator[StateManager, None, None]:
+@pytest_asyncio.fixture(
+ loop_scope="function", scope="function", params=["in_process", "disk", "redis"]
+)
+async def state_manager(request) -> AsyncGenerator[StateManager, None]:
"""Instance of state manager parametrized for redis and in-process.
Args:
@@ -1621,7 +1659,7 @@ def state_manager(request) -> Generator[StateManager, None, None]:
yield state_manager
if isinstance(state_manager, StateManagerRedis):
- asyncio.get_event_loop().run_until_complete(state_manager.close())
+ await state_manager.close()
@pytest.fixture()
@@ -1666,7 +1704,7 @@ async def test_state_manager_modify_state(
assert not state_manager._states_locks[token].locked()
# separate instances should NOT share locks
- sm2 = state_manager.__class__(state=TestState)
+ sm2 = type(state_manager)(state=TestState)
assert sm2._state_manager_lock is state_manager._state_manager_lock
assert not sm2._states_locks
if state_manager._states_locks:
@@ -1709,8 +1747,8 @@ async def test_state_manager_contend(
assert not state_manager._states_locks[token].locked()
-@pytest.fixture(scope="function")
-def state_manager_redis() -> Generator[StateManager, None, None]:
+@pytest_asyncio.fixture(loop_scope="function", scope="function")
+async def state_manager_redis() -> AsyncGenerator[StateManager, None]:
"""Instance of state manager for redis only.
Yields:
@@ -1723,7 +1761,7 @@ def state_manager_redis() -> Generator[StateManager, None, None]:
yield state_manager
- asyncio.get_event_loop().run_until_complete(state_manager.close())
+ await state_manager.close()
@pytest.fixture()
@@ -1752,6 +1790,7 @@ async def test_state_manager_lock_expire(
substate_token_redis: A token + substate name for looking up in state manager.
"""
state_manager_redis.lock_expiration = LOCK_EXPIRATION
+ state_manager_redis.lock_warning_threshold = LOCK_WARNING_THRESHOLD
async with state_manager_redis.modify_state(substate_token_redis):
await asyncio.sleep(0.01)
@@ -1776,6 +1815,7 @@ async def test_state_manager_lock_expire_contend(
unexp_num1 = 666
state_manager_redis.lock_expiration = LOCK_EXPIRATION
+ state_manager_redis.lock_warning_threshold = LOCK_WARNING_THRESHOLD
order = []
@@ -1805,13 +1845,63 @@ async def test_state_manager_lock_expire_contend(
assert (await state_manager_redis.get_state(substate_token_redis)).num1 == exp_num1
+@pytest.mark.asyncio
+async def test_state_manager_lock_warning_threshold_contend(
+ state_manager_redis: StateManager, token: str, substate_token_redis: str, mocker
+):
+ """Test that the state manager triggers a warning when lock contention exceeds the warning threshold.
+
+ Args:
+ state_manager_redis: A state manager instance.
+ token: A token.
+ substate_token_redis: A token + substate name for looking up in state manager.
+ mocker: Pytest mocker object.
+ """
+ console_warn = mocker.patch("reflex.utils.console.warn")
+
+ state_manager_redis.lock_expiration = LOCK_EXPIRATION
+ state_manager_redis.lock_warning_threshold = LOCK_WARNING_THRESHOLD
+
+ order = []
+
+ async def _coro_blocker():
+ async with state_manager_redis.modify_state(substate_token_redis):
+ order.append("blocker")
+ await asyncio.sleep(LOCK_WARN_SLEEP)
+
+ tasks = [
+ asyncio.create_task(_coro_blocker()),
+ ]
+
+ await tasks[0]
+ console_warn.assert_called()
+ assert console_warn.call_count == 7
+
+
+class CopyingAsyncMock(AsyncMock):
+ """An AsyncMock, but deepcopy the args and kwargs first."""
+
+ def __call__(self, *args, **kwargs):
+ """Call the mock.
+
+ Args:
+ args: the arguments passed to the mock
+ kwargs: the keyword arguments passed to the mock
+
+ Returns:
+ The result of the mock call
+ """
+ args = copy.deepcopy(args)
+ kwargs = copy.deepcopy(kwargs)
+ return super().__call__(*args, **kwargs)
+
+
@pytest.fixture(scope="function")
-def mock_app(monkeypatch, state_manager: StateManager) -> rx.App:
- """Mock app fixture.
+def mock_app_simple(monkeypatch) -> rx.App:
+ """Simple Mock app fixture.
Args:
monkeypatch: Pytest monkeypatch object.
- state_manager: A state manager.
Returns:
The app, after mocking out prerequisites.get_app()
@@ -1822,8 +1912,7 @@ def mock_app(monkeypatch, state_manager: StateManager) -> rx.App:
setattr(app_module, CompileVars.APP, app)
app.state = TestState
- app._state_manager = state_manager
- app.event_namespace.emit = AsyncMock() # type: ignore
+ app.event_namespace.emit = CopyingAsyncMock() # type: ignore
def _mock_get_app(*args, **kwargs):
return app_module
@@ -1832,6 +1921,21 @@ def mock_app(monkeypatch, state_manager: StateManager) -> rx.App:
return app
+@pytest.fixture(scope="function")
+def mock_app(mock_app_simple: rx.App, state_manager: StateManager) -> rx.App:
+ """Mock app fixture.
+
+ Args:
+ mock_app_simple: A simple mock app.
+ state_manager: A state manager.
+
+ Returns:
+ The app, after mocking out prerequisites.get_app()
+ """
+ mock_app_simple._state_manager = state_manager
+ return mock_app_simple
+
+
@pytest.mark.asyncio
async def test_state_proxy(grandchild_state: GrandchildState, mock_app: rx.App):
"""Test that the state proxy works.
@@ -1883,11 +1987,11 @@ async def test_state_proxy(grandchild_state: GrandchildState, mock_app: rx.App):
async with sp:
assert sp._self_actx is not None
assert sp._self_mutable # proxy is mutable inside context
- if isinstance(mock_app.state_manager, StateManagerMemory):
+ if isinstance(mock_app.state_manager, (StateManagerMemory, StateManagerDisk)):
# For in-process store, only one instance of the state exists
assert sp.__wrapped__ is grandchild_state
else:
- # When redis or disk is used, a new+updated instance is assigned to the proxy
+ # When redis is used, a new+updated instance is assigned to the proxy
assert sp.__wrapped__ is not grandchild_state
sp.value2 = "42"
assert not sp._self_mutable # proxy is not mutable after exiting context
@@ -1898,7 +2002,7 @@ async def test_state_proxy(grandchild_state: GrandchildState, mock_app: rx.App):
gotten_state = await mock_app.state_manager.get_state(
_substate_key(grandchild_state.router.session.client_token, grandchild_state)
)
- if isinstance(mock_app.state_manager, StateManagerMemory):
+ if isinstance(mock_app.state_manager, (StateManagerMemory, StateManagerDisk)):
# For in-process store, only one instance of the state exists
assert gotten_state is parent_state
else:
@@ -1912,21 +2016,19 @@ async def test_state_proxy(grandchild_state: GrandchildState, mock_app: rx.App):
mock_app.event_namespace.emit.assert_called_once()
mcall = mock_app.event_namespace.emit.mock_calls[0]
assert mcall.args[0] == str(SocketEvent.EVENT)
- assert json.loads(mcall.args[1]) == dataclasses.asdict(
- StateUpdate(
- delta={
- parent_state.get_full_name(): {
- "upper": "",
- "sum": 3.14,
- },
- grandchild_state.get_full_name(): {
- "value2": "42",
- },
- GrandchildState3.get_full_name(): {
- "computed": "",
- },
- }
- )
+ assert mcall.args[1] == StateUpdate(
+ delta={
+ parent_state.get_full_name(): {
+ "upper": "",
+ "sum": 3.14,
+ },
+ grandchild_state.get_full_name(): {
+ "value2": "42",
+ },
+ GrandchildState3.get_full_name(): {
+ "computed": "",
+ },
+ }
)
assert mcall.kwargs["to"] == grandchild_state.router.session.session_id
@@ -1937,6 +2039,10 @@ class BackgroundTaskState(BaseState):
order: List[str] = []
dict_list: Dict[str, List[int]] = {"foo": [1, 2, 3]}
+ def __init__(self, **kwargs): # noqa: D107
+ super().__init__(**kwargs)
+ self.router_data = {"simulate": "hydrate"}
+
@rx.var
def computed_order(self) -> List[str]:
"""Get the order as a computed var.
@@ -1946,7 +2052,7 @@ class BackgroundTaskState(BaseState):
"""
return self.order
- @rx.background
+ @rx.event(background=True)
async def background_task(self):
"""A background task that updates the state."""
async with self:
@@ -1983,7 +2089,7 @@ class BackgroundTaskState(BaseState):
self.other() # direct calling event handlers works in context
self._private_method()
- @rx.background
+ @rx.event(background=True)
async def background_task_reset(self):
"""A background task that resets the state."""
with pytest.raises(ImmutableStateError):
@@ -1997,7 +2103,7 @@ class BackgroundTaskState(BaseState):
async with self:
self.order.append("reset")
- @rx.background
+ @rx.event(background=True)
async def background_task_generator(self):
"""A background task generator that does nothing.
@@ -2104,51 +2210,51 @@ async def test_background_task_no_block(mock_app: rx.App, token: str):
assert mock_app.event_namespace is not None
emit_mock = mock_app.event_namespace.emit
- first_ws_message = json.loads(emit_mock.mock_calls[0].args[1])
+ first_ws_message = emit_mock.mock_calls[0].args[1]
assert (
- first_ws_message["delta"][BackgroundTaskState.get_full_name()].pop("router")
+ first_ws_message.delta[BackgroundTaskState.get_full_name()].pop("router")
is not None
)
- assert first_ws_message == {
- "delta": {
+ assert first_ws_message == StateUpdate(
+ delta={
BackgroundTaskState.get_full_name(): {
"order": ["background_task:start"],
"computed_order": ["background_task:start"],
}
},
- "events": [],
- "final": True,
- }
+ events=[],
+ final=True,
+ )
for call in emit_mock.mock_calls[1:5]:
- assert json.loads(call.args[1]) == {
- "delta": {
+ assert call.args[1] == StateUpdate(
+ delta={
BackgroundTaskState.get_full_name(): {
"computed_order": ["background_task:start"],
}
},
- "events": [],
- "final": True,
- }
- assert json.loads(emit_mock.mock_calls[-2].args[1]) == {
- "delta": {
+ events=[],
+ final=True,
+ )
+ assert emit_mock.mock_calls[-2].args[1] == StateUpdate(
+ delta={
BackgroundTaskState.get_full_name(): {
"order": exp_order,
"computed_order": exp_order,
"dict_list": {},
}
},
- "events": [],
- "final": True,
- }
- assert json.loads(emit_mock.mock_calls[-1].args[1]) == {
- "delta": {
+ events=[],
+ final=True,
+ )
+ assert emit_mock.mock_calls[-1].args[1] == StateUpdate(
+ delta={
BackgroundTaskState.get_full_name(): {
"computed_order": exp_order,
},
},
- "events": [],
- "final": True,
- }
+ events=[],
+ final=True,
+ )
@pytest.mark.asyncio
@@ -2494,13 +2600,16 @@ def test_mutable_copy_vars(mutable_state: MutableTestState, copy_func: Callable)
def test_duplicate_substate_class(mocker):
+ # Neuter pytest escape hatch, because we want to test duplicate detection.
mocker.patch("reflex.state.is_testing_env", lambda: False)
+ # Neuter state handling since these _are_ defined inside a function.
+ mocker.patch("reflex.state.BaseState._handle_local_def", lambda: None)
with pytest.raises(ValueError):
class TestState(BaseState):
pass
- class ChildTestState(TestState): # type: ignore # noqa
+ class ChildTestState(TestState): # type: ignore
pass
class ChildTestState(TestState): # type: ignore # noqa
@@ -2617,23 +2726,23 @@ def test_state_union_optional():
c3r: Custom3 = Custom3(c2r=Custom2(c1r=Custom1(foo="")))
custom_union: Union[Custom1, Custom2, Custom3] = Custom1(foo="")
- assert str(UnionState.c3.c2) == f'{str(UnionState.c3)}?.["c2"]' # type: ignore
- assert str(UnionState.c3.c2.c1) == f'{str(UnionState.c3)}?.["c2"]?.["c1"]' # type: ignore
+ assert str(UnionState.c3.c2) == f'{UnionState.c3!s}?.["c2"]' # type: ignore
+ assert str(UnionState.c3.c2.c1) == f'{UnionState.c3!s}?.["c2"]?.["c1"]' # type: ignore
assert (
- str(UnionState.c3.c2.c1.foo) == f'{str(UnionState.c3)}?.["c2"]?.["c1"]?.["foo"]' # type: ignore
+ str(UnionState.c3.c2.c1.foo) == f'{UnionState.c3!s}?.["c2"]?.["c1"]?.["foo"]' # type: ignore
)
assert (
- str(UnionState.c3.c2.c1r.foo) == f'{str(UnionState.c3)}?.["c2"]?.["c1r"]["foo"]' # type: ignore
+ str(UnionState.c3.c2.c1r.foo) == f'{UnionState.c3!s}?.["c2"]?.["c1r"]["foo"]' # type: ignore
)
- assert str(UnionState.c3.c2r.c1) == f'{str(UnionState.c3)}?.["c2r"]["c1"]' # type: ignore
+ assert str(UnionState.c3.c2r.c1) == f'{UnionState.c3!s}?.["c2r"]["c1"]' # type: ignore
assert (
- str(UnionState.c3.c2r.c1.foo) == f'{str(UnionState.c3)}?.["c2r"]["c1"]?.["foo"]' # type: ignore
+ str(UnionState.c3.c2r.c1.foo) == f'{UnionState.c3!s}?.["c2r"]["c1"]?.["foo"]' # type: ignore
)
assert (
- str(UnionState.c3.c2r.c1r.foo) == f'{str(UnionState.c3)}?.["c2r"]["c1r"]["foo"]' # type: ignore
+ str(UnionState.c3.c2r.c1r.foo) == f'{UnionState.c3!s}?.["c2r"]["c1r"]["foo"]' # type: ignore
)
- assert str(UnionState.c3i.c2) == f'{str(UnionState.c3i)}["c2"]' # type: ignore
- assert str(UnionState.c3r.c2) == f'{str(UnionState.c3r)}["c2"]' # type: ignore
+ assert str(UnionState.c3i.c2) == f'{UnionState.c3i!s}["c2"]' # type: ignore
+ assert str(UnionState.c3r.c2) == f'{UnionState.c3r!s}["c2"]' # type: ignore
assert UnionState.custom_union.foo is not None # type: ignore
assert UnionState.custom_union.c1 is not None # type: ignore
assert UnionState.custom_union.c1r is not None # type: ignore
@@ -2684,7 +2793,7 @@ def test_set_base_field_via_setter():
assert "c2" in bfss.dirty_vars
-def exp_is_hydrated(state: State, is_hydrated: bool = True) -> Dict[str, Any]:
+def exp_is_hydrated(state: BaseState, is_hydrated: bool = True) -> Dict[str, Any]:
"""Expected IS_HYDRATED delta that would be emitted by HydrateMiddleware.
Args:
@@ -2702,6 +2811,7 @@ class OnLoadState(State):
num: int = 0
+ @rx.event
def test_handler(self):
"""Test handler."""
self.num += 1
@@ -2762,7 +2872,8 @@ async def test_preprocess(app_module_mock, token, test_state, expected, mocker):
app = app_module_mock.app = App(
state=State, load_events={"index": [test_state.test_handler]}
)
- state = State()
+ async with app.state_manager.modify_state(_substate_key(token, State)) as state:
+ state.router_data = {"simulate": "hydrate"}
updates = []
async for update in rx.app.process(
@@ -2789,6 +2900,9 @@ async def test_preprocess(app_module_mock, token, test_state, expected, mocker):
}
assert (await state._process(events[1]).__anext__()).delta == exp_is_hydrated(state)
+ if isinstance(app.state_manager, StateManagerRedis):
+ await app.state_manager.close()
+
@pytest.mark.asyncio
async def test_preprocess_multiple_load_events(app_module_mock, token, mocker):
@@ -2806,7 +2920,8 @@ async def test_preprocess_multiple_load_events(app_module_mock, token, mocker):
state=State,
load_events={"index": [OnLoadState.test_handler, OnLoadState.test_handler]},
)
- state = State()
+ async with app.state_manager.modify_state(_substate_key(token, State)) as state:
+ state.router_data = {"simulate": "hydrate"}
updates = []
async for update in rx.app.process(
@@ -2836,6 +2951,9 @@ async def test_preprocess_multiple_load_events(app_module_mock, token, mocker):
}
assert (await state._process(events[2]).__anext__()).delta == exp_is_hydrated(state)
+ if isinstance(app.state_manager, StateManagerRedis):
+ await app.state_manager.close()
+
@pytest.mark.asyncio
async def test_get_state(mock_app: rx.App, token: str):
@@ -2921,7 +3039,7 @@ async def test_get_state(mock_app: rx.App, token: str):
_substate_key(token, ChildState2)
)
assert isinstance(new_test_state, TestState)
- if isinstance(mock_app.state_manager, StateManagerMemory):
+ if isinstance(mock_app.state_manager, (StateManagerMemory, StateManagerDisk)):
# In memory, it's the same instance
assert new_test_state is test_state
test_state._clean()
@@ -2931,15 +3049,6 @@ async def test_get_state(mock_app: rx.App, token: str):
ChildState2.get_name(),
ChildState3.get_name(),
)
- elif isinstance(mock_app.state_manager, StateManagerDisk):
- # On disk, it's a new instance
- assert new_test_state is not test_state
- # All substates are available
- assert tuple(sorted(new_test_state.substates)) == (
- ChildState.get_name(),
- ChildState2.get_name(),
- ChildState3.get_name(),
- )
else:
# With redis, we get a whole new instance
assert new_test_state is not test_state
@@ -3074,12 +3183,12 @@ def test_potentially_dirty_substates():
"""
class State(RxState):
- @ComputedVar
+ @computed_var
def foo(self) -> str:
return ""
class C1(State):
- @ComputedVar
+ @computed_var
def bar(self) -> str:
return ""
@@ -3171,16 +3280,53 @@ async def test_setvar(mock_app: rx.App, token: str):
TestState.setvar(42, 42)
+@pytest.mark.asyncio
+async def test_setvar_async_setter():
+ """Test that overridden async setters raise Exception when used with setvar."""
+ with pytest.raises(NotImplementedError):
+ TestState.setvar("asynctest", 42)
+
+
@pytest.mark.skipif("REDIS_URL" not in os.environ, reason="Test requires redis")
@pytest.mark.parametrize(
"expiration_kwargs, expected_values",
[
- ({"redis_lock_expiration": 20000}, (20000, constants.Expiration.TOKEN)),
+ (
+ {"redis_lock_expiration": 20000},
+ (
+ 20000,
+ constants.Expiration.TOKEN,
+ constants.Expiration.LOCK_WARNING_THRESHOLD,
+ ),
+ ),
(
{"redis_lock_expiration": 50000, "redis_token_expiration": 5600},
- (50000, 5600),
+ (50000, 5600, constants.Expiration.LOCK_WARNING_THRESHOLD),
+ ),
+ (
+ {"redis_token_expiration": 7600},
+ (
+ constants.Expiration.LOCK,
+ 7600,
+ constants.Expiration.LOCK_WARNING_THRESHOLD,
+ ),
+ ),
+ (
+ {"redis_lock_expiration": 50000, "redis_lock_warning_threshold": 1500},
+ (50000, constants.Expiration.TOKEN, 1500),
+ ),
+ (
+ {"redis_token_expiration": 5600, "redis_lock_warning_threshold": 3000},
+ (constants.Expiration.LOCK, 5600, 3000),
+ ),
+ (
+ {
+ "redis_lock_expiration": 50000,
+ "redis_token_expiration": 5600,
+ "redis_lock_warning_threshold": 2000,
+ },
+ (50000, 5600, 2000),
),
- ({"redis_token_expiration": 7600}, (constants.Expiration.LOCK, 7600)),
],
)
def test_redis_state_manager_config_knobs(tmp_path, expiration_kwargs, expected_values):
@@ -3196,6 +3342,7 @@ import reflex as rx
config = rx.Config(
app_name="project1",
redis_url="redis://localhost:6379",
+ state_manager_mode="redis",
{config_items}
)
"""
@@ -3209,6 +3356,44 @@ config = rx.Config(
state_manager = StateManager.create(state=State)
assert state_manager.lock_expiration == expected_values[0] # type: ignore
assert state_manager.token_expiration == expected_values[1] # type: ignore
+ assert state_manager.lock_warning_threshold == expected_values[2] # type: ignore
+
+
+@pytest.mark.skipif("REDIS_URL" not in os.environ, reason="Test requires redis")
+@pytest.mark.parametrize(
+ "redis_lock_expiration, redis_lock_warning_threshold",
+ [
+ (10000, 10000),
+ (20000, 30000),
+ ],
+)
+def test_redis_state_manager_config_knobs_invalid_lock_warning_threshold(
+ tmp_path, redis_lock_expiration, redis_lock_warning_threshold
+):
+ proj_root = tmp_path / "project1"
+ proj_root.mkdir()
+
+ config_string = f"""
+import reflex as rx
+config = rx.Config(
+ app_name="project1",
+ redis_url="redis://localhost:6379",
+ state_manager_mode="redis",
+ redis_lock_expiration = {redis_lock_expiration},
+ redis_lock_warning_threshold = {redis_lock_warning_threshold},
+)
+ """
+
+ (proj_root / "rxconfig.py").write_text(dedent(config_string))
+
+ with chdir(proj_root):
+ # reload config for each parameter to avoid stale values
+ reflex.config.get_config(reload=True)
+ from reflex.state import State, StateManager
+
+ with pytest.raises(InvalidLockWarningThresholdError):
+ StateManager.create(state=State)
+ del sys.modules[constants.Config.MODULE]
class MixinState(State, mixin=True):
@@ -3262,3 +3447,320 @@ def test_child_mixin_state() -> None:
assert "computed" in ChildUsesMixinState.inherited_vars
assert "computed" not in ChildUsesMixinState.computed_vars
+
+
+def test_assignment_to_undeclared_vars():
+ """Test that an attribute error is thrown when undeclared vars are set."""
+
+ class State(BaseState):
+ val: str
+ _val: str
+ __val: str # type: ignore
+
+ def handle_supported_regular_vars(self):
+ self.val = "no underscore"
+ self._val = "single leading underscore"
+ self.__val = "double leading undercore"
+
+ def handle_regular_var(self):
+ self.num = 5
+
+ def handle_backend_var(self):
+ self._num = 5
+
+ def handle_non_var(self):
+ self.__num = 5
+
+ class Substate(State):
+ def handle_var(self):
+ self.value = 20
+
+ state = State() # type: ignore
+ sub_state = Substate() # type: ignore
+
+ with pytest.raises(SetUndefinedStateVarError):
+ state.handle_regular_var()
+
+ with pytest.raises(SetUndefinedStateVarError):
+ sub_state.handle_var()
+
+ with pytest.raises(SetUndefinedStateVarError):
+ state.handle_backend_var()
+
+ state.handle_supported_regular_vars()
+ state.handle_non_var()
+
+
+@pytest.mark.asyncio
+async def test_deserialize_gc_state_disk(token):
+ """Test that a state can be deserialized from disk with a grandchild state.
+
+ Args:
+ token: A token.
+ """
+
+ class Root(BaseState):
+ pass
+
+ class State(Root):
+ num: int = 42
+
+ class Child(State):
+ foo: str = "bar"
+
+ dsm = StateManagerDisk(state=Root)
+ async with dsm.modify_state(token) as root:
+ s = await root.get_state(State)
+ s.num += 1
+ c = await root.get_state(Child)
+ assert s._get_was_touched()
+ assert not c._get_was_touched()
+
+ dsm2 = StateManagerDisk(state=Root)
+ root = await dsm2.get_state(token)
+ s = await root.get_state(State)
+ assert s.num == 43
+ c = await root.get_state(Child)
+ assert c.foo == "bar"
+
+
+class Obj(Base):
+ """A object containing a callable for testing fallback pickle."""
+
+ _f: Callable
+
+
+def test_fallback_pickle():
+ """Test that state serialization will fall back to dill."""
+
+ class DillState(BaseState):
+ _o: Optional[Obj] = None
+ _f: Optional[Callable] = None
+ _g: Any = None
+
+ state = DillState(_reflex_internal_init=True) # type: ignore
+ state._o = Obj(_f=lambda: 42)
+ state._f = lambda: 420
+
+ pk = state._serialize()
+
+ unpickled_state = BaseState._deserialize(pk)
+ assert unpickled_state._f() == 420
+ assert unpickled_state._o._f() == 42
+
+ # Threading locks are unpicklable normally, and raise TypeError instead of PicklingError.
+ state2 = DillState(_reflex_internal_init=True) # type: ignore
+ state2._g = threading.Lock()
+ pk2 = state2._serialize()
+ unpickled_state2 = BaseState._deserialize(pk2)
+ assert isinstance(unpickled_state2._g, type(threading.Lock()))
+
+ # Some object, like generator, are still unpicklable with dill.
+ state3 = DillState(_reflex_internal_init=True) # type: ignore
+ state3._g = (i for i in range(10))
+
+ with pytest.raises(StateSerializationError):
+ _ = state3._serialize()
+
+
+def test_typed_state() -> None:
+ class TypedState(rx.State):
+ field: rx.Field[str] = rx.field("")
+
+ _ = TypedState(field="str")
+
+
+class ModelV1(BaseModelV1):
+ """A pydantic BaseModel v1."""
+
+ foo: str = "bar"
+
+
+class ModelV2(BaseModelV2):
+ """A pydantic BaseModel v2."""
+
+ foo: str = "bar"
+
+
+@dataclasses.dataclass
+class ModelDC:
+ """A dataclass."""
+
+ foo: str = "bar"
+
+
+class PydanticState(rx.State):
+ """A state with pydantic BaseModel vars."""
+
+ v1: ModelV1 = ModelV1()
+ v2: ModelV2 = ModelV2()
+ dc: ModelDC = ModelDC()
+
+
+def test_mutable_models():
+ """Test that dataclass and pydantic BaseModel v1 and v2 use dep tracking."""
+ state = PydanticState()
+ assert isinstance(state.v1, MutableProxy)
+ state.v1.foo = "baz"
+ assert state.dirty_vars == {"v1"}
+ state.dirty_vars.clear()
+
+ assert isinstance(state.v2, MutableProxy)
+ state.v2.foo = "baz"
+ assert state.dirty_vars == {"v2"}
+ state.dirty_vars.clear()
+
+ # Not yet supported ENG-4083
+ # assert isinstance(state.dc, MutableProxy) #noqa: ERA001
+ # state.dc.foo = "baz" #noqa: ERA001
+ # assert state.dirty_vars == {"dc"} #noqa: ERA001
+ # state.dirty_vars.clear() #noqa: ERA001
+
+
+def test_get_value():
+ class GetValueState(rx.State):
+ foo: str = "FOO"
+ bar: str = "BAR"
+
+ state = GetValueState()
+
+ assert state.dict() == {
+ state.get_full_name(): {
+ "foo": "FOO",
+ "bar": "BAR",
+ }
+ }
+ assert state.get_delta() == {}
+
+ state.bar = "foo"
+
+ assert state.dict() == {
+ state.get_full_name(): {
+ "foo": "FOO",
+ "bar": "foo",
+ }
+ }
+ assert state.get_delta() == {
+ state.get_full_name(): {
+ "bar": "foo",
+ }
+ }
+
+
+def test_init_mixin() -> None:
+ """Ensure that State mixins can not be instantiated directly."""
+
+ class Mixin(BaseState, mixin=True):
+ pass
+
+ with pytest.raises(ReflexRuntimeError):
+ Mixin()
+
+ class SubMixin(Mixin, mixin=True):
+ pass
+
+ with pytest.raises(ReflexRuntimeError):
+ SubMixin()
+
+
+class ReflexModel(rx.Model):
+ """A model for testing."""
+
+ foo: str
+
+
+class UpcastState(rx.State):
+ """A state for testing upcasting."""
+
+ passed: bool = False
+
+ def rx_model(self, m: ReflexModel): # noqa: D102
+ assert isinstance(m, ReflexModel)
+ self.passed = True
+
+ def rx_base(self, o: Object): # noqa: D102
+ assert isinstance(o, Object)
+ self.passed = True
+
+ def rx_base_or_none(self, o: Optional[Object]): # noqa: D102
+ if o is not None:
+ assert isinstance(o, Object)
+ self.passed = True
+
+ def rx_basemodelv1(self, m: ModelV1): # noqa: D102
+ assert isinstance(m, ModelV1)
+ self.passed = True
+
+ def rx_basemodelv2(self, m: ModelV2): # noqa: D102
+ assert isinstance(m, ModelV2)
+ self.passed = True
+
+ def rx_dataclass(self, dc: ModelDC): # noqa: D102
+ assert isinstance(dc, ModelDC)
+ self.passed = True
+
+ def py_set(self, s: set): # noqa: D102
+ assert isinstance(s, set)
+ self.passed = True
+
+ def py_Set(self, s: Set): # noqa: D102
+ assert isinstance(s, Set)
+ self.passed = True
+
+ def py_tuple(self, t: tuple): # noqa: D102
+ assert isinstance(t, tuple)
+ self.passed = True
+
+ def py_Tuple(self, t: Tuple): # noqa: D102
+ assert isinstance(t, tuple)
+ self.passed = True
+
+ def py_dict(self, d: dict[str, str]): # noqa: D102
+ assert isinstance(d, dict)
+ self.passed = True
+
+ def py_list(self, ls: list[str]): # noqa: D102
+ assert isinstance(ls, list)
+ self.passed = True
+
+ def py_Any(self, a: Any): # noqa: D102
+ assert isinstance(a, list)
+ self.passed = True
+
+ def py_unresolvable(self, u: "Unresolvable"): # noqa: D102, F821 # type: ignore
+ assert isinstance(u, list)
+ self.passed = True
+
+
+@pytest.mark.asyncio
+@pytest.mark.usefixtures("mock_app_simple")
+@pytest.mark.parametrize(
+ ("handler", "payload"),
+ [
+ (UpcastState.rx_model, {"m": {"foo": "bar"}}),
+ (UpcastState.rx_base, {"o": {"foo": "bar"}}),
+ (UpcastState.rx_base_or_none, {"o": {"foo": "bar"}}),
+ (UpcastState.rx_base_or_none, {"o": None}),
+ (UpcastState.rx_basemodelv1, {"m": {"foo": "bar"}}),
+ (UpcastState.rx_basemodelv2, {"m": {"foo": "bar"}}),
+ (UpcastState.rx_dataclass, {"dc": {"foo": "bar"}}),
+ (UpcastState.py_set, {"s": ["foo", "foo"]}),
+ (UpcastState.py_Set, {"s": ["foo", "foo"]}),
+ (UpcastState.py_tuple, {"t": ["foo", "foo"]}),
+ (UpcastState.py_Tuple, {"t": ["foo", "foo"]}),
+ (UpcastState.py_dict, {"d": {"foo": "bar"}}),
+ (UpcastState.py_list, {"ls": ["foo", "foo"]}),
+ (UpcastState.py_Any, {"a": ["foo"]}),
+ (UpcastState.py_unresolvable, {"u": ["foo"]}),
+ ],
+)
+async def test_upcast_event_handler_arg(handler, payload):
+ """Test that upcast event handler args work correctly.
+
+ Args:
+ handler: The handler to test.
+ payload: The payload to test.
+ """
+ state = UpcastState()
+ async for update in state._process_event(handler, state, payload):
+ assert update.delta == {UpcastState.get_full_name(): {"passed": True}}
diff --git a/tests/units/test_state_tree.py b/tests/units/test_state_tree.py
index 7c1e13a91..ebdd877de 100644
--- a/tests/units/test_state_tree.py
+++ b/tests/units/test_state_tree.py
@@ -1,9 +1,9 @@
"""Specialized test for a larger state tree."""
-import asyncio
-from typing import Generator
+from typing import AsyncGenerator
import pytest
+import pytest_asyncio
import reflex as rx
from reflex.state import BaseState, StateManager, StateManagerRedis, _substate_key
@@ -210,8 +210,10 @@ ALWAYS_COMPUTED_DICT_KEYS = [
]
-@pytest.fixture(scope="function")
-def state_manager_redis(app_module_mock) -> Generator[StateManager, None, None]:
+@pytest_asyncio.fixture(loop_scope="function", scope="function")
+async def state_manager_redis(
+ app_module_mock,
+) -> AsyncGenerator[StateManager, None]:
"""Instance of state manager for redis only.
Args:
@@ -228,7 +230,7 @@ def state_manager_redis(app_module_mock) -> Generator[StateManager, None, None]:
yield state_manager
- asyncio.get_event_loop().run_until_complete(state_manager.close())
+ await state_manager.close()
@pytest.mark.asyncio
diff --git a/tests/units/test_telemetry.py b/tests/units/test_telemetry.py
index a434779d4..d8a77dfd6 100644
--- a/tests/units/test_telemetry.py
+++ b/tests/units/test_telemetry.py
@@ -34,12 +34,6 @@ def test_disable():
@pytest.mark.parametrize("event", ["init", "reinit", "run-dev", "run-prod", "export"])
def test_send(mocker, event):
httpx_post_mock = mocker.patch("httpx.post")
- # mocker.patch(
- # "builtins.open",
- # mocker.mock_open(
- # read_data='{"project_hash": "78285505863498957834586115958872998605"}'
- # ),
- # )
# Mock the read_text method of Path
pathlib_path_read_text_mock = mocker.patch(
@@ -52,4 +46,4 @@ def test_send(mocker, event):
telemetry._send(event, telemetry_enabled=True)
httpx_post_mock.assert_called_once()
- pathlib_path_read_text_mock.assert_called_once()
+ assert pathlib_path_read_text_mock.call_count == 2
diff --git a/tests/units/test_testing.py b/tests/units/test_testing.py
index b01709202..83a03ad83 100644
--- a/tests/units/test_testing.py
+++ b/tests/units/test_testing.py
@@ -29,7 +29,7 @@ def test_app_harness(tmp_path):
with AppHarness.create(
root=tmp_path,
- app_source=BasicApp, # type: ignore
+ app_source=BasicApp,
) as harness:
assert harness.app_instance is not None
assert harness.backend is not None
diff --git a/tests/units/test_var.py b/tests/units/test_var.py
index 227b01d85..048752d11 100644
--- a/tests/units/test_var.py
+++ b/tests/units/test_var.py
@@ -1,5 +1,6 @@
import json
import math
+import sys
import typing
from typing import Dict, List, Optional, Set, Tuple, Union, cast
@@ -21,12 +22,12 @@ from reflex.vars.base import (
var_operation,
var_operation_return,
)
-from reflex.vars.function import ArgsFunctionOperation, FunctionStringVar
-from reflex.vars.number import (
- LiteralBooleanVar,
- LiteralNumberVar,
- NumberVar,
+from reflex.vars.function import (
+ ArgsFunctionOperation,
+ DestructuredArg,
+ FunctionStringVar,
)
+from reflex.vars.number import LiteralBooleanVar, LiteralNumberVar, NumberVar
from reflex.vars.object import LiteralObjectVar, ObjectVar
from reflex.vars.sequence import (
ArrayVar,
@@ -214,7 +215,7 @@ def test_str(prop, expected):
@pytest.mark.parametrize(
- "prop,expected",
+ ("prop", "expected"),
[
(Var(_js_expr="p", _var_type=int), 0),
(Var(_js_expr="p", _var_type=float), 0.0),
@@ -226,14 +227,14 @@ def test_str(prop, expected):
(Var(_js_expr="p", _var_type=set), set()),
],
)
-def test_default_value(prop, expected):
+def test_default_value(prop: Var, expected):
"""Test that the default value of a var is correct.
Args:
prop: The var to test.
expected: The expected default value.
"""
- assert prop.get_default_value() == expected
+ assert prop._get_default_value() == expected
@pytest.mark.parametrize(
@@ -249,14 +250,14 @@ def test_default_value(prop, expected):
],
),
)
-def test_get_setter(prop, expected):
+def test_get_setter(prop: Var, expected):
"""Test that the name of the setter function of a var is correct.
Args:
prop: The var to test.
expected: The expected name of the setter function.
"""
- assert prop.get_setter_name() == expected
+ assert prop._get_setter_name() == expected
@pytest.mark.parametrize(
@@ -398,6 +399,44 @@ def test_list_tuple_contains(var, expected):
assert str(var.contains(other_var)) == f"{expected}.includes(other)"
+class Foo(rx.Base):
+ """Foo class."""
+
+ bar: int
+ baz: str
+
+
+class Bar(rx.Base):
+ """Bar class."""
+
+ bar: str
+ baz: str
+ foo: int
+
+
+@pytest.mark.parametrize(
+ ("var", "var_type"),
+ (
+ [
+ (Var(_js_expr="", _var_type=Foo | Bar).guess_type(), Foo | Bar),
+ (Var(_js_expr="", _var_type=Foo | Bar).guess_type().bar, Union[int, str]),
+ ]
+ if sys.version_info >= (3, 10)
+ else []
+ )
+ + [
+ (Var(_js_expr="", _var_type=Union[Foo, Bar]).guess_type(), Union[Foo, Bar]),
+ (Var(_js_expr="", _var_type=Union[Foo, Bar]).guess_type().baz, str),
+ (
+ Var(_js_expr="", _var_type=Union[Foo, Bar]).guess_type().foo,
+ Union[int, None],
+ ),
+ ],
+)
+def test_var_types(var, var_type):
+ assert var._var_type == var_type
+
+
@pytest.mark.parametrize(
"var, expected",
[
@@ -480,8 +519,8 @@ def test_var_indexing_types(var, type_):
type_ : The type on indexed object.
"""
- assert var[2]._var_type == type_[0]
- assert var[3]._var_type == type_[1]
+ assert var[0]._var_type == type_[0]
+ assert var[1]._var_type == type_[1]
def test_var_indexing_str():
@@ -490,7 +529,7 @@ def test_var_indexing_str():
# Test that indexing gives a type of Var[str].
assert isinstance(str_var[0], Var)
- assert str_var[0]._var_type == str
+ assert str_var[0]._var_type is str
# Test basic indexing.
assert str(str_var[0]) == "str.at(0)"
@@ -623,7 +662,7 @@ def test_str_var_slicing():
# Test that slicing gives a type of Var[str].
assert isinstance(str_var[:1], Var)
- assert str_var[:1]._var_type == str
+ assert str_var[:1]._var_type is str
# Test basic slicing.
assert str(str_var[:1]) == 'str.split("").slice(undefined, 1).join("")'
@@ -886,13 +925,13 @@ def test_function_var():
)
assert (
str(manual_addition_func.call(1, 2))
- == '(((a, b) => (({ ["args"] : [a, b], ["result"] : a + b })))(1, 2))'
+ == '(((a, b) => ({ ["args"] : [a, b], ["result"] : a + b }))(1, 2))'
)
- increment_func = addition_func(1)
+ increment_func = addition_func.partial(1)
assert (
str(increment_func.call(2))
- == "(((...args) => ((((a, b) => a + b)(1, ...args))))(2))"
+ == "(((...args) => (((a, b) => a + b)(1, ...args)))(2))"
)
create_hello_statement = ArgsFunctionOperation.create(
@@ -902,9 +941,25 @@ def test_function_var():
last_name = LiteralStringVar.create("Universe")
assert (
str(create_hello_statement.call(f"{first_name} {last_name}"))
- == '(((name) => (("Hello, "+name+"!")))("Steven Universe"))'
+ == '(((name) => ("Hello, "+name+"!"))("Steven Universe"))'
)
+ # Test with destructured arguments
+ destructured_func = ArgsFunctionOperation.create(
+ (DestructuredArg(fields=("a", "b")),),
+ Var(_js_expr="a + b"),
+ )
+ assert (
+ str(destructured_func.call({"a": 1, "b": 2}))
+ == '((({a, b}) => a + b)(({ ["a"] : 1, ["b"] : 2 })))'
+ )
+
+ # Test with explicit return
+ explicit_return_func = ArgsFunctionOperation.create(
+ ("a", "b"), Var(_js_expr="return a + b"), explicit_return=True
+ )
+ assert str(explicit_return_func.call(1, 2)) == "(((a, b) => {return a + b})(1, 2))"
+
def test_var_operation():
@var_operation
@@ -1267,7 +1322,6 @@ def test_fstring_roundtrip(value):
Var(_js_expr="var", _var_type=float).guess_type(),
Var(_js_expr="var", _var_type=str).guess_type(),
Var(_js_expr="var", _var_type=bool).guess_type(),
- Var(_js_expr="var", _var_type=dict).guess_type(),
Var(_js_expr="var", _var_type=None).guess_type(),
],
)
@@ -1279,7 +1333,7 @@ def test_unsupported_types_for_reverse(var):
"""
with pytest.raises(TypeError) as err:
var.reverse()
- assert err.value.args[0] == f"Cannot reverse non-list var."
+ assert err.value.args[0] == "Cannot reverse non-list var."
@pytest.mark.parametrize(
@@ -1288,10 +1342,10 @@ def test_unsupported_types_for_reverse(var):
Var(_js_expr="var", _var_type=int).guess_type(),
Var(_js_expr="var", _var_type=float).guess_type(),
Var(_js_expr="var", _var_type=bool).guess_type(),
- Var(_js_expr="var", _var_type=None).guess_type(),
+ Var(_js_expr="var", _var_type=type(None)).guess_type(),
],
)
-def test_unsupported_types_for_contains(var):
+def test_unsupported_types_for_contains(var: Var):
"""Test that unsupported types for contains throw a type error.
Args:
@@ -1441,8 +1495,6 @@ def test_valid_var_operations(operand1_var: Var, operand2_var, operators: List[s
)
eval(f"operand1_var {operator} operand2_var")
eval(f"operand2_var {operator} operand1_var")
- # operand1_var.operation(op=operator, other=operand2_var)
- # operand1_var.operation(op=operator, other=operand2_var, flip=True)
@pytest.mark.parametrize(
@@ -1716,14 +1768,12 @@ def test_valid_var_operations(operand1_var: Var, operand2_var, operators: List[s
)
def test_invalid_var_operations(operand1_var: Var, operand2_var, operators: List[str]):
for operator in operators:
- print(f"testing {operator} on {str(operand1_var)} and {str(operand2_var)}")
+ print(f"testing {operator} on {operand1_var!s} and {operand2_var!s}")
with pytest.raises(TypeError):
print(eval(f"operand1_var {operator} operand2_var"))
- # operand1_var.operation(op=operator, other=operand2_var)
with pytest.raises(TypeError):
print(eval(f"operand2_var {operator} operand1_var"))
- # operand1_var.operation(op=operator, other=operand2_var, flip=True)
@pytest.mark.parametrize(
@@ -1809,3 +1859,6 @@ def test_to_string_operation():
assert cast(Var, TestState.email)._var_type == Email
assert cast(Var, TestState.optional_email)._var_type == Optional[Email]
+
+ single_var = Var.create(Email())
+ assert single_var._var_type == Email
diff --git a/tests/units/utils/test_format.py b/tests/units/utils/test_format.py
index 042c3f323..cd1d0179d 100644
--- a/tests/units/utils/test_format.py
+++ b/tests/units/utils/test_format.py
@@ -8,7 +8,7 @@ import plotly.graph_objects as go
import pytest
from reflex.components.tags.tag import Tag
-from reflex.event import EventChain, EventHandler, EventSpec, FrontendEvent
+from reflex.event import EventChain, EventHandler, EventSpec, JavascriptInputEvent
from reflex.style import Style
from reflex.utils import format
from reflex.utils.serializers import serialize_figure
@@ -374,7 +374,7 @@ def test_format_match(
events=[EventSpec(handler=EventHandler(fn=mock_event))],
args_spec=lambda: [],
),
- '((...args) => ((addEvents([(Event("mock_event", ({ })))], args, ({ })))))',
+ '((...args) => (addEvents([(Event("mock_event", ({ }), ({ })))], args, ({ }))))',
),
(
EventChain(
@@ -387,7 +387,7 @@ def test_format_match(
Var(
_js_expr="_e",
)
- .to(ObjectVar, FrontendEvent)
+ .to(ObjectVar, JavascriptInputEvent)
.target.value,
),
),
@@ -395,7 +395,7 @@ def test_format_match(
],
args_spec=lambda e: [e.target.value],
),
- '((_e) => ((addEvents([(Event("mock_event", ({ ["arg"] : _e["target"]["value"] })))], [_e], ({ })))))',
+ '((_e) => (addEvents([(Event("mock_event", ({ ["arg"] : _e["target"]["value"] }), ({ })))], [_e], ({ }))))',
),
(
EventChain(
@@ -403,7 +403,19 @@ def test_format_match(
args_spec=lambda: [],
event_actions={"stopPropagation": True},
),
- '((...args) => ((addEvents([(Event("mock_event", ({ })))], args, ({ ["stopPropagation"] : true })))))',
+ '((...args) => (addEvents([(Event("mock_event", ({ }), ({ })))], args, ({ ["stopPropagation"] : true }))))',
+ ),
+ (
+ EventChain(
+ events=[
+ EventSpec(
+ handler=EventHandler(fn=mock_event),
+ event_actions={"stopPropagation": True},
+ )
+ ],
+ args_spec=lambda: [],
+ ),
+ '((...args) => (addEvents([(Event("mock_event", ({ }), ({ ["stopPropagation"] : true })))], args, ({ }))))',
),
(
EventChain(
@@ -411,7 +423,7 @@ def test_format_match(
args_spec=lambda: [],
event_actions={"preventDefault": True},
),
- '((...args) => ((addEvents([(Event("mock_event", ({ })))], args, ({ ["preventDefault"] : true })))))',
+ '((...args) => (addEvents([(Event("mock_event", ({ }), ({ })))], args, ({ ["preventDefault"] : true }))))',
),
({"a": "red", "b": "blue"}, '({ ["a"] : "red", ["b"] : "blue" })'),
(Var(_js_expr="var", _var_type=int).guess_type(), "var"),
@@ -519,7 +531,7 @@ def test_format_event_handler(input, output):
[
(
EventSpec(handler=EventHandler(fn=mock_event)),
- '(Event("mock_event", ({ })))',
+ '(Event("mock_event", ({ }), ({ })))',
),
],
)
@@ -589,6 +601,7 @@ formatted_router = {
"sum": 3.14,
"upper": "",
"router": formatted_router,
+ "asynctest": 0,
},
ChildState.get_full_name(): {
"count": 23,
diff --git a/tests/units/utils/test_serializers.py b/tests/units/utils/test_serializers.py
index 630187309..355f40d3f 100644
--- a/tests/units/utils/test_serializers.py
+++ b/tests/units/utils/test_serializers.py
@@ -20,7 +20,6 @@ from reflex.vars.base import LiteralVar
def test_has_serializer(type_: Type, expected: bool):
"""Test that has_serializer returns the correct value.
-
Args:
type_: The type to check.
expected: The expected result.
@@ -41,7 +40,6 @@ def test_has_serializer(type_: Type, expected: bool):
def test_get_serializer(type_: Type, expected: serializers.Serializer):
"""Test that get_serializer returns the correct value.
-
Args:
type_: The type to check.
expected: The expected result.
@@ -96,7 +94,7 @@ class StrEnum(str, Enum):
BAR = "bar"
-class TestEnum(Enum):
+class FooBarEnum(Enum):
"""A lone enum class."""
FOO = "foo"
@@ -151,10 +149,10 @@ class BaseSubclass(Base):
"key2": "prefix_bar",
},
),
- (TestEnum.FOO, "foo"),
- ([TestEnum.FOO, TestEnum.BAR], ["foo", "bar"]),
+ (FooBarEnum.FOO, "foo"),
+ ([FooBarEnum.FOO, FooBarEnum.BAR], ["foo", "bar"]),
(
- {"key1": TestEnum.FOO, "key2": TestEnum.BAR},
+ {"key1": FooBarEnum.FOO, "key2": FooBarEnum.BAR},
{
"key1": "foo",
"key2": "bar",
@@ -195,7 +193,6 @@ class BaseSubclass(Base):
def test_serialize(value: Any, expected: str):
"""Test that serialize returns the correct value.
-
Args:
value: The value to serialize.
expected: The expected result.
diff --git a/tests/units/utils/test_types.py b/tests/units/utils/test_types.py
index fc9261e04..87790e979 100644
--- a/tests/units/utils/test_types.py
+++ b/tests/units/utils/test_types.py
@@ -1,4 +1,4 @@
-from typing import Any, List, Literal, Tuple, Union
+from typing import Any, Dict, List, Literal, Tuple, Union
import pytest
@@ -17,7 +17,7 @@ def test_validate_literal_error_msg(params, allowed_value_str, value_str):
types.validate_literal(*params)
assert (
- err.value.args[0] == f"prop value for {str(params[0])} of the `{params[-1]}` "
+ err.value.args[0] == f"prop value for {params[0]!s} of the `{params[-1]}` "
f"component should be one of the following: {allowed_value_str}. Got {value_str} instead"
)
@@ -45,3 +45,48 @@ def test_issubclass(
cls: types.GenericType, cls_check: types.GenericType, expected: bool
) -> None:
assert types._issubclass(cls, cls_check) == expected
+
+
+class CustomDict(dict[str, str]):
+ """A custom dict with generic arguments."""
+
+ pass
+
+
+class ChildCustomDict(CustomDict):
+ """A child of CustomDict."""
+
+ pass
+
+
+class GenericDict(dict):
+ """A generic dict with no generic arguments."""
+
+ pass
+
+
+class ChildGenericDict(GenericDict):
+ """A child of GenericDict."""
+
+ pass
+
+
+@pytest.mark.parametrize(
+ "cls,expected",
+ [
+ (int, False),
+ (str, False),
+ (float, False),
+ (Tuple[int], True),
+ (List[int], True),
+ (Union[int, str], True),
+ (Union[str, int], True),
+ (Dict[str, int], True),
+ (CustomDict, True),
+ (ChildCustomDict, True),
+ (GenericDict, False),
+ (ChildGenericDict, False),
+ ],
+)
+def test_has_args(cls, expected: bool) -> None:
+ assert types.has_args(cls) == expected
diff --git a/tests/units/utils/test_utils.py b/tests/units/utils/test_utils.py
index 5cdd846fe..20bad4146 100644
--- a/tests/units/utils/test_utils.py
+++ b/tests/units/utils/test_utils.py
@@ -2,7 +2,7 @@ import os
import typing
from functools import cached_property
from pathlib import Path
-from typing import Any, ClassVar, List, Literal, Type, Union
+from typing import Any, ClassVar, Dict, List, Literal, Type, Union
import pytest
import typer
@@ -10,15 +10,12 @@ from packaging import version
from reflex import constants
from reflex.base import Base
+from reflex.config import environment
from reflex.event import EventHandler
from reflex.state import BaseState
-from reflex.utils import (
- build,
- prerequisites,
- types,
-)
+from reflex.utils import build, prerequisites, types
from reflex.utils import exec as utils_exec
-from reflex.utils.exceptions import ReflexError
+from reflex.utils.exceptions import ReflexError, SystemPackageMissingError
from reflex.vars.base import Var
@@ -77,6 +74,47 @@ def test_is_generic_alias(cls: type, expected: bool):
assert types.is_generic_alias(cls) == expected
+@pytest.mark.parametrize(
+ ("subclass", "superclass", "expected"),
+ [
+ *[
+ (base_type, base_type, True)
+ for base_type in [int, float, str, bool, list, dict]
+ ],
+ *[
+ (one_type, another_type, False)
+ for one_type in [int, float, str, list, dict]
+ for another_type in [int, float, str, list, dict]
+ if one_type != another_type
+ ],
+ (bool, int, True),
+ (int, bool, False),
+ (list, List, True),
+ (list, List[str], True), # this is wrong, but it's a limitation of the function
+ (List, list, True),
+ (List[int], list, True),
+ (List[int], List, True),
+ (List[int], List[str], False),
+ (List[int], List[int], True),
+ (List[int], List[float], False),
+ (List[int], List[Union[int, float]], True),
+ (List[int], List[Union[float, str]], False),
+ (Union[int, float], List[Union[int, float]], False),
+ (Union[int, float], Union[int, float, str], True),
+ (Union[int, float], Union[str, float], False),
+ (Dict[str, int], Dict[str, int], True),
+ (Dict[str, bool], Dict[str, int], True),
+ (Dict[str, int], Dict[str, bool], False),
+ (Dict[str, Any], dict[str, str], False),
+ (Dict[str, str], dict[str, str], True),
+ (Dict[str, str], dict[str, Any], True),
+ (Dict[str, Any], dict[str, Any], True),
+ ],
+)
+def test_typehint_issubclass(subclass, superclass, expected):
+ assert types.typehint_issubclass(subclass, superclass) == expected
+
+
def test_validate_invalid_bun_path(mocker):
"""Test that an error is thrown when a custom specified bun path is not valid
or does not exist.
@@ -117,7 +155,7 @@ def test_remove_existing_bun_installation(mocker):
Args:
mocker: Pytest mocker.
"""
- mocker.patch("reflex.utils.prerequisites.os.path.exists", return_value=True)
+ mocker.patch("reflex.utils.prerequisites.Path.exists", return_value=True)
rm = mocker.patch("reflex.utils.prerequisites.path_ops.rm", mocker.Mock())
prerequisites.remove_existing_bun_installation()
@@ -260,7 +298,7 @@ def tmp_working_dir(tmp_path):
Yields:
subdirectory of tmp_path which is now the current working directory.
"""
- old_pwd = Path(".").resolve()
+ old_pwd = Path.cwd()
working_dir = tmp_path / "working_dir"
working_dir.mkdir()
os.chdir(working_dir)
@@ -458,10 +496,10 @@ def test_bun_install_without_unzip(mocker):
mocker: Pytest mocker object.
"""
mocker.patch("reflex.utils.path_ops.which", return_value=None)
- mocker.patch("os.path.exists", return_value=False)
+ mocker.patch("pathlib.Path.exists", return_value=False)
mocker.patch("reflex.utils.prerequisites.constants.IS_WINDOWS", False)
- with pytest.raises(FileNotFoundError):
+ with pytest.raises(SystemPackageMissingError):
prerequisites.install_bun()
@@ -476,7 +514,7 @@ def test_bun_install_version(mocker, bun_version):
"""
mocker.patch("reflex.utils.prerequisites.constants.IS_WINDOWS", False)
- mocker.patch("os.path.exists", return_value=True)
+ mocker.patch("pathlib.Path.exists", return_value=True)
mocker.patch(
"reflex.utils.prerequisites.get_bun_version",
return_value=version.parse(bun_version),
@@ -542,7 +580,9 @@ def test_style_prop_with_event_handler_value(callable):
style = {
"color": (
- EventHandler(fn=callable) if type(callable) != EventHandler else callable
+ EventHandler(fn=callable)
+ if type(callable) is not EventHandler
+ else callable
)
}
@@ -550,3 +590,11 @@ def test_style_prop_with_event_handler_value(callable):
rx.box(
style=style, # type: ignore
)
+
+
+def test_is_prod_mode() -> None:
+ """Test that the prod mode is correctly determined."""
+ environment.REFLEX_ENV_MODE.set(constants.Env.PROD)
+ assert utils_exec.is_prod_mode()
+ environment.REFLEX_ENV_MODE.set(None)
+ assert not utils_exec.is_prod_mode()
diff --git a/tests/units/vars/test_base.py b/tests/units/vars/test_base.py
index f83d79373..68bc0c38e 100644
--- a/tests/units/vars/test_base.py
+++ b/tests/units/vars/test_base.py
@@ -5,6 +5,30 @@ import pytest
from reflex.vars.base import figure_out_type
+class CustomDict(dict[str, str]):
+ """A custom dict with generic arguments."""
+
+ pass
+
+
+class ChildCustomDict(CustomDict):
+ """A child of CustomDict."""
+
+ pass
+
+
+class GenericDict(dict):
+ """A generic dict with no generic arguments."""
+
+ pass
+
+
+class ChildGenericDict(GenericDict):
+ """A child of GenericDict."""
+
+ pass
+
+
@pytest.mark.parametrize(
("value", "expected"),
[
@@ -15,6 +39,10 @@ from reflex.vars.base import figure_out_type
([1, 2.0, "a"], List[Union[int, float, str]]),
({"a": 1, "b": 2}, Dict[str, int]),
({"a": 1, 2: "b"}, Dict[Union[int, str], Union[str, int]]),
+ (CustomDict(), CustomDict),
+ (ChildCustomDict(), ChildCustomDict),
+ (GenericDict({1: 1}), Dict[int, int]),
+ (ChildGenericDict({1: 1}), Dict[int, int]),
],
)
def test_figure_out_type(value, expected):
diff --git a/tests/units/vars/test_object.py b/tests/units/vars/test_object.py
new file mode 100644
index 000000000..efcb21166
--- /dev/null
+++ b/tests/units/vars/test_object.py
@@ -0,0 +1,102 @@
+import pytest
+from typing_extensions import assert_type
+
+import reflex as rx
+from reflex.utils.types import GenericType
+from reflex.vars.base import Var
+from reflex.vars.object import LiteralObjectVar, ObjectVar
+
+
+class Bare:
+ """A bare class with a single attribute."""
+
+ quantity: int = 0
+
+
+@rx.serializer
+def serialize_bare(obj: Bare) -> dict:
+ """A serializer for the bare class.
+
+ Args:
+ obj: The object to serialize.
+
+ Returns:
+ A dictionary with the quantity attribute.
+ """
+ return {"quantity": obj.quantity}
+
+
+class Base(rx.Base):
+ """A reflex base class with a single attribute."""
+
+ quantity: int = 0
+
+
+class ObjectState(rx.State):
+ """A reflex state with bare and base objects."""
+
+ bare: rx.Field[Bare] = rx.field(Bare())
+ base: rx.Field[Base] = rx.field(Base())
+
+
+@pytest.mark.parametrize("type_", [Base, Bare])
+def test_var_create(type_: GenericType) -> None:
+ my_object = type_()
+ var = Var.create(my_object)
+ assert var._var_type is type_
+
+ quantity = var.quantity
+ assert quantity._var_type is int
+
+
+@pytest.mark.parametrize("type_", [Base, Bare])
+def test_literal_create(type_: GenericType) -> None:
+ my_object = type_()
+ var = LiteralObjectVar.create(my_object)
+ assert var._var_type is type_
+
+ quantity = var.quantity
+ assert quantity._var_type is int
+
+
+@pytest.mark.parametrize("type_", [Base, Bare])
+def test_guess(type_: GenericType) -> None:
+ my_object = type_()
+ var = Var.create(my_object)
+ var = var.guess_type()
+ assert var._var_type is type_
+
+ quantity = var.quantity
+ assert quantity._var_type is int
+
+
+@pytest.mark.parametrize("type_", [Base, Bare])
+def test_state(type_: GenericType) -> None:
+ attr_name = type_.__name__.lower()
+ var = getattr(ObjectState, attr_name)
+ assert var._var_type is type_
+
+ quantity = var.quantity
+ assert quantity._var_type is int
+
+
+@pytest.mark.parametrize("type_", [Base, Bare])
+def test_state_to_operation(type_: GenericType) -> None:
+ attr_name = type_.__name__.lower()
+ original_var = getattr(ObjectState, attr_name)
+
+ var = original_var.to(ObjectVar, type_)
+ assert var._var_type is type_
+
+ var = original_var.to(ObjectVar)
+ assert var._var_type is type_
+
+
+def test_typing() -> None:
+ # Bare
+ var = ObjectState.bare.to(ObjectVar)
+ _ = assert_type(var, ObjectVar[Bare])
+
+ # Base
+ var = ObjectState.base
+ _ = assert_type(var, ObjectVar[Base])