refactor match code
This commit is contained in:
parent
d1ff6d51a2
commit
c74313992f
@ -5,7 +5,7 @@ from typing import Any, Union, cast
|
|||||||
from typing_extensions import Unpack
|
from typing_extensions import Unpack
|
||||||
|
|
||||||
from reflex.components.base import Fragment
|
from reflex.components.base import Fragment
|
||||||
from reflex.components.component import BaseComponent, Component
|
from reflex.components.component import BaseComponent
|
||||||
from reflex.utils import types
|
from reflex.utils import types
|
||||||
from reflex.utils.exceptions import MatchTypeError
|
from reflex.utils.exceptions import MatchTypeError
|
||||||
from reflex.vars.base import VAR_TYPE, Var
|
from reflex.vars.base import VAR_TYPE, Var
|
||||||
@ -36,11 +36,14 @@ def _process_match_cases(cases: tuple[CASE_TYPE[VAR_TYPE], ...]):
|
|||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
def _validate_return_types(match_cases: tuple[CASE_TYPE[VAR_TYPE], ...]) -> None:
|
def _validate_return_types(*return_values: Any) -> bool:
|
||||||
"""Validate that match cases have the same return types.
|
"""Validate that match cases have the same return types.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
match_cases: The match cases.
|
return_values: The return values of the match cases.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
True if all cases have the same return types.
|
||||||
|
|
||||||
Raises:
|
Raises:
|
||||||
MatchTypeError: If the return types of cases are different.
|
MatchTypeError: If the return types of cases are different.
|
||||||
@ -54,22 +57,20 @@ def _validate_return_types(match_cases: tuple[CASE_TYPE[VAR_TYPE], ...]) -> None
|
|||||||
)
|
)
|
||||||
)
|
)
|
||||||
|
|
||||||
def type_of_return_type(obj: Any) -> Any:
|
def type_of_return_value(obj: Any) -> Any:
|
||||||
if isinstance(obj, Var):
|
if isinstance(obj, Var):
|
||||||
return obj._var_type
|
return obj._var_type
|
||||||
return type(obj)
|
return type(obj)
|
||||||
|
|
||||||
return_types = [case[-1] for case in match_cases]
|
is_return_type_component = [
|
||||||
|
is_component_or_component_var(return_type) for return_type in return_values
|
||||||
|
]
|
||||||
|
|
||||||
if any(
|
if any(is_return_type_component) and not all(is_return_type_component):
|
||||||
is_component_or_component_var(return_type) for return_type in return_types
|
|
||||||
) and not all(
|
|
||||||
is_component_or_component_var(return_type) for return_type in return_types
|
|
||||||
):
|
|
||||||
non_component_return_types = [
|
non_component_return_types = [
|
||||||
(type_of_return_type(return_type), i)
|
(type_of_return_value(return_value), i)
|
||||||
for i, return_type in enumerate(return_types)
|
for i, return_value in enumerate(return_values)
|
||||||
if not is_component_or_component_var(return_type)
|
if not is_return_type_component[i]
|
||||||
]
|
]
|
||||||
raise MatchTypeError(
|
raise MatchTypeError(
|
||||||
"Match cases should have the same return types. "
|
"Match cases should have the same return types. "
|
||||||
@ -82,6 +83,8 @@ def _validate_return_types(match_cases: tuple[CASE_TYPE[VAR_TYPE], ...]) -> None
|
|||||||
)
|
)
|
||||||
)
|
)
|
||||||
|
|
||||||
|
return all(is_return_type_component)
|
||||||
|
|
||||||
|
|
||||||
def _create_match_var(
|
def _create_match_var(
|
||||||
match_cond_var: Var,
|
match_cond_var: Var,
|
||||||
@ -119,7 +122,7 @@ def match(
|
|||||||
Raises:
|
Raises:
|
||||||
ValueError: If the default case is not the last case or the tuple elements are less than 2.
|
ValueError: If the default case is not the last case or the tuple elements are less than 2.
|
||||||
"""
|
"""
|
||||||
default = None
|
default = types.Unset()
|
||||||
|
|
||||||
if len([case for case in cases if not isinstance(case, tuple)]) > 1:
|
if len([case for case in cases if not isinstance(case, tuple)]) > 1:
|
||||||
raise ValueError("rx.match can only have one default case.")
|
raise ValueError("rx.match can only have one default case.")
|
||||||
@ -136,22 +139,17 @@ def match(
|
|||||||
|
|
||||||
_process_match_cases(actual_cases)
|
_process_match_cases(actual_cases)
|
||||||
|
|
||||||
_validate_return_types(actual_cases)
|
is_component_match = _validate_return_types(
|
||||||
|
*[case[-1] for case in actual_cases],
|
||||||
|
*([default] if not isinstance(default, types.Unset) else []),
|
||||||
|
)
|
||||||
|
|
||||||
if default is None and any(
|
if isinstance(default, types.Unset) and not is_component_match:
|
||||||
not (
|
|
||||||
isinstance((return_type := case[-1]), Component)
|
|
||||||
or (
|
|
||||||
isinstance(return_type, Var)
|
|
||||||
and types.typehint_issubclass(return_type._var_type, Component)
|
|
||||||
)
|
|
||||||
)
|
|
||||||
for case in actual_cases
|
|
||||||
):
|
|
||||||
raise ValueError(
|
raise ValueError(
|
||||||
"For cases with return types as Vars, a default case must be provided"
|
"For cases with return types as Vars, a default case must be provided"
|
||||||
)
|
)
|
||||||
elif default is None:
|
|
||||||
|
if isinstance(default, types.Unset):
|
||||||
default = Fragment.create()
|
default = Fragment.create()
|
||||||
|
|
||||||
default = cast(Var[VAR_TYPE] | VAR_TYPE, default)
|
default = cast(Var[VAR_TYPE] | VAR_TYPE, default)
|
||||||
|
Loading…
Reference in New Issue
Block a user