diff --git a/reflex/experimental/client_state.py b/reflex/experimental/client_state.py index 1982b3dfe..a023b557a 100644 --- a/reflex/experimental/client_state.py +++ b/reflex/experimental/client_state.py @@ -12,7 +12,7 @@ from reflex.event import EventChain, EventHandler, EventSpec, run_script from reflex.utils.imports import ImportVar from reflex.vars import VarData, get_unique_variable_name from reflex.vars.base import LiteralVar, Var -from reflex.vars.function import FunctionVar +from reflex.vars.function import ArgsFunctionOperationBuilder, FunctionVar NoValue = object() @@ -45,6 +45,7 @@ class ClientStateVar(Var): # Track the names of the getters and setters _setter_name: str = dataclasses.field(default="") _getter_name: str = dataclasses.field(default="") + _id_name: str = dataclasses.field(default="") # Whether to add the var and setter to the global `refs` object for use in any Component. _global_ref: bool = dataclasses.field(default=True) @@ -96,6 +97,7 @@ class ClientStateVar(Var): """ if var_name is None: var_name = get_unique_variable_name() + id_name = "id_" + get_unique_variable_name() if not isinstance(var_name, str): raise ValueError("var_name must be a string.") if default is NoValue: @@ -106,19 +108,23 @@ class ClientStateVar(Var): default_var = default setter_name = f"set{var_name.capitalize()}" hooks = { + f"const {id_name} = useId()": None, f"const [{var_name}, {setter_name}] = useState({default_var!s})": None, } imports = { - "react": [ImportVar(tag="useState")], + "react": [ImportVar(tag="useState"), ImportVar(tag="useId")], } if global_ref: - hooks[f"{_client_state_ref(var_name)} = {var_name}"] = None - hooks[f"{_client_state_ref(setter_name)} = {setter_name}"] = None + hooks[f"{_client_state_ref(var_name)} ??= {{}}"] = None + hooks[f"{_client_state_ref(setter_name)} ??= {{}}"] = None + hooks[f"{_client_state_ref(var_name)}[{id_name}] = {var_name}"] = None + hooks[f"{_client_state_ref(setter_name)}[{id_name}] = {setter_name}"] = None imports.update(_refs_import) return cls( _js_expr="", _setter_name=setter_name, _getter_name=var_name, + _id_name=id_name, _global_ref=global_ref, _var_type=default_var._var_type, _var_data=VarData.merge( @@ -144,10 +150,11 @@ class ClientStateVar(Var): return ( Var( _js_expr=( - _client_state_ref(self._getter_name) + _client_state_ref(self._getter_name) + f"[{self._id_name}]" if self._global_ref else self._getter_name - ) + ), + _var_data=self._var_data, ) .to(self._var_type) ._replace( @@ -170,28 +177,43 @@ class ClientStateVar(Var): Returns: A special EventChain Var which will set the value when triggered. """ - setter = ( - _client_state_ref(self._setter_name) - if self._global_ref - else self._setter_name - ) _var_data = VarData(imports=_refs_import if self._global_ref else {}) + + arg_name = get_unique_variable_name() + setter = ( + ArgsFunctionOperationBuilder.create( + args_names=(arg_name,), + return_expr=Var("Array.prototype.forEach.call") + .to(FunctionVar) + .call( + Var("Object.values") + .to(FunctionVar) + .call(Var(_client_state_ref(self._setter_name))), + ArgsFunctionOperationBuilder.create( + args_names=("setter",), + return_expr=Var("setter").to(FunctionVar).call(Var(arg_name)), + ), + ), + _var_data=_var_data, + ) + if self._global_ref + else Var(self._setter_name, _var_data=_var_data).to(FunctionVar) + ) + if value is not NoValue: # This is a hack to make it work like an EventSpec taking an arg value_var = LiteralVar.create(value) - _var_data = VarData.merge(_var_data, value_var._get_all_var_data()) value_str = str(value_var) - if value_str.startswith("_"): + setter = ArgsFunctionOperationBuilder.create( # remove patterns of ["*"] from the value_str using regex - arg = re.sub(r"\[\".*\"\]", "", value_str) - setter = f"(({arg}) => {setter}({value_str}))" - else: - setter = f"(() => {setter}({value_str}))" - return Var( - _js_expr=setter, - _var_data=_var_data, - ).to(FunctionVar, EventChain) + args_names=(re.sub(r"\[\".*\"\]", "", value_str)) + if value_str.startswith("_") + else (), + return_expr=setter.call(value_var), + ) + + return setter.to(FunctionVar, EventChain) @property def set(self) -> Var: