From 6a9f83cc2d3e56128d6245870ae0c4c9e71ec2b9 Mon Sep 17 00:00:00 2001 From: Khaleel Al-Adhami Date: Wed, 2 Oct 2024 18:04:04 -0700 Subject: [PATCH 01/16] set loglevel to info with hosting cli (#4043) * set loglevel to info with hosting cli * reduce reused logic --- reflex/constants/base.py | 8 ++++++++ reflex/reflex.py | 16 +++++----------- 2 files changed, 13 insertions(+), 11 deletions(-) diff --git a/reflex/constants/base.py b/reflex/constants/base.py index 225e8000b..0914c087f 100644 --- a/reflex/constants/base.py +++ b/reflex/constants/base.py @@ -191,6 +191,14 @@ class LogLevel(str, Enum): levels = list(LogLevel) return levels.index(self) <= levels.index(other) + def subprocess_level(self): + """Return the log level for the subprocess. + + Returns: + The log level for the subprocess + """ + return self if self != LogLevel.DEFAULT else LogLevel.INFO + # Server socket configuration variables POLLING_MAX_HTTP_BUFFER_SIZE = 1000 * 1000 diff --git a/reflex/reflex.py b/reflex/reflex.py index 43ebe2eb4..99dccf831 100644 --- a/reflex/reflex.py +++ b/reflex/reflex.py @@ -15,7 +15,6 @@ from reflex_cli.utils import dependency from reflex import constants from reflex.config import get_config -from reflex.constants.base import LogLevel from reflex.custom_components.custom_components import custom_components_cli from reflex.state import reset_disk_state_manager from reflex.utils import console, redir, telemetry @@ -247,11 +246,6 @@ def _run( setup_frontend(Path.cwd()) commands.append((frontend_cmd, Path.cwd(), frontend_port, backend)) - # If no loglevel is specified, set the subprocesses loglevel to WARNING. - subprocesses_loglevel = ( - loglevel if loglevel != LogLevel.DEFAULT else LogLevel.WARNING - ) - # In prod mode, run the backend on a separate thread. if backend and env == constants.Env.PROD: commands.append( @@ -259,7 +253,7 @@ def _run( backend_cmd, backend_host, backend_port, - subprocesses_loglevel, + loglevel.subprocess_level(), frontend, ) ) @@ -269,7 +263,7 @@ def _run( # In dev mode, run the backend on the main thread. if backend and env == constants.Env.DEV: backend_cmd( - backend_host, int(backend_port), subprocesses_loglevel, frontend + backend_host, int(backend_port), loglevel.subprocess_level(), frontend ) # The windows uvicorn bug workaround # https://github.com/reflex-dev/reflex/issues/2335 @@ -342,7 +336,7 @@ def export( backend=backend, zip_dest_dir=zip_dest_dir, upload_db_file=upload_db_file, - loglevel=loglevel, + loglevel=loglevel.subprocess_level(), ) @@ -577,7 +571,7 @@ def deploy( frontend=frontend, backend=backend, zipping=zipping, - loglevel=loglevel, + loglevel=loglevel.subprocess_level(), upload_db_file=upload_db_file, ), key=key, @@ -591,7 +585,7 @@ def deploy( interactive=interactive, with_metrics=with_metrics, with_tracing=with_tracing, - loglevel=loglevel.value, + loglevel=loglevel.subprocess_level(), ) From f3be9a33058ec7b4f6c30717a91be111b8c8cdb7 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Thomas=20Brand=C3=A9ho?= Date: Thu, 3 Oct 2024 06:04:44 -0700 Subject: [PATCH 02/16] fix granian message (#4037) --- reflex/utils/exec.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/reflex/utils/exec.py b/reflex/utils/exec.py index b6550fdde..acb69ee19 100644 --- a/reflex/utils/exec.py +++ b/reflex/utils/exec.py @@ -284,7 +284,7 @@ def run_granian_backend(host, port, loglevel: LogLevel): ).serve() except ImportError: console.error( - 'InstallError: REFLEX_USE_GRANIAN is set but `granian` is not installed. (run `pip install "granian>=1.6.0"`)' + 'InstallError: REFLEX_USE_GRANIAN is set but `granian` is not installed. (run `pip install "granian[reload]>=1.6.0"`)' ) os._exit(1) @@ -410,7 +410,7 @@ def run_granian_backend_prod(host, port, loglevel): ) except ImportError: console.error( - 'InstallError: REFLEX_USE_GRANIAN is set but `granian` is not installed. (run `pip install "granian>=1.6.0"`)' + 'InstallError: REFLEX_USE_GRANIAN is set but `granian` is not installed. (run `pip install "granian[reload]>=1.6.0"`)' ) From 3f5194316298eb936f2f3c498f197e41bec56eec Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Thomas=20Brand=C3=A9ho?= Date: Thu, 3 Oct 2024 08:50:39 -0700 Subject: [PATCH 03/16] use pathlib as much as possible (#3967) * use pathlib as much as possible * fixstuff * break locally to unbreak in CI :shrug: * add type on env * debug attempt 1 * debugged * oops, there is the actual fix * fix 3.9 compat --- benchmarks/benchmark_lighthouse.py | 35 +++-- benchmarks/benchmark_package_size.py | 17 ++- benchmarks/benchmark_web_size.py | 3 +- benchmarks/utils.py | 15 +- reflex/compiler/compiler.py | 2 +- reflex/compiler/utils.py | 9 +- reflex/config.py | 5 +- reflex/constants/base.py | 26 ++-- reflex/constants/config.py | 7 +- reflex/constants/custom_components.py | 7 +- reflex/constants/installer.py | 29 ++-- reflex/custom_components/custom_components.py | 71 +++++----- reflex/reflex.py | 3 - reflex/utils/build.py | 34 ++--- reflex/utils/path_ops.py | 6 +- reflex/utils/prerequisites.py | 129 ++++++------------ reflex/utils/processes.py | 4 +- tests/integration/test_urls.py | 44 +++--- tests/units/compiler/test_compiler.py | 6 +- tests/units/test_config.py | 2 +- tests/units/test_telemetry.py | 2 +- tests/units/utils/test_utils.py | 6 +- 22 files changed, 202 insertions(+), 260 deletions(-) diff --git a/benchmarks/benchmark_lighthouse.py b/benchmarks/benchmark_lighthouse.py index 72f486b6f..25d5eaac4 100644 --- a/benchmarks/benchmark_lighthouse.py +++ b/benchmarks/benchmark_lighthouse.py @@ -3,8 +3,8 @@ from __future__ import annotations import json -import os import sys +from pathlib import Path from utils import send_data_to_posthog @@ -28,7 +28,7 @@ def insert_benchmarking_data( send_data_to_posthog("lighthouse_benchmark", properties) -def get_lighthouse_scores(directory_path: str) -> dict: +def get_lighthouse_scores(directory_path: str | Path) -> dict: """Extracts the Lighthouse scores from the JSON files in the specified directory. Args: @@ -38,24 +38,21 @@ def get_lighthouse_scores(directory_path: str) -> dict: dict: The Lighthouse scores. """ scores = {} - + directory_path = Path(directory_path) try: - for filename in os.listdir(directory_path): - if filename.endswith(".json") and filename != "manifest.json": - file_path = os.path.join(directory_path, filename) - with open(file_path, "r") as file: - data = json.load(file) - # Extract scores and add them to the dictionary with the filename as key - scores[data["finalUrl"].replace("http://localhost:3000/", "/")] = { - "performance_score": data["categories"]["performance"]["score"], - "accessibility_score": data["categories"]["accessibility"][ - "score" - ], - "best_practices_score": data["categories"]["best-practices"][ - "score" - ], - "seo_score": data["categories"]["seo"]["score"], - } + for filename in directory_path.iterdir(): + if filename.suffix == ".json" and filename.stem != "manifest": + file_path = directory_path / filename + data = json.loads(file_path.read_text()) + # Extract scores and add them to the dictionary with the filename as key + scores[data["finalUrl"].replace("http://localhost:3000/", "/")] = { + "performance_score": data["categories"]["performance"]["score"], + "accessibility_score": data["categories"]["accessibility"]["score"], + "best_practices_score": data["categories"]["best-practices"][ + "score" + ], + "seo_score": data["categories"]["seo"]["score"], + } except Exception as e: return {"error": e} diff --git a/benchmarks/benchmark_package_size.py b/benchmarks/benchmark_package_size.py index 8e2704355..778b52769 100644 --- a/benchmarks/benchmark_package_size.py +++ b/benchmarks/benchmark_package_size.py @@ -2,11 +2,12 @@ import argparse import os +from pathlib import Path from utils import get_directory_size, get_python_version, send_data_to_posthog -def get_package_size(venv_path, os_name): +def get_package_size(venv_path: Path, os_name): """Get the size of a specified package. Args: @@ -26,14 +27,12 @@ def get_package_size(venv_path, os_name): is_windows = "windows" in os_name - full_path = ( - ["lib", f"python{python_version}", "site-packages"] + package_dir: Path = ( + venv_path / "lib" / f"python{python_version}" / "site-packages" if not is_windows - else ["Lib", "site-packages"] + else venv_path / "Lib" / "site-packages" ) - - package_dir = os.path.join(venv_path, *full_path) - if not os.path.exists(package_dir): + if not package_dir.exists(): raise ValueError( "Error: Virtual environment does not exist or is not activated." ) @@ -63,9 +62,9 @@ def insert_benchmarking_data( path: The path to the dir or file to check size. """ if "./dist" in path: - size = get_directory_size(path) + size = get_directory_size(Path(path)) else: - size = get_package_size(path, os_type_version) + size = get_package_size(Path(path), os_type_version) # Prepare the event data properties = { diff --git a/benchmarks/benchmark_web_size.py b/benchmarks/benchmark_web_size.py index 6c2f40bbc..3ceccecf8 100644 --- a/benchmarks/benchmark_web_size.py +++ b/benchmarks/benchmark_web_size.py @@ -2,6 +2,7 @@ import argparse import os +from pathlib import Path from utils import get_directory_size, send_data_to_posthog @@ -28,7 +29,7 @@ def insert_benchmarking_data( pr_id: The id of the PR. path: The path to the dir or file to check size. """ - size = get_directory_size(path) + size = get_directory_size(Path(path)) # Prepare the event data properties = { diff --git a/benchmarks/utils.py b/benchmarks/utils.py index 7b02c8cc8..bfadf5b4e 100644 --- a/benchmarks/utils.py +++ b/benchmarks/utils.py @@ -2,12 +2,13 @@ import os import subprocess +from pathlib import Path import httpx from httpx import HTTPError -def get_python_version(venv_path, os_name): +def get_python_version(venv_path: Path, os_name): """Get the python version of python in a virtual env. Args: @@ -18,13 +19,13 @@ def get_python_version(venv_path, os_name): The python version. """ python_executable = ( - os.path.join(venv_path, "bin", "python") + venv_path / "bin" / "python" if "windows" not in os_name - else os.path.join(venv_path, "Scripts", "python.exe") + else venv_path / "Scripts" / "python.exe" ) try: output = subprocess.check_output( - [python_executable, "--version"], stderr=subprocess.STDOUT + [str(python_executable), "--version"], stderr=subprocess.STDOUT ) python_version = output.decode("utf-8").strip().split()[1] return ".".join(python_version.split(".")[:-1]) @@ -32,7 +33,7 @@ def get_python_version(venv_path, os_name): return None -def get_directory_size(directory): +def get_directory_size(directory: Path): """Get the size of a directory in bytes. Args: @@ -44,8 +45,8 @@ def get_directory_size(directory): total_size = 0 for dirpath, _, filenames in os.walk(directory): for f in filenames: - fp = os.path.join(dirpath, f) - total_size += os.path.getsize(fp) + fp = Path(dirpath) / f + total_size += fp.stat().st_size return total_size diff --git a/reflex/compiler/compiler.py b/reflex/compiler/compiler.py index edf03039e..343bda3e1 100644 --- a/reflex/compiler/compiler.py +++ b/reflex/compiler/compiler.py @@ -171,7 +171,7 @@ def _compile_root_stylesheet(stylesheets: list[str]) -> str: stylesheet_full_path = ( Path.cwd() / constants.Dirs.APP_ASSETS / stylesheet.strip("/") ) - if not os.path.exists(stylesheet_full_path): + if not stylesheet_full_path.exists(): raise FileNotFoundError( f"The stylesheet file {stylesheet_full_path} does not exist." ) diff --git a/reflex/compiler/utils.py b/reflex/compiler/utils.py index b10552554..6f4fa2d1b 100644 --- a/reflex/compiler/utils.py +++ b/reflex/compiler/utils.py @@ -2,7 +2,6 @@ from __future__ import annotations -import os from pathlib import Path from typing import Any, Callable, Dict, Optional, Type, Union from urllib.parse import urlparse @@ -457,16 +456,16 @@ def add_meta( return page -def write_page(path: str, code: str): +def write_page(path: str | Path, code: str): """Write the given code to the given path. Args: path: The path to write the code to. code: The code to write. """ - path_ops.mkdir(os.path.dirname(path)) - with open(path, "w", encoding="utf-8") as f: - f.write(code) + path = Path(path) + path_ops.mkdir(path.parent) + path.write_text(code, encoding="utf-8") def empty_dir(path: str | Path, keep_files: list[str] | None = None): diff --git a/reflex/config.py b/reflex/config.py index c237d7421..a5d66cb52 100644 --- a/reflex/config.py +++ b/reflex/config.py @@ -6,7 +6,8 @@ import importlib import os import sys import urllib.parse -from typing import Any, Dict, List, Optional, Set +from pathlib import Path +from typing import Any, Dict, List, Optional, Set, Union try: import pydantic.v1 as pydantic @@ -188,7 +189,7 @@ class Config(Base): telemetry_enabled: bool = True # The bun path - bun_path: str = constants.Bun.DEFAULT_PATH + bun_path: Union[str, Path] = constants.Bun.DEFAULT_PATH # List of origins that are allowed to connect to the backend API. cors_allowed_origins: List[str] = ["*"] diff --git a/reflex/constants/base.py b/reflex/constants/base.py index 0914c087f..05675643f 100644 --- a/reflex/constants/base.py +++ b/reflex/constants/base.py @@ -6,6 +6,7 @@ import os import platform from enum import Enum from importlib import metadata +from pathlib import Path from types import SimpleNamespace from platformdirs import PlatformDirs @@ -66,18 +67,19 @@ class Reflex(SimpleNamespace): # Get directory value from enviroment variables if it exists. _dir = os.environ.get("REFLEX_DIR", "") - DIR = _dir or ( - # on windows, we use C:/Users//AppData/Local/reflex. - # on macOS, we use ~/Library/Application Support/reflex. - # on linux, we use ~/.local/share/reflex. - # If user sets REFLEX_DIR envroment variable use that instead. - PlatformDirs(MODULE_NAME, False).user_data_dir + DIR = Path( + _dir + or ( + # on windows, we use C:/Users//AppData/Local/reflex. + # on macOS, we use ~/Library/Application Support/reflex. + # on linux, we use ~/.local/share/reflex. + # If user sets REFLEX_DIR envroment variable use that instead. + PlatformDirs(MODULE_NAME, False).user_data_dir + ) ) # The root directory of the reflex library. - ROOT_DIR = os.path.dirname( - os.path.dirname(os.path.dirname(os.path.abspath(__file__))) - ) + ROOT_DIR = Path(__file__).parents[2] RELEASES_URL = f"https://api.github.com/repos/reflex-dev/templates/releases" @@ -125,11 +127,11 @@ class Templates(SimpleNamespace): """Folders used by the template system of Reflex.""" # The template directory used during reflex init. - BASE = os.path.join(Reflex.ROOT_DIR, Reflex.MODULE_NAME, ".templates") + BASE = Reflex.ROOT_DIR / Reflex.MODULE_NAME / ".templates" # The web subdirectory of the template directory. - WEB_TEMPLATE = os.path.join(BASE, "web") + WEB_TEMPLATE = BASE / "web" # The jinja template directory. - JINJA_TEMPLATE = os.path.join(BASE, "jinja") + JINJA_TEMPLATE = BASE / "jinja" # Where the code for the templates is stored. CODE = "code" diff --git a/reflex/constants/config.py b/reflex/constants/config.py index 966727426..3ff7aade5 100644 --- a/reflex/constants/config.py +++ b/reflex/constants/config.py @@ -1,6 +1,7 @@ """Config constants.""" import os +from pathlib import Path from types import SimpleNamespace from reflex.constants.base import Dirs, Reflex @@ -17,9 +18,7 @@ class Config(SimpleNamespace): # The name of the reflex config module. MODULE = "rxconfig" # The python config file. - FILE = f"{MODULE}{Ext.PY}" - # The previous config file. - PREVIOUS_FILE = f"pcconfig{Ext.PY}" + FILE = Path(f"{MODULE}{Ext.PY}") class Expiration(SimpleNamespace): @@ -37,7 +36,7 @@ class GitIgnore(SimpleNamespace): """Gitignore constants.""" # The gitignore file. - FILE = ".gitignore" + FILE = Path(".gitignore") # Files to gitignore. DEFAULTS = {Dirs.WEB, "*.db", "__pycache__/", "*.py[cod]", "assets/external/"} diff --git a/reflex/constants/custom_components.py b/reflex/constants/custom_components.py index 3ea9cf6ed..d879a01f2 100644 --- a/reflex/constants/custom_components.py +++ b/reflex/constants/custom_components.py @@ -2,6 +2,7 @@ from __future__ import annotations +from pathlib import Path from types import SimpleNamespace @@ -11,9 +12,9 @@ class CustomComponents(SimpleNamespace): # The name of the custom components source directory. SRC_DIR = "custom_components" # The name of the custom components pyproject.toml file. - PYPROJECT_TOML = "pyproject.toml" + PYPROJECT_TOML = Path("pyproject.toml") # The name of the custom components package README file. - PACKAGE_README = "README.md" + PACKAGE_README = Path("README.md") # The name of the custom components package .gitignore file. PACKAGE_GITIGNORE = ".gitignore" # The name of the distribution directory as result of a build. @@ -29,6 +30,6 @@ class CustomComponents(SimpleNamespace): "testpypi": "https://test.pypi.org/legacy/", } # The .gitignore file for the custom component project. - FILE = ".gitignore" + FILE = Path(".gitignore") # Files to gitignore. DEFAULTS = {"__pycache__/", "*.py[cod]", "*.egg-info/", "dist/"} diff --git a/reflex/constants/installer.py b/reflex/constants/installer.py index 01a11a37e..a815b1284 100644 --- a/reflex/constants/installer.py +++ b/reflex/constants/installer.py @@ -2,7 +2,6 @@ from __future__ import annotations -import os import platform from types import SimpleNamespace @@ -40,11 +39,10 @@ class Bun(SimpleNamespace): # Min Bun Version MIN_VERSION = "0.7.0" # The directory to store the bun. - ROOT_PATH = os.path.join(Reflex.DIR, "bun") + ROOT_PATH = Reflex.DIR / "bun" # Default bun path. - DEFAULT_PATH = os.path.join( - ROOT_PATH, "bin", "bun" if not IS_WINDOWS else "bun.exe" - ) + DEFAULT_PATH = ROOT_PATH / "bin" / ("bun" if not IS_WINDOWS else "bun.exe") + # URL to bun install script. INSTALL_URL = "https://bun.sh/install" # URL to windows install script. @@ -65,10 +63,10 @@ class Fnm(SimpleNamespace): # The FNM version. VERSION = "1.35.1" # The directory to store fnm. - DIR = os.path.join(Reflex.DIR, "fnm") + DIR = Reflex.DIR / "fnm" FILENAME = get_fnm_name() # The fnm executable binary. - EXE = os.path.join(DIR, "fnm.exe" if IS_WINDOWS else "fnm") + EXE = DIR / ("fnm.exe" if IS_WINDOWS else "fnm") # The URL to the fnm release binary INSTALL_URL = ( @@ -86,18 +84,19 @@ class Node(SimpleNamespace): MIN_VERSION = "18.17.0" # The node bin path. - BIN_PATH = os.path.join( - Fnm.DIR, - "node-versions", - f"v{VERSION}", - "installation", - "bin" if not IS_WINDOWS else "", + BIN_PATH = ( + Fnm.DIR + / "node-versions" + / f"v{VERSION}" + / "installation" + / ("bin" if not IS_WINDOWS else "") ) + # The default path where node is installed. - PATH = os.path.join(BIN_PATH, "node.exe" if IS_WINDOWS else "node") + PATH = BIN_PATH / ("node.exe" if IS_WINDOWS else "node") # The default path where npm is installed. - NPM_PATH = os.path.join(BIN_PATH, "npm") + NPM_PATH = BIN_PATH / "npm" # The environment variable to use the system installed node. USE_SYSTEM_VAR = "REFLEX_USE_SYSTEM_NODE" diff --git a/reflex/custom_components/custom_components.py b/reflex/custom_components/custom_components.py index e6957f8fd..146bb12c2 100644 --- a/reflex/custom_components/custom_components.py +++ b/reflex/custom_components/custom_components.py @@ -36,7 +36,7 @@ POST_CUSTOM_COMPONENTS_GALLERY_TIMEOUT = 15 @contextmanager -def set_directory(working_directory: str): +def set_directory(working_directory: str | Path): """Context manager that sets the working directory. Args: @@ -45,7 +45,8 @@ def set_directory(working_directory: str): Yields: Yield to the caller to perform operations in the working directory. """ - current_directory = os.getcwd() + current_directory = Path.cwd() + working_directory = Path(working_directory) try: os.chdir(working_directory) yield @@ -62,14 +63,14 @@ def _create_package_config(module_name: str, package_name: str): """ from reflex.compiler import templates - with open(CustomComponents.PYPROJECT_TOML, "w") as f: - f.write( - templates.CUSTOM_COMPONENTS_PYPROJECT_TOML.render( - module_name=module_name, - package_name=package_name, - reflex_version=constants.Reflex.VERSION, - ) + pyproject = Path(CustomComponents.PYPROJECT_TOML) + pyproject.write_text( + templates.CUSTOM_COMPONENTS_PYPROJECT_TOML.render( + module_name=module_name, + package_name=package_name, + reflex_version=constants.Reflex.VERSION, ) + ) def _get_package_config(exit_on_fail: bool = True) -> dict: @@ -84,11 +85,11 @@ def _get_package_config(exit_on_fail: bool = True) -> dict: Raises: Exit: If the pyproject.toml file is not found. """ + pyproject = Path(CustomComponents.PYPROJECT_TOML) try: - with open(CustomComponents.PYPROJECT_TOML, "rb") as f: - return dict(tomlkit.load(f)) + return dict(tomlkit.loads(pyproject.read_bytes())) except (OSError, TOMLKitError) as ex: - console.error(f"Unable to read from pyproject.toml due to {ex}") + console.error(f"Unable to read from {pyproject} due to {ex}") if exit_on_fail: raise typer.Exit(code=1) from ex raise @@ -103,17 +104,17 @@ def _create_readme(module_name: str, package_name: str): """ from reflex.compiler import templates - with open(CustomComponents.PACKAGE_README, "w") as f: - f.write( - templates.CUSTOM_COMPONENTS_README.render( - module_name=module_name, - package_name=package_name, - ) + readme = Path(CustomComponents.PACKAGE_README) + readme.write_text( + templates.CUSTOM_COMPONENTS_README.render( + module_name=module_name, + package_name=package_name, ) + ) def _write_source_and_init_py( - custom_component_src_dir: str, + custom_component_src_dir: Path, component_class_name: str, module_name: str, ): @@ -126,27 +127,17 @@ def _write_source_and_init_py( """ from reflex.compiler import templates - with open( - os.path.join( - custom_component_src_dir, - f"{module_name}.py", - ), - "w", - ) as f: - f.write( - templates.CUSTOM_COMPONENTS_SOURCE.render( - component_class_name=component_class_name, module_name=module_name - ) + module_path = custom_component_src_dir / f"{module_name}.py" + module_path.write_text( + templates.CUSTOM_COMPONENTS_SOURCE.render( + component_class_name=component_class_name, module_name=module_name ) + ) - with open( - os.path.join( - custom_component_src_dir, - CustomComponents.INIT_FILE, - ), - "w", - ) as f: - f.write(templates.CUSTOM_COMPONENTS_INIT_FILE.render(module_name=module_name)) + init_path = custom_component_src_dir / CustomComponents.INIT_FILE + init_path.write_text( + templates.CUSTOM_COMPONENTS_INIT_FILE.render(module_name=module_name) + ) def _populate_demo_app(name_variants: NameVariants): @@ -192,7 +183,7 @@ def _get_default_library_name_parts() -> list[str]: Returns: The parts of default library name. """ - current_dir_name = os.getcwd().split(os.path.sep)[-1] + current_dir_name = Path.cwd().name cleaned_dir_name = re.sub("[^0-9a-zA-Z-_]+", "", current_dir_name).lower() parts = [part for part in re.split("-|_", cleaned_dir_name) if part] @@ -345,7 +336,7 @@ def init( console.set_log_level(loglevel) - if os.path.exists(CustomComponents.PYPROJECT_TOML): + if CustomComponents.PYPROJECT_TOML.exists(): console.error(f"A {CustomComponents.PYPROJECT_TOML} already exists. Aborting.") typer.Exit(code=1) diff --git a/reflex/reflex.py b/reflex/reflex.py index 99dccf831..4608ed171 100644 --- a/reflex/reflex.py +++ b/reflex/reflex.py @@ -114,9 +114,6 @@ def _init( app_name, generation_hash=generation_hash ) - # Migrate Pynecone projects to Reflex. - prerequisites.migrate_to_reflex() - # Initialize the .gitignore. prerequisites.initialize_gitignore() diff --git a/reflex/utils/build.py b/reflex/utils/build.py index 7a67ec32e..f16224d05 100644 --- a/reflex/utils/build.py +++ b/reflex/utils/build.py @@ -61,8 +61,8 @@ def generate_sitemap_config(deploy_url: str, export=False): def _zip( component_name: constants.ComponentName, - target: str, - root_dir: str, + target: str | Path, + root_dir: str | Path, exclude_venv_dirs: bool, upload_db_file: bool = False, dirs_to_exclude: set[str] | None = None, @@ -82,22 +82,22 @@ def _zip( top_level_dirs_to_exclude: The top level directory names immediately under root_dir to exclude. Do not exclude folders by these names further in the sub-directories. """ + target = Path(target) + root_dir = Path(root_dir) dirs_to_exclude = dirs_to_exclude or set() files_to_exclude = files_to_exclude or set() files_to_zip: list[str] = [] # Traverse the root directory in a top-down manner. In this traversal order, # we can modify the dirs list in-place to remove directories we don't want to include. for root, dirs, files in os.walk(root_dir, topdown=True): + root = Path(root) # Modify the dirs in-place so excluded and hidden directories are skipped in next traversal. dirs[:] = [ d for d in dirs - if (basename := os.path.basename(os.path.normpath(d))) - not in dirs_to_exclude + if (basename := Path(d).resolve().name) not in dirs_to_exclude and not basename.startswith(".") - and ( - not exclude_venv_dirs or not _looks_like_venv_dir(os.path.join(root, d)) - ) + and (not exclude_venv_dirs or not _looks_like_venv_dir(root / d)) ] # If we are at the top level with root_dir, exclude the top level dirs. if top_level_dirs_to_exclude and root == root_dir: @@ -109,7 +109,7 @@ def _zip( if not f.startswith(".") and (upload_db_file or not f.endswith(".db")) ] files_to_zip += [ - os.path.join(root, file) for file in files if file not in files_to_exclude + str(root / file) for file in files if file not in files_to_exclude ] # Create a progress bar for zipping the component. @@ -126,13 +126,13 @@ def _zip( for file in files_to_zip: console.debug(f"{target}: {file}", progress=progress) progress.advance(task) - zipf.write(file, os.path.relpath(file, root_dir)) + zipf.write(file, Path(file).relative_to(root_dir)) def zip_app( frontend: bool = True, backend: bool = True, - zip_dest_dir: str = os.getcwd(), + zip_dest_dir: str | Path = Path.cwd(), upload_db_file: bool = False, ): """Zip up the app. @@ -143,6 +143,7 @@ def zip_app( zip_dest_dir: The directory to export the zip file to. upload_db_file: Whether to upload the database file. """ + zip_dest_dir = Path(zip_dest_dir) files_to_exclude = { constants.ComponentName.FRONTEND.zip(), constants.ComponentName.BACKEND.zip(), @@ -151,8 +152,8 @@ def zip_app( if frontend: _zip( component_name=constants.ComponentName.FRONTEND, - target=os.path.join(zip_dest_dir, constants.ComponentName.FRONTEND.zip()), - root_dir=str(prerequisites.get_web_dir() / constants.Dirs.STATIC), + target=zip_dest_dir / constants.ComponentName.FRONTEND.zip(), + root_dir=prerequisites.get_web_dir() / constants.Dirs.STATIC, files_to_exclude=files_to_exclude, exclude_venv_dirs=False, ) @@ -160,8 +161,8 @@ def zip_app( if backend: _zip( component_name=constants.ComponentName.BACKEND, - target=os.path.join(zip_dest_dir, constants.ComponentName.BACKEND.zip()), - root_dir=".", + target=zip_dest_dir / constants.ComponentName.BACKEND.zip(), + root_dir=Path("."), dirs_to_exclude={"__pycache__"}, files_to_exclude=files_to_exclude, top_level_dirs_to_exclude={"assets"}, @@ -266,5 +267,6 @@ def setup_frontend_prod( build(deploy_url=get_config().deploy_url) -def _looks_like_venv_dir(dir_to_check: str) -> bool: - return os.path.exists(os.path.join(dir_to_check, "pyvenv.cfg")) +def _looks_like_venv_dir(dir_to_check: str | Path) -> bool: + dir_to_check = Path(dir_to_check) + return (dir_to_check / "pyvenv.cfg").exists() diff --git a/reflex/utils/path_ops.py b/reflex/utils/path_ops.py index 00affd820..1de635b88 100644 --- a/reflex/utils/path_ops.py +++ b/reflex/utils/path_ops.py @@ -164,7 +164,7 @@ def use_system_bun() -> bool: return use_system_install(constants.Bun.USE_SYSTEM_VAR) -def get_node_bin_path() -> str | None: +def get_node_bin_path() -> Path | None: """Get the node binary dir path. Returns: @@ -173,8 +173,8 @@ def get_node_bin_path() -> str | None: bin_path = Path(constants.Node.BIN_PATH) if not bin_path.exists(): str_path = which("node") - return str(Path(str_path).parent.resolve()) if str_path else str_path - return str(bin_path.resolve()) + return Path(str_path).parent.resolve() if str_path else None + return bin_path.resolve() def get_node_path() -> str | None: diff --git a/reflex/utils/prerequisites.py b/reflex/utils/prerequisites.py index f9eb9a790..0e49e3b55 100644 --- a/reflex/utils/prerequisites.py +++ b/reflex/utils/prerequisites.py @@ -2,9 +2,9 @@ from __future__ import annotations +import contextlib import dataclasses import functools -import glob import importlib import importlib.metadata import json @@ -19,7 +19,6 @@ import tempfile import time import zipfile from datetime import datetime -from fileinput import FileInput from pathlib import Path from types import ModuleType from typing import Callable, List, Optional @@ -192,7 +191,7 @@ def get_bun_version() -> version.Version | None: """ try: # Run the bun -v command and capture the output - result = processes.new_process([get_config().bun_path, "-v"], run=True) + result = processes.new_process([str(get_config().bun_path), "-v"], run=True) return version.parse(result.stdout) # type: ignore except FileNotFoundError: return None @@ -217,7 +216,7 @@ def get_install_package_manager() -> str | None: or windows_npm_escape_hatch() ): return get_package_manager() - return get_config().bun_path + return str(get_config().bun_path) def get_package_manager() -> str | None: @@ -394,9 +393,7 @@ def validate_app_name(app_name: str | None = None) -> str: Raises: Exit: if the app directory name is reflex or if the name is not standard for a python package name. """ - app_name = ( - app_name if app_name else os.getcwd().split(os.path.sep)[-1].replace("-", "_") - ) + app_name = app_name if app_name else Path.cwd().name.replace("-", "_") # Make sure the app is not named "reflex". if app_name.lower() == constants.Reflex.MODULE_NAME: console.error( @@ -430,7 +427,7 @@ def create_config(app_name: str): def initialize_gitignore( - gitignore_file: str = constants.GitIgnore.FILE, + gitignore_file: Path = constants.GitIgnore.FILE, files_to_ignore: set[str] = constants.GitIgnore.DEFAULTS, ): """Initialize the template .gitignore file. @@ -441,9 +438,10 @@ def initialize_gitignore( """ # Combine with the current ignored files. current_ignore: set[str] = set() - if os.path.exists(gitignore_file): - with open(gitignore_file, "r") as f: - current_ignore |= set([line.strip() for line in f.readlines()]) + if gitignore_file.exists(): + current_ignore |= set( + line.strip() for line in gitignore_file.read_text().splitlines() + ) if files_to_ignore == current_ignore: console.debug(f"{gitignore_file} already up to date.") @@ -451,9 +449,11 @@ def initialize_gitignore( files_to_ignore |= current_ignore # Write files to the .gitignore file. - with open(gitignore_file, "w", newline="\n") as f: - console.debug(f"Creating {gitignore_file}") - f.write(f"{(path_ops.join(sorted(files_to_ignore))).lstrip()}\n") + gitignore_file.touch(exist_ok=True) + console.debug(f"Creating {gitignore_file}") + gitignore_file.write_text( + "\n".join(sorted(files_to_ignore)) + "\n", + ) def initialize_requirements_txt(): @@ -546,8 +546,8 @@ def initialize_app_directory( # Rename the template app to the app name. path_ops.mv(template_code_dir_name, app_name) path_ops.mv( - os.path.join(app_name, template_name + constants.Ext.PY), - os.path.join(app_name, app_name + constants.Ext.PY), + Path(app_name) / (template_name + constants.Ext.PY), + Path(app_name) / (app_name + constants.Ext.PY), ) # Fix up the imports. @@ -691,7 +691,7 @@ def _update_next_config( def remove_existing_bun_installation(): """Remove existing bun installation.""" console.debug("Removing existing bun installation.") - if os.path.exists(get_config().bun_path): + if Path(get_config().bun_path).exists(): path_ops.rm(constants.Bun.ROOT_PATH) @@ -731,7 +731,7 @@ def download_and_extract_fnm_zip(): # Download the zip file url = constants.Fnm.INSTALL_URL console.debug(f"Downloading {url}") - fnm_zip_file = os.path.join(constants.Fnm.DIR, f"{constants.Fnm.FILENAME}.zip") + fnm_zip_file = constants.Fnm.DIR / f"{constants.Fnm.FILENAME}.zip" # Function to download and extract the FNM zip release. try: # Download the FNM zip release. @@ -770,7 +770,7 @@ def install_node(): return path_ops.mkdir(constants.Fnm.DIR) - if not os.path.exists(constants.Fnm.EXE): + if not constants.Fnm.EXE.exists(): download_and_extract_fnm_zip() if constants.IS_WINDOWS: @@ -827,7 +827,7 @@ def install_bun(): ) # Skip if bun is already installed. - if os.path.exists(get_config().bun_path) and get_bun_version() == version.parse( + if Path(get_config().bun_path).exists() and get_bun_version() == version.parse( constants.Bun.VERSION ): console.debug("Skipping bun installation as it is already installed.") @@ -842,7 +842,7 @@ def install_bun(): f"irm {constants.Bun.WINDOWS_INSTALL_URL}|iex", ], env={ - "BUN_INSTALL": constants.Bun.ROOT_PATH, + "BUN_INSTALL": str(constants.Bun.ROOT_PATH), "BUN_VERSION": constants.Bun.VERSION, }, shell=True, @@ -858,25 +858,26 @@ def install_bun(): download_and_run( constants.Bun.INSTALL_URL, f"bun-v{constants.Bun.VERSION}", - BUN_INSTALL=constants.Bun.ROOT_PATH, + BUN_INSTALL=str(constants.Bun.ROOT_PATH), ) -def _write_cached_procedure_file(payload: str, cache_file: str): - with open(cache_file, "w") as f: - f.write(payload) +def _write_cached_procedure_file(payload: str, cache_file: str | Path): + cache_file = Path(cache_file) + cache_file.write_text(payload) -def _read_cached_procedure_file(cache_file: str) -> str | None: - if os.path.exists(cache_file): - with open(cache_file, "r") as f: - return f.read() +def _read_cached_procedure_file(cache_file: str | Path) -> str | None: + cache_file = Path(cache_file) + if cache_file.exists(): + return cache_file.read_text() return None -def _clear_cached_procedure_file(cache_file: str): - if os.path.exists(cache_file): - os.remove(cache_file) +def _clear_cached_procedure_file(cache_file: str | Path): + cache_file = Path(cache_file) + if cache_file.exists(): + cache_file.unlink() def cached_procedure(cache_file: str, payload_fn: Callable[..., str]): @@ -977,7 +978,7 @@ def needs_reinit(frontend: bool = True) -> bool: Raises: Exit: If the app is not initialized. """ - if not os.path.exists(constants.Config.FILE): + if not constants.Config.FILE.exists(): console.error( f"[cyan]{constants.Config.FILE}[/cyan] not found. Move to the root folder of your project, or run [bold]{constants.Reflex.MODULE_NAME} init[/bold] to start a new project." ) @@ -988,7 +989,7 @@ def needs_reinit(frontend: bool = True) -> bool: return False # Make sure the .reflex directory exists. - if not os.path.exists(constants.Reflex.DIR): + if not constants.Reflex.DIR.exists(): return True # Make sure the .web directory exists in frontend mode. @@ -1093,25 +1094,21 @@ def ensure_reflex_installation_id() -> Optional[int]: """ try: initialize_reflex_user_directory() - installation_id_file = os.path.join(constants.Reflex.DIR, "installation_id") + installation_id_file = constants.Reflex.DIR / "installation_id" installation_id = None - if os.path.exists(installation_id_file): - try: - with open(installation_id_file, "r") as f: - installation_id = int(f.read()) - except Exception: + if installation_id_file.exists(): + with contextlib.suppress(Exception): + installation_id = int(installation_id_file.read_text()) # If anything goes wrong at all... just regenerate. # Like what? Examples: # - file not exists # - file not readable # - content not parseable as an int - pass if installation_id is None: installation_id = random.getrandbits(128) - with open(installation_id_file, "w") as f: - f.write(str(installation_id)) + installation_id_file.write_text(str(installation_id)) # If we get here, installation_id is definitely set return installation_id except Exception as e: @@ -1205,50 +1202,6 @@ def prompt_for_template(templates: list[Template]) -> str: return templates[int(template)].name -def migrate_to_reflex(): - """Migration from Pynecone to Reflex.""" - # Check if the old config file exists. - if not os.path.exists(constants.Config.PREVIOUS_FILE): - return - - # Ask the user if they want to migrate. - action = console.ask( - "Pynecone project detected. Automatically upgrade to Reflex?", - choices=["y", "n"], - ) - if action == "n": - return - - # Rename pcconfig to rxconfig. - console.log( - f"[bold]Renaming {constants.Config.PREVIOUS_FILE} to {constants.Config.FILE}" - ) - os.rename(constants.Config.PREVIOUS_FILE, constants.Config.FILE) - - # Find all python files in the app directory. - file_pattern = os.path.join(get_config().app_name, "**/*.py") - file_list = glob.glob(file_pattern, recursive=True) - - # Add the config file to the list of files to be migrated. - file_list.append(constants.Config.FILE) - - # Migrate all files. - updates = { - "Pynecone": "Reflex", - "pynecone as pc": "reflex as rx", - "pynecone.io": "reflex.dev", - "pynecone": "reflex", - "pc.": "rx.", - "pcconfig": "rxconfig", - } - for file_path in file_list: - with FileInput(file_path, inplace=True) as file: - for line in file: - for old, new in updates.items(): - line = line.replace(old, new) - print(line, end="") - - def fetch_app_templates(version: str) -> dict[str, Template]: """Fetch a dict of templates from the templates repo using github API. @@ -1401,7 +1354,7 @@ def initialize_app(app_name: str, template: str | None = None): from reflex.utils import telemetry # Check if the app is already initialized. - if os.path.exists(constants.Config.FILE): + if constants.Config.FILE.exists(): telemetry.send("reinit") return diff --git a/reflex/utils/processes.py b/reflex/utils/processes.py index c435af7d0..a45676c01 100644 --- a/reflex/utils/processes.py +++ b/reflex/utils/processes.py @@ -156,7 +156,7 @@ def new_process(args, run: bool = False, show_logs: bool = False, **kwargs): Raises: Exit: When attempting to run a command with a None value. """ - node_bin_path = path_ops.get_node_bin_path() + node_bin_path = str(path_ops.get_node_bin_path()) if not node_bin_path and not prerequisites.CURRENTLY_INSTALLING_NODE: console.warn( "The path to the Node binary could not be found. Please ensure that Node is properly " @@ -167,7 +167,7 @@ def new_process(args, run: bool = False, show_logs: bool = False, **kwargs): console.error(f"Invalid command: {args}") raise typer.Exit(1) # Add the node bin path to the PATH environment variable. - env = { + env: dict[str, str] = { **os.environ, "PATH": os.pathsep.join( [node_bin_path if node_bin_path else "", os.environ["PATH"]] diff --git a/tests/integration/test_urls.py b/tests/integration/test_urls.py index bcf17fe41..81689aa18 100755 --- a/tests/integration/test_urls.py +++ b/tests/integration/test_urls.py @@ -8,7 +8,7 @@ import pytest import requests -def check_urls(repo_dir): +def check_urls(repo_dir: Path): """Check that all URLs in the repo are valid and secure. Args: @@ -21,33 +21,33 @@ def check_urls(repo_dir): errors = [] for root, _dirs, files in os.walk(repo_dir): - if "__pycache__" in root: + root = Path(root) + if root.stem == "__pycache__": continue for file_name in files: if not file_name.endswith(".py") and not file_name.endswith(".md"): continue - file_path = os.path.join(root, file_name) + file_path = root / file_name try: - with open(file_path, "r", encoding="utf-8", errors="ignore") as file: - for line in file: - urls = url_pattern.findall(line) - for url in set(urls): - if url.startswith("http://"): - errors.append( - f"Found insecure HTTP URL: {url} in {file_path}" - ) - url = url.strip('"\n') - try: - response = requests.head( - url, allow_redirects=True, timeout=5 - ) - response.raise_for_status() - except requests.RequestException as e: - errors.append( - f"Error accessing URL: {url} in {file_path} | Error: {e}, , Check your path ends with a /" - ) + for line in file_path.read_text().splitlines(): + urls = url_pattern.findall(line) + for url in set(urls): + if url.startswith("http://"): + errors.append( + f"Found insecure HTTP URL: {url} in {file_path}" + ) + url = url.strip('"\n') + try: + response = requests.head( + url, allow_redirects=True, timeout=5 + ) + response.raise_for_status() + except requests.RequestException as e: + errors.append( + f"Error accessing URL: {url} in {file_path} | Error: {e}, , Check your path ends with a /" + ) except Exception as e: errors.append(f"Error reading file: {file_path} | Error: {e}") @@ -58,7 +58,7 @@ def check_urls(repo_dir): "repo_dir", [Path(__file__).resolve().parent.parent / "reflex"], ) -def test_find_and_check_urls(repo_dir): +def test_find_and_check_urls(repo_dir: Path): """Test that all URLs in the repo are valid and secure. Args: diff --git a/tests/units/compiler/test_compiler.py b/tests/units/compiler/test_compiler.py index 63014cf33..afacf43c5 100644 --- a/tests/units/compiler/test_compiler.py +++ b/tests/units/compiler/test_compiler.py @@ -1,4 +1,4 @@ -import os +from pathlib import Path from typing import List import pytest @@ -130,7 +130,7 @@ def test_compile_stylesheets(tmp_path, mocker): ] assert compiler.compile_root_stylesheet(stylesheets) == ( - os.path.join(".web", "styles", "styles.css"), + str(Path(".web") / "styles" / "styles.css"), f"@import url('./tailwind.css'); \n" f"@import url('https://fonts.googleapis.com/css?family=Sofia&effect=neon|outline|emboss|shadow-multiple'); \n" f"@import url('https://cdn.jsdelivr.net/npm/bootstrap@3.3.7/dist/css/bootstrap.min.css'); \n" @@ -164,7 +164,7 @@ def test_compile_stylesheets_exclude_tailwind(tmp_path, mocker): ] assert compiler.compile_root_stylesheet(stylesheets) == ( - os.path.join(".web", "styles", "styles.css"), + str(Path(".web") / "styles" / "styles.css"), "@import url('../public/styles.css'); \n", ) diff --git a/tests/units/test_config.py b/tests/units/test_config.py index 31dd77649..a6c6fe697 100644 --- a/tests/units/test_config.py +++ b/tests/units/test_config.py @@ -192,4 +192,4 @@ def test_reflex_dir_env_var(monkeypatch, tmp_path): mp_ctx = multiprocessing.get_context(method="spawn") with mp_ctx.Pool(processes=1) as pool: - assert pool.apply(reflex_dir_constant) == str(tmp_path) + assert pool.apply(reflex_dir_constant) == tmp_path diff --git a/tests/units/test_telemetry.py b/tests/units/test_telemetry.py index a434779d4..25ad91323 100644 --- a/tests/units/test_telemetry.py +++ b/tests/units/test_telemetry.py @@ -52,4 +52,4 @@ def test_send(mocker, event): telemetry._send(event, telemetry_enabled=True) httpx_post_mock.assert_called_once() - pathlib_path_read_text_mock.assert_called_once() + assert pathlib_path_read_text_mock.call_count == 2 diff --git a/tests/units/utils/test_utils.py b/tests/units/utils/test_utils.py index 5cdd846fe..41bd4e661 100644 --- a/tests/units/utils/test_utils.py +++ b/tests/units/utils/test_utils.py @@ -117,7 +117,7 @@ def test_remove_existing_bun_installation(mocker): Args: mocker: Pytest mocker. """ - mocker.patch("reflex.utils.prerequisites.os.path.exists", return_value=True) + mocker.patch("reflex.utils.prerequisites.Path.exists", return_value=True) rm = mocker.patch("reflex.utils.prerequisites.path_ops.rm", mocker.Mock()) prerequisites.remove_existing_bun_installation() @@ -458,7 +458,7 @@ def test_bun_install_without_unzip(mocker): mocker: Pytest mocker object. """ mocker.patch("reflex.utils.path_ops.which", return_value=None) - mocker.patch("os.path.exists", return_value=False) + mocker.patch("pathlib.Path.exists", return_value=False) mocker.patch("reflex.utils.prerequisites.constants.IS_WINDOWS", False) with pytest.raises(FileNotFoundError): @@ -476,7 +476,7 @@ def test_bun_install_version(mocker, bun_version): """ mocker.patch("reflex.utils.prerequisites.constants.IS_WINDOWS", False) - mocker.patch("os.path.exists", return_value=True) + mocker.patch("pathlib.Path.exists", return_value=True) mocker.patch( "reflex.utils.prerequisites.get_bun_version", return_value=version.parse(bun_version), From 12d73e4167c4d204591a1eb4a6390327e362d1e8 Mon Sep 17 00:00:00 2001 From: Elijah Ahianyo Date: Thu, 3 Oct 2024 16:48:50 +0000 Subject: [PATCH 04/16] Track the last reflex run time (#4045) --- reflex/utils/build.py | 3 +++ reflex/utils/prerequisites.py | 8 ++++++++ 2 files changed, 11 insertions(+) diff --git a/reflex/utils/build.py b/reflex/utils/build.py index f16224d05..770809015 100644 --- a/reflex/utils/build.py +++ b/reflex/utils/build.py @@ -237,6 +237,9 @@ def setup_frontend( # Set the environment variables in client (env.json). set_env_json() + # update the last reflex run time. + prerequisites.set_last_reflex_run_time() + # Disable the Next telemetry. if disable_telemetry: processes.new_process( diff --git a/reflex/utils/prerequisites.py b/reflex/utils/prerequisites.py index 0e49e3b55..3c2875204 100644 --- a/reflex/utils/prerequisites.py +++ b/reflex/utils/prerequisites.py @@ -131,6 +131,14 @@ def get_or_set_last_reflex_version_check_datetime(): return last_version_check_datetime +def set_last_reflex_run_time(): + """Set the last Reflex run time.""" + path_ops.update_json_file( + get_web_dir() / constants.Reflex.JSON, + {"last_reflex_run_datetime": str(datetime.now())}, + ) + + def check_node_version() -> bool: """Check the version of Node.js. From 4b3d05621282b449aaf11c3f6d4982536d558797 Mon Sep 17 00:00:00 2001 From: Elijah Ahianyo Date: Thu, 3 Oct 2024 17:35:34 +0000 Subject: [PATCH 05/16] [ENG-3476] Setting State Vars that are not defined should raise an error (#4007) --- reflex/state.py | 15 +++++++++++++ reflex/utils/exceptions.py | 4 ++++ tests/units/test_state.py | 43 ++++++++++++++++++++++++++++++++++++++ 3 files changed, 62 insertions(+) diff --git a/reflex/state.py b/reflex/state.py index cda36a0a9..64ea960e1 100644 --- a/reflex/state.py +++ b/reflex/state.py @@ -74,6 +74,7 @@ from reflex.utils.exceptions import ( EventHandlerShadowsBuiltInStateMethod, ImmutableStateError, LockExpiredError, + SetUndefinedStateVarError, ) from reflex.utils.exec import is_testing_env from reflex.utils.serializers import serializer @@ -1260,6 +1261,9 @@ class BaseState(Base, ABC, extra=pydantic.Extra.allow): Args: name: The name of the attribute. value: The value of the attribute. + + Raises: + SetUndefinedStateVarError: If a value of a var is set without first defining it. """ if isinstance(value, MutableProxy): # unwrap proxy objects when assigning back to the state @@ -1277,6 +1281,17 @@ class BaseState(Base, ABC, extra=pydantic.Extra.allow): self._mark_dirty() return + if ( + name not in self.vars + and name not in self.get_skip_vars() + and not name.startswith("__") + and not name.startswith(f"_{type(self).__name__}__") + ): + raise SetUndefinedStateVarError( + f"The state variable '{name}' has not been defined in '{type(self).__name__}'. " + f"All state variables must be declared before they can be set." + ) + # Set the attribute. super().__setattr__(name, value) diff --git a/reflex/utils/exceptions.py b/reflex/utils/exceptions.py index 7c3532861..9c79a387a 100644 --- a/reflex/utils/exceptions.py +++ b/reflex/utils/exceptions.py @@ -115,3 +115,7 @@ class PrimitiveUnserializableToJSON(ReflexError, ValueError): class InvalidLifespanTaskType(ReflexError, TypeError): """Raised when an invalid task type is registered as a lifespan task.""" + + +class SetUndefinedStateVarError(ReflexError, AttributeError): + """Raised when setting the value of a var without first declaring it.""" diff --git a/tests/units/test_state.py b/tests/units/test_state.py index 205162b9f..5bfac7628 100644 --- a/tests/units/test_state.py +++ b/tests/units/test_state.py @@ -41,6 +41,7 @@ from reflex.state import ( ) from reflex.testing import chdir from reflex.utils import format, prerequisites, types +from reflex.utils.exceptions import SetUndefinedStateVarError from reflex.utils.format import json_dumps from reflex.vars.base import ComputedVar, Var from tests.units.states.mutation import MutableSQLAModel, MutableTestState @@ -3262,3 +3263,45 @@ def test_child_mixin_state() -> None: assert "computed" in ChildUsesMixinState.inherited_vars assert "computed" not in ChildUsesMixinState.computed_vars + + +def test_assignment_to_undeclared_vars(): + """Test that an attribute error is thrown when undeclared vars are set.""" + + class State(BaseState): + val: str + _val: str + __val: str # type: ignore + + def handle_supported_regular_vars(self): + self.val = "no underscore" + self._val = "single leading underscore" + self.__val = "double leading undercore" + + def handle_regular_var(self): + self.num = 5 + + def handle_backend_var(self): + self._num = 5 + + def handle_non_var(self): + self.__num = 5 + + class Substate(State): + def handle_var(self): + self.value = 20 + + state = State() # type: ignore + sub_state = Substate() # type: ignore + + with pytest.raises(SetUndefinedStateVarError): + state.handle_regular_var() + + with pytest.raises(SetUndefinedStateVarError): + sub_state.handle_var() + + with pytest.raises(SetUndefinedStateVarError): + state.handle_backend_var() + + state.handle_supported_regular_vars() + state.handle_non_var() From 27bb7179d62c83547fd8fae2d2e3c69d8e5215a4 Mon Sep 17 00:00:00 2001 From: Khaleel Al-Adhami Date: Thu, 3 Oct 2024 12:24:56 -0700 Subject: [PATCH 06/16] default should be warning for subprocesses not info (#4049) --- reflex/constants/base.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/reflex/constants/base.py b/reflex/constants/base.py index 05675643f..b86f083cc 100644 --- a/reflex/constants/base.py +++ b/reflex/constants/base.py @@ -199,7 +199,7 @@ class LogLevel(str, Enum): Returns: The log level for the subprocess """ - return self if self != LogLevel.DEFAULT else LogLevel.INFO + return self if self != LogLevel.DEFAULT else LogLevel.WARNING # Server socket configuration variables From 56709a210b0b722c05b72dc515634d19afb1586b Mon Sep 17 00:00:00 2001 From: Khaleel Al-Adhami Date: Thu, 3 Oct 2024 13:01:19 -0700 Subject: [PATCH 07/16] add of_type to _evaluate (#4051) * add of_type to _evaluate * get it right pyright --- reflex/state.py | 20 ++++++++++++++++---- tests/integration/test_dynamic_components.py | 4 +++- 2 files changed, 19 insertions(+), 5 deletions(-) diff --git a/reflex/state.py b/reflex/state.py index 64ea960e1..1746835be 100644 --- a/reflex/state.py +++ b/reflex/state.py @@ -699,11 +699,14 @@ class BaseState(Base, ABC, extra=pydantic.Extra.allow): ) @classmethod - def _evaluate(cls, f: Callable[[Self], Any]) -> Var: + def _evaluate( + cls, f: Callable[[Self], Any], of_type: Union[type, None] = None + ) -> Var: """Evaluate a function to a ComputedVar. Experimental. Args: f: The function to evaluate. + of_type: The type of the ComputedVar. Defaults to Component. Returns: The ComputedVar. @@ -711,14 +714,23 @@ class BaseState(Base, ABC, extra=pydantic.Extra.allow): console.warn( "The _evaluate method is experimental and may be removed in future versions." ) - from reflex.components.base.fragment import fragment from reflex.components.component import Component + of_type = of_type or Component + unique_var_name = get_unique_variable_name() - @computed_var(_js_expr=unique_var_name, return_type=Component) + @computed_var(_js_expr=unique_var_name, return_type=of_type) def computed_var_func(state: Self): - return fragment(f(state)) + result = f(state) + + if not isinstance(result, of_type): + console.warn( + f"Inline ComputedVar {f} expected type {of_type}, got {type(result)}. " + "You can specify expected type with `of_type` argument." + ) + + return result setattr(cls, unique_var_name, computed_var_func) cls.computed_vars[unique_var_name] = computed_var_func diff --git a/tests/integration/test_dynamic_components.py b/tests/integration/test_dynamic_components.py index 5a4d99f9e..aeebd10e9 100644 --- a/tests/integration/test_dynamic_components.py +++ b/tests/integration/test_dynamic_components.py @@ -65,7 +65,9 @@ def DynamicComponents(): DynamicComponentsState.client_token_component, DynamicComponentsState.button, rx.text( - DynamicComponentsState._evaluate(lambda state: factorial(state.value)), + DynamicComponentsState._evaluate( + lambda state: factorial(state.value), of_type=int + ), id="factorial", ), ) From 40f1880932cec15231262e705566033835920594 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Thomas=20Brand=C3=A9ho?= Date: Thu, 3 Oct 2024 14:18:28 -0700 Subject: [PATCH 08/16] move router dataclasses in their own file (#4044) --- reflex/istate/data.py | 126 ++++++++++++++++++++++++++++++++++++++++++ reflex/state.py | 120 +--------------------------------------- 2 files changed, 127 insertions(+), 119 deletions(-) create mode 100644 reflex/istate/data.py diff --git a/reflex/istate/data.py b/reflex/istate/data.py new file mode 100644 index 000000000..9f6e3b3f4 --- /dev/null +++ b/reflex/istate/data.py @@ -0,0 +1,126 @@ +"""This module contains the dataclasses representing the router object.""" + +import dataclasses +from typing import Optional + +from reflex import constants +from reflex.utils import format + + +@dataclasses.dataclass(frozen=True) +class HeaderData: + """An object containing headers data.""" + + host: str = "" + origin: str = "" + upgrade: str = "" + connection: str = "" + cookie: str = "" + pragma: str = "" + cache_control: str = "" + user_agent: str = "" + sec_websocket_version: str = "" + sec_websocket_key: str = "" + sec_websocket_extensions: str = "" + accept_encoding: str = "" + accept_language: str = "" + + def __init__(self, router_data: Optional[dict] = None): + """Initalize the HeaderData object based on router_data. + + Args: + router_data: the router_data dict. + """ + if router_data: + for k, v in router_data.get(constants.RouteVar.HEADERS, {}).items(): + object.__setattr__(self, format.to_snake_case(k), v) + else: + for k in dataclasses.fields(self): + object.__setattr__(self, k.name, "") + + +@dataclasses.dataclass(frozen=True) +class PageData: + """An object containing page data.""" + + host: str = "" # repeated with self.headers.origin (remove or keep the duplicate?) + path: str = "" + raw_path: str = "" + full_path: str = "" + full_raw_path: str = "" + params: dict = dataclasses.field(default_factory=dict) + + def __init__(self, router_data: Optional[dict] = None): + """Initalize the PageData object based on router_data. + + Args: + router_data: the router_data dict. + """ + if router_data: + object.__setattr__( + self, + "host", + router_data.get(constants.RouteVar.HEADERS, {}).get("origin", ""), + ) + object.__setattr__( + self, "path", router_data.get(constants.RouteVar.PATH, "") + ) + object.__setattr__( + self, "raw_path", router_data.get(constants.RouteVar.ORIGIN, "") + ) + object.__setattr__(self, "full_path", f"{self.host}{self.path}") + object.__setattr__(self, "full_raw_path", f"{self.host}{self.raw_path}") + object.__setattr__( + self, "params", router_data.get(constants.RouteVar.QUERY, {}) + ) + else: + object.__setattr__(self, "host", "") + object.__setattr__(self, "path", "") + object.__setattr__(self, "raw_path", "") + object.__setattr__(self, "full_path", "") + object.__setattr__(self, "full_raw_path", "") + object.__setattr__(self, "params", {}) + + +@dataclasses.dataclass(frozen=True, init=False) +class SessionData: + """An object containing session data.""" + + client_token: str = "" + client_ip: str = "" + session_id: str = "" + + def __init__(self, router_data: Optional[dict] = None): + """Initalize the SessionData object based on router_data. + + Args: + router_data: the router_data dict. + """ + if router_data: + client_token = router_data.get(constants.RouteVar.CLIENT_TOKEN, "") + client_ip = router_data.get(constants.RouteVar.CLIENT_IP, "") + session_id = router_data.get(constants.RouteVar.SESSION_ID, "") + else: + client_token = client_ip = session_id = "" + object.__setattr__(self, "client_token", client_token) + object.__setattr__(self, "client_ip", client_ip) + object.__setattr__(self, "session_id", session_id) + + +@dataclasses.dataclass(frozen=True, init=False) +class RouterData: + """An object containing RouterData.""" + + session: SessionData = dataclasses.field(default_factory=SessionData) + headers: HeaderData = dataclasses.field(default_factory=HeaderData) + page: PageData = dataclasses.field(default_factory=PageData) + + def __init__(self, router_data: Optional[dict] = None): + """Initialize the RouterData object. + + Args: + router_data: the router_data dict. + """ + object.__setattr__(self, "session", SessionData(router_data)) + object.__setattr__(self, "headers", HeaderData(router_data)) + object.__setattr__(self, "page", PageData(router_data)) diff --git a/reflex/state.py b/reflex/state.py index 1746835be..b1988e38a 100644 --- a/reflex/state.py +++ b/reflex/state.py @@ -38,6 +38,7 @@ from sqlalchemy.orm import DeclarativeBase from typing_extensions import Self from reflex.config import get_config +from reflex.istate.data import RouterData from reflex.vars.base import ( ComputedVar, DynamicRouteVar, @@ -93,125 +94,6 @@ var = computed_var TOO_LARGE_SERIALIZED_STATE = 100 * 1024 # 100kb -@dataclasses.dataclass(frozen=True) -class HeaderData: - """An object containing headers data.""" - - host: str = "" - origin: str = "" - upgrade: str = "" - connection: str = "" - cookie: str = "" - pragma: str = "" - cache_control: str = "" - user_agent: str = "" - sec_websocket_version: str = "" - sec_websocket_key: str = "" - sec_websocket_extensions: str = "" - accept_encoding: str = "" - accept_language: str = "" - - def __init__(self, router_data: Optional[dict] = None): - """Initalize the HeaderData object based on router_data. - - Args: - router_data: the router_data dict. - """ - if router_data: - for k, v in router_data.get(constants.RouteVar.HEADERS, {}).items(): - object.__setattr__(self, format.to_snake_case(k), v) - else: - for k in dataclasses.fields(self): - object.__setattr__(self, k.name, "") - - -@dataclasses.dataclass(frozen=True) -class PageData: - """An object containing page data.""" - - host: str = "" # repeated with self.headers.origin (remove or keep the duplicate?) - path: str = "" - raw_path: str = "" - full_path: str = "" - full_raw_path: str = "" - params: dict = dataclasses.field(default_factory=dict) - - def __init__(self, router_data: Optional[dict] = None): - """Initalize the PageData object based on router_data. - - Args: - router_data: the router_data dict. - """ - if router_data: - object.__setattr__( - self, - "host", - router_data.get(constants.RouteVar.HEADERS, {}).get("origin", ""), - ) - object.__setattr__( - self, "path", router_data.get(constants.RouteVar.PATH, "") - ) - object.__setattr__( - self, "raw_path", router_data.get(constants.RouteVar.ORIGIN, "") - ) - object.__setattr__(self, "full_path", f"{self.host}{self.path}") - object.__setattr__(self, "full_raw_path", f"{self.host}{self.raw_path}") - object.__setattr__( - self, "params", router_data.get(constants.RouteVar.QUERY, {}) - ) - else: - object.__setattr__(self, "host", "") - object.__setattr__(self, "path", "") - object.__setattr__(self, "raw_path", "") - object.__setattr__(self, "full_path", "") - object.__setattr__(self, "full_raw_path", "") - object.__setattr__(self, "params", {}) - - -@dataclasses.dataclass(frozen=True, init=False) -class SessionData: - """An object containing session data.""" - - client_token: str = "" - client_ip: str = "" - session_id: str = "" - - def __init__(self, router_data: Optional[dict] = None): - """Initalize the SessionData object based on router_data. - - Args: - router_data: the router_data dict. - """ - if router_data: - client_token = router_data.get(constants.RouteVar.CLIENT_TOKEN, "") - client_ip = router_data.get(constants.RouteVar.CLIENT_IP, "") - session_id = router_data.get(constants.RouteVar.SESSION_ID, "") - else: - client_token = client_ip = session_id = "" - object.__setattr__(self, "client_token", client_token) - object.__setattr__(self, "client_ip", client_ip) - object.__setattr__(self, "session_id", session_id) - - -@dataclasses.dataclass(frozen=True, init=False) -class RouterData: - """An object containing RouterData.""" - - session: SessionData = dataclasses.field(default_factory=SessionData) - headers: HeaderData = dataclasses.field(default_factory=HeaderData) - page: PageData = dataclasses.field(default_factory=PageData) - - def __init__(self, router_data: Optional[dict] = None): - """Initialize the RouterData object. - - Args: - router_data: the router_data dict. - """ - object.__setattr__(self, "session", SessionData(router_data)) - object.__setattr__(self, "headers", HeaderData(router_data)) - object.__setattr__(self, "page", PageData(router_data)) - - def _no_chain_background_task( state_cls: Type["BaseState"], name: str, fn: Callable ) -> Callable: From a66e0f2e11bea56c1664fe1ae740b7d81d8e88f1 Mon Sep 17 00:00:00 2001 From: Masen Furer Date: Thu, 3 Oct 2024 14:18:53 -0700 Subject: [PATCH 09/16] [ENG-3870] rx.call_script with f-string var produces incorrect code (#4039) * Add additional test cases for rx.call_script Include internal vars inside an f-string to be properly rendered on the backend and frontend. * [ENG-3870] rx.call_script with f-string var produces incorrect code Avoid casting javascript code with embedded Var as LiteralStringVar There are two cases that need to be handled: 1. The javascript code contains Vars with VarData, these can only be evaluated in the component context, since they may use hooks. Vars with VarData cannot be used from the backend. In this case, we cast the given code as a raw js expression and include the extracted VarData. 2. The javascript code has no VarData. In this case, we pass the code as the raw js expression and cast to a python str to get a js literal string to eval. * use VarData.__bool__ instead of `is None` --- reflex/event.py | 10 ++ tests/integration/test_call_script.py | 159 ++++++++++++++++++++++++++ 2 files changed, 169 insertions(+) diff --git a/reflex/event.py b/reflex/event.py index 95358ace1..9b54eddeb 100644 --- a/reflex/event.py +++ b/reflex/event.py @@ -839,6 +839,16 @@ def call_script( ), ), } + if isinstance(javascript_code, str): + # When there is VarData, include it and eval the JS code inline on the client. + javascript_code, original_code = ( + LiteralVar.create(javascript_code), + javascript_code, + ) + if not javascript_code._get_all_var_data(): + # Without VarData, cast to string and eval the code in the event loop. + javascript_code = str(Var(_js_expr=original_code)) + return server_side( "_call_script", get_fn_signature(call_script), diff --git a/tests/integration/test_call_script.py b/tests/integration/test_call_script.py index 5a3b83abf..744d83d16 100644 --- a/tests/integration/test_call_script.py +++ b/tests/integration/test_call_script.py @@ -46,6 +46,7 @@ def CallScript(): inline_counter: int = 0 external_counter: int = 0 value: str = "Initial" + last_result: str = "" def call_script_callback(self, result): self.results.append(result) @@ -137,6 +138,32 @@ def CallScript(): callback=CallScriptState.set_external_counter, # type: ignore ) + def call_with_var_f_string(self): + return rx.call_script( + f"{rx.Var('inline_counter')} + {rx.Var('external_counter')}", + callback=CallScriptState.set_last_result, # type: ignore + ) + + def call_with_var_str_cast(self): + return rx.call_script( + f"{str(rx.Var('inline_counter'))} + {str(rx.Var('external_counter'))}", + callback=CallScriptState.set_last_result, # type: ignore + ) + + def call_with_var_f_string_wrapped(self): + return rx.call_script( + rx.Var(f"{rx.Var('inline_counter')} + {rx.Var('external_counter')}"), + callback=CallScriptState.set_last_result, # type: ignore + ) + + def call_with_var_str_cast_wrapped(self): + return rx.call_script( + rx.Var( + f"{str(rx.Var('inline_counter'))} + {str(rx.Var('external_counter'))}" + ), + callback=CallScriptState.set_last_result, # type: ignore + ) + def reset_(self): yield rx.call_script("inline_counter = 0; external_counter = 0") self.reset() @@ -234,6 +261,68 @@ def CallScript(): id="update_value", ), rx.button("Reset", id="reset", on_click=CallScriptState.reset_), + rx.input( + value=CallScriptState.last_result, + id="last_result", + read_only=True, + on_click=CallScriptState.set_last_result(""), # type: ignore + ), + rx.button( + "call_with_var_f_string", + on_click=CallScriptState.call_with_var_f_string, + id="call_with_var_f_string", + ), + rx.button( + "call_with_var_str_cast", + on_click=CallScriptState.call_with_var_str_cast, + id="call_with_var_str_cast", + ), + rx.button( + "call_with_var_f_string_wrapped", + on_click=CallScriptState.call_with_var_f_string_wrapped, + id="call_with_var_f_string_wrapped", + ), + rx.button( + "call_with_var_str_cast_wrapped", + on_click=CallScriptState.call_with_var_str_cast_wrapped, + id="call_with_var_str_cast_wrapped", + ), + rx.button( + "call_with_var_f_string_inline", + on_click=rx.call_script( + f"{rx.Var('inline_counter')} + {CallScriptState.last_result}", + callback=CallScriptState.set_last_result, # type: ignore + ), + id="call_with_var_f_string_inline", + ), + rx.button( + "call_with_var_str_cast_inline", + on_click=rx.call_script( + f"{str(rx.Var('inline_counter'))} + {str(rx.Var('external_counter'))}", + callback=CallScriptState.set_last_result, # type: ignore + ), + id="call_with_var_str_cast_inline", + ), + rx.button( + "call_with_var_f_string_wrapped_inline", + on_click=rx.call_script( + rx.Var( + f"{rx.Var('inline_counter')} + {CallScriptState.last_result}" + ), + callback=CallScriptState.set_last_result, # type: ignore + ), + id="call_with_var_f_string_wrapped_inline", + ), + rx.button( + "call_with_var_str_cast_wrapped_inline", + on_click=rx.call_script( + rx.Var( + f"{str(rx.Var('inline_counter'))} + {str(rx.Var('external_counter'))}" + ), + callback=CallScriptState.set_last_result, # type: ignore + ), + id="call_with_var_str_cast_wrapped_inline", + ), ) @@ -363,3 +452,73 @@ def test_call_script( call_script.poll_for_content(update_value_button, exp_not_equal="Initial") == "updated" ) + + +def test_call_script_w_var( + call_script: AppHarness, + driver: WebDriver, +): + """Test evaluating javascript expressions containing Vars. + + Args: + call_script: harness for CallScript app. + driver: WebDriver instance. + """ + assert_token(driver) + last_result = driver.find_element(By.ID, "last_result") + assert last_result.get_attribute("value") == "" + + inline_return_button = driver.find_element(By.ID, "inline_return") + + call_with_var_f_string_button = driver.find_element(By.ID, "call_with_var_f_string") + call_with_var_str_cast_button = driver.find_element(By.ID, "call_with_var_str_cast") + call_with_var_f_string_wrapped_button = driver.find_element( + By.ID, "call_with_var_f_string_wrapped" + ) + call_with_var_str_cast_wrapped_button = driver.find_element( + By.ID, "call_with_var_str_cast_wrapped" + ) + call_with_var_f_string_inline_button = driver.find_element( + By.ID, "call_with_var_f_string_inline" + ) + call_with_var_str_cast_inline_button = driver.find_element( + By.ID, "call_with_var_str_cast_inline" + ) + call_with_var_f_string_wrapped_inline_button = driver.find_element( + By.ID, "call_with_var_f_string_wrapped_inline" + ) + call_with_var_str_cast_wrapped_inline_button = driver.find_element( + By.ID, "call_with_var_str_cast_wrapped_inline" + ) + + inline_return_button.click() + call_with_var_f_string_button.click() + assert call_script.poll_for_value(last_result, exp_not_equal="") == "1" + + inline_return_button.click() + call_with_var_str_cast_button.click() + assert call_script.poll_for_value(last_result, exp_not_equal="1") == "2" + + inline_return_button.click() + call_with_var_f_string_wrapped_button.click() + assert call_script.poll_for_value(last_result, exp_not_equal="2") == "3" + + inline_return_button.click() + call_with_var_str_cast_wrapped_button.click() + assert call_script.poll_for_value(last_result, exp_not_equal="3") == "4" + + inline_return_button.click() + call_with_var_f_string_inline_button.click() + assert call_script.poll_for_value(last_result, exp_not_equal="4") == "9" + + inline_return_button.click() + call_with_var_str_cast_inline_button.click() + assert call_script.poll_for_value(last_result, exp_not_equal="9") == "6" + + inline_return_button.click() + call_with_var_f_string_wrapped_inline_button.click() + assert call_script.poll_for_value(last_result, exp_not_equal="6") == "13" + + inline_return_button.click() + call_with_var_str_cast_wrapped_inline_button.click() + assert call_script.poll_for_value(last_result, exp_not_equal="13") == "8" From ad0827c59ce249e94330c0b0c7dd414d7c2e8b25 Mon Sep 17 00:00:00 2001 From: Khaleel Al-Adhami Date: Thu, 3 Oct 2024 14:25:21 -0700 Subject: [PATCH 10/16] bundle chakra in window for CSR (#4042) * bundle chakra in window for CSR * remove repeated chakra ui reference * use dynamically generated libraries * remove js from it --- .../.templates/jinja/web/pages/_app.js.jinja2 | 14 ++++---- reflex/compiler/compiler.py | 24 ++++++++++++++ reflex/components/dynamic.py | 33 ++++++++++++++----- reflex/utils/exceptions.py | 4 +++ 4 files changed, 59 insertions(+), 16 deletions(-) diff --git a/reflex/.templates/jinja/web/pages/_app.js.jinja2 b/reflex/.templates/jinja/web/pages/_app.js.jinja2 index 97c31925d..c893e19e2 100644 --- a/reflex/.templates/jinja/web/pages/_app.js.jinja2 +++ b/reflex/.templates/jinja/web/pages/_app.js.jinja2 @@ -7,10 +7,9 @@ import '/styles/styles.css' {% block declaration %} import { EventLoopProvider, StateProvider, defaultColorMode } from "/utils/context.js"; import { ThemeProvider } from 'next-themes' -import * as React from "react"; -import * as utils_context from "/utils/context.js"; -import * as utils_state from "/utils/state.js"; -import * as radix from "@radix-ui/themes"; +{% for library_alias, library_path in window_libraries %} +import * as {{library_alias}} from "{{library_path}}"; +{% endfor %} {% for custom_code in custom_codes %} {{custom_code}} @@ -33,10 +32,9 @@ export default function MyApp({ Component, pageProps }) { React.useEffect(() => { // Make contexts and state objects available globally for dynamic eval'd components let windowImports = { - "react": React, - "@radix-ui/themes": radix, - "/utils/context": utils_context, - "/utils/state": utils_state, +{% for library_alias, library_path in window_libraries %} + "{{library_path}}": {{library_alias}}, +{% endfor %} }; window["__reflex"] = windowImports; }, []); diff --git a/reflex/compiler/compiler.py b/reflex/compiler/compiler.py index 343bda3e1..0c29f941d 100644 --- a/reflex/compiler/compiler.py +++ b/reflex/compiler/compiler.py @@ -40,6 +40,20 @@ def _compile_document_root(root: Component) -> str: ) +def _normalize_library_name(lib: str) -> str: + """Normalize the library name. + + Args: + lib: The library name to normalize. + + Returns: + The normalized library name. + """ + if lib == "react": + return "React" + return lib.replace("@", "").replace("/", "_").replace("-", "_") + + def _compile_app(app_root: Component) -> str: """Compile the app template component. @@ -49,10 +63,20 @@ def _compile_app(app_root: Component) -> str: Returns: The compiled app. """ + from reflex.components.dynamic import bundled_libraries + + window_libraries = [ + (_normalize_library_name(name), name) for name in bundled_libraries + ] + [ + ("utils_context", f"/{constants.Dirs.UTILS}/context"), + ("utils_state", f"/{constants.Dirs.UTILS}/state"), + ] + return templates.APP_ROOT.render( imports=utils.compile_imports(app_root._get_all_imports()), custom_codes=app_root._get_all_custom_code(), hooks={**app_root._get_all_hooks_internal(), **app_root._get_all_hooks()}, + window_libraries=window_libraries, render=app_root.render(), ) diff --git a/reflex/components/dynamic.py b/reflex/components/dynamic.py index 390b6e688..2e336027b 100644 --- a/reflex/components/dynamic.py +++ b/reflex/components/dynamic.py @@ -1,12 +1,18 @@ """Components that are dynamically generated on the backend.""" +from typing import TYPE_CHECKING + from reflex import constants from reflex.utils import imports +from reflex.utils.exceptions import DynamicComponentMissingLibrary from reflex.utils.format import format_library_name from reflex.utils.serializers import serializer from reflex.vars import Var, get_unique_variable_name from reflex.vars.base import VarData, transform +if TYPE_CHECKING: + from reflex.components.component import Component + def get_cdn_url(lib: str) -> str: """Get the CDN URL for a library. @@ -20,6 +26,23 @@ def get_cdn_url(lib: str) -> str: return f"https://cdn.jsdelivr.net/npm/{lib}" + "/+esm" +bundled_libraries = {"react", "@radix-ui/themes"} + + +def bundle_library(component: "Component"): + """Bundle a library with the component. + + Args: + component: The component to bundle the library with. + + Raises: + DynamicComponentMissingLibrary: Raised when a dynamic component is missing a library. + """ + if component.library is None: + raise DynamicComponentMissingLibrary("Component must have a library to bundle.") + bundled_libraries.add(format_library_name(component.library)) + + def load_dynamic_serializer(): """Load the serializer for dynamic components.""" # Causes a circular import, so we import here. @@ -58,10 +81,7 @@ def load_dynamic_serializer(): ) ] = None - libs_in_window = [ - "react", - "@radix-ui/themes", - ] + libs_in_window = bundled_libraries imports = {} for lib, names in component._get_all_imports().items(): @@ -69,10 +89,7 @@ def load_dynamic_serializer(): if ( not lib.startswith((".", "/")) and not lib.startswith("http") - and all( - formatted_lib_name != lib_in_window - for lib_in_window in libs_in_window - ) + and formatted_lib_name not in libs_in_window ): imports[get_cdn_url(lib)] = names else: diff --git a/reflex/utils/exceptions.py b/reflex/utils/exceptions.py index 9c79a387a..0383f7ba6 100644 --- a/reflex/utils/exceptions.py +++ b/reflex/utils/exceptions.py @@ -117,5 +117,9 @@ class InvalidLifespanTaskType(ReflexError, TypeError): """Raised when an invalid task type is registered as a lifespan task.""" +class DynamicComponentMissingLibrary(ReflexError, ValueError): + """Raised when a dynamic component is missing a library.""" + + class SetUndefinedStateVarError(ReflexError, AttributeError): """Raised when setting the value of a var without first declaring it.""" From 73e8a4e0abfd2fb63fae10b69a9af830bf7a5c93 Mon Sep 17 00:00:00 2001 From: Khaleel Al-Adhami Date: Thu, 3 Oct 2024 15:33:51 -0700 Subject: [PATCH 11/16] support eventspec/eventchain in var operations (#4038) --- reflex/.templates/web/utils/state.js | 22 ++- reflex/app.py | 4 +- reflex/components/component.py | 50 ++++-- reflex/event.py | 193 ++++++++++++++++++++- reflex/utils/format.py | 14 +- reflex/vars/base.py | 81 +++++---- tests/units/components/base/test_script.py | 6 +- tests/units/components/test_component.py | 8 +- tests/units/utils/test_format.py | 22 ++- 9 files changed, 310 insertions(+), 90 deletions(-) diff --git a/reflex/.templates/web/utils/state.js b/reflex/.templates/web/utils/state.js index 78e671809..0fe0db8c1 100644 --- a/reflex/.templates/web/utils/state.js +++ b/reflex/.templates/web/utils/state.js @@ -544,13 +544,19 @@ export const uploadFiles = async ( /** * Create an event object. - * @param name The name of the event. - * @param payload The payload of the event. - * @param handler The client handler to process event. + * @param {string} name The name of the event. + * @param {Object.} payload The payload of the event. + * @param {Object.} event_actions The actions to take on the event. + * @param {string} handler The client handler to process event. * @returns The event object. */ -export const Event = (name, payload = {}, handler = null) => { - return { name, payload, handler }; +export const Event = ( + name, + payload = {}, + event_actions = {}, + handler = null +) => { + return { name, payload, handler, event_actions }; }; /** @@ -676,6 +682,12 @@ export const useEventLoop = ( if (!(args instanceof Array)) { args = [args]; } + + event_actions = events.reduce( + (acc, e) => ({ ...acc, ...e.event_actions }), + event_actions ?? {} + ); + const _e = args.filter((o) => o?.preventDefault !== undefined)[0]; if (event_actions?.preventDefault && _e?.preventDefault) { diff --git a/reflex/app.py b/reflex/app.py index 111dd9dfd..d8a6f2590 100644 --- a/reflex/app.py +++ b/reflex/app.py @@ -1536,7 +1536,9 @@ class EventNamespace(AsyncNamespace): """ fields = json.loads(data) # Get the event. - event = Event(**{k: v for k, v in fields.items() if k != "handler"}) + event = Event( + **{k: v for k, v in fields.items() if k not in ("handler", "event_actions")} + ) self.token_to_sid[event.token] = sid self.sid_to_token[sid] = event.token diff --git a/reflex/components/component.py b/reflex/components/component.py index 9bdd12f0e..26ea2fd3f 100644 --- a/reflex/components/component.py +++ b/reflex/components/component.py @@ -38,8 +38,10 @@ from reflex.constants import ( ) from reflex.event import ( EventChain, + EventChainVar, EventHandler, EventSpec, + EventVar, call_event_fn, call_event_handler, get_handler_args, @@ -514,7 +516,7 @@ class Component(BaseComponent, ABC): Var, EventHandler, EventSpec, - List[Union[EventHandler, EventSpec]], + List[Union[EventHandler, EventSpec, EventVar]], Callable, ], ) -> Union[EventChain, Var]: @@ -532,11 +534,16 @@ class Component(BaseComponent, ABC): """ # If it's an event chain var, return it. if isinstance(value, Var): - if value._var_type is not EventChain: + if isinstance(value, EventChainVar): + return value + elif isinstance(value, EventVar): + value = [value] + elif issubclass(value._var_type, (EventChain, EventSpec)): + return self._create_event_chain(args_spec, value.guess_type()) + else: raise ValueError( - f"Invalid event chain: {repr(value)} of type {type(value)}" + f"Invalid event chain: {str(value)} of type {value._var_type}" ) - return value elif isinstance(value, EventChain): # Trust that the caller knows what they're doing passing an EventChain directly return value @@ -547,7 +554,7 @@ class Component(BaseComponent, ABC): # If the input is a list of event handlers, create an event chain. if isinstance(value, List): - events: list[EventSpec] = [] + events: List[Union[EventSpec, EventVar]] = [] for v in value: if isinstance(v, (EventHandler, EventSpec)): # Call the event handler to get the event. @@ -561,6 +568,8 @@ class Component(BaseComponent, ABC): "lambda inside an EventChain list." ) events.extend(result) + elif isinstance(v, EventVar): + events.append(v) else: raise ValueError(f"Invalid event: {v}") @@ -570,32 +579,30 @@ class Component(BaseComponent, ABC): if isinstance(result, Var): # Recursively call this function if the lambda returned an EventChain Var. return self._create_event_chain(args_spec, result) - events = result + events = [*result] # Otherwise, raise an error. else: raise ValueError(f"Invalid event chain: {value}") # Add args to the event specs if necessary. - events = [e.with_args(get_handler_args(e)) for e in events] - - # Collect event_actions from each spec - event_actions = {} - for e in events: - event_actions.update(e.event_actions) + events = [ + (e.with_args(get_handler_args(e)) if isinstance(e, EventSpec) else e) + for e in events + ] # Return the event chain. if isinstance(args_spec, Var): return EventChain( events=events, args_spec=None, - event_actions=event_actions, + event_actions={}, ) else: return EventChain( events=events, args_spec=args_spec, - event_actions=event_actions, + event_actions={}, ) def get_event_triggers(self) -> Dict[str, Any]: @@ -1030,8 +1037,11 @@ class Component(BaseComponent, ABC): elif isinstance(event, EventChain): event_args = [] for spec in event.events: - for args in spec.args: - event_args.extend(args) + if isinstance(spec, EventSpec): + for args in spec.args: + event_args.extend(args) + else: + event_args.append(spec) yield event_trigger, event_args def _get_vars(self, include_children: bool = False) -> list[Var]: @@ -1105,8 +1115,12 @@ class Component(BaseComponent, ABC): for trigger in self.event_triggers.values(): if isinstance(trigger, EventChain): for event in trigger.events: - if event.handler.state_full_name: - return True + if isinstance(event, EventSpec): + if event.handler.state_full_name: + return True + else: + if event._var_state: + return True elif isinstance(trigger, Var) and trigger._var_state: return True return False diff --git a/reflex/event.py b/reflex/event.py index 9b54eddeb..7384cf5bf 100644 --- a/reflex/event.py +++ b/reflex/event.py @@ -4,16 +4,19 @@ from __future__ import annotations import dataclasses import inspect +import sys import types import urllib.parse from base64 import b64encode from typing import ( Any, Callable, + ClassVar, Dict, List, Optional, Tuple, + Type, Union, get_type_hints, ) @@ -25,8 +28,15 @@ from reflex.utils import format from reflex.utils.exceptions import EventFnArgMismatch, EventHandlerArgMismatch from reflex.utils.types import ArgsSpec, GenericType from reflex.vars import VarData -from reflex.vars.base import LiteralVar, Var -from reflex.vars.function import FunctionStringVar, FunctionVar +from reflex.vars.base import ( + CachedVarOperation, + LiteralNoneVar, + LiteralVar, + ToOperation, + Var, + cached_property_no_lock, +) +from reflex.vars.function import ArgsFunctionOperation, FunctionStringVar, FunctionVar from reflex.vars.object import ObjectVar try: @@ -375,7 +385,7 @@ class CallableEventSpec(EventSpec): class EventChain(EventActionsMixin): """Container for a chain of events that will be executed in order.""" - events: List[EventSpec] = dataclasses.field(default_factory=list) + events: List[Union[EventSpec, EventVar]] = dataclasses.field(default_factory=list) args_spec: Optional[Callable] = dataclasses.field(default=None) @@ -478,7 +488,7 @@ class FileUpload: if isinstance(events, Var): raise ValueError(f"{on_upload_progress} cannot return a var {events}.") on_upload_progress_chain = EventChain( - events=events, + events=[*events], args_spec=self.on_upload_progress_args_spec, ) formatted_chain = str(format.format_prop(on_upload_progress_chain)) @@ -1136,3 +1146,178 @@ def get_fn_signature(fn: Callable) -> inspect.Signature: "state", inspect.Parameter.POSITIONAL_OR_KEYWORD, annotation=Any ) return signature.replace(parameters=(new_param, *signature.parameters.values())) + + +class EventVar(ObjectVar): + """Base class for event vars.""" + + +@dataclasses.dataclass( + eq=False, + frozen=True, + **{"slots": True} if sys.version_info >= (3, 10) else {}, +) +class LiteralEventVar(CachedVarOperation, LiteralVar, EventVar): + """A literal event var.""" + + _var_value: EventSpec = dataclasses.field(default=None) # type: ignore + + def __hash__(self) -> int: + """Get the hash of the var. + + Returns: + The hash of the var. + """ + return hash((self.__class__.__name__, self._js_expr)) + + @cached_property_no_lock + def _cached_var_name(self) -> str: + """The name of the var. + + Returns: + The name of the var. + """ + return str( + FunctionStringVar("Event").call( + # event handler name + ".".join( + filter( + None, + format.get_event_handler_parts(self._var_value.handler), + ) + ), + # event handler args + {str(name): value for name, value in self._var_value.args}, + # event actions + self._var_value.event_actions, + # client handler name + *( + [self._var_value.client_handler_name] + if self._var_value.client_handler_name + else [] + ), + ) + ) + + @classmethod + def create( + cls, + value: EventSpec, + _var_data: VarData | None = None, + ) -> LiteralEventVar: + """Create a new LiteralEventVar instance. + + Args: + value: The value of the var. + _var_data: The data of the var. + + Returns: + The created LiteralEventVar instance. + """ + return cls( + _js_expr="", + _var_type=EventSpec, + _var_data=_var_data, + _var_value=value, + ) + + +class EventChainVar(FunctionVar): + """Base class for event chain vars.""" + + +@dataclasses.dataclass( + eq=False, + frozen=True, + **{"slots": True} if sys.version_info >= (3, 10) else {}, +) +class LiteralEventChainVar(CachedVarOperation, LiteralVar, EventChainVar): + """A literal event chain var.""" + + _var_value: EventChain = dataclasses.field(default=None) # type: ignore + + def __hash__(self) -> int: + """Get the hash of the var. + + Returns: + The hash of the var. + """ + return hash((self.__class__.__name__, self._js_expr)) + + @cached_property_no_lock + def _cached_var_name(self) -> str: + """The name of the var. + + Returns: + The name of the var. + """ + sig = inspect.signature(self._var_value.args_spec) # type: ignore + if sig.parameters: + arg_def = tuple((f"_{p}" for p in sig.parameters)) + arg_def_expr = LiteralVar.create([Var(_js_expr=arg) for arg in arg_def]) + else: + # add a default argument for addEvents if none were specified in value.args_spec + # used to trigger the preventDefault() on the event. + arg_def = ("...args",) + arg_def_expr = Var(_js_expr="args") + + return str( + ArgsFunctionOperation.create( + arg_def, + FunctionStringVar.create("addEvents").call( + LiteralVar.create( + [LiteralVar.create(event) for event in self._var_value.events] + ), + arg_def_expr, + self._var_value.event_actions, + ), + ) + ) + + @classmethod + def create( + cls, + value: EventChain, + _var_data: VarData | None = None, + ) -> LiteralEventChainVar: + """Create a new LiteralEventChainVar instance. + + Args: + value: The value of the var. + _var_data: The data of the var. + + Returns: + The created LiteralEventChainVar instance. + """ + return cls( + _js_expr="", + _var_type=EventChain, + _var_data=_var_data, + _var_value=value, + ) + + +@dataclasses.dataclass( + eq=False, + frozen=True, + **{"slots": True} if sys.version_info >= (3, 10) else {}, +) +class ToEventVarOperation(ToOperation, EventVar): + """Result of a cast to an event var.""" + + _original: Var = dataclasses.field(default_factory=lambda: LiteralNoneVar.create()) + + _default_var_type: ClassVar[Type] = EventSpec + + +@dataclasses.dataclass( + eq=False, + frozen=True, + **{"slots": True} if sys.version_info >= (3, 10) else {}, +) +class ToEventChainVarOperation(ToOperation, EventChainVar): + """Result of a cast to an event chain var.""" + + _original: Var = dataclasses.field(default_factory=lambda: LiteralNoneVar.create()) + + _default_var_type: ClassVar[Type] = EventChain diff --git a/reflex/utils/format.py b/reflex/utils/format.py index 4029bd275..65c0f049b 100644 --- a/reflex/utils/format.py +++ b/reflex/utils/format.py @@ -359,19 +359,7 @@ def format_prop( # Handle event props. if isinstance(prop, EventChain): - sig = inspect.signature(prop.args_spec) # type: ignore - if sig.parameters: - arg_def = ",".join(f"_{p}" for p in sig.parameters) - arg_def_expr = f"[{arg_def}]" - else: - # add a default argument for addEvents if none were specified in prop.args_spec - # used to trigger the preventDefault() on the event. - arg_def = "...args" - arg_def_expr = "args" - - chain = ",".join([format_event(event) for event in prop.events]) - event = f"addEvents([{chain}], {arg_def_expr}, {json_dumps(prop.event_actions)})" - prop = f"({arg_def}) => {event}" + return str(Var.create(prop)) # Handle other types. elif isinstance(prop, str): diff --git a/reflex/vars/base.py b/reflex/vars/base.py index 2d78a14be..0f8a80f8d 100644 --- a/reflex/vars/base.py +++ b/reflex/vars/base.py @@ -385,6 +385,15 @@ class Var(Generic[VAR_TYPE]): Returns: The converted var. """ + from reflex.event import ( + EventChain, + EventChainVar, + EventSpec, + EventVar, + ToEventChainVarOperation, + ToEventVarOperation, + ) + from .function import FunctionVar, ToFunctionOperation from .number import ( BooleanVar, @@ -416,6 +425,10 @@ class Var(Generic[VAR_TYPE]): return self.to(BooleanVar, output) if fixed_output_type is None: return ToNoneOperation.create(self) + if fixed_output_type is EventSpec: + return self.to(EventVar, output) + if fixed_output_type is EventChain: + return self.to(EventChainVar, output) if issubclass(fixed_output_type, Base): return self.to(ObjectVar, output) if dataclasses.is_dataclass(fixed_output_type) and not issubclass( @@ -453,10 +466,13 @@ class Var(Generic[VAR_TYPE]): if issubclass(output, StringVar): return ToStringOperation.create(self, var_type or str) - if issubclass(output, (ObjectVar, Base)): - return ToObjectOperation.create(self, var_type or dict) + if issubclass(output, EventVar): + return ToEventVarOperation.create(self, var_type or EventSpec) - if dataclasses.is_dataclass(output): + if issubclass(output, EventChainVar): + return ToEventChainVarOperation.create(self, var_type or EventChain) + + if issubclass(output, (ObjectVar, Base)): return ToObjectOperation.create(self, var_type or dict) if issubclass(output, FunctionVar): @@ -469,6 +485,9 @@ class Var(Generic[VAR_TYPE]): if issubclass(output, NoneVar): return ToNoneOperation.create(self) + if dataclasses.is_dataclass(output): + return ToObjectOperation.create(self, var_type or dict) + # If we can't determine the first argument, we just replace the _var_type. if not issubclass(output, Var) or var_type is None: return dataclasses.replace( @@ -494,6 +513,8 @@ class Var(Generic[VAR_TYPE]): Raises: TypeError: If the type is not supported for guessing. """ + from reflex.event import EventChain, EventChainVar, EventSpec, EventVar + from .number import BooleanVar, NumberVar from .object import ObjectVar from .sequence import ArrayVar, StringVar @@ -539,6 +560,10 @@ class Var(Generic[VAR_TYPE]): return self.to(ArrayVar, self._var_type) if issubclass(fixed_type, str): return self.to(StringVar, self._var_type) + if issubclass(fixed_type, EventSpec): + return self.to(EventVar, self._var_type) + if issubclass(fixed_type, EventChain): + return self.to(EventChainVar, self._var_type) if issubclass(fixed_type, Base): return self.to(ObjectVar, self._var_type) if dataclasses.is_dataclass(fixed_type): @@ -1029,47 +1054,22 @@ class LiteralVar(Var): if value is None: return LiteralNoneVar.create(_var_data=_var_data) - from reflex.event import EventChain, EventHandler, EventSpec + from reflex.event import ( + EventChain, + EventHandler, + EventSpec, + LiteralEventChainVar, + LiteralEventVar, + ) from reflex.utils.format import get_event_handler_parts - from .function import ArgsFunctionOperation, FunctionStringVar from .object import LiteralObjectVar if isinstance(value, EventSpec): - event_name = LiteralVar.create( - ".".join(filter(None, get_event_handler_parts(value.handler))) - ) - event_args = LiteralVar.create( - {str(name): value for name, value in value.args} - ) - event_client_name = LiteralVar.create(value.client_handler_name) - return FunctionStringVar("Event").call( - event_name, - event_args, - *([event_client_name] if value.client_handler_name else []), - ) + return LiteralEventVar.create(value, _var_data=_var_data) if isinstance(value, EventChain): - sig = inspect.signature(value.args_spec) # type: ignore - if sig.parameters: - arg_def = tuple((f"_{p}" for p in sig.parameters)) - arg_def_expr = LiteralVar.create([Var(_js_expr=arg) for arg in arg_def]) - else: - # add a default argument for addEvents if none were specified in value.args_spec - # used to trigger the preventDefault() on the event. - arg_def = ("...args",) - arg_def_expr = Var(_js_expr="args") - - return ArgsFunctionOperation.create( - arg_def, - FunctionStringVar.create("addEvents").call( - LiteralVar.create( - [LiteralVar.create(event) for event in value.events] - ), - arg_def_expr, - LiteralVar.create(value.event_actions), - ), - ) + return LiteralEventChainVar.create(value, _var_data=_var_data) if isinstance(value, EventHandler): return Var(_js_expr=".".join(filter(None, get_event_handler_parts(value)))) @@ -2126,9 +2126,16 @@ class NoneVar(Var[None]): """A var representing None.""" +@dataclasses.dataclass( + eq=False, + frozen=True, + **{"slots": True} if sys.version_info >= (3, 10) else {}, +) class LiteralNoneVar(LiteralVar, NoneVar): """A var representing None.""" + _var_value: None = None + def json(self) -> str: """Serialize the var to a JSON string. diff --git a/tests/units/components/base/test_script.py b/tests/units/components/base/test_script.py index c6b67da11..be62276f2 100644 --- a/tests/units/components/base/test_script.py +++ b/tests/units/components/base/test_script.py @@ -58,14 +58,14 @@ def test_script_event_handler(): ) render_dict = component.render() assert ( - f'onReady={{((...args) => ((addEvents([(Event("{EvState.get_full_name()}.on_ready", ({{ }})))], args, ({{ }})))))}}' + f'onReady={{((...args) => ((addEvents([(Event("{EvState.get_full_name()}.on_ready", ({{ }}), ({{ }})))], args, ({{ }})))))}}' in render_dict["props"] ) assert ( - f'onLoad={{((...args) => ((addEvents([(Event("{EvState.get_full_name()}.on_load", ({{ }})))], args, ({{ }})))))}}' + f'onLoad={{((...args) => ((addEvents([(Event("{EvState.get_full_name()}.on_load", ({{ }}), ({{ }})))], args, ({{ }})))))}}' in render_dict["props"] ) assert ( - f'onError={{((...args) => ((addEvents([(Event("{EvState.get_full_name()}.on_error", ({{ }})))], args, ({{ }})))))}}' + f'onError={{((...args) => ((addEvents([(Event("{EvState.get_full_name()}.on_error", ({{ }}), ({{ }})))], args, ({{ }})))))}}' in render_dict["props"] ) diff --git a/tests/units/components/test_component.py b/tests/units/components/test_component.py index 73d3f611b..5e94db052 100644 --- a/tests/units/components/test_component.py +++ b/tests/units/components/test_component.py @@ -832,7 +832,7 @@ def test_component_event_trigger_arbitrary_args(): assert comp.render()["props"][0] == ( "onFoo={((__e, _alpha, _bravo, _charlie) => ((addEvents(" - f'[(Event("{C1State.get_full_name()}.mock_handler", ({{ ["_e"] : __e["target"]["value"], ["_bravo"] : _bravo["nested"], ["_charlie"] : (_charlie["custom"] + 42) }})))], ' + f'[(Event("{C1State.get_full_name()}.mock_handler", ({{ ["_e"] : __e["target"]["value"], ["_bravo"] : _bravo["nested"], ["_charlie"] : (_charlie["custom"] + 42) }}), ({{ }})))], ' "[__e, _alpha, _bravo, _charlie], ({ })))))}" ) @@ -1178,7 +1178,7 @@ TEST_VAR = LiteralVar.create("test")._replace( ) FORMATTED_TEST_VAR = LiteralVar.create(f"foo{TEST_VAR}bar") STYLE_VAR = TEST_VAR._replace(_js_expr="style") -EVENT_CHAIN_VAR = TEST_VAR._replace(_var_type=EventChain) +EVENT_CHAIN_VAR = TEST_VAR.to(EventChain) ARG_VAR = Var(_js_expr="arg") TEST_VAR_DICT_OF_DICT = LiteralVar.create({"a": {"b": "test"}})._replace( @@ -2159,7 +2159,7 @@ class TriggerState(rx.State): rx.text("random text", on_click=TriggerState.do_something), rx.text( "random text", - on_click=Var(_js_expr="toggleColorMode", _var_type=EventChain), + on_click=Var(_js_expr="toggleColorMode").to(EventChain), ), ), True, @@ -2169,7 +2169,7 @@ class TriggerState(rx.State): rx.text("random text", on_click=rx.console_log("log")), rx.text( "random text", - on_click=Var(_js_expr="toggleColorMode", _var_type=EventChain), + on_click=Var(_js_expr="toggleColorMode").to(EventChain), ), ), False, diff --git a/tests/units/utils/test_format.py b/tests/units/utils/test_format.py index 042c3f323..d7b0c791e 100644 --- a/tests/units/utils/test_format.py +++ b/tests/units/utils/test_format.py @@ -374,7 +374,7 @@ def test_format_match( events=[EventSpec(handler=EventHandler(fn=mock_event))], args_spec=lambda: [], ), - '((...args) => ((addEvents([(Event("mock_event", ({ })))], args, ({ })))))', + '((...args) => ((addEvents([(Event("mock_event", ({ }), ({ })))], args, ({ })))))', ), ( EventChain( @@ -395,7 +395,7 @@ def test_format_match( ], args_spec=lambda e: [e.target.value], ), - '((_e) => ((addEvents([(Event("mock_event", ({ ["arg"] : _e["target"]["value"] })))], [_e], ({ })))))', + '((_e) => ((addEvents([(Event("mock_event", ({ ["arg"] : _e["target"]["value"] }), ({ })))], [_e], ({ })))))', ), ( EventChain( @@ -403,7 +403,19 @@ def test_format_match( args_spec=lambda: [], event_actions={"stopPropagation": True}, ), - '((...args) => ((addEvents([(Event("mock_event", ({ })))], args, ({ ["stopPropagation"] : true })))))', + '((...args) => ((addEvents([(Event("mock_event", ({ }), ({ })))], args, ({ ["stopPropagation"] : true })))))', + ), + ( + EventChain( + events=[ + EventSpec( + handler=EventHandler(fn=mock_event), + event_actions={"stopPropagation": True}, + ) + ], + args_spec=lambda: [], + ), + '((...args) => ((addEvents([(Event("mock_event", ({ }), ({ ["stopPropagation"] : true })))], args, ({ })))))', ), ( EventChain( @@ -411,7 +423,7 @@ def test_format_match( args_spec=lambda: [], event_actions={"preventDefault": True}, ), - '((...args) => ((addEvents([(Event("mock_event", ({ })))], args, ({ ["preventDefault"] : true })))))', + '((...args) => ((addEvents([(Event("mock_event", ({ }), ({ })))], args, ({ ["preventDefault"] : true })))))', ), ({"a": "red", "b": "blue"}, '({ ["a"] : "red", ["b"] : "blue" })'), (Var(_js_expr="var", _var_type=int).guess_type(), "var"), @@ -519,7 +531,7 @@ def test_format_event_handler(input, output): [ ( EventSpec(handler=EventHandler(fn=mock_event)), - '(Event("mock_event", ({ })))', + '(Event("mock_event", ({ }), ({ })))', ), ], ) From 0f8630fb2df80873a248508d4bf112e354364a77 Mon Sep 17 00:00:00 2001 From: Khaleel Al-Adhami Date: Thu, 3 Oct 2024 15:58:04 -0700 Subject: [PATCH 12/16] remove var operation error (#4053) * remove var operation error * dang it darglint --- reflex/app.py | 13 ------------- 1 file changed, 13 deletions(-) diff --git a/reflex/app.py b/reflex/app.py index d8a6f2590..584b8a321 100644 --- a/reflex/app.py +++ b/reflex/app.py @@ -431,25 +431,12 @@ class App(MiddlewareMixin, LifespanMixin, Base): The generated component. Raises: - VarOperationTypeError: When an invalid component var related function is passed. - TypeError: When an invalid component function is passed. exceptions.MatchTypeError: If the return types of match cases in rx.match are different. """ - from reflex.utils.exceptions import VarOperationTypeError - try: return component if isinstance(component, Component) else component() except exceptions.MatchTypeError: raise - except TypeError as e: - message = str(e) - if "Var" in message: - raise VarOperationTypeError( - "You may be trying to use an invalid Python function on a state var. " - "When referencing a var inside your render code, only limited var operations are supported. " - "See the var operation docs here: https://reflex.dev/docs/vars/var-operations/" - ) from e - raise e def add_page( self, From fafdeb892e575ac89315c4e632c15d9abedb685f Mon Sep 17 00:00:00 2001 From: Khaleel Al-Adhami Date: Thu, 3 Oct 2024 15:58:42 -0700 Subject: [PATCH 13/16] Include emotion inside of dynamic components (#4052) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit * bundle chakra in window for CSR * remove repeated chakra ui reference * use dynamically generated libraries * remove js from it * include emotion react for dynamic components * make code more readable Co-authored-by: Thomas Brandého * jsx yea * what --------- Co-authored-by: Thomas Brandého --- reflex/components/dynamic.py | 15 +++++++++++++-- 1 file changed, 13 insertions(+), 2 deletions(-) diff --git a/reflex/components/dynamic.py b/reflex/components/dynamic.py index 2e336027b..8d0bab669 100644 --- a/reflex/components/dynamic.py +++ b/reflex/components/dynamic.py @@ -26,7 +26,11 @@ def get_cdn_url(lib: str) -> str: return f"https://cdn.jsdelivr.net/npm/{lib}" + "/+esm" -bundled_libraries = {"react", "@radix-ui/themes"} +bundled_libraries = { + "react", + "@radix-ui/themes", + "@emotion/react", +} def bundle_library(component: "Component"): @@ -127,7 +131,14 @@ def load_dynamic_serializer(): module_code_lines.insert(0, "const React = window.__reflex.react;") - return "//__reflex_evaluate\n" + "\n".join(module_code_lines) + return "\n".join( + [ + "//__reflex_evaluate", + "/** @jsx jsx */", + "const { jsx } = window.__reflex['@emotion/react']", + *module_code_lines, + ] + ) @transform def evaluate_component(js_string: Var[str]) -> Var[Component]: From d77b900bd7a9b05f2d60d6f70120c1e61c08b162 Mon Sep 17 00:00:00 2001 From: Masen Furer Date: Thu, 3 Oct 2024 19:19:06 -0700 Subject: [PATCH 14/16] [ENG-3867] Garden Variety Pickle (#4054) * Use regular `pickle` module from stdlib * Avoid recreating the rx.State tree for every `get_state` * Remove dill dependency * relock deps --- poetry.lock | 96 +++++++++++++++++--------------------- pyproject.toml | 1 - reflex/state.py | 87 ++++++++++++++++++++++------------ reflex/utils/exceptions.py | 4 ++ 4 files changed, 103 insertions(+), 85 deletions(-) diff --git a/poetry.lock b/poetry.lock index f94a3832a..928731c26 100644 --- a/poetry.lock +++ b/poetry.lock @@ -516,21 +516,6 @@ files = [ {file = "darglint-1.8.1.tar.gz", hash = "sha256:080d5106df149b199822e7ee7deb9c012b49891538f14a11be681044f0bb20da"}, ] -[[package]] -name = "dill" -version = "0.3.8" -description = "serialize all of Python" -optional = false -python-versions = ">=3.8" -files = [ - {file = "dill-0.3.8-py3-none-any.whl", hash = "sha256:c36ca9ffb54365bdd2f8eb3eff7d2a21237f8452b57ace88b1ac615b7e815bd7"}, - {file = "dill-0.3.8.tar.gz", hash = "sha256:3ebe3c479ad625c4553aca177444d89b486b1d84982eeacded644afc0cf797ca"}, -] - -[package.extras] -graph = ["objgraph (>=1.7.2)"] -profile = ["gprof2dot (>=2022.7.29)"] - [[package]] name = "distlib" version = "0.3.8" @@ -719,13 +704,13 @@ files = [ [[package]] name = "httpcore" -version = "1.0.5" +version = "1.0.6" description = "A minimal low-level HTTP client." optional = false python-versions = ">=3.8" files = [ - {file = "httpcore-1.0.5-py3-none-any.whl", hash = "sha256:421f18bac248b25d310f3cacd198d55b8e6125c107797b609ff9b7a6ba7991b5"}, - {file = "httpcore-1.0.5.tar.gz", hash = "sha256:34a38e2f9291467ee3b44e89dd52615370e152954ba21721378a87b2960f7a61"}, + {file = "httpcore-1.0.6-py3-none-any.whl", hash = "sha256:27b59625743b85577a8c0e10e55b50b5368a4f2cfe8cc7bcfa9cf00829c2682f"}, + {file = "httpcore-1.0.6.tar.gz", hash = "sha256:73f6dbd6eb8c21bbf7ef8efad555481853f5f6acdeaff1edb0694289269ee17f"}, ] [package.dependencies] @@ -736,7 +721,7 @@ h11 = ">=0.13,<0.15" asyncio = ["anyio (>=4.0,<5.0)"] http2 = ["h2 (>=3,<5)"] socks = ["socksio (==1.*)"] -trio = ["trio (>=0.22.0,<0.26.0)"] +trio = ["trio (>=0.22.0,<1.0)"] [[package]] name = "httpx" @@ -863,21 +848,25 @@ test = ["portend", "pytest (>=6,!=8.1.*)", "pytest-checkdocs (>=2.4)", "pytest-c [[package]] name = "jaraco-functools" -version = "4.0.2" +version = "4.1.0" description = "Functools like those found in stdlib" optional = false python-versions = ">=3.8" files = [ - {file = "jaraco.functools-4.0.2-py3-none-any.whl", hash = "sha256:c9d16a3ed4ccb5a889ad8e0b7a343401ee5b2a71cee6ed192d3f68bc351e94e3"}, - {file = "jaraco_functools-4.0.2.tar.gz", hash = "sha256:3460c74cd0d32bf82b9576bbb3527c4364d5b27a21f5158a62aed6c4b42e23f5"}, + {file = "jaraco.functools-4.1.0-py3-none-any.whl", hash = "sha256:ad159f13428bc4acbf5541ad6dec511f91573b90fba04df61dafa2a1231cf649"}, + {file = "jaraco_functools-4.1.0.tar.gz", hash = "sha256:70f7e0e2ae076498e212562325e805204fc092d7b4c17e0e86c959e249701a9d"}, ] [package.dependencies] more-itertools = "*" [package.extras] +check = ["pytest-checkdocs (>=2.4)", "pytest-ruff (>=0.2.1)"] +cover = ["pytest-cov"] doc = ["furo", "jaraco.packaging (>=9.3)", "jaraco.tidelift (>=1.4)", "rst.linker (>=1.9)", "sphinx (>=3.5)", "sphinx-lint"] -test = ["jaraco.classes", "pytest (>=6,!=8.1.*)", "pytest-checkdocs (>=2.4)", "pytest-cov", "pytest-enabler (>=2.2)", "pytest-mypy", "pytest-ruff (>=0.2.1)"] +enabler = ["pytest-enabler (>=2.2)"] +test = ["jaraco.classes", "pytest (>=6,!=8.1.*)"] +type = ["pytest-mypy"] [[package]] name = "jeepney" @@ -1788,13 +1777,13 @@ windows-terminal = ["colorama (>=0.4.6)"] [[package]] name = "pyproject-hooks" -version = "1.1.0" +version = "1.2.0" description = "Wrappers to call pyproject.toml-based build backend hooks." optional = false python-versions = ">=3.7" files = [ - {file = "pyproject_hooks-1.1.0-py3-none-any.whl", hash = "sha256:7ceeefe9aec63a1064c18d939bdc3adf2d8aa1988a510afec15151578b232aa2"}, - {file = "pyproject_hooks-1.1.0.tar.gz", hash = "sha256:4b37730834edbd6bd37f26ece6b44802fb1c1ee2ece0e54ddff8bfc06db86965"}, + {file = "pyproject_hooks-1.2.0-py3-none-any.whl", hash = "sha256:9e5c6bfa8dcc30091c74b0cf803c81fdd29d94f01992a7707bc97babb1141913"}, + {file = "pyproject_hooks-1.2.0.tar.gz", hash = "sha256:1e859bd5c40fae9448642dd871adf459e5e2084186e8d2c2a79a824c970da1f8"}, ] [[package]] @@ -1992,13 +1981,13 @@ docs = ["sphinx"] [[package]] name = "python-multipart" -version = "0.0.10" +version = "0.0.12" description = "A streaming multipart parser for Python" optional = false python-versions = ">=3.8" files = [ - {file = "python_multipart-0.0.10-py3-none-any.whl", hash = "sha256:2b06ad9e8d50c7a8db80e3b56dab590137b323410605af2be20d62a5f1ba1dc8"}, - {file = "python_multipart-0.0.10.tar.gz", hash = "sha256:46eb3c6ce6fdda5fb1a03c7e11d490e407c6930a2703fe7aef4da71c374688fa"}, + {file = "python_multipart-0.0.12-py3-none-any.whl", hash = "sha256:43dcf96cf65888a9cd3423544dd0d75ac10f7aa0c3c28a175bbcd00c9ce1aebf"}, + {file = "python_multipart-0.0.12.tar.gz", hash = "sha256:045e1f98d719c1ce085ed7f7e1ef9d8ccc8c02ba02b5566d5f7521410ced58cb"}, ] [[package]] @@ -2143,31 +2132,31 @@ md = ["cmarkgfm (>=0.8.0)"] [[package]] name = "redis" -version = "5.0.8" +version = "5.1.0" description = "Python client for Redis database and key-value store" optional = false -python-versions = ">=3.7" +python-versions = ">=3.8" files = [ - {file = "redis-5.0.8-py3-none-any.whl", hash = "sha256:56134ee08ea909106090934adc36f65c9bcbbaecea5b21ba704ba6fb561f8eb4"}, - {file = "redis-5.0.8.tar.gz", hash = "sha256:0c5b10d387568dfe0698c6fad6615750c24170e548ca2deac10c649d463e9870"}, + {file = "redis-5.1.0-py3-none-any.whl", hash = "sha256:fd4fccba0d7f6aa48c58a78d76ddb4afc698f5da4a2c1d03d916e4fd7ab88cdd"}, + {file = "redis-5.1.0.tar.gz", hash = "sha256:b756df1e4a3858fcc0ef861f3fc53623a96c41e2b1f5304e09e0fe758d333d40"}, ] [package.dependencies] async-timeout = {version = ">=4.0.3", markers = "python_full_version < \"3.11.3\""} [package.extras] -hiredis = ["hiredis (>1.0.0)"] -ocsp = ["cryptography (>=36.0.1)", "pyopenssl (==20.0.1)", "requests (>=2.26.0)"] +hiredis = ["hiredis (>=3.0.0)"] +ocsp = ["cryptography (>=36.0.1)", "pyopenssl (==23.2.1)", "requests (>=2.31.0)"] [[package]] name = "reflex-chakra" -version = "0.6.0" +version = "0.6.1" description = "reflex using chakra components" optional = false python-versions = "<4.0,>=3.8" files = [ - {file = "reflex_chakra-0.6.0-py3-none-any.whl", hash = "sha256:eca1593fca67289e05591dd21fbcc8632c119d64a08bdc41fd995055a114cc91"}, - {file = "reflex_chakra-0.6.0.tar.gz", hash = "sha256:db1c7b48f1ba547bf91e5af103fce6fc7191d7225b414ebfbada7d983e33dd87"}, + {file = "reflex_chakra-0.6.1-py3-none-any.whl", hash = "sha256:824d461264b6d2c836ba4a2a430e677a890b82e83da149672accfc58786442fa"}, + {file = "reflex_chakra-0.6.1.tar.gz", hash = "sha256:4b9b3c8bada19cbb4d1b8d8bc4ab0460ec008a91f380010c34d416d5b613dc07"}, ] [package.dependencies] @@ -2247,18 +2236,19 @@ idna2008 = ["idna"] [[package]] name = "rich" -version = "13.8.1" +version = "13.9.1" description = "Render rich text, tables, progress bars, syntax highlighting, markdown and more to the terminal" optional = false -python-versions = ">=3.7.0" +python-versions = ">=3.8.0" files = [ - {file = "rich-13.8.1-py3-none-any.whl", hash = "sha256:1760a3c0848469b97b558fc61c85233e3dafb69c7a071b4d60c38099d3cd4c06"}, - {file = "rich-13.8.1.tar.gz", hash = "sha256:8260cda28e3db6bf04d2d1ef4dbc03ba80a824c88b0e7668a0f23126a424844a"}, + {file = "rich-13.9.1-py3-none-any.whl", hash = "sha256:b340e739f30aa58921dc477b8adaa9ecdb7cecc217be01d93730ee1bc8aa83be"}, + {file = "rich-13.9.1.tar.gz", hash = "sha256:097cffdf85db1babe30cc7deba5ab3a29e1b9885047dab24c57e9a7f8a9c1466"}, ] [package.dependencies] markdown-it-py = ">=2.2.0" pygments = ">=2.13.0,<3.0.0" +typing-extensions = {version = ">=4.0.0,<5.0", markers = "python_version < \"3.11\""} [package.extras] jupyter = ["ipywidgets (>=7.5.1,<9)"] @@ -2595,13 +2585,13 @@ files = [ [[package]] name = "tomli" -version = "2.0.1" +version = "2.0.2" description = "A lil' TOML parser" optional = false -python-versions = ">=3.7" +python-versions = ">=3.8" files = [ - {file = "tomli-2.0.1-py3-none-any.whl", hash = "sha256:939de3e7a6161af0c887ef91b7d41a53e7c5a1ca976325f429cb46ea9bc30ecc"}, - {file = "tomli-2.0.1.tar.gz", hash = "sha256:de526c12914f0c550d15924c62d72abc48d6fe7364aa87328337a31007fe8a4f"}, + {file = "tomli-2.0.2-py3-none-any.whl", hash = "sha256:2ebe24485c53d303f690b0ec092806a085f07af5a5aa1464f3931eec36caaa38"}, + {file = "tomli-2.0.2.tar.gz", hash = "sha256:d46d457a85337051c36524bc5349dd91b1877838e2979ac5ced3e710ed8a60ed"}, ] [[package]] @@ -2734,13 +2724,13 @@ zstd = ["zstandard (>=0.18.0)"] [[package]] name = "uvicorn" -version = "0.30.6" +version = "0.31.0" description = "The lightning-fast ASGI server." optional = false python-versions = ">=3.8" files = [ - {file = "uvicorn-0.30.6-py3-none-any.whl", hash = "sha256:65fd46fe3fda5bdc1b03b94eb634923ff18cd35b2f084813ea79d1f103f711b5"}, - {file = "uvicorn-0.30.6.tar.gz", hash = "sha256:4b15decdda1e72be08209e860a1e10e92439ad5b97cf44cc945fcbee66fc5788"}, + {file = "uvicorn-0.31.0-py3-none-any.whl", hash = "sha256:cac7be4dd4d891c363cd942160a7b02e69150dcbc7a36be04d5f4af4b17c8ced"}, + {file = "uvicorn-0.31.0.tar.gz", hash = "sha256:13bc21373d103859f68fe739608e2eb054a816dea79189bc3ca08ea89a275906"}, ] [package.dependencies] @@ -2753,13 +2743,13 @@ standard = ["colorama (>=0.4)", "httptools (>=0.5.0)", "python-dotenv (>=0.13)", [[package]] name = "virtualenv" -version = "20.26.5" +version = "20.26.6" description = "Virtual Python Environment builder" optional = false python-versions = ">=3.7" files = [ - {file = "virtualenv-20.26.5-py3-none-any.whl", hash = "sha256:4f3ac17b81fba3ce3bd6f4ead2749a72da5929c01774948e243db9ba41df4ff6"}, - {file = "virtualenv-20.26.5.tar.gz", hash = "sha256:ce489cac131aa58f4b25e321d6d186171f78e6cb13fafbf32a840cee67733ff4"}, + {file = "virtualenv-20.26.6-py3-none-any.whl", hash = "sha256:7345cc5b25405607a624d8418154577459c3e0277f5466dd79c49d5e492995f2"}, + {file = "virtualenv-20.26.6.tar.gz", hash = "sha256:280aede09a2a5c317e409a00102e7077c6432c5a38f0ef938e643805a7ad2c48"}, ] [package.dependencies] @@ -3011,4 +3001,4 @@ type = ["pytest-mypy"] [metadata] lock-version = "2.0" python-versions = "^3.9" -content-hash = "adccd071775567aeefe219261aeb9e222906c865745f03edb1e770edc79c44ac" +content-hash = "e4b462ebfae90550ba7fa49b360d7110c0d344ee616c23989c22d866ef8f6f31" diff --git a/pyproject.toml b/pyproject.toml index 08c4fbdbc..281741368 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -27,7 +27,6 @@ packages = [ [tool.poetry.dependencies] python = "^3.9" -dill = ">=0.3.8,<0.4" fastapi = ">=0.96.0,!=0.111.0,!=0.111.1" gunicorn = ">=20.1.0,<24.0" jinja2 = ">=3.1.2,<4.0" diff --git a/reflex/state.py b/reflex/state.py index b1988e38a..5798564fa 100644 --- a/reflex/state.py +++ b/reflex/state.py @@ -9,6 +9,7 @@ import dataclasses import functools import inspect import os +import pickle import uuid from abc import ABC, abstractmethod from collections import defaultdict @@ -19,6 +20,7 @@ from typing import ( TYPE_CHECKING, Any, AsyncIterator, + BinaryIO, Callable, ClassVar, Dict, @@ -33,7 +35,6 @@ from typing import ( get_type_hints, ) -import dill from sqlalchemy.orm import DeclarativeBase from typing_extensions import Self @@ -76,6 +77,7 @@ from reflex.utils.exceptions import ( ImmutableStateError, LockExpiredError, SetUndefinedStateVarError, + StateSchemaMismatchError, ) from reflex.utils.exec import is_testing_env from reflex.utils.serializers import serializer @@ -1914,7 +1916,7 @@ class BaseState(Base, ABC, extra=pydantic.Extra.allow): def __getstate__(self): """Get the state for redis serialization. - This method is called by cloudpickle to serialize the object. + This method is called by pickle to serialize the object. It explicitly removes parent_state and substates because those are serialized separately by the StateManagerRedis to allow for better horizontal scaling as state size increases. @@ -1930,6 +1932,43 @@ class BaseState(Base, ABC, extra=pydantic.Extra.allow): state["__dict__"].pop("_was_touched", None) return state + def _serialize(self) -> bytes: + """Serialize the state for redis. + + Returns: + The serialized state. + """ + return pickle.dumps((state_to_schema(self), self)) + + @classmethod + def _deserialize( + cls, data: bytes | None = None, fp: BinaryIO | None = None + ) -> BaseState: + """Deserialize the state from redis/disk. + + data and fp are mutually exclusive, but one must be provided. + + Args: + data: The serialized state data. + fp: The file pointer to the serialized state data. + + Returns: + The deserialized state. + + Raises: + ValueError: If both data and fp are provided, or neither are provided. + StateSchemaMismatchError: If the state schema does not match the expected schema. + """ + if data is not None and fp is None: + (substate_schema, state) = pickle.loads(data) + elif fp is not None and data is None: + (substate_schema, state) = pickle.load(fp) + else: + raise ValueError("Only one of `data` or `fp` must be provided") + if substate_schema != state_to_schema(state): + raise StateSchemaMismatchError() + return state + class State(BaseState): """The app Base State.""" @@ -2086,7 +2125,11 @@ class ComponentState(State, mixin=True): """ cls._per_component_state_instance_count += 1 state_cls_name = f"{cls.__name__}_n{cls._per_component_state_instance_count}" - component_state = type(state_cls_name, (cls, State), {}, mixin=False) + component_state = type( + state_cls_name, (cls, State), {"__module__": __name__}, mixin=False + ) + # Save a reference to the dynamic state for pickle/unpickle. + globals()[state_cls_name] = component_state component = component_state.get_component(*children, **props) component.State = component_state return component @@ -2552,7 +2595,7 @@ def is_serializable(value: Any) -> bool: Whether the value is serializable. """ try: - return bool(dill.dumps(value)) + return bool(pickle.dumps(value)) except Exception: return False @@ -2688,8 +2731,7 @@ class StateManagerDisk(StateManager): if token_path.exists(): try: with token_path.open(mode="rb") as file: - (substate_schema, substate) = dill.load(file) - if substate_schema == state_to_schema(substate): + substate = BaseState._deserialize(fp=file) await self.populate_substates(client_token, substate, root_state) return substate except Exception: @@ -2731,10 +2773,12 @@ class StateManagerDisk(StateManager): client_token, substate_address = _split_substate_key(token) root_state_token = _substate_key(client_token, substate_address.split(".")[0]) + root_state = self.states.get(root_state_token) + if root_state is None: + # Create a new root state which will be persisted in the next set_state call. + root_state = self.state(_reflex_internal_init=True) - return await self.load_state( - root_state_token, self.state(_reflex_internal_init=True) - ) + return await self.load_state(root_state_token, root_state) async def set_state_for_substate(self, client_token: str, substate: BaseState): """Set the state for a substate. @@ -2747,7 +2791,7 @@ class StateManagerDisk(StateManager): self.states[substate_token] = substate - state_dilled = dill.dumps((state_to_schema(substate), substate)) + state_dilled = substate._serialize() if not self.states_directory.exists(): self.states_directory.mkdir(parents=True, exist_ok=True) self.token_path(substate_token).write_bytes(state_dilled) @@ -2790,25 +2834,6 @@ class StateManagerDisk(StateManager): await self.set_state(token, state) -# Workaround https://github.com/cloudpipe/cloudpickle/issues/408 for dynamic pydantic classes -if not isinstance(State.validate.__func__, FunctionType): - cython_function_or_method = type(State.validate.__func__) - - @dill.register(cython_function_or_method) - def _dill_reduce_cython_function_or_method(pickler, obj): - # Ignore cython function when pickling. - pass - - -@dill.register(type(State)) -def _dill_reduce_state(pickler, obj): - if obj is not State and issubclass(obj, State): - # Avoid serializing subclasses of State, instead get them by reference from the State class. - pickler.save_reduce(State.get_class_substate, (obj.get_full_name(),), obj=obj) - else: - dill.Pickler.dispatch[type](pickler, obj) - - def _default_lock_expiration() -> int: """Get the default lock expiration time. @@ -2948,7 +2973,7 @@ class StateManagerRedis(StateManager): if redis_state is not None: # Deserialize the substate. - state = dill.loads(redis_state) + state = BaseState._deserialize(data=redis_state) # Populate parent state if missing and requested. if parent_state is None: @@ -3060,7 +3085,7 @@ class StateManagerRedis(StateManager): ) # Persist only the given state (parents or substates are excluded by BaseState.__getstate__). if state._get_was_touched(): - pickle_state = dill.dumps(state, byref=True) + pickle_state = state._serialize() self._warn_if_too_large(state, len(pickle_state)) await self.redis.set( _substate_key(client_token, state), diff --git a/reflex/utils/exceptions.py b/reflex/utils/exceptions.py index 0383f7ba6..8bce605b5 100644 --- a/reflex/utils/exceptions.py +++ b/reflex/utils/exceptions.py @@ -123,3 +123,7 @@ class DynamicComponentMissingLibrary(ReflexError, ValueError): class SetUndefinedStateVarError(ReflexError, AttributeError): """Raised when setting the value of a var without first declaring it.""" + + +class StateSchemaMismatchError(ReflexError, TypeError): + """Raised when the serialized schema of a state class does not match the current schema.""" From 9b5a36814ab46ced35638854dfdd1098db0065a0 Mon Sep 17 00:00:00 2001 From: Masen Furer Date: Fri, 4 Oct 2024 12:27:52 -0700 Subject: [PATCH 15/16] bump version to 0.6.3dev1 (#4061) --- pyproject.toml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pyproject.toml b/pyproject.toml index 281741368..7d79ddf2f 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -1,6 +1,6 @@ [tool.poetry] name = "reflex" -version = "0.6.2dev1" +version = "0.6.3dev1" description = "Web apps in pure Python." license = "Apache-2.0" authors = [ From 12b81ad7543b08fc84ad1b0854d645eaa5cde7e7 Mon Sep 17 00:00:00 2001 From: Khaleel Al-Adhami Date: Fri, 4 Oct 2024 13:56:25 -0700 Subject: [PATCH 16/16] convert literal type to its variants (#4062) --- reflex/vars/base.py | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/reflex/vars/base.py b/reflex/vars/base.py index 0f8a80f8d..7eab62c68 100644 --- a/reflex/vars/base.py +++ b/reflex/vars/base.py @@ -547,6 +547,10 @@ class Var(Generic[VAR_TYPE]): return self + if fixed_type is Literal: + args = get_args(var_type) + fixed_type = unionize(*(type(arg) for arg in args)) + if not inspect.isclass(fixed_type): raise TypeError(f"Unsupported type {var_type} for guess_type.")