improve match and cond checking

This commit is contained in:
Khaleel Al-Adhami 2025-01-23 15:03:45 -08:00
parent 0d746bf762
commit 2ffa698c6b
2 changed files with 6 additions and 9 deletions

View File

@ -2,13 +2,12 @@
from __future__ import annotations from __future__ import annotations
from typing import Any, overload from typing import Any, Union, overload
from reflex.components.base.fragment import Fragment from reflex.components.base.fragment import Fragment
from reflex.components.component import BaseComponent, Component from reflex.components.component import BaseComponent, Component
from reflex.style import LIGHT_COLOR_MODE, resolved_color_mode from reflex.style import LIGHT_COLOR_MODE, resolved_color_mode
from reflex.utils import types from reflex.utils import types
from reflex.utils.types import safe_issubclass
from reflex.vars.base import LiteralVar, Var from reflex.vars.base import LiteralVar, Var
from reflex.vars.number import ternary_operation from reflex.vars.number import ternary_operation
@ -41,9 +40,8 @@ def cond(condition: Any, c1: Any, c2: Any = None) -> Component | Var:
# If the first component is a component, create a Fragment if the second component is not set. # If the first component is a component, create a Fragment if the second component is not set.
if isinstance(c1, BaseComponent) or ( if isinstance(c1, BaseComponent) or (
isinstance(c1, Var) isinstance(c1, Var)
and ( and types.safe_typehint_issubclass(
safe_issubclass(c1._var_type, BaseComponent) c1._var_type, Union[BaseComponent, list[BaseComponent]]
or types.safe_typehint_issubclass(c1._var_type, list[BaseComponent])
) )
): ):
c2 = c2 if c2 is not None else Fragment.create() c2 = c2 if c2 is not None else Fragment.create()

View File

@ -1,6 +1,6 @@
"""rx.match.""" """rx.match."""
from typing import Any, cast from typing import Any, Union, cast
from typing_extensions import Unpack from typing_extensions import Unpack
@ -49,9 +49,8 @@ def _validate_return_types(match_cases: tuple[CASE_TYPE[VAR_TYPE], ...]) -> None
def is_component_or_component_var(obj: Any) -> bool: def is_component_or_component_var(obj: Any) -> bool:
return types._isinstance(obj, BaseComponent) or ( return types._isinstance(obj, BaseComponent) or (
isinstance(obj, Var) isinstance(obj, Var)
and ( and types.safe_typehint_issubclass(
types.safe_typehint_issubclass(obj._var_type, BaseComponent) obj._var_type, Union[list[BaseComponent], BaseComponent]
or types.safe_typehint_issubclass(obj._var_type, list[BaseComponent])
) )
) )