have the pyi here as well

This commit is contained in:
Khaleel Al-Adhami 2024-10-21 18:07:29 -07:00
parent e88aa8257b
commit 3c19ed4f0b

View File

@ -372,6 +372,53 @@ def _extract_class_props_as_ast_nodes(
return kwargs
def type_to_ast(typ) -> ast.AST:
"""Converts any type annotation into its AST representation.
Handles nested generic types, unions, etc.
Args:
typ: The type annotation to convert.
Returns:
The AST representation of the type annotation.
"""
if typ is type(None):
return ast.Name(id="None")
origin = get_origin(typ)
# Handle plain types (int, str, custom classes, etc.)
if origin is None:
if hasattr(typ, "__name__"):
return ast.Name(id=typ.__name__)
elif hasattr(typ, "_name"):
return ast.Name(id=typ._name)
return ast.Name(id=str(typ))
# Get the base type name (List, Dict, Optional, etc.)
base_name = origin._name if hasattr(origin, "_name") else origin.__name__
# Get type arguments
args = get_args(typ)
# Handle empty type arguments
if not args:
return ast.Name(id=base_name)
# Convert all type arguments recursively
arg_nodes = [type_to_ast(arg) for arg in args]
# Special case for single-argument types (like List[T] or Optional[T])
if len(arg_nodes) == 1:
slice_value = arg_nodes[0]
else:
slice_value = ast.Tuple(elts=arg_nodes, ctx=ast.Load())
return ast.Subscript(
value=ast.Name(id=base_name), slice=ast.Index(value=slice_value), ctx=ast.Load()
)
def _get_parent_imports(func):
_imports = {"reflex.vars": ["Var"]}
for type_hint in inspect.get_annotations(func).values():
@ -439,12 +486,24 @@ def _generate_component_create_functiondef(
for argument in arguments
]
return ast.Name(
id=f"Optional[EventType[{', '.join(
[arg.__name__ for arg in arguments_without_var]
)}]]"
# Convert each argument type to its AST representation
type_args = [type_to_ast(arg) for arg in arguments_without_var]
# Join the type arguments with commas for EventType
args_str = ", ".join(ast.unparse(arg) for arg in type_args)
# Create EventType using the joined string
event_type = ast.Name(id=f"EventType[{args_str}]")
# Wrap in Optional
optional_type = ast.Subscript(
value=ast.Name(id="Optional"),
slice=ast.Index(value=event_type),
ctx=ast.Load(),
)
return ast.Name(id=ast.unparse(optional_type))
if isinstance(annotation, str) and annotation.startswith("Tuple["):
inside_of_tuple = annotation.removeprefix("Tuple[").removesuffix("]")