diff --git a/reflex/model.py b/reflex/model.py index 06bb87b02..a7e7d3180 100644 --- a/reflex/model.py +++ b/reflex/model.py @@ -5,7 +5,7 @@ from __future__ import annotations import re from collections import defaultdict from contextlib import suppress -from typing import Any, ClassVar, Optional, Type, Union +from typing import Any, Awaitable, ClassVar, Optional, Type, Union import alembic.autogenerate import alembic.command @@ -19,6 +19,7 @@ import sqlalchemy.exc import sqlalchemy.ext.asyncio import sqlalchemy.orm from alembic.runtime.migration import MigrationContext +from sqlalchemy.util.concurrency import greenlet_spawn from reflex.base import Base from reflex.config import environment, get_config @@ -243,6 +244,18 @@ class ModelRegistry: return metadata +class _AsyncAttrGetitem: + __slots__ = "_instance" + + def __init__(self, _instance: Model): + self._instance = _instance + + def __getattr__(self, name: str) -> Awaitable[Any]: + if name.startswith("_"): + return getattr(self._instance, name) + return greenlet_spawn(getattr, self._instance, name) + + class Model(Base, sqlmodel.SQLModel): # pyright: ignore [reportGeneralTypeIssues,reportIncompatibleVariableOverride] """Base class to define a table in the database.""" @@ -261,6 +274,15 @@ class Model(Base, sqlmodel.SQLModel): # pyright: ignore [reportGeneralTypeIssue super().__init_subclass__() + @property + def awaitable_attrs(self) -> _AsyncAttrGetitem: + """Provide a namespace of all attributes on this object wrapped as awaitables. + + Returns: + An awaitable attribute namespace. + """ + return _AsyncAttrGetitem(self) + @classmethod def _dict_recursive(cls, value: Any): """Recursively serialize the relationship object(s).