Merge branch 'main' into lendemor/add_ERA_rules

This commit is contained in:
Lendemor 2024-12-10 21:43:55 +01:00
commit b2fa14d251
17 changed files with 162 additions and 35 deletions

View File

@ -331,7 +331,7 @@ _MAPPING: dict = {
"SessionStorage",
],
"middleware": ["middleware", "Middleware"],
"model": ["session", "Model"],
"model": ["asession", "session", "Model"],
"state": [
"var",
"ComponentState",

View File

@ -186,6 +186,7 @@ from .istate.wrappers import get_state as get_state
from .middleware import Middleware as Middleware
from .middleware import middleware as middleware
from .model import Model as Model
from .model import asession as asession
from .model import session as session
from .page import page as page
from .state import ComponentState as ComponentState

View File

@ -1156,7 +1156,7 @@ class App(MiddlewareMixin, LifespanMixin):
if hasattr(handler_fn, "__name__"):
_fn_name = handler_fn.__name__
else:
_fn_name = handler_fn.__class__.__name__
_fn_name = type(handler_fn).__name__
if isinstance(handler_fn, functools.partial):
raise ValueError(

View File

@ -161,7 +161,7 @@ class ComponentNamespace(SimpleNamespace):
Returns:
The hash of the namespace.
"""
return hash(self.__class__.__name__)
return hash(type(self).__name__)
def evaluate_style_namespaces(style: ComponentStyle) -> dict:
@ -2565,7 +2565,7 @@ class LiteralComponentVar(CachedVarOperation, LiteralVar, ComponentVar):
Returns:
The hash of the var.
"""
return hash((self.__class__.__name__, self._js_expr))
return hash((type(self).__name__, self._js_expr))
@classmethod
def create(

View File

@ -49,9 +49,9 @@ class Cond(MemoizationLeaf):
The conditional component.
"""
# Wrap everything in fragments.
if comp1.__class__.__name__ != "Fragment":
if type(comp1).__name__ != "Fragment":
comp1 = Fragment.create(comp1)
if comp2 is None or comp2.__class__.__name__ != "Fragment":
if comp2 is None or type(comp2).__name__ != "Fragment":
comp2 = Fragment.create(comp2) if comp2 else Fragment.create()
return Fragment.create(
cls(

View File

@ -139,9 +139,7 @@ class RadixThemesComponent(Component):
component = super().create(*children, **props)
if component.library is None:
component.library = RadixThemesComponent.__fields__["library"].default
component.alias = "RadixThemes" + (
component.tag or component.__class__.__name__
)
component.alias = "RadixThemes" + (component.tag or type(component).__name__)
return component
@staticmethod

View File

@ -512,6 +512,9 @@ class EnvironmentVariables:
# Whether to print the SQL queries if the log level is INFO or lower.
SQLALCHEMY_ECHO: EnvVar[bool] = env_var(False)
# Whether to check db connections before using them.
SQLALCHEMY_POOL_PRE_PING: EnvVar[bool] = env_var(True)
# Whether to ignore the redis config error. Some redis servers only allow out-of-band configuration.
REFLEX_IGNORE_REDIS_CONFIG_ERROR: EnvVar[bool] = env_var(False)
@ -568,6 +571,10 @@ class EnvironmentVariables:
environment = EnvironmentVariables()
# These vars are not logged because they may contain sensitive information.
_sensitive_env_vars = {"DB_URL", "ASYNC_DB_URL", "REDIS_URL"}
class Config(Base):
"""The config defines runtime settings for the app.
@ -621,6 +628,9 @@ class Config(Base):
# The database url used by rx.Model.
db_url: Optional[str] = "sqlite:///reflex.db"
# The async database url used by rx.Model.
async_db_url: Optional[str] = None
# The redis url
redis_url: Optional[str] = None
@ -748,18 +758,20 @@ class Config(Base):
# If the env var is set, override the config value.
if env_var is not None:
if key.upper() != "DB_URL":
console.info(
f"Overriding config value {key} with env var {key.upper()}={env_var}",
dedupe=True,
)
# Interpret the value.
value = interpret_env_var_value(env_var, field.outer_type_, field.name)
# Set the value.
updated_values[key] = value
if key.upper() in _sensitive_env_vars:
env_var = "***"
console.info(
f"Overriding config value {key} with env var {key.upper()}={env_var}",
dedupe=True,
)
return updated_values
def get_event_namespace(self) -> str:

View File

@ -1556,7 +1556,7 @@ class LiteralEventVar(VarOperationCall, LiteralVar, EventVar):
Returns:
The hash of the var.
"""
return hash((self.__class__.__name__, self._js_expr))
return hash((type(self).__name__, self._js_expr))
@classmethod
def create(
@ -1620,7 +1620,7 @@ class LiteralEventChainVar(ArgsFunctionOperationBuilder, LiteralVar, EventChainV
Returns:
The hash of the var.
"""
return hash((self.__class__.__name__, self._js_expr))
return hash((type(self).__name__, self._js_expr))
@classmethod
def create(

View File

@ -2,6 +2,7 @@
from __future__ import annotations
import re
from collections import defaultdict
from typing import Any, ClassVar, Optional, Type, Union
@ -14,6 +15,7 @@ import alembic.script
import alembic.util
import sqlalchemy
import sqlalchemy.exc
import sqlalchemy.ext.asyncio
import sqlalchemy.orm
from reflex.base import Base
@ -21,6 +23,48 @@ from reflex.config import environment, get_config
from reflex.utils import console
from reflex.utils.compat import sqlmodel, sqlmodel_field_has_primary_key
_ENGINE: dict[str, sqlalchemy.engine.Engine] = {}
_ASYNC_ENGINE: dict[str, sqlalchemy.ext.asyncio.AsyncEngine] = {}
_AsyncSessionLocal: dict[str | None, sqlalchemy.ext.asyncio.async_sessionmaker] = {}
# Import AsyncSession _after_ reflex.utils.compat
from sqlmodel.ext.asyncio.session import AsyncSession # noqa: E402
def _safe_db_url_for_logging(url: str) -> str:
"""Remove username and password from the database URL for logging.
Args:
url: The database URL.
Returns:
The database URL with the username and password removed.
"""
return re.sub(r"://[^@]+@", "://<username>:<password>@", url)
def get_engine_args(url: str | None = None) -> dict[str, Any]:
"""Get the database engine arguments.
Args:
url: The database url.
Returns:
The database engine arguments as a dict.
"""
kwargs: dict[str, Any] = dict(
# Print the SQL queries if the log level is INFO or lower.
echo=environment.SQLALCHEMY_ECHO.get(),
# Check connections before returning them.
pool_pre_ping=environment.SQLALCHEMY_POOL_PRE_PING.get(),
)
conf = get_config()
url = url or conf.db_url
if url is not None and url.startswith("sqlite"):
# Needed for the admin dash on sqlite.
kwargs["connect_args"] = {"check_same_thread": False}
return kwargs
def get_engine(url: str | None = None) -> sqlalchemy.engine.Engine:
"""Get the database engine.
@ -38,15 +82,62 @@ def get_engine(url: str | None = None) -> sqlalchemy.engine.Engine:
url = url or conf.db_url
if url is None:
raise ValueError("No database url configured")
global _ENGINE
if url in _ENGINE:
return _ENGINE[url]
if not environment.ALEMBIC_CONFIG.get().exists():
console.warn(
"Database is not initialized, run [bold]reflex db init[/bold] first."
)
# Print the SQL queries if the log level is INFO or lower.
echo_db_query = environment.SQLALCHEMY_ECHO.get()
# Needed for the admin dash on sqlite.
connect_args = {"check_same_thread": False} if url.startswith("sqlite") else {}
return sqlmodel.create_engine(url, echo=echo_db_query, connect_args=connect_args)
_ENGINE[url] = sqlmodel.create_engine(
url,
**get_engine_args(url),
)
return _ENGINE[url]
def get_async_engine(url: str | None) -> sqlalchemy.ext.asyncio.AsyncEngine:
"""Get the async database engine.
Args:
url: The database url.
Returns:
The async database engine.
Raises:
ValueError: If the async database url is None.
"""
if url is None:
conf = get_config()
url = conf.async_db_url
if url is not None and conf.db_url is not None:
async_db_url_tail = url.partition("://")[2]
db_url_tail = conf.db_url.partition("://")[2]
if async_db_url_tail != db_url_tail:
console.warn(
f"async_db_url `{_safe_db_url_for_logging(url)}` "
"should reference the same database as "
f"db_url `{_safe_db_url_for_logging(conf.db_url)}`."
)
if url is None:
raise ValueError("No async database url configured")
global _ASYNC_ENGINE
if url in _ASYNC_ENGINE:
return _ASYNC_ENGINE[url]
if not environment.ALEMBIC_CONFIG.get().exists():
console.warn(
"Database is not initialized, run [bold]reflex db init[/bold] first."
)
_ASYNC_ENGINE[url] = sqlalchemy.ext.asyncio.create_async_engine(
url,
**get_engine_args(url),
)
return _ASYNC_ENGINE[url]
async def get_db_status() -> bool:
@ -425,6 +516,31 @@ def session(url: str | None = None) -> sqlmodel.Session:
return sqlmodel.Session(get_engine(url))
def asession(url: str | None = None) -> AsyncSession:
"""Get an async sqlmodel session to interact with the database.
async with rx.asession() as asession:
...
Most operations against the `asession` must be awaited.
Args:
url: The database url.
Returns:
An async database session.
"""
global _AsyncSessionLocal
if url not in _AsyncSessionLocal:
_AsyncSessionLocal[url] = sqlalchemy.ext.asyncio.async_sessionmaker(
bind=get_async_engine(url),
class_=AsyncSession,
autocommit=False,
autoflush=False,
)
return _AsyncSessionLocal[url]()
def sqla_session(url: str | None = None) -> sqlalchemy.orm.Session:
"""Get a bare sqlalchemy session to interact with the database.

View File

@ -438,7 +438,7 @@ class BaseState(Base, ABC, extra=pydantic.Extra.allow):
Returns:
The string representation of the state.
"""
return f"{self.__class__.__name__}({self.dict()})"
return f"{type(self).__name__}({self.dict()})"
@classmethod
def _get_computed_vars(cls) -> list[ComputedVar]:
@ -3618,7 +3618,7 @@ class MutableProxy(wrapt.ObjectProxy):
Returns:
The representation of the wrapped object.
"""
return f"{self.__class__.__name__}({self.__wrapped__})"
return f"{type(self).__name__}({self.__wrapped__})"
def _mark_dirty(
self,

View File

@ -1569,7 +1569,7 @@ class CachedVarOperation:
if name == "_js_expr":
return self._cached_var_name
parent_classes = inspect.getmro(self.__class__)
parent_classes = inspect.getmro(type(self))
next_class = parent_classes[parent_classes.index(CachedVarOperation) + 1]
@ -1611,7 +1611,7 @@ class CachedVarOperation:
"""
return hash(
(
self.__class__.__name__,
type(self).__name__,
*[
getattr(self, field.name)
for field in dataclasses.fields(self) # type: ignore
@ -1733,7 +1733,7 @@ class CallableVar(Var):
Returns:
The hash of the object.
"""
return hash((self.__class__.__name__, self.original_var))
return hash((type(self).__name__, self.original_var))
RETURN_TYPE = TypeVar("RETURN_TYPE")

View File

@ -1012,7 +1012,7 @@ class LiteralNumberVar(LiteralVar, NumberVar):
Returns:
int: The hash value of the object.
"""
return hash((self.__class__.__name__, self._var_value))
return hash((type(self).__name__, self._var_value))
@classmethod
def create(cls, value: float | int, _var_data: VarData | None = None):
@ -1064,7 +1064,7 @@ class LiteralBooleanVar(LiteralVar, BooleanVar):
Returns:
int: The hash value of the object.
"""
return hash((self.__class__.__name__, self._var_value))
return hash((type(self).__name__, self._var_value))
@classmethod
def create(cls, value: bool, _var_data: VarData | None = None):

View File

@ -362,7 +362,7 @@ class LiteralObjectVar(CachedVarOperation, ObjectVar[OBJECT_TYPE], LiteralVar):
Returns:
The hash of the var.
"""
return hash((self.__class__.__name__, self._js_expr))
return hash((type(self).__name__, self._js_expr))
@cached_property_no_lock
def _cached_get_all_var_data(self) -> VarData | None:

View File

@ -667,7 +667,7 @@ class LiteralStringVar(LiteralVar, StringVar[str]):
Returns:
The hash of the var.
"""
return hash((self.__class__.__name__, self._var_value))
return hash((type(self).__name__, self._var_value))
def json(self) -> str:
"""Get the JSON representation of the var.

View File

@ -73,7 +73,7 @@ def StateInheritance():
def on_click_other_mixin(self):
self.other_mixin_clicks += 1
self.other_mixin = (
f"{self.__class__.__name__}.clicked.{self.other_mixin_clicks}"
f"{type(self).__name__}.clicked.{self.other_mixin_clicks}"
)
class Base1(Mixin, rx.State):

View File

@ -46,7 +46,7 @@ def test_default_primary_key(model_default_primary: Model):
Args:
model_default_primary: Fixture.
"""
assert "id" in model_default_primary.__class__.__fields__
assert "id" in type(model_default_primary).__fields__
def test_custom_primary_key(model_custom_primary: Model):
@ -55,7 +55,7 @@ def test_custom_primary_key(model_custom_primary: Model):
Args:
model_custom_primary: Fixture.
"""
assert "id" not in model_custom_primary.__class__.__fields__
assert "id" not in type(model_custom_primary).__fields__
@pytest.mark.filterwarnings(

View File

@ -1697,7 +1697,7 @@ async def test_state_manager_modify_state(
assert not state_manager._states_locks[token].locked()
# separate instances should NOT share locks
sm2 = state_manager.__class__(state=TestState)
sm2 = type(state_manager)(state=TestState)
assert sm2._state_manager_lock is state_manager._state_manager_lock
assert not sm2._states_locks
if state_manager._states_locks: