From a4f8e18ff115b3d3a1d212305bdbba9c3d79570a Mon Sep 17 00:00:00 2001 From: Benedikt Bartscher Date: Wed, 4 Dec 2024 23:49:39 +0100 Subject: [PATCH] fix: only make type optional if it's not already, add helper to unwrap rx.Field --- reflex/state.py | 27 ++++++++++++++++++++------- 1 file changed, 20 insertions(+), 7 deletions(-) diff --git a/reflex/state.py b/reflex/state.py index 55f29cf45..5317c109f 100644 --- a/reflex/state.py +++ b/reflex/state.py @@ -104,6 +104,7 @@ from reflex.utils.serializers import serializer from reflex.utils.types import ( _isinstance, get_origin, + is_optional, is_union, override, value_inside_optional, @@ -278,6 +279,22 @@ if TYPE_CHECKING: from pydantic.v1.fields import ModelField +def unwrap_field_type(type_: Type) -> Type: + """Unwrap rx.Field type annotations. + + Args: + type_: The type to unwrap. + + Returns: + The unwrapped type. + """ + from reflex.vars import Field + + if get_origin(type_) is Field: + return get_args(type_)[0] + return type_ + + def get_var_for_field(cls: Type[BaseState], f: ModelField): """Get a Var instance for a Pydantic field. @@ -288,16 +305,12 @@ def get_var_for_field(cls: Type[BaseState], f: ModelField): Returns: The Var instance. """ - from reflex.vars import Field - field_name = format.format_state_name(cls.get_full_name()) + "." + f.name return dispatch( field_name=field_name, var_data=VarData.from_state(cls, f.name), - result_var_type=f.outer_type_ - if get_origin(f.outer_type_) is not Field - else get_args(f.outer_type_)[0], + result_var_type=unwrap_field_type(f.outer_type_), ) @@ -1310,8 +1323,8 @@ class BaseState(Base, ABC, extra=pydantic.Extra.allow): if name in fields: field = fields[name] - field_type = field.outer_type_ - if field.allow_none: + field_type = unwrap_field_type(field.outer_type_) + if field.allow_none and not is_optional(field_type): field_type = Union[field_type, None] if not _isinstance(value, field_type): console.deprecate(