Compare commits

...

6 Commits

Author SHA1 Message Date
Lendemor
76dce8bba3 Merge branch 'main' into lendemor/builtins_states 2024-10-25 12:49:36 +02:00
Lendemor
7320425285 Merge branch 'main' into lendemor/builtins_states 2024-10-23 17:28:15 +02:00
Lendemor
703a6da3d7 Merge branch 'main' into lendemor/builtins_states 2024-10-22 22:27:10 +02:00
Lendemor
bec66b894e fix tests and hardcoded constants 2024-10-22 21:05:10 +02:00
Lendemor
9f3752ea74 Merge branch 'main' into lendemor/builtins_states 2024-10-22 20:53:52 +02:00
Lendemor
4e8169e1ed move builtins states to their own file 2024-10-22 20:53:10 +02:00
23 changed files with 255 additions and 232 deletions

View File

@ -320,6 +320,7 @@ _MAPPING: dict = {
"upload_files",
"window_alert",
],
"istate.builtins": ["ComponentState", "State"],
"istate.storage": [
"Cookie",
"LocalStorage",
@ -329,8 +330,6 @@ _MAPPING: dict = {
"model": ["session", "Model"],
"state": [
"var",
"ComponentState",
"State",
"dynamic",
],
"style": ["Style", "toggle_color_mode"],

View File

@ -174,6 +174,8 @@ from .event import stop_propagation as stop_propagation
from .event import upload_files as upload_files
from .event import window_alert as window_alert
from .experimental import _x as _x
from .istate.builtins import ComponentState as ComponentState
from .istate.builtins import State as State
from .istate.storage import Cookie as Cookie
from .istate.storage import LocalStorage as LocalStorage
from .istate.storage import SessionStorage as SessionStorage
@ -182,8 +184,6 @@ from .middleware import middleware as middleware
from .model import Model as Model
from .model import session as session
from .page import page as page
from .state import ComponentState as ComponentState
from .state import State as State
from .state import dynamic as dynamic
from .state import var as var
from .style import Style as Style

View File

@ -66,6 +66,7 @@ from reflex.components.core.upload import Upload, get_upload_dir
from reflex.components.radix import themes
from reflex.config import environment, get_config
from reflex.event import Event, EventHandler, EventSpec, window_alert
from reflex.istate.builtins import State
from reflex.model import Model, get_db_status
from reflex.page import (
DECORATED_PAGES,
@ -78,7 +79,6 @@ from reflex.route import (
from reflex.state import (
BaseState,
RouterData,
State,
StateManager,
StateUpdate,
_substate_key,

View File

@ -8,7 +8,7 @@ from reflex.compiler.compiler import _compile_component
from reflex.components.component import Component
from reflex.components.el import div, p
from reflex.event import EventHandler
from reflex.state import FrontendEventExceptionState
from reflex.istate.builtins import FrontendEventExceptionState
from reflex.vars.base import Var

View File

@ -21,7 +21,7 @@ from typing import (
Union,
)
import reflex.state
import reflex.istate.builtins
from reflex.base import Base
from reflex.compiler.templates import STATEFUL_COMPONENT
from reflex.components.core.breakpoints import Breakpoints
@ -225,7 +225,7 @@ class Component(BaseComponent, ABC):
_memoization_mode: MemoizationMode = MemoizationMode()
# State class associated with this component instance
State: Optional[Type[reflex.state.State]] = None
State: Optional[Type[reflex.istate.builtins.State]] = None
def add_imports(self) -> ImportDict | list[ImportDict]:
"""Add imports for the component.

View File

@ -9,7 +9,7 @@ from reflex.components.base.fragment import Fragment
from reflex.components.component import Component
from reflex.components.tags import IterTag
from reflex.constants import MemoizationMode
from reflex.state import ComponentState
from reflex.istate.builtins import ComponentState
from reflex.vars.base import LiteralVar, Var

View File

@ -62,16 +62,20 @@ class CompileVars(SimpleNamespace):
# The name of the function for converting a dict to an event.
TO_EVENT = "Event"
# The name of the internal on_load event.
ON_LOAD_INTERNAL = "reflex___state____on_load_internal_state.on_load_internal"
ON_LOAD_INTERNAL = (
"reflex___istate___builtins____on_load_internal_state.on_load_internal"
)
# The name of the internal event to update generic state vars.
UPDATE_VARS_INTERNAL = (
"reflex___state____update_vars_internal_state.update_vars_internal"
"reflex___istate___builtins____update_vars_internal_state.update_vars_internal"
)
# The name of the frontend event exception state
FRONTEND_EXCEPTION_STATE = "reflex___state____frontend_event_exception_state"
FRONTEND_EXCEPTION_STATE = (
"reflex___istate___builtins____frontend_event_exception_state"
)
# The full name of the frontend exception state
FRONTEND_EXCEPTION_STATE_FULL = (
f"reflex___state____state.{FRONTEND_EXCEPTION_STATE}"
f"reflex___istate___builtins____state.{FRONTEND_EXCEPTION_STATE}"
)

View File

@ -14,7 +14,7 @@ from reflex.components.radix.themes.layout.container import Container
from reflex.components.radix.themes.layout.stack import HStack
from reflex.event import call_script
from reflex.experimental import hooks
from reflex.state import ComponentState
from reflex.istate.builtins import ComponentState
from reflex.style import Style
from reflex.vars.base import Var

View File

@ -11,7 +11,7 @@ from reflex.components.component import Component, ComponentNamespace, Memoizati
from reflex.components.radix.primitives.drawer import DrawerRoot
from reflex.components.radix.themes.layout.box import Box
from reflex.event import EventType
from reflex.state import ComponentState
from reflex.istate.builtins import ComponentState
from reflex.style import Style
from reflex.vars.base import Var

207
reflex/istate/builtins.py Normal file
View File

@ -0,0 +1,207 @@
"""The built-in states used by reflex apps."""
from __future__ import annotations
from collections import defaultdict
from typing import TYPE_CHECKING, Any, ClassVar, Type
import reflex.istate.dynamic
from reflex import constants, event
from reflex.event import Event, EventSpec, fix_events
from reflex.state import BaseState
from reflex.utils import prerequisites
if TYPE_CHECKING:
from reflex.components.component import Component
class State(BaseState):
"""The app Base State."""
# The hydrated bool.
is_hydrated: bool = False
class FrontendEventExceptionState(State):
"""Substate for handling frontend exceptions."""
@event
def handle_frontend_exception(self, stack: str, component_stack: str) -> None:
"""Handle frontend exceptions.
If a frontend exception handler is provided, it will be called.
Otherwise, the default frontend exception handler will be called.
Args:
stack: The stack trace of the exception.
component_stack: The stack trace of the component where the exception occurred.
"""
app_instance = getattr(prerequisites.get_app(), constants.CompileVars.APP)
app_instance.frontend_exception_handler(Exception(stack))
class UpdateVarsInternalState(State):
"""Substate for handling internal state var updates."""
async def update_vars_internal(self, vars: dict[str, Any]) -> None:
"""Apply updates to fully qualified state vars.
The keys in `vars` should be in the form of `{state.get_full_name()}.{var_name}`,
and each value will be set on the appropriate substate instance.
This function is primarily used to apply cookie and local storage
updates from the frontend to the appropriate substate.
Args:
vars: The fully qualified vars and values to update.
"""
for var, value in vars.items():
state_name, _, var_name = var.rpartition(".")
var_state_cls = State.get_class_substate(state_name)
var_state = await self.get_state(var_state_cls)
setattr(var_state, var_name, value)
class OnLoadInternalState(State):
"""Substate for handling on_load event enumeration.
This is a separate substate to avoid deserializing the entire state tree for every page navigation.
"""
def on_load_internal(self) -> list[Event | EventSpec] | None:
"""Queue on_load handlers for the current page.
Returns:
The list of events to queue for on load handling.
"""
# Do not app._compile()! It should be already compiled by now.
app = getattr(prerequisites.get_app(), constants.CompileVars.APP)
load_events = app.get_load_events(self.router.page.path)
if not load_events:
self.is_hydrated = True
return # Fast path for navigation with no on_load events defined.
self.is_hydrated = False
return [
*fix_events(
load_events,
self.router.session.client_token,
router_data=self.router_data,
),
State.set_is_hydrated(True), # type: ignore
]
class ComponentState(State, mixin=True):
"""Base class to allow for the creation of a state instance per component.
This allows for the bundling of UI and state logic into a single class,
where each instance has a separate instance of the state.
Subclass this class and define vars and event handlers in the traditional way.
Then define a `get_component` method that returns the UI for the component instance.
See the full [docs](https://reflex.dev/docs/substates/component-state/) for more.
Basic example:
```python
# Subclass ComponentState and define vars and event handlers.
class Counter(rx.ComponentState):
# Define vars that change.
count: int = 0
# Define event handlers.
def increment(self):
self.count += 1
def decrement(self):
self.count -= 1
@classmethod
def get_component(cls, **props):
# Access the state vars and event handlers using `cls`.
return rx.hstack(
rx.button("Decrement", on_click=cls.decrement),
rx.text(cls.count),
rx.button("Increment", on_click=cls.increment),
**props,
)
counter = Counter.create()
```
"""
# The number of components created from this class.
_per_component_state_instance_count: ClassVar[int] = 0
@classmethod
def __init_subclass__(cls, mixin: bool = True, **kwargs):
"""Overwrite mixin default to True.
Args:
mixin: Whether the subclass is a mixin and should not be initialized.
**kwargs: The kwargs to pass to the pydantic init_subclass method.
"""
super().__init_subclass__(mixin=mixin, **kwargs)
@classmethod
def get_component(cls, *children, **props) -> "Component":
"""Get the component instance.
Args:
children: The children of the component.
props: The props of the component.
Raises:
NotImplementedError: if the subclass does not override this method.
"""
raise NotImplementedError(
f"{cls.__name__} must implement get_component to return the component instance."
)
@classmethod
def create(cls, *children, **props) -> "Component":
"""Create a new instance of the Component.
Args:
children: The children of the component.
props: The props of the component.
Returns:
A new instance of the Component with an independent copy of the State.
"""
cls._per_component_state_instance_count += 1
state_cls_name = f"{cls.__name__}_n{cls._per_component_state_instance_count}"
component_state = type(
state_cls_name,
(cls, State),
{"__module__": reflex.istate.dynamic.__name__},
mixin=False,
)
# Save a reference to the dynamic state for pickle/unpickle.
setattr(reflex.istate.dynamic, state_cls_name, component_state)
component = component_state.get_component(*children, **props)
component.State = component_state
return component
def reload_state_module(
module: str,
state: Type[BaseState] = State,
) -> None:
"""Reset rx.State subclasses to avoid conflict when reloading.
Args:
module: The module to reload.
state: Recursive argument for the state class to reload.
"""
for subclass in tuple(state.class_subclasses):
reload_state_module(module=module, state=subclass)
if subclass.__module__ == module and module is not None:
state.class_subclasses.remove(subclass)
state._always_dirty_substates.discard(subclass.get_name())
state._computed_var_dependencies = defaultdict(set)
state._substate_var_dependencies = defaultdict(set)
state._init_var_dependency_dicts()
state.get_class_substate.cache_clear()

View File

@ -40,7 +40,6 @@ from typing import (
from sqlalchemy.orm import DeclarativeBase
from typing_extensions import Self
from reflex import event
from reflex.config import get_config
from reflex.istate.data import RouterData
from reflex.istate.storage import (
@ -94,13 +93,12 @@ from reflex.utils.serializers import serializer
from reflex.utils.types import get_origin, override
from reflex.vars import VarData
if TYPE_CHECKING:
from reflex.components.component import Component
Delta = Dict[str, Any]
var = computed_var
if TYPE_CHECKING:
from reflex.components.component import Component
# If the state is this large, it's considered a performance issue.
TOO_LARGE_SERIALIZED_STATE = 100 * 1024 # 100kb
@ -899,11 +897,13 @@ class BaseState(Base, ABC, extra=pydantic.Extra.allow):
if len(path) == 0:
return cls
print("get_name: ", cls.get_name())
if path[0] == cls.get_name():
if len(path) == 1:
return cls
path = path[1:]
for substate in cls.get_substates():
print("substate get_name: ", substate.get_name())
if path[0] == substate.get_name():
return substate.get_class_substate(path[1:])
raise ValueError(f"Invalid path: {path}")
@ -2112,13 +2112,6 @@ class BaseState(Base, ABC, extra=pydantic.Extra.allow):
return state
class State(BaseState):
"""The app Base State."""
# The hydrated bool.
is_hydrated: bool = False
T = TypeVar("T", bound=BaseState)
@ -2164,169 +2157,6 @@ def dynamic(func: Callable[[T], Component]):
return wrapper
class FrontendEventExceptionState(State):
"""Substate for handling frontend exceptions."""
@event
def handle_frontend_exception(self, stack: str, component_stack: str) -> None:
"""Handle frontend exceptions.
If a frontend exception handler is provided, it will be called.
Otherwise, the default frontend exception handler will be called.
Args:
stack: The stack trace of the exception.
component_stack: The stack trace of the component where the exception occurred.
"""
app_instance = getattr(prerequisites.get_app(), constants.CompileVars.APP)
app_instance.frontend_exception_handler(Exception(stack))
class UpdateVarsInternalState(State):
"""Substate for handling internal state var updates."""
async def update_vars_internal(self, vars: dict[str, Any]) -> None:
"""Apply updates to fully qualified state vars.
The keys in `vars` should be in the form of `{state.get_full_name()}.{var_name}`,
and each value will be set on the appropriate substate instance.
This function is primarily used to apply cookie and local storage
updates from the frontend to the appropriate substate.
Args:
vars: The fully qualified vars and values to update.
"""
for var, value in vars.items():
state_name, _, var_name = var.rpartition(".")
var_state_cls = State.get_class_substate(state_name)
var_state = await self.get_state(var_state_cls)
setattr(var_state, var_name, value)
class OnLoadInternalState(State):
"""Substate for handling on_load event enumeration.
This is a separate substate to avoid deserializing the entire state tree for every page navigation.
"""
def on_load_internal(self) -> list[Event | EventSpec] | None:
"""Queue on_load handlers for the current page.
Returns:
The list of events to queue for on load handling.
"""
# Do not app._compile()! It should be already compiled by now.
app = getattr(prerequisites.get_app(), constants.CompileVars.APP)
load_events = app.get_load_events(self.router.page.path)
if not load_events:
self.is_hydrated = True
return # Fast path for navigation with no on_load events defined.
self.is_hydrated = False
return [
*fix_events(
load_events,
self.router.session.client_token,
router_data=self.router_data,
),
State.set_is_hydrated(True), # type: ignore
]
class ComponentState(State, mixin=True):
"""Base class to allow for the creation of a state instance per component.
This allows for the bundling of UI and state logic into a single class,
where each instance has a separate instance of the state.
Subclass this class and define vars and event handlers in the traditional way.
Then define a `get_component` method that returns the UI for the component instance.
See the full [docs](https://reflex.dev/docs/substates/component-state/) for more.
Basic example:
```python
# Subclass ComponentState and define vars and event handlers.
class Counter(rx.ComponentState):
# Define vars that change.
count: int = 0
# Define event handlers.
def increment(self):
self.count += 1
def decrement(self):
self.count -= 1
@classmethod
def get_component(cls, **props):
# Access the state vars and event handlers using `cls`.
return rx.hstack(
rx.button("Decrement", on_click=cls.decrement),
rx.text(cls.count),
rx.button("Increment", on_click=cls.increment),
**props,
)
counter = Counter.create()
```
"""
# The number of components created from this class.
_per_component_state_instance_count: ClassVar[int] = 0
@classmethod
def __init_subclass__(cls, mixin: bool = True, **kwargs):
"""Overwrite mixin default to True.
Args:
mixin: Whether the subclass is a mixin and should not be initialized.
**kwargs: The kwargs to pass to the pydantic init_subclass method.
"""
super().__init_subclass__(mixin=mixin, **kwargs)
@classmethod
def get_component(cls, *children, **props) -> "Component":
"""Get the component instance.
Args:
children: The children of the component.
props: The props of the component.
Raises:
NotImplementedError: if the subclass does not override this method.
"""
raise NotImplementedError(
f"{cls.__name__} must implement get_component to return the component instance."
)
@classmethod
def create(cls, *children, **props) -> "Component":
"""Create a new instance of the Component.
Args:
children: The children of the component.
props: The props of the component.
Returns:
A new instance of the Component with an independent copy of the State.
"""
cls._per_component_state_instance_count += 1
state_cls_name = f"{cls.__name__}_n{cls._per_component_state_instance_count}"
component_state = type(
state_cls_name,
(cls, State),
{"__module__": reflex.istate.dynamic.__name__},
mixin=False,
)
# Save a reference to the dynamic state for pickle/unpickle.
setattr(reflex.istate.dynamic, state_cls_name, component_state)
component = component_state.get_component(*children, **props)
component.State = component_state
return component
class StateProxy(wrapt.ObjectProxy):
"""Proxy of a state instance to control mutability of vars for a background task.
@ -2408,7 +2238,7 @@ class StateProxy(wrapt.ObjectProxy):
super().__setattr__(
"__wrapped__",
await parent_state.get_state(
State.get_class_substate(self._self_substate_path)
BaseState.get_class_substate(self._self_substate_path)
),
)
return self
@ -3743,25 +3573,3 @@ def code_uses_state_contexts(javascript_code: str) -> bool:
True if the code attempts to access a member of StateContexts.
"""
return bool("useContext(StateContexts" in javascript_code)
def reload_state_module(
module: str,
state: Type[BaseState] = State,
) -> None:
"""Reset rx.State subclasses to avoid conflict when reloading.
Args:
module: The module to reload.
state: Recursive argument for the state class to reload.
"""
for subclass in tuple(state.class_subclasses):
reload_state_module(module=module, state=subclass)
if subclass.__module__ == module and module is not None:
state.class_subclasses.remove(subclass)
state._always_dirty_substates.discard(subclass.get_name())
state._computed_var_dependencies = defaultdict(set)
state._substate_var_dependencies = defaultdict(set)
state._init_var_dependency_dicts()
state.get_class_substate.cache_clear()

View File

@ -43,13 +43,13 @@ import reflex.utils.exec
import reflex.utils.format
import reflex.utils.prerequisites
import reflex.utils.processes
from reflex.istate.builtins import reload_state_module
from reflex.state import (
BaseState,
StateManager,
StateManagerDisk,
StateManagerMemory,
StateManagerRedis,
reload_state_module,
)
try:

View File

@ -429,7 +429,7 @@ def get_event_handler_parts(handler: EventHandler) -> tuple[str, str]:
# Get the function name
name = parts[-1]
from reflex.state import State
from reflex.istate.builtins import State
if state_full_name == "state" and name not in State.__dict__:
return ("", to_snake_case(handler.fn.__qualname__))

View File

@ -279,7 +279,7 @@ def get_app(reload: bool = False) -> ModuleType:
app = __import__(module, fromlist=(constants.CompileVars.APP,))
if reload:
from reflex.state import reload_state_module
from reflex.istate.builtins import reload_state_module
# Reset rx.State subclasses to avoid conflict when reloading.
reload_state_module(module=module)

View File

@ -5,7 +5,8 @@ from typing import Generator
import pytest
from selenium.webdriver.common.by import By
from reflex.state import State, _substate_key
from reflex.istate.builtins import State
from reflex.state import _substate_key
from reflex.testing import AppHarness
from . import utils

View File

@ -12,7 +12,8 @@ from reflex.components.core.foreach import (
)
from reflex.components.radix.themes.layout.box import box
from reflex.components.radix.themes.typography.text import text
from reflex.state import BaseState, ComponentState
from reflex.istate.builtins import ComponentState
from reflex.state import BaseState
from reflex.vars.base import Var
from reflex.vars.number import NumberVar
from reflex.vars.sequence import ArrayVar

View File

@ -1,7 +1,7 @@
import pytest
from reflex.components.core.html import Html
from reflex.state import State
from reflex.istate.builtins import State
def test_html_no_children():

View File

@ -10,7 +10,7 @@ from reflex.components.core.upload import (
get_upload_url,
)
from reflex.event import EventSpec
from reflex.state import State
from reflex.istate.builtins import State
from reflex.vars.base import LiteralVar, Var

View File

@ -1,7 +1,7 @@
import pytest
from reflex.event import Event
from reflex.state import State
from reflex.istate.builtins import State
def create_event(name):

View File

@ -3,8 +3,9 @@ from __future__ import annotations
import pytest
from reflex.app import App
from reflex.istate.builtins import State
from reflex.middleware.hydrate_middleware import HydrateMiddleware
from reflex.state import State, StateUpdate
from reflex.state import StateUpdate
class TestState(State):
@ -38,7 +39,7 @@ async def test_preprocess_no_events(hydrate_middleware, event1, mocker):
event1: An Event.
mocker: pytest mock object.
"""
mocker.patch("reflex.state.State.class_subclasses", {TestState})
mocker.patch("reflex.istate.builtins.State.class_subclasses", {TestState})
state = State()
update = await hydrate_middleware.preprocess(
app=App(state=State),

View File

@ -4,7 +4,8 @@ from pathlib import Path
from typing import ClassVar, List
import reflex as rx
from reflex.state import BaseState, State
from reflex.istate.builtins import State
from reflex.state import BaseState
class UploadState(BaseState):

View File

@ -34,13 +34,12 @@ from reflex.components.base.fragment import Fragment
from reflex.components.core.cond import Cond
from reflex.components.radix.themes.typography.text import Text
from reflex.event import Event
from reflex.istate.builtins import OnLoadInternalState, State
from reflex.middleware import HydrateMiddleware
from reflex.model import Model
from reflex.state import (
BaseState,
OnLoadInternalState,
RouterData,
State,
StateManagerDisk,
StateManagerMemory,
StateManagerRedis,
@ -760,7 +759,7 @@ async def test_upload_file(tmp_path, state, delta, token: str, mocker):
mocker: pytest mocker object.
"""
mocker.patch(
"reflex.state.State.class_subclasses",
"reflex.istate.builtins.State.class_subclasses",
{state if state is FileUploadState else FileStateBase1},
)
state._tmp_path = tmp_path

View File

@ -24,14 +24,13 @@ from reflex.base import Base
from reflex.components.sonner.toast import Toaster
from reflex.constants import CompileVars, RouteVar, SocketEvent
from reflex.event import Event, EventHandler
from reflex.istate.builtins import OnLoadInternalState, State
from reflex.state import (
BaseState,
ImmutableStateError,
LockExpiredError,
MutableProxy,
OnLoadInternalState,
RouterData,
State,
StateManager,
StateManagerDisk,
StateManagerMemory,
@ -2779,7 +2778,8 @@ async def test_preprocess(app_module_mock, token, test_state, expected, mocker):
mocker: pytest mock object.
"""
mocker.patch(
"reflex.state.State.class_subclasses", {test_state, OnLoadInternalState}
"reflex.istate.builtins.State.class_subclasses",
{test_state, OnLoadInternalState},
)
app = app_module_mock.app = App(
state=State, load_events={"index": [test_state.test_handler]}
@ -2825,7 +2825,8 @@ async def test_preprocess_multiple_load_events(app_module_mock, token, mocker):
mocker: pytest mock object.
"""
mocker.patch(
"reflex.state.State.class_subclasses", {OnLoadState, OnLoadInternalState}
"reflex.istate.builtins.State.class_subclasses",
{OnLoadState, OnLoadInternalState},
)
app = app_module_mock.app = App(
state=State,
@ -3231,7 +3232,8 @@ config = rx.Config(
with chdir(proj_root):
# reload config for each parameter to avoid stale values
reflex.config.get_config(reload=True)
from reflex.state import State, StateManager
from reflex.istate.builtins import State
from reflex.state import StateManager
state_manager = StateManager.create(state=State)
assert state_manager.lock_expiration == expected_values[0] # type: ignore