var_data fixes with hooks values (#4717)

* var_data fixes with hooks values

* remove the raise error
This commit is contained in:
Khaleel Al-Adhami 2025-01-31 13:07:51 -08:00 committed by GitHub
parent 335816cbf7
commit 12a42b6c47
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
4 changed files with 62 additions and 4 deletions

View File

@ -37,6 +37,7 @@ from typing_extensions import (
) )
from reflex import constants from reflex import constants
from reflex.constants.compiler import CompileVars, Hooks, Imports
from reflex.constants.state import FRONTEND_EVENT_STATE from reflex.constants.state import FRONTEND_EVENT_STATE
from reflex.utils import console, format from reflex.utils import console, format
from reflex.utils.exceptions import ( from reflex.utils.exceptions import (
@ -1729,7 +1730,13 @@ class LiteralEventChainVar(ArgsFunctionOperationBuilder, LiteralVar, EventChainV
arg_def_expr = Var(_js_expr="args") arg_def_expr = Var(_js_expr="args")
if value.invocation is None: if value.invocation is None:
invocation = FunctionStringVar.create("addEvents") invocation = FunctionStringVar.create(
CompileVars.ADD_EVENTS,
_var_data=VarData(
imports=Imports.EVENTS,
hooks={Hooks.EVENTS: None},
),
)
else: else:
invocation = value.invocation invocation = value.invocation

View File

@ -29,6 +29,7 @@ from typing import (
Mapping, Mapping,
NoReturn, NoReturn,
Optional, Optional,
Sequence,
Set, Set,
Tuple, Tuple,
Type, Type,
@ -131,7 +132,7 @@ class VarData:
state: str = "", state: str = "",
field_name: str = "", field_name: str = "",
imports: ImportDict | ParsedImportDict | None = None, imports: ImportDict | ParsedImportDict | None = None,
hooks: Mapping[str, VarData | None] | None = None, hooks: Mapping[str, VarData | None] | Sequence[str] | str | None = None,
deps: list[Var] | None = None, deps: list[Var] | None = None,
position: Hooks.HookPosition | None = None, position: Hooks.HookPosition | None = None,
): ):
@ -145,6 +146,10 @@ class VarData:
deps: Dependencies of the var for useCallback. deps: Dependencies of the var for useCallback.
position: Position of the hook in the component. position: Position of the hook in the component.
""" """
if isinstance(hooks, str):
hooks = [hooks]
if not isinstance(hooks, dict):
hooks = {hook: None for hook in (hooks or [])}
immutable_imports: ImmutableParsedImportDict = tuple( immutable_imports: ImmutableParsedImportDict = tuple(
(k, tuple(v)) for k, v in parse_imports(imports or {}).items() (k, tuple(v)) for k, v in parse_imports(imports or {}).items()
) )
@ -155,6 +160,16 @@ class VarData:
object.__setattr__(self, "deps", tuple(deps or [])) object.__setattr__(self, "deps", tuple(deps or []))
object.__setattr__(self, "position", position or None) object.__setattr__(self, "position", position or None)
if hooks and any(hooks.values()):
merged_var_data = VarData.merge(self, *hooks.values())
if merged_var_data is not None:
object.__setattr__(self, "state", merged_var_data.state)
object.__setattr__(self, "field_name", merged_var_data.field_name)
object.__setattr__(self, "imports", merged_var_data.imports)
object.__setattr__(self, "hooks", merged_var_data.hooks)
object.__setattr__(self, "deps", merged_var_data.deps)
object.__setattr__(self, "position", merged_var_data.position)
def old_school_imports(self) -> ImportDict: def old_school_imports(self) -> ImportDict:
"""Return the imports as a mutable dict. """Return the imports as a mutable dict.

View File

@ -3,6 +3,7 @@ from typing import Callable, List
import pytest import pytest
import reflex as rx import reflex as rx
from reflex.constants.compiler import Hooks, Imports
from reflex.event import ( from reflex.event import (
Event, Event,
EventChain, EventChain,
@ -14,7 +15,7 @@ from reflex.event import (
) )
from reflex.state import BaseState from reflex.state import BaseState
from reflex.utils import format from reflex.utils import format
from reflex.vars.base import Field, LiteralVar, Var, field from reflex.vars.base import Field, LiteralVar, Var, VarData, field
def make_var(value) -> Var: def make_var(value) -> Var:
@ -443,9 +444,28 @@ def test_event_var_data():
return (value,) return (value,)
# Ensure chain carries _var_data # Ensure chain carries _var_data
chain_var = Var.create(EventChain(events=[S.s(S.x)], args_spec=_args_spec)) chain_var = Var.create(
EventChain(
events=[S.s(S.x)],
args_spec=_args_spec,
invocation=rx.vars.FunctionStringVar.create(""),
)
)
assert chain_var._get_all_var_data() == S.x._get_all_var_data() assert chain_var._get_all_var_data() == S.x._get_all_var_data()
chain_var_data = Var.create(
EventChain(
events=[],
args_spec=_args_spec,
)
)._get_all_var_data()
assert chain_var_data is not None
assert chain_var_data == VarData(
imports=Imports.EVENTS,
hooks={Hooks.EVENTS: None},
)
def test_event_bound_method() -> None: def test_event_bound_method() -> None:
class S(BaseState): class S(BaseState):

View File

@ -1862,3 +1862,19 @@ def test_to_string_operation():
single_var = Var.create(Email()) single_var = Var.create(Email())
assert single_var._var_type == Email assert single_var._var_type == Email
def test_var_data_hooks():
var_data_str = VarData(hooks="what")
var_data_list = VarData(hooks=["what"])
var_data_dict = VarData(hooks={"what": None})
assert var_data_str == var_data_list == var_data_dict
var_data_list_multiple = VarData(hooks=["what", "whot"])
var_data_dict_multiple = VarData(hooks={"what": None, "whot": None})
assert var_data_list_multiple == var_data_dict_multiple
def test_var_data_with_hooks_value():
var_data = VarData(hooks={"what": VarData(hooks={"whot": VarData(hooks="whott")})})
assert var_data == VarData(hooks=["what", "whot", "whott"])