From ad3134413b17c7a6bdf39bb0e2d4bbcc0e9e9524 Mon Sep 17 00:00:00 2001 From: Alexander Morgan Date: Fri, 7 Jun 2024 23:28:44 +0200 Subject: [PATCH] Make better/less use of dict.keys() calls (#3455) --- reflex/.templates/apps/demo/code/webui/state.py | 5 +++-- reflex/app.py | 7 ++----- reflex/components/component.py | 11 +++++++---- reflex/utils/pyi_generator.py | 2 +- reflex/utils/types.py | 2 +- tests/components/core/test_banner.py | 10 +++++----- tests/components/core/test_debounce.py | 7 +++---- tests/components/datadisplay/test_datatable.py | 2 +- tests/components/test_component.py | 2 +- tests/test_app.py | 10 +++++----- tests/test_state.py | 8 ++++---- tests/utils/test_imports.py | 2 +- 12 files changed, 34 insertions(+), 34 deletions(-) diff --git a/reflex/.templates/apps/demo/code/webui/state.py b/reflex/.templates/apps/demo/code/webui/state.py index 4956e6f59..51739222d 100644 --- a/reflex/.templates/apps/demo/code/webui/state.py +++ b/reflex/.templates/apps/demo/code/webui/state.py @@ -66,7 +66,8 @@ class State(State): del self.chats[self.current_chat] if len(self.chats) == 0: self.chats = DEFAULT_CHATS - self.current_chat = list(self.chats.keys())[0] + # set self.current_chat to the first chat. + self.current_chat = next(iter(self.chats)) self.toggle_drawer() def set_chat(self, chat_name: str): @@ -85,7 +86,7 @@ class State(State): Returns: The list of chat names. """ - return list(self.chats.keys()) + return [*self.chats] async def process_question(self, form_data: dict[str, str]): """Get the response from the API. diff --git a/reflex/app.py b/reflex/app.py index 89d68c1db..65bc32438 100644 --- a/reflex/app.py +++ b/reflex/app.py @@ -707,11 +707,8 @@ class App(LifespanMixin, Base): page_imports = { i for i, tags in imports.items() - if i - not in [ - *constants.PackageJson.DEPENDENCIES.keys(), - *constants.PackageJson.DEV_DEPENDENCIES.keys(), - ] + if i not in constants.PackageJson.DEPENDENCIES + and i not in constants.PackageJson.DEV_DEPENDENCIES and not any(i.startswith(prefix) for prefix in ["/", ".", "next/"]) and i != "" and any(tag.install for tag in tags) diff --git a/reflex/components/component.py b/reflex/components/component.py index e2d8e22c7..516aba929 100644 --- a/reflex/components/component.py +++ b/reflex/components/component.py @@ -360,7 +360,6 @@ class Component(BaseComponent, ABC): # Get the component fields, triggers, and props. fields = self.get_fields() component_specific_triggers = self.get_event_triggers() - triggers = component_specific_triggers.keys() props = self.get_props() # Add any events triggers. @@ -370,13 +369,17 @@ class Component(BaseComponent, ABC): # Iterate through the kwargs and set the props. for key, value in kwargs.items(): - if key.startswith("on_") and key not in triggers and key not in props: + if ( + key.startswith("on_") + and key not in component_specific_triggers + and key not in props + ): raise ValueError( f"The {(comp_name := type(self).__name__)} does not take in an `{key}` event trigger. If {comp_name}" f" is a third party component make sure to add `{key}` to the component's event triggers. " f"visit https://reflex.dev/docs/wrapping-react/guide/#event-triggers for more info." ) - if key in triggers: + if key in component_specific_triggers: # Event triggers are bound to event chains. field_type = EventChain elif key in props: @@ -436,7 +439,7 @@ class Component(BaseComponent, ABC): ) # Check if the key is an event trigger. - if key in triggers: + if key in component_specific_triggers: # Temporarily disable full control for event triggers. kwargs["event_triggers"][key] = self._create_event_chain( value=value, args_spec=component_specific_triggers[key] diff --git a/reflex/utils/pyi_generator.py b/reflex/utils/pyi_generator.py index cc27aac70..30f885f33 100644 --- a/reflex/utils/pyi_generator.py +++ b/reflex/utils/pyi_generator.py @@ -424,7 +424,7 @@ def _generate_component_create_functiondef( ), ast.Constant(value=None), ) - for trigger in sorted(clz().get_event_triggers().keys()) + for trigger in sorted(clz().get_event_triggers()) ) logger.debug(f"Generated {clz.__name__}.create method with {len(kwargs)} kwargs") create_args = ast.arguments( diff --git a/reflex/utils/types.py b/reflex/utils/types.py index f35765677..6c759bf32 100644 --- a/reflex/utils/types.py +++ b/reflex/utils/types.py @@ -509,7 +509,7 @@ def validate_parameter_literals(func): annotations = {param[0]: param[1].annotation for param in func_params} # validate args - for param, arg in zip(annotations.keys(), args): + for param, arg in zip(annotations, args): if annotations[param] is inspect.Parameter.empty: continue validate_literal(param, arg, annotations[param], func.__name__) diff --git a/tests/components/core/test_banner.py b/tests/components/core/test_banner.py index f929eef37..184a65cc9 100644 --- a/tests/components/core/test_banner.py +++ b/tests/components/core/test_banner.py @@ -10,19 +10,19 @@ from reflex.components.radix.themes.typography.text import Text def test_websocket_target_url(): url = WebsocketTargetURL.create() _imports = url._get_all_imports(collapse=True) - assert list(_imports.keys()) == ["/utils/state", "/env.json"] + assert tuple(_imports) == ("/utils/state", "/env.json") def test_connection_banner(): banner = ConnectionBanner.create() _imports = banner._get_all_imports(collapse=True) - assert list(_imports.keys()) == [ + assert tuple(_imports) == ( "react", "/utils/context", "/utils/state", "@radix-ui/themes@^3.0.0", "/env.json", - ] + ) msg = "Connection error" custom_banner = ConnectionBanner.create(Text.create(msg)) @@ -32,13 +32,13 @@ def test_connection_banner(): def test_connection_modal(): modal = ConnectionModal.create() _imports = modal._get_all_imports(collapse=True) - assert list(_imports.keys()) == [ + assert tuple(_imports) == ( "react", "/utils/context", "/utils/state", "@radix-ui/themes@^3.0.0", "/env.json", - ] + ) msg = "Connection error" custom_modal = ConnectionModal.create(Text.create(msg)) diff --git a/tests/components/core/test_debounce.py b/tests/components/core/test_debounce.py index 8a8ec394c..e4bcb05c4 100644 --- a/tests/components/core/test_debounce.py +++ b/tests/components/core/test_debounce.py @@ -98,11 +98,10 @@ def test_event_triggers(): on_change=S.on_change, ) ) - default_event_triggers = list(rx.Component().get_event_triggers().keys()) - assert list(debounced_input.get_event_triggers().keys()) == [ - *default_event_triggers, + assert tuple(debounced_input.get_event_triggers()) == ( + *rx.Component().get_event_triggers(), # default event triggers "on_change", - ] + ) def test_render_child_props_recursive(): diff --git a/tests/components/datadisplay/test_datatable.py b/tests/components/datadisplay/test_datatable.py index c755064ce..2557be62b 100644 --- a/tests/components/datadisplay/test_datatable.py +++ b/tests/components/datadisplay/test_datatable.py @@ -114,4 +114,4 @@ def test_serialize_dataframe(): value = serialize(df) assert value == serialize_dataframe(df) assert isinstance(value, dict) - assert list(value.keys()) == ["columns", "data"] + assert tuple(value) == ("columns", "data") diff --git a/tests/components/test_component.py b/tests/components/test_component.py index 356b9feae..2e395ce37 100644 --- a/tests/components/test_component.py +++ b/tests/components/test_component.py @@ -566,7 +566,7 @@ def test_get_event_triggers(component1, component2): EventTriggers.ON_MOUNT, EventTriggers.ON_UNMOUNT, } - assert set(component1().get_event_triggers().keys()) == default_triggers + assert component1().get_event_triggers().keys() == default_triggers assert ( component2().get_event_triggers().keys() == {"on_open", "on_close"} | default_triggers diff --git a/tests/test_app.py b/tests/test_app.py index 46079e6da..5709d93a9 100644 --- a/tests/test_app.py +++ b/tests/test_app.py @@ -235,9 +235,9 @@ def test_add_page_default_route(app: App, index_page, about_page): """ assert app.pages == {} app.add_page(index_page) - assert set(app.pages.keys()) == {"index"} + assert app.pages.keys() == {"index"} app.add_page(about_page) - assert set(app.pages.keys()) == {"index", "about"} + assert app.pages.keys() == {"index", "about"} def test_add_page_set_route(app: App, index_page, windows_platform: bool): @@ -251,7 +251,7 @@ def test_add_page_set_route(app: App, index_page, windows_platform: bool): route = "test" if windows_platform else "/test" assert app.pages == {} app.add_page(index_page, route=route) - assert set(app.pages.keys()) == {"test"} + assert app.pages.keys() == {"test"} def test_add_page_set_route_dynamic(index_page, windows_platform: bool): @@ -268,7 +268,7 @@ def test_add_page_set_route_dynamic(index_page, windows_platform: bool): route.lstrip("/").replace("/", "\\") assert app.pages == {} app.add_page(index_page, route=route) - assert set(app.pages.keys()) == {"test/[dynamic]"} + assert app.pages.keys() == {"test/[dynamic]"} assert "dynamic" in app.state.computed_vars assert app.state.computed_vars["dynamic"]._deps(objclass=EmptyState) == { constants.ROUTER @@ -287,7 +287,7 @@ def test_add_page_set_route_nested(app: App, index_page, windows_platform: bool) route = "test\\nested" if windows_platform else "/test/nested" assert app.pages == {} app.add_page(index_page, route=route) - assert set(app.pages.keys()) == {route.strip(os.path.sep)} + assert app.pages.keys() == {route.strip(os.path.sep)} def test_add_page_invalid_api_route(app: App, index_page): diff --git a/tests/test_state.py b/tests/test_state.py index 6fcd1a67f..77f6bc606 100644 --- a/tests/test_state.py +++ b/tests/test_state.py @@ -287,7 +287,7 @@ def test_class_vars(test_state): test_state: A state. """ cls = type(test_state) - assert set(cls.vars.keys()) == { + assert cls.vars.keys() == { "router", "num1", "num2", @@ -310,7 +310,7 @@ def test_event_handlers(test_state): Args: test_state: A state. """ - expected = { + expected_keys = ( "do_something", "set_array", "set_complex", @@ -320,10 +320,10 @@ def test_event_handlers(test_state): "set_num1", "set_num2", "set_obj", - } + ) cls = type(test_state) - assert set(cls.event_handlers.keys()).intersection(expected) == expected + assert all(key in cls.event_handlers for key in expected_keys) def test_default_value(test_state): diff --git a/tests/utils/test_imports.py b/tests/utils/test_imports.py index c7253ff6b..9a5537136 100644 --- a/tests/utils/test_imports.py +++ b/tests/utils/test_imports.py @@ -72,7 +72,7 @@ def test_merge_imports(input_1, input_2, output): """ res = merge_imports(input_1, input_2) - assert set(res.keys()) == set(output.keys()) + assert res.keys() == output.keys() for key in output: assert set(res[key]) == set(output[key])