diff --git a/reflex/state.py b/reflex/state.py index 6ef3ff3e8..6705f8a50 100644 --- a/reflex/state.py +++ b/reflex/state.py @@ -15,7 +15,6 @@ import time import typing import uuid from abc import ABC, abstractmethod -from collections import defaultdict from hashlib import md5 from pathlib import Path from types import FunctionType, MethodType @@ -3468,7 +3467,9 @@ class StateManagerRedis(StateManager): except ValueError: # The requested state is missing, so fetch and link it (and its parents). link_tasks.add( - asyncio.create_task(self._link_arbitrary_state(state, substate_cls)) + asyncio.create_task( + self._link_arbitrary_state(state, substate_cls) + ) ) for substate_name, substate_task in tasks.items(): diff --git a/reflex/vars/base.py b/reflex/vars/base.py index 6bc5b25c4..7d2e2af85 100644 --- a/reflex/vars/base.py +++ b/reflex/vars/base.py @@ -1914,16 +1914,10 @@ class ComputedVar(Var[RETURN_TYPE]): and all_var_data.state else None ) - var_name = ( - dep._js_expr[len(formatted_state_prefix) :] - if state_name - and ( - formatted_state_prefix := format_state_name(state_name) - + "." - ) - and dep._js_expr.startswith(formatted_state_prefix) - else dep._js_expr - ) + if all_var_data is not None: + var_name = all_var_data.field_name + else: + var_name = dep._js_expr _static_deps.setdefault(state_name, set()).add(var_name) elif isinstance(dep, str) and dep != "": _static_deps.setdefault(None, set()).add(dep) @@ -2214,28 +2208,22 @@ class ComputedVar(Var[RETURN_TYPE]): start_column = getting_var[0].positions.col_offset end_line = getting_var[-1].positions.end_lineno end_column = getting_var[-1].positions.end_col_offset - source = inspect.getsource(inspect.getmodule(obj)).splitlines(True)[start_line - 1: end_line] + source = inspect.getsource(inspect.getmodule(obj)).splitlines(True)[ + start_line - 1 : end_line + ] if len(source) > 1: snipped_source = "".join( [ source[0][start_column:], source[1:-2] if len(source) > 2 else "", - source[-1][:end_column] + source[-1][:end_column], ] ) else: snipped_source = source[0][start_column:end_column] the_var = eval(f"({snipped_source})", obj.__globals__) - print(the_var) - # code = source[start_line - 1] - # bytecode = bytearray((dis.opmap["RESUME"], 0)) - # for ins in getting_var: - # bytecode.append(ins.opcode) - # bytecode.append(ins.arg or 0 & 0xFF) - # bytecode.extend((dis.opmap["RETURN_VALUE"], 0)) - # bc = dis.Bytecode(obj) - # code = bc.codeobj.replace(co_code=bytes(bytecode), co_argcount=0, co_nlocals=0, co_varnames=()) - # breakpoint() + the_var_data = the_var._get_all_var_data() + d.setdefault(the_var_data.state, set()).add(the_var_data.field_name) getting_var = False elif isinstance(getting_var, list): getting_var.append(instruction) @@ -2266,7 +2254,6 @@ class ComputedVar(Var[RETURN_TYPE]): # Special case: arbitrary var access requested. getting_var = True continue - print(f"{self_on_top_of_stack=}") target_state = objclass.get_root_state().get_class_substate( self_on_top_of_stack ) diff --git a/tests/units/test_state.py b/tests/units/test_state.py index c5e2b1287..10713b2d2 100644 --- a/tests/units/test_state.py +++ b/tests/units/test_state.py @@ -1170,9 +1170,15 @@ def test_conditional_computed_vars(): ms = MainState() # Initially there are no dirty computed vars. - assert ms._dirty_computed_vars(from_vars={"flag"}) == {(MainState.get_full_name(), "rendered_var")} - assert ms._dirty_computed_vars(from_vars={"t2"}) == {(MainState.get_full_name(), "rendered_var")} - assert ms._dirty_computed_vars(from_vars={"t1"}) == {(MainState.get_full_name(), "rendered_var")} + assert ms._dirty_computed_vars(from_vars={"flag"}) == { + (MainState.get_full_name(), "rendered_var") + } + assert ms._dirty_computed_vars(from_vars={"t2"}) == { + (MainState.get_full_name(), "rendered_var") + } + assert ms._dirty_computed_vars(from_vars={"t1"}) == { + (MainState.get_full_name(), "rendered_var") + } assert ms.computed_vars["rendered_var"]._deps(objclass=MainState) == { MainState.get_full_name(): {"flag", "t1", "t2"} } @@ -1369,7 +1375,10 @@ def test_cached_var_depends_on_event_handler(use_partial: bool): assert isinstance(HandlerState.handler, EventHandler) s = HandlerState() - assert (HandlerState.get_full_name(), "cached_x_side_effect") in s._var_dependencies["x"] + assert ( + HandlerState.get_full_name(), + "cached_x_side_effect", + ) in s._var_dependencies["x"] assert s.cached_x_side_effect == 1 assert s.x == 43 s.handler() @@ -3221,7 +3230,9 @@ async def test_router_var_dep() -> None: foo = RouterVarDepState.computed_vars["foo"] State._init_var_dependency_dicts() - assert foo._deps(objclass=RouterVarDepState) == {RouterVarDepState.get_full_name(): {"router"}} + assert foo._deps(objclass=RouterVarDepState) == { + RouterVarDepState.get_full_name(): {"router"} + } assert State._var_dependencies == { "router": {(RouterVarDepState.get_full_name(), "foo")} } @@ -3236,7 +3247,9 @@ async def test_router_var_dep() -> None: state.parent_state = parent_state parent_state.substates = {RouterVarDepState.get_name(): state} - populated_substate_classes = await rx_state._recursively_populate_dependent_substates() + populated_substate_classes = ( + await rx_state._recursively_populate_dependent_substates() + ) assert populated_substate_classes == {State, RouterVarDepState} assert state.dirty_vars == set() @@ -3873,4 +3886,3 @@ async def test_async_computed_var_get_state(mock_app: rx.App, token: str): assert await child.v == 2 root.parent_var = 2 assert await child.v == 3 -