Merge branch 'main' into lendemor/add_ERA_rules
This commit is contained in:
commit
b2fa14d251
@ -331,7 +331,7 @@ _MAPPING: dict = {
|
|||||||
"SessionStorage",
|
"SessionStorage",
|
||||||
],
|
],
|
||||||
"middleware": ["middleware", "Middleware"],
|
"middleware": ["middleware", "Middleware"],
|
||||||
"model": ["session", "Model"],
|
"model": ["asession", "session", "Model"],
|
||||||
"state": [
|
"state": [
|
||||||
"var",
|
"var",
|
||||||
"ComponentState",
|
"ComponentState",
|
||||||
|
@ -186,6 +186,7 @@ from .istate.wrappers import get_state as get_state
|
|||||||
from .middleware import Middleware as Middleware
|
from .middleware import Middleware as Middleware
|
||||||
from .middleware import middleware as middleware
|
from .middleware import middleware as middleware
|
||||||
from .model import Model as Model
|
from .model import Model as Model
|
||||||
|
from .model import asession as asession
|
||||||
from .model import session as session
|
from .model import session as session
|
||||||
from .page import page as page
|
from .page import page as page
|
||||||
from .state import ComponentState as ComponentState
|
from .state import ComponentState as ComponentState
|
||||||
|
@ -1156,7 +1156,7 @@ class App(MiddlewareMixin, LifespanMixin):
|
|||||||
if hasattr(handler_fn, "__name__"):
|
if hasattr(handler_fn, "__name__"):
|
||||||
_fn_name = handler_fn.__name__
|
_fn_name = handler_fn.__name__
|
||||||
else:
|
else:
|
||||||
_fn_name = handler_fn.__class__.__name__
|
_fn_name = type(handler_fn).__name__
|
||||||
|
|
||||||
if isinstance(handler_fn, functools.partial):
|
if isinstance(handler_fn, functools.partial):
|
||||||
raise ValueError(
|
raise ValueError(
|
||||||
|
@ -161,7 +161,7 @@ class ComponentNamespace(SimpleNamespace):
|
|||||||
Returns:
|
Returns:
|
||||||
The hash of the namespace.
|
The hash of the namespace.
|
||||||
"""
|
"""
|
||||||
return hash(self.__class__.__name__)
|
return hash(type(self).__name__)
|
||||||
|
|
||||||
|
|
||||||
def evaluate_style_namespaces(style: ComponentStyle) -> dict:
|
def evaluate_style_namespaces(style: ComponentStyle) -> dict:
|
||||||
@ -2565,7 +2565,7 @@ class LiteralComponentVar(CachedVarOperation, LiteralVar, ComponentVar):
|
|||||||
Returns:
|
Returns:
|
||||||
The hash of the var.
|
The hash of the var.
|
||||||
"""
|
"""
|
||||||
return hash((self.__class__.__name__, self._js_expr))
|
return hash((type(self).__name__, self._js_expr))
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def create(
|
def create(
|
||||||
|
@ -49,9 +49,9 @@ class Cond(MemoizationLeaf):
|
|||||||
The conditional component.
|
The conditional component.
|
||||||
"""
|
"""
|
||||||
# Wrap everything in fragments.
|
# Wrap everything in fragments.
|
||||||
if comp1.__class__.__name__ != "Fragment":
|
if type(comp1).__name__ != "Fragment":
|
||||||
comp1 = Fragment.create(comp1)
|
comp1 = Fragment.create(comp1)
|
||||||
if comp2 is None or comp2.__class__.__name__ != "Fragment":
|
if comp2 is None or type(comp2).__name__ != "Fragment":
|
||||||
comp2 = Fragment.create(comp2) if comp2 else Fragment.create()
|
comp2 = Fragment.create(comp2) if comp2 else Fragment.create()
|
||||||
return Fragment.create(
|
return Fragment.create(
|
||||||
cls(
|
cls(
|
||||||
|
@ -139,9 +139,7 @@ class RadixThemesComponent(Component):
|
|||||||
component = super().create(*children, **props)
|
component = super().create(*children, **props)
|
||||||
if component.library is None:
|
if component.library is None:
|
||||||
component.library = RadixThemesComponent.__fields__["library"].default
|
component.library = RadixThemesComponent.__fields__["library"].default
|
||||||
component.alias = "RadixThemes" + (
|
component.alias = "RadixThemes" + (component.tag or type(component).__name__)
|
||||||
component.tag or component.__class__.__name__
|
|
||||||
)
|
|
||||||
return component
|
return component
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
|
@ -512,6 +512,9 @@ class EnvironmentVariables:
|
|||||||
# Whether to print the SQL queries if the log level is INFO or lower.
|
# Whether to print the SQL queries if the log level is INFO or lower.
|
||||||
SQLALCHEMY_ECHO: EnvVar[bool] = env_var(False)
|
SQLALCHEMY_ECHO: EnvVar[bool] = env_var(False)
|
||||||
|
|
||||||
|
# Whether to check db connections before using them.
|
||||||
|
SQLALCHEMY_POOL_PRE_PING: EnvVar[bool] = env_var(True)
|
||||||
|
|
||||||
# Whether to ignore the redis config error. Some redis servers only allow out-of-band configuration.
|
# Whether to ignore the redis config error. Some redis servers only allow out-of-band configuration.
|
||||||
REFLEX_IGNORE_REDIS_CONFIG_ERROR: EnvVar[bool] = env_var(False)
|
REFLEX_IGNORE_REDIS_CONFIG_ERROR: EnvVar[bool] = env_var(False)
|
||||||
|
|
||||||
@ -568,6 +571,10 @@ class EnvironmentVariables:
|
|||||||
environment = EnvironmentVariables()
|
environment = EnvironmentVariables()
|
||||||
|
|
||||||
|
|
||||||
|
# These vars are not logged because they may contain sensitive information.
|
||||||
|
_sensitive_env_vars = {"DB_URL", "ASYNC_DB_URL", "REDIS_URL"}
|
||||||
|
|
||||||
|
|
||||||
class Config(Base):
|
class Config(Base):
|
||||||
"""The config defines runtime settings for the app.
|
"""The config defines runtime settings for the app.
|
||||||
|
|
||||||
@ -621,6 +628,9 @@ class Config(Base):
|
|||||||
# The database url used by rx.Model.
|
# The database url used by rx.Model.
|
||||||
db_url: Optional[str] = "sqlite:///reflex.db"
|
db_url: Optional[str] = "sqlite:///reflex.db"
|
||||||
|
|
||||||
|
# The async database url used by rx.Model.
|
||||||
|
async_db_url: Optional[str] = None
|
||||||
|
|
||||||
# The redis url
|
# The redis url
|
||||||
redis_url: Optional[str] = None
|
redis_url: Optional[str] = None
|
||||||
|
|
||||||
@ -748,18 +758,20 @@ class Config(Base):
|
|||||||
|
|
||||||
# If the env var is set, override the config value.
|
# If the env var is set, override the config value.
|
||||||
if env_var is not None:
|
if env_var is not None:
|
||||||
if key.upper() != "DB_URL":
|
|
||||||
console.info(
|
|
||||||
f"Overriding config value {key} with env var {key.upper()}={env_var}",
|
|
||||||
dedupe=True,
|
|
||||||
)
|
|
||||||
|
|
||||||
# Interpret the value.
|
# Interpret the value.
|
||||||
value = interpret_env_var_value(env_var, field.outer_type_, field.name)
|
value = interpret_env_var_value(env_var, field.outer_type_, field.name)
|
||||||
|
|
||||||
# Set the value.
|
# Set the value.
|
||||||
updated_values[key] = value
|
updated_values[key] = value
|
||||||
|
|
||||||
|
if key.upper() in _sensitive_env_vars:
|
||||||
|
env_var = "***"
|
||||||
|
|
||||||
|
console.info(
|
||||||
|
f"Overriding config value {key} with env var {key.upper()}={env_var}",
|
||||||
|
dedupe=True,
|
||||||
|
)
|
||||||
|
|
||||||
return updated_values
|
return updated_values
|
||||||
|
|
||||||
def get_event_namespace(self) -> str:
|
def get_event_namespace(self) -> str:
|
||||||
|
@ -1556,7 +1556,7 @@ class LiteralEventVar(VarOperationCall, LiteralVar, EventVar):
|
|||||||
Returns:
|
Returns:
|
||||||
The hash of the var.
|
The hash of the var.
|
||||||
"""
|
"""
|
||||||
return hash((self.__class__.__name__, self._js_expr))
|
return hash((type(self).__name__, self._js_expr))
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def create(
|
def create(
|
||||||
@ -1620,7 +1620,7 @@ class LiteralEventChainVar(ArgsFunctionOperationBuilder, LiteralVar, EventChainV
|
|||||||
Returns:
|
Returns:
|
||||||
The hash of the var.
|
The hash of the var.
|
||||||
"""
|
"""
|
||||||
return hash((self.__class__.__name__, self._js_expr))
|
return hash((type(self).__name__, self._js_expr))
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def create(
|
def create(
|
||||||
|
126
reflex/model.py
126
reflex/model.py
@ -2,6 +2,7 @@
|
|||||||
|
|
||||||
from __future__ import annotations
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import re
|
||||||
from collections import defaultdict
|
from collections import defaultdict
|
||||||
from typing import Any, ClassVar, Optional, Type, Union
|
from typing import Any, ClassVar, Optional, Type, Union
|
||||||
|
|
||||||
@ -14,6 +15,7 @@ import alembic.script
|
|||||||
import alembic.util
|
import alembic.util
|
||||||
import sqlalchemy
|
import sqlalchemy
|
||||||
import sqlalchemy.exc
|
import sqlalchemy.exc
|
||||||
|
import sqlalchemy.ext.asyncio
|
||||||
import sqlalchemy.orm
|
import sqlalchemy.orm
|
||||||
|
|
||||||
from reflex.base import Base
|
from reflex.base import Base
|
||||||
@ -21,6 +23,48 @@ from reflex.config import environment, get_config
|
|||||||
from reflex.utils import console
|
from reflex.utils import console
|
||||||
from reflex.utils.compat import sqlmodel, sqlmodel_field_has_primary_key
|
from reflex.utils.compat import sqlmodel, sqlmodel_field_has_primary_key
|
||||||
|
|
||||||
|
_ENGINE: dict[str, sqlalchemy.engine.Engine] = {}
|
||||||
|
_ASYNC_ENGINE: dict[str, sqlalchemy.ext.asyncio.AsyncEngine] = {}
|
||||||
|
_AsyncSessionLocal: dict[str | None, sqlalchemy.ext.asyncio.async_sessionmaker] = {}
|
||||||
|
|
||||||
|
# Import AsyncSession _after_ reflex.utils.compat
|
||||||
|
from sqlmodel.ext.asyncio.session import AsyncSession # noqa: E402
|
||||||
|
|
||||||
|
|
||||||
|
def _safe_db_url_for_logging(url: str) -> str:
|
||||||
|
"""Remove username and password from the database URL for logging.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
url: The database URL.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
The database URL with the username and password removed.
|
||||||
|
"""
|
||||||
|
return re.sub(r"://[^@]+@", "://<username>:<password>@", url)
|
||||||
|
|
||||||
|
|
||||||
|
def get_engine_args(url: str | None = None) -> dict[str, Any]:
|
||||||
|
"""Get the database engine arguments.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
url: The database url.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
The database engine arguments as a dict.
|
||||||
|
"""
|
||||||
|
kwargs: dict[str, Any] = dict(
|
||||||
|
# Print the SQL queries if the log level is INFO or lower.
|
||||||
|
echo=environment.SQLALCHEMY_ECHO.get(),
|
||||||
|
# Check connections before returning them.
|
||||||
|
pool_pre_ping=environment.SQLALCHEMY_POOL_PRE_PING.get(),
|
||||||
|
)
|
||||||
|
conf = get_config()
|
||||||
|
url = url or conf.db_url
|
||||||
|
if url is not None and url.startswith("sqlite"):
|
||||||
|
# Needed for the admin dash on sqlite.
|
||||||
|
kwargs["connect_args"] = {"check_same_thread": False}
|
||||||
|
return kwargs
|
||||||
|
|
||||||
|
|
||||||
def get_engine(url: str | None = None) -> sqlalchemy.engine.Engine:
|
def get_engine(url: str | None = None) -> sqlalchemy.engine.Engine:
|
||||||
"""Get the database engine.
|
"""Get the database engine.
|
||||||
@ -38,15 +82,62 @@ def get_engine(url: str | None = None) -> sqlalchemy.engine.Engine:
|
|||||||
url = url or conf.db_url
|
url = url or conf.db_url
|
||||||
if url is None:
|
if url is None:
|
||||||
raise ValueError("No database url configured")
|
raise ValueError("No database url configured")
|
||||||
|
|
||||||
|
global _ENGINE
|
||||||
|
if url in _ENGINE:
|
||||||
|
return _ENGINE[url]
|
||||||
|
|
||||||
if not environment.ALEMBIC_CONFIG.get().exists():
|
if not environment.ALEMBIC_CONFIG.get().exists():
|
||||||
console.warn(
|
console.warn(
|
||||||
"Database is not initialized, run [bold]reflex db init[/bold] first."
|
"Database is not initialized, run [bold]reflex db init[/bold] first."
|
||||||
)
|
)
|
||||||
# Print the SQL queries if the log level is INFO or lower.
|
_ENGINE[url] = sqlmodel.create_engine(
|
||||||
echo_db_query = environment.SQLALCHEMY_ECHO.get()
|
url,
|
||||||
# Needed for the admin dash on sqlite.
|
**get_engine_args(url),
|
||||||
connect_args = {"check_same_thread": False} if url.startswith("sqlite") else {}
|
)
|
||||||
return sqlmodel.create_engine(url, echo=echo_db_query, connect_args=connect_args)
|
return _ENGINE[url]
|
||||||
|
|
||||||
|
|
||||||
|
def get_async_engine(url: str | None) -> sqlalchemy.ext.asyncio.AsyncEngine:
|
||||||
|
"""Get the async database engine.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
url: The database url.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
The async database engine.
|
||||||
|
|
||||||
|
Raises:
|
||||||
|
ValueError: If the async database url is None.
|
||||||
|
"""
|
||||||
|
if url is None:
|
||||||
|
conf = get_config()
|
||||||
|
url = conf.async_db_url
|
||||||
|
if url is not None and conf.db_url is not None:
|
||||||
|
async_db_url_tail = url.partition("://")[2]
|
||||||
|
db_url_tail = conf.db_url.partition("://")[2]
|
||||||
|
if async_db_url_tail != db_url_tail:
|
||||||
|
console.warn(
|
||||||
|
f"async_db_url `{_safe_db_url_for_logging(url)}` "
|
||||||
|
"should reference the same database as "
|
||||||
|
f"db_url `{_safe_db_url_for_logging(conf.db_url)}`."
|
||||||
|
)
|
||||||
|
if url is None:
|
||||||
|
raise ValueError("No async database url configured")
|
||||||
|
|
||||||
|
global _ASYNC_ENGINE
|
||||||
|
if url in _ASYNC_ENGINE:
|
||||||
|
return _ASYNC_ENGINE[url]
|
||||||
|
|
||||||
|
if not environment.ALEMBIC_CONFIG.get().exists():
|
||||||
|
console.warn(
|
||||||
|
"Database is not initialized, run [bold]reflex db init[/bold] first."
|
||||||
|
)
|
||||||
|
_ASYNC_ENGINE[url] = sqlalchemy.ext.asyncio.create_async_engine(
|
||||||
|
url,
|
||||||
|
**get_engine_args(url),
|
||||||
|
)
|
||||||
|
return _ASYNC_ENGINE[url]
|
||||||
|
|
||||||
|
|
||||||
async def get_db_status() -> bool:
|
async def get_db_status() -> bool:
|
||||||
@ -425,6 +516,31 @@ def session(url: str | None = None) -> sqlmodel.Session:
|
|||||||
return sqlmodel.Session(get_engine(url))
|
return sqlmodel.Session(get_engine(url))
|
||||||
|
|
||||||
|
|
||||||
|
def asession(url: str | None = None) -> AsyncSession:
|
||||||
|
"""Get an async sqlmodel session to interact with the database.
|
||||||
|
|
||||||
|
async with rx.asession() as asession:
|
||||||
|
...
|
||||||
|
|
||||||
|
Most operations against the `asession` must be awaited.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
url: The database url.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
An async database session.
|
||||||
|
"""
|
||||||
|
global _AsyncSessionLocal
|
||||||
|
if url not in _AsyncSessionLocal:
|
||||||
|
_AsyncSessionLocal[url] = sqlalchemy.ext.asyncio.async_sessionmaker(
|
||||||
|
bind=get_async_engine(url),
|
||||||
|
class_=AsyncSession,
|
||||||
|
autocommit=False,
|
||||||
|
autoflush=False,
|
||||||
|
)
|
||||||
|
return _AsyncSessionLocal[url]()
|
||||||
|
|
||||||
|
|
||||||
def sqla_session(url: str | None = None) -> sqlalchemy.orm.Session:
|
def sqla_session(url: str | None = None) -> sqlalchemy.orm.Session:
|
||||||
"""Get a bare sqlalchemy session to interact with the database.
|
"""Get a bare sqlalchemy session to interact with the database.
|
||||||
|
|
||||||
|
@ -438,7 +438,7 @@ class BaseState(Base, ABC, extra=pydantic.Extra.allow):
|
|||||||
Returns:
|
Returns:
|
||||||
The string representation of the state.
|
The string representation of the state.
|
||||||
"""
|
"""
|
||||||
return f"{self.__class__.__name__}({self.dict()})"
|
return f"{type(self).__name__}({self.dict()})"
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def _get_computed_vars(cls) -> list[ComputedVar]:
|
def _get_computed_vars(cls) -> list[ComputedVar]:
|
||||||
@ -3618,7 +3618,7 @@ class MutableProxy(wrapt.ObjectProxy):
|
|||||||
Returns:
|
Returns:
|
||||||
The representation of the wrapped object.
|
The representation of the wrapped object.
|
||||||
"""
|
"""
|
||||||
return f"{self.__class__.__name__}({self.__wrapped__})"
|
return f"{type(self).__name__}({self.__wrapped__})"
|
||||||
|
|
||||||
def _mark_dirty(
|
def _mark_dirty(
|
||||||
self,
|
self,
|
||||||
|
@ -1569,7 +1569,7 @@ class CachedVarOperation:
|
|||||||
if name == "_js_expr":
|
if name == "_js_expr":
|
||||||
return self._cached_var_name
|
return self._cached_var_name
|
||||||
|
|
||||||
parent_classes = inspect.getmro(self.__class__)
|
parent_classes = inspect.getmro(type(self))
|
||||||
|
|
||||||
next_class = parent_classes[parent_classes.index(CachedVarOperation) + 1]
|
next_class = parent_classes[parent_classes.index(CachedVarOperation) + 1]
|
||||||
|
|
||||||
@ -1611,7 +1611,7 @@ class CachedVarOperation:
|
|||||||
"""
|
"""
|
||||||
return hash(
|
return hash(
|
||||||
(
|
(
|
||||||
self.__class__.__name__,
|
type(self).__name__,
|
||||||
*[
|
*[
|
||||||
getattr(self, field.name)
|
getattr(self, field.name)
|
||||||
for field in dataclasses.fields(self) # type: ignore
|
for field in dataclasses.fields(self) # type: ignore
|
||||||
@ -1733,7 +1733,7 @@ class CallableVar(Var):
|
|||||||
Returns:
|
Returns:
|
||||||
The hash of the object.
|
The hash of the object.
|
||||||
"""
|
"""
|
||||||
return hash((self.__class__.__name__, self.original_var))
|
return hash((type(self).__name__, self.original_var))
|
||||||
|
|
||||||
|
|
||||||
RETURN_TYPE = TypeVar("RETURN_TYPE")
|
RETURN_TYPE = TypeVar("RETURN_TYPE")
|
||||||
|
@ -1012,7 +1012,7 @@ class LiteralNumberVar(LiteralVar, NumberVar):
|
|||||||
Returns:
|
Returns:
|
||||||
int: The hash value of the object.
|
int: The hash value of the object.
|
||||||
"""
|
"""
|
||||||
return hash((self.__class__.__name__, self._var_value))
|
return hash((type(self).__name__, self._var_value))
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def create(cls, value: float | int, _var_data: VarData | None = None):
|
def create(cls, value: float | int, _var_data: VarData | None = None):
|
||||||
@ -1064,7 +1064,7 @@ class LiteralBooleanVar(LiteralVar, BooleanVar):
|
|||||||
Returns:
|
Returns:
|
||||||
int: The hash value of the object.
|
int: The hash value of the object.
|
||||||
"""
|
"""
|
||||||
return hash((self.__class__.__name__, self._var_value))
|
return hash((type(self).__name__, self._var_value))
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def create(cls, value: bool, _var_data: VarData | None = None):
|
def create(cls, value: bool, _var_data: VarData | None = None):
|
||||||
|
@ -362,7 +362,7 @@ class LiteralObjectVar(CachedVarOperation, ObjectVar[OBJECT_TYPE], LiteralVar):
|
|||||||
Returns:
|
Returns:
|
||||||
The hash of the var.
|
The hash of the var.
|
||||||
"""
|
"""
|
||||||
return hash((self.__class__.__name__, self._js_expr))
|
return hash((type(self).__name__, self._js_expr))
|
||||||
|
|
||||||
@cached_property_no_lock
|
@cached_property_no_lock
|
||||||
def _cached_get_all_var_data(self) -> VarData | None:
|
def _cached_get_all_var_data(self) -> VarData | None:
|
||||||
|
@ -667,7 +667,7 @@ class LiteralStringVar(LiteralVar, StringVar[str]):
|
|||||||
Returns:
|
Returns:
|
||||||
The hash of the var.
|
The hash of the var.
|
||||||
"""
|
"""
|
||||||
return hash((self.__class__.__name__, self._var_value))
|
return hash((type(self).__name__, self._var_value))
|
||||||
|
|
||||||
def json(self) -> str:
|
def json(self) -> str:
|
||||||
"""Get the JSON representation of the var.
|
"""Get the JSON representation of the var.
|
||||||
|
@ -73,7 +73,7 @@ def StateInheritance():
|
|||||||
def on_click_other_mixin(self):
|
def on_click_other_mixin(self):
|
||||||
self.other_mixin_clicks += 1
|
self.other_mixin_clicks += 1
|
||||||
self.other_mixin = (
|
self.other_mixin = (
|
||||||
f"{self.__class__.__name__}.clicked.{self.other_mixin_clicks}"
|
f"{type(self).__name__}.clicked.{self.other_mixin_clicks}"
|
||||||
)
|
)
|
||||||
|
|
||||||
class Base1(Mixin, rx.State):
|
class Base1(Mixin, rx.State):
|
||||||
|
@ -46,7 +46,7 @@ def test_default_primary_key(model_default_primary: Model):
|
|||||||
Args:
|
Args:
|
||||||
model_default_primary: Fixture.
|
model_default_primary: Fixture.
|
||||||
"""
|
"""
|
||||||
assert "id" in model_default_primary.__class__.__fields__
|
assert "id" in type(model_default_primary).__fields__
|
||||||
|
|
||||||
|
|
||||||
def test_custom_primary_key(model_custom_primary: Model):
|
def test_custom_primary_key(model_custom_primary: Model):
|
||||||
@ -55,7 +55,7 @@ def test_custom_primary_key(model_custom_primary: Model):
|
|||||||
Args:
|
Args:
|
||||||
model_custom_primary: Fixture.
|
model_custom_primary: Fixture.
|
||||||
"""
|
"""
|
||||||
assert "id" not in model_custom_primary.__class__.__fields__
|
assert "id" not in type(model_custom_primary).__fields__
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.filterwarnings(
|
@pytest.mark.filterwarnings(
|
||||||
|
@ -1697,7 +1697,7 @@ async def test_state_manager_modify_state(
|
|||||||
assert not state_manager._states_locks[token].locked()
|
assert not state_manager._states_locks[token].locked()
|
||||||
|
|
||||||
# separate instances should NOT share locks
|
# separate instances should NOT share locks
|
||||||
sm2 = state_manager.__class__(state=TestState)
|
sm2 = type(state_manager)(state=TestState)
|
||||||
assert sm2._state_manager_lock is state_manager._state_manager_lock
|
assert sm2._state_manager_lock is state_manager._state_manager_lock
|
||||||
assert not sm2._states_locks
|
assert not sm2._states_locks
|
||||||
if state_manager._states_locks:
|
if state_manager._states_locks:
|
||||||
|
Loading…
Reference in New Issue
Block a user