From abfc099779c510133f704598df14b3e75012396c Mon Sep 17 00:00:00 2001 From: Elijah Ahianyo Date: Thu, 4 Jan 2024 17:48:18 +0000 Subject: [PATCH] rx.match component (#2318) * initial commit * add more tests * refactor match jinja template * add docstrings * cleanup * possible fix for pyright * fix conflicts * fix conflicts again * comments * fixed bug from review * fix tests * address PR comment * fix tests * type error improvement * formatting * darglint fix * more tests * stringify switch condition and cases as js doesnt support complex types(lists and dicts) in switch cases. * Update reflex/vars.py Co-authored-by: Masen Furer * change usages * Precommit fix --------- Co-authored-by: Alek Petuskey Co-authored-by: Masen Furer Co-authored-by: Alek Petuskey --- .../jinja/web/pages/utils.js.jinja2 | 24 ++ reflex/__init__.py | 1 + reflex/__init__.pyi | 2 + reflex/app.py | 5 +- reflex/components/core/__init__.py | 2 + reflex/components/core/match.py | 257 +++++++++++++++ reflex/components/tags/__init__.py | 1 + reflex/components/tags/match_tag.py | 19 ++ reflex/utils/exceptions.py | 6 + reflex/utils/format.py | 37 ++- reflex/vars.py | 20 ++ scripts/pyi_generator.py | 1 + tests/components/layout/test_match.py | 306 ++++++++++++++++++ tests/components/test_component.py | 1 + tests/test_var.py | 26 ++ tests/utils/test_format.py | 38 ++- 16 files changed, 743 insertions(+), 3 deletions(-) create mode 100644 reflex/components/core/match.py create mode 100644 reflex/components/tags/match_tag.py create mode 100644 tests/components/layout/test_match.py diff --git a/reflex/.templates/jinja/web/pages/utils.js.jinja2 b/reflex/.templates/jinja/web/pages/utils.js.jinja2 index 144e228f5..03890e9b6 100644 --- a/reflex/.templates/jinja/web/pages/utils.js.jinja2 +++ b/reflex/.templates/jinja/web/pages/utils.js.jinja2 @@ -8,6 +8,8 @@ {{- component }} {%- elif "iterable" in component %} {{- render_iterable_tag(component) }} + {%- elif component.name == "match"%} + {{- render_match_tag(component) }} {%- elif "cond" in component %} {{- render_condition_tag(component) }} {%- elif component.children|length %} @@ -77,6 +79,28 @@ {% if props|length %} {{ props|join(" ") }}{% endif %} {% endmacro %} +{# Rendering Match component. #} +{# Args: #} +{# component: component dictionary #} +{% macro render_match_tag(component) %} +{ + (() => { + switch (JSON.stringify({{ component.cond._var_full_name }})) { + {% for case in component.match_cases %} + {% for condition in case[:-1] %} + case JSON.stringify({{ condition._var_name_unwrapped }}): + {% endfor %} + return {{ case[-1] }}; + break; + {% endfor %} + default: + return {{ component.default }}; + break; + } + })() + } +{%- endmacro %} + {# Rendering content with args. #} {# Args: #} diff --git a/reflex/__init__.py b/reflex/__init__.py index 012bdf196..0d60550b5 100644 --- a/reflex/__init__.py +++ b/reflex/__init__.py @@ -114,6 +114,7 @@ _ALL_COMPONENTS = [ "List", "ListItem", "Markdown", + "Match", "Menu", "MenuButton", "MenuDivider", diff --git a/reflex/__init__.pyi b/reflex/__init__.pyi index 482de65ac..1f00ca652 100644 --- a/reflex/__init__.pyi +++ b/reflex/__init__.pyi @@ -107,6 +107,7 @@ from reflex.components import LinkOverlay as LinkOverlay from reflex.components import List as List from reflex.components import ListItem as ListItem from reflex.components import Markdown as Markdown +from reflex.components import Match as Match from reflex.components import Menu as Menu from reflex.components import MenuButton as MenuButton from reflex.components import MenuDivider as MenuDivider @@ -317,6 +318,7 @@ from reflex.components import link_overlay as link_overlay from reflex.components import list as list from reflex.components import list_item as list_item from reflex.components import markdown as markdown +from reflex.components import match as match from reflex.components import menu as menu from reflex.components import menu_button as menu_button from reflex.components import menu_divider as menu_divider diff --git a/reflex/app.py b/reflex/app.py index 6dcd3760e..781c50df3 100644 --- a/reflex/app.py +++ b/reflex/app.py @@ -64,7 +64,7 @@ from reflex.state import ( StateManager, StateUpdate, ) -from reflex.utils import console, format, prerequisites, types +from reflex.utils import console, exceptions, format, prerequisites, types from reflex.utils.imports import ImportVar # Define custom types. @@ -344,9 +344,12 @@ class App(Base): Raises: TypeError: When an invalid component function is passed. + exceptions.MatchTypeError: If the return types of match cases in rx.match are different. """ try: return component if isinstance(component, Component) else component() + except exceptions.MatchTypeError: + raise except TypeError as e: message = str(e) if "BaseVar" in message or "ComputedVar" in message: diff --git a/reflex/components/core/__init__.py b/reflex/components/core/__init__.py index b83e94c99..17aecc228 100644 --- a/reflex/components/core/__init__.py +++ b/reflex/components/core/__init__.py @@ -4,6 +4,7 @@ from .banner import ConnectionBanner, ConnectionModal from .cond import Cond, cond from .debounce import DebounceInput from .foreach import Foreach +from .match import Match from .responsive import ( desktop_only, mobile_and_tablet, @@ -17,4 +18,5 @@ connection_banner = ConnectionBanner.create connection_modal = ConnectionModal.create debounce_input = DebounceInput.create foreach = Foreach.create +match = Match.create upload = Upload.create diff --git a/reflex/components/core/match.py b/reflex/components/core/match.py new file mode 100644 index 000000000..efcbe97c0 --- /dev/null +++ b/reflex/components/core/match.py @@ -0,0 +1,257 @@ +"""rx.match.""" +import textwrap +from typing import Any, Dict, List, Optional, Tuple, Union + +from reflex.components.base import Fragment +from reflex.components.component import BaseComponent, Component, MemoizationLeaf +from reflex.components.tags import MatchTag, Tag +from reflex.utils import format, imports, types +from reflex.utils.exceptions import MatchTypeError +from reflex.vars import BaseVar, Var, VarData + + +class Match(MemoizationLeaf): + """Match cases based on a condition.""" + + # The condition to determine which case to match. + cond: Var[Any] + + # The list of match cases to be matched. + match_cases: List[Any] = [] + + # The catchall case to match. + default: Any + + @classmethod + def create(cls, cond: Any, *cases) -> Union[Component, BaseVar]: + """Create a Match Component. + + Args: + cond: The condition to determine which case to match. + cases: This list of cases to match. + + Returns: + The match component. + + Raises: + ValueError: When a default case is not provided for cases with Var return types. + """ + match_cond_var = cls._create_condition_var(cond) + cases, default = cls._process_cases(list(cases)) + match_cases = cls._process_match_cases(cases) + + cls._validate_return_types(match_cases) + + if default is None and types._issubclass(type(match_cases[0][-1]), BaseVar): + raise ValueError( + "For cases with return types as Vars, a default case must be provided" + ) + + return cls._create_match_cond_var_or_component( + match_cond_var, match_cases, default + ) + + @classmethod + def _create_condition_var(cls, cond: Any) -> BaseVar: + """Convert the condition to a Var. + + Args: + cond: The condition. + + Returns: + The condition as a base var + + Raises: + ValueError: If the condition is not provided. + """ + match_cond_var = Var.create(cond) + if match_cond_var is None: + raise ValueError("The condition must be set") + return match_cond_var # type: ignore + + @classmethod + def _process_cases( + cls, cases: List + ) -> Tuple[List, Optional[Union[BaseVar, BaseComponent]]]: + """Process the list of match cases and the catchall default case. + + Args: + cases: The list of match cases. + + Returns: + The default case and the list of match case tuples. + + Raises: + ValueError: If there are multiple default cases. + """ + default = None + + if len([case for case in cases if not isinstance(case, tuple)]) > 1: + raise ValueError("rx.match can only have one default case.") + + # Get the default case which should be the last non-tuple arg + if not isinstance(cases[-1], tuple): + default = cases.pop() + default = ( + Var.create(default, _var_is_string=type(default) is str) + if not isinstance(default, BaseComponent) + else default + ) + + return cases, default # type: ignore + + @classmethod + def _process_match_cases(cls, cases: List) -> List[List[BaseVar]]: + """Process the individual match cases. + + Args: + cases: The match cases. + + Returns: + The processed match cases. + + Raises: + ValueError: If the default case is not the last case or the tuple elements are less than 2. + """ + match_cases = [] + for case in cases: + if not isinstance(case, tuple): + raise ValueError( + "rx.match should have tuples of cases and a default case as the last argument." + ) + # There should be at least two elements in a case tuple(a condition and return value) + if len(case) < 2: + raise ValueError( + "A case tuple should have at least a match case element and a return value." + ) + + case_list = [] + for element in case: + # convert all non component element to vars. + el = ( + Var.create(element, _var_is_string=type(element) is str) + if not isinstance(element, BaseComponent) + else element + ) + if not isinstance(el, (BaseVar, BaseComponent)): + raise ValueError("Case element must be a var or component") + case_list.append(el) + + match_cases.append(case_list) + + return match_cases + + @classmethod + def _validate_return_types(cls, match_cases: List[List[BaseVar]]) -> None: + """Validate that match cases have the same return types. + + Args: + match_cases: The match cases. + + Raises: + MatchTypeError: If the return types of cases are different. + """ + first_case_return = match_cases[0][-1] + return_type = type(first_case_return) + + if types._isinstance(first_case_return, BaseComponent): + return_type = BaseComponent + elif types._isinstance(first_case_return, BaseVar): + return_type = BaseVar + + for index, case in enumerate(match_cases): + if not types._issubclass(type(case[-1]), return_type): + raise MatchTypeError( + f"Match cases should have the same return types. Case {index} with return " + f"value `{case[-1]._var_name if isinstance(case[-1], BaseVar) else textwrap.shorten(str(case[-1]), width=250)}`" + f" of type {type(case[-1])!r} is not {return_type}" + ) + + @classmethod + def _create_match_cond_var_or_component( + cls, + match_cond_var: Var, + match_cases: List[List[BaseVar]], + default: Optional[Union[BaseVar, BaseComponent]], + ) -> Union[Component, BaseVar]: + """Create and return the match condition var or component. + + Args: + match_cond_var: The match condition. + match_cases: The list of match cases. + default: The default case. + + Returns: + The match component wrapped in a fragment or the match var. + + Raises: + ValueError: If the return types are not vars when creating a match var for Var types. + """ + if default is None and types._issubclass( + type(match_cases[0][-1]), BaseComponent + ): + default = Fragment.create() + + if types._issubclass(type(match_cases[0][-1]), BaseComponent): + return Fragment.create( + cls( + cond=match_cond_var, + match_cases=match_cases, + default=default, + ) + ) + + # Validate the match cases (as well as the default case) to have Var return types. + if any( + case for case in match_cases if not types._isinstance(case[-1], BaseVar) + ) or not types._isinstance(default, BaseVar): + raise ValueError("Return types of match cases should be Vars.") + + # match cases and default should all be Vars at this point. + # Retrieve var data of every var in the match cases and default. + var_data = [ + *[el._var_data for case in match_cases for el in case], + default._var_data, # type: ignore + ] + + return match_cond_var._replace( + _var_name=format.format_match( + cond=match_cond_var._var_full_name, + match_cases=match_cases, # type: ignore + default=default, # type: ignore + ), + _var_type=default._var_type, # type: ignore + _var_is_local=False, + _var_full_name_needs_state_prefix=False, + merge_var_data=VarData.merge(*var_data), + ) + + def _render(self) -> Tag: + return MatchTag( + cond=self.cond, match_cases=self.match_cases, default=self.default + ) + + def render(self) -> Dict: + """Render the component. + + Returns: + The dictionary for template of component. + """ + tag = self._render() + tag.name = "match" + return dict(tag) + + def _get_imports(self): + merged_imports = super()._get_imports() + # Obtain the imports of all components the in match case. + for case in self.match_cases: + if isinstance(case[-1], BaseComponent): + merged_imports = imports.merge_imports( + merged_imports, case[-1].get_imports() + ) + # Get the import of the default case component. + if isinstance(self.default, BaseComponent): + merged_imports = imports.merge_imports( + merged_imports, self.default.get_imports() + ) + return merged_imports diff --git a/reflex/components/tags/__init__.py b/reflex/components/tags/__init__.py index b7eb8c4b5..993da11fe 100644 --- a/reflex/components/tags/__init__.py +++ b/reflex/components/tags/__init__.py @@ -2,4 +2,5 @@ from .cond_tag import CondTag from .iter_tag import IterTag +from .match_tag import MatchTag from .tag import Tag diff --git a/reflex/components/tags/match_tag.py b/reflex/components/tags/match_tag.py new file mode 100644 index 000000000..c2f6649d5 --- /dev/null +++ b/reflex/components/tags/match_tag.py @@ -0,0 +1,19 @@ +"""Tag to conditionally match cases.""" + +from typing import Any, List + +from reflex.components.tags.tag import Tag +from reflex.vars import Var + + +class MatchTag(Tag): + """A match tag.""" + + # The condition to determine which case to match. + cond: Var[Any] + + # The list of match cases to be matched. + match_cases: List[Any] + + # The catchall case to match. + default: Any diff --git a/reflex/utils/exceptions.py b/reflex/utils/exceptions.py index 878f6fb16..e4a9a6e6c 100644 --- a/reflex/utils/exceptions.py +++ b/reflex/utils/exceptions.py @@ -13,3 +13,9 @@ class ImmutableStateError(AttributeError): class LockExpiredError(Exception): """Raised when the state lock expires while an event is being processed.""" + + +class MatchTypeError(TypeError): + """Raised when the return types of match cases are different.""" + + pass diff --git a/reflex/utils/format.py b/reflex/utils/format.py index b42ed2b10..46fa41436 100644 --- a/reflex/utils/format.py +++ b/reflex/utils/format.py @@ -7,7 +7,7 @@ import json import os import re import sys -from typing import TYPE_CHECKING, Any, Union +from typing import TYPE_CHECKING, Any, List, Union from reflex import constants from reflex.utils import exceptions, serializers, types @@ -272,6 +272,41 @@ def format_cond( return wrap(f"{cond} ? {true_value} : {false_value}", "{") +def format_match(cond: str | Var, match_cases: List[BaseVar], default: Var) -> str: + """Format a match expression whose return type is a Var. + + Args: + cond: The condition. + match_cases: The list of cases to match. + default: The default case. + + Returns: + The formatted match expression + + """ + switch_code = f"(() => {{ switch (JSON.stringify({cond})) {{" + + for case in match_cases: + conditions = case[:-1] + return_value = case[-1] + + case_conditions = " ".join( + [ + f"case JSON.stringify({condition._var_name_unwrapped}):" + for condition in conditions + ] + ) + case_code = ( + f"{case_conditions} return ({return_value._var_name_unwrapped}); break;" + ) + switch_code += case_code + + switch_code += f"default: return ({default._var_name_unwrapped}); break;" + switch_code += "};})()" + + return switch_code + + def format_prop( prop: Union[Var, EventChain, ComponentStyle, str], ) -> Union[int, float, str]: diff --git a/reflex/vars.py b/reflex/vars.py index 2f31f1190..04351609f 100644 --- a/reflex/vars.py +++ b/reflex/vars.py @@ -1534,6 +1534,26 @@ class Var: """ return self._var_data.state if self._var_data else "" + @property + def _var_name_unwrapped(self) -> str: + """Get the var str without wrapping in curly braces. + + Returns: + The str var without the wrapped curly braces + """ + type_ = ( + get_origin(self._var_type) + if types.is_generic_alias(self._var_type) + else self._var_type + ) + + wrapped_var = str(self) + return ( + wrapped_var + if not self._var_state and issubclass(type_, dict) + else wrapped_var.strip("{}") + ) + # Allow automatic serialization of Var within JSON structures serializers.serializer(_encode_var) diff --git a/scripts/pyi_generator.py b/scripts/pyi_generator.py index b7143806b..9dc4141a8 100644 --- a/scripts/pyi_generator.py +++ b/scripts/pyi_generator.py @@ -30,6 +30,7 @@ EXCLUDED_FILES = [ "bare.py", "foreach.py", "cond.py", + "match.py", "multiselect.py", "literals.py", ] diff --git a/tests/components/layout/test_match.py b/tests/components/layout/test_match.py new file mode 100644 index 000000000..233e9c643 --- /dev/null +++ b/tests/components/layout/test_match.py @@ -0,0 +1,306 @@ +from typing import Tuple + +import pytest + +import reflex as rx +from reflex.components.core.match import Match +from reflex.state import BaseState +from reflex.utils.exceptions import MatchTypeError +from reflex.vars import BaseVar + + +class MatchState(BaseState): + """A test state.""" + + value: int = 0 + num: int = 5 + string: str = "random string" + + +def test_match_components(): + """Test matching cases with return values as components.""" + match_case_tuples = ( + (1, rx.text("first value")), + (2, 3, rx.text("second value")), + ([1, 2], rx.text("third value")), + ("random", rx.text("fourth value")), + ({"foo": "bar"}, rx.text("fifth value")), + (MatchState.num + 1, rx.text("sixth value")), + rx.text("default value"), + ) + match_comp = Match.create(MatchState.value, *match_case_tuples) + match_dict = match_comp.render() # type: ignore + assert match_dict["name"] == "Fragment" + + [match_child] = match_dict["children"] + + assert match_child["name"] == "match" + assert str(match_child["cond"]) == "{match_state.value}" + + match_cases = match_child["match_cases"] + assert len(match_cases) == 6 + + assert match_cases[0][0]._var_name == "1" + assert match_cases[0][0]._var_type == int + first_return_value_render = match_cases[0][1].render() + assert first_return_value_render["name"] == "Text" + assert first_return_value_render["children"][0]["contents"] == "{`first value`}" + + assert match_cases[1][0]._var_name == "2" + assert match_cases[1][0]._var_type == int + assert match_cases[1][1]._var_name == "3" + assert match_cases[1][1]._var_type == int + second_return_value_render = match_cases[1][2].render() + assert second_return_value_render["name"] == "Text" + assert second_return_value_render["children"][0]["contents"] == "{`second value`}" + + assert match_cases[2][0]._var_name == "[1, 2]" + assert match_cases[2][0]._var_type == list + third_return_value_render = match_cases[2][1].render() + assert third_return_value_render["name"] == "Text" + assert third_return_value_render["children"][0]["contents"] == "{`third value`}" + + assert match_cases[3][0]._var_name == "random" + assert match_cases[3][0]._var_type == str + fourth_return_value_render = match_cases[3][1].render() + assert fourth_return_value_render["name"] == "Text" + assert fourth_return_value_render["children"][0]["contents"] == "{`fourth value`}" + + assert match_cases[4][0]._var_name == '{"foo": "bar"}' + assert match_cases[4][0]._var_type == dict + fifth_return_value_render = match_cases[4][1].render() + assert fifth_return_value_render["name"] == "Text" + assert fifth_return_value_render["children"][0]["contents"] == "{`fifth value`}" + + assert match_cases[5][0]._var_name == "(match_state.num + 1)" + assert match_cases[5][0]._var_type == int + fifth_return_value_render = match_cases[5][1].render() + assert fifth_return_value_render["name"] == "Text" + assert fifth_return_value_render["children"][0]["contents"] == "{`sixth value`}" + + default = match_child["default"].render() + + assert default["name"] == "Text" + assert default["children"][0]["contents"] == "{`default value`}" + + +@pytest.mark.parametrize( + "cases, expected", + [ + ( + ( + (1, "first"), + (2, 3, "second value"), + ([1, 2], "third-value"), + ("random", "fourth_value"), + ({"foo": "bar"}, "fifth value"), + (MatchState.num + 1, "sixth value"), + (f"{MatchState.value} - string", MatchState.string), + (MatchState.string, f"{MatchState.value} - string"), + "default value", + ), + "(() => { switch (JSON.stringify(match_state.value)) {case JSON.stringify(1): return (`first`); break;case JSON.stringify(2): case JSON.stringify(3): return " + "(`second value`); break;case JSON.stringify([1, 2]): return (`third-value`); break;case JSON.stringify(`random`): " + 'return (`fourth_value`); break;case JSON.stringify({"foo": "bar"}): return (`fifth value`); ' + "break;case JSON.stringify((match_state.num + 1)): return (`sixth value`); break;case JSON.stringify(`${match_state.value} - string`): " + "return (match_state.string); break;case JSON.stringify(match_state.string): return (`${match_state.value} - string`); break;default: " + "return (`default value`); break;};})()", + ), + ( + ( + (1, "first"), + (2, 3, "second value"), + ([1, 2], "third-value"), + ("random", "fourth_value"), + ({"foo": "bar"}, "fifth value"), + (MatchState.num + 1, "sixth value"), + (f"{MatchState.value} - string", MatchState.string), + (MatchState.string, f"{MatchState.value} - string"), + MatchState.string, + ), + "(() => { switch (JSON.stringify(match_state.value)) {case JSON.stringify(1): return (`first`); break;case JSON.stringify(2): case JSON.stringify(3): return " + "(`second value`); break;case JSON.stringify([1, 2]): return (`third-value`); break;case JSON.stringify(`random`): " + 'return (`fourth_value`); break;case JSON.stringify({"foo": "bar"}): return (`fifth value`); ' + "break;case JSON.stringify((match_state.num + 1)): return (`sixth value`); break;case JSON.stringify(`${match_state.value} - string`): " + "return (match_state.string); break;case JSON.stringify(match_state.string): return (`${match_state.value} - string`); break;default: " + "return (match_state.string); break;};})()", + ), + ], +) +def test_match_vars(cases, expected): + """Test matching cases with return values as Vars. + + Args: + cases: The match cases. + expected: The expected var full name. + """ + match_comp = Match.create(MatchState.value, *cases) + assert isinstance(match_comp, BaseVar) + assert match_comp._var_full_name == expected + + +def test_match_on_component_without_default(): + """Test that matching cases with return values as components returns a Fragment + as the default case if not provided. + """ + match_case_tuples = ( + (1, rx.text("first value")), + (2, 3, rx.text("second value")), + ) + + match_comp = Match.create(MatchState.value, *match_case_tuples) + default = match_comp.render()["children"][0]["default"] # type: ignore + + assert isinstance(default, rx.Fragment) + + +def test_match_on_var_no_default(): + """Test that an error is thrown when cases with return Values as Var do not have a default case.""" + match_case_tuples = ( + (1, "red"), + (2, 3, "blue"), + ([1, 2], "green"), + ) + + with pytest.raises( + ValueError, + match="For cases with return types as Vars, a default case must be provided", + ): + Match.create(MatchState.value, *match_case_tuples) + + +@pytest.mark.parametrize( + "match_case", + [ + ( + (1, "red"), + (2, 3, "blue"), + "black", + ([1, 2], "green"), + ), + ( + (1, rx.text("first value")), + (2, 3, rx.text("second value")), + ([1, 2], rx.text("third value")), + rx.text("default value"), + ("random", rx.text("fourth value")), + ({"foo": "bar"}, rx.text("fifth value")), + (MatchState.num + 1, rx.text("sixth value")), + ), + ], +) +def test_match_default_not_last_arg(match_case): + """Test that an error is thrown when the default case is not the last arg. + + Args: + match_case: The cases to match. + """ + with pytest.raises( + ValueError, + match="rx.match should have tuples of cases and a default case as the last argument.", + ): + Match.create(MatchState.value, *match_case) + + +@pytest.mark.parametrize( + "match_case", + [ + ( + (1, "red"), + (2, 3, "blue"), + ("green",), + "black", + ), + ( + (1, rx.text("first value")), + (2, 3, rx.text("second value")), + ([1, 2],), + rx.text("default value"), + ), + ], +) +def test_match_case_tuple_elements(match_case): + """Test that a match has at least 2 elements(a condition and a return value). + + Args: + match_case: The cases to match. + """ + with pytest.raises( + ValueError, + match="A case tuple should have at least a match case element and a return value.", + ): + Match.create(MatchState.value, *match_case) + + +@pytest.mark.parametrize( + "cases, error_msg", + [ + ( + ( + (1, rx.text("first value")), + (2, 3, rx.text("second value")), + ([1, 2], rx.text("third value")), + ("random", "red"), + ({"foo": "bar"}, "green"), + (MatchState.num + 1, "black"), + rx.text("default value"), + ), + "Match cases should have the same return types. Case 3 with return value `red` of type " + " is not ", + ), + ( + ( + ("random", "red"), + ({"foo": "bar"}, "green"), + (MatchState.num + 1, "black"), + (1, rx.text("first value")), + (2, 3, rx.text("second value")), + ([1, 2], rx.text("third value")), + rx.text("default value"), + ), + "Match cases should have the same return types. Case 3 with return value ` {`first value`} ` " + "of type is not ", + ), + ], +) +def test_match_different_return_types(cases: Tuple, error_msg: str): + """Test that an error is thrown when the return values are of different types. + + Args: + cases: The match cases. + error_msg: Expected error message. + """ + with pytest.raises(MatchTypeError, match=error_msg): + Match.create(MatchState.value, *cases) + + +@pytest.mark.parametrize( + "match_case", + [ + ( + (1, "red"), + (2, 3, "blue"), + ([1, 2], "green"), + "black", + "white", + ), + ( + (1, rx.text("first value")), + (2, 3, rx.text("second value")), + ([1, 2], rx.text("third value")), + ("random", rx.text("fourth value")), + ({"foo": "bar"}, rx.text("fifth value")), + (MatchState.num + 1, rx.text("sixth value")), + rx.text("default value"), + rx.text("another default value"), + ), + ], +) +def test_match_multiple_default_cases(match_case): + """Test that there is only one default case. + + Args: + match_case: the cases to match. + """ + with pytest.raises(ValueError, match="rx.match can only have one default case."): + Match.create(MatchState.value, *match_case) diff --git a/tests/components/test_component.py b/tests/components/test_component.py index 679fc8d80..7bdd6cc9f 100644 --- a/tests/components/test_component.py +++ b/tests/components/test_component.py @@ -896,6 +896,7 @@ def test_instantiate_all_components(): "FormControl", "Html", "Icon", + "Match", "Markdown", "MultiSelect", "Option", diff --git a/tests/test_var.py b/tests/test_var.py index a6b9deeae..c5ba3dc3c 100644 --- a/tests/test_var.py +++ b/tests/test_var.py @@ -24,6 +24,13 @@ test_vars = [ ] +class ATestState(BaseState): + """Test state.""" + + value: str + dict_val: Dict[str, List] = {} + + @pytest.fixture def TestObj(): class TestObj(Base): @@ -1137,3 +1144,22 @@ def test_invalid_var_operations(operand1_var: Var, operand2_var, operators: List with pytest.raises(TypeError): operand1_var.operation(op=operator, other=operand2_var, flip=True) + + +@pytest.mark.parametrize( + "var, expected", + [ + (Var.create("string_value", _var_is_string=True), "`string_value`"), + (Var.create(1), "1"), + (Var.create([1, 2, 3]), "[1, 2, 3]"), + (Var.create({"foo": "bar"}), '{"foo": "bar"}'), + (Var.create(ATestState.value, _var_is_string=True), "a_test_state.value"), + ( + Var.create(f"{ATestState.value} string", _var_is_string=True), + "`${a_test_state.value} string`", + ), + (Var.create(ATestState.dict_val), "a_test_state.dict_val"), + ], +) +def test_var_name_unwrapped(var, expected): + assert var._var_name_unwrapped == expected diff --git a/tests/utils/test_format.py b/tests/utils/test_format.py index f56b4d22b..8c62f734a 100644 --- a/tests/utils/test_format.py +++ b/tests/utils/test_format.py @@ -1,5 +1,5 @@ import datetime -from typing import Any +from typing import Any, List import pytest @@ -294,6 +294,42 @@ def test_format_cond(condition: str, true_value: str, false_value: str, expected assert format.format_cond(condition, true_value, false_value) == expected +@pytest.mark.parametrize( + "condition, match_cases, default,expected", + [ + ( + "state__state.value", + [ + [Var.create(1), Var.create("red", _var_is_string=True)], + [Var.create(2), Var.create(3), Var.create("blue", _var_is_string=True)], + [TestState.mapping, TestState.num1], + [ + Var.create(f"{TestState.map_key}-key", _var_is_string=True), + Var.create("return-key", _var_is_string=True), + ], + ], + Var.create("yellow", _var_is_string=True), + "(() => { switch (JSON.stringify(state__state.value)) {case JSON.stringify(1): return (`red`); break;case JSON.stringify(2): case JSON.stringify(3): " + "return (`blue`); break;case JSON.stringify(test_state.mapping): return " + "(test_state.num1); break;case JSON.stringify(`${test_state.map_key}-key`): return (`return-key`);" + " break;default: return (`yellow`); break;};})()", + ) + ], +) +def test_format_match( + condition: str, match_cases: List[BaseVar], default: BaseVar, expected: str +): + """Test formatting a match statement. + + Args: + condition: The condition to match. + match_cases: List of match cases to be matched. + default: Catchall case for the match statement. + expected: The expected string output. + """ + assert format.format_match(condition, match_cases, default) == expected + + @pytest.mark.parametrize( "prop,formatted", [