RED-1052/rx.State as Base State (#2146)

This commit is contained in:
Elijah Ahianyo 2023-11-29 17:43:33 +00:00 committed by GitHub
parent f8395b1fd6
commit e3ee98098a
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
49 changed files with 356 additions and 270 deletions

View File

@ -93,7 +93,7 @@ def BackgroundTask():
rx.button("Reset", on_click=State.reset_counter, id="reset"),
)
app = rx.App(state=State)
app = rx.App(state=rx.State)
app.add_page(index)
app.compile()

View File

@ -135,7 +135,7 @@ def CallScript():
yield rx.call_script("inline_counter = 0; external_counter = 0")
self.reset()
app = rx.App(state=CallScriptState)
app = rx.App(state=rx.State)
with open("assets/external.js", "w") as f:
f.write(external_scripts)

View File

@ -97,7 +97,7 @@ def ClientSide():
rx.box(ClientSideSubSubState.l1s, id="l1s"),
)
app = rx.App(state=ClientSideState)
app = rx.App(state=rx.State)
app.add_page(index)
app.add_page(index, route="/foo")
app.compile()
@ -263,7 +263,6 @@ async def test_client_side_state(
state_var_input.send_keys("c7")
input_value_input.send_keys("c7 value")
set_sub_state_button.click()
state_var_input.send_keys("l1")
input_value_input.send_keys("l1 value")
set_sub_state_button.click()
@ -276,7 +275,6 @@ async def test_client_side_state(
state_var_input.send_keys("l4")
input_value_input.send_keys("l4 value")
set_sub_state_button.click()
state_var_input.send_keys("c1s")
input_value_input.send_keys("c1s value")
set_sub_sub_state_button.click()
@ -285,28 +283,28 @@ async def test_client_side_state(
set_sub_sub_state_button.click()
exp_cookies = {
"client_side_state.client_side_sub_state.c1": {
"state.client_side_state.client_side_sub_state.c1": {
"domain": "localhost",
"httpOnly": False,
"name": "client_side_state.client_side_sub_state.c1",
"name": "state.client_side_state.client_side_sub_state.c1",
"path": "/",
"sameSite": "Lax",
"secure": False,
"value": "c1%20value",
},
"client_side_state.client_side_sub_state.c2": {
"state.client_side_state.client_side_sub_state.c2": {
"domain": "localhost",
"httpOnly": False,
"name": "client_side_state.client_side_sub_state.c2",
"name": "state.client_side_state.client_side_sub_state.c2",
"path": "/",
"sameSite": "Lax",
"secure": False,
"value": "c2%20value",
},
"client_side_state.client_side_sub_state.c4": {
"state.client_side_state.client_side_sub_state.c4": {
"domain": "localhost",
"httpOnly": False,
"name": "client_side_state.client_side_sub_state.c4",
"name": "state.client_side_state.client_side_sub_state.c4",
"path": "/",
"sameSite": "Strict",
"secure": False,
@ -321,19 +319,19 @@ async def test_client_side_state(
"secure": False,
"value": "c6%20value",
},
"client_side_state.client_side_sub_state.c7": {
"state.client_side_state.client_side_sub_state.c7": {
"domain": "localhost",
"httpOnly": False,
"name": "client_side_state.client_side_sub_state.c7",
"name": "state.client_side_state.client_side_sub_state.c7",
"path": "/",
"sameSite": "Lax",
"secure": False,
"value": "c7%20value",
},
"client_side_state.client_side_sub_state.client_side_sub_sub_state.c1s": {
"state.client_side_state.client_side_sub_state.client_side_sub_sub_state.c1s": {
"domain": "localhost",
"httpOnly": False,
"name": "client_side_state.client_side_sub_state.client_side_sub_sub_state.c1s",
"name": "state.client_side_state.client_side_sub_state.client_side_sub_sub_state.c1s",
"path": "/",
"sameSite": "Lax",
"secure": False,
@ -354,40 +352,45 @@ async def test_client_side_state(
input_value_input.send_keys("c3 value")
set_sub_state_button.click()
AppHarness._poll_for(
lambda: "client_side_state.client_side_sub_state.c3" in cookie_info_map(driver)
lambda: "state.client_side_state.client_side_sub_state.c3"
in cookie_info_map(driver)
)
c3_cookie = cookie_info_map(driver)["client_side_state.client_side_sub_state.c3"]
c3_cookie = cookie_info_map(driver)[
"state.client_side_state.client_side_sub_state.c3"
]
assert c3_cookie.pop("expiry") is not None
assert c3_cookie == {
"domain": "localhost",
"httpOnly": False,
"name": "client_side_state.client_side_sub_state.c3",
"name": "state.client_side_state.client_side_sub_state.c3",
"path": "/",
"sameSite": "Lax",
"secure": False,
"value": "c3%20value",
}
time.sleep(2) # wait for c3 to expire
assert "client_side_state.client_side_sub_state.c3" not in cookie_info_map(driver)
assert "state.client_side_state.client_side_sub_state.c3" not in cookie_info_map(
driver
)
local_storage_items = local_storage.items()
local_storage_items.pop("chakra-ui-color-mode", None)
assert (
local_storage_items.pop("client_side_state.client_side_sub_state.l1")
local_storage_items.pop("state.client_side_state.client_side_sub_state.l1")
== "l1 value"
)
assert (
local_storage_items.pop("client_side_state.client_side_sub_state.l2")
local_storage_items.pop("state.client_side_state.client_side_sub_state.l2")
== "l2 value"
)
assert local_storage_items.pop("l3") == "l3 value"
assert (
local_storage_items.pop("client_side_state.client_side_sub_state.l4")
local_storage_items.pop("state.client_side_state.client_side_sub_state.l4")
== "l4 value"
)
assert (
local_storage_items.pop(
"client_side_state.client_side_sub_state.client_side_sub_sub_state.l1s"
"state.client_side_state.client_side_sub_state.client_side_sub_sub_state.l1s"
)
== "l1s value"
)
@ -482,12 +485,15 @@ async def test_client_side_state(
# make sure c5 cookie shows up on the `/foo` route
AppHarness._poll_for(
lambda: "client_side_state.client_side_sub_state.c5" in cookie_info_map(driver)
lambda: "state.client_side_state.client_side_sub_state.c5"
in cookie_info_map(driver)
)
assert cookie_info_map(driver)["client_side_state.client_side_sub_state.c5"] == {
assert cookie_info_map(driver)[
"state.client_side_state.client_side_sub_state.c5"
] == {
"domain": "localhost",
"httpOnly": False,
"name": "client_side_state.client_side_sub_state.c5",
"name": "state.client_side_state.client_side_sub_state.c5",
"path": "/foo/",
"sameSite": "Lax",
"secure": False,

View File

@ -19,7 +19,7 @@ def ConnectionBanner():
def index():
return rx.text("Hello World")
app = rx.App(state=State)
app = rx.App(state=rx.State)
app.add_page(index)
app.compile()

View File

@ -56,7 +56,7 @@ def DynamicRoute():
def redirect_page():
return rx.fragment(rx.text("redirecting..."))
app = rx.App(state=DynamicState)
app = rx.App(state=rx.State)
app.add_page(index)
app.add_page(index, route="/page/[page_id]", on_load=DynamicState.on_load) # type: ignore
app.add_page(index, route="/static/x", on_load=DynamicState.on_load) # type: ignore
@ -143,10 +143,12 @@ def poll_for_order(
return await dynamic_route.get_state(token)
async def _check():
return (await _backend_state()).order == exp_order
return (await _backend_state()).substates[
"dynamic_state"
].order == exp_order
await AppHarness._poll_for_async(_check)
assert (await _backend_state()).order == exp_order
assert (await _backend_state()).substates["dynamic_state"].order == exp_order
return _poll_for_order

View File

@ -130,7 +130,7 @@ def TestEventAction():
on_click=EventActionState.on_click("outer"), # type: ignore
)
app = rx.App(state=EventActionState)
app = rx.App(state=rx.State)
app.add_page(index)
app.compile()
@ -211,10 +211,14 @@ def poll_for_order(
return await event_action.get_state(token)
async def _check():
return (await _backend_state()).order == exp_order
return (await _backend_state()).substates[
"event_action_state"
].order == exp_order
await AppHarness._poll_for_async(_check)
assert (await _backend_state()).order == exp_order
assert (await _backend_state()).substates[
"event_action_state"
].order == exp_order
return _poll_for_order

View File

@ -122,7 +122,7 @@ def EventChain():
time.sleep(0.5)
self.interim_value = "final"
app = rx.App(state=State)
app = rx.App(state=rx.State)
token_input = rx.input(
value=State.router.session.client_token, is_read_only=True, id="token"
@ -401,12 +401,12 @@ async def test_event_chain_click(
btn.click()
async def _has_all_events():
return len((await event_chain.get_state(token)).event_order) == len(
exp_event_order
)
return len(
(await event_chain.get_state(token)).substates["state"].event_order
) == len(exp_event_order)
await AppHarness._poll_for_async(_has_all_events)
event_order = (await event_chain.get_state(token)).event_order
event_order = (await event_chain.get_state(token)).substates["state"].event_order
assert event_order == exp_event_order
@ -453,12 +453,12 @@ async def test_event_chain_on_load(
token = assert_token(event_chain, driver)
async def _has_all_events():
return len((await event_chain.get_state(token)).event_order) == len(
exp_event_order
)
return len(
(await event_chain.get_state(token)).substates["state"].event_order
) == len(exp_event_order)
await AppHarness._poll_for_async(_has_all_events)
backend_state = await event_chain.get_state(token)
backend_state = (await event_chain.get_state(token)).substates["state"]
assert backend_state.event_order == exp_event_order
assert backend_state.is_hydrated is True
@ -529,12 +529,12 @@ async def test_event_chain_on_mount(
unmount_button.click()
async def _has_all_events():
return len((await event_chain.get_state(token)).event_order) == len(
exp_event_order
)
return len(
(await event_chain.get_state(token)).substates["state"].event_order
) == len(exp_event_order)
await AppHarness._poll_for_async(_has_all_events)
event_order = (await event_chain.get_state(token)).event_order
event_order = (await event_chain.get_state(token)).substates["state"].event_order
assert event_order == exp_event_order

View File

@ -22,7 +22,7 @@ def FormSubmit():
def form_submit(self, form_data: dict):
self.form_data = form_data
app = rx.App(state=FormState)
app = rx.App(state=rx.State)
@app.add_page
def index():
@ -75,7 +75,7 @@ def FormSubmitName():
def form_submit(self, form_data: dict):
self.form_data = form_data
app = rx.App(state=FormState)
app = rx.App(state=rx.State)
@app.add_page
def index():
@ -210,7 +210,7 @@ async def test_submit(driver, form_submit: AppHarness):
submit_input.click()
async def get_form_data():
return (await form_submit.get_state(token)).form_data
return (await form_submit.get_state(token)).substates["form_state"].form_data
# wait for the form data to arrive at the backend
form_data = await AppHarness._poll_for_async(get_form_data)

View File

@ -16,7 +16,7 @@ def FullyControlledInput():
class State(rx.State):
text: str = "initial"
app = rx.App(state=State)
app = rx.App(state=rx.State)
@app.add_page
def index():
@ -85,13 +85,15 @@ async def test_fully_controlled_input(fully_controlled_input: AppHarness):
debounce_input.send_keys("foo")
time.sleep(0.5)
assert debounce_input.get_attribute("value") == "ifoonitial"
assert (await fully_controlled_input.get_state(token)).text == "ifoonitial"
assert (await fully_controlled_input.get_state(token)).substates[
"state"
].text == "ifoonitial"
assert fully_controlled_input.poll_for_value(value_input) == "ifoonitial"
# clear the input on the backend
async with fully_controlled_input.modify_state(token) as state:
state.text = ""
assert (await fully_controlled_input.get_state(token)).text == ""
state.substates["state"].text = ""
assert (await fully_controlled_input.get_state(token)).substates["state"].text == ""
assert (
fully_controlled_input.poll_for_value(
debounce_input, exp_not_equal="ifoonitial"
@ -103,9 +105,9 @@ async def test_fully_controlled_input(fully_controlled_input: AppHarness):
debounce_input.send_keys("getting testing done")
time.sleep(0.5)
assert debounce_input.get_attribute("value") == "getting testing done"
assert (
await fully_controlled_input.get_state(token)
).text == "getting testing done"
assert (await fully_controlled_input.get_state(token)).substates[
"state"
].text == "getting testing done"
assert fully_controlled_input.poll_for_value(value_input) == "getting testing done"
# type into the on_change input
@ -113,7 +115,9 @@ async def test_fully_controlled_input(fully_controlled_input: AppHarness):
time.sleep(0.5)
assert debounce_input.get_attribute("value") == "overwrite the state"
assert on_change_input.get_attribute("value") == "overwrite the state"
assert (await fully_controlled_input.get_state(token)).text == "overwrite the state"
assert (await fully_controlled_input.get_state(token)).substates[
"state"
].text == "overwrite the state"
assert fully_controlled_input.poll_for_value(value_input) == "overwrite the state"
clear_button.click()

View File

@ -42,7 +42,7 @@ def LoginSample():
rx.button("Do it", on_click=State.login, id="doit"),
)
app = rx.App(state=State)
app = rx.App(state=rx.State)
app.add_page(index)
app.add_page(login)
app.compile()
@ -137,6 +137,6 @@ def test_login_flow(
logout_button = driver.find_element(By.ID, "logout")
logout_button.click()
assert login_sample._poll_for(lambda: local_storage["state.auth_token"] == "")
assert login_sample._poll_for(lambda: local_storage["state.state.auth_token"] == "")
with pytest.raises(NoSuchElementException):
driver.find_element(By.ID, "auth-token")

View File

@ -81,7 +81,7 @@ def RadixThemesApp():
)
app = rx.App(
state=State,
state=rx.State,
theme=rdxt.theme(rdxt.theme_panel(), accent_color="grass"),
)
app.add_page(index)

View File

@ -33,7 +33,7 @@ def ServerSideEvent():
def set_value_return_c(self):
return rx.set_value("c", "")
app = rx.App(state=SSState)
app = rx.App(state=rx.State)
@app.add_page
def index():

View File

@ -26,7 +26,7 @@ def Table():
caption: str = "random caption"
app = rx.App(state=TableState)
app = rx.App(state=rx.State)
@app.add_page
def index():

View File

@ -113,7 +113,7 @@ def UploadFile():
),
)
app = rx.App(state=UploadState)
app = rx.App(state=rx.State)
app.add_page(index)
app.compile()
@ -192,7 +192,7 @@ async def test_upload_file(
# look up the backend state and assert on uploaded contents
async def get_file_data():
return (await upload_file.get_state(token))._file_data
return (await upload_file.get_state(token)).substates["upload_state"]._file_data
file_data = await AppHarness._poll_for_async(get_file_data)
assert isinstance(file_data, dict)
@ -205,8 +205,8 @@ async def test_upload_file(
state = await upload_file.get_state(token)
if secondary:
# only the secondary form tracks progress and chain events
assert state.event_order.count("upload_progress") == 1
assert state.event_order.count("chain_event") == 1
assert state.substates["upload_state"].event_order.count("upload_progress") == 1
assert state.substates["upload_state"].event_order.count("chain_event") == 1
@pytest.mark.asyncio
@ -251,7 +251,7 @@ async def test_upload_file_multiple(tmp_path, upload_file: AppHarness, driver):
# look up the backend state and assert on uploaded contents
async def get_file_data():
return (await upload_file.get_state(token))._file_data
return (await upload_file.get_state(token)).substates["upload_state"]._file_data
file_data = await AppHarness._poll_for_async(get_file_data)
assert isinstance(file_data, dict)
@ -349,7 +349,7 @@ async def test_cancel_upload(tmp_path, upload_file: AppHarness, driver: WebDrive
# look up the backend state and assert on progress
state = await upload_file.get_state(token)
assert state.progress_dicts
assert exp_name not in state._file_data
assert state.substates["upload_state"].progress_dicts
assert exp_name not in state.substates["upload_state"]._file_data
target_file.unlink()

View File

@ -30,7 +30,7 @@ def VarOperations():
dict2: dict = {3: 4}
html_str: str = "<div>hello</div>"
app = rx.App(state=VarOperationState)
app = rx.App(state=rx.State)
@app.add_page
def index():

View File

@ -57,6 +57,7 @@ from reflex.route import (
verify_route_validity,
)
from reflex.state import (
BaseState,
RouterData,
State,
StateManager,
@ -98,7 +99,7 @@ class App(Base):
socket_app: Optional[ASGIApp] = None
# The state class to use for the app.
state: Optional[Type[State]] = None
state: Optional[Type[BaseState]] = None
# Class to manage many client states.
_state_manager: Optional[StateManager] = None
@ -149,25 +150,24 @@ class App(Base):
"`connect_error_component` is deprecated, use `overlay_component` instead"
)
super().__init__(*args, **kwargs)
state_subclasses = State.__subclasses__()
inferred_state = state_subclasses[-1] if state_subclasses else None
state_subclasses = BaseState.__subclasses__()
is_testing_env = constants.PYTEST_CURRENT_TEST in os.environ
# Special case to allow test cases have multiple subclasses of rx.State.
# Special case to allow test cases have multiple subclasses of rx.BaseState.
if not is_testing_env:
# Only one State class is allowed.
# Only one Base State class is allowed.
if len(state_subclasses) > 1:
raise ValueError(
"rx.State has been subclassed multiple times. Only one subclass is allowed"
"rx.BaseState cannot be subclassed multiple times. use rx.State instead"
)
# verify that provided state is valid
if self.state and inferred_state and self.state is not inferred_state:
if self.state and self.state is not State:
console.warn(
f"Using substate ({self.state.__name__}) as root state in `rx.App` is currently not supported."
f" Defaulting to root state: ({inferred_state.__name__})"
f" Defaulting to root state: ({State.__name__})"
)
self.state = inferred_state
self.state = State
# Get the config
config = get_config()
@ -265,7 +265,7 @@ class App(Base):
raise ValueError("The state manager has not been initialized.")
return self._state_manager
async def preprocess(self, state: State, event: Event) -> StateUpdate | None:
async def preprocess(self, state: BaseState, event: Event) -> StateUpdate | None:
"""Preprocess the event.
This is where middleware can modify the event before it is processed.
@ -290,7 +290,7 @@ class App(Base):
return out # type: ignore
async def postprocess(
self, state: State, event: Event, update: StateUpdate
self, state: BaseState, event: Event, update: StateUpdate
) -> StateUpdate:
"""Postprocess the event.
@ -764,7 +764,7 @@ class App(Base):
future.result()
@contextlib.asynccontextmanager
async def modify_state(self, token: str) -> AsyncIterator[State]:
async def modify_state(self, token: str) -> AsyncIterator[BaseState]:
"""Modify the state out of band.
Args:
@ -792,7 +792,9 @@ class App(Base):
sid=state.router.session.session_id,
)
def _process_background(self, state: State, event: Event) -> asyncio.Task | None:
def _process_background(
self, state: BaseState, event: Event
) -> asyncio.Task | None:
"""Process an event in the background and emit updates as they arrive.
Args:

View File

@ -33,6 +33,7 @@ from reflex.route import (
)
from reflex.state import (
State as State,
BaseState as BaseState,
StateManager as StateManager,
StateUpdate as StateUpdate,
)
@ -69,7 +70,7 @@ class App(Base):
api: FastAPI
sio: Optional[AsyncServer]
socket_app: Optional[ASGIApp]
state: Type[State]
state: Type[BaseState]
state_manager: StateManager
style: ComponentStyle
middleware: List[Middleware]

View File

@ -1,11 +1,42 @@
"""Define the base Reflex class."""
from __future__ import annotations
from typing import Any
import os
from typing import Any, List, Type
import pydantic
from pydantic import BaseModel
from pydantic.fields import ModelField
from reflex import constants
def validate_field_name(bases: List[Type["BaseModel"]], field_name: str) -> None:
"""Ensure that the field's name does not shadow an existing attribute of the model.
Args:
bases: List of base models to check for shadowed attrs.
field_name: name of attribute
Raises:
NameError: If state var field shadows another in its parent state
"""
reload = os.getenv(constants.RELOAD_CONFIG) == "True"
for base in bases:
try:
if not reload and getattr(base, field_name, None):
pass
except TypeError as te:
raise NameError(
f'State var "{field_name}" in {base} has been shadowed by a substate var; '
f'use a different field name instead".'
) from te
# monkeypatch pydantic validate_field_name method to skip validating
# shadowed state vars when reloading app via utils.prerequisites.get_app(reload=True)
pydantic.main.validate_field_name = validate_field_name # type: ignore
class Base(pydantic.BaseModel):
"""The base class subclassed by all Reflex classes.

View File

@ -15,7 +15,7 @@ from reflex.components.component import (
StatefulComponent,
)
from reflex.config import get_config
from reflex.state import State
from reflex.state import BaseState
from reflex.utils.imports import ImportVar
@ -63,7 +63,7 @@ def _compile_theme(theme: dict) -> str:
return templates.THEME.render(theme=theme)
def _compile_contexts(state: Optional[Type[State]]) -> str:
def _compile_contexts(state: Optional[Type[BaseState]]) -> str:
"""Compile the initial state and contexts.
Args:
@ -87,7 +87,7 @@ def _compile_contexts(state: Optional[Type[State]]) -> str:
def _compile_page(
component: Component,
state: Type[State],
state: Type[BaseState],
) -> str:
"""Compile the component given the app state.
@ -337,7 +337,7 @@ def compile_theme(style: ComponentStyle) -> tuple[str, str]:
return output_path, code
def compile_contexts(state: Optional[Type[State]]) -> tuple[str, str]:
def compile_contexts(state: Optional[Type[BaseState]]) -> tuple[str, str]:
"""Compile the initial state / context.
Args:
@ -353,7 +353,7 @@ def compile_contexts(state: Optional[Type[State]]) -> tuple[str, str]:
def compile_page(
path: str, component: Component, state: Type[State]
path: str, component: Component, state: Type[BaseState]
) -> tuple[str, str]:
"""Compile a single page.

View File

@ -21,7 +21,7 @@ from reflex.components.base import (
Title,
)
from reflex.components.component import Component, ComponentStyle, CustomComponent
from reflex.state import Cookie, LocalStorage, State
from reflex.state import BaseState, Cookie, LocalStorage
from reflex.style import Style
from reflex.utils import console, format, imports, path_ops
@ -128,7 +128,7 @@ def get_import_dict(lib: str, default: str = "", rest: list[str] | None = None)
}
def compile_state(state: Type[State]) -> dict:
def compile_state(state: Type[BaseState]) -> dict:
"""Compile the state of the app.
Args:
@ -170,7 +170,7 @@ def _compile_client_storage_field(
def _compile_client_storage_recursive(
state: Type[State],
state: Type[BaseState],
) -> tuple[dict[str, dict], dict[str, dict[str, str]]]:
"""Compile the client-side storage for the given state recursively.
@ -208,7 +208,7 @@ def _compile_client_storage_recursive(
return cookies, local_storage
def compile_client_storage(state: Type[State]) -> dict[str, dict]:
def compile_client_storage(state: Type[BaseState]) -> dict[str, dict]:
"""Compile the client-side storage for the given state.
Args:

View File

@ -6,6 +6,7 @@ from .base import (
LOCAL_STORAGE,
POLLING_MAX_HTTP_BUFFER_SIZE,
PYTEST_CURRENT_TEST,
RELOAD_CONFIG,
SKIP_COMPILE_ENV_VAR,
ColorMode,
Dirs,
@ -85,6 +86,7 @@ __ALL__ = [
PYTEST_CURRENT_TEST,
PRODUCTION_BACKEND_URL,
Reflex,
RELOAD_CONFIG,
RequirementsTxt,
RouteArgType,
RouteRegex,

View File

@ -173,3 +173,4 @@ SKIP_COMPILE_ENV_VAR = "__REFLEX_SKIP_COMPILE"
# Testing variables.
# Testing os env set by pytest when running a test case.
PYTEST_CURRENT_TEST = "PYTEST_CURRENT_TEST"
RELOAD_CONFIG = "__REFLEX_RELOAD_CONFIG"

View File

@ -22,7 +22,7 @@ from reflex.utils.types import ArgsSpec
from reflex.vars import BaseVar, Var
if TYPE_CHECKING:
from reflex.state import State
from reflex.state import BaseState
class Event(Base):
@ -64,7 +64,7 @@ def background(fn):
def _no_chain_background_task(
state_cls: Type["State"], name: str, fn: Callable
state_cls: Type["BaseState"], name: str, fn: Callable
) -> Callable:
"""Protect against directly chaining a background task from another event handler.

View File

@ -6,7 +6,7 @@ from typing import TYPE_CHECKING, Optional
from reflex import constants
from reflex.event import Event, fix_events, get_hydrate_event
from reflex.middleware.middleware import Middleware
from reflex.state import State, StateUpdate
from reflex.state import BaseState, StateUpdate
from reflex.utils import format
if TYPE_CHECKING:
@ -17,7 +17,7 @@ class HydrateMiddleware(Middleware):
"""Middleware to handle initial app hydration."""
async def preprocess(
self, app: App, state: State, event: Event
self, app: App, state: BaseState, event: Event
) -> Optional[StateUpdate]:
"""Preprocess the event.

View File

@ -6,7 +6,7 @@ from typing import TYPE_CHECKING, Optional
from reflex.base import Base
from reflex.event import Event
from reflex.state import State, StateUpdate
from reflex.state import BaseState, StateUpdate
if TYPE_CHECKING:
from reflex.app import App
@ -16,7 +16,7 @@ class Middleware(Base, ABC):
"""Middleware to preprocess and postprocess requests."""
async def preprocess(
self, app: App, state: State, event: Event
self, app: App, state: BaseState, event: Event
) -> Optional[StateUpdate]:
"""Preprocess the event.
@ -31,7 +31,7 @@ class Middleware(Base, ABC):
return None
async def postprocess(
self, app: App, state: State, event: Event, update: StateUpdate
self, app: App, state: BaseState, event: Event, update: StateUpdate
) -> StateUpdate:
"""Postprocess the event.

View File

@ -7,6 +7,7 @@ import copy
import functools
import inspect
import json
import os
import traceback
import urllib.parse
import uuid
@ -81,7 +82,7 @@ class HeaderData(Base):
class PageData(Base):
"""An object containing page data."""
host: str = "" # repeated with self.headers.origin (remove or keep the duplicate?)
host: str = "" # repeated with self.headers.origin (remove or keep the duplicate?)
path: str = ""
raw_path: str = ""
full_path: str = ""
@ -152,7 +153,7 @@ RESERVED_BACKEND_VAR_NAMES = {
}
class State(Base, ABC, extra=pydantic.Extra.allow):
class BaseState(Base, ABC, extra=pydantic.Extra.allow):
"""The state of the app."""
# A map from the var name to the var.
@ -176,6 +177,9 @@ class State(Base, ABC, extra=pydantic.Extra.allow):
# The event handlers.
event_handlers: ClassVar[Dict[str, EventHandler]] = {}
# A set of subclassses of this class.
class_subclasses: ClassVar[Set[Type[BaseState]]] = set()
# Mapping of var name to set of computed variables that depend on it
_computed_var_dependencies: ClassVar[Dict[str, Set[str]]] = {}
@ -189,10 +193,10 @@ class State(Base, ABC, extra=pydantic.Extra.allow):
_always_dirty_substates: ClassVar[Set[str]] = set()
# The parent state.
parent_state: Optional[State] = None
parent_state: Optional[BaseState] = None
# The substates of the state.
substates: Dict[str, State] = {}
substates: Dict[str, BaseState] = {}
# The set of dirty vars.
dirty_vars: Set[str] = set()
@ -209,10 +213,7 @@ class State(Base, ABC, extra=pydantic.Extra.allow):
# The router data for the current page
router: RouterData = RouterData()
# The hydrated bool.
is_hydrated: bool = False
def __init__(self, *args, parent_state: State | None = None, **kwargs):
def __init__(self, *args, parent_state: BaseState | None = None, **kwargs):
"""Initialize the state.
Args:
@ -220,28 +221,20 @@ class State(Base, ABC, extra=pydantic.Extra.allow):
parent_state: The parent state.
**kwargs: The kwargs to pass to the Pydantic init method.
Raises:
ValueError: If a substate class shadows another.
"""
kwargs["parent_state"] = parent_state
super().__init__(*args, **kwargs)
# Setup the substates.
for substate in self.get_substates():
substate_name = substate.get_name()
if substate_name in self.substates:
raise ValueError(
f"The substate class '{substate_name}' has been defined multiple times. Shadowing "
f"substate classes is not allowed."
)
self.substates[substate_name] = substate(parent_state=self)
self.substates[substate.get_name()] = substate(parent_state=self)
# Convert the event handlers to functions.
self._init_event_handlers()
# Create a fresh copy of the backend variables for this instance
self._backend_vars = copy.deepcopy(self.backend_vars)
def _init_event_handlers(self, state: State | None = None):
def _init_event_handlers(self, state: BaseState | None = None):
"""Initialize event handlers.
Allow event handlers to be called directly on the instance. This is
@ -281,17 +274,44 @@ class State(Base, ABC, extra=pydantic.Extra.allow):
Args:
**kwargs: The kwargs to pass to the pydantic init_subclass method.
Raises:
ValueError: If a substate class shadows another.
"""
is_testing_env = constants.PYTEST_CURRENT_TEST in os.environ
super().__init_subclass__(**kwargs)
# Event handlers should not shadow builtin state methods.
cls._check_overridden_methods()
# Reset subclass tracking for this class.
cls.class_subclasses = set()
# Get the parent vars.
parent_state = cls.get_parent_state()
if parent_state is not None:
cls.inherited_vars = parent_state.vars
cls.inherited_backend_vars = parent_state.backend_vars
# Check if another substate class with the same name has already been defined.
if cls.__name__ in set(c.__name__ for c in parent_state.class_subclasses):
if is_testing_env:
# Clear existing subclass with same name when app is reloaded via
# utils.prerequisites.get_app(reload=True)
parent_state.class_subclasses = set(
c
for c in parent_state.class_subclasses
if c.__name__ != cls.__name__
)
else:
# During normal operation, subclasses cannot have the same name, even if they are
# defined in different modules.
raise ValueError(
f"The substate class '{cls.__name__}' has been defined multiple times. "
"Shadowing substate classes is not allowed."
)
# Track this new subclass in the parent state's subclasses set.
parent_state.class_subclasses.add(cls)
cls.new_backend_vars = {
name: value
for name, value in cls.__dict__.items()
@ -437,7 +457,7 @@ class State(Base, ABC, extra=pydantic.Extra.allow):
@classmethod
@functools.lru_cache()
def get_parent_state(cls) -> Type[State] | None:
def get_parent_state(cls) -> Type[BaseState] | None:
"""Get the parent state.
Returns:
@ -446,20 +466,19 @@ class State(Base, ABC, extra=pydantic.Extra.allow):
parent_states = [
base
for base in cls.__bases__
if types._issubclass(base, State) and base is not State
if types._issubclass(base, BaseState) and base is not BaseState
]
assert len(parent_states) < 2, "Only one parent state is allowed."
return parent_states[0] if len(parent_states) == 1 else None # type: ignore
@classmethod
@functools.lru_cache()
def get_substates(cls) -> set[Type[State]]:
def get_substates(cls) -> set[Type[BaseState]]:
"""Get the substates of the state.
Returns:
The substates of the state.
"""
return set(cls.__subclasses__())
return cls.class_subclasses
@classmethod
@functools.lru_cache()
@ -487,7 +506,7 @@ class State(Base, ABC, extra=pydantic.Extra.allow):
@classmethod
@functools.lru_cache()
def get_class_substate(cls, path: Sequence[str]) -> Type[State]:
def get_class_substate(cls, path: Sequence[str]) -> Type[BaseState]:
"""Get the class substate.
Args:
@ -643,7 +662,7 @@ class State(Base, ABC, extra=pydantic.Extra.allow):
"""
return {
func[0]: func[1]
for func in inspect.getmembers(State, predicate=inspect.isfunction)
for func in inspect.getmembers(BaseState, predicate=inspect.isfunction)
if not func[0].startswith("__")
}
@ -909,7 +928,7 @@ class State(Base, ABC, extra=pydantic.Extra.allow):
for substate in self.substates.values():
substate._reset_client_storage()
def get_substate(self, path: Sequence[str]) -> State | None:
def get_substate(self, path: Sequence[str]) -> BaseState | None:
"""Get the substate.
Args:
@ -933,7 +952,7 @@ class State(Base, ABC, extra=pydantic.Extra.allow):
def _get_event_handler(
self, event: Event
) -> tuple[State | StateProxy, EventHandler]:
) -> tuple[BaseState | StateProxy, EventHandler]:
"""Get the event handler for the given event.
Args:
@ -1050,7 +1069,7 @@ class State(Base, ABC, extra=pydantic.Extra.allow):
)
async def _process_event(
self, handler: EventHandler, state: State | StateProxy, payload: Dict
self, handler: EventHandler, state: BaseState | StateProxy, payload: Dict
) -> AsyncIterator[StateUpdate]:
"""Process event.
@ -1263,7 +1282,7 @@ class State(Base, ABC, extra=pydantic.Extra.allow):
d.update(substate_d)
return d
async def __aenter__(self) -> State:
async def __aenter__(self) -> BaseState:
"""Enter the async context manager protocol.
This should not be used for the State class, but exists for
@ -1288,6 +1307,13 @@ class State(Base, ABC, extra=pydantic.Extra.allow):
pass
class State(BaseState):
"""The app Base State."""
# The hydrated bool.
is_hydrated: bool = False
class StateProxy(wrapt.ObjectProxy):
"""Proxy of a state instance to control mutability of vars for a background task.
@ -1455,10 +1481,10 @@ class StateManager(Base, ABC):
"""A class to manage many client states."""
# The state class to use.
state: Type[State]
state: Type[BaseState]
@classmethod
def create(cls, state: Type[State]):
def create(cls, state: Type[BaseState]):
"""Create a new state manager.
Args:
@ -1473,7 +1499,7 @@ class StateManager(Base, ABC):
return StateManagerMemory(state=state)
@abstractmethod
async def get_state(self, token: str) -> State:
async def get_state(self, token: str) -> BaseState:
"""Get the state for a token.
Args:
@ -1485,7 +1511,7 @@ class StateManager(Base, ABC):
pass
@abstractmethod
async def set_state(self, token: str, state: State):
async def set_state(self, token: str, state: BaseState):
"""Set the state for a token.
Args:
@ -1496,7 +1522,7 @@ class StateManager(Base, ABC):
@abstractmethod
@contextlib.asynccontextmanager
async def modify_state(self, token: str) -> AsyncIterator[State]:
async def modify_state(self, token: str) -> AsyncIterator[BaseState]:
"""Modify the state for a token while holding exclusive lock.
Args:
@ -1512,7 +1538,7 @@ class StateManagerMemory(StateManager):
"""A state manager that stores states in memory."""
# The mapping of client ids to states.
states: Dict[str, State] = {}
states: Dict[str, BaseState] = {}
# The mutex ensures the dict of mutexes is updated exclusively
_state_manager_lock = asyncio.Lock()
@ -1527,7 +1553,7 @@ class StateManagerMemory(StateManager):
"_states_locks": {"exclude": True},
}
async def get_state(self, token: str) -> State:
async def get_state(self, token: str) -> BaseState:
"""Get the state for a token.
Args:
@ -1540,7 +1566,7 @@ class StateManagerMemory(StateManager):
self.states[token] = self.state()
return self.states[token]
async def set_state(self, token: str, state: State):
async def set_state(self, token: str, state: BaseState):
"""Set the state for a token.
Args:
@ -1550,7 +1576,7 @@ class StateManagerMemory(StateManager):
pass
@contextlib.asynccontextmanager
async def modify_state(self, token: str) -> AsyncIterator[State]:
async def modify_state(self, token: str) -> AsyncIterator[BaseState]:
"""Modify the state for a token while holding exclusive lock.
Args:
@ -1598,7 +1624,7 @@ class StateManagerRedis(StateManager):
b"evicted",
}
async def get_state(self, token: str) -> State:
async def get_state(self, token: str) -> BaseState:
"""Get the state for a token.
Args:
@ -1613,7 +1639,9 @@ class StateManagerRedis(StateManager):
return await self.get_state(token)
return cloudpickle.loads(redis_state)
async def set_state(self, token: str, state: State, lock_id: bytes | None = None):
async def set_state(
self, token: str, state: BaseState, lock_id: bytes | None = None
):
"""Set the state for a token.
Args:
@ -1637,7 +1665,7 @@ class StateManagerRedis(StateManager):
await self.redis.set(token, cloudpickle.dumps(state), ex=self.token_expiration)
@contextlib.asynccontextmanager
async def modify_state(self, token: str) -> AsyncIterator[State]:
async def modify_state(self, token: str) -> AsyncIterator[BaseState]:
"""Modify the state for a token while holding exclusive lock.
Args:
@ -1879,7 +1907,7 @@ class MutableProxy(wrapt.ObjectProxy):
__mutable_types__ = (list, dict, set, Base)
def __init__(self, wrapped: Any, state: State, field_name: str):
def __init__(self, wrapped: Any, state: BaseState, field_name: str):
"""Create a proxy for a mutable object that tracks changes.
Args:

View File

@ -38,7 +38,7 @@ import reflex.utils.build
import reflex.utils.exec
import reflex.utils.prerequisites
import reflex.utils.processes
from reflex.state import State, StateManagerMemory, StateManagerRedis
from reflex.state import BaseState, State, StateManagerMemory, StateManagerRedis
try:
from selenium import webdriver # pyright: ignore [reportMissingImports]
@ -162,6 +162,9 @@ class AppHarness:
with chdir(self.app_path):
# ensure config and app are reloaded when testing different app
reflex.config.get_config(reload=True)
# reset rx.State subclasses
State.class_subclasses.clear()
# self.app_module.app.
self.app_module = reflex.utils.prerequisites.get_app(reload=True)
self.app_instance = self.app_module.app
if isinstance(self.app_instance.state_manager, StateManagerRedis):
@ -434,7 +437,7 @@ class AppHarness:
self._frontends.append(driver)
return driver
async def get_state(self, token: str) -> State:
async def get_state(self, token: str) -> BaseState:
"""Get the state associated with the given token.
Args:
@ -561,7 +564,7 @@ class AppHarness:
)
return element.get_attribute("value")
def poll_for_clients(self, timeout: TimeoutType = None) -> dict[str, reflex.State]:
def poll_for_clients(self, timeout: TimeoutType = None) -> dict[str, BaseState]:
"""Poll app state_manager for any connected clients.
Args:

View File

@ -124,11 +124,13 @@ def get_app(reload: bool = False) -> ModuleType:
Returns:
The app based on the default config.
"""
os.environ[constants.RELOAD_CONFIG] = str(reload)
config = get_config()
module = ".".join([config.app_name, config.app_name])
sys.path.insert(0, os.getcwd())
app = __import__(module, fromlist=(constants.CompileVars.APP,))
if reload:
importlib.reload(app)
return app

View File

@ -41,7 +41,7 @@ from reflex.utils import console, format, imports, serializers, types
from reflex.utils.imports import ImportDict, ImportVar
if TYPE_CHECKING:
from reflex.state import State
from reflex.state import BaseState
# Set of unique variable names.
USED_VARIABLES = set()
@ -1472,7 +1472,7 @@ class Var:
)
)
def _var_set_state(self, state: Type[State] | str) -> Any:
def _var_set_state(self, state: Type[BaseState] | str) -> Any:
"""Set the state of the var.
Args:
@ -1604,14 +1604,14 @@ class BaseVar(Var):
return setter
return ".".join((self._var_data.state, setter))
def get_setter(self) -> Callable[[State, Any], None]:
def get_setter(self) -> Callable[[BaseState, Any], None]:
"""Get the var's setter function.
Returns:
A function that that creates a setter for the var.
"""
def setter(state: State, value: Any):
def setter(state: BaseState, value: Any):
"""Get the setter for the var.
Args:
@ -1643,9 +1643,9 @@ class ComputedVar(Var, property):
def __init__(
self,
fget: Callable[[State], Any],
fset: Callable[[State, Any], None] | None = None,
fdel: Callable[[State], Any] | None = None,
fget: Callable[[BaseState], Any],
fset: Callable[[BaseState, Any], None] | None = None,
fdel: Callable[[BaseState], Any] | None = None,
doc: str | None = None,
**kwargs,
):

View File

@ -5,6 +5,7 @@ from _typeshed import Incomplete
from reflex import constants as constants
from reflex.base import Base as Base
from reflex.state import State as State
from reflex.state import BaseState as BaseState
from reflex.utils import console as console, format as format, types as types
from reflex.utils.imports import ImportVar
from types import FunctionType
@ -110,7 +111,7 @@ class Var:
def as_ref(self) -> Var: ...
@property
def _var_full_name(self) -> str: ...
def _var_set_state(self, state: Type[State] | str) -> Any: ...
def _var_set_state(self, state: Type[BaseState] | str) -> Any: ...
@dataclass(eq=False)
class BaseVar(Var):
@ -123,7 +124,7 @@ class BaseVar(Var):
def __hash__(self) -> int: ...
def get_default_value(self) -> Any: ...
def get_setter_name(self, include_state: bool = ...) -> str: ...
def get_setter(self) -> Callable[[State, Any], None]: ...
def get_setter(self) -> Callable[[BaseState, Any], None]: ...
@dataclass(init=False)
class ComputedVar(Var):

View File

@ -1 +1,4 @@
"""Root directory for tests."""
import os
from reflex import constants

View File

@ -2,7 +2,7 @@
import pytest
from reflex.components.base.script import Script
from reflex.state import State
from reflex.state import BaseState
def test_script_inline():
@ -31,7 +31,7 @@ def test_script_neither():
Script.create()
class EvState(State):
class EvState(BaseState):
"""State for testing event handlers."""
def on_ready(self):

View File

@ -5,6 +5,7 @@ import pandas as pd
import pytest
import reflex as rx
from reflex.state import BaseState
@pytest.fixture
@ -18,7 +19,7 @@ def data_table_state(request):
The data table state class.
"""
class DataTableState(rx.State):
class DataTableState(BaseState):
data = request.param["data"]
columns = ["column1", "column2"]
@ -33,7 +34,7 @@ def data_table_state2():
The data table state class.
"""
class DataTableState(rx.State):
class DataTableState(BaseState):
_data = pd.DataFrame()
@rx.var
@ -51,7 +52,7 @@ def data_table_state3():
The data table state class.
"""
class DataTableState(rx.State):
class DataTableState(BaseState):
_data: List = []
_columns: List = ["col1", "col2"]
@ -74,7 +75,7 @@ def data_table_state4():
The data table state class.
"""
class DataTableState(rx.State):
class DataTableState(BaseState):
_data: List = []
_columns: List = ["col1", "col2"]

View File

@ -4,12 +4,12 @@ from typing import List, Tuple
import pytest
from reflex.components.datadisplay.table import Tbody, Tfoot, Thead
from reflex.state import State
from reflex.state import BaseState
PYTHON_GT_V38 = sys.version_info.major >= 3 and sys.version_info.minor > 8
class TableState(State):
class TableState(BaseState):
"""Test State class."""
rows_List_List_str: List[List[str]] = [["random", "row"]]

View File

@ -3,6 +3,7 @@
import pytest
import reflex as rx
from reflex.state import BaseState
from reflex.vars import BaseVar
@ -24,7 +25,7 @@ def test_render_many_child():
_ = rx.debounce_input("foo", "bar").render()
class S(rx.State):
class S(BaseState):
"""Example state for debounce tests."""
value: str = ""

View File

@ -15,12 +15,13 @@ from reflex.components.layout.responsive import (
tablet_only,
)
from reflex.components.typography.text import Text
from reflex.state import BaseState
from reflex.vars import Var
@pytest.fixture
def cond_state(request):
class CondState(rx.State):
class CondState(BaseState):
value: request.param["value_type"] = request.param["value"] # noqa
return CondState

View File

@ -4,10 +4,10 @@ import pytest
from reflex.components import box, foreach, text
from reflex.components.layout import Foreach
from reflex.state import State
from reflex.state import BaseState
class ForEachState(State):
class ForEachState(BaseState):
"""A state for testing the ForEach component."""
colors_list: List[str] = ["red", "yellow"]

View File

@ -14,7 +14,7 @@ from reflex.components.component import (
from reflex.components.layout.box import Box
from reflex.constants import EventTriggers
from reflex.event import EventChain, EventHandler
from reflex.state import State
from reflex.state import BaseState
from reflex.style import Style
from reflex.utils import imports
from reflex.utils.imports import ImportVar
@ -23,7 +23,7 @@ from reflex.vars import Var, VarData
@pytest.fixture
def test_state():
class TestState(State):
class TestState(BaseState):
num: int
def do_something(self):
@ -400,7 +400,7 @@ def test_get_event_triggers(component1, component2):
)
class C1State(State):
class C1State(BaseState):
"""State for testing C1 component."""
def mock_handler(self, _e, _bravo, _charlie):

View File

@ -8,7 +8,6 @@ from typing import Dict, Generator
import pytest
import reflex as rx
from reflex.app import App
from reflex.event import EventSpec
@ -225,23 +224,3 @@ def token() -> str:
A fresh/unique token string.
"""
return str(uuid.uuid4())
@pytest.fixture
def duplicate_substate():
"""Create a Test state that has duplicate child substates.
Returns:
The test state.
"""
class TestState(rx.State):
pass
class ChildTestState(TestState): # type: ignore # noqa
pass
class ChildTestState(TestState): # type: ignore # noqa
pass
return TestState

View File

@ -2,13 +2,14 @@ from typing import Any, Dict
import pytest
from reflex import constants
from reflex.app import App
from reflex.constants import CompileVars
from reflex.middleware.hydrate_middleware import HydrateMiddleware
from reflex.state import State, StateUpdate
from reflex.state import BaseState, StateUpdate
def exp_is_hydrated(state: State) -> Dict[str, Any]:
def exp_is_hydrated(state: BaseState) -> Dict[str, Any]:
"""Expected IS_HYDRATED delta that would be emitted by HydrateMiddleware.
Args:
@ -20,7 +21,7 @@ def exp_is_hydrated(state: State) -> Dict[str, Any]:
return {state.get_name(): {CompileVars.IS_HYDRATED: True}}
class TestState(State):
class TestState(BaseState):
"""A test state with no return in handler."""
__test__ = False
@ -32,7 +33,7 @@ class TestState(State):
self.num += 1
class TestState2(State):
class TestState2(BaseState):
"""A test state with return in handler."""
__test__ = False
@ -54,7 +55,7 @@ class TestState2(State):
self.name = "random"
class TestState3(State):
class TestState3(BaseState):
"""A test state with async handler."""
__test__ = False
@ -97,6 +98,9 @@ async def test_preprocess(
event_fixture: The event fixture(an Event).
expected: Expected delta.
"""
test_state.add_var(
constants.CompileVars.IS_HYDRATED, type_=bool, default_value=False
)
app = App(state=test_state, load_events={"index": [test_state.test_handler]})
state = test_state()

View File

@ -1,10 +1,12 @@
"""Common rx.State subclasses for use in tests."""
"""Common rx.BaseState subclasses for use in tests."""
import reflex as rx
from reflex.state import BaseState
from .mutation import DictMutationTestState, ListMutationTestState, MutableTestState
from .upload import (
ChildFileUploadState,
FileStateBase1,
FileStateBase2,
FileUploadState,
GrandChildFileUploadState,
SubUploadState,
@ -12,7 +14,7 @@ from .upload import (
)
class GenState(rx.State):
class GenState(BaseState):
"""A state with event handlers that generate multiple updates."""
value: int

View File

@ -3,9 +3,10 @@
from typing import Dict, List, Set, Union
import reflex as rx
from reflex.state import BaseState
class DictMutationTestState(rx.State):
class DictMutationTestState(BaseState):
"""A state for testing ReflexDict mutation."""
# plain dict
@ -62,7 +63,7 @@ class DictMutationTestState(rx.State):
self.friend_in_nested_dict["friend"]["age"] = 30
class ListMutationTestState(rx.State):
class ListMutationTestState(BaseState):
"""A state for testing ReflexList mutation."""
# plain list
@ -144,7 +145,7 @@ class CustomVar(rx.Base):
custom: OtherBase = OtherBase()
class MutableTestState(rx.State):
class MutableTestState(BaseState):
"""A test state."""
array: List[Union[str, List, Dict[str, str]]] = [

View File

@ -3,9 +3,10 @@ from pathlib import Path
from typing import ClassVar, List
import reflex as rx
from reflex.state import BaseState, State
class UploadState(rx.State):
class UploadState(BaseState):
"""The base state for uploading a file."""
async def handle_upload1(self, files: List[rx.UploadFile]):
@ -17,7 +18,7 @@ class UploadState(rx.State):
pass
class BaseState(rx.State):
class BaseState(BaseState):
"""The test base state."""
pass
@ -37,7 +38,7 @@ class SubUploadState(BaseState):
pass
class FileUploadState(rx.State):
class FileUploadState(State):
"""The base state for uploading a file."""
img_list: List[str]
@ -79,7 +80,7 @@ class FileUploadState(rx.State):
pass
class FileStateBase1(rx.State):
class FileStateBase1(State):
"""The base state for a child FileUploadState."""
pass

View File

@ -28,7 +28,7 @@ from reflex.components import Box, Component, Cond, Fragment, Text
from reflex.event import Event, get_hydrate_event
from reflex.middleware import HydrateMiddleware
from reflex.model import Model
from reflex.state import RouterData, State, StateManagerRedis, StateUpdate
from reflex.state import BaseState, RouterData, State, StateManagerRedis, StateUpdate
from reflex.style import Style
from reflex.utils import format
from reflex.vars import ComputedVar
@ -43,7 +43,7 @@ from .states import (
)
class EmptyState(State):
class EmptyState(BaseState):
"""An empty state."""
pass
@ -77,14 +77,14 @@ def about_page():
return about
class ATestState(State):
class ATestState(BaseState):
"""A simple state for testing."""
var: int
@pytest.fixture()
def test_state() -> Type[State]:
def test_state() -> Type[BaseState]:
"""A default state.
Returns:
@ -94,14 +94,14 @@ def test_state() -> Type[State]:
@pytest.fixture()
def redundant_test_state() -> Type[State]:
def redundant_test_state() -> Type[BaseState]:
"""A default state.
Returns:
A default state.
"""
class RedundantTestState(State):
class RedundantTestState(BaseState):
var: int
return RedundantTestState
@ -198,12 +198,12 @@ def test_default_app(app: App):
def test_multiple_states_error(monkeypatch, test_state, redundant_test_state):
"""Test that an error is thrown when multiple classes subclass rx.State.
"""Test that an error is thrown when multiple classes subclass rx.BaseState.
Args:
monkeypatch: Pytest monkeypatch object.
test_state: A test state subclassing rx.State.
redundant_test_state: Another test state subclassing rx.State.
test_state: A test state subclassing rx.BaseState.
redundant_test_state: Another test state subclassing rx.BaseState.
"""
monkeypatch.delenv(constants.PYTEST_CURRENT_TEST)
with pytest.raises(ValueError):
@ -705,12 +705,12 @@ async def test_dict_mutation_detection__plain_list(
[
(
FileUploadState,
{"file_upload_state": {"img_list": ["image1.jpg", "image2.jpg"]}},
{"state.file_upload_state": {"img_list": ["image1.jpg", "image2.jpg"]}},
),
(
ChildFileUploadState,
{
"file_state_base1.child_file_upload_state": {
"state.file_state_base1.child_file_upload_state": {
"img_list": ["image1.jpg", "image2.jpg"]
}
},
@ -718,14 +718,14 @@ async def test_dict_mutation_detection__plain_list(
(
GrandChildFileUploadState,
{
"file_state_base1.file_state_base2.grand_child_file_upload_state": {
"state.file_state_base1.file_state_base2.grand_child_file_upload_state": {
"img_list": ["image1.jpg", "image2.jpg"]
}
},
),
],
)
async def test_upload_file(tmp_path, state, delta, token: str):
async def test_upload_file(tmp_path, state, delta, token: str, mocker):
"""Test that file upload works correctly.
Args:
@ -733,10 +733,15 @@ async def test_upload_file(tmp_path, state, delta, token: str):
state: The state class.
delta: Expected delta
token: a Token.
mocker: pytest mocker object.
"""
mocker.patch(
"reflex.state.State.class_subclasses",
{state if state is FileUploadState else FileStateBase1},
)
state._tmp_path = tmp_path
# The App state must be the "root" of the state tree
app = App(state=state if state is FileUploadState else FileStateBase1)
app = App(state=State)
app.event_namespace.emit = AsyncMock() # type: ignore
current_state = await app.state_manager.get_state(token)
data = b"This is binary data"
@ -749,7 +754,7 @@ async def test_upload_file(tmp_path, state, delta, token: str):
request_mock = unittest.mock.Mock()
request_mock.headers = {
"reflex-client-token": token,
"reflex-event-handler": f"{state_name}.multi_handle_upload",
"reflex-event-handler": f"state.{state_name}.multi_handle_upload",
}
file1 = UploadFile(
@ -851,7 +856,7 @@ async def test_upload_file_background(state, tmp_path, token):
await app.state_manager.redis.close()
class DynamicState(State):
class DynamicState(BaseState):
"""State class for testing dynamic route var.
This is defined at module level because event handlers cannot be addressed
@ -891,9 +896,7 @@ class DynamicState(State):
@pytest.mark.asyncio
async def test_dynamic_route_var_route_change_completed_on_load(
index_page,
windows_platform: bool,
token: str,
index_page, windows_platform: bool, token: str, mocker
):
"""Create app with dynamic route var, and simulate navigation.
@ -904,7 +907,12 @@ async def test_dynamic_route_var_route_change_completed_on_load(
index_page: The index page.
windows_platform: Whether the system is windows.
token: a Token.
mocker: pytest mocker object.
"""
mocker.patch("reflex.state.State.class_subclasses", {DynamicState})
DynamicState.add_var(
constants.CompileVars.IS_HYDRATED, type_=bool, default_value=False
)
arg_name = "dynamic"
route = f"/test/[{arg_name}]"
if windows_platform:

View File

@ -4,7 +4,7 @@ import pytest
from reflex import event
from reflex.event import Event, EventHandler, EventSpec, fix_events
from reflex.state import State
from reflex.state import BaseState
from reflex.utils import format
from reflex.vars import Var
@ -303,7 +303,7 @@ def test_event_actions():
def test_event_actions_on_state():
class EventActionState(State):
class EventActionState(BaseState):
def handler(self):
pass

View File

@ -19,11 +19,11 @@ from reflex.base import Base
from reflex.constants import CompileVars, RouteVar, SocketEvent
from reflex.event import Event, EventHandler
from reflex.state import (
BaseState,
ImmutableStateError,
LockExpiredError,
MutableProxy,
RouterData,
State,
StateManager,
StateManagerMemory,
StateManagerRedis,
@ -75,7 +75,7 @@ class Object(Base):
prop2: str = "hello"
class TestState(State):
class TestState(BaseState):
"""A test state."""
# Set this class as not test one
@ -148,7 +148,7 @@ class GrandchildState(ChildState):
pass
class DateTimeState(State):
class DateTimeState(BaseState):
"""A State with some datetime fields."""
d: datetime.date = datetime.date.fromisoformat("1989-11-09")
@ -253,7 +253,6 @@ def test_class_vars(test_state):
"""
cls = type(test_state)
assert set(cls.vars.keys()) == {
CompileVars.IS_HYDRATED, # added by hydrate_middleware to all State
"router",
"num1",
"num2",
@ -641,7 +640,6 @@ def test_reset(test_state, child_state):
"obj",
"upper",
"complex",
"is_hydrated",
"fig",
"key",
"sum",
@ -837,7 +835,7 @@ def test_get_query_params(test_state):
def test_add_var():
class DynamicState(State):
class DynamicState(BaseState):
pass
ds1 = DynamicState()
@ -870,7 +868,7 @@ def test_add_var_default_handlers(test_state):
assert isinstance(test_state.event_handlers["set_rand_int"], EventHandler)
class InterdependentState(State):
class InterdependentState(BaseState):
"""A state with 3 vars and 3 computed vars.
x: a variable that no computed var depends on
@ -915,7 +913,7 @@ class InterdependentState(State):
@pytest.fixture
def interdependent_state() -> State:
def interdependent_state() -> BaseState:
"""A state with varying dependency between vars.
Returns:
@ -988,7 +986,7 @@ def test_per_state_backend_var(interdependent_state):
def test_child_state():
"""Test that the child state computed vars can reference parent state vars."""
class MainState(State):
class MainState(BaseState):
v: int = 2
class ChildState(MainState):
@ -1006,7 +1004,7 @@ def test_child_state():
def test_conditional_computed_vars():
"""Test that computed vars can have conditionals."""
class MainState(State):
class MainState(BaseState):
flag: bool = False
t1: str = "a"
t2: str = "b"
@ -1051,7 +1049,7 @@ def test_event_handlers_convert_to_fns(test_state, child_state):
def test_event_handlers_call_other_handlers():
"""Test that event handlers can call other event handlers."""
class MainState(State):
class MainState(BaseState):
v: int = 0
def set_v(self, v: int):
@ -1077,7 +1075,7 @@ def test_computed_var_cached():
"""Test that a ComputedVar doesn't recalculate when accessed."""
comp_v_calls = 0
class ComputedState(State):
class ComputedState(BaseState):
v: int = 0
@rx.cached_var
@ -1102,7 +1100,7 @@ def test_computed_var_cached():
def test_computed_var_cached_depends_on_non_cached():
"""Test that a cached_var is recalculated if it depends on non-cached ComputedVar."""
class ComputedState(State):
class ComputedState(BaseState):
v: int = 0
@rx.var
@ -1144,7 +1142,7 @@ def test_computed_var_depends_on_parent_non_cached():
"""Child state cached_var that depends on parent state un cached var is always recalculated."""
counter = 0
class ParentState(State):
class ParentState(BaseState):
@rx.var
def no_cache_v(self) -> int:
nonlocal counter
@ -1165,21 +1163,18 @@ def test_computed_var_depends_on_parent_non_cached():
dict1 = ps.dict()
assert dict1[ps.get_full_name()] == {
"no_cache_v": 1,
CompileVars.IS_HYDRATED: False,
"router": formatted_router,
}
assert dict1[cs.get_full_name()] == {"dep_v": 2}
dict2 = ps.dict()
assert dict2[ps.get_full_name()] == {
"no_cache_v": 3,
CompileVars.IS_HYDRATED: False,
"router": formatted_router,
}
assert dict2[cs.get_full_name()] == {"dep_v": 4}
dict3 = ps.dict()
assert dict3[ps.get_full_name()] == {
"no_cache_v": 5,
CompileVars.IS_HYDRATED: False,
"router": formatted_router,
}
assert dict3[cs.get_full_name()] == {"dep_v": 6}
@ -1195,7 +1190,7 @@ def test_cached_var_depends_on_event_handler(use_partial: bool):
"""
counter = 0
class HandlerState(State):
class HandlerState(BaseState):
x: int = 42
def handler(self):
@ -1226,7 +1221,7 @@ def test_cached_var_depends_on_event_handler(use_partial: bool):
def test_computed_var_dependencies():
"""Test that a ComputedVar correctly tracks its dependencies."""
class ComputedState(State):
class ComputedState(BaseState):
v: int = 0
w: int = 0
x: int = 0
@ -1293,7 +1288,7 @@ def test_computed_var_dependencies():
def test_backend_method():
"""A method with leading underscore should be callable from event handler."""
class BackendMethodState(State):
class BackendMethodState(BaseState):
def _be_method(self):
return True
@ -1369,7 +1364,7 @@ def test_error_on_state_method_shadow():
"""Test that an error is thrown when an event handler shadows a state method."""
with pytest.raises(NameError) as err:
class InvalidTest(rx.State):
class InvalidTest(BaseState):
def reset(self):
pass
@ -1382,7 +1377,7 @@ def test_error_on_state_method_shadow():
def test_state_with_invalid_yield():
"""Test that an error is thrown when a state yields an invalid value."""
class StateWithInvalidYield(rx.State):
class StateWithInvalidYield(BaseState):
"""A state that yields an invalid value."""
def invalid_handler(self):
@ -1666,7 +1661,7 @@ async def test_state_proxy(grandchild_state: GrandchildState, mock_app: rx.App):
assert mcall.kwargs["to"] == grandchild_state.get_sid()
class BackgroundTaskState(State):
class BackgroundTaskState(BaseState):
"""A state with a background task."""
order: List[str] = []
@ -2192,9 +2187,20 @@ def test_mutable_copy_vars(mutable_state, copy_func):
assert not isinstance(var_copy, MutableProxy)
def test_duplicate_substate_class(duplicate_substate):
def test_duplicate_substate_class(mocker):
mocker.patch("reflex.state.os.environ", {})
with pytest.raises(ValueError):
duplicate_substate()
class TestState(BaseState):
pass
class ChildTestState(TestState): # type: ignore # noqa
pass
class ChildTestState(TestState): # type: ignore # noqa
pass
return TestState
class Foo(Base):
@ -2206,7 +2212,7 @@ class Foo(Base):
def test_json_dumps_with_mutables():
"""Test that json.dumps works with Base vars inside mutable types."""
class MutableContainsBase(State):
class MutableContainsBase(BaseState):
items: List[Foo] = [Foo()]
dict_val = MutableContainsBase().dict()
@ -2216,7 +2222,7 @@ def test_json_dumps_with_mutables():
f_formatted_router = str(formatted_router).replace("'", '"')
assert (
val
== f'{{"{MutableContainsBase.get_full_name()}": {{"is_hydrated": false, "items": {f_items}, "router": {f_formatted_router}}}}}'
== f'{{"{MutableContainsBase.get_full_name()}": {{"items": {f_items}, "router": {f_formatted_router}}}}}'
)
@ -2225,7 +2231,7 @@ def test_reset_with_mutables():
default = [[0, 0], [0, 1], [1, 1]]
copied_default = copy.deepcopy(default)
class MutableResetState(State):
class MutableResetState(BaseState):
items: List[List[int]] = default
instance = MutableResetState()
@ -2273,7 +2279,7 @@ class Custom3(Base):
def test_state_union_optional():
"""Test that state can be defined with Union and Optional vars."""
class UnionState(State):
class UnionState(BaseState):
int_float: Union[int, float] = 0
opt_int: Optional[int]
c3: Optional[Custom3]

View File

@ -6,7 +6,7 @@ import pytest
from pandas import DataFrame
from reflex.base import Base
from reflex.state import State
from reflex.state import BaseState
from reflex.vars import (
BaseVar,
ComputedVar,
@ -24,12 +24,6 @@ test_vars = [
]
class BaseState(State):
"""A Test State."""
val: str = "key"
@pytest.fixture
def TestObj():
class TestObj(Base):
@ -41,7 +35,7 @@ def TestObj():
@pytest.fixture
def ParentState(TestObj):
class ParentState(State):
class ParentState(BaseState):
foo: int
bar: int
@ -74,7 +68,7 @@ def GrandChildState(ChildState, TestObj):
@pytest.fixture
def StateWithAnyVar(TestObj):
class StateWithAnyVar(State):
class StateWithAnyVar(BaseState):
@ComputedVar
def var_without_annotation(self) -> typing.Any:
return TestObj
@ -84,7 +78,7 @@ def StateWithAnyVar(TestObj):
@pytest.fixture
def StateWithCorrectVarAnnotation():
class StateWithCorrectVarAnnotation(State):
class StateWithCorrectVarAnnotation(BaseState):
@ComputedVar
def var_with_annotation(self) -> str:
return "Correct annotation"
@ -94,7 +88,7 @@ def StateWithCorrectVarAnnotation():
@pytest.fixture
def StateWithWrongVarAnnotation(TestObj):
class StateWithWrongVarAnnotation(State):
class StateWithWrongVarAnnotation(BaseState):
@ComputedVar
def var_with_annotation(self) -> str:
return TestObj

View File

@ -528,7 +528,6 @@ formatted_router = {
},
"dt": "1989-11-09 18:53:00+01:00",
"fig": [],
"is_hydrated": False,
"key": "",
"map_key": "a",
"mapping": {"a": [1, 2, 3], "b": [4, 5, 6]},
@ -553,7 +552,6 @@ formatted_router = {
DateTimeState.get_full_name(): {
"d": "1989-11-09",
"dt": "1989-11-09 18:53:00+01:00",
"is_hydrated": False,
"t": "18:53:00+01:00",
"td": "11 days, 0:11:00",
"router": formatted_router,

View File

@ -10,7 +10,7 @@ from packaging import version
from reflex import constants
from reflex.base import Base
from reflex.event import EventHandler
from reflex.state import State
from reflex.state import BaseState
from reflex.utils import (
build,
prerequisites,
@ -43,7 +43,7 @@ V056 = version.parse("0.5.6")
VMAXPLUS1 = version.parse(get_above_max_version())
class ExampleTestState(State):
class ExampleTestState(BaseState):
"""Test state class."""
def test_event_handler(self):