From 3ef7106e0e3b1d0980cc516c1450ebd02ae0dd54 Mon Sep 17 00:00:00 2001 From: benedikt-bartscher <31854409+benedikt-bartscher@users.noreply.github.com> Date: Tue, 10 Dec 2024 21:35:46 +0100 Subject: [PATCH 1/2] style: prefer type() over __class__ (#4512) --- reflex/app.py | 2 +- reflex/components/component.py | 4 ++-- reflex/components/core/cond.py | 4 ++-- reflex/components/radix/themes/base.py | 4 +--- reflex/event.py | 4 ++-- reflex/state.py | 4 ++-- reflex/vars/base.py | 6 +++--- reflex/vars/number.py | 4 ++-- reflex/vars/object.py | 2 +- reflex/vars/sequence.py | 2 +- tests/integration/test_state_inheritance.py | 2 +- tests/units/test_model.py | 4 ++-- tests/units/test_state.py | 2 +- 13 files changed, 21 insertions(+), 23 deletions(-) diff --git a/reflex/app.py b/reflex/app.py index cdf21aa35..eb453bd0b 100644 --- a/reflex/app.py +++ b/reflex/app.py @@ -1157,7 +1157,7 @@ class App(MiddlewareMixin, LifespanMixin): if hasattr(handler_fn, "__name__"): _fn_name = handler_fn.__name__ else: - _fn_name = handler_fn.__class__.__name__ + _fn_name = type(handler_fn).__name__ if isinstance(handler_fn, functools.partial): raise ValueError( diff --git a/reflex/components/component.py b/reflex/components/component.py index 5c1c4941d..75a821ac8 100644 --- a/reflex/components/component.py +++ b/reflex/components/component.py @@ -161,7 +161,7 @@ class ComponentNamespace(SimpleNamespace): Returns: The hash of the namespace. """ - return hash(self.__class__.__name__) + return hash(type(self).__name__) def evaluate_style_namespaces(style: ComponentStyle) -> dict: @@ -2565,7 +2565,7 @@ class LiteralComponentVar(CachedVarOperation, LiteralVar, ComponentVar): Returns: The hash of the var. """ - return hash((self.__class__.__name__, self._js_expr)) + return hash((type(self).__name__, self._js_expr)) @classmethod def create( diff --git a/reflex/components/core/cond.py b/reflex/components/core/cond.py index b75667890..488990f54 100644 --- a/reflex/components/core/cond.py +++ b/reflex/components/core/cond.py @@ -49,9 +49,9 @@ class Cond(MemoizationLeaf): The conditional component. """ # Wrap everything in fragments. - if comp1.__class__.__name__ != "Fragment": + if type(comp1).__name__ != "Fragment": comp1 = Fragment.create(comp1) - if comp2 is None or comp2.__class__.__name__ != "Fragment": + if comp2 is None or type(comp2).__name__ != "Fragment": comp2 = Fragment.create(comp2) if comp2 else Fragment.create() return Fragment.create( cls( diff --git a/reflex/components/radix/themes/base.py b/reflex/components/radix/themes/base.py index 65a9ae835..9a8a2d7c0 100644 --- a/reflex/components/radix/themes/base.py +++ b/reflex/components/radix/themes/base.py @@ -139,9 +139,7 @@ class RadixThemesComponent(Component): component = super().create(*children, **props) if component.library is None: component.library = RadixThemesComponent.__fields__["library"].default - component.alias = "RadixThemes" + ( - component.tag or component.__class__.__name__ - ) + component.alias = "RadixThemes" + (component.tag or type(component).__name__) # value = props.get("value") # if value is not None and component.alias == "RadixThemesSelect.Root": # lv = LiteralVar.create(value) diff --git a/reflex/event.py b/reflex/event.py index 3022de556..9807b696b 100644 --- a/reflex/event.py +++ b/reflex/event.py @@ -1556,7 +1556,7 @@ class LiteralEventVar(VarOperationCall, LiteralVar, EventVar): Returns: The hash of the var. """ - return hash((self.__class__.__name__, self._js_expr)) + return hash((type(self).__name__, self._js_expr)) @classmethod def create( @@ -1620,7 +1620,7 @@ class LiteralEventChainVar(ArgsFunctionOperationBuilder, LiteralVar, EventChainV Returns: The hash of the var. """ - return hash((self.__class__.__name__, self._js_expr)) + return hash((type(self).__name__, self._js_expr)) @classmethod def create( diff --git a/reflex/state.py b/reflex/state.py index 5e132ef65..09e76159c 100644 --- a/reflex/state.py +++ b/reflex/state.py @@ -438,7 +438,7 @@ class BaseState(Base, ABC, extra=pydantic.Extra.allow): Returns: The string representation of the state. """ - return f"{self.__class__.__name__}({self.dict()})" + return f"{type(self).__name__}({self.dict()})" @classmethod def _get_computed_vars(cls) -> list[ComputedVar]: @@ -3618,7 +3618,7 @@ class MutableProxy(wrapt.ObjectProxy): Returns: The representation of the wrapped object. """ - return f"{self.__class__.__name__}({self.__wrapped__})" + return f"{type(self).__name__}({self.__wrapped__})" def _mark_dirty( self, diff --git a/reflex/vars/base.py b/reflex/vars/base.py index ea78d1a89..941a9d81a 100644 --- a/reflex/vars/base.py +++ b/reflex/vars/base.py @@ -1569,7 +1569,7 @@ class CachedVarOperation: if name == "_js_expr": return self._cached_var_name - parent_classes = inspect.getmro(self.__class__) + parent_classes = inspect.getmro(type(self)) next_class = parent_classes[parent_classes.index(CachedVarOperation) + 1] @@ -1611,7 +1611,7 @@ class CachedVarOperation: """ return hash( ( - self.__class__.__name__, + type(self).__name__, *[ getattr(self, field.name) for field in dataclasses.fields(self) # type: ignore @@ -1733,7 +1733,7 @@ class CallableVar(Var): Returns: The hash of the object. """ - return hash((self.__class__.__name__, self.original_var)) + return hash((type(self).__name__, self.original_var)) RETURN_TYPE = TypeVar("RETURN_TYPE") diff --git a/reflex/vars/number.py b/reflex/vars/number.py index a762796e2..d14f9695b 100644 --- a/reflex/vars/number.py +++ b/reflex/vars/number.py @@ -1012,7 +1012,7 @@ class LiteralNumberVar(LiteralVar, NumberVar): Returns: int: The hash value of the object. """ - return hash((self.__class__.__name__, self._var_value)) + return hash((type(self).__name__, self._var_value)) @classmethod def create(cls, value: float | int, _var_data: VarData | None = None): @@ -1064,7 +1064,7 @@ class LiteralBooleanVar(LiteralVar, BooleanVar): Returns: int: The hash value of the object. """ - return hash((self.__class__.__name__, self._var_value)) + return hash((type(self).__name__, self._var_value)) @classmethod def create(cls, value: bool, _var_data: VarData | None = None): diff --git a/reflex/vars/object.py b/reflex/vars/object.py index 032fc8058..5de431f5a 100644 --- a/reflex/vars/object.py +++ b/reflex/vars/object.py @@ -362,7 +362,7 @@ class LiteralObjectVar(CachedVarOperation, ObjectVar[OBJECT_TYPE], LiteralVar): Returns: The hash of the var. """ - return hash((self.__class__.__name__, self._js_expr)) + return hash((type(self).__name__, self._js_expr)) @cached_property_no_lock def _cached_get_all_var_data(self) -> VarData | None: diff --git a/reflex/vars/sequence.py b/reflex/vars/sequence.py index f5639685f..a02645309 100644 --- a/reflex/vars/sequence.py +++ b/reflex/vars/sequence.py @@ -667,7 +667,7 @@ class LiteralStringVar(LiteralVar, StringVar[str]): Returns: The hash of the var. """ - return hash((self.__class__.__name__, self._var_value)) + return hash((type(self).__name__, self._var_value)) def json(self) -> str: """Get the JSON representation of the var. diff --git a/tests/integration/test_state_inheritance.py b/tests/integration/test_state_inheritance.py index 81512a67a..6b93a2ed7 100644 --- a/tests/integration/test_state_inheritance.py +++ b/tests/integration/test_state_inheritance.py @@ -73,7 +73,7 @@ def StateInheritance(): def on_click_other_mixin(self): self.other_mixin_clicks += 1 self.other_mixin = ( - f"{self.__class__.__name__}.clicked.{self.other_mixin_clicks}" + f"{type(self).__name__}.clicked.{self.other_mixin_clicks}" ) class Base1(Mixin, rx.State): diff --git a/tests/units/test_model.py b/tests/units/test_model.py index ac8187e03..0a83f39ec 100644 --- a/tests/units/test_model.py +++ b/tests/units/test_model.py @@ -46,7 +46,7 @@ def test_default_primary_key(model_default_primary: Model): Args: model_default_primary: Fixture. """ - assert "id" in model_default_primary.__class__.__fields__ + assert "id" in type(model_default_primary).__fields__ def test_custom_primary_key(model_custom_primary: Model): @@ -55,7 +55,7 @@ def test_custom_primary_key(model_custom_primary: Model): Args: model_custom_primary: Fixture. """ - assert "id" not in model_custom_primary.__class__.__fields__ + assert "id" not in type(model_custom_primary).__fields__ @pytest.mark.filterwarnings( diff --git a/tests/units/test_state.py b/tests/units/test_state.py index 9dc2bc4e2..5c062ccb2 100644 --- a/tests/units/test_state.py +++ b/tests/units/test_state.py @@ -1698,7 +1698,7 @@ async def test_state_manager_modify_state( assert not state_manager._states_locks[token].locked() # separate instances should NOT share locks - sm2 = state_manager.__class__(state=TestState) + sm2 = type(state_manager)(state=TestState) assert sm2._state_manager_lock is state_manager._state_manager_lock assert not sm2._states_locks if state_manager._states_locks: From 4922f7ba055ca579e9f2d05248e767a3190dee3a Mon Sep 17 00:00:00 2001 From: Masen Furer Date: Tue, 10 Dec 2024 12:37:09 -0800 Subject: [PATCH 2/2] Reuse the sqlalchemy engine once it's created (#4493) * Reuse the sqlalchemy engine once it's created * Implement `rx.asession` for async database support Requires setting `async_db_url` to the same database as `db_url`, except using an async driver; this will vary by database. * resolve the url first, so the key into _ENGINE is correct * Ping db connections before returning them from pool Move connect engine kwargs to a separate function * the param is `echo` * sanity check that config db_url and async_db_url are the same throw a warning if the part following the `://` differs between these two * create_async_engine: use sqlalchemy async API update types * redact ASYNC_DB_URL similarly to DB_URL when overridden in config * update rx.asession docstring * use async_sessionmaker * Redact sensitive env vars instead of hiding them --- reflex/__init__.py | 2 +- reflex/__init__.pyi | 1 + reflex/config.py | 24 ++++++--- reflex/model.py | 126 ++++++++++++++++++++++++++++++++++++++++++-- 4 files changed, 141 insertions(+), 12 deletions(-) diff --git a/reflex/__init__.py b/reflex/__init__.py index 562524416..0be568b1a 100644 --- a/reflex/__init__.py +++ b/reflex/__init__.py @@ -331,7 +331,7 @@ _MAPPING: dict = { "SessionStorage", ], "middleware": ["middleware", "Middleware"], - "model": ["session", "Model"], + "model": ["asession", "session", "Model"], "state": [ "var", "ComponentState", diff --git a/reflex/__init__.pyi b/reflex/__init__.pyi index 6f61435e6..6bdfa11a0 100644 --- a/reflex/__init__.pyi +++ b/reflex/__init__.pyi @@ -186,6 +186,7 @@ from .istate.wrappers import get_state as get_state from .middleware import Middleware as Middleware from .middleware import middleware as middleware from .model import Model as Model +from .model import asession as asession from .model import session as session from .page import page as page from .state import ComponentState as ComponentState diff --git a/reflex/config.py b/reflex/config.py index 1d952f4f0..c40abcb39 100644 --- a/reflex/config.py +++ b/reflex/config.py @@ -512,6 +512,9 @@ class EnvironmentVariables: # Whether to print the SQL queries if the log level is INFO or lower. SQLALCHEMY_ECHO: EnvVar[bool] = env_var(False) + # Whether to check db connections before using them. + SQLALCHEMY_POOL_PRE_PING: EnvVar[bool] = env_var(True) + # Whether to ignore the redis config error. Some redis servers only allow out-of-band configuration. REFLEX_IGNORE_REDIS_CONFIG_ERROR: EnvVar[bool] = env_var(False) @@ -568,6 +571,10 @@ class EnvironmentVariables: environment = EnvironmentVariables() +# These vars are not logged because they may contain sensitive information. +_sensitive_env_vars = {"DB_URL", "ASYNC_DB_URL", "REDIS_URL"} + + class Config(Base): """The config defines runtime settings for the app. @@ -621,6 +628,9 @@ class Config(Base): # The database url used by rx.Model. db_url: Optional[str] = "sqlite:///reflex.db" + # The async database url used by rx.Model. + async_db_url: Optional[str] = None + # The redis url redis_url: Optional[str] = None @@ -748,18 +758,20 @@ class Config(Base): # If the env var is set, override the config value. if env_var is not None: - if key.upper() != "DB_URL": - console.info( - f"Overriding config value {key} with env var {key.upper()}={env_var}", - dedupe=True, - ) - # Interpret the value. value = interpret_env_var_value(env_var, field.outer_type_, field.name) # Set the value. updated_values[key] = value + if key.upper() in _sensitive_env_vars: + env_var = "***" + + console.info( + f"Overriding config value {key} with env var {key.upper()}={env_var}", + dedupe=True, + ) + return updated_values def get_event_namespace(self) -> str: diff --git a/reflex/model.py b/reflex/model.py index 4b070ec67..03b1cc828 100644 --- a/reflex/model.py +++ b/reflex/model.py @@ -2,6 +2,7 @@ from __future__ import annotations +import re from collections import defaultdict from typing import Any, ClassVar, Optional, Type, Union @@ -14,6 +15,7 @@ import alembic.script import alembic.util import sqlalchemy import sqlalchemy.exc +import sqlalchemy.ext.asyncio import sqlalchemy.orm from reflex.base import Base @@ -21,6 +23,48 @@ from reflex.config import environment, get_config from reflex.utils import console from reflex.utils.compat import sqlmodel, sqlmodel_field_has_primary_key +_ENGINE: dict[str, sqlalchemy.engine.Engine] = {} +_ASYNC_ENGINE: dict[str, sqlalchemy.ext.asyncio.AsyncEngine] = {} +_AsyncSessionLocal: dict[str | None, sqlalchemy.ext.asyncio.async_sessionmaker] = {} + +# Import AsyncSession _after_ reflex.utils.compat +from sqlmodel.ext.asyncio.session import AsyncSession # noqa: E402 + + +def _safe_db_url_for_logging(url: str) -> str: + """Remove username and password from the database URL for logging. + + Args: + url: The database URL. + + Returns: + The database URL with the username and password removed. + """ + return re.sub(r"://[^@]+@", "://:@", url) + + +def get_engine_args(url: str | None = None) -> dict[str, Any]: + """Get the database engine arguments. + + Args: + url: The database url. + + Returns: + The database engine arguments as a dict. + """ + kwargs: dict[str, Any] = dict( + # Print the SQL queries if the log level is INFO or lower. + echo=environment.SQLALCHEMY_ECHO.get(), + # Check connections before returning them. + pool_pre_ping=environment.SQLALCHEMY_POOL_PRE_PING.get(), + ) + conf = get_config() + url = url or conf.db_url + if url is not None and url.startswith("sqlite"): + # Needed for the admin dash on sqlite. + kwargs["connect_args"] = {"check_same_thread": False} + return kwargs + def get_engine(url: str | None = None) -> sqlalchemy.engine.Engine: """Get the database engine. @@ -38,15 +82,62 @@ def get_engine(url: str | None = None) -> sqlalchemy.engine.Engine: url = url or conf.db_url if url is None: raise ValueError("No database url configured") + + global _ENGINE + if url in _ENGINE: + return _ENGINE[url] + if not environment.ALEMBIC_CONFIG.get().exists(): console.warn( "Database is not initialized, run [bold]reflex db init[/bold] first." ) - # Print the SQL queries if the log level is INFO or lower. - echo_db_query = environment.SQLALCHEMY_ECHO.get() - # Needed for the admin dash on sqlite. - connect_args = {"check_same_thread": False} if url.startswith("sqlite") else {} - return sqlmodel.create_engine(url, echo=echo_db_query, connect_args=connect_args) + _ENGINE[url] = sqlmodel.create_engine( + url, + **get_engine_args(url), + ) + return _ENGINE[url] + + +def get_async_engine(url: str | None) -> sqlalchemy.ext.asyncio.AsyncEngine: + """Get the async database engine. + + Args: + url: The database url. + + Returns: + The async database engine. + + Raises: + ValueError: If the async database url is None. + """ + if url is None: + conf = get_config() + url = conf.async_db_url + if url is not None and conf.db_url is not None: + async_db_url_tail = url.partition("://")[2] + db_url_tail = conf.db_url.partition("://")[2] + if async_db_url_tail != db_url_tail: + console.warn( + f"async_db_url `{_safe_db_url_for_logging(url)}` " + "should reference the same database as " + f"db_url `{_safe_db_url_for_logging(conf.db_url)}`." + ) + if url is None: + raise ValueError("No async database url configured") + + global _ASYNC_ENGINE + if url in _ASYNC_ENGINE: + return _ASYNC_ENGINE[url] + + if not environment.ALEMBIC_CONFIG.get().exists(): + console.warn( + "Database is not initialized, run [bold]reflex db init[/bold] first." + ) + _ASYNC_ENGINE[url] = sqlalchemy.ext.asyncio.create_async_engine( + url, + **get_engine_args(url), + ) + return _ASYNC_ENGINE[url] async def get_db_status() -> bool: @@ -425,6 +516,31 @@ def session(url: str | None = None) -> sqlmodel.Session: return sqlmodel.Session(get_engine(url)) +def asession(url: str | None = None) -> AsyncSession: + """Get an async sqlmodel session to interact with the database. + + async with rx.asession() as asession: + ... + + Most operations against the `asession` must be awaited. + + Args: + url: The database url. + + Returns: + An async database session. + """ + global _AsyncSessionLocal + if url not in _AsyncSessionLocal: + _AsyncSessionLocal[url] = sqlalchemy.ext.asyncio.async_sessionmaker( + bind=get_async_engine(url), + class_=AsyncSession, + autocommit=False, + autoflush=False, + ) + return _AsyncSessionLocal[url]() + + def sqla_session(url: str | None = None) -> sqlalchemy.orm.Session: """Get a bare sqlalchemy session to interact with the database.