This commit is contained in:
Khaleel Al-Adhami 2024-10-16 12:27:05 -07:00
commit 0111293147
4 changed files with 142 additions and 89 deletions

View File

@ -22,6 +22,7 @@ from typing import (
TypeVar, TypeVar,
Union, Union,
get_type_hints, get_type_hints,
overload,
) )
from typing_extensions import ParamSpec, get_args, get_origin from typing_extensions import ParamSpec, get_args, get_origin
@ -32,14 +33,17 @@ from reflex.utils.exceptions import EventFnArgMismatch, EventHandlerArgMismatch
from reflex.utils.types import ArgsSpec, GenericType from reflex.utils.types import ArgsSpec, GenericType
from reflex.vars import VarData from reflex.vars import VarData
from reflex.vars.base import ( from reflex.vars.base import (
CachedVarOperation,
LiteralNoneVar, LiteralNoneVar,
LiteralVar, LiteralVar,
ToOperation, ToOperation,
Var, Var,
cached_property_no_lock,
) )
from reflex.vars.function import ArgsFunctionOperation, FunctionStringVar, FunctionVar from reflex.vars.function import (
ArgsFunctionOperation,
FunctionStringVar,
FunctionVar,
VarOperationCall,
)
from reflex.vars.object import ObjectVar from reflex.vars.object import ObjectVar
try: try:
@ -1041,7 +1045,8 @@ def resolve_annotation(annotations: dict[str, Any], arg_name: str):
deprecation_version="0.6.3", deprecation_version="0.6.3",
removal_version="0.7.0", removal_version="0.7.0",
) )
return JavascriptInputEvent # Allow arbitrary attribute access two levels deep until removed.
return Dict[str, dict]
return annotation return annotation
@ -1258,7 +1263,7 @@ class EventVar(ObjectVar):
frozen=True, frozen=True,
**{"slots": True} if sys.version_info >= (3, 10) else {}, **{"slots": True} if sys.version_info >= (3, 10) else {},
) )
class LiteralEventVar(CachedVarOperation, LiteralVar, EventVar): class LiteralEventVar(VarOperationCall, LiteralVar, EventVar):
"""A literal event var.""" """A literal event var."""
_var_value: EventSpec = dataclasses.field(default=None) # type: ignore _var_value: EventSpec = dataclasses.field(default=None) # type: ignore
@ -1271,35 +1276,6 @@ class LiteralEventVar(CachedVarOperation, LiteralVar, EventVar):
""" """
return hash((self.__class__.__name__, self._js_expr)) return hash((self.__class__.__name__, self._js_expr))
@cached_property_no_lock
def _cached_var_name(self) -> str:
"""The name of the var.
Returns:
The name of the var.
"""
return str(
FunctionStringVar("Event").call(
# event handler name
".".join(
filter(
None,
format.get_event_handler_parts(self._var_value.handler),
)
),
# event handler args
{str(name): value for name, value in self._var_value.args},
# event actions
self._var_value.event_actions,
# client handler name
*(
[self._var_value.client_handler_name]
if self._var_value.client_handler_name
else []
),
)
)
@classmethod @classmethod
def create( def create(
cls, cls,
@ -1320,6 +1296,22 @@ class LiteralEventVar(CachedVarOperation, LiteralVar, EventVar):
_var_type=EventSpec, _var_type=EventSpec,
_var_data=_var_data, _var_data=_var_data,
_var_value=value, _var_value=value,
_func=FunctionStringVar("Event"),
_args=(
# event handler name
".".join(
filter(
None,
format.get_event_handler_parts(value.handler),
)
),
# event handler args
{str(name): value for name, value in value.args},
# event actions
value.event_actions,
# client handler name
*([value.client_handler_name] if value.client_handler_name else []),
),
) )
@ -1332,7 +1324,10 @@ class EventChainVar(FunctionVar):
frozen=True, frozen=True,
**{"slots": True} if sys.version_info >= (3, 10) else {}, **{"slots": True} if sys.version_info >= (3, 10) else {},
) )
class LiteralEventChainVar(CachedVarOperation, LiteralVar, EventChainVar): # Note: LiteralVar is second in the inheritance list allowing it act like a
# CachedVarOperation (ArgsFunctionOperation) and get the _js_expr from the
# _cached_var_name property.
class LiteralEventChainVar(ArgsFunctionOperation, LiteralVar, EventChainVar):
"""A literal event chain var.""" """A literal event chain var."""
_var_value: EventChain = dataclasses.field(default=None) # type: ignore _var_value: EventChain = dataclasses.field(default=None) # type: ignore
@ -1345,41 +1340,6 @@ class LiteralEventChainVar(CachedVarOperation, LiteralVar, EventChainVar):
""" """
return hash((self.__class__.__name__, self._js_expr)) return hash((self.__class__.__name__, self._js_expr))
@cached_property_no_lock
def _cached_var_name(self) -> str:
"""The name of the var.
Returns:
The name of the var.
"""
sig = inspect.signature(self._var_value.args_spec) # type: ignore
if sig.parameters:
arg_def = tuple((f"_{p}" for p in sig.parameters))
arg_def_expr = LiteralVar.create([Var(_js_expr=arg) for arg in arg_def])
else:
# add a default argument for addEvents if none were specified in value.args_spec
# used to trigger the preventDefault() on the event.
arg_def = ("...args",)
arg_def_expr = Var(_js_expr="args")
if self._var_value.invocation is None:
invocation = FunctionStringVar.create("addEvents")
else:
invocation = self._var_value.invocation
return str(
ArgsFunctionOperation.create(
arg_def,
invocation.call(
LiteralVar.create(
[LiteralVar.create(event) for event in self._var_value.events]
),
arg_def_expr,
self._var_value.event_actions,
),
)
)
@classmethod @classmethod
def create( def create(
cls, cls,
@ -1395,10 +1355,31 @@ class LiteralEventChainVar(CachedVarOperation, LiteralVar, EventChainVar):
Returns: Returns:
The created LiteralEventChainVar instance. The created LiteralEventChainVar instance.
""" """
sig = inspect.signature(value.args_spec) # type: ignore
if sig.parameters:
arg_def = tuple((f"_{p}" for p in sig.parameters))
arg_def_expr = LiteralVar.create([Var(_js_expr=arg) for arg in arg_def])
else:
# add a default argument for addEvents if none were specified in value.args_spec
# used to trigger the preventDefault() on the event.
arg_def = ("...args",)
arg_def_expr = Var(_js_expr="args")
if value.invocation is None:
invocation = FunctionStringVar.create("addEvents")
else:
invocation = value.invocation
return cls( return cls(
_js_expr="", _js_expr="",
_var_type=EventChain, _var_type=EventChain,
_var_data=_var_data, _var_data=_var_data,
_args_names=arg_def,
_return_expr=invocation.call(
LiteralVar.create([LiteralVar.create(event) for event in value.events]),
arg_def_expr,
value.event_actions,
),
_var_value=value, _var_value=value,
) )
@ -1437,6 +1418,11 @@ EventType = Union[IndividualEventType[G], List[IndividualEventType[G]]]
P = ParamSpec("P") P = ParamSpec("P")
T = TypeVar("T") T = TypeVar("T")
V = TypeVar("V")
V2 = TypeVar("V2")
V3 = TypeVar("V3")
V4 = TypeVar("V4")
V5 = TypeVar("V5")
if sys.version_info >= (3, 10): if sys.version_info >= (3, 10):
from typing import Concatenate from typing import Concatenate
@ -1452,7 +1438,55 @@ if sys.version_info >= (3, 10):
""" """
self.func = func self.func = func
def __get__(self, instance, owner) -> Callable[P, T]: @overload
def __get__(
self: EventCallback[[V], T], instance: None, owner
) -> Callable[[Union[Var[V], V]], EventSpec]: ...
@overload
def __get__(
self: EventCallback[[V, V2], T], instance: None, owner
) -> Callable[[Union[Var[V], V], Union[Var[V2], V2]], EventSpec]: ...
@overload
def __get__(
self: EventCallback[[V, V2, V3], T], instance: None, owner
) -> Callable[
[Union[Var[V], V], Union[Var[V2], V2], Union[Var[V3], V3]],
EventSpec,
]: ...
@overload
def __get__(
self: EventCallback[[V, V2, V3, V4], T], instance: None, owner
) -> Callable[
[
Union[Var[V], V],
Union[Var[V2], V2],
Union[Var[V3], V3],
Union[Var[V4], V4],
],
EventSpec,
]: ...
@overload
def __get__(
self: EventCallback[[V, V2, V3, V4, V5], T], instance: None, owner
) -> Callable[
[
Union[Var[V], V],
Union[Var[V2], V2],
Union[Var[V3], V3],
Union[Var[V4], V4],
Union[Var[V5], V5],
],
EventSpec,
]: ...
@overload
def __get__(self, instance, owner) -> Callable[P, T]: ...
def __get__(self, instance, owner) -> Callable:
"""Get the function with the instance bound to it. """Get the function with the instance bound to it.
Args: Args:

View File

@ -4,7 +4,6 @@ from __future__ import annotations
import atexit import atexit
import os import os
import webbrowser
from pathlib import Path from pathlib import Path
from typing import List, Optional from typing import List, Optional
@ -586,18 +585,6 @@ def deploy(
) )
@cli.command()
def demo(
frontend_port: str = typer.Option(
"3001", help="Specify a different frontend port."
),
backend_port: str = typer.Option("8001", help="Specify a different backend port."),
):
"""Run the demo app."""
# Open the demo app in a terminal.
webbrowser.open("https://demo.reflex.run")
cli.add_typer(db_cli, name="db", help="Subcommands for managing the database schema.") cli.add_typer(db_cli, name="db", help="Subcommands for managing the database schema.")
cli.add_typer(script_cli, name="script", help="Subcommands running helper scripts.") cli.add_typer(script_cli, name="script", help="Subcommands running helper scripts.")
cli.add_typer( cli.add_typer(

View File

@ -2566,9 +2566,9 @@ class StateManager(Base, ABC):
The state manager (either disk, memory or redis). The state manager (either disk, memory or redis).
""" """
config = get_config() config = get_config()
if config.state_manager_mode == constants.StateManagerMode.DISK:
return StateManagerMemory(state=state)
if config.state_manager_mode == constants.StateManagerMode.MEMORY: if config.state_manager_mode == constants.StateManagerMode.MEMORY:
return StateManagerMemory(state=state)
if config.state_manager_mode == constants.StateManagerMode.DISK:
return StateManagerDisk(state=state) return StateManagerDisk(state=state)
if config.state_manager_mode == constants.StateManagerMode.REDIS: if config.state_manager_mode == constants.StateManagerMode.REDIS:
redis = prerequisites.get_redis() redis = prerequisites.get_redis()

View File

@ -2,11 +2,18 @@ from typing import List
import pytest import pytest
from reflex import event from reflex.event import (
from reflex.event import Event, EventHandler, EventSpec, call_event_handler, fix_events Event,
EventChain,
EventHandler,
EventSpec,
call_event_handler,
event,
fix_events,
)
from reflex.state import BaseState from reflex.state import BaseState
from reflex.utils import format from reflex.utils import format
from reflex.vars.base import LiteralVar, Var from reflex.vars.base import Field, LiteralVar, Var, field
def make_var(value) -> Var: def make_var(value) -> Var:
@ -388,3 +395,28 @@ def test_event_actions_on_state():
assert sp_handler.event_actions == {"stopPropagation": True} assert sp_handler.event_actions == {"stopPropagation": True}
# should NOT affect other references to the handler # should NOT affect other references to the handler
assert not handler.event_actions assert not handler.event_actions
def test_event_var_data():
class S(BaseState):
x: Field[int] = field(0)
@event
def s(self, value: int):
pass
# Handler doesn't have any _var_data because it's just a str
handler_var = Var.create(S.s)
assert handler_var._get_all_var_data() is None
# Ensure spec carries _var_data
spec_var = Var.create(S.s(S.x))
assert spec_var._get_all_var_data() == S.x._get_all_var_data()
# Needed to instantiate the EventChain
def _args_spec(value: Var[int]) -> tuple[Var[int]]:
return (value,)
# Ensure chain carries _var_data
chain_var = Var.create(EventChain(events=[S.s(S.x)], args_spec=_args_spec))
assert chain_var._get_all_var_data() == S.x._get_all_var_data()