refactor match code

This commit is contained in:
Khaleel Al-Adhami 2025-02-13 14:42:11 -08:00
parent d1ff6d51a2
commit c74313992f

View File

@ -5,7 +5,7 @@ from typing import Any, Union, cast
from typing_extensions import Unpack from typing_extensions import Unpack
from reflex.components.base import Fragment from reflex.components.base import Fragment
from reflex.components.component import BaseComponent, Component from reflex.components.component import BaseComponent
from reflex.utils import types from reflex.utils import types
from reflex.utils.exceptions import MatchTypeError from reflex.utils.exceptions import MatchTypeError
from reflex.vars.base import VAR_TYPE, Var from reflex.vars.base import VAR_TYPE, Var
@ -36,11 +36,14 @@ def _process_match_cases(cases: tuple[CASE_TYPE[VAR_TYPE], ...]):
) )
def _validate_return_types(match_cases: tuple[CASE_TYPE[VAR_TYPE], ...]) -> None: def _validate_return_types(*return_values: Any) -> bool:
"""Validate that match cases have the same return types. """Validate that match cases have the same return types.
Args: Args:
match_cases: The match cases. return_values: The return values of the match cases.
Returns:
True if all cases have the same return types.
Raises: Raises:
MatchTypeError: If the return types of cases are different. MatchTypeError: If the return types of cases are different.
@ -54,22 +57,20 @@ def _validate_return_types(match_cases: tuple[CASE_TYPE[VAR_TYPE], ...]) -> None
) )
) )
def type_of_return_type(obj: Any) -> Any: def type_of_return_value(obj: Any) -> Any:
if isinstance(obj, Var): if isinstance(obj, Var):
return obj._var_type return obj._var_type
return type(obj) return type(obj)
return_types = [case[-1] for case in match_cases] is_return_type_component = [
is_component_or_component_var(return_type) for return_type in return_values
]
if any( if any(is_return_type_component) and not all(is_return_type_component):
is_component_or_component_var(return_type) for return_type in return_types
) and not all(
is_component_or_component_var(return_type) for return_type in return_types
):
non_component_return_types = [ non_component_return_types = [
(type_of_return_type(return_type), i) (type_of_return_value(return_value), i)
for i, return_type in enumerate(return_types) for i, return_value in enumerate(return_values)
if not is_component_or_component_var(return_type) if not is_return_type_component[i]
] ]
raise MatchTypeError( raise MatchTypeError(
"Match cases should have the same return types. " "Match cases should have the same return types. "
@ -82,6 +83,8 @@ def _validate_return_types(match_cases: tuple[CASE_TYPE[VAR_TYPE], ...]) -> None
) )
) )
return all(is_return_type_component)
def _create_match_var( def _create_match_var(
match_cond_var: Var, match_cond_var: Var,
@ -119,7 +122,7 @@ def match(
Raises: Raises:
ValueError: If the default case is not the last case or the tuple elements are less than 2. ValueError: If the default case is not the last case or the tuple elements are less than 2.
""" """
default = None default = types.Unset()
if len([case for case in cases if not isinstance(case, tuple)]) > 1: if len([case for case in cases if not isinstance(case, tuple)]) > 1:
raise ValueError("rx.match can only have one default case.") raise ValueError("rx.match can only have one default case.")
@ -136,22 +139,17 @@ def match(
_process_match_cases(actual_cases) _process_match_cases(actual_cases)
_validate_return_types(actual_cases) is_component_match = _validate_return_types(
*[case[-1] for case in actual_cases],
*([default] if not isinstance(default, types.Unset) else []),
)
if default is None and any( if isinstance(default, types.Unset) and not is_component_match:
not (
isinstance((return_type := case[-1]), Component)
or (
isinstance(return_type, Var)
and types.typehint_issubclass(return_type._var_type, Component)
)
)
for case in actual_cases
):
raise ValueError( raise ValueError(
"For cases with return types as Vars, a default case must be provided" "For cases with return types as Vars, a default case must be provided"
) )
elif default is None:
if isinstance(default, types.Unset):
default = Fragment.create() default = Fragment.create()
default = cast(Var[VAR_TYPE] | VAR_TYPE, default) default = cast(Var[VAR_TYPE] | VAR_TYPE, default)