don't use _outer_type if we don't have to

This commit is contained in:
Khaleel Al-Adhami 2024-12-12 17:55:56 +03:00
parent 1b6f539657
commit c9fe2e10b3
4 changed files with 73 additions and 34 deletions

View File

@ -357,12 +357,16 @@ class Component(BaseComponent, ABC):
if field.name not in props: if field.name not in props:
continue continue
field_type = types.value_inside_optional(
types.get_field_type(cls, field.name)
)
# Set default values for any props. # Set default values for any props.
if types._issubclass(field.type_, Var): if types._issubclass(field_type, Var):
field.required = False field.required = False
if field.default is not None: if field.default is not None:
field.default = LiteralVar.create(field.default) field.default = LiteralVar.create(field.default)
elif types._issubclass(field.type_, EventHandler): elif types._issubclass(field_type, EventHandler):
field.required = False field.required = False
# Ensure renamed props from parent classes are applied to the subclass. # Ensure renamed props from parent classes are applied to the subclass.
@ -426,7 +430,9 @@ class Component(BaseComponent, ABC):
field_type = EventChain field_type = EventChain
elif key in props: elif key in props:
# Set the field type. # Set the field type.
field_type = fields[key].type_ field_type = types.value_inside_optional(
types.get_field_type(type(self), key)
)
else: else:
continue continue
@ -446,7 +452,10 @@ class Component(BaseComponent, ABC):
if kwargs[key] is None: if kwargs[key] is None:
raise TypeError raise TypeError
expected_type = fields[key].outer_type_.__args__[0] expected_type = types.get_args(
types.get_field_type(type(self), key)
)[0]
# validate literal fields. # validate literal fields.
types.validate_literal( types.validate_literal(
key, value, expected_type, type(self).__name__ key, value, expected_type, type(self).__name__
@ -461,7 +470,7 @@ class Component(BaseComponent, ABC):
except TypeError: except TypeError:
# If it is not a valid var, check the base types. # If it is not a valid var, check the base types.
passed_type = type(value) passed_type = type(value)
expected_type = fields[key].outer_type_ expected_type = types.get_field_type(type(self), key)
if types.is_union(passed_type): if types.is_union(passed_type):
# We need to check all possible types in the union. # We need to check all possible types in the union.
passed_types = ( passed_types = (
@ -674,8 +683,11 @@ class Component(BaseComponent, ABC):
# Look for component specific triggers, # Look for component specific triggers,
# e.g. variable declared as EventHandler types. # e.g. variable declared as EventHandler types.
for field in self.get_fields().values(): for name, field in self.get_fields().items():
if types._issubclass(field.outer_type_, EventHandler): if types._issubclass(
types.value_inside_optional(types.get_field_type(type(self), name)),
EventHandler,
):
args_spec = None args_spec = None
annotation = field.annotation annotation = field.annotation
if (metadata := getattr(annotation, "__metadata__", None)) is not None: if (metadata := getattr(annotation, "__metadata__", None)) is not None:
@ -787,9 +799,11 @@ class Component(BaseComponent, ABC):
""" """
return { return {
name name
for name, field in cls.get_fields().items() for name in cls.get_fields()
if name in cls.get_props() if name in cls.get_props()
and types._issubclass(field.outer_type_, Component) and types._issubclass(
types.value_inside_optional(types.get_field_type(cls, name)), Component
)
} }
@classmethod @classmethod

View File

@ -27,7 +27,12 @@ from typing import (
from typing_extensions import Annotated, get_type_hints from typing_extensions import Annotated, get_type_hints
from reflex.utils.exceptions import ConfigError, EnvironmentVarValueError from reflex.utils.exceptions import ConfigError, EnvironmentVarValueError
from reflex.utils.types import GenericType, is_union, value_inside_optional from reflex.utils.types import (
GenericType,
is_union,
true_type_for_pydantic_field,
value_inside_optional,
)
try: try:
import pydantic.v1 as pydantic import pydantic.v1 as pydantic
@ -759,7 +764,9 @@ class Config(Base):
# If the env var is set, override the config value. # If the env var is set, override the config value.
if env_var is not None: if env_var is not None:
# Interpret the value. # Interpret the value.
value = interpret_env_var_value(env_var, field.outer_type_, field.name) value = interpret_env_var_value(
env_var, true_type_for_pydantic_field(field), field.name
)
# Set the value. # Set the value.
updated_values[key] = value updated_values[key] = value

View File

@ -107,9 +107,9 @@ from reflex.utils.serializers import serializer
from reflex.utils.types import ( from reflex.utils.types import (
_isinstance, _isinstance,
get_origin, get_origin,
is_optional,
is_union, is_union,
override, override,
true_type_for_pydantic_field,
value_inside_optional, value_inside_optional,
) )
from reflex.vars import VarData from reflex.vars import VarData
@ -282,7 +282,7 @@ if TYPE_CHECKING:
from pydantic.v1.fields import ModelField from pydantic.v1.fields import ModelField
def _unwrap_field_type(type_: Type) -> Type: def _unwrap_field_type(type_: types.GenericType) -> Type:
"""Unwrap rx.Field type annotations. """Unwrap rx.Field type annotations.
Args: Args:
@ -313,7 +313,7 @@ def get_var_for_field(cls: Type[BaseState], f: ModelField):
return dispatch( return dispatch(
field_name=field_name, field_name=field_name,
var_data=VarData.from_state(cls, f.name), var_data=VarData.from_state(cls, f.name),
result_var_type=_unwrap_field_type(f.outer_type_), result_var_type=_unwrap_field_type(true_type_for_pydantic_field(f)),
) )
@ -1329,9 +1329,7 @@ class BaseState(Base, ABC, extra=pydantic.Extra.allow):
if name in fields: if name in fields:
field = fields[name] field = fields[name]
field_type = _unwrap_field_type(field.outer_type_) field_type = _unwrap_field_type(true_type_for_pydantic_field(field))
if field.allow_none and not is_optional(field_type):
field_type = Union[field_type, None]
if not _isinstance(value, field_type): if not _isinstance(value, field_type):
console.deprecate( console.deprecate(
"mismatched-type-assignment", "mismatched-type-assignment",

View File

@ -269,6 +269,25 @@ def is_optional(cls: GenericType) -> bool:
return is_union(cls) and type(None) in get_args(cls) return is_union(cls) and type(None) in get_args(cls)
def true_type_for_pydantic_field(f: ModelField):
"""Get the type for a pydantic field.
Args:
f: The field to get the type for.
Returns:
The type for the field.
"""
outer_type = f.outer_type_
if (
f.allow_none
and not is_optional(outer_type)
and outer_type not in (None, type(None))
):
return Optional[outer_type]
return outer_type
def value_inside_optional(cls: GenericType) -> GenericType: def value_inside_optional(cls: GenericType) -> GenericType:
"""Get the value inside an Optional type or the original type. """Get the value inside an Optional type or the original type.
@ -283,6 +302,22 @@ def value_inside_optional(cls: GenericType) -> GenericType:
return cls return cls
def get_field_type(cls: GenericType, field_name: str) -> GenericType | None:
"""Get the type of a field in a class.
Args:
cls: The class to check.
field_name: The name of the field to check.
Returns:
The type of the field, if it exists, else None.
"""
type_hints = get_type_hints(cls)
if field_name in type_hints:
return type_hints[field_name]
return None
def get_property_hint(attr: Any | None) -> GenericType | None: def get_property_hint(attr: Any | None) -> GenericType | None:
"""Check if an attribute is a property and return its type hint. """Check if an attribute is a property and return its type hint.
@ -320,24 +355,9 @@ def get_attribute_access_type(cls: GenericType, name: str) -> GenericType | None
if hint := get_property_hint(attr): if hint := get_property_hint(attr):
return hint return hint
if ( if hasattr(cls, "__fields__") and name in cls.__fields__:
hasattr(cls, "__fields__")
and name in cls.__fields__
and hasattr(cls.__fields__[name], "outer_type_")
):
# pydantic models # pydantic models
field = cls.__fields__[name] return get_field_type(cls, name)
type_ = field.outer_type_
if isinstance(type_, ModelField):
type_ = type_.type_
if (
not field.required
and field.default is None
and field.default_factory is None
):
# Ensure frontend uses null coalescing when accessing.
type_ = Optional[type_]
return type_
elif isinstance(cls, type) and issubclass(cls, DeclarativeBase): elif isinstance(cls, type) and issubclass(cls, DeclarativeBase):
insp = sqlalchemy.inspect(cls) insp = sqlalchemy.inspect(cls)
if name in insp.columns: if name in insp.columns: