From 5411cc07becd3e6aeecf37ce49e399b0b5217ecd Mon Sep 17 00:00:00 2001 From: Masen Furer Date: Fri, 7 Feb 2025 16:55:04 -0800 Subject: [PATCH] Add `awaitable_attrs` to rx.Model This works like sqlalchemy's AsyncAttrs mixin, providing an async-context for accessing fields and relationships when using async sessions. --- reflex/model.py | 24 +++++++++++++++++++++++- 1 file changed, 23 insertions(+), 1 deletion(-) diff --git a/reflex/model.py b/reflex/model.py index 06bb87b02..a7e7d3180 100644 --- a/reflex/model.py +++ b/reflex/model.py @@ -5,7 +5,7 @@ from __future__ import annotations import re from collections import defaultdict from contextlib import suppress -from typing import Any, ClassVar, Optional, Type, Union +from typing import Any, Awaitable, ClassVar, Optional, Type, Union import alembic.autogenerate import alembic.command @@ -19,6 +19,7 @@ import sqlalchemy.exc import sqlalchemy.ext.asyncio import sqlalchemy.orm from alembic.runtime.migration import MigrationContext +from sqlalchemy.util.concurrency import greenlet_spawn from reflex.base import Base from reflex.config import environment, get_config @@ -243,6 +244,18 @@ class ModelRegistry: return metadata +class _AsyncAttrGetitem: + __slots__ = "_instance" + + def __init__(self, _instance: Model): + self._instance = _instance + + def __getattr__(self, name: str) -> Awaitable[Any]: + if name.startswith("_"): + return getattr(self._instance, name) + return greenlet_spawn(getattr, self._instance, name) + + class Model(Base, sqlmodel.SQLModel): # pyright: ignore [reportGeneralTypeIssues,reportIncompatibleVariableOverride] """Base class to define a table in the database.""" @@ -261,6 +274,15 @@ class Model(Base, sqlmodel.SQLModel): # pyright: ignore [reportGeneralTypeIssue super().__init_subclass__() + @property + def awaitable_attrs(self) -> _AsyncAttrGetitem: + """Provide a namespace of all attributes on this object wrapped as awaitables. + + Returns: + An awaitable attribute namespace. + """ + return _AsyncAttrGetitem(self) + @classmethod def _dict_recursive(cls, value: Any): """Recursively serialize the relationship object(s).