From 2430a9c1e2a02a298f91955f8c551e438bb951a2 Mon Sep 17 00:00:00 2001 From: Khaleel Al-Adhami Date: Tue, 18 Feb 2025 14:07:36 -0800 Subject: [PATCH] do the same for set --- reflex/app.py | 3 +-- reflex/app_mixins/lifespan.py | 4 ++-- reflex/components/component.py | 13 ++++++------- reflex/components/el/elements/forms.py | 4 ++-- reflex/config.py | 3 +-- reflex/state.py | 18 +++++++++--------- reflex/vars/base.py | 4 ++-- tests/units/components/core/test_foreach.py | 4 ++-- tests/units/states/mutation.py | 8 ++++---- tests/units/test_var.py | 10 +++++----- 10 files changed, 34 insertions(+), 37 deletions(-) diff --git a/reflex/app.py b/reflex/app.py index ff72a4ddf..f819bc1f4 100644 --- a/reflex/app.py +++ b/reflex/app.py @@ -27,7 +27,6 @@ from typing import ( Dict, MutableMapping, Optional, - Set, Type, Union, get_args, @@ -392,7 +391,7 @@ class App(MiddlewareMixin, LifespanMixin): _event_namespace: Optional[EventNamespace] = None # Background tasks that are currently running. - _background_tasks: Set[asyncio.Task] = dataclasses.field(default_factory=set) + _background_tasks: set[asyncio.Task] = dataclasses.field(default_factory=set) # Frontend Error Handler Function frontend_exception_handler: Callable[[Exception], None] = ( diff --git a/reflex/app_mixins/lifespan.py b/reflex/app_mixins/lifespan.py index 50b90f25c..9f5330e4c 100644 --- a/reflex/app_mixins/lifespan.py +++ b/reflex/app_mixins/lifespan.py @@ -7,7 +7,7 @@ import contextlib import dataclasses import functools import inspect -from typing import Callable, Coroutine, Set, Union +from typing import Callable, Coroutine, Union from fastapi import FastAPI @@ -22,7 +22,7 @@ class LifespanMixin(AppMixin): """A Mixin that allow tasks to run during the whole app lifespan.""" # Lifespan tasks that are planned to run. - lifespan_tasks: Set[Union[asyncio.Task, Callable]] = dataclasses.field( + lifespan_tasks: set[Union[asyncio.Task, Callable]] = dataclasses.field( default_factory=set ) diff --git a/reflex/components/component.py b/reflex/components/component.py index 2f6c5f45f..6b95fba95 100644 --- a/reflex/components/component.py +++ b/reflex/components/component.py @@ -17,7 +17,6 @@ from typing import ( List, Optional, Sequence, - Set, Type, Union, ) @@ -646,7 +645,7 @@ class Component(BaseComponent, ABC): @classmethod @lru_cache(maxsize=None) - def get_props(cls) -> Set[str]: + def get_props(cls) -> set[str]: """Get the unique fields for the component. Returns: @@ -656,7 +655,7 @@ class Component(BaseComponent, ABC): @classmethod @lru_cache(maxsize=None) - def get_initial_props(cls) -> Set[str]: + def get_initial_props(cls) -> set[str]: """Get the initial props to set for the component. Returns: @@ -1155,7 +1154,7 @@ class Component(BaseComponent, ABC): """ return None - def _get_all_dynamic_imports(self) -> Set[str]: + def _get_all_dynamic_imports(self) -> set[str]: """Get dynamic imports for the component and its children. Returns: @@ -1531,7 +1530,7 @@ class Component(BaseComponent, ABC): def _get_all_custom_components( self, seen: set[str] | None = None - ) -> Set[CustomComponent]: + ) -> set[CustomComponent]: """Get all the custom components used by the component. Args: @@ -1697,7 +1696,7 @@ class CustomComponent(Component): return hash(self.tag) @classmethod - def get_props(cls) -> Set[str]: # pyright: ignore [reportIncompatibleVariableOverride] + def get_props(cls) -> set[str]: # pyright: ignore [reportIncompatibleVariableOverride] """Get the props for the component. Returns: @@ -1707,7 +1706,7 @@ class CustomComponent(Component): def _get_all_custom_components( self, seen: set[str] | None = None - ) -> Set[CustomComponent]: + ) -> set[CustomComponent]: """Get all the custom components used by the component. Args: diff --git a/reflex/components/el/elements/forms.py b/reflex/components/el/elements/forms.py index b715c414d..0dfc71276 100644 --- a/reflex/components/el/elements/forms.py +++ b/reflex/components/el/elements/forms.py @@ -3,7 +3,7 @@ from __future__ import annotations from hashlib import md5 -from typing import Any, Iterator, Literal, Set, Tuple, Union +from typing import Any, Iterator, Literal, Tuple, Union from jinja2 import Environment @@ -720,7 +720,7 @@ class Textarea(BaseHTML): "enter_key_submit", ] - def _get_all_custom_code(self) -> Set[str]: + def _get_all_custom_code(self) -> set[str]: """Include the custom code for auto_height and enter_key_submit functionality. Returns: diff --git a/reflex/config.py b/reflex/config.py index 50aee588f..37b41dbc5 100644 --- a/reflex/config.py +++ b/reflex/config.py @@ -24,7 +24,6 @@ from typing import ( Callable, Generic, Optional, - Set, TypeVar, get_args, get_origin, @@ -836,7 +835,7 @@ class Config(Base): redis_token_expiration: int = constants.Expiration.TOKEN # Attributes that were explicitly set by the user. - _non_default_attributes: Set[str] = pydantic.PrivateAttr(set()) + _non_default_attributes: set[str] = pydantic.PrivateAttr(set()) # Path to file containing key-values pairs to override in the environment; Dotenv format. env_file: Optional[str] = None diff --git a/reflex/state.py b/reflex/state.py index 56e190adc..90b130b94 100644 --- a/reflex/state.py +++ b/reflex/state.py @@ -117,7 +117,7 @@ if environment.REFLEX_PERF_MODE.get() != PerformanceMode.OFF: # If the state is this large, it's considered a performance issue. TOO_LARGE_SERIALIZED_STATE = environment.REFLEX_STATE_SIZE_LIMIT.get() * 1024 # Only warn about each state class size once. - _WARNED_ABOUT_STATE_SIZE: Set[str] = set() + _WARNED_ABOUT_STATE_SIZE: set[str] = set() # Errors caught during pickling of state HANDLED_PICKLE_ERRORS = ( @@ -351,19 +351,19 @@ class BaseState(Base, ABC, extra=pydantic.Extra.allow): event_handlers: ClassVar[Dict[str, EventHandler]] = {} # A set of subclassses of this class. - class_subclasses: ClassVar[Set[Type[BaseState]]] = set() + class_subclasses: ClassVar[set[Type[BaseState]]] = set() # Mapping of var name to set of (state_full_name, var_name) that depend on it. - _var_dependencies: ClassVar[Dict[str, Set[Tuple[str, str]]]] = {} + _var_dependencies: ClassVar[Dict[str, set[Tuple[str, str]]]] = {} # Set of vars which always need to be recomputed - _always_dirty_computed_vars: ClassVar[Set[str]] = set() + _always_dirty_computed_vars: ClassVar[set[str]] = set() # Set of substates which always need to be recomputed - _always_dirty_substates: ClassVar[Set[str]] = set() + _always_dirty_substates: ClassVar[set[str]] = set() # Set of states which might need to be recomputed if vars in this state change. - _potentially_dirty_states: ClassVar[Set[str]] = set() + _potentially_dirty_states: ClassVar[set[str]] = set() # The parent state. parent_state: Optional[BaseState] = None @@ -372,10 +372,10 @@ class BaseState(Base, ABC, extra=pydantic.Extra.allow): substates: Dict[str, BaseState] = {} # The set of dirty vars. - dirty_vars: Set[str] = set() + dirty_vars: set[str] = set() # The set of dirty substates. - dirty_substates: Set[str] = set() + dirty_substates: set[str] = set() # The routing path that triggered the state router_data: Dict[str, Any] = {} @@ -3208,7 +3208,7 @@ class StateManagerRedis(StateManager): ) # These events indicate that a lock is no longer held - _redis_keyspace_lock_release_events: Set[bytes] = { + _redis_keyspace_lock_release_events: set[bytes] = { b"del", b"expire", b"expired", diff --git a/reflex/vars/base.py b/reflex/vars/base.py index 587632e54..ec27740be 100644 --- a/reflex/vars/base.py +++ b/reflex/vars/base.py @@ -1678,7 +1678,7 @@ def figure_out_type(value: Any) -> types.GenericType: if isinstance(value, list): return list[unionize(*(figure_out_type(v) for v in value))] if isinstance(value, set): - return Set[unionize(*(figure_out_type(v) for v in value))] + return set[unionize(*(figure_out_type(v) for v in value))] if isinstance(value, tuple): return Tuple[unionize(*(figure_out_type(v) for v in value)), ...] if isinstance(value, Mapping): @@ -3269,7 +3269,7 @@ class Field(Generic[FIELD_TYPE]): @overload def __get__( - self: Field[list[V]] | Field[Set[V]] | Field[Tuple[V, ...]], + self: Field[list[V]] | Field[set[V]] | Field[Tuple[V, ...]], instance: None, owner: Any, ) -> ArrayVar[list[V]]: ... diff --git a/tests/units/components/core/test_foreach.py b/tests/units/components/core/test_foreach.py index 1e64a426f..c86e636c6 100644 --- a/tests/units/components/core/test_foreach.py +++ b/tests/units/components/core/test_foreach.py @@ -1,4 +1,4 @@ -from typing import Set, Tuple, Union +from typing import Tuple, Union import pydantic.v1 import pytest @@ -50,7 +50,7 @@ class ForEachState(BaseState): "red", "yellow", ) - colors_set: Set[str] = {"red", "green"} + colors_set: set[str] = {"red", "green"} bad_annotation_list: list = [["red", "orange"], ["yellow", "blue"]] color_index_tuple: Tuple[int, str] = (0, "red") diff --git a/tests/units/states/mutation.py b/tests/units/states/mutation.py index 40811a452..fda3c9a0e 100644 --- a/tests/units/states/mutation.py +++ b/tests/units/states/mutation.py @@ -1,6 +1,6 @@ """Test states for mutable vars.""" -from typing import Dict, List, Set, Union +from typing import Dict, List, Union from sqlalchemy import ARRAY, JSON, String from sqlalchemy.orm import DeclarativeBase, Mapped, mapped_column @@ -145,7 +145,7 @@ class CustomVar(rx.Base): foo: str = "" array: list[str] = [] hashmap: dict[str, str] = {} - test_set: Set[str] = set() + test_set: set[str] = set() custom: OtherBase = OtherBase() @@ -163,7 +163,7 @@ class MutableSQLAModel(MutableSQLABase): id: Mapped[int] = mapped_column(primary_key=True, autoincrement=True) strlist: Mapped[list[str]] = mapped_column(ARRAY(String)) hashmap: Mapped[dict[str, str]] = mapped_column(JSON) - test_set: Mapped[Set[str]] = mapped_column(ARRAY(String)) + test_set: Mapped[set[str]] = mapped_column(ARRAY(String)) @serializer @@ -194,7 +194,7 @@ class MutableTestState(BaseState): "another_key": "another_value", "third_key": {"key": "value"}, } - test_set: Set[Union[str, int]] = {1, 2, 3, 4, "five"} + test_set: set[Union[str, int]] = {1, 2, 3, 4, "five"} custom: CustomVar = CustomVar() _be_custom: CustomVar = CustomVar() sqla_model: MutableSQLAModel = MutableSQLAModel( diff --git a/tests/units/test_var.py b/tests/units/test_var.py index a0237d1d1..3558abfcd 100644 --- a/tests/units/test_var.py +++ b/tests/units/test_var.py @@ -1,7 +1,7 @@ import json import math import typing -from typing import List, Mapping, Optional, Set, Tuple, Union, cast +from typing import List, Mapping, Optional, Tuple, Union, cast import pytest from pandas import DataFrame @@ -594,7 +594,7 @@ def test_computed_var_replace_with_invalid_kwargs(): ), ( Var(_js_expr="lst", _var_type=list[int]).guess_type(), - Var(_js_expr="set_var", _var_type=Set[str]).guess_type(), + Var(_js_expr="set_var", _var_type=set[str]).guess_type(), ), ( Var(_js_expr="lst", _var_type=list[int]).guess_type(), @@ -714,7 +714,7 @@ def test_dict_indexing(): ), ( Var(_js_expr="lst", _var_type=dict[str, str]).guess_type(), - Var(_js_expr="set_var", _var_type=Set[str]).guess_type(), + Var(_js_expr="set_var", _var_type=set[str]).guess_type(), ), ( Var(_js_expr="lst", _var_type=dict[str, str]).guess_type(), @@ -745,7 +745,7 @@ def test_dict_indexing(): ), ( Var(_js_expr="df", _var_type=DataFrame).guess_type(), - Var(_js_expr="set_var", _var_type=Set[str]).guess_type(), + Var(_js_expr="set_var", _var_type=set[str]).guess_type(), ), ( Var(_js_expr="df", _var_type=DataFrame).guess_type(), @@ -1813,7 +1813,7 @@ def cv_fget(state: BaseState) -> int: ([ComputedVar(fget=cv_fget)], {None: {"cv_fget"}}), ], ) -def test_computed_var_deps(deps: list[Union[str, Var]], expected: Set[str]): +def test_computed_var_deps(deps: list[Union[str, Var]], expected: set[str]): @computed_var(deps=deps) def test_var(state) -> int: return 1