Make better/less use of dict.keys() calls (#3455)
This commit is contained in:
parent
bb44d51f2f
commit
ad3134413b
@ -66,7 +66,8 @@ class State(State):
|
||||
del self.chats[self.current_chat]
|
||||
if len(self.chats) == 0:
|
||||
self.chats = DEFAULT_CHATS
|
||||
self.current_chat = list(self.chats.keys())[0]
|
||||
# set self.current_chat to the first chat.
|
||||
self.current_chat = next(iter(self.chats))
|
||||
self.toggle_drawer()
|
||||
|
||||
def set_chat(self, chat_name: str):
|
||||
@ -85,7 +86,7 @@ class State(State):
|
||||
Returns:
|
||||
The list of chat names.
|
||||
"""
|
||||
return list(self.chats.keys())
|
||||
return [*self.chats]
|
||||
|
||||
async def process_question(self, form_data: dict[str, str]):
|
||||
"""Get the response from the API.
|
||||
|
@ -707,11 +707,8 @@ class App(LifespanMixin, Base):
|
||||
page_imports = {
|
||||
i
|
||||
for i, tags in imports.items()
|
||||
if i
|
||||
not in [
|
||||
*constants.PackageJson.DEPENDENCIES.keys(),
|
||||
*constants.PackageJson.DEV_DEPENDENCIES.keys(),
|
||||
]
|
||||
if i not in constants.PackageJson.DEPENDENCIES
|
||||
and i not in constants.PackageJson.DEV_DEPENDENCIES
|
||||
and not any(i.startswith(prefix) for prefix in ["/", ".", "next/"])
|
||||
and i != ""
|
||||
and any(tag.install for tag in tags)
|
||||
|
@ -360,7 +360,6 @@ class Component(BaseComponent, ABC):
|
||||
# Get the component fields, triggers, and props.
|
||||
fields = self.get_fields()
|
||||
component_specific_triggers = self.get_event_triggers()
|
||||
triggers = component_specific_triggers.keys()
|
||||
props = self.get_props()
|
||||
|
||||
# Add any events triggers.
|
||||
@ -370,13 +369,17 @@ class Component(BaseComponent, ABC):
|
||||
|
||||
# Iterate through the kwargs and set the props.
|
||||
for key, value in kwargs.items():
|
||||
if key.startswith("on_") and key not in triggers and key not in props:
|
||||
if (
|
||||
key.startswith("on_")
|
||||
and key not in component_specific_triggers
|
||||
and key not in props
|
||||
):
|
||||
raise ValueError(
|
||||
f"The {(comp_name := type(self).__name__)} does not take in an `{key}` event trigger. If {comp_name}"
|
||||
f" is a third party component make sure to add `{key}` to the component's event triggers. "
|
||||
f"visit https://reflex.dev/docs/wrapping-react/guide/#event-triggers for more info."
|
||||
)
|
||||
if key in triggers:
|
||||
if key in component_specific_triggers:
|
||||
# Event triggers are bound to event chains.
|
||||
field_type = EventChain
|
||||
elif key in props:
|
||||
@ -436,7 +439,7 @@ class Component(BaseComponent, ABC):
|
||||
)
|
||||
|
||||
# Check if the key is an event trigger.
|
||||
if key in triggers:
|
||||
if key in component_specific_triggers:
|
||||
# Temporarily disable full control for event triggers.
|
||||
kwargs["event_triggers"][key] = self._create_event_chain(
|
||||
value=value, args_spec=component_specific_triggers[key]
|
||||
|
@ -424,7 +424,7 @@ def _generate_component_create_functiondef(
|
||||
),
|
||||
ast.Constant(value=None),
|
||||
)
|
||||
for trigger in sorted(clz().get_event_triggers().keys())
|
||||
for trigger in sorted(clz().get_event_triggers())
|
||||
)
|
||||
logger.debug(f"Generated {clz.__name__}.create method with {len(kwargs)} kwargs")
|
||||
create_args = ast.arguments(
|
||||
|
@ -509,7 +509,7 @@ def validate_parameter_literals(func):
|
||||
annotations = {param[0]: param[1].annotation for param in func_params}
|
||||
|
||||
# validate args
|
||||
for param, arg in zip(annotations.keys(), args):
|
||||
for param, arg in zip(annotations, args):
|
||||
if annotations[param] is inspect.Parameter.empty:
|
||||
continue
|
||||
validate_literal(param, arg, annotations[param], func.__name__)
|
||||
|
@ -10,19 +10,19 @@ from reflex.components.radix.themes.typography.text import Text
|
||||
def test_websocket_target_url():
|
||||
url = WebsocketTargetURL.create()
|
||||
_imports = url._get_all_imports(collapse=True)
|
||||
assert list(_imports.keys()) == ["/utils/state", "/env.json"]
|
||||
assert tuple(_imports) == ("/utils/state", "/env.json")
|
||||
|
||||
|
||||
def test_connection_banner():
|
||||
banner = ConnectionBanner.create()
|
||||
_imports = banner._get_all_imports(collapse=True)
|
||||
assert list(_imports.keys()) == [
|
||||
assert tuple(_imports) == (
|
||||
"react",
|
||||
"/utils/context",
|
||||
"/utils/state",
|
||||
"@radix-ui/themes@^3.0.0",
|
||||
"/env.json",
|
||||
]
|
||||
)
|
||||
|
||||
msg = "Connection error"
|
||||
custom_banner = ConnectionBanner.create(Text.create(msg))
|
||||
@ -32,13 +32,13 @@ def test_connection_banner():
|
||||
def test_connection_modal():
|
||||
modal = ConnectionModal.create()
|
||||
_imports = modal._get_all_imports(collapse=True)
|
||||
assert list(_imports.keys()) == [
|
||||
assert tuple(_imports) == (
|
||||
"react",
|
||||
"/utils/context",
|
||||
"/utils/state",
|
||||
"@radix-ui/themes@^3.0.0",
|
||||
"/env.json",
|
||||
]
|
||||
)
|
||||
|
||||
msg = "Connection error"
|
||||
custom_modal = ConnectionModal.create(Text.create(msg))
|
||||
|
@ -98,11 +98,10 @@ def test_event_triggers():
|
||||
on_change=S.on_change,
|
||||
)
|
||||
)
|
||||
default_event_triggers = list(rx.Component().get_event_triggers().keys())
|
||||
assert list(debounced_input.get_event_triggers().keys()) == [
|
||||
*default_event_triggers,
|
||||
assert tuple(debounced_input.get_event_triggers()) == (
|
||||
*rx.Component().get_event_triggers(), # default event triggers
|
||||
"on_change",
|
||||
]
|
||||
)
|
||||
|
||||
|
||||
def test_render_child_props_recursive():
|
||||
|
@ -114,4 +114,4 @@ def test_serialize_dataframe():
|
||||
value = serialize(df)
|
||||
assert value == serialize_dataframe(df)
|
||||
assert isinstance(value, dict)
|
||||
assert list(value.keys()) == ["columns", "data"]
|
||||
assert tuple(value) == ("columns", "data")
|
||||
|
@ -566,7 +566,7 @@ def test_get_event_triggers(component1, component2):
|
||||
EventTriggers.ON_MOUNT,
|
||||
EventTriggers.ON_UNMOUNT,
|
||||
}
|
||||
assert set(component1().get_event_triggers().keys()) == default_triggers
|
||||
assert component1().get_event_triggers().keys() == default_triggers
|
||||
assert (
|
||||
component2().get_event_triggers().keys()
|
||||
== {"on_open", "on_close"} | default_triggers
|
||||
|
@ -235,9 +235,9 @@ def test_add_page_default_route(app: App, index_page, about_page):
|
||||
"""
|
||||
assert app.pages == {}
|
||||
app.add_page(index_page)
|
||||
assert set(app.pages.keys()) == {"index"}
|
||||
assert app.pages.keys() == {"index"}
|
||||
app.add_page(about_page)
|
||||
assert set(app.pages.keys()) == {"index", "about"}
|
||||
assert app.pages.keys() == {"index", "about"}
|
||||
|
||||
|
||||
def test_add_page_set_route(app: App, index_page, windows_platform: bool):
|
||||
@ -251,7 +251,7 @@ def test_add_page_set_route(app: App, index_page, windows_platform: bool):
|
||||
route = "test" if windows_platform else "/test"
|
||||
assert app.pages == {}
|
||||
app.add_page(index_page, route=route)
|
||||
assert set(app.pages.keys()) == {"test"}
|
||||
assert app.pages.keys() == {"test"}
|
||||
|
||||
|
||||
def test_add_page_set_route_dynamic(index_page, windows_platform: bool):
|
||||
@ -268,7 +268,7 @@ def test_add_page_set_route_dynamic(index_page, windows_platform: bool):
|
||||
route.lstrip("/").replace("/", "\\")
|
||||
assert app.pages == {}
|
||||
app.add_page(index_page, route=route)
|
||||
assert set(app.pages.keys()) == {"test/[dynamic]"}
|
||||
assert app.pages.keys() == {"test/[dynamic]"}
|
||||
assert "dynamic" in app.state.computed_vars
|
||||
assert app.state.computed_vars["dynamic"]._deps(objclass=EmptyState) == {
|
||||
constants.ROUTER
|
||||
@ -287,7 +287,7 @@ def test_add_page_set_route_nested(app: App, index_page, windows_platform: bool)
|
||||
route = "test\\nested" if windows_platform else "/test/nested"
|
||||
assert app.pages == {}
|
||||
app.add_page(index_page, route=route)
|
||||
assert set(app.pages.keys()) == {route.strip(os.path.sep)}
|
||||
assert app.pages.keys() == {route.strip(os.path.sep)}
|
||||
|
||||
|
||||
def test_add_page_invalid_api_route(app: App, index_page):
|
||||
|
@ -287,7 +287,7 @@ def test_class_vars(test_state):
|
||||
test_state: A state.
|
||||
"""
|
||||
cls = type(test_state)
|
||||
assert set(cls.vars.keys()) == {
|
||||
assert cls.vars.keys() == {
|
||||
"router",
|
||||
"num1",
|
||||
"num2",
|
||||
@ -310,7 +310,7 @@ def test_event_handlers(test_state):
|
||||
Args:
|
||||
test_state: A state.
|
||||
"""
|
||||
expected = {
|
||||
expected_keys = (
|
||||
"do_something",
|
||||
"set_array",
|
||||
"set_complex",
|
||||
@ -320,10 +320,10 @@ def test_event_handlers(test_state):
|
||||
"set_num1",
|
||||
"set_num2",
|
||||
"set_obj",
|
||||
}
|
||||
)
|
||||
|
||||
cls = type(test_state)
|
||||
assert set(cls.event_handlers.keys()).intersection(expected) == expected
|
||||
assert all(key in cls.event_handlers for key in expected_keys)
|
||||
|
||||
|
||||
def test_default_value(test_state):
|
||||
|
@ -72,7 +72,7 @@ def test_merge_imports(input_1, input_2, output):
|
||||
|
||||
"""
|
||||
res = merge_imports(input_1, input_2)
|
||||
assert set(res.keys()) == set(output.keys())
|
||||
assert res.keys() == output.keys()
|
||||
|
||||
for key in output:
|
||||
assert set(res[key]) == set(output[key])
|
||||
|
Loading…
Reference in New Issue
Block a user