This commit is contained in:
Masen Furer 2025-02-22 16:38:36 +00:00 committed by GitHub
commit 330e4719ae
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194

View File

@ -5,7 +5,7 @@ from __future__ import annotations
import re import re
from collections import defaultdict from collections import defaultdict
from contextlib import suppress from contextlib import suppress
from typing import Any, ClassVar, Optional, Type, Union from typing import Any, Awaitable, ClassVar, Optional, Type, Union
import alembic.autogenerate import alembic.autogenerate
import alembic.command import alembic.command
@ -19,6 +19,7 @@ import sqlalchemy.exc
import sqlalchemy.ext.asyncio import sqlalchemy.ext.asyncio
import sqlalchemy.orm import sqlalchemy.orm
from alembic.runtime.migration import MigrationContext from alembic.runtime.migration import MigrationContext
from sqlalchemy.util.concurrency import greenlet_spawn
from reflex.base import Base from reflex.base import Base
from reflex.config import environment, get_config from reflex.config import environment, get_config
@ -243,6 +244,18 @@ class ModelRegistry:
return metadata return metadata
class _AsyncAttrGetitem:
__slots__ = "_instance"
def __init__(self, _instance: Model):
self._instance = _instance
def __getattr__(self, name: str) -> Awaitable[Any]:
if name.startswith("_"):
return getattr(self._instance, name)
return greenlet_spawn(getattr, self._instance, name)
class Model(Base, sqlmodel.SQLModel): # pyright: ignore [reportGeneralTypeIssues,reportIncompatibleVariableOverride] class Model(Base, sqlmodel.SQLModel): # pyright: ignore [reportGeneralTypeIssues,reportIncompatibleVariableOverride]
"""Base class to define a table in the database.""" """Base class to define a table in the database."""
@ -261,6 +274,15 @@ class Model(Base, sqlmodel.SQLModel): # pyright: ignore [reportGeneralTypeIssue
super().__init_subclass__() super().__init_subclass__()
@property
def awaitable_attrs(self) -> _AsyncAttrGetitem:
"""Provide a namespace of all attributes on this object wrapped as awaitables.
Returns:
An awaitable attribute namespace.
"""
return _AsyncAttrGetitem(self)
@classmethod @classmethod
def _dict_recursive(cls, value: Any): def _dict_recursive(cls, value: Any):
"""Recursively serialize the relationship object(s). """Recursively serialize the relationship object(s).