From c74313992f06e78be937f51d051529f29af57652 Mon Sep 17 00:00:00 2001 From: Khaleel Al-Adhami Date: Thu, 13 Feb 2025 14:42:11 -0800 Subject: [PATCH] refactor match code --- reflex/components/core/match.py | 50 ++++++++++++++++----------------- 1 file changed, 24 insertions(+), 26 deletions(-) diff --git a/reflex/components/core/match.py b/reflex/components/core/match.py index 1b32c3033..4db9bf5bd 100644 --- a/reflex/components/core/match.py +++ b/reflex/components/core/match.py @@ -5,7 +5,7 @@ from typing import Any, Union, cast from typing_extensions import Unpack from reflex.components.base import Fragment -from reflex.components.component import BaseComponent, Component +from reflex.components.component import BaseComponent from reflex.utils import types from reflex.utils.exceptions import MatchTypeError from reflex.vars.base import VAR_TYPE, Var @@ -36,11 +36,14 @@ def _process_match_cases(cases: tuple[CASE_TYPE[VAR_TYPE], ...]): ) -def _validate_return_types(match_cases: tuple[CASE_TYPE[VAR_TYPE], ...]) -> None: +def _validate_return_types(*return_values: Any) -> bool: """Validate that match cases have the same return types. Args: - match_cases: The match cases. + return_values: The return values of the match cases. + + Returns: + True if all cases have the same return types. Raises: MatchTypeError: If the return types of cases are different. @@ -54,22 +57,20 @@ def _validate_return_types(match_cases: tuple[CASE_TYPE[VAR_TYPE], ...]) -> None ) ) - def type_of_return_type(obj: Any) -> Any: + def type_of_return_value(obj: Any) -> Any: if isinstance(obj, Var): return obj._var_type return type(obj) - return_types = [case[-1] for case in match_cases] + is_return_type_component = [ + is_component_or_component_var(return_type) for return_type in return_values + ] - if any( - is_component_or_component_var(return_type) for return_type in return_types - ) and not all( - is_component_or_component_var(return_type) for return_type in return_types - ): + if any(is_return_type_component) and not all(is_return_type_component): non_component_return_types = [ - (type_of_return_type(return_type), i) - for i, return_type in enumerate(return_types) - if not is_component_or_component_var(return_type) + (type_of_return_value(return_value), i) + for i, return_value in enumerate(return_values) + if not is_return_type_component[i] ] raise MatchTypeError( "Match cases should have the same return types. " @@ -82,6 +83,8 @@ def _validate_return_types(match_cases: tuple[CASE_TYPE[VAR_TYPE], ...]) -> None ) ) + return all(is_return_type_component) + def _create_match_var( match_cond_var: Var, @@ -119,7 +122,7 @@ def match( Raises: ValueError: If the default case is not the last case or the tuple elements are less than 2. """ - default = None + default = types.Unset() if len([case for case in cases if not isinstance(case, tuple)]) > 1: raise ValueError("rx.match can only have one default case.") @@ -136,22 +139,17 @@ def match( _process_match_cases(actual_cases) - _validate_return_types(actual_cases) + is_component_match = _validate_return_types( + *[case[-1] for case in actual_cases], + *([default] if not isinstance(default, types.Unset) else []), + ) - if default is None and any( - not ( - isinstance((return_type := case[-1]), Component) - or ( - isinstance(return_type, Var) - and types.typehint_issubclass(return_type._var_type, Component) - ) - ) - for case in actual_cases - ): + if isinstance(default, types.Unset) and not is_component_match: raise ValueError( "For cases with return types as Vars, a default case must be provided" ) - elif default is None: + + if isinstance(default, types.Unset): default = Fragment.create() default = cast(Var[VAR_TYPE] | VAR_TYPE, default)