Make better/less use of dict.keys() calls (#3455)

This commit is contained in:
Alexander Morgan 2024-06-07 23:28:44 +02:00 committed by GitHub
parent bb44d51f2f
commit ad3134413b
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
12 changed files with 34 additions and 34 deletions

View File

@ -66,7 +66,8 @@ class State(State):
del self.chats[self.current_chat] del self.chats[self.current_chat]
if len(self.chats) == 0: if len(self.chats) == 0:
self.chats = DEFAULT_CHATS self.chats = DEFAULT_CHATS
self.current_chat = list(self.chats.keys())[0] # set self.current_chat to the first chat.
self.current_chat = next(iter(self.chats))
self.toggle_drawer() self.toggle_drawer()
def set_chat(self, chat_name: str): def set_chat(self, chat_name: str):
@ -85,7 +86,7 @@ class State(State):
Returns: Returns:
The list of chat names. The list of chat names.
""" """
return list(self.chats.keys()) return [*self.chats]
async def process_question(self, form_data: dict[str, str]): async def process_question(self, form_data: dict[str, str]):
"""Get the response from the API. """Get the response from the API.

View File

@ -707,11 +707,8 @@ class App(LifespanMixin, Base):
page_imports = { page_imports = {
i i
for i, tags in imports.items() for i, tags in imports.items()
if i if i not in constants.PackageJson.DEPENDENCIES
not in [ and i not in constants.PackageJson.DEV_DEPENDENCIES
*constants.PackageJson.DEPENDENCIES.keys(),
*constants.PackageJson.DEV_DEPENDENCIES.keys(),
]
and not any(i.startswith(prefix) for prefix in ["/", ".", "next/"]) and not any(i.startswith(prefix) for prefix in ["/", ".", "next/"])
and i != "" and i != ""
and any(tag.install for tag in tags) and any(tag.install for tag in tags)

View File

@ -360,7 +360,6 @@ class Component(BaseComponent, ABC):
# Get the component fields, triggers, and props. # Get the component fields, triggers, and props.
fields = self.get_fields() fields = self.get_fields()
component_specific_triggers = self.get_event_triggers() component_specific_triggers = self.get_event_triggers()
triggers = component_specific_triggers.keys()
props = self.get_props() props = self.get_props()
# Add any events triggers. # Add any events triggers.
@ -370,13 +369,17 @@ class Component(BaseComponent, ABC):
# Iterate through the kwargs and set the props. # Iterate through the kwargs and set the props.
for key, value in kwargs.items(): for key, value in kwargs.items():
if key.startswith("on_") and key not in triggers and key not in props: if (
key.startswith("on_")
and key not in component_specific_triggers
and key not in props
):
raise ValueError( raise ValueError(
f"The {(comp_name := type(self).__name__)} does not take in an `{key}` event trigger. If {comp_name}" f"The {(comp_name := type(self).__name__)} does not take in an `{key}` event trigger. If {comp_name}"
f" is a third party component make sure to add `{key}` to the component's event triggers. " f" is a third party component make sure to add `{key}` to the component's event triggers. "
f"visit https://reflex.dev/docs/wrapping-react/guide/#event-triggers for more info." f"visit https://reflex.dev/docs/wrapping-react/guide/#event-triggers for more info."
) )
if key in triggers: if key in component_specific_triggers:
# Event triggers are bound to event chains. # Event triggers are bound to event chains.
field_type = EventChain field_type = EventChain
elif key in props: elif key in props:
@ -436,7 +439,7 @@ class Component(BaseComponent, ABC):
) )
# Check if the key is an event trigger. # Check if the key is an event trigger.
if key in triggers: if key in component_specific_triggers:
# Temporarily disable full control for event triggers. # Temporarily disable full control for event triggers.
kwargs["event_triggers"][key] = self._create_event_chain( kwargs["event_triggers"][key] = self._create_event_chain(
value=value, args_spec=component_specific_triggers[key] value=value, args_spec=component_specific_triggers[key]

View File

@ -424,7 +424,7 @@ def _generate_component_create_functiondef(
), ),
ast.Constant(value=None), ast.Constant(value=None),
) )
for trigger in sorted(clz().get_event_triggers().keys()) for trigger in sorted(clz().get_event_triggers())
) )
logger.debug(f"Generated {clz.__name__}.create method with {len(kwargs)} kwargs") logger.debug(f"Generated {clz.__name__}.create method with {len(kwargs)} kwargs")
create_args = ast.arguments( create_args = ast.arguments(

View File

@ -509,7 +509,7 @@ def validate_parameter_literals(func):
annotations = {param[0]: param[1].annotation for param in func_params} annotations = {param[0]: param[1].annotation for param in func_params}
# validate args # validate args
for param, arg in zip(annotations.keys(), args): for param, arg in zip(annotations, args):
if annotations[param] is inspect.Parameter.empty: if annotations[param] is inspect.Parameter.empty:
continue continue
validate_literal(param, arg, annotations[param], func.__name__) validate_literal(param, arg, annotations[param], func.__name__)

View File

@ -10,19 +10,19 @@ from reflex.components.radix.themes.typography.text import Text
def test_websocket_target_url(): def test_websocket_target_url():
url = WebsocketTargetURL.create() url = WebsocketTargetURL.create()
_imports = url._get_all_imports(collapse=True) _imports = url._get_all_imports(collapse=True)
assert list(_imports.keys()) == ["/utils/state", "/env.json"] assert tuple(_imports) == ("/utils/state", "/env.json")
def test_connection_banner(): def test_connection_banner():
banner = ConnectionBanner.create() banner = ConnectionBanner.create()
_imports = banner._get_all_imports(collapse=True) _imports = banner._get_all_imports(collapse=True)
assert list(_imports.keys()) == [ assert tuple(_imports) == (
"react", "react",
"/utils/context", "/utils/context",
"/utils/state", "/utils/state",
"@radix-ui/themes@^3.0.0", "@radix-ui/themes@^3.0.0",
"/env.json", "/env.json",
] )
msg = "Connection error" msg = "Connection error"
custom_banner = ConnectionBanner.create(Text.create(msg)) custom_banner = ConnectionBanner.create(Text.create(msg))
@ -32,13 +32,13 @@ def test_connection_banner():
def test_connection_modal(): def test_connection_modal():
modal = ConnectionModal.create() modal = ConnectionModal.create()
_imports = modal._get_all_imports(collapse=True) _imports = modal._get_all_imports(collapse=True)
assert list(_imports.keys()) == [ assert tuple(_imports) == (
"react", "react",
"/utils/context", "/utils/context",
"/utils/state", "/utils/state",
"@radix-ui/themes@^3.0.0", "@radix-ui/themes@^3.0.0",
"/env.json", "/env.json",
] )
msg = "Connection error" msg = "Connection error"
custom_modal = ConnectionModal.create(Text.create(msg)) custom_modal = ConnectionModal.create(Text.create(msg))

View File

@ -98,11 +98,10 @@ def test_event_triggers():
on_change=S.on_change, on_change=S.on_change,
) )
) )
default_event_triggers = list(rx.Component().get_event_triggers().keys()) assert tuple(debounced_input.get_event_triggers()) == (
assert list(debounced_input.get_event_triggers().keys()) == [ *rx.Component().get_event_triggers(), # default event triggers
*default_event_triggers,
"on_change", "on_change",
] )
def test_render_child_props_recursive(): def test_render_child_props_recursive():

View File

@ -114,4 +114,4 @@ def test_serialize_dataframe():
value = serialize(df) value = serialize(df)
assert value == serialize_dataframe(df) assert value == serialize_dataframe(df)
assert isinstance(value, dict) assert isinstance(value, dict)
assert list(value.keys()) == ["columns", "data"] assert tuple(value) == ("columns", "data")

View File

@ -566,7 +566,7 @@ def test_get_event_triggers(component1, component2):
EventTriggers.ON_MOUNT, EventTriggers.ON_MOUNT,
EventTriggers.ON_UNMOUNT, EventTriggers.ON_UNMOUNT,
} }
assert set(component1().get_event_triggers().keys()) == default_triggers assert component1().get_event_triggers().keys() == default_triggers
assert ( assert (
component2().get_event_triggers().keys() component2().get_event_triggers().keys()
== {"on_open", "on_close"} | default_triggers == {"on_open", "on_close"} | default_triggers

View File

@ -235,9 +235,9 @@ def test_add_page_default_route(app: App, index_page, about_page):
""" """
assert app.pages == {} assert app.pages == {}
app.add_page(index_page) app.add_page(index_page)
assert set(app.pages.keys()) == {"index"} assert app.pages.keys() == {"index"}
app.add_page(about_page) app.add_page(about_page)
assert set(app.pages.keys()) == {"index", "about"} assert app.pages.keys() == {"index", "about"}
def test_add_page_set_route(app: App, index_page, windows_platform: bool): def test_add_page_set_route(app: App, index_page, windows_platform: bool):
@ -251,7 +251,7 @@ def test_add_page_set_route(app: App, index_page, windows_platform: bool):
route = "test" if windows_platform else "/test" route = "test" if windows_platform else "/test"
assert app.pages == {} assert app.pages == {}
app.add_page(index_page, route=route) app.add_page(index_page, route=route)
assert set(app.pages.keys()) == {"test"} assert app.pages.keys() == {"test"}
def test_add_page_set_route_dynamic(index_page, windows_platform: bool): def test_add_page_set_route_dynamic(index_page, windows_platform: bool):
@ -268,7 +268,7 @@ def test_add_page_set_route_dynamic(index_page, windows_platform: bool):
route.lstrip("/").replace("/", "\\") route.lstrip("/").replace("/", "\\")
assert app.pages == {} assert app.pages == {}
app.add_page(index_page, route=route) app.add_page(index_page, route=route)
assert set(app.pages.keys()) == {"test/[dynamic]"} assert app.pages.keys() == {"test/[dynamic]"}
assert "dynamic" in app.state.computed_vars assert "dynamic" in app.state.computed_vars
assert app.state.computed_vars["dynamic"]._deps(objclass=EmptyState) == { assert app.state.computed_vars["dynamic"]._deps(objclass=EmptyState) == {
constants.ROUTER constants.ROUTER
@ -287,7 +287,7 @@ def test_add_page_set_route_nested(app: App, index_page, windows_platform: bool)
route = "test\\nested" if windows_platform else "/test/nested" route = "test\\nested" if windows_platform else "/test/nested"
assert app.pages == {} assert app.pages == {}
app.add_page(index_page, route=route) app.add_page(index_page, route=route)
assert set(app.pages.keys()) == {route.strip(os.path.sep)} assert app.pages.keys() == {route.strip(os.path.sep)}
def test_add_page_invalid_api_route(app: App, index_page): def test_add_page_invalid_api_route(app: App, index_page):

View File

@ -287,7 +287,7 @@ def test_class_vars(test_state):
test_state: A state. test_state: A state.
""" """
cls = type(test_state) cls = type(test_state)
assert set(cls.vars.keys()) == { assert cls.vars.keys() == {
"router", "router",
"num1", "num1",
"num2", "num2",
@ -310,7 +310,7 @@ def test_event_handlers(test_state):
Args: Args:
test_state: A state. test_state: A state.
""" """
expected = { expected_keys = (
"do_something", "do_something",
"set_array", "set_array",
"set_complex", "set_complex",
@ -320,10 +320,10 @@ def test_event_handlers(test_state):
"set_num1", "set_num1",
"set_num2", "set_num2",
"set_obj", "set_obj",
} )
cls = type(test_state) cls = type(test_state)
assert set(cls.event_handlers.keys()).intersection(expected) == expected assert all(key in cls.event_handlers for key in expected_keys)
def test_default_value(test_state): def test_default_value(test_state):

View File

@ -72,7 +72,7 @@ def test_merge_imports(input_1, input_2, output):
""" """
res = merge_imports(input_1, input_2) res = merge_imports(input_1, input_2)
assert set(res.keys()) == set(output.keys()) assert res.keys() == output.keys()
for key in output: for key in output:
assert set(res[key]) == set(output[key]) assert set(res[key]) == set(output[key])