diff --git a/.github/workflows/benchmarks.yml b/.github/workflows/benchmarks.yml
index c8ab75adb..b849dd328 100644
--- a/.github/workflows/benchmarks.yml
+++ b/.github/workflows/benchmarks.yml
@@ -17,7 +17,7 @@ defaults:
env:
PYTHONIOENCODING: 'utf8'
TELEMETRY_ENABLED: false
- NODE_OPTIONS: '--max_old_space_size=4096'
+ NODE_OPTIONS: '--max_old_space_size=8192'
PR_TITLE: ${{ github.event.pull_request.title }}
jobs:
diff --git a/.github/workflows/integration_tests.yml b/.github/workflows/integration_tests.yml
index 8fa787aba..58f7a668a 100644
--- a/.github/workflows/integration_tests.yml
+++ b/.github/workflows/integration_tests.yml
@@ -29,7 +29,7 @@ env:
# - Best effort print lines that contain illegal chars (map to some default char, etc.)
PYTHONIOENCODING: 'utf8'
TELEMETRY_ENABLED: false
- NODE_OPTIONS: '--max_old_space_size=4096'
+ NODE_OPTIONS: '--max_old_space_size=8192'
PR_TITLE: ${{ github.event.pull_request.title }}
jobs:
diff --git a/integration/test_background_task.py b/integration/test_background_task.py
index 96a47e951..98b6e48ff 100644
--- a/integration/test_background_task.py
+++ b/integration/test_background_task.py
@@ -12,7 +12,10 @@ def BackgroundTask():
"""Test that background tasks work as expected."""
import asyncio
+ import pytest
+
import reflex as rx
+ from reflex.state import ImmutableStateError
class State(rx.State):
counter: int = 0
@@ -71,6 +74,38 @@ def BackgroundTask():
self.racy_task(), self.racy_task(), self.racy_task(), self.racy_task()
)
+ @rx.background
+ async def nested_async_with_self(self):
+ async with self:
+ self.counter += 1
+ with pytest.raises(ImmutableStateError):
+ async with self:
+ self.counter += 1
+
+ async def triple_count(self):
+ third_state = await self.get_state(ThirdState)
+ await third_state._triple_count()
+
+ class OtherState(rx.State):
+ @rx.background
+ async def get_other_state(self):
+ async with self:
+ state = await self.get_state(State)
+ state.counter += 1
+ await state.triple_count()
+ with pytest.raises(ImmutableStateError):
+ await state.triple_count()
+ with pytest.raises(ImmutableStateError):
+ state.counter += 1
+ async with state:
+ state.counter += 1
+ await state.triple_count()
+
+ class ThirdState(rx.State):
+ async def _triple_count(self):
+ state = await self.get_state(State)
+ state.counter *= 3
+
def index() -> rx.Component:
return rx.vstack(
rx.chakra.input(
@@ -109,6 +144,16 @@ def BackgroundTask():
on_click=State.handle_racy_event,
id="racy-increment",
),
+ rx.button(
+ "Nested Async with Self",
+ on_click=State.nested_async_with_self,
+ id="nested-async-with-self",
+ ),
+ rx.button(
+ "Increment from OtherState",
+ on_click=OtherState.get_other_state,
+ id="increment-from-other-state",
+ ),
rx.button("Reset", on_click=State.reset_counter, id="reset"),
)
@@ -230,3 +275,61 @@ def test_background_task(
assert background_task._poll_for(
lambda: not background_task.app_instance.background_tasks # type: ignore
)
+
+
+def test_nested_async_with_self(
+ background_task: AppHarness,
+ driver: WebDriver,
+ token: str,
+):
+ """Test that nested async with self in the same coroutine raises Exception.
+
+ Args:
+ background_task: harness for BackgroundTask app.
+ driver: WebDriver instance.
+ token: The token for the connected client.
+ """
+ assert background_task.app_instance is not None
+
+ # get a reference to all buttons
+ nested_async_with_self_button = driver.find_element(By.ID, "nested-async-with-self")
+ increment_button = driver.find_element(By.ID, "increment")
+
+ # get a reference to the counter
+ counter = driver.find_element(By.ID, "counter")
+ assert background_task._poll_for(lambda: counter.text == "0", timeout=5)
+
+ nested_async_with_self_button.click()
+ assert background_task._poll_for(lambda: counter.text == "1", timeout=5)
+
+ increment_button.click()
+ assert background_task._poll_for(lambda: counter.text == "2", timeout=5)
+
+
+def test_get_state(
+ background_task: AppHarness,
+ driver: WebDriver,
+ token: str,
+):
+ """Test that get_state returns a state bound to the correct StateProxy.
+
+ Args:
+ background_task: harness for BackgroundTask app.
+ driver: WebDriver instance.
+ token: The token for the connected client.
+ """
+ assert background_task.app_instance is not None
+
+ # get a reference to all buttons
+ other_state_button = driver.find_element(By.ID, "increment-from-other-state")
+ increment_button = driver.find_element(By.ID, "increment")
+
+ # get a reference to the counter
+ counter = driver.find_element(By.ID, "counter")
+ assert background_task._poll_for(lambda: counter.text == "0", timeout=5)
+
+ other_state_button.click()
+ assert background_task._poll_for(lambda: counter.text == "12", timeout=5)
+
+ increment_button.click()
+ assert background_task._poll_for(lambda: counter.text == "13", timeout=5)
diff --git a/pyproject.toml b/pyproject.toml
index 4900510cb..656ff091a 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -1,6 +1,6 @@
[tool.poetry]
name = "reflex"
-version = "0.5.7"
+version = "0.5.9"
description = "Web apps in pure Python."
license = "Apache-2.0"
authors = [
diff --git a/reflex/.templates/web/utils/client_side_routing.js b/reflex/.templates/web/utils/client_side_routing.js
index 75fb581c8..1718c8e61 100644
--- a/reflex/.templates/web/utils/client_side_routing.js
+++ b/reflex/.templates/web/utils/client_side_routing.js
@@ -23,7 +23,12 @@ export const useClientSideRouting = () => {
router.replace({
pathname: window.location.pathname,
query: window.location.search.slice(1),
- })
+ }).then(()=>{
+ // Check if the current route is /404
+ if (router.pathname === '/404') {
+ setRouteNotFound(true); // Mark as an actual 404
+ }
+ })
.catch((e) => {
setRouteNotFound(true) // navigation failed, so this is a real 404
})
diff --git a/reflex/.templates/web/utils/state.js b/reflex/.templates/web/utils/state.js
index f67ce6858..26b2d0d0c 100644
--- a/reflex/.templates/web/utils/state.js
+++ b/reflex/.templates/web/utils/state.js
@@ -647,7 +647,12 @@ export const useEventLoop = (
const [connectErrors, setConnectErrors] = useState([]);
// Function to add new events to the event queue.
- const addEvents = (events, _e, event_actions) => {
+ const addEvents = (events, args, event_actions) => {
+ if (!(args instanceof Array)) {
+ args = [args];
+ }
+ const _e = args.filter((o) => o?.preventDefault !== undefined)[0]
+
if (event_actions?.preventDefault && _e?.preventDefault) {
_e.preventDefault();
}
@@ -777,7 +782,7 @@ export const useEventLoop = (
// Route after the initial page hydration.
useEffect(() => {
const change_start = () => {
- const main_state_dispatch = dispatch["state"]
+ const main_state_dispatch = dispatch["reflex___state____state"]
if (main_state_dispatch !== undefined) {
main_state_dispatch({ is_hydrated: false })
}
diff --git a/reflex/app.py b/reflex/app.py
index 001c95885..92ad3d149 100644
--- a/reflex/app.py
+++ b/reflex/app.py
@@ -112,11 +112,29 @@ def default_backend_exception_handler(exception: Exception) -> EventSpec:
EventSpec: The window alert event.
"""
+ from reflex.components.sonner.toast import Toaster, toast
+
error = traceback.format_exc()
console.error(f"[Reflex Backend Exception]\n {error}\n")
- return window_alert("An error occurred. See logs for details.")
+ error_message = (
+ ["Contact the website administrator."]
+ if is_prod_mode()
+ else [f"{type(exception).__name__}: {exception}.", "See logs for details."]
+ )
+ if Toaster.is_used:
+ return toast(
+ "An error occurred.",
+ level="error",
+ description="
".join(error_message),
+ position="top-center",
+ id="backend_error",
+ style={"width": "500px"},
+ ) # type: ignore
+ else:
+ error_message.insert(0, "An error occurred.")
+ return window_alert("\n".join(error_message))
def default_overlay_component() -> Component:
@@ -183,7 +201,7 @@ class App(MiddlewareMixin, LifespanMixin, Base):
# A component that is present on every page (defaults to the Connection Error banner).
overlay_component: Optional[Union[Component, ComponentCallable]] = (
- default_overlay_component
+ default_overlay_component()
)
# Error boundary component to wrap the app with.
diff --git a/reflex/components/component.py b/reflex/components/component.py
index 1786dd060..bb3e9053f 100644
--- a/reflex/components/component.py
+++ b/reflex/components/component.py
@@ -24,6 +24,7 @@ from typing import (
import reflex.state
from reflex.base import Base
from reflex.compiler.templates import STATEFUL_COMPONENT
+from reflex.components.core.breakpoints import Breakpoints
from reflex.components.tags import Tag
from reflex.constants import (
Dirs,
@@ -466,6 +467,12 @@ class Component(BaseComponent, ABC):
# Merge styles, the later ones overriding keys in the earlier ones.
style = {k: v for style_dict in style for k, v in style_dict.items()}
+ if isinstance(style, Breakpoints):
+ style = {
+ # Assign the Breakpoints to the self-referential selector to avoid squashing down to a regular dict.
+ "&": style,
+ }
+
kwargs["style"] = Style(
{
**self.get_fields()["style"].default,
diff --git a/reflex/components/core/banner.py b/reflex/components/core/banner.py
index b634ab75a..c6b46696c 100644
--- a/reflex/components/core/banner.py
+++ b/reflex/components/core/banner.py
@@ -153,6 +153,20 @@ useEffect(() => {{
hook,
]
+ @classmethod
+ def create(cls, *children, **props) -> Component:
+ """Create a connection toaster component.
+
+ Args:
+ *children: The children of the component.
+ **props: The properties of the component.
+
+ Returns:
+ The connection toaster component.
+ """
+ Toaster.is_used = True
+ return super().create(*children, **props)
+
class ConnectionBanner(Component):
"""A connection banner component."""
diff --git a/reflex/components/core/banner.pyi b/reflex/components/core/banner.pyi
index ddaafb153..b9b6d506f 100644
--- a/reflex/components/core/banner.pyi
+++ b/reflex/components/core/banner.pyi
@@ -187,7 +187,7 @@ class ConnectionToaster(Toaster):
] = None,
**props,
) -> "ConnectionToaster":
- """Create the component.
+ """Create a connection toaster component.
Args:
*children: The children of the component.
@@ -211,10 +211,10 @@ class ConnectionToaster(Toaster):
class_name: The class name for the component.
autofocus: Whether the component should take the focus once the page is loaded
custom_attrs: custom attribute
- **props: The props of the component.
+ **props: The properties of the component.
Returns:
- The component.
+ The connection toaster component.
"""
...
diff --git a/reflex/components/el/__init__.py b/reflex/components/el/__init__.py
index 750e65dba..9fe1d89cd 100644
--- a/reflex/components/el/__init__.py
+++ b/reflex/components/el/__init__.py
@@ -10,6 +10,7 @@ _SUBMODULES: set[str] = {"elements"}
_SUBMOD_ATTRS: dict[str, list[str]] = {
f"elements.{k}": v for k, v in elements._MAPPING.items()
}
+_PYRIGHT_IGNORE_IMPORTS = elements._PYRIGHT_IGNORE_IMPORTS
__getattr__, __dir__, __all__ = lazy_loader.attach(
__name__,
diff --git a/reflex/components/el/__init__.pyi b/reflex/components/el/__init__.pyi
index 080c9c37c..d312db84e 100644
--- a/reflex/components/el/__init__.pyi
+++ b/reflex/components/el/__init__.pyi
@@ -3,6 +3,7 @@
# This file was generated by `reflex/utils/pyi_generator.py`!
# ------------------------------------------------------
+from . import elements
from .elements.forms import Button as Button
from .elements.forms import Fieldset as Fieldset
from .elements.forms import Form as Form
@@ -91,7 +92,7 @@ from .elements.media import Defs as Defs
from .elements.media import Embed as Embed
from .elements.media import Iframe as Iframe
from .elements.media import Img as Img
-from .elements.media import Lineargradient as Lineargradient
+from .elements.media import LinearGradient as LinearGradient
from .elements.media import Map as Map
from .elements.media import Object as Object
from .elements.media import Path as Path
@@ -104,19 +105,19 @@ from .elements.media import Track as Track
from .elements.media import Video as Video
from .elements.media import area as area
from .elements.media import audio as audio
-from .elements.media import defs as defs
+from .elements.media import defs as defs # type: ignore
from .elements.media import embed as embed
from .elements.media import iframe as iframe
from .elements.media import image as image
from .elements.media import img as img
-from .elements.media import lineargradient as lineargradient
+from .elements.media import lineargradient as lineargradient # type: ignore
from .elements.media import map as map
from .elements.media import object as object
-from .elements.media import path as path
+from .elements.media import path as path # type: ignore
from .elements.media import picture as picture
from .elements.media import portal as portal
from .elements.media import source as source
-from .elements.media import stop as stop
+from .elements.media import stop as stop # type: ignore
from .elements.media import svg as svg
from .elements.media import track as track
from .elements.media import video as video
@@ -230,3 +231,5 @@ from .elements.typography import ol as ol
from .elements.typography import p as p
from .elements.typography import pre as pre
from .elements.typography import ul as ul
+
+_PYRIGHT_IGNORE_IMPORTS = elements._PYRIGHT_IGNORE_IMPORTS
diff --git a/reflex/components/el/elements/__init__.py b/reflex/components/el/elements/__init__.py
index 1c2684f1d..024ae8c3d 100644
--- a/reflex/components/el/elements/__init__.py
+++ b/reflex/components/el/elements/__init__.py
@@ -67,6 +67,7 @@ _MAPPING = {
"svg",
"defs",
"lineargradient",
+ "LinearGradient",
"stop",
"path",
],
@@ -129,12 +130,13 @@ _MAPPING = {
}
-EXCLUDE = ["del_", "Del", "image"]
+EXCLUDE = ["del_", "Del", "image", "lineargradient", "LinearGradient"]
for _, v in _MAPPING.items():
v.extend([mod.capitalize() for mod in v if mod not in EXCLUDE])
_SUBMOD_ATTRS: dict[str, list[str]] = _MAPPING
+_PYRIGHT_IGNORE_IMPORTS = ["stop", "lineargradient", "path", "defs"]
__getattr__, __dir__, __all__ = lazy_loader.attach(
__name__,
submod_attrs=_SUBMOD_ATTRS,
diff --git a/reflex/components/el/elements/__init__.pyi b/reflex/components/el/elements/__init__.pyi
index 8c35756e5..4f218f361 100644
--- a/reflex/components/el/elements/__init__.pyi
+++ b/reflex/components/el/elements/__init__.pyi
@@ -91,7 +91,7 @@ from .media import Defs as Defs
from .media import Embed as Embed
from .media import Iframe as Iframe
from .media import Img as Img
-from .media import Lineargradient as Lineargradient
+from .media import LinearGradient as LinearGradient
from .media import Map as Map
from .media import Object as Object
from .media import Path as Path
@@ -104,19 +104,19 @@ from .media import Track as Track
from .media import Video as Video
from .media import area as area
from .media import audio as audio
-from .media import defs as defs
+from .media import defs as defs # type: ignore
from .media import embed as embed
from .media import iframe as iframe
from .media import image as image
from .media import img as img
-from .media import lineargradient as lineargradient
+from .media import lineargradient as lineargradient # type: ignore
from .media import map as map
from .media import object as object
-from .media import path as path
+from .media import path as path # type: ignore
from .media import picture as picture
from .media import portal as portal
from .media import source as source
-from .media import stop as stop
+from .media import stop as stop # type: ignore
from .media import svg as svg
from .media import track as track
from .media import video as video
@@ -294,6 +294,7 @@ _MAPPING = {
"svg",
"defs",
"lineargradient",
+ "LinearGradient",
"stop",
"path",
],
@@ -347,6 +348,7 @@ _MAPPING = {
"Del",
],
}
-EXCLUDE = ["del_", "Del", "image"]
+EXCLUDE = ["del_", "Del", "image", "lineargradient", "LinearGradient"]
for _, v in _MAPPING.items():
v.extend([mod.capitalize() for mod in v if mod not in EXCLUDE])
+_PYRIGHT_IGNORE_IMPORTS = ["stop", "lineargradient", "path", "defs"]
diff --git a/reflex/components/el/elements/media.py b/reflex/components/el/elements/media.py
index 8d56c78b4..b2bdc9e6f 100644
--- a/reflex/components/el/elements/media.py
+++ b/reflex/components/el/elements/media.py
@@ -2,8 +2,9 @@
from typing import Any, Union
-from reflex import Component
+from reflex import Component, ComponentNamespace
from reflex.constants.colors import Color
+from reflex.utils import console
from reflex.vars import Var as Var
from .base import BaseHTML
@@ -309,6 +310,56 @@ class Svg(BaseHTML):
"""Display the svg element."""
tag = "svg"
+ # The width of the svg.
+ width: Var[Union[str, int]]
+ # The height of the svg.
+ height: Var[Union[str, int]]
+ # The XML namespace declaration.
+ xmlns: Var[str]
+
+
+class Circle(BaseHTML):
+ """The SVG circle component."""
+
+ tag = "circle"
+ # The x-axis coordinate of the center of the circle.
+ cx: Var[Union[str, int]]
+ # The y-axis coordinate of the center of the circle.
+ cy: Var[Union[str, int]]
+ # The radius of the circle.
+ r: Var[Union[str, int]]
+ # The total length for the circle's circumference, in user units.
+ path_length: Var[int]
+
+
+class Rect(BaseHTML):
+ """The SVG rect component."""
+
+ tag = "rect"
+ # The x coordinate of the rect.
+ x: Var[Union[str, int]]
+ # The y coordinate of the rect.
+ y: Var[Union[str, int]]
+ # The width of the rect
+ width: Var[Union[str, int]]
+ # The height of the rect.
+ height: Var[Union[str, int]]
+ # The horizontal corner radius of the rect. Defaults to ry if it is specified.
+ rx: Var[Union[str, int]]
+ # The vertical corner radius of the rect. Defaults to rx if it is specified.
+ ry: Var[Union[str, int]]
+ # The total length of the rectangle's perimeter, in user units.
+ path_length: Var[int]
+
+
+class Polygon(BaseHTML):
+ """The SVG polygon component."""
+
+ tag = "polygon"
+ # defines the list of points (pairs of x,y absolute coordinates) required to draw the polygon.
+ points: Var[str]
+ # This prop lets specify the total length for the path, in user units.
+ path_length: Var[int]
class Defs(BaseHTML):
@@ -317,30 +368,30 @@ class Defs(BaseHTML):
tag = "defs"
-class Lineargradient(BaseHTML):
+class LinearGradient(BaseHTML):
"""Display the linearGradient element."""
tag = "linearGradient"
- # Units for the gradient
+ # Units for the gradient.
gradient_units: Var[Union[str, bool]]
- # Transform applied to the gradient
+ # Transform applied to the gradient.
gradient_transform: Var[Union[str, bool]]
- # Method used to spread the gradient
+ # Method used to spread the gradient.
spread_method: Var[Union[str, bool]]
- # X coordinate of the starting point of the gradient
+ # X coordinate of the starting point of the gradient.
x1: Var[Union[str, int, bool]]
- # X coordinate of the ending point of the gradient
+ # X coordinate of the ending point of the gradient.
x2: Var[Union[str, int, bool]]
- # Y coordinate of the starting point of the gradient
+ # Y coordinate of the starting point of the gradient.
y1: Var[Union[str, int, bool]]
- # Y coordinate of the ending point of the gradient
+ # Y coordinate of the ending point of the gradient.
y2: Var[Union[str, int, bool]]
@@ -349,13 +400,13 @@ class Stop(BaseHTML):
tag = "stop"
- # Offset of the gradient stop
+ # Offset of the gradient stop.
offset: Var[Union[str, float, int]]
- # Color of the gradient stop
+ # Color of the gradient stop.
stop_color: Var[Union[str, Color, bool]]
- # Opacity of the gradient stop
+ # Opacity of the gradient stop.
stop_opacity: Var[Union[str, float, int, bool]]
@@ -364,10 +415,23 @@ class Path(BaseHTML):
tag = "path"
- # Defines the shape of the path
+ # Defines the shape of the path.
d: Var[Union[str, int, bool]]
+class SVG(ComponentNamespace):
+ """SVG component namespace."""
+
+ circle = staticmethod(Circle.create)
+ rect = staticmethod(Rect.create)
+ polygon = staticmethod(Polygon.create)
+ path = staticmethod(Path.create)
+ stop = staticmethod(Stop.create)
+ linear_gradient = staticmethod(LinearGradient.create)
+ defs = staticmethod(Defs.create)
+ __call__ = staticmethod(Svg.create)
+
+
area = Area.create
audio = Audio.create
image = img = Img.create
@@ -380,8 +444,24 @@ object = Object.create
picture = Picture.create
portal = Portal.create
source = Source.create
-svg = Svg.create
-defs = Defs.create
-lineargradient = Lineargradient.create
-stop = Stop.create
-path = Path.create
+svg = SVG()
+
+
+def __getattr__(name: str):
+ if name in ("defs", "lineargradient", "stop", "path"):
+ console.deprecate(
+ f"`rx.el.{name}`",
+ reason=f"use `rx.el.svg.{'linear_gradient' if name =='lineargradient' else name}`",
+ deprecation_version="0.5.8",
+ removal_version="0.6.0",
+ )
+ return (
+ LinearGradient.create
+ if name == "lineargradient"
+ else globals()[name.capitalize()].create
+ )
+
+ try:
+ return globals()[name]
+ except KeyError:
+ raise AttributeError(f"module '{__name__} has no attribute '{name}'") from None
diff --git a/reflex/components/el/elements/media.pyi b/reflex/components/el/elements/media.pyi
index 7a9a064fb..ba5f14137 100644
--- a/reflex/components/el/elements/media.pyi
+++ b/reflex/components/el/elements/media.pyi
@@ -5,6 +5,7 @@
# ------------------------------------------------------
from typing import Any, Callable, Dict, Optional, Union, overload
+from reflex import ComponentNamespace
from reflex.constants.colors import Color
from reflex.event import EventHandler, EventSpec
from reflex.style import Style
@@ -1563,6 +1564,9 @@ class Svg(BaseHTML):
def create( # type: ignore
cls,
*children,
+ width: Optional[Union[Var[Union[int, str]], str, int]] = None,
+ height: Optional[Union[Var[Union[int, str]], str, int]] = None,
+ xmlns: Optional[Union[Var[str], str]] = None,
access_key: Optional[Union[Var[Union[bool, int, str]], str, int, bool]] = None,
auto_capitalize: Optional[
Union[Var[Union[bool, int, str]], str, int, bool]
@@ -1644,6 +1648,383 @@ class Svg(BaseHTML):
Args:
*children: The children of the component.
+ width: The width of the svg.
+ height: The height of the svg.
+ xmlns: The XML namespace declaration.
+ access_key: Provides a hint for generating a keyboard shortcut for the current element.
+ auto_capitalize: Controls whether and how text input is automatically capitalized as it is entered/edited by the user.
+ content_editable: Indicates whether the element's content is editable.
+ context_menu: Defines the ID of a