diff --git a/benchmarks/benchmark_lighthouse.py b/benchmarks/benchmark_lighthouse.py index 72f486b6f..25d5eaac4 100644 --- a/benchmarks/benchmark_lighthouse.py +++ b/benchmarks/benchmark_lighthouse.py @@ -3,8 +3,8 @@ from __future__ import annotations import json -import os import sys +from pathlib import Path from utils import send_data_to_posthog @@ -28,7 +28,7 @@ def insert_benchmarking_data( send_data_to_posthog("lighthouse_benchmark", properties) -def get_lighthouse_scores(directory_path: str) -> dict: +def get_lighthouse_scores(directory_path: str | Path) -> dict: """Extracts the Lighthouse scores from the JSON files in the specified directory. Args: @@ -38,24 +38,21 @@ def get_lighthouse_scores(directory_path: str) -> dict: dict: The Lighthouse scores. """ scores = {} - + directory_path = Path(directory_path) try: - for filename in os.listdir(directory_path): - if filename.endswith(".json") and filename != "manifest.json": - file_path = os.path.join(directory_path, filename) - with open(file_path, "r") as file: - data = json.load(file) - # Extract scores and add them to the dictionary with the filename as key - scores[data["finalUrl"].replace("http://localhost:3000/", "/")] = { - "performance_score": data["categories"]["performance"]["score"], - "accessibility_score": data["categories"]["accessibility"][ - "score" - ], - "best_practices_score": data["categories"]["best-practices"][ - "score" - ], - "seo_score": data["categories"]["seo"]["score"], - } + for filename in directory_path.iterdir(): + if filename.suffix == ".json" and filename.stem != "manifest": + file_path = directory_path / filename + data = json.loads(file_path.read_text()) + # Extract scores and add them to the dictionary with the filename as key + scores[data["finalUrl"].replace("http://localhost:3000/", "/")] = { + "performance_score": data["categories"]["performance"]["score"], + "accessibility_score": data["categories"]["accessibility"]["score"], + "best_practices_score": data["categories"]["best-practices"][ + "score" + ], + "seo_score": data["categories"]["seo"]["score"], + } except Exception as e: return {"error": e} diff --git a/benchmarks/benchmark_package_size.py b/benchmarks/benchmark_package_size.py index 8e2704355..778b52769 100644 --- a/benchmarks/benchmark_package_size.py +++ b/benchmarks/benchmark_package_size.py @@ -2,11 +2,12 @@ import argparse import os +from pathlib import Path from utils import get_directory_size, get_python_version, send_data_to_posthog -def get_package_size(venv_path, os_name): +def get_package_size(venv_path: Path, os_name): """Get the size of a specified package. Args: @@ -26,14 +27,12 @@ def get_package_size(venv_path, os_name): is_windows = "windows" in os_name - full_path = ( - ["lib", f"python{python_version}", "site-packages"] + package_dir: Path = ( + venv_path / "lib" / f"python{python_version}" / "site-packages" if not is_windows - else ["Lib", "site-packages"] + else venv_path / "Lib" / "site-packages" ) - - package_dir = os.path.join(venv_path, *full_path) - if not os.path.exists(package_dir): + if not package_dir.exists(): raise ValueError( "Error: Virtual environment does not exist or is not activated." ) @@ -63,9 +62,9 @@ def insert_benchmarking_data( path: The path to the dir or file to check size. """ if "./dist" in path: - size = get_directory_size(path) + size = get_directory_size(Path(path)) else: - size = get_package_size(path, os_type_version) + size = get_package_size(Path(path), os_type_version) # Prepare the event data properties = { diff --git a/benchmarks/benchmark_web_size.py b/benchmarks/benchmark_web_size.py index 6c2f40bbc..3ceccecf8 100644 --- a/benchmarks/benchmark_web_size.py +++ b/benchmarks/benchmark_web_size.py @@ -2,6 +2,7 @@ import argparse import os +from pathlib import Path from utils import get_directory_size, send_data_to_posthog @@ -28,7 +29,7 @@ def insert_benchmarking_data( pr_id: The id of the PR. path: The path to the dir or file to check size. """ - size = get_directory_size(path) + size = get_directory_size(Path(path)) # Prepare the event data properties = { diff --git a/benchmarks/utils.py b/benchmarks/utils.py index 7b02c8cc8..bfadf5b4e 100644 --- a/benchmarks/utils.py +++ b/benchmarks/utils.py @@ -2,12 +2,13 @@ import os import subprocess +from pathlib import Path import httpx from httpx import HTTPError -def get_python_version(venv_path, os_name): +def get_python_version(venv_path: Path, os_name): """Get the python version of python in a virtual env. Args: @@ -18,13 +19,13 @@ def get_python_version(venv_path, os_name): The python version. """ python_executable = ( - os.path.join(venv_path, "bin", "python") + venv_path / "bin" / "python" if "windows" not in os_name - else os.path.join(venv_path, "Scripts", "python.exe") + else venv_path / "Scripts" / "python.exe" ) try: output = subprocess.check_output( - [python_executable, "--version"], stderr=subprocess.STDOUT + [str(python_executable), "--version"], stderr=subprocess.STDOUT ) python_version = output.decode("utf-8").strip().split()[1] return ".".join(python_version.split(".")[:-1]) @@ -32,7 +33,7 @@ def get_python_version(venv_path, os_name): return None -def get_directory_size(directory): +def get_directory_size(directory: Path): """Get the size of a directory in bytes. Args: @@ -44,8 +45,8 @@ def get_directory_size(directory): total_size = 0 for dirpath, _, filenames in os.walk(directory): for f in filenames: - fp = os.path.join(dirpath, f) - total_size += os.path.getsize(fp) + fp = Path(dirpath) / f + total_size += fp.stat().st_size return total_size diff --git a/poetry.lock b/poetry.lock index 6d7663028..6614c7bd9 100644 --- a/poetry.lock +++ b/poetry.lock @@ -516,21 +516,6 @@ files = [ {file = "darglint-1.8.1.tar.gz", hash = "sha256:080d5106df149b199822e7ee7deb9c012b49891538f14a11be681044f0bb20da"}, ] -[[package]] -name = "dill" -version = "0.3.9" -description = "serialize all of Python" -optional = false -python-versions = ">=3.8" -files = [ - {file = "dill-0.3.9-py3-none-any.whl", hash = "sha256:468dff3b89520b474c0397703366b7b95eebe6303f108adf9b19da1f702be87a"}, - {file = "dill-0.3.9.tar.gz", hash = "sha256:81aa267dddf68cbfe8029c42ca9ec6a4ab3b22371d1c450abc54422577b4512c"}, -] - -[package.extras] -graph = ["objgraph (>=1.7.2)"] -profile = ["gprof2dot (>=2022.7.29)"] - [[package]] name = "distlib" version = "0.3.8" @@ -2147,13 +2132,13 @@ md = ["cmarkgfm (>=0.8.0)"] [[package]] name = "redis" -version = "5.1.0" +version = "5.1.1" description = "Python client for Redis database and key-value store" optional = false python-versions = ">=3.8" files = [ - {file = "redis-5.1.0-py3-none-any.whl", hash = "sha256:fd4fccba0d7f6aa48c58a78d76ddb4afc698f5da4a2c1d03d916e4fd7ab88cdd"}, - {file = "redis-5.1.0.tar.gz", hash = "sha256:b756df1e4a3858fcc0ef861f3fc53623a96c41e2b1f5304e09e0fe758d333d40"}, + {file = "redis-5.1.1-py3-none-any.whl", hash = "sha256:f8ea06b7482a668c6475ae202ed8d9bcaa409f6e87fb77ed1043d912afd62e24"}, + {file = "redis-5.1.1.tar.gz", hash = "sha256:f6c997521fedbae53387307c5d0bf784d9acc28d9f1d058abeac566ec4dbed72"}, ] [package.dependencies] @@ -2251,13 +2236,13 @@ idna2008 = ["idna"] [[package]] name = "rich" -version = "13.9.1" +version = "13.9.2" description = "Render rich text, tables, progress bars, syntax highlighting, markdown and more to the terminal" optional = false python-versions = ">=3.8.0" files = [ - {file = "rich-13.9.1-py3-none-any.whl", hash = "sha256:b340e739f30aa58921dc477b8adaa9ecdb7cecc217be01d93730ee1bc8aa83be"}, - {file = "rich-13.9.1.tar.gz", hash = "sha256:097cffdf85db1babe30cc7deba5ab3a29e1b9885047dab24c57e9a7f8a9c1466"}, + {file = "rich-13.9.2-py3-none-any.whl", hash = "sha256:8c82a3d3f8dcfe9e734771313e606b39d8247bb6b826e196f4914b333b743cf1"}, + {file = "rich-13.9.2.tar.gz", hash = "sha256:51a2c62057461aaf7152b4d611168f93a9fc73068f8ded2790f29fe2b5366d0c"}, ] [package.dependencies] @@ -3016,4 +3001,4 @@ type = ["pytest-mypy"] [metadata] lock-version = "2.0" python-versions = "^3.9" -content-hash = "c4224e45026cde032517f07df3c8553850b8ada5001311485e2abbfa4f4eaf5e" +content-hash = "36059dc143f1eb94f4c87a6cfe94de94eddac6e3d01fe76d28b6ed065c1b7836" diff --git a/pyproject.toml b/pyproject.toml index d058405ea..22f90eba2 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -27,7 +27,6 @@ packages = [ [tool.poetry.dependencies] python = "^3.9" -dill = ">=0.3.8,<0.4" fastapi = ">=0.96.0,!=0.111.0,!=0.111.1" gunicorn = ">=20.1.0,<24.0" jinja2 = ">=3.1.2,<4.0" diff --git a/reflex/.templates/jinja/web/pages/_app.js.jinja2 b/reflex/.templates/jinja/web/pages/_app.js.jinja2 index 97c31925d..c893e19e2 100644 --- a/reflex/.templates/jinja/web/pages/_app.js.jinja2 +++ b/reflex/.templates/jinja/web/pages/_app.js.jinja2 @@ -7,10 +7,9 @@ import '/styles/styles.css' {% block declaration %} import { EventLoopProvider, StateProvider, defaultColorMode } from "/utils/context.js"; import { ThemeProvider } from 'next-themes' -import * as React from "react"; -import * as utils_context from "/utils/context.js"; -import * as utils_state from "/utils/state.js"; -import * as radix from "@radix-ui/themes"; +{% for library_alias, library_path in window_libraries %} +import * as {{library_alias}} from "{{library_path}}"; +{% endfor %} {% for custom_code in custom_codes %} {{custom_code}} @@ -33,10 +32,9 @@ export default function MyApp({ Component, pageProps }) { React.useEffect(() => { // Make contexts and state objects available globally for dynamic eval'd components let windowImports = { - "react": React, - "@radix-ui/themes": radix, - "/utils/context": utils_context, - "/utils/state": utils_state, +{% for library_alias, library_path in window_libraries %} + "{{library_path}}": {{library_alias}}, +{% endfor %} }; window["__reflex"] = windowImports; }, []); diff --git a/reflex/.templates/web/utils/state.js b/reflex/.templates/web/utils/state.js index 78e671809..0fe0db8c1 100644 --- a/reflex/.templates/web/utils/state.js +++ b/reflex/.templates/web/utils/state.js @@ -544,13 +544,19 @@ export const uploadFiles = async ( /** * Create an event object. - * @param name The name of the event. - * @param payload The payload of the event. - * @param handler The client handler to process event. + * @param {string} name The name of the event. + * @param {Object.} payload The payload of the event. + * @param {Object.} event_actions The actions to take on the event. + * @param {string} handler The client handler to process event. * @returns The event object. */ -export const Event = (name, payload = {}, handler = null) => { - return { name, payload, handler }; +export const Event = ( + name, + payload = {}, + event_actions = {}, + handler = null +) => { + return { name, payload, handler, event_actions }; }; /** @@ -676,6 +682,12 @@ export const useEventLoop = ( if (!(args instanceof Array)) { args = [args]; } + + event_actions = events.reduce( + (acc, e) => ({ ...acc, ...e.event_actions }), + event_actions ?? {} + ); + const _e = args.filter((o) => o?.preventDefault !== undefined)[0]; if (event_actions?.preventDefault && _e?.preventDefault) { diff --git a/reflex/app.py b/reflex/app.py index 111dd9dfd..584b8a321 100644 --- a/reflex/app.py +++ b/reflex/app.py @@ -431,25 +431,12 @@ class App(MiddlewareMixin, LifespanMixin, Base): The generated component. Raises: - VarOperationTypeError: When an invalid component var related function is passed. - TypeError: When an invalid component function is passed. exceptions.MatchTypeError: If the return types of match cases in rx.match are different. """ - from reflex.utils.exceptions import VarOperationTypeError - try: return component if isinstance(component, Component) else component() except exceptions.MatchTypeError: raise - except TypeError as e: - message = str(e) - if "Var" in message: - raise VarOperationTypeError( - "You may be trying to use an invalid Python function on a state var. " - "When referencing a var inside your render code, only limited var operations are supported. " - "See the var operation docs here: https://reflex.dev/docs/vars/var-operations/" - ) from e - raise e def add_page( self, @@ -1536,7 +1523,9 @@ class EventNamespace(AsyncNamespace): """ fields = json.loads(data) # Get the event. - event = Event(**{k: v for k, v in fields.items() if k != "handler"}) + event = Event( + **{k: v for k, v in fields.items() if k not in ("handler", "event_actions")} + ) self.token_to_sid[event.token] = sid self.sid_to_token[sid] = event.token diff --git a/reflex/compiler/compiler.py b/reflex/compiler/compiler.py index edf03039e..0c29f941d 100644 --- a/reflex/compiler/compiler.py +++ b/reflex/compiler/compiler.py @@ -40,6 +40,20 @@ def _compile_document_root(root: Component) -> str: ) +def _normalize_library_name(lib: str) -> str: + """Normalize the library name. + + Args: + lib: The library name to normalize. + + Returns: + The normalized library name. + """ + if lib == "react": + return "React" + return lib.replace("@", "").replace("/", "_").replace("-", "_") + + def _compile_app(app_root: Component) -> str: """Compile the app template component. @@ -49,10 +63,20 @@ def _compile_app(app_root: Component) -> str: Returns: The compiled app. """ + from reflex.components.dynamic import bundled_libraries + + window_libraries = [ + (_normalize_library_name(name), name) for name in bundled_libraries + ] + [ + ("utils_context", f"/{constants.Dirs.UTILS}/context"), + ("utils_state", f"/{constants.Dirs.UTILS}/state"), + ] + return templates.APP_ROOT.render( imports=utils.compile_imports(app_root._get_all_imports()), custom_codes=app_root._get_all_custom_code(), hooks={**app_root._get_all_hooks_internal(), **app_root._get_all_hooks()}, + window_libraries=window_libraries, render=app_root.render(), ) @@ -171,7 +195,7 @@ def _compile_root_stylesheet(stylesheets: list[str]) -> str: stylesheet_full_path = ( Path.cwd() / constants.Dirs.APP_ASSETS / stylesheet.strip("/") ) - if not os.path.exists(stylesheet_full_path): + if not stylesheet_full_path.exists(): raise FileNotFoundError( f"The stylesheet file {stylesheet_full_path} does not exist." ) diff --git a/reflex/compiler/utils.py b/reflex/compiler/utils.py index b10552554..6f4fa2d1b 100644 --- a/reflex/compiler/utils.py +++ b/reflex/compiler/utils.py @@ -2,7 +2,6 @@ from __future__ import annotations -import os from pathlib import Path from typing import Any, Callable, Dict, Optional, Type, Union from urllib.parse import urlparse @@ -457,16 +456,16 @@ def add_meta( return page -def write_page(path: str, code: str): +def write_page(path: str | Path, code: str): """Write the given code to the given path. Args: path: The path to write the code to. code: The code to write. """ - path_ops.mkdir(os.path.dirname(path)) - with open(path, "w", encoding="utf-8") as f: - f.write(code) + path = Path(path) + path_ops.mkdir(path.parent) + path.write_text(code, encoding="utf-8") def empty_dir(path: str | Path, keep_files: list[str] | None = None): diff --git a/reflex/components/component.py b/reflex/components/component.py index 9bdd12f0e..26ea2fd3f 100644 --- a/reflex/components/component.py +++ b/reflex/components/component.py @@ -38,8 +38,10 @@ from reflex.constants import ( ) from reflex.event import ( EventChain, + EventChainVar, EventHandler, EventSpec, + EventVar, call_event_fn, call_event_handler, get_handler_args, @@ -514,7 +516,7 @@ class Component(BaseComponent, ABC): Var, EventHandler, EventSpec, - List[Union[EventHandler, EventSpec]], + List[Union[EventHandler, EventSpec, EventVar]], Callable, ], ) -> Union[EventChain, Var]: @@ -532,11 +534,16 @@ class Component(BaseComponent, ABC): """ # If it's an event chain var, return it. if isinstance(value, Var): - if value._var_type is not EventChain: + if isinstance(value, EventChainVar): + return value + elif isinstance(value, EventVar): + value = [value] + elif issubclass(value._var_type, (EventChain, EventSpec)): + return self._create_event_chain(args_spec, value.guess_type()) + else: raise ValueError( - f"Invalid event chain: {repr(value)} of type {type(value)}" + f"Invalid event chain: {str(value)} of type {value._var_type}" ) - return value elif isinstance(value, EventChain): # Trust that the caller knows what they're doing passing an EventChain directly return value @@ -547,7 +554,7 @@ class Component(BaseComponent, ABC): # If the input is a list of event handlers, create an event chain. if isinstance(value, List): - events: list[EventSpec] = [] + events: List[Union[EventSpec, EventVar]] = [] for v in value: if isinstance(v, (EventHandler, EventSpec)): # Call the event handler to get the event. @@ -561,6 +568,8 @@ class Component(BaseComponent, ABC): "lambda inside an EventChain list." ) events.extend(result) + elif isinstance(v, EventVar): + events.append(v) else: raise ValueError(f"Invalid event: {v}") @@ -570,32 +579,30 @@ class Component(BaseComponent, ABC): if isinstance(result, Var): # Recursively call this function if the lambda returned an EventChain Var. return self._create_event_chain(args_spec, result) - events = result + events = [*result] # Otherwise, raise an error. else: raise ValueError(f"Invalid event chain: {value}") # Add args to the event specs if necessary. - events = [e.with_args(get_handler_args(e)) for e in events] - - # Collect event_actions from each spec - event_actions = {} - for e in events: - event_actions.update(e.event_actions) + events = [ + (e.with_args(get_handler_args(e)) if isinstance(e, EventSpec) else e) + for e in events + ] # Return the event chain. if isinstance(args_spec, Var): return EventChain( events=events, args_spec=None, - event_actions=event_actions, + event_actions={}, ) else: return EventChain( events=events, args_spec=args_spec, - event_actions=event_actions, + event_actions={}, ) def get_event_triggers(self) -> Dict[str, Any]: @@ -1030,8 +1037,11 @@ class Component(BaseComponent, ABC): elif isinstance(event, EventChain): event_args = [] for spec in event.events: - for args in spec.args: - event_args.extend(args) + if isinstance(spec, EventSpec): + for args in spec.args: + event_args.extend(args) + else: + event_args.append(spec) yield event_trigger, event_args def _get_vars(self, include_children: bool = False) -> list[Var]: @@ -1105,8 +1115,12 @@ class Component(BaseComponent, ABC): for trigger in self.event_triggers.values(): if isinstance(trigger, EventChain): for event in trigger.events: - if event.handler.state_full_name: - return True + if isinstance(event, EventSpec): + if event.handler.state_full_name: + return True + else: + if event._var_state: + return True elif isinstance(trigger, Var) and trigger._var_state: return True return False diff --git a/reflex/components/dynamic.py b/reflex/components/dynamic.py index 390b6e688..8d0bab669 100644 --- a/reflex/components/dynamic.py +++ b/reflex/components/dynamic.py @@ -1,12 +1,18 @@ """Components that are dynamically generated on the backend.""" +from typing import TYPE_CHECKING + from reflex import constants from reflex.utils import imports +from reflex.utils.exceptions import DynamicComponentMissingLibrary from reflex.utils.format import format_library_name from reflex.utils.serializers import serializer from reflex.vars import Var, get_unique_variable_name from reflex.vars.base import VarData, transform +if TYPE_CHECKING: + from reflex.components.component import Component + def get_cdn_url(lib: str) -> str: """Get the CDN URL for a library. @@ -20,6 +26,27 @@ def get_cdn_url(lib: str) -> str: return f"https://cdn.jsdelivr.net/npm/{lib}" + "/+esm" +bundled_libraries = { + "react", + "@radix-ui/themes", + "@emotion/react", +} + + +def bundle_library(component: "Component"): + """Bundle a library with the component. + + Args: + component: The component to bundle the library with. + + Raises: + DynamicComponentMissingLibrary: Raised when a dynamic component is missing a library. + """ + if component.library is None: + raise DynamicComponentMissingLibrary("Component must have a library to bundle.") + bundled_libraries.add(format_library_name(component.library)) + + def load_dynamic_serializer(): """Load the serializer for dynamic components.""" # Causes a circular import, so we import here. @@ -58,10 +85,7 @@ def load_dynamic_serializer(): ) ] = None - libs_in_window = [ - "react", - "@radix-ui/themes", - ] + libs_in_window = bundled_libraries imports = {} for lib, names in component._get_all_imports().items(): @@ -69,10 +93,7 @@ def load_dynamic_serializer(): if ( not lib.startswith((".", "/")) and not lib.startswith("http") - and all( - formatted_lib_name != lib_in_window - for lib_in_window in libs_in_window - ) + and formatted_lib_name not in libs_in_window ): imports[get_cdn_url(lib)] = names else: @@ -110,7 +131,14 @@ def load_dynamic_serializer(): module_code_lines.insert(0, "const React = window.__reflex.react;") - return "//__reflex_evaluate\n" + "\n".join(module_code_lines) + return "\n".join( + [ + "//__reflex_evaluate", + "/** @jsx jsx */", + "const { jsx } = window.__reflex['@emotion/react']", + *module_code_lines, + ] + ) @transform def evaluate_component(js_string: Var[str]) -> Var[Component]: diff --git a/reflex/config.py b/reflex/config.py index c237d7421..a5d66cb52 100644 --- a/reflex/config.py +++ b/reflex/config.py @@ -6,7 +6,8 @@ import importlib import os import sys import urllib.parse -from typing import Any, Dict, List, Optional, Set +from pathlib import Path +from typing import Any, Dict, List, Optional, Set, Union try: import pydantic.v1 as pydantic @@ -188,7 +189,7 @@ class Config(Base): telemetry_enabled: bool = True # The bun path - bun_path: str = constants.Bun.DEFAULT_PATH + bun_path: Union[str, Path] = constants.Bun.DEFAULT_PATH # List of origins that are allowed to connect to the backend API. cors_allowed_origins: List[str] = ["*"] diff --git a/reflex/constants/base.py b/reflex/constants/base.py index 225e8000b..b86f083cc 100644 --- a/reflex/constants/base.py +++ b/reflex/constants/base.py @@ -6,6 +6,7 @@ import os import platform from enum import Enum from importlib import metadata +from pathlib import Path from types import SimpleNamespace from platformdirs import PlatformDirs @@ -66,18 +67,19 @@ class Reflex(SimpleNamespace): # Get directory value from enviroment variables if it exists. _dir = os.environ.get("REFLEX_DIR", "") - DIR = _dir or ( - # on windows, we use C:/Users//AppData/Local/reflex. - # on macOS, we use ~/Library/Application Support/reflex. - # on linux, we use ~/.local/share/reflex. - # If user sets REFLEX_DIR envroment variable use that instead. - PlatformDirs(MODULE_NAME, False).user_data_dir + DIR = Path( + _dir + or ( + # on windows, we use C:/Users//AppData/Local/reflex. + # on macOS, we use ~/Library/Application Support/reflex. + # on linux, we use ~/.local/share/reflex. + # If user sets REFLEX_DIR envroment variable use that instead. + PlatformDirs(MODULE_NAME, False).user_data_dir + ) ) # The root directory of the reflex library. - ROOT_DIR = os.path.dirname( - os.path.dirname(os.path.dirname(os.path.abspath(__file__))) - ) + ROOT_DIR = Path(__file__).parents[2] RELEASES_URL = f"https://api.github.com/repos/reflex-dev/templates/releases" @@ -125,11 +127,11 @@ class Templates(SimpleNamespace): """Folders used by the template system of Reflex.""" # The template directory used during reflex init. - BASE = os.path.join(Reflex.ROOT_DIR, Reflex.MODULE_NAME, ".templates") + BASE = Reflex.ROOT_DIR / Reflex.MODULE_NAME / ".templates" # The web subdirectory of the template directory. - WEB_TEMPLATE = os.path.join(BASE, "web") + WEB_TEMPLATE = BASE / "web" # The jinja template directory. - JINJA_TEMPLATE = os.path.join(BASE, "jinja") + JINJA_TEMPLATE = BASE / "jinja" # Where the code for the templates is stored. CODE = "code" @@ -191,6 +193,14 @@ class LogLevel(str, Enum): levels = list(LogLevel) return levels.index(self) <= levels.index(other) + def subprocess_level(self): + """Return the log level for the subprocess. + + Returns: + The log level for the subprocess + """ + return self if self != LogLevel.DEFAULT else LogLevel.WARNING + # Server socket configuration variables POLLING_MAX_HTTP_BUFFER_SIZE = 1000 * 1000 diff --git a/reflex/constants/config.py b/reflex/constants/config.py index 966727426..3ff7aade5 100644 --- a/reflex/constants/config.py +++ b/reflex/constants/config.py @@ -1,6 +1,7 @@ """Config constants.""" import os +from pathlib import Path from types import SimpleNamespace from reflex.constants.base import Dirs, Reflex @@ -17,9 +18,7 @@ class Config(SimpleNamespace): # The name of the reflex config module. MODULE = "rxconfig" # The python config file. - FILE = f"{MODULE}{Ext.PY}" - # The previous config file. - PREVIOUS_FILE = f"pcconfig{Ext.PY}" + FILE = Path(f"{MODULE}{Ext.PY}") class Expiration(SimpleNamespace): @@ -37,7 +36,7 @@ class GitIgnore(SimpleNamespace): """Gitignore constants.""" # The gitignore file. - FILE = ".gitignore" + FILE = Path(".gitignore") # Files to gitignore. DEFAULTS = {Dirs.WEB, "*.db", "__pycache__/", "*.py[cod]", "assets/external/"} diff --git a/reflex/constants/custom_components.py b/reflex/constants/custom_components.py index 3ea9cf6ed..d879a01f2 100644 --- a/reflex/constants/custom_components.py +++ b/reflex/constants/custom_components.py @@ -2,6 +2,7 @@ from __future__ import annotations +from pathlib import Path from types import SimpleNamespace @@ -11,9 +12,9 @@ class CustomComponents(SimpleNamespace): # The name of the custom components source directory. SRC_DIR = "custom_components" # The name of the custom components pyproject.toml file. - PYPROJECT_TOML = "pyproject.toml" + PYPROJECT_TOML = Path("pyproject.toml") # The name of the custom components package README file. - PACKAGE_README = "README.md" + PACKAGE_README = Path("README.md") # The name of the custom components package .gitignore file. PACKAGE_GITIGNORE = ".gitignore" # The name of the distribution directory as result of a build. @@ -29,6 +30,6 @@ class CustomComponents(SimpleNamespace): "testpypi": "https://test.pypi.org/legacy/", } # The .gitignore file for the custom component project. - FILE = ".gitignore" + FILE = Path(".gitignore") # Files to gitignore. DEFAULTS = {"__pycache__/", "*.py[cod]", "*.egg-info/", "dist/"} diff --git a/reflex/constants/installer.py b/reflex/constants/installer.py index 01a11a37e..a815b1284 100644 --- a/reflex/constants/installer.py +++ b/reflex/constants/installer.py @@ -2,7 +2,6 @@ from __future__ import annotations -import os import platform from types import SimpleNamespace @@ -40,11 +39,10 @@ class Bun(SimpleNamespace): # Min Bun Version MIN_VERSION = "0.7.0" # The directory to store the bun. - ROOT_PATH = os.path.join(Reflex.DIR, "bun") + ROOT_PATH = Reflex.DIR / "bun" # Default bun path. - DEFAULT_PATH = os.path.join( - ROOT_PATH, "bin", "bun" if not IS_WINDOWS else "bun.exe" - ) + DEFAULT_PATH = ROOT_PATH / "bin" / ("bun" if not IS_WINDOWS else "bun.exe") + # URL to bun install script. INSTALL_URL = "https://bun.sh/install" # URL to windows install script. @@ -65,10 +63,10 @@ class Fnm(SimpleNamespace): # The FNM version. VERSION = "1.35.1" # The directory to store fnm. - DIR = os.path.join(Reflex.DIR, "fnm") + DIR = Reflex.DIR / "fnm" FILENAME = get_fnm_name() # The fnm executable binary. - EXE = os.path.join(DIR, "fnm.exe" if IS_WINDOWS else "fnm") + EXE = DIR / ("fnm.exe" if IS_WINDOWS else "fnm") # The URL to the fnm release binary INSTALL_URL = ( @@ -86,18 +84,19 @@ class Node(SimpleNamespace): MIN_VERSION = "18.17.0" # The node bin path. - BIN_PATH = os.path.join( - Fnm.DIR, - "node-versions", - f"v{VERSION}", - "installation", - "bin" if not IS_WINDOWS else "", + BIN_PATH = ( + Fnm.DIR + / "node-versions" + / f"v{VERSION}" + / "installation" + / ("bin" if not IS_WINDOWS else "") ) + # The default path where node is installed. - PATH = os.path.join(BIN_PATH, "node.exe" if IS_WINDOWS else "node") + PATH = BIN_PATH / ("node.exe" if IS_WINDOWS else "node") # The default path where npm is installed. - NPM_PATH = os.path.join(BIN_PATH, "npm") + NPM_PATH = BIN_PATH / "npm" # The environment variable to use the system installed node. USE_SYSTEM_VAR = "REFLEX_USE_SYSTEM_NODE" diff --git a/reflex/custom_components/custom_components.py b/reflex/custom_components/custom_components.py index e6957f8fd..146bb12c2 100644 --- a/reflex/custom_components/custom_components.py +++ b/reflex/custom_components/custom_components.py @@ -36,7 +36,7 @@ POST_CUSTOM_COMPONENTS_GALLERY_TIMEOUT = 15 @contextmanager -def set_directory(working_directory: str): +def set_directory(working_directory: str | Path): """Context manager that sets the working directory. Args: @@ -45,7 +45,8 @@ def set_directory(working_directory: str): Yields: Yield to the caller to perform operations in the working directory. """ - current_directory = os.getcwd() + current_directory = Path.cwd() + working_directory = Path(working_directory) try: os.chdir(working_directory) yield @@ -62,14 +63,14 @@ def _create_package_config(module_name: str, package_name: str): """ from reflex.compiler import templates - with open(CustomComponents.PYPROJECT_TOML, "w") as f: - f.write( - templates.CUSTOM_COMPONENTS_PYPROJECT_TOML.render( - module_name=module_name, - package_name=package_name, - reflex_version=constants.Reflex.VERSION, - ) + pyproject = Path(CustomComponents.PYPROJECT_TOML) + pyproject.write_text( + templates.CUSTOM_COMPONENTS_PYPROJECT_TOML.render( + module_name=module_name, + package_name=package_name, + reflex_version=constants.Reflex.VERSION, ) + ) def _get_package_config(exit_on_fail: bool = True) -> dict: @@ -84,11 +85,11 @@ def _get_package_config(exit_on_fail: bool = True) -> dict: Raises: Exit: If the pyproject.toml file is not found. """ + pyproject = Path(CustomComponents.PYPROJECT_TOML) try: - with open(CustomComponents.PYPROJECT_TOML, "rb") as f: - return dict(tomlkit.load(f)) + return dict(tomlkit.loads(pyproject.read_bytes())) except (OSError, TOMLKitError) as ex: - console.error(f"Unable to read from pyproject.toml due to {ex}") + console.error(f"Unable to read from {pyproject} due to {ex}") if exit_on_fail: raise typer.Exit(code=1) from ex raise @@ -103,17 +104,17 @@ def _create_readme(module_name: str, package_name: str): """ from reflex.compiler import templates - with open(CustomComponents.PACKAGE_README, "w") as f: - f.write( - templates.CUSTOM_COMPONENTS_README.render( - module_name=module_name, - package_name=package_name, - ) + readme = Path(CustomComponents.PACKAGE_README) + readme.write_text( + templates.CUSTOM_COMPONENTS_README.render( + module_name=module_name, + package_name=package_name, ) + ) def _write_source_and_init_py( - custom_component_src_dir: str, + custom_component_src_dir: Path, component_class_name: str, module_name: str, ): @@ -126,27 +127,17 @@ def _write_source_and_init_py( """ from reflex.compiler import templates - with open( - os.path.join( - custom_component_src_dir, - f"{module_name}.py", - ), - "w", - ) as f: - f.write( - templates.CUSTOM_COMPONENTS_SOURCE.render( - component_class_name=component_class_name, module_name=module_name - ) + module_path = custom_component_src_dir / f"{module_name}.py" + module_path.write_text( + templates.CUSTOM_COMPONENTS_SOURCE.render( + component_class_name=component_class_name, module_name=module_name ) + ) - with open( - os.path.join( - custom_component_src_dir, - CustomComponents.INIT_FILE, - ), - "w", - ) as f: - f.write(templates.CUSTOM_COMPONENTS_INIT_FILE.render(module_name=module_name)) + init_path = custom_component_src_dir / CustomComponents.INIT_FILE + init_path.write_text( + templates.CUSTOM_COMPONENTS_INIT_FILE.render(module_name=module_name) + ) def _populate_demo_app(name_variants: NameVariants): @@ -192,7 +183,7 @@ def _get_default_library_name_parts() -> list[str]: Returns: The parts of default library name. """ - current_dir_name = os.getcwd().split(os.path.sep)[-1] + current_dir_name = Path.cwd().name cleaned_dir_name = re.sub("[^0-9a-zA-Z-_]+", "", current_dir_name).lower() parts = [part for part in re.split("-|_", cleaned_dir_name) if part] @@ -345,7 +336,7 @@ def init( console.set_log_level(loglevel) - if os.path.exists(CustomComponents.PYPROJECT_TOML): + if CustomComponents.PYPROJECT_TOML.exists(): console.error(f"A {CustomComponents.PYPROJECT_TOML} already exists. Aborting.") typer.Exit(code=1) diff --git a/reflex/event.py b/reflex/event.py index 95358ace1..7384cf5bf 100644 --- a/reflex/event.py +++ b/reflex/event.py @@ -4,16 +4,19 @@ from __future__ import annotations import dataclasses import inspect +import sys import types import urllib.parse from base64 import b64encode from typing import ( Any, Callable, + ClassVar, Dict, List, Optional, Tuple, + Type, Union, get_type_hints, ) @@ -25,8 +28,15 @@ from reflex.utils import format from reflex.utils.exceptions import EventFnArgMismatch, EventHandlerArgMismatch from reflex.utils.types import ArgsSpec, GenericType from reflex.vars import VarData -from reflex.vars.base import LiteralVar, Var -from reflex.vars.function import FunctionStringVar, FunctionVar +from reflex.vars.base import ( + CachedVarOperation, + LiteralNoneVar, + LiteralVar, + ToOperation, + Var, + cached_property_no_lock, +) +from reflex.vars.function import ArgsFunctionOperation, FunctionStringVar, FunctionVar from reflex.vars.object import ObjectVar try: @@ -375,7 +385,7 @@ class CallableEventSpec(EventSpec): class EventChain(EventActionsMixin): """Container for a chain of events that will be executed in order.""" - events: List[EventSpec] = dataclasses.field(default_factory=list) + events: List[Union[EventSpec, EventVar]] = dataclasses.field(default_factory=list) args_spec: Optional[Callable] = dataclasses.field(default=None) @@ -478,7 +488,7 @@ class FileUpload: if isinstance(events, Var): raise ValueError(f"{on_upload_progress} cannot return a var {events}.") on_upload_progress_chain = EventChain( - events=events, + events=[*events], args_spec=self.on_upload_progress_args_spec, ) formatted_chain = str(format.format_prop(on_upload_progress_chain)) @@ -839,6 +849,16 @@ def call_script( ), ), } + if isinstance(javascript_code, str): + # When there is VarData, include it and eval the JS code inline on the client. + javascript_code, original_code = ( + LiteralVar.create(javascript_code), + javascript_code, + ) + if not javascript_code._get_all_var_data(): + # Without VarData, cast to string and eval the code in the event loop. + javascript_code = str(Var(_js_expr=original_code)) + return server_side( "_call_script", get_fn_signature(call_script), @@ -1126,3 +1146,178 @@ def get_fn_signature(fn: Callable) -> inspect.Signature: "state", inspect.Parameter.POSITIONAL_OR_KEYWORD, annotation=Any ) return signature.replace(parameters=(new_param, *signature.parameters.values())) + + +class EventVar(ObjectVar): + """Base class for event vars.""" + + +@dataclasses.dataclass( + eq=False, + frozen=True, + **{"slots": True} if sys.version_info >= (3, 10) else {}, +) +class LiteralEventVar(CachedVarOperation, LiteralVar, EventVar): + """A literal event var.""" + + _var_value: EventSpec = dataclasses.field(default=None) # type: ignore + + def __hash__(self) -> int: + """Get the hash of the var. + + Returns: + The hash of the var. + """ + return hash((self.__class__.__name__, self._js_expr)) + + @cached_property_no_lock + def _cached_var_name(self) -> str: + """The name of the var. + + Returns: + The name of the var. + """ + return str( + FunctionStringVar("Event").call( + # event handler name + ".".join( + filter( + None, + format.get_event_handler_parts(self._var_value.handler), + ) + ), + # event handler args + {str(name): value for name, value in self._var_value.args}, + # event actions + self._var_value.event_actions, + # client handler name + *( + [self._var_value.client_handler_name] + if self._var_value.client_handler_name + else [] + ), + ) + ) + + @classmethod + def create( + cls, + value: EventSpec, + _var_data: VarData | None = None, + ) -> LiteralEventVar: + """Create a new LiteralEventVar instance. + + Args: + value: The value of the var. + _var_data: The data of the var. + + Returns: + The created LiteralEventVar instance. + """ + return cls( + _js_expr="", + _var_type=EventSpec, + _var_data=_var_data, + _var_value=value, + ) + + +class EventChainVar(FunctionVar): + """Base class for event chain vars.""" + + +@dataclasses.dataclass( + eq=False, + frozen=True, + **{"slots": True} if sys.version_info >= (3, 10) else {}, +) +class LiteralEventChainVar(CachedVarOperation, LiteralVar, EventChainVar): + """A literal event chain var.""" + + _var_value: EventChain = dataclasses.field(default=None) # type: ignore + + def __hash__(self) -> int: + """Get the hash of the var. + + Returns: + The hash of the var. + """ + return hash((self.__class__.__name__, self._js_expr)) + + @cached_property_no_lock + def _cached_var_name(self) -> str: + """The name of the var. + + Returns: + The name of the var. + """ + sig = inspect.signature(self._var_value.args_spec) # type: ignore + if sig.parameters: + arg_def = tuple((f"_{p}" for p in sig.parameters)) + arg_def_expr = LiteralVar.create([Var(_js_expr=arg) for arg in arg_def]) + else: + # add a default argument for addEvents if none were specified in value.args_spec + # used to trigger the preventDefault() on the event. + arg_def = ("...args",) + arg_def_expr = Var(_js_expr="args") + + return str( + ArgsFunctionOperation.create( + arg_def, + FunctionStringVar.create("addEvents").call( + LiteralVar.create( + [LiteralVar.create(event) for event in self._var_value.events] + ), + arg_def_expr, + self._var_value.event_actions, + ), + ) + ) + + @classmethod + def create( + cls, + value: EventChain, + _var_data: VarData | None = None, + ) -> LiteralEventChainVar: + """Create a new LiteralEventChainVar instance. + + Args: + value: The value of the var. + _var_data: The data of the var. + + Returns: + The created LiteralEventChainVar instance. + """ + return cls( + _js_expr="", + _var_type=EventChain, + _var_data=_var_data, + _var_value=value, + ) + + +@dataclasses.dataclass( + eq=False, + frozen=True, + **{"slots": True} if sys.version_info >= (3, 10) else {}, +) +class ToEventVarOperation(ToOperation, EventVar): + """Result of a cast to an event var.""" + + _original: Var = dataclasses.field(default_factory=lambda: LiteralNoneVar.create()) + + _default_var_type: ClassVar[Type] = EventSpec + + +@dataclasses.dataclass( + eq=False, + frozen=True, + **{"slots": True} if sys.version_info >= (3, 10) else {}, +) +class ToEventChainVarOperation(ToOperation, EventChainVar): + """Result of a cast to an event chain var.""" + + _original: Var = dataclasses.field(default_factory=lambda: LiteralNoneVar.create()) + + _default_var_type: ClassVar[Type] = EventChain diff --git a/reflex/istate/data.py b/reflex/istate/data.py new file mode 100644 index 000000000..9f6e3b3f4 --- /dev/null +++ b/reflex/istate/data.py @@ -0,0 +1,126 @@ +"""This module contains the dataclasses representing the router object.""" + +import dataclasses +from typing import Optional + +from reflex import constants +from reflex.utils import format + + +@dataclasses.dataclass(frozen=True) +class HeaderData: + """An object containing headers data.""" + + host: str = "" + origin: str = "" + upgrade: str = "" + connection: str = "" + cookie: str = "" + pragma: str = "" + cache_control: str = "" + user_agent: str = "" + sec_websocket_version: str = "" + sec_websocket_key: str = "" + sec_websocket_extensions: str = "" + accept_encoding: str = "" + accept_language: str = "" + + def __init__(self, router_data: Optional[dict] = None): + """Initalize the HeaderData object based on router_data. + + Args: + router_data: the router_data dict. + """ + if router_data: + for k, v in router_data.get(constants.RouteVar.HEADERS, {}).items(): + object.__setattr__(self, format.to_snake_case(k), v) + else: + for k in dataclasses.fields(self): + object.__setattr__(self, k.name, "") + + +@dataclasses.dataclass(frozen=True) +class PageData: + """An object containing page data.""" + + host: str = "" # repeated with self.headers.origin (remove or keep the duplicate?) + path: str = "" + raw_path: str = "" + full_path: str = "" + full_raw_path: str = "" + params: dict = dataclasses.field(default_factory=dict) + + def __init__(self, router_data: Optional[dict] = None): + """Initalize the PageData object based on router_data. + + Args: + router_data: the router_data dict. + """ + if router_data: + object.__setattr__( + self, + "host", + router_data.get(constants.RouteVar.HEADERS, {}).get("origin", ""), + ) + object.__setattr__( + self, "path", router_data.get(constants.RouteVar.PATH, "") + ) + object.__setattr__( + self, "raw_path", router_data.get(constants.RouteVar.ORIGIN, "") + ) + object.__setattr__(self, "full_path", f"{self.host}{self.path}") + object.__setattr__(self, "full_raw_path", f"{self.host}{self.raw_path}") + object.__setattr__( + self, "params", router_data.get(constants.RouteVar.QUERY, {}) + ) + else: + object.__setattr__(self, "host", "") + object.__setattr__(self, "path", "") + object.__setattr__(self, "raw_path", "") + object.__setattr__(self, "full_path", "") + object.__setattr__(self, "full_raw_path", "") + object.__setattr__(self, "params", {}) + + +@dataclasses.dataclass(frozen=True, init=False) +class SessionData: + """An object containing session data.""" + + client_token: str = "" + client_ip: str = "" + session_id: str = "" + + def __init__(self, router_data: Optional[dict] = None): + """Initalize the SessionData object based on router_data. + + Args: + router_data: the router_data dict. + """ + if router_data: + client_token = router_data.get(constants.RouteVar.CLIENT_TOKEN, "") + client_ip = router_data.get(constants.RouteVar.CLIENT_IP, "") + session_id = router_data.get(constants.RouteVar.SESSION_ID, "") + else: + client_token = client_ip = session_id = "" + object.__setattr__(self, "client_token", client_token) + object.__setattr__(self, "client_ip", client_ip) + object.__setattr__(self, "session_id", session_id) + + +@dataclasses.dataclass(frozen=True, init=False) +class RouterData: + """An object containing RouterData.""" + + session: SessionData = dataclasses.field(default_factory=SessionData) + headers: HeaderData = dataclasses.field(default_factory=HeaderData) + page: PageData = dataclasses.field(default_factory=PageData) + + def __init__(self, router_data: Optional[dict] = None): + """Initialize the RouterData object. + + Args: + router_data: the router_data dict. + """ + object.__setattr__(self, "session", SessionData(router_data)) + object.__setattr__(self, "headers", HeaderData(router_data)) + object.__setattr__(self, "page", PageData(router_data)) diff --git a/reflex/reflex.py b/reflex/reflex.py index 43ebe2eb4..4608ed171 100644 --- a/reflex/reflex.py +++ b/reflex/reflex.py @@ -15,7 +15,6 @@ from reflex_cli.utils import dependency from reflex import constants from reflex.config import get_config -from reflex.constants.base import LogLevel from reflex.custom_components.custom_components import custom_components_cli from reflex.state import reset_disk_state_manager from reflex.utils import console, redir, telemetry @@ -115,9 +114,6 @@ def _init( app_name, generation_hash=generation_hash ) - # Migrate Pynecone projects to Reflex. - prerequisites.migrate_to_reflex() - # Initialize the .gitignore. prerequisites.initialize_gitignore() @@ -247,11 +243,6 @@ def _run( setup_frontend(Path.cwd()) commands.append((frontend_cmd, Path.cwd(), frontend_port, backend)) - # If no loglevel is specified, set the subprocesses loglevel to WARNING. - subprocesses_loglevel = ( - loglevel if loglevel != LogLevel.DEFAULT else LogLevel.WARNING - ) - # In prod mode, run the backend on a separate thread. if backend and env == constants.Env.PROD: commands.append( @@ -259,7 +250,7 @@ def _run( backend_cmd, backend_host, backend_port, - subprocesses_loglevel, + loglevel.subprocess_level(), frontend, ) ) @@ -269,7 +260,7 @@ def _run( # In dev mode, run the backend on the main thread. if backend and env == constants.Env.DEV: backend_cmd( - backend_host, int(backend_port), subprocesses_loglevel, frontend + backend_host, int(backend_port), loglevel.subprocess_level(), frontend ) # The windows uvicorn bug workaround # https://github.com/reflex-dev/reflex/issues/2335 @@ -342,7 +333,7 @@ def export( backend=backend, zip_dest_dir=zip_dest_dir, upload_db_file=upload_db_file, - loglevel=loglevel, + loglevel=loglevel.subprocess_level(), ) @@ -577,7 +568,7 @@ def deploy( frontend=frontend, backend=backend, zipping=zipping, - loglevel=loglevel, + loglevel=loglevel.subprocess_level(), upload_db_file=upload_db_file, ), key=key, @@ -591,7 +582,7 @@ def deploy( interactive=interactive, with_metrics=with_metrics, with_tracing=with_tracing, - loglevel=loglevel.value, + loglevel=loglevel.subprocess_level(), ) diff --git a/reflex/state.py b/reflex/state.py index cda36a0a9..5798564fa 100644 --- a/reflex/state.py +++ b/reflex/state.py @@ -9,6 +9,7 @@ import dataclasses import functools import inspect import os +import pickle import uuid from abc import ABC, abstractmethod from collections import defaultdict @@ -19,6 +20,7 @@ from typing import ( TYPE_CHECKING, Any, AsyncIterator, + BinaryIO, Callable, ClassVar, Dict, @@ -33,11 +35,11 @@ from typing import ( get_type_hints, ) -import dill from sqlalchemy.orm import DeclarativeBase from typing_extensions import Self from reflex.config import get_config +from reflex.istate.data import RouterData from reflex.vars.base import ( ComputedVar, DynamicRouteVar, @@ -74,6 +76,8 @@ from reflex.utils.exceptions import ( EventHandlerShadowsBuiltInStateMethod, ImmutableStateError, LockExpiredError, + SetUndefinedStateVarError, + StateSchemaMismatchError, ) from reflex.utils.exec import is_testing_env from reflex.utils.serializers import serializer @@ -92,125 +96,6 @@ var = computed_var TOO_LARGE_SERIALIZED_STATE = 100 * 1024 # 100kb -@dataclasses.dataclass(frozen=True) -class HeaderData: - """An object containing headers data.""" - - host: str = "" - origin: str = "" - upgrade: str = "" - connection: str = "" - cookie: str = "" - pragma: str = "" - cache_control: str = "" - user_agent: str = "" - sec_websocket_version: str = "" - sec_websocket_key: str = "" - sec_websocket_extensions: str = "" - accept_encoding: str = "" - accept_language: str = "" - - def __init__(self, router_data: Optional[dict] = None): - """Initalize the HeaderData object based on router_data. - - Args: - router_data: the router_data dict. - """ - if router_data: - for k, v in router_data.get(constants.RouteVar.HEADERS, {}).items(): - object.__setattr__(self, format.to_snake_case(k), v) - else: - for k in dataclasses.fields(self): - object.__setattr__(self, k.name, "") - - -@dataclasses.dataclass(frozen=True) -class PageData: - """An object containing page data.""" - - host: str = "" # repeated with self.headers.origin (remove or keep the duplicate?) - path: str = "" - raw_path: str = "" - full_path: str = "" - full_raw_path: str = "" - params: dict = dataclasses.field(default_factory=dict) - - def __init__(self, router_data: Optional[dict] = None): - """Initalize the PageData object based on router_data. - - Args: - router_data: the router_data dict. - """ - if router_data: - object.__setattr__( - self, - "host", - router_data.get(constants.RouteVar.HEADERS, {}).get("origin", ""), - ) - object.__setattr__( - self, "path", router_data.get(constants.RouteVar.PATH, "") - ) - object.__setattr__( - self, "raw_path", router_data.get(constants.RouteVar.ORIGIN, "") - ) - object.__setattr__(self, "full_path", f"{self.host}{self.path}") - object.__setattr__(self, "full_raw_path", f"{self.host}{self.raw_path}") - object.__setattr__( - self, "params", router_data.get(constants.RouteVar.QUERY, {}) - ) - else: - object.__setattr__(self, "host", "") - object.__setattr__(self, "path", "") - object.__setattr__(self, "raw_path", "") - object.__setattr__(self, "full_path", "") - object.__setattr__(self, "full_raw_path", "") - object.__setattr__(self, "params", {}) - - -@dataclasses.dataclass(frozen=True, init=False) -class SessionData: - """An object containing session data.""" - - client_token: str = "" - client_ip: str = "" - session_id: str = "" - - def __init__(self, router_data: Optional[dict] = None): - """Initalize the SessionData object based on router_data. - - Args: - router_data: the router_data dict. - """ - if router_data: - client_token = router_data.get(constants.RouteVar.CLIENT_TOKEN, "") - client_ip = router_data.get(constants.RouteVar.CLIENT_IP, "") - session_id = router_data.get(constants.RouteVar.SESSION_ID, "") - else: - client_token = client_ip = session_id = "" - object.__setattr__(self, "client_token", client_token) - object.__setattr__(self, "client_ip", client_ip) - object.__setattr__(self, "session_id", session_id) - - -@dataclasses.dataclass(frozen=True, init=False) -class RouterData: - """An object containing RouterData.""" - - session: SessionData = dataclasses.field(default_factory=SessionData) - headers: HeaderData = dataclasses.field(default_factory=HeaderData) - page: PageData = dataclasses.field(default_factory=PageData) - - def __init__(self, router_data: Optional[dict] = None): - """Initialize the RouterData object. - - Args: - router_data: the router_data dict. - """ - object.__setattr__(self, "session", SessionData(router_data)) - object.__setattr__(self, "headers", HeaderData(router_data)) - object.__setattr__(self, "page", PageData(router_data)) - - def _no_chain_background_task( state_cls: Type["BaseState"], name: str, fn: Callable ) -> Callable: @@ -698,11 +583,14 @@ class BaseState(Base, ABC, extra=pydantic.Extra.allow): ) @classmethod - def _evaluate(cls, f: Callable[[Self], Any]) -> Var: + def _evaluate( + cls, f: Callable[[Self], Any], of_type: Union[type, None] = None + ) -> Var: """Evaluate a function to a ComputedVar. Experimental. Args: f: The function to evaluate. + of_type: The type of the ComputedVar. Defaults to Component. Returns: The ComputedVar. @@ -710,14 +598,23 @@ class BaseState(Base, ABC, extra=pydantic.Extra.allow): console.warn( "The _evaluate method is experimental and may be removed in future versions." ) - from reflex.components.base.fragment import fragment from reflex.components.component import Component + of_type = of_type or Component + unique_var_name = get_unique_variable_name() - @computed_var(_js_expr=unique_var_name, return_type=Component) + @computed_var(_js_expr=unique_var_name, return_type=of_type) def computed_var_func(state: Self): - return fragment(f(state)) + result = f(state) + + if not isinstance(result, of_type): + console.warn( + f"Inline ComputedVar {f} expected type {of_type}, got {type(result)}. " + "You can specify expected type with `of_type` argument." + ) + + return result setattr(cls, unique_var_name, computed_var_func) cls.computed_vars[unique_var_name] = computed_var_func @@ -1260,6 +1157,9 @@ class BaseState(Base, ABC, extra=pydantic.Extra.allow): Args: name: The name of the attribute. value: The value of the attribute. + + Raises: + SetUndefinedStateVarError: If a value of a var is set without first defining it. """ if isinstance(value, MutableProxy): # unwrap proxy objects when assigning back to the state @@ -1277,6 +1177,17 @@ class BaseState(Base, ABC, extra=pydantic.Extra.allow): self._mark_dirty() return + if ( + name not in self.vars + and name not in self.get_skip_vars() + and not name.startswith("__") + and not name.startswith(f"_{type(self).__name__}__") + ): + raise SetUndefinedStateVarError( + f"The state variable '{name}' has not been defined in '{type(self).__name__}'. " + f"All state variables must be declared before they can be set." + ) + # Set the attribute. super().__setattr__(name, value) @@ -2005,7 +1916,7 @@ class BaseState(Base, ABC, extra=pydantic.Extra.allow): def __getstate__(self): """Get the state for redis serialization. - This method is called by cloudpickle to serialize the object. + This method is called by pickle to serialize the object. It explicitly removes parent_state and substates because those are serialized separately by the StateManagerRedis to allow for better horizontal scaling as state size increases. @@ -2021,6 +1932,43 @@ class BaseState(Base, ABC, extra=pydantic.Extra.allow): state["__dict__"].pop("_was_touched", None) return state + def _serialize(self) -> bytes: + """Serialize the state for redis. + + Returns: + The serialized state. + """ + return pickle.dumps((state_to_schema(self), self)) + + @classmethod + def _deserialize( + cls, data: bytes | None = None, fp: BinaryIO | None = None + ) -> BaseState: + """Deserialize the state from redis/disk. + + data and fp are mutually exclusive, but one must be provided. + + Args: + data: The serialized state data. + fp: The file pointer to the serialized state data. + + Returns: + The deserialized state. + + Raises: + ValueError: If both data and fp are provided, or neither are provided. + StateSchemaMismatchError: If the state schema does not match the expected schema. + """ + if data is not None and fp is None: + (substate_schema, state) = pickle.loads(data) + elif fp is not None and data is None: + (substate_schema, state) = pickle.load(fp) + else: + raise ValueError("Only one of `data` or `fp` must be provided") + if substate_schema != state_to_schema(state): + raise StateSchemaMismatchError() + return state + class State(BaseState): """The app Base State.""" @@ -2177,7 +2125,11 @@ class ComponentState(State, mixin=True): """ cls._per_component_state_instance_count += 1 state_cls_name = f"{cls.__name__}_n{cls._per_component_state_instance_count}" - component_state = type(state_cls_name, (cls, State), {}, mixin=False) + component_state = type( + state_cls_name, (cls, State), {"__module__": __name__}, mixin=False + ) + # Save a reference to the dynamic state for pickle/unpickle. + globals()[state_cls_name] = component_state component = component_state.get_component(*children, **props) component.State = component_state return component @@ -2643,7 +2595,7 @@ def is_serializable(value: Any) -> bool: Whether the value is serializable. """ try: - return bool(dill.dumps(value)) + return bool(pickle.dumps(value)) except Exception: return False @@ -2779,8 +2731,7 @@ class StateManagerDisk(StateManager): if token_path.exists(): try: with token_path.open(mode="rb") as file: - (substate_schema, substate) = dill.load(file) - if substate_schema == state_to_schema(substate): + substate = BaseState._deserialize(fp=file) await self.populate_substates(client_token, substate, root_state) return substate except Exception: @@ -2822,10 +2773,12 @@ class StateManagerDisk(StateManager): client_token, substate_address = _split_substate_key(token) root_state_token = _substate_key(client_token, substate_address.split(".")[0]) + root_state = self.states.get(root_state_token) + if root_state is None: + # Create a new root state which will be persisted in the next set_state call. + root_state = self.state(_reflex_internal_init=True) - return await self.load_state( - root_state_token, self.state(_reflex_internal_init=True) - ) + return await self.load_state(root_state_token, root_state) async def set_state_for_substate(self, client_token: str, substate: BaseState): """Set the state for a substate. @@ -2838,7 +2791,7 @@ class StateManagerDisk(StateManager): self.states[substate_token] = substate - state_dilled = dill.dumps((state_to_schema(substate), substate)) + state_dilled = substate._serialize() if not self.states_directory.exists(): self.states_directory.mkdir(parents=True, exist_ok=True) self.token_path(substate_token).write_bytes(state_dilled) @@ -2881,25 +2834,6 @@ class StateManagerDisk(StateManager): await self.set_state(token, state) -# Workaround https://github.com/cloudpipe/cloudpickle/issues/408 for dynamic pydantic classes -if not isinstance(State.validate.__func__, FunctionType): - cython_function_or_method = type(State.validate.__func__) - - @dill.register(cython_function_or_method) - def _dill_reduce_cython_function_or_method(pickler, obj): - # Ignore cython function when pickling. - pass - - -@dill.register(type(State)) -def _dill_reduce_state(pickler, obj): - if obj is not State and issubclass(obj, State): - # Avoid serializing subclasses of State, instead get them by reference from the State class. - pickler.save_reduce(State.get_class_substate, (obj.get_full_name(),), obj=obj) - else: - dill.Pickler.dispatch[type](pickler, obj) - - def _default_lock_expiration() -> int: """Get the default lock expiration time. @@ -3039,7 +2973,7 @@ class StateManagerRedis(StateManager): if redis_state is not None: # Deserialize the substate. - state = dill.loads(redis_state) + state = BaseState._deserialize(data=redis_state) # Populate parent state if missing and requested. if parent_state is None: @@ -3151,7 +3085,7 @@ class StateManagerRedis(StateManager): ) # Persist only the given state (parents or substates are excluded by BaseState.__getstate__). if state._get_was_touched(): - pickle_state = dill.dumps(state, byref=True) + pickle_state = state._serialize() self._warn_if_too_large(state, len(pickle_state)) await self.redis.set( _substate_key(client_token, state), diff --git a/reflex/utils/build.py b/reflex/utils/build.py index 7a67ec32e..770809015 100644 --- a/reflex/utils/build.py +++ b/reflex/utils/build.py @@ -61,8 +61,8 @@ def generate_sitemap_config(deploy_url: str, export=False): def _zip( component_name: constants.ComponentName, - target: str, - root_dir: str, + target: str | Path, + root_dir: str | Path, exclude_venv_dirs: bool, upload_db_file: bool = False, dirs_to_exclude: set[str] | None = None, @@ -82,22 +82,22 @@ def _zip( top_level_dirs_to_exclude: The top level directory names immediately under root_dir to exclude. Do not exclude folders by these names further in the sub-directories. """ + target = Path(target) + root_dir = Path(root_dir) dirs_to_exclude = dirs_to_exclude or set() files_to_exclude = files_to_exclude or set() files_to_zip: list[str] = [] # Traverse the root directory in a top-down manner. In this traversal order, # we can modify the dirs list in-place to remove directories we don't want to include. for root, dirs, files in os.walk(root_dir, topdown=True): + root = Path(root) # Modify the dirs in-place so excluded and hidden directories are skipped in next traversal. dirs[:] = [ d for d in dirs - if (basename := os.path.basename(os.path.normpath(d))) - not in dirs_to_exclude + if (basename := Path(d).resolve().name) not in dirs_to_exclude and not basename.startswith(".") - and ( - not exclude_venv_dirs or not _looks_like_venv_dir(os.path.join(root, d)) - ) + and (not exclude_venv_dirs or not _looks_like_venv_dir(root / d)) ] # If we are at the top level with root_dir, exclude the top level dirs. if top_level_dirs_to_exclude and root == root_dir: @@ -109,7 +109,7 @@ def _zip( if not f.startswith(".") and (upload_db_file or not f.endswith(".db")) ] files_to_zip += [ - os.path.join(root, file) for file in files if file not in files_to_exclude + str(root / file) for file in files if file not in files_to_exclude ] # Create a progress bar for zipping the component. @@ -126,13 +126,13 @@ def _zip( for file in files_to_zip: console.debug(f"{target}: {file}", progress=progress) progress.advance(task) - zipf.write(file, os.path.relpath(file, root_dir)) + zipf.write(file, Path(file).relative_to(root_dir)) def zip_app( frontend: bool = True, backend: bool = True, - zip_dest_dir: str = os.getcwd(), + zip_dest_dir: str | Path = Path.cwd(), upload_db_file: bool = False, ): """Zip up the app. @@ -143,6 +143,7 @@ def zip_app( zip_dest_dir: The directory to export the zip file to. upload_db_file: Whether to upload the database file. """ + zip_dest_dir = Path(zip_dest_dir) files_to_exclude = { constants.ComponentName.FRONTEND.zip(), constants.ComponentName.BACKEND.zip(), @@ -151,8 +152,8 @@ def zip_app( if frontend: _zip( component_name=constants.ComponentName.FRONTEND, - target=os.path.join(zip_dest_dir, constants.ComponentName.FRONTEND.zip()), - root_dir=str(prerequisites.get_web_dir() / constants.Dirs.STATIC), + target=zip_dest_dir / constants.ComponentName.FRONTEND.zip(), + root_dir=prerequisites.get_web_dir() / constants.Dirs.STATIC, files_to_exclude=files_to_exclude, exclude_venv_dirs=False, ) @@ -160,8 +161,8 @@ def zip_app( if backend: _zip( component_name=constants.ComponentName.BACKEND, - target=os.path.join(zip_dest_dir, constants.ComponentName.BACKEND.zip()), - root_dir=".", + target=zip_dest_dir / constants.ComponentName.BACKEND.zip(), + root_dir=Path("."), dirs_to_exclude={"__pycache__"}, files_to_exclude=files_to_exclude, top_level_dirs_to_exclude={"assets"}, @@ -236,6 +237,9 @@ def setup_frontend( # Set the environment variables in client (env.json). set_env_json() + # update the last reflex run time. + prerequisites.set_last_reflex_run_time() + # Disable the Next telemetry. if disable_telemetry: processes.new_process( @@ -266,5 +270,6 @@ def setup_frontend_prod( build(deploy_url=get_config().deploy_url) -def _looks_like_venv_dir(dir_to_check: str) -> bool: - return os.path.exists(os.path.join(dir_to_check, "pyvenv.cfg")) +def _looks_like_venv_dir(dir_to_check: str | Path) -> bool: + dir_to_check = Path(dir_to_check) + return (dir_to_check / "pyvenv.cfg").exists() diff --git a/reflex/utils/exceptions.py b/reflex/utils/exceptions.py index 7c3532861..8bce605b5 100644 --- a/reflex/utils/exceptions.py +++ b/reflex/utils/exceptions.py @@ -115,3 +115,15 @@ class PrimitiveUnserializableToJSON(ReflexError, ValueError): class InvalidLifespanTaskType(ReflexError, TypeError): """Raised when an invalid task type is registered as a lifespan task.""" + + +class DynamicComponentMissingLibrary(ReflexError, ValueError): + """Raised when a dynamic component is missing a library.""" + + +class SetUndefinedStateVarError(ReflexError, AttributeError): + """Raised when setting the value of a var without first declaring it.""" + + +class StateSchemaMismatchError(ReflexError, TypeError): + """Raised when the serialized schema of a state class does not match the current schema.""" diff --git a/reflex/utils/exec.py b/reflex/utils/exec.py index b6550fdde..acb69ee19 100644 --- a/reflex/utils/exec.py +++ b/reflex/utils/exec.py @@ -284,7 +284,7 @@ def run_granian_backend(host, port, loglevel: LogLevel): ).serve() except ImportError: console.error( - 'InstallError: REFLEX_USE_GRANIAN is set but `granian` is not installed. (run `pip install "granian>=1.6.0"`)' + 'InstallError: REFLEX_USE_GRANIAN is set but `granian` is not installed. (run `pip install "granian[reload]>=1.6.0"`)' ) os._exit(1) @@ -410,7 +410,7 @@ def run_granian_backend_prod(host, port, loglevel): ) except ImportError: console.error( - 'InstallError: REFLEX_USE_GRANIAN is set but `granian` is not installed. (run `pip install "granian>=1.6.0"`)' + 'InstallError: REFLEX_USE_GRANIAN is set but `granian` is not installed. (run `pip install "granian[reload]>=1.6.0"`)' ) diff --git a/reflex/utils/format.py b/reflex/utils/format.py index 4029bd275..65c0f049b 100644 --- a/reflex/utils/format.py +++ b/reflex/utils/format.py @@ -359,19 +359,7 @@ def format_prop( # Handle event props. if isinstance(prop, EventChain): - sig = inspect.signature(prop.args_spec) # type: ignore - if sig.parameters: - arg_def = ",".join(f"_{p}" for p in sig.parameters) - arg_def_expr = f"[{arg_def}]" - else: - # add a default argument for addEvents if none were specified in prop.args_spec - # used to trigger the preventDefault() on the event. - arg_def = "...args" - arg_def_expr = "args" - - chain = ",".join([format_event(event) for event in prop.events]) - event = f"addEvents([{chain}], {arg_def_expr}, {json_dumps(prop.event_actions)})" - prop = f"({arg_def}) => {event}" + return str(Var.create(prop)) # Handle other types. elif isinstance(prop, str): diff --git a/reflex/utils/path_ops.py b/reflex/utils/path_ops.py index 00affd820..1de635b88 100644 --- a/reflex/utils/path_ops.py +++ b/reflex/utils/path_ops.py @@ -164,7 +164,7 @@ def use_system_bun() -> bool: return use_system_install(constants.Bun.USE_SYSTEM_VAR) -def get_node_bin_path() -> str | None: +def get_node_bin_path() -> Path | None: """Get the node binary dir path. Returns: @@ -173,8 +173,8 @@ def get_node_bin_path() -> str | None: bin_path = Path(constants.Node.BIN_PATH) if not bin_path.exists(): str_path = which("node") - return str(Path(str_path).parent.resolve()) if str_path else str_path - return str(bin_path.resolve()) + return Path(str_path).parent.resolve() if str_path else None + return bin_path.resolve() def get_node_path() -> str | None: diff --git a/reflex/utils/prerequisites.py b/reflex/utils/prerequisites.py index f9eb9a790..3c2875204 100644 --- a/reflex/utils/prerequisites.py +++ b/reflex/utils/prerequisites.py @@ -2,9 +2,9 @@ from __future__ import annotations +import contextlib import dataclasses import functools -import glob import importlib import importlib.metadata import json @@ -19,7 +19,6 @@ import tempfile import time import zipfile from datetime import datetime -from fileinput import FileInput from pathlib import Path from types import ModuleType from typing import Callable, List, Optional @@ -132,6 +131,14 @@ def get_or_set_last_reflex_version_check_datetime(): return last_version_check_datetime +def set_last_reflex_run_time(): + """Set the last Reflex run time.""" + path_ops.update_json_file( + get_web_dir() / constants.Reflex.JSON, + {"last_reflex_run_datetime": str(datetime.now())}, + ) + + def check_node_version() -> bool: """Check the version of Node.js. @@ -192,7 +199,7 @@ def get_bun_version() -> version.Version | None: """ try: # Run the bun -v command and capture the output - result = processes.new_process([get_config().bun_path, "-v"], run=True) + result = processes.new_process([str(get_config().bun_path), "-v"], run=True) return version.parse(result.stdout) # type: ignore except FileNotFoundError: return None @@ -217,7 +224,7 @@ def get_install_package_manager() -> str | None: or windows_npm_escape_hatch() ): return get_package_manager() - return get_config().bun_path + return str(get_config().bun_path) def get_package_manager() -> str | None: @@ -394,9 +401,7 @@ def validate_app_name(app_name: str | None = None) -> str: Raises: Exit: if the app directory name is reflex or if the name is not standard for a python package name. """ - app_name = ( - app_name if app_name else os.getcwd().split(os.path.sep)[-1].replace("-", "_") - ) + app_name = app_name if app_name else Path.cwd().name.replace("-", "_") # Make sure the app is not named "reflex". if app_name.lower() == constants.Reflex.MODULE_NAME: console.error( @@ -430,7 +435,7 @@ def create_config(app_name: str): def initialize_gitignore( - gitignore_file: str = constants.GitIgnore.FILE, + gitignore_file: Path = constants.GitIgnore.FILE, files_to_ignore: set[str] = constants.GitIgnore.DEFAULTS, ): """Initialize the template .gitignore file. @@ -441,9 +446,10 @@ def initialize_gitignore( """ # Combine with the current ignored files. current_ignore: set[str] = set() - if os.path.exists(gitignore_file): - with open(gitignore_file, "r") as f: - current_ignore |= set([line.strip() for line in f.readlines()]) + if gitignore_file.exists(): + current_ignore |= set( + line.strip() for line in gitignore_file.read_text().splitlines() + ) if files_to_ignore == current_ignore: console.debug(f"{gitignore_file} already up to date.") @@ -451,9 +457,11 @@ def initialize_gitignore( files_to_ignore |= current_ignore # Write files to the .gitignore file. - with open(gitignore_file, "w", newline="\n") as f: - console.debug(f"Creating {gitignore_file}") - f.write(f"{(path_ops.join(sorted(files_to_ignore))).lstrip()}\n") + gitignore_file.touch(exist_ok=True) + console.debug(f"Creating {gitignore_file}") + gitignore_file.write_text( + "\n".join(sorted(files_to_ignore)) + "\n", + ) def initialize_requirements_txt(): @@ -546,8 +554,8 @@ def initialize_app_directory( # Rename the template app to the app name. path_ops.mv(template_code_dir_name, app_name) path_ops.mv( - os.path.join(app_name, template_name + constants.Ext.PY), - os.path.join(app_name, app_name + constants.Ext.PY), + Path(app_name) / (template_name + constants.Ext.PY), + Path(app_name) / (app_name + constants.Ext.PY), ) # Fix up the imports. @@ -691,7 +699,7 @@ def _update_next_config( def remove_existing_bun_installation(): """Remove existing bun installation.""" console.debug("Removing existing bun installation.") - if os.path.exists(get_config().bun_path): + if Path(get_config().bun_path).exists(): path_ops.rm(constants.Bun.ROOT_PATH) @@ -731,7 +739,7 @@ def download_and_extract_fnm_zip(): # Download the zip file url = constants.Fnm.INSTALL_URL console.debug(f"Downloading {url}") - fnm_zip_file = os.path.join(constants.Fnm.DIR, f"{constants.Fnm.FILENAME}.zip") + fnm_zip_file = constants.Fnm.DIR / f"{constants.Fnm.FILENAME}.zip" # Function to download and extract the FNM zip release. try: # Download the FNM zip release. @@ -770,7 +778,7 @@ def install_node(): return path_ops.mkdir(constants.Fnm.DIR) - if not os.path.exists(constants.Fnm.EXE): + if not constants.Fnm.EXE.exists(): download_and_extract_fnm_zip() if constants.IS_WINDOWS: @@ -827,7 +835,7 @@ def install_bun(): ) # Skip if bun is already installed. - if os.path.exists(get_config().bun_path) and get_bun_version() == version.parse( + if Path(get_config().bun_path).exists() and get_bun_version() == version.parse( constants.Bun.VERSION ): console.debug("Skipping bun installation as it is already installed.") @@ -842,7 +850,7 @@ def install_bun(): f"irm {constants.Bun.WINDOWS_INSTALL_URL}|iex", ], env={ - "BUN_INSTALL": constants.Bun.ROOT_PATH, + "BUN_INSTALL": str(constants.Bun.ROOT_PATH), "BUN_VERSION": constants.Bun.VERSION, }, shell=True, @@ -858,25 +866,26 @@ def install_bun(): download_and_run( constants.Bun.INSTALL_URL, f"bun-v{constants.Bun.VERSION}", - BUN_INSTALL=constants.Bun.ROOT_PATH, + BUN_INSTALL=str(constants.Bun.ROOT_PATH), ) -def _write_cached_procedure_file(payload: str, cache_file: str): - with open(cache_file, "w") as f: - f.write(payload) +def _write_cached_procedure_file(payload: str, cache_file: str | Path): + cache_file = Path(cache_file) + cache_file.write_text(payload) -def _read_cached_procedure_file(cache_file: str) -> str | None: - if os.path.exists(cache_file): - with open(cache_file, "r") as f: - return f.read() +def _read_cached_procedure_file(cache_file: str | Path) -> str | None: + cache_file = Path(cache_file) + if cache_file.exists(): + return cache_file.read_text() return None -def _clear_cached_procedure_file(cache_file: str): - if os.path.exists(cache_file): - os.remove(cache_file) +def _clear_cached_procedure_file(cache_file: str | Path): + cache_file = Path(cache_file) + if cache_file.exists(): + cache_file.unlink() def cached_procedure(cache_file: str, payload_fn: Callable[..., str]): @@ -977,7 +986,7 @@ def needs_reinit(frontend: bool = True) -> bool: Raises: Exit: If the app is not initialized. """ - if not os.path.exists(constants.Config.FILE): + if not constants.Config.FILE.exists(): console.error( f"[cyan]{constants.Config.FILE}[/cyan] not found. Move to the root folder of your project, or run [bold]{constants.Reflex.MODULE_NAME} init[/bold] to start a new project." ) @@ -988,7 +997,7 @@ def needs_reinit(frontend: bool = True) -> bool: return False # Make sure the .reflex directory exists. - if not os.path.exists(constants.Reflex.DIR): + if not constants.Reflex.DIR.exists(): return True # Make sure the .web directory exists in frontend mode. @@ -1093,25 +1102,21 @@ def ensure_reflex_installation_id() -> Optional[int]: """ try: initialize_reflex_user_directory() - installation_id_file = os.path.join(constants.Reflex.DIR, "installation_id") + installation_id_file = constants.Reflex.DIR / "installation_id" installation_id = None - if os.path.exists(installation_id_file): - try: - with open(installation_id_file, "r") as f: - installation_id = int(f.read()) - except Exception: + if installation_id_file.exists(): + with contextlib.suppress(Exception): + installation_id = int(installation_id_file.read_text()) # If anything goes wrong at all... just regenerate. # Like what? Examples: # - file not exists # - file not readable # - content not parseable as an int - pass if installation_id is None: installation_id = random.getrandbits(128) - with open(installation_id_file, "w") as f: - f.write(str(installation_id)) + installation_id_file.write_text(str(installation_id)) # If we get here, installation_id is definitely set return installation_id except Exception as e: @@ -1205,50 +1210,6 @@ def prompt_for_template(templates: list[Template]) -> str: return templates[int(template)].name -def migrate_to_reflex(): - """Migration from Pynecone to Reflex.""" - # Check if the old config file exists. - if not os.path.exists(constants.Config.PREVIOUS_FILE): - return - - # Ask the user if they want to migrate. - action = console.ask( - "Pynecone project detected. Automatically upgrade to Reflex?", - choices=["y", "n"], - ) - if action == "n": - return - - # Rename pcconfig to rxconfig. - console.log( - f"[bold]Renaming {constants.Config.PREVIOUS_FILE} to {constants.Config.FILE}" - ) - os.rename(constants.Config.PREVIOUS_FILE, constants.Config.FILE) - - # Find all python files in the app directory. - file_pattern = os.path.join(get_config().app_name, "**/*.py") - file_list = glob.glob(file_pattern, recursive=True) - - # Add the config file to the list of files to be migrated. - file_list.append(constants.Config.FILE) - - # Migrate all files. - updates = { - "Pynecone": "Reflex", - "pynecone as pc": "reflex as rx", - "pynecone.io": "reflex.dev", - "pynecone": "reflex", - "pc.": "rx.", - "pcconfig": "rxconfig", - } - for file_path in file_list: - with FileInput(file_path, inplace=True) as file: - for line in file: - for old, new in updates.items(): - line = line.replace(old, new) - print(line, end="") - - def fetch_app_templates(version: str) -> dict[str, Template]: """Fetch a dict of templates from the templates repo using github API. @@ -1401,7 +1362,7 @@ def initialize_app(app_name: str, template: str | None = None): from reflex.utils import telemetry # Check if the app is already initialized. - if os.path.exists(constants.Config.FILE): + if constants.Config.FILE.exists(): telemetry.send("reinit") return diff --git a/reflex/utils/processes.py b/reflex/utils/processes.py index c435af7d0..a45676c01 100644 --- a/reflex/utils/processes.py +++ b/reflex/utils/processes.py @@ -156,7 +156,7 @@ def new_process(args, run: bool = False, show_logs: bool = False, **kwargs): Raises: Exit: When attempting to run a command with a None value. """ - node_bin_path = path_ops.get_node_bin_path() + node_bin_path = str(path_ops.get_node_bin_path()) if not node_bin_path and not prerequisites.CURRENTLY_INSTALLING_NODE: console.warn( "The path to the Node binary could not be found. Please ensure that Node is properly " @@ -167,7 +167,7 @@ def new_process(args, run: bool = False, show_logs: bool = False, **kwargs): console.error(f"Invalid command: {args}") raise typer.Exit(1) # Add the node bin path to the PATH environment variable. - env = { + env: dict[str, str] = { **os.environ, "PATH": os.pathsep.join( [node_bin_path if node_bin_path else "", os.environ["PATH"]] diff --git a/reflex/vars/base.py b/reflex/vars/base.py index 2d78a14be..0f8a80f8d 100644 --- a/reflex/vars/base.py +++ b/reflex/vars/base.py @@ -385,6 +385,15 @@ class Var(Generic[VAR_TYPE]): Returns: The converted var. """ + from reflex.event import ( + EventChain, + EventChainVar, + EventSpec, + EventVar, + ToEventChainVarOperation, + ToEventVarOperation, + ) + from .function import FunctionVar, ToFunctionOperation from .number import ( BooleanVar, @@ -416,6 +425,10 @@ class Var(Generic[VAR_TYPE]): return self.to(BooleanVar, output) if fixed_output_type is None: return ToNoneOperation.create(self) + if fixed_output_type is EventSpec: + return self.to(EventVar, output) + if fixed_output_type is EventChain: + return self.to(EventChainVar, output) if issubclass(fixed_output_type, Base): return self.to(ObjectVar, output) if dataclasses.is_dataclass(fixed_output_type) and not issubclass( @@ -453,10 +466,13 @@ class Var(Generic[VAR_TYPE]): if issubclass(output, StringVar): return ToStringOperation.create(self, var_type or str) - if issubclass(output, (ObjectVar, Base)): - return ToObjectOperation.create(self, var_type or dict) + if issubclass(output, EventVar): + return ToEventVarOperation.create(self, var_type or EventSpec) - if dataclasses.is_dataclass(output): + if issubclass(output, EventChainVar): + return ToEventChainVarOperation.create(self, var_type or EventChain) + + if issubclass(output, (ObjectVar, Base)): return ToObjectOperation.create(self, var_type or dict) if issubclass(output, FunctionVar): @@ -469,6 +485,9 @@ class Var(Generic[VAR_TYPE]): if issubclass(output, NoneVar): return ToNoneOperation.create(self) + if dataclasses.is_dataclass(output): + return ToObjectOperation.create(self, var_type or dict) + # If we can't determine the first argument, we just replace the _var_type. if not issubclass(output, Var) or var_type is None: return dataclasses.replace( @@ -494,6 +513,8 @@ class Var(Generic[VAR_TYPE]): Raises: TypeError: If the type is not supported for guessing. """ + from reflex.event import EventChain, EventChainVar, EventSpec, EventVar + from .number import BooleanVar, NumberVar from .object import ObjectVar from .sequence import ArrayVar, StringVar @@ -539,6 +560,10 @@ class Var(Generic[VAR_TYPE]): return self.to(ArrayVar, self._var_type) if issubclass(fixed_type, str): return self.to(StringVar, self._var_type) + if issubclass(fixed_type, EventSpec): + return self.to(EventVar, self._var_type) + if issubclass(fixed_type, EventChain): + return self.to(EventChainVar, self._var_type) if issubclass(fixed_type, Base): return self.to(ObjectVar, self._var_type) if dataclasses.is_dataclass(fixed_type): @@ -1029,47 +1054,22 @@ class LiteralVar(Var): if value is None: return LiteralNoneVar.create(_var_data=_var_data) - from reflex.event import EventChain, EventHandler, EventSpec + from reflex.event import ( + EventChain, + EventHandler, + EventSpec, + LiteralEventChainVar, + LiteralEventVar, + ) from reflex.utils.format import get_event_handler_parts - from .function import ArgsFunctionOperation, FunctionStringVar from .object import LiteralObjectVar if isinstance(value, EventSpec): - event_name = LiteralVar.create( - ".".join(filter(None, get_event_handler_parts(value.handler))) - ) - event_args = LiteralVar.create( - {str(name): value for name, value in value.args} - ) - event_client_name = LiteralVar.create(value.client_handler_name) - return FunctionStringVar("Event").call( - event_name, - event_args, - *([event_client_name] if value.client_handler_name else []), - ) + return LiteralEventVar.create(value, _var_data=_var_data) if isinstance(value, EventChain): - sig = inspect.signature(value.args_spec) # type: ignore - if sig.parameters: - arg_def = tuple((f"_{p}" for p in sig.parameters)) - arg_def_expr = LiteralVar.create([Var(_js_expr=arg) for arg in arg_def]) - else: - # add a default argument for addEvents if none were specified in value.args_spec - # used to trigger the preventDefault() on the event. - arg_def = ("...args",) - arg_def_expr = Var(_js_expr="args") - - return ArgsFunctionOperation.create( - arg_def, - FunctionStringVar.create("addEvents").call( - LiteralVar.create( - [LiteralVar.create(event) for event in value.events] - ), - arg_def_expr, - LiteralVar.create(value.event_actions), - ), - ) + return LiteralEventChainVar.create(value, _var_data=_var_data) if isinstance(value, EventHandler): return Var(_js_expr=".".join(filter(None, get_event_handler_parts(value)))) @@ -2126,9 +2126,16 @@ class NoneVar(Var[None]): """A var representing None.""" +@dataclasses.dataclass( + eq=False, + frozen=True, + **{"slots": True} if sys.version_info >= (3, 10) else {}, +) class LiteralNoneVar(LiteralVar, NoneVar): """A var representing None.""" + _var_value: None = None + def json(self) -> str: """Serialize the var to a JSON string. diff --git a/tests/integration/test_call_script.py b/tests/integration/test_call_script.py index 5a3b83abf..744d83d16 100644 --- a/tests/integration/test_call_script.py +++ b/tests/integration/test_call_script.py @@ -46,6 +46,7 @@ def CallScript(): inline_counter: int = 0 external_counter: int = 0 value: str = "Initial" + last_result: str = "" def call_script_callback(self, result): self.results.append(result) @@ -137,6 +138,32 @@ def CallScript(): callback=CallScriptState.set_external_counter, # type: ignore ) + def call_with_var_f_string(self): + return rx.call_script( + f"{rx.Var('inline_counter')} + {rx.Var('external_counter')}", + callback=CallScriptState.set_last_result, # type: ignore + ) + + def call_with_var_str_cast(self): + return rx.call_script( + f"{str(rx.Var('inline_counter'))} + {str(rx.Var('external_counter'))}", + callback=CallScriptState.set_last_result, # type: ignore + ) + + def call_with_var_f_string_wrapped(self): + return rx.call_script( + rx.Var(f"{rx.Var('inline_counter')} + {rx.Var('external_counter')}"), + callback=CallScriptState.set_last_result, # type: ignore + ) + + def call_with_var_str_cast_wrapped(self): + return rx.call_script( + rx.Var( + f"{str(rx.Var('inline_counter'))} + {str(rx.Var('external_counter'))}" + ), + callback=CallScriptState.set_last_result, # type: ignore + ) + def reset_(self): yield rx.call_script("inline_counter = 0; external_counter = 0") self.reset() @@ -234,6 +261,68 @@ def CallScript(): id="update_value", ), rx.button("Reset", id="reset", on_click=CallScriptState.reset_), + rx.input( + value=CallScriptState.last_result, + id="last_result", + read_only=True, + on_click=CallScriptState.set_last_result(""), # type: ignore + ), + rx.button( + "call_with_var_f_string", + on_click=CallScriptState.call_with_var_f_string, + id="call_with_var_f_string", + ), + rx.button( + "call_with_var_str_cast", + on_click=CallScriptState.call_with_var_str_cast, + id="call_with_var_str_cast", + ), + rx.button( + "call_with_var_f_string_wrapped", + on_click=CallScriptState.call_with_var_f_string_wrapped, + id="call_with_var_f_string_wrapped", + ), + rx.button( + "call_with_var_str_cast_wrapped", + on_click=CallScriptState.call_with_var_str_cast_wrapped, + id="call_with_var_str_cast_wrapped", + ), + rx.button( + "call_with_var_f_string_inline", + on_click=rx.call_script( + f"{rx.Var('inline_counter')} + {CallScriptState.last_result}", + callback=CallScriptState.set_last_result, # type: ignore + ), + id="call_with_var_f_string_inline", + ), + rx.button( + "call_with_var_str_cast_inline", + on_click=rx.call_script( + f"{str(rx.Var('inline_counter'))} + {str(rx.Var('external_counter'))}", + callback=CallScriptState.set_last_result, # type: ignore + ), + id="call_with_var_str_cast_inline", + ), + rx.button( + "call_with_var_f_string_wrapped_inline", + on_click=rx.call_script( + rx.Var( + f"{rx.Var('inline_counter')} + {CallScriptState.last_result}" + ), + callback=CallScriptState.set_last_result, # type: ignore + ), + id="call_with_var_f_string_wrapped_inline", + ), + rx.button( + "call_with_var_str_cast_wrapped_inline", + on_click=rx.call_script( + rx.Var( + f"{str(rx.Var('inline_counter'))} + {str(rx.Var('external_counter'))}" + ), + callback=CallScriptState.set_last_result, # type: ignore + ), + id="call_with_var_str_cast_wrapped_inline", + ), ) @@ -363,3 +452,73 @@ def test_call_script( call_script.poll_for_content(update_value_button, exp_not_equal="Initial") == "updated" ) + + +def test_call_script_w_var( + call_script: AppHarness, + driver: WebDriver, +): + """Test evaluating javascript expressions containing Vars. + + Args: + call_script: harness for CallScript app. + driver: WebDriver instance. + """ + assert_token(driver) + last_result = driver.find_element(By.ID, "last_result") + assert last_result.get_attribute("value") == "" + + inline_return_button = driver.find_element(By.ID, "inline_return") + + call_with_var_f_string_button = driver.find_element(By.ID, "call_with_var_f_string") + call_with_var_str_cast_button = driver.find_element(By.ID, "call_with_var_str_cast") + call_with_var_f_string_wrapped_button = driver.find_element( + By.ID, "call_with_var_f_string_wrapped" + ) + call_with_var_str_cast_wrapped_button = driver.find_element( + By.ID, "call_with_var_str_cast_wrapped" + ) + call_with_var_f_string_inline_button = driver.find_element( + By.ID, "call_with_var_f_string_inline" + ) + call_with_var_str_cast_inline_button = driver.find_element( + By.ID, "call_with_var_str_cast_inline" + ) + call_with_var_f_string_wrapped_inline_button = driver.find_element( + By.ID, "call_with_var_f_string_wrapped_inline" + ) + call_with_var_str_cast_wrapped_inline_button = driver.find_element( + By.ID, "call_with_var_str_cast_wrapped_inline" + ) + + inline_return_button.click() + call_with_var_f_string_button.click() + assert call_script.poll_for_value(last_result, exp_not_equal="") == "1" + + inline_return_button.click() + call_with_var_str_cast_button.click() + assert call_script.poll_for_value(last_result, exp_not_equal="1") == "2" + + inline_return_button.click() + call_with_var_f_string_wrapped_button.click() + assert call_script.poll_for_value(last_result, exp_not_equal="2") == "3" + + inline_return_button.click() + call_with_var_str_cast_wrapped_button.click() + assert call_script.poll_for_value(last_result, exp_not_equal="3") == "4" + + inline_return_button.click() + call_with_var_f_string_inline_button.click() + assert call_script.poll_for_value(last_result, exp_not_equal="4") == "9" + + inline_return_button.click() + call_with_var_str_cast_inline_button.click() + assert call_script.poll_for_value(last_result, exp_not_equal="9") == "6" + + inline_return_button.click() + call_with_var_f_string_wrapped_inline_button.click() + assert call_script.poll_for_value(last_result, exp_not_equal="6") == "13" + + inline_return_button.click() + call_with_var_str_cast_wrapped_inline_button.click() + assert call_script.poll_for_value(last_result, exp_not_equal="13") == "8" diff --git a/tests/integration/test_dynamic_components.py b/tests/integration/test_dynamic_components.py index 5a4d99f9e..aeebd10e9 100644 --- a/tests/integration/test_dynamic_components.py +++ b/tests/integration/test_dynamic_components.py @@ -65,7 +65,9 @@ def DynamicComponents(): DynamicComponentsState.client_token_component, DynamicComponentsState.button, rx.text( - DynamicComponentsState._evaluate(lambda state: factorial(state.value)), + DynamicComponentsState._evaluate( + lambda state: factorial(state.value), of_type=int + ), id="factorial", ), ) diff --git a/tests/integration/test_urls.py b/tests/integration/test_urls.py index bcf17fe41..81689aa18 100755 --- a/tests/integration/test_urls.py +++ b/tests/integration/test_urls.py @@ -8,7 +8,7 @@ import pytest import requests -def check_urls(repo_dir): +def check_urls(repo_dir: Path): """Check that all URLs in the repo are valid and secure. Args: @@ -21,33 +21,33 @@ def check_urls(repo_dir): errors = [] for root, _dirs, files in os.walk(repo_dir): - if "__pycache__" in root: + root = Path(root) + if root.stem == "__pycache__": continue for file_name in files: if not file_name.endswith(".py") and not file_name.endswith(".md"): continue - file_path = os.path.join(root, file_name) + file_path = root / file_name try: - with open(file_path, "r", encoding="utf-8", errors="ignore") as file: - for line in file: - urls = url_pattern.findall(line) - for url in set(urls): - if url.startswith("http://"): - errors.append( - f"Found insecure HTTP URL: {url} in {file_path}" - ) - url = url.strip('"\n') - try: - response = requests.head( - url, allow_redirects=True, timeout=5 - ) - response.raise_for_status() - except requests.RequestException as e: - errors.append( - f"Error accessing URL: {url} in {file_path} | Error: {e}, , Check your path ends with a /" - ) + for line in file_path.read_text().splitlines(): + urls = url_pattern.findall(line) + for url in set(urls): + if url.startswith("http://"): + errors.append( + f"Found insecure HTTP URL: {url} in {file_path}" + ) + url = url.strip('"\n') + try: + response = requests.head( + url, allow_redirects=True, timeout=5 + ) + response.raise_for_status() + except requests.RequestException as e: + errors.append( + f"Error accessing URL: {url} in {file_path} | Error: {e}, , Check your path ends with a /" + ) except Exception as e: errors.append(f"Error reading file: {file_path} | Error: {e}") @@ -58,7 +58,7 @@ def check_urls(repo_dir): "repo_dir", [Path(__file__).resolve().parent.parent / "reflex"], ) -def test_find_and_check_urls(repo_dir): +def test_find_and_check_urls(repo_dir: Path): """Test that all URLs in the repo are valid and secure. Args: diff --git a/tests/units/compiler/test_compiler.py b/tests/units/compiler/test_compiler.py index 63014cf33..afacf43c5 100644 --- a/tests/units/compiler/test_compiler.py +++ b/tests/units/compiler/test_compiler.py @@ -1,4 +1,4 @@ -import os +from pathlib import Path from typing import List import pytest @@ -130,7 +130,7 @@ def test_compile_stylesheets(tmp_path, mocker): ] assert compiler.compile_root_stylesheet(stylesheets) == ( - os.path.join(".web", "styles", "styles.css"), + str(Path(".web") / "styles" / "styles.css"), f"@import url('./tailwind.css'); \n" f"@import url('https://fonts.googleapis.com/css?family=Sofia&effect=neon|outline|emboss|shadow-multiple'); \n" f"@import url('https://cdn.jsdelivr.net/npm/bootstrap@3.3.7/dist/css/bootstrap.min.css'); \n" @@ -164,7 +164,7 @@ def test_compile_stylesheets_exclude_tailwind(tmp_path, mocker): ] assert compiler.compile_root_stylesheet(stylesheets) == ( - os.path.join(".web", "styles", "styles.css"), + str(Path(".web") / "styles" / "styles.css"), "@import url('../public/styles.css'); \n", ) diff --git a/tests/units/components/base/test_script.py b/tests/units/components/base/test_script.py index c6b67da11..be62276f2 100644 --- a/tests/units/components/base/test_script.py +++ b/tests/units/components/base/test_script.py @@ -58,14 +58,14 @@ def test_script_event_handler(): ) render_dict = component.render() assert ( - f'onReady={{((...args) => ((addEvents([(Event("{EvState.get_full_name()}.on_ready", ({{ }})))], args, ({{ }})))))}}' + f'onReady={{((...args) => ((addEvents([(Event("{EvState.get_full_name()}.on_ready", ({{ }}), ({{ }})))], args, ({{ }})))))}}' in render_dict["props"] ) assert ( - f'onLoad={{((...args) => ((addEvents([(Event("{EvState.get_full_name()}.on_load", ({{ }})))], args, ({{ }})))))}}' + f'onLoad={{((...args) => ((addEvents([(Event("{EvState.get_full_name()}.on_load", ({{ }}), ({{ }})))], args, ({{ }})))))}}' in render_dict["props"] ) assert ( - f'onError={{((...args) => ((addEvents([(Event("{EvState.get_full_name()}.on_error", ({{ }})))], args, ({{ }})))))}}' + f'onError={{((...args) => ((addEvents([(Event("{EvState.get_full_name()}.on_error", ({{ }}), ({{ }})))], args, ({{ }})))))}}' in render_dict["props"] ) diff --git a/tests/units/components/test_component.py b/tests/units/components/test_component.py index 73d3f611b..5e94db052 100644 --- a/tests/units/components/test_component.py +++ b/tests/units/components/test_component.py @@ -832,7 +832,7 @@ def test_component_event_trigger_arbitrary_args(): assert comp.render()["props"][0] == ( "onFoo={((__e, _alpha, _bravo, _charlie) => ((addEvents(" - f'[(Event("{C1State.get_full_name()}.mock_handler", ({{ ["_e"] : __e["target"]["value"], ["_bravo"] : _bravo["nested"], ["_charlie"] : (_charlie["custom"] + 42) }})))], ' + f'[(Event("{C1State.get_full_name()}.mock_handler", ({{ ["_e"] : __e["target"]["value"], ["_bravo"] : _bravo["nested"], ["_charlie"] : (_charlie["custom"] + 42) }}), ({{ }})))], ' "[__e, _alpha, _bravo, _charlie], ({ })))))}" ) @@ -1178,7 +1178,7 @@ TEST_VAR = LiteralVar.create("test")._replace( ) FORMATTED_TEST_VAR = LiteralVar.create(f"foo{TEST_VAR}bar") STYLE_VAR = TEST_VAR._replace(_js_expr="style") -EVENT_CHAIN_VAR = TEST_VAR._replace(_var_type=EventChain) +EVENT_CHAIN_VAR = TEST_VAR.to(EventChain) ARG_VAR = Var(_js_expr="arg") TEST_VAR_DICT_OF_DICT = LiteralVar.create({"a": {"b": "test"}})._replace( @@ -2159,7 +2159,7 @@ class TriggerState(rx.State): rx.text("random text", on_click=TriggerState.do_something), rx.text( "random text", - on_click=Var(_js_expr="toggleColorMode", _var_type=EventChain), + on_click=Var(_js_expr="toggleColorMode").to(EventChain), ), ), True, @@ -2169,7 +2169,7 @@ class TriggerState(rx.State): rx.text("random text", on_click=rx.console_log("log")), rx.text( "random text", - on_click=Var(_js_expr="toggleColorMode", _var_type=EventChain), + on_click=Var(_js_expr="toggleColorMode").to(EventChain), ), ), False, diff --git a/tests/units/test_config.py b/tests/units/test_config.py index 31dd77649..a6c6fe697 100644 --- a/tests/units/test_config.py +++ b/tests/units/test_config.py @@ -192,4 +192,4 @@ def test_reflex_dir_env_var(monkeypatch, tmp_path): mp_ctx = multiprocessing.get_context(method="spawn") with mp_ctx.Pool(processes=1) as pool: - assert pool.apply(reflex_dir_constant) == str(tmp_path) + assert pool.apply(reflex_dir_constant) == tmp_path diff --git a/tests/units/test_state.py b/tests/units/test_state.py index 205162b9f..5bfac7628 100644 --- a/tests/units/test_state.py +++ b/tests/units/test_state.py @@ -41,6 +41,7 @@ from reflex.state import ( ) from reflex.testing import chdir from reflex.utils import format, prerequisites, types +from reflex.utils.exceptions import SetUndefinedStateVarError from reflex.utils.format import json_dumps from reflex.vars.base import ComputedVar, Var from tests.units.states.mutation import MutableSQLAModel, MutableTestState @@ -3262,3 +3263,45 @@ def test_child_mixin_state() -> None: assert "computed" in ChildUsesMixinState.inherited_vars assert "computed" not in ChildUsesMixinState.computed_vars + + +def test_assignment_to_undeclared_vars(): + """Test that an attribute error is thrown when undeclared vars are set.""" + + class State(BaseState): + val: str + _val: str + __val: str # type: ignore + + def handle_supported_regular_vars(self): + self.val = "no underscore" + self._val = "single leading underscore" + self.__val = "double leading undercore" + + def handle_regular_var(self): + self.num = 5 + + def handle_backend_var(self): + self._num = 5 + + def handle_non_var(self): + self.__num = 5 + + class Substate(State): + def handle_var(self): + self.value = 20 + + state = State() # type: ignore + sub_state = Substate() # type: ignore + + with pytest.raises(SetUndefinedStateVarError): + state.handle_regular_var() + + with pytest.raises(SetUndefinedStateVarError): + sub_state.handle_var() + + with pytest.raises(SetUndefinedStateVarError): + state.handle_backend_var() + + state.handle_supported_regular_vars() + state.handle_non_var() diff --git a/tests/units/test_telemetry.py b/tests/units/test_telemetry.py index a434779d4..25ad91323 100644 --- a/tests/units/test_telemetry.py +++ b/tests/units/test_telemetry.py @@ -52,4 +52,4 @@ def test_send(mocker, event): telemetry._send(event, telemetry_enabled=True) httpx_post_mock.assert_called_once() - pathlib_path_read_text_mock.assert_called_once() + assert pathlib_path_read_text_mock.call_count == 2 diff --git a/tests/units/utils/test_format.py b/tests/units/utils/test_format.py index 042c3f323..d7b0c791e 100644 --- a/tests/units/utils/test_format.py +++ b/tests/units/utils/test_format.py @@ -374,7 +374,7 @@ def test_format_match( events=[EventSpec(handler=EventHandler(fn=mock_event))], args_spec=lambda: [], ), - '((...args) => ((addEvents([(Event("mock_event", ({ })))], args, ({ })))))', + '((...args) => ((addEvents([(Event("mock_event", ({ }), ({ })))], args, ({ })))))', ), ( EventChain( @@ -395,7 +395,7 @@ def test_format_match( ], args_spec=lambda e: [e.target.value], ), - '((_e) => ((addEvents([(Event("mock_event", ({ ["arg"] : _e["target"]["value"] })))], [_e], ({ })))))', + '((_e) => ((addEvents([(Event("mock_event", ({ ["arg"] : _e["target"]["value"] }), ({ })))], [_e], ({ })))))', ), ( EventChain( @@ -403,7 +403,19 @@ def test_format_match( args_spec=lambda: [], event_actions={"stopPropagation": True}, ), - '((...args) => ((addEvents([(Event("mock_event", ({ })))], args, ({ ["stopPropagation"] : true })))))', + '((...args) => ((addEvents([(Event("mock_event", ({ }), ({ })))], args, ({ ["stopPropagation"] : true })))))', + ), + ( + EventChain( + events=[ + EventSpec( + handler=EventHandler(fn=mock_event), + event_actions={"stopPropagation": True}, + ) + ], + args_spec=lambda: [], + ), + '((...args) => ((addEvents([(Event("mock_event", ({ }), ({ ["stopPropagation"] : true })))], args, ({ })))))', ), ( EventChain( @@ -411,7 +423,7 @@ def test_format_match( args_spec=lambda: [], event_actions={"preventDefault": True}, ), - '((...args) => ((addEvents([(Event("mock_event", ({ })))], args, ({ ["preventDefault"] : true })))))', + '((...args) => ((addEvents([(Event("mock_event", ({ }), ({ })))], args, ({ ["preventDefault"] : true })))))', ), ({"a": "red", "b": "blue"}, '({ ["a"] : "red", ["b"] : "blue" })'), (Var(_js_expr="var", _var_type=int).guess_type(), "var"), @@ -519,7 +531,7 @@ def test_format_event_handler(input, output): [ ( EventSpec(handler=EventHandler(fn=mock_event)), - '(Event("mock_event", ({ })))', + '(Event("mock_event", ({ }), ({ })))', ), ], ) diff --git a/tests/units/utils/test_utils.py b/tests/units/utils/test_utils.py index 5cdd846fe..41bd4e661 100644 --- a/tests/units/utils/test_utils.py +++ b/tests/units/utils/test_utils.py @@ -117,7 +117,7 @@ def test_remove_existing_bun_installation(mocker): Args: mocker: Pytest mocker. """ - mocker.patch("reflex.utils.prerequisites.os.path.exists", return_value=True) + mocker.patch("reflex.utils.prerequisites.Path.exists", return_value=True) rm = mocker.patch("reflex.utils.prerequisites.path_ops.rm", mocker.Mock()) prerequisites.remove_existing_bun_installation() @@ -458,7 +458,7 @@ def test_bun_install_without_unzip(mocker): mocker: Pytest mocker object. """ mocker.patch("reflex.utils.path_ops.which", return_value=None) - mocker.patch("os.path.exists", return_value=False) + mocker.patch("pathlib.Path.exists", return_value=False) mocker.patch("reflex.utils.prerequisites.constants.IS_WINDOWS", False) with pytest.raises(FileNotFoundError): @@ -476,7 +476,7 @@ def test_bun_install_version(mocker, bun_version): """ mocker.patch("reflex.utils.prerequisites.constants.IS_WINDOWS", False) - mocker.patch("os.path.exists", return_value=True) + mocker.patch("pathlib.Path.exists", return_value=True) mocker.patch( "reflex.utils.prerequisites.get_bun_version", return_value=version.parse(bun_version),