prepare base

This commit is contained in:
Benedikt Bartscher 2024-02-28 23:42:32 +01:00
parent 0b4b7c0d12
commit 25d05856f9
No known key found for this signature in database

View File

@ -2,15 +2,16 @@
from __future__ import annotations
import os
from typing import Any, List, Type
from typing import Any, List, Type, Dict, Optional
import pydantic
from pydantic import BaseModel
from pydantic.fields import ModelField
from pydantic.fields import FieldInfo
from reflex import constants
# TODO: migrate to pydantic v2
def validate_field_name(bases: List[Type["BaseModel"]], field_name: str) -> None:
"""Ensure that the field's name does not shadow an existing attribute of the model.
@ -35,7 +36,8 @@ def validate_field_name(bases: List[Type["BaseModel"]], field_name: str) -> None
# monkeypatch pydantic validate_field_name method to skip validating
# shadowed state vars when reloading app via utils.prerequisites.get_app(reload=True)
pydantic.main.validate_field_name = validate_field_name # type: ignore
# TODO
# pydantic.main.validate_field_name = validate_field_name # type: ignore
class Base(pydantic.BaseModel):
@ -61,9 +63,10 @@ class Base(pydantic.BaseModel):
Returns:
The object as a json string.
"""
from reflex.utils.serializers import serialize
# from reflex.utils.serializers import serialize
return self.__config__.json_dumps(self.dict(), default=serialize)
return self.model_dump_json()
# return self.__config__.json_dumps(self.dict(), default=serialize)
def set(self, **kwargs):
"""Set multiple fields and return the object.
@ -85,7 +88,8 @@ class Base(pydantic.BaseModel):
Returns:
The fields of the object.
"""
return cls.__fields__
return cls.model_fields
@classmethod
def add_field(cls, var: Any, default_value: Any):
@ -97,14 +101,10 @@ class Base(pydantic.BaseModel):
var: The variable to add a pydantic field for.
default_value: The default value of the field
"""
new_field = ModelField.infer(
name=var._var_name,
value=default_value,
annotation=var._var_type,
class_validators=None,
config=cls.__config__,
)
cls.__fields__.update({var._var_name: new_field})
field_info = FieldInfo(default=default_value, annotation=var._var_type)
cls.model_fields.update({var._var_name: field_info})
cls.model_rebuild(force=True)
def get_value(self, key: str) -> Any:
"""Get the value of a field.
@ -115,7 +115,7 @@ class Base(pydantic.BaseModel):
Returns:
The value of the field.
"""
if isinstance(key, str) and key in self.__fields__:
if isinstance(key, str) and key in self.get_fields():
# Seems like this function signature was wrong all along?
# If the user wants a field that we know of, get it and pass it off to _get_value
key = getattr(self, key)