diff --git a/reflex/experimental/client_state.py b/reflex/experimental/client_state.py index 45dfef237..ca91905ae 100644 --- a/reflex/experimental/client_state.py +++ b/reflex/experimental/client_state.py @@ -34,6 +34,18 @@ def _client_state_ref(var_name: str) -> str: return f"refs['_client_state_{var_name}']" +def _client_state_ref_dict(var_name: str) -> str: + """Get the ref path for a ClientStateVar. + + Args: + var_name: The name of the variable. + + Returns: + An accessor for ClientStateVar ref as a string. + """ + return f"refs['_client_state_dict_{var_name}']" + + @dataclasses.dataclass( eq=False, frozen=True, @@ -115,10 +127,41 @@ class ClientStateVar(Var): "react": [ImportVar(tag="useState"), ImportVar(tag="useId")], } if global_ref: - hooks[f"{_client_state_ref(var_name)} ??= {{}}"] = None - hooks[f"{_client_state_ref(setter_name)} ??= {{}}"] = None - hooks[f"{_client_state_ref(var_name)}[{id_name}] = {var_name}"] = None - hooks[f"{_client_state_ref(setter_name)}[{id_name}] = {setter_name}"] = None + arg_name = get_unique_variable_name() + func = ArgsFunctionOperationBuilder.create( + args_names=(arg_name,), + return_expr=Var("Array.prototype.forEach.call") + .to(FunctionVar) + .call( + ( + Var("Object.values") + .to(FunctionVar) + .call(Var(_client_state_ref_dict(setter_name))) + .to(list) + .to(list) + ) + + Var.create( + [ + Var( + f"(value) => {{ {_client_state_ref(var_name)} = value; }}" + ) + ] + ).to(list), + ArgsFunctionOperationBuilder.create( + args_names=("setter",), + return_expr=Var("setter").to(FunctionVar).call(Var(arg_name)), + ), + ), + ) + + hooks[f"{_client_state_ref(setter_name)} = {func!s}"] = None + hooks[f"{_client_state_ref(var_name)} ??= {var_name!s}"] = None + hooks[f"{_client_state_ref_dict(var_name)} ??= {{}}"] = None + hooks[f"{_client_state_ref_dict(setter_name)} ??= {{}}"] = None + hooks[f"{_client_state_ref_dict(var_name)}[{id_name}] = {var_name}"] = None + hooks[ + f"{_client_state_ref_dict(setter_name)}[{id_name}] = {setter_name}" + ] = None imports.update(_refs_import) return cls( _js_expr="", @@ -150,7 +193,7 @@ class ClientStateVar(Var): return ( Var( _js_expr=( - _client_state_ref(self._getter_name) + f"[{self._id_name}]" + _client_state_ref_dict(self._getter_name) + f"[{self._id_name}]" if self._global_ref else self._getter_name ), @@ -179,26 +222,11 @@ class ClientStateVar(Var): """ _var_data = VarData(imports=_refs_import if self._global_ref else {}) - arg_name = get_unique_variable_name() setter = ( - ArgsFunctionOperationBuilder.create( - args_names=(arg_name,), - return_expr=Var("Array.prototype.forEach.call") - .to(FunctionVar) - .call( - Var("Object.values") - .to(FunctionVar) - .call(Var(_client_state_ref(self._setter_name))), - ArgsFunctionOperationBuilder.create( - args_names=("setter",), - return_expr=Var("setter").to(FunctionVar).call(Var(arg_name)), - ), - ), - _var_data=_var_data, - ) + Var(_client_state_ref(self._setter_name)) if self._global_ref - else Var(self._setter_name, _var_data=_var_data).to(FunctionVar) - ) + else Var(self._setter_name, _var_data=_var_data) + ).to(FunctionVar) if value is not NoValue: # This is a hack to make it work like an EventSpec taking an arg