have the pyi here as well
This commit is contained in:
parent
e88aa8257b
commit
3c19ed4f0b
@ -372,6 +372,53 @@ def _extract_class_props_as_ast_nodes(
|
|||||||
return kwargs
|
return kwargs
|
||||||
|
|
||||||
|
|
||||||
|
def type_to_ast(typ) -> ast.AST:
|
||||||
|
"""Converts any type annotation into its AST representation.
|
||||||
|
Handles nested generic types, unions, etc.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
typ: The type annotation to convert.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
The AST representation of the type annotation.
|
||||||
|
"""
|
||||||
|
if typ is type(None):
|
||||||
|
return ast.Name(id="None")
|
||||||
|
|
||||||
|
origin = get_origin(typ)
|
||||||
|
|
||||||
|
# Handle plain types (int, str, custom classes, etc.)
|
||||||
|
if origin is None:
|
||||||
|
if hasattr(typ, "__name__"):
|
||||||
|
return ast.Name(id=typ.__name__)
|
||||||
|
elif hasattr(typ, "_name"):
|
||||||
|
return ast.Name(id=typ._name)
|
||||||
|
return ast.Name(id=str(typ))
|
||||||
|
|
||||||
|
# Get the base type name (List, Dict, Optional, etc.)
|
||||||
|
base_name = origin._name if hasattr(origin, "_name") else origin.__name__
|
||||||
|
|
||||||
|
# Get type arguments
|
||||||
|
args = get_args(typ)
|
||||||
|
|
||||||
|
# Handle empty type arguments
|
||||||
|
if not args:
|
||||||
|
return ast.Name(id=base_name)
|
||||||
|
|
||||||
|
# Convert all type arguments recursively
|
||||||
|
arg_nodes = [type_to_ast(arg) for arg in args]
|
||||||
|
|
||||||
|
# Special case for single-argument types (like List[T] or Optional[T])
|
||||||
|
if len(arg_nodes) == 1:
|
||||||
|
slice_value = arg_nodes[0]
|
||||||
|
else:
|
||||||
|
slice_value = ast.Tuple(elts=arg_nodes, ctx=ast.Load())
|
||||||
|
|
||||||
|
return ast.Subscript(
|
||||||
|
value=ast.Name(id=base_name), slice=ast.Index(value=slice_value), ctx=ast.Load()
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
def _get_parent_imports(func):
|
def _get_parent_imports(func):
|
||||||
_imports = {"reflex.vars": ["Var"]}
|
_imports = {"reflex.vars": ["Var"]}
|
||||||
for type_hint in inspect.get_annotations(func).values():
|
for type_hint in inspect.get_annotations(func).values():
|
||||||
@ -439,12 +486,24 @@ def _generate_component_create_functiondef(
|
|||||||
for argument in arguments
|
for argument in arguments
|
||||||
]
|
]
|
||||||
|
|
||||||
return ast.Name(
|
# Convert each argument type to its AST representation
|
||||||
id=f"Optional[EventType[{', '.join(
|
type_args = [type_to_ast(arg) for arg in arguments_without_var]
|
||||||
[arg.__name__ for arg in arguments_without_var]
|
|
||||||
)}]]"
|
# Join the type arguments with commas for EventType
|
||||||
|
args_str = ", ".join(ast.unparse(arg) for arg in type_args)
|
||||||
|
|
||||||
|
# Create EventType using the joined string
|
||||||
|
event_type = ast.Name(id=f"EventType[{args_str}]")
|
||||||
|
|
||||||
|
# Wrap in Optional
|
||||||
|
optional_type = ast.Subscript(
|
||||||
|
value=ast.Name(id="Optional"),
|
||||||
|
slice=ast.Index(value=event_type),
|
||||||
|
ctx=ast.Load(),
|
||||||
)
|
)
|
||||||
|
|
||||||
|
return ast.Name(id=ast.unparse(optional_type))
|
||||||
|
|
||||||
if isinstance(annotation, str) and annotation.startswith("Tuple["):
|
if isinstance(annotation, str) and annotation.startswith("Tuple["):
|
||||||
inside_of_tuple = annotation.removeprefix("Tuple[").removesuffix("]")
|
inside_of_tuple = annotation.removeprefix("Tuple[").removesuffix("]")
|
||||||
|
|
||||||
|
Loading…
Reference in New Issue
Block a user