Merge remote-tracking branch 'origin/main' into lendemor/add_PERF_rules

This commit is contained in:
Masen Furer 2024-12-13 14:10:11 -08:00
commit 291b6f814e
No known key found for this signature in database
GPG Key ID: B0008AD22B3B3A95
46 changed files with 924 additions and 265 deletions

View File

@ -5,6 +5,7 @@ from __future__ import annotations
import argparse
import json
import os
from pathlib import Path
from utils import send_data_to_posthog
@ -18,7 +19,7 @@ def extract_stats_from_json(json_file: str) -> list[dict]:
Returns:
list[dict]: The stats for each test.
"""
with open(json_file, "r") as file:
with Path(json_file).open() as file:
json_data = json.load(file)
# Load the JSON data if it is a string, otherwise assume it's already a dictionary

View File

@ -5,6 +5,7 @@ from __future__ import annotations
import argparse
import json
import os
from pathlib import Path
from utils import send_data_to_posthog
@ -18,7 +19,7 @@ def extract_stats_from_json(json_file: str) -> dict:
Returns:
dict: The stats for each test.
"""
with open(json_file, "r") as file:
with Path(json_file).open() as file:
json_data = json.load(file)
# Load the JSON data if it is a string, otherwise assume it's already a dictionary

View File

@ -4,26 +4,19 @@ version = "0.6.7dev1"
description = "Web apps in pure Python."
license = "Apache-2.0"
authors = [
"Nikhil Rao <nikhil@reflex.dev>",
"Alek Petuskey <alek@reflex.dev>",
"Masen Furer <masen@reflex.dev>",
"Elijah Ahianyo <elijah@reflex.dev>",
"Thomas Brandého <thomas@reflex.dev>",
"Nikhil Rao <nikhil@reflex.dev>",
"Alek Petuskey <alek@reflex.dev>",
"Masen Furer <masen@reflex.dev>",
"Elijah Ahianyo <elijah@reflex.dev>",
"Thomas Brandého <thomas@reflex.dev>",
]
readme = "README.md"
homepage = "https://reflex.dev"
repository = "https://github.com/reflex-dev/reflex"
documentation = "https://reflex.dev/docs/getting-started/introduction"
keywords = [
"web",
"framework",
]
classifiers = [
"Development Status :: 4 - Beta",
]
packages = [
{include = "reflex"}
]
keywords = ["web", "framework"]
classifiers = ["Development Status :: 4 - Beta"]
packages = [{ include = "reflex" }]
[tool.poetry.dependencies]
python = "^3.9"
@ -42,11 +35,11 @@ uvicorn = ">=0.20.0"
starlette-admin = ">=0.11.0,<1.0"
alembic = ">=1.11.1,<2.0"
platformdirs = ">=3.10.0,<5.0"
distro = {version = ">=1.8.0,<2.0", platform = "linux"}
distro = { version = ">=1.8.0,<2.0", platform = "linux" }
python-engineio = "!=4.6.0"
wrapt = [
{version = ">=1.14.0,<2.0", python = ">=3.11"},
{version = ">=1.11.0,<2.0", python = "<3.11"},
{ version = ">=1.14.0,<2.0", python = ">=3.11" },
{ version = ">=1.11.0,<2.0", python = "<3.11" },
]
packaging = ">=23.1,<25.0"
reflex-hosting-cli = ">=0.1.29,<2.0"
@ -94,7 +87,7 @@ build-backend = "poetry.core.masonry.api"
target-version = "py39"
output-format = "concise"
lint.isort.split-on-trailing-comma = false
lint.select = ["B", "D", "E", "F", "I", "SIM", "W", "RUF", "FURB", "PERF", "ERA"]
lint.select = ["B", "C4", "D", "E", "ERA", "F", "FURB", "I", "PERF", "PTH", "RUF", "SIM", "W"]
lint.ignore = ["B008", "D205", "E501", "F403", "SIM115", "RUF006", "RUF012"]
lint.pydocstyle.convention = "google"

View File

@ -5,11 +5,15 @@ export function {{tag_name}} () {
{{ hook }}
{% endfor %}
{% for hook, data in component._get_all_hooks().items() if not data.position or data.position == const.hook_position.PRE_TRIGGER %}
{{ hook }}
{% endfor %}
{% for hook in memo_trigger_hooks %}
{{ hook }}
{% endfor %}
{% for hook in component._get_all_hooks() %}
{% for hook, data in component._get_all_hooks().items() if data.position and data.position == const.hook_position.POST_TRIGGER %}
{{ hook }}
{% endfor %}

View File

@ -436,7 +436,7 @@ class App(MiddlewareMixin, LifespanMixin):
allow_credentials=True,
allow_methods=["*"],
allow_headers=["*"],
allow_origins=["*"],
allow_origins=get_config().cors_allowed_origins,
)
@property

View File

@ -45,6 +45,7 @@ class ReflexJinjaEnvironment(Environment):
"on_load_internal": constants.CompileVars.ON_LOAD_INTERNAL,
"update_vars_internal": constants.CompileVars.UPDATE_VARS_INTERNAL,
"frontend_exception_state": constants.CompileVars.FRONTEND_EXCEPTION_STATE_FULL,
"hook_position": constants.Hooks.HookPosition,
}

View File

@ -115,7 +115,7 @@ def compile_imports(import_dict: ParsedImportDict) -> list[dict]:
default, rest = compile_import_statement(fields)
# prevent lib from being rendered on the page if all imports are non rendered kind
if not any({f.render for f in fields}): # type: ignore
if not any(f.render for f in fields): # type: ignore
continue
if not lib:

View File

@ -1208,7 +1208,7 @@ class Component(BaseComponent, ABC):
Yields:
The parent classes that define the method (differently than the base).
"""
seen_methods = set([getattr(Component, method)])
seen_methods = {getattr(Component, method)}
for clz in cls.mro():
if clz is Component:
break
@ -1368,7 +1368,9 @@ class Component(BaseComponent, ABC):
if user_hooks_data is not None:
other_imports.append(user_hooks_data.imports)
other_imports.extend(
hook_imports for hook_imports in self._get_added_hooks().values()
hook_vardata.imports
for hook_vardata in self._get_added_hooks().values()
if hook_vardata is not None
)
return imports.merge_imports(_imports, *other_imports)
@ -1390,15 +1392,9 @@ class Component(BaseComponent, ABC):
# Collect imports from Vars used directly by this component.
var_datas = [var._get_all_var_data() for var in self._get_vars()]
var_imports: List[ImmutableParsedImportDict] = list(
map(
lambda var_data: var_data.imports,
filter(
None,
var_datas,
),
)
)
var_imports: List[ImmutableParsedImportDict] = [
var_data.imports for var_data in var_datas if var_data is not None
]
added_import_dicts: list[ParsedImportDict] = []
for clz in self._iter_parent_classes_with_method("add_imports"):
@ -1523,7 +1519,7 @@ class Component(BaseComponent, ABC):
**self._get_special_hooks(),
}
def _get_added_hooks(self) -> dict[str, ImportDict]:
def _get_added_hooks(self) -> dict[str, VarData | None]:
"""Get the hooks added via `add_hooks` method.
Returns:
@ -1532,17 +1528,15 @@ class Component(BaseComponent, ABC):
code = {}
def extract_var_hooks(hook: Var):
_imports = {}
var_data = VarData.merge(hook._get_all_var_data())
if var_data is not None:
for sub_hook in var_data.hooks:
code[sub_hook] = {}
if var_data.imports:
_imports = var_data.imports
code[sub_hook] = None
if str(hook) in code:
code[str(hook)] = imports.merge_imports(code[str(hook)], _imports)
code[str(hook)] = VarData.merge(var_data, code[str(hook)])
else:
code[str(hook)] = _imports
code[str(hook)] = var_data
# Add the hook code from add_hooks for each parent class (this is reversed to preserve
# the order of the hooks in the final output)
@ -1551,7 +1545,7 @@ class Component(BaseComponent, ABC):
if isinstance(hook, Var):
extract_var_hooks(hook)
else:
code[hook] = {}
code[hook] = None
return code
@ -1593,8 +1587,7 @@ class Component(BaseComponent, ABC):
if hooks is not None:
code[hooks] = None
for hook in self._get_added_hooks():
code[hook] = None
code.update(self._get_added_hooks())
# Add the hook code for the children.
for child in self.children:
@ -2196,6 +2189,31 @@ class StatefulComponent(BaseComponent):
]
return [var_name]
@staticmethod
def _get_deps_from_event_trigger(event: EventChain | EventSpec | Var) -> set[str]:
"""Get the dependencies accessed by event triggers.
Args:
event: The event trigger to extract deps from.
Returns:
The dependencies accessed by the event triggers.
"""
events: list = [event]
deps = set()
if isinstance(event, EventChain):
events.extend(event.events)
for ev in events:
if isinstance(ev, EventSpec):
for arg in ev.args:
for a in arg:
var_datas = VarData.merge(a._get_all_var_data())
if var_datas and var_datas.deps is not None:
deps |= {str(dep) for dep in var_datas.deps}
return deps
@classmethod
def _get_memoized_event_triggers(
cls,
@ -2232,6 +2250,11 @@ class StatefulComponent(BaseComponent):
# Calculate Var dependencies accessed by the handler for useCallback dep array.
var_deps = ["addEvents", "Event"]
# Get deps from event trigger var data.
var_deps.extend(cls._get_deps_from_event_trigger(event))
# Get deps from hooks.
for arg in event_args:
var_data = arg._get_all_var_data()
if var_data is None:

View File

@ -6,11 +6,12 @@ from typing import Dict, List, Tuple, Union
from reflex.components.base.fragment import Fragment
from reflex.components.tags.tag import Tag
from reflex.constants.compiler import Hooks
from reflex.event import EventChain, EventHandler, passthrough_event_spec
from reflex.utils.format import format_prop, wrap
from reflex.utils.imports import ImportVar
from reflex.vars import get_unique_variable_name
from reflex.vars.base import Var
from reflex.vars.base import Var, VarData
class Clipboard(Fragment):
@ -72,7 +73,7 @@ class Clipboard(Fragment):
),
}
def add_hooks(self) -> list[str]:
def add_hooks(self) -> list[str | Var[str]]:
"""Add hook to register paste event listener.
Returns:
@ -83,13 +84,14 @@ class Clipboard(Fragment):
return []
if isinstance(on_paste, EventChain):
on_paste = wrap(str(format_prop(on_paste)).strip("{}"), "(")
hook_expr = f"usePasteHandler({self.targets!s}, {self.on_paste_event_actions!s}, {on_paste!s})"
return [
"usePasteHandler(%s, %s, %s)"
% (
str(self.targets),
str(self.on_paste_event_actions),
on_paste,
)
Var(
hook_expr,
_var_type="str",
_var_data=VarData(position=Hooks.HookPosition.POST_TRIGGER),
),
]

View File

@ -71,6 +71,6 @@ class Clipboard(Fragment):
...
def add_imports(self) -> dict[str, ImportVar]: ...
def add_hooks(self) -> list[str]: ...
def add_hooks(self) -> list[str | Var[str]]: ...
clipboard = Clipboard.create

View File

@ -339,8 +339,11 @@ class DataEditor(NoSSRComponent):
editor_id = get_unique_variable_name()
# Define the name of the getData callback associated with this component and assign to get_cell_content.
data_callback = f"getData_{editor_id}"
self.get_cell_content = Var(_js_expr=data_callback) # type: ignore
if self.get_cell_content is not None:
data_callback = self.get_cell_content._js_expr
else:
data_callback = f"getData_{editor_id}"
self.get_cell_content = Var(_js_expr=data_callback) # type: ignore
code = [f"function {data_callback}([col, row])" "{"]

View File

@ -8,6 +8,7 @@ from reflex.event import EventHandler, no_args_event_spec, passthrough_event_spe
from reflex.vars.base import Var
from ..base import LiteralAccentColor, RadixThemesComponent
from .checkbox import Checkbox
LiteralDirType = Literal["ltr", "rtl"]
@ -232,6 +233,15 @@ class ContextMenuSeparator(RadixThemesComponent):
tag = "ContextMenu.Separator"
class ContextMenuCheckbox(Checkbox):
"""The component that contains the checkbox."""
tag = "ContextMenu.CheckboxItem"
# Text to render as shortcut.
shortcut: Var[str]
class ContextMenu(ComponentNamespace):
"""Menu representing a set of actions, displayed at the origin of a pointer right-click or long-press."""
@ -243,6 +253,7 @@ class ContextMenu(ComponentNamespace):
sub_content = staticmethod(ContextMenuSubContent.create)
item = staticmethod(ContextMenuItem.create)
separator = staticmethod(ContextMenuSeparator.create)
checkbox = staticmethod(ContextMenuCheckbox.create)
context_menu = ContextMenu()

View File

@ -12,6 +12,7 @@ from reflex.style import Style
from reflex.vars.base import Var
from ..base import RadixThemesComponent
from .checkbox import Checkbox
LiteralDirType = Literal["ltr", "rtl"]
LiteralSizeType = Literal["1", "2"]
@ -672,6 +673,159 @@ class ContextMenuSeparator(RadixThemesComponent):
"""
...
class ContextMenuCheckbox(Checkbox):
@overload
@classmethod
def create( # type: ignore
cls,
*children,
shortcut: Optional[Union[Var[str], str]] = None,
as_child: Optional[Union[Var[bool], bool]] = None,
size: Optional[
Union[
Breakpoints[str, Literal["1", "2", "3"]],
Literal["1", "2", "3"],
Var[
Union[
Breakpoints[str, Literal["1", "2", "3"]], Literal["1", "2", "3"]
]
],
]
] = None,
variant: Optional[
Union[
Literal["classic", "soft", "surface"],
Var[Literal["classic", "soft", "surface"]],
]
] = None,
color_scheme: Optional[
Union[
Literal[
"amber",
"blue",
"bronze",
"brown",
"crimson",
"cyan",
"gold",
"grass",
"gray",
"green",
"indigo",
"iris",
"jade",
"lime",
"mint",
"orange",
"pink",
"plum",
"purple",
"red",
"ruby",
"sky",
"teal",
"tomato",
"violet",
"yellow",
],
Var[
Literal[
"amber",
"blue",
"bronze",
"brown",
"crimson",
"cyan",
"gold",
"grass",
"gray",
"green",
"indigo",
"iris",
"jade",
"lime",
"mint",
"orange",
"pink",
"plum",
"purple",
"red",
"ruby",
"sky",
"teal",
"tomato",
"violet",
"yellow",
]
],
]
] = None,
high_contrast: Optional[Union[Var[bool], bool]] = None,
default_checked: Optional[Union[Var[bool], bool]] = None,
checked: Optional[Union[Var[bool], bool]] = None,
disabled: Optional[Union[Var[bool], bool]] = None,
required: Optional[Union[Var[bool], bool]] = None,
name: Optional[Union[Var[str], str]] = None,
value: Optional[Union[Var[str], str]] = None,
style: Optional[Style] = None,
key: Optional[Any] = None,
id: Optional[Any] = None,
class_name: Optional[Any] = None,
autofocus: Optional[bool] = None,
custom_attrs: Optional[Dict[str, Union[Var, Any]]] = None,
on_blur: Optional[EventType[[], BASE_STATE]] = None,
on_change: Optional[
Union[EventType[[], BASE_STATE], EventType[[bool], BASE_STATE]]
] = None,
on_click: Optional[EventType[[], BASE_STATE]] = None,
on_context_menu: Optional[EventType[[], BASE_STATE]] = None,
on_double_click: Optional[EventType[[], BASE_STATE]] = None,
on_focus: Optional[EventType[[], BASE_STATE]] = None,
on_mount: Optional[EventType[[], BASE_STATE]] = None,
on_mouse_down: Optional[EventType[[], BASE_STATE]] = None,
on_mouse_enter: Optional[EventType[[], BASE_STATE]] = None,
on_mouse_leave: Optional[EventType[[], BASE_STATE]] = None,
on_mouse_move: Optional[EventType[[], BASE_STATE]] = None,
on_mouse_out: Optional[EventType[[], BASE_STATE]] = None,
on_mouse_over: Optional[EventType[[], BASE_STATE]] = None,
on_mouse_up: Optional[EventType[[], BASE_STATE]] = None,
on_scroll: Optional[EventType[[], BASE_STATE]] = None,
on_unmount: Optional[EventType[[], BASE_STATE]] = None,
**props,
) -> "ContextMenuCheckbox":
"""Create a new component instance.
Will prepend "RadixThemes" to the component tag to avoid conflicts with
other UI libraries for common names, like Text and Button.
Args:
*children: Child components.
shortcut: Text to render as shortcut.
as_child: Change the default rendered element for the one passed as a child, merging their props and behavior.
size: Checkbox size "1" - "3"
variant: Variant of checkbox: "classic" | "surface" | "soft"
color_scheme: Override theme color for checkbox
high_contrast: Whether to render the checkbox with higher contrast color against background
default_checked: Whether the checkbox is checked by default
checked: Whether the checkbox is checked
disabled: Whether the checkbox is disabled
required: Whether the checkbox is required
name: The name of the checkbox control when submitting the form.
value: The value of the checkbox control when submitting the form.
on_change: Fired when the checkbox is checked or unchecked.
style: The style of the component.
key: A unique key for the component.
id: The id for the component.
class_name: The class name for the component.
autofocus: Whether the component should take the focus once the page is loaded
custom_attrs: custom attribute
**props: Component properties.
Returns:
A new component instance.
"""
...
class ContextMenu(ComponentNamespace):
root = staticmethod(ContextMenuRoot.create)
trigger = staticmethod(ContextMenuTrigger.create)
@ -681,5 +835,6 @@ class ContextMenu(ComponentNamespace):
sub_content = staticmethod(ContextMenuSubContent.create)
item = staticmethod(ContextMenuItem.create)
separator = staticmethod(ContextMenuSeparator.create)
checkbox = staticmethod(ContextMenuCheckbox.create)
context_menu = ContextMenu()

View File

@ -79,7 +79,7 @@ class IconButton(elements.Button, RadixLoadingProp, RadixThemesComponent):
else:
size_map_var = Match.create(
props["size"],
*[(size, px) for size, px in RADIX_TO_LUCIDE_SIZE.items()],
*list(RADIX_TO_LUCIDE_SIZE.items()),
12,
)
if not isinstance(size_map_var, Var):

View File

@ -19,7 +19,7 @@ LiteralTextFieldSize = Literal["1", "2", "3"]
LiteralTextFieldVariant = Literal["classic", "surface", "soft"]
class TextFieldRoot(elements.Div, RadixThemesComponent):
class TextFieldRoot(elements.Input, RadixThemesComponent):
"""Captures user input with an optional slot for buttons and icons."""
tag = "TextField.Root"

View File

@ -17,7 +17,7 @@ from ..base import RadixThemesComponent
LiteralTextFieldSize = Literal["1", "2", "3"]
LiteralTextFieldVariant = Literal["classic", "surface", "soft"]
class TextFieldRoot(elements.Div, RadixThemesComponent):
class TextFieldRoot(elements.Input, RadixThemesComponent):
@overload
@classmethod
def create( # type: ignore
@ -120,6 +120,30 @@ class TextFieldRoot(elements.Div, RadixThemesComponent):
type: Optional[Union[Var[str], str]] = None,
value: Optional[Union[Var[Union[float, int, str]], float, int, str]] = None,
list: Optional[Union[Var[Union[bool, int, str]], bool, int, str]] = None,
accept: Optional[Union[Var[Union[bool, int, str]], bool, int, str]] = None,
alt: Optional[Union[Var[Union[bool, int, str]], bool, int, str]] = None,
auto_focus: Optional[Union[Var[Union[bool, int, str]], bool, int, str]] = None,
capture: Optional[Union[Var[Union[bool, int, str]], bool, int, str]] = None,
checked: Optional[Union[Var[Union[bool, int, str]], bool, int, str]] = None,
default_checked: Optional[Union[Var[bool], bool]] = None,
dirname: Optional[Union[Var[Union[bool, int, str]], bool, int, str]] = None,
form: Optional[Union[Var[Union[bool, int, str]], bool, int, str]] = None,
form_action: Optional[Union[Var[Union[bool, int, str]], bool, int, str]] = None,
form_enc_type: Optional[
Union[Var[Union[bool, int, str]], bool, int, str]
] = None,
form_method: Optional[Union[Var[Union[bool, int, str]], bool, int, str]] = None,
form_no_validate: Optional[
Union[Var[Union[bool, int, str]], bool, int, str]
] = None,
form_target: Optional[Union[Var[Union[bool, int, str]], bool, int, str]] = None,
max: Optional[Union[Var[Union[bool, int, str]], bool, int, str]] = None,
min: Optional[Union[Var[Union[bool, int, str]], bool, int, str]] = None,
multiple: Optional[Union[Var[Union[bool, int, str]], bool, int, str]] = None,
pattern: Optional[Union[Var[Union[bool, int, str]], bool, int, str]] = None,
src: Optional[Union[Var[Union[bool, int, str]], bool, int, str]] = None,
step: Optional[Union[Var[Union[bool, int, str]], bool, int, str]] = None,
use_map: Optional[Union[Var[Union[bool, int, str]], bool, int, str]] = None,
access_key: Optional[Union[Var[Union[bool, int, str]], bool, int, str]] = None,
auto_capitalize: Optional[
Union[Var[Union[bool, int, str]], bool, int, str]
@ -192,12 +216,12 @@ class TextFieldRoot(elements.Div, RadixThemesComponent):
Args:
*children: The children of the component.
size: Text field size "1" - "3"
size: Specifies the visible width of a text control
variant: Variant of text field: "classic" | "surface" | "soft"
color_scheme: Override theme color for text field
radius: Override theme radius for text field: "none" | "small" | "medium" | "large" | "full"
auto_complete: Whether the input should have autocomplete enabled
default_value: The value of the input when initially rendered.
default_value: The initial value for a text field
disabled: Disables the input
max_length: Specifies the maximum number of characters allowed in the input
min_length: Specifies the minimum number of characters required in the input
@ -208,11 +232,31 @@ class TextFieldRoot(elements.Div, RadixThemesComponent):
type: Specifies the type of input
value: Value of the input
list: References a datalist for suggested options
on_change: Fired when the value of the textarea changes.
on_focus: Fired when the textarea is focused.
on_blur: Fired when the textarea is blurred.
on_key_down: Fired when a key is pressed down.
on_key_up: Fired when a key is released.
on_change: Fired when the input value changes
on_focus: Fired when the input gains focus
on_blur: Fired when the input loses focus
on_key_down: Fired when a key is pressed down
on_key_up: Fired when a key is released
accept: Accepted types of files when the input is file type
alt: Alternate text for input type="image"
auto_focus: Automatically focuses the input when the page loads
capture: Captures media from the user (camera or microphone)
checked: Indicates whether the input is checked (for checkboxes and radio buttons)
default_checked: The initial value (for checkboxes and radio buttons)
dirname: Name part of the input to submit in 'dir' and 'name' pair when form is submitted
form: Associates the input with a form (by id)
form_action: URL to send the form data to (for type="submit" buttons)
form_enc_type: How the form data should be encoded when submitting to the server (for type="submit" buttons)
form_method: HTTP method to use for sending form data (for type="submit" buttons)
form_no_validate: Bypasses form validation when submitting (for type="submit" buttons)
form_target: Specifies where to display the response after submitting the form (for type="submit" buttons)
max: Specifies the maximum value for the input
min: Specifies the minimum value for the input
multiple: Indicates whether multiple values can be entered in an input of the type email or file
pattern: Regex pattern the input's value must match to be valid
src: URL for image inputs
step: Specifies the legal number intervals for an input
use_map: Name of the image map used with the input
access_key: Provides a hint for generating a keyboard shortcut for the current element.
auto_capitalize: Controls whether and how text input is automatically capitalized as it is entered/edited by the user.
content_editable: Indicates whether the element's content is editable.
@ -457,6 +501,30 @@ class TextField(ComponentNamespace):
type: Optional[Union[Var[str], str]] = None,
value: Optional[Union[Var[Union[float, int, str]], float, int, str]] = None,
list: Optional[Union[Var[Union[bool, int, str]], bool, int, str]] = None,
accept: Optional[Union[Var[Union[bool, int, str]], bool, int, str]] = None,
alt: Optional[Union[Var[Union[bool, int, str]], bool, int, str]] = None,
auto_focus: Optional[Union[Var[Union[bool, int, str]], bool, int, str]] = None,
capture: Optional[Union[Var[Union[bool, int, str]], bool, int, str]] = None,
checked: Optional[Union[Var[Union[bool, int, str]], bool, int, str]] = None,
default_checked: Optional[Union[Var[bool], bool]] = None,
dirname: Optional[Union[Var[Union[bool, int, str]], bool, int, str]] = None,
form: Optional[Union[Var[Union[bool, int, str]], bool, int, str]] = None,
form_action: Optional[Union[Var[Union[bool, int, str]], bool, int, str]] = None,
form_enc_type: Optional[
Union[Var[Union[bool, int, str]], bool, int, str]
] = None,
form_method: Optional[Union[Var[Union[bool, int, str]], bool, int, str]] = None,
form_no_validate: Optional[
Union[Var[Union[bool, int, str]], bool, int, str]
] = None,
form_target: Optional[Union[Var[Union[bool, int, str]], bool, int, str]] = None,
max: Optional[Union[Var[Union[bool, int, str]], bool, int, str]] = None,
min: Optional[Union[Var[Union[bool, int, str]], bool, int, str]] = None,
multiple: Optional[Union[Var[Union[bool, int, str]], bool, int, str]] = None,
pattern: Optional[Union[Var[Union[bool, int, str]], bool, int, str]] = None,
src: Optional[Union[Var[Union[bool, int, str]], bool, int, str]] = None,
step: Optional[Union[Var[Union[bool, int, str]], bool, int, str]] = None,
use_map: Optional[Union[Var[Union[bool, int, str]], bool, int, str]] = None,
access_key: Optional[Union[Var[Union[bool, int, str]], bool, int, str]] = None,
auto_capitalize: Optional[
Union[Var[Union[bool, int, str]], bool, int, str]
@ -529,12 +597,12 @@ class TextField(ComponentNamespace):
Args:
*children: The children of the component.
size: Text field size "1" - "3"
size: Specifies the visible width of a text control
variant: Variant of text field: "classic" | "surface" | "soft"
color_scheme: Override theme color for text field
radius: Override theme radius for text field: "none" | "small" | "medium" | "large" | "full"
auto_complete: Whether the input should have autocomplete enabled
default_value: The value of the input when initially rendered.
default_value: The initial value for a text field
disabled: Disables the input
max_length: Specifies the maximum number of characters allowed in the input
min_length: Specifies the minimum number of characters required in the input
@ -545,11 +613,31 @@ class TextField(ComponentNamespace):
type: Specifies the type of input
value: Value of the input
list: References a datalist for suggested options
on_change: Fired when the value of the textarea changes.
on_focus: Fired when the textarea is focused.
on_blur: Fired when the textarea is blurred.
on_key_down: Fired when a key is pressed down.
on_key_up: Fired when a key is released.
on_change: Fired when the input value changes
on_focus: Fired when the input gains focus
on_blur: Fired when the input loses focus
on_key_down: Fired when a key is pressed down
on_key_up: Fired when a key is released
accept: Accepted types of files when the input is file type
alt: Alternate text for input type="image"
auto_focus: Automatically focuses the input when the page loads
capture: Captures media from the user (camera or microphone)
checked: Indicates whether the input is checked (for checkboxes and radio buttons)
default_checked: The initial value (for checkboxes and radio buttons)
dirname: Name part of the input to submit in 'dir' and 'name' pair when form is submitted
form: Associates the input with a form (by id)
form_action: URL to send the form data to (for type="submit" buttons)
form_enc_type: How the form data should be encoded when submitting to the server (for type="submit" buttons)
form_method: HTTP method to use for sending form data (for type="submit" buttons)
form_no_validate: Bypasses form validation when submitting (for type="submit" buttons)
form_target: Specifies where to display the response after submitting the form (for type="submit" buttons)
max: Specifies the maximum value for the input
min: Specifies the minimum value for the input
multiple: Indicates whether multiple values can be entered in an input of the type email or file
pattern: Regex pattern the input's value must match to be valid
src: URL for image inputs
step: Specifies the legal number intervals for an input
use_map: Name of the image map used with the input
access_key: Provides a hint for generating a keyboard shortcut for the current element.
auto_capitalize: Controls whether and how text input is automatically capitalized as it is entered/edited by the user.
content_editable: Indicates whether the element's content is editable.

View File

@ -84,10 +84,10 @@ class ChartBase(RechartsCharts):
cls._ensure_valid_dimension("width", width)
cls._ensure_valid_dimension("height", height)
dim_props = dict(
width=width or "100%",
height=height or "100%",
)
dim_props = {
"width": width or "100%",
"height": height or "100%",
}
# Provide min dimensions so the graph always appears, even if the outer container is zero-size.
if width is None:
dim_props["min_width"] = 200

View File

@ -873,7 +873,7 @@ def get_config(reload: bool = False) -> Config:
with _config_lock:
sys_path = sys.path.copy()
sys.path.clear()
sys.path.append(os.getcwd())
sys.path.append(str(Path.cwd()))
try:
# Try to import the module with only the current directory in the path.
return _get_config()

View File

@ -132,6 +132,12 @@ class Hooks(SimpleNamespace):
}
})"""
class HookPosition(enum.Enum):
"""The position of the hook in the component."""
PRE_TRIGGER = "pre_trigger"
POST_TRIGGER = "post_trigger"
class MemoizationDisposition(enum.Enum):
"""The conditions under which a component should be memoized."""

View File

@ -10,7 +10,7 @@ class CustomComponents(SimpleNamespace):
"""Constants for the custom components."""
# The name of the custom components source directory.
SRC_DIR = "custom_components"
SRC_DIR = Path("custom_components")
# The name of the custom components pyproject.toml file.
PYPROJECT_TOML = Path("pyproject.toml")
# The name of the custom components package README file.

View File

@ -31,7 +31,7 @@ class RouteVar(SimpleNamespace):
# This subset of router_data is included in chained on_load events.
ROUTER_DATA_INCLUDE = set((RouteVar.PATH, RouteVar.ORIGIN, RouteVar.QUERY))
ROUTER_DATA_INCLUDE = {RouteVar.PATH, RouteVar.ORIGIN, RouteVar.QUERY}
class RouteRegex(SimpleNamespace):

View File

@ -150,27 +150,27 @@ def _populate_demo_app(name_variants: NameVariants):
from reflex.compiler import templates
from reflex.reflex import _init
demo_app_dir = name_variants.demo_app_dir
demo_app_dir = Path(name_variants.demo_app_dir)
demo_app_name = name_variants.demo_app_name
console.info(f"Creating app for testing: {demo_app_dir}")
console.info(f"Creating app for testing: {demo_app_dir!s}")
os.makedirs(demo_app_dir)
demo_app_dir.mkdir(exist_ok=True)
with set_directory(demo_app_dir):
# We start with the blank template as basis.
_init(name=demo_app_name, template=constants.Templates.DEFAULT)
# Then overwrite the app source file with the one we want for testing custom components.
# This source file is rendered using jinja template file.
with open(f"{demo_app_name}/{demo_app_name}.py", "w") as f:
f.write(
templates.CUSTOM_COMPONENTS_DEMO_APP.render(
custom_component_module_dir=name_variants.custom_component_module_dir,
module_name=name_variants.module_name,
)
demo_file = Path(f"{demo_app_name}/{demo_app_name}.py")
demo_file.write_text(
templates.CUSTOM_COMPONENTS_DEMO_APP.render(
custom_component_module_dir=name_variants.custom_component_module_dir,
module_name=name_variants.module_name,
)
)
# Append the custom component package to the requirements.txt file.
with open(f"{constants.RequirementsTxt.FILE}", "a") as f:
with Path(f"{constants.RequirementsTxt.FILE}").open(mode="a") as f:
f.write(f"{name_variants.package_name}\n")
@ -296,13 +296,14 @@ def _populate_custom_component_project(name_variants: NameVariants):
)
console.info(
f"Initializing the component directory: {CustomComponents.SRC_DIR}/{name_variants.custom_component_module_dir}"
f"Initializing the component directory: {CustomComponents.SRC_DIR / name_variants.custom_component_module_dir}"
)
os.makedirs(CustomComponents.SRC_DIR)
CustomComponents.SRC_DIR.mkdir(exist_ok=True)
with set_directory(CustomComponents.SRC_DIR):
os.makedirs(name_variants.custom_component_module_dir)
module_dir = Path(name_variants.custom_component_module_dir)
module_dir.mkdir(exist_ok=True, parents=True)
_write_source_and_init_py(
custom_component_src_dir=name_variants.custom_component_module_dir,
custom_component_src_dir=module_dir,
component_class_name=name_variants.component_class_name,
module_name=name_variants.module_name,
)
@ -814,7 +815,7 @@ def _validate_project_info():
)
pyproject_toml["project"] = project
try:
with open(CustomComponents.PYPROJECT_TOML, "w") as f:
with CustomComponents.PYPROJECT_TOML.open("w") as f:
tomlkit.dump(pyproject_toml, f)
except (OSError, TOMLKitError) as ex:
console.error(f"Unable to write to pyproject.toml due to {ex}")
@ -922,16 +923,15 @@ def _validate_url_with_protocol_prefix(url: str | None) -> bool:
def _get_file_from_prompt_in_loop() -> Tuple[bytes, str] | None:
image_file = file_extension = None
while image_file is None:
image_filepath = console.ask(
"Upload a preview image of your demo app (enter to skip)"
image_filepath = Path(
console.ask("Upload a preview image of your demo app (enter to skip)")
)
if not image_filepath:
break
file_extension = image_filepath.split(".")[-1]
file_extension = image_filepath.suffix
try:
with open(image_filepath, "rb") as f:
image_file = f.read()
return image_file, file_extension
image_file = image_filepath.read_bytes()
return image_file, file_extension
except OSError as ose:
console.error(f"Unable to read the {file_extension} file due to {ose}")
raise typer.Exit(code=1) from ose

View File

@ -25,6 +25,7 @@ from typing import (
overload,
)
import typing_extensions
from typing_extensions import (
Concatenate,
ParamSpec,
@ -296,7 +297,7 @@ class EventSpec(EventActionsMixin):
handler: EventHandler,
event_actions: Dict[str, Union[bool, int]] | None = None,
client_handler_name: str = "",
args: Tuple[Tuple[Var, Var], ...] = tuple(),
args: Tuple[Tuple[Var, Var], ...] = (),
):
"""Initialize an EventSpec.
@ -311,7 +312,7 @@ class EventSpec(EventActionsMixin):
object.__setattr__(self, "event_actions", event_actions)
object.__setattr__(self, "handler", handler)
object.__setattr__(self, "client_handler_name", client_handler_name)
object.__setattr__(self, "args", args or tuple())
object.__setattr__(self, "args", args or ())
def with_args(self, args: Tuple[Tuple[Var, Var], ...]) -> EventSpec:
"""Copy the event spec, with updated args.
@ -514,7 +515,7 @@ def no_args_event_spec() -> Tuple[()]:
Returns:
An empty tuple.
"""
return tuple() # type: ignore
return () # type: ignore
# These chains can be used for their side effects when no other events are desired.
@ -715,26 +716,61 @@ def server_side(name: str, sig: inspect.Signature, **kwargs) -> EventSpec:
)
@overload
def redirect(
path: str | Var[str],
external: Optional[bool] = False,
replace: Optional[bool] = False,
is_external: Optional[bool] = None,
replace: bool = False,
) -> EventSpec: ...
@overload
@typing_extensions.deprecated("`external` is deprecated use `is_external` instead")
def redirect(
path: str | Var[str],
is_external: Optional[bool] = None,
replace: bool = False,
external: Optional[bool] = None,
) -> EventSpec: ...
def redirect(
path: str | Var[str],
is_external: Optional[bool] = None,
replace: bool = False,
external: Optional[bool] = None,
) -> EventSpec:
"""Redirect to a new path.
Args:
path: The path to redirect to.
external: Whether to open in new tab or not.
is_external: Whether to open in new tab or not.
replace: If True, the current page will not create a new history entry.
external(Deprecated): Whether to open in new tab or not.
Returns:
An event to redirect to the path.
"""
if external is not None:
console.deprecate(
"The `external` prop in `rx.redirect`",
"use `is_external` instead.",
"0.6.6",
"0.7.0",
)
# is_external should take precedence over external.
is_external = (
(False if external is None else external)
if is_external is None
else is_external
)
return server_side(
"_redirect",
get_fn_signature(redirect),
path=path,
external=external,
external=is_external,
replace=replace,
)
@ -1102,9 +1138,7 @@ def run_script(
Var(javascript_code) if isinstance(javascript_code, str) else javascript_code
)
return call_function(
ArgsFunctionOperation.create(tuple(), javascript_code), callback
)
return call_function(ArgsFunctionOperation.create((), javascript_code), callback)
def get_event(state, event):
@ -1456,7 +1490,7 @@ def get_handler_args(
"""
args = inspect.getfullargspec(event_spec.handler.fn).args
return event_spec.args if len(args) > 1 else tuple()
return event_spec.args if len(args) > 1 else ()
def fix_events(

View File

@ -53,12 +53,12 @@ def get_engine_args(url: str | None = None) -> dict[str, Any]:
Returns:
The database engine arguments as a dict.
"""
kwargs: dict[str, Any] = dict(
kwargs: dict[str, Any] = {
# Print the SQL queries if the log level is INFO or lower.
echo=environment.SQLALCHEMY_ECHO.get(),
"echo": environment.SQLALCHEMY_ECHO.get(),
# Check connections before returning them.
pool_pre_ping=environment.SQLALCHEMY_POOL_PRE_PING.get(),
)
"pool_pre_ping": environment.SQLALCHEMY_POOL_PRE_PING.get(),
}
conf = get_config()
url = url or conf.db_url
if url is not None and url.startswith("sqlite"):

View File

@ -3,7 +3,6 @@
from __future__ import annotations
import atexit
import os
from pathlib import Path
from typing import List, Optional
@ -298,7 +297,7 @@ def export(
True, "--frontend-only", help="Export only frontend.", show_default=False
),
zip_dest_dir: str = typer.Option(
os.getcwd(),
str(Path.cwd()),
help="The directory to export the zip files to.",
show_default=False,
),
@ -443,13 +442,13 @@ def deploy(
hidden=True,
),
regions: List[str] = typer.Option(
list(),
[],
"-r",
"--region",
help="The regions to deploy to. `reflex cloud regions` For multiple envs, repeat this option, e.g. --region sjc --region iad",
),
envs: List[str] = typer.Option(
list(),
[],
"--env",
help="The environment variables to set: <key>=<value>. For multiple envs, repeat this option, e.g. --env k1=v2 --env k2=v2.",
),

View File

@ -437,9 +437,7 @@ class BaseState(Base, ABC, extra=pydantic.Extra.allow):
)
# Create a fresh copy of the backend variables for this instance
self._backend_vars = copy.deepcopy(
{name: item for name, item in self.backend_vars.items()}
)
self._backend_vars = copy.deepcopy(self.backend_vars)
def __repr__(self) -> str:
"""Get the string representation of the state.
@ -523,9 +521,7 @@ class BaseState(Base, ABC, extra=pydantic.Extra.allow):
cls.inherited_backend_vars = parent_state.backend_vars
# Check if another substate class with the same name has already been defined.
if cls.get_name() in set(
c.get_name() for c in parent_state.class_subclasses
):
if cls.get_name() in {c.get_name() for c in parent_state.class_subclasses}:
# This should not happen, since we have added module prefix to state names in #3214
raise StateValueError(
f"The substate class '{cls.get_name()}' has been defined multiple times. "
@ -788,11 +784,11 @@ class BaseState(Base, ABC, extra=pydantic.Extra.allow):
)
# ComputedVar with cache=False always need to be recomputed
cls._always_dirty_computed_vars = set(
cls._always_dirty_computed_vars = {
cvar_name
for cvar_name, cvar in cls.computed_vars.items()
if not cvar._cache
)
}
# Any substate containing a ComputedVar with cache=False always needs to be recomputed
if cls._always_dirty_computed_vars:
@ -1862,11 +1858,11 @@ class BaseState(Base, ABC, extra=pydantic.Extra.allow):
Returns:
Set of computed vars to include in the delta.
"""
return set(
return {
cvar
for cvar in self.computed_vars
if self.computed_vars[cvar].needs_update(instance=self)
)
}
def _dirty_computed_vars(
self, from_vars: set[str] | None = None, include_backend: bool = True
@ -1880,12 +1876,12 @@ class BaseState(Base, ABC, extra=pydantic.Extra.allow):
Returns:
Set of computed vars to include in the delta.
"""
return set(
return {
cvar
for dirty_var in from_vars or self.dirty_vars
for cvar in self._computed_var_dependencies[dirty_var]
if include_backend or not self.computed_vars[cvar]._backend
)
}
@classmethod
def _potentially_dirty_substates(cls) -> set[Type[BaseState]]:
@ -1895,16 +1891,16 @@ class BaseState(Base, ABC, extra=pydantic.Extra.allow):
Set of State classes that may need to be fetched to recalc computed vars.
"""
# _always_dirty_substates need to be fetched to recalc computed vars.
fetch_substates = set(
fetch_substates = {
cls.get_class_substate((cls.get_name(), *substate_name.split(".")))
for substate_name in cls._always_dirty_substates
)
}
for dependent_substates in cls._substate_var_dependencies.values():
fetch_substates.update(
set(
{
cls.get_class_substate((cls.get_name(), *substate_name.split(".")))
for substate_name in dependent_substates
)
}
)
return fetch_substates
@ -2206,7 +2202,7 @@ class BaseState(Base, ABC, extra=pydantic.Extra.allow):
return md5(
pickle.dumps(
list(sorted(_field_tuple(field_name) for field_name in cls.base_vars))
sorted(_field_tuple(field_name) for field_name in cls.base_vars)
)
).hexdigest()
@ -3354,7 +3350,7 @@ class StateManagerRedis(StateManager):
state_cls = self.state.get_class_substate(state_path)
else:
raise RuntimeError(
"StateManagerRedis requires token to be specified in the form of {token}_{state_full_name}"
f"StateManagerRedis requires token to be specified in the form of {{token}}_{{state_full_name}}, but got {token}"
)
# The deserialized or newly created (sub)state instance.
@ -3653,33 +3649,30 @@ class MutableProxy(wrapt.ObjectProxy):
"""A proxy for a mutable object that tracks changes."""
# Methods on wrapped objects which should mark the state as dirty.
__mark_dirty_attrs__ = set(
[
"add",
"append",
"clear",
"difference_update",
"discard",
"extend",
"insert",
"intersection_update",
"pop",
"popitem",
"remove",
"reverse",
"setdefault",
"sort",
"symmetric_difference_update",
"update",
]
)
__mark_dirty_attrs__ = {
"add",
"append",
"clear",
"difference_update",
"discard",
"extend",
"insert",
"intersection_update",
"pop",
"popitem",
"remove",
"reverse",
"setdefault",
"sort",
"symmetric_difference_update",
"update",
}
# Methods on wrapped objects might return mutable objects that should be tracked.
__wrap_mutable_attrs__ = set(
[
"get",
"setdefault",
]
)
__wrap_mutable_attrs__ = {
"get",
"setdefault",
}
# These internal attributes on rx.Base should NOT be wrapped in a MutableProxy.
__never_wrap_base_attrs__ = set(Base.__dict__) - {"set"} | set(
@ -3722,7 +3715,7 @@ class MutableProxy(wrapt.ObjectProxy):
self,
wrapped=None,
instance=None,
args=tuple(),
args=(),
kwargs=None,
) -> Any:
"""Mark the state as dirty, then call a wrapped function.
@ -3978,7 +3971,7 @@ class ImmutableMutableProxy(MutableProxy):
self,
wrapped=None,
instance=None,
args=tuple(),
args=(),
kwargs=None,
) -> Any:
"""Raise an exception when an attempt is made to modify the object.

View File

@ -8,7 +8,6 @@ import dataclasses
import functools
import inspect
import os
import pathlib
import platform
import re
import signal
@ -20,6 +19,7 @@ import threading
import time
import types
from http.server import SimpleHTTPRequestHandler
from pathlib import Path
from typing import (
TYPE_CHECKING,
Any,
@ -100,7 +100,7 @@ class chdir(contextlib.AbstractContextManager):
def __enter__(self):
"""Save current directory and perform chdir."""
self._old_cwd.append(os.getcwd())
self._old_cwd.append(Path.cwd())
os.chdir(self.path)
def __exit__(self, *excinfo):
@ -120,8 +120,8 @@ class AppHarness:
app_source: Optional[
Callable[[], None] | types.ModuleType | str | functools.partial[Any]
]
app_path: pathlib.Path
app_module_path: pathlib.Path
app_path: Path
app_module_path: Path
app_module: Optional[types.ModuleType] = None
app_instance: Optional[reflex.App] = None
frontend_process: Optional[subprocess.Popen] = None
@ -136,7 +136,7 @@ class AppHarness:
@classmethod
def create(
cls,
root: pathlib.Path,
root: Path,
app_source: Optional[
Callable[[], None] | types.ModuleType | str | functools.partial[Any]
] = None,
@ -814,7 +814,7 @@ class AppHarness:
class SimpleHTTPRequestHandlerCustomErrors(SimpleHTTPRequestHandler):
"""SimpleHTTPRequestHandler with custom error page handling."""
def __init__(self, *args, error_page_map: dict[int, pathlib.Path], **kwargs):
def __init__(self, *args, error_page_map: dict[int, Path], **kwargs):
"""Initialize the handler.
Args:
@ -857,8 +857,8 @@ class Subdir404TCPServer(socketserver.TCPServer):
def __init__(
self,
*args,
root: pathlib.Path,
error_page_map: dict[int, pathlib.Path] | None,
root: Path,
error_page_map: dict[int, Path] | None,
**kwargs,
):
"""Initialize the server.

View File

@ -150,7 +150,7 @@ def zip_app(
_zip(
component_name=constants.ComponentName.BACKEND,
target=zip_dest_dir / constants.ComponentName.BACKEND.zip(),
root_dir=Path("."),
root_dir=Path.cwd(),
dirs_to_exclude={"__pycache__"},
files_to_exclude=files_to_exclude,
top_level_dirs_to_exclude={"assets"},

View File

@ -24,7 +24,7 @@ from reflex.utils.prerequisites import get_web_dir
frontend_process = None
def detect_package_change(json_file_path: str) -> str:
def detect_package_change(json_file_path: Path) -> str:
"""Calculates the SHA-256 hash of a JSON file and returns it as a hexadecimal string.
Args:
@ -37,7 +37,7 @@ def detect_package_change(json_file_path: str) -> str:
>>> detect_package_change("package.json")
'a1b2c3d4e5f6g7h8i9j0k1l2m3n4o5p6q7r8s9t0u1v2w3x4y5z6a7b8c9d0e1f2'
"""
with open(json_file_path, "r") as file:
with json_file_path.open("r") as file:
json_data = json.load(file)
# Calculate the hash
@ -81,7 +81,7 @@ def run_process_and_launch_url(run_command: list[str], backend_present=True):
from reflex.utils import processes
json_file_path = get_web_dir() / constants.PackageJson.PATH
last_hash = detect_package_change(str(json_file_path))
last_hash = detect_package_change(json_file_path)
process = None
first_run = True
@ -117,14 +117,14 @@ def run_process_and_launch_url(run_command: list[str], backend_present=True):
console.print("New packages detected: Updating app...")
else:
if any(
[x in line for x in ("bin executable does not exist on disk",)]
x in line for x in ("bin executable does not exist on disk",)
):
console.error(
"Try setting `REFLEX_USE_NPM=1` and re-running `reflex init` and `reflex run` to use npm instead of bun:\n"
"`REFLEX_USE_NPM=1 reflex init`\n"
"`REFLEX_USE_NPM=1 reflex run`"
)
new_hash = detect_package_change(str(json_file_path))
new_hash = detect_package_change(json_file_path)
if new_hash != last_hash:
last_hash = new_hash
kill(process.pid)

View File

@ -1,6 +1,5 @@
"""Export utilities."""
import os
from pathlib import Path
from typing import Optional
@ -15,7 +14,7 @@ def export(
zipping: bool = True,
frontend: bool = True,
backend: bool = True,
zip_dest_dir: str = os.getcwd(),
zip_dest_dir: str = str(Path.cwd()),
upload_db_file: bool = False,
api_url: Optional[str] = None,
deploy_url: Optional[str] = None,

View File

@ -205,14 +205,14 @@ def update_json_file(file_path: str | Path, update_dict: dict[str, int | str]):
# Read the existing json object from the file.
json_object = {}
if fp.stat().st_size:
with open(fp) as f:
with fp.open() as f:
json_object = json.load(f)
# Update the json object with the new data.
json_object.update(update_dict)
# Write the updated json object to the file
with open(fp, "w") as f:
with fp.open("w") as f:
json.dump(json_object, f, ensure_ascii=False)

View File

@ -290,7 +290,7 @@ def get_app(reload: bool = False) -> ModuleType:
"If this error occurs in a reflex test case, ensure that `get_app` is mocked."
)
module = config.module
sys.path.insert(0, os.getcwd())
sys.path.insert(0, str(Path.cwd()))
app = __import__(module, fromlist=(constants.CompileVars.APP,))
if reload:
@ -438,9 +438,11 @@ def create_config(app_name: str):
from reflex.compiler import templates
config_name = f"{re.sub(r'[^a-zA-Z]', '', app_name).capitalize()}Config"
with open(constants.Config.FILE, "w") as f:
console.debug(f"Creating {constants.Config.FILE}")
f.write(templates.RXCONFIG.render(app_name=app_name, config_name=config_name))
console.debug(f"Creating {constants.Config.FILE}")
constants.Config.FILE.write_text(
templates.RXCONFIG.render(app_name=app_name, config_name=config_name)
)
def initialize_gitignore(
@ -494,14 +496,14 @@ def initialize_requirements_txt():
console.debug(f"Detected encoding for {fp} as {encoding}.")
try:
other_requirements_exist = False
with open(fp, "r", encoding=encoding) as f:
with fp.open("r", encoding=encoding) as f:
for req in f:
# Check if we have a package name that is reflex
if re.match(r"^reflex[^a-zA-Z0-9]", req):
console.debug(f"{fp} already has reflex as dependency.")
return
other_requirements_exist = True
with open(fp, "a", encoding=encoding) as f:
with fp.open("a", encoding=encoding) as f:
preceding_newline = "\n" if other_requirements_exist else ""
f.write(
f"{preceding_newline}{constants.RequirementsTxt.DEFAULTS_STUB}{constants.Reflex.VERSION}\n"
@ -699,7 +701,7 @@ def _update_next_config(
}
if transpile_packages:
next_config["transpilePackages"] = list(
set((format_library_name(p) for p in transpile_packages))
{format_library_name(p) for p in transpile_packages}
)
if export:
next_config["output"] = "export"
@ -732,13 +734,13 @@ def download_and_run(url: str, *args, show_status: bool = False, **env):
response.raise_for_status()
# Save the script to a temporary file.
script = tempfile.NamedTemporaryFile()
with open(script.name, "w") as f:
f.write(response.text)
script = Path(tempfile.NamedTemporaryFile().name)
script.write_text(response.text)
# Run the script.
env = {**os.environ, **env}
process = processes.new_process(["bash", f.name, *args], env=env)
process = processes.new_process(["bash", str(script), *args], env=env)
show = processes.show_status if show_status else processes.show_logs
show(f"Installing {url}", process)
@ -752,14 +754,14 @@ def download_and_extract_fnm_zip():
# Download the zip file
url = constants.Fnm.INSTALL_URL
console.debug(f"Downloading {url}")
fnm_zip_file = constants.Fnm.DIR / f"{constants.Fnm.FILENAME}.zip"
fnm_zip_file: Path = constants.Fnm.DIR / f"{constants.Fnm.FILENAME}.zip"
# Function to download and extract the FNM zip release.
try:
# Download the FNM zip release.
# TODO: show progress to improve UX
response = net.get(url, follow_redirects=True)
response.raise_for_status()
with open(fnm_zip_file, "wb") as output_file:
with fnm_zip_file.open("wb") as output_file:
for chunk in response.iter_bytes():
output_file.write(chunk)
@ -807,7 +809,7 @@ def install_node():
)
else: # All other platforms (Linux, MacOS).
# Add execute permissions to fnm executable.
os.chmod(constants.Fnm.EXE, stat.S_IXUSR)
constants.Fnm.EXE.chmod(stat.S_IXUSR)
# Install node.
# Specify arm64 arch explicitly for M1s and M2s.
architecture_arg = (
@ -925,7 +927,7 @@ def cached_procedure(cache_file: str, payload_fn: Callable[..., str]):
@cached_procedure(
cache_file=str(get_web_dir() / "reflex.install_frontend_packages.cached"),
payload_fn=lambda p, c: f"{sorted(list(p))!r},{c.json()}",
payload_fn=lambda p, c: f"{sorted(p)!r},{c.json()}",
)
def install_frontend_packages(packages: set[str], config: Config):
"""Installs the base and custom frontend packages.
@ -1300,7 +1302,7 @@ def fetch_app_templates(version: str) -> dict[str, Template]:
for tp in templates_data:
if tp["hidden"] or tp["code_url"] is None:
continue
known_fields = set(f.name for f in dataclasses.fields(Template))
known_fields = {f.name for f in dataclasses.fields(Template)}
filtered_templates[tp["name"]] = Template(
**{k: v for k, v in tp.items() if k in known_fields}
)
@ -1326,7 +1328,7 @@ def create_config_init_app_from_remote_template(app_name: str, template_url: str
raise typer.Exit(1) from ose
# Use httpx GET with redirects to download the zip file.
zip_file_path = Path(temp_dir) / "template.zip"
zip_file_path: Path = Path(temp_dir) / "template.zip"
try:
# Note: following redirects can be risky. We only allow this for reflex built templates at the moment.
response = net.get(template_url, follow_redirects=True)
@ -1336,9 +1338,8 @@ def create_config_init_app_from_remote_template(app_name: str, template_url: str
console.error(f"Failed to download the template: {he}")
raise typer.Exit(1) from he
try:
with open(zip_file_path, "wb") as f:
f.write(response.content)
console.debug(f"Downloaded the zip to {zip_file_path}")
zip_file_path.write_bytes(response.content)
console.debug(f"Downloaded the zip to {zip_file_path}")
except OSError as ose:
console.error(f"Unable to write the downloaded zip to disk {ose}")
raise typer.Exit(1) from ose

View File

@ -9,6 +9,7 @@ from .base import get_unique_variable_name as get_unique_variable_name
from .base import get_uuid_string_var as get_uuid_string_var
from .base import var_operation as var_operation
from .base import var_operation_return as var_operation_return
from .datetime import DateTimeVar as DateTimeVar
from .function import FunctionStringVar as FunctionStringVar
from .function import FunctionVar as FunctionVar
from .function import VarOperationCall as VarOperationCall

View File

@ -42,7 +42,8 @@ from typing_extensions import ParamSpec, TypeGuard, deprecated, get_type_hints,
from reflex import constants
from reflex.base import Base
from reflex.utils import console, imports, serializers, types
from reflex.constants.compiler import Hooks
from reflex.utils import console, exceptions, imports, serializers, types
from reflex.utils.exceptions import (
VarAttributeError,
VarDependencyError,
@ -115,12 +116,20 @@ class VarData:
# Hooks that need to be present in the component to render this var
hooks: Tuple[str, ...] = dataclasses.field(default_factory=tuple)
# Dependencies of the var
deps: Tuple[Var, ...] = dataclasses.field(default_factory=tuple)
# Position of the hook in the component
position: Hooks.HookPosition | None = None
def __init__(
self,
state: str = "",
field_name: str = "",
imports: ImportDict | ParsedImportDict | None = None,
hooks: dict[str, None] | None = None,
deps: list[Var] | None = None,
position: Hooks.HookPosition | None = None,
):
"""Initialize the var data.
@ -129,6 +138,8 @@ class VarData:
field_name: The name of the field in the state.
imports: Imports needed to render this var.
hooks: Hooks that need to be present in the component to render this var.
deps: Dependencies of the var for useCallback.
position: Position of the hook in the component.
"""
immutable_imports: ImmutableParsedImportDict = tuple(
sorted(
@ -139,6 +150,8 @@ class VarData:
object.__setattr__(self, "field_name", field_name)
object.__setattr__(self, "imports", immutable_imports)
object.__setattr__(self, "hooks", tuple(hooks or {}))
object.__setattr__(self, "deps", tuple(deps or []))
object.__setattr__(self, "position", position or None)
def old_school_imports(self) -> ImportDict:
"""Return the imports as a mutable dict.
@ -146,7 +159,7 @@ class VarData:
Returns:
The imports as a mutable dict.
"""
return dict((k, list(v)) for k, v in self.imports)
return {k: list(v) for k, v in self.imports}
def merge(*all: VarData | None) -> VarData | None:
"""Merge multiple var data objects.
@ -154,6 +167,9 @@ class VarData:
Args:
*all: The var data objects to merge.
Raises:
ReflexError: If trying to merge VarData with different positions.
Returns:
The merged var data object.
@ -184,12 +200,32 @@ class VarData:
*(var_data.imports for var_data in all_var_datas)
)
if state or _imports or hooks or field_name:
deps = [dep for var_data in all_var_datas for dep in var_data.deps]
positions = list(
{
var_data.position
for var_data in all_var_datas
if var_data.position is not None
}
)
if positions:
if len(positions) > 1:
raise exceptions.ReflexError(
f"Cannot merge var data with different positions: {positions}"
)
position = positions[0]
else:
position = None
if state or _imports or hooks or field_name or deps or position:
return VarData(
state=state,
field_name=field_name,
imports=_imports,
hooks=hooks,
deps=deps,
position=position,
)
return None
@ -200,7 +236,14 @@ class VarData:
Returns:
True if any field is set to a non-default value.
"""
return bool(self.state or self.imports or self.hooks or self.field_name)
return bool(
self.state
or self.imports
or self.hooks
or self.field_name
or self.deps
or self.position
)
@classmethod
def from_state(cls, state: Type[BaseState] | str, field_name: str = "") -> VarData:
@ -480,7 +523,6 @@ class Var(Generic[VAR_TYPE]):
raise TypeError(
"The _var_full_name_needs_state_prefix argument is not supported for Var."
)
value_with_replaced = dataclasses.replace(
self,
_var_type=_var_type or self._var_type,
@ -1591,14 +1633,12 @@ class CachedVarOperation:
The cached VarData.
"""
return VarData.merge(
*map(
lambda value: (
value._get_all_var_data() if isinstance(value, Var) else None
),
map(
lambda field: getattr(self, field.name),
dataclasses.fields(self), # type: ignore
),
*(
value._get_all_var_data() if isinstance(value, Var) else None
for value in (
getattr(self, field.name)
for field in dataclasses.fields(self) # type: ignore
)
),
self._var_data,
)
@ -1889,20 +1929,20 @@ class ComputedVar(Var[RETURN_TYPE]):
Raises:
TypeError: If kwargs contains keys that are not allowed.
"""
field_values = dict(
fget=kwargs.pop("fget", self._fget),
initial_value=kwargs.pop("initial_value", self._initial_value),
cache=kwargs.pop("cache", self._cache),
deps=kwargs.pop("deps", self._static_deps),
auto_deps=kwargs.pop("auto_deps", self._auto_deps),
interval=kwargs.pop("interval", self._update_interval),
backend=kwargs.pop("backend", self._backend),
_js_expr=kwargs.pop("_js_expr", self._js_expr),
_var_type=kwargs.pop("_var_type", self._var_type),
_var_data=kwargs.pop(
field_values = {
"fget": kwargs.pop("fget", self._fget),
"initial_value": kwargs.pop("initial_value", self._initial_value),
"cache": kwargs.pop("cache", self._cache),
"deps": kwargs.pop("deps", self._static_deps),
"auto_deps": kwargs.pop("auto_deps", self._auto_deps),
"interval": kwargs.pop("interval", self._update_interval),
"backend": kwargs.pop("backend", self._backend),
"_js_expr": kwargs.pop("_js_expr", self._js_expr),
"_var_type": kwargs.pop("_var_type", self._var_type),
"_var_data": kwargs.pop(
"_var_data", VarData.merge(self._var_data, merge_var_data)
),
)
}
if kwargs:
unexpected_kwargs = ", ".join(kwargs.keys())
@ -2371,10 +2411,7 @@ class CustomVarOperation(CachedVarOperation, Var[T]):
The cached VarData.
"""
return VarData.merge(
*map(
lambda arg: arg[1]._get_all_var_data(),
self._args,
),
*(arg[1]._get_all_var_data() for arg in self._args),
self._return._get_all_var_data(),
self._var_data,
)

222
reflex/vars/datetime.py Normal file
View File

@ -0,0 +1,222 @@
"""Immutable datetime and date vars."""
from __future__ import annotations
import dataclasses
import sys
from datetime import date, datetime
from typing import Any, NoReturn, TypeVar, Union, overload
from reflex.utils.exceptions import VarTypeError
from reflex.vars.number import BooleanVar
from .base import (
CustomVarOperationReturn,
LiteralVar,
Var,
VarData,
var_operation,
var_operation_return,
)
DATETIME_T = TypeVar("DATETIME_T", datetime, date)
datetime_types = Union[datetime, date]
def raise_var_type_error():
"""Raise a VarTypeError.
Raises:
VarTypeError: Cannot compare a datetime object with a non-datetime object.
"""
raise VarTypeError("Cannot compare a datetime object with a non-datetime object.")
class DateTimeVar(Var[DATETIME_T], python_types=(datetime, date)):
"""A variable that holds a datetime or date object."""
@overload
def __lt__(self, other: datetime_types) -> BooleanVar: ...
@overload
def __lt__(self, other: NoReturn) -> NoReturn: ...
def __lt__(self, other: Any):
"""Less than comparison.
Args:
other: The other datetime to compare.
Returns:
The result of the comparison.
"""
if not isinstance(other, DATETIME_TYPES):
raise_var_type_error()
return date_lt_operation(self, other)
@overload
def __le__(self, other: datetime_types) -> BooleanVar: ...
@overload
def __le__(self, other: NoReturn) -> NoReturn: ...
def __le__(self, other: Any):
"""Less than or equal comparison.
Args:
other: The other datetime to compare.
Returns:
The result of the comparison.
"""
if not isinstance(other, DATETIME_TYPES):
raise_var_type_error()
return date_le_operation(self, other)
@overload
def __gt__(self, other: datetime_types) -> BooleanVar: ...
@overload
def __gt__(self, other: NoReturn) -> NoReturn: ...
def __gt__(self, other: Any):
"""Greater than comparison.
Args:
other: The other datetime to compare.
Returns:
The result of the comparison.
"""
if not isinstance(other, DATETIME_TYPES):
raise_var_type_error()
return date_gt_operation(self, other)
@overload
def __ge__(self, other: datetime_types) -> BooleanVar: ...
@overload
def __ge__(self, other: NoReturn) -> NoReturn: ...
def __ge__(self, other: Any):
"""Greater than or equal comparison.
Args:
other: The other datetime to compare.
Returns:
The result of the comparison.
"""
if not isinstance(other, DATETIME_TYPES):
raise_var_type_error()
return date_ge_operation(self, other)
@var_operation
def date_gt_operation(lhs: Var | Any, rhs: Var | Any) -> CustomVarOperationReturn:
"""Greater than comparison.
Args:
lhs: The left-hand side of the operation.
rhs: The right-hand side of the operation.
Returns:
The result of the operation.
"""
return date_compare_operation(rhs, lhs, strict=True)
@var_operation
def date_lt_operation(lhs: Var | Any, rhs: Var | Any) -> CustomVarOperationReturn:
"""Less than comparison.
Args:
lhs: The left-hand side of the operation.
rhs: The right-hand side of the operation.
Returns:
The result of the operation.
"""
return date_compare_operation(lhs, rhs, strict=True)
@var_operation
def date_le_operation(lhs: Var | Any, rhs: Var | Any) -> CustomVarOperationReturn:
"""Less than or equal comparison.
Args:
lhs: The left-hand side of the operation.
rhs: The right-hand side of the operation.
Returns:
The result of the operation.
"""
return date_compare_operation(lhs, rhs)
@var_operation
def date_ge_operation(lhs: Var | Any, rhs: Var | Any) -> CustomVarOperationReturn:
"""Greater than or equal comparison.
Args:
lhs: The left-hand side of the operation.
rhs: The right-hand side of the operation.
Returns:
The result of the operation.
"""
return date_compare_operation(rhs, lhs)
def date_compare_operation(
lhs: DateTimeVar[DATETIME_T] | Any,
rhs: DateTimeVar[DATETIME_T] | Any,
strict: bool = False,
) -> CustomVarOperationReturn:
"""Check if the value is less than the other value.
Args:
lhs: The left-hand side of the operation.
rhs: The right-hand side of the operation.
strict: Whether to use strict comparison.
Returns:
The result of the operation.
"""
return var_operation_return(
f"({lhs} { '<' if strict else '<='} {rhs})",
bool,
)
@dataclasses.dataclass(
eq=False,
frozen=True,
**{"slots": True} if sys.version_info >= (3, 10) else {},
)
class LiteralDatetimeVar(LiteralVar, DateTimeVar):
"""Base class for immutable datetime and date vars."""
_var_value: datetime | date = dataclasses.field(default=datetime.now())
@classmethod
def create(cls, value: datetime | date, _var_data: VarData | None = None):
"""Create a new instance of the class.
Args:
value: The value to set.
Returns:
LiteralDatetimeVar: The new instance of the class.
"""
js_expr = f'"{value!s}"'
return cls(
_js_expr=js_expr,
_var_type=type(value),
_var_value=value,
_var_data=_var_data,
)
DATETIME_TYPES = (datetime, date, DateTimeVar)

View File

@ -292,7 +292,7 @@ class VarOperationCall(Generic[P, R], CachedVarOperation, Var[R]):
class DestructuredArg:
"""Class for destructured arguments."""
fields: Tuple[str, ...] = tuple()
fields: Tuple[str, ...] = ()
rest: Optional[str] = None
def to_javascript(self) -> str:
@ -314,7 +314,7 @@ class DestructuredArg:
class FunctionArgs:
"""Class for function arguments."""
args: Tuple[Union[str, DestructuredArg], ...] = tuple()
args: Tuple[Union[str, DestructuredArg], ...] = ()
rest: Optional[str] = None

View File

@ -51,7 +51,7 @@ def raise_unsupported_operand_types(
VarTypeError: The operand types are unsupported.
"""
raise VarTypeError(
f"Unsupported Operand type(s) for {operator}: {', '.join(map(lambda t: t.__name__, operands_types))}"
f"Unsupported Operand type(s) for {operator}: {', '.join(t.__name__ for t in operands_types)}"
)

View File

@ -1177,7 +1177,7 @@ class ArrayVar(Var[ARRAY_VAR_TYPE], python_types=(list, tuple, set)):
if num_args == 0:
return_value = fn()
function_var = ArgsFunctionOperation.create(tuple(), return_value)
function_var = ArgsFunctionOperation.create((), return_value)
else:
# generic number var
number_var = Var("").to(NumberVar, int)

View File

@ -15,6 +15,7 @@ from .utils import SessionStorage
def CallScript():
"""A test app for browser javascript integration."""
from pathlib import Path
from typing import Dict, List, Optional, Union
import reflex as rx
@ -186,8 +187,7 @@ def CallScript():
self.reset()
app = rx.App(state=rx.State)
with open("assets/external.js", "w") as f:
f.write(external_scripts)
Path("assets/external.js").write_text(external_scripts)
@app.add_page
def index():

View File

@ -0,0 +1,87 @@
from typing import Generator
import pytest
from playwright.sync_api import Page, expect
from reflex.testing import AppHarness
def DatetimeOperationsApp():
from datetime import datetime
import reflex as rx
class DtOperationsState(rx.State):
date1: datetime = datetime(2021, 1, 1)
date2: datetime = datetime(2031, 1, 1)
date3: datetime = datetime(2021, 1, 1)
app = rx.App(state=DtOperationsState)
@app.add_page
def index():
return rx.vstack(
rx.text(DtOperationsState.date1, id="date1"),
rx.text(DtOperationsState.date2, id="date2"),
rx.text(DtOperationsState.date3, id="date3"),
rx.text("Operations between date1 and date2"),
rx.text(DtOperationsState.date1 == DtOperationsState.date2, id="1_eq_2"),
rx.text(DtOperationsState.date1 != DtOperationsState.date2, id="1_neq_2"),
rx.text(DtOperationsState.date1 < DtOperationsState.date2, id="1_lt_2"),
rx.text(DtOperationsState.date1 <= DtOperationsState.date2, id="1_le_2"),
rx.text(DtOperationsState.date1 > DtOperationsState.date2, id="1_gt_2"),
rx.text(DtOperationsState.date1 >= DtOperationsState.date2, id="1_ge_2"),
rx.text("Operations between date1 and date3"),
rx.text(DtOperationsState.date1 == DtOperationsState.date3, id="1_eq_3"),
rx.text(DtOperationsState.date1 != DtOperationsState.date3, id="1_neq_3"),
rx.text(DtOperationsState.date1 < DtOperationsState.date3, id="1_lt_3"),
rx.text(DtOperationsState.date1 <= DtOperationsState.date3, id="1_le_3"),
rx.text(DtOperationsState.date1 > DtOperationsState.date3, id="1_gt_3"),
rx.text(DtOperationsState.date1 >= DtOperationsState.date3, id="1_ge_3"),
)
@pytest.fixture()
def datetime_operations_app(tmp_path_factory) -> Generator[AppHarness, None, None]:
"""Start Table app at tmp_path via AppHarness.
Args:
tmp_path_factory: pytest tmp_path_factory fixture
Yields:
running AppHarness instance
"""
with AppHarness.create(
root=tmp_path_factory.mktemp("datetime_operations_app"),
app_source=DatetimeOperationsApp, # type: ignore
) as harness:
assert harness.app_instance is not None, "app is not running"
yield harness
def test_datetime_operations(datetime_operations_app: AppHarness, page: Page):
assert datetime_operations_app.frontend_url is not None
page.goto(datetime_operations_app.frontend_url)
expect(page).to_have_url(datetime_operations_app.frontend_url + "/")
# Check the actual values
expect(page.locator("id=date1")).to_have_text("2021-01-01 00:00:00")
expect(page.locator("id=date2")).to_have_text("2031-01-01 00:00:00")
expect(page.locator("id=date3")).to_have_text("2021-01-01 00:00:00")
# Check the operations between date1 and date2
expect(page.locator("id=1_eq_2")).to_have_text("false")
expect(page.locator("id=1_neq_2")).to_have_text("true")
expect(page.locator("id=1_lt_2")).to_have_text("true")
expect(page.locator("id=1_le_2")).to_have_text("true")
expect(page.locator("id=1_gt_2")).to_have_text("false")
expect(page.locator("id=1_ge_2")).to_have_text("false")
# Check the operations between date1 and date3
expect(page.locator("id=1_eq_3")).to_have_text("true")
expect(page.locator("id=1_neq_3")).to_have_text("false")
expect(page.locator("id=1_lt_3")).to_have_text("false")
expect(page.locator("id=1_le_3")).to_have_text("true")
expect(page.locator("id=1_gt_3")).to_have_text("false")
expect(page.locator("id=1_ge_3")).to_have_text("true")

View File

@ -12,7 +12,7 @@ def test_websocket_target_url():
url = WebsocketTargetURL.create()
var_data = url._get_all_var_data()
assert var_data is not None
assert sorted(tuple((key for key, _ in var_data.imports))) == sorted(
assert sorted(key for key, _ in var_data.imports) == sorted(
("$/utils/state", "$/env.json")
)
@ -20,7 +20,7 @@ def test_websocket_target_url():
def test_connection_banner():
banner = ConnectionBanner.create()
_imports = banner._get_all_imports(collapse=True)
assert sorted(tuple(_imports)) == sorted(
assert sorted(_imports) == sorted(
(
"react",
"$/utils/context",
@ -38,7 +38,7 @@ def test_connection_banner():
def test_connection_modal():
modal = ConnectionModal.create()
_imports = modal._get_all_imports(collapse=True)
assert sorted(tuple(_imports)) == sorted(
assert sorted(_imports) == sorted(
(
"react",
"$/utils/context",

View File

@ -61,14 +61,13 @@ class FileUploadState(State):
"""
for file in files:
upload_data = await file.read()
outfile = f"{self._tmp_path}/{file.filename}"
assert file.filename is not None
outfile = self._tmp_path / file.filename
# Save the file.
with open(outfile, "wb") as file_object:
file_object.write(upload_data)
outfile.write_bytes(upload_data)
# Update the img var.
assert file.filename is not None
self.img_list.append(file.filename)
@rx.event(background=True)
@ -109,14 +108,13 @@ class ChildFileUploadState(FileStateBase1):
"""
for file in files:
upload_data = await file.read()
outfile = f"{self._tmp_path}/{file.filename}"
assert file.filename is not None
outfile = self._tmp_path / file.filename
# Save the file.
with open(outfile, "wb") as file_object:
file_object.write(upload_data)
outfile.write_bytes(upload_data)
# Update the img var.
assert file.filename is not None
self.img_list.append(file.filename)
@rx.event(background=True)
@ -157,14 +155,13 @@ class GrandChildFileUploadState(FileStateBase2):
"""
for file in files:
upload_data = await file.read()
outfile = f"{self._tmp_path}/{file.filename}"
assert file.filename is not None
outfile = self._tmp_path / file.filename
# Save the file.
with open(outfile, "wb") as file_object:
file_object.write(upload_data)
outfile.write_bytes(upload_data)
# Update the img var.
assert file.filename is not None
self.img_list.append(file.filename)
@rx.event(background=True)

View File

@ -105,8 +105,8 @@ def test_initialize_requirements_txt_no_op(mocker):
return_value=Mock(best=lambda: Mock(encoding="utf-8")),
)
mock_fp_touch = mocker.patch("pathlib.Path.touch")
open_mock = mock_open(read_data="reflex==0.2.9")
mocker.patch("builtins.open", open_mock)
open_mock = mock_open(read_data="reflex==0.6.7")
mocker.patch("pathlib.Path.open", open_mock)
initialize_requirements_txt()
assert open_mock.call_count == 1
assert open_mock.call_args.kwargs["encoding"] == "utf-8"
@ -122,7 +122,7 @@ def test_initialize_requirements_txt_missing_reflex(mocker):
return_value=Mock(best=lambda: Mock(encoding="utf-8")),
)
open_mock = mock_open(read_data="random-package=1.2.3")
mocker.patch("builtins.open", open_mock)
mocker.patch("pathlib.Path.open", open_mock)
initialize_requirements_txt()
# Currently open for read, then open for append
assert open_mock.call_count == 2
@ -138,7 +138,7 @@ def test_initialize_requirements_txt_not_exist(mocker):
# File does not exist, create file with reflex
mocker.patch("pathlib.Path.exists", return_value=False)
open_mock = mock_open()
mocker.patch("builtins.open", open_mock)
mocker.patch("pathlib.Path.open", open_mock)
initialize_requirements_txt()
assert open_mock.call_count == 2
# By default, use utf-8 encoding
@ -170,7 +170,7 @@ def test_requirements_txt_other_encoding(mocker):
)
initialize_requirements_txt()
open_mock = mock_open(read_data="random-package=1.2.3")
mocker.patch("builtins.open", open_mock)
mocker.patch("pathlib.Path.open", open_mock)
initialize_requirements_txt()
# Currently open for read, then open for append
assert open_mock.call_count == 2

View File

@ -372,7 +372,7 @@ def test_basic_operations(TestObj):
"var, expected",
[
(v([1, 2, 3]), "[1, 2, 3]"),
(v(set([1, 2, 3])), "[1, 2, 3]"),
(v({1, 2, 3}), "[1, 2, 3]"),
(v(["1", "2", "3"]), '["1", "2", "3"]'),
(
Var(_js_expr="foo")._var_set_state("state").to(list),
@ -903,7 +903,7 @@ def test_literal_var():
True,
False,
None,
set([1, 2, 3]),
{1, 2, 3},
]
)
assert (

View File

@ -222,9 +222,10 @@ def test_serialize(value: Any, expected: str):
'"2021-01-01 01:01:01.000001"',
True,
),
(datetime.date(2021, 1, 1), '"2021-01-01"', True),
(Color(color="slate", shade=1), '"var(--slate-1)"', True),
(BaseSubclass, '"BaseSubclass"', True),
(Path("."), '"."', True),
(Path(), '"."', True),
],
)
def test_serialize_var_to_str(value: Any, expected: str, exp_var_is_string: bool):

View File

@ -270,7 +270,7 @@ def test_unsupported_literals(cls: type):
("appname2.io", "AppnameioConfig"),
],
)
def test_create_config(app_name, expected_config_name, mocker):
def test_create_config(app_name: str, expected_config_name: str, mocker):
"""Test templates.RXCONFIG is formatted with correct app name and config class name.
Args:
@ -278,7 +278,7 @@ def test_create_config(app_name, expected_config_name, mocker):
expected_config_name: Expected config name.
mocker: Mocker object.
"""
mocker.patch("builtins.open")
mocker.patch("pathlib.Path.write_text")
tmpl_mock = mocker.patch("reflex.compiler.templates.RXCONFIG")
prerequisites.create_config(app_name)
tmpl_mock.render.assert_called_with(
@ -464,7 +464,7 @@ def test_node_install_unix(tmp_path, mocker, machine, system):
mocker.patch("httpx.stream", return_value=Resp())
download = mocker.patch("reflex.utils.prerequisites.download_and_extract_fnm_zip")
process = mocker.patch("reflex.utils.processes.new_process")
chmod = mocker.patch("reflex.utils.prerequisites.os.chmod")
chmod = mocker.patch("pathlib.Path.chmod")
mocker.patch("reflex.utils.processes.stream_logs")
prerequisites.install_node()