From 397683a9c1c418a8a316a707ef152d1f8b94c71e Mon Sep 17 00:00:00 2001 From: Khaleel Al-Adhami Date: Fri, 16 Aug 2024 18:37:11 -0700 Subject: [PATCH] use even better logic for finding state wrt computedvar --- reflex/ivars/base.py | 37 +++++++++++++++++++++++++++---------- 1 file changed, 27 insertions(+), 10 deletions(-) diff --git a/reflex/ivars/base.py b/reflex/ivars/base.py index 72022c427..144b54e7b 100644 --- a/reflex/ivars/base.py +++ b/reflex/ivars/base.py @@ -21,6 +21,7 @@ from typing import ( List, Literal, Optional, + Sequence, Set, Tuple, Type, @@ -1485,22 +1486,38 @@ class ImmutableComputedVar(ImmutableVar[RETURN_TYPE]): Returns: The value of the var for the given instance. """ - from reflex.state import BaseState - if instance is None: - list_of_modules = self.fget.__qualname__.split(".") + from reflex.state import BaseState + + path_to_function = self.fget.__qualname__.split(".") class_name_where_defined = ( - list_of_modules[-2] if len(list_of_modules) > 1 else owner.__name__ + path_to_function[-2] if len(path_to_function) > 1 else owner.__name__ ) - classes_where_defined = [ - c - for c in inspect.getmro(owner) - if c.__name__ == class_name_where_defined - ] + + def contains_class_name(states: Sequence[Type]) -> bool: + return any(c.__name__ == class_name_where_defined for c in states) + + def is_not_mixin(state: Type[BaseState]) -> bool: + return not state._mixin + + def length_of_state(state: Type[BaseState]) -> int: + return len(inspect.getmro(state)) + class_where_defined = cast( Type[BaseState], - classes_where_defined[0] if classes_where_defined else owner, + min( + filter( + is_not_mixin, + filter( + lambda state: contains_class_name(inspect.getmro(state)), + inspect.getmro(owner), + ), + ), + default=owner, + key=length_of_state, + ), ) + return self._replace( _var_name=format_state_name(class_where_defined.get_full_name()) + "."