var_data fixes with hooks values

This commit is contained in:
Khaleel Al-Adhami 2025-01-29 14:36:54 -08:00
parent 2c3257d4ea
commit 6691af0846
4 changed files with 63 additions and 4 deletions

View File

@ -37,6 +37,7 @@ from typing_extensions import (
) )
from reflex import constants from reflex import constants
from reflex.constants.compiler import CompileVars, Hooks, Imports
from reflex.constants.state import FRONTEND_EVENT_STATE from reflex.constants.state import FRONTEND_EVENT_STATE
from reflex.utils import console, format from reflex.utils import console, format
from reflex.utils.exceptions import ( from reflex.utils.exceptions import (
@ -1729,7 +1730,13 @@ class LiteralEventChainVar(ArgsFunctionOperationBuilder, LiteralVar, EventChainV
arg_def_expr = Var(_js_expr="args") arg_def_expr = Var(_js_expr="args")
if value.invocation is None: if value.invocation is None:
invocation = FunctionStringVar.create("addEvents") invocation = FunctionStringVar.create(
CompileVars.ADD_EVENTS,
_var_data=VarData(
imports=Imports.EVENTS,
hooks={Hooks.EVENTS: None},
),
)
else: else:
invocation = value.invocation invocation = value.invocation

View File

@ -29,6 +29,7 @@ from typing import (
Mapping, Mapping,
NoReturn, NoReturn,
Optional, Optional,
Sequence,
Set, Set,
Tuple, Tuple,
Type, Type,
@ -131,7 +132,7 @@ class VarData:
state: str = "", state: str = "",
field_name: str = "", field_name: str = "",
imports: ImportDict | ParsedImportDict | None = None, imports: ImportDict | ParsedImportDict | None = None,
hooks: Mapping[str, VarData | None] | None = None, hooks: Mapping[str, VarData | None] | Sequence[str] | str | None = None,
deps: list[Var] | None = None, deps: list[Var] | None = None,
position: Hooks.HookPosition | None = None, position: Hooks.HookPosition | None = None,
): ):
@ -145,6 +146,10 @@ class VarData:
deps: Dependencies of the var for useCallback. deps: Dependencies of the var for useCallback.
position: Position of the hook in the component. position: Position of the hook in the component.
""" """
if isinstance(hooks, str):
hooks = [hooks]
if not isinstance(hooks, dict):
hooks = {hook: None for hook in (hooks or [])}
immutable_imports: ImmutableParsedImportDict = tuple( immutable_imports: ImmutableParsedImportDict = tuple(
(k, tuple(v)) for k, v in parse_imports(imports or {}).items() (k, tuple(v)) for k, v in parse_imports(imports or {}).items()
) )
@ -155,6 +160,17 @@ class VarData:
object.__setattr__(self, "deps", tuple(deps or [])) object.__setattr__(self, "deps", tuple(deps or []))
object.__setattr__(self, "position", position or None) object.__setattr__(self, "position", position or None)
if hooks and any(hooks.values()):
merged_var_data = VarData.merge(self, *hooks.values())
if merged_var_data is None:
raise ValueError("what")
object.__setattr__(self, "state", merged_var_data.state)
object.__setattr__(self, "field_name", merged_var_data.field_name)
object.__setattr__(self, "imports", merged_var_data.imports)
object.__setattr__(self, "hooks", merged_var_data.hooks)
object.__setattr__(self, "deps", merged_var_data.deps)
object.__setattr__(self, "position", merged_var_data.position)
def old_school_imports(self) -> ImportDict: def old_school_imports(self) -> ImportDict:
"""Return the imports as a mutable dict. """Return the imports as a mutable dict.

View File

@ -3,6 +3,7 @@ from typing import Callable, List
import pytest import pytest
import reflex as rx import reflex as rx
from reflex.constants.compiler import Hooks, Imports
from reflex.event import ( from reflex.event import (
Event, Event,
EventChain, EventChain,
@ -14,7 +15,7 @@ from reflex.event import (
) )
from reflex.state import BaseState from reflex.state import BaseState
from reflex.utils import format from reflex.utils import format
from reflex.vars.base import Field, LiteralVar, Var, field from reflex.vars.base import Field, LiteralVar, Var, VarData, field
def make_var(value) -> Var: def make_var(value) -> Var:
@ -443,9 +444,28 @@ def test_event_var_data():
return (value,) return (value,)
# Ensure chain carries _var_data # Ensure chain carries _var_data
chain_var = Var.create(EventChain(events=[S.s(S.x)], args_spec=_args_spec)) chain_var = Var.create(
EventChain(
events=[S.s(S.x)],
args_spec=_args_spec,
invocation=rx.vars.FunctionStringVar.create(""),
)
)
assert chain_var._get_all_var_data() == S.x._get_all_var_data() assert chain_var._get_all_var_data() == S.x._get_all_var_data()
chain_var_data = Var.create(
EventChain(
events=[],
args_spec=_args_spec,
)
)._get_all_var_data()
assert chain_var_data is not None
assert chain_var_data == VarData(
imports=Imports.EVENTS,
hooks={Hooks.EVENTS: None},
)
def test_event_bound_method() -> None: def test_event_bound_method() -> None:
class S(BaseState): class S(BaseState):

View File

@ -1862,3 +1862,19 @@ def test_to_string_operation():
single_var = Var.create(Email()) single_var = Var.create(Email())
assert single_var._var_type == Email assert single_var._var_type == Email
def test_var_data_hooks():
var_data_str = VarData(hooks="what")
var_data_list = VarData(hooks=["what"])
var_data_dict = VarData(hooks={"what": None})
assert var_data_str == var_data_list == var_data_dict
var_data_list_multiple = VarData(hooks=["what", "whot"])
var_data_dict_multiple = VarData(hooks={"what": None, "whot": None})
assert var_data_list_multiple == var_data_dict_multiple
def test_var_data_with_hooks_value():
var_data = VarData(hooks={"what": VarData(hooks={"whot": VarData(hooks="whott")})})
assert var_data == VarData(hooks=["what", "whot", "whott"])