From 29fc4b020a54137c7e953946cf94eff1c9b885a7 Mon Sep 17 00:00:00 2001 From: Khaleel Al-Adhami Date: Wed, 22 Jan 2025 16:39:43 -0800 Subject: [PATCH] make the match logic better --- reflex/components/core/match.py | 53 +++++++++++++++-------- tests/units/components/core/test_match.py | 9 ++-- 2 files changed, 38 insertions(+), 24 deletions(-) diff --git a/reflex/components/core/match.py b/reflex/components/core/match.py index d1b359e49..a9e7e10c5 100644 --- a/reflex/components/core/match.py +++ b/reflex/components/core/match.py @@ -1,6 +1,5 @@ """rx.match.""" -import textwrap from typing import Any, cast from typing_extensions import Unpack @@ -46,27 +45,43 @@ def _validate_return_types(match_cases: tuple[CASE_TYPE[VAR_TYPE], ...]) -> None Raises: MatchTypeError: If the return types of cases are different. """ - first_case_return = match_cases[0][-1] - return_type = type(first_case_return) - if types._isinstance(first_case_return, BaseComponent): - return_type = BaseComponent - elif types._isinstance(first_case_return, Var): - return_type = Var + def is_component_or_component_var(obj: Any) -> bool: + return types._isinstance(obj, BaseComponent) or ( + isinstance(obj, Var) + and ( + types.safe_typehint_issubclass(obj._var_type, BaseComponent) + or types.safe_typehint_issubclass(obj._var_type, list[BaseComponent]) + ) + ) - for index, case in enumerate(match_cases): - if not ( - types._issubclass(type(case[-1]), return_type) - or ( - isinstance(case[-1], Var) - and types.typehint_issubclass(case[-1]._var_type, return_type) - ) - ): - raise MatchTypeError( - f"Match cases should have the same return types. Case {index} with return " - f"value `{case[-1]._js_expr if isinstance(case[-1], Var) else textwrap.shorten(str(case[-1]), width=250)}`" - f" of type {(type(case[-1]) if not isinstance(case[-1], Var) else case[-1]._var_type)!r} is not {return_type}" + def type_of_return_type(obj: Any) -> Any: + if isinstance(obj, Var): + return obj._var_type + return type(obj) + + return_types = [case[-1] for case in match_cases] + + if any( + is_component_or_component_var(return_type) for return_type in return_types + ) and not all( + is_component_or_component_var(return_type) for return_type in return_types + ): + non_component_return_types = [ + (type_of_return_type(return_type), i) + for i, return_type in enumerate(return_types) + if not is_component_or_component_var(return_type) + ] + raise MatchTypeError( + "Match cases should have the same return types. " + + "Expected return types to be of type Component or Var[Component]. " + + ". ".join( + [ + f"Return type of case {i} is {return_type}" + for return_type, i in non_component_return_types + ] ) + ) def _create_match_var( diff --git a/tests/units/components/core/test_match.py b/tests/units/components/core/test_match.py index 862bd5ad3..129234c16 100644 --- a/tests/units/components/core/test_match.py +++ b/tests/units/components/core/test_match.py @@ -1,3 +1,4 @@ +import re from typing import Tuple import pytest @@ -177,8 +178,7 @@ def test_match_case_tuple_elements(match_case): (MatchState.num + 1, "black"), rx.text("default value"), ), - "Match cases should have the same return types. Case 3 with return value `red` of type " - " is not ", + "Match cases should have the same return types. Expected return types to be of type Component or Var[Component]. Return type of case 3 is . Return type of case 4 is . Return type of case 5 is ", ), ( ( @@ -190,8 +190,7 @@ def test_match_case_tuple_elements(match_case): ([1, 2], rx.text("third value")), rx.text("default value"), ), - 'Match cases should have the same return types. Case 3 with return value ` {"first value"} ` ' - "of type is not ", + "Match cases should have the same return types. Expected return types to be of type Component or Var[Component]. Return type of case 0 is . Return type of case 1 is . Return type of case 2 is ", ), ], ) @@ -202,7 +201,7 @@ def test_match_different_return_types(cases: Tuple, error_msg: str): cases: The match cases. error_msg: Expected error message. """ - with pytest.raises(MatchTypeError, match=error_msg): + with pytest.raises(MatchTypeError, match=re.escape(error_msg)): match(MatchState.value, *cases) # pyright: ignore[reportCallIssue]