Merge 9e963338ef
into 6848915883
This commit is contained in:
commit
a851d5689a
@ -345,12 +345,16 @@ class Component(BaseComponent, ABC):
|
|||||||
if field.name not in props:
|
if field.name not in props:
|
||||||
continue
|
continue
|
||||||
|
|
||||||
|
field_type = types.value_inside_optional(
|
||||||
|
types.get_field_type(cls, field.name)
|
||||||
|
)
|
||||||
|
|
||||||
# Set default values for any props.
|
# Set default values for any props.
|
||||||
if types._issubclass(field.type_, Var):
|
if types._issubclass(field_type, Var):
|
||||||
field.required = False
|
field.required = False
|
||||||
if field.default is not None:
|
if field.default is not None:
|
||||||
field.default = LiteralVar.create(field.default)
|
field.default = LiteralVar.create(field.default)
|
||||||
elif types._issubclass(field.type_, EventHandler):
|
elif types._issubclass(field_type, EventHandler):
|
||||||
field.required = False
|
field.required = False
|
||||||
|
|
||||||
# Ensure renamed props from parent classes are applied to the subclass.
|
# Ensure renamed props from parent classes are applied to the subclass.
|
||||||
@ -414,7 +418,9 @@ class Component(BaseComponent, ABC):
|
|||||||
field_type = EventChain
|
field_type = EventChain
|
||||||
elif key in props:
|
elif key in props:
|
||||||
# Set the field type.
|
# Set the field type.
|
||||||
field_type = fields[key].type_
|
field_type = types.value_inside_optional(
|
||||||
|
types.get_field_type(type(self), key)
|
||||||
|
)
|
||||||
|
|
||||||
else:
|
else:
|
||||||
continue
|
continue
|
||||||
@ -436,7 +442,10 @@ class Component(BaseComponent, ABC):
|
|||||||
try:
|
try:
|
||||||
kwargs[key] = determine_key(value)
|
kwargs[key] = determine_key(value)
|
||||||
|
|
||||||
expected_type = fields[key].outer_type_.__args__[0]
|
expected_type = types.get_args(
|
||||||
|
types.get_field_type(type(self), key)
|
||||||
|
)[0]
|
||||||
|
|
||||||
# validate literal fields.
|
# validate literal fields.
|
||||||
types.validate_literal(
|
types.validate_literal(
|
||||||
key, value, expected_type, type(self).__name__
|
key, value, expected_type, type(self).__name__
|
||||||
@ -451,7 +460,7 @@ class Component(BaseComponent, ABC):
|
|||||||
except TypeError:
|
except TypeError:
|
||||||
# If it is not a valid var, check the base types.
|
# If it is not a valid var, check the base types.
|
||||||
passed_type = type(value)
|
passed_type = type(value)
|
||||||
expected_type = fields[key].outer_type_
|
expected_type = types.get_field_type(type(self), key)
|
||||||
if types.is_union(passed_type):
|
if types.is_union(passed_type):
|
||||||
# We need to check all possible types in the union.
|
# We need to check all possible types in the union.
|
||||||
passed_types = (
|
passed_types = (
|
||||||
@ -563,8 +572,11 @@ class Component(BaseComponent, ABC):
|
|||||||
|
|
||||||
# Look for component specific triggers,
|
# Look for component specific triggers,
|
||||||
# e.g. variable declared as EventHandler types.
|
# e.g. variable declared as EventHandler types.
|
||||||
for field in self.get_fields().values():
|
for name, field in self.get_fields().items():
|
||||||
if types._issubclass(field.outer_type_, EventHandler):
|
if types._issubclass(
|
||||||
|
types.value_inside_optional(types.get_field_type(type(self), name)),
|
||||||
|
EventHandler,
|
||||||
|
):
|
||||||
args_spec = None
|
args_spec = None
|
||||||
annotation = field.annotation
|
annotation = field.annotation
|
||||||
if (metadata := getattr(annotation, "__metadata__", None)) is not None:
|
if (metadata := getattr(annotation, "__metadata__", None)) is not None:
|
||||||
@ -675,9 +687,11 @@ class Component(BaseComponent, ABC):
|
|||||||
"""
|
"""
|
||||||
return {
|
return {
|
||||||
name
|
name
|
||||||
for name, field in cls.get_fields().items()
|
for name in cls.get_fields()
|
||||||
if name in cls.get_props()
|
if name in cls.get_props()
|
||||||
and types._issubclass(field.outer_type_, Component)
|
and types._issubclass(
|
||||||
|
types.value_inside_optional(types.get_field_type(cls, name)), Component
|
||||||
|
)
|
||||||
}
|
}
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
|
@ -38,7 +38,12 @@ from reflex import constants
|
|||||||
from reflex.base import Base
|
from reflex.base import Base
|
||||||
from reflex.utils import console
|
from reflex.utils import console
|
||||||
from reflex.utils.exceptions import ConfigError, EnvironmentVarValueError
|
from reflex.utils.exceptions import ConfigError, EnvironmentVarValueError
|
||||||
from reflex.utils.types import GenericType, is_union, value_inside_optional
|
from reflex.utils.types import (
|
||||||
|
GenericType,
|
||||||
|
is_union,
|
||||||
|
true_type_for_pydantic_field,
|
||||||
|
value_inside_optional,
|
||||||
|
)
|
||||||
|
|
||||||
try:
|
try:
|
||||||
import pydantic.v1 as pydantic
|
import pydantic.v1 as pydantic
|
||||||
@ -939,7 +944,9 @@ class Config(Base):
|
|||||||
# If the env var is set, override the config value.
|
# If the env var is set, override the config value.
|
||||||
if env_var is not None:
|
if env_var is not None:
|
||||||
# Interpret the value.
|
# Interpret the value.
|
||||||
value = interpret_env_var_value(env_var, field.outer_type_, field.name)
|
value = interpret_env_var_value(
|
||||||
|
env_var, true_type_for_pydantic_field(field), field.name
|
||||||
|
)
|
||||||
|
|
||||||
# Set the value.
|
# Set the value.
|
||||||
updated_values[key] = value
|
updated_values[key] = value
|
||||||
|
@ -115,9 +115,9 @@ from reflex.utils.serializers import serializer
|
|||||||
from reflex.utils.types import (
|
from reflex.utils.types import (
|
||||||
_isinstance,
|
_isinstance,
|
||||||
get_origin,
|
get_origin,
|
||||||
is_optional,
|
|
||||||
is_union,
|
is_union,
|
||||||
override,
|
override,
|
||||||
|
true_type_for_pydantic_field,
|
||||||
value_inside_optional,
|
value_inside_optional,
|
||||||
)
|
)
|
||||||
from reflex.vars import VarData
|
from reflex.vars import VarData
|
||||||
@ -293,7 +293,7 @@ if TYPE_CHECKING:
|
|||||||
from pydantic.v1.fields import ModelField
|
from pydantic.v1.fields import ModelField
|
||||||
|
|
||||||
|
|
||||||
def _unwrap_field_type(type_: Type) -> Type:
|
def _unwrap_field_type(type_: types.GenericType) -> Type:
|
||||||
"""Unwrap rx.Field type annotations.
|
"""Unwrap rx.Field type annotations.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
@ -324,7 +324,7 @@ def get_var_for_field(cls: Type[BaseState], f: ModelField):
|
|||||||
return dispatch(
|
return dispatch(
|
||||||
field_name=field_name,
|
field_name=field_name,
|
||||||
var_data=VarData.from_state(cls, f.name),
|
var_data=VarData.from_state(cls, f.name),
|
||||||
result_var_type=_unwrap_field_type(f.outer_type_),
|
result_var_type=_unwrap_field_type(true_type_for_pydantic_field(f)),
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
@ -1368,9 +1368,7 @@ class BaseState(Base, ABC, extra=pydantic.Extra.allow):
|
|||||||
|
|
||||||
if name in fields:
|
if name in fields:
|
||||||
field = fields[name]
|
field = fields[name]
|
||||||
field_type = _unwrap_field_type(field.outer_type_)
|
field_type = _unwrap_field_type(true_type_for_pydantic_field(field))
|
||||||
if field.allow_none and not is_optional(field_type):
|
|
||||||
field_type = Union[field_type, None]
|
|
||||||
if not _isinstance(value, field_type):
|
if not _isinstance(value, field_type):
|
||||||
console.error(
|
console.error(
|
||||||
f"Expected field '{type(self).__name__}.{name}' to receive type '{field_type}',"
|
f"Expected field '{type(self).__name__}.{name}' to receive type '{field_type}',"
|
||||||
|
@ -14,6 +14,7 @@ from typing import (
|
|||||||
Callable,
|
Callable,
|
||||||
ClassVar,
|
ClassVar,
|
||||||
Dict,
|
Dict,
|
||||||
|
ForwardRef,
|
||||||
FrozenSet,
|
FrozenSet,
|
||||||
Iterable,
|
Iterable,
|
||||||
List,
|
List,
|
||||||
@ -274,6 +275,33 @@ def is_optional(cls: GenericType) -> bool:
|
|||||||
return is_union(cls) and type(None) in get_args(cls)
|
return is_union(cls) and type(None) in get_args(cls)
|
||||||
|
|
||||||
|
|
||||||
|
def true_type_for_pydantic_field(f: ModelField):
|
||||||
|
"""Get the type for a pydantic field.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
f: The field to get the type for.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
The type for the field.
|
||||||
|
"""
|
||||||
|
if not isinstance(f.annotation, (str, ForwardRef)):
|
||||||
|
return f.annotation
|
||||||
|
|
||||||
|
type_ = f.outer_type_
|
||||||
|
|
||||||
|
if (
|
||||||
|
f.field_info.default is None
|
||||||
|
or (isinstance(f.annotation, str) and f.annotation.startswith("Optional"))
|
||||||
|
or (
|
||||||
|
isinstance(f.annotation, ForwardRef)
|
||||||
|
and f.annotation.__forward_arg__.startswith("Optional")
|
||||||
|
)
|
||||||
|
) and not is_optional(type_):
|
||||||
|
return Optional[type_]
|
||||||
|
|
||||||
|
return type_
|
||||||
|
|
||||||
|
|
||||||
def value_inside_optional(cls: GenericType) -> GenericType:
|
def value_inside_optional(cls: GenericType) -> GenericType:
|
||||||
"""Get the value inside an Optional type or the original type.
|
"""Get the value inside an Optional type or the original type.
|
||||||
|
|
||||||
@ -288,6 +316,29 @@ def value_inside_optional(cls: GenericType) -> GenericType:
|
|||||||
return cls
|
return cls
|
||||||
|
|
||||||
|
|
||||||
|
def get_field_type(cls: GenericType, field_name: str) -> GenericType | None:
|
||||||
|
"""Get the type of a field in a class.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
cls: The class to check.
|
||||||
|
field_name: The name of the field to check.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
The type of the field, if it exists, else None.
|
||||||
|
"""
|
||||||
|
if (
|
||||||
|
hasattr(cls, "__fields__")
|
||||||
|
and field_name in cls.__fields__
|
||||||
|
and hasattr(cls.__fields__[field_name], "annotation")
|
||||||
|
and not isinstance(cls.__fields__[field_name].annotation, (str, ForwardRef))
|
||||||
|
):
|
||||||
|
return cls.__fields__[field_name].annotation
|
||||||
|
type_hints = get_type_hints(cls)
|
||||||
|
if field_name in type_hints:
|
||||||
|
return type_hints[field_name]
|
||||||
|
return None
|
||||||
|
|
||||||
|
|
||||||
def get_property_hint(attr: Any | None) -> GenericType | None:
|
def get_property_hint(attr: Any | None) -> GenericType | None:
|
||||||
"""Check if an attribute is a property and return its type hint.
|
"""Check if an attribute is a property and return its type hint.
|
||||||
|
|
||||||
@ -325,24 +376,9 @@ def get_attribute_access_type(cls: GenericType, name: str) -> GenericType | None
|
|||||||
if hint := get_property_hint(attr):
|
if hint := get_property_hint(attr):
|
||||||
return hint
|
return hint
|
||||||
|
|
||||||
if (
|
if hasattr(cls, "__fields__") and name in cls.__fields__:
|
||||||
hasattr(cls, "__fields__")
|
|
||||||
and name in cls.__fields__
|
|
||||||
and hasattr(cls.__fields__[name], "outer_type_")
|
|
||||||
):
|
|
||||||
# pydantic models
|
# pydantic models
|
||||||
field = cls.__fields__[name]
|
return get_field_type(cls, name)
|
||||||
type_ = field.outer_type_
|
|
||||||
if isinstance(type_, ModelField):
|
|
||||||
type_ = type_.type_
|
|
||||||
if (
|
|
||||||
not field.required
|
|
||||||
and field.default is None
|
|
||||||
and field.default_factory is None
|
|
||||||
):
|
|
||||||
# Ensure frontend uses null coalescing when accessing.
|
|
||||||
type_ = Optional[type_]
|
|
||||||
return type_
|
|
||||||
elif isinstance(cls, type) and issubclass(cls, DeclarativeBase):
|
elif isinstance(cls, type) and issubclass(cls, DeclarativeBase):
|
||||||
insp = sqlalchemy.inspect(cls)
|
insp = sqlalchemy.inspect(cls)
|
||||||
if name in insp.columns:
|
if name in insp.columns:
|
||||||
|
Loading…
Reference in New Issue
Block a user