more typing for components

This commit is contained in:
Lendemor 2024-11-20 21:17:21 +01:00
parent 6183da6d8a
commit b3aa558d44
12 changed files with 28 additions and 26 deletions

View File

@ -812,7 +812,7 @@ class Component(BaseComponent, ABC):
# Filter out None props
props = {key: value for key, value in props.items() if value is not None}
def validate_children(children):
def validate_children(children: tuple):
for child in children:
if isinstance(child, tuple):
validate_children(child)
@ -957,7 +957,7 @@ class Component(BaseComponent, ABC):
else {}
)
def render(self) -> Dict:
def render(self) -> dict:
"""Render the component.
Returns:
@ -975,7 +975,7 @@ class Component(BaseComponent, ABC):
self._replace_prop_names(rendered_dict)
return rendered_dict
def _replace_prop_names(self, rendered_dict) -> None:
def _replace_prop_names(self, rendered_dict: dict) -> None:
"""Replace the prop names in the render dictionary.
Args:
@ -1015,7 +1015,7 @@ class Component(BaseComponent, ABC):
comp.__name__ for comp in (Fragment, Foreach, Cond, Match)
]
def validate_child(child):
def validate_child(child: Any):
child_name = type(child).__name__
# Iterate through the immediate children of fragment

View File

@ -24,7 +24,7 @@ class ClientSideRouting(Component):
library = "$/utils/client_side_routing"
tag = "useClientSideRouting"
def add_hooks(self) -> list[str]:
def add_hooks(self) -> list[str | Var]:
"""Get the hooks to render.
Returns:
@ -41,7 +41,7 @@ class ClientSideRouting(Component):
return ""
def wait_for_client_redirect(component) -> Component:
def wait_for_client_redirect(component: Component) -> Component:
"""Wait for a redirect to occur before rendering a component.
This prevents the 404 page from flashing while the redirect is happening.

View File

@ -13,7 +13,7 @@ from reflex.vars.base import Var
route_not_found: Var
class ClientSideRouting(Component):
def add_hooks(self) -> list[str]: ...
def add_hooks(self) -> list[str | Var]: ...
def render(self) -> str: ...
@overload
@classmethod
@ -60,7 +60,7 @@ class ClientSideRouting(Component):
"""
...
def wait_for_client_redirect(component) -> Component: ...
def wait_for_client_redirect(component: Component) -> Component: ...
class Default404Page(Component):
@overload

View File

@ -154,7 +154,7 @@ def cond(condition: Any, c1: Any, c2: Any = None) -> Component | Var:
if c2 is None:
raise ValueError("For conditional vars, the second argument must be set.")
def create_var(cond_part):
def create_var(cond_part: Any) -> Var[Any]:
return LiteralVar.create(cond_part)
# convert the truth and false cond parts into vars so the _var_data can be obtained.

View File

@ -109,7 +109,7 @@ class Match(MemoizationLeaf):
return cases, default
@classmethod
def _create_case_var_with_var_data(cls, case_element):
def _create_case_var_with_var_data(cls, case_element: Any) -> Var:
"""Convert a case element into a Var.If the case
is a Style type, we extract the var data and merge it with the
newly created Var.

View File

@ -15,10 +15,8 @@ def svg_logo(color: Union[str, rx.Var[str]] = rx.color_mode_cond("#110F1F", "whi
The Reflex logo SVG.
"""
def logo_path(d):
return rx.el.svg.path(
d=d,
)
def logo_path(d: str):
return rx.el.svg.path(d=d)
paths = [
"M0 11.5999V0.399902H8.96V4.8799H6.72V2.6399H2.24V4.8799H6.72V7.1199H2.24V11.5999H0ZM6.72 11.5999V7.1199H8.96V11.5999H6.72Z",

View File

@ -6,7 +6,7 @@ from reflex.components.component import Component
class Element(Component):
"""The base class for all raw HTML elements."""
def __eq__(self, other):
def __eq__(self, other: object):
"""Two elements are equal if they have the same tag.
Args:

View File

@ -8,7 +8,7 @@ from functools import lru_cache
from hashlib import md5
from typing import Any, Callable, Dict, Sequence, Union
from reflex.components.component import Component, CustomComponent
from reflex.components.component import BaseComponent, Component, CustomComponent
from reflex.components.tags.tag import Tag
from reflex.utils import types
from reflex.utils.imports import ImportDict, ImportVar
@ -379,7 +379,9 @@ const {str(_LANGUAGE)} = match ? match[1] : '';
# fallback to the default fn Var creation if the component is not a MarkdownComponentMap.
return MarkdownComponentMap.create_map_fn_var(fn_body=formatted_component)
def _get_map_fn_custom_code_from_children(self, component) -> list[str]:
def _get_map_fn_custom_code_from_children(
self, component: BaseComponent
) -> list[str]:
"""Recursively get markdown custom code from children components.
Args:
@ -409,7 +411,7 @@ const {str(_LANGUAGE)} = match ? match[1] : '';
return custom_code_list
@staticmethod
def _component_map_hash(component_map) -> str:
def _component_map_hash(component_map: dict) -> str:
inp = str(
{tag: component(_MOCK_ARG) for tag, component in component_map.items()}
).encode()

View File

@ -83,7 +83,7 @@ class Image(NextComponent):
style = props.get("style", {})
DEFAULT_W_H = "100%"
def check_prop_type(prop_name, prop_value):
def check_prop_type(prop_name: str, prop_value: int | str | None):
if types.check_prop_in_allowed_types(prop_value, allowed_types=[int]):
props[prop_name] = prop_value

View File

@ -48,7 +48,7 @@ class PropsBase(Base):
class NoExtrasAllowedProps(Base):
"""A class that holds props to be passed or applied to a component with no extra props allowed."""
def __init__(self, component_name=None, **kwargs):
def __init__(self, component_name: str | None = None, **kwargs):
"""Initialize the props.
Args:

View File

@ -17,7 +17,7 @@ rx.text(
from __future__ import annotations
from typing import Dict, List, Literal, Optional, Union, get_args
from typing import Any, Dict, List, Literal, Optional, Union, get_args
from reflex.components.component import BaseComponent
from reflex.components.core.cond import Cond, color_mode_cond, cond
@ -78,17 +78,19 @@ position_map: Dict[str, List[str]] = {
# needed to inverse contains for find
def _find(const: List[str], var):
def _find(const: List[str], var: Any):
return LiteralArrayVar.create(const).contains(var)
def _set_var_default(props, position, prop, default1, default2=""):
def _set_var_default(
props: dict, position: Any, prop: str, default1: str, default2: str = ""
):
props.setdefault(
prop, cond(_find(position_map[prop], position), default1, default2)
)
def _set_static_default(props, position, prop, default):
def _set_static_default(props: dict, position: Any, prop: str, default: str):
if prop in position:
props.setdefault(prop, default)
@ -142,7 +144,7 @@ class ColorModeIconButton(IconButton):
if allow_system:
def color_mode_item(_color_mode):
def color_mode_item(_color_mode: str):
return dropdown_menu.item(
_color_mode.title(), on_click=set_color_mode(_color_mode)
)

View File

@ -193,7 +193,7 @@ ordered_list = list_ns.ordered
unordered_list = list_ns.unordered
def __getattr__(name):
def __getattr__(name: Any):
# special case for when accessing list to avoid shadowing
# python's built in list object.
if name == "list":