Merge 9e963338ef
into 6848915883
This commit is contained in:
commit
a851d5689a
@ -345,12 +345,16 @@ class Component(BaseComponent, ABC):
|
||||
if field.name not in props:
|
||||
continue
|
||||
|
||||
field_type = types.value_inside_optional(
|
||||
types.get_field_type(cls, field.name)
|
||||
)
|
||||
|
||||
# Set default values for any props.
|
||||
if types._issubclass(field.type_, Var):
|
||||
if types._issubclass(field_type, Var):
|
||||
field.required = False
|
||||
if field.default is not None:
|
||||
field.default = LiteralVar.create(field.default)
|
||||
elif types._issubclass(field.type_, EventHandler):
|
||||
elif types._issubclass(field_type, EventHandler):
|
||||
field.required = False
|
||||
|
||||
# Ensure renamed props from parent classes are applied to the subclass.
|
||||
@ -414,7 +418,9 @@ class Component(BaseComponent, ABC):
|
||||
field_type = EventChain
|
||||
elif key in props:
|
||||
# Set the field type.
|
||||
field_type = fields[key].type_
|
||||
field_type = types.value_inside_optional(
|
||||
types.get_field_type(type(self), key)
|
||||
)
|
||||
|
||||
else:
|
||||
continue
|
||||
@ -436,7 +442,10 @@ class Component(BaseComponent, ABC):
|
||||
try:
|
||||
kwargs[key] = determine_key(value)
|
||||
|
||||
expected_type = fields[key].outer_type_.__args__[0]
|
||||
expected_type = types.get_args(
|
||||
types.get_field_type(type(self), key)
|
||||
)[0]
|
||||
|
||||
# validate literal fields.
|
||||
types.validate_literal(
|
||||
key, value, expected_type, type(self).__name__
|
||||
@ -451,7 +460,7 @@ class Component(BaseComponent, ABC):
|
||||
except TypeError:
|
||||
# If it is not a valid var, check the base types.
|
||||
passed_type = type(value)
|
||||
expected_type = fields[key].outer_type_
|
||||
expected_type = types.get_field_type(type(self), key)
|
||||
if types.is_union(passed_type):
|
||||
# We need to check all possible types in the union.
|
||||
passed_types = (
|
||||
@ -563,8 +572,11 @@ class Component(BaseComponent, ABC):
|
||||
|
||||
# Look for component specific triggers,
|
||||
# e.g. variable declared as EventHandler types.
|
||||
for field in self.get_fields().values():
|
||||
if types._issubclass(field.outer_type_, EventHandler):
|
||||
for name, field in self.get_fields().items():
|
||||
if types._issubclass(
|
||||
types.value_inside_optional(types.get_field_type(type(self), name)),
|
||||
EventHandler,
|
||||
):
|
||||
args_spec = None
|
||||
annotation = field.annotation
|
||||
if (metadata := getattr(annotation, "__metadata__", None)) is not None:
|
||||
@ -675,9 +687,11 @@ class Component(BaseComponent, ABC):
|
||||
"""
|
||||
return {
|
||||
name
|
||||
for name, field in cls.get_fields().items()
|
||||
for name in cls.get_fields()
|
||||
if name in cls.get_props()
|
||||
and types._issubclass(field.outer_type_, Component)
|
||||
and types._issubclass(
|
||||
types.value_inside_optional(types.get_field_type(cls, name)), Component
|
||||
)
|
||||
}
|
||||
|
||||
@classmethod
|
||||
|
@ -38,7 +38,12 @@ from reflex import constants
|
||||
from reflex.base import Base
|
||||
from reflex.utils import console
|
||||
from reflex.utils.exceptions import ConfigError, EnvironmentVarValueError
|
||||
from reflex.utils.types import GenericType, is_union, value_inside_optional
|
||||
from reflex.utils.types import (
|
||||
GenericType,
|
||||
is_union,
|
||||
true_type_for_pydantic_field,
|
||||
value_inside_optional,
|
||||
)
|
||||
|
||||
try:
|
||||
import pydantic.v1 as pydantic
|
||||
@ -939,7 +944,9 @@ class Config(Base):
|
||||
# If the env var is set, override the config value.
|
||||
if env_var is not None:
|
||||
# Interpret the value.
|
||||
value = interpret_env_var_value(env_var, field.outer_type_, field.name)
|
||||
value = interpret_env_var_value(
|
||||
env_var, true_type_for_pydantic_field(field), field.name
|
||||
)
|
||||
|
||||
# Set the value.
|
||||
updated_values[key] = value
|
||||
|
@ -115,9 +115,9 @@ from reflex.utils.serializers import serializer
|
||||
from reflex.utils.types import (
|
||||
_isinstance,
|
||||
get_origin,
|
||||
is_optional,
|
||||
is_union,
|
||||
override,
|
||||
true_type_for_pydantic_field,
|
||||
value_inside_optional,
|
||||
)
|
||||
from reflex.vars import VarData
|
||||
@ -293,7 +293,7 @@ if TYPE_CHECKING:
|
||||
from pydantic.v1.fields import ModelField
|
||||
|
||||
|
||||
def _unwrap_field_type(type_: Type) -> Type:
|
||||
def _unwrap_field_type(type_: types.GenericType) -> Type:
|
||||
"""Unwrap rx.Field type annotations.
|
||||
|
||||
Args:
|
||||
@ -324,7 +324,7 @@ def get_var_for_field(cls: Type[BaseState], f: ModelField):
|
||||
return dispatch(
|
||||
field_name=field_name,
|
||||
var_data=VarData.from_state(cls, f.name),
|
||||
result_var_type=_unwrap_field_type(f.outer_type_),
|
||||
result_var_type=_unwrap_field_type(true_type_for_pydantic_field(f)),
|
||||
)
|
||||
|
||||
|
||||
@ -1368,9 +1368,7 @@ class BaseState(Base, ABC, extra=pydantic.Extra.allow):
|
||||
|
||||
if name in fields:
|
||||
field = fields[name]
|
||||
field_type = _unwrap_field_type(field.outer_type_)
|
||||
if field.allow_none and not is_optional(field_type):
|
||||
field_type = Union[field_type, None]
|
||||
field_type = _unwrap_field_type(true_type_for_pydantic_field(field))
|
||||
if not _isinstance(value, field_type):
|
||||
console.error(
|
||||
f"Expected field '{type(self).__name__}.{name}' to receive type '{field_type}',"
|
||||
|
@ -14,6 +14,7 @@ from typing import (
|
||||
Callable,
|
||||
ClassVar,
|
||||
Dict,
|
||||
ForwardRef,
|
||||
FrozenSet,
|
||||
Iterable,
|
||||
List,
|
||||
@ -274,6 +275,33 @@ def is_optional(cls: GenericType) -> bool:
|
||||
return is_union(cls) and type(None) in get_args(cls)
|
||||
|
||||
|
||||
def true_type_for_pydantic_field(f: ModelField):
|
||||
"""Get the type for a pydantic field.
|
||||
|
||||
Args:
|
||||
f: The field to get the type for.
|
||||
|
||||
Returns:
|
||||
The type for the field.
|
||||
"""
|
||||
if not isinstance(f.annotation, (str, ForwardRef)):
|
||||
return f.annotation
|
||||
|
||||
type_ = f.outer_type_
|
||||
|
||||
if (
|
||||
f.field_info.default is None
|
||||
or (isinstance(f.annotation, str) and f.annotation.startswith("Optional"))
|
||||
or (
|
||||
isinstance(f.annotation, ForwardRef)
|
||||
and f.annotation.__forward_arg__.startswith("Optional")
|
||||
)
|
||||
) and not is_optional(type_):
|
||||
return Optional[type_]
|
||||
|
||||
return type_
|
||||
|
||||
|
||||
def value_inside_optional(cls: GenericType) -> GenericType:
|
||||
"""Get the value inside an Optional type or the original type.
|
||||
|
||||
@ -288,6 +316,29 @@ def value_inside_optional(cls: GenericType) -> GenericType:
|
||||
return cls
|
||||
|
||||
|
||||
def get_field_type(cls: GenericType, field_name: str) -> GenericType | None:
|
||||
"""Get the type of a field in a class.
|
||||
|
||||
Args:
|
||||
cls: The class to check.
|
||||
field_name: The name of the field to check.
|
||||
|
||||
Returns:
|
||||
The type of the field, if it exists, else None.
|
||||
"""
|
||||
if (
|
||||
hasattr(cls, "__fields__")
|
||||
and field_name in cls.__fields__
|
||||
and hasattr(cls.__fields__[field_name], "annotation")
|
||||
and not isinstance(cls.__fields__[field_name].annotation, (str, ForwardRef))
|
||||
):
|
||||
return cls.__fields__[field_name].annotation
|
||||
type_hints = get_type_hints(cls)
|
||||
if field_name in type_hints:
|
||||
return type_hints[field_name]
|
||||
return None
|
||||
|
||||
|
||||
def get_property_hint(attr: Any | None) -> GenericType | None:
|
||||
"""Check if an attribute is a property and return its type hint.
|
||||
|
||||
@ -325,24 +376,9 @@ def get_attribute_access_type(cls: GenericType, name: str) -> GenericType | None
|
||||
if hint := get_property_hint(attr):
|
||||
return hint
|
||||
|
||||
if (
|
||||
hasattr(cls, "__fields__")
|
||||
and name in cls.__fields__
|
||||
and hasattr(cls.__fields__[name], "outer_type_")
|
||||
):
|
||||
if hasattr(cls, "__fields__") and name in cls.__fields__:
|
||||
# pydantic models
|
||||
field = cls.__fields__[name]
|
||||
type_ = field.outer_type_
|
||||
if isinstance(type_, ModelField):
|
||||
type_ = type_.type_
|
||||
if (
|
||||
not field.required
|
||||
and field.default is None
|
||||
and field.default_factory is None
|
||||
):
|
||||
# Ensure frontend uses null coalescing when accessing.
|
||||
type_ = Optional[type_]
|
||||
return type_
|
||||
return get_field_type(cls, name)
|
||||
elif isinstance(cls, type) and issubclass(cls, DeclarativeBase):
|
||||
insp = sqlalchemy.inspect(cls)
|
||||
if name in insp.columns:
|
||||
|
Loading…
Reference in New Issue
Block a user