This commit is contained in:
Khaleel Al-Adhami 2024-10-16 12:27:05 -07:00
commit 0111293147
4 changed files with 142 additions and 89 deletions

View File

@ -22,6 +22,7 @@ from typing import (
TypeVar,
Union,
get_type_hints,
overload,
)
from typing_extensions import ParamSpec, get_args, get_origin
@ -32,14 +33,17 @@ from reflex.utils.exceptions import EventFnArgMismatch, EventHandlerArgMismatch
from reflex.utils.types import ArgsSpec, GenericType
from reflex.vars import VarData
from reflex.vars.base import (
CachedVarOperation,
LiteralNoneVar,
LiteralVar,
ToOperation,
Var,
cached_property_no_lock,
)
from reflex.vars.function import ArgsFunctionOperation, FunctionStringVar, FunctionVar
from reflex.vars.function import (
ArgsFunctionOperation,
FunctionStringVar,
FunctionVar,
VarOperationCall,
)
from reflex.vars.object import ObjectVar
try:
@ -1041,7 +1045,8 @@ def resolve_annotation(annotations: dict[str, Any], arg_name: str):
deprecation_version="0.6.3",
removal_version="0.7.0",
)
return JavascriptInputEvent
# Allow arbitrary attribute access two levels deep until removed.
return Dict[str, dict]
return annotation
@ -1258,7 +1263,7 @@ class EventVar(ObjectVar):
frozen=True,
**{"slots": True} if sys.version_info >= (3, 10) else {},
)
class LiteralEventVar(CachedVarOperation, LiteralVar, EventVar):
class LiteralEventVar(VarOperationCall, LiteralVar, EventVar):
"""A literal event var."""
_var_value: EventSpec = dataclasses.field(default=None) # type: ignore
@ -1271,35 +1276,6 @@ class LiteralEventVar(CachedVarOperation, LiteralVar, EventVar):
"""
return hash((self.__class__.__name__, self._js_expr))
@cached_property_no_lock
def _cached_var_name(self) -> str:
"""The name of the var.
Returns:
The name of the var.
"""
return str(
FunctionStringVar("Event").call(
# event handler name
".".join(
filter(
None,
format.get_event_handler_parts(self._var_value.handler),
)
),
# event handler args
{str(name): value for name, value in self._var_value.args},
# event actions
self._var_value.event_actions,
# client handler name
*(
[self._var_value.client_handler_name]
if self._var_value.client_handler_name
else []
),
)
)
@classmethod
def create(
cls,
@ -1320,6 +1296,22 @@ class LiteralEventVar(CachedVarOperation, LiteralVar, EventVar):
_var_type=EventSpec,
_var_data=_var_data,
_var_value=value,
_func=FunctionStringVar("Event"),
_args=(
# event handler name
".".join(
filter(
None,
format.get_event_handler_parts(value.handler),
)
),
# event handler args
{str(name): value for name, value in value.args},
# event actions
value.event_actions,
# client handler name
*([value.client_handler_name] if value.client_handler_name else []),
),
)
@ -1332,7 +1324,10 @@ class EventChainVar(FunctionVar):
frozen=True,
**{"slots": True} if sys.version_info >= (3, 10) else {},
)
class LiteralEventChainVar(CachedVarOperation, LiteralVar, EventChainVar):
# Note: LiteralVar is second in the inheritance list allowing it act like a
# CachedVarOperation (ArgsFunctionOperation) and get the _js_expr from the
# _cached_var_name property.
class LiteralEventChainVar(ArgsFunctionOperation, LiteralVar, EventChainVar):
"""A literal event chain var."""
_var_value: EventChain = dataclasses.field(default=None) # type: ignore
@ -1345,41 +1340,6 @@ class LiteralEventChainVar(CachedVarOperation, LiteralVar, EventChainVar):
"""
return hash((self.__class__.__name__, self._js_expr))
@cached_property_no_lock
def _cached_var_name(self) -> str:
"""The name of the var.
Returns:
The name of the var.
"""
sig = inspect.signature(self._var_value.args_spec) # type: ignore
if sig.parameters:
arg_def = tuple((f"_{p}" for p in sig.parameters))
arg_def_expr = LiteralVar.create([Var(_js_expr=arg) for arg in arg_def])
else:
# add a default argument for addEvents if none were specified in value.args_spec
# used to trigger the preventDefault() on the event.
arg_def = ("...args",)
arg_def_expr = Var(_js_expr="args")
if self._var_value.invocation is None:
invocation = FunctionStringVar.create("addEvents")
else:
invocation = self._var_value.invocation
return str(
ArgsFunctionOperation.create(
arg_def,
invocation.call(
LiteralVar.create(
[LiteralVar.create(event) for event in self._var_value.events]
),
arg_def_expr,
self._var_value.event_actions,
),
)
)
@classmethod
def create(
cls,
@ -1395,10 +1355,31 @@ class LiteralEventChainVar(CachedVarOperation, LiteralVar, EventChainVar):
Returns:
The created LiteralEventChainVar instance.
"""
sig = inspect.signature(value.args_spec) # type: ignore
if sig.parameters:
arg_def = tuple((f"_{p}" for p in sig.parameters))
arg_def_expr = LiteralVar.create([Var(_js_expr=arg) for arg in arg_def])
else:
# add a default argument for addEvents if none were specified in value.args_spec
# used to trigger the preventDefault() on the event.
arg_def = ("...args",)
arg_def_expr = Var(_js_expr="args")
if value.invocation is None:
invocation = FunctionStringVar.create("addEvents")
else:
invocation = value.invocation
return cls(
_js_expr="",
_var_type=EventChain,
_var_data=_var_data,
_args_names=arg_def,
_return_expr=invocation.call(
LiteralVar.create([LiteralVar.create(event) for event in value.events]),
arg_def_expr,
value.event_actions,
),
_var_value=value,
)
@ -1437,6 +1418,11 @@ EventType = Union[IndividualEventType[G], List[IndividualEventType[G]]]
P = ParamSpec("P")
T = TypeVar("T")
V = TypeVar("V")
V2 = TypeVar("V2")
V3 = TypeVar("V3")
V4 = TypeVar("V4")
V5 = TypeVar("V5")
if sys.version_info >= (3, 10):
from typing import Concatenate
@ -1452,7 +1438,55 @@ if sys.version_info >= (3, 10):
"""
self.func = func
def __get__(self, instance, owner) -> Callable[P, T]:
@overload
def __get__(
self: EventCallback[[V], T], instance: None, owner
) -> Callable[[Union[Var[V], V]], EventSpec]: ...
@overload
def __get__(
self: EventCallback[[V, V2], T], instance: None, owner
) -> Callable[[Union[Var[V], V], Union[Var[V2], V2]], EventSpec]: ...
@overload
def __get__(
self: EventCallback[[V, V2, V3], T], instance: None, owner
) -> Callable[
[Union[Var[V], V], Union[Var[V2], V2], Union[Var[V3], V3]],
EventSpec,
]: ...
@overload
def __get__(
self: EventCallback[[V, V2, V3, V4], T], instance: None, owner
) -> Callable[
[
Union[Var[V], V],
Union[Var[V2], V2],
Union[Var[V3], V3],
Union[Var[V4], V4],
],
EventSpec,
]: ...
@overload
def __get__(
self: EventCallback[[V, V2, V3, V4, V5], T], instance: None, owner
) -> Callable[
[
Union[Var[V], V],
Union[Var[V2], V2],
Union[Var[V3], V3],
Union[Var[V4], V4],
Union[Var[V5], V5],
],
EventSpec,
]: ...
@overload
def __get__(self, instance, owner) -> Callable[P, T]: ...
def __get__(self, instance, owner) -> Callable:
"""Get the function with the instance bound to it.
Args:

View File

@ -4,7 +4,6 @@ from __future__ import annotations
import atexit
import os
import webbrowser
from pathlib import Path
from typing import List, Optional
@ -586,18 +585,6 @@ def deploy(
)
@cli.command()
def demo(
frontend_port: str = typer.Option(
"3001", help="Specify a different frontend port."
),
backend_port: str = typer.Option("8001", help="Specify a different backend port."),
):
"""Run the demo app."""
# Open the demo app in a terminal.
webbrowser.open("https://demo.reflex.run")
cli.add_typer(db_cli, name="db", help="Subcommands for managing the database schema.")
cli.add_typer(script_cli, name="script", help="Subcommands running helper scripts.")
cli.add_typer(

View File

@ -2566,9 +2566,9 @@ class StateManager(Base, ABC):
The state manager (either disk, memory or redis).
"""
config = get_config()
if config.state_manager_mode == constants.StateManagerMode.DISK:
return StateManagerMemory(state=state)
if config.state_manager_mode == constants.StateManagerMode.MEMORY:
return StateManagerMemory(state=state)
if config.state_manager_mode == constants.StateManagerMode.DISK:
return StateManagerDisk(state=state)
if config.state_manager_mode == constants.StateManagerMode.REDIS:
redis = prerequisites.get_redis()

View File

@ -2,11 +2,18 @@ from typing import List
import pytest
from reflex import event
from reflex.event import Event, EventHandler, EventSpec, call_event_handler, fix_events
from reflex.event import (
Event,
EventChain,
EventHandler,
EventSpec,
call_event_handler,
event,
fix_events,
)
from reflex.state import BaseState
from reflex.utils import format
from reflex.vars.base import LiteralVar, Var
from reflex.vars.base import Field, LiteralVar, Var, field
def make_var(value) -> Var:
@ -388,3 +395,28 @@ def test_event_actions_on_state():
assert sp_handler.event_actions == {"stopPropagation": True}
# should NOT affect other references to the handler
assert not handler.event_actions
def test_event_var_data():
class S(BaseState):
x: Field[int] = field(0)
@event
def s(self, value: int):
pass
# Handler doesn't have any _var_data because it's just a str
handler_var = Var.create(S.s)
assert handler_var._get_all_var_data() is None
# Ensure spec carries _var_data
spec_var = Var.create(S.s(S.x))
assert spec_var._get_all_var_data() == S.x._get_all_var_data()
# Needed to instantiate the EventChain
def _args_spec(value: Var[int]) -> tuple[Var[int]]:
return (value,)
# Ensure chain carries _var_data
chain_var = Var.create(EventChain(events=[S.s(S.x)], args_spec=_args_spec))
assert chain_var._get_all_var_data() == S.x._get_all_var_data()