From 112b2ed948f67bcf31002ffa81aec2f6e813e400 Mon Sep 17 00:00:00 2001 From: Khaleel Al-Adhami Date: Thu, 16 Jan 2025 20:06:02 -0800 Subject: [PATCH] solve some but not all pyright issues --- reflex/utils/pyi_generator.py | 42 +++++++++++++++++++++++++---------- reflex/utils/types.py | 21 ++++++++++++------ reflex/vars/number.py | 14 +++++++----- 3 files changed, 53 insertions(+), 24 deletions(-) diff --git a/reflex/utils/pyi_generator.py b/reflex/utils/pyi_generator.py index 152c06949..52a3a9fff 100644 --- a/reflex/utils/pyi_generator.py +++ b/reflex/utils/pyi_generator.py @@ -16,7 +16,7 @@ from itertools import chain from multiprocessing import Pool, cpu_count from pathlib import Path from types import ModuleType, SimpleNamespace -from typing import Any, Callable, Iterable, Sequence, Type, get_args, get_origin +from typing import Any, Callable, Iterable, Sequence, Type, cast, get_args, get_origin from reflex.components.component import Component from reflex.utils import types as rx_types @@ -229,7 +229,9 @@ def _generate_imports( """ return [ *[ - ast.ImportFrom(module=name, names=[ast.alias(name=val) for val in values]) + ast.ImportFrom( + module=name, names=[ast.alias(name=val) for val in values], level=0 + ) for name, values in DEFAULT_IMPORTS.items() ], ast.Import([ast.alias("reflex")]), @@ -428,16 +430,15 @@ def type_to_ast(typ, cls: type) -> ast.AST: return ast.Name(id=base_name) # Convert all type arguments recursively - arg_nodes = [type_to_ast(arg, cls) for arg in args] + arg_nodes = cast(list[ast.expr], [type_to_ast(arg, cls) for arg in args]) # Special case for single-argument types (like List[T] or Optional[T]) if len(arg_nodes) == 1: slice_value = arg_nodes[0] else: slice_value = ast.Tuple(elts=arg_nodes, ctx=ast.Load()) - return ast.Subscript( - value=ast.Name(id=base_name), slice=ast.Index(value=slice_value), ctx=ast.Load() + value=ast.Name(id=base_name), slice=slice_value, ctx=ast.Load() ) @@ -630,7 +631,7 @@ def _generate_component_create_functiondef( ), ), ast.Expr( - value=ast.Ellipsis(), + value=ast.Constant(...), ), ], decorator_list=[ @@ -641,8 +642,14 @@ def _generate_component_create_functiondef( else [ast.Name(id="classmethod")] ), ], - lineno=node.lineno if node is not None else None, returns=ast.Constant(value=clz.__name__), + **( + { + "lineno": node.lineno, + } + if node is not None + else {} + ), ) return definition @@ -690,13 +697,19 @@ def _generate_staticmethod_call_functiondef( ), ], decorator_list=[ast.Name(id="staticmethod")], - lineno=node.lineno if node is not None else None, returns=ast.Constant( value=_get_type_hint( typing.get_type_hints(clz.__call__).get("return", None), type_hint_globals, ) ), + **( + { + "lineno": node.lineno, + } + if node is not None + else {} + ), ) return definition @@ -731,7 +744,12 @@ def _generate_namespace_call_functiondef( # Determine which class is wrapped by the namespace __call__ method component_clz = clz.__call__.__self__ - if clz.__call__.__func__.__name__ != "create": + func = getattr(clz.__call__, "__func__", None) + + if func is None: + raise TypeError(f"__call__ method on {clz_name} does not have a __func__") + + if func.__name__ != "create": return None definition = _generate_component_create_functiondef( @@ -914,7 +932,7 @@ class StubGenerator(ast.NodeTransformer): node.body.append(call_definition) if not node.body: # We should never return an empty body. - node.body.append(ast.Expr(value=ast.Ellipsis())) + node.body.append(ast.Expr(value=ast.Constant(...))) self.current_class = None return node @@ -941,9 +959,9 @@ class StubGenerator(ast.NodeTransformer): if node.name.startswith("_") and node.name != "__call__": return None # remove private methods - if node.body[-1] != ast.Expr(value=ast.Ellipsis()): + if node.body[-1] != ast.Expr(value=ast.Constant(...)): # Blank out the function body for public functions. - node.body = [ast.Expr(value=ast.Ellipsis())] + node.body = [ast.Expr(value=ast.Constant(...))] return node def visit_Assign(self, node: ast.Assign) -> ast.Assign | None: diff --git a/reflex/utils/types.py b/reflex/utils/types.py index d968b9a09..e323c5441 100644 --- a/reflex/utils/types.py +++ b/reflex/utils/types.py @@ -69,21 +69,21 @@ else: # Potential GenericAlias types for isinstance checks. -GenericAliasTypes = [_GenericAlias] +_GenericAliasTypes: list[type] = [_GenericAlias] with contextlib.suppress(ImportError): # For newer versions of Python. from types import GenericAlias # type: ignore - GenericAliasTypes.append(GenericAlias) + _GenericAliasTypes.append(GenericAlias) with contextlib.suppress(ImportError): # For older versions of Python. from typing import _SpecialGenericAlias # type: ignore - GenericAliasTypes.append(_SpecialGenericAlias) + _GenericAliasTypes.append(_SpecialGenericAlias) -GenericAliasTypes = tuple(GenericAliasTypes) +GenericAliasTypes = tuple(_GenericAliasTypes) # Potential Union types for isinstance checks (UnionType added in py3.10). UnionTypes = (Union, types.UnionType) if hasattr(types, "UnionType") else (Union,) @@ -181,7 +181,7 @@ def is_generic_alias(cls: GenericType) -> bool: return isinstance(cls, GenericAliasTypes) -def unionize(*args: GenericType) -> Type: +def unionize(*args: GenericType) -> GenericType: """Unionize the types. Args: @@ -415,7 +415,7 @@ def get_attribute_access_type(cls: GenericType, name: str) -> GenericType | None @lru_cache() -def get_base_class(cls: GenericType) -> Type: +def get_base_class(cls: GenericType) -> Type | tuple[Type, ...]: """Get the base class of a class. Args: @@ -435,7 +435,14 @@ def get_base_class(cls: GenericType) -> Type: return type(get_args(cls)[0]) if is_union(cls): - return tuple(get_base_class(arg) for arg in get_args(cls)) + base_classes = [] + for arg in get_args(cls): + sub_base_classes = get_base_class(arg) + if isinstance(sub_base_classes, tuple): + base_classes.extend(sub_base_classes) + else: + base_classes.append(sub_base_classes) + return tuple(base_classes) return get_base_class(cls.__origin__) if is_generic_alias(cls) else cls diff --git a/reflex/vars/number.py b/reflex/vars/number.py index ee1bb4f7f..2d0fb3326 100644 --- a/reflex/vars/number.py +++ b/reflex/vars/number.py @@ -15,6 +15,7 @@ from typing import ( Sequence, TypeVar, Union, + cast, overload, ) @@ -1102,7 +1103,7 @@ class MatchOperation(CachedVarOperation, Var[VAR_TYPE]): _cases: tuple[TUPLE_ENDS_IN_VAR[VAR_TYPE], ...] = dataclasses.field( default_factory=tuple ) - _default: Var[VAR_TYPE] = dataclasses.field( + _default: Var[VAR_TYPE] = dataclasses.field( # pyright: ignore[reportAssignmentType] default_factory=lambda: Var.create(None) ) @@ -1170,11 +1171,14 @@ class MatchOperation(CachedVarOperation, Var[VAR_TYPE]): The match operation. """ cond = Var.create(cond) - cases = tuple(tuple(Var.create(c) for c in case) for case in cases) - default = Var.create(default) + cases = cast( + tuple[TUPLE_ENDS_IN_VAR[VAR_TYPE], ...], + tuple(tuple(Var.create(c) for c in case) for case in cases), + ) + _default = cast(Var[VAR_TYPE], Var.create(default)) var_type = _var_type or unionize( *(case[-1]._var_type for case in cases), - default._var_type, + _default._var_type, ) return cls( _js_expr="", @@ -1182,7 +1186,7 @@ class MatchOperation(CachedVarOperation, Var[VAR_TYPE]): _var_type=var_type, _cond=cond, _cases=cases, - _default=default, + _default=_default, )