[REF-2879] Make client_state work without global refs (local only) (#3379)

* Make client_state work without global refs (local only)

* client_state: if the default is str, mark _var_is_string=True

Ensure that a string default is not rendered literally

* add `to_int` as a Var operation

* Allow an event handler lambda to return a Var in some cases

If an lambda is passed to an event trigger and it returns a single Var, then
treat it like the Var was directly passed for the event trigger.

This allows ClientState.set_var to be used within a lambda.
This commit is contained in:
Masen Furer 2024-06-07 13:29:52 -07:00 committed by GitHub
parent e138d9dfd0
commit bb44d51f2f
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
5 changed files with 129 additions and 40 deletions

View File

@ -519,13 +519,23 @@ class Component(BaseComponent, ABC):
events.append(event)
elif isinstance(v, Callable):
# Call the lambda to get the event chain.
events.extend(call_event_fn(v, args_spec))
result = call_event_fn(v, args_spec)
if isinstance(result, Var):
raise ValueError(
f"Invalid event chain: {v}. Cannot use a Var-returning "
"lambda inside an EventChain list."
)
events.extend(result)
else:
raise ValueError(f"Invalid event: {v}")
# If the input is a callable, create an event chain.
elif isinstance(value, Callable):
events = call_event_fn(value, args_spec)
result = call_event_fn(value, args_spec)
if isinstance(result, Var):
# Recursively call this function if the lambda returned an EventChain Var.
return self._create_event_chain(args_spec, result)
events = result
# Otherwise, raise an error.
else:

View File

@ -415,6 +415,8 @@ class FileUpload(Base):
) # type: ignore
else:
raise ValueError(f"{on_upload_progress} is not a valid event handler.")
if isinstance(events, Var):
raise ValueError(f"{on_upload_progress} cannot return a var {events}.")
on_upload_progress_chain = EventChain(
events=events,
args_spec=self.on_upload_progress_args_spec,
@ -831,19 +833,19 @@ def parse_args_spec(arg_spec: ArgsSpec):
)
def call_event_fn(fn: Callable, arg: Union[Var, ArgsSpec]) -> list[EventSpec]:
def call_event_fn(fn: Callable, arg: Union[Var, ArgsSpec]) -> list[EventSpec] | Var:
"""Call a function to a list of event specs.
The function should return either a single EventSpec or a list of EventSpecs.
If the function takes in an arg, the arg will be passed to the function.
Otherwise, the function will be called with no args.
The function should return a single EventSpec, a list of EventSpecs, or a
single Var. If the function takes in an arg, the arg will be passed to the
function. Otherwise, the function will be called with no args.
Args:
fn: The function to call.
arg: The argument to pass to the function.
Returns:
The event specs from calling the function.
The event specs from calling the function or a Var.
Raises:
EventHandlerValueError: If the lambda has an invalid signature.
@ -866,6 +868,10 @@ def call_event_fn(fn: Callable, arg: Union[Var, ArgsSpec]) -> list[EventSpec]:
else:
raise EventHandlerValueError(f"Lambda {fn} must have 0 or 1 arguments.")
# If the function returns a Var, assume it's an EventChain and render it directly.
if isinstance(out, Var):
return out
# Convert the output to a list.
if not isinstance(out, List):
out = [out]

View File

@ -1,4 +1,5 @@
"""Handle client side state with `useState`."""
from __future__ import annotations
import dataclasses
import sys
@ -9,6 +10,13 @@ from reflex.event import EventChain, EventHandler, EventSpec, call_script
from reflex.utils.imports import ImportVar
from reflex.vars import Var, VarData
NoValue = object()
_refs_import = {
f"/{constants.Dirs.STATE_PATH}": [ImportVar(tag="refs")],
}
def _client_state_ref(var_name: str) -> str:
"""Get the ref path for a ClientStateVar.
@ -36,6 +44,9 @@ class ClientStateVar(Var):
_setter_name: str = dataclasses.field()
_getter_name: str = dataclasses.field()
# Whether to add the var and setter to the global `refs` object for use in any Component.
_global_ref: bool = dataclasses.field(default=True)
# The type of the var.
_var_type: Type = dataclasses.field(default=Any)
@ -62,7 +73,9 @@ class ClientStateVar(Var):
)
@classmethod
def create(cls, var_name, default=None) -> "ClientStateVar":
def create(
cls, var_name: str, default: Any = NoValue, global_ref: bool = True
) -> "ClientStateVar":
"""Create a local_state Var that can be accessed and updated on the client.
The `ClientStateVar` should be included in the highest parent component
@ -72,7 +85,7 @@ class ClientStateVar(Var):
To render the var in a component, use the `value` property.
To update the var in a component, use the `set` property.
To update the var in a component, use the `set` property or `set_value` method.
To access the var in an event handler, use the `retrieve` method with
`callback` set to the event handler which should receive the value.
@ -83,36 +96,45 @@ class ClientStateVar(Var):
Args:
var_name: The name of the variable.
default: The default value of the variable.
global_ref: Whether the state should be accessible in any Component and on the backend.
Returns:
ClientStateVar
"""
if default is None:
assert isinstance(var_name, str), "var_name must be a string."
if default is NoValue:
default_var = Var.create_safe("", _var_is_local=False, _var_is_string=False)
elif not isinstance(default, Var):
default_var = Var.create_safe(default)
default_var = Var.create_safe(
default,
_var_is_string=isinstance(default, str),
)
else:
default_var = default
setter_name = f"set{var_name.capitalize()}"
hooks = {
f"const [{var_name}, {setter_name}] = useState({default_var._var_name_unwrapped})": None,
}
imports = {
"react": [ImportVar(tag="useState")],
}
if global_ref:
hooks[f"{_client_state_ref(var_name)} = {var_name}"] = None
hooks[f"{_client_state_ref(setter_name)} = {setter_name}"] = None
imports.update(_refs_import)
return cls(
_var_name="",
_setter_name=setter_name,
_getter_name=var_name,
_global_ref=global_ref,
_var_is_local=False,
_var_is_string=False,
_var_type=default_var._var_type,
_var_data=VarData.merge(
default_var._var_data,
VarData( # type: ignore
hooks={
f"const [{var_name}, {setter_name}] = useState({default_var._var_name_unwrapped})": None,
f"{_client_state_ref(var_name)} = {var_name}": None,
f"{_client_state_ref(setter_name)} = {setter_name}": None,
},
imports={
"react": [ImportVar(tag="useState", install=False)],
f"/{constants.Dirs.STATE_PATH}": [ImportVar(tag="refs")],
},
hooks=hooks,
imports=imports,
),
),
)
@ -130,16 +152,56 @@ class ClientStateVar(Var):
"""
return (
Var.create_safe(
_client_state_ref(self._getter_name),
_client_state_ref(self._getter_name)
if self._global_ref
else self._getter_name,
_var_is_local=False,
_var_is_string=False,
)
.to(self._var_type)
._replace(
merge_var_data=VarData( # type: ignore
imports={
f"/{constants.Dirs.STATE_PATH}": [ImportVar(tag="refs")],
}
imports=_refs_import if self._global_ref else {}
)
)
)
def set_value(self, value: Any = NoValue) -> Var:
"""Set the value of the client state variable.
This property can only be attached to a frontend event trigger.
To set a value from a backend event handler, see `push`.
Args:
value: The value to set.
Returns:
A special EventChain Var which will set the value when triggered.
"""
setter = (
_client_state_ref(self._setter_name)
if self._global_ref
else self._setter_name
)
if value is not NoValue:
# This is a hack to make it work like an EventSpec taking an arg
value = Var.create_safe(value, _var_is_string=isinstance(value, str))
if not value._var_is_string and value._var_full_name.startswith("_"):
arg = value._var_name_unwrapped
else:
arg = ""
setter = f"({arg}) => {setter}({value._var_name_unwrapped})"
return (
Var.create_safe(
setter,
_var_is_local=False,
_var_is_string=False,
)
.to(EventChain)
._replace(
merge_var_data=VarData( # type: ignore
imports=_refs_import if self._global_ref else {}
)
)
)
@ -155,21 +217,7 @@ class ClientStateVar(Var):
Returns:
A special EventChain Var which will set the value when triggered.
"""
return (
Var.create_safe(
_client_state_ref(self._setter_name),
_var_is_local=False,
_var_is_string=False,
)
.to(EventChain)
._replace(
merge_var_data=VarData( # type: ignore
imports={
f"/{constants.Dirs.STATE_PATH}": [ImportVar(tag="refs")],
}
)
)
)
return self.set_value()
def retrieve(
self, callback: Union[EventHandler, Callable, None] = None
@ -183,7 +231,12 @@ class ClientStateVar(Var):
Returns:
An EventSpec which will retrieve the value when triggered.
Raises:
ValueError: If the ClientStateVar is not global.
"""
if not self._global_ref:
raise ValueError("ClientStateVar must be global to retrieve the value.")
return call_script(_client_state_ref(self._getter_name), callback=callback)
def push(self, value: Any) -> EventSpec:
@ -196,5 +249,10 @@ class ClientStateVar(Var):
Returns:
An EventSpec which will push the value when triggered.
Raises:
ValueError: If the ClientStateVar is not global.
"""
if not self._global_ref:
raise ValueError("ClientStateVar must be global to push the value.")
return call_script(f"{_client_state_ref(self._setter_name)}({value})")

View File

@ -612,6 +612,9 @@ def format_queue_events(
Returns:
The compiled javascript callback to queue the given events on the frontend.
Raises:
ValueError: If a lambda function is given which returns a Var.
"""
from reflex.event import (
EventChain,
@ -648,7 +651,11 @@ def format_queue_events(
if isinstance(spec, (EventHandler, EventSpec)):
specs = [call_event_handler(spec, args_spec or _default_args_spec)]
elif isinstance(spec, type(lambda: None)):
specs = call_event_fn(spec, args_spec or _default_args_spec)
specs = call_event_fn(spec, args_spec or _default_args_spec) # type: ignore
if isinstance(specs, Var):
raise ValueError(
f"Invalid event spec: {specs}. Expected a list of EventSpecs."
)
payloads.extend(format_event(s) for s in specs)
# Return the final code snippet, expecting queueEvents, processEvent, and socket to be in scope.

View File

@ -552,6 +552,14 @@ class Var:
fn = "JSON.stringify" if json else "String"
return self.operation(fn=fn, type_=str)
def to_int(self) -> Var:
"""Convert a var to an int.
Returns:
The parseInt var.
"""
return self.operation(fn="parseInt", type_=int)
def __hash__(self) -> int:
"""Define a hash function for a var.