diff --git a/reflex/state.py b/reflex/state.py index e073fbfee..eef1f153b 100644 --- a/reflex/state.py +++ b/reflex/state.py @@ -1899,36 +1899,33 @@ class BaseState(Base, ABC, extra=pydantic.Extra.allow): if include_backend or not self.computed_vars[cvar]._backend ) - # TODO: just return full name? cache? @classmethod - def _potentially_dirty_substates(cls) -> set[Type[BaseState]]: + def _potentially_dirty_substates(cls) -> set[str]: """Determine substates which could be affected by dirty vars in this state. Returns: - Set of State classes that may need to be fetched to recalc computed vars. + Set of State full names that may need to be fetched to recalc computed vars. """ # _always_dirty_substates need to be fetched to recalc computed vars. fetch_substates = set( - cls.get_class_substate((cls.get_name(), *substate_name.split("."))) + f"{cls.get_full_name()}.{substate_name}" for substate_name in cls._always_dirty_substates ) for dependent_substates in cls._substate_var_dependencies.values(): fetch_substates.update( set( - cls.get_class_substate((cls.get_name(), *substate_name.split("."))) + f"{cls.get_full_name()}.{substate_name}" for substate_name in dependent_substates ) ) return fetch_substates - # TODO: just return full name? cache? - # this only needs to be computed once, and only for the root state? @classmethod - def _recursive_potentially_dirty_substates(cls) -> set[Type[BaseState]]: + def _recursive_potentially_dirty_substates(cls) -> set[str]: """Recursively determine substates which could be affected by dirty vars in this state. Returns: - Set of State classes that may need to be fetched to recalc computed vars. + Set of full state names that may need to be fetched to recalc computed vars. """ fetch_substates = cls._potentially_dirty_substates() for substate_cls in cls.get_substates(): @@ -3285,12 +3282,7 @@ class StateManagerRedis(StateManager): walk_state_path = walk_state_path.rpartition(".")[0] state_tokens.add(walk_state_path) - state_tokens.update( - { - substate.get_full_name() - for substate in self.state._recursive_potentially_dirty_substates() - } - ) + state_tokens.update(self.state._recursive_potentially_dirty_substates()) if get_substates: state_tokens.update( { diff --git a/tests/units/test_state.py b/tests/units/test_state.py index 8e61b8dae..04c037715 100644 --- a/tests/units/test_state.py +++ b/tests/units/test_state.py @@ -3135,10 +3135,17 @@ def test_potentially_dirty_substates(): def bar(self) -> str: return "" - assert RxState._potentially_dirty_substates() == {State} - assert State._potentially_dirty_substates() == {C1} + assert RxState._potentially_dirty_substates() == {State.get_full_name()} + assert State._potentially_dirty_substates() == {C1.get_full_name()} assert C1._potentially_dirty_substates() == set() + assert RxState._recursive_potentially_dirty_substates() == { + State.get_full_name(), + C1.get_full_name(), + } + assert State._recursive_potentially_dirty_substates() == {C1.get_full_name()} + assert C1._recursive_potentially_dirty_substates() == set() + def test_router_var_dep() -> None: """Test that router var dependencies are correctly tracked.""" @@ -3159,7 +3166,9 @@ def test_router_var_dep() -> None: State._init_var_dependency_dicts() assert foo._deps(objclass=RouterVarDepState) == {"router"} - assert RouterVarParentState._potentially_dirty_substates() == {RouterVarDepState} + assert RouterVarParentState._potentially_dirty_substates() == { + RouterVarDepState.get_full_name() + } assert RouterVarParentState._substate_var_dependencies == { "router": {RouterVarDepState.get_name()} }