diff --git a/reflex/vars/dep_tracking.py b/reflex/vars/dep_tracking.py index 3bd8a6bf0..387de5a66 100644 --- a/reflex/vars/dep_tracking.py +++ b/reflex/vars/dep_tracking.py @@ -81,6 +81,9 @@ class DependencyTracker: Args: instruction: The dis instruction to process. + + Raises: + VarValueError: if the attribute is an disallowed name. """ from .base import ComputedVar @@ -133,6 +136,9 @@ class DependencyTracker: Args: instruction: The dis instruction to process. + + Raises: + VarValueError: if the state class cannot be determined from the instruction. """ from reflex.state import BaseState @@ -140,14 +146,25 @@ class DependencyTracker: raise VarValueError( f"Dependency detection cannot identify get_state class from local var {instruction.argval}." ) + if isinstance(self.func, CodeType): + raise VarValueError( + "Dependency detection cannot identify get_state class from a code object." + ) if instruction.opname == "LOAD_GLOBAL": # Special case: referencing state class from global scope. - self._getting_state_class = self.func.__globals__.get(instruction.argval) + try: + self._getting_state_class = inspect.getclosurevars(self.func).globals[ + instruction.argval + ] + except (ValueError, KeyError) as ve: + raise VarValueError( + f"Cached var {self!s} cannot access arbitrary state `{instruction.argval}`, not found in globals." + ) from ve elif instruction.opname == "LOAD_DEREF": # Special case: referencing state class from closure. - closure = dict(zip(self.func.__code__.co_freevars, self.func.__closure__)) try: - self._getting_state_class = closure[instruction.argval].cell_contents + closure = inspect.getclosurevars(self.func).nonlocals + self._getting_state_class = closure[instruction.argval] except ValueError as ve: raise VarValueError( f"Cached var {self!s} cannot access arbitrary state `{instruction.argval}`, is it defined yet?." @@ -169,33 +186,54 @@ class DependencyTracker: Returns: The Var object. + + Raises: + VarValueError: if the source code for the var cannot be determined. """ # Get the original source code and eval it to get the Var. - start_line = self._getting_var_instructions[0].positions.lineno - start_column = self._getting_var_instructions[0].positions.col_offset - end_line = self._getting_var_instructions[-1].positions.end_lineno - end_column = self._getting_var_instructions[-1].positions.end_col_offset - source = inspect.getsource(inspect.getmodule(self.func)).splitlines(True)[ - start_line - 1 : end_line - ] + module = inspect.getmodule(self.func) + positions = self._getting_var_instructions[0].positions + if module is None or positions is None: + raise VarValueError( + f"Cannot determine the source code for the var in {self.func!r}." + ) + start_line = positions.lineno + start_column = positions.col_offset + end_line = positions.end_lineno + end_column = positions.end_col_offset + if ( + start_line is None + or start_column is None + or end_line is None + or end_column is None + ): + raise VarValueError( + f"Cannot determine the source code for the var in {self.func!r}." + ) + source = inspect.getsource(module).splitlines(True)[start_line - 1 : end_line] # Create a python source string snippet. if len(source) > 1: snipped_source = "".join( [ - source[0][start_column:], - source[1:-2] if len(source) > 2 else "", - source[-1][:end_column], + *source[0][start_column:], + *(source[1:-2] if len(source) > 2 else []), + *source[-1][:end_column], ] ) else: snipped_source = source[0][start_column:end_column] + # Fallback if the closure is not available. + globals = {} + closure = {} try: - closure = dict(zip(self.func.__code__.co_freevars, self.func.__closure__)) + if not isinstance(self.func, CodeType): + closurevars = inspect.getclosurevars(self.func) + closure = closurevars.nonlocals + globals = dict(closurevars.globals) except Exception: - # Fallback if the closure is not available. - closure = {} + pass # Evaluate the string in the context of the function's globals and closure. - return eval(f"({snipped_source})", self.func.__globals__, closure) + return eval(f"({snipped_source})", globals, closure) def handle_getting_var(self, instruction: dis.Instruction) -> None: """Handle bytecode analysis when `get_var_value` was called in the function. @@ -207,11 +245,18 @@ class DependencyTracker: Args: instruction: The dis instruction to process. + + Raises: + VarValueError: if the source code for the var cannot be determined. """ if instruction.opname == "CALL" and self._getting_var_instructions: if self._getting_var_instructions: the_var = self._eval_var() the_var_data = the_var._get_all_var_data() + if the_var_data is None: + raise VarValueError( + f"Cannot determine the source code for the var in {self.func!r}." + ) self.dependencies.setdefault(the_var_data.state, set()).add( the_var_data.field_name ) @@ -227,10 +272,6 @@ class DependencyTracker: Recursively called when the function makes a method call on "self" or define comprehensions or nested functions that may reference "self". - - Raises: - VarValueError: if the function references the get_state, parent_state, or substates attributes - (cannot track deps in a related state, only implicitly via parent state). """ for instruction in dis.get_instructions(self.func): if self.scan_status == ScanStatus.GETTING_STATE: diff --git a/tests/units/test_state.py b/tests/units/test_state.py index 0d9d438ea..2a07f1b2e 100644 --- a/tests/units/test_state.py +++ b/tests/units/test_state.py @@ -3189,8 +3189,7 @@ async def test_get_state_from_sibling_not_cached(mock_app: rx.App, token: str): RxState = State -@pytest.mark.skip(reason="This test is maybe not relevant anymore.") -def test_potentially_dirty_substates(): +def test_potentially_dirty_states(): """Test that potentially_dirty_substates returns the correct substates. Even if the name "State" is shadowed, it should still work correctly. @@ -3206,9 +3205,9 @@ def test_potentially_dirty_substates(): def bar(self) -> str: return "" - assert RxState._potentially_dirty_substates() == set() - assert State._potentially_dirty_substates() == set() - assert C1._potentially_dirty_substates() == set() + assert RxState._get_potentially_dirty_states() == set() + assert State._get_potentially_dirty_states() == set() + assert C1._get_potentially_dirty_states() == set() @pytest.mark.asyncio @@ -3857,7 +3856,7 @@ async def test_async_computed_var_get_state(mock_app: rx.App, token: str): child3 = await self.get_state(Child3) return child3.child3_var + p.parent_var - mock_app.state_manager.state = mock_app.state = Parent + mock_app.state_manager.state = mock_app._state = Parent # Get the top level state via unconnected sibling. root = await mock_app.state_manager.get_state(_substate_key(token, Child))