Fix pre-commit issues

This commit is contained in:
Masen Furer 2025-01-28 23:54:11 -08:00
parent aa5b3037d7
commit e74e913b4c
No known key found for this signature in database
GPG Key ID: 2AE2BD5531FF94F4
2 changed files with 67 additions and 27 deletions

View File

@ -81,6 +81,9 @@ class DependencyTracker:
Args: Args:
instruction: The dis instruction to process. instruction: The dis instruction to process.
Raises:
VarValueError: if the attribute is an disallowed name.
""" """
from .base import ComputedVar from .base import ComputedVar
@ -133,6 +136,9 @@ class DependencyTracker:
Args: Args:
instruction: The dis instruction to process. instruction: The dis instruction to process.
Raises:
VarValueError: if the state class cannot be determined from the instruction.
""" """
from reflex.state import BaseState from reflex.state import BaseState
@ -140,14 +146,25 @@ class DependencyTracker:
raise VarValueError( raise VarValueError(
f"Dependency detection cannot identify get_state class from local var {instruction.argval}." f"Dependency detection cannot identify get_state class from local var {instruction.argval}."
) )
if isinstance(self.func, CodeType):
raise VarValueError(
"Dependency detection cannot identify get_state class from a code object."
)
if instruction.opname == "LOAD_GLOBAL": if instruction.opname == "LOAD_GLOBAL":
# Special case: referencing state class from global scope. # Special case: referencing state class from global scope.
self._getting_state_class = self.func.__globals__.get(instruction.argval) try:
self._getting_state_class = inspect.getclosurevars(self.func).globals[
instruction.argval
]
except (ValueError, KeyError) as ve:
raise VarValueError(
f"Cached var {self!s} cannot access arbitrary state `{instruction.argval}`, not found in globals."
) from ve
elif instruction.opname == "LOAD_DEREF": elif instruction.opname == "LOAD_DEREF":
# Special case: referencing state class from closure. # Special case: referencing state class from closure.
closure = dict(zip(self.func.__code__.co_freevars, self.func.__closure__))
try: try:
self._getting_state_class = closure[instruction.argval].cell_contents closure = inspect.getclosurevars(self.func).nonlocals
self._getting_state_class = closure[instruction.argval]
except ValueError as ve: except ValueError as ve:
raise VarValueError( raise VarValueError(
f"Cached var {self!s} cannot access arbitrary state `{instruction.argval}`, is it defined yet?." f"Cached var {self!s} cannot access arbitrary state `{instruction.argval}`, is it defined yet?."
@ -169,33 +186,54 @@ class DependencyTracker:
Returns: Returns:
The Var object. The Var object.
Raises:
VarValueError: if the source code for the var cannot be determined.
""" """
# Get the original source code and eval it to get the Var. # Get the original source code and eval it to get the Var.
start_line = self._getting_var_instructions[0].positions.lineno module = inspect.getmodule(self.func)
start_column = self._getting_var_instructions[0].positions.col_offset positions = self._getting_var_instructions[0].positions
end_line = self._getting_var_instructions[-1].positions.end_lineno if module is None or positions is None:
end_column = self._getting_var_instructions[-1].positions.end_col_offset raise VarValueError(
source = inspect.getsource(inspect.getmodule(self.func)).splitlines(True)[ f"Cannot determine the source code for the var in {self.func!r}."
start_line - 1 : end_line )
] start_line = positions.lineno
start_column = positions.col_offset
end_line = positions.end_lineno
end_column = positions.end_col_offset
if (
start_line is None
or start_column is None
or end_line is None
or end_column is None
):
raise VarValueError(
f"Cannot determine the source code for the var in {self.func!r}."
)
source = inspect.getsource(module).splitlines(True)[start_line - 1 : end_line]
# Create a python source string snippet. # Create a python source string snippet.
if len(source) > 1: if len(source) > 1:
snipped_source = "".join( snipped_source = "".join(
[ [
source[0][start_column:], *source[0][start_column:],
source[1:-2] if len(source) > 2 else "", *(source[1:-2] if len(source) > 2 else []),
source[-1][:end_column], *source[-1][:end_column],
] ]
) )
else: else:
snipped_source = source[0][start_column:end_column] snipped_source = source[0][start_column:end_column]
# Fallback if the closure is not available.
globals = {}
closure = {}
try: try:
closure = dict(zip(self.func.__code__.co_freevars, self.func.__closure__)) if not isinstance(self.func, CodeType):
closurevars = inspect.getclosurevars(self.func)
closure = closurevars.nonlocals
globals = dict(closurevars.globals)
except Exception: except Exception:
# Fallback if the closure is not available. pass
closure = {}
# Evaluate the string in the context of the function's globals and closure. # Evaluate the string in the context of the function's globals and closure.
return eval(f"({snipped_source})", self.func.__globals__, closure) return eval(f"({snipped_source})", globals, closure)
def handle_getting_var(self, instruction: dis.Instruction) -> None: def handle_getting_var(self, instruction: dis.Instruction) -> None:
"""Handle bytecode analysis when `get_var_value` was called in the function. """Handle bytecode analysis when `get_var_value` was called in the function.
@ -207,11 +245,18 @@ class DependencyTracker:
Args: Args:
instruction: The dis instruction to process. instruction: The dis instruction to process.
Raises:
VarValueError: if the source code for the var cannot be determined.
""" """
if instruction.opname == "CALL" and self._getting_var_instructions: if instruction.opname == "CALL" and self._getting_var_instructions:
if self._getting_var_instructions: if self._getting_var_instructions:
the_var = self._eval_var() the_var = self._eval_var()
the_var_data = the_var._get_all_var_data() the_var_data = the_var._get_all_var_data()
if the_var_data is None:
raise VarValueError(
f"Cannot determine the source code for the var in {self.func!r}."
)
self.dependencies.setdefault(the_var_data.state, set()).add( self.dependencies.setdefault(the_var_data.state, set()).add(
the_var_data.field_name the_var_data.field_name
) )
@ -227,10 +272,6 @@ class DependencyTracker:
Recursively called when the function makes a method call on "self" or Recursively called when the function makes a method call on "self" or
define comprehensions or nested functions that may reference "self". define comprehensions or nested functions that may reference "self".
Raises:
VarValueError: if the function references the get_state, parent_state, or substates attributes
(cannot track deps in a related state, only implicitly via parent state).
""" """
for instruction in dis.get_instructions(self.func): for instruction in dis.get_instructions(self.func):
if self.scan_status == ScanStatus.GETTING_STATE: if self.scan_status == ScanStatus.GETTING_STATE:

View File

@ -3189,8 +3189,7 @@ async def test_get_state_from_sibling_not_cached(mock_app: rx.App, token: str):
RxState = State RxState = State
@pytest.mark.skip(reason="This test is maybe not relevant anymore.") def test_potentially_dirty_states():
def test_potentially_dirty_substates():
"""Test that potentially_dirty_substates returns the correct substates. """Test that potentially_dirty_substates returns the correct substates.
Even if the name "State" is shadowed, it should still work correctly. Even if the name "State" is shadowed, it should still work correctly.
@ -3206,9 +3205,9 @@ def test_potentially_dirty_substates():
def bar(self) -> str: def bar(self) -> str:
return "" return ""
assert RxState._potentially_dirty_substates() == set() assert RxState._get_potentially_dirty_states() == set()
assert State._potentially_dirty_substates() == set() assert State._get_potentially_dirty_states() == set()
assert C1._potentially_dirty_substates() == set() assert C1._get_potentially_dirty_states() == set()
@pytest.mark.asyncio @pytest.mark.asyncio
@ -3857,7 +3856,7 @@ async def test_async_computed_var_get_state(mock_app: rx.App, token: str):
child3 = await self.get_state(Child3) child3 = await self.get_state(Child3)
return child3.child3_var + p.parent_var return child3.child3_var + p.parent_var
mock_app.state_manager.state = mock_app.state = Parent mock_app.state_manager.state = mock_app._state = Parent
# Get the top level state via unconnected sibling. # Get the top level state via unconnected sibling.
root = await mock_app.state_manager.get_state(_substate_key(token, Child)) root = await mock_app.state_manager.get_state(_substate_key(token, Child))