diff --git a/reflex/state.py b/reflex/state.py index c86bd9a1b..5e132ef65 100644 --- a/reflex/state.py +++ b/reflex/state.py @@ -104,6 +104,7 @@ from reflex.utils.serializers import serializer from reflex.utils.types import ( _isinstance, get_origin, + is_optional, is_union, override, value_inside_optional, @@ -278,6 +279,22 @@ if TYPE_CHECKING: from pydantic.v1.fields import ModelField +def _unwrap_field_type(type_: Type) -> Type: + """Unwrap rx.Field type annotations. + + Args: + type_: The type to unwrap. + + Returns: + The unwrapped type. + """ + from reflex.vars import Field + + if get_origin(type_) is Field: + return get_args(type_)[0] + return type_ + + def get_var_for_field(cls: Type[BaseState], f: ModelField): """Get a Var instance for a Pydantic field. @@ -288,16 +305,12 @@ def get_var_for_field(cls: Type[BaseState], f: ModelField): Returns: The Var instance. """ - from reflex.vars import Field - field_name = format.format_state_name(cls.get_full_name()) + "." + f.name return dispatch( field_name=field_name, var_data=VarData.from_state(cls, f.name), - result_var_type=f.outer_type_ - if get_origin(f.outer_type_) is not Field - else get_args(f.outer_type_)[0], + result_var_type=_unwrap_field_type(f.outer_type_), ) @@ -1313,8 +1326,8 @@ class BaseState(Base, ABC, extra=pydantic.Extra.allow): if name in fields: field = fields[name] - field_type = field.outer_type_ - if field.allow_none: + field_type = _unwrap_field_type(field.outer_type_) + if field.allow_none and not is_optional(field_type): field_type = Union[field_type, None] if not _isinstance(value, field_type): console.deprecate(