rx.match component (#2318)
* initial commit * add more tests * refactor match jinja template * add docstrings * cleanup * possible fix for pyright * fix conflicts * fix conflicts again * comments * fixed bug from review * fix tests * address PR comment * fix tests * type error improvement * formatting * darglint fix * more tests * stringify switch condition and cases as js doesnt support complex types(lists and dicts) in switch cases. * Update reflex/vars.py Co-authored-by: Masen Furer <m_github@0x26.net> * change usages * Precommit fix --------- Co-authored-by: Alek Petuskey <alek@pynecone.io> Co-authored-by: Masen Furer <m_github@0x26.net> Co-authored-by: Alek Petuskey <alekpetuskey@aleks-mbp.lan>
This commit is contained in:
parent
32b9b00e05
commit
abfc099779
@ -8,6 +8,8 @@
|
||||
{{- component }}
|
||||
{%- elif "iterable" in component %}
|
||||
{{- render_iterable_tag(component) }}
|
||||
{%- elif component.name == "match"%}
|
||||
{{- render_match_tag(component) }}
|
||||
{%- elif "cond" in component %}
|
||||
{{- render_condition_tag(component) }}
|
||||
{%- elif component.children|length %}
|
||||
@ -77,6 +79,28 @@
|
||||
{% if props|length %} {{ props|join(" ") }}{% endif %}
|
||||
{% endmacro %}
|
||||
|
||||
{# Rendering Match component. #}
|
||||
{# Args: #}
|
||||
{# component: component dictionary #}
|
||||
{% macro render_match_tag(component) %}
|
||||
{
|
||||
(() => {
|
||||
switch (JSON.stringify({{ component.cond._var_full_name }})) {
|
||||
{% for case in component.match_cases %}
|
||||
{% for condition in case[:-1] %}
|
||||
case JSON.stringify({{ condition._var_name_unwrapped }}):
|
||||
{% endfor %}
|
||||
return {{ case[-1] }};
|
||||
break;
|
||||
{% endfor %}
|
||||
default:
|
||||
return {{ component.default }};
|
||||
break;
|
||||
}
|
||||
})()
|
||||
}
|
||||
{%- endmacro %}
|
||||
|
||||
|
||||
{# Rendering content with args. #}
|
||||
{# Args: #}
|
||||
|
@ -114,6 +114,7 @@ _ALL_COMPONENTS = [
|
||||
"List",
|
||||
"ListItem",
|
||||
"Markdown",
|
||||
"Match",
|
||||
"Menu",
|
||||
"MenuButton",
|
||||
"MenuDivider",
|
||||
|
@ -107,6 +107,7 @@ from reflex.components import LinkOverlay as LinkOverlay
|
||||
from reflex.components import List as List
|
||||
from reflex.components import ListItem as ListItem
|
||||
from reflex.components import Markdown as Markdown
|
||||
from reflex.components import Match as Match
|
||||
from reflex.components import Menu as Menu
|
||||
from reflex.components import MenuButton as MenuButton
|
||||
from reflex.components import MenuDivider as MenuDivider
|
||||
@ -317,6 +318,7 @@ from reflex.components import link_overlay as link_overlay
|
||||
from reflex.components import list as list
|
||||
from reflex.components import list_item as list_item
|
||||
from reflex.components import markdown as markdown
|
||||
from reflex.components import match as match
|
||||
from reflex.components import menu as menu
|
||||
from reflex.components import menu_button as menu_button
|
||||
from reflex.components import menu_divider as menu_divider
|
||||
|
@ -64,7 +64,7 @@ from reflex.state import (
|
||||
StateManager,
|
||||
StateUpdate,
|
||||
)
|
||||
from reflex.utils import console, format, prerequisites, types
|
||||
from reflex.utils import console, exceptions, format, prerequisites, types
|
||||
from reflex.utils.imports import ImportVar
|
||||
|
||||
# Define custom types.
|
||||
@ -344,9 +344,12 @@ class App(Base):
|
||||
|
||||
Raises:
|
||||
TypeError: When an invalid component function is passed.
|
||||
exceptions.MatchTypeError: If the return types of match cases in rx.match are different.
|
||||
"""
|
||||
try:
|
||||
return component if isinstance(component, Component) else component()
|
||||
except exceptions.MatchTypeError:
|
||||
raise
|
||||
except TypeError as e:
|
||||
message = str(e)
|
||||
if "BaseVar" in message or "ComputedVar" in message:
|
||||
|
@ -4,6 +4,7 @@ from .banner import ConnectionBanner, ConnectionModal
|
||||
from .cond import Cond, cond
|
||||
from .debounce import DebounceInput
|
||||
from .foreach import Foreach
|
||||
from .match import Match
|
||||
from .responsive import (
|
||||
desktop_only,
|
||||
mobile_and_tablet,
|
||||
@ -17,4 +18,5 @@ connection_banner = ConnectionBanner.create
|
||||
connection_modal = ConnectionModal.create
|
||||
debounce_input = DebounceInput.create
|
||||
foreach = Foreach.create
|
||||
match = Match.create
|
||||
upload = Upload.create
|
||||
|
257
reflex/components/core/match.py
Normal file
257
reflex/components/core/match.py
Normal file
@ -0,0 +1,257 @@
|
||||
"""rx.match."""
|
||||
import textwrap
|
||||
from typing import Any, Dict, List, Optional, Tuple, Union
|
||||
|
||||
from reflex.components.base import Fragment
|
||||
from reflex.components.component import BaseComponent, Component, MemoizationLeaf
|
||||
from reflex.components.tags import MatchTag, Tag
|
||||
from reflex.utils import format, imports, types
|
||||
from reflex.utils.exceptions import MatchTypeError
|
||||
from reflex.vars import BaseVar, Var, VarData
|
||||
|
||||
|
||||
class Match(MemoizationLeaf):
|
||||
"""Match cases based on a condition."""
|
||||
|
||||
# The condition to determine which case to match.
|
||||
cond: Var[Any]
|
||||
|
||||
# The list of match cases to be matched.
|
||||
match_cases: List[Any] = []
|
||||
|
||||
# The catchall case to match.
|
||||
default: Any
|
||||
|
||||
@classmethod
|
||||
def create(cls, cond: Any, *cases) -> Union[Component, BaseVar]:
|
||||
"""Create a Match Component.
|
||||
|
||||
Args:
|
||||
cond: The condition to determine which case to match.
|
||||
cases: This list of cases to match.
|
||||
|
||||
Returns:
|
||||
The match component.
|
||||
|
||||
Raises:
|
||||
ValueError: When a default case is not provided for cases with Var return types.
|
||||
"""
|
||||
match_cond_var = cls._create_condition_var(cond)
|
||||
cases, default = cls._process_cases(list(cases))
|
||||
match_cases = cls._process_match_cases(cases)
|
||||
|
||||
cls._validate_return_types(match_cases)
|
||||
|
||||
if default is None and types._issubclass(type(match_cases[0][-1]), BaseVar):
|
||||
raise ValueError(
|
||||
"For cases with return types as Vars, a default case must be provided"
|
||||
)
|
||||
|
||||
return cls._create_match_cond_var_or_component(
|
||||
match_cond_var, match_cases, default
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def _create_condition_var(cls, cond: Any) -> BaseVar:
|
||||
"""Convert the condition to a Var.
|
||||
|
||||
Args:
|
||||
cond: The condition.
|
||||
|
||||
Returns:
|
||||
The condition as a base var
|
||||
|
||||
Raises:
|
||||
ValueError: If the condition is not provided.
|
||||
"""
|
||||
match_cond_var = Var.create(cond)
|
||||
if match_cond_var is None:
|
||||
raise ValueError("The condition must be set")
|
||||
return match_cond_var # type: ignore
|
||||
|
||||
@classmethod
|
||||
def _process_cases(
|
||||
cls, cases: List
|
||||
) -> Tuple[List, Optional[Union[BaseVar, BaseComponent]]]:
|
||||
"""Process the list of match cases and the catchall default case.
|
||||
|
||||
Args:
|
||||
cases: The list of match cases.
|
||||
|
||||
Returns:
|
||||
The default case and the list of match case tuples.
|
||||
|
||||
Raises:
|
||||
ValueError: If there are multiple default cases.
|
||||
"""
|
||||
default = None
|
||||
|
||||
if len([case for case in cases if not isinstance(case, tuple)]) > 1:
|
||||
raise ValueError("rx.match can only have one default case.")
|
||||
|
||||
# Get the default case which should be the last non-tuple arg
|
||||
if not isinstance(cases[-1], tuple):
|
||||
default = cases.pop()
|
||||
default = (
|
||||
Var.create(default, _var_is_string=type(default) is str)
|
||||
if not isinstance(default, BaseComponent)
|
||||
else default
|
||||
)
|
||||
|
||||
return cases, default # type: ignore
|
||||
|
||||
@classmethod
|
||||
def _process_match_cases(cls, cases: List) -> List[List[BaseVar]]:
|
||||
"""Process the individual match cases.
|
||||
|
||||
Args:
|
||||
cases: The match cases.
|
||||
|
||||
Returns:
|
||||
The processed match cases.
|
||||
|
||||
Raises:
|
||||
ValueError: If the default case is not the last case or the tuple elements are less than 2.
|
||||
"""
|
||||
match_cases = []
|
||||
for case in cases:
|
||||
if not isinstance(case, tuple):
|
||||
raise ValueError(
|
||||
"rx.match should have tuples of cases and a default case as the last argument."
|
||||
)
|
||||
# There should be at least two elements in a case tuple(a condition and return value)
|
||||
if len(case) < 2:
|
||||
raise ValueError(
|
||||
"A case tuple should have at least a match case element and a return value."
|
||||
)
|
||||
|
||||
case_list = []
|
||||
for element in case:
|
||||
# convert all non component element to vars.
|
||||
el = (
|
||||
Var.create(element, _var_is_string=type(element) is str)
|
||||
if not isinstance(element, BaseComponent)
|
||||
else element
|
||||
)
|
||||
if not isinstance(el, (BaseVar, BaseComponent)):
|
||||
raise ValueError("Case element must be a var or component")
|
||||
case_list.append(el)
|
||||
|
||||
match_cases.append(case_list)
|
||||
|
||||
return match_cases
|
||||
|
||||
@classmethod
|
||||
def _validate_return_types(cls, match_cases: List[List[BaseVar]]) -> None:
|
||||
"""Validate that match cases have the same return types.
|
||||
|
||||
Args:
|
||||
match_cases: The match cases.
|
||||
|
||||
Raises:
|
||||
MatchTypeError: If the return types of cases are different.
|
||||
"""
|
||||
first_case_return = match_cases[0][-1]
|
||||
return_type = type(first_case_return)
|
||||
|
||||
if types._isinstance(first_case_return, BaseComponent):
|
||||
return_type = BaseComponent
|
||||
elif types._isinstance(first_case_return, BaseVar):
|
||||
return_type = BaseVar
|
||||
|
||||
for index, case in enumerate(match_cases):
|
||||
if not types._issubclass(type(case[-1]), return_type):
|
||||
raise MatchTypeError(
|
||||
f"Match cases should have the same return types. Case {index} with return "
|
||||
f"value `{case[-1]._var_name if isinstance(case[-1], BaseVar) else textwrap.shorten(str(case[-1]), width=250)}`"
|
||||
f" of type {type(case[-1])!r} is not {return_type}"
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def _create_match_cond_var_or_component(
|
||||
cls,
|
||||
match_cond_var: Var,
|
||||
match_cases: List[List[BaseVar]],
|
||||
default: Optional[Union[BaseVar, BaseComponent]],
|
||||
) -> Union[Component, BaseVar]:
|
||||
"""Create and return the match condition var or component.
|
||||
|
||||
Args:
|
||||
match_cond_var: The match condition.
|
||||
match_cases: The list of match cases.
|
||||
default: The default case.
|
||||
|
||||
Returns:
|
||||
The match component wrapped in a fragment or the match var.
|
||||
|
||||
Raises:
|
||||
ValueError: If the return types are not vars when creating a match var for Var types.
|
||||
"""
|
||||
if default is None and types._issubclass(
|
||||
type(match_cases[0][-1]), BaseComponent
|
||||
):
|
||||
default = Fragment.create()
|
||||
|
||||
if types._issubclass(type(match_cases[0][-1]), BaseComponent):
|
||||
return Fragment.create(
|
||||
cls(
|
||||
cond=match_cond_var,
|
||||
match_cases=match_cases,
|
||||
default=default,
|
||||
)
|
||||
)
|
||||
|
||||
# Validate the match cases (as well as the default case) to have Var return types.
|
||||
if any(
|
||||
case for case in match_cases if not types._isinstance(case[-1], BaseVar)
|
||||
) or not types._isinstance(default, BaseVar):
|
||||
raise ValueError("Return types of match cases should be Vars.")
|
||||
|
||||
# match cases and default should all be Vars at this point.
|
||||
# Retrieve var data of every var in the match cases and default.
|
||||
var_data = [
|
||||
*[el._var_data for case in match_cases for el in case],
|
||||
default._var_data, # type: ignore
|
||||
]
|
||||
|
||||
return match_cond_var._replace(
|
||||
_var_name=format.format_match(
|
||||
cond=match_cond_var._var_full_name,
|
||||
match_cases=match_cases, # type: ignore
|
||||
default=default, # type: ignore
|
||||
),
|
||||
_var_type=default._var_type, # type: ignore
|
||||
_var_is_local=False,
|
||||
_var_full_name_needs_state_prefix=False,
|
||||
merge_var_data=VarData.merge(*var_data),
|
||||
)
|
||||
|
||||
def _render(self) -> Tag:
|
||||
return MatchTag(
|
||||
cond=self.cond, match_cases=self.match_cases, default=self.default
|
||||
)
|
||||
|
||||
def render(self) -> Dict:
|
||||
"""Render the component.
|
||||
|
||||
Returns:
|
||||
The dictionary for template of component.
|
||||
"""
|
||||
tag = self._render()
|
||||
tag.name = "match"
|
||||
return dict(tag)
|
||||
|
||||
def _get_imports(self):
|
||||
merged_imports = super()._get_imports()
|
||||
# Obtain the imports of all components the in match case.
|
||||
for case in self.match_cases:
|
||||
if isinstance(case[-1], BaseComponent):
|
||||
merged_imports = imports.merge_imports(
|
||||
merged_imports, case[-1].get_imports()
|
||||
)
|
||||
# Get the import of the default case component.
|
||||
if isinstance(self.default, BaseComponent):
|
||||
merged_imports = imports.merge_imports(
|
||||
merged_imports, self.default.get_imports()
|
||||
)
|
||||
return merged_imports
|
@ -2,4 +2,5 @@
|
||||
|
||||
from .cond_tag import CondTag
|
||||
from .iter_tag import IterTag
|
||||
from .match_tag import MatchTag
|
||||
from .tag import Tag
|
||||
|
19
reflex/components/tags/match_tag.py
Normal file
19
reflex/components/tags/match_tag.py
Normal file
@ -0,0 +1,19 @@
|
||||
"""Tag to conditionally match cases."""
|
||||
|
||||
from typing import Any, List
|
||||
|
||||
from reflex.components.tags.tag import Tag
|
||||
from reflex.vars import Var
|
||||
|
||||
|
||||
class MatchTag(Tag):
|
||||
"""A match tag."""
|
||||
|
||||
# The condition to determine which case to match.
|
||||
cond: Var[Any]
|
||||
|
||||
# The list of match cases to be matched.
|
||||
match_cases: List[Any]
|
||||
|
||||
# The catchall case to match.
|
||||
default: Any
|
@ -13,3 +13,9 @@ class ImmutableStateError(AttributeError):
|
||||
|
||||
class LockExpiredError(Exception):
|
||||
"""Raised when the state lock expires while an event is being processed."""
|
||||
|
||||
|
||||
class MatchTypeError(TypeError):
|
||||
"""Raised when the return types of match cases are different."""
|
||||
|
||||
pass
|
||||
|
@ -7,7 +7,7 @@ import json
|
||||
import os
|
||||
import re
|
||||
import sys
|
||||
from typing import TYPE_CHECKING, Any, Union
|
||||
from typing import TYPE_CHECKING, Any, List, Union
|
||||
|
||||
from reflex import constants
|
||||
from reflex.utils import exceptions, serializers, types
|
||||
@ -272,6 +272,41 @@ def format_cond(
|
||||
return wrap(f"{cond} ? {true_value} : {false_value}", "{")
|
||||
|
||||
|
||||
def format_match(cond: str | Var, match_cases: List[BaseVar], default: Var) -> str:
|
||||
"""Format a match expression whose return type is a Var.
|
||||
|
||||
Args:
|
||||
cond: The condition.
|
||||
match_cases: The list of cases to match.
|
||||
default: The default case.
|
||||
|
||||
Returns:
|
||||
The formatted match expression
|
||||
|
||||
"""
|
||||
switch_code = f"(() => {{ switch (JSON.stringify({cond})) {{"
|
||||
|
||||
for case in match_cases:
|
||||
conditions = case[:-1]
|
||||
return_value = case[-1]
|
||||
|
||||
case_conditions = " ".join(
|
||||
[
|
||||
f"case JSON.stringify({condition._var_name_unwrapped}):"
|
||||
for condition in conditions
|
||||
]
|
||||
)
|
||||
case_code = (
|
||||
f"{case_conditions} return ({return_value._var_name_unwrapped}); break;"
|
||||
)
|
||||
switch_code += case_code
|
||||
|
||||
switch_code += f"default: return ({default._var_name_unwrapped}); break;"
|
||||
switch_code += "};})()"
|
||||
|
||||
return switch_code
|
||||
|
||||
|
||||
def format_prop(
|
||||
prop: Union[Var, EventChain, ComponentStyle, str],
|
||||
) -> Union[int, float, str]:
|
||||
|
@ -1534,6 +1534,26 @@ class Var:
|
||||
"""
|
||||
return self._var_data.state if self._var_data else ""
|
||||
|
||||
@property
|
||||
def _var_name_unwrapped(self) -> str:
|
||||
"""Get the var str without wrapping in curly braces.
|
||||
|
||||
Returns:
|
||||
The str var without the wrapped curly braces
|
||||
"""
|
||||
type_ = (
|
||||
get_origin(self._var_type)
|
||||
if types.is_generic_alias(self._var_type)
|
||||
else self._var_type
|
||||
)
|
||||
|
||||
wrapped_var = str(self)
|
||||
return (
|
||||
wrapped_var
|
||||
if not self._var_state and issubclass(type_, dict)
|
||||
else wrapped_var.strip("{}")
|
||||
)
|
||||
|
||||
|
||||
# Allow automatic serialization of Var within JSON structures
|
||||
serializers.serializer(_encode_var)
|
||||
|
@ -30,6 +30,7 @@ EXCLUDED_FILES = [
|
||||
"bare.py",
|
||||
"foreach.py",
|
||||
"cond.py",
|
||||
"match.py",
|
||||
"multiselect.py",
|
||||
"literals.py",
|
||||
]
|
||||
|
306
tests/components/layout/test_match.py
Normal file
306
tests/components/layout/test_match.py
Normal file
@ -0,0 +1,306 @@
|
||||
from typing import Tuple
|
||||
|
||||
import pytest
|
||||
|
||||
import reflex as rx
|
||||
from reflex.components.core.match import Match
|
||||
from reflex.state import BaseState
|
||||
from reflex.utils.exceptions import MatchTypeError
|
||||
from reflex.vars import BaseVar
|
||||
|
||||
|
||||
class MatchState(BaseState):
|
||||
"""A test state."""
|
||||
|
||||
value: int = 0
|
||||
num: int = 5
|
||||
string: str = "random string"
|
||||
|
||||
|
||||
def test_match_components():
|
||||
"""Test matching cases with return values as components."""
|
||||
match_case_tuples = (
|
||||
(1, rx.text("first value")),
|
||||
(2, 3, rx.text("second value")),
|
||||
([1, 2], rx.text("third value")),
|
||||
("random", rx.text("fourth value")),
|
||||
({"foo": "bar"}, rx.text("fifth value")),
|
||||
(MatchState.num + 1, rx.text("sixth value")),
|
||||
rx.text("default value"),
|
||||
)
|
||||
match_comp = Match.create(MatchState.value, *match_case_tuples)
|
||||
match_dict = match_comp.render() # type: ignore
|
||||
assert match_dict["name"] == "Fragment"
|
||||
|
||||
[match_child] = match_dict["children"]
|
||||
|
||||
assert match_child["name"] == "match"
|
||||
assert str(match_child["cond"]) == "{match_state.value}"
|
||||
|
||||
match_cases = match_child["match_cases"]
|
||||
assert len(match_cases) == 6
|
||||
|
||||
assert match_cases[0][0]._var_name == "1"
|
||||
assert match_cases[0][0]._var_type == int
|
||||
first_return_value_render = match_cases[0][1].render()
|
||||
assert first_return_value_render["name"] == "Text"
|
||||
assert first_return_value_render["children"][0]["contents"] == "{`first value`}"
|
||||
|
||||
assert match_cases[1][0]._var_name == "2"
|
||||
assert match_cases[1][0]._var_type == int
|
||||
assert match_cases[1][1]._var_name == "3"
|
||||
assert match_cases[1][1]._var_type == int
|
||||
second_return_value_render = match_cases[1][2].render()
|
||||
assert second_return_value_render["name"] == "Text"
|
||||
assert second_return_value_render["children"][0]["contents"] == "{`second value`}"
|
||||
|
||||
assert match_cases[2][0]._var_name == "[1, 2]"
|
||||
assert match_cases[2][0]._var_type == list
|
||||
third_return_value_render = match_cases[2][1].render()
|
||||
assert third_return_value_render["name"] == "Text"
|
||||
assert third_return_value_render["children"][0]["contents"] == "{`third value`}"
|
||||
|
||||
assert match_cases[3][0]._var_name == "random"
|
||||
assert match_cases[3][0]._var_type == str
|
||||
fourth_return_value_render = match_cases[3][1].render()
|
||||
assert fourth_return_value_render["name"] == "Text"
|
||||
assert fourth_return_value_render["children"][0]["contents"] == "{`fourth value`}"
|
||||
|
||||
assert match_cases[4][0]._var_name == '{"foo": "bar"}'
|
||||
assert match_cases[4][0]._var_type == dict
|
||||
fifth_return_value_render = match_cases[4][1].render()
|
||||
assert fifth_return_value_render["name"] == "Text"
|
||||
assert fifth_return_value_render["children"][0]["contents"] == "{`fifth value`}"
|
||||
|
||||
assert match_cases[5][0]._var_name == "(match_state.num + 1)"
|
||||
assert match_cases[5][0]._var_type == int
|
||||
fifth_return_value_render = match_cases[5][1].render()
|
||||
assert fifth_return_value_render["name"] == "Text"
|
||||
assert fifth_return_value_render["children"][0]["contents"] == "{`sixth value`}"
|
||||
|
||||
default = match_child["default"].render()
|
||||
|
||||
assert default["name"] == "Text"
|
||||
assert default["children"][0]["contents"] == "{`default value`}"
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"cases, expected",
|
||||
[
|
||||
(
|
||||
(
|
||||
(1, "first"),
|
||||
(2, 3, "second value"),
|
||||
([1, 2], "third-value"),
|
||||
("random", "fourth_value"),
|
||||
({"foo": "bar"}, "fifth value"),
|
||||
(MatchState.num + 1, "sixth value"),
|
||||
(f"{MatchState.value} - string", MatchState.string),
|
||||
(MatchState.string, f"{MatchState.value} - string"),
|
||||
"default value",
|
||||
),
|
||||
"(() => { switch (JSON.stringify(match_state.value)) {case JSON.stringify(1): return (`first`); break;case JSON.stringify(2): case JSON.stringify(3): return "
|
||||
"(`second value`); break;case JSON.stringify([1, 2]): return (`third-value`); break;case JSON.stringify(`random`): "
|
||||
'return (`fourth_value`); break;case JSON.stringify({"foo": "bar"}): return (`fifth value`); '
|
||||
"break;case JSON.stringify((match_state.num + 1)): return (`sixth value`); break;case JSON.stringify(`${match_state.value} - string`): "
|
||||
"return (match_state.string); break;case JSON.stringify(match_state.string): return (`${match_state.value} - string`); break;default: "
|
||||
"return (`default value`); break;};})()",
|
||||
),
|
||||
(
|
||||
(
|
||||
(1, "first"),
|
||||
(2, 3, "second value"),
|
||||
([1, 2], "third-value"),
|
||||
("random", "fourth_value"),
|
||||
({"foo": "bar"}, "fifth value"),
|
||||
(MatchState.num + 1, "sixth value"),
|
||||
(f"{MatchState.value} - string", MatchState.string),
|
||||
(MatchState.string, f"{MatchState.value} - string"),
|
||||
MatchState.string,
|
||||
),
|
||||
"(() => { switch (JSON.stringify(match_state.value)) {case JSON.stringify(1): return (`first`); break;case JSON.stringify(2): case JSON.stringify(3): return "
|
||||
"(`second value`); break;case JSON.stringify([1, 2]): return (`third-value`); break;case JSON.stringify(`random`): "
|
||||
'return (`fourth_value`); break;case JSON.stringify({"foo": "bar"}): return (`fifth value`); '
|
||||
"break;case JSON.stringify((match_state.num + 1)): return (`sixth value`); break;case JSON.stringify(`${match_state.value} - string`): "
|
||||
"return (match_state.string); break;case JSON.stringify(match_state.string): return (`${match_state.value} - string`); break;default: "
|
||||
"return (match_state.string); break;};})()",
|
||||
),
|
||||
],
|
||||
)
|
||||
def test_match_vars(cases, expected):
|
||||
"""Test matching cases with return values as Vars.
|
||||
|
||||
Args:
|
||||
cases: The match cases.
|
||||
expected: The expected var full name.
|
||||
"""
|
||||
match_comp = Match.create(MatchState.value, *cases)
|
||||
assert isinstance(match_comp, BaseVar)
|
||||
assert match_comp._var_full_name == expected
|
||||
|
||||
|
||||
def test_match_on_component_without_default():
|
||||
"""Test that matching cases with return values as components returns a Fragment
|
||||
as the default case if not provided.
|
||||
"""
|
||||
match_case_tuples = (
|
||||
(1, rx.text("first value")),
|
||||
(2, 3, rx.text("second value")),
|
||||
)
|
||||
|
||||
match_comp = Match.create(MatchState.value, *match_case_tuples)
|
||||
default = match_comp.render()["children"][0]["default"] # type: ignore
|
||||
|
||||
assert isinstance(default, rx.Fragment)
|
||||
|
||||
|
||||
def test_match_on_var_no_default():
|
||||
"""Test that an error is thrown when cases with return Values as Var do not have a default case."""
|
||||
match_case_tuples = (
|
||||
(1, "red"),
|
||||
(2, 3, "blue"),
|
||||
([1, 2], "green"),
|
||||
)
|
||||
|
||||
with pytest.raises(
|
||||
ValueError,
|
||||
match="For cases with return types as Vars, a default case must be provided",
|
||||
):
|
||||
Match.create(MatchState.value, *match_case_tuples)
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"match_case",
|
||||
[
|
||||
(
|
||||
(1, "red"),
|
||||
(2, 3, "blue"),
|
||||
"black",
|
||||
([1, 2], "green"),
|
||||
),
|
||||
(
|
||||
(1, rx.text("first value")),
|
||||
(2, 3, rx.text("second value")),
|
||||
([1, 2], rx.text("third value")),
|
||||
rx.text("default value"),
|
||||
("random", rx.text("fourth value")),
|
||||
({"foo": "bar"}, rx.text("fifth value")),
|
||||
(MatchState.num + 1, rx.text("sixth value")),
|
||||
),
|
||||
],
|
||||
)
|
||||
def test_match_default_not_last_arg(match_case):
|
||||
"""Test that an error is thrown when the default case is not the last arg.
|
||||
|
||||
Args:
|
||||
match_case: The cases to match.
|
||||
"""
|
||||
with pytest.raises(
|
||||
ValueError,
|
||||
match="rx.match should have tuples of cases and a default case as the last argument.",
|
||||
):
|
||||
Match.create(MatchState.value, *match_case)
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"match_case",
|
||||
[
|
||||
(
|
||||
(1, "red"),
|
||||
(2, 3, "blue"),
|
||||
("green",),
|
||||
"black",
|
||||
),
|
||||
(
|
||||
(1, rx.text("first value")),
|
||||
(2, 3, rx.text("second value")),
|
||||
([1, 2],),
|
||||
rx.text("default value"),
|
||||
),
|
||||
],
|
||||
)
|
||||
def test_match_case_tuple_elements(match_case):
|
||||
"""Test that a match has at least 2 elements(a condition and a return value).
|
||||
|
||||
Args:
|
||||
match_case: The cases to match.
|
||||
"""
|
||||
with pytest.raises(
|
||||
ValueError,
|
||||
match="A case tuple should have at least a match case element and a return value.",
|
||||
):
|
||||
Match.create(MatchState.value, *match_case)
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"cases, error_msg",
|
||||
[
|
||||
(
|
||||
(
|
||||
(1, rx.text("first value")),
|
||||
(2, 3, rx.text("second value")),
|
||||
([1, 2], rx.text("third value")),
|
||||
("random", "red"),
|
||||
({"foo": "bar"}, "green"),
|
||||
(MatchState.num + 1, "black"),
|
||||
rx.text("default value"),
|
||||
),
|
||||
"Match cases should have the same return types. Case 3 with return value `red` of type "
|
||||
"<class 'reflex.vars.BaseVar'> is not <class 'reflex.components.component.BaseComponent'>",
|
||||
),
|
||||
(
|
||||
(
|
||||
("random", "red"),
|
||||
({"foo": "bar"}, "green"),
|
||||
(MatchState.num + 1, "black"),
|
||||
(1, rx.text("first value")),
|
||||
(2, 3, rx.text("second value")),
|
||||
([1, 2], rx.text("third value")),
|
||||
rx.text("default value"),
|
||||
),
|
||||
"Match cases should have the same return types. Case 3 with return value `<Text> {`first value`} </Text>` "
|
||||
"of type <class 'reflex.components.chakra.typography.text.Text'> is not <class 'reflex.vars.BaseVar'>",
|
||||
),
|
||||
],
|
||||
)
|
||||
def test_match_different_return_types(cases: Tuple, error_msg: str):
|
||||
"""Test that an error is thrown when the return values are of different types.
|
||||
|
||||
Args:
|
||||
cases: The match cases.
|
||||
error_msg: Expected error message.
|
||||
"""
|
||||
with pytest.raises(MatchTypeError, match=error_msg):
|
||||
Match.create(MatchState.value, *cases)
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"match_case",
|
||||
[
|
||||
(
|
||||
(1, "red"),
|
||||
(2, 3, "blue"),
|
||||
([1, 2], "green"),
|
||||
"black",
|
||||
"white",
|
||||
),
|
||||
(
|
||||
(1, rx.text("first value")),
|
||||
(2, 3, rx.text("second value")),
|
||||
([1, 2], rx.text("third value")),
|
||||
("random", rx.text("fourth value")),
|
||||
({"foo": "bar"}, rx.text("fifth value")),
|
||||
(MatchState.num + 1, rx.text("sixth value")),
|
||||
rx.text("default value"),
|
||||
rx.text("another default value"),
|
||||
),
|
||||
],
|
||||
)
|
||||
def test_match_multiple_default_cases(match_case):
|
||||
"""Test that there is only one default case.
|
||||
|
||||
Args:
|
||||
match_case: the cases to match.
|
||||
"""
|
||||
with pytest.raises(ValueError, match="rx.match can only have one default case."):
|
||||
Match.create(MatchState.value, *match_case)
|
@ -896,6 +896,7 @@ def test_instantiate_all_components():
|
||||
"FormControl",
|
||||
"Html",
|
||||
"Icon",
|
||||
"Match",
|
||||
"Markdown",
|
||||
"MultiSelect",
|
||||
"Option",
|
||||
|
@ -24,6 +24,13 @@ test_vars = [
|
||||
]
|
||||
|
||||
|
||||
class ATestState(BaseState):
|
||||
"""Test state."""
|
||||
|
||||
value: str
|
||||
dict_val: Dict[str, List] = {}
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def TestObj():
|
||||
class TestObj(Base):
|
||||
@ -1137,3 +1144,22 @@ def test_invalid_var_operations(operand1_var: Var, operand2_var, operators: List
|
||||
|
||||
with pytest.raises(TypeError):
|
||||
operand1_var.operation(op=operator, other=operand2_var, flip=True)
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"var, expected",
|
||||
[
|
||||
(Var.create("string_value", _var_is_string=True), "`string_value`"),
|
||||
(Var.create(1), "1"),
|
||||
(Var.create([1, 2, 3]), "[1, 2, 3]"),
|
||||
(Var.create({"foo": "bar"}), '{"foo": "bar"}'),
|
||||
(Var.create(ATestState.value, _var_is_string=True), "a_test_state.value"),
|
||||
(
|
||||
Var.create(f"{ATestState.value} string", _var_is_string=True),
|
||||
"`${a_test_state.value} string`",
|
||||
),
|
||||
(Var.create(ATestState.dict_val), "a_test_state.dict_val"),
|
||||
],
|
||||
)
|
||||
def test_var_name_unwrapped(var, expected):
|
||||
assert var._var_name_unwrapped == expected
|
||||
|
@ -1,5 +1,5 @@
|
||||
import datetime
|
||||
from typing import Any
|
||||
from typing import Any, List
|
||||
|
||||
import pytest
|
||||
|
||||
@ -294,6 +294,42 @@ def test_format_cond(condition: str, true_value: str, false_value: str, expected
|
||||
assert format.format_cond(condition, true_value, false_value) == expected
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"condition, match_cases, default,expected",
|
||||
[
|
||||
(
|
||||
"state__state.value",
|
||||
[
|
||||
[Var.create(1), Var.create("red", _var_is_string=True)],
|
||||
[Var.create(2), Var.create(3), Var.create("blue", _var_is_string=True)],
|
||||
[TestState.mapping, TestState.num1],
|
||||
[
|
||||
Var.create(f"{TestState.map_key}-key", _var_is_string=True),
|
||||
Var.create("return-key", _var_is_string=True),
|
||||
],
|
||||
],
|
||||
Var.create("yellow", _var_is_string=True),
|
||||
"(() => { switch (JSON.stringify(state__state.value)) {case JSON.stringify(1): return (`red`); break;case JSON.stringify(2): case JSON.stringify(3): "
|
||||
"return (`blue`); break;case JSON.stringify(test_state.mapping): return "
|
||||
"(test_state.num1); break;case JSON.stringify(`${test_state.map_key}-key`): return (`return-key`);"
|
||||
" break;default: return (`yellow`); break;};})()",
|
||||
)
|
||||
],
|
||||
)
|
||||
def test_format_match(
|
||||
condition: str, match_cases: List[BaseVar], default: BaseVar, expected: str
|
||||
):
|
||||
"""Test formatting a match statement.
|
||||
|
||||
Args:
|
||||
condition: The condition to match.
|
||||
match_cases: List of match cases to be matched.
|
||||
default: Catchall case for the match statement.
|
||||
expected: The expected string output.
|
||||
"""
|
||||
assert format.format_match(condition, match_cases, default) == expected
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"prop,formatted",
|
||||
[
|
||||
|
Loading…
Reference in New Issue
Block a user