more typing for components

This commit is contained in:
Lendemor 2024-11-20 21:17:21 +01:00
parent 6183da6d8a
commit b3aa558d44
12 changed files with 28 additions and 26 deletions

View File

@ -812,7 +812,7 @@ class Component(BaseComponent, ABC):
# Filter out None props # Filter out None props
props = {key: value for key, value in props.items() if value is not None} props = {key: value for key, value in props.items() if value is not None}
def validate_children(children): def validate_children(children: tuple):
for child in children: for child in children:
if isinstance(child, tuple): if isinstance(child, tuple):
validate_children(child) validate_children(child)
@ -957,7 +957,7 @@ class Component(BaseComponent, ABC):
else {} else {}
) )
def render(self) -> Dict: def render(self) -> dict:
"""Render the component. """Render the component.
Returns: Returns:
@ -975,7 +975,7 @@ class Component(BaseComponent, ABC):
self._replace_prop_names(rendered_dict) self._replace_prop_names(rendered_dict)
return rendered_dict return rendered_dict
def _replace_prop_names(self, rendered_dict) -> None: def _replace_prop_names(self, rendered_dict: dict) -> None:
"""Replace the prop names in the render dictionary. """Replace the prop names in the render dictionary.
Args: Args:
@ -1015,7 +1015,7 @@ class Component(BaseComponent, ABC):
comp.__name__ for comp in (Fragment, Foreach, Cond, Match) comp.__name__ for comp in (Fragment, Foreach, Cond, Match)
] ]
def validate_child(child): def validate_child(child: Any):
child_name = type(child).__name__ child_name = type(child).__name__
# Iterate through the immediate children of fragment # Iterate through the immediate children of fragment

View File

@ -24,7 +24,7 @@ class ClientSideRouting(Component):
library = "$/utils/client_side_routing" library = "$/utils/client_side_routing"
tag = "useClientSideRouting" tag = "useClientSideRouting"
def add_hooks(self) -> list[str]: def add_hooks(self) -> list[str | Var]:
"""Get the hooks to render. """Get the hooks to render.
Returns: Returns:
@ -41,7 +41,7 @@ class ClientSideRouting(Component):
return "" return ""
def wait_for_client_redirect(component) -> Component: def wait_for_client_redirect(component: Component) -> Component:
"""Wait for a redirect to occur before rendering a component. """Wait for a redirect to occur before rendering a component.
This prevents the 404 page from flashing while the redirect is happening. This prevents the 404 page from flashing while the redirect is happening.

View File

@ -13,7 +13,7 @@ from reflex.vars.base import Var
route_not_found: Var route_not_found: Var
class ClientSideRouting(Component): class ClientSideRouting(Component):
def add_hooks(self) -> list[str]: ... def add_hooks(self) -> list[str | Var]: ...
def render(self) -> str: ... def render(self) -> str: ...
@overload @overload
@classmethod @classmethod
@ -60,7 +60,7 @@ class ClientSideRouting(Component):
""" """
... ...
def wait_for_client_redirect(component) -> Component: ... def wait_for_client_redirect(component: Component) -> Component: ...
class Default404Page(Component): class Default404Page(Component):
@overload @overload

View File

@ -154,7 +154,7 @@ def cond(condition: Any, c1: Any, c2: Any = None) -> Component | Var:
if c2 is None: if c2 is None:
raise ValueError("For conditional vars, the second argument must be set.") raise ValueError("For conditional vars, the second argument must be set.")
def create_var(cond_part): def create_var(cond_part: Any) -> Var[Any]:
return LiteralVar.create(cond_part) return LiteralVar.create(cond_part)
# convert the truth and false cond parts into vars so the _var_data can be obtained. # convert the truth and false cond parts into vars so the _var_data can be obtained.

View File

@ -109,7 +109,7 @@ class Match(MemoizationLeaf):
return cases, default return cases, default
@classmethod @classmethod
def _create_case_var_with_var_data(cls, case_element): def _create_case_var_with_var_data(cls, case_element: Any) -> Var:
"""Convert a case element into a Var.If the case """Convert a case element into a Var.If the case
is a Style type, we extract the var data and merge it with the is a Style type, we extract the var data and merge it with the
newly created Var. newly created Var.

View File

@ -15,10 +15,8 @@ def svg_logo(color: Union[str, rx.Var[str]] = rx.color_mode_cond("#110F1F", "whi
The Reflex logo SVG. The Reflex logo SVG.
""" """
def logo_path(d): def logo_path(d: str):
return rx.el.svg.path( return rx.el.svg.path(d=d)
d=d,
)
paths = [ paths = [
"M0 11.5999V0.399902H8.96V4.8799H6.72V2.6399H2.24V4.8799H6.72V7.1199H2.24V11.5999H0ZM6.72 11.5999V7.1199H8.96V11.5999H6.72Z", "M0 11.5999V0.399902H8.96V4.8799H6.72V2.6399H2.24V4.8799H6.72V7.1199H2.24V11.5999H0ZM6.72 11.5999V7.1199H8.96V11.5999H6.72Z",

View File

@ -6,7 +6,7 @@ from reflex.components.component import Component
class Element(Component): class Element(Component):
"""The base class for all raw HTML elements.""" """The base class for all raw HTML elements."""
def __eq__(self, other): def __eq__(self, other: object):
"""Two elements are equal if they have the same tag. """Two elements are equal if they have the same tag.
Args: Args:

View File

@ -8,7 +8,7 @@ from functools import lru_cache
from hashlib import md5 from hashlib import md5
from typing import Any, Callable, Dict, Sequence, Union from typing import Any, Callable, Dict, Sequence, Union
from reflex.components.component import Component, CustomComponent from reflex.components.component import BaseComponent, Component, CustomComponent
from reflex.components.tags.tag import Tag from reflex.components.tags.tag import Tag
from reflex.utils import types from reflex.utils import types
from reflex.utils.imports import ImportDict, ImportVar from reflex.utils.imports import ImportDict, ImportVar
@ -379,7 +379,9 @@ const {str(_LANGUAGE)} = match ? match[1] : '';
# fallback to the default fn Var creation if the component is not a MarkdownComponentMap. # fallback to the default fn Var creation if the component is not a MarkdownComponentMap.
return MarkdownComponentMap.create_map_fn_var(fn_body=formatted_component) return MarkdownComponentMap.create_map_fn_var(fn_body=formatted_component)
def _get_map_fn_custom_code_from_children(self, component) -> list[str]: def _get_map_fn_custom_code_from_children(
self, component: BaseComponent
) -> list[str]:
"""Recursively get markdown custom code from children components. """Recursively get markdown custom code from children components.
Args: Args:
@ -409,7 +411,7 @@ const {str(_LANGUAGE)} = match ? match[1] : '';
return custom_code_list return custom_code_list
@staticmethod @staticmethod
def _component_map_hash(component_map) -> str: def _component_map_hash(component_map: dict) -> str:
inp = str( inp = str(
{tag: component(_MOCK_ARG) for tag, component in component_map.items()} {tag: component(_MOCK_ARG) for tag, component in component_map.items()}
).encode() ).encode()

View File

@ -83,7 +83,7 @@ class Image(NextComponent):
style = props.get("style", {}) style = props.get("style", {})
DEFAULT_W_H = "100%" DEFAULT_W_H = "100%"
def check_prop_type(prop_name, prop_value): def check_prop_type(prop_name: str, prop_value: int | str | None):
if types.check_prop_in_allowed_types(prop_value, allowed_types=[int]): if types.check_prop_in_allowed_types(prop_value, allowed_types=[int]):
props[prop_name] = prop_value props[prop_name] = prop_value

View File

@ -48,7 +48,7 @@ class PropsBase(Base):
class NoExtrasAllowedProps(Base): class NoExtrasAllowedProps(Base):
"""A class that holds props to be passed or applied to a component with no extra props allowed.""" """A class that holds props to be passed or applied to a component with no extra props allowed."""
def __init__(self, component_name=None, **kwargs): def __init__(self, component_name: str | None = None, **kwargs):
"""Initialize the props. """Initialize the props.
Args: Args:

View File

@ -17,7 +17,7 @@ rx.text(
from __future__ import annotations from __future__ import annotations
from typing import Dict, List, Literal, Optional, Union, get_args from typing import Any, Dict, List, Literal, Optional, Union, get_args
from reflex.components.component import BaseComponent from reflex.components.component import BaseComponent
from reflex.components.core.cond import Cond, color_mode_cond, cond from reflex.components.core.cond import Cond, color_mode_cond, cond
@ -78,17 +78,19 @@ position_map: Dict[str, List[str]] = {
# needed to inverse contains for find # needed to inverse contains for find
def _find(const: List[str], var): def _find(const: List[str], var: Any):
return LiteralArrayVar.create(const).contains(var) return LiteralArrayVar.create(const).contains(var)
def _set_var_default(props, position, prop, default1, default2=""): def _set_var_default(
props: dict, position: Any, prop: str, default1: str, default2: str = ""
):
props.setdefault( props.setdefault(
prop, cond(_find(position_map[prop], position), default1, default2) prop, cond(_find(position_map[prop], position), default1, default2)
) )
def _set_static_default(props, position, prop, default): def _set_static_default(props: dict, position: Any, prop: str, default: str):
if prop in position: if prop in position:
props.setdefault(prop, default) props.setdefault(prop, default)
@ -142,7 +144,7 @@ class ColorModeIconButton(IconButton):
if allow_system: if allow_system:
def color_mode_item(_color_mode): def color_mode_item(_color_mode: str):
return dropdown_menu.item( return dropdown_menu.item(
_color_mode.title(), on_click=set_color_mode(_color_mode) _color_mode.title(), on_click=set_color_mode(_color_mode)
) )

View File

@ -193,7 +193,7 @@ ordered_list = list_ns.ordered
unordered_list = list_ns.unordered unordered_list = list_ns.unordered
def __getattr__(name): def __getattr__(name: Any):
# special case for when accessing list to avoid shadowing # special case for when accessing list to avoid shadowing
# python's built in list object. # python's built in list object.
if name == "list": if name == "list":