Compare commits

...

3 Commits

Author SHA1 Message Date
Masen Furer
484dcf28a7
Fix type hints for py39 2024-11-06 16:37:39 -08:00
Masen Furer
3de041629c
add to= type for serializers 2024-11-06 16:34:38 -08:00
Masen Furer
a493b70fad
[ENG-4003] Support DeclarativeBase in core
Add serializers and Var definitions for SQLAlchemy DeclarativeBase models.
2024-11-06 16:28:32 -08:00

View File

@ -2,8 +2,11 @@
from __future__ import annotations
import dataclasses
import sys
from collections import defaultdict
from typing import Any, ClassVar, Optional, Type, Union
from collections.abc import Mapping, MutableSet, Sequence
from typing import Any, ClassVar, Dict, List, Optional, Type, TypeVar, Union
import alembic.autogenerate
import alembic.command
@ -15,11 +18,14 @@ import alembic.util
import sqlalchemy
import sqlalchemy.exc
import sqlalchemy.orm
from sqlalchemy.orm import DeclarativeBase
from reflex.base import Base
from reflex.config import environment, get_config
from reflex.utils import console
from reflex.utils.compat import sqlmodel, sqlmodel_field_has_primary_key
from reflex.utils.serializers import serializer
from reflex.vars.object import LiteralObjectVar, ObjectVar
def get_engine(url: str | None = None) -> sqlalchemy.engine.Engine:
@ -435,3 +441,69 @@ def sqla_session(url: str | None = None) -> sqlalchemy.orm.Session:
A database session.
"""
return sqlalchemy.orm.Session(get_engine(url))
class DeclarativeBaseVar(ObjectVar, python_types=DeclarativeBase):
"""Var for a SQLAlchemy DeclarativeBase object."""
@dataclasses.dataclass(
eq=False,
frozen=True,
**{"slots": True} if sys.version_info >= (3, 10) else {},
)
class LiteralDeclarativeBaseVar(LiteralObjectVar, DeclarativeBaseVar):
"""Literal Var for a SQLAlchemy DeclarativeBase object."""
_var_value: DeclarativeBase | None = None
T = TypeVar("T")
K = TypeVar("K")
V = TypeVar("V")
@serializer(to=list)
def serialize_Sequence(s: Union[Sequence[T], MutableSet[T]]) -> List[T]:
"""Serialize a sequence or mutable set as a regular list.
Args:
s: The sequence or mutable set to serialize.
Returns:
The serialized list.
"""
return list(s)
@serializer(to=dict)
def serialize_Mapping(m: Mapping[K, V]) -> Dict[K, V]:
"""Serialize a mapping as a regular dictionary.
Args:
m: The mapping to serialize.
Returns:
The serialized dictionary.
"""
return dict(m)
@serializer(to=dict)
def serialize_DeclarativeBase(obj: DeclarativeBase) -> Dict[str, str]:
"""Serialize a SQLAlchemy DeclarativeBase object as a dictionary.
Args:
obj: The SQLAlchemy DeclarativeBase object to serialize.
Returns:
The serialized dictionary.
"""
s = {}
for attr in sqlalchemy.inspect(type(obj)).all_orm_descriptors.keys(): # noqa: SIM118
try:
s[attr] = getattr(obj, attr)
except sqlalchemy.orm.exc.DetachedInstanceError:
# This happens when the relationship was never loaded and the session is closed.
continue
return s