Save the var from get_var_name

This commit is contained in:
Masen Furer 2025-01-24 11:46:37 -08:00
parent 3d50c1b623
commit 3da64264ba
No known key found for this signature in database
GPG Key ID: B0008AD22B3B3A95
3 changed files with 32 additions and 32 deletions

View File

@ -15,7 +15,6 @@ import time
import typing
import uuid
from abc import ABC, abstractmethod
from collections import defaultdict
from hashlib import md5
from pathlib import Path
from types import FunctionType, MethodType
@ -3468,7 +3467,9 @@ class StateManagerRedis(StateManager):
except ValueError:
# The requested state is missing, so fetch and link it (and its parents).
link_tasks.add(
asyncio.create_task(self._link_arbitrary_state(state, substate_cls))
asyncio.create_task(
self._link_arbitrary_state(state, substate_cls)
)
)
for substate_name, substate_task in tasks.items():

View File

@ -1914,16 +1914,10 @@ class ComputedVar(Var[RETURN_TYPE]):
and all_var_data.state
else None
)
var_name = (
dep._js_expr[len(formatted_state_prefix) :]
if state_name
and (
formatted_state_prefix := format_state_name(state_name)
+ "."
)
and dep._js_expr.startswith(formatted_state_prefix)
else dep._js_expr
)
if all_var_data is not None:
var_name = all_var_data.field_name
else:
var_name = dep._js_expr
_static_deps.setdefault(state_name, set()).add(var_name)
elif isinstance(dep, str) and dep != "":
_static_deps.setdefault(None, set()).add(dep)
@ -2214,28 +2208,22 @@ class ComputedVar(Var[RETURN_TYPE]):
start_column = getting_var[0].positions.col_offset
end_line = getting_var[-1].positions.end_lineno
end_column = getting_var[-1].positions.end_col_offset
source = inspect.getsource(inspect.getmodule(obj)).splitlines(True)[start_line - 1: end_line]
source = inspect.getsource(inspect.getmodule(obj)).splitlines(True)[
start_line - 1 : end_line
]
if len(source) > 1:
snipped_source = "".join(
[
source[0][start_column:],
source[1:-2] if len(source) > 2 else "",
source[-1][:end_column]
source[-1][:end_column],
]
)
else:
snipped_source = source[0][start_column:end_column]
the_var = eval(f"({snipped_source})", obj.__globals__)
print(the_var)
# code = source[start_line - 1]
# bytecode = bytearray((dis.opmap["RESUME"], 0))
# for ins in getting_var:
# bytecode.append(ins.opcode)
# bytecode.append(ins.arg or 0 & 0xFF)
# bytecode.extend((dis.opmap["RETURN_VALUE"], 0))
# bc = dis.Bytecode(obj)
# code = bc.codeobj.replace(co_code=bytes(bytecode), co_argcount=0, co_nlocals=0, co_varnames=())
# breakpoint()
the_var_data = the_var._get_all_var_data()
d.setdefault(the_var_data.state, set()).add(the_var_data.field_name)
getting_var = False
elif isinstance(getting_var, list):
getting_var.append(instruction)
@ -2266,7 +2254,6 @@ class ComputedVar(Var[RETURN_TYPE]):
# Special case: arbitrary var access requested.
getting_var = True
continue
print(f"{self_on_top_of_stack=}")
target_state = objclass.get_root_state().get_class_substate(
self_on_top_of_stack
)

View File

@ -1170,9 +1170,15 @@ def test_conditional_computed_vars():
ms = MainState()
# Initially there are no dirty computed vars.
assert ms._dirty_computed_vars(from_vars={"flag"}) == {(MainState.get_full_name(), "rendered_var")}
assert ms._dirty_computed_vars(from_vars={"t2"}) == {(MainState.get_full_name(), "rendered_var")}
assert ms._dirty_computed_vars(from_vars={"t1"}) == {(MainState.get_full_name(), "rendered_var")}
assert ms._dirty_computed_vars(from_vars={"flag"}) == {
(MainState.get_full_name(), "rendered_var")
}
assert ms._dirty_computed_vars(from_vars={"t2"}) == {
(MainState.get_full_name(), "rendered_var")
}
assert ms._dirty_computed_vars(from_vars={"t1"}) == {
(MainState.get_full_name(), "rendered_var")
}
assert ms.computed_vars["rendered_var"]._deps(objclass=MainState) == {
MainState.get_full_name(): {"flag", "t1", "t2"}
}
@ -1369,7 +1375,10 @@ def test_cached_var_depends_on_event_handler(use_partial: bool):
assert isinstance(HandlerState.handler, EventHandler)
s = HandlerState()
assert (HandlerState.get_full_name(), "cached_x_side_effect") in s._var_dependencies["x"]
assert (
HandlerState.get_full_name(),
"cached_x_side_effect",
) in s._var_dependencies["x"]
assert s.cached_x_side_effect == 1
assert s.x == 43
s.handler()
@ -3221,7 +3230,9 @@ async def test_router_var_dep() -> None:
foo = RouterVarDepState.computed_vars["foo"]
State._init_var_dependency_dicts()
assert foo._deps(objclass=RouterVarDepState) == {RouterVarDepState.get_full_name(): {"router"}}
assert foo._deps(objclass=RouterVarDepState) == {
RouterVarDepState.get_full_name(): {"router"}
}
assert State._var_dependencies == {
"router": {(RouterVarDepState.get_full_name(), "foo")}
}
@ -3236,7 +3247,9 @@ async def test_router_var_dep() -> None:
state.parent_state = parent_state
parent_state.substates = {RouterVarDepState.get_name(): state}
populated_substate_classes = await rx_state._recursively_populate_dependent_substates()
populated_substate_classes = (
await rx_state._recursively_populate_dependent_substates()
)
assert populated_substate_classes == {State, RouterVarDepState}
assert state.dirty_vars == set()
@ -3873,4 +3886,3 @@ async def test_async_computed_var_get_state(mock_app: rx.App, token: str):
assert await child.v == 2
root.parent_var = 2
assert await child.v == 3