This commit is contained in:
Khaleel Al-Adhami 2025-02-15 14:15:57 -08:00 committed by GitHub
commit a851d5689a
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
4 changed files with 89 additions and 34 deletions

View File

@ -345,12 +345,16 @@ class Component(BaseComponent, ABC):
if field.name not in props:
continue
field_type = types.value_inside_optional(
types.get_field_type(cls, field.name)
)
# Set default values for any props.
if types._issubclass(field.type_, Var):
if types._issubclass(field_type, Var):
field.required = False
if field.default is not None:
field.default = LiteralVar.create(field.default)
elif types._issubclass(field.type_, EventHandler):
elif types._issubclass(field_type, EventHandler):
field.required = False
# Ensure renamed props from parent classes are applied to the subclass.
@ -414,7 +418,9 @@ class Component(BaseComponent, ABC):
field_type = EventChain
elif key in props:
# Set the field type.
field_type = fields[key].type_
field_type = types.value_inside_optional(
types.get_field_type(type(self), key)
)
else:
continue
@ -436,7 +442,10 @@ class Component(BaseComponent, ABC):
try:
kwargs[key] = determine_key(value)
expected_type = fields[key].outer_type_.__args__[0]
expected_type = types.get_args(
types.get_field_type(type(self), key)
)[0]
# validate literal fields.
types.validate_literal(
key, value, expected_type, type(self).__name__
@ -451,7 +460,7 @@ class Component(BaseComponent, ABC):
except TypeError:
# If it is not a valid var, check the base types.
passed_type = type(value)
expected_type = fields[key].outer_type_
expected_type = types.get_field_type(type(self), key)
if types.is_union(passed_type):
# We need to check all possible types in the union.
passed_types = (
@ -563,8 +572,11 @@ class Component(BaseComponent, ABC):
# Look for component specific triggers,
# e.g. variable declared as EventHandler types.
for field in self.get_fields().values():
if types._issubclass(field.outer_type_, EventHandler):
for name, field in self.get_fields().items():
if types._issubclass(
types.value_inside_optional(types.get_field_type(type(self), name)),
EventHandler,
):
args_spec = None
annotation = field.annotation
if (metadata := getattr(annotation, "__metadata__", None)) is not None:
@ -675,9 +687,11 @@ class Component(BaseComponent, ABC):
"""
return {
name
for name, field in cls.get_fields().items()
for name in cls.get_fields()
if name in cls.get_props()
and types._issubclass(field.outer_type_, Component)
and types._issubclass(
types.value_inside_optional(types.get_field_type(cls, name)), Component
)
}
@classmethod

View File

@ -38,7 +38,12 @@ from reflex import constants
from reflex.base import Base
from reflex.utils import console
from reflex.utils.exceptions import ConfigError, EnvironmentVarValueError
from reflex.utils.types import GenericType, is_union, value_inside_optional
from reflex.utils.types import (
GenericType,
is_union,
true_type_for_pydantic_field,
value_inside_optional,
)
try:
import pydantic.v1 as pydantic
@ -939,7 +944,9 @@ class Config(Base):
# If the env var is set, override the config value.
if env_var is not None:
# Interpret the value.
value = interpret_env_var_value(env_var, field.outer_type_, field.name)
value = interpret_env_var_value(
env_var, true_type_for_pydantic_field(field), field.name
)
# Set the value.
updated_values[key] = value

View File

@ -115,9 +115,9 @@ from reflex.utils.serializers import serializer
from reflex.utils.types import (
_isinstance,
get_origin,
is_optional,
is_union,
override,
true_type_for_pydantic_field,
value_inside_optional,
)
from reflex.vars import VarData
@ -293,7 +293,7 @@ if TYPE_CHECKING:
from pydantic.v1.fields import ModelField
def _unwrap_field_type(type_: Type) -> Type:
def _unwrap_field_type(type_: types.GenericType) -> Type:
"""Unwrap rx.Field type annotations.
Args:
@ -324,7 +324,7 @@ def get_var_for_field(cls: Type[BaseState], f: ModelField):
return dispatch(
field_name=field_name,
var_data=VarData.from_state(cls, f.name),
result_var_type=_unwrap_field_type(f.outer_type_),
result_var_type=_unwrap_field_type(true_type_for_pydantic_field(f)),
)
@ -1368,9 +1368,7 @@ class BaseState(Base, ABC, extra=pydantic.Extra.allow):
if name in fields:
field = fields[name]
field_type = _unwrap_field_type(field.outer_type_)
if field.allow_none and not is_optional(field_type):
field_type = Union[field_type, None]
field_type = _unwrap_field_type(true_type_for_pydantic_field(field))
if not _isinstance(value, field_type):
console.error(
f"Expected field '{type(self).__name__}.{name}' to receive type '{field_type}',"

View File

@ -14,6 +14,7 @@ from typing import (
Callable,
ClassVar,
Dict,
ForwardRef,
FrozenSet,
Iterable,
List,
@ -274,6 +275,33 @@ def is_optional(cls: GenericType) -> bool:
return is_union(cls) and type(None) in get_args(cls)
def true_type_for_pydantic_field(f: ModelField):
"""Get the type for a pydantic field.
Args:
f: The field to get the type for.
Returns:
The type for the field.
"""
if not isinstance(f.annotation, (str, ForwardRef)):
return f.annotation
type_ = f.outer_type_
if (
f.field_info.default is None
or (isinstance(f.annotation, str) and f.annotation.startswith("Optional"))
or (
isinstance(f.annotation, ForwardRef)
and f.annotation.__forward_arg__.startswith("Optional")
)
) and not is_optional(type_):
return Optional[type_]
return type_
def value_inside_optional(cls: GenericType) -> GenericType:
"""Get the value inside an Optional type or the original type.
@ -288,6 +316,29 @@ def value_inside_optional(cls: GenericType) -> GenericType:
return cls
def get_field_type(cls: GenericType, field_name: str) -> GenericType | None:
"""Get the type of a field in a class.
Args:
cls: The class to check.
field_name: The name of the field to check.
Returns:
The type of the field, if it exists, else None.
"""
if (
hasattr(cls, "__fields__")
and field_name in cls.__fields__
and hasattr(cls.__fields__[field_name], "annotation")
and not isinstance(cls.__fields__[field_name].annotation, (str, ForwardRef))
):
return cls.__fields__[field_name].annotation
type_hints = get_type_hints(cls)
if field_name in type_hints:
return type_hints[field_name]
return None
def get_property_hint(attr: Any | None) -> GenericType | None:
"""Check if an attribute is a property and return its type hint.
@ -325,24 +376,9 @@ def get_attribute_access_type(cls: GenericType, name: str) -> GenericType | None
if hint := get_property_hint(attr):
return hint
if (
hasattr(cls, "__fields__")
and name in cls.__fields__
and hasattr(cls.__fields__[name], "outer_type_")
):
if hasattr(cls, "__fields__") and name in cls.__fields__:
# pydantic models
field = cls.__fields__[name]
type_ = field.outer_type_
if isinstance(type_, ModelField):
type_ = type_.type_
if (
not field.required
and field.default is None
and field.default_factory is None
):
# Ensure frontend uses null coalescing when accessing.
type_ = Optional[type_]
return type_
return get_field_type(cls, name)
elif isinstance(cls, type) and issubclass(cls, DeclarativeBase):
insp = sqlalchemy.inspect(cls)
if name in insp.columns: