From bb44d51f2fe7383a851990185aeaaa8e6a276c9d Mon Sep 17 00:00:00 2001 From: Masen Furer Date: Fri, 7 Jun 2024 13:29:52 -0700 Subject: [PATCH] [REF-2879] Make client_state work without global refs (local only) (#3379) * Make client_state work without global refs (local only) * client_state: if the default is str, mark _var_is_string=True Ensure that a string default is not rendered literally * add `to_int` as a Var operation * Allow an event handler lambda to return a Var in some cases If an lambda is passed to an event trigger and it returns a single Var, then treat it like the Var was directly passed for the event trigger. This allows ClientState.set_var to be used within a lambda. --- reflex/components/component.py | 14 +++- reflex/event.py | 16 ++-- reflex/experimental/client_state.py | 122 ++++++++++++++++++++-------- reflex/utils/format.py | 9 +- reflex/vars.py | 8 ++ 5 files changed, 129 insertions(+), 40 deletions(-) diff --git a/reflex/components/component.py b/reflex/components/component.py index b039ebe0d..e2d8e22c7 100644 --- a/reflex/components/component.py +++ b/reflex/components/component.py @@ -519,13 +519,23 @@ class Component(BaseComponent, ABC): events.append(event) elif isinstance(v, Callable): # Call the lambda to get the event chain. - events.extend(call_event_fn(v, args_spec)) + result = call_event_fn(v, args_spec) + if isinstance(result, Var): + raise ValueError( + f"Invalid event chain: {v}. Cannot use a Var-returning " + "lambda inside an EventChain list." + ) + events.extend(result) else: raise ValueError(f"Invalid event: {v}") # If the input is a callable, create an event chain. elif isinstance(value, Callable): - events = call_event_fn(value, args_spec) + result = call_event_fn(value, args_spec) + if isinstance(result, Var): + # Recursively call this function if the lambda returned an EventChain Var. + return self._create_event_chain(args_spec, result) + events = result # Otherwise, raise an error. else: diff --git a/reflex/event.py b/reflex/event.py index 01abb64b9..cc2d6b42b 100644 --- a/reflex/event.py +++ b/reflex/event.py @@ -415,6 +415,8 @@ class FileUpload(Base): ) # type: ignore else: raise ValueError(f"{on_upload_progress} is not a valid event handler.") + if isinstance(events, Var): + raise ValueError(f"{on_upload_progress} cannot return a var {events}.") on_upload_progress_chain = EventChain( events=events, args_spec=self.on_upload_progress_args_spec, @@ -831,19 +833,19 @@ def parse_args_spec(arg_spec: ArgsSpec): ) -def call_event_fn(fn: Callable, arg: Union[Var, ArgsSpec]) -> list[EventSpec]: +def call_event_fn(fn: Callable, arg: Union[Var, ArgsSpec]) -> list[EventSpec] | Var: """Call a function to a list of event specs. - The function should return either a single EventSpec or a list of EventSpecs. - If the function takes in an arg, the arg will be passed to the function. - Otherwise, the function will be called with no args. + The function should return a single EventSpec, a list of EventSpecs, or a + single Var. If the function takes in an arg, the arg will be passed to the + function. Otherwise, the function will be called with no args. Args: fn: The function to call. arg: The argument to pass to the function. Returns: - The event specs from calling the function. + The event specs from calling the function or a Var. Raises: EventHandlerValueError: If the lambda has an invalid signature. @@ -866,6 +868,10 @@ def call_event_fn(fn: Callable, arg: Union[Var, ArgsSpec]) -> list[EventSpec]: else: raise EventHandlerValueError(f"Lambda {fn} must have 0 or 1 arguments.") + # If the function returns a Var, assume it's an EventChain and render it directly. + if isinstance(out, Var): + return out + # Convert the output to a list. if not isinstance(out, List): out = [out] diff --git a/reflex/experimental/client_state.py b/reflex/experimental/client_state.py index 9282c4721..39f61aa7c 100644 --- a/reflex/experimental/client_state.py +++ b/reflex/experimental/client_state.py @@ -1,4 +1,5 @@ """Handle client side state with `useState`.""" +from __future__ import annotations import dataclasses import sys @@ -9,6 +10,13 @@ from reflex.event import EventChain, EventHandler, EventSpec, call_script from reflex.utils.imports import ImportVar from reflex.vars import Var, VarData +NoValue = object() + + +_refs_import = { + f"/{constants.Dirs.STATE_PATH}": [ImportVar(tag="refs")], +} + def _client_state_ref(var_name: str) -> str: """Get the ref path for a ClientStateVar. @@ -36,6 +44,9 @@ class ClientStateVar(Var): _setter_name: str = dataclasses.field() _getter_name: str = dataclasses.field() + # Whether to add the var and setter to the global `refs` object for use in any Component. + _global_ref: bool = dataclasses.field(default=True) + # The type of the var. _var_type: Type = dataclasses.field(default=Any) @@ -62,7 +73,9 @@ class ClientStateVar(Var): ) @classmethod - def create(cls, var_name, default=None) -> "ClientStateVar": + def create( + cls, var_name: str, default: Any = NoValue, global_ref: bool = True + ) -> "ClientStateVar": """Create a local_state Var that can be accessed and updated on the client. The `ClientStateVar` should be included in the highest parent component @@ -72,7 +85,7 @@ class ClientStateVar(Var): To render the var in a component, use the `value` property. - To update the var in a component, use the `set` property. + To update the var in a component, use the `set` property or `set_value` method. To access the var in an event handler, use the `retrieve` method with `callback` set to the event handler which should receive the value. @@ -83,36 +96,45 @@ class ClientStateVar(Var): Args: var_name: The name of the variable. default: The default value of the variable. + global_ref: Whether the state should be accessible in any Component and on the backend. Returns: ClientStateVar """ - if default is None: + assert isinstance(var_name, str), "var_name must be a string." + if default is NoValue: default_var = Var.create_safe("", _var_is_local=False, _var_is_string=False) elif not isinstance(default, Var): - default_var = Var.create_safe(default) + default_var = Var.create_safe( + default, + _var_is_string=isinstance(default, str), + ) else: default_var = default setter_name = f"set{var_name.capitalize()}" + hooks = { + f"const [{var_name}, {setter_name}] = useState({default_var._var_name_unwrapped})": None, + } + imports = { + "react": [ImportVar(tag="useState")], + } + if global_ref: + hooks[f"{_client_state_ref(var_name)} = {var_name}"] = None + hooks[f"{_client_state_ref(setter_name)} = {setter_name}"] = None + imports.update(_refs_import) return cls( _var_name="", _setter_name=setter_name, _getter_name=var_name, + _global_ref=global_ref, _var_is_local=False, _var_is_string=False, _var_type=default_var._var_type, _var_data=VarData.merge( default_var._var_data, VarData( # type: ignore - hooks={ - f"const [{var_name}, {setter_name}] = useState({default_var._var_name_unwrapped})": None, - f"{_client_state_ref(var_name)} = {var_name}": None, - f"{_client_state_ref(setter_name)} = {setter_name}": None, - }, - imports={ - "react": [ImportVar(tag="useState", install=False)], - f"/{constants.Dirs.STATE_PATH}": [ImportVar(tag="refs")], - }, + hooks=hooks, + imports=imports, ), ), ) @@ -130,16 +152,56 @@ class ClientStateVar(Var): """ return ( Var.create_safe( - _client_state_ref(self._getter_name), + _client_state_ref(self._getter_name) + if self._global_ref + else self._getter_name, _var_is_local=False, _var_is_string=False, ) .to(self._var_type) ._replace( merge_var_data=VarData( # type: ignore - imports={ - f"/{constants.Dirs.STATE_PATH}": [ImportVar(tag="refs")], - } + imports=_refs_import if self._global_ref else {} + ) + ) + ) + + def set_value(self, value: Any = NoValue) -> Var: + """Set the value of the client state variable. + + This property can only be attached to a frontend event trigger. + + To set a value from a backend event handler, see `push`. + + Args: + value: The value to set. + + Returns: + A special EventChain Var which will set the value when triggered. + """ + setter = ( + _client_state_ref(self._setter_name) + if self._global_ref + else self._setter_name + ) + if value is not NoValue: + # This is a hack to make it work like an EventSpec taking an arg + value = Var.create_safe(value, _var_is_string=isinstance(value, str)) + if not value._var_is_string and value._var_full_name.startswith("_"): + arg = value._var_name_unwrapped + else: + arg = "" + setter = f"({arg}) => {setter}({value._var_name_unwrapped})" + return ( + Var.create_safe( + setter, + _var_is_local=False, + _var_is_string=False, + ) + .to(EventChain) + ._replace( + merge_var_data=VarData( # type: ignore + imports=_refs_import if self._global_ref else {} ) ) ) @@ -155,21 +217,7 @@ class ClientStateVar(Var): Returns: A special EventChain Var which will set the value when triggered. """ - return ( - Var.create_safe( - _client_state_ref(self._setter_name), - _var_is_local=False, - _var_is_string=False, - ) - .to(EventChain) - ._replace( - merge_var_data=VarData( # type: ignore - imports={ - f"/{constants.Dirs.STATE_PATH}": [ImportVar(tag="refs")], - } - ) - ) - ) + return self.set_value() def retrieve( self, callback: Union[EventHandler, Callable, None] = None @@ -183,7 +231,12 @@ class ClientStateVar(Var): Returns: An EventSpec which will retrieve the value when triggered. + + Raises: + ValueError: If the ClientStateVar is not global. """ + if not self._global_ref: + raise ValueError("ClientStateVar must be global to retrieve the value.") return call_script(_client_state_ref(self._getter_name), callback=callback) def push(self, value: Any) -> EventSpec: @@ -196,5 +249,10 @@ class ClientStateVar(Var): Returns: An EventSpec which will push the value when triggered. + + Raises: + ValueError: If the ClientStateVar is not global. """ + if not self._global_ref: + raise ValueError("ClientStateVar must be global to push the value.") return call_script(f"{_client_state_ref(self._setter_name)}({value})") diff --git a/reflex/utils/format.py b/reflex/utils/format.py index f3c2e63de..9df51b433 100644 --- a/reflex/utils/format.py +++ b/reflex/utils/format.py @@ -612,6 +612,9 @@ def format_queue_events( Returns: The compiled javascript callback to queue the given events on the frontend. + + Raises: + ValueError: If a lambda function is given which returns a Var. """ from reflex.event import ( EventChain, @@ -648,7 +651,11 @@ def format_queue_events( if isinstance(spec, (EventHandler, EventSpec)): specs = [call_event_handler(spec, args_spec or _default_args_spec)] elif isinstance(spec, type(lambda: None)): - specs = call_event_fn(spec, args_spec or _default_args_spec) + specs = call_event_fn(spec, args_spec or _default_args_spec) # type: ignore + if isinstance(specs, Var): + raise ValueError( + f"Invalid event spec: {specs}. Expected a list of EventSpecs." + ) payloads.extend(format_event(s) for s in specs) # Return the final code snippet, expecting queueEvents, processEvent, and socket to be in scope. diff --git a/reflex/vars.py b/reflex/vars.py index ce51c0324..83ebfea68 100644 --- a/reflex/vars.py +++ b/reflex/vars.py @@ -552,6 +552,14 @@ class Var: fn = "JSON.stringify" if json else "String" return self.operation(fn=fn, type_=str) + def to_int(self) -> Var: + """Convert a var to an int. + + Returns: + The parseInt var. + """ + return self.operation(fn="parseInt", type_=int) + def __hash__(self) -> int: """Define a hash function for a var.