precommit

This commit is contained in:
Khaleel Al-Adhami 2025-01-31 14:19:13 -08:00
parent 46c66b2adf
commit 749577f0bc
12 changed files with 45 additions and 52 deletions

View File

@ -929,15 +929,16 @@ class Component(BaseComponent, ABC):
valid_children = self._valid_children + allowed_components valid_children = self._valid_children + allowed_components
def child_is_in_valid(child): def child_is_in_valid(child_component: Any):
if type(child).__name__ in valid_children: if type(child_component).__name__ in valid_children:
return True return True
if ( if (
not isinstance(child, Bare) not isinstance(child_component, Bare)
or child.contents is None or child_component.contents is None
or not isinstance(child.contents, Var) or not isinstance(child_component.contents, Var)
or (var_data := child.contents._get_all_var_data()) is None or (var_data := child_component.contents._get_all_var_data())
is None
): ):
return False return False

View File

@ -4,8 +4,8 @@ from __future__ import annotations
from typing import Optional from typing import Optional
from reflex.components.base.fragment import Fragment
from reflex import constants from reflex import constants
from reflex.components.base.fragment import Fragment
from reflex.components.component import Component from reflex.components.component import Component
from reflex.components.core.cond import cond from reflex.components.core.cond import cond
from reflex.components.datadisplay.logo import svg_logo from reflex.components.datadisplay.logo import svg_logo

View File

@ -2,10 +2,8 @@
from __future__ import annotations from __future__ import annotations
import functools
from typing import Callable, Iterable from typing import Callable, Iterable
from reflex.utils.exceptions import UntypedVarError
from reflex.vars.base import LiteralVar, Var from reflex.vars.base import LiteralVar, Var
from reflex.vars.object import ObjectVar from reflex.vars.object import ObjectVar
from reflex.vars.sequence import ArrayVar from reflex.vars.sequence import ArrayVar

View File

@ -139,7 +139,7 @@ class ColorModeIconButton(IconButton):
if allow_system: if allow_system:
def color_mode_item(_color_mode: str): def color_mode_item(_color_mode: Literal["light", "dark", "system"]):
return dropdown_menu.item( return dropdown_menu.item(
_color_mode.title(), on_click=set_color_mode(_color_mode) _color_mode.title(), on_click=set_color_mode(_color_mode)
) )

View File

@ -880,7 +880,9 @@ def typehint_issubclass(possible_subclass: Any, possible_superclass: Any) -> boo
): ):
return all( return all(
typehint_issubclass(subclass, superclass) typehint_issubclass(subclass, superclass)
for subclass, superclass in zip(possible_subclass, possible_superclass) for subclass, superclass in zip(
possible_subclass, possible_superclass, strict=False
)
) )
if possible_subclass is possible_superclass: if possible_subclass is possible_superclass:
return True return True

View File

@ -13,7 +13,7 @@ import random
import re import re
import string import string
import warnings import warnings
from types import CodeType, FunctionType from types import CodeType, EllipsisType, FunctionType
from typing import ( from typing import (
TYPE_CHECKING, TYPE_CHECKING,
Any, Any,
@ -54,7 +54,6 @@ from reflex.constants.compiler import Hooks
from reflex.utils import console, exceptions, imports, serializers, types from reflex.utils import console, exceptions, imports, serializers, types
from reflex.utils.exceptions import ( from reflex.utils.exceptions import (
UntypedComputedVarError, UntypedComputedVarError,
VarAttributeError,
VarDependencyError, VarDependencyError,
VarTypeError, VarTypeError,
VarValueError, VarValueError,
@ -108,12 +107,7 @@ class ReflexCallable(Protocol[P, R]):
__call__: Callable[P, R] __call__: Callable[P, R]
if sys.version_info >= (3, 10): ReflexCallableParams = Union[EllipsisType, Tuple[GenericType, ...]]
from types import EllipsisType
ReflexCallableParams = Union[EllipsisType, Tuple[GenericType, ...]]
else:
ReflexCallableParams = Union[Any, Tuple[GenericType, ...]]
def unwrap_reflex_callalbe( def unwrap_reflex_callalbe(
@ -1336,10 +1330,15 @@ class Var(Generic[VAR_TYPE]):
""" """
from .sequence import ArrayVar from .sequence import ArrayVar
if step is None:
return ArrayVar.range(first_endpoint, second_endpoint)
return ArrayVar.range(first_endpoint, second_endpoint, step) return ArrayVar.range(first_endpoint, second_endpoint, step)
def __bool__(self) -> bool: if not TYPE_CHECKING:
"""Raise exception if using Var in a boolean context.
def __bool__(self) -> bool:
"""Raise exception if using Var in a boolean context.
Raises: Raises:
VarTypeError: when attempting to bool-ify the Var. VarTypeError: when attempting to bool-ify the Var.
@ -1924,12 +1923,12 @@ def var_operation(
_raw_js_function=custom_operation_return._raw_js_function, _raw_js_function=custom_operation_return._raw_js_function,
_original_var_operation=simplified_operation, _original_var_operation=simplified_operation,
_var_type=ReflexCallable[ _var_type=ReflexCallable[
tuple( tuple( # pyright: ignore [reportInvalidTypeArguments]
arg_python_type arg_python_type
if isinstance(arg_default_values[i], inspect.Parameter) if isinstance(arg_default_values[i], inspect.Parameter)
else VarWithDefault[arg_python_type] else VarWithDefault[arg_python_type]
for i, (_, arg_python_type) in enumerate(args_with_type_hints) for i, (_, arg_python_type) in enumerate(args_with_type_hints)
), # type: ignore ),
custom_operation_return._var_type, custom_operation_return._var_type,
], ],
) )
@ -2049,11 +2048,6 @@ class CachedVarOperation:
RETURN_TYPE = TypeVar("RETURN_TYPE") RETURN_TYPE = TypeVar("RETURN_TYPE")
DICT_KEY = TypeVar("DICT_KEY")
DICT_VAL = TypeVar("DICT_VAL")
LIST_INSIDE = TypeVar("LIST_INSIDE")
class FakeComputedVarBaseClass(property): class FakeComputedVarBaseClass(property):
"""A fake base class for ComputedVar to avoid inheriting from property.""" """A fake base class for ComputedVar to avoid inheriting from property."""
@ -2273,17 +2267,17 @@ class ComputedVar(Var[RETURN_TYPE]):
@overload @overload
def __get__( def __get__(
self: ComputedVar[Mapping[DICT_KEY, DICT_VAL]], self: ComputedVar[MAPPING_TYPE],
instance: None, instance: None,
owner: Type, owner: Type,
) -> ObjectVar[Mapping[DICT_KEY, DICT_VAL]]: ... ) -> ObjectVar[MAPPING_TYPE]: ...
@overload @overload
def __get__( def __get__(
self: ComputedVar[Sequence[LIST_INSIDE]], self: ComputedVar[SEQUENCE_TYPE],
instance: None, instance: None,
owner: Type, owner: Type,
) -> ArrayVar[Sequence[LIST_INSIDE]]: ... ) -> ArrayVar[SEQUENCE_TYPE]: ...
@overload @overload
def __get__(self, instance: None, owner: Type) -> ComputedVar[RETURN_TYPE]: ... def __get__(self, instance: None, owner: Type) -> ComputedVar[RETURN_TYPE]: ...
@ -2588,7 +2582,7 @@ RETURN = TypeVar("RETURN")
@dataclasses.dataclass( @dataclasses.dataclass(
eq=False, eq=False,
frozen=True, frozen=True,
**{"slots": True} if sys.version_info >= (3, 10) else {}, slots=True,
) )
class CustomVarOperationReturn(Var[RETURN]): class CustomVarOperationReturn(Var[RETURN]):
"""Base class for custom var operations.""" """Base class for custom var operations."""
@ -3202,7 +3196,7 @@ class Field(Generic[FIELD_TYPE]):
def __get__(self: Field[int], instance: None, owner: Any) -> NumberVar[int]: ... def __get__(self: Field[int], instance: None, owner: Any) -> NumberVar[int]: ...
@overload @overload
def __get__(self: Field[float], instance: None, owner) -> NumberVar[float]: ... def __get__(self: Field[float], instance: None, owner: Any) -> NumberVar[float]: ...
@overload @overload
def __get__(self: Field[str], instance: None, owner: Any) -> StringVar[str]: ... def __get__(self: Field[str], instance: None, owner: Any) -> StringVar[str]: ...
@ -3251,7 +3245,7 @@ def field(value: FIELD_TYPE) -> Field[FIELD_TYPE]:
Returns: Returns:
The Field. The Field.
""" """
return value # type: ignore return value # pyright: ignore [reportReturnType]
def and_operation(a: Var | Any, b: Var | Any) -> Var: def and_operation(a: Var | Any, b: Var | Any) -> Var:

View File

@ -1193,7 +1193,6 @@ class FunctionVar(
@overload @overload
def call(self: FunctionVar[NoReturn], *args: Var | Any) -> Var: ... def call(self: FunctionVar[NoReturn], *args: Var | Any) -> Var: ...
def call(self, *args: Var | Any) -> Var: # pyright: ignore [reportInconsistentOverload]
def call(self, *args: Var | Any) -> Var: # pyright: ignore [reportInconsistentOverload] def call(self, *args: Var | Any) -> Var: # pyright: ignore [reportInconsistentOverload]
"""Call the function with the given arguments. """Call the function with the given arguments.
@ -1299,7 +1298,7 @@ class FunctionVar(
""" """
args_types, return_type = unwrap_reflex_callalbe(self._var_type) args_types, return_type = unwrap_reflex_callalbe(self._var_type)
if isinstance(args_types, tuple): if isinstance(args_types, tuple):
return ReflexCallable[[*args_types[len(args) :]], return_type], None # type: ignore return ReflexCallable[[*args_types[len(args) :]], return_type], None
return ReflexCallable[..., return_type], None return ReflexCallable[..., return_type], None
def _arg_len(self) -> int | None: def _arg_len(self) -> int | None:
@ -1637,7 +1636,7 @@ def pre_check_args(
Raises: Raises:
VarTypeError: If the arguments are invalid. VarTypeError: If the arguments are invalid.
""" """
for i, (validator, arg) in enumerate(zip(self._validators, args)): for i, (validator, arg) in enumerate(zip(self._validators, args, strict=False)):
if (validation_message := validator(arg)) is not None: if (validation_message := validator(arg)) is not None:
arg_name = self._args.args[i] if i < len(self._args.args) else None arg_name = self._args.args[i] if i < len(self._args.args) else None
if arg_name is not None: if arg_name is not None:
@ -1694,9 +1693,9 @@ class ArgsFunctionOperation(CachedVarOperation, FunctionVar[CALLABLE_TYPE]):
_cached_var_name = cached_property_no_lock(format_args_function_operation) _cached_var_name = cached_property_no_lock(format_args_function_operation)
_pre_check = pre_check_args # type: ignore _pre_check = pre_check_args
_partial_type = figure_partial_type # type: ignore _partial_type = figure_partial_type
@classmethod @classmethod
def create( def create(
@ -1776,9 +1775,9 @@ class ArgsFunctionOperationBuilder(
_cached_var_name = cached_property_no_lock(format_args_function_operation) _cached_var_name = cached_property_no_lock(format_args_function_operation)
_pre_check = pre_check_args # type: ignore _pre_check = pre_check_args
_partial_type = figure_partial_type # type: ignore _partial_type = figure_partial_type
@classmethod @classmethod
def create( def create(

View File

@ -1080,7 +1080,7 @@ TUPLE_ENDS_IN_VAR_RELAXED = tuple[
@dataclasses.dataclass( @dataclasses.dataclass(
eq=False, eq=False,
frozen=True, frozen=True,
**{"slots": True} if sys.version_info >= (3, 10) else {}, slots=True,
) )
class MatchOperation(CachedVarOperation, Var[VAR_TYPE]): class MatchOperation(CachedVarOperation, Var[VAR_TYPE]):
"""Base class for immutable match operations.""" """Base class for immutable match operations."""

View File

@ -142,7 +142,6 @@ class ObjectVar(Var[OBJECT_TYPE], python_types=Mapping):
# NoReturn is used here to catch when key value is Any # NoReturn is used here to catch when key value is Any
@overload @overload
def __getitem__( # pyright: ignore [reportOverlappingOverload]
def __getitem__( # pyright: ignore [reportOverlappingOverload] def __getitem__( # pyright: ignore [reportOverlappingOverload]
self: ObjectVar[Mapping[Any, NoReturn]], self: ObjectVar[Mapping[Any, NoReturn]],
key: Var | Any, key: Var | Any,

View File

@ -773,7 +773,7 @@ def map_array_operation(
type_computer=nary_type_computer( type_computer=nary_type_computer(
ReflexCallable[[List[Any], ReflexCallable], List[Any]], ReflexCallable[[List[Any], ReflexCallable], List[Any]],
ReflexCallable[[ReflexCallable], List[Any]], ReflexCallable[[ReflexCallable], List[Any]],
computer=lambda args: List[unwrap_reflex_callalbe(args[1]._var_type)[1]], # type: ignore computer=lambda args: List[unwrap_reflex_callalbe(args[1]._var_type)[1]],
), ),
) )
@ -846,7 +846,7 @@ class SliceVar(Var[slice], python_types=slice):
@dataclasses.dataclass( @dataclasses.dataclass(
eq=False, eq=False,
frozen=True, frozen=True,
**{"slots": True} if sys.version_info >= (3, 10) else {}, slots=True,
) )
class LiteralSliceVar(CachedVarOperation, LiteralVar, SliceVar): class LiteralSliceVar(CachedVarOperation, LiteralVar, SliceVar):
"""Base class for immutable literal slice vars.""" """Base class for immutable literal slice vars."""
@ -1245,7 +1245,7 @@ _decode_var_pattern = re.compile(_decode_var_pattern_re, flags=re.DOTALL)
@dataclasses.dataclass( @dataclasses.dataclass(
eq=False, eq=False,
frozen=True, frozen=True,
**{"slots": True} if sys.version_info >= (3, 10) else {}, slots=True,
) )
class LiteralStringVar(LiteralVar, StringVar[str]): class LiteralStringVar(LiteralVar, StringVar[str]):
"""Base class for immutable literal string vars.""" """Base class for immutable literal string vars."""
@ -1367,7 +1367,7 @@ class LiteralStringVar(LiteralVar, StringVar[str]):
@dataclasses.dataclass( @dataclasses.dataclass(
eq=False, eq=False,
frozen=True, frozen=True,
**{"slots": True} if sys.version_info >= (3, 10) else {}, slots=True,
) )
class ConcatVarOperation(CachedVarOperation, StringVar[str]): class ConcatVarOperation(CachedVarOperation, StringVar[str]):
"""Representing a concatenation of literal string vars.""" """Representing a concatenation of literal string vars."""

View File

@ -136,7 +136,7 @@ async def test_lifespan(lifespan_app: AppHarness):
task_global = driver.find_element(By.ID, "task_global") task_global = driver.find_element(By.ID, "task_global")
assert context_global.text == "2" assert context_global.text == "2"
assert lifespan_app.app_module.lifespan_context_global_getter() == 2 # type: ignore assert lifespan_app.app_module.lifespan_context_global_getter() == 2
original_task_global_text = task_global.text original_task_global_text = task_global.text
original_task_global_value = int(original_task_global_text) original_task_global_value = int(original_task_global_text)
@ -145,7 +145,7 @@ async def test_lifespan(lifespan_app: AppHarness):
assert ( assert (
lifespan_app.app_module.lifespan_task_global_getter() lifespan_app.app_module.lifespan_task_global_getter()
> original_task_global_value > original_task_global_value
) # type: ignore )
assert int(task_global.text) > original_task_global_value assert int(task_global.text) > original_task_global_value
# Kill the backend # Kill the backend

View File

@ -1249,11 +1249,11 @@ def test_type_chains():
List[int], List[int],
) )
assert ( assert (
str(object_var.keys()[0].upper()) str(object_var.keys()[0].upper()) # pyright: ignore [reportAttributeAccessIssue]
== '(((...args) => (((_string) => String.prototype.toUpperCase.apply(_string))((((...args) => (((_array, _index_or_slice) => atSliceOrIndex(_array, _index_or_slice))((Object.keys(({ ["a"] : 1, ["b"] : 2, ["c"] : 3 }))), ...args)))(0)), ...args)))())' == '(((...args) => (((_string) => String.prototype.toUpperCase.apply(_string))((((...args) => (((_array, _index_or_slice) => atSliceOrIndex(_array, _index_or_slice))((Object.keys(({ ["a"] : 1, ["b"] : 2, ["c"] : 3 }))), ...args)))(0)), ...args)))())'
) )
assert ( assert (
str(object_var.entries()[1][1] - 1) str(object_var.entries()[1][1] - 1) # pyright: ignore [reportCallIssue, reportOperatorIssue]
== '((((...args) => (((_array, _index_or_slice) => atSliceOrIndex(_array, _index_or_slice))((((...args) => (((_array, _index_or_slice) => atSliceOrIndex(_array, _index_or_slice))((Object.entries(({ ["a"] : 1, ["b"] : 2, ["c"] : 3 }))), ...args)))(1)), ...args)))(1)) - 1)' == '((((...args) => (((_array, _index_or_slice) => atSliceOrIndex(_array, _index_or_slice))((((...args) => (((_array, _index_or_slice) => atSliceOrIndex(_array, _index_or_slice))((Object.entries(({ ["a"] : 1, ["b"] : 2, ["c"] : 3 }))), ...args)))(1)), ...args)))(1)) - 1)'
) )
assert ( assert (