diff --git a/reflex/utils/pyi_generator.py b/reflex/utils/pyi_generator.py index 4a7fdfca0..026a53bca 100644 --- a/reflex/utils/pyi_generator.py +++ b/reflex/utils/pyi_generator.py @@ -372,6 +372,53 @@ def _extract_class_props_as_ast_nodes( return kwargs +def type_to_ast(typ) -> ast.AST: + """Converts any type annotation into its AST representation. + Handles nested generic types, unions, etc. + + Args: + typ: The type annotation to convert. + + Returns: + The AST representation of the type annotation. + """ + if typ is type(None): + return ast.Name(id="None") + + origin = get_origin(typ) + + # Handle plain types (int, str, custom classes, etc.) + if origin is None: + if hasattr(typ, "__name__"): + return ast.Name(id=typ.__name__) + elif hasattr(typ, "_name"): + return ast.Name(id=typ._name) + return ast.Name(id=str(typ)) + + # Get the base type name (List, Dict, Optional, etc.) + base_name = origin._name if hasattr(origin, "_name") else origin.__name__ + + # Get type arguments + args = get_args(typ) + + # Handle empty type arguments + if not args: + return ast.Name(id=base_name) + + # Convert all type arguments recursively + arg_nodes = [type_to_ast(arg) for arg in args] + + # Special case for single-argument types (like List[T] or Optional[T]) + if len(arg_nodes) == 1: + slice_value = arg_nodes[0] + else: + slice_value = ast.Tuple(elts=arg_nodes, ctx=ast.Load()) + + return ast.Subscript( + value=ast.Name(id=base_name), slice=ast.Index(value=slice_value), ctx=ast.Load() + ) + + def _get_parent_imports(func): _imports = {"reflex.vars": ["Var"]} for type_hint in inspect.get_annotations(func).values(): @@ -439,12 +486,24 @@ def _generate_component_create_functiondef( for argument in arguments ] - return ast.Name( - id=f"Optional[EventType[{', '.join( - [arg.__name__ for arg in arguments_without_var] - )}]]" + # Convert each argument type to its AST representation + type_args = [type_to_ast(arg) for arg in arguments_without_var] + + # Join the type arguments with commas for EventType + args_str = ", ".join(ast.unparse(arg) for arg in type_args) + + # Create EventType using the joined string + event_type = ast.Name(id=f"EventType[{args_str}]") + + # Wrap in Optional + optional_type = ast.Subscript( + value=ast.Name(id="Optional"), + slice=ast.Index(value=event_type), + ctx=ast.Load(), ) + return ast.Name(id=ast.unparse(optional_type)) + if isinstance(annotation, str) and annotation.startswith("Tuple["): inside_of_tuple = annotation.removeprefix("Tuple[").removesuffix("]")