Fix pre-commit issues

This commit is contained in:
Masen Furer 2025-01-28 23:54:11 -08:00
parent aa5b3037d7
commit e74e913b4c
No known key found for this signature in database
GPG Key ID: 2AE2BD5531FF94F4
2 changed files with 67 additions and 27 deletions

View File

@ -81,6 +81,9 @@ class DependencyTracker:
Args:
instruction: The dis instruction to process.
Raises:
VarValueError: if the attribute is an disallowed name.
"""
from .base import ComputedVar
@ -133,6 +136,9 @@ class DependencyTracker:
Args:
instruction: The dis instruction to process.
Raises:
VarValueError: if the state class cannot be determined from the instruction.
"""
from reflex.state import BaseState
@ -140,14 +146,25 @@ class DependencyTracker:
raise VarValueError(
f"Dependency detection cannot identify get_state class from local var {instruction.argval}."
)
if isinstance(self.func, CodeType):
raise VarValueError(
"Dependency detection cannot identify get_state class from a code object."
)
if instruction.opname == "LOAD_GLOBAL":
# Special case: referencing state class from global scope.
self._getting_state_class = self.func.__globals__.get(instruction.argval)
try:
self._getting_state_class = inspect.getclosurevars(self.func).globals[
instruction.argval
]
except (ValueError, KeyError) as ve:
raise VarValueError(
f"Cached var {self!s} cannot access arbitrary state `{instruction.argval}`, not found in globals."
) from ve
elif instruction.opname == "LOAD_DEREF":
# Special case: referencing state class from closure.
closure = dict(zip(self.func.__code__.co_freevars, self.func.__closure__))
try:
self._getting_state_class = closure[instruction.argval].cell_contents
closure = inspect.getclosurevars(self.func).nonlocals
self._getting_state_class = closure[instruction.argval]
except ValueError as ve:
raise VarValueError(
f"Cached var {self!s} cannot access arbitrary state `{instruction.argval}`, is it defined yet?."
@ -169,33 +186,54 @@ class DependencyTracker:
Returns:
The Var object.
Raises:
VarValueError: if the source code for the var cannot be determined.
"""
# Get the original source code and eval it to get the Var.
start_line = self._getting_var_instructions[0].positions.lineno
start_column = self._getting_var_instructions[0].positions.col_offset
end_line = self._getting_var_instructions[-1].positions.end_lineno
end_column = self._getting_var_instructions[-1].positions.end_col_offset
source = inspect.getsource(inspect.getmodule(self.func)).splitlines(True)[
start_line - 1 : end_line
]
module = inspect.getmodule(self.func)
positions = self._getting_var_instructions[0].positions
if module is None or positions is None:
raise VarValueError(
f"Cannot determine the source code for the var in {self.func!r}."
)
start_line = positions.lineno
start_column = positions.col_offset
end_line = positions.end_lineno
end_column = positions.end_col_offset
if (
start_line is None
or start_column is None
or end_line is None
or end_column is None
):
raise VarValueError(
f"Cannot determine the source code for the var in {self.func!r}."
)
source = inspect.getsource(module).splitlines(True)[start_line - 1 : end_line]
# Create a python source string snippet.
if len(source) > 1:
snipped_source = "".join(
[
source[0][start_column:],
source[1:-2] if len(source) > 2 else "",
source[-1][:end_column],
*source[0][start_column:],
*(source[1:-2] if len(source) > 2 else []),
*source[-1][:end_column],
]
)
else:
snipped_source = source[0][start_column:end_column]
# Fallback if the closure is not available.
globals = {}
closure = {}
try:
closure = dict(zip(self.func.__code__.co_freevars, self.func.__closure__))
if not isinstance(self.func, CodeType):
closurevars = inspect.getclosurevars(self.func)
closure = closurevars.nonlocals
globals = dict(closurevars.globals)
except Exception:
# Fallback if the closure is not available.
closure = {}
pass
# Evaluate the string in the context of the function's globals and closure.
return eval(f"({snipped_source})", self.func.__globals__, closure)
return eval(f"({snipped_source})", globals, closure)
def handle_getting_var(self, instruction: dis.Instruction) -> None:
"""Handle bytecode analysis when `get_var_value` was called in the function.
@ -207,11 +245,18 @@ class DependencyTracker:
Args:
instruction: The dis instruction to process.
Raises:
VarValueError: if the source code for the var cannot be determined.
"""
if instruction.opname == "CALL" and self._getting_var_instructions:
if self._getting_var_instructions:
the_var = self._eval_var()
the_var_data = the_var._get_all_var_data()
if the_var_data is None:
raise VarValueError(
f"Cannot determine the source code for the var in {self.func!r}."
)
self.dependencies.setdefault(the_var_data.state, set()).add(
the_var_data.field_name
)
@ -227,10 +272,6 @@ class DependencyTracker:
Recursively called when the function makes a method call on "self" or
define comprehensions or nested functions that may reference "self".
Raises:
VarValueError: if the function references the get_state, parent_state, or substates attributes
(cannot track deps in a related state, only implicitly via parent state).
"""
for instruction in dis.get_instructions(self.func):
if self.scan_status == ScanStatus.GETTING_STATE:

View File

@ -3189,8 +3189,7 @@ async def test_get_state_from_sibling_not_cached(mock_app: rx.App, token: str):
RxState = State
@pytest.mark.skip(reason="This test is maybe not relevant anymore.")
def test_potentially_dirty_substates():
def test_potentially_dirty_states():
"""Test that potentially_dirty_substates returns the correct substates.
Even if the name "State" is shadowed, it should still work correctly.
@ -3206,9 +3205,9 @@ def test_potentially_dirty_substates():
def bar(self) -> str:
return ""
assert RxState._potentially_dirty_substates() == set()
assert State._potentially_dirty_substates() == set()
assert C1._potentially_dirty_substates() == set()
assert RxState._get_potentially_dirty_states() == set()
assert State._get_potentially_dirty_states() == set()
assert C1._get_potentially_dirty_states() == set()
@pytest.mark.asyncio
@ -3857,7 +3856,7 @@ async def test_async_computed_var_get_state(mock_app: rx.App, token: str):
child3 = await self.get_state(Child3)
return child3.child3_var + p.parent_var
mock_app.state_manager.state = mock_app.state = Parent
mock_app.state_manager.state = mock_app._state = Parent
# Get the top level state via unconnected sibling.
root = await mock_app.state_manager.get_state(_substate_key(token, Child))