diff --git a/.gitignore b/.gitignore index 0f7d9e5ff..8bd92964c 100644 --- a/.gitignore +++ b/.gitignore @@ -4,6 +4,7 @@ assets/external/* dist/* examples/ .web +.states .idea .vscode .coverage diff --git a/reflex/.templates/web/utils/state.js b/reflex/.templates/web/utils/state.js index 5b8046347..93c664ef1 100644 --- a/reflex/.templates/web/utils/state.js +++ b/reflex/.templates/web/utils/state.js @@ -408,7 +408,7 @@ export const connect = async ( socket.current = io(endpoint.href, { path: endpoint["pathname"], transports: transports, - protocols: env.TEST_MODE ? undefined : [reflexEnvironment.version], + protocols: [reflexEnvironment.version], autoUnref: false, }); // Ensure undefined fields in events are sent as null instead of removed diff --git a/reflex/app.py b/reflex/app.py index 8bc249612..6523f598a 100644 --- a/reflex/app.py +++ b/reflex/app.py @@ -405,7 +405,31 @@ class App(MiddlewareMixin, LifespanMixin): self.sio.register_namespace(self.event_namespace) # Mount the socket app with the API. if self.api: - self.api.mount(str(constants.Endpoint.EVENT), socket_app) + + class HeaderMiddleware: + def __init__(self, app): + self.app = app + + async def __call__(self, scope, receive, send): + original_send = send + + async def modified_send(message): + headers = dict(scope["headers"]) + protocol_key = b"sec-websocket-protocol" + if ( + message["type"] == "websocket.accept" + and protocol_key in headers + ): + message["headers"] = [ + *message.get("headers", []), + (b"sec-websocket-protocol", headers[protocol_key]), + ] + return await original_send(message) + + return await self.app(scope, receive, modified_send) + + socket_app_with_headers = HeaderMiddleware(socket_app) + self.api.mount(str(constants.Endpoint.EVENT), socket_app_with_headers) # Check the exception handlers self._validate_exception_handlers() diff --git a/reflex/config.py b/reflex/config.py index 3b88f78cd..3878d021e 100644 --- a/reflex/config.py +++ b/reflex/config.py @@ -490,6 +490,9 @@ class EnvironmentVariables: # The working directory for the next.js commands. REFLEX_WEB_WORKDIR: EnvVar[Path] = env_var(Path(constants.Dirs.WEB)) + # The working directory for the states directory. + REFLEX_STATES_WORKDIR: EnvVar[Path] = env_var(Path(constants.Dirs.STATES)) + # Path to the alembic config file ALEMBIC_CONFIG: EnvVar[ExistingPath] = env_var(Path(constants.ALEMBIC_CONFIG)) diff --git a/reflex/constants/base.py b/reflex/constants/base.py index f737858c0..11f3e3c05 100644 --- a/reflex/constants/base.py +++ b/reflex/constants/base.py @@ -52,7 +52,7 @@ class Dirs(SimpleNamespace): # The name of the postcss config file. POSTCSS_JS = "postcss.config.js" # The name of the states directory. - STATES = "states" + STATES = ".states" class Reflex(SimpleNamespace): diff --git a/reflex/constants/config.py b/reflex/constants/config.py index 7425fd864..a49216c00 100644 --- a/reflex/constants/config.py +++ b/reflex/constants/config.py @@ -39,7 +39,14 @@ class GitIgnore(SimpleNamespace): # The gitignore file. FILE = Path(".gitignore") # Files to gitignore. - DEFAULTS = {Dirs.WEB, "*.db", "__pycache__/", "*.py[cod]", "assets/external/"} + DEFAULTS = { + Dirs.WEB, + Dirs.STATES, + "*.db", + "__pycache__/", + "*.py[cod]", + "assets/external/", + } class RequirementsTxt(SimpleNamespace): diff --git a/reflex/state.py b/reflex/state.py index 5f478c176..05d920033 100644 --- a/reflex/state.py +++ b/reflex/state.py @@ -3046,7 +3046,7 @@ def is_serializable(value: Any) -> bool: def reset_disk_state_manager(): """Reset the disk state manager.""" - states_directory = prerequisites.get_web_dir() / constants.Dirs.STATES + states_directory = prerequisites.get_states_dir() if states_directory.exists(): for path in states_directory.iterdir(): path.unlink() @@ -3094,7 +3094,7 @@ class StateManagerDisk(StateManager): Returns: The states directory. """ - return prerequisites.get_web_dir() / constants.Dirs.STATES + return prerequisites.get_states_dir() def _purge_expired_states(self): """Purge expired states from the disk.""" diff --git a/reflex/utils/exec.py b/reflex/utils/exec.py index 556620e1c..583e82f83 100644 --- a/reflex/utils/exec.py +++ b/reflex/utils/exec.py @@ -307,7 +307,7 @@ def run_granian_backend(host, port, loglevel: LogLevel): log_level=LogLevels(loglevel.value), reload=True, reload_paths=get_reload_dirs(), - reload_ignore_dirs=[".web"], + reload_ignore_dirs=[".web", ".states"], ).serve() except ImportError: console.error( @@ -467,19 +467,19 @@ def output_system_info(): system = platform.system() + fnm_info = f"[FNM {prerequisites.get_fnm_version()} (Expected: {constants.Fnm.VERSION}) (PATH: {constants.Fnm.EXE})]" + if system != "Windows" or ( system == "Windows" and prerequisites.is_windows_bun_supported() ): dependencies.extend( [ - f"[FNM {prerequisites.get_fnm_version()} (Expected: {constants.Fnm.VERSION}) (PATH: {constants.Fnm.EXE})]", - f"[Bun {prerequisites.get_bun_version()} (Expected: {constants.Bun.VERSION}) (PATH: {config.bun_path})]", + fnm_info, + f"[Bun {prerequisites.get_bun_version()} (Expected: {constants.Bun.VERSION}) (PATH: {path_ops.get_bun_path()})]", ], ) else: - dependencies.append( - f"[FNM {prerequisites.get_fnm_version()} (Expected: {constants.Fnm.VERSION}) (PATH: {constants.Fnm.EXE})]", - ) + dependencies.append(fnm_info) if system == "Linux": import distro diff --git a/reflex/utils/path_ops.py b/reflex/utils/path_ops.py index b447718d2..edab085ff 100644 --- a/reflex/utils/path_ops.py +++ b/reflex/utils/path_ops.py @@ -9,7 +9,7 @@ import shutil from pathlib import Path from reflex import constants -from reflex.config import environment +from reflex.config import environment, get_config # Shorthand for join. join = os.linesep.join @@ -118,7 +118,7 @@ def ln(src: str | Path, dest: str | Path, overwrite: bool = False) -> bool: return True -def which(program: str | Path) -> str | Path | None: +def which(program: str | Path) -> Path | None: """Find the path to an executable. Args: @@ -127,7 +127,8 @@ def which(program: str | Path) -> str | Path | None: Returns: The path to the executable. """ - return shutil.which(str(program)) + which_result = shutil.which(program) + return Path(which_result) if which_result else None def use_system_node() -> bool: @@ -156,12 +157,12 @@ def get_node_bin_path() -> Path | None: """ bin_path = Path(constants.Node.BIN_PATH) if not bin_path.exists(): - str_path = which("node") - return Path(str_path).parent.resolve() if str_path else None - return bin_path.resolve() + path = which("node") + return path.parent.absolute() if path else None + return bin_path.absolute() -def get_node_path() -> str | None: +def get_node_path() -> Path | None: """Get the node binary path. Returns: @@ -169,9 +170,8 @@ def get_node_path() -> str | None: """ node_path = Path(constants.Node.PATH) if use_system_node() or not node_path.exists(): - system_node_path = which("node") - return str(system_node_path) if system_node_path else None - return str(node_path) + node_path = which("node") + return node_path def get_npm_path() -> Path | None: @@ -182,11 +182,22 @@ def get_npm_path() -> Path | None: """ npm_path = Path(constants.Node.NPM_PATH) if use_system_node() or not npm_path.exists(): - system_npm_path = which("npm") - npm_path = Path(system_npm_path) if system_npm_path else None + npm_path = which("npm") return npm_path.absolute() if npm_path else None +def get_bun_path() -> Path | None: + """Get bun binary path. + + Returns: + The path to the bun binary file. + """ + bun_path = get_config().bun_path + if use_system_bun() or not bun_path.exists(): + bun_path = which("bun") + return bun_path.absolute() if bun_path else None + + def update_json_file(file_path: str | Path, update_dict: dict[str, int | str]): """Update the contents of a json file. @@ -196,6 +207,9 @@ def update_json_file(file_path: str | Path, update_dict: dict[str, int | str]): """ fp = Path(file_path) + # Create the parent directory if it doesn't exist. + fp.parent.mkdir(parents=True, exist_ok=True) + # Create the file if it doesn't exist. fp.touch(exist_ok=True) diff --git a/reflex/utils/prerequisites.py b/reflex/utils/prerequisites.py index 6623a178d..ea510c842 100644 --- a/reflex/utils/prerequisites.py +++ b/reflex/utils/prerequisites.py @@ -87,6 +87,17 @@ def get_web_dir() -> Path: return environment.REFLEX_WEB_WORKDIR.get() +def get_states_dir() -> Path: + """Get the working directory for the states. + + Can be overridden with REFLEX_STATES_WORKDIR. + + Returns: + The working directory. + """ + return environment.REFLEX_STATES_WORKDIR.get() + + def check_latest_package_version(package_name: str): """Check if the latest version of the package is installed. @@ -194,10 +205,14 @@ def get_bun_version() -> version.Version | None: Returns: The version of bun. """ + bun_path = path_ops.get_bun_path() + if bun_path is None: + return None try: # Run the bun -v command and capture the output result = processes.new_process([str(get_config().bun_path), "-v"], run=True) - return version.parse(result.stdout) # pyright: ignore [reportArgumentType] + result = processes.new_process([str(get_config().bun_path), "-v"], run=True) + return version.parse(str(result.stdout)) # pyright: ignore [reportArgumentType] except FileNotFoundError: return None except version.InvalidVersion as e: @@ -1051,9 +1066,7 @@ def install_bun(): ) # Skip if bun is already installed. - if Path(get_config().bun_path).exists() and get_bun_version() == version.parse( - constants.Bun.VERSION - ): + if get_bun_version() == version.parse(constants.Bun.VERSION): console.debug("Skipping bun installation as it is already installed.") return @@ -1074,8 +1087,7 @@ def install_bun(): show_logs=console.is_debug(), ) else: - unzip_path = path_ops.which("unzip") - if unzip_path is None: + if path_ops.which("unzip") is None: raise SystemPackageMissingError("unzip") # Run the bun install script. @@ -1279,12 +1291,9 @@ def validate_bun(): Raises: Exit: If custom specified bun does not exist or does not meet requirements. """ - # if a custom bun path is provided, make sure its valid - # This is specific to non-FHS OS - bun_path = get_config().bun_path - if path_ops.use_system_bun(): - bun_path = path_ops.which("bun") - if bun_path != constants.Bun.DEFAULT_PATH: + bun_path = path_ops.get_bun_path() + + if bun_path and bun_path.samefile(constants.Bun.DEFAULT_PATH): console.info(f"Using custom Bun path: {bun_path}") bun_version = get_bun_version() if not bun_version: diff --git a/tests/units/utils/test_utils.py b/tests/units/utils/test_utils.py index 39a1ec35c..44356dac5 100644 --- a/tests/units/utils/test_utils.py +++ b/tests/units/utils/test_utils.py @@ -122,13 +122,13 @@ def test_validate_invalid_bun_path(mocker): Args: mocker: Pytest mocker object. """ - mock = mocker.Mock() - mocker.patch.object(mock, "bun_path", return_value="/mock/path") - mocker.patch("reflex.utils.prerequisites.get_config", mock) + mock_path = mocker.Mock() + mocker.patch("reflex.utils.path_ops.get_bun_path", return_value=mock_path) mocker.patch("reflex.utils.prerequisites.get_bun_version", return_value=None) with pytest.raises(typer.Exit): prerequisites.validate_bun() + mock_path.samefile.assert_called_once() def test_validate_bun_path_incompatible_version(mocker): @@ -137,9 +137,8 @@ def test_validate_bun_path_incompatible_version(mocker): Args: mocker: Pytest mocker object. """ - mock = mocker.Mock() - mocker.patch.object(mock, "bun_path", return_value="/mock/path") - mocker.patch("reflex.utils.prerequisites.get_config", mock) + mock_path = mocker.Mock() + mocker.patch("reflex.utils.path_ops.get_bun_path", return_value=mock_path) mocker.patch( "reflex.utils.prerequisites.get_bun_version", return_value=version.parse("0.6.5"),