From 268effe62e203dc0eb495f50b14d0660891d942d Mon Sep 17 00:00:00 2001 From: Elijah Ahianyo Date: Mon, 20 Jan 2025 18:12:54 +0000 Subject: [PATCH 1/4] [ENG-4134]Allow specifying custom app module in rxconfig (#4556) * Allow custom app module in rxconfig * what was that pyscopg mess? * fix another mess * get this working with relative imports and hot reload * typing to named tuple * minor refactor * revert redis knobs positions * fix pyright except 1 * fix pyright hopefully * use the resolved module path * testing workflow * move nba-proxy job to counter job * just cast the type * fix tests for python 3.9 * darglint * CR Suggestions for #4556 (#4644) * reload_dirs: search up from app_module for last directory containing __init__ * Change custom app_module to use an import string * preserve sys.path entries added while loading rxconfig.py --------- Co-authored-by: Masen Furer --- .github/workflows/integration_tests.yml | 22 +++++++++++- reflex/app_module_for_backend.py | 7 ++-- reflex/config.py | 30 ++++++++++++++-- reflex/event.py | 2 +- reflex/state.py | 29 ++++++++-------- reflex/utils/exec.py | 26 ++++++++++++-- reflex/utils/prerequisites.py | 46 ++++++++++++++++++++++--- 7 files changed, 132 insertions(+), 30 deletions(-) diff --git a/.github/workflows/integration_tests.yml b/.github/workflows/integration_tests.yml index 017336ba5..2ca9aed23 100644 --- a/.github/workflows/integration_tests.yml +++ b/.github/workflows/integration_tests.yml @@ -33,7 +33,7 @@ env: PR_TITLE: ${{ github.event.pull_request.title }} jobs: - example-counter: + example-counter-and-nba-proxy: env: OUTPUT_FILE: import_benchmark.json timeout-minutes: 30 @@ -119,6 +119,26 @@ jobs: --benchmark-json "./reflex-examples/counter/${{ env.OUTPUT_FILE }}" --branch-name "${{ github.head_ref || github.ref_name }}" --pr-id "${{ github.event.pull_request.id }}" --app-name "counter" + - name: Install requirements for nba proxy example + working-directory: ./reflex-examples/nba-proxy + run: | + poetry run uv pip install -r requirements.txt + - name: Install additional dependencies for DB access + run: poetry run uv pip install psycopg + - name: Check export --backend-only before init for nba-proxy example + working-directory: ./reflex-examples/nba-proxy + run: | + poetry run reflex export --backend-only + - name: Init Website for nba-proxy example + working-directory: ./reflex-examples/nba-proxy + run: | + poetry run reflex init --loglevel debug + - name: Run Website and Check for errors + run: | + # Check that npm is home + npm -v + poetry run bash scripts/integration.sh ./reflex-examples/nba-proxy dev + reflex-web: strategy: diff --git a/reflex/app_module_for_backend.py b/reflex/app_module_for_backend.py index 8109fc3d6..b0ae0a29f 100644 --- a/reflex/app_module_for_backend.py +++ b/reflex/app_module_for_backend.py @@ -7,14 +7,13 @@ from concurrent.futures import ThreadPoolExecutor from reflex import constants from reflex.utils import telemetry from reflex.utils.exec import is_prod_mode -from reflex.utils.prerequisites import get_app +from reflex.utils.prerequisites import get_and_validate_app if constants.CompileVars.APP != "app": raise AssertionError("unexpected variable name for 'app'") telemetry.send("compile") -app_module = get_app(reload=False) -app = getattr(app_module, constants.CompileVars.APP) +app, app_module = get_and_validate_app(reload=False) # For py3.9 compatibility when redis is used, we MUST add any decorator pages # before compiling the app in a thread to avoid event loop error (REF-2172). app._apply_decorated_pages() @@ -30,7 +29,7 @@ if is_prod_mode(): # ensure only "app" is exposed. del app_module del compile_future -del get_app +del get_and_validate_app del is_prod_mode del telemetry del constants diff --git a/reflex/config.py b/reflex/config.py index 7614417d5..8511694fb 100644 --- a/reflex/config.py +++ b/reflex/config.py @@ -12,6 +12,7 @@ import threading import urllib.parse from importlib.util import find_spec from pathlib import Path +from types import ModuleType from typing import ( TYPE_CHECKING, Any, @@ -607,6 +608,9 @@ class Config(Base): # The name of the app (should match the name of the app directory). app_name: str + # The path to the app module. + app_module_import: Optional[str] = None + # The log level to use. loglevel: constants.LogLevel = constants.LogLevel.DEFAULT @@ -729,6 +733,19 @@ class Config(Base): "REDIS_URL is required when using the redis state manager." ) + @property + def app_module(self) -> ModuleType | None: + """Return the app module if `app_module_import` is set. + + Returns: + The app module. + """ + return ( + importlib.import_module(self.app_module_import) + if self.app_module_import + else None + ) + @property def module(self) -> str: """Get the module name of the app. @@ -736,6 +753,8 @@ class Config(Base): Returns: The module name. """ + if self.app_module is not None: + return self.app_module.__name__ return ".".join([self.app_name, self.app_name]) def update_from_env(self) -> dict[str, Any]: @@ -874,7 +893,7 @@ def get_config(reload: bool = False) -> Config: return cached_rxconfig.config with _config_lock: - sys_path = sys.path.copy() + orig_sys_path = sys.path.copy() sys.path.clear() sys.path.append(str(Path.cwd())) try: @@ -882,9 +901,14 @@ def get_config(reload: bool = False) -> Config: return _get_config() except Exception: # If the module import fails, try to import with the original sys.path. - sys.path.extend(sys_path) + sys.path.extend(orig_sys_path) return _get_config() finally: + # Find any entries added to sys.path by rxconfig.py itself. + extra_paths = [ + p for p in sys.path if p not in orig_sys_path and p != str(Path.cwd()) + ] # Restore the original sys.path. sys.path.clear() - sys.path.extend(sys_path) + sys.path.extend(extra_paths) + sys.path.extend(orig_sys_path) diff --git a/reflex/event.py b/reflex/event.py index 28852fde5..886a306c1 100644 --- a/reflex/event.py +++ b/reflex/event.py @@ -1591,7 +1591,7 @@ def get_handler_args( def fix_events( - events: list[EventHandler | EventSpec] | None, + events: list[EventSpec | EventHandler] | None, token: str, router_data: dict[str, Any] | None = None, ) -> list[Event]: diff --git a/reflex/state.py b/reflex/state.py index e15c73978..66098d232 100644 --- a/reflex/state.py +++ b/reflex/state.py @@ -1776,9 +1776,9 @@ class BaseState(Base, ABC, extra=pydantic.Extra.allow): except Exception as ex: state._clean() - app_instance = getattr(prerequisites.get_app(), constants.CompileVars.APP) - - event_specs = app_instance.backend_exception_handler(ex) + event_specs = ( + prerequisites.get_and_validate_app().app.backend_exception_handler(ex) + ) if event_specs is None: return StateUpdate() @@ -1888,9 +1888,9 @@ class BaseState(Base, ABC, extra=pydantic.Extra.allow): except Exception as ex: telemetry.send_error(ex, context="backend") - app_instance = getattr(prerequisites.get_app(), constants.CompileVars.APP) - - event_specs = app_instance.backend_exception_handler(ex) + event_specs = ( + prerequisites.get_and_validate_app().app.backend_exception_handler(ex) + ) yield state._as_state_update( handler, @@ -2403,8 +2403,9 @@ class FrontendEventExceptionState(State): component_stack: The stack trace of the component where the exception occurred. """ - app_instance = getattr(prerequisites.get_app(), constants.CompileVars.APP) - app_instance.frontend_exception_handler(Exception(stack)) + prerequisites.get_and_validate_app().app.frontend_exception_handler( + Exception(stack) + ) class UpdateVarsInternalState(State): @@ -2442,15 +2443,16 @@ class OnLoadInternalState(State): The list of events to queue for on load handling. """ # Do not app._compile()! It should be already compiled by now. - app = getattr(prerequisites.get_app(), constants.CompileVars.APP) - load_events = app.get_load_events(self.router.page.path) + load_events = prerequisites.get_and_validate_app().app.get_load_events( + self.router.page.path + ) if not load_events: self.is_hydrated = True return # Fast path for navigation with no on_load events defined. self.is_hydrated = False return [ *fix_events( - load_events, + cast(list[Union[EventSpec, EventHandler]], load_events), self.router.session.client_token, router_data=self.router_data, ), @@ -2609,7 +2611,7 @@ class StateProxy(wrapt.ObjectProxy): """ super().__init__(state_instance) # compile is not relevant to backend logic - self._self_app = getattr(prerequisites.get_app(), constants.CompileVars.APP) + self._self_app = prerequisites.get_and_validate_app().app self._self_substate_path = tuple(state_instance.get_full_name().split(".")) self._self_actx = None self._self_mutable = False @@ -3702,8 +3704,7 @@ def get_state_manager() -> StateManager: Returns: The state manager. """ - app = getattr(prerequisites.get_app(), constants.CompileVars.APP) - return app.state_manager + return prerequisites.get_and_validate_app().app.state_manager class MutableProxy(wrapt.ObjectProxy): diff --git a/reflex/utils/exec.py b/reflex/utils/exec.py index 621c4a608..6087818d9 100644 --- a/reflex/utils/exec.py +++ b/reflex/utils/exec.py @@ -240,6 +240,28 @@ def run_backend( run_uvicorn_backend(host, port, loglevel) +def get_reload_dirs() -> list[str]: + """Get the reload directories for the backend. + + Returns: + The reload directories for the backend. + """ + config = get_config() + reload_dirs = [config.app_name] + if config.app_module is not None and config.app_module.__file__: + module_path = Path(config.app_module.__file__).resolve().parent + while module_path.parent.name: + for parent_file in module_path.parent.iterdir(): + if parent_file == "__init__.py": + # go up a level to find dir without `__init__.py` + module_path = module_path.parent + break + else: + break + reload_dirs.append(str(module_path)) + return reload_dirs + + def run_uvicorn_backend(host, port, loglevel: LogLevel): """Run the backend in development mode using Uvicorn. @@ -256,7 +278,7 @@ def run_uvicorn_backend(host, port, loglevel: LogLevel): port=port, log_level=loglevel.value, reload=True, - reload_dirs=[get_config().app_name], + reload_dirs=get_reload_dirs(), ) @@ -281,7 +303,7 @@ def run_granian_backend(host, port, loglevel: LogLevel): interface=Interfaces.ASGI, log_level=LogLevels(loglevel.value), reload=True, - reload_paths=[Path(get_config().app_name)], + reload_paths=get_reload_dirs(), reload_ignore_dirs=[".web"], ).serve() except ImportError: diff --git a/reflex/utils/prerequisites.py b/reflex/utils/prerequisites.py index e450393c3..ac1eb58da 100644 --- a/reflex/utils/prerequisites.py +++ b/reflex/utils/prerequisites.py @@ -17,11 +17,12 @@ import stat import sys import tempfile import time +import typing import zipfile from datetime import datetime from pathlib import Path from types import ModuleType -from typing import Callable, List, Optional +from typing import Callable, List, NamedTuple, Optional import httpx import typer @@ -42,9 +43,19 @@ from reflex.utils.exceptions import ( from reflex.utils.format import format_library_name from reflex.utils.registry import _get_npm_registry +if typing.TYPE_CHECKING: + from reflex.app import App + CURRENTLY_INSTALLING_NODE = False +class AppInfo(NamedTuple): + """A tuple containing the app instance and module.""" + + app: App + module: ModuleType + + @dataclasses.dataclass(frozen=True) class Template: """A template for a Reflex app.""" @@ -291,8 +302,11 @@ def get_app(reload: bool = False) -> ModuleType: ) module = config.module sys.path.insert(0, str(Path.cwd())) - app = __import__(module, fromlist=(constants.CompileVars.APP,)) - + app = ( + __import__(module, fromlist=(constants.CompileVars.APP,)) + if not config.app_module + else config.app_module + ) if reload: from reflex.state import reload_state_module @@ -308,6 +322,29 @@ def get_app(reload: bool = False) -> ModuleType: raise +def get_and_validate_app(reload: bool = False) -> AppInfo: + """Get the app instance based on the default config and validate it. + + Args: + reload: Re-import the app module from disk + + Returns: + The app instance and the app module. + + Raises: + RuntimeError: If the app instance is not an instance of rx.App. + """ + from reflex.app import App + + app_module = get_app(reload=reload) + app = getattr(app_module, constants.CompileVars.APP) + if not isinstance(app, App): + raise RuntimeError( + "The app instance in the specified app_module_import in rxconfig must be an instance of rx.App." + ) + return AppInfo(app=app, module=app_module) + + def get_compiled_app(reload: bool = False, export: bool = False) -> ModuleType: """Get the app module based on the default config after first compiling it. @@ -318,8 +355,7 @@ def get_compiled_app(reload: bool = False, export: bool = False) -> ModuleType: Returns: The compiled app based on the default config. """ - app_module = get_app(reload=reload) - app = getattr(app_module, constants.CompileVars.APP) + app, app_module = get_and_validate_app(reload=reload) # For py3.9 compatibility when redis is used, we MUST add any decorator pages # before compiling the app in a thread to avoid event loop error (REF-2172). app._apply_decorated_pages() From 9c019a65d588bbd9592e99bb6a0abce8145f741f Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Thomas=20Brand=C3=A9ho?= Date: Mon, 20 Jan 2025 13:55:53 -0800 Subject: [PATCH 2/4] check for dict passed as children for component (#4656) --- reflex/components/component.py | 15 +++++++-------- reflex/utils/exceptions.py | 18 +++++++++++++++++- tests/units/components/test_component.py | 17 ++++++++++------- 3 files changed, 34 insertions(+), 16 deletions(-) diff --git a/reflex/components/component.py b/reflex/components/component.py index 8649b593d..cfe9a8dc2 100644 --- a/reflex/components/component.py +++ b/reflex/components/component.py @@ -740,22 +740,21 @@ class Component(BaseComponent, ABC): # Import here to avoid circular imports. from reflex.components.base.bare import Bare from reflex.components.base.fragment import Fragment - from reflex.utils.exceptions import ComponentTypeError + from reflex.utils.exceptions import ChildrenTypeError # Filter out None props props = {key: value for key, value in props.items() if value is not None} def validate_children(children): for child in children: - if isinstance(child, tuple): + if isinstance(child, (tuple, list)): validate_children(child) + # Make sure the child is a valid type. - if not types._isinstance(child, ComponentChild): - raise ComponentTypeError( - "Children of Reflex components must be other components, " - "state vars, or primitive Python types. " - f"Got child {child} of type {type(child)}.", - ) + if isinstance(child, dict) or not types._isinstance( + child, ComponentChild + ): + raise ChildrenTypeError(component=cls.__name__, child=child) # Validate all the children. validate_children(children) diff --git a/reflex/utils/exceptions.py b/reflex/utils/exceptions.py index 339abcda1..838d0a89d 100644 --- a/reflex/utils/exceptions.py +++ b/reflex/utils/exceptions.py @@ -1,6 +1,6 @@ """Custom Exceptions.""" -from typing import NoReturn +from typing import Any, NoReturn class ReflexError(Exception): @@ -31,6 +31,22 @@ class ComponentTypeError(ReflexError, TypeError): """Custom TypeError for component related errors.""" +class ChildrenTypeError(ComponentTypeError): + """Raised when the children prop of a component is not a valid type.""" + + def __init__(self, component: str, child: Any): + """Initialize the exception. + + Args: + component: The name of the component. + child: The child that caused the error. + """ + super().__init__( + f"Component {component} received child {child} of type {type(child)}. " + "Accepted types are other components, state vars, or primitive Python types (dict excluded)." + ) + + class EventHandlerTypeError(ReflexError, TypeError): """Custom TypeError for event handler related errors.""" diff --git a/tests/units/components/test_component.py b/tests/units/components/test_component.py index 674873b69..6396e4322 100644 --- a/tests/units/components/test_component.py +++ b/tests/units/components/test_component.py @@ -27,7 +27,7 @@ from reflex.event import ( from reflex.state import BaseState from reflex.style import Style from reflex.utils import imports -from reflex.utils.exceptions import EventFnArgMismatch +from reflex.utils.exceptions import ChildrenTypeError, EventFnArgMismatch from reflex.utils.imports import ImportDict, ImportVar, ParsedImportDict, parse_imports from reflex.vars import VarData from reflex.vars.base import LiteralVar, Var @@ -645,14 +645,17 @@ def test_create_filters_none_props(test_component): assert str(component.style["text-align"]) == '"center"' -@pytest.mark.parametrize("children", [((None,),), ("foo", ("bar", (None,)))]) +@pytest.mark.parametrize( + "children", + [ + ((None,),), + ("foo", ("bar", (None,))), + ({"foo": "bar"},), + ], +) def test_component_create_unallowed_types(children, test_component): - with pytest.raises(TypeError) as err: + with pytest.raises(ChildrenTypeError): test_component.create(*children) - assert ( - err.value.args[0] - == "Children of Reflex components must be other components, state vars, or primitive Python types. Got child None of type ." - ) @pytest.mark.parametrize( From 2855ed488719ff51492b8968c473f57345c5abee Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Thomas=20Brand=C3=A9ho?= Date: Mon, 20 Jan 2025 13:58:17 -0800 Subject: [PATCH 3/4] add some of the TRY rules (#4651) --- pyproject.toml | 4 +- reflex/app.py | 8 +--- reflex/components/component.py | 20 ++++---- reflex/custom_components/custom_components.py | 8 ++-- reflex/utils/prerequisites.py | 48 +++++++++++++------ reflex/utils/telemetry.py | 3 +- tests/integration/test_connection_banner.py | 3 +- 7 files changed, 56 insertions(+), 38 deletions(-) diff --git a/pyproject.toml b/pyproject.toml index eccf21230..d1ae1dcf0 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -86,8 +86,8 @@ build-backend = "poetry.core.masonry.api" target-version = "py39" output-format = "concise" lint.isort.split-on-trailing-comma = false -lint.select = ["B", "C4", "D", "E", "ERA", "F", "FURB", "I", "PERF", "PTH", "RUF", "SIM", "T", "W"] -lint.ignore = ["B008", "D205", "E501", "F403", "SIM115", "RUF006", "RUF012"] +lint.select = ["B", "C4", "D", "E", "ERA", "F", "FURB", "I", "PERF", "PTH", "RUF", "SIM", "T", "TRY", "W"] +lint.ignore = ["B008", "D205", "E501", "F403", "SIM115", "RUF006", "RUF012", "TRY0"] lint.pydocstyle.convention = "google" [tool.ruff.lint.per-file-ignores] diff --git a/reflex/app.py b/reflex/app.py index 08cb4314e..60be0d7dd 100644 --- a/reflex/app.py +++ b/reflex/app.py @@ -463,14 +463,8 @@ class App(MiddlewareMixin, LifespanMixin): Returns: The generated component. - - Raises: - exceptions.MatchTypeError: If the return types of match cases in rx.match are different. """ - try: - return component if isinstance(component, Component) else component() - except exceptions.MatchTypeError: - raise + return component if isinstance(component, Component) else component() def add_page( self, diff --git a/reflex/components/component.py b/reflex/components/component.py index cfe9a8dc2..ed90a0f24 100644 --- a/reflex/components/component.py +++ b/reflex/components/component.py @@ -429,20 +429,22 @@ class Component(BaseComponent, ABC): else: continue + def determine_key(value): + # Try to create a var from the value + key = value if isinstance(value, Var) else LiteralVar.create(value) + + # Check that the var type is not None. + if key is None: + raise TypeError + + return key + # Check whether the key is a component prop. if types._issubclass(field_type, Var): # Used to store the passed types if var type is a union. passed_types = None try: - # Try to create a var from the value. - if isinstance(value, Var): - kwargs[key] = value - else: - kwargs[key] = LiteralVar.create(value) - - # Check that the var type is not None. - if kwargs[key] is None: - raise TypeError + kwargs[key] = determine_key(value) expected_type = fields[key].outer_type_.__args__[0] # validate literal fields. diff --git a/reflex/custom_components/custom_components.py b/reflex/custom_components/custom_components.py index 4a169802f..8000e7f4c 100644 --- a/reflex/custom_components/custom_components.py +++ b/reflex/custom_components/custom_components.py @@ -421,12 +421,13 @@ def _run_commands_in_subprocess(cmds: list[str]) -> bool: console.debug(f"Running command: {' '.join(cmds)}") try: result = subprocess.run(cmds, capture_output=True, text=True, check=True) - console.debug(result.stdout) - return True except subprocess.CalledProcessError as cpe: console.error(cpe.stdout) console.error(cpe.stderr) return False + else: + console.debug(result.stdout) + return True def _make_pyi_files(): @@ -931,10 +932,11 @@ def _get_file_from_prompt_in_loop() -> Tuple[bytes, str] | None: file_extension = image_filepath.suffix try: image_file = image_filepath.read_bytes() - return image_file, file_extension except OSError as ose: console.error(f"Unable to read the {file_extension} file due to {ose}") raise typer.Exit(code=1) from ose + else: + return image_file, file_extension console.debug(f"File extension detected: {file_extension}") return None diff --git a/reflex/utils/prerequisites.py b/reflex/utils/prerequisites.py index ac1eb58da..4f9cc0c14 100644 --- a/reflex/utils/prerequisites.py +++ b/reflex/utils/prerequisites.py @@ -278,6 +278,22 @@ def windows_npm_escape_hatch() -> bool: return environment.REFLEX_USE_NPM.get() +def _check_app_name(config: Config): + """Check if the app name is set in the config. + + Args: + config: The config object. + + Raises: + RuntimeError: If the app name is not set in the config. + """ + if not config.app_name: + raise RuntimeError( + "Cannot get the app module because `app_name` is not set in rxconfig! " + "If this error occurs in a reflex test case, ensure that `get_app` is mocked." + ) + + def get_app(reload: bool = False) -> ModuleType: """Get the app module based on the default config. @@ -288,18 +304,16 @@ def get_app(reload: bool = False) -> ModuleType: The app based on the default config. Raises: - RuntimeError: If the app name is not set in the config. + Exception: If an error occurs while getting the app module. """ from reflex.utils import telemetry try: environment.RELOAD_CONFIG.set(reload) config = get_config() - if not config.app_name: - raise RuntimeError( - "Cannot get the app module because `app_name` is not set in rxconfig! " - "If this error occurs in a reflex test case, ensure that `get_app` is mocked." - ) + + _check_app_name(config) + module = config.module sys.path.insert(0, str(Path.cwd())) app = ( @@ -315,11 +329,11 @@ def get_app(reload: bool = False) -> ModuleType: # Reload the app module. importlib.reload(app) - - return app except Exception as ex: telemetry.send_error(ex, context="frontend") raise + else: + return app def get_and_validate_app(reload: bool = False) -> AppInfo: @@ -1189,11 +1203,12 @@ def ensure_reflex_installation_id() -> Optional[int]: if installation_id is None: installation_id = random.getrandbits(128) installation_id_file.write_text(str(installation_id)) - # If we get here, installation_id is definitely set - return installation_id except Exception as e: console.debug(f"Failed to ensure reflex installation id: {e}") return None + else: + # If we get here, installation_id is definitely set + return installation_id def initialize_reflex_user_directory(): @@ -1407,19 +1422,22 @@ def create_config_init_app_from_remote_template(app_name: str, template_url: str except OSError as ose: console.error(f"Failed to create temp directory for extracting zip: {ose}") raise typer.Exit(1) from ose + try: zipfile.ZipFile(zip_file_path).extractall(path=unzip_dir) # The zip file downloaded from github looks like: # repo-name-branch/**/*, so we need to remove the top level directory. - if len(subdirs := os.listdir(unzip_dir)) != 1: - console.error(f"Expected one directory in the zip, found {subdirs}") - raise typer.Exit(1) - template_dir = unzip_dir / subdirs[0] - console.debug(f"Template folder is located at {template_dir}") except Exception as uze: console.error(f"Failed to unzip the template: {uze}") raise typer.Exit(1) from uze + if len(subdirs := os.listdir(unzip_dir)) != 1: + console.error(f"Expected one directory in the zip, found {subdirs}") + raise typer.Exit(1) + + template_dir = unzip_dir / subdirs[0] + console.debug(f"Template folder is located at {template_dir}") + # Move the rxconfig file here first. path_ops.mv(str(template_dir / constants.Config.FILE), constants.Config.FILE) new_config = get_config(reload=True) diff --git a/reflex/utils/telemetry.py b/reflex/utils/telemetry.py index fc90932a6..8e9130b09 100644 --- a/reflex/utils/telemetry.py +++ b/reflex/utils/telemetry.py @@ -156,9 +156,10 @@ def _prepare_event(event: str, **kwargs) -> dict: def _send_event(event_data: dict) -> bool: try: httpx.post(POSTHOG_API_URL, json=event_data) - return True except Exception: return False + else: + return True def _send(event, telemetry_enabled, **kwargs): diff --git a/tests/integration/test_connection_banner.py b/tests/integration/test_connection_banner.py index 44187c8ba..18259fe3f 100644 --- a/tests/integration/test_connection_banner.py +++ b/tests/integration/test_connection_banner.py @@ -71,9 +71,10 @@ def has_error_modal(driver: WebDriver) -> bool: """ try: driver.find_element(By.XPATH, CONNECTION_ERROR_XPATH) - return True except NoSuchElementException: return False + else: + return True @pytest.mark.asyncio From 4dc106545b1f42535bf278ca4092a878515b614d Mon Sep 17 00:00:00 2001 From: Khaleel Al-Adhami Date: Mon, 20 Jan 2025 14:00:08 -0800 Subject: [PATCH 4/4] add defensive checks against data being funny (#4633) --- reflex/app.py | 30 ++++++++++++++++++++++++++++-- reflex/utils/exceptions.py | 4 ++++ 2 files changed, 32 insertions(+), 2 deletions(-) diff --git a/reflex/app.py b/reflex/app.py index 60be0d7dd..0d672e4c0 100644 --- a/reflex/app.py +++ b/reflex/app.py @@ -1557,10 +1557,36 @@ class EventNamespace(AsyncNamespace): Args: sid: The Socket.IO session id. data: The event data. + + Raises: + EventDeserializationError: If the event data is not a dictionary. """ fields = data - # Get the event. - event = Event(**{k: v for k, v in fields.items() if k in _EVENT_FIELDS}) + + if isinstance(fields, str): + console.warn( + "Received event data as a string. This generally should not happen and may indicate a bug." + f" Event data: {fields}" + ) + try: + fields = json.loads(fields) + except json.JSONDecodeError as ex: + raise exceptions.EventDeserializationError( + f"Failed to deserialize event data: {fields}." + ) from ex + + if not isinstance(fields, dict): + raise exceptions.EventDeserializationError( + f"Event data must be a dictionary, but received {fields} of type {type(fields)}." + ) + + try: + # Get the event. + event = Event(**{k: v for k, v in fields.items() if k in _EVENT_FIELDS}) + except (TypeError, ValueError) as ex: + raise exceptions.EventDeserializationError( + f"Failed to deserialize event data: {fields}." + ) from ex self.token_to_sid[event.token] = sid self.sid_to_token[sid] = event.token diff --git a/reflex/utils/exceptions.py b/reflex/utils/exceptions.py index 838d0a89d..37a68e420 100644 --- a/reflex/utils/exceptions.py +++ b/reflex/utils/exceptions.py @@ -187,6 +187,10 @@ class SystemPackageMissingError(ReflexError): """Raised when a system package is missing.""" +class EventDeserializationError(ReflexError, ValueError): + """Raised when an event cannot be deserialized.""" + + def raise_system_package_missing_error(package: str) -> NoReturn: """Raise a SystemPackageMissingError.