162 lines
4.8 KiB
Python
162 lines
4.8 KiB
Python
"""rx.match."""
|
|
|
|
from typing import Any, Union, cast
|
|
|
|
from typing_extensions import Unpack
|
|
|
|
from reflex.components.base import Fragment
|
|
from reflex.components.component import BaseComponent
|
|
from reflex.utils import types
|
|
from reflex.utils.exceptions import MatchTypeError
|
|
from reflex.vars.base import VAR_TYPE, Var
|
|
from reflex.vars.number import MatchOperation
|
|
|
|
# Shape of one rx.match case tuple: leading match value(s) followed by the
# return value (a Var or a plain value) as the final element. The type itself
# permits any number of leading values; `_process_match_cases` enforces at
# runtime that a case tuple has at least two elements.
CASE_TYPE = tuple[Unpack[tuple[Any, ...]], Var[VAR_TYPE] | VAR_TYPE]
|
|
|
|
|
|
def _process_match_cases(cases: tuple[CASE_TYPE[VAR_TYPE], ...]):
    """Check that every non-default match case is a well-formed tuple.

    Args:
        cases: The match cases, excluding the default case.

    Raises:
        ValueError: If any case is not a tuple (for example, a default value
            that was not passed as the last argument) or if a case tuple has
            fewer than two elements (a match value and a return value).
    """
    for entry in cases:
        problem = None
        if not isinstance(entry, tuple):
            # A bare value here means a default was placed before other cases.
            problem = "rx.match should have tuples of cases and a default case as the last argument."
        elif len(entry) < 2:
            # A case needs at least one match value plus its return value.
            problem = "A case tuple should have at least a match case element and a return value."
        if problem is not None:
            raise ValueError(problem)
|
|
|
|
|
|
def _validate_return_types(*return_values: Any) -> bool:
    """Ensure the match return values are all components or all non-components.

    Args:
        return_values: The return values of the match cases, in case order
            (the caller appends the default case last when one is given).

    Returns:
        True if every return value is a component (or a Var whose type is a
        component or a list of components); False if none of them are.

    Raises:
        MatchTypeError: If component and non-component return values are mixed.
    """
    component_hint = Union[list[BaseComponent], BaseComponent]

    def _counts_as_component(value: Any) -> bool:
        # Either an actual component instance...
        if types._isinstance(value, BaseComponent):
            return True
        # ...or a Var whose declared type is a component (or list of them).
        if not isinstance(value, Var):
            return False
        return types.safe_typehint_issubclass(value._var_type, component_hint)

    def _describe_type(value: Any) -> Any:
        # Vars report their declared type; plain values their Python type.
        return value._var_type if isinstance(value, Var) else type(value)

    component_flags = [_counts_as_component(value) for value in return_values]

    # Homogeneous results (all components, or no components at all) are valid.
    if all(component_flags) or not any(component_flags):
        return all(component_flags)

    mismatch_details = ". ".join(
        f"Return type of case {index} is {_describe_type(value)}"
        for index, (value, is_component) in enumerate(
            zip(return_values, component_flags)
        )
        if not is_component
    )
    raise MatchTypeError(
        "Match cases should have the same return types. "
        "Expected return types to be of type Component or Var[Component]. "
        + mismatch_details
    )
|
|
|
|
|
|
def _create_match_var(
    match_cond_var: Var,
    match_cases: tuple[CASE_TYPE[VAR_TYPE], ...],
    default: VAR_TYPE | Var[VAR_TYPE],
) -> Var[VAR_TYPE]:
    """Build the Var representing a match expression.

    Args:
        match_cond_var: The value being matched on.
        match_cases: The case tuples, excluding the default case (the caller
            is expected to have validated them already).
        default: The default case value.

    Returns:
        The match var produced by ``MatchOperation.create``.
    """
    match_var = MatchOperation.create(match_cond_var, match_cases, default)
    return match_var
|
|
|
|
|
|
def match(
    cond: Any,
    *cases: Unpack[
        tuple[Unpack[tuple[CASE_TYPE[VAR_TYPE], ...]], Var[VAR_TYPE] | VAR_TYPE]
    ],
) -> Var[VAR_TYPE]:
    """Create a match var.

    Args:
        cond: The condition to match.
        cases: The match cases. Each case should be a tuple with the first
            elements as the match case and the last element as the return
            value. The last argument should be the default case.

    Returns:
        The match var.

    Raises:
        ValueError: If no cases are given, if more than one default (non-tuple)
            argument is given, if the default case is not the last argument, if
            a case tuple has fewer than 2 elements, or if the return values are
            not components and no default case is provided.
        MatchTypeError: If component and non-component return values are mixed.
    """
    if not cases:
        raise ValueError("rx.match should have at least one case.")

    # Every non-tuple argument is treated as a default; at most one is allowed.
    if sum(not isinstance(case, tuple) for case in cases) > 1:
        raise ValueError("rx.match can only have one default case.")

    # The default case, if present, must be the last (non-tuple) argument.
    if isinstance(cases[-1], tuple):
        default = types.Unset()
        actual_cases = cast(tuple[CASE_TYPE[VAR_TYPE], ...], cases)
    else:
        default = cases[-1]
        actual_cases = cases[:-1]

    # Raises if a misplaced default or a too-short case tuple remains.
    _process_match_cases(actual_cases)

    has_default = not isinstance(default, types.Unset)
    is_component_match = _validate_return_types(
        *[case[-1] for case in actual_cases],
        *([default] if has_default else []),
    )

    if not has_default:
        # Only component matches get an implicit default (an empty Fragment);
        # every other match must supply its own default case.
        if not is_component_match:
            raise ValueError(
                "For cases with return types as Vars, a default case must be provided"
            )
        default = Fragment.create()

    default = cast(Var[VAR_TYPE] | VAR_TYPE, default)

    return _create_match_var(
        cond,
        actual_cases,
        default,
    )
|