From c9fe2e10b36286bdce611d6e6cd39bf0f5b4c4e7 Mon Sep 17 00:00:00 2001 From: Khaleel Al-Adhami Date: Thu, 12 Dec 2024 17:55:56 +0300 Subject: [PATCH 1/5] don't use _outer_type if we don't have to --- reflex/components/component.py | 32 ++++++++++++++------ reflex/config.py | 11 +++++-- reflex/state.py | 10 +++---- reflex/utils/types.py | 54 +++++++++++++++++++++++----------- 4 files changed, 73 insertions(+), 34 deletions(-) diff --git a/reflex/components/component.py b/reflex/components/component.py index fd7c93cbd..9cf96bcd7 100644 --- a/reflex/components/component.py +++ b/reflex/components/component.py @@ -357,12 +357,16 @@ class Component(BaseComponent, ABC): if field.name not in props: continue + field_type = types.value_inside_optional( + types.get_field_type(cls, field.name) + ) + # Set default values for any props. - if types._issubclass(field.type_, Var): + if types._issubclass(field_type, Var): field.required = False if field.default is not None: field.default = LiteralVar.create(field.default) - elif types._issubclass(field.type_, EventHandler): + elif types._issubclass(field_type, EventHandler): field.required = False # Ensure renamed props from parent classes are applied to the subclass. @@ -426,7 +430,9 @@ class Component(BaseComponent, ABC): field_type = EventChain elif key in props: # Set the field type. - field_type = fields[key].type_ + field_type = types.value_inside_optional( + types.get_field_type(type(self), key) + ) else: continue @@ -446,7 +452,10 @@ class Component(BaseComponent, ABC): if kwargs[key] is None: raise TypeError - expected_type = fields[key].outer_type_.__args__[0] + expected_type = types.get_args( + types.get_field_type(type(self), key) + )[0] + # validate literal fields. types.validate_literal( key, value, expected_type, type(self).__name__ @@ -461,7 +470,7 @@ class Component(BaseComponent, ABC): except TypeError: # If it is not a valid var, check the base types. passed_type = type(value) - expected_type = fields[key].outer_type_ + expected_type = types.get_field_type(type(self), key) if types.is_union(passed_type): # We need to check all possible types in the union. passed_types = ( @@ -674,8 +683,11 @@ class Component(BaseComponent, ABC): # Look for component specific triggers, # e.g. variable declared as EventHandler types. - for field in self.get_fields().values(): - if types._issubclass(field.outer_type_, EventHandler): + for name, field in self.get_fields().items(): + if types._issubclass( + types.value_inside_optional(types.get_field_type(type(self), name)), + EventHandler, + ): args_spec = None annotation = field.annotation if (metadata := getattr(annotation, "__metadata__", None)) is not None: @@ -787,9 +799,11 @@ class Component(BaseComponent, ABC): """ return { name - for name, field in cls.get_fields().items() + for name in cls.get_fields() if name in cls.get_props() - and types._issubclass(field.outer_type_, Component) + and types._issubclass( + types.value_inside_optional(types.get_field_type(cls, name)), Component + ) } @classmethod diff --git a/reflex/config.py b/reflex/config.py index ae2c0ea0e..6606547cc 100644 --- a/reflex/config.py +++ b/reflex/config.py @@ -27,7 +27,12 @@ from typing import ( from typing_extensions import Annotated, get_type_hints from reflex.utils.exceptions import ConfigError, EnvironmentVarValueError -from reflex.utils.types import GenericType, is_union, value_inside_optional +from reflex.utils.types import ( + GenericType, + is_union, + true_type_for_pydantic_field, + value_inside_optional, +) try: import pydantic.v1 as pydantic @@ -759,7 +764,9 @@ class Config(Base): # If the env var is set, override the config value. if env_var is not None: # Interpret the value. - value = interpret_env_var_value(env_var, field.outer_type_, field.name) + value = interpret_env_var_value( + env_var, true_type_for_pydantic_field(field), field.name + ) # Set the value. updated_values[key] = value diff --git a/reflex/state.py b/reflex/state.py index f4a9d2d57..0ebe2a8ff 100644 --- a/reflex/state.py +++ b/reflex/state.py @@ -107,9 +107,9 @@ from reflex.utils.serializers import serializer from reflex.utils.types import ( _isinstance, get_origin, - is_optional, is_union, override, + true_type_for_pydantic_field, value_inside_optional, ) from reflex.vars import VarData @@ -282,7 +282,7 @@ if TYPE_CHECKING: from pydantic.v1.fields import ModelField -def _unwrap_field_type(type_: Type) -> Type: +def _unwrap_field_type(type_: types.GenericType) -> Type: """Unwrap rx.Field type annotations. Args: @@ -313,7 +313,7 @@ def get_var_for_field(cls: Type[BaseState], f: ModelField): return dispatch( field_name=field_name, var_data=VarData.from_state(cls, f.name), - result_var_type=_unwrap_field_type(f.outer_type_), + result_var_type=_unwrap_field_type(true_type_for_pydantic_field(f)), ) @@ -1329,9 +1329,7 @@ class BaseState(Base, ABC, extra=pydantic.Extra.allow): if name in fields: field = fields[name] - field_type = _unwrap_field_type(field.outer_type_) - if field.allow_none and not is_optional(field_type): - field_type = Union[field_type, None] + field_type = _unwrap_field_type(true_type_for_pydantic_field(field)) if not _isinstance(value, field_type): console.deprecate( "mismatched-type-assignment", diff --git a/reflex/utils/types.py b/reflex/utils/types.py index b8bcbf2d6..d8c787a8d 100644 --- a/reflex/utils/types.py +++ b/reflex/utils/types.py @@ -269,6 +269,25 @@ def is_optional(cls: GenericType) -> bool: return is_union(cls) and type(None) in get_args(cls) +def true_type_for_pydantic_field(f: ModelField): + """Get the type for a pydantic field. + + Args: + f: The field to get the type for. + + Returns: + The type for the field. + """ + outer_type = f.outer_type_ + if ( + f.allow_none + and not is_optional(outer_type) + and outer_type not in (None, type(None)) + ): + return Optional[outer_type] + return outer_type + + def value_inside_optional(cls: GenericType) -> GenericType: """Get the value inside an Optional type or the original type. @@ -283,6 +302,22 @@ def value_inside_optional(cls: GenericType) -> GenericType: return cls +def get_field_type(cls: GenericType, field_name: str) -> GenericType | None: + """Get the type of a field in a class. + + Args: + cls: The class to check. + field_name: The name of the field to check. + + Returns: + The type of the field, if it exists, else None. + """ + type_hints = get_type_hints(cls) + if field_name in type_hints: + return type_hints[field_name] + return None + + def get_property_hint(attr: Any | None) -> GenericType | None: """Check if an attribute is a property and return its type hint. @@ -320,24 +355,9 @@ def get_attribute_access_type(cls: GenericType, name: str) -> GenericType | None if hint := get_property_hint(attr): return hint - if ( - hasattr(cls, "__fields__") - and name in cls.__fields__ - and hasattr(cls.__fields__[name], "outer_type_") - ): + if hasattr(cls, "__fields__") and name in cls.__fields__: # pydantic models - field = cls.__fields__[name] - type_ = field.outer_type_ - if isinstance(type_, ModelField): - type_ = type_.type_ - if ( - not field.required - and field.default is None - and field.default_factory is None - ): - # Ensure frontend uses null coalescing when accessing. - type_ = Optional[type_] - return type_ + return get_field_type(cls, name) elif isinstance(cls, type) and issubclass(cls, DeclarativeBase): insp = sqlalchemy.inspect(cls) if name in insp.columns: From c2a39b46d73e2a405cb05aa25773dabdda6e43f3 Mon Sep 17 00:00:00 2001 From: Khaleel Al-Adhami Date: Fri, 13 Dec 2024 07:12:16 +0300 Subject: [PATCH 2/5] apparently we should use .annotation, and .allow_none is useless --- reflex/utils/types.py | 12 ++++-------- 1 file changed, 4 insertions(+), 8 deletions(-) diff --git a/reflex/utils/types.py b/reflex/utils/types.py index d8c787a8d..cba8ec02a 100644 --- a/reflex/utils/types.py +++ b/reflex/utils/types.py @@ -14,6 +14,7 @@ from typing import ( Callable, ClassVar, Dict, + ForwardRef, FrozenSet, Iterable, List, @@ -278,14 +279,9 @@ def true_type_for_pydantic_field(f: ModelField): Returns: The type for the field. """ - outer_type = f.outer_type_ - if ( - f.allow_none - and not is_optional(outer_type) - and outer_type not in (None, type(None)) - ): - return Optional[outer_type] - return outer_type + if not isinstance(f.annotation, (str, ForwardRef)): + return f.annotation + return f.outer_type_ def value_inside_optional(cls: GenericType) -> GenericType: From d3fbf123b52714f3a4c97cc01231fd5d112d8148 Mon Sep 17 00:00:00 2001 From: Khaleel Al-Adhami Date: Fri, 13 Dec 2024 07:14:33 +0300 Subject: [PATCH 3/5] have a shorter path for get_field_type if it's nice --- reflex/utils/types.py | 7 +++++++ 1 file changed, 7 insertions(+) diff --git a/reflex/utils/types.py b/reflex/utils/types.py index cba8ec02a..184eebf7c 100644 --- a/reflex/utils/types.py +++ b/reflex/utils/types.py @@ -308,6 +308,13 @@ def get_field_type(cls: GenericType, field_name: str) -> GenericType | None: Returns: The type of the field, if it exists, else None. """ + if ( + hasattr(cls, "__fields__") + and field_name in cls.__fields__ + and hasattr(cls.__fields__[field_name], "annotation") + and not isinstance(cls.__fields__[field_name].annotation, (str, ForwardRef)) + ): + return cls.__fields__[field_name].annotation type_hints = get_type_hints(cls) if field_name in type_hints: return type_hints[field_name] From 3fb8450fb5d928f07c6d7440193bc53863993218 Mon Sep 17 00:00:00 2001 From: Khaleel Al-Adhami Date: Fri, 13 Dec 2024 23:22:47 +0300 Subject: [PATCH 4/5] check against optional in annotation str --- reflex/utils/types.py | 14 +++++++++++++- 1 file changed, 13 insertions(+), 1 deletion(-) diff --git a/reflex/utils/types.py b/reflex/utils/types.py index 184eebf7c..4e54aa70f 100644 --- a/reflex/utils/types.py +++ b/reflex/utils/types.py @@ -281,7 +281,19 @@ def true_type_for_pydantic_field(f: ModelField): """ if not isinstance(f.annotation, (str, ForwardRef)): return f.annotation - return f.outer_type_ + + type_ = f.outer_type_ + + if ( + (isinstance(f.annotation, str) and f.annotation.startswith("Optional")) + or ( + isinstance(f.annotation, ForwardRef) + and f.annotation.__forward_arg__.startswith("Optional") + ) + ) and not is_optional(type_): + return Optional[type_] + + return type_ def value_inside_optional(cls: GenericType) -> GenericType: From 9e963338efc7b737134735db8a85d9a3e7ad174b Mon Sep 17 00:00:00 2001 From: Khaleel Al-Adhami Date: Tue, 17 Dec 2024 19:26:58 +0300 Subject: [PATCH 5/5] add check for default value being null --- reflex/utils/types.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/reflex/utils/types.py b/reflex/utils/types.py index 4e54aa70f..d39f1b483 100644 --- a/reflex/utils/types.py +++ b/reflex/utils/types.py @@ -285,7 +285,8 @@ def true_type_for_pydantic_field(f: ModelField): type_ = f.outer_type_ if ( - (isinstance(f.annotation, str) and f.annotation.startswith("Optional")) + f.field_info.default is None + or (isinstance(f.annotation, str) and f.annotation.startswith("Optional")) or ( isinstance(f.annotation, ForwardRef) and f.annotation.__forward_arg__.startswith("Optional")