diff --git a/reflex/state.py b/reflex/state.py index 37bc06360..e3e189b22 100644 --- a/reflex/state.py +++ b/reflex/state.py @@ -220,6 +220,7 @@ class EventHandlerSetVar(EventHandler): Raises: AttributeError: If the given Var name does not exist on the state. EventHandlerValueError: If the given Var name is not a str + NotImplementedError: If the setter for the given Var is async """ from reflex.utils.exceptions import EventHandlerValueError @@ -228,11 +229,20 @@ class EventHandlerSetVar(EventHandler): raise EventHandlerValueError( f"Var name must be passed as a string, got {args[0]!r}" ) + + handler = getattr(self.state_cls, constants.SETTER_PREFIX + args[0], None) + # Check that the requested Var setter exists on the State at compile time. - if getattr(self.state_cls, constants.SETTER_PREFIX + args[0], None) is None: + if handler is None: raise AttributeError( f"Variable `{args[0]}` cannot be set on `{self.state_cls.get_full_name()}`" ) + + if asyncio.iscoroutinefunction(handler.fn): + raise NotImplementedError( + f"Setter for {args[0]} is async, which is not supported." + ) + return super().__call__(*args) diff --git a/tests/units/test_state.py b/tests/units/test_state.py index ebfeeb72c..544ddc606 100644 --- a/tests/units/test_state.py +++ b/tests/units/test_state.py @@ -106,6 +106,7 @@ class TestState(BaseState): fig: Figure = Figure() dt: datetime.datetime = datetime.datetime.fromisoformat("1989-11-09T18:53:00+01:00") _backend: int = 0 + asynctest: int = 0 @ComputedVar def sum(self) -> float: @@ -129,6 +130,14 @@ class TestState(BaseState): """Do something.""" pass + async def set_asynctest(self, value: int): + """Set the asynctest value. Intentionally overwrite the default setter with an async one. + + Args: + value: The new value. + """ + self.asynctest = value + class ChildState(TestState): """A child state fixture.""" @@ -313,6 +322,7 @@ def test_class_vars(test_state): "upper", "fig", "dt", + "asynctest", } @@ -733,6 +743,7 @@ def test_reset(test_state, child_state): "mapping", "dt", "_backend", + "asynctest", } # The dirty vars should be reset. @@ -3179,6 +3190,13 @@ async def test_setvar(mock_app: rx.App, token: str): TestState.setvar(42, 42) +@pytest.mark.asyncio +async def test_setvar_async_setter(): + """Test that overridden async setters raise Exception when used with setvar.""" + with pytest.raises(NotImplementedError): + TestState.setvar("asynctest", 42) + + @pytest.mark.skipif("REDIS_URL" not in os.environ, reason="Test requires redis") @pytest.mark.parametrize( "expiration_kwargs, expected_values", diff --git a/tests/units/utils/test_format.py b/tests/units/utils/test_format.py index 17485d52e..f8b605541 100644 --- a/tests/units/utils/test_format.py +++ b/tests/units/utils/test_format.py @@ -601,6 +601,7 @@ formatted_router = { "sum": 3.14, "upper": "", "router": formatted_router, + "asynctest": 0, }, ChildState.get_full_name(): { "count": 23,