get rid of match class

This commit is contained in:
Khaleel Al-Adhami 2025-01-17 16:17:24 -08:00
parent 36af8255d3
commit 00019daa27
7 changed files with 137 additions and 195 deletions

View File

@ -8,8 +8,6 @@
{{- component }} {{- component }}
{%- elif "iterable" in component %} {%- elif "iterable" in component %}
{{- render_iterable_tag(component) }} {{- render_iterable_tag(component) }}
{%- elif component.name == "match"%}
{{- render_match_tag(component) }}
{%- elif "cond" in component %} {%- elif "cond" in component %}
{{- render_condition_tag(component) }} {{- render_condition_tag(component) }}
{%- elif component.children|length %} {%- elif component.children|length %}
@ -75,29 +73,6 @@
{% if props|length %} {{ props|join(" ") }}{% endif %} {% if props|length %} {{ props|join(" ") }}{% endif %}
{% endmacro %} {% endmacro %}
{# Rendering Match component. #}
{# Args: #}
{# component: component dictionary #}
{% macro render_match_tag(component) %}
{
(() => {
switch (JSON.stringify({{ component.cond._js_expr }})) {
{% for case in component.match_cases %}
{% for condition in case[:-1] %}
case JSON.stringify({{ condition._js_expr }}):
{% endfor %}
return {{ case[-1] }};
break;
{% endfor %}
default:
return {{ component.default }};
break;
}
})()
}
{%- endmacro %}
{# Rendering content with args. #} {# Rendering content with args. #}
{# Args: #} {# Args: #}
{# component: component dictionary #} {# component: component dictionary #}

View File

@ -934,7 +934,6 @@ class Component(BaseComponent, ABC):
from reflex.components.base.fragment import Fragment from reflex.components.base.fragment import Fragment
from reflex.components.core.cond import Cond from reflex.components.core.cond import Cond
from reflex.components.core.foreach import Foreach from reflex.components.core.foreach import Foreach
from reflex.components.core.match import Match
no_valid_parents_defined = all(child._valid_parents == [] for child in children) no_valid_parents_defined = all(child._valid_parents == [] for child in children)
if ( if (
@ -945,9 +944,7 @@ class Component(BaseComponent, ABC):
return return
comp_name = type(self).__name__ comp_name = type(self).__name__
allowed_components = [ allowed_components = [comp.__name__ for comp in (Fragment, Foreach, Cond)]
comp.__name__ for comp in (Fragment, Foreach, Cond, Match)
]
def validate_child(child): def validate_child(child):
child_name = type(child).__name__ child_name = type(child).__name__
@ -971,11 +968,6 @@ class Component(BaseComponent, ABC):
for c in var_data.components: for c in var_data.components:
validate_child(c) validate_child(c)
if isinstance(child, Match):
for cases in child.match_cases:
validate_child(cases[-1])
validate_child(child.default)
if self._invalid_children and child_name in self._invalid_children: if self._invalid_children and child_name in self._invalid_children:
raise ValueError( raise ValueError(
f"The component `{comp_name}` cannot have `{child_name}` as a child component" f"The component `{comp_name}` cannot have `{child_name}` as a child component"
@ -2073,7 +2065,6 @@ class StatefulComponent(BaseComponent):
from reflex.components.base.bare import Bare from reflex.components.base.bare import Bare
from reflex.components.core.cond import Cond from reflex.components.core.cond import Cond
from reflex.components.core.foreach import Foreach from reflex.components.core.foreach import Foreach
from reflex.components.core.match import Match
if isinstance(child, Bare): if isinstance(child, Bare):
return child.contents return child.contents
@ -2081,8 +2072,6 @@ class StatefulComponent(BaseComponent):
return child.cond return child.cond
if isinstance(child, Foreach): if isinstance(child, Foreach):
return child.iterable return child.iterable
if isinstance(child, Match):
return child.cond
return child return child
@classmethod @classmethod

View File

@ -30,7 +30,6 @@ _SUBMOD_ATTRS: dict[str, list[str]] = {
"html": ["html", "Html"], "html": ["html", "Html"],
"match": [ "match": [
"match", "match",
"Match",
], ],
"breakpoints": ["breakpoints", "set_breakpoints"], "breakpoints": ["breakpoints", "set_breakpoints"],
"responsive": [ "responsive": [

View File

@ -1,12 +1,12 @@
"""rx.match.""" """rx.match."""
import textwrap import textwrap
from typing import Any, List, cast from typing import Any, cast
from typing_extensions import Unpack from typing_extensions import Unpack
from reflex.components.base import Fragment from reflex.components.base import Fragment
from reflex.components.component import BaseComponent, Component, MemoizationLeaf from reflex.components.component import BaseComponent, Component
from reflex.utils import types from reflex.utils import types
from reflex.utils.exceptions import MatchTypeError from reflex.utils.exceptions import MatchTypeError
from reflex.vars.base import VAR_TYPE, Var from reflex.vars.base import VAR_TYPE, Var
@ -15,155 +15,135 @@ from reflex.vars.number import MatchOperation
CASE_TYPE = tuple[Unpack[tuple[Any, ...]], Var[VAR_TYPE] | VAR_TYPE] CASE_TYPE = tuple[Unpack[tuple[Any, ...]], Var[VAR_TYPE] | VAR_TYPE]
class Match(MemoizationLeaf): def _process_match_cases(cases: tuple[CASE_TYPE[VAR_TYPE], ...]):
"""Match cases based on a condition.""" """Process the individual match cases.
# The condition to determine which case to match. Args:
cond: Var[Any] cases: The match cases.
# The list of match cases to be matched. Raises:
match_cases: List[Any] = [] ValueError: If the default case is not the last case or the tuple elements are less than 2.
"""
# The catchall case to match. for case in cases:
default: Any if not isinstance(case, tuple):
@classmethod
def create(
cls,
cond: Any,
*cases: Unpack[
tuple[Unpack[tuple[CASE_TYPE[VAR_TYPE], ...]], Var[VAR_TYPE] | VAR_TYPE]
],
) -> Var[VAR_TYPE]:
"""Create a Match Component.
Args:
cond: The condition to determine which case to match.
cases: This list of cases to match.
Returns:
The match component.
Raises:
ValueError: When a default case is not provided for cases with Var return types.
"""
default = None
if len([case for case in cases if not isinstance(case, tuple)]) > 1:
raise ValueError("rx.match can only have one default case.")
if not cases:
raise ValueError("rx.match should have at least one case.")
# Get the default case which should be the last non-tuple arg
if not isinstance(cases[-1], tuple):
default = cases[-1]
actual_cases = cases[:-1]
else:
actual_cases = cast(tuple[CASE_TYPE[VAR_TYPE], ...], cases)
cls._process_match_cases(actual_cases)
cls._validate_return_types(actual_cases)
if default is None and any(
not (
isinstance((return_type := case[-1]), Component)
or (
isinstance(return_type, Var)
and types.typehint_issubclass(return_type._var_type, Component)
)
)
for case in actual_cases
):
raise ValueError( raise ValueError(
"For cases with return types as Vars, a default case must be provided" "rx.match should have tuples of cases and a default case as the last argument."
) )
elif default is None:
default = Fragment.create()
default = cast(Var[VAR_TYPE] | VAR_TYPE, default) # There should be at least two elements in a case tuple(a condition and return value)
if len(case) < 2:
raise ValueError(
"A case tuple should have at least a match case element and a return value."
)
return cls._create_match_cond_var_or_component(
cond, def _validate_return_types(match_cases: tuple[CASE_TYPE[VAR_TYPE], ...]) -> None:
actual_cases, """Validate that match cases have the same return types.
default,
Args:
match_cases: The match cases.
Raises:
MatchTypeError: If the return types of cases are different.
"""
first_case_return = match_cases[0][-1]
return_type = type(first_case_return)
if types._isinstance(first_case_return, BaseComponent):
return_type = BaseComponent
elif types._isinstance(first_case_return, Var):
return_type = Var
for index, case in enumerate(match_cases):
if not (
types._issubclass(type(case[-1]), return_type)
or (
isinstance(case[-1], Var)
and types.typehint_issubclass(case[-1]._var_type, return_type)
)
):
raise MatchTypeError(
f"Match cases should have the same return types. Case {index} with return "
f"value `{case[-1]._js_expr if isinstance(case[-1], Var) else textwrap.shorten(str(case[-1]), width=250)}`"
f" of type {(type(case[-1]) if not isinstance(case[-1], Var) else case[-1]._var_type)!r} is not {return_type}"
)
def _create_match_var(
match_cond_var: Var,
match_cases: tuple[CASE_TYPE[VAR_TYPE], ...],
default: VAR_TYPE | Var[VAR_TYPE],
) -> Var[VAR_TYPE]:
"""Create the match var.
Args:
match_cond_var: The match condition var.
match_cases: The match cases.
default: The default case.
Returns:
The match var.
"""
return MatchOperation.create(match_cond_var, match_cases, default)
def match(
cond: Any,
*cases: Unpack[
tuple[Unpack[tuple[CASE_TYPE[VAR_TYPE], ...]], Var[VAR_TYPE] | VAR_TYPE]
],
) -> Var[VAR_TYPE]:
"""Create a match var.
Args:
cond: The condition to match.
cases: The match cases. Each case should be a tuple with the first elements as the match case and the last element as the return value. The last argument should be the default case.
Returns:
The match var.
Raises:
ValueError: If the default case is not the last case or the tuple elements are less than 2.
"""
default = None
if len([case for case in cases if not isinstance(case, tuple)]) > 1:
raise ValueError("rx.match can only have one default case.")
if not cases:
raise ValueError("rx.match should have at least one case.")
# Get the default case which should be the last non-tuple arg
if not isinstance(cases[-1], tuple):
default = cases[-1]
actual_cases = cases[:-1]
else:
actual_cases = cast(tuple[CASE_TYPE[VAR_TYPE], ...], cases)
_process_match_cases(actual_cases)
_validate_return_types(actual_cases)
if default is None and any(
not (
isinstance((return_type := case[-1]), Component)
or (
isinstance(return_type, Var)
and types.typehint_issubclass(return_type._var_type, Component)
)
) )
for case in actual_cases
):
raise ValueError(
"For cases with return types as Vars, a default case must be provided"
)
elif default is None:
default = Fragment.create()
@classmethod default = cast(Var[VAR_TYPE] | VAR_TYPE, default)
def _process_match_cases(cls, cases: tuple[CASE_TYPE[VAR_TYPE], ...]):
"""Process the individual match cases.
Args: return _create_match_var(
cases: The match cases. cond,
actual_cases,
Raises: default,
ValueError: If the default case is not the last case or the tuple elements are less than 2. )
"""
for case in cases:
if not isinstance(case, tuple):
raise ValueError(
"rx.match should have tuples of cases and a default case as the last argument."
)
# There should be at least two elements in a case tuple(a condition and return value)
if len(case) < 2:
raise ValueError(
"A case tuple should have at least a match case element and a return value."
)
@classmethod
def _validate_return_types(
cls, match_cases: tuple[CASE_TYPE[VAR_TYPE], ...]
) -> None:
"""Validate that match cases have the same return types.
Args:
match_cases: The match cases.
Raises:
MatchTypeError: If the return types of cases are different.
"""
first_case_return = match_cases[0][-1]
return_type = type(first_case_return)
if types._isinstance(first_case_return, BaseComponent):
return_type = BaseComponent
elif types._isinstance(first_case_return, Var):
return_type = Var
for index, case in enumerate(match_cases):
if not (
types._issubclass(type(case[-1]), return_type)
or (
isinstance(case[-1], Var)
and types.typehint_issubclass(case[-1]._var_type, return_type)
)
):
raise MatchTypeError(
f"Match cases should have the same return types. Case {index} with return "
f"value `{case[-1]._js_expr if isinstance(case[-1], Var) else textwrap.shorten(str(case[-1]), width=250)}`"
f" of type {(type(case[-1]) if not isinstance(case[-1], Var) else case[-1]._var_type)!r} is not {return_type}"
)
@classmethod
def _create_match_cond_var_or_component(
cls,
match_cond_var: Var,
match_cases: tuple[CASE_TYPE[VAR_TYPE], ...],
default: VAR_TYPE | Var[VAR_TYPE],
) -> Var[VAR_TYPE]:
"""Create and return the match condition var or component.
Args:
match_cond_var: The match condition.
match_cases: The list of match cases.
default: The default case.
Returns:
The match component wrapped in a fragment or the match var.
"""
return MatchOperation.create(match_cond_var, match_cases, default)
match = Match.create

View File

@ -5,8 +5,8 @@ from __future__ import annotations
from typing import Literal from typing import Literal
from reflex.components.component import Component from reflex.components.component import Component
from reflex.components.core import match
from reflex.components.core.breakpoints import Responsive from reflex.components.core.breakpoints import Responsive
from reflex.components.core.match import Match
from reflex.components.el import elements from reflex.components.el import elements
from reflex.components.lucide import Icon from reflex.components.lucide import Icon
from reflex.style import Style from reflex.style import Style
@ -77,7 +77,7 @@ class IconButton(elements.Button, RadixLoadingProp, RadixThemesComponent):
if isinstance(props["size"], str): if isinstance(props["size"], str):
children[0].size = RADIX_TO_LUCIDE_SIZE[props["size"]] children[0].size = RADIX_TO_LUCIDE_SIZE[props["size"]]
else: else:
size_map_var = Match.create( size_map_var = match(
props["size"], props["size"],
*list(RADIX_TO_LUCIDE_SIZE.items()), *list(RADIX_TO_LUCIDE_SIZE.items()),
12, 12,

View File

@ -3,7 +3,7 @@ from typing import Tuple
import pytest import pytest
import reflex as rx import reflex as rx
from reflex.components.core.match import Match from reflex.components.core.match import match
from reflex.state import BaseState from reflex.state import BaseState
from reflex.utils.exceptions import MatchTypeError from reflex.utils.exceptions import MatchTypeError
from reflex.vars.base import Var from reflex.vars.base import Var
@ -67,7 +67,7 @@ def test_match_vars(cases, expected):
cases: The match cases. cases: The match cases.
expected: The expected var full name. expected: The expected var full name.
""" """
match_comp = Match.create(MatchState.value, *cases) # pyright: ignore[reportCallIssue] match_comp = match(MatchState.value, *cases) # pyright: ignore[reportCallIssue]
assert isinstance(match_comp, Var) assert isinstance(match_comp, Var)
assert str(match_comp) == expected assert str(match_comp) == expected
@ -81,7 +81,7 @@ def test_match_on_component_without_default():
(2, 3, rx.text("second value")), (2, 3, rx.text("second value")),
) )
match_comp = Match.create(MatchState.value, *match_case_tuples) match_comp = match(MatchState.value, *match_case_tuples)
assert isinstance(match_comp, Var) assert isinstance(match_comp, Var)
@ -98,7 +98,7 @@ def test_match_on_var_no_default():
ValueError, ValueError,
match="For cases with return types as Vars, a default case must be provided", match="For cases with return types as Vars, a default case must be provided",
): ):
Match.create(MatchState.value, *match_case_tuples) match(MatchState.value, *match_case_tuples)
@pytest.mark.parametrize( @pytest.mark.parametrize(
@ -131,7 +131,7 @@ def test_match_default_not_last_arg(match_case):
ValueError, ValueError,
match="rx.match should have tuples of cases and a default case as the last argument.", match="rx.match should have tuples of cases and a default case as the last argument.",
): ):
Match.create(MatchState.value, *match_case) # pyright: ignore[reportCallIssue] match(MatchState.value, *match_case) # pyright: ignore[reportCallIssue]
@pytest.mark.parametrize( @pytest.mark.parametrize(
@ -161,7 +161,7 @@ def test_match_case_tuple_elements(match_case):
ValueError, ValueError,
match="A case tuple should have at least a match case element and a return value.", match="A case tuple should have at least a match case element and a return value.",
): ):
Match.create(MatchState.value, *match_case) # pyright: ignore[reportCallIssue] match(MatchState.value, *match_case) # pyright: ignore[reportCallIssue]
@pytest.mark.parametrize( @pytest.mark.parametrize(
@ -203,7 +203,7 @@ def test_match_different_return_types(cases: Tuple, error_msg: str):
error_msg: Expected error message. error_msg: Expected error message.
""" """
with pytest.raises(MatchTypeError, match=error_msg): with pytest.raises(MatchTypeError, match=error_msg):
Match.create(MatchState.value, *cases) # pyright: ignore[reportCallIssue] match(MatchState.value, *cases) # pyright: ignore[reportCallIssue]
@pytest.mark.parametrize( @pytest.mark.parametrize(
@ -235,9 +235,9 @@ def test_match_multiple_default_cases(match_case):
match_case: the cases to match. match_case: the cases to match.
""" """
with pytest.raises(ValueError, match="rx.match can only have one default case."): with pytest.raises(ValueError, match="rx.match can only have one default case."):
Match.create(MatchState.value, *match_case) # pyright: ignore[reportCallIssue] match(MatchState.value, *match_case) # pyright: ignore[reportCallIssue]
def test_match_no_cond(): def test_match_no_cond():
with pytest.raises(ValueError): with pytest.raises(ValueError):
_ = Match.create(None) # pyright: ignore[reportCallIssue] _ = match(None) # pyright: ignore[reportCallIssue]

View File

@ -1451,7 +1451,6 @@ def test_instantiate_all_components():
"FormControl", "FormControl",
"Html", "Html",
"Icon", "Icon",
"Match",
"Markdown", "Markdown",
"MultiSelect", "MultiSelect",
"Option", "Option",