solve some but not all pyright issues

This commit is contained in:
Khaleel Al-Adhami 2025-01-16 20:06:02 -08:00
parent 3d73f561b7
commit 112b2ed948
3 changed files with 53 additions and 24 deletions

View File

@ -16,7 +16,7 @@ from itertools import chain
from multiprocessing import Pool, cpu_count
from pathlib import Path
from types import ModuleType, SimpleNamespace
from typing import Any, Callable, Iterable, Sequence, Type, get_args, get_origin
from typing import Any, Callable, Iterable, Sequence, Type, cast, get_args, get_origin
from reflex.components.component import Component
from reflex.utils import types as rx_types
@ -229,7 +229,9 @@ def _generate_imports(
"""
return [
*[
ast.ImportFrom(module=name, names=[ast.alias(name=val) for val in values])
ast.ImportFrom(
module=name, names=[ast.alias(name=val) for val in values], level=0
)
for name, values in DEFAULT_IMPORTS.items()
],
ast.Import([ast.alias("reflex")]),
@ -428,16 +430,15 @@ def type_to_ast(typ, cls: type) -> ast.AST:
return ast.Name(id=base_name)
# Convert all type arguments recursively
arg_nodes = [type_to_ast(arg, cls) for arg in args]
arg_nodes = cast(list[ast.expr], [type_to_ast(arg, cls) for arg in args])
# Special case for single-argument types (like List[T] or Optional[T])
if len(arg_nodes) == 1:
slice_value = arg_nodes[0]
else:
slice_value = ast.Tuple(elts=arg_nodes, ctx=ast.Load())
return ast.Subscript(
value=ast.Name(id=base_name), slice=ast.Index(value=slice_value), ctx=ast.Load()
value=ast.Name(id=base_name), slice=slice_value, ctx=ast.Load()
)
@ -630,7 +631,7 @@ def _generate_component_create_functiondef(
),
),
ast.Expr(
value=ast.Ellipsis(),
value=ast.Constant(...),
),
],
decorator_list=[
@ -641,8 +642,14 @@ def _generate_component_create_functiondef(
else [ast.Name(id="classmethod")]
),
],
lineno=node.lineno if node is not None else None,
returns=ast.Constant(value=clz.__name__),
**(
{
"lineno": node.lineno,
}
if node is not None
else {}
),
)
return definition
@ -690,13 +697,19 @@ def _generate_staticmethod_call_functiondef(
),
],
decorator_list=[ast.Name(id="staticmethod")],
lineno=node.lineno if node is not None else None,
returns=ast.Constant(
value=_get_type_hint(
typing.get_type_hints(clz.__call__).get("return", None),
type_hint_globals,
)
),
**(
{
"lineno": node.lineno,
}
if node is not None
else {}
),
)
return definition
@ -731,7 +744,12 @@ def _generate_namespace_call_functiondef(
# Determine which class is wrapped by the namespace __call__ method
component_clz = clz.__call__.__self__
if clz.__call__.__func__.__name__ != "create":
func = getattr(clz.__call__, "__func__", None)
if func is None:
raise TypeError(f"__call__ method on {clz_name} does not have a __func__")
if func.__name__ != "create":
return None
definition = _generate_component_create_functiondef(
@ -914,7 +932,7 @@ class StubGenerator(ast.NodeTransformer):
node.body.append(call_definition)
if not node.body:
# We should never return an empty body.
node.body.append(ast.Expr(value=ast.Ellipsis()))
node.body.append(ast.Expr(value=ast.Constant(...)))
self.current_class = None
return node
@ -941,9 +959,9 @@ class StubGenerator(ast.NodeTransformer):
if node.name.startswith("_") and node.name != "__call__":
return None # remove private methods
if node.body[-1] != ast.Expr(value=ast.Ellipsis()):
if node.body[-1] != ast.Expr(value=ast.Constant(...)):
# Blank out the function body for public functions.
node.body = [ast.Expr(value=ast.Ellipsis())]
node.body = [ast.Expr(value=ast.Constant(...))]
return node
def visit_Assign(self, node: ast.Assign) -> ast.Assign | None:

View File

@ -69,21 +69,21 @@ else:
# Potential GenericAlias types for isinstance checks.
GenericAliasTypes = [_GenericAlias]
_GenericAliasTypes: list[type] = [_GenericAlias]
with contextlib.suppress(ImportError):
# For newer versions of Python.
from types import GenericAlias # type: ignore
GenericAliasTypes.append(GenericAlias)
_GenericAliasTypes.append(GenericAlias)
with contextlib.suppress(ImportError):
# For older versions of Python.
from typing import _SpecialGenericAlias # type: ignore
GenericAliasTypes.append(_SpecialGenericAlias)
_GenericAliasTypes.append(_SpecialGenericAlias)
GenericAliasTypes = tuple(GenericAliasTypes)
GenericAliasTypes = tuple(_GenericAliasTypes)
# Potential Union types for isinstance checks (UnionType added in py3.10).
UnionTypes = (Union, types.UnionType) if hasattr(types, "UnionType") else (Union,)
@ -181,7 +181,7 @@ def is_generic_alias(cls: GenericType) -> bool:
return isinstance(cls, GenericAliasTypes)
def unionize(*args: GenericType) -> Type:
def unionize(*args: GenericType) -> GenericType:
"""Unionize the types.
Args:
@ -415,7 +415,7 @@ def get_attribute_access_type(cls: GenericType, name: str) -> GenericType | None
@lru_cache()
def get_base_class(cls: GenericType) -> Type:
def get_base_class(cls: GenericType) -> Type | tuple[Type, ...]:
"""Get the base class of a class.
Args:
@ -435,7 +435,14 @@ def get_base_class(cls: GenericType) -> Type:
return type(get_args(cls)[0])
if is_union(cls):
return tuple(get_base_class(arg) for arg in get_args(cls))
base_classes = []
for arg in get_args(cls):
sub_base_classes = get_base_class(arg)
if isinstance(sub_base_classes, tuple):
base_classes.extend(sub_base_classes)
else:
base_classes.append(sub_base_classes)
return tuple(base_classes)
return get_base_class(cls.__origin__) if is_generic_alias(cls) else cls

View File

@ -15,6 +15,7 @@ from typing import (
Sequence,
TypeVar,
Union,
cast,
overload,
)
@ -1102,7 +1103,7 @@ class MatchOperation(CachedVarOperation, Var[VAR_TYPE]):
_cases: tuple[TUPLE_ENDS_IN_VAR[VAR_TYPE], ...] = dataclasses.field(
default_factory=tuple
)
_default: Var[VAR_TYPE] = dataclasses.field(
_default: Var[VAR_TYPE] = dataclasses.field( # pyright: ignore[reportAssignmentType]
default_factory=lambda: Var.create(None)
)
@ -1170,11 +1171,14 @@ class MatchOperation(CachedVarOperation, Var[VAR_TYPE]):
The match operation.
"""
cond = Var.create(cond)
cases = tuple(tuple(Var.create(c) for c in case) for case in cases)
default = Var.create(default)
cases = cast(
tuple[TUPLE_ENDS_IN_VAR[VAR_TYPE], ...],
tuple(tuple(Var.create(c) for c in case) for case in cases),
)
_default = cast(Var[VAR_TYPE], Var.create(default))
var_type = _var_type or unionize(
*(case[-1]._var_type for case in cases),
default._var_type,
_default._var_type,
)
return cls(
_js_expr="",
@ -1182,7 +1186,7 @@ class MatchOperation(CachedVarOperation, Var[VAR_TYPE]):
_var_type=var_type,
_cond=cond,
_cases=cases,
_default=default,
_default=_default,
)