From 00019daa27292a098135d6438d8c5df579d66913 Mon Sep 17 00:00:00 2001 From: Khaleel Al-Adhami Date: Fri, 17 Jan 2025 16:17:24 -0800 Subject: [PATCH] get rid of match class --- .../jinja/web/pages/utils.js.jinja2 | 25 -- reflex/components/component.py | 13 +- reflex/components/core/__init__.py | 1 - reflex/components/core/match.py | 270 ++++++++---------- .../radix/themes/components/icon_button.py | 4 +- tests/units/components/core/test_match.py | 18 +- tests/units/components/test_component.py | 1 - 7 files changed, 137 insertions(+), 195 deletions(-) diff --git a/reflex/.templates/jinja/web/pages/utils.js.jinja2 b/reflex/.templates/jinja/web/pages/utils.js.jinja2 index 624e3bee8..d161e846d 100644 --- a/reflex/.templates/jinja/web/pages/utils.js.jinja2 +++ b/reflex/.templates/jinja/web/pages/utils.js.jinja2 @@ -8,8 +8,6 @@ {{- component }} {%- elif "iterable" in component %} {{- render_iterable_tag(component) }} - {%- elif component.name == "match"%} - {{- render_match_tag(component) }} {%- elif "cond" in component %} {{- render_condition_tag(component) }} {%- elif component.children|length %} @@ -75,29 +73,6 @@ {% if props|length %} {{ props|join(" ") }}{% endif %} {% endmacro %} -{# Rendering Match component. #} -{# Args: #} -{# component: component dictionary #} -{% macro render_match_tag(component) %} -{ - (() => { - switch (JSON.stringify({{ component.cond._js_expr }})) { - {% for case in component.match_cases %} - {% for condition in case[:-1] %} - case JSON.stringify({{ condition._js_expr }}): - {% endfor %} - return {{ case[-1] }}; - break; - {% endfor %} - default: - return {{ component.default }}; - break; - } - })() - } -{%- endmacro %} - - {# Rendering content with args. #} {# Args: #} {# component: component dictionary #} diff --git a/reflex/components/component.py b/reflex/components/component.py index 015406d92..f3f69bbb2 100644 --- a/reflex/components/component.py +++ b/reflex/components/component.py @@ -934,7 +934,6 @@ class Component(BaseComponent, ABC): from reflex.components.base.fragment import Fragment from reflex.components.core.cond import Cond from reflex.components.core.foreach import Foreach - from reflex.components.core.match import Match no_valid_parents_defined = all(child._valid_parents == [] for child in children) if ( @@ -945,9 +944,7 @@ class Component(BaseComponent, ABC): return comp_name = type(self).__name__ - allowed_components = [ - comp.__name__ for comp in (Fragment, Foreach, Cond, Match) - ] + allowed_components = [comp.__name__ for comp in (Fragment, Foreach, Cond)] def validate_child(child): child_name = type(child).__name__ @@ -971,11 +968,6 @@ class Component(BaseComponent, ABC): for c in var_data.components: validate_child(c) - if isinstance(child, Match): - for cases in child.match_cases: - validate_child(cases[-1]) - validate_child(child.default) - if self._invalid_children and child_name in self._invalid_children: raise ValueError( f"The component `{comp_name}` cannot have `{child_name}` as a child component" @@ -2073,7 +2065,6 @@ class StatefulComponent(BaseComponent): from reflex.components.base.bare import Bare from reflex.components.core.cond import Cond from reflex.components.core.foreach import Foreach - from reflex.components.core.match import Match if isinstance(child, Bare): return child.contents @@ -2081,8 +2072,6 @@ class StatefulComponent(BaseComponent): return child.cond if isinstance(child, Foreach): return child.iterable - if isinstance(child, Match): - return child.cond return child @classmethod diff --git a/reflex/components/core/__init__.py b/reflex/components/core/__init__.py index fbe0bdc84..c61bf90e3 100644 --- a/reflex/components/core/__init__.py +++ b/reflex/components/core/__init__.py @@ -30,7 +30,6 @@ _SUBMOD_ATTRS: dict[str, list[str]] = { "html": ["html", "Html"], "match": [ "match", - "Match", ], "breakpoints": ["breakpoints", "set_breakpoints"], "responsive": [ diff --git a/reflex/components/core/match.py b/reflex/components/core/match.py index e697be7b2..d1b359e49 100644 --- a/reflex/components/core/match.py +++ b/reflex/components/core/match.py @@ -1,12 +1,12 @@ """rx.match.""" import textwrap -from typing import Any, List, cast +from typing import Any, cast from typing_extensions import Unpack from reflex.components.base import Fragment -from reflex.components.component import BaseComponent, Component, MemoizationLeaf +from reflex.components.component import BaseComponent, Component from reflex.utils import types from reflex.utils.exceptions import MatchTypeError from reflex.vars.base import VAR_TYPE, Var @@ -15,155 +15,135 @@ from reflex.vars.number import MatchOperation CASE_TYPE = tuple[Unpack[tuple[Any, ...]], Var[VAR_TYPE] | VAR_TYPE] -class Match(MemoizationLeaf): - """Match cases based on a condition.""" +def _process_match_cases(cases: tuple[CASE_TYPE[VAR_TYPE], ...]): + """Process the individual match cases. - # The condition to determine which case to match. - cond: Var[Any] + Args: + cases: The match cases. - # The list of match cases to be matched. - match_cases: List[Any] = [] - - # The catchall case to match. - default: Any - - @classmethod - def create( - cls, - cond: Any, - *cases: Unpack[ - tuple[Unpack[tuple[CASE_TYPE[VAR_TYPE], ...]], Var[VAR_TYPE] | VAR_TYPE] - ], - ) -> Var[VAR_TYPE]: - """Create a Match Component. - - Args: - cond: The condition to determine which case to match. - cases: This list of cases to match. - - Returns: - The match component. - - Raises: - ValueError: When a default case is not provided for cases with Var return types. - """ - default = None - - if len([case for case in cases if not isinstance(case, tuple)]) > 1: - raise ValueError("rx.match can only have one default case.") - - if not cases: - raise ValueError("rx.match should have at least one case.") - - # Get the default case which should be the last non-tuple arg - if not isinstance(cases[-1], tuple): - default = cases[-1] - actual_cases = cases[:-1] - else: - actual_cases = cast(tuple[CASE_TYPE[VAR_TYPE], ...], cases) - - cls._process_match_cases(actual_cases) - - cls._validate_return_types(actual_cases) - - if default is None and any( - not ( - isinstance((return_type := case[-1]), Component) - or ( - isinstance(return_type, Var) - and types.typehint_issubclass(return_type._var_type, Component) - ) - ) - for case in actual_cases - ): + Raises: + ValueError: If the default case is not the last case or the tuple elements are less than 2. + """ + for case in cases: + if not isinstance(case, tuple): raise ValueError( - "For cases with return types as Vars, a default case must be provided" + "rx.match should have tuples of cases and a default case as the last argument." ) - elif default is None: - default = Fragment.create() - default = cast(Var[VAR_TYPE] | VAR_TYPE, default) + # There should be at least two elements in a case tuple(a condition and return value) + if len(case) < 2: + raise ValueError( + "A case tuple should have at least a match case element and a return value." + ) - return cls._create_match_cond_var_or_component( - cond, - actual_cases, - default, + +def _validate_return_types(match_cases: tuple[CASE_TYPE[VAR_TYPE], ...]) -> None: + """Validate that match cases have the same return types. + + Args: + match_cases: The match cases. + + Raises: + MatchTypeError: If the return types of cases are different. + """ + first_case_return = match_cases[0][-1] + return_type = type(first_case_return) + + if types._isinstance(first_case_return, BaseComponent): + return_type = BaseComponent + elif types._isinstance(first_case_return, Var): + return_type = Var + + for index, case in enumerate(match_cases): + if not ( + types._issubclass(type(case[-1]), return_type) + or ( + isinstance(case[-1], Var) + and types.typehint_issubclass(case[-1]._var_type, return_type) + ) + ): + raise MatchTypeError( + f"Match cases should have the same return types. Case {index} with return " + f"value `{case[-1]._js_expr if isinstance(case[-1], Var) else textwrap.shorten(str(case[-1]), width=250)}`" + f" of type {(type(case[-1]) if not isinstance(case[-1], Var) else case[-1]._var_type)!r} is not {return_type}" + ) + + +def _create_match_var( + match_cond_var: Var, + match_cases: tuple[CASE_TYPE[VAR_TYPE], ...], + default: VAR_TYPE | Var[VAR_TYPE], +) -> Var[VAR_TYPE]: + """Create the match var. + + Args: + match_cond_var: The match condition var. + match_cases: The match cases. + default: The default case. + + Returns: + The match var. + """ + return MatchOperation.create(match_cond_var, match_cases, default) + + +def match( + cond: Any, + *cases: Unpack[ + tuple[Unpack[tuple[CASE_TYPE[VAR_TYPE], ...]], Var[VAR_TYPE] | VAR_TYPE] + ], +) -> Var[VAR_TYPE]: + """Create a match var. + + Args: + cond: The condition to match. + cases: The match cases. Each case should be a tuple with the first elements as the match case and the last element as the return value. The last argument should be the default case. + + Returns: + The match var. + + Raises: + ValueError: If the default case is not the last case or the tuple elements are less than 2. + """ + default = None + + if len([case for case in cases if not isinstance(case, tuple)]) > 1: + raise ValueError("rx.match can only have one default case.") + + if not cases: + raise ValueError("rx.match should have at least one case.") + + # Get the default case which should be the last non-tuple arg + if not isinstance(cases[-1], tuple): + default = cases[-1] + actual_cases = cases[:-1] + else: + actual_cases = cast(tuple[CASE_TYPE[VAR_TYPE], ...], cases) + + _process_match_cases(actual_cases) + + _validate_return_types(actual_cases) + + if default is None and any( + not ( + isinstance((return_type := case[-1]), Component) + or ( + isinstance(return_type, Var) + and types.typehint_issubclass(return_type._var_type, Component) + ) ) + for case in actual_cases + ): + raise ValueError( + "For cases with return types as Vars, a default case must be provided" + ) + elif default is None: + default = Fragment.create() - @classmethod - def _process_match_cases(cls, cases: tuple[CASE_TYPE[VAR_TYPE], ...]): - """Process the individual match cases. + default = cast(Var[VAR_TYPE] | VAR_TYPE, default) - Args: - cases: The match cases. - - Raises: - ValueError: If the default case is not the last case or the tuple elements are less than 2. - """ - for case in cases: - if not isinstance(case, tuple): - raise ValueError( - "rx.match should have tuples of cases and a default case as the last argument." - ) - - # There should be at least two elements in a case tuple(a condition and return value) - if len(case) < 2: - raise ValueError( - "A case tuple should have at least a match case element and a return value." - ) - - @classmethod - def _validate_return_types( - cls, match_cases: tuple[CASE_TYPE[VAR_TYPE], ...] - ) -> None: - """Validate that match cases have the same return types. - - Args: - match_cases: The match cases. - - Raises: - MatchTypeError: If the return types of cases are different. - """ - first_case_return = match_cases[0][-1] - return_type = type(first_case_return) - - if types._isinstance(first_case_return, BaseComponent): - return_type = BaseComponent - elif types._isinstance(first_case_return, Var): - return_type = Var - - for index, case in enumerate(match_cases): - if not ( - types._issubclass(type(case[-1]), return_type) - or ( - isinstance(case[-1], Var) - and types.typehint_issubclass(case[-1]._var_type, return_type) - ) - ): - raise MatchTypeError( - f"Match cases should have the same return types. Case {index} with return " - f"value `{case[-1]._js_expr if isinstance(case[-1], Var) else textwrap.shorten(str(case[-1]), width=250)}`" - f" of type {(type(case[-1]) if not isinstance(case[-1], Var) else case[-1]._var_type)!r} is not {return_type}" - ) - - @classmethod - def _create_match_cond_var_or_component( - cls, - match_cond_var: Var, - match_cases: tuple[CASE_TYPE[VAR_TYPE], ...], - default: VAR_TYPE | Var[VAR_TYPE], - ) -> Var[VAR_TYPE]: - """Create and return the match condition var or component. - - Args: - match_cond_var: The match condition. - match_cases: The list of match cases. - default: The default case. - - Returns: - The match component wrapped in a fragment or the match var. - """ - return MatchOperation.create(match_cond_var, match_cases, default) - - -match = Match.create + return _create_match_var( + cond, + actual_cases, + default, + ) diff --git a/reflex/components/radix/themes/components/icon_button.py b/reflex/components/radix/themes/components/icon_button.py index 68c67485a..9daf86691 100644 --- a/reflex/components/radix/themes/components/icon_button.py +++ b/reflex/components/radix/themes/components/icon_button.py @@ -5,8 +5,8 @@ from __future__ import annotations from typing import Literal from reflex.components.component import Component +from reflex.components.core import match from reflex.components.core.breakpoints import Responsive -from reflex.components.core.match import Match from reflex.components.el import elements from reflex.components.lucide import Icon from reflex.style import Style @@ -77,7 +77,7 @@ class IconButton(elements.Button, RadixLoadingProp, RadixThemesComponent): if isinstance(props["size"], str): children[0].size = RADIX_TO_LUCIDE_SIZE[props["size"]] else: - size_map_var = Match.create( + size_map_var = match( props["size"], *list(RADIX_TO_LUCIDE_SIZE.items()), 12, diff --git a/tests/units/components/core/test_match.py b/tests/units/components/core/test_match.py index 83581a415..862bd5ad3 100644 --- a/tests/units/components/core/test_match.py +++ b/tests/units/components/core/test_match.py @@ -3,7 +3,7 @@ from typing import Tuple import pytest import reflex as rx -from reflex.components.core.match import Match +from reflex.components.core.match import match from reflex.state import BaseState from reflex.utils.exceptions import MatchTypeError from reflex.vars.base import Var @@ -67,7 +67,7 @@ def test_match_vars(cases, expected): cases: The match cases. expected: The expected var full name. """ - match_comp = Match.create(MatchState.value, *cases) # pyright: ignore[reportCallIssue] + match_comp = match(MatchState.value, *cases) # pyright: ignore[reportCallIssue] assert isinstance(match_comp, Var) assert str(match_comp) == expected @@ -81,7 +81,7 @@ def test_match_on_component_without_default(): (2, 3, rx.text("second value")), ) - match_comp = Match.create(MatchState.value, *match_case_tuples) + match_comp = match(MatchState.value, *match_case_tuples) assert isinstance(match_comp, Var) @@ -98,7 +98,7 @@ def test_match_on_var_no_default(): ValueError, match="For cases with return types as Vars, a default case must be provided", ): - Match.create(MatchState.value, *match_case_tuples) + match(MatchState.value, *match_case_tuples) @pytest.mark.parametrize( @@ -131,7 +131,7 @@ def test_match_default_not_last_arg(match_case): ValueError, match="rx.match should have tuples of cases and a default case as the last argument.", ): - Match.create(MatchState.value, *match_case) # pyright: ignore[reportCallIssue] + match(MatchState.value, *match_case) # pyright: ignore[reportCallIssue] @pytest.mark.parametrize( @@ -161,7 +161,7 @@ def test_match_case_tuple_elements(match_case): ValueError, match="A case tuple should have at least a match case element and a return value.", ): - Match.create(MatchState.value, *match_case) # pyright: ignore[reportCallIssue] + match(MatchState.value, *match_case) # pyright: ignore[reportCallIssue] @pytest.mark.parametrize( @@ -203,7 +203,7 @@ def test_match_different_return_types(cases: Tuple, error_msg: str): error_msg: Expected error message. """ with pytest.raises(MatchTypeError, match=error_msg): - Match.create(MatchState.value, *cases) # pyright: ignore[reportCallIssue] + match(MatchState.value, *cases) # pyright: ignore[reportCallIssue] @pytest.mark.parametrize( @@ -235,9 +235,9 @@ def test_match_multiple_default_cases(match_case): match_case: the cases to match. """ with pytest.raises(ValueError, match="rx.match can only have one default case."): - Match.create(MatchState.value, *match_case) # pyright: ignore[reportCallIssue] + match(MatchState.value, *match_case) # pyright: ignore[reportCallIssue] def test_match_no_cond(): with pytest.raises(ValueError): - _ = Match.create(None) # pyright: ignore[reportCallIssue] + _ = match(None) # pyright: ignore[reportCallIssue] diff --git a/tests/units/components/test_component.py b/tests/units/components/test_component.py index 766f96e61..bf2e18140 100644 --- a/tests/units/components/test_component.py +++ b/tests/units/components/test_component.py @@ -1451,7 +1451,6 @@ def test_instantiate_all_components(): "FormControl", "Html", "Icon", - "Match", "Markdown", "MultiSelect", "Option",