From c9fe2e10b36286bdce611d6e6cd39bf0f5b4c4e7 Mon Sep 17 00:00:00 2001
From: Khaleel Al-Adhami <khaleel.aladhami@gmail.com>
Date: Thu, 12 Dec 2024 17:55:56 +0300
Subject: [PATCH] don't use _outer_type if we don't have to

---
 reflex/components/component.py | 32 ++++++++++++++------
 reflex/config.py               | 11 +++++--
 reflex/state.py                | 10 +++----
 reflex/utils/types.py          | 54 +++++++++++++++++++++++-----------
 4 files changed, 73 insertions(+), 34 deletions(-)

diff --git a/reflex/components/component.py b/reflex/components/component.py
index fd7c93cbd..9cf96bcd7 100644
--- a/reflex/components/component.py
+++ b/reflex/components/component.py
@@ -357,12 +357,16 @@ class Component(BaseComponent, ABC):
             if field.name not in props:
                 continue
 
+            field_type = types.value_inside_optional(
+                types.get_field_type(cls, field.name)
+            )
+
             # Set default values for any props.
-            if types._issubclass(field.type_, Var):
+            if types._issubclass(field_type, Var):
                 field.required = False
                 if field.default is not None:
                     field.default = LiteralVar.create(field.default)
-            elif types._issubclass(field.type_, EventHandler):
+            elif types._issubclass(field_type, EventHandler):
                 field.required = False
 
         # Ensure renamed props from parent classes are applied to the subclass.
@@ -426,7 +430,9 @@ class Component(BaseComponent, ABC):
                 field_type = EventChain
             elif key in props:
                 # Set the field type.
-                field_type = fields[key].type_
+                field_type = types.value_inside_optional(
+                    types.get_field_type(type(self), key)
+                )
 
             else:
                 continue
@@ -446,7 +452,10 @@ class Component(BaseComponent, ABC):
                     if kwargs[key] is None:
                         raise TypeError
 
-                    expected_type = fields[key].outer_type_.__args__[0]
+                    expected_type = types.get_args(
+                        types.get_field_type(type(self), key)
+                    )[0]
+
                     # validate literal fields.
                     types.validate_literal(
                         key, value, expected_type, type(self).__name__
@@ -461,7 +470,7 @@ class Component(BaseComponent, ABC):
                 except TypeError:
                     # If it is not a valid var, check the base types.
                     passed_type = type(value)
-                    expected_type = fields[key].outer_type_
+                    expected_type = types.get_field_type(type(self), key)
                 if types.is_union(passed_type):
                     # We need to check all possible types in the union.
                     passed_types = (
@@ -674,8 +683,11 @@ class Component(BaseComponent, ABC):
 
         # Look for component specific triggers,
         # e.g. variable declared as EventHandler types.
-        for field in self.get_fields().values():
-            if types._issubclass(field.outer_type_, EventHandler):
+        for name, field in self.get_fields().items():
+            if types._issubclass(
+                types.value_inside_optional(types.get_field_type(type(self), name)),
+                EventHandler,
+            ):
                 args_spec = None
                 annotation = field.annotation
                 if (metadata := getattr(annotation, "__metadata__", None)) is not None:
@@ -787,9 +799,11 @@ class Component(BaseComponent, ABC):
         """
         return {
             name
-            for name, field in cls.get_fields().items()
+            for name in cls.get_fields()
             if name in cls.get_props()
-            and types._issubclass(field.outer_type_, Component)
+            and types._issubclass(
+                types.value_inside_optional(types.get_field_type(cls, name)), Component
+            )
         }
 
     @classmethod
diff --git a/reflex/config.py b/reflex/config.py
index ae2c0ea0e..6606547cc 100644
--- a/reflex/config.py
+++ b/reflex/config.py
@@ -27,7 +27,12 @@ from typing import (
 from typing_extensions import Annotated, get_type_hints
 
 from reflex.utils.exceptions import ConfigError, EnvironmentVarValueError
-from reflex.utils.types import GenericType, is_union, value_inside_optional
+from reflex.utils.types import (
+    GenericType,
+    is_union,
+    true_type_for_pydantic_field,
+    value_inside_optional,
+)
 
 try:
     import pydantic.v1 as pydantic
@@ -759,7 +764,9 @@ class Config(Base):
             # If the env var is set, override the config value.
             if env_var is not None:
                 # Interpret the value.
-                value = interpret_env_var_value(env_var, field.outer_type_, field.name)
+                value = interpret_env_var_value(
+                    env_var, true_type_for_pydantic_field(field), field.name
+                )
 
                 # Set the value.
                 updated_values[key] = value
diff --git a/reflex/state.py b/reflex/state.py
index f4a9d2d57..0ebe2a8ff 100644
--- a/reflex/state.py
+++ b/reflex/state.py
@@ -107,9 +107,9 @@ from reflex.utils.serializers import serializer
 from reflex.utils.types import (
     _isinstance,
     get_origin,
-    is_optional,
     is_union,
     override,
+    true_type_for_pydantic_field,
     value_inside_optional,
 )
 from reflex.vars import VarData
@@ -282,7 +282,7 @@ if TYPE_CHECKING:
     from pydantic.v1.fields import ModelField
 
 
-def _unwrap_field_type(type_: Type) -> Type:
+def _unwrap_field_type(type_: types.GenericType) -> Type:
     """Unwrap rx.Field type annotations.
 
     Args:
@@ -313,7 +313,7 @@ def get_var_for_field(cls: Type[BaseState], f: ModelField):
     return dispatch(
         field_name=field_name,
         var_data=VarData.from_state(cls, f.name),
-        result_var_type=_unwrap_field_type(f.outer_type_),
+        result_var_type=_unwrap_field_type(true_type_for_pydantic_field(f)),
     )
 
 
@@ -1329,9 +1329,7 @@ class BaseState(Base, ABC, extra=pydantic.Extra.allow):
 
         if name in fields:
             field = fields[name]
-            field_type = _unwrap_field_type(field.outer_type_)
-            if field.allow_none and not is_optional(field_type):
-                field_type = Union[field_type, None]
+            field_type = _unwrap_field_type(true_type_for_pydantic_field(field))
             if not _isinstance(value, field_type):
                 console.deprecate(
                     "mismatched-type-assignment",
diff --git a/reflex/utils/types.py b/reflex/utils/types.py
index b8bcbf2d6..d8c787a8d 100644
--- a/reflex/utils/types.py
+++ b/reflex/utils/types.py
@@ -269,6 +269,25 @@ def is_optional(cls: GenericType) -> bool:
     return is_union(cls) and type(None) in get_args(cls)
 
 
+def true_type_for_pydantic_field(f: ModelField):
+    """Get the type for a pydantic field.
+
+    Args:
+        f: The field to get the type for.
+
+    Returns:
+        The type for the field.
+    """
+    outer_type = f.outer_type_
+    if (
+        f.allow_none
+        and not is_optional(outer_type)
+        and outer_type not in (None, type(None))
+    ):
+        return Optional[outer_type]
+    return outer_type
+
+
 def value_inside_optional(cls: GenericType) -> GenericType:
     """Get the value inside an Optional type or the original type.
 
@@ -283,6 +302,22 @@ def value_inside_optional(cls: GenericType) -> GenericType:
     return cls
 
 
+def get_field_type(cls: GenericType, field_name: str) -> GenericType | None:
+    """Get the type of a field in a class.
+
+    Args:
+        cls: The class to check.
+        field_name: The name of the field to check.
+
+    Returns:
+        The type of the field, if it exists, else None.
+    """
+    type_hints = get_type_hints(cls)
+    if field_name in type_hints:
+        return type_hints[field_name]
+    return None
+
+
 def get_property_hint(attr: Any | None) -> GenericType | None:
     """Check if an attribute is a property and return its type hint.
 
@@ -320,24 +355,9 @@ def get_attribute_access_type(cls: GenericType, name: str) -> GenericType | None
     if hint := get_property_hint(attr):
         return hint
 
-    if (
-        hasattr(cls, "__fields__")
-        and name in cls.__fields__
-        and hasattr(cls.__fields__[name], "outer_type_")
-    ):
+    if hasattr(cls, "__fields__") and name in cls.__fields__:
         # pydantic models
-        field = cls.__fields__[name]
-        type_ = field.outer_type_
-        if isinstance(type_, ModelField):
-            type_ = type_.type_
-        if (
-            not field.required
-            and field.default is None
-            and field.default_factory is None
-        ):
-            # Ensure frontend uses null coalescing when accessing.
-            type_ = Optional[type_]
-        return type_
+        return get_field_type(cls, name)
     elif isinstance(cls, type) and issubclass(cls, DeclarativeBase):
         insp = sqlalchemy.inspect(cls)
         if name in insp.columns: