From 3080cadd3103f0462f7487fdeae859c5724fbb5c Mon Sep 17 00:00:00 2001 From: Khaleel Al-Adhami Date: Thu, 13 Feb 2025 13:36:45 -0800 Subject: [PATCH] some random fixes after merge --- reflex/components/core/client_side_routing.py | 8 +-- reflex/components/core/cond.py | 23 +++++--- reflex/components/core/foreach.py | 58 ++++++++----------- reflex/vars/base.py | 20 +------ reflex/vars/function.py | 4 +- reflex/vars/object.py | 1 + reflex/vars/sequence.py | 2 +- 7 files changed, 52 insertions(+), 64 deletions(-) diff --git a/reflex/components/core/client_side_routing.py b/reflex/components/core/client_side_routing.py index 0fc40de5f..43d705ada 100644 --- a/reflex/components/core/client_side_routing.py +++ b/reflex/components/core/client_side_routing.py @@ -41,7 +41,7 @@ class ClientSideRouting(Component): return "" -def wait_for_client_redirect(component: Component) -> Component: +def wait_for_client_redirect(component: Component) -> Var[Component]: """Wait for a redirect to occur before rendering a component. This prevents the 404 page from flashing while the redirect is happening. @@ -53,9 +53,9 @@ def wait_for_client_redirect(component: Component) -> Component: The conditionally rendered component. """ return cond( - condition=route_not_found, - c1=component, - c2=ClientSideRouting.create(), + route_not_found, + component, + ClientSideRouting.create(), ) diff --git a/reflex/components/core/cond.py b/reflex/components/core/cond.py index c98ae5108..d2e7c26b0 100644 --- a/reflex/components/core/cond.py +++ b/reflex/components/core/cond.py @@ -2,7 +2,7 @@ from __future__ import annotations -from typing import Any, Union, overload +from typing import Any, TypeVar, Union, overload from reflex.components.base.fragment import Fragment from reflex.components.component import BaseComponent, Component @@ -13,14 +13,20 @@ from reflex.vars.number import ternary_operation @overload -def cond(condition: Any, c1: Component, c2: Any = None) -> Component: ... +def cond( + condition: Any, c1: BaseComponent | Var[BaseComponent], c2: Any = None, / +) -> Var[Component]: ... + + +T = TypeVar("T") +V = TypeVar("V") @overload -def cond(condition: Any, c1: Any, c2: Any) -> Var: ... +def cond(condition: Any, c1: T | Var[T], c2: V | Var[V], /) -> Var[T | V]: ... -def cond(condition: Any, c1: Any, c2: Any = None) -> Component | Var: +def cond(condition: Any, c1: Any, c2: Any = None, /) -> Var: """Create a conditional component or Prop. Args: @@ -59,14 +65,17 @@ def cond(condition: Any, c1: Any, c2: Any = None) -> Component | Var: @overload -def color_mode_cond(light: Component, dark: Component | None = None) -> Component: ... # pyright: ignore [reportOverlappingOverload] +def color_mode_cond( + light: BaseComponent | Var[BaseComponent], + dark: BaseComponent | Var[BaseComponent] | None = ..., +) -> Var[Component]: ... @overload -def color_mode_cond(light: Any, dark: Any = None) -> Var: ... +def color_mode_cond(light: T | Var[T], dark: V | Var[V]) -> Var[T | V]: ... -def color_mode_cond(light: Any, dark: Any = None) -> Var | Component: +def color_mode_cond(light: Any, dark: Any = None) -> Var: """Create a component or Prop based on color_mode. Args: diff --git a/reflex/components/core/foreach.py b/reflex/components/core/foreach.py index a4d03e8c4..fc59ff3fb 100644 --- a/reflex/components/core/foreach.py +++ b/reflex/components/core/foreach.py @@ -4,9 +4,8 @@ from __future__ import annotations from typing import Callable, Iterable +from reflex.vars import ArrayVar, ObjectVar, StringVar from reflex.vars.base import LiteralVar, Var -from reflex.vars.object import ObjectVar -from reflex.vars.sequence import ArrayVar class ForeachVarError(TypeError): @@ -30,44 +29,37 @@ def foreach( Returns: The foreach component. - Raises: - ForeachVarError: If the iterable is of type Any. - TypeError: If the render function is a ComponentState. - UntypedVarError: If the iterable is of type Any without a type annotation. - """ - from reflex.vars import ArrayVar, ObjectVar, StringVar + Raises: + ForeachVarError: If the iterable is of type Any. + TypeError: If the render function is a ComponentState. + UntypedVarError: If the iterable is of type Any without a type annotation. + """ + from reflex.state import ComponentState - iterable = LiteralVar.create(iterable).guess_type() + iterable = LiteralVar.create(iterable).guess_type() - if iterable._var_type == Any: - raise ForeachVarError( - f"Could not foreach over var `{iterable!s}` of type Any. " - "(If you are trying to foreach over a state var, add a type annotation to the var). " - "See https://reflex.dev/docs/library/dynamic-rendering/foreach/" - ) + if isinstance(iterable, ObjectVar): + iterable = iterable.entries() - if ( - hasattr(render_fn, "__qualname__") - and render_fn.__qualname__ == ComponentState.create.__qualname__ - ): - raise TypeError( - "Using a ComponentState as `render_fn` inside `rx.foreach` is not supported yet." - ) + if isinstance(iterable, StringVar): + iterable = iterable.split() - if isinstance(iterable, ObjectVar): - iterable = iterable.entries() + if not isinstance(iterable, ArrayVar): + raise ForeachVarError( + f"Could not foreach over var `{iterable!s}` of type {iterable._var_type}. " + "See https://reflex.dev/docs/library/dynamic-rendering/foreach/" + ) - if isinstance(iterable, StringVar): - iterable = iterable.split() + if ( + hasattr(render_fn, "__qualname__") + and render_fn.__qualname__ == ComponentState.create.__qualname__ + ): + raise TypeError( + "Using a ComponentState as `render_fn` inside `rx.foreach` is not supported yet." + ) - if not isinstance(iterable, ArrayVar): - raise ForeachVarError( - f"Could not foreach over var `{iterable!s}` of type {iterable._var_type}. " - "See https://reflex.dev/docs/library/dynamic-rendering/foreach/" - ) + return iterable.foreach(render_fn) - return iterable.foreach(render_fn) - class Foreach: """Create a foreach component.""" diff --git a/reflex/vars/base.py b/reflex/vars/base.py index c4cd910ae..d10e0bfbf 100644 --- a/reflex/vars/base.py +++ b/reflex/vars/base.py @@ -2403,14 +2403,7 @@ class ComputedVar(Var[RETURN_TYPE]): self: ComputedVar[SEQUENCE_TYPE], instance: None, owner: Type, - ) -> ArrayVar[list[LIST_INSIDE]]: ... - - @overload - def __get__( - self: ComputedVar[tuple[LIST_INSIDE, ...]], - instance: None, - owner: Type, - ) -> ArrayVar[tuple[LIST_INSIDE, ...]]: ... + ) -> ArrayVar[SEQUENCE_TYPE]: ... @overload def __get__(self, instance: None, owner: Type) -> ComputedVar[RETURN_TYPE]: ... @@ -2646,17 +2639,10 @@ class AsyncComputedVar(ComputedVar[RETURN_TYPE]): @overload def __get__( - self: AsyncComputedVar[list[LIST_INSIDE]], + self: AsyncComputedVar[SEQUENCE_TYPE], instance: None, owner: Type, - ) -> ArrayVar[list[LIST_INSIDE]]: ... - - @overload - def __get__( - self: AsyncComputedVar[tuple[LIST_INSIDE, ...]], - instance: None, - owner: Type, - ) -> ArrayVar[tuple[LIST_INSIDE, ...]]: ... + ) -> ArrayVar[SEQUENCE_TYPE]: ... @overload def __get__( diff --git a/reflex/vars/function.py b/reflex/vars/function.py index 9120370d3..56c38a007 100644 --- a/reflex/vars/function.py +++ b/reflex/vars/function.py @@ -1880,11 +1880,11 @@ def _generate_overloads_for_function_var_call(maximum_args: int = 4) -> str: ): for return_type, return_type_var in return_type_mapping.items(): required_args = [ - f"arg{j + 1}: Union[V" f"{j + 1}, Var[V{j + 1}]]" + f"arg{j + 1}: Union[V{j + 1}, Var[V{j + 1}]]" for j in range(number_of_required_args) ] optional_args = [ - f"arg{j + 1}: Union[V" f"{j + 1}, Var[V{j + 1}], Unset] = Unset()" + f"arg{j + 1}: Union[V{j + 1}, Var[V{j + 1}], Unset] = Unset()" for j in range( number_of_required_args, number_of_required_args + number_of_optional_args, diff --git a/reflex/vars/object.py b/reflex/vars/object.py index 0e25e5288..c23fca1f6 100644 --- a/reflex/vars/object.py +++ b/reflex/vars/object.py @@ -28,6 +28,7 @@ from reflex.utils.types import ( get_attribute_access_type, get_origin, safe_issubclass, + unionize, ) from .base import ( diff --git a/reflex/vars/sequence.py b/reflex/vars/sequence.py index c285e06c8..e1f2b8190 100644 --- a/reflex/vars/sequence.py +++ b/reflex/vars/sequence.py @@ -473,7 +473,7 @@ def array_length_operation(array: Var[ARRAY_VAR_TYPE]): @var_operation def string_split_operation( - string: Var[str], sep: VarWithDefault[str] = VarWithDefault("") + string: Var[STRING_TYPE], sep: VarWithDefault[STRING_TYPE] = VarWithDefault("") ): """Split a string.