This commit is contained in:
Khaleel Al-Adhami 2025-02-15 14:15:57 -08:00 committed by GitHub
commit a851d5689a
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
4 changed files with 89 additions and 34 deletions

View File

@ -345,12 +345,16 @@ class Component(BaseComponent, ABC):
if field.name not in props: if field.name not in props:
continue continue
field_type = types.value_inside_optional(
types.get_field_type(cls, field.name)
)
# Set default values for any props. # Set default values for any props.
if types._issubclass(field.type_, Var): if types._issubclass(field_type, Var):
field.required = False field.required = False
if field.default is not None: if field.default is not None:
field.default = LiteralVar.create(field.default) field.default = LiteralVar.create(field.default)
elif types._issubclass(field.type_, EventHandler): elif types._issubclass(field_type, EventHandler):
field.required = False field.required = False
# Ensure renamed props from parent classes are applied to the subclass. # Ensure renamed props from parent classes are applied to the subclass.
@ -414,7 +418,9 @@ class Component(BaseComponent, ABC):
field_type = EventChain field_type = EventChain
elif key in props: elif key in props:
# Set the field type. # Set the field type.
field_type = fields[key].type_ field_type = types.value_inside_optional(
types.get_field_type(type(self), key)
)
else: else:
continue continue
@ -436,7 +442,10 @@ class Component(BaseComponent, ABC):
try: try:
kwargs[key] = determine_key(value) kwargs[key] = determine_key(value)
expected_type = fields[key].outer_type_.__args__[0] expected_type = types.get_args(
types.get_field_type(type(self), key)
)[0]
# validate literal fields. # validate literal fields.
types.validate_literal( types.validate_literal(
key, value, expected_type, type(self).__name__ key, value, expected_type, type(self).__name__
@ -451,7 +460,7 @@ class Component(BaseComponent, ABC):
except TypeError: except TypeError:
# If it is not a valid var, check the base types. # If it is not a valid var, check the base types.
passed_type = type(value) passed_type = type(value)
expected_type = fields[key].outer_type_ expected_type = types.get_field_type(type(self), key)
if types.is_union(passed_type): if types.is_union(passed_type):
# We need to check all possible types in the union. # We need to check all possible types in the union.
passed_types = ( passed_types = (
@ -563,8 +572,11 @@ class Component(BaseComponent, ABC):
# Look for component specific triggers, # Look for component specific triggers,
# e.g. variable declared as EventHandler types. # e.g. variable declared as EventHandler types.
for field in self.get_fields().values(): for name, field in self.get_fields().items():
if types._issubclass(field.outer_type_, EventHandler): if types._issubclass(
types.value_inside_optional(types.get_field_type(type(self), name)),
EventHandler,
):
args_spec = None args_spec = None
annotation = field.annotation annotation = field.annotation
if (metadata := getattr(annotation, "__metadata__", None)) is not None: if (metadata := getattr(annotation, "__metadata__", None)) is not None:
@ -675,9 +687,11 @@ class Component(BaseComponent, ABC):
""" """
return { return {
name name
for name, field in cls.get_fields().items() for name in cls.get_fields()
if name in cls.get_props() if name in cls.get_props()
and types._issubclass(field.outer_type_, Component) and types._issubclass(
types.value_inside_optional(types.get_field_type(cls, name)), Component
)
} }
@classmethod @classmethod

View File

@ -38,7 +38,12 @@ from reflex import constants
from reflex.base import Base from reflex.base import Base
from reflex.utils import console from reflex.utils import console
from reflex.utils.exceptions import ConfigError, EnvironmentVarValueError from reflex.utils.exceptions import ConfigError, EnvironmentVarValueError
from reflex.utils.types import GenericType, is_union, value_inside_optional from reflex.utils.types import (
GenericType,
is_union,
true_type_for_pydantic_field,
value_inside_optional,
)
try: try:
import pydantic.v1 as pydantic import pydantic.v1 as pydantic
@ -939,7 +944,9 @@ class Config(Base):
# If the env var is set, override the config value. # If the env var is set, override the config value.
if env_var is not None: if env_var is not None:
# Interpret the value. # Interpret the value.
value = interpret_env_var_value(env_var, field.outer_type_, field.name) value = interpret_env_var_value(
env_var, true_type_for_pydantic_field(field), field.name
)
# Set the value. # Set the value.
updated_values[key] = value updated_values[key] = value

View File

@ -115,9 +115,9 @@ from reflex.utils.serializers import serializer
from reflex.utils.types import ( from reflex.utils.types import (
_isinstance, _isinstance,
get_origin, get_origin,
is_optional,
is_union, is_union,
override, override,
true_type_for_pydantic_field,
value_inside_optional, value_inside_optional,
) )
from reflex.vars import VarData from reflex.vars import VarData
@ -293,7 +293,7 @@ if TYPE_CHECKING:
from pydantic.v1.fields import ModelField from pydantic.v1.fields import ModelField
def _unwrap_field_type(type_: Type) -> Type: def _unwrap_field_type(type_: types.GenericType) -> Type:
"""Unwrap rx.Field type annotations. """Unwrap rx.Field type annotations.
Args: Args:
@ -324,7 +324,7 @@ def get_var_for_field(cls: Type[BaseState], f: ModelField):
return dispatch( return dispatch(
field_name=field_name, field_name=field_name,
var_data=VarData.from_state(cls, f.name), var_data=VarData.from_state(cls, f.name),
result_var_type=_unwrap_field_type(f.outer_type_), result_var_type=_unwrap_field_type(true_type_for_pydantic_field(f)),
) )
@ -1368,9 +1368,7 @@ class BaseState(Base, ABC, extra=pydantic.Extra.allow):
if name in fields: if name in fields:
field = fields[name] field = fields[name]
field_type = _unwrap_field_type(field.outer_type_) field_type = _unwrap_field_type(true_type_for_pydantic_field(field))
if field.allow_none and not is_optional(field_type):
field_type = Union[field_type, None]
if not _isinstance(value, field_type): if not _isinstance(value, field_type):
console.error( console.error(
f"Expected field '{type(self).__name__}.{name}' to receive type '{field_type}'," f"Expected field '{type(self).__name__}.{name}' to receive type '{field_type}',"

View File

@ -14,6 +14,7 @@ from typing import (
Callable, Callable,
ClassVar, ClassVar,
Dict, Dict,
ForwardRef,
FrozenSet, FrozenSet,
Iterable, Iterable,
List, List,
@ -274,6 +275,33 @@ def is_optional(cls: GenericType) -> bool:
return is_union(cls) and type(None) in get_args(cls) return is_union(cls) and type(None) in get_args(cls)
def true_type_for_pydantic_field(f: ModelField):
"""Get the type for a pydantic field.
Args:
f: The field to get the type for.
Returns:
The type for the field.
"""
if not isinstance(f.annotation, (str, ForwardRef)):
return f.annotation
type_ = f.outer_type_
if (
f.field_info.default is None
or (isinstance(f.annotation, str) and f.annotation.startswith("Optional"))
or (
isinstance(f.annotation, ForwardRef)
and f.annotation.__forward_arg__.startswith("Optional")
)
) and not is_optional(type_):
return Optional[type_]
return type_
def value_inside_optional(cls: GenericType) -> GenericType: def value_inside_optional(cls: GenericType) -> GenericType:
"""Get the value inside an Optional type or the original type. """Get the value inside an Optional type or the original type.
@ -288,6 +316,29 @@ def value_inside_optional(cls: GenericType) -> GenericType:
return cls return cls
def get_field_type(cls: GenericType, field_name: str) -> GenericType | None:
"""Get the type of a field in a class.
Args:
cls: The class to check.
field_name: The name of the field to check.
Returns:
The type of the field, if it exists, else None.
"""
if (
hasattr(cls, "__fields__")
and field_name in cls.__fields__
and hasattr(cls.__fields__[field_name], "annotation")
and not isinstance(cls.__fields__[field_name].annotation, (str, ForwardRef))
):
return cls.__fields__[field_name].annotation
type_hints = get_type_hints(cls)
if field_name in type_hints:
return type_hints[field_name]
return None
def get_property_hint(attr: Any | None) -> GenericType | None: def get_property_hint(attr: Any | None) -> GenericType | None:
"""Check if an attribute is a property and return its type hint. """Check if an attribute is a property and return its type hint.
@ -325,24 +376,9 @@ def get_attribute_access_type(cls: GenericType, name: str) -> GenericType | None
if hint := get_property_hint(attr): if hint := get_property_hint(attr):
return hint return hint
if ( if hasattr(cls, "__fields__") and name in cls.__fields__:
hasattr(cls, "__fields__")
and name in cls.__fields__
and hasattr(cls.__fields__[name], "outer_type_")
):
# pydantic models # pydantic models
field = cls.__fields__[name] return get_field_type(cls, name)
type_ = field.outer_type_
if isinstance(type_, ModelField):
type_ = type_.type_
if (
not field.required
and field.default is None
and field.default_factory is None
):
# Ensure frontend uses null coalescing when accessing.
type_ = Optional[type_]
return type_
elif isinstance(cls, type) and issubclass(cls, DeclarativeBase): elif isinstance(cls, type) and issubclass(cls, DeclarativeBase):
insp = sqlalchemy.inspect(cls) insp = sqlalchemy.inspect(cls)
if name in insp.columns: if name in insp.columns: