diff --git a/.github/workflows/performance.yml b/.github/workflows/performance.yml new file mode 100644 index 000000000..c7bd1003a --- /dev/null +++ b/.github/workflows/performance.yml @@ -0,0 +1,34 @@ +name: performance-tests + +on: + push: + branches: + - "main" # or "master" + paths-ignore: + - "**/*.md" + pull_request: + workflow_dispatch: + +env: + TELEMETRY_ENABLED: false + NODE_OPTIONS: "--max_old_space_size=8192" + PR_TITLE: ${{ github.event.pull_request.title }} + APP_HARNESS_HEADLESS: 1 + PYTHONUNBUFFERED: 1 + +jobs: + benchmarks: + name: Run benchmarks + runs-on: ubuntu-latest + steps: + - uses: actions/checkout@v4 + - uses: ./.github/actions/setup_build_env + with: + python-version: 3.12.8 + run-poetry-install: true + create-venv-at-path: .venv + - name: Run benchmarks + uses: CodSpeedHQ/action@v3 + with: + token: ${{ secrets.CODSPEED_TOKEN }} + run: poetry run pytest benchmarks/test_evaluate.py --codspeed diff --git a/.gitignore b/.gitignore index 8bd92964c..29a868796 100644 --- a/.gitignore +++ b/.gitignore @@ -15,3 +15,4 @@ requirements.txt .pyi_generator_last_run .pyi_generator_diff reflex.db +.codspeed \ No newline at end of file diff --git a/benchmarks/test_evaluate.py b/benchmarks/test_evaluate.py new file mode 100644 index 000000000..aa4c8237e --- /dev/null +++ b/benchmarks/test_evaluate.py @@ -0,0 +1,231 @@ +from dataclasses import dataclass +from typing import cast + +import pytest + +import reflex as rx + + +class SideBarState(rx.State): + """State for the side bar.""" + + current_page: rx.Field[str] = rx.field("/") + + +@dataclass(frozen=True) +class SideBarPage: + """A page in the side bar.""" + + title: str + href: str + + +@dataclass(frozen=True) +class SideBarSection: + """A section in the side bar.""" + + name: str + icon: str + pages: tuple[SideBarPage, ...] + + +@dataclass(frozen=True) +class Category: + """A category in the side bar.""" + + name: str + href: str + sections: tuple[SideBarSection, ...] + + +SIDE_BAR = ( + Category( + name="General", + href="/", + sections=( + SideBarSection( + name="Home", + icon="home", + pages=( + SideBarPage(title="Home", href="/"), + SideBarPage(title="Contact", href="/contact"), + ), + ), + SideBarSection( + name="About", + icon="info", + pages=( + SideBarPage(title="About", href="/about"), + SideBarPage(title="FAQ", href="/faq"), + ), + ), + ), + ), + Category( + name="Projects", + href="/projects", + sections=( + SideBarSection( + name="Python", + icon="worm", + pages=( + SideBarPage(title="Python", href="/projects/python"), + SideBarPage(title="Django", href="/projects/django"), + SideBarPage(title="Flask", href="/projects/flask"), + SideBarPage(title="FastAPI", href="/projects/fastapi"), + SideBarPage(title="Pyramid", href="/projects/pyramid"), + SideBarPage(title="Tornado", href="/projects/tornado"), + SideBarPage(title="TurboGears", href="/projects/turbogears"), + SideBarPage(title="Web2py", href="/projects/web2py"), + SideBarPage(title="Zope", href="/projects/zope"), + SideBarPage(title="Plone", href="/projects/plone"), + SideBarPage(title="Quixote", href="/projects/quixote"), + SideBarPage(title="Bottle", href="/projects/bottle"), + SideBarPage(title="CherryPy", href="/projects/cherrypy"), + SideBarPage(title="Falcon", href="/projects/falcon"), + SideBarPage(title="Sanic", href="/projects/sanic"), + SideBarPage(title="Starlette", href="/projects/starlette"), + ), + ), + SideBarSection( + name="JavaScript", + icon="banana", + pages=( + SideBarPage(title="JavaScript", href="/projects/javascript"), + SideBarPage(title="Angular", href="/projects/angular"), + SideBarPage(title="React", href="/projects/react"), + SideBarPage(title="Vue", href="/projects/vue"), + SideBarPage(title="Ember", href="/projects/ember"), + SideBarPage(title="Backbone", href="/projects/backbone"), + SideBarPage(title="Meteor", href="/projects/meteor"), + SideBarPage(title="Svelte", href="/projects/svelte"), + SideBarPage(title="Preact", href="/projects/preact"), + SideBarPage(title="Mithril", href="/projects/mithril"), + SideBarPage(title="Aurelia", href="/projects/aurelia"), + SideBarPage(title="Polymer", href="/projects/polymer"), + SideBarPage(title="Knockout", href="/projects/knockout"), + SideBarPage(title="Dojo", href="/projects/dojo"), + SideBarPage(title="Riot", href="/projects/riot"), + SideBarPage(title="Alpine", href="/projects/alpine"), + SideBarPage(title="Stimulus", href="/projects/stimulus"), + SideBarPage(title="Marko", href="/projects/marko"), + SideBarPage(title="Sapper", href="/projects/sapper"), + SideBarPage(title="Nuxt", href="/projects/nuxt"), + SideBarPage(title="Next", href="/projects/next"), + SideBarPage(title="Gatsby", href="/projects/gatsby"), + SideBarPage(title="Gridsome", href="/projects/gridsome"), + SideBarPage(title="Nest", href="/projects/nest"), + SideBarPage(title="Express", href="/projects/express"), + SideBarPage(title="Koa", href="/projects/koa"), + SideBarPage(title="Hapi", href="/projects/hapi"), + SideBarPage(title="LoopBack", href="/projects/loopback"), + SideBarPage(title="Feathers", href="/projects/feathers"), + SideBarPage(title="Sails", href="/projects/sails"), + SideBarPage(title="Adonis", href="/projects/adonis"), + SideBarPage(title="Meteor", href="/projects/meteor"), + SideBarPage(title="Derby", href="/projects/derby"), + SideBarPage(title="Socket.IO", href="/projects/socketio"), + ), + ), + ), + ), +) + + +def side_bar_page(page: SideBarPage): + return rx.box( + rx.link( + page.title, + href=page.href, + ) + ) + + +def side_bar_section(section: SideBarSection): + return rx.accordion.item( + rx.accordion.header( + rx.accordion.trigger( + rx.hstack( + rx.hstack( + rx.icon(section.icon), + section.name, + align="center", + ), + rx.accordion.icon(), + width="100%", + justify="between", + ) + ) + ), + rx.accordion.content( + rx.vstack( + *map(side_bar_page, section.pages), + ), + border_inline_start="1px solid", + padding_inline_start="1em", + margin_inline_start="1.5em", + ), + value=section.name, + width="100%", + variant="ghost", + ) + + +def side_bar_category(category: Category): + selected_section = cast( + rx.Var, + rx.match( + SideBarState.current_page, + *[ + ( + section.name, + section.name, + ) + for section in category.sections + ], + None, + ), + ) + return rx.vstack( + rx.heading( + rx.link( + category.name, + href=category.href, + ), + size="5", + ), + rx.accordion.root( + *map(side_bar_section, category.sections), + default_value=selected_section.to(str), + variant="ghost", + width="100%", + collapsible=True, + type="multiple", + ), + width="100%", + ) + + +def side_bar(): + return rx.vstack( + *map(side_bar_category, SIDE_BAR), + width="fit-content", + ) + + +LOREM_IPSUM = "Lorem ipsum dolor sit amet, dolor ut dolore pariatur aliqua enim tempor sed. Labore excepteur sed exercitation. Ullamco aliquip lorem sunt enim in incididunt. Magna anim officia sint cillum labore. Ut eu non dolore minim nostrud magna eu, aute ex in incididunt irure eu. Fugiat et magna magna est excepteur eiusmod minim. Quis eiusmod et non pariatur dolor veniam incididunt, eiusmod irure enim sed dolor lorem pariatur do. Occaecat duis irure excepteur dolore. Proident ut laborum pariatur sit sit, nisi nostrud voluptate magna commodo laborum esse velit. Voluptate non minim deserunt adipiscing irure deserunt cupidatat. Laboris veniam commodo incididunt veniam lorem occaecat, fugiat ipsum dolor cupidatat. Ea officia sed eu excepteur culpa adipiscing, tempor consectetur ullamco eu. Anim ex proident nulla sunt culpa, voluptate veniam proident est adipiscing sint elit velit. Laboris adipiscing est culpa cillum magna. Sit veniam nulla nulla, aliqua eiusmod commodo lorem cupidatat commodo occaecat. Fugiat cillum dolor incididunt mollit eiusmod sint. Non lorem dolore labore excepteur minim laborum sed. Irure nisi do lorem nulla sunt commodo, deserunt quis mollit consectetur minim et esse est, proident nostrud officia enim sed reprehenderit. Magna cillum consequat aute reprehenderit duis sunt ullamco. Labore qui mollit voluptate. Duis dolor sint aute amet aliquip officia, est non mollit tempor enim quis fugiat, eu do culpa consectetur magna. Do ullamco aliqua voluptate culpa excepteur reprehenderit reprehenderit. Occaecat nulla sit est magna. Deserunt ea voluptate veniam cillum. Amet cupidatat duis est tempor fugiat ex eu, officia est sunt consectetur labore esse exercitation. Nisi cupidatat irure est nisi. Officia amet eu veniam reprehenderit. In amet incididunt tempor commodo ea labore. Mollit dolor aliquip excepteur, voluptate aute occaecat id officia proident. Ullamco est amet tempor. Proident aliquip proident mollit do aliquip ipsum, culpa quis aute id irure. Velit excepteur cillum cillum ut cupidatat. Occaecat qui elit esse nulla minim. Consequat velit id ad pariatur tempor. Eiusmod deserunt aliqua ex sed quis non. Dolor sint commodo ex in deserunt nostrud excepteur, pariatur ex aliqua anim adipiscing amet proident. Laboris eu laborum magna lorem ipsum fugiat velit." + + +def complicated_page(): + return rx.hstack( + side_bar(), + rx.box( + rx.heading("Complicated Page", size="1"), + rx.text(LOREM_IPSUM), + ), + ) + + +@pytest.mark.benchmark +def test_component_init(): + complicated_page() diff --git a/poetry.lock b/poetry.lock index f8d4cf949..125b71b55 100644 --- a/poetry.lock +++ b/poetry.lock @@ -251,7 +251,7 @@ files = [ {file = "cffi-1.17.1-cp39-cp39-win_amd64.whl", hash = "sha256:d016c76bdd850f3c626af19b0542c9677ba156e4ee4fccfdd7848803533ef662"}, {file = "cffi-1.17.1.tar.gz", hash = "sha256:1c39c6016c32bc48dd54561950ebd6836e1670f2ae46128f67cf49e789c52824"}, ] -markers = {main = "(platform_machine != \"ppc64le\" and platform_machine != \"s390x\") and sys_platform == \"linux\" and platform_python_implementation != \"PyPy\" and (python_version <= \"3.11\" or python_version >= \"3.12\")", dev = "os_name == \"nt\" and implementation_name != \"pypy\" and (python_version <= \"3.11\" or python_version >= \"3.12\")"} +markers = {main = "(platform_machine != \"ppc64le\" and platform_machine != \"s390x\") and sys_platform == \"linux\" and platform_python_implementation != \"PyPy\" and (python_version <= \"3.11\" or python_version >= \"3.12\")", dev = "python_version <= \"3.11\" or python_version >= \"3.12\""} [package.dependencies] pycparser = "*" @@ -495,7 +495,6 @@ files = [ {file = "cryptography-44.0.0-cp37-abi3-manylinux_2_28_aarch64.whl", hash = "sha256:761817a3377ef15ac23cd7834715081791d4ec77f9297ee694ca1ee9c2c7e5eb"}, {file = "cryptography-44.0.0-cp37-abi3-manylinux_2_28_x86_64.whl", hash = "sha256:3c672a53c0fb4725a29c303be906d3c1fa99c32f58abe008a82705f9ee96f40b"}, {file = "cryptography-44.0.0-cp37-abi3-manylinux_2_34_aarch64.whl", hash = "sha256:4ac4c9f37eba52cb6fbeaf5b59c152ea976726b865bd4cf87883a7e7006cc543"}, - {file = "cryptography-44.0.0-cp37-abi3-manylinux_2_34_x86_64.whl", hash = "sha256:60eb32934076fa07e4316b7b2742fa52cbb190b42c2df2863dbc4230a0a9b385"}, {file = "cryptography-44.0.0-cp37-abi3-musllinux_1_2_aarch64.whl", hash = "sha256:ed3534eb1090483c96178fcb0f8893719d96d5274dfde98aa6add34614e97c8e"}, {file = "cryptography-44.0.0-cp37-abi3-musllinux_1_2_x86_64.whl", hash = "sha256:f3f6fdfa89ee2d9d496e2c087cebef9d4fcbb0ad63c40e821b39f74bf48d9c5e"}, {file = "cryptography-44.0.0-cp37-abi3-win32.whl", hash = "sha256:eb33480f1bad5b78233b0ad3e1b0be21e8ef1da745d8d2aecbb20671658b9053"}, @@ -506,7 +505,6 @@ files = [ {file = "cryptography-44.0.0-cp39-abi3-manylinux_2_28_aarch64.whl", hash = "sha256:c5eb858beed7835e5ad1faba59e865109f3e52b3783b9ac21e7e47dc5554e289"}, {file = "cryptography-44.0.0-cp39-abi3-manylinux_2_28_x86_64.whl", hash = "sha256:f53c2c87e0fb4b0c00fa9571082a057e37690a8f12233306161c8f4b819960b7"}, {file = "cryptography-44.0.0-cp39-abi3-manylinux_2_34_aarch64.whl", hash = "sha256:9e6fc8a08e116fb7c7dd1f040074c9d7b51d74a8ea40d4df2fc7aa08b76b9e6c"}, - {file = "cryptography-44.0.0-cp39-abi3-manylinux_2_34_x86_64.whl", hash = "sha256:9abcc2e083cbe8dde89124a47e5e53ec38751f0d7dfd36801008f316a127d7ba"}, {file = "cryptography-44.0.0-cp39-abi3-musllinux_1_2_aarch64.whl", hash = "sha256:d2436114e46b36d00f8b72ff57e598978b37399d2786fd39793c36c6d5cb1c64"}, {file = "cryptography-44.0.0-cp39-abi3-musllinux_1_2_x86_64.whl", hash = "sha256:a01956ddfa0a6790d594f5b34fc1bfa6098aca434696a03cfdbe469b8ed79285"}, {file = "cryptography-44.0.0-cp39-abi3-win32.whl", hash = "sha256:eca27345e1214d1b9f9490d200f9db5a874479be914199194e746c893788d417"}, @@ -1101,7 +1099,7 @@ version = "3.0.0" description = "Python port of markdown-it. Markdown parsing, done right!" optional = false python-versions = ">=3.8" -groups = ["main"] +groups = ["main", "dev"] markers = "python_version <= \"3.11\" or python_version >= \"3.12\"" files = [ {file = "markdown-it-py-3.0.0.tar.gz", hash = "sha256:e3f60a94fa066dc52ec76661e37c851cb232d92f9886b15cb560aaada2df8feb"}, @@ -1199,7 +1197,7 @@ version = "0.1.2" description = "Markdown URL utilities" optional = false python-versions = ">=3.7" -groups = ["main"] +groups = ["main", "dev"] markers = "python_version <= \"3.11\" or python_version >= \"3.12\"" files = [ {file = "mdurl-0.1.2-py3-none-any.whl", hash = "sha256:84008a41e51615a49fc9966191ff91509e3c40b939176e643fd50a5c2196b8f8"}, @@ -1690,7 +1688,7 @@ files = [ {file = "pycparser-2.22-py3-none-any.whl", hash = "sha256:c3702b6d3dd8c7abc1afa565d7e63d53a1d0bd86cdc24edd75470f4de499cfcc"}, {file = "pycparser-2.22.tar.gz", hash = "sha256:491c8be9c040f5390f5bf44a5b07752bd07f56edf992381b05c701439eec10f6"}, ] -markers = {main = "(platform_machine != \"ppc64le\" and platform_machine != \"s390x\") and sys_platform == \"linux\" and platform_python_implementation != \"PyPy\" and (python_version <= \"3.11\" or python_version >= \"3.12\")", dev = "os_name == \"nt\" and implementation_name != \"pypy\" and (python_version <= \"3.11\" or python_version >= \"3.12\")"} +markers = {main = "(platform_machine != \"ppc64le\" and platform_machine != \"s390x\") and sys_platform == \"linux\" and platform_python_implementation != \"PyPy\" and (python_version <= \"3.11\" or python_version >= \"3.12\")", dev = "python_version <= \"3.11\" or python_version >= \"3.12\""} [[package]] name = "pydantic" @@ -1853,7 +1851,7 @@ version = "2.19.1" description = "Pygments is a syntax highlighting package written in Python." optional = false python-versions = ">=3.8" -groups = ["main"] +groups = ["main", "dev"] markers = "python_version <= \"3.11\" or python_version >= \"3.12\"" files = [ {file = "pygments-2.19.1-py3-none-any.whl", hash = "sha256:9ea1544ad55cecf4b8242fab6dd35a93bbce657034b0611ee383099054ab6d8c"}, @@ -1998,6 +1996,39 @@ aspect = ["aspectlib"] elasticsearch = ["elasticsearch"] histogram = ["pygal", "pygaljs", "setuptools"] +[[package]] +name = "pytest-codspeed" +version = "3.1.2" +description = "Pytest plugin to create CodSpeed benchmarks" +optional = false +python-versions = ">=3.9" +groups = ["dev"] +markers = "python_version <= \"3.11\" or python_version >= \"3.12\"" +files = [ + {file = "pytest_codspeed-3.1.2-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:aed496f873670ce0ea8f980a7c1a2c6a08f415e0ebdf207bf651b2d922103374"}, + {file = "pytest_codspeed-3.1.2-cp310-cp310-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:ee45b0b763f6b5fa5d74c7b91d694a9615561c428b320383660672f4471756e3"}, + {file = "pytest_codspeed-3.1.2-cp311-cp311-manylinux_2_17_aarch64.manylinux2014_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:c84e591a7a0f67d45e2dc9fd05b276971a3aabcab7478fe43363ebefec1358f4"}, + {file = "pytest_codspeed-3.1.2-cp311-cp311-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:c6ae6d094247156407770e6b517af70b98862dd59a3c31034aede11d5f71c32c"}, + {file = "pytest_codspeed-3.1.2-cp312-cp312-manylinux_2_17_aarch64.manylinux2014_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:d0f264991de5b5cdc118b96fc671386cca3f0f34e411482939bf2459dc599097"}, + {file = "pytest_codspeed-3.1.2-cp312-cp312-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:c0695a4bcd5ff04e8379124dba5d9795ea5e0cadf38be7a0406432fc1467b555"}, + {file = "pytest_codspeed-3.1.2-cp313-cp313-manylinux_2_17_aarch64.manylinux2014_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:6dc356c8dcaaa883af83310f397ac06c96fac9b8a1146e303d4b374b2cb46a18"}, + {file = "pytest_codspeed-3.1.2-cp313-cp313-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:cc8a5d0366322a75cf562f7d8d672d28c1cf6948695c4dddca50331e08f6b3d5"}, + {file = "pytest_codspeed-3.1.2-cp39-cp39-manylinux_2_17_aarch64.manylinux2014_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:6c5fe7a19b72f54f217480b3b527102579547b1de9fe3acd9e66cb4629ff46c8"}, + {file = "pytest_codspeed-3.1.2-cp39-cp39-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:b67205755a665593f6521a98317d02a9d07d6fdc593f6634de2c94dea47a3055"}, + {file = "pytest_codspeed-3.1.2-py3-none-any.whl", hash = "sha256:5e7ed0315e33496c5c07dba262b50303b8d0bc4c3d10bf1d422a41e70783f1cb"}, + {file = "pytest_codspeed-3.1.2.tar.gz", hash = "sha256:09c1733af3aab35e94a621aa510f2d2114f65591e6f644c42ca3f67547edad4b"}, +] + +[package.dependencies] +cffi = ">=1.17.1" +pytest = ">=3.8" +rich = ">=13.8.1" + +[package.extras] +compat = ["pytest-benchmark (>=5.0.0,<5.1.0)", "pytest-xdist (>=3.6.1,<3.7.0)"] +lint = ["mypy (>=1.11.2,<1.12.0)", "ruff (>=0.6.5,<0.7.0)"] +test = ["pytest (>=7.0,<8.0)", "pytest-cov (>=4.0.0,<4.1.0)"] + [[package]] name = "pytest-cov" version = "6.0.0" @@ -2362,7 +2393,7 @@ version = "13.9.4" description = "Render rich text, tables, progress bars, syntax highlighting, markdown and more to the terminal" optional = false python-versions = ">=3.8.0" -groups = ["main"] +groups = ["main", "dev"] markers = "python_version <= \"3.11\" or python_version >= \"3.12\"" files = [ {file = "rich-13.9.4-py3-none-any.whl", hash = "sha256:6049d5e6ec054bf2779ab3358186963bac2ea89175919d699e378b99738c2a90"}, @@ -3152,4 +3183,4 @@ type = ["pytest-mypy"] [metadata] lock-version = "2.1" python-versions = ">=3.10, <4.0" -content-hash = "35c503a68e87896b4f7d7c209dd3fe6d707ebcc1702377cab0a1339554c6ad77" +content-hash = "822150bcbf41e5cbb61da0a059b41d8971e3c6c974c8af4be7ef55126648aea1" diff --git a/pyproject.toml b/pyproject.toml index 6eeb17489..8d0b37a23 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -71,6 +71,7 @@ selenium = ">=4.11.0,<5.0" pytest-benchmark = ">=4.0.0,<6.0" playwright = ">=1.46.0" pytest-playwright = ">=0.5.1" +pytest-codspeed = "^3.1.2" [tool.poetry.scripts] reflex = "reflex.reflex:cli" diff --git a/reflex/.templates/jinja/web/pages/utils.js.jinja2 b/reflex/.templates/jinja/web/pages/utils.js.jinja2 index 567ca6e60..08aeb0d38 100644 --- a/reflex/.templates/jinja/web/pages/utils.js.jinja2 +++ b/reflex/.templates/jinja/web/pages/utils.js.jinja2 @@ -86,7 +86,7 @@ {% for condition in case[:-1] %} case JSON.stringify({{ condition._js_expr }}): {% endfor %} - return {{ case[-1] }}; + return {{ render(case[-1]) }}; break; {% endfor %} default: diff --git a/reflex/.templates/web/utils/state.js b/reflex/.templates/web/utils/state.js index 93c664ef1..009910a32 100644 --- a/reflex/.templates/web/utils/state.js +++ b/reflex/.templates/web/utils/state.js @@ -106,6 +106,18 @@ export const getBackendURL = (url_str) => { return endpoint; }; +/** + * Check if the backend is disabled. + * + * @returns True if the backend is disabled, false otherwise. + */ +export const isBackendDisabled = () => { + const cookie = document.cookie + .split("; ") + .find((row) => row.startsWith("backend-enabled=")); + return cookie !== undefined && cookie.split("=")[1] == "false"; +}; + /** * Determine if any event in the event queue is stateful. * @@ -301,10 +313,7 @@ export const applyEvent = async (event, socket) => { // Send the event to the server. if (socket) { - socket.emit( - "event", - event, - ); + socket.emit("event", event); return true; } @@ -497,7 +506,7 @@ export const uploadFiles = async ( return false; } - const upload_ref_name = `__upload_controllers_${upload_id}` + const upload_ref_name = `__upload_controllers_${upload_id}`; if (refs[upload_ref_name]) { console.log("Upload already in progress for ", upload_id); @@ -815,7 +824,7 @@ export const useEventLoop = ( return; } // only use websockets if state is present - if (Object.keys(initialState).length > 1) { + if (Object.keys(initialState).length > 1 && !isBackendDisabled()) { // Initialize the websocket connection. if (!socket.current) { connect( diff --git a/reflex/app.py b/reflex/app.py index d33fa4b31..ce6808816 100644 --- a/reflex/app.py +++ b/reflex/app.py @@ -59,7 +59,11 @@ from reflex.components.component import ( ComponentStyle, evaluate_style_namespaces, ) -from reflex.components.core.banner import connection_pulser, connection_toaster +from reflex.components.core.banner import ( + backend_disabled, + connection_pulser, + connection_toaster, +) from reflex.components.core.breakpoints import set_breakpoints from reflex.components.core.client_side_routing import ( Default404Page, @@ -158,9 +162,12 @@ def default_overlay_component() -> Component: Returns: The default overlay_component, which is a connection_modal. """ + config = get_config() + return Fragment.create( connection_pulser(), connection_toaster(), + *([backend_disabled()] if config.is_reflex_cloud else []), *codespaces.codespaces_auto_redirect(), ) diff --git a/reflex/components/base/error_boundary.py b/reflex/components/base/error_boundary.py index f328773c2..74867a757 100644 --- a/reflex/components/base/error_boundary.py +++ b/reflex/components/base/error_boundary.py @@ -11,10 +11,11 @@ from reflex.event import EventHandler, set_clipboard from reflex.state import FrontendEventExceptionState from reflex.vars.base import Var from reflex.vars.function import ArgsFunctionOperation +from reflex.vars.object import ObjectVar def on_error_spec( - error: Var[Dict[str, str]], info: Var[Dict[str, str]] + error: ObjectVar[Dict[str, str]], info: ObjectVar[Dict[str, str]] ) -> Tuple[Var[str], Var[str]]: """The spec for the on_error event handler. diff --git a/reflex/components/base/error_boundary.pyi b/reflex/components/base/error_boundary.pyi index 2e01c7da0..8d27af0f3 100644 --- a/reflex/components/base/error_boundary.pyi +++ b/reflex/components/base/error_boundary.pyi @@ -9,9 +9,10 @@ from reflex.components.component import Component from reflex.event import BASE_STATE, EventType from reflex.style import Style from reflex.vars.base import Var +from reflex.vars.object import ObjectVar def on_error_spec( - error: Var[Dict[str, str]], info: Var[Dict[str, str]] + error: ObjectVar[Dict[str, str]], info: ObjectVar[Dict[str, str]] ) -> Tuple[Var[str], Var[str]]: ... class ErrorBoundary(Component): diff --git a/reflex/components/component.py b/reflex/components/component.py index 3f1b88fea..440a408df 100644 --- a/reflex/components/component.py +++ b/reflex/components/component.py @@ -1944,7 +1944,7 @@ class StatefulComponent(BaseComponent): if not should_memoize: # Determine if any Vars have associated data. - for prop_var in component._get_vars(): + for prop_var in component._get_vars(include_children=True): if prop_var._get_all_var_data(): should_memoize = True break @@ -2327,8 +2327,8 @@ class MemoizationLeaf(Component): """ comp = super().create(*children, **props) if comp._get_all_hooks(): - comp._memoization_mode = cls._memoization_mode.copy( - update={"disposition": MemoizationDisposition.ALWAYS} + comp._memoization_mode = dataclasses.replace( + comp._memoization_mode, disposition=MemoizationDisposition.ALWAYS ) return comp @@ -2389,7 +2389,7 @@ def render_dict_to_var(tag: dict | Component | str, imported_names: set[str]) -> if tag["name"] == "match": element = tag["cond"] - conditionals = tag["default"] + conditionals = render_dict_to_var(tag["default"], imported_names) for case in tag["match_cases"][::-1]: condition = case[0].to_string() == element.to_string() @@ -2398,7 +2398,7 @@ def render_dict_to_var(tag: dict | Component | str, imported_names: set[str]) -> conditionals = ternary_operation( condition, - case[-1], + render_dict_to_var(case[-1], imported_names), conditionals, ) @@ -2457,6 +2457,7 @@ def render_dict_to_var(tag: dict | Component | str, imported_names: set[str]) -> @dataclasses.dataclass( eq=False, frozen=True, + slots=True, ) class LiteralComponentVar(CachedVarOperation, LiteralVar, ComponentVar): """A Var that represents a Component.""" diff --git a/reflex/components/core/banner.py b/reflex/components/core/banner.py index 6479bf3b2..882975f2f 100644 --- a/reflex/components/core/banner.py +++ b/reflex/components/core/banner.py @@ -4,8 +4,10 @@ from __future__ import annotations from typing import Optional +from reflex import constants from reflex.components.component import Component from reflex.components.core.cond import cond +from reflex.components.datadisplay.logo import svg_logo from reflex.components.el.elements.typography import Div from reflex.components.lucide.icon import Icon from reflex.components.radix.themes.components.dialog import ( @@ -293,7 +295,84 @@ class ConnectionPulser(Div): ) +class BackendDisabled(Div): + """A component that displays a message when the backend is disabled.""" + + @classmethod + def create(cls, **props) -> Component: + """Create a backend disabled component. + + Args: + **props: The properties of the component. + + Returns: + The backend disabled component. + """ + import reflex as rx + + is_backend_disabled = Var( + "backendDisabled", + _var_type=bool, + _var_data=VarData( + hooks={ + "const [backendDisabled, setBackendDisabled] = useState(false);": None, + "useEffect(() => { setBackendDisabled(isBackendDisabled()); }, []);": None, + }, + imports={ + f"$/{constants.Dirs.STATE_PATH}": [ + ImportVar(tag="isBackendDisabled") + ], + }, + ), + ) + + return super().create( + rx.cond( + is_backend_disabled, + rx.box( + rx.box( + rx.card( + rx.vstack( + svg_logo(), + rx.text( + "You ran out of compute credits.", + ), + rx.callout( + rx.fragment( + "Please upgrade your plan or raise your compute credits at ", + rx.link( + "Reflex Cloud.", + href="https://cloud.reflex.dev/", + ), + ), + width="100%", + icon="info", + variant="surface", + ), + ), + font_size="20px", + font_family='"Inter", "Helvetica", "Arial", sans-serif', + variant="classic", + ), + position="fixed", + top="50%", + left="50%", + transform="translate(-50%, -50%)", + width="40ch", + max_width="90vw", + ), + position="fixed", + z_index=9999, + backdrop_filter="grayscale(1) blur(5px)", + width="100dvw", + height="100dvh", + ), + ) + ) + + connection_banner = ConnectionBanner.create connection_modal = ConnectionModal.create connection_toaster = ConnectionToaster.create connection_pulser = ConnectionPulser.create +backend_disabled = BackendDisabled.create diff --git a/reflex/components/core/banner.pyi b/reflex/components/core/banner.pyi index f44ee7992..2ea514965 100644 --- a/reflex/components/core/banner.pyi +++ b/reflex/components/core/banner.pyi @@ -350,7 +350,93 @@ class ConnectionPulser(Div): """ ... +class BackendDisabled(Div): + @overload + @classmethod + def create( # type: ignore + cls, + *children, + access_key: Optional[Union[Var[Union[bool, int, str]], bool, int, str]] = None, + auto_capitalize: Optional[ + Union[Var[Union[bool, int, str]], bool, int, str] + ] = None, + content_editable: Optional[ + Union[Var[Union[bool, int, str]], bool, int, str] + ] = None, + context_menu: Optional[ + Union[Var[Union[bool, int, str]], bool, int, str] + ] = None, + dir: Optional[Union[Var[Union[bool, int, str]], bool, int, str]] = None, + draggable: Optional[Union[Var[Union[bool, int, str]], bool, int, str]] = None, + enter_key_hint: Optional[ + Union[Var[Union[bool, int, str]], bool, int, str] + ] = None, + hidden: Optional[Union[Var[Union[bool, int, str]], bool, int, str]] = None, + input_mode: Optional[Union[Var[Union[bool, int, str]], bool, int, str]] = None, + item_prop: Optional[Union[Var[Union[bool, int, str]], bool, int, str]] = None, + lang: Optional[Union[Var[Union[bool, int, str]], bool, int, str]] = None, + role: Optional[Union[Var[Union[bool, int, str]], bool, int, str]] = None, + slot: Optional[Union[Var[Union[bool, int, str]], bool, int, str]] = None, + spell_check: Optional[Union[Var[Union[bool, int, str]], bool, int, str]] = None, + tab_index: Optional[Union[Var[Union[bool, int, str]], bool, int, str]] = None, + title: Optional[Union[Var[Union[bool, int, str]], bool, int, str]] = None, + style: Optional[Style] = None, + key: Optional[Any] = None, + id: Optional[Any] = None, + class_name: Optional[Any] = None, + autofocus: Optional[bool] = None, + custom_attrs: Optional[Dict[str, Union[Var, Any]]] = None, + on_blur: Optional[EventType[[], BASE_STATE]] = None, + on_click: Optional[EventType[[], BASE_STATE]] = None, + on_context_menu: Optional[EventType[[], BASE_STATE]] = None, + on_double_click: Optional[EventType[[], BASE_STATE]] = None, + on_focus: Optional[EventType[[], BASE_STATE]] = None, + on_mount: Optional[EventType[[], BASE_STATE]] = None, + on_mouse_down: Optional[EventType[[], BASE_STATE]] = None, + on_mouse_enter: Optional[EventType[[], BASE_STATE]] = None, + on_mouse_leave: Optional[EventType[[], BASE_STATE]] = None, + on_mouse_move: Optional[EventType[[], BASE_STATE]] = None, + on_mouse_out: Optional[EventType[[], BASE_STATE]] = None, + on_mouse_over: Optional[EventType[[], BASE_STATE]] = None, + on_mouse_up: Optional[EventType[[], BASE_STATE]] = None, + on_scroll: Optional[EventType[[], BASE_STATE]] = None, + on_unmount: Optional[EventType[[], BASE_STATE]] = None, + **props, + ) -> "BackendDisabled": + """Create a backend disabled component. + + Args: + access_key: Provides a hint for generating a keyboard shortcut for the current element. + auto_capitalize: Controls whether and how text input is automatically capitalized as it is entered/edited by the user. + content_editable: Indicates whether the element's content is editable. + context_menu: Defines the ID of a element which will serve as the element's context menu. + dir: Defines the text direction. Allowed values are ltr (Left-To-Right) or rtl (Right-To-Left) + draggable: Defines whether the element can be dragged. + enter_key_hint: Hints what media types the media element is able to play. + hidden: Defines whether the element is hidden. + input_mode: Defines the type of the element. + item_prop: Defines the name of the element for metadata purposes. + lang: Defines the language used in the element. + role: Defines the role of the element. + slot: Assigns a slot in a shadow DOM shadow tree to an element. + spell_check: Defines whether the element may be checked for spelling errors. + tab_index: Defines the position of the current element in the tabbing order. + title: Defines a tooltip for the element. + style: The style of the component. + key: A unique key for the component. + id: The id for the component. + class_name: The class name for the component. + autofocus: Whether the component should take the focus once the page is loaded + custom_attrs: custom attribute + **props: The properties of the component. + + Returns: + The backend disabled component. + """ + ... + connection_banner = ConnectionBanner.create connection_modal = ConnectionModal.create connection_toaster = ConnectionToaster.create connection_pulser = ConnectionPulser.create +backend_disabled = BackendDisabled.create diff --git a/reflex/components/core/foreach.py b/reflex/components/core/foreach.py index 30dda9c6a..927b01333 100644 --- a/reflex/components/core/foreach.py +++ b/reflex/components/core/foreach.py @@ -11,6 +11,7 @@ from reflex.components.component import Component from reflex.components.tags import IterTag from reflex.constants import MemoizationMode from reflex.state import ComponentState +from reflex.utils.exceptions import UntypedVarError from reflex.vars.base import LiteralVar, Var @@ -51,6 +52,7 @@ class Foreach(Component): Raises: ForeachVarError: If the iterable is of type Any. TypeError: If the render function is a ComponentState. + UntypedVarError: If the iterable is of type Any without a type annotation. """ iterable = LiteralVar.create(iterable) if iterable._var_type == Any: @@ -72,8 +74,14 @@ class Foreach(Component): iterable=iterable, render_fn=render_fn, ) - # Keep a ref to a rendered component to determine correct imports/hooks/styles. - component.children = [component._render().render_component()] + try: + # Keep a ref to a rendered component to determine correct imports/hooks/styles. + component.children = [component._render().render_component()] + except UntypedVarError as e: + raise UntypedVarError( + f"Could not foreach over var `{iterable!s}` without a type annotation. " + "See https://reflex.dev/docs/library/dynamic-rendering/foreach/" + ) from e return component def _render(self) -> IterTag: diff --git a/reflex/components/datadisplay/dataeditor.py b/reflex/components/datadisplay/dataeditor.py index b2d6417bd..338fb2e44 100644 --- a/reflex/components/datadisplay/dataeditor.py +++ b/reflex/components/datadisplay/dataeditor.py @@ -387,7 +387,8 @@ class DataEditor(NoSSRComponent): raise ValueError( "DataEditor data must be an ArrayVar if rows is not provided." ) - props["rows"] = data.length() if isinstance(data, Var) else len(data) + + props["rows"] = data.length() if isinstance(data, ArrayVar) else len(data) if not isinstance(columns, Var) and len(columns): if types.is_dataframe(type(data)) or ( diff --git a/reflex/components/datadisplay/shiki_code_block.py b/reflex/components/datadisplay/shiki_code_block.py index 2d3040966..a4aaec1d4 100644 --- a/reflex/components/datadisplay/shiki_code_block.py +++ b/reflex/components/datadisplay/shiki_code_block.py @@ -621,18 +621,22 @@ class ShikiCodeBlock(Component, MarkdownComponentMap): Returns: Imports for the component. + + Raises: + ValueError: If the transformers are not of type LiteralVar. """ imports = defaultdict(list) + if not isinstance(self.transformers, LiteralVar): + raise ValueError( + f"transformers should be a LiteralVar type. Got {type(self.transformers)} instead." + ) for transformer in self.transformers._var_value: if isinstance(transformer, ShikiBaseTransformers): imports[transformer.library].extend( [ImportVar(tag=str(fn)) for fn in transformer.fns] ) - ( + if transformer.library not in self.lib_dependencies: self.lib_dependencies.append(transformer.library) - if transformer.library not in self.lib_dependencies - else None - ) return imports @classmethod diff --git a/reflex/components/radix/primitives/accordion.py b/reflex/components/radix/primitives/accordion.py index 2d9c7ae96..90a1c41f0 100644 --- a/reflex/components/radix/primitives/accordion.py +++ b/reflex/components/radix/primitives/accordion.py @@ -10,6 +10,7 @@ from reflex.components.core.cond import cond from reflex.components.lucide.icon import Icon from reflex.components.radix.primitives.base import RadixPrimitiveComponent from reflex.components.radix.themes.base import LiteralAccentColor, LiteralRadius +from reflex.constants.compiler import MemoizationMode from reflex.event import EventHandler from reflex.style import Style from reflex.vars import get_uuid_string_var @@ -342,6 +343,8 @@ class AccordionTrigger(AccordionComponent): alias = "RadixAccordionTrigger" + _memoization_mode = MemoizationMode(recursive=False) + @classmethod def create(cls, *children, **props) -> Component: """Create the Accordion trigger component. diff --git a/reflex/components/radix/primitives/drawer.py b/reflex/components/radix/primitives/drawer.py index b9056c9d0..30d1a6ae3 100644 --- a/reflex/components/radix/primitives/drawer.py +++ b/reflex/components/radix/primitives/drawer.py @@ -10,6 +10,7 @@ from reflex.components.component import Component, ComponentNamespace from reflex.components.radix.primitives.base import RadixPrimitiveComponent from reflex.components.radix.themes.base import Theme from reflex.components.radix.themes.layout.flex import Flex +from reflex.constants.compiler import MemoizationMode from reflex.event import EventHandler, no_args_event_spec, passthrough_event_spec from reflex.vars.base import Var @@ -85,6 +86,8 @@ class DrawerTrigger(DrawerComponent): # Defaults to true, if the first child acts as the trigger. as_child: Var[bool] = Var.create(True) + _memoization_mode = MemoizationMode(recursive=False) + @classmethod def create(cls, *children: Any, **props: Any) -> Component: """Create a new DrawerTrigger instance. diff --git a/reflex/components/radix/themes/components/alert_dialog.py b/reflex/components/radix/themes/components/alert_dialog.py index 36d38532c..bc5e2dc7e 100644 --- a/reflex/components/radix/themes/components/alert_dialog.py +++ b/reflex/components/radix/themes/components/alert_dialog.py @@ -5,6 +5,7 @@ from typing import Literal from reflex.components.component import ComponentNamespace from reflex.components.core.breakpoints import Responsive from reflex.components.el import elements +from reflex.constants.compiler import MemoizationMode from reflex.event import EventHandler, no_args_event_spec, passthrough_event_spec from reflex.vars.base import Var @@ -33,6 +34,8 @@ class AlertDialogTrigger(RadixThemesTriggerComponent): tag = "AlertDialog.Trigger" + _memoization_mode = MemoizationMode(recursive=False) + class AlertDialogContent(elements.Div, RadixThemesComponent): """Contains the content of the dialog. This component is based on the div element.""" diff --git a/reflex/components/radix/themes/components/card.py b/reflex/components/radix/themes/components/card.py index 30823de56..e99ea9cef 100644 --- a/reflex/components/radix/themes/components/card.py +++ b/reflex/components/radix/themes/components/card.py @@ -20,7 +20,7 @@ class Card(elements.Div, RadixThemesComponent): # Card size: "1" - "5" size: Var[Responsive[Literal["1", "2", "3", "4", "5"],]] - # Variant of Card: "solid" | "soft" | "outline" | "ghost" + # Variant of Card: "surface" | "classic" | "ghost" variant: Var[Literal["surface", "classic", "ghost"]] diff --git a/reflex/components/radix/themes/components/card.pyi b/reflex/components/radix/themes/components/card.pyi index d8ab6c06b..e515982e4 100644 --- a/reflex/components/radix/themes/components/card.pyi +++ b/reflex/components/radix/themes/components/card.pyi @@ -94,7 +94,7 @@ class Card(elements.Div, RadixThemesComponent): *children: Child components. as_child: Change the default rendered element for the one passed as a child, merging their props and behavior. size: Card size: "1" - "5" - variant: Variant of Card: "solid" | "soft" | "outline" | "ghost" + variant: Variant of Card: "surface" | "classic" | "ghost" access_key: Provides a hint for generating a keyboard shortcut for the current element. auto_capitalize: Controls whether and how text input is automatically capitalized as it is entered/edited by the user. content_editable: Indicates whether the element's content is editable. diff --git a/reflex/components/radix/themes/components/context_menu.py b/reflex/components/radix/themes/components/context_menu.py index f8512a902..60d23db1a 100644 --- a/reflex/components/radix/themes/components/context_menu.py +++ b/reflex/components/radix/themes/components/context_menu.py @@ -4,6 +4,7 @@ from typing import Dict, List, Literal, Union from reflex.components.component import ComponentNamespace from reflex.components.core.breakpoints import Responsive +from reflex.constants.compiler import MemoizationMode from reflex.event import EventHandler, no_args_event_spec, passthrough_event_spec from reflex.vars.base import Var @@ -55,6 +56,8 @@ class ContextMenuTrigger(RadixThemesComponent): _invalid_children: List[str] = ["ContextMenuContent"] + _memoization_mode = MemoizationMode(recursive=False) + class ContextMenuContent(RadixThemesComponent): """The component that pops out when the context menu is open.""" @@ -153,6 +156,8 @@ class ContextMenuSubTrigger(RadixThemesComponent): _valid_parents: List[str] = ["ContextMenuContent", "ContextMenuSub"] + _memoization_mode = MemoizationMode(recursive=False) + class ContextMenuSubContent(RadixThemesComponent): """The component that pops out when a submenu is open.""" diff --git a/reflex/components/radix/themes/components/dialog.py b/reflex/components/radix/themes/components/dialog.py index 1b7c3b532..ce6e52cb5 100644 --- a/reflex/components/radix/themes/components/dialog.py +++ b/reflex/components/radix/themes/components/dialog.py @@ -5,6 +5,7 @@ from typing import Literal from reflex.components.component import ComponentNamespace from reflex.components.core.breakpoints import Responsive from reflex.components.el import elements +from reflex.constants.compiler import MemoizationMode from reflex.event import EventHandler, no_args_event_spec, passthrough_event_spec from reflex.vars.base import Var @@ -31,6 +32,8 @@ class DialogTrigger(RadixThemesTriggerComponent): tag = "Dialog.Trigger" + _memoization_mode = MemoizationMode(recursive=False) + class DialogTitle(RadixThemesComponent): """Title component to display inside a Dialog modal.""" diff --git a/reflex/components/radix/themes/components/dropdown_menu.py b/reflex/components/radix/themes/components/dropdown_menu.py index abce3e3bb..6d5709e11 100644 --- a/reflex/components/radix/themes/components/dropdown_menu.py +++ b/reflex/components/radix/themes/components/dropdown_menu.py @@ -4,6 +4,7 @@ from typing import Dict, List, Literal, Union from reflex.components.component import ComponentNamespace from reflex.components.core.breakpoints import Responsive +from reflex.constants.compiler import MemoizationMode from reflex.event import EventHandler, no_args_event_spec, passthrough_event_spec from reflex.vars.base import Var @@ -60,6 +61,8 @@ class DropdownMenuTrigger(RadixThemesTriggerComponent): _invalid_children: List[str] = ["DropdownMenuContent"] + _memoization_mode = MemoizationMode(recursive=False) + class DropdownMenuContent(RadixThemesComponent): """The Dropdown Menu Content component that pops out when the dropdown menu is open.""" @@ -143,6 +146,8 @@ class DropdownMenuSubTrigger(RadixThemesTriggerComponent): _valid_parents: List[str] = ["DropdownMenuContent", "DropdownMenuSub"] + _memoization_mode = MemoizationMode(recursive=False) + class DropdownMenuSub(RadixThemesComponent): """Contains all the parts of a submenu.""" diff --git a/reflex/components/radix/themes/components/hover_card.py b/reflex/components/radix/themes/components/hover_card.py index bd5489ce6..9e7aa4688 100644 --- a/reflex/components/radix/themes/components/hover_card.py +++ b/reflex/components/radix/themes/components/hover_card.py @@ -5,6 +5,7 @@ from typing import Dict, Literal, Union from reflex.components.component import ComponentNamespace from reflex.components.core.breakpoints import Responsive from reflex.components.el import elements +from reflex.constants.compiler import MemoizationMode from reflex.event import EventHandler, passthrough_event_spec from reflex.vars.base import Var @@ -37,6 +38,8 @@ class HoverCardTrigger(RadixThemesTriggerComponent): tag = "HoverCard.Trigger" + _memoization_mode = MemoizationMode(recursive=False) + class HoverCardContent(elements.Div, RadixThemesComponent): """Contains the content of the open hover card.""" diff --git a/reflex/components/radix/themes/components/popover.py b/reflex/components/radix/themes/components/popover.py index bdf5f4af3..4c0542cb7 100644 --- a/reflex/components/radix/themes/components/popover.py +++ b/reflex/components/radix/themes/components/popover.py @@ -5,6 +5,7 @@ from typing import Dict, Literal, Union from reflex.components.component import ComponentNamespace from reflex.components.core.breakpoints import Responsive from reflex.components.el import elements +from reflex.constants.compiler import MemoizationMode from reflex.event import EventHandler, no_args_event_spec, passthrough_event_spec from reflex.vars.base import Var @@ -34,6 +35,8 @@ class PopoverTrigger(RadixThemesTriggerComponent): tag = "Popover.Trigger" + _memoization_mode = MemoizationMode(recursive=False) + class PopoverContent(elements.Div, RadixThemesComponent): """Contains content to be rendered in the open popover.""" diff --git a/reflex/components/radix/themes/components/select.py b/reflex/components/radix/themes/components/select.py index 45e5712bc..6ac992380 100644 --- a/reflex/components/radix/themes/components/select.py +++ b/reflex/components/radix/themes/components/select.py @@ -5,6 +5,7 @@ from typing import List, Literal, Union import reflex as rx from reflex.components.component import Component, ComponentNamespace from reflex.components.core.breakpoints import Responsive +from reflex.constants.compiler import MemoizationMode from reflex.event import no_args_event_spec, passthrough_event_spec from reflex.vars.base import Var @@ -69,6 +70,8 @@ class SelectTrigger(RadixThemesComponent): _valid_parents: List[str] = ["SelectRoot"] + _memoization_mode = MemoizationMode(recursive=False) + class SelectContent(RadixThemesComponent): """The component that pops out when the select is open.""" diff --git a/reflex/components/radix/themes/components/tabs.py b/reflex/components/radix/themes/components/tabs.py index adfb32fab..7b5e5f475 100644 --- a/reflex/components/radix/themes/components/tabs.py +++ b/reflex/components/radix/themes/components/tabs.py @@ -7,6 +7,7 @@ from typing import Any, Dict, List, Literal from reflex.components.component import Component, ComponentNamespace from reflex.components.core.breakpoints import Responsive from reflex.components.core.colors import color +from reflex.constants.compiler import MemoizationMode from reflex.event import EventHandler, passthrough_event_spec from reflex.vars.base import Var @@ -95,6 +96,8 @@ class TabsTrigger(RadixThemesComponent): _valid_parents: List[str] = ["TabsList"] + _memoization_mode = MemoizationMode(recursive=False) + @classmethod def create(cls, *children, **props) -> Component: """Create a TabsTrigger component. diff --git a/reflex/components/tags/tag.py b/reflex/components/tags/tag.py index 8a6326dc0..983726e56 100644 --- a/reflex/components/tags/tag.py +++ b/reflex/components/tags/tag.py @@ -3,13 +3,33 @@ from __future__ import annotations import dataclasses -from typing import Any, Dict, List, Optional, Union +from typing import Any, Dict, List, Optional, Sequence, Union from reflex.event import EventChain from reflex.utils import format, types from reflex.vars.base import LiteralVar, Var +def render_prop(value: Any) -> Any: + """Render the prop. + + Args: + value: The value to render. + + Returns: + The rendered value. + """ + from reflex.components.component import BaseComponent + + if isinstance(value, BaseComponent): + return value.render() + if isinstance(value, Sequence) and not isinstance(value, str): + return [render_prop(v) for v in value] + if callable(value) and not isinstance(value, Var): + return None + return value + + @dataclasses.dataclass() class Tag: """A React tag.""" @@ -65,25 +85,10 @@ class Tag: Yields: Tuple[str, Any]: The field name and value. """ - from reflex.components.component import BaseComponent - for field in dataclasses.fields(self): - value = getattr(self, field.name) - if isinstance(value, list): - children = [] - for child in value: - if isinstance(child, BaseComponent): - children.append(child.render()) - else: - children.append(child) - yield field.name, children - continue - if isinstance(value, BaseComponent): - yield field.name, value.render() - continue - if callable(value) and not isinstance(value, Var): - continue - yield field.name, getattr(self, field.name) + rendered_value = render_prop(getattr(self, field.name)) + if rendered_value is not None: + yield field.name, rendered_value def add_props(self, **kwargs: Optional[Any]) -> Tag: """Add props to the tag. diff --git a/reflex/config.py b/reflex/config.py index f6992f8b5..6609067f9 100644 --- a/reflex/config.py +++ b/reflex/config.py @@ -703,6 +703,9 @@ class Config(Base): # Path to file containing key-values pairs to override in the environment; Dotenv format. env_file: Optional[str] = None + # Whether the app is running in the reflex cloud environment. + is_reflex_cloud: bool = False + def __init__(self, *args, **kwargs): """Initialize the config values. diff --git a/reflex/constants/compiler.py b/reflex/constants/compiler.py index dc5d80fe0..9bc9978dc 100644 --- a/reflex/constants/compiler.py +++ b/reflex/constants/compiler.py @@ -1,10 +1,10 @@ """Compiler variables.""" +import dataclasses import enum from enum import Enum from types import SimpleNamespace -from reflex.base import Base from reflex.constants import Dirs from reflex.utils.imports import ImportVar @@ -151,7 +151,8 @@ class MemoizationDisposition(enum.Enum): NEVER = "never" -class MemoizationMode(Base): +@dataclasses.dataclass(frozen=True) +class MemoizationMode: """The mode for memoizing a Component.""" # The conditions under which the component should be memoized. diff --git a/reflex/event.py b/reflex/event.py index fbbfc70b2..5ce0f3dc1 100644 --- a/reflex/event.py +++ b/reflex/event.py @@ -4,7 +4,6 @@ from __future__ import annotations import dataclasses import inspect -import sys import types import urllib.parse from base64 import b64encode @@ -37,6 +36,7 @@ from typing_extensions import ( ) from reflex import constants +from reflex.constants.compiler import CompileVars, Hooks, Imports from reflex.constants.state import FRONTEND_EVENT_STATE from reflex.utils import console, format from reflex.utils.exceptions import ( @@ -540,7 +540,7 @@ class JavasciptKeyboardEvent: shiftKey: bool = False # noqa: N815 -def input_event(e: Var[JavascriptInputEvent]) -> Tuple[Var[str]]: +def input_event(e: ObjectVar[JavascriptInputEvent]) -> Tuple[Var[str]]: """Get the value from an input event. Args: @@ -561,7 +561,9 @@ class KeyInputInfo(TypedDict): shift_key: bool -def key_event(e: Var[JavasciptKeyboardEvent]) -> Tuple[Var[str], Var[KeyInputInfo]]: +def key_event( + e: ObjectVar[JavasciptKeyboardEvent], +) -> Tuple[Var[str], Var[KeyInputInfo]]: """Get the key from a keyboard event. Args: @@ -571,7 +573,7 @@ def key_event(e: Var[JavasciptKeyboardEvent]) -> Tuple[Var[str], Var[KeyInputInf The key from the keyboard event. """ return ( - e.key, + e.key.to(str), Var.create( { "alt_key": e.altKey, @@ -579,7 +581,7 @@ def key_event(e: Var[JavasciptKeyboardEvent]) -> Tuple[Var[str], Var[KeyInputInf "meta_key": e.metaKey, "shift_key": e.shiftKey, }, - ), + ).to(KeyInputInfo), ) @@ -1353,7 +1355,7 @@ def unwrap_var_annotation(annotation: GenericType): Returns: The unwrapped annotation. """ - if get_origin(annotation) is Var and (args := get_args(annotation)): + if get_origin(annotation) in (Var, ObjectVar) and (args := get_args(annotation)): return args[0] return annotation @@ -1619,7 +1621,7 @@ class EventVar(ObjectVar, python_types=EventSpec): @dataclasses.dataclass( eq=False, frozen=True, - **{"slots": True} if sys.version_info >= (3, 10) else {}, + slots=True, ) class LiteralEventVar(VarOperationCall, LiteralVar, EventVar): """A literal event var.""" @@ -1680,7 +1682,7 @@ class EventChainVar(BuilderFunctionVar, python_types=EventChain): @dataclasses.dataclass( eq=False, frozen=True, - **{"slots": True} if sys.version_info >= (3, 10) else {}, + slots=True, ) # Note: LiteralVar is second in the inheritance list allowing it act like a # CachedVarOperation (ArgsFunctionOperation) and get the _js_expr from the @@ -1712,6 +1714,9 @@ class LiteralEventChainVar(ArgsFunctionOperationBuilder, LiteralVar, EventChainV Returns: The created LiteralEventChainVar instance. + + Raises: + ValueError: If the invocation is not a FunctionVar. """ arg_spec = ( value.args_spec[0] @@ -1729,10 +1734,21 @@ class LiteralEventChainVar(ArgsFunctionOperationBuilder, LiteralVar, EventChainV arg_def_expr = Var(_js_expr="args") if value.invocation is None: - invocation = FunctionStringVar.create("addEvents") + invocation = FunctionStringVar.create( + CompileVars.ADD_EVENTS, + _var_data=VarData( + imports=Imports.EVENTS, + hooks={Hooks.EVENTS: None}, + ), + ) else: invocation = value.invocation + if invocation is not None and not isinstance(invocation, FunctionVar): + raise ValueError( + f"EventChain invocation must be a FunctionVar, got {invocation!s} of type {invocation._var_type!s}." + ) + return cls( _js_expr="", _var_type=EventChain, diff --git a/reflex/experimental/client_state.py b/reflex/experimental/client_state.py index ce3a941bb..8138c2721 100644 --- a/reflex/experimental/client_state.py +++ b/reflex/experimental/client_state.py @@ -4,7 +4,6 @@ from __future__ import annotations import dataclasses import re -import sys from typing import Any, Callable, Union from reflex import constants @@ -49,7 +48,7 @@ def _client_state_ref_dict(var_name: str) -> str: @dataclasses.dataclass( eq=False, frozen=True, - **{"slots": True} if sys.version_info >= (3, 10) else {}, + slots=True, ) class ClientStateVar(Var): """A Var that exists on the client via useState.""" diff --git a/reflex/experimental/layout.py b/reflex/experimental/layout.py index 1c74271b2..d54e87f8b 100644 --- a/reflex/experimental/layout.py +++ b/reflex/experimental/layout.py @@ -12,6 +12,7 @@ from reflex.components.radix.themes.components.icon_button import IconButton from reflex.components.radix.themes.layout.box import Box from reflex.components.radix.themes.layout.container import Container from reflex.components.radix.themes.layout.stack import HStack +from reflex.constants.compiler import MemoizationMode from reflex.event import run_script from reflex.experimental import hooks from reflex.state import ComponentState @@ -146,6 +147,8 @@ sidebar_trigger_style = { class SidebarTrigger(Fragment): """A component that renders the sidebar trigger.""" + _memoization_mode = MemoizationMode(recursive=False) + @classmethod def create(cls, sidebar: Component, **props): """Create the sidebar trigger component. diff --git a/reflex/state.py b/reflex/state.py index 7688dcf78..92aaa4710 100644 --- a/reflex/state.py +++ b/reflex/state.py @@ -1569,9 +1569,11 @@ class BaseState(Base, ABC, extra=pydantic.Extra.allow): if not isinstance(var, Var): return var + unset = object() + # Fast case: this is a literal var and the value is known. - if hasattr(var, "_var_value"): - return var._var_value + if (var_value := getattr(var, "_var_value", unset)) is not unset: + return var_value # pyright: ignore [reportReturnType] var_data = var._get_all_var_data() if var_data is None or not var_data.state: diff --git a/reflex/style.py b/reflex/style.py index f5e424fe2..192835ca3 100644 --- a/reflex/style.py +++ b/reflex/style.py @@ -303,8 +303,16 @@ class Style(dict): Returns: The combined style. """ - _var_data = VarData.merge(self._var_data, getattr(other, "_var_data", None)) - return Style(super().__or__(self, other), _var_data=_var_data) # pyright: ignore [reportGeneralTypeIssues, reportCallIssue] + other_var_data = None + if not isinstance(other, Style): + other_dict, other_var_data = convert(other) + else: + other_dict, other_var_data = other, other._var_data + + new_style = Style(super().__or__(other_dict)) + if self._var_data or other_var_data: + new_style._var_data = VarData.merge(self._var_data, other_var_data) + return new_style def _format_emotion_style_pseudo_selector(key: str) -> str: diff --git a/reflex/utils/exceptions.py b/reflex/utils/exceptions.py index be3f6ab69..05fbb297c 100644 --- a/reflex/utils/exceptions.py +++ b/reflex/utils/exceptions.py @@ -75,6 +75,10 @@ class VarAttributeError(ReflexError, AttributeError): """Custom AttributeError for var related errors.""" +class UntypedVarError(ReflexError, TypeError): + """Custom TypeError for untyped var errors.""" + + class UntypedComputedVarError(ReflexError, TypeError): """Custom TypeError for untyped computed var errors.""" diff --git a/reflex/vars/base.py b/reflex/vars/base.py index f8a26e795..bb0a767f5 100644 --- a/reflex/vars/base.py +++ b/reflex/vars/base.py @@ -11,7 +11,6 @@ import json import random import re import string -import sys import warnings from types import CodeType, FunctionType from typing import ( @@ -29,6 +28,7 @@ from typing import ( Mapping, NoReturn, Optional, + Sequence, Set, Tuple, Type, @@ -80,6 +80,7 @@ if TYPE_CHECKING: VAR_TYPE = TypeVar("VAR_TYPE", covariant=True) OTHER_VAR_TYPE = TypeVar("OTHER_VAR_TYPE") STRING_T = TypeVar("STRING_T", bound=str) +SEQUENCE_TYPE = TypeVar("SEQUENCE_TYPE", bound=Sequence) warnings.filterwarnings("ignore", message="fields may not start with an underscore") @@ -130,7 +131,7 @@ class VarData: state: str = "", field_name: str = "", imports: ImportDict | ParsedImportDict | None = None, - hooks: Mapping[str, VarData | None] | None = None, + hooks: Mapping[str, VarData | None] | Sequence[str] | str | None = None, deps: list[Var] | None = None, position: Hooks.HookPosition | None = None, ): @@ -144,6 +145,10 @@ class VarData: deps: Dependencies of the var for useCallback. position: Position of the hook in the component. """ + if isinstance(hooks, str): + hooks = [hooks] + if not isinstance(hooks, dict): + hooks = {hook: None for hook in (hooks or [])} immutable_imports: ImmutableParsedImportDict = tuple( (k, tuple(v)) for k, v in parse_imports(imports or {}).items() ) @@ -154,6 +159,16 @@ class VarData: object.__setattr__(self, "deps", tuple(deps or [])) object.__setattr__(self, "position", position or None) + if hooks and any(hooks.values()): + merged_var_data = VarData.merge(self, *hooks.values()) + if merged_var_data is not None: + object.__setattr__(self, "state", merged_var_data.state) + object.__setattr__(self, "field_name", merged_var_data.field_name) + object.__setattr__(self, "imports", merged_var_data.imports) + object.__setattr__(self, "hooks", merged_var_data.hooks) + object.__setattr__(self, "deps", merged_var_data.deps) + object.__setattr__(self, "position", merged_var_data.position) + def old_school_imports(self) -> ImportDict: """Return the imports as a mutable dict. @@ -433,7 +448,7 @@ class Var(Generic[VAR_TYPE]): @dataclasses.dataclass( eq=False, frozen=True, - **{"slots": True} if sys.version_info >= (3, 10) else {}, + slots=True, ) class ToVarOperation(ToOperation, cls): """Base class of converting a var to another var type.""" @@ -444,7 +459,12 @@ class Var(Generic[VAR_TYPE]): _default_var_type: ClassVar[GenericType] = default_type - ToVarOperation.__name__ = f'To{cls.__name__.removesuffix("Var")}Operation' + new_to_var_operation_name = f"To{cls.__name__.removesuffix('Var')}Operation" + ToVarOperation.__qualname__ = ( + ToVarOperation.__qualname__.removesuffix(ToVarOperation.__name__) + + new_to_var_operation_name + ) + ToVarOperation.__name__ = new_to_var_operation_name _var_subclasses.append(VarSubclassEntry(cls, ToVarOperation, python_types)) @@ -576,7 +596,7 @@ class Var(Generic[VAR_TYPE]): @overload @classmethod - def create( + def create( # pyright: ignore [reportOverlappingOverload] cls, value: STRING_T, _var_data: VarData | None = None, @@ -590,6 +610,22 @@ class Var(Generic[VAR_TYPE]): _var_data: VarData | None = None, ) -> NoneVar: ... + @overload + @classmethod + def create( + cls, + value: MAPPING_TYPE, + _var_data: VarData | None = None, + ) -> ObjectVar[MAPPING_TYPE]: ... + + @overload + @classmethod + def create( + cls, + value: SEQUENCE_TYPE, + _var_data: VarData | None = None, + ) -> ArrayVar[SEQUENCE_TYPE]: ... + @overload @classmethod def create( @@ -671,8 +707,8 @@ class Var(Generic[VAR_TYPE]): @overload def to( self, - output: type[Mapping], - ) -> ObjectVar[Mapping]: ... + output: type[MAPPING_TYPE], + ) -> ObjectVar[MAPPING_TYPE]: ... @overload def to( @@ -723,7 +759,7 @@ class Var(Generic[VAR_TYPE]): return get_to_operation(NoneVar).create(self) # pyright: ignore [reportReturnType] # Handle fixed_output_type being Base or a dataclass. - if can_use_in_object_var(fixed_output_type): + if can_use_in_object_var(output): return self.to(ObjectVar, output) if inspect.isclass(output): @@ -755,6 +791,9 @@ class Var(Generic[VAR_TYPE]): return self + @overload + def guess_type(self: Var[NoReturn]) -> Var[Any]: ... # pyright: ignore [reportOverlappingOverload] + @overload def guess_type(self: Var[str]) -> StringVar: ... @@ -764,6 +803,9 @@ class Var(Generic[VAR_TYPE]): @overload def guess_type(self: Var[int] | Var[float] | Var[int | float]) -> NumberVar: ... + @overload + def guess_type(self: Var[BASE_TYPE]) -> ObjectVar[BASE_TYPE]: ... + @overload def guess_type(self) -> Self: ... @@ -912,7 +954,7 @@ class Var(Generic[VAR_TYPE]): return setter - def _var_set_state(self, state: type[BaseState] | str): + def _var_set_state(self, state: type[BaseState] | str) -> Self: """Set the state of the var. Args: @@ -927,7 +969,7 @@ class Var(Generic[VAR_TYPE]): else format_state_name(state.get_full_name()) ) - return StateOperation.create( + return StateOperation.create( # pyright: ignore [reportReturnType] formatted_state_name, self, _var_data=VarData.merge( @@ -1106,43 +1148,6 @@ class Var(Generic[VAR_TYPE]): """ return self - def __getattr__(self, name: str): - """Get an attribute of the var. - - Args: - name: The name of the attribute. - - Returns: - The attribute. - - Raises: - VarAttributeError: If the attribute does not exist. - TypeError: If the var type is Any. - """ - if name.startswith("_"): - return self.__getattribute__(name) - - if name == "contains": - raise TypeError( - f"Var of type {self._var_type} does not support contains check." - ) - if name == "reverse": - raise TypeError("Cannot reverse non-list var.") - - if self._var_type is Any: - raise TypeError( - f"You must provide an annotation for the state var `{self!s}`. Annotation cannot be `{self._var_type}`." - ) - - if name in REPLACED_NAMES: - raise VarAttributeError( - f"Field {name!r} was renamed to {REPLACED_NAMES[name]!r}" - ) - - raise VarAttributeError( - f"The State var has no attribute '{name}' or may have been annotated wrongly.", - ) - def _decode(self) -> Any: """Decode Var as a python value. @@ -1204,36 +1209,76 @@ class Var(Generic[VAR_TYPE]): return ArrayVar.range(first_endpoint, second_endpoint, step) - def __bool__(self) -> bool: - """Raise exception if using Var in a boolean context. + if not TYPE_CHECKING: - Raises: - VarTypeError: when attempting to bool-ify the Var. - """ - raise VarTypeError( - f"Cannot convert Var {str(self)!r} to bool for use with `if`, `and`, `or`, and `not`. " - "Instead use `rx.cond` and bitwise operators `&` (and), `|` (or), `~` (invert)." - ) + def __getattr__(self, name: str): + """Get an attribute of the var. - def __iter__(self) -> Any: - """Raise exception if using Var in an iterable context. + Args: + name: The name of the attribute. - Raises: - VarTypeError: when attempting to iterate over the Var. - """ - raise VarTypeError( - f"Cannot iterate over Var {str(self)!r}. Instead use `rx.foreach`." - ) + Raises: + VarAttributeError: If the attribute does not exist. + UntypedVarError: If the var type is Any. + TypeError: If the var type is Any. - def __contains__(self, _: Any) -> Var: - """Override the 'in' operator to alert the user that it is not supported. + # noqa: DAR101 self + """ + if name.startswith("_"): + raise VarAttributeError(f"Attribute {name} not found.") - Raises: - VarTypeError: the operation is not supported - """ - raise VarTypeError( - "'in' operator not supported for Var types, use Var.contains() instead." - ) + if name == "contains": + raise TypeError( + f"Var of type {self._var_type} does not support contains check." + ) + if name == "reverse": + raise TypeError("Cannot reverse non-list var.") + + if self._var_type is Any: + raise exceptions.UntypedVarError( + f"You must provide an annotation for the state var `{self!s}`. Annotation cannot be `{self._var_type}`." + ) + + raise VarAttributeError( + f"The State var has no attribute '{name}' or may have been annotated wrongly.", + ) + + def __bool__(self) -> bool: + """Raise exception if using Var in a boolean context. + + Raises: + VarTypeError: when attempting to bool-ify the Var. + + # noqa: DAR101 self + """ + raise VarTypeError( + f"Cannot convert Var {str(self)!r} to bool for use with `if`, `and`, `or`, and `not`. " + "Instead use `rx.cond` and bitwise operators `&` (and), `|` (or), `~` (invert)." + ) + + def __iter__(self) -> Any: + """Raise exception if using Var in an iterable context. + + Raises: + VarTypeError: when attempting to iterate over the Var. + + # noqa: DAR101 self + """ + raise VarTypeError( + f"Cannot iterate over Var {str(self)!r}. Instead use `rx.foreach`." + ) + + def __contains__(self, _: Any) -> Var: + """Override the 'in' operator to alert the user that it is not supported. + + Raises: + VarTypeError: the operation is not supported + + # noqa: DAR101 self + """ + raise VarTypeError( + "'in' operator not supported for Var types, use Var.contains() instead." + ) OUTPUT = TypeVar("OUTPUT", bound=Var) @@ -1384,7 +1429,7 @@ class LiteralVar(Var): TypeError: If the value is not a supported type for LiteralVar. """ from .object import LiteralObjectVar - from .sequence import LiteralStringVar + from .sequence import ArrayVar, LiteralStringVar if isinstance(value, Var): if _var_data is None: @@ -1440,6 +1485,9 @@ class LiteralVar(Var): _var_data=_var_data, ) + if isinstance(value, range): + return ArrayVar.range(value.start, value.stop, value.step) + raise TypeError( f"Unsupported type {type(value)} for LiteralVar. Tried to create a LiteralVar from {value}." ) @@ -1447,6 +1495,12 @@ class LiteralVar(Var): def __post_init__(self): """Post-initialize the var.""" + @property + def _var_value(self) -> Any: + raise NotImplementedError( + "LiteralVar subclasses must implement the _var_value property." + ) + def json(self) -> str: """Serialize the var to a JSON string. @@ -1519,7 +1573,7 @@ def var_operation( ) -> Callable[P, StringVar]: ... -LIST_T = TypeVar("LIST_T", bound=Union[List[Any], Tuple, Set]) +LIST_T = TypeVar("LIST_T", bound=Sequence) @overload @@ -1756,7 +1810,7 @@ def _or_operation(a: Var, b: Var): @dataclasses.dataclass( eq=False, frozen=True, - **{"slots": True} if sys.version_info >= (3, 10) else {}, + slots=True, ) class CallableVar(Var): """Decorate a Var-returning function to act as both a Var and a function. @@ -1837,7 +1891,7 @@ def is_computed_var(obj: Any) -> TypeGuard[ComputedVar]: @dataclasses.dataclass( eq=False, frozen=True, - **{"slots": True} if sys.version_info >= (3, 10) else {}, + slots=True, ) class ComputedVar(Var[RETURN_TYPE]): """A field with computed getters.""" @@ -2031,6 +2085,13 @@ class ComputedVar(Var[RETURN_TYPE]): return True return datetime.datetime.now() - last_updated > self._update_interval + @overload + def __get__( + self: ComputedVar[bool], + instance: None, + owner: Type, + ) -> BooleanVar: ... + @overload def __get__( self: ComputedVar[int] | ComputedVar[float], @@ -2059,13 +2120,6 @@ class ComputedVar(Var[RETURN_TYPE]): owner: Type, ) -> ArrayVar[list[LIST_INSIDE]]: ... - @overload - def __get__( - self: ComputedVar[set[LIST_INSIDE]], - instance: None, - owner: Type, - ) -> ArrayVar[set[LIST_INSIDE]]: ... - @overload def __get__( self: ComputedVar[tuple[LIST_INSIDE, ...]], @@ -2268,7 +2322,7 @@ async def _default_async_computed_var(_self: BaseState) -> Any: eq=False, frozen=True, init=False, - **{"slots": True} if sys.version_info >= (3, 10) else {}, + slots=True, ) class AsyncComputedVar(ComputedVar[RETURN_TYPE]): """A computed var that wraps a coroutinefunction.""" @@ -2279,7 +2333,14 @@ class AsyncComputedVar(ComputedVar[RETURN_TYPE]): @overload def __get__( - self: AsyncComputedVar[int] | AsyncComputedVar[float], + self: AsyncComputedVar[bool], + instance: None, + owner: Type, + ) -> BooleanVar: ... + + @overload + def __get__( + self: AsyncComputedVar[int] | ComputedVar[float], instance: None, owner: Type, ) -> NumberVar: ... @@ -2305,13 +2366,6 @@ class AsyncComputedVar(ComputedVar[RETURN_TYPE]): owner: Type, ) -> ArrayVar[list[LIST_INSIDE]]: ... - @overload - def __get__( - self: AsyncComputedVar[set[LIST_INSIDE]], - instance: None, - owner: Type, - ) -> ArrayVar[set[LIST_INSIDE]]: ... - @overload def __get__( self: AsyncComputedVar[tuple[LIST_INSIDE, ...]], @@ -2535,7 +2589,7 @@ def var_operation_return( @dataclasses.dataclass( eq=False, frozen=True, - **{"slots": True} if sys.version_info >= (3, 10) else {}, + slots=True, ) class CustomVarOperation(CachedVarOperation, Var[T]): """Base class for custom var operations.""" @@ -2606,7 +2660,7 @@ class NoneVar(Var[None], python_types=type(None)): @dataclasses.dataclass( eq=False, frozen=True, - **{"slots": True} if sys.version_info >= (3, 10) else {}, + slots=True, ) class LiteralNoneVar(LiteralVar, NoneVar): """A var representing None.""" @@ -2668,7 +2722,7 @@ def get_to_operation(var_subclass: Type[Var]) -> Type[ToOperation]: @dataclasses.dataclass( eq=False, frozen=True, - **{"slots": True} if sys.version_info >= (3, 10) else {}, + slots=True, ) class StateOperation(CachedVarOperation, Var): """A var operation that accesses a field on an object.""" @@ -2815,19 +2869,6 @@ def _extract_var_data(value: Iterable) -> list[VarData | None]: return var_datas -# These names were changed in reflex 0.3.0 -REPLACED_NAMES = { - "full_name": "_var_full_name", - "name": "_js_expr", - "state": "_var_data.state", - "type_": "_var_type", - "is_local": "_var_is_local", - "is_string": "_var_is_string", - "set_state": "_var_set_state", - "deps": "_deps", -} - - dispatchers: Dict[GenericType, Callable[[Var], Var]] = {} diff --git a/reflex/vars/datetime.py b/reflex/vars/datetime.py index b20cfc7a6..c43c24165 100644 --- a/reflex/vars/datetime.py +++ b/reflex/vars/datetime.py @@ -3,7 +3,6 @@ from __future__ import annotations import dataclasses -import sys from datetime import date, datetime from typing import Any, NoReturn, TypeVar, Union, overload @@ -193,7 +192,7 @@ def date_compare_operation( @dataclasses.dataclass( eq=False, frozen=True, - **{"slots": True} if sys.version_info >= (3, 10) else {}, + slots=True, ) class LiteralDatetimeVar(LiteralVar, DateTimeVar): """Base class for immutable datetime and date vars.""" diff --git a/reflex/vars/function.py b/reflex/vars/function.py index e8691cfb1..505a69b4c 100644 --- a/reflex/vars/function.py +++ b/reflex/vars/function.py @@ -226,7 +226,7 @@ class FunctionStringVar(FunctionVar[CALLABLE_TYPE]): @dataclasses.dataclass( eq=False, frozen=True, - **{"slots": True} if sys.version_info >= (3, 10) else {}, + slots=True, ) class VarOperationCall(Generic[P, R], CachedVarOperation, Var[R]): """Base class for immutable vars that are the result of a function call.""" @@ -350,7 +350,7 @@ def format_args_function_operation( @dataclasses.dataclass( eq=False, frozen=True, - **{"slots": True} if sys.version_info >= (3, 10) else {}, + slots=True, ) class ArgsFunctionOperation(CachedVarOperation, FunctionVar): """Base class for immutable function defined via arguments and return expression.""" @@ -407,7 +407,7 @@ class ArgsFunctionOperation(CachedVarOperation, FunctionVar): @dataclasses.dataclass( eq=False, frozen=True, - **{"slots": True} if sys.version_info >= (3, 10) else {}, + slots=True, ) class ArgsFunctionOperationBuilder(CachedVarOperation, BuilderFunctionVar): """Base class for immutable function defined via arguments and return expression with the builder pattern.""" diff --git a/reflex/vars/number.py b/reflex/vars/number.py index 050dc2329..35a55490a 100644 --- a/reflex/vars/number.py +++ b/reflex/vars/number.py @@ -5,7 +5,6 @@ from __future__ import annotations import dataclasses import json import math -import sys from typing import ( TYPE_CHECKING, Any, @@ -160,7 +159,7 @@ class NumberVar(Var[NUMBER_T], python_types=(int, float)): """ from .sequence import ArrayVar, LiteralArrayVar - if isinstance(other, (list, tuple, set, ArrayVar)): + if isinstance(other, (list, tuple, ArrayVar)): if isinstance(other, ArrayVar): return other * self return LiteralArrayVar.create(other) * self @@ -187,7 +186,7 @@ class NumberVar(Var[NUMBER_T], python_types=(int, float)): """ from .sequence import ArrayVar, LiteralArrayVar - if isinstance(other, (list, tuple, set, ArrayVar)): + if isinstance(other, (list, tuple, ArrayVar)): if isinstance(other, ArrayVar): return other * self return LiteralArrayVar.create(other) * self @@ -973,7 +972,7 @@ def boolean_not_operation(value: BooleanVar): @dataclasses.dataclass( eq=False, frozen=True, - **{"slots": True} if sys.version_info >= (3, 10) else {}, + slots=True, ) class LiteralNumberVar(LiteralVar, NumberVar): """Base class for immutable literal number vars.""" @@ -1032,7 +1031,7 @@ class LiteralNumberVar(LiteralVar, NumberVar): @dataclasses.dataclass( eq=False, frozen=True, - **{"slots": True} if sys.version_info >= (3, 10) else {}, + slots=True, ) class LiteralBooleanVar(LiteralVar, BooleanVar): """Base class for immutable literal boolean vars.""" diff --git a/reflex/vars/object.py b/reflex/vars/object.py index ed4221e4c..cb29cabfb 100644 --- a/reflex/vars/object.py +++ b/reflex/vars/object.py @@ -3,7 +3,6 @@ from __future__ import annotations import dataclasses -import sys import typing from inspect import isclass from typing import ( @@ -167,12 +166,6 @@ class ObjectVar(Var[OBJECT_TYPE], python_types=Mapping): key: Var | Any, ) -> ArrayVar[list[ARRAY_INNER_TYPE]]: ... - @overload - def __getitem__( - self: ObjectVar[Mapping[Any, set[ARRAY_INNER_TYPE]]], - key: Var | Any, - ) -> ArrayVar[set[ARRAY_INNER_TYPE]]: ... - @overload def __getitem__( self: ObjectVar[Mapping[Any, tuple[ARRAY_INNER_TYPE, ...]]], @@ -229,12 +222,6 @@ class ObjectVar(Var[OBJECT_TYPE], python_types=Mapping): name: str, ) -> ArrayVar[list[ARRAY_INNER_TYPE]]: ... - @overload - def __getattr__( - self: ObjectVar[Mapping[Any, set[ARRAY_INNER_TYPE]]], - name: str, - ) -> ArrayVar[set[ARRAY_INNER_TYPE]]: ... - @overload def __getattr__( self: ObjectVar[Mapping[Any, tuple[ARRAY_INNER_TYPE, ...]]], @@ -305,7 +292,7 @@ class ObjectVar(Var[OBJECT_TYPE], python_types=Mapping): @dataclasses.dataclass( eq=False, frozen=True, - **{"slots": True} if sys.version_info >= (3, 10) else {}, + slots=True, ) class LiteralObjectVar(CachedVarOperation, ObjectVar[OBJECT_TYPE], LiteralVar): """Base class for immutable literal object vars.""" @@ -355,17 +342,20 @@ class LiteralObjectVar(CachedVarOperation, ObjectVar[OBJECT_TYPE], LiteralVar): Returns: The JSON representation of the object. + + Raises: + TypeError: The keys and values of the object must be literal vars to get the JSON representation """ - return ( - "{" - + ", ".join( - [ - f"{LiteralVar.create(key).json()}:{LiteralVar.create(value).json()}" - for key, value in self._var_value.items() - ] - ) - + "}" - ) + keys_and_values = [] + for key, value in self._var_value.items(): + key = LiteralVar.create(key) + value = LiteralVar.create(value) + if not isinstance(key, LiteralVar) or not isinstance(value, LiteralVar): + raise TypeError( + "The keys and values of the object must be literal vars to get the JSON representation." + ) + keys_and_values.append(f"{key.json()}:{value.json()}") + return "{" + ", ".join(keys_and_values) + "}" def __hash__(self) -> int: """Get the hash of the var. @@ -487,7 +477,7 @@ def object_merge_operation(lhs: ObjectVar, rhs: ObjectVar): @dataclasses.dataclass( eq=False, frozen=True, - **{"slots": True} if sys.version_info >= (3, 10) else {}, + slots=True, ) class ObjectItemOperation(CachedVarOperation, Var): """Operation to get an item from an object.""" diff --git a/reflex/vars/sequence.py b/reflex/vars/sequence.py index 5e9f6468e..dfd9a6af8 100644 --- a/reflex/vars/sequence.py +++ b/reflex/vars/sequence.py @@ -6,7 +6,6 @@ import dataclasses import inspect import json import re -import sys import typing from typing import ( TYPE_CHECKING, @@ -15,7 +14,7 @@ from typing import ( List, Literal, NoReturn, - Set, + Sequence, Tuple, Type, Union, @@ -596,7 +595,7 @@ _decode_var_pattern = re.compile(_decode_var_pattern_re, flags=re.DOTALL) @dataclasses.dataclass( eq=False, frozen=True, - **{"slots": True} if sys.version_info >= (3, 10) else {}, + slots=True, ) class LiteralStringVar(LiteralVar, StringVar[str]): """Base class for immutable literal string vars.""" @@ -718,7 +717,7 @@ class LiteralStringVar(LiteralVar, StringVar[str]): @dataclasses.dataclass( eq=False, frozen=True, - **{"slots": True} if sys.version_info >= (3, 10) else {}, + slots=True, ) class ConcatVarOperation(CachedVarOperation, StringVar[str]): """Representing a concatenation of literal string vars.""" @@ -794,7 +793,8 @@ class ConcatVarOperation(CachedVarOperation, StringVar[str]): ) -ARRAY_VAR_TYPE = TypeVar("ARRAY_VAR_TYPE", bound=Union[List, Tuple, Set]) +ARRAY_VAR_TYPE = TypeVar("ARRAY_VAR_TYPE", bound=Sequence, covariant=True) +OTHER_ARRAY_VAR_TYPE = TypeVar("OTHER_ARRAY_VAR_TYPE", bound=Sequence) OTHER_TUPLE = TypeVar("OTHER_TUPLE") @@ -887,6 +887,11 @@ class ArrayVar(Var[ARRAY_VAR_TYPE], python_types=(list, tuple, set)): i: Literal[0, -2], ) -> NumberVar: ... + @overload + def __getitem__( + self: ArrayVar[Tuple[Any, bool]], i: Literal[1, -1] + ) -> BooleanVar: ... + @overload def __getitem__( self: ( @@ -914,7 +919,7 @@ class ArrayVar(Var[ARRAY_VAR_TYPE], python_types=(list, tuple, set)): @overload def __getitem__( - self: ArrayVar[Tuple[Any, bool]], i: Literal[1, -1] + self: ARRAY_VAR_OF_LIST_ELEMENT[bool], i: int | NumberVar ) -> BooleanVar: ... @overload @@ -932,23 +937,12 @@ class ArrayVar(Var[ARRAY_VAR_TYPE], python_types=(list, tuple, set)): self: ARRAY_VAR_OF_LIST_ELEMENT[str], i: int | NumberVar ) -> StringVar: ... - @overload - def __getitem__( - self: ARRAY_VAR_OF_LIST_ELEMENT[bool], i: int | NumberVar - ) -> BooleanVar: ... - @overload def __getitem__( self: ARRAY_VAR_OF_LIST_ELEMENT[List[INNER_ARRAY_VAR]], i: int | NumberVar, ) -> ArrayVar[List[INNER_ARRAY_VAR]]: ... - @overload - def __getitem__( - self: ARRAY_VAR_OF_LIST_ELEMENT[Set[INNER_ARRAY_VAR]], - i: int | NumberVar, - ) -> ArrayVar[Set[INNER_ARRAY_VAR]]: ... - @overload def __getitem__( self: ARRAY_VAR_OF_LIST_ELEMENT[Tuple[KEY_TYPE, VALUE_TYPE]], @@ -1239,26 +1233,18 @@ class ArrayVar(Var[ARRAY_VAR_TYPE], python_types=(list, tuple, set)): LIST_ELEMENT = TypeVar("LIST_ELEMENT") -ARRAY_VAR_OF_LIST_ELEMENT = Union[ - ArrayVar[List[LIST_ELEMENT]], - ArrayVar[Set[LIST_ELEMENT]], - ArrayVar[Tuple[LIST_ELEMENT, ...]], -] +ARRAY_VAR_OF_LIST_ELEMENT = ArrayVar[Sequence[LIST_ELEMENT]] @dataclasses.dataclass( eq=False, frozen=True, - **{"slots": True} if sys.version_info >= (3, 10) else {}, + slots=True, ) class LiteralArrayVar(CachedVarOperation, LiteralVar, ArrayVar[ARRAY_VAR_TYPE]): """Base class for immutable literal array vars.""" - _var_value: Union[ - List[Union[Var, Any]], - Set[Union[Var, Any]], - Tuple[Union[Var, Any], ...], - ] = dataclasses.field(default_factory=list) + _var_value: Sequence[Union[Var, Any]] = dataclasses.field(default=()) @cached_property_no_lock def _cached_var_name(self) -> str: @@ -1303,22 +1289,28 @@ class LiteralArrayVar(CachedVarOperation, LiteralVar, ArrayVar[ARRAY_VAR_TYPE]): Returns: The JSON representation of the var. + + Raises: + TypeError: If the array elements are not of type LiteralVar. """ - return ( - "[" - + ", ".join( - [LiteralVar.create(element).json() for element in self._var_value] - ) - + "]" - ) + elements = [] + for element in self._var_value: + element_var = LiteralVar.create(element) + if not isinstance(element_var, LiteralVar): + raise TypeError( + f"Array elements must be of type LiteralVar, not {type(element_var)}" + ) + elements.append(element_var.json()) + + return "[" + ", ".join(elements) + "]" @classmethod def create( cls, - value: ARRAY_VAR_TYPE, - _var_type: Type[ARRAY_VAR_TYPE] | None = None, + value: OTHER_ARRAY_VAR_TYPE, + _var_type: Type[OTHER_ARRAY_VAR_TYPE] | None = None, _var_data: VarData | None = None, - ) -> LiteralArrayVar[ARRAY_VAR_TYPE]: + ) -> LiteralArrayVar[OTHER_ARRAY_VAR_TYPE]: """Create a var from a string value. Args: @@ -1329,7 +1321,7 @@ class LiteralArrayVar(CachedVarOperation, LiteralVar, ArrayVar[ARRAY_VAR_TYPE]): Returns: The var. """ - return cls( + return LiteralArrayVar( _js_expr="", _var_type=figure_out_type(value) if _var_type is None else _var_type, _var_data=_var_data, @@ -1356,7 +1348,7 @@ def string_split_operation(string: StringVar[Any], sep: StringVar | str = ""): @dataclasses.dataclass( eq=False, frozen=True, - **{"slots": True} if sys.version_info >= (3, 10) else {}, + slots=True, ) class ArraySliceOperation(CachedVarOperation, ArrayVar): """Base class for immutable string vars that are the result of a string slice operation.""" @@ -1593,7 +1585,7 @@ def array_range_operation( The range of numbers. """ return var_operation_return( - js_expression=f"Array.from({{ length: ({stop!s} - {start!s}) / {step!s} }}, (_, i) => {start!s} + i * {step!s})", + js_expression=f"Array.from({{ length: Math.ceil(({stop!s} - {start!s}) / {step!s}) }}, (_, i) => {start!s} + i * {step!s})", var_type=List[int], ) @@ -1705,7 +1697,7 @@ class ColorVar(StringVar[Color], python_types=Color): @dataclasses.dataclass( eq=False, frozen=True, - **{"slots": True} if sys.version_info >= (3, 10) else {}, + slots=True, ) class LiteralColorVar(CachedVarOperation, LiteralVar, ColorVar): """Base class for immutable literal color vars.""" diff --git a/tests/integration/test_connection_banner.py b/tests/integration/test_connection_banner.py index c80f30626..4867cf868 100644 --- a/tests/integration/test_connection_banner.py +++ b/tests/integration/test_connection_banner.py @@ -1,5 +1,6 @@ """Test case for displaying the connection banner when the websocket drops.""" +import functools from typing import Generator import pytest @@ -11,12 +12,19 @@ from reflex.testing import AppHarness, WebDriver from .utils import SessionStorage -def ConnectionBanner(): - """App with a connection banner.""" +def ConnectionBanner(is_reflex_cloud: bool = False): + """App with a connection banner. + + Args: + is_reflex_cloud: The value for config.is_reflex_cloud. + """ import asyncio import reflex as rx + # Simulate reflex cloud deploy + rx.config.get_config().is_reflex_cloud = is_reflex_cloud + class State(rx.State): foo: int = 0 @@ -40,19 +48,43 @@ def ConnectionBanner(): app.add_page(index) +@pytest.fixture( + params=[False, True], ids=["reflex_cloud_disabled", "reflex_cloud_enabled"] +) +def simulate_is_reflex_cloud(request) -> bool: + """Fixture to simulate reflex cloud deployment. + + Args: + request: pytest request fixture. + + Returns: + True if reflex cloud is enabled, False otherwise. + """ + return request.param + + @pytest.fixture() -def connection_banner(tmp_path) -> Generator[AppHarness, None, None]: +def connection_banner( + tmp_path, + simulate_is_reflex_cloud: bool, +) -> Generator[AppHarness, None, None]: """Start ConnectionBanner app at tmp_path via AppHarness. Args: tmp_path: pytest tmp_path fixture + simulate_is_reflex_cloud: Whether is_reflex_cloud is set for the app. Yields: running AppHarness instance """ with AppHarness.create( root=tmp_path, - app_source=ConnectionBanner, + app_source=functools.partial( + ConnectionBanner, is_reflex_cloud=simulate_is_reflex_cloud + ), + app_name="connection_banner_reflex_cloud" + if simulate_is_reflex_cloud + else "connection_banner", ) as harness: yield harness @@ -77,6 +109,38 @@ def has_error_modal(driver: WebDriver) -> bool: return True +def has_cloud_banner(driver: WebDriver) -> bool: + """Check if the cloud banner is displayed. + + Args: + driver: Selenium webdriver instance. + + Returns: + True if the banner is displayed, False otherwise. + """ + try: + driver.find_element( + By.XPATH, "//*[ contains(text(), 'You ran out of compute credits.') ]" + ) + except NoSuchElementException: + return False + else: + return True + + +def _assert_token(connection_banner, driver): + """Poll for backend to be up. + + Args: + connection_banner: AppHarness instance. + driver: Selenium webdriver instance. + """ + ss = SessionStorage(driver) + assert connection_banner._poll_for( + lambda: ss.get("token") is not None + ), "token not found" + + @pytest.mark.asyncio async def test_connection_banner(connection_banner: AppHarness): """Test that the connection banner is displayed when the websocket drops. @@ -88,10 +152,7 @@ async def test_connection_banner(connection_banner: AppHarness): assert connection_banner.backend is not None driver = connection_banner.frontend() - ss = SessionStorage(driver) - assert connection_banner._poll_for( - lambda: ss.get("token") is not None - ), "token not found" + _assert_token(connection_banner, driver) assert connection_banner._poll_for(lambda: not has_error_modal(driver)) @@ -132,3 +193,36 @@ async def test_connection_banner(connection_banner: AppHarness): # Count should have incremented after coming back up assert connection_banner.poll_for_value(counter_element, exp_not_equal="1") == "2" + + +@pytest.mark.asyncio +async def test_cloud_banner( + connection_banner: AppHarness, simulate_is_reflex_cloud: bool +): + """Test that the connection banner is displayed when the websocket drops. + + Args: + connection_banner: AppHarness instance. + simulate_is_reflex_cloud: Whether is_reflex_cloud is set for the app. + """ + assert connection_banner.app_instance is not None + assert connection_banner.backend is not None + driver = connection_banner.frontend() + + driver.add_cookie({"name": "backend-enabled", "value": "truly"}) + driver.refresh() + _assert_token(connection_banner, driver) + assert connection_banner._poll_for(lambda: not has_cloud_banner(driver)) + + driver.add_cookie({"name": "backend-enabled", "value": "false"}) + driver.refresh() + if simulate_is_reflex_cloud: + assert connection_banner._poll_for(lambda: has_cloud_banner(driver)) + else: + _assert_token(connection_banner, driver) + assert connection_banner._poll_for(lambda: not has_cloud_banner(driver)) + + driver.delete_cookie("backend-enabled") + driver.refresh() + _assert_token(connection_banner, driver) + assert connection_banner._poll_for(lambda: not has_cloud_banner(driver)) diff --git a/tests/integration/test_var_operations.py b/tests/integration/test_var_operations.py index a09c8612e..9b952c575 100644 --- a/tests/integration/test_var_operations.py +++ b/tests/integration/test_var_operations.py @@ -600,6 +600,11 @@ def VarOperations(): ), id="foreach_in_match", ), + # Literal range var in a foreach + rx.box(rx.foreach(range(42, 80, 27), rx.text.span), id="range_in_foreach1"), + rx.box(rx.foreach(range(42, 80, 3), rx.text.span), id="range_in_foreach2"), + rx.box(rx.foreach(range(42, 20, -6), rx.text.span), id="range_in_foreach3"), + rx.box(rx.foreach(range(42, 43, 5), rx.text.span), id="range_in_foreach4"), ) @@ -799,6 +804,11 @@ def test_var_operations(driver, var_operations: AppHarness): ("memo_comp_nested", "345"), # foreach in a match ("foreach_in_match", "first\nsecond\nthird"), + # literal range in a foreach + ("range_in_foreach1", "4269"), + ("range_in_foreach2", "42454851545760636669727578"), + ("range_in_foreach3", "42363024"), + ("range_in_foreach4", "42"), ] for tag, expected in tests: diff --git a/tests/units/components/core/test_match.py b/tests/units/components/core/test_match.py index b765750ee..11602b77a 100644 --- a/tests/units/components/core/test_match.py +++ b/tests/units/components/core/test_match.py @@ -3,6 +3,7 @@ from typing import List, Mapping, Tuple import pytest import reflex as rx +from reflex.components.component import Component from reflex.components.core.match import Match from reflex.state import BaseState from reflex.utils.exceptions import MatchTypeError @@ -29,6 +30,8 @@ def test_match_components(): rx.text("default value"), ) match_comp = Match.create(MatchState.value, *match_case_tuples) + + assert isinstance(match_comp, Component) match_dict = match_comp.render() assert match_dict["name"] == "Fragment" @@ -42,7 +45,7 @@ def test_match_components(): assert match_cases[0][0]._js_expr == "1" assert match_cases[0][0]._var_type is int - first_return_value_render = match_cases[0][1].render() + first_return_value_render = match_cases[0][1] assert first_return_value_render["name"] == "RadixThemesText" assert first_return_value_render["children"][0]["contents"] == '{"first value"}' @@ -50,31 +53,31 @@ def test_match_components(): assert match_cases[1][0]._var_type is int assert match_cases[1][1]._js_expr == "3" assert match_cases[1][1]._var_type is int - second_return_value_render = match_cases[1][2].render() + second_return_value_render = match_cases[1][2] assert second_return_value_render["name"] == "RadixThemesText" assert second_return_value_render["children"][0]["contents"] == '{"second value"}' assert match_cases[2][0]._js_expr == "[1, 2]" assert match_cases[2][0]._var_type == List[int] - third_return_value_render = match_cases[2][1].render() + third_return_value_render = match_cases[2][1] assert third_return_value_render["name"] == "RadixThemesText" assert third_return_value_render["children"][0]["contents"] == '{"third value"}' assert match_cases[3][0]._js_expr == '"random"' assert match_cases[3][0]._var_type is str - fourth_return_value_render = match_cases[3][1].render() + fourth_return_value_render = match_cases[3][1] assert fourth_return_value_render["name"] == "RadixThemesText" assert fourth_return_value_render["children"][0]["contents"] == '{"fourth value"}' assert match_cases[4][0]._js_expr == '({ ["foo"] : "bar" })' assert match_cases[4][0]._var_type == Mapping[str, str] - fifth_return_value_render = match_cases[4][1].render() + fifth_return_value_render = match_cases[4][1] assert fifth_return_value_render["name"] == "RadixThemesText" assert fifth_return_value_render["children"][0]["contents"] == '{"fifth value"}' assert match_cases[5][0]._js_expr == f"({MatchState.get_name()}.num + 1)" assert match_cases[5][0]._var_type is int - fifth_return_value_render = match_cases[5][1].render() + fifth_return_value_render = match_cases[5][1] assert fifth_return_value_render["name"] == "RadixThemesText" assert fifth_return_value_render["children"][0]["contents"] == '{"sixth value"}' @@ -151,6 +154,7 @@ def test_match_on_component_without_default(): ) match_comp = Match.create(MatchState.value, *match_case_tuples) + assert isinstance(match_comp, Component) default = match_comp.render()["children"][0]["default"] assert isinstance(default, dict) and default["name"] == Fragment.__name__ diff --git a/tests/units/components/test_component.py b/tests/units/components/test_component.py index 26e530f7c..8cffa6e0e 100644 --- a/tests/units/components/test_component.py +++ b/tests/units/components/test_component.py @@ -36,6 +36,7 @@ from reflex.utils.exceptions import ( from reflex.utils.imports import ImportDict, ImportVar, ParsedImportDict, parse_imports from reflex.vars import VarData from reflex.vars.base import LiteralVar, Var +from reflex.vars.object import ObjectVar @pytest.fixture @@ -842,12 +843,12 @@ def test_component_event_trigger_arbitrary_args(): """Test that we can define arbitrary types for the args of an event trigger.""" def on_foo_spec( - _e: Var[JavascriptInputEvent], + _e: ObjectVar[JavascriptInputEvent], alpha: Var[str], bravo: dict[str, Any], - charlie: Var[_Obj], + charlie: ObjectVar[_Obj], ): - return [_e.target.value, bravo["nested"], charlie.custom + 42] + return [_e.target.value, bravo["nested"], charlie.custom.to(int) + 42] class C1(Component): library = "/local" @@ -1328,7 +1329,7 @@ class EventState(rx.State): ), pytest.param( rx.fragment(class_name=[TEST_VAR, "other-class"]), - [LiteralVar.create([TEST_VAR, "other-class"]).join(" ")], + [Var.create([TEST_VAR, "other-class"]).join(" ")], id="fstring-dual-class_name", ), pytest.param( diff --git a/tests/units/test_app.py b/tests/units/test_app.py index a361c1d18..bf1a8a313 100644 --- a/tests/units/test_app.py +++ b/tests/units/test_app.py @@ -471,15 +471,15 @@ async def test_dynamic_var_event(test_state: Type[ATestState], token: str): """ state = test_state() # pyright: ignore [reportCallIssue] state.add_var("int_val", int, 0) - result = await state._process( + async for result in state._process( Event( token=token, name=f"{test_state.get_name()}.set_int_val", router_data={"pathname": "/", "query": {}}, payload={"value": 50}, ) - ).__anext__() - assert result.delta == {test_state.get_name(): {"int_val": 50}} + ): + assert result.delta == {test_state.get_name(): {"int_val": 50}} @pytest.mark.asyncio @@ -583,18 +583,17 @@ async def test_list_mutation_detection__plain_list( token: a Token. """ for event_name, expected_delta in event_tuples: - result = await list_mutation_state._process( + async for result in list_mutation_state._process( Event( token=token, name=f"{list_mutation_state.get_name()}.{event_name}", router_data={"pathname": "/", "query": {}}, payload={}, ) - ).__anext__() - - # prefix keys in expected_delta with the state name - expected_delta = {list_mutation_state.get_name(): expected_delta} - assert result.delta == expected_delta + ): + # prefix keys in expected_delta with the state name + expected_delta = {list_mutation_state.get_name(): expected_delta} + assert result.delta == expected_delta @pytest.mark.asyncio @@ -709,19 +708,18 @@ async def test_dict_mutation_detection__plain_list( token: a Token. """ for event_name, expected_delta in event_tuples: - result = await dict_mutation_state._process( + async for result in dict_mutation_state._process( Event( token=token, name=f"{dict_mutation_state.get_name()}.{event_name}", router_data={"pathname": "/", "query": {}}, payload={}, ) - ).__anext__() + ): + # prefix keys in expected_delta with the state name + expected_delta = {dict_mutation_state.get_name(): expected_delta} - # prefix keys in expected_delta with the state name - expected_delta = {dict_mutation_state.get_name(): expected_delta} - - assert result.delta == expected_delta + assert result.delta == expected_delta @pytest.mark.asyncio diff --git a/tests/units/test_event.py b/tests/units/test_event.py index 5e47991da..afcfda504 100644 --- a/tests/units/test_event.py +++ b/tests/units/test_event.py @@ -3,6 +3,7 @@ from typing import Callable, List import pytest import reflex as rx +from reflex.constants.compiler import Hooks, Imports from reflex.event import ( Event, EventChain, @@ -14,7 +15,7 @@ from reflex.event import ( ) from reflex.state import BaseState from reflex.utils import format -from reflex.vars.base import Field, LiteralVar, Var, field +from reflex.vars.base import Field, LiteralVar, Var, VarData, field def make_var(value) -> Var: @@ -443,9 +444,28 @@ def test_event_var_data(): return (value,) # Ensure chain carries _var_data - chain_var = Var.create(EventChain(events=[S.s(S.x)], args_spec=_args_spec)) + chain_var = Var.create( + EventChain( + events=[S.s(S.x)], + args_spec=_args_spec, + invocation=rx.vars.FunctionStringVar.create(""), + ) + ) assert chain_var._get_all_var_data() == S.x._get_all_var_data() + chain_var_data = Var.create( + EventChain( + events=[], + args_spec=_args_spec, + ) + )._get_all_var_data() + assert chain_var_data is not None + + assert chain_var_data == VarData( + imports=Imports.EVENTS, + hooks={Hooks.EVENTS: None}, + ) + def test_event_bound_method() -> None: class S(BaseState): diff --git a/tests/units/test_state.py b/tests/units/test_state.py index ad844cf01..44c3f60b7 100644 --- a/tests/units/test_state.py +++ b/tests/units/test_state.py @@ -790,17 +790,16 @@ async def test_process_event_simple(test_state): assert test_state.num1 == 0 event = Event(token="t", name="set_num1", payload={"value": 69}) - update = await test_state._process(event).__anext__() + async for update in test_state._process(event): + # The event should update the value. + assert test_state.num1 == 69 - # The event should update the value. - assert test_state.num1 == 69 - - # The delta should contain the changes, including computed vars. - assert update.delta == { - TestState.get_full_name(): {"num1": 69, "sum": 72.14}, - GrandchildState3.get_full_name(): {"computed": ""}, - } - assert update.events == [] + # The delta should contain the changes, including computed vars. + assert update.delta == { + TestState.get_full_name(): {"num1": 69, "sum": 72.14}, + GrandchildState3.get_full_name(): {"computed": ""}, + } + assert update.events == [] @pytest.mark.asyncio @@ -820,15 +819,15 @@ async def test_process_event_substate(test_state, child_state, grandchild_state) name=f"{ChildState.get_name()}.change_both", payload={"value": "hi", "count": 12}, ) - update = await test_state._process(event).__anext__() - assert child_state.value == "HI" - assert child_state.count == 24 - assert update.delta == { - # TestState.get_full_name(): {"sum": 3.14, "upper": ""}, - ChildState.get_full_name(): {"value": "HI", "count": 24}, - GrandchildState3.get_full_name(): {"computed": ""}, - } - test_state._clean() + async for update in test_state._process(event): + assert child_state.value == "HI" + assert child_state.count == 24 + assert update.delta == { + # TestState.get_full_name(): {"sum": 3.14, "upper": ""}, + ChildState.get_full_name(): {"value": "HI", "count": 24}, + GrandchildState3.get_full_name(): {"computed": ""}, + } + test_state._clean() # Test with the granchild state. assert grandchild_state.value2 == "" @@ -837,13 +836,13 @@ async def test_process_event_substate(test_state, child_state, grandchild_state) name=f"{GrandchildState.get_full_name()}.set_value2", payload={"value": "new"}, ) - update = await test_state._process(event).__anext__() - assert grandchild_state.value2 == "new" - assert update.delta == { - # TestState.get_full_name(): {"sum": 3.14, "upper": ""}, - GrandchildState.get_full_name(): {"value2": "new"}, - GrandchildState3.get_full_name(): {"computed": ""}, - } + async for update in test_state._process(event): + assert grandchild_state.value2 == "new" + assert update.delta == { + # TestState.get_full_name(): {"sum": 3.14, "upper": ""}, + GrandchildState.get_full_name(): {"value2": "new"}, + GrandchildState3.get_full_name(): {"computed": ""}, + } @pytest.mark.asyncio @@ -2917,10 +2916,10 @@ async def test_preprocess(app_module_mock, token, test_state, expected, mocker): events = updates[0].events assert len(events) == 2 - assert (await state._process(events[0]).__anext__()).delta == { - test_state.get_full_name(): {"num": 1} - } - assert (await state._process(events[1]).__anext__()).delta == exp_is_hydrated(state) + async for update in state._process(events[0]): + assert update.delta == {test_state.get_full_name(): {"num": 1}} + async for update in state._process(events[1]): + assert update.delta == exp_is_hydrated(state) if isinstance(app.state_manager, StateManagerRedis): await app.state_manager.close() @@ -2965,13 +2964,12 @@ async def test_preprocess_multiple_load_events(app_module_mock, token, mocker): events = updates[0].events assert len(events) == 3 - assert (await state._process(events[0]).__anext__()).delta == { - OnLoadState.get_full_name(): {"num": 1} - } - assert (await state._process(events[1]).__anext__()).delta == { - OnLoadState.get_full_name(): {"num": 2} - } - assert (await state._process(events[2]).__anext__()).delta == exp_is_hydrated(state) + async for update in state._process(events[0]): + assert update.delta == {OnLoadState.get_full_name(): {"num": 1}} + async for update in state._process(events[1]): + assert update.delta == {OnLoadState.get_full_name(): {"num": 2}} + async for update in state._process(events[2]): + assert update.delta == exp_is_hydrated(state) if isinstance(app.state_manager, StateManagerRedis): await app.state_manager.close() diff --git a/tests/units/test_style.py b/tests/units/test_style.py index 6ab00d561..e8ff5bd01 100644 --- a/tests/units/test_style.py +++ b/tests/units/test_style.py @@ -541,3 +541,7 @@ def test_style_update_with_var_data(): assert s2._var_data is not None assert "const red = true" in s2._var_data.hooks assert "const blue = true" in s2._var_data.hooks + + s3 = s1 | s2 + assert s3._var_data is not None + assert "_varData" not in s3 diff --git a/tests/units/test_var.py b/tests/units/test_var.py index 4a41adaf7..a72242814 100644 --- a/tests/units/test_var.py +++ b/tests/units/test_var.py @@ -1,6 +1,5 @@ import json import math -import sys import typing from typing import Dict, List, Mapping, Optional, Set, Tuple, Union, cast @@ -422,19 +421,13 @@ class Bar(rx.Base): @pytest.mark.parametrize( ("var", "var_type"), - ( - [ - (Var(_js_expr="", _var_type=Foo | Bar).guess_type(), Foo | Bar), - (Var(_js_expr="", _var_type=Foo | Bar).guess_type().bar, Union[int, str]), - ] - if sys.version_info >= (3, 10) - else [] - ) - + [ - (Var(_js_expr="", _var_type=Union[Foo, Bar]).guess_type(), Union[Foo, Bar]), - (Var(_js_expr="", _var_type=Union[Foo, Bar]).guess_type().baz, str), + [ + (Var(_js_expr="").to(Foo | Bar), Foo | Bar), + (Var(_js_expr="").to(Foo | Bar).bar, Union[int, str]), + (Var(_js_expr="").to(Union[Foo, Bar]), Union[Foo, Bar]), + (Var(_js_expr="").to(Union[Foo, Bar]).baz, str), ( - Var(_js_expr="", _var_type=Union[Foo, Bar]).guess_type().foo, + Var(_js_expr="").to(Union[Foo, Bar]).foo, Union[int, None], ), ], @@ -1076,19 +1069,19 @@ def test_array_operations(): assert str(array_var.reverse()) == "[1, 2, 3, 4, 5].slice().reverse()" assert ( str(ArrayVar.range(10)) - == "Array.from({ length: (10 - 0) / 1 }, (_, i) => 0 + i * 1)" + == "Array.from({ length: Math.ceil((10 - 0) / 1) }, (_, i) => 0 + i * 1)" ) assert ( str(ArrayVar.range(1, 10)) - == "Array.from({ length: (10 - 1) / 1 }, (_, i) => 1 + i * 1)" + == "Array.from({ length: Math.ceil((10 - 1) / 1) }, (_, i) => 1 + i * 1)" ) assert ( str(ArrayVar.range(1, 10, 2)) - == "Array.from({ length: (10 - 1) / 2 }, (_, i) => 1 + i * 2)" + == "Array.from({ length: Math.ceil((10 - 1) / 2) }, (_, i) => 1 + i * 2)" ) assert ( str(ArrayVar.range(1, 10, -1)) - == "Array.from({ length: (10 - 1) / -1 }, (_, i) => 1 + i * -1)" + == "Array.from({ length: Math.ceil((10 - 1) / -1) }, (_, i) => 1 + i * -1)" ) @@ -1358,7 +1351,7 @@ def test_unsupported_types_for_contains(var: Var): var: The base var. """ with pytest.raises(TypeError) as err: - assert var.contains(1) + assert var.contains(1) # pyright: ignore [reportAttributeAccessIssue] assert ( err.value.args[0] == f"Var of type {var._var_type} does not support contains check." @@ -1388,7 +1381,7 @@ def test_unsupported_types_for_string_contains(other): def test_unsupported_default_contains(): with pytest.raises(TypeError) as err: - assert 1 in Var(_js_expr="var", _var_type=str).guess_type() + assert 1 in Var(_js_expr="var", _var_type=str).guess_type() # pyright: ignore [reportOperatorIssue] assert ( err.value.args[0] == "'in' operator not supported for Var types, use Var.contains() instead." @@ -1884,3 +1877,19 @@ async def test_async_computed_var(): assert await my_state.async_computed_var == 3 assert await my_state.async_computed_var == 3 assert side_effect_counter == 2 + + +def test_var_data_hooks(): + var_data_str = VarData(hooks="what") + var_data_list = VarData(hooks=["what"]) + var_data_dict = VarData(hooks={"what": None}) + assert var_data_str == var_data_list == var_data_dict + + var_data_list_multiple = VarData(hooks=["what", "whot"]) + var_data_dict_multiple = VarData(hooks={"what": None, "whot": None}) + assert var_data_list_multiple == var_data_dict_multiple + + +def test_var_data_with_hooks_value(): + var_data = VarData(hooks={"what": VarData(hooks={"whot": VarData(hooks="whott")})}) + assert var_data == VarData(hooks=["what", "whot", "whott"])