Compare commits

...

5 Commits

Author SHA1 Message Date
Khaleel Al-Adhami
9e963338ef add check for default value being null 2024-12-17 19:26:58 +03:00
Khaleel Al-Adhami
3fb8450fb5 check against optional in annotation str 2024-12-13 23:22:47 +03:00
Khaleel Al-Adhami
d3fbf123b5 have a shorter path for get_field_type if it's nice 2024-12-13 07:14:33 +03:00
Khaleel Al-Adhami
c2a39b46d7 apparently we should use .annotation, and .allow_none is useless 2024-12-13 07:12:16 +03:00
Khaleel Al-Adhami
c9fe2e10b3 don't use _outer_type if we don't have to 2024-12-12 17:55:56 +03:00
4 changed files with 89 additions and 34 deletions

View File

@ -357,12 +357,16 @@ class Component(BaseComponent, ABC):
if field.name not in props:
continue
field_type = types.value_inside_optional(
types.get_field_type(cls, field.name)
)
# Set default values for any props.
if types._issubclass(field.type_, Var):
if types._issubclass(field_type, Var):
field.required = False
if field.default is not None:
field.default = LiteralVar.create(field.default)
elif types._issubclass(field.type_, EventHandler):
elif types._issubclass(field_type, EventHandler):
field.required = False
# Ensure renamed props from parent classes are applied to the subclass.
@ -426,7 +430,9 @@ class Component(BaseComponent, ABC):
field_type = EventChain
elif key in props:
# Set the field type.
field_type = fields[key].type_
field_type = types.value_inside_optional(
types.get_field_type(type(self), key)
)
else:
continue
@ -446,7 +452,10 @@ class Component(BaseComponent, ABC):
if kwargs[key] is None:
raise TypeError
expected_type = fields[key].outer_type_.__args__[0]
expected_type = types.get_args(
types.get_field_type(type(self), key)
)[0]
# validate literal fields.
types.validate_literal(
key, value, expected_type, type(self).__name__
@ -461,7 +470,7 @@ class Component(BaseComponent, ABC):
except TypeError:
# If it is not a valid var, check the base types.
passed_type = type(value)
expected_type = fields[key].outer_type_
expected_type = types.get_field_type(type(self), key)
if types.is_union(passed_type):
# We need to check all possible types in the union.
passed_types = (
@ -674,8 +683,11 @@ class Component(BaseComponent, ABC):
# Look for component specific triggers,
# e.g. variable declared as EventHandler types.
for field in self.get_fields().values():
if types._issubclass(field.outer_type_, EventHandler):
for name, field in self.get_fields().items():
if types._issubclass(
types.value_inside_optional(types.get_field_type(type(self), name)),
EventHandler,
):
args_spec = None
annotation = field.annotation
if (metadata := getattr(annotation, "__metadata__", None)) is not None:
@ -787,9 +799,11 @@ class Component(BaseComponent, ABC):
"""
return {
name
for name, field in cls.get_fields().items()
for name in cls.get_fields()
if name in cls.get_props()
and types._issubclass(field.outer_type_, Component)
and types._issubclass(
types.value_inside_optional(types.get_field_type(cls, name)), Component
)
}
@classmethod

View File

@ -27,7 +27,12 @@ from typing import (
from typing_extensions import Annotated, get_type_hints
from reflex.utils.exceptions import ConfigError, EnvironmentVarValueError
from reflex.utils.types import GenericType, is_union, value_inside_optional
from reflex.utils.types import (
GenericType,
is_union,
true_type_for_pydantic_field,
value_inside_optional,
)
try:
import pydantic.v1 as pydantic
@ -759,7 +764,9 @@ class Config(Base):
# If the env var is set, override the config value.
if env_var is not None:
# Interpret the value.
value = interpret_env_var_value(env_var, field.outer_type_, field.name)
value = interpret_env_var_value(
env_var, true_type_for_pydantic_field(field), field.name
)
# Set the value.
updated_values[key] = value

View File

@ -107,9 +107,9 @@ from reflex.utils.serializers import serializer
from reflex.utils.types import (
_isinstance,
get_origin,
is_optional,
is_union,
override,
true_type_for_pydantic_field,
value_inside_optional,
)
from reflex.vars import VarData
@ -282,7 +282,7 @@ if TYPE_CHECKING:
from pydantic.v1.fields import ModelField
def _unwrap_field_type(type_: Type) -> Type:
def _unwrap_field_type(type_: types.GenericType) -> Type:
"""Unwrap rx.Field type annotations.
Args:
@ -313,7 +313,7 @@ def get_var_for_field(cls: Type[BaseState], f: ModelField):
return dispatch(
field_name=field_name,
var_data=VarData.from_state(cls, f.name),
result_var_type=_unwrap_field_type(f.outer_type_),
result_var_type=_unwrap_field_type(true_type_for_pydantic_field(f)),
)
@ -1329,9 +1329,7 @@ class BaseState(Base, ABC, extra=pydantic.Extra.allow):
if name in fields:
field = fields[name]
field_type = _unwrap_field_type(field.outer_type_)
if field.allow_none and not is_optional(field_type):
field_type = Union[field_type, None]
field_type = _unwrap_field_type(true_type_for_pydantic_field(field))
if not _isinstance(value, field_type):
console.deprecate(
"mismatched-type-assignment",

View File

@ -14,6 +14,7 @@ from typing import (
Callable,
ClassVar,
Dict,
ForwardRef,
FrozenSet,
Iterable,
List,
@ -269,6 +270,33 @@ def is_optional(cls: GenericType) -> bool:
return is_union(cls) and type(None) in get_args(cls)
def true_type_for_pydantic_field(f: ModelField):
"""Get the type for a pydantic field.
Args:
f: The field to get the type for.
Returns:
The type for the field.
"""
if not isinstance(f.annotation, (str, ForwardRef)):
return f.annotation
type_ = f.outer_type_
if (
f.field_info.default is None
or (isinstance(f.annotation, str) and f.annotation.startswith("Optional"))
or (
isinstance(f.annotation, ForwardRef)
and f.annotation.__forward_arg__.startswith("Optional")
)
) and not is_optional(type_):
return Optional[type_]
return type_
def value_inside_optional(cls: GenericType) -> GenericType:
"""Get the value inside an Optional type or the original type.
@ -283,6 +311,29 @@ def value_inside_optional(cls: GenericType) -> GenericType:
return cls
def get_field_type(cls: GenericType, field_name: str) -> GenericType | None:
"""Get the type of a field in a class.
Args:
cls: The class to check.
field_name: The name of the field to check.
Returns:
The type of the field, if it exists, else None.
"""
if (
hasattr(cls, "__fields__")
and field_name in cls.__fields__
and hasattr(cls.__fields__[field_name], "annotation")
and not isinstance(cls.__fields__[field_name].annotation, (str, ForwardRef))
):
return cls.__fields__[field_name].annotation
type_hints = get_type_hints(cls)
if field_name in type_hints:
return type_hints[field_name]
return None
def get_property_hint(attr: Any | None) -> GenericType | None:
"""Check if an attribute is a property and return its type hint.
@ -320,24 +371,9 @@ def get_attribute_access_type(cls: GenericType, name: str) -> GenericType | None
if hint := get_property_hint(attr):
return hint
if (
hasattr(cls, "__fields__")
and name in cls.__fields__
and hasattr(cls.__fields__[name], "outer_type_")
):
if hasattr(cls, "__fields__") and name in cls.__fields__:
# pydantic models
field = cls.__fields__[name]
type_ = field.outer_type_
if isinstance(type_, ModelField):
type_ = type_.type_
if (
not field.required
and field.default is None
and field.default_factory is None
):
# Ensure frontend uses null coalescing when accessing.
type_ = Optional[type_]
return type_
return get_field_type(cls, name)
elif isinstance(cls, type) and issubclass(cls, DeclarativeBase):
insp = sqlalchemy.inspect(cls)
if name in insp.columns: