From 67dcbeae262fa72eeb65167825bc7b1f70e85a08 Mon Sep 17 00:00:00 2001 From: Masen Furer Date: Tue, 15 Oct 2024 16:30:32 -0700 Subject: [PATCH] test case for event related vars carrying _var_data --- tests/units/test_event.py | 38 +++++++++++++++++++++++++++++++++++--- 1 file changed, 35 insertions(+), 3 deletions(-) diff --git a/tests/units/test_event.py b/tests/units/test_event.py index 3996a6101..d7b7cf7a2 100644 --- a/tests/units/test_event.py +++ b/tests/units/test_event.py @@ -2,11 +2,18 @@ from typing import List import pytest -from reflex import event -from reflex.event import Event, EventHandler, EventSpec, call_event_handler, fix_events +from reflex.event import ( + Event, + EventChain, + EventHandler, + EventSpec, + call_event_handler, + event, + fix_events, +) from reflex.state import BaseState from reflex.utils import format -from reflex.vars.base import LiteralVar, Var +from reflex.vars.base import Field, LiteralVar, Var, field def make_var(value) -> Var: @@ -388,3 +395,28 @@ def test_event_actions_on_state(): assert sp_handler.event_actions == {"stopPropagation": True} # should NOT affect other references to the handler assert not handler.event_actions + + +def test_event_var_data(): + class S(BaseState): + x: Field[int] = field(0) + + @event + def s(self, value: int): + pass + + # Handler doesn't have any _var_data because it's just a str + handler_var = Var.create(S.s) + assert handler_var._get_all_var_data() is None + + # Ensure spec carries _var_data + spec_var = Var.create(S.s(S.x)) + assert spec_var._get_all_var_data() == S.x._get_all_var_data() + + # Needed to instantiate the EventChain + def _args_spec(value: Var[int]) -> tuple[Var[int]]: + return (value,) + + # Ensure chain carries _var_data + chain_var = Var.create(EventChain(events=[S.s(S.x)], args_spec=_args_spec)) + assert chain_var._get_all_var_data() == S.x._get_all_var_data()