Make better/less use of dict.keys() calls (#3455)
This commit is contained in:
parent
bb44d51f2f
commit
ad3134413b
@ -66,7 +66,8 @@ class State(State):
|
|||||||
del self.chats[self.current_chat]
|
del self.chats[self.current_chat]
|
||||||
if len(self.chats) == 0:
|
if len(self.chats) == 0:
|
||||||
self.chats = DEFAULT_CHATS
|
self.chats = DEFAULT_CHATS
|
||||||
self.current_chat = list(self.chats.keys())[0]
|
# set self.current_chat to the first chat.
|
||||||
|
self.current_chat = next(iter(self.chats))
|
||||||
self.toggle_drawer()
|
self.toggle_drawer()
|
||||||
|
|
||||||
def set_chat(self, chat_name: str):
|
def set_chat(self, chat_name: str):
|
||||||
@ -85,7 +86,7 @@ class State(State):
|
|||||||
Returns:
|
Returns:
|
||||||
The list of chat names.
|
The list of chat names.
|
||||||
"""
|
"""
|
||||||
return list(self.chats.keys())
|
return [*self.chats]
|
||||||
|
|
||||||
async def process_question(self, form_data: dict[str, str]):
|
async def process_question(self, form_data: dict[str, str]):
|
||||||
"""Get the response from the API.
|
"""Get the response from the API.
|
||||||
|
@ -707,11 +707,8 @@ class App(LifespanMixin, Base):
|
|||||||
page_imports = {
|
page_imports = {
|
||||||
i
|
i
|
||||||
for i, tags in imports.items()
|
for i, tags in imports.items()
|
||||||
if i
|
if i not in constants.PackageJson.DEPENDENCIES
|
||||||
not in [
|
and i not in constants.PackageJson.DEV_DEPENDENCIES
|
||||||
*constants.PackageJson.DEPENDENCIES.keys(),
|
|
||||||
*constants.PackageJson.DEV_DEPENDENCIES.keys(),
|
|
||||||
]
|
|
||||||
and not any(i.startswith(prefix) for prefix in ["/", ".", "next/"])
|
and not any(i.startswith(prefix) for prefix in ["/", ".", "next/"])
|
||||||
and i != ""
|
and i != ""
|
||||||
and any(tag.install for tag in tags)
|
and any(tag.install for tag in tags)
|
||||||
|
@ -360,7 +360,6 @@ class Component(BaseComponent, ABC):
|
|||||||
# Get the component fields, triggers, and props.
|
# Get the component fields, triggers, and props.
|
||||||
fields = self.get_fields()
|
fields = self.get_fields()
|
||||||
component_specific_triggers = self.get_event_triggers()
|
component_specific_triggers = self.get_event_triggers()
|
||||||
triggers = component_specific_triggers.keys()
|
|
||||||
props = self.get_props()
|
props = self.get_props()
|
||||||
|
|
||||||
# Add any events triggers.
|
# Add any events triggers.
|
||||||
@ -370,13 +369,17 @@ class Component(BaseComponent, ABC):
|
|||||||
|
|
||||||
# Iterate through the kwargs and set the props.
|
# Iterate through the kwargs and set the props.
|
||||||
for key, value in kwargs.items():
|
for key, value in kwargs.items():
|
||||||
if key.startswith("on_") and key not in triggers and key not in props:
|
if (
|
||||||
|
key.startswith("on_")
|
||||||
|
and key not in component_specific_triggers
|
||||||
|
and key not in props
|
||||||
|
):
|
||||||
raise ValueError(
|
raise ValueError(
|
||||||
f"The {(comp_name := type(self).__name__)} does not take in an `{key}` event trigger. If {comp_name}"
|
f"The {(comp_name := type(self).__name__)} does not take in an `{key}` event trigger. If {comp_name}"
|
||||||
f" is a third party component make sure to add `{key}` to the component's event triggers. "
|
f" is a third party component make sure to add `{key}` to the component's event triggers. "
|
||||||
f"visit https://reflex.dev/docs/wrapping-react/guide/#event-triggers for more info."
|
f"visit https://reflex.dev/docs/wrapping-react/guide/#event-triggers for more info."
|
||||||
)
|
)
|
||||||
if key in triggers:
|
if key in component_specific_triggers:
|
||||||
# Event triggers are bound to event chains.
|
# Event triggers are bound to event chains.
|
||||||
field_type = EventChain
|
field_type = EventChain
|
||||||
elif key in props:
|
elif key in props:
|
||||||
@ -436,7 +439,7 @@ class Component(BaseComponent, ABC):
|
|||||||
)
|
)
|
||||||
|
|
||||||
# Check if the key is an event trigger.
|
# Check if the key is an event trigger.
|
||||||
if key in triggers:
|
if key in component_specific_triggers:
|
||||||
# Temporarily disable full control for event triggers.
|
# Temporarily disable full control for event triggers.
|
||||||
kwargs["event_triggers"][key] = self._create_event_chain(
|
kwargs["event_triggers"][key] = self._create_event_chain(
|
||||||
value=value, args_spec=component_specific_triggers[key]
|
value=value, args_spec=component_specific_triggers[key]
|
||||||
|
@ -424,7 +424,7 @@ def _generate_component_create_functiondef(
|
|||||||
),
|
),
|
||||||
ast.Constant(value=None),
|
ast.Constant(value=None),
|
||||||
)
|
)
|
||||||
for trigger in sorted(clz().get_event_triggers().keys())
|
for trigger in sorted(clz().get_event_triggers())
|
||||||
)
|
)
|
||||||
logger.debug(f"Generated {clz.__name__}.create method with {len(kwargs)} kwargs")
|
logger.debug(f"Generated {clz.__name__}.create method with {len(kwargs)} kwargs")
|
||||||
create_args = ast.arguments(
|
create_args = ast.arguments(
|
||||||
|
@ -509,7 +509,7 @@ def validate_parameter_literals(func):
|
|||||||
annotations = {param[0]: param[1].annotation for param in func_params}
|
annotations = {param[0]: param[1].annotation for param in func_params}
|
||||||
|
|
||||||
# validate args
|
# validate args
|
||||||
for param, arg in zip(annotations.keys(), args):
|
for param, arg in zip(annotations, args):
|
||||||
if annotations[param] is inspect.Parameter.empty:
|
if annotations[param] is inspect.Parameter.empty:
|
||||||
continue
|
continue
|
||||||
validate_literal(param, arg, annotations[param], func.__name__)
|
validate_literal(param, arg, annotations[param], func.__name__)
|
||||||
|
@ -10,19 +10,19 @@ from reflex.components.radix.themes.typography.text import Text
|
|||||||
def test_websocket_target_url():
|
def test_websocket_target_url():
|
||||||
url = WebsocketTargetURL.create()
|
url = WebsocketTargetURL.create()
|
||||||
_imports = url._get_all_imports(collapse=True)
|
_imports = url._get_all_imports(collapse=True)
|
||||||
assert list(_imports.keys()) == ["/utils/state", "/env.json"]
|
assert tuple(_imports) == ("/utils/state", "/env.json")
|
||||||
|
|
||||||
|
|
||||||
def test_connection_banner():
|
def test_connection_banner():
|
||||||
banner = ConnectionBanner.create()
|
banner = ConnectionBanner.create()
|
||||||
_imports = banner._get_all_imports(collapse=True)
|
_imports = banner._get_all_imports(collapse=True)
|
||||||
assert list(_imports.keys()) == [
|
assert tuple(_imports) == (
|
||||||
"react",
|
"react",
|
||||||
"/utils/context",
|
"/utils/context",
|
||||||
"/utils/state",
|
"/utils/state",
|
||||||
"@radix-ui/themes@^3.0.0",
|
"@radix-ui/themes@^3.0.0",
|
||||||
"/env.json",
|
"/env.json",
|
||||||
]
|
)
|
||||||
|
|
||||||
msg = "Connection error"
|
msg = "Connection error"
|
||||||
custom_banner = ConnectionBanner.create(Text.create(msg))
|
custom_banner = ConnectionBanner.create(Text.create(msg))
|
||||||
@ -32,13 +32,13 @@ def test_connection_banner():
|
|||||||
def test_connection_modal():
|
def test_connection_modal():
|
||||||
modal = ConnectionModal.create()
|
modal = ConnectionModal.create()
|
||||||
_imports = modal._get_all_imports(collapse=True)
|
_imports = modal._get_all_imports(collapse=True)
|
||||||
assert list(_imports.keys()) == [
|
assert tuple(_imports) == (
|
||||||
"react",
|
"react",
|
||||||
"/utils/context",
|
"/utils/context",
|
||||||
"/utils/state",
|
"/utils/state",
|
||||||
"@radix-ui/themes@^3.0.0",
|
"@radix-ui/themes@^3.0.0",
|
||||||
"/env.json",
|
"/env.json",
|
||||||
]
|
)
|
||||||
|
|
||||||
msg = "Connection error"
|
msg = "Connection error"
|
||||||
custom_modal = ConnectionModal.create(Text.create(msg))
|
custom_modal = ConnectionModal.create(Text.create(msg))
|
||||||
|
@ -98,11 +98,10 @@ def test_event_triggers():
|
|||||||
on_change=S.on_change,
|
on_change=S.on_change,
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
default_event_triggers = list(rx.Component().get_event_triggers().keys())
|
assert tuple(debounced_input.get_event_triggers()) == (
|
||||||
assert list(debounced_input.get_event_triggers().keys()) == [
|
*rx.Component().get_event_triggers(), # default event triggers
|
||||||
*default_event_triggers,
|
|
||||||
"on_change",
|
"on_change",
|
||||||
]
|
)
|
||||||
|
|
||||||
|
|
||||||
def test_render_child_props_recursive():
|
def test_render_child_props_recursive():
|
||||||
|
@ -114,4 +114,4 @@ def test_serialize_dataframe():
|
|||||||
value = serialize(df)
|
value = serialize(df)
|
||||||
assert value == serialize_dataframe(df)
|
assert value == serialize_dataframe(df)
|
||||||
assert isinstance(value, dict)
|
assert isinstance(value, dict)
|
||||||
assert list(value.keys()) == ["columns", "data"]
|
assert tuple(value) == ("columns", "data")
|
||||||
|
@ -566,7 +566,7 @@ def test_get_event_triggers(component1, component2):
|
|||||||
EventTriggers.ON_MOUNT,
|
EventTriggers.ON_MOUNT,
|
||||||
EventTriggers.ON_UNMOUNT,
|
EventTriggers.ON_UNMOUNT,
|
||||||
}
|
}
|
||||||
assert set(component1().get_event_triggers().keys()) == default_triggers
|
assert component1().get_event_triggers().keys() == default_triggers
|
||||||
assert (
|
assert (
|
||||||
component2().get_event_triggers().keys()
|
component2().get_event_triggers().keys()
|
||||||
== {"on_open", "on_close"} | default_triggers
|
== {"on_open", "on_close"} | default_triggers
|
||||||
|
@ -235,9 +235,9 @@ def test_add_page_default_route(app: App, index_page, about_page):
|
|||||||
"""
|
"""
|
||||||
assert app.pages == {}
|
assert app.pages == {}
|
||||||
app.add_page(index_page)
|
app.add_page(index_page)
|
||||||
assert set(app.pages.keys()) == {"index"}
|
assert app.pages.keys() == {"index"}
|
||||||
app.add_page(about_page)
|
app.add_page(about_page)
|
||||||
assert set(app.pages.keys()) == {"index", "about"}
|
assert app.pages.keys() == {"index", "about"}
|
||||||
|
|
||||||
|
|
||||||
def test_add_page_set_route(app: App, index_page, windows_platform: bool):
|
def test_add_page_set_route(app: App, index_page, windows_platform: bool):
|
||||||
@ -251,7 +251,7 @@ def test_add_page_set_route(app: App, index_page, windows_platform: bool):
|
|||||||
route = "test" if windows_platform else "/test"
|
route = "test" if windows_platform else "/test"
|
||||||
assert app.pages == {}
|
assert app.pages == {}
|
||||||
app.add_page(index_page, route=route)
|
app.add_page(index_page, route=route)
|
||||||
assert set(app.pages.keys()) == {"test"}
|
assert app.pages.keys() == {"test"}
|
||||||
|
|
||||||
|
|
||||||
def test_add_page_set_route_dynamic(index_page, windows_platform: bool):
|
def test_add_page_set_route_dynamic(index_page, windows_platform: bool):
|
||||||
@ -268,7 +268,7 @@ def test_add_page_set_route_dynamic(index_page, windows_platform: bool):
|
|||||||
route.lstrip("/").replace("/", "\\")
|
route.lstrip("/").replace("/", "\\")
|
||||||
assert app.pages == {}
|
assert app.pages == {}
|
||||||
app.add_page(index_page, route=route)
|
app.add_page(index_page, route=route)
|
||||||
assert set(app.pages.keys()) == {"test/[dynamic]"}
|
assert app.pages.keys() == {"test/[dynamic]"}
|
||||||
assert "dynamic" in app.state.computed_vars
|
assert "dynamic" in app.state.computed_vars
|
||||||
assert app.state.computed_vars["dynamic"]._deps(objclass=EmptyState) == {
|
assert app.state.computed_vars["dynamic"]._deps(objclass=EmptyState) == {
|
||||||
constants.ROUTER
|
constants.ROUTER
|
||||||
@ -287,7 +287,7 @@ def test_add_page_set_route_nested(app: App, index_page, windows_platform: bool)
|
|||||||
route = "test\\nested" if windows_platform else "/test/nested"
|
route = "test\\nested" if windows_platform else "/test/nested"
|
||||||
assert app.pages == {}
|
assert app.pages == {}
|
||||||
app.add_page(index_page, route=route)
|
app.add_page(index_page, route=route)
|
||||||
assert set(app.pages.keys()) == {route.strip(os.path.sep)}
|
assert app.pages.keys() == {route.strip(os.path.sep)}
|
||||||
|
|
||||||
|
|
||||||
def test_add_page_invalid_api_route(app: App, index_page):
|
def test_add_page_invalid_api_route(app: App, index_page):
|
||||||
|
@ -287,7 +287,7 @@ def test_class_vars(test_state):
|
|||||||
test_state: A state.
|
test_state: A state.
|
||||||
"""
|
"""
|
||||||
cls = type(test_state)
|
cls = type(test_state)
|
||||||
assert set(cls.vars.keys()) == {
|
assert cls.vars.keys() == {
|
||||||
"router",
|
"router",
|
||||||
"num1",
|
"num1",
|
||||||
"num2",
|
"num2",
|
||||||
@ -310,7 +310,7 @@ def test_event_handlers(test_state):
|
|||||||
Args:
|
Args:
|
||||||
test_state: A state.
|
test_state: A state.
|
||||||
"""
|
"""
|
||||||
expected = {
|
expected_keys = (
|
||||||
"do_something",
|
"do_something",
|
||||||
"set_array",
|
"set_array",
|
||||||
"set_complex",
|
"set_complex",
|
||||||
@ -320,10 +320,10 @@ def test_event_handlers(test_state):
|
|||||||
"set_num1",
|
"set_num1",
|
||||||
"set_num2",
|
"set_num2",
|
||||||
"set_obj",
|
"set_obj",
|
||||||
}
|
)
|
||||||
|
|
||||||
cls = type(test_state)
|
cls = type(test_state)
|
||||||
assert set(cls.event_handlers.keys()).intersection(expected) == expected
|
assert all(key in cls.event_handlers for key in expected_keys)
|
||||||
|
|
||||||
|
|
||||||
def test_default_value(test_state):
|
def test_default_value(test_state):
|
||||||
|
@ -72,7 +72,7 @@ def test_merge_imports(input_1, input_2, output):
|
|||||||
|
|
||||||
"""
|
"""
|
||||||
res = merge_imports(input_1, input_2)
|
res = merge_imports(input_1, input_2)
|
||||||
assert set(res.keys()) == set(output.keys())
|
assert res.keys() == output.keys()
|
||||||
|
|
||||||
for key in output:
|
for key in output:
|
||||||
assert set(res[key]) == set(output[key])
|
assert set(res[key]) == set(output[key])
|
||||||
|
Loading…
Reference in New Issue
Block a user