Compare commits

...

1 Commits

Author SHA1 Message Date
Masen Furer
5411cc07be
Add awaitable_attrs to rx.Model
This works like sqlalchemy's AsyncAttrs mixin, providing an async-context for
accessing fields and relationships when using async sessions.
2025-02-07 16:55:04 -08:00

View File

@ -5,7 +5,7 @@ from __future__ import annotations
import re
from collections import defaultdict
from contextlib import suppress
from typing import Any, ClassVar, Optional, Type, Union
from typing import Any, Awaitable, ClassVar, Optional, Type, Union
import alembic.autogenerate
import alembic.command
@ -19,6 +19,7 @@ import sqlalchemy.exc
import sqlalchemy.ext.asyncio
import sqlalchemy.orm
from alembic.runtime.migration import MigrationContext
from sqlalchemy.util.concurrency import greenlet_spawn
from reflex.base import Base
from reflex.config import environment, get_config
@ -243,6 +244,18 @@ class ModelRegistry:
return metadata
class _AsyncAttrGetitem:
__slots__ = "_instance"
def __init__(self, _instance: Model):
self._instance = _instance
def __getattr__(self, name: str) -> Awaitable[Any]:
if name.startswith("_"):
return getattr(self._instance, name)
return greenlet_spawn(getattr, self._instance, name)
class Model(Base, sqlmodel.SQLModel): # pyright: ignore [reportGeneralTypeIssues,reportIncompatibleVariableOverride]
"""Base class to define a table in the database."""
@ -261,6 +274,15 @@ class Model(Base, sqlmodel.SQLModel): # pyright: ignore [reportGeneralTypeIssue
super().__init_subclass__()
@property
def awaitable_attrs(self) -> _AsyncAttrGetitem:
"""Provide a namespace of all attributes on this object wrapped as awaitables.
Returns:
An awaitable attribute namespace.
"""
return _AsyncAttrGetitem(self)
@classmethod
def _dict_recursive(cls, value: Any):
"""Recursively serialize the relationship object(s).