fix: handle default_factory in get_attribute_access_type, add tests for sqla dataclasses
This commit is contained in:
parent
fd0fd2c6d4
commit
930c1e2129
@ -331,7 +331,11 @@ def get_attribute_access_type(cls: GenericType, name: str) -> GenericType | None
|
||||
type_ = field.outer_type_
|
||||
if isinstance(type_, ModelField):
|
||||
type_ = type_.type_
|
||||
if not field.required and field.default is None:
|
||||
if (
|
||||
not field.required
|
||||
and field.default is None
|
||||
and field.default_factory is None
|
||||
):
|
||||
# Ensure frontend uses null coalescing when accessing.
|
||||
type_ = Optional[type_]
|
||||
return type_
|
||||
|
@ -3,11 +3,19 @@ from __future__ import annotations
|
||||
from typing import Dict, List, Optional, Type, Union
|
||||
|
||||
import attrs
|
||||
import pydantic.v1
|
||||
import pytest
|
||||
import sqlalchemy
|
||||
import sqlmodel
|
||||
from sqlalchemy import JSON, TypeDecorator
|
||||
from sqlalchemy.ext.hybrid import hybrid_property
|
||||
from sqlalchemy.orm import DeclarativeBase, Mapped, mapped_column, relationship
|
||||
from sqlalchemy.orm import (
|
||||
DeclarativeBase,
|
||||
Mapped,
|
||||
MappedAsDataclass,
|
||||
mapped_column,
|
||||
relationship,
|
||||
)
|
||||
|
||||
import reflex as rx
|
||||
from reflex.utils.types import GenericType, get_attribute_access_type
|
||||
@ -53,6 +61,10 @@ class SQLALabel(SQLABase):
|
||||
id: Mapped[int] = mapped_column(primary_key=True)
|
||||
test_id: Mapped[int] = mapped_column(sqlalchemy.ForeignKey("test.id"))
|
||||
test: Mapped[SQLAClass] = relationship(back_populates="labels")
|
||||
test_dataclass_id: Mapped[int] = mapped_column(
|
||||
sqlalchemy.ForeignKey("test_dataclass.id")
|
||||
)
|
||||
test_dataclass: Mapped[SQLAClassDataclass] = relationship(back_populates="labels")
|
||||
|
||||
|
||||
class SQLAClass(SQLABase):
|
||||
@ -75,6 +87,62 @@ class SQLAClass(SQLABase):
|
||||
# do not use lower case dict here!
|
||||
# https://github.com/sqlalchemy/sqlalchemy/issues/9902
|
||||
dict_str_str: Mapped[Dict[str, str]] = mapped_column()
|
||||
default_factory: Mapped[List[int]] = mapped_column(
|
||||
sqlalchemy.types.ARRAY(item_type=sqlalchemy.INTEGER),
|
||||
)
|
||||
|
||||
@property
|
||||
def str_property(self) -> str:
|
||||
"""String property.
|
||||
|
||||
Returns:
|
||||
Name attribute
|
||||
"""
|
||||
return self.name
|
||||
|
||||
@hybrid_property
|
||||
def str_or_int_property(self) -> Union[str, int]:
|
||||
"""String or int property.
|
||||
|
||||
Returns:
|
||||
Name attribute
|
||||
"""
|
||||
return self.name
|
||||
|
||||
@hybrid_property
|
||||
def first_label(self) -> Optional[SQLALabel]:
|
||||
"""First label property.
|
||||
|
||||
Returns:
|
||||
First label
|
||||
"""
|
||||
return self.labels[0] if self.labels else None
|
||||
|
||||
|
||||
class SQLAClassDataclass(MappedAsDataclass, SQLABase):
|
||||
"""Test sqlalchemy model."""
|
||||
|
||||
id: Mapped[int] = mapped_column(primary_key=True)
|
||||
count: Mapped[int] = mapped_column()
|
||||
name: Mapped[str] = mapped_column()
|
||||
int_list: Mapped[List[int]] = mapped_column(
|
||||
sqlalchemy.types.ARRAY(item_type=sqlalchemy.INTEGER)
|
||||
)
|
||||
str_list: Mapped[List[str]] = mapped_column(
|
||||
sqlalchemy.types.ARRAY(item_type=sqlalchemy.String)
|
||||
)
|
||||
optional_int: Mapped[Optional[int]] = mapped_column(nullable=True)
|
||||
sqla_tag_id: Mapped[int] = mapped_column(sqlalchemy.ForeignKey(SQLATag.id))
|
||||
sqla_tag: Mapped[Optional[SQLATag]] = relationship()
|
||||
labels: Mapped[List[SQLALabel]] = relationship(back_populates="test_dataclass")
|
||||
# do not use lower case dict here!
|
||||
# https://github.com/sqlalchemy/sqlalchemy/issues/9902
|
||||
dict_str_str: Mapped[Dict[str, str]] = mapped_column()
|
||||
default_factory: Mapped[List[int]] = mapped_column(
|
||||
sqlalchemy.types.ARRAY(item_type=sqlalchemy.INTEGER),
|
||||
default_factory=list,
|
||||
)
|
||||
__tablename__: str = "test_dataclass"
|
||||
|
||||
@property
|
||||
def str_property(self) -> str:
|
||||
@ -115,6 +183,7 @@ class ModelClass(rx.Model):
|
||||
sqla_tag: Optional[SQLATag] = None
|
||||
labels: List[SQLALabel] = []
|
||||
dict_str_str: Dict[str, str] = {}
|
||||
default_factory: List[int] = sqlmodel.Field(default_factory=list)
|
||||
|
||||
@property
|
||||
def str_property(self) -> str:
|
||||
@ -155,6 +224,7 @@ class BaseClass(rx.Base):
|
||||
sqla_tag: Optional[SQLATag] = None
|
||||
labels: List[SQLALabel] = []
|
||||
dict_str_str: Dict[str, str] = {}
|
||||
default_factory: List[int] = pydantic.v1.Field(default_factory=list)
|
||||
|
||||
@property
|
||||
def str_property(self) -> str:
|
||||
@ -195,6 +265,7 @@ class BareClass:
|
||||
sqla_tag: Optional[SQLATag] = None
|
||||
labels: List[SQLALabel] = []
|
||||
dict_str_str: Dict[str, str] = {}
|
||||
default_factory: List[int] = []
|
||||
|
||||
@property
|
||||
def str_property(self) -> str:
|
||||
@ -236,6 +307,7 @@ class AttrClass:
|
||||
sqla_tag: Optional[SQLATag] = None
|
||||
labels: List[SQLALabel] = []
|
||||
dict_str_str: Dict[str, str] = {}
|
||||
default_factory: List[int] = attrs.field(factory=list)
|
||||
|
||||
@property
|
||||
def str_property(self) -> str:
|
||||
@ -268,6 +340,7 @@ class AttrClass:
|
||||
@pytest.fixture(
|
||||
params=[
|
||||
SQLAClass,
|
||||
SQLAClassDataclass,
|
||||
BaseClass,
|
||||
BareClass,
|
||||
ModelClass,
|
||||
@ -300,6 +373,7 @@ def cls(request: pytest.FixtureRequest) -> type:
|
||||
pytest.param("str_property", str, id="str_property"),
|
||||
pytest.param("str_or_int_property", Union[str, int], id="str_or_int_property"),
|
||||
pytest.param("first_label", Optional[SQLALabel], id="first_label"),
|
||||
pytest.param("default_factory", List[int], id="default_factory"),
|
||||
],
|
||||
)
|
||||
def test_get_attribute_access_type(cls: type, attr: str, expected: GenericType) -> None:
|
||||
|
Loading…
Reference in New Issue
Block a user