diff --git a/reflex/components/component.py b/reflex/components/component.py index 005f7791d..41ae80ece 100644 --- a/reflex/components/component.py +++ b/reflex/components/component.py @@ -345,12 +345,16 @@ class Component(BaseComponent, ABC): if field.name not in props: continue + field_type = types.value_inside_optional( + types.get_field_type(cls, field.name) + ) + # Set default values for any props. - if types._issubclass(field.type_, Var): + if types._issubclass(field_type, Var): field.required = False if field.default is not None: field.default = LiteralVar.create(field.default) - elif types._issubclass(field.type_, EventHandler): + elif types._issubclass(field_type, EventHandler): field.required = False # Ensure renamed props from parent classes are applied to the subclass. @@ -414,7 +418,9 @@ class Component(BaseComponent, ABC): field_type = EventChain elif key in props: # Set the field type. - field_type = fields[key].type_ + field_type = types.value_inside_optional( + types.get_field_type(type(self), key) + ) else: continue @@ -436,7 +442,10 @@ class Component(BaseComponent, ABC): try: kwargs[key] = determine_key(value) - expected_type = fields[key].outer_type_.__args__[0] + expected_type = types.get_args( + types.get_field_type(type(self), key) + )[0] + # validate literal fields. types.validate_literal( key, value, expected_type, type(self).__name__ @@ -451,7 +460,7 @@ class Component(BaseComponent, ABC): except TypeError: # If it is not a valid var, check the base types. passed_type = type(value) - expected_type = fields[key].outer_type_ + expected_type = types.get_field_type(type(self), key) if types.is_union(passed_type): # We need to check all possible types in the union. passed_types = ( @@ -563,8 +572,11 @@ class Component(BaseComponent, ABC): # Look for component specific triggers, # e.g. variable declared as EventHandler types. - for field in self.get_fields().values(): - if types._issubclass(field.outer_type_, EventHandler): + for name, field in self.get_fields().items(): + if types._issubclass( + types.value_inside_optional(types.get_field_type(type(self), name)), + EventHandler, + ): args_spec = None annotation = field.annotation if (metadata := getattr(annotation, "__metadata__", None)) is not None: @@ -675,9 +687,11 @@ class Component(BaseComponent, ABC): """ return { name - for name, field in cls.get_fields().items() + for name in cls.get_fields() if name in cls.get_props() - and types._issubclass(field.outer_type_, Component) + and types._issubclass( + types.value_inside_optional(types.get_field_type(cls, name)), Component + ) } @classmethod diff --git a/reflex/config.py b/reflex/config.py index d0829e627..68feedd1b 100644 --- a/reflex/config.py +++ b/reflex/config.py @@ -38,7 +38,12 @@ from reflex import constants from reflex.base import Base from reflex.utils import console from reflex.utils.exceptions import ConfigError, EnvironmentVarValueError -from reflex.utils.types import GenericType, is_union, value_inside_optional +from reflex.utils.types import ( + GenericType, + is_union, + true_type_for_pydantic_field, + value_inside_optional, +) try: import pydantic.v1 as pydantic @@ -939,7 +944,9 @@ class Config(Base): # If the env var is set, override the config value. if env_var is not None: # Interpret the value. - value = interpret_env_var_value(env_var, field.outer_type_, field.name) + value = interpret_env_var_value( + env_var, true_type_for_pydantic_field(field), field.name + ) # Set the value. updated_values[key] = value diff --git a/reflex/state.py b/reflex/state.py index 77c352cfa..a0578af7d 100644 --- a/reflex/state.py +++ b/reflex/state.py @@ -115,9 +115,9 @@ from reflex.utils.serializers import serializer from reflex.utils.types import ( _isinstance, get_origin, - is_optional, is_union, override, + true_type_for_pydantic_field, value_inside_optional, ) from reflex.vars import VarData @@ -293,7 +293,7 @@ if TYPE_CHECKING: from pydantic.v1.fields import ModelField -def _unwrap_field_type(type_: Type) -> Type: +def _unwrap_field_type(type_: types.GenericType) -> Type: """Unwrap rx.Field type annotations. Args: @@ -324,7 +324,7 @@ def get_var_for_field(cls: Type[BaseState], f: ModelField): return dispatch( field_name=field_name, var_data=VarData.from_state(cls, f.name), - result_var_type=_unwrap_field_type(f.outer_type_), + result_var_type=_unwrap_field_type(true_type_for_pydantic_field(f)), ) @@ -1368,9 +1368,7 @@ class BaseState(Base, ABC, extra=pydantic.Extra.allow): if name in fields: field = fields[name] - field_type = _unwrap_field_type(field.outer_type_) - if field.allow_none and not is_optional(field_type): - field_type = Union[field_type, None] + field_type = _unwrap_field_type(true_type_for_pydantic_field(field)) if not _isinstance(value, field_type): console.error( f"Expected field '{type(self).__name__}.{name}' to receive type '{field_type}'," diff --git a/reflex/utils/types.py b/reflex/utils/types.py index b432319e0..cdca50388 100644 --- a/reflex/utils/types.py +++ b/reflex/utils/types.py @@ -14,6 +14,7 @@ from typing import ( Callable, ClassVar, Dict, + ForwardRef, FrozenSet, Iterable, List, @@ -274,6 +275,33 @@ def is_optional(cls: GenericType) -> bool: return is_union(cls) and type(None) in get_args(cls) +def true_type_for_pydantic_field(f: ModelField): + """Get the type for a pydantic field. + + Args: + f: The field to get the type for. + + Returns: + The type for the field. + """ + if not isinstance(f.annotation, (str, ForwardRef)): + return f.annotation + + type_ = f.outer_type_ + + if ( + f.field_info.default is None + or (isinstance(f.annotation, str) and f.annotation.startswith("Optional")) + or ( + isinstance(f.annotation, ForwardRef) + and f.annotation.__forward_arg__.startswith("Optional") + ) + ) and not is_optional(type_): + return Optional[type_] + + return type_ + + def value_inside_optional(cls: GenericType) -> GenericType: """Get the value inside an Optional type or the original type. @@ -288,6 +316,29 @@ def value_inside_optional(cls: GenericType) -> GenericType: return cls +def get_field_type(cls: GenericType, field_name: str) -> GenericType | None: + """Get the type of a field in a class. + + Args: + cls: The class to check. + field_name: The name of the field to check. + + Returns: + The type of the field, if it exists, else None. + """ + if ( + hasattr(cls, "__fields__") + and field_name in cls.__fields__ + and hasattr(cls.__fields__[field_name], "annotation") + and not isinstance(cls.__fields__[field_name].annotation, (str, ForwardRef)) + ): + return cls.__fields__[field_name].annotation + type_hints = get_type_hints(cls) + if field_name in type_hints: + return type_hints[field_name] + return None + + def get_property_hint(attr: Any | None) -> GenericType | None: """Check if an attribute is a property and return its type hint. @@ -325,24 +376,9 @@ def get_attribute_access_type(cls: GenericType, name: str) -> GenericType | None if hint := get_property_hint(attr): return hint - if ( - hasattr(cls, "__fields__") - and name in cls.__fields__ - and hasattr(cls.__fields__[name], "outer_type_") - ): + if hasattr(cls, "__fields__") and name in cls.__fields__: # pydantic models - field = cls.__fields__[name] - type_ = field.outer_type_ - if isinstance(type_, ModelField): - type_ = type_.type_ - if ( - not field.required - and field.default is None - and field.default_factory is None - ): - # Ensure frontend uses null coalescing when accessing. - type_ = Optional[type_] - return type_ + return get_field_type(cls, name) elif isinstance(cls, type) and issubclass(cls, DeclarativeBase): insp = sqlalchemy.inspect(cls) if name in insp.columns: