diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml
index dbe069ae8..0bad7b996 100644
--- a/.pre-commit-config.yaml
+++ b/.pre-commit-config.yaml
@@ -3,7 +3,7 @@ fail_fast: true
repos:
- repo: https://github.com/charliermarsh/ruff-pre-commit
- rev: v0.8.2
+ rev: v0.9.3
hooks:
- id: ruff-format
args: [reflex, tests]
@@ -24,11 +24,12 @@ repos:
name: update-pyi-files
always_run: true
language: system
+ require_serial: true
description: 'Update pyi files as needed'
entry: python3 scripts/make_pyi.py
- repo: https://github.com/RobertCraigie/pyright-python
- rev: v1.1.392
+ rev: v1.1.393
hooks:
- id: pyright
args: [reflex, tests]
diff --git a/poetry.lock b/poetry.lock
index 125b71b55..f5007ee07 100644
--- a/poetry.lock
+++ b/poetry.lock
@@ -164,15 +164,15 @@ virtualenv = ["virtualenv (>=20.0.35)"]
[[package]]
name = "certifi"
-version = "2024.12.14"
+version = "2025.1.31"
description = "Python package for providing Mozilla's CA Bundle."
optional = false
python-versions = ">=3.6"
groups = ["main", "dev"]
markers = "python_version <= \"3.11\" or python_version >= \"3.12\""
files = [
- {file = "certifi-2024.12.14-py3-none-any.whl", hash = "sha256:1275f7a45be9464efc1173084eaa30f866fe2e47d389406136d332ed4967ec56"},
- {file = "certifi-2024.12.14.tar.gz", hash = "sha256:b650d30f370c2b724812bee08008be0c4163b163ddaec3f2546c1caf65f191db"},
+ {file = "certifi-2025.1.31-py3-none-any.whl", hash = "sha256:ca78db4565a652026a4db2bcdf68f2fb589ea80d0be70e03929ed730746b84fe"},
+ {file = "certifi-2025.1.31.tar.gz", hash = "sha256:3d5da6925056f6f18f119200434a4780a94263f10d1c21d032a6f6b2baa20651"},
]
[[package]]
@@ -618,15 +618,15 @@ test = ["pytest (>=6)"]
[[package]]
name = "fastapi"
-version = "0.115.7"
+version = "0.115.8"
description = "FastAPI framework, high performance, easy to learn, fast to code, ready for production"
optional = false
python-versions = ">=3.8"
groups = ["main"]
markers = "python_version <= \"3.11\" or python_version >= \"3.12\""
files = [
- {file = "fastapi-0.115.7-py3-none-any.whl", hash = "sha256:eb6a8c8bf7f26009e8147111ff15b5177a0e19bb4a45bc3486ab14804539d21e"},
- {file = "fastapi-0.115.7.tar.gz", hash = "sha256:0f106da6c01d88a6786b3248fb4d7a940d071f6f488488898ad5d354b25ed015"},
+ {file = "fastapi-0.115.8-py3-none-any.whl", hash = "sha256:753a96dd7e036b34eeef8babdfcfe3f28ff79648f86551eb36bfc1b0bf4a8cbf"},
+ {file = "fastapi-0.115.8.tar.gz", hash = "sha256:0ce9111231720190473e222cdf0f07f7206ad7e53ea02beb1d2dc36e2f0741e9"},
]
[package.dependencies]
@@ -1876,15 +1876,15 @@ files = [
[[package]]
name = "pyright"
-version = "1.1.392.post0"
+version = "1.1.393"
description = "Command line wrapper for pyright"
optional = false
python-versions = ">=3.7"
groups = ["dev"]
markers = "python_version <= \"3.11\" or python_version >= \"3.12\""
files = [
- {file = "pyright-1.1.392.post0-py3-none-any.whl", hash = "sha256:252f84458a46fa2f0fd4e2f91fc74f50b9ca52c757062e93f6c250c0d8329eb2"},
- {file = "pyright-1.1.392.post0.tar.gz", hash = "sha256:3b7f88de74a28dcfa90c7d90c782b6569a48c2be5f9d4add38472bdaac247ebd"},
+ {file = "pyright-1.1.393-py3-none-any.whl", hash = "sha256:8320629bb7a44ca90944ba599390162bf59307f3d9fb6e27da3b7011b8c17ae5"},
+ {file = "pyright-1.1.393.tar.gz", hash = "sha256:aeeb7ff4e0364775ef416a80111613f91a05c8e01e58ecfefc370ca0db7aed9c"},
]
[package.dependencies]
@@ -1998,25 +1998,25 @@ histogram = ["pygal", "pygaljs", "setuptools"]
[[package]]
name = "pytest-codspeed"
-version = "3.1.2"
+version = "3.2.0"
description = "Pytest plugin to create CodSpeed benchmarks"
optional = false
python-versions = ">=3.9"
groups = ["dev"]
markers = "python_version <= \"3.11\" or python_version >= \"3.12\""
files = [
- {file = "pytest_codspeed-3.1.2-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:aed496f873670ce0ea8f980a7c1a2c6a08f415e0ebdf207bf651b2d922103374"},
- {file = "pytest_codspeed-3.1.2-cp310-cp310-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:ee45b0b763f6b5fa5d74c7b91d694a9615561c428b320383660672f4471756e3"},
- {file = "pytest_codspeed-3.1.2-cp311-cp311-manylinux_2_17_aarch64.manylinux2014_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:c84e591a7a0f67d45e2dc9fd05b276971a3aabcab7478fe43363ebefec1358f4"},
- {file = "pytest_codspeed-3.1.2-cp311-cp311-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:c6ae6d094247156407770e6b517af70b98862dd59a3c31034aede11d5f71c32c"},
- {file = "pytest_codspeed-3.1.2-cp312-cp312-manylinux_2_17_aarch64.manylinux2014_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:d0f264991de5b5cdc118b96fc671386cca3f0f34e411482939bf2459dc599097"},
- {file = "pytest_codspeed-3.1.2-cp312-cp312-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:c0695a4bcd5ff04e8379124dba5d9795ea5e0cadf38be7a0406432fc1467b555"},
- {file = "pytest_codspeed-3.1.2-cp313-cp313-manylinux_2_17_aarch64.manylinux2014_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:6dc356c8dcaaa883af83310f397ac06c96fac9b8a1146e303d4b374b2cb46a18"},
- {file = "pytest_codspeed-3.1.2-cp313-cp313-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:cc8a5d0366322a75cf562f7d8d672d28c1cf6948695c4dddca50331e08f6b3d5"},
- {file = "pytest_codspeed-3.1.2-cp39-cp39-manylinux_2_17_aarch64.manylinux2014_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:6c5fe7a19b72f54f217480b3b527102579547b1de9fe3acd9e66cb4629ff46c8"},
- {file = "pytest_codspeed-3.1.2-cp39-cp39-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:b67205755a665593f6521a98317d02a9d07d6fdc593f6634de2c94dea47a3055"},
- {file = "pytest_codspeed-3.1.2-py3-none-any.whl", hash = "sha256:5e7ed0315e33496c5c07dba262b50303b8d0bc4c3d10bf1d422a41e70783f1cb"},
- {file = "pytest_codspeed-3.1.2.tar.gz", hash = "sha256:09c1733af3aab35e94a621aa510f2d2114f65591e6f644c42ca3f67547edad4b"},
+ {file = "pytest_codspeed-3.2.0-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:c5165774424c7ab8db7e7acdb539763a0e5657996effefdf0664d7fd95158d34"},
+ {file = "pytest_codspeed-3.2.0-cp310-cp310-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:9bd55f92d772592c04a55209950c50880413ae46876e66bd349ef157075ca26c"},
+ {file = "pytest_codspeed-3.2.0-cp311-cp311-manylinux_2_17_aarch64.manylinux2014_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:4cf6f56067538f4892baa8d7ab5ef4e45bb59033be1ef18759a2c7fc55b32035"},
+ {file = "pytest_codspeed-3.2.0-cp311-cp311-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:39a687b05c3d145642061b45ea78e47e12f13ce510104d1a2cda00eee0e36f58"},
+ {file = "pytest_codspeed-3.2.0-cp312-cp312-manylinux_2_17_aarch64.manylinux2014_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:46a1afaaa1ac4c2ca5b0700d31ac46d80a27612961d031067d73c6ccbd8d3c2b"},
+ {file = "pytest_codspeed-3.2.0-cp312-cp312-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:c48ce3af3dfa78413ed3d69d1924043aa1519048dbff46edccf8f35a25dab3c2"},
+ {file = "pytest_codspeed-3.2.0-cp313-cp313-manylinux_2_17_aarch64.manylinux2014_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:66692506d33453df48b36a84703448cb8b22953eea51f03fbb2eb758dc2bdc4f"},
+ {file = "pytest_codspeed-3.2.0-cp313-cp313-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:479774f80d0bdfafa16112700df4dbd31bf2a6757fac74795fd79c0a7b3c389b"},
+ {file = "pytest_codspeed-3.2.0-cp39-cp39-manylinux_2_17_aarch64.manylinux2014_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:109f9f4dd1088019c3b3f887d003b7d65f98a7736ca1d457884f5aa293e8e81c"},
+ {file = "pytest_codspeed-3.2.0-cp39-cp39-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:e2f69a03b52c9bb041aec1b8ee54b7b6c37a6d0a948786effa4c71157765b6da"},
+ {file = "pytest_codspeed-3.2.0-py3-none-any.whl", hash = "sha256:54b5c2e986d6a28e7b0af11d610ea57bd5531cec8326abe486f1b55b09d91c39"},
+ {file = "pytest_codspeed-3.2.0.tar.gz", hash = "sha256:f9d1b1a3b2c69cdc0490a1e8b1ced44bffbd0e8e21d81a7160cfdd923f6e8155"},
]
[package.dependencies]
@@ -2070,15 +2070,15 @@ dev = ["pre-commit", "pytest-asyncio", "tox"]
[[package]]
name = "pytest-playwright"
-version = "0.6.2"
+version = "0.7.0"
description = "A pytest wrapper with fixtures for Playwright to automate web browsers"
optional = false
python-versions = ">=3.9"
groups = ["dev"]
markers = "python_version <= \"3.11\" or python_version >= \"3.12\""
files = [
- {file = "pytest_playwright-0.6.2-py3-none-any.whl", hash = "sha256:0eff73bebe497b0158befed91e2f5fe94cfa17181f8b3acf575beed84e7e9043"},
- {file = "pytest_playwright-0.6.2.tar.gz", hash = "sha256:ff4054b19aa05df096ac6f74f0572591566aaf0f6d97f6cb9674db8a4d4ed06c"},
+ {file = "pytest_playwright-0.7.0-py3-none-any.whl", hash = "sha256:2516d0871fa606634bfe32afbcc0342d68da2dbff97fe3459849e9c428486da2"},
+ {file = "pytest_playwright-0.7.0.tar.gz", hash = "sha256:b3f2ea514bbead96d26376fac182f68dcd6571e7cb41680a89ff1673c05d60b6"},
]
[package.dependencies]
@@ -2180,15 +2180,15 @@ docs = ["sphinx"]
[[package]]
name = "pytz"
-version = "2024.2"
+version = "2025.1"
description = "World timezone definitions, modern and historical"
optional = false
python-versions = "*"
groups = ["dev"]
markers = "python_version <= \"3.11\" or python_version >= \"3.12\""
files = [
- {file = "pytz-2024.2-py2.py3-none-any.whl", hash = "sha256:31c7c1817eb7fae7ca4b8c7ee50c72f93aa2dd863de768e1ef4245d426aa0725"},
- {file = "pytz-2024.2.tar.gz", hash = "sha256:2aa355083c50a0f93fa581709deac0c9ad65cca8a9e9beac660adcbd493c798a"},
+ {file = "pytz-2025.1-py2.py3-none-any.whl", hash = "sha256:89dd22dca55b46eac6eda23b2d72721bf1bdfef212645d81513ef5d03038de57"},
+ {file = "pytz-2025.1.tar.gz", hash = "sha256:c2db42be2a2518b28e65f9207c4d05e6ff547d1efa4086469ef855e4ab70178e"},
]
[[package]]
@@ -2311,15 +2311,15 @@ ocsp = ["cryptography (>=36.0.1)", "pyopenssl (==23.2.1)", "requests (>=2.31.0)"
[[package]]
name = "reflex-hosting-cli"
-version = "0.1.33"
+version = "0.1.34"
description = "Reflex Hosting CLI"
optional = false
python-versions = "<4.0,>=3.9"
groups = ["main"]
markers = "python_version <= \"3.11\" or python_version >= \"3.12\""
files = [
- {file = "reflex_hosting_cli-0.1.33-py3-none-any.whl", hash = "sha256:3fe72fc448a231c61de4ac646f42c936c70e91330f616a23aec658f905d53bc4"},
- {file = "reflex_hosting_cli-0.1.33.tar.gz", hash = "sha256:81c4a896b106eea99f1cab53ea23a6e19802592ce0468cc38d93d440bc95263a"},
+ {file = "reflex_hosting_cli-0.1.34-py3-none-any.whl", hash = "sha256:eabc4dc7bf68e022a9388614c1a35b5ab36b01021df063d0c3356eda0e245264"},
+ {file = "reflex_hosting_cli-0.1.34.tar.gz", hash = "sha256:07be37fda6dcede0a5d4bc1fd1786d9a3df5ad4e49dc1b6ba335418563cfecec"},
]
[package.dependencies]
@@ -2410,31 +2410,31 @@ jupyter = ["ipywidgets (>=7.5.1,<9)"]
[[package]]
name = "ruff"
-version = "0.8.2"
+version = "0.9.3"
description = "An extremely fast Python linter and code formatter, written in Rust."
optional = false
python-versions = ">=3.7"
groups = ["dev"]
markers = "python_version <= \"3.11\" or python_version >= \"3.12\""
files = [
- {file = "ruff-0.8.2-py3-none-linux_armv6l.whl", hash = "sha256:c49ab4da37e7c457105aadfd2725e24305ff9bc908487a9bf8d548c6dad8bb3d"},
- {file = "ruff-0.8.2-py3-none-macosx_10_12_x86_64.whl", hash = "sha256:ec016beb69ac16be416c435828be702ee694c0d722505f9c1f35e1b9c0cc1bf5"},
- {file = "ruff-0.8.2-py3-none-macosx_11_0_arm64.whl", hash = "sha256:f05cdf8d050b30e2ba55c9b09330b51f9f97d36d4673213679b965d25a785f3c"},
- {file = "ruff-0.8.2-py3-none-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:60f578c11feb1d3d257b2fb043ddb47501ab4816e7e221fbb0077f0d5d4e7b6f"},
- {file = "ruff-0.8.2-py3-none-manylinux_2_17_armv7l.manylinux2014_armv7l.whl", hash = "sha256:cbd5cf9b0ae8f30eebc7b360171bd50f59ab29d39f06a670b3e4501a36ba5897"},
- {file = "ruff-0.8.2-py3-none-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:b402ddee3d777683de60ff76da801fa7e5e8a71038f57ee53e903afbcefdaa58"},
- {file = "ruff-0.8.2-py3-none-manylinux_2_17_ppc64.manylinux2014_ppc64.whl", hash = "sha256:705832cd7d85605cb7858d8a13d75993c8f3ef1397b0831289109e953d833d29"},
- {file = "ruff-0.8.2-py3-none-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:32096b41aaf7a5cc095fa45b4167b890e4c8d3fd217603f3634c92a541de7248"},
- {file = "ruff-0.8.2-py3-none-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:e769083da9439508833cfc7c23e351e1809e67f47c50248250ce1ac52c21fb93"},
- {file = "ruff-0.8.2-py3-none-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:5fe716592ae8a376c2673fdfc1f5c0c193a6d0411f90a496863c99cd9e2ae25d"},
- {file = "ruff-0.8.2-py3-none-musllinux_1_2_aarch64.whl", hash = "sha256:81c148825277e737493242b44c5388a300584d73d5774defa9245aaef55448b0"},
- {file = "ruff-0.8.2-py3-none-musllinux_1_2_armv7l.whl", hash = "sha256:d261d7850c8367704874847d95febc698a950bf061c9475d4a8b7689adc4f7fa"},
- {file = "ruff-0.8.2-py3-none-musllinux_1_2_i686.whl", hash = "sha256:1ca4e3a87496dc07d2427b7dd7ffa88a1e597c28dad65ae6433ecb9f2e4f022f"},
- {file = "ruff-0.8.2-py3-none-musllinux_1_2_x86_64.whl", hash = "sha256:729850feed82ef2440aa27946ab39c18cb4a8889c1128a6d589ffa028ddcfc22"},
- {file = "ruff-0.8.2-py3-none-win32.whl", hash = "sha256:ac42caaa0411d6a7d9594363294416e0e48fc1279e1b0e948391695db2b3d5b1"},
- {file = "ruff-0.8.2-py3-none-win_amd64.whl", hash = "sha256:2aae99ec70abf43372612a838d97bfe77d45146254568d94926e8ed5bbb409ea"},
- {file = "ruff-0.8.2-py3-none-win_arm64.whl", hash = "sha256:fb88e2a506b70cfbc2de6fae6681c4f944f7dd5f2fe87233a7233d888bad73e8"},
- {file = "ruff-0.8.2.tar.gz", hash = "sha256:b84f4f414dda8ac7f75075c1fa0b905ac0ff25361f42e6d5da681a465e0f78e5"},
+ {file = "ruff-0.9.3-py3-none-linux_armv6l.whl", hash = "sha256:7f39b879064c7d9670197d91124a75d118d00b0990586549949aae80cdc16624"},
+ {file = "ruff-0.9.3-py3-none-macosx_10_12_x86_64.whl", hash = "sha256:a187171e7c09efa4b4cc30ee5d0d55a8d6c5311b3e1b74ac5cb96cc89bafc43c"},
+ {file = "ruff-0.9.3-py3-none-macosx_11_0_arm64.whl", hash = "sha256:c59ab92f8e92d6725b7ded9d4a31be3ef42688a115c6d3da9457a5bda140e2b4"},
+ {file = "ruff-0.9.3-py3-none-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:2dc153c25e715be41bb228bc651c1e9b1a88d5c6e5ed0194fa0dfea02b026439"},
+ {file = "ruff-0.9.3-py3-none-manylinux_2_17_armv7l.manylinux2014_armv7l.whl", hash = "sha256:646909a1e25e0dc28fbc529eab8eb7bb583079628e8cbe738192853dbbe43af5"},
+ {file = "ruff-0.9.3-py3-none-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:5a5a46e09355695fbdbb30ed9889d6cf1c61b77b700a9fafc21b41f097bfbba4"},
+ {file = "ruff-0.9.3-py3-none-manylinux_2_17_ppc64.manylinux2014_ppc64.whl", hash = "sha256:c4bb09d2bbb394e3730d0918c00276e79b2de70ec2a5231cd4ebb51a57df9ba1"},
+ {file = "ruff-0.9.3-py3-none-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:96a87ec31dc1044d8c2da2ebbed1c456d9b561e7d087734336518181b26b3aa5"},
+ {file = "ruff-0.9.3-py3-none-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:9bb7554aca6f842645022fe2d301c264e6925baa708b392867b7a62645304df4"},
+ {file = "ruff-0.9.3-py3-none-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:cabc332b7075a914ecea912cd1f3d4370489c8018f2c945a30bcc934e3bc06a6"},
+ {file = "ruff-0.9.3-py3-none-musllinux_1_2_aarch64.whl", hash = "sha256:33866c3cc2a575cbd546f2cd02bdd466fed65118e4365ee538a3deffd6fcb730"},
+ {file = "ruff-0.9.3-py3-none-musllinux_1_2_armv7l.whl", hash = "sha256:006e5de2621304c8810bcd2ee101587712fa93b4f955ed0985907a36c427e0c2"},
+ {file = "ruff-0.9.3-py3-none-musllinux_1_2_i686.whl", hash = "sha256:ba6eea4459dbd6b1be4e6bfc766079fb9b8dd2e5a35aff6baee4d9b1514ea519"},
+ {file = "ruff-0.9.3-py3-none-musllinux_1_2_x86_64.whl", hash = "sha256:90230a6b8055ad47d3325e9ee8f8a9ae7e273078a66401ac66df68943ced029b"},
+ {file = "ruff-0.9.3-py3-none-win32.whl", hash = "sha256:eabe5eb2c19a42f4808c03b82bd313fc84d4e395133fb3fc1b1516170a31213c"},
+ {file = "ruff-0.9.3-py3-none-win_amd64.whl", hash = "sha256:040ceb7f20791dfa0e78b4230ee9dce23da3b64dd5848e40e3bf3ab76468dcf4"},
+ {file = "ruff-0.9.3-py3-none-win_arm64.whl", hash = "sha256:800d773f6d4d33b0a3c60e2c6ae8f4c202ea2de056365acfa519aa48acf28e0b"},
+ {file = "ruff-0.9.3.tar.gz", hash = "sha256:8293f89985a090ebc3ed1064df31f3b4b56320cdfcec8b60d3295bddb955c22a"},
]
[[package]]
@@ -3183,4 +3183,4 @@ type = ["pytest-mypy"]
[metadata]
lock-version = "2.1"
python-versions = ">=3.10, <4.0"
-content-hash = "822150bcbf41e5cbb61da0a059b41d8971e3c6c974c8af4be7ef55126648aea1"
+content-hash = "3b7e6e6e872c68f951f191d85a7d76fe1dd86caf32e2143a53a3152a3686fc7f"
diff --git a/pyproject.toml b/pyproject.toml
index 8d0b37a23..2b5507a1d 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -41,7 +41,7 @@ wrapt = [
{ version = ">=1.11.0,<2.0", python = "<3.11" },
]
packaging = ">=23.1,<25.0"
-reflex-hosting-cli = ">=0.1.29,<2.0"
+reflex-hosting-cli = ">=0.1.29"
charset-normalizer = ">=3.3.2,<4.0"
wheel = ">=0.42.0,<1.0"
build = ">=1.0.3,<2.0"
@@ -61,7 +61,7 @@ dill = ">=0.3.8"
toml = ">=0.10.2,<1.0"
pytest-asyncio = ">=0.24.0"
pytest-cov = ">=4.0.0,<7.0"
-ruff = "0.8.2"
+ruff = "0.9.3"
pandas = ">=2.1.1,<3.0"
pillow = ">=10.0.0,<12.0"
plotly = ">=5.13.0,<6.0"
@@ -88,7 +88,7 @@ target-version = "py310"
output-format = "concise"
lint.isort.split-on-trailing-comma = false
lint.select = ["ANN001","B", "C4", "D", "E", "ERA", "F", "FURB", "I", "N", "PERF", "PGH", "PTH", "RUF", "SIM", "T", "TRY", "W"]
-lint.ignore = ["B008", "D205", "E501", "F403", "SIM115", "RUF006", "RUF012", "TRY0"]
+lint.ignore = ["B008", "D205", "E501", "F403", "SIM115", "RUF006", "RUF008", "RUF012", "TRY0"]
lint.pydocstyle.convention = "google"
[tool.ruff.lint.per-file-ignores]
diff --git a/reflex/.templates/jinja/web/pages/_app.js.jinja2 b/reflex/.templates/jinja/web/pages/_app.js.jinja2
index 40e31dee6..ee3e24540 100644
--- a/reflex/.templates/jinja/web/pages/_app.js.jinja2
+++ b/reflex/.templates/jinja/web/pages/_app.js.jinja2
@@ -38,13 +38,13 @@ export default function MyApp({ Component, pageProps }) {
}, []);
return (
-
-
-
-
-
-
-
+
+
+
+
+
+
+
);
}
diff --git a/reflex/.templates/web/utils/state.js b/reflex/.templates/web/utils/state.js
index 009910a32..2f09ac2de 100644
--- a/reflex/.templates/web/utils/state.js
+++ b/reflex/.templates/web/utils/state.js
@@ -227,8 +227,8 @@ export const applyEvent = async (event, socket) => {
a.href = eval?.(
event.payload.url.replace(
"getBackendURL(env.UPLOAD)",
- `"${getBackendURL(env.UPLOAD)}"`
- )
+ `"${getBackendURL(env.UPLOAD)}"`,
+ ),
);
}
a.download = event.payload.filename;
@@ -341,7 +341,7 @@ export const applyRestEvent = async (event, socket) => {
event.payload.files,
event.payload.upload_id,
event.payload.on_upload_progress,
- socket
+ socket,
);
return false;
}
@@ -408,7 +408,7 @@ export const connect = async (
dispatch,
transports,
setConnectErrors,
- client_storage = {}
+ client_storage = {},
) => {
// Get backend URL object from the endpoint.
const endpoint = getBackendURL(EVENTURL);
@@ -499,7 +499,7 @@ export const uploadFiles = async (
files,
upload_id,
on_upload_progress,
- socket
+ socket,
) => {
// return if there's no file to upload
if (files === undefined || files.length === 0) {
@@ -604,7 +604,7 @@ export const Event = (
name,
payload = {},
event_actions = {},
- handler = null
+ handler = null,
) => {
return { name, payload, handler, event_actions };
};
@@ -631,7 +631,7 @@ export const hydrateClientStorage = (client_storage) => {
for (const state_key in client_storage.local_storage) {
const options = client_storage.local_storage[state_key];
const local_storage_value = localStorage.getItem(
- options.name || state_key
+ options.name || state_key,
);
if (local_storage_value !== null) {
client_storage_values[state_key] = local_storage_value;
@@ -642,7 +642,7 @@ export const hydrateClientStorage = (client_storage) => {
for (const state_key in client_storage.session_storage) {
const session_options = client_storage.session_storage[state_key];
const session_storage_value = sessionStorage.getItem(
- session_options.name || state_key
+ session_options.name || state_key,
);
if (session_storage_value != null) {
client_storage_values[state_key] = session_storage_value;
@@ -667,7 +667,7 @@ export const hydrateClientStorage = (client_storage) => {
const applyClientStorageDelta = (client_storage, delta) => {
// find the main state and check for is_hydrated
const unqualified_states = Object.keys(delta).filter(
- (key) => key.split(".").length === 1
+ (key) => key.split(".").length === 1,
);
if (unqualified_states.length === 1) {
const main_state = delta[unqualified_states[0]];
@@ -701,7 +701,7 @@ const applyClientStorageDelta = (client_storage, delta) => {
const session_options = client_storage.session_storage[state_key];
sessionStorage.setItem(
session_options.name || state_key,
- delta[substate][key]
+ delta[substate][key],
);
}
}
@@ -721,7 +721,7 @@ const applyClientStorageDelta = (client_storage, delta) => {
export const useEventLoop = (
dispatch,
initial_events = () => [],
- client_storage = {}
+ client_storage = {},
) => {
const socket = useRef(null);
const router = useRouter();
@@ -735,7 +735,7 @@ export const useEventLoop = (
event_actions = events.reduce(
(acc, e) => ({ ...acc, ...e.event_actions }),
- event_actions ?? {}
+ event_actions ?? {},
);
const _e = args.filter((o) => o?.preventDefault !== undefined)[0];
@@ -763,7 +763,7 @@ export const useEventLoop = (
debounce(
combined_name,
() => queueEvents(events, socket),
- event_actions.debounce
+ event_actions.debounce,
);
} else {
queueEvents(events, socket);
@@ -782,7 +782,7 @@ export const useEventLoop = (
query,
asPath,
}))(router),
- }))
+ })),
);
sentHydrate.current = true;
}
@@ -817,13 +817,9 @@ export const useEventLoop = (
};
}, []);
- // Main event loop.
+ // Handle socket connect/disconnect.
useEffect(() => {
- // Skip if the router is not ready.
- if (!router.isReady) {
- return;
- }
- // only use websockets if state is present
+ // only use websockets if state is present and backend is not disabled (reflex cloud).
if (Object.keys(initialState).length > 1 && !isBackendDisabled()) {
// Initialize the websocket connection.
if (!socket.current) {
@@ -832,16 +828,31 @@ export const useEventLoop = (
dispatch,
["websocket"],
setConnectErrors,
- client_storage
+ client_storage,
);
}
- (async () => {
- // Process all outstanding events.
- while (event_queue.length > 0 && !event_processing) {
- await processEvent(socket.current);
- }
- })();
}
+
+ // Cleanup function.
+ return () => {
+ if (socket.current) {
+ socket.current.disconnect();
+ }
+ };
+ }, []);
+
+ // Main event loop.
+ useEffect(() => {
+ // Skip if the router is not ready.
+ if (!router.isReady || isBackendDisabled()) {
+ return;
+ }
+ (async () => {
+ // Process all outstanding events.
+ while (event_queue.length > 0 && !event_processing) {
+ await processEvent(socket.current);
+ }
+ })();
});
// localStorage event handling
@@ -865,7 +876,7 @@ export const useEventLoop = (
vars[storage_to_state_map[e.key]] = e.newValue;
const event = Event(
`${state_name}.reflex___state____update_vars_internal_state.update_vars_internal`,
- { vars: vars }
+ { vars: vars },
);
addEvents([event], e);
}
@@ -958,7 +969,7 @@ export const getRefValues = (refs) => {
return refs.map((ref) =>
ref.current
? ref.current.value || ref.current.getAttribute("aria-valuenow")
- : null
+ : null,
);
};
diff --git a/reflex/app.py b/reflex/app.py
index a8d5d156d..d9104ece6 100644
--- a/reflex/app.py
+++ b/reflex/app.py
@@ -54,6 +54,7 @@ from reflex.compiler.compiler import ExecutorSafeFunctions, compile_theme
from reflex.components.base.app_wrap import AppWrap
from reflex.components.base.error_boundary import ErrorBoundary
from reflex.components.base.fragment import Fragment
+from reflex.components.base.strict_mode import StrictMode
from reflex.components.component import (
Component,
ComponentStyle,
@@ -151,7 +152,7 @@ def default_backend_exception_handler(exception: Exception) -> EventSpec:
position="top-center",
id="backend_error",
style={"width": "500px"},
- ) # pyright: ignore [reportReturnType]
+ )
else:
error_message.insert(0, "An error occurred.")
return window_alert("\n".join(error_message))
@@ -918,11 +919,17 @@ class App(MiddlewareMixin, LifespanMixin):
if not var._cache:
continue
deps = var._deps(objclass=state)
- for dep in deps:
- if dep not in state.vars and dep not in state.backend_vars:
- raise exceptions.VarDependencyError(
- f"ComputedVar {var._js_expr} on state {state.__name__} has an invalid dependency {dep}"
- )
+ for state_name, dep_set in deps.items():
+ state_cls = (
+ state.get_root_state().get_class_substate(state_name)
+ if state_name != state.get_full_name()
+ else state
+ )
+ for dep in dep_set:
+ if dep not in state_cls.vars and dep not in state_cls.backend_vars:
+ raise exceptions.VarDependencyError(
+ f"ComputedVar {var._js_expr} on state {state.__name__} has an invalid dependency {state_name}.{dep}"
+ )
for substate in state.class_subclasses:
self._validate_var_dependencies(substate)
@@ -960,27 +967,23 @@ class App(MiddlewareMixin, LifespanMixin):
# If a theme component was provided, wrap the app with it
app_wrappers[(20, "Theme")] = self.theme
- should_compile = self._should_compile()
-
- for route in self._unevaluated_pages:
- console.debug(f"Evaluating page: {route}")
- self._compile_page(route, save_page=should_compile)
-
- # Add the optional endpoints (_upload)
- self._add_optional_endpoints()
-
- if not should_compile:
- return
-
# Get the env mode.
config = get_config()
- self._validate_var_dependencies()
- self._setup_overlay_component()
- self._setup_error_boundary()
+ if config.react_strict_mode:
+ app_wrappers[(200, "StrictMode")] = StrictMode.create()
- if config.show_built_with_reflex:
- self._setup_sticky_badge()
+ should_compile = self._should_compile()
+
+ if not should_compile:
+ for route in self._unevaluated_pages:
+ console.debug(f"Evaluating page: {route}")
+ self._compile_page(route, save_page=should_compile)
+
+ # Add the optional endpoints (_upload)
+ self._add_optional_endpoints()
+
+ return
# Create a progress bar.
progress = Progress(
@@ -990,16 +993,33 @@ class App(MiddlewareMixin, LifespanMixin):
)
# try to be somewhat accurate - but still not 100%
- adhoc_steps_without_executor = 6
+ adhoc_steps_without_executor = 7
fixed_pages_within_executor = 5
progress.start()
task = progress.add_task(
f"[{get_compilation_time()}] Compiling:",
total=len(self._pages)
+ + (len(self._unevaluated_pages) * 2)
+ fixed_pages_within_executor
+ adhoc_steps_without_executor,
)
+ for route in self._unevaluated_pages:
+ console.debug(f"Evaluating page: {route}")
+ self._compile_page(route, save_page=should_compile)
+ progress.advance(task)
+
+ # Add the optional endpoints (_upload)
+ self._add_optional_endpoints()
+
+ self._validate_var_dependencies()
+ self._setup_overlay_component()
+ self._setup_error_boundary()
+ if config.show_built_with_reflex:
+ self._setup_sticky_badge()
+
+ progress.advance(task)
+
# Store the compile results.
compile_results = []
@@ -1306,7 +1326,7 @@ class App(MiddlewareMixin, LifespanMixin):
):
raise ValueError(
f"Provided custom {handler_domain} exception handler `{_fn_name}` has the wrong argument order."
- f"Expected `{required_arg}` as the {required_arg_index+1} argument but got `{list(arg_annotations.keys())[required_arg_index]}`"
+ f"Expected `{required_arg}` as the {required_arg_index + 1} argument but got `{list(arg_annotations.keys())[required_arg_index]}`"
)
if not issubclass(arg_annotations[required_arg], Exception):
diff --git a/reflex/compiler/compiler.py b/reflex/compiler/compiler.py
index 218dd0c55..c2a76aad3 100644
--- a/reflex/compiler/compiler.py
+++ b/reflex/compiler/compiler.py
@@ -239,11 +239,19 @@ def _compile_components(
component_renders.append(component_render)
imports = utils.merge_imports(imports, component_imports)
+ dynamic_imports = {
+ comp_import: None
+ for comp_render in component_renders
+ if "dynamic_imports" in comp_render
+ for comp_import in comp_render["dynamic_imports"]
+ }
+
# Compile the components page.
return (
templates.COMPONENTS.render(
imports=utils.compile_imports(imports),
components=component_renders,
+ dynamic_imports=dynamic_imports,
),
imports,
)
diff --git a/reflex/compiler/utils.py b/reflex/compiler/utils.py
index d145e6c0b..57241fea9 100644
--- a/reflex/compiler/utils.py
+++ b/reflex/compiler/utils.py
@@ -2,12 +2,15 @@
from __future__ import annotations
+import asyncio
+import concurrent.futures
import traceback
from datetime import datetime
from pathlib import Path
from typing import Any, Callable, Dict, Optional, Type, Union
from urllib.parse import urlparse
+from reflex.utils.exec import is_in_app_harness
from reflex.utils.prerequisites import get_web_dir
from reflex.vars.base import Var
@@ -33,7 +36,7 @@ from reflex.components.base import (
)
from reflex.components.component import Component, ComponentStyle, CustomComponent
from reflex.istate.storage import Cookie, LocalStorage, SessionStorage
-from reflex.state import BaseState
+from reflex.state import BaseState, _resolve_delta
from reflex.style import Style
from reflex.utils import console, format, imports, path_ops
from reflex.utils.imports import ImportVar, ParsedImportDict
@@ -177,7 +180,24 @@ def compile_state(state: Type[BaseState]) -> dict:
initial_state = state(_reflex_internal_init=True).dict(
initial=True, include_computed=False
)
- return initial_state
+ try:
+ _ = asyncio.get_running_loop()
+ except RuntimeError:
+ pass
+ else:
+ if is_in_app_harness():
+ # Playwright tests already have an event loop running, so we can't use asyncio.run.
+ with concurrent.futures.ThreadPoolExecutor() as pool:
+ resolved_initial_state = pool.submit(
+ asyncio.run, _resolve_delta(initial_state)
+ ).result()
+ console.warn(
+ f"Had to get initial state in a thread 🤮 {resolved_initial_state}",
+ )
+ return resolved_initial_state
+
+    # Normally the compile runs before any event loop starts, so asyncio.run is available for calling.
+ return asyncio.run(_resolve_delta(initial_state))
def _compile_client_storage_field(
@@ -300,6 +320,7 @@ def compile_custom_component(
"render": render.render(),
"hooks": render._get_all_hooks(),
"custom_code": render._get_all_custom_code(),
+ "dynamic_imports": render._get_all_dynamic_imports(),
},
imports,
)
diff --git a/reflex/components/base/strict_mode.py b/reflex/components/base/strict_mode.py
new file mode 100644
index 000000000..46b01ad87
--- /dev/null
+++ b/reflex/components/base/strict_mode.py
@@ -0,0 +1,10 @@
+"""Module for the StrictMode component."""
+
+from reflex.components.component import Component
+
+
+class StrictMode(Component):
+ """A React strict mode component to enable strict mode for its children."""
+
+ library = "react"
+ tag = "StrictMode"
diff --git a/reflex/components/base/strict_mode.pyi b/reflex/components/base/strict_mode.pyi
new file mode 100644
index 000000000..9005c0222
--- /dev/null
+++ b/reflex/components/base/strict_mode.pyi
@@ -0,0 +1,57 @@
+"""Stub file for reflex/components/base/strict_mode.py"""
+
+# ------------------- DO NOT EDIT ----------------------
+# This file was generated by `reflex/utils/pyi_generator.py`!
+# ------------------------------------------------------
+from typing import Any, Dict, Optional, Union, overload
+
+from reflex.components.component import Component
+from reflex.event import BASE_STATE, EventType
+from reflex.style import Style
+from reflex.vars.base import Var
+
+class StrictMode(Component):
+ @overload
+ @classmethod
+ def create( # type: ignore
+ cls,
+ *children,
+ style: Optional[Style] = None,
+ key: Optional[Any] = None,
+ id: Optional[Any] = None,
+ class_name: Optional[Any] = None,
+ autofocus: Optional[bool] = None,
+ custom_attrs: Optional[Dict[str, Union[Var, Any]]] = None,
+ on_blur: Optional[EventType[[], BASE_STATE]] = None,
+ on_click: Optional[EventType[[], BASE_STATE]] = None,
+ on_context_menu: Optional[EventType[[], BASE_STATE]] = None,
+ on_double_click: Optional[EventType[[], BASE_STATE]] = None,
+ on_focus: Optional[EventType[[], BASE_STATE]] = None,
+ on_mount: Optional[EventType[[], BASE_STATE]] = None,
+ on_mouse_down: Optional[EventType[[], BASE_STATE]] = None,
+ on_mouse_enter: Optional[EventType[[], BASE_STATE]] = None,
+ on_mouse_leave: Optional[EventType[[], BASE_STATE]] = None,
+ on_mouse_move: Optional[EventType[[], BASE_STATE]] = None,
+ on_mouse_out: Optional[EventType[[], BASE_STATE]] = None,
+ on_mouse_over: Optional[EventType[[], BASE_STATE]] = None,
+ on_mouse_up: Optional[EventType[[], BASE_STATE]] = None,
+ on_scroll: Optional[EventType[[], BASE_STATE]] = None,
+ on_unmount: Optional[EventType[[], BASE_STATE]] = None,
+ **props,
+ ) -> "StrictMode":
+ """Create the component.
+
+ Args:
+ *children: The children of the component.
+ style: The style of the component.
+ key: A unique key for the component.
+ id: The id for the component.
+ class_name: The class name for the component.
+ autofocus: Whether the component should take the focus once the page is loaded
+ custom_attrs: custom attribute
+ **props: The props of the component.
+
+ Returns:
+ The component.
+ """
+ ...
diff --git a/reflex/components/component.py b/reflex/components/component.py
index 440a408df..6d1264f4d 100644
--- a/reflex/components/component.py
+++ b/reflex/components/component.py
@@ -623,8 +623,7 @@ class Component(BaseComponent, ABC):
if props is None:
# Add component props to the tag.
props = {
- attr[:-1] if attr.endswith("_") else attr: getattr(self, attr)
- for attr in self.get_props()
+ attr.removesuffix("_"): getattr(self, attr) for attr in self.get_props()
}
# Add ref to element if `id` is not None.
diff --git a/reflex/components/datadisplay/dataeditor.py b/reflex/components/datadisplay/dataeditor.py
index 338fb2e44..dfac0452a 100644
--- a/reflex/components/datadisplay/dataeditor.py
+++ b/reflex/components/datadisplay/dataeditor.py
@@ -347,7 +347,7 @@ class DataEditor(NoSSRComponent):
data_callback = f"getData_{editor_id}"
self.get_cell_content = Var(_js_expr=data_callback)
- code = [f"function {data_callback}([col, row])" "{"]
+ code = [f"function {data_callback}([col, row]){{"]
columns_path = str(self.columns)
data_path = str(self.data)
diff --git a/reflex/components/sonner/toast.pyi b/reflex/components/sonner/toast.pyi
index 829e959d5..632fb0d87 100644
--- a/reflex/components/sonner/toast.pyi
+++ b/reflex/components/sonner/toast.pyi
@@ -177,7 +177,7 @@ class ToastNamespace(ComponentNamespace):
@staticmethod
def __call__(
message: Union[str, Var] = "", level: Optional[str] = None, **props
- ) -> "Optional[EventSpec]":
+ ) -> "EventSpec":
"""Send a toast message.
Args:
diff --git a/reflex/custom_components/custom_components.py b/reflex/custom_components/custom_components.py
index 024e77eee..a54004803 100644
--- a/reflex/custom_components/custom_components.py
+++ b/reflex/custom_components/custom_components.py
@@ -772,7 +772,7 @@ def _validate_project_info():
pyproject_toml = _get_package_config()
project = pyproject_toml["project"]
console.print(
- f'Double check the information before publishing: {project["name"]} version {project["version"]}'
+ f"Double check the information before publishing: {project['name']} version {project['version']}"
)
console.print("Update or enter to keep the current information.")
@@ -784,7 +784,7 @@ def _validate_project_info():
author["name"] = console.ask("Author Name", default=author.get("name", ""))
author["email"] = console.ask("Author Email", default=author.get("email", ""))
- console.print(f'Current keywords are: {project.get("keywords") or []}')
+ console.print(f"Current keywords are: {project.get('keywords') or []}")
keyword_action = console.ask(
"Keep, replace or append?", choices=["k", "r", "a"], default="k"
)
diff --git a/reflex/event.py b/reflex/event.py
index 5ce0f3dc1..f35e88389 100644
--- a/reflex/event.py
+++ b/reflex/event.py
@@ -332,7 +332,7 @@ class EventSpec(EventActionsMixin):
arg = None
try:
for arg in args:
- values.append(LiteralVar.create(value=arg)) # noqa: PERF401
+ values.append(LiteralVar.create(value=arg)) # noqa: PERF401, RUF100
except TypeError as e:
raise EventHandlerTypeError(
f"Arguments to event handlers must be Vars or JSON-serializable. Got {arg} of type {type(arg)}."
diff --git a/reflex/middleware/hydrate_middleware.py b/reflex/middleware/hydrate_middleware.py
index 2198b82c2..2dea54e17 100644
--- a/reflex/middleware/hydrate_middleware.py
+++ b/reflex/middleware/hydrate_middleware.py
@@ -8,7 +8,7 @@ from typing import TYPE_CHECKING, Optional
from reflex import constants
from reflex.event import Event, get_hydrate_event
from reflex.middleware.middleware import Middleware
-from reflex.state import BaseState, StateUpdate
+from reflex.state import BaseState, StateUpdate, _resolve_delta
if TYPE_CHECKING:
from reflex.app import App
@@ -42,7 +42,7 @@ class HydrateMiddleware(Middleware):
setattr(state, constants.CompileVars.IS_HYDRATED, False)
# Get the initial state.
- delta = state.dict()
+ delta = await _resolve_delta(state.dict())
# since a full dict was captured, clean any dirtiness
state._clean()
diff --git a/reflex/state.py b/reflex/state.py
index 6c74d5e55..92aaa4710 100644
--- a/reflex/state.py
+++ b/reflex/state.py
@@ -15,7 +15,6 @@ import time
import typing
import uuid
from abc import ABC, abstractmethod
-from collections import defaultdict
from hashlib import md5
from pathlib import Path
from types import FunctionType, MethodType
@@ -329,6 +328,25 @@ def get_var_for_field(cls: Type[BaseState], f: ModelField):
)
+async def _resolve_delta(delta: Delta) -> Delta:
+ """Await all coroutines in the delta.
+
+ Args:
+ delta: The delta to process.
+
+ Returns:
+ The same delta dict with all coroutines resolved to their return value.
+ """
+ tasks = {}
+ for state_name, state_delta in delta.items():
+ for var_name, value in state_delta.items():
+ if asyncio.iscoroutine(value):
+ tasks[state_name, var_name] = asyncio.create_task(value)
+ for (state_name, var_name), task in tasks.items():
+ delta[state_name][var_name] = await task
+ return delta
+
+
class BaseState(Base, ABC, extra=pydantic.Extra.allow):
"""The state of the app."""
@@ -356,11 +374,8 @@ class BaseState(Base, ABC, extra=pydantic.Extra.allow):
# A set of subclassses of this class.
class_subclasses: ClassVar[Set[Type[BaseState]]] = set()
- # Mapping of var name to set of computed variables that depend on it
- _computed_var_dependencies: ClassVar[Dict[str, Set[str]]] = {}
-
- # Mapping of var name to set of substates that depend on it
- _substate_var_dependencies: ClassVar[Dict[str, Set[str]]] = {}
+ # Mapping of var name to set of (state_full_name, var_name) that depend on it.
+ _var_dependencies: ClassVar[Dict[str, Set[Tuple[str, str]]]] = {}
# Set of vars which always need to be recomputed
_always_dirty_computed_vars: ClassVar[Set[str]] = set()
@@ -368,6 +383,9 @@ class BaseState(Base, ABC, extra=pydantic.Extra.allow):
# Set of substates which always need to be recomputed
_always_dirty_substates: ClassVar[Set[str]] = set()
+ # Set of states which might need to be recomputed if vars in this state change.
+ _potentially_dirty_states: ClassVar[Set[str]] = set()
+
# The parent state.
parent_state: Optional[BaseState] = None
@@ -519,6 +537,7 @@ class BaseState(Base, ABC, extra=pydantic.Extra.allow):
# Reset dirty substate tracking for this class.
cls._always_dirty_substates = set()
+ cls._potentially_dirty_states = set()
# Get the parent vars.
parent_state = cls.get_parent_state()
@@ -622,8 +641,7 @@ class BaseState(Base, ABC, extra=pydantic.Extra.allow):
setattr(cls, name, handler)
# Initialize per-class var dependency tracking.
- cls._computed_var_dependencies = defaultdict(set)
- cls._substate_var_dependencies = defaultdict(set)
+ cls._var_dependencies = {}
cls._init_var_dependency_dicts()
@staticmethod
@@ -768,26 +786,27 @@ class BaseState(Base, ABC, extra=pydantic.Extra.allow):
Additional updates tracking dicts for vars and substates that always
need to be recomputed.
"""
- inherited_vars = set(cls.inherited_vars).union(
- set(cls.inherited_backend_vars),
- )
for cvar_name, cvar in cls.computed_vars.items():
- # Add the dependencies.
- for var in cvar._deps(objclass=cls):
- cls._computed_var_dependencies[var].add(cvar_name)
- if var in inherited_vars:
- # track that this substate depends on its parent for this var
- state_name = cls.get_name()
- parent_state = cls.get_parent_state()
- while parent_state is not None and var in {
- **parent_state.vars,
- **parent_state.backend_vars,
+ if not cvar._cache:
+ # Do not perform dep calculation when cache=False (these are always dirty).
+ continue
+ for state_name, dvar_set in cvar._deps(objclass=cls).items():
+ state_cls = cls.get_root_state().get_class_substate(state_name)
+ for dvar in dvar_set:
+ defining_state_cls = state_cls
+ while dvar in {
+ *defining_state_cls.inherited_vars,
+ *defining_state_cls.inherited_backend_vars,
}:
- parent_state._substate_var_dependencies[var].add(state_name)
- state_name, parent_state = (
- parent_state.get_name(),
- parent_state.get_parent_state(),
- )
+ parent_state = defining_state_cls.get_parent_state()
+ if parent_state is not None:
+ defining_state_cls = parent_state
+ defining_state_cls._var_dependencies.setdefault(dvar, set()).add(
+ (cls.get_full_name(), cvar_name)
+ )
+ defining_state_cls._potentially_dirty_states.add(
+ cls.get_full_name()
+ )
# ComputedVar with cache=False always need to be recomputed
cls._always_dirty_computed_vars = {
@@ -902,6 +921,17 @@ class BaseState(Base, ABC, extra=pydantic.Extra.allow):
raise ValueError(f"Only one parent state is allowed {parent_states}.")
return parent_states[0] if len(parent_states) == 1 else None
+ @classmethod
+ @functools.lru_cache()
+ def get_root_state(cls) -> Type[BaseState]:
+ """Get the root state.
+
+ Returns:
+ The root state.
+ """
+ parent_state = cls.get_parent_state()
+ return cls if parent_state is None else parent_state.get_root_state()
+
@classmethod
def get_substates(cls) -> set[Type[BaseState]]:
"""Get the substates of the state.
@@ -1351,7 +1381,7 @@ class BaseState(Base, ABC, extra=pydantic.Extra.allow):
super().__setattr__(name, value)
# Add the var to the dirty list.
- if name in self.vars or name in self._computed_var_dependencies:
+ if name in self.base_vars:
self.dirty_vars.add(name)
self._mark_dirty()
@@ -1422,64 +1452,21 @@ class BaseState(Base, ABC, extra=pydantic.Extra.allow):
return self.substates[path[0]].get_substate(path[1:])
@classmethod
- def _get_common_ancestor(cls, other: Type[BaseState]) -> str:
- """Find the name of the nearest common ancestor shared by this and the other state.
-
- Args:
- other: The other state.
+ def _get_potentially_dirty_states(cls) -> set[type[BaseState]]:
+ """Get substates which may have dirty vars due to dependencies.
Returns:
- Full name of the nearest common ancestor.
+ The set of potentially dirty substate classes.
"""
- common_ancestor_parts = []
- for part1, part2 in zip(
- cls.get_full_name().split("."),
- other.get_full_name().split("."),
- strict=True,
- ):
- if part1 != part2:
- break
- common_ancestor_parts.append(part1)
- return ".".join(common_ancestor_parts)
-
- @classmethod
- def _determine_missing_parent_states(
- cls, target_state_cls: Type[BaseState]
- ) -> tuple[str, list[str]]:
- """Determine the missing parent states between the target_state_cls and common ancestor of this state.
-
- Args:
- target_state_cls: The class of the state to find missing parent states for.
-
- Returns:
- The name of the common ancestor and the list of missing parent states.
- """
- common_ancestor_name = cls._get_common_ancestor(target_state_cls)
- common_ancestor_parts = common_ancestor_name.split(".")
- target_state_parts = tuple(target_state_cls.get_full_name().split("."))
- relative_target_state_parts = target_state_parts[len(common_ancestor_parts) :]
-
- # Determine which parent states to fetch from the common ancestor down to the target_state_cls.
- fetch_parent_states = [common_ancestor_name]
- for relative_parent_state_name in relative_target_state_parts:
- fetch_parent_states.append(
- ".".join((fetch_parent_states[-1], relative_parent_state_name))
- )
-
- return common_ancestor_name, fetch_parent_states[1:-1]
-
- def _get_parent_states(self) -> list[tuple[str, BaseState]]:
- """Get all parent state instances up to the root of the state tree.
-
- Returns:
- A list of tuples containing the name and the instance of each parent state.
- """
- parent_states_with_name = []
- parent_state = self
- while parent_state.parent_state is not None:
- parent_state = parent_state.parent_state
- parent_states_with_name.append((parent_state.get_full_name(), parent_state))
- return parent_states_with_name
+ return {
+ cls.get_class_substate(substate_name)
+ for substate_name in cls._always_dirty_substates
+ }.union(
+ {
+ cls.get_root_state().get_class_substate(substate_name)
+ for substate_name in cls._potentially_dirty_states
+ }
+ )
def _get_root_state(self) -> BaseState:
"""Get the root state of the state tree.
@@ -1492,55 +1479,38 @@ class BaseState(Base, ABC, extra=pydantic.Extra.allow):
parent_state = parent_state.parent_state
return parent_state
- async def _populate_parent_states(self, target_state_cls: Type[BaseState]):
- """Populate substates in the tree between the target_state_cls and common ancestor of this state.
+ async def _get_state_from_redis(self, state_cls: Type[T_STATE]) -> T_STATE:
+ """Get a state instance from redis.
Args:
- target_state_cls: The class of the state to populate parent states for.
+ state_cls: The class of the state.
Returns:
- The parent state instance of target_state_cls.
+ The instance of state_cls associated with this state's client_token.
Raises:
RuntimeError: If redis is not used in this backend process.
+ StateMismatchError: If the state instance is not of the expected type.
"""
+ # Then get the target state and all its substates.
state_manager = get_state_manager()
if not isinstance(state_manager, StateManagerRedis):
raise RuntimeError(
- f"Cannot populate parent states of {target_state_cls.get_full_name()} without redis. "
+ f"Requested state {state_cls.get_full_name()} is not cached and cannot be accessed without redis. "
"(All states should already be available -- this is likely a bug).",
)
+ state_in_redis = await state_manager.get_state(
+ token=_substate_key(self.router.session.client_token, state_cls),
+ top_level=False,
+ for_state_instance=self,
+ )
- # Find the missing parent states up to the common ancestor.
- (
- common_ancestor_name,
- missing_parent_states,
- ) = self._determine_missing_parent_states(target_state_cls)
-
- # Fetch all missing parent states and link them up to the common ancestor.
- parent_states_tuple = self._get_parent_states()
- root_state = parent_states_tuple[-1][1]
- parent_states_by_name = dict(parent_states_tuple)
- parent_state = parent_states_by_name[common_ancestor_name]
- for parent_state_name in missing_parent_states:
- try:
- parent_state = root_state.get_substate(parent_state_name.split("."))
- # The requested state is already cached, do NOT fetch it again.
- continue
- except ValueError:
- # The requested state is missing, fetch from redis.
- pass
- parent_state = await state_manager.get_state(
- token=_substate_key(
- self.router.session.client_token, parent_state_name
- ),
- top_level=False,
- get_substates=False,
- parent_state=parent_state,
+ if not isinstance(state_in_redis, state_cls):
+ raise StateMismatchError(
+ f"Searched for state {state_cls.get_full_name()} but found {state_in_redis}."
)
- # Return the direct parent of target_state_cls for subsequent linking.
- return parent_state
+ return state_in_redis
def _get_state_from_cache(self, state_cls: Type[T_STATE]) -> T_STATE:
"""Get a state instance from the cache.
@@ -1562,44 +1532,6 @@ class BaseState(Base, ABC, extra=pydantic.Extra.allow):
)
return substate
- async def _get_state_from_redis(self, state_cls: Type[T_STATE]) -> T_STATE:
- """Get a state instance from redis.
-
- Args:
- state_cls: The class of the state.
-
- Returns:
- The instance of state_cls associated with this state's client_token.
-
- Raises:
- RuntimeError: If redis is not used in this backend process.
- StateMismatchError: If the state instance is not of the expected type.
- """
- # Fetch all missing parent states from redis.
- parent_state_of_state_cls = await self._populate_parent_states(state_cls)
-
- # Then get the target state and all its substates.
- state_manager = get_state_manager()
- if not isinstance(state_manager, StateManagerRedis):
- raise RuntimeError(
- f"Requested state {state_cls.get_full_name()} is not cached and cannot be accessed without redis. "
- "(All states should already be available -- this is likely a bug).",
- )
-
- state_in_redis = await state_manager.get_state(
- token=_substate_key(self.router.session.client_token, state_cls),
- top_level=False,
- get_substates=True,
- parent_state=parent_state_of_state_cls,
- )
-
- if not isinstance(state_in_redis, state_cls):
- raise StateMismatchError(
- f"Searched for state {state_cls.get_full_name()} but found {state_in_redis}."
- )
-
- return state_in_redis
-
async def get_state(self, state_cls: Type[T_STATE]) -> T_STATE:
"""Get an instance of the state associated with this token.
@@ -1738,7 +1670,7 @@ class BaseState(Base, ABC, extra=pydantic.Extra.allow):
f"Your handler {handler.fn.__qualname__} must only return/yield: None, Events or other EventHandlers referenced by their class (not using `self`)"
)
- def _as_state_update(
+ async def _as_state_update(
self,
handler: EventHandler,
events: EventSpec | list[EventSpec] | None,
@@ -1766,7 +1698,7 @@ class BaseState(Base, ABC, extra=pydantic.Extra.allow):
try:
# Get the delta after processing the event.
- delta = state.get_delta()
+ delta = await _resolve_delta(state.get_delta())
state._clean()
return StateUpdate(
@@ -1866,24 +1798,28 @@ class BaseState(Base, ABC, extra=pydantic.Extra.allow):
# Handle async generators.
if inspect.isasyncgen(events):
async for event in events:
- yield state._as_state_update(handler, event, final=False)
- yield state._as_state_update(handler, events=None, final=True)
+ yield await state._as_state_update(handler, event, final=False)
+ yield await state._as_state_update(handler, events=None, final=True)
# Handle regular generators.
elif inspect.isgenerator(events):
try:
while True:
- yield state._as_state_update(handler, next(events), final=False)
+ yield await state._as_state_update(
+ handler, next(events), final=False
+ )
except StopIteration as si:
# the "return" value of the generator is not available
# in the loop, we must catch StopIteration to access it
if si.value is not None:
- yield state._as_state_update(handler, si.value, final=False)
- yield state._as_state_update(handler, events=None, final=True)
+ yield await state._as_state_update(
+ handler, si.value, final=False
+ )
+ yield await state._as_state_update(handler, events=None, final=True)
# Handle regular event chains.
else:
- yield state._as_state_update(handler, events, final=True)
+ yield await state._as_state_update(handler, events, final=True)
# If an error occurs, throw a window alert.
except Exception as ex:
@@ -1893,7 +1829,7 @@ class BaseState(Base, ABC, extra=pydantic.Extra.allow):
prerequisites.get_and_validate_app().app.backend_exception_handler(ex)
)
- yield state._as_state_update(
+ yield await state._as_state_update(
handler,
event_specs,
final=True,
@@ -1901,15 +1837,28 @@ class BaseState(Base, ABC, extra=pydantic.Extra.allow):
def _mark_dirty_computed_vars(self) -> None:
"""Mark ComputedVars that need to be recalculated based on dirty_vars."""
+ # Append expired computed vars to dirty_vars to trigger recalculation
+ self.dirty_vars.update(self._expired_computed_vars())
+ # Append always dirty computed vars to dirty_vars to trigger recalculation
+ self.dirty_vars.update(self._always_dirty_computed_vars)
+
dirty_vars = self.dirty_vars
while dirty_vars:
calc_vars, dirty_vars = dirty_vars, set()
- for cvar in self._dirty_computed_vars(from_vars=calc_vars):
- self.dirty_vars.add(cvar)
+ for state_name, cvar in self._dirty_computed_vars(from_vars=calc_vars):
+ if state_name == self.get_full_name():
+ defining_state = self
+ else:
+ defining_state = self._get_root_state().get_substate(
+ tuple(state_name.split("."))
+ )
+ defining_state.dirty_vars.add(cvar)
dirty_vars.add(cvar)
- actual_var = self.computed_vars.get(cvar)
+ actual_var = defining_state.computed_vars.get(cvar)
if actual_var is not None:
- actual_var.mark_dirty(instance=self)
+ actual_var.mark_dirty(instance=defining_state)
+ if defining_state is not self:
+ defining_state._mark_dirty()
def _expired_computed_vars(self) -> set[str]:
"""Determine ComputedVars that need to be recalculated based on the expiration time.
@@ -1925,7 +1874,7 @@ class BaseState(Base, ABC, extra=pydantic.Extra.allow):
def _dirty_computed_vars(
self, from_vars: set[str] | None = None, include_backend: bool = True
- ) -> set[str]:
+ ) -> set[tuple[str, str]]:
"""Determine ComputedVars that need to be recalculated based on the given vars.
Args:
@@ -1936,33 +1885,12 @@ class BaseState(Base, ABC, extra=pydantic.Extra.allow):
Set of computed vars to include in the delta.
"""
return {
- cvar
+ (state_name, cvar)
for dirty_var in from_vars or self.dirty_vars
- for cvar in self._computed_var_dependencies[dirty_var]
+ for state_name, cvar in self._var_dependencies.get(dirty_var, set())
if include_backend or not self.computed_vars[cvar]._backend
}
- @classmethod
- def _potentially_dirty_substates(cls) -> set[Type[BaseState]]:
- """Determine substates which could be affected by dirty vars in this state.
-
- Returns:
- Set of State classes that may need to be fetched to recalc computed vars.
- """
- # _always_dirty_substates need to be fetched to recalc computed vars.
- fetch_substates = {
- cls.get_class_substate((cls.get_name(), *substate_name.split(".")))
- for substate_name in cls._always_dirty_substates
- }
- for dependent_substates in cls._substate_var_dependencies.values():
- fetch_substates.update(
- {
- cls.get_class_substate((cls.get_name(), *substate_name.split(".")))
- for substate_name in dependent_substates
- }
- )
- return fetch_substates
-
def get_delta(self) -> Delta:
"""Get the delta for the state.
@@ -1971,21 +1899,15 @@ class BaseState(Base, ABC, extra=pydantic.Extra.allow):
"""
delta = {}
- # Apply dirty variables down into substates
- self.dirty_vars.update(self._always_dirty_computed_vars)
- self._mark_dirty()
-
+ self._mark_dirty_computed_vars()
frontend_computed_vars: set[str] = {
name for name, cv in self.computed_vars.items() if not cv._backend
}
# Return the dirty vars for this instance, any cached/dependent computed vars,
# and always dirty computed vars (cache=False)
- delta_vars = (
- self.dirty_vars.intersection(self.base_vars)
- .union(self.dirty_vars.intersection(frontend_computed_vars))
- .union(self._dirty_computed_vars(include_backend=False))
- .union(self._always_dirty_computed_vars)
+ delta_vars = self.dirty_vars.intersection(self.base_vars).union(
+ self.dirty_vars.intersection(frontend_computed_vars)
)
subdelta: Dict[str, Any] = {
@@ -2015,23 +1937,9 @@ class BaseState(Base, ABC, extra=pydantic.Extra.allow):
self.parent_state.dirty_substates.add(self.get_name())
self.parent_state._mark_dirty()
- # Append expired computed vars to dirty_vars to trigger recalculation
- self.dirty_vars.update(self._expired_computed_vars())
-
# have to mark computed vars dirty to allow access to newly computed
# values within the same ComputedVar function
self._mark_dirty_computed_vars()
- self._mark_dirty_substates()
-
- def _mark_dirty_substates(self):
- """Propagate dirty var / computed var status into substates."""
- substates = self.substates
- for var in self.dirty_vars:
- for substate_name in self._substate_var_dependencies[var]:
- self.dirty_substates.add(substate_name)
- substate = substates[substate_name]
- substate.dirty_vars.add(var)
- substate._mark_dirty()
def _update_was_touched(self):
"""Update the _was_touched flag based on dirty_vars."""
@@ -2103,11 +2011,7 @@ class BaseState(Base, ABC, extra=pydantic.Extra.allow):
The object as a dictionary.
"""
if include_computed:
- # Apply dirty variables down into substates to allow never-cached ComputedVar to
- # trigger recalculation of dependent vars
- self.dirty_vars.update(self._always_dirty_computed_vars)
- self._mark_dirty()
-
+ self._mark_dirty_computed_vars()
base_vars = {
prop_name: self.get_value(prop_name) for prop_name in self.base_vars
}
@@ -2824,7 +2728,7 @@ class StateProxy(wrapt.ObjectProxy):
await self.__wrapped__.get_state(state_cls), parent_state_proxy=self
)
- def _as_state_update(self, *args, **kwargs) -> StateUpdate:
+ async def _as_state_update(self, *args, **kwargs) -> StateUpdate:
"""Temporarily allow mutability to access parent_state.
Args:
@@ -2837,7 +2741,7 @@ class StateProxy(wrapt.ObjectProxy):
original_mutable = self._self_mutable
self._self_mutable = True
try:
- return self.__wrapped__._as_state_update(*args, **kwargs)
+ return await self.__wrapped__._as_state_update(*args, **kwargs)
finally:
self._self_mutable = original_mutable
@@ -3313,103 +3217,106 @@ class StateManagerRedis(StateManager):
b"evicted",
}
- async def _get_parent_state(
- self, token: str, state: BaseState | None = None
- ) -> BaseState | None:
- """Get the parent state for the state requested in the token.
+ def _get_required_state_classes(
+ self,
+ target_state_cls: Type[BaseState],
+ subclasses: bool = False,
+ required_state_classes: set[Type[BaseState]] | None = None,
+ ) -> set[Type[BaseState]]:
+ """Recursively determine which states are required to fetch the target state.
+
+ This will always include potentially dirty substates that depend on vars
+ in the target_state_cls.
Args:
- token: The token to get the state for (_substate_key).
- state: The state instance to get parent state for.
+ target_state_cls: The target state class being fetched.
+ subclasses: Whether to include subclasses of the target state.
+ required_state_classes: Recursive argument tracking state classes that have already been seen.
Returns:
- The parent state for the state requested by the token or None if there is no such parent.
+ The set of state classes required to fetch the target state.
"""
- parent_state = None
- client_token, state_path = _split_substate_key(token)
- parent_state_name = state_path.rpartition(".")[0]
- if parent_state_name:
- cached_substates = None
- if state is not None:
- cached_substates = [state]
- # Retrieve the parent state to populate event handlers onto this substate.
- parent_state = await self.get_state(
- token=_substate_key(client_token, parent_state_name),
- top_level=False,
- get_substates=False,
- cached_substates=cached_substates,
+ if required_state_classes is None:
+ required_state_classes = set()
+ # Get the substates if requested.
+ if subclasses:
+ for substate in target_state_cls.get_substates():
+ self._get_required_state_classes(
+ substate,
+ subclasses=True,
+ required_state_classes=required_state_classes,
+ )
+ if target_state_cls in required_state_classes:
+ return required_state_classes
+ required_state_classes.add(target_state_cls)
+
+ # Get dependent substates.
+ for pd_substates in target_state_cls._get_potentially_dirty_states():
+ self._get_required_state_classes(
+ pd_substates,
+ subclasses=False,
+ required_state_classes=required_state_classes,
)
- return parent_state
- async def _populate_substates(
+ # Get the parent state if it exists.
+ if parent_state := target_state_cls.get_parent_state():
+ self._get_required_state_classes(
+ parent_state,
+ subclasses=False,
+ required_state_classes=required_state_classes,
+ )
+ return required_state_classes
+
+ def _get_populated_states(
self,
- token: str,
- state: BaseState,
- all_substates: bool = False,
- ):
- """Fetch and link substates for the given state instance.
-
- There is no return value; the side-effect is that `state` will have `substates` populated,
- and each substate will have its `parent_state` set to `state`.
+ target_state: BaseState,
+ populated_states: dict[str, BaseState] | None = None,
+ ) -> dict[str, BaseState]:
+ """Recursively determine which states from target_state are already fetched.
Args:
- token: The token to get the state for.
- state: The state instance to populate substates for.
- all_substates: Whether to fetch all substates or just required substates.
+ target_state: The state to check for populated states.
+ populated_states: Recursive argument tracking states seen in previous calls.
+
+ Returns:
+ A dictionary of state full name to state instance.
"""
- client_token, _ = _split_substate_key(token)
-
- if all_substates:
- # All substates are requested.
- fetch_substates = state.get_substates()
- else:
- # Only _potentially_dirty_substates need to be fetched to recalc computed vars.
- fetch_substates = state._potentially_dirty_substates()
-
- tasks = {}
- # Retrieve the necessary substates from redis.
- for substate_cls in fetch_substates:
- if substate_cls.get_name() in state.substates:
- continue
- substate_name = substate_cls.get_name()
- tasks[substate_name] = asyncio.create_task(
- self.get_state(
- token=_substate_key(client_token, substate_cls),
- top_level=False,
- get_substates=all_substates,
- parent_state=state,
- )
+ if populated_states is None:
+ populated_states = {}
+ if target_state.get_full_name() in populated_states:
+ return populated_states
+ populated_states[target_state.get_full_name()] = target_state
+ for substate in target_state.substates.values():
+ self._get_populated_states(substate, populated_states=populated_states)
+ if target_state.parent_state is not None:
+ self._get_populated_states(
+ target_state.parent_state, populated_states=populated_states
)
-
- for substate_name, substate_task in tasks.items():
- state.substates[substate_name] = await substate_task
+ return populated_states
@override
async def get_state(
self,
token: str,
top_level: bool = True,
- get_substates: bool = True,
- parent_state: BaseState | None = None,
- cached_substates: list[BaseState] | None = None,
+ for_state_instance: BaseState | None = None,
) -> BaseState:
"""Get the state for a token.
Args:
token: The token to get the state for.
top_level: If true, return an instance of the top-level state (self.state).
- get_substates: If true, also retrieve substates.
- parent_state: If provided, use this parent_state instead of getting it from redis.
- cached_substates: If provided, attach these substates to the state.
+ for_state_instance: If provided, attach the requested states to this existing state tree.
Returns:
The state for the token.
Raises:
- RuntimeError: when the state_cls is not specified in the token
+ RuntimeError: when the state_cls is not specified in the token, or when the parent state for a
+ requested state was not fetched.
"""
# Split the actual token from the fully qualified substate name.
- _, state_path = _split_substate_key(token)
+ token, state_path = _split_substate_key(token)
if state_path:
# Get the State class associated with the given path.
state_cls = self.state.get_class_substate(state_path)
@@ -3418,43 +3325,59 @@ class StateManagerRedis(StateManager):
f"StateManagerRedis requires token to be specified in the form of {{token}}_{{state_full_name}}, but got {token}"
)
- # The deserialized or newly created (sub)state instance.
- state = None
+ # Determine which states we already have.
+ flat_state_tree: dict[str, BaseState] = (
+ self._get_populated_states(for_state_instance) if for_state_instance else {}
+ )
- # Fetch the serialized substate from redis.
- redis_state = await self.redis.get(token)
+ # Determine which states from the tree need to be fetched.
+ required_state_classes = sorted(
+ self._get_required_state_classes(state_cls, subclasses=True)
+ - {type(s) for s in flat_state_tree.values()},
+ key=lambda x: x.get_full_name(),
+ )
- if redis_state is not None:
- # Deserialize the substate.
- with contextlib.suppress(StateSchemaMismatchError):
- state = BaseState._deserialize(data=redis_state)
- if state is None:
- # Key didn't exist or schema mismatch so create a new instance for this token.
- state = state_cls(
- init_substates=False,
- _reflex_internal_init=True,
- )
- # Populate parent state if missing and requested.
- if parent_state is None:
- parent_state = await self._get_parent_state(token, state)
- # Set up Bidirectional linkage between this state and its parent.
- if parent_state is not None:
- parent_state.substates[state.get_name()] = state
- state.parent_state = parent_state
- # Avoid fetching substates multiple times.
- if cached_substates:
- for substate in cached_substates:
- state.substates[substate.get_name()] = substate
- if substate.parent_state is None:
- substate.parent_state = state
- # Populate substates if requested.
- await self._populate_substates(token, state, all_substates=get_substates)
+ redis_pipeline = self.redis.pipeline()
+ for state_cls in required_state_classes:
+ redis_pipeline.get(_substate_key(token, state_cls))
+
+ for state_cls, redis_state in zip(
+ required_state_classes,
+ await redis_pipeline.execute(),
+ strict=False,
+ ):
+ state = None
+
+ if redis_state is not None:
+ # Deserialize the substate.
+ with contextlib.suppress(StateSchemaMismatchError):
+ state = BaseState._deserialize(data=redis_state)
+ if state is None:
+ # Key didn't exist or schema mismatch so create a new instance for this token.
+ state = state_cls(
+ init_substates=False,
+ _reflex_internal_init=True,
+ )
+ flat_state_tree[state.get_full_name()] = state
+ if state.get_parent_state() is not None:
+ parent_state_name, _dot, state_name = state.get_full_name().rpartition(
+ "."
+ )
+ parent_state = flat_state_tree.get(parent_state_name)
+ if parent_state is None:
+ raise RuntimeError(
+ f"Parent state for {state.get_full_name()} was not found "
+ "in the state tree, but should have already been fetched. "
+ "This is a bug",
+ )
+ parent_state.substates[state_name] = state
+ state.parent_state = parent_state
# To retain compatibility with previous implementation, by default, we return
- # the top-level state by chasing `parent_state` pointers up the tree.
+ # the top-level state which should always be fetched or already cached.
if top_level:
- return state._get_root_state()
- return state
+ return flat_state_tree[self.state.get_full_name()]
+ return flat_state_tree[state_cls.get_full_name()]
@override
async def set_state(
@@ -4154,12 +4077,19 @@ def reload_state_module(
state: Recursive argument for the state class to reload.
"""
+ # Clean out all potentially dirty states of reloaded modules.
+ for pd_state in tuple(state._potentially_dirty_states):
+ with contextlib.suppress(ValueError):
+ if (
+ state.get_root_state().get_class_substate(pd_state).__module__ == module
+ and module is not None
+ ):
+ state._potentially_dirty_states.remove(pd_state)
for subclass in tuple(state.class_subclasses):
reload_state_module(module=module, state=subclass)
if subclass.__module__ == module and module is not None:
state.class_subclasses.remove(subclass)
state._always_dirty_substates.discard(subclass.get_name())
- state._computed_var_dependencies = defaultdict(set)
- state._substate_var_dependencies = defaultdict(set)
+ state._var_dependencies = {}
state._init_var_dependency_dicts()
state.get_class_substate.cache_clear()
diff --git a/reflex/utils/exec.py b/reflex/utils/exec.py
index 479ff816a..67df7ea91 100644
--- a/reflex/utils/exec.py
+++ b/reflex/utils/exec.py
@@ -488,7 +488,7 @@ def output_system_info():
dependencies.append(fnm_info)
if system == "Linux":
- import distro
+ import distro # pyright: ignore[reportMissingImports]
os_version = distro.name(pretty=True)
else:
diff --git a/reflex/utils/prerequisites.py b/reflex/utils/prerequisites.py
index 2978cc792..629198185 100644
--- a/reflex/utils/prerequisites.py
+++ b/reflex/utils/prerequisites.py
@@ -912,7 +912,6 @@ def _update_next_config(
next_config = {
"basePath": config.frontend_path or "",
"compress": config.next_compression,
- "reactStrictMode": config.react_strict_mode,
"trailingSlash": True,
"staticPageGenerationTimeout": config.static_page_generation_timeout,
}
@@ -1855,7 +1854,7 @@ def initialize_main_module_index_from_generation(app_name: str, generation_hash:
[
resp.text,
"",
- "" "def index() -> rx.Component:",
+ "def index() -> rx.Component:",
f" return {render_func_name}()",
"",
"",
diff --git a/reflex/utils/pyi_generator.py b/reflex/utils/pyi_generator.py
index bd10e1c1c..beb355d31 100644
--- a/reflex/utils/pyi_generator.py
+++ b/reflex/utils/pyi_generator.py
@@ -622,7 +622,7 @@ def _generate_component_create_functiondef(
defaults=[],
)
- definition = ast.FunctionDef(
+ definition = ast.FunctionDef( # pyright: ignore [reportCallIssue]
name="create",
args=create_args,
body=[ # pyright: ignore [reportArgumentType]
@@ -684,7 +684,7 @@ def _generate_staticmethod_call_functiondef(
else []
),
)
- definition = ast.FunctionDef(
+ definition = ast.FunctionDef( # pyright: ignore [reportCallIssue]
name="__call__",
args=call_args,
body=[
@@ -699,6 +699,7 @@ def _generate_staticmethod_call_functiondef(
value=_get_type_hint(
typing.get_type_hints(clz.__call__).get("return", None),
type_hint_globals,
+ is_optional=False,
)
),
)
diff --git a/reflex/vars/base.py b/reflex/vars/base.py
index 8a76f250d..9f6652122 100644
--- a/reflex/vars/base.py
+++ b/reflex/vars/base.py
@@ -5,7 +5,6 @@ from __future__ import annotations
import contextlib
import dataclasses
import datetime
-import dis
import functools
import inspect
import json
@@ -20,6 +19,7 @@ from typing import (
Any,
Callable,
ClassVar,
+ Coroutine,
Dict,
FrozenSet,
Generic,
@@ -40,6 +40,7 @@ from typing import (
overload,
)
+from sqlalchemy.orm import DeclarativeBase
from typing_extensions import ParamSpec, TypeGuard, deprecated, get_type_hints, override
from reflex import constants
@@ -51,7 +52,6 @@ from reflex.utils.exceptions import (
VarAttributeError,
VarDependencyError,
VarTypeError,
- VarValueError,
)
from reflex.utils.format import format_state_name
from reflex.utils.imports import (
@@ -574,7 +574,7 @@ class Var(Generic[VAR_TYPE]):
@overload
@classmethod
- def create( # type: ignore[override]
+ def create( # pyright: ignore[reportOverlappingOverload]
cls,
value: bool,
_var_data: VarData | None = None,
@@ -582,7 +582,7 @@ class Var(Generic[VAR_TYPE]):
@overload
@classmethod
- def create( # type: ignore[override]
+ def create(
cls,
value: int,
_var_data: VarData | None = None,
@@ -606,7 +606,7 @@ class Var(Generic[VAR_TYPE]):
@overload
@classmethod
- def create(
+ def create( # pyright: ignore[reportOverlappingOverload]
cls,
value: None,
_var_data: VarData | None = None,
@@ -1983,7 +1983,7 @@ class ComputedVar(Var[RETURN_TYPE]):
_initial_value: RETURN_TYPE | types.Unset = dataclasses.field(default=types.Unset())
# Explicit var dependencies to track
- _static_deps: set[str] = dataclasses.field(default_factory=set)
+ _static_deps: dict[str | None, set[str]] = dataclasses.field(default_factory=dict)
# Whether var dependencies should be auto-determined
_auto_deps: bool = dataclasses.field(default=True)
@@ -2053,26 +2053,72 @@ class ComputedVar(Var[RETURN_TYPE]):
object.__setattr__(self, "_update_interval", interval)
- if deps is None:
- deps = []
- else:
- for dep in deps:
- if isinstance(dep, Var):
- continue
- if isinstance(dep, str) and dep != "":
- continue
- raise TypeError(
- "ComputedVar dependencies must be Var instances or var names (non-empty strings)."
- )
object.__setattr__(
self,
"_static_deps",
- {dep._js_expr if isinstance(dep, Var) else dep for dep in deps},
+ self._calculate_static_deps(deps),
)
object.__setattr__(self, "_auto_deps", auto_deps)
object.__setattr__(self, "_fget", fget)
+ def _calculate_static_deps(
+ self,
+ deps: Union[List[Union[str, Var]], dict[str | None, set[str]]] | None = None,
+ ) -> dict[str | None, set[str]]:
+ """Calculate the static dependencies of the computed var from user input or existing dependencies.
+
+ Args:
+ deps: The user input dependencies or existing dependencies.
+
+ Returns:
+ The static dependencies.
+ """
+ if isinstance(deps, dict):
+ # Assume a dict is coming from _replace, so no special processing.
+ return deps
+ _static_deps = {}
+ if deps is not None:
+ for dep in deps:
+ _static_deps = self._add_static_dep(dep, _static_deps)
+ return _static_deps
+
+ def _add_static_dep(
+ self, dep: Union[str, Var], deps: dict[str | None, set[str]] | None = None
+ ) -> dict[str | None, set[str]]:
+ """Add a static dependency to the computed var or existing dependency set.
+
+ Args:
+ dep: The dependency to add.
+ deps: The existing dependency set.
+
+ Returns:
+ The updated dependency set.
+
+ Raises:
+ TypeError: If the computed var dependencies are not Var instances or var names.
+ """
+ if deps is None:
+ deps = self._static_deps
+ if isinstance(dep, Var):
+ state_name = (
+ all_var_data.state
+ if (all_var_data := dep._get_all_var_data()) and all_var_data.state
+ else None
+ )
+ if all_var_data is not None:
+ var_name = all_var_data.field_name
+ else:
+ var_name = dep._js_expr
+ deps.setdefault(state_name, set()).add(var_name)
+ elif isinstance(dep, str) and dep != "":
+ deps.setdefault(None, set()).add(dep)
+ else:
+ raise TypeError(
+ "ComputedVar dependencies must be Var instances or var names (non-empty strings)."
+ )
+ return deps
+
@override
def _replace(
self,
@@ -2093,6 +2139,8 @@ class ComputedVar(Var[RETURN_TYPE]):
Raises:
TypeError: If kwargs contains keys that are not allowed.
"""
+ if "deps" in kwargs:
+ kwargs["deps"] = self._calculate_static_deps(kwargs["deps"])
field_values = {
"fget": kwargs.pop("fget", self._fget),
"initial_value": kwargs.pop("initial_value", self._initial_value),
@@ -2149,6 +2197,13 @@ class ComputedVar(Var[RETURN_TYPE]):
return True
return datetime.datetime.now() - last_updated > self._update_interval
+ @overload
+ def __get__(
+ self: ComputedVar[bool],
+ instance: None,
+ owner: Type,
+ ) -> BooleanVar: ...
+
@overload
def __get__(
self: ComputedVar[int] | ComputedVar[float],
@@ -2233,125 +2288,67 @@ class ComputedVar(Var[RETURN_TYPE]):
setattr(instance, self._last_updated_attr, datetime.datetime.now())
value = getattr(instance, self._cache_attr)
+ self._check_deprecated_return_type(instance, value)
+
+ return value
+
+ def _check_deprecated_return_type(self, instance: BaseState, value: Any) -> None:
if not _isinstance(value, self._var_type):
console.error(
f"Computed var '{type(instance).__name__}.{self._js_expr}' must return"
f" type '{self._var_type}', got '{type(value)}'."
)
- return value
-
def _deps(
self,
- objclass: Type,
+ objclass: Type[BaseState],
obj: FunctionType | CodeType | None = None,
- self_name: Optional[str] = None,
- ) -> set[str]:
+ ) -> dict[str, set[str]]:
"""Determine var dependencies of this ComputedVar.
- Save references to attributes accessed on "self". Recursively called
- when the function makes a method call on "self" or define comprehensions
- or nested functions that may reference "self".
+ Save references to attributes accessed on "self" or other fetched states.
+
+ Recursively called when the function makes a method call on "self" or
+ defines comprehensions or nested functions that may reference "self".
Args:
objclass: the class obj this ComputedVar is attached to.
obj: the object to disassemble (defaults to the fget function).
- self_name: if specified, look for this name in LOAD_FAST and LOAD_DEREF instructions.
Returns:
- A set of variable names accessed by the given obj.
-
- Raises:
- VarValueError: if the function references the get_state, parent_state, or substates attributes
- (cannot track deps in a related state, only implicitly via parent state).
+ A dictionary mapping state names to the set of variable names
+ accessed by the given obj.
"""
+ from .dep_tracking import DependencyTracker
+
+ d = {}
+ if self._static_deps:
+ d.update(self._static_deps)
+ # None is a placeholder for the current state class.
+ if None in d:
+ d[objclass.get_full_name()] = d.pop(None)
+
if not self._auto_deps:
- return self._static_deps
- d = self._static_deps.copy()
+ return d
+
if obj is None:
fget = self._fget
if fget is not None:
obj = cast(FunctionType, fget)
else:
- return set()
- with contextlib.suppress(AttributeError):
- # unbox functools.partial
- obj = cast(FunctionType, obj.func) # pyright: ignore [reportAttributeAccessIssue]
- with contextlib.suppress(AttributeError):
- # unbox EventHandler
- obj = cast(FunctionType, obj.fn) # pyright: ignore [reportAttributeAccessIssue]
+ return d
- if self_name is None and isinstance(obj, FunctionType):
- try:
- # the first argument to the function is the name of "self" arg
- self_name = obj.__code__.co_varnames[0]
- except (AttributeError, IndexError):
- self_name = None
- if self_name is None:
- # cannot reference attributes on self if method takes no args
- return set()
-
- invalid_names = ["get_state", "parent_state", "substates", "get_substate"]
- self_is_top_of_stack = False
- for instruction in dis.get_instructions(obj):
- if (
- instruction.opname in ("LOAD_FAST", "LOAD_DEREF")
- and instruction.argval == self_name
- ):
- # bytecode loaded the class instance to the top of stack, next load instruction
- # is referencing an attribute on self
- self_is_top_of_stack = True
- continue
- if self_is_top_of_stack and instruction.opname in (
- "LOAD_ATTR",
- "LOAD_METHOD",
- ):
- try:
- ref_obj = getattr(objclass, instruction.argval)
- except Exception:
- ref_obj = None
- if instruction.argval in invalid_names:
- raise VarValueError(
- f"Cached var {self!s} cannot access arbitrary state via `{instruction.argval}`."
- )
- if callable(ref_obj):
- # recurse into callable attributes
- d.update(
- self._deps(
- objclass=objclass,
- obj=ref_obj, # pyright: ignore [reportArgumentType]
- )
- )
- # recurse into property fget functions
- elif isinstance(ref_obj, property) and not isinstance(
- ref_obj, ComputedVar
- ):
- d.update(
- self._deps(
- objclass=objclass,
- obj=ref_obj.fget, # pyright: ignore [reportArgumentType]
- )
- )
- elif (
- instruction.argval in objclass.backend_vars
- or instruction.argval in objclass.vars
- ):
- # var access
- d.add(instruction.argval)
- elif instruction.opname == "LOAD_CONST" and isinstance(
- instruction.argval, CodeType
- ):
- # recurse into nested functions / comprehensions, which can reference
- # instance attributes from the outer scope
- d.update(
- self._deps(
- objclass=objclass,
- obj=instruction.argval,
- self_name=self_name,
- )
- )
- self_is_top_of_stack = False
- return d
+ try:
+ return DependencyTracker(
+ func=obj, state_cls=objclass, dependencies=d
+ ).dependencies
+ except Exception as e:
+ console.warn(
+ "Failed to automatically determine dependencies for computed var "
+ f"{objclass.__name__}.{self._js_expr}: {e}. "
+ "Provide static_deps and set auto_deps=False to suppress this warning."
+ )
+ return d
def mark_dirty(self, instance: BaseState) -> None:
"""Mark this ComputedVar as dirty.
@@ -2362,6 +2359,37 @@ class ComputedVar(Var[RETURN_TYPE]):
with contextlib.suppress(AttributeError):
delattr(instance, self._cache_attr)
+ def add_dependency(self, objclass: Type[BaseState], dep: Var):
+ """Explicitly add a dependency to the ComputedVar.
+
+ After adding the dependency, when the `dep` changes, this computed var
+ will be marked dirty.
+
+ Args:
+ objclass: The class obj this ComputedVar is attached to.
+ dep: The dependency to add.
+
+ Raises:
+ VarDependencyError: If the dependency is not a Var instance with a
+ state and field name
+ """
+ if all_var_data := dep._get_all_var_data():
+ state_name = all_var_data.state
+ if state_name:
+ var_name = all_var_data.field_name
+ if var_name:
+ self._static_deps.setdefault(state_name, set()).add(var_name)
+ objclass.get_root_state().get_class_substate(
+ state_name
+ )._var_dependencies.setdefault(var_name, set()).add(
+ (objclass.get_full_name(), self._js_expr)
+ )
+ return
+ raise VarDependencyError(
+ "ComputedVar dependencies must be Var instances with a state and "
+ f"field name, got {dep!r}."
+ )
+
def _determine_var_type(self) -> Type:
"""Get the type of the var.
@@ -2398,6 +2426,126 @@ class DynamicRouteVar(ComputedVar[Union[str, List[str]]]):
pass
+async def _default_async_computed_var(_self: BaseState) -> Any:
+ return None
+
+
+@dataclasses.dataclass(
+ eq=False,
+ frozen=True,
+ init=False,
+ slots=True,
+)
+class AsyncComputedVar(ComputedVar[RETURN_TYPE]):
+ """A computed var that wraps a coroutinefunction."""
+
+ _fget: Callable[[BaseState], Coroutine[None, None, RETURN_TYPE]] = (
+ dataclasses.field(default=_default_async_computed_var)
+ )
+
+ @overload
+ def __get__(
+ self: AsyncComputedVar[bool],
+ instance: None,
+ owner: Type,
+ ) -> BooleanVar: ...
+
+ @overload
+ def __get__(
+ self: AsyncComputedVar[int] | ComputedVar[float],
+ instance: None,
+ owner: Type,
+ ) -> NumberVar: ...
+
+ @overload
+ def __get__(
+ self: AsyncComputedVar[str],
+ instance: None,
+ owner: Type,
+ ) -> StringVar: ...
+
+ @overload
+ def __get__(
+ self: AsyncComputedVar[Mapping[DICT_KEY, DICT_VAL]],
+ instance: None,
+ owner: Type,
+ ) -> ObjectVar[Mapping[DICT_KEY, DICT_VAL]]: ...
+
+ @overload
+ def __get__(
+ self: AsyncComputedVar[list[LIST_INSIDE]],
+ instance: None,
+ owner: Type,
+ ) -> ArrayVar[list[LIST_INSIDE]]: ...
+
+ @overload
+ def __get__(
+ self: AsyncComputedVar[tuple[LIST_INSIDE, ...]],
+ instance: None,
+ owner: Type,
+ ) -> ArrayVar[tuple[LIST_INSIDE, ...]]: ...
+
+ @overload
+ def __get__(self, instance: None, owner: Type) -> AsyncComputedVar[RETURN_TYPE]: ...
+
+ @overload
+ def __get__(
+ self, instance: BaseState, owner: Type
+ ) -> Coroutine[None, None, RETURN_TYPE]: ...
+
+ def __get__(
+ self, instance: BaseState | None, owner
+ ) -> Var | Coroutine[None, None, RETURN_TYPE]:
+ """Get the ComputedVar value.
+
+ If the value is already cached on the instance, return the cached value.
+
+ Args:
+ instance: the instance of the class accessing this computed var.
+ owner: the class that this descriptor is attached to.
+
+ Returns:
+ The value of the var for the given instance.
+ """
+ if instance is None:
+ return super(AsyncComputedVar, self).__get__(instance, owner)
+
+ if not self._cache:
+
+ async def _awaitable_result(instance: BaseState = instance) -> RETURN_TYPE:
+ value = await self.fget(instance)
+ self._check_deprecated_return_type(instance, value)
+ return value
+
+ return _awaitable_result()
+ else:
+ # handle caching
+ async def _awaitable_result(instance: BaseState = instance) -> RETURN_TYPE:
+ if not hasattr(instance, self._cache_attr) or self.needs_update(
+ instance
+ ):
+ # Set cache attr on state instance.
+ setattr(instance, self._cache_attr, await self.fget(instance))
+ # Ensure the computed var gets serialized to redis.
+ instance._was_touched = True
+ # Set the last updated timestamp on the state instance.
+ setattr(instance, self._last_updated_attr, datetime.datetime.now())
+ value = getattr(instance, self._cache_attr)
+ self._check_deprecated_return_type(instance, value)
+ return value
+
+ return _awaitable_result()
+
+ @property
+ def fget(self) -> Callable[[BaseState], Coroutine[None, None, RETURN_TYPE]]:
+ """Get the getter function.
+
+ Returns:
+ The getter function.
+ """
+ return self._fget
+
+
if TYPE_CHECKING:
BASE_STATE = TypeVar("BASE_STATE", bound=BaseState)
@@ -2464,10 +2612,27 @@ def computed_var(
raise VarDependencyError("Cannot track dependencies without caching.")
if fget is not None:
- return ComputedVar(fget, cache=cache)
+ if inspect.iscoroutinefunction(fget):
+ computed_var_cls = AsyncComputedVar
+ else:
+ computed_var_cls = ComputedVar
+ return computed_var_cls(
+ fget,
+ initial_value=initial_value,
+ cache=cache,
+ deps=deps,
+ auto_deps=auto_deps,
+ interval=interval,
+ backend=backend,
+ **kwargs,
+ )
def wrapper(fget: Callable[[BASE_STATE], Any]) -> ComputedVar:
- return ComputedVar(
+ if inspect.iscoroutinefunction(fget):
+ computed_var_cls = AsyncComputedVar
+ else:
+ computed_var_cls = ComputedVar
+ return computed_var_cls(
fget,
initial_value=initial_value,
cache=cache,
@@ -3053,10 +3218,16 @@ def dispatch(
V = TypeVar("V")
-BASE_TYPE = TypeVar("BASE_TYPE", bound=Base)
+BASE_TYPE = TypeVar("BASE_TYPE", bound=Base | None)
+SQLA_TYPE = TypeVar("SQLA_TYPE", bound=DeclarativeBase | None)
+
+if TYPE_CHECKING:
+ from _typeshed import DataclassInstance
+
+ DATACLASS_TYPE = TypeVar("DATACLASS_TYPE", bound=DataclassInstance | None)
FIELD_TYPE = TypeVar("FIELD_TYPE")
-MAPPING_TYPE = TypeVar("MAPPING_TYPE", bound=Mapping)
+MAPPING_TYPE = TypeVar("MAPPING_TYPE", bound=Mapping | None)
class Field(Generic[FIELD_TYPE]):
@@ -3101,6 +3272,18 @@ class Field(Generic[FIELD_TYPE]):
self: Field[BASE_TYPE], instance: None, owner: Any
) -> ObjectVar[BASE_TYPE]: ...
+ @overload
+ def __get__(
+ self: Field[SQLA_TYPE], instance: None, owner: Any
+ ) -> ObjectVar[SQLA_TYPE]: ...
+
+ if TYPE_CHECKING:
+
+ @overload
+ def __get__(
+ self: Field[DATACLASS_TYPE], instance: None, owner: Any
+ ) -> ObjectVar[DATACLASS_TYPE]: ...
+
@overload
def __get__(self, instance: None, owner: Any) -> Var[FIELD_TYPE]: ...
diff --git a/reflex/vars/datetime.py b/reflex/vars/datetime.py
index c43c24165..a18df78d0 100644
--- a/reflex/vars/datetime.py
+++ b/reflex/vars/datetime.py
@@ -184,7 +184,7 @@ def date_compare_operation(
The result of the operation.
"""
return var_operation_return(
- f"({lhs} { '<' if strict else '<='} {rhs})",
+ f"({lhs} {'<' if strict else '<='} {rhs})",
bool,
)
diff --git a/reflex/vars/dep_tracking.py b/reflex/vars/dep_tracking.py
new file mode 100644
index 000000000..0b2367799
--- /dev/null
+++ b/reflex/vars/dep_tracking.py
@@ -0,0 +1,344 @@
+"""Collection of base classes."""
+
+from __future__ import annotations
+
+import contextlib
+import dataclasses
+import dis
+import enum
+import inspect
+from types import CellType, CodeType, FunctionType
+from typing import TYPE_CHECKING, Any, ClassVar, Type, cast
+
+from reflex.utils.exceptions import VarValueError
+
+if TYPE_CHECKING:
+ from reflex.state import BaseState
+
+ from .base import Var
+
+
+CellEmpty = object()
+
+
+def get_cell_value(cell: CellType) -> Any:
+ """Get the value of a cell object.
+
+ Args:
+ cell: The cell object to get the value from. (func.__closure__ objects)
+
+ Returns:
+ The value from the cell or CellEmpty if a ValueError is raised.
+ """
+ try:
+ return cell.cell_contents
+ except ValueError:
+ return CellEmpty
+
+
+class ScanStatus(enum.Enum):
+ """State of the dis instruction scanning loop."""
+
+ SCANNING = enum.auto()
+ GETTING_ATTR = enum.auto()
+ GETTING_STATE = enum.auto()
+ GETTING_VAR = enum.auto()
+
+
+@dataclasses.dataclass
+class DependencyTracker:
+ """State machine for identifying state attributes that are accessed by a function."""
+
+ func: FunctionType | CodeType = dataclasses.field()
+ state_cls: Type[BaseState] = dataclasses.field()
+
+ dependencies: dict[str, set[str]] = dataclasses.field(default_factory=dict)
+
+ scan_status: ScanStatus = dataclasses.field(default=ScanStatus.SCANNING)
+ top_of_stack: str | None = dataclasses.field(default=None)
+
+ tracked_locals: dict[str, Type[BaseState]] = dataclasses.field(default_factory=dict)
+
+ _getting_state_class: Type[BaseState] | None = dataclasses.field(default=None)
+ _getting_var_instructions: list[dis.Instruction] = dataclasses.field(
+ default_factory=list
+ )
+
+ INVALID_NAMES: ClassVar[list[str]] = ["parent_state", "substates", "get_substate"]
+
+ def __post_init__(self):
+ """After initializing, populate the dependencies dict."""
+ with contextlib.suppress(AttributeError):
+ # unbox functools.partial
+ self.func = cast(FunctionType, self.func.func) # pyright: ignore[reportAttributeAccessIssue]
+ with contextlib.suppress(AttributeError):
+ # unbox EventHandler
+ self.func = cast(FunctionType, self.func.fn) # pyright: ignore[reportAttributeAccessIssue]
+
+ if isinstance(self.func, FunctionType):
+ with contextlib.suppress(AttributeError, IndexError):
+ # the first argument to the function is the name of "self" arg
+ self.tracked_locals[self.func.__code__.co_varnames[0]] = self.state_cls
+
+ self._populate_dependencies()
+
+ def _merge_deps(self, tracker: DependencyTracker) -> None:
+ """Merge dependencies from another DependencyTracker.
+
+ Args:
+ tracker: The DependencyTracker to merge dependencies from.
+ """
+ for state_name, dep_name in tracker.dependencies.items():
+ self.dependencies.setdefault(state_name, set()).update(dep_name)
+
+ def load_attr_or_method(self, instruction: dis.Instruction) -> None:
+ """Handle loading an attribute or method from the object on top of the stack.
+
+ This method directly tracks attributes and recursively merges
+ dependencies from analyzing the dependencies of any methods called.
+
+ Args:
+ instruction: The dis instruction to process.
+
+ Raises:
+ VarValueError: if the attribute is a disallowed name.
+ """
+ from .base import ComputedVar
+
+ if instruction.argval in self.INVALID_NAMES:
+ raise VarValueError(
+ f"Cached var {self!s} cannot access arbitrary state via `{instruction.argval}`."
+ )
+ if instruction.argval == "get_state":
+ # Special case: arbitrary state access requested.
+ self.scan_status = ScanStatus.GETTING_STATE
+ return
+ if instruction.argval == "get_var_value":
+ # Special case: arbitrary var access requested.
+ self.scan_status = ScanStatus.GETTING_VAR
+ return
+
+ # Reset status back to SCANNING after attribute is accessed.
+ self.scan_status = ScanStatus.SCANNING
+ if not self.top_of_stack:
+ return
+ target_state = self.tracked_locals[self.top_of_stack]
+ try:
+ ref_obj = getattr(target_state, instruction.argval)
+ except AttributeError:
+ # Not found on this state class, maybe it is a dynamic attribute that will be picked up later.
+ ref_obj = None
+
+ if isinstance(ref_obj, property) and not isinstance(ref_obj, ComputedVar):
+ # recurse into property fget functions
+ ref_obj = ref_obj.fget
+ if callable(ref_obj):
+ # recurse into callable attributes
+ self._merge_deps(
+ type(self)(func=cast(FunctionType, ref_obj), state_cls=target_state)
+ )
+ elif (
+ instruction.argval in target_state.backend_vars
+ or instruction.argval in target_state.vars
+ ):
+ # var access
+ self.dependencies.setdefault(target_state.get_full_name(), set()).add(
+ instruction.argval
+ )
+
+ def _get_globals(self) -> dict[str, Any]:
+ """Get the globals of the function.
+
+ Returns:
+ The var names and values in the globals of the function.
+ """
+ if isinstance(self.func, CodeType):
+ return {}
+ return self.func.__globals__ # pyright: ignore[reportAttributeAccessIssue]
+
+ def _get_closure(self) -> dict[str, Any]:
+ """Get the closure of the function, with unbound values omitted.
+
+ Returns:
+ The var names and values in the closure of the function.
+ """
+ if isinstance(self.func, CodeType):
+ return {}
+ return {
+ var_name: get_cell_value(cell)
+ for var_name, cell in zip(
+ self.func.__code__.co_freevars, # pyright: ignore[reportAttributeAccessIssue]
+ self.func.__closure__ or (),
+ strict=False,
+ )
+ if get_cell_value(cell) is not CellEmpty
+ }
+
+ def handle_getting_state(self, instruction: dis.Instruction) -> None:
+ """Handle bytecode analysis when `get_state` was called in the function.
+
+ If the wrapped function is getting an arbitrary state and saving it to a
+ local variable, this method associates the local variable name with the
+ state class in self.tracked_locals.
+
+ When an attribute/method is accessed on a tracked local, it will be
+ associated with this state.
+
+ Args:
+ instruction: The dis instruction to process.
+
+ Raises:
+ VarValueError: if the state class cannot be determined from the instruction.
+ """
+ from reflex.state import BaseState
+
+ if instruction.opname == "LOAD_FAST":
+ raise VarValueError(
+ f"Dependency detection cannot identify get_state class from local var {instruction.argval}."
+ )
+ if isinstance(self.func, CodeType):
+ raise VarValueError(
+ "Dependency detection cannot identify get_state class from a code object."
+ )
+ if instruction.opname == "LOAD_GLOBAL":
+ # Special case: referencing state class from global scope.
+ try:
+ self._getting_state_class = self._get_globals()[instruction.argval]
+ except (ValueError, KeyError) as ve:
+ raise VarValueError(
+ f"Cached var {self!s} cannot access arbitrary state `{instruction.argval}`, not found in globals."
+ ) from ve
+ elif instruction.opname == "LOAD_DEREF":
+ # Special case: referencing state class from closure.
+ try:
+ self._getting_state_class = self._get_closure()[instruction.argval]
+ except (ValueError, KeyError) as ve:
+ raise VarValueError(
+ f"Cached var {self!s} cannot access arbitrary state `{instruction.argval}`, is it defined yet?"
+ ) from ve
+ elif instruction.opname == "STORE_FAST":
+ # Storing the result of get_state in a local variable.
+ if not isinstance(self._getting_state_class, type) or not issubclass(
+ self._getting_state_class, BaseState
+ ):
+ raise VarValueError(
+ f"Cached var {self!s} cannot determine dependencies in fetched state `{instruction.argval}`."
+ )
+ self.tracked_locals[instruction.argval] = self._getting_state_class
+ self.scan_status = ScanStatus.SCANNING
+ self._getting_state_class = None
+
+ def _eval_var(self) -> Var:
+ """Evaluate instructions from the wrapped function to get the Var object.
+
+ Returns:
+ The Var object.
+
+ Raises:
+ VarValueError: if the source code for the var cannot be determined.
+ """
+ # Get the original source code and eval it to get the Var.
+ module = inspect.getmodule(self.func)
+ positions0 = self._getting_var_instructions[0].positions
+ positions1 = self._getting_var_instructions[-1].positions
+ if module is None or positions0 is None or positions1 is None:
+ raise VarValueError(
+ f"Cannot determine the source code for the var in {self.func!r}."
+ )
+ start_line = positions0.lineno
+ start_column = positions0.col_offset
+ end_line = positions1.end_lineno
+ end_column = positions1.end_col_offset
+ if (
+ start_line is None
+ or start_column is None
+ or end_line is None
+ or end_column is None
+ ):
+ raise VarValueError(
+ f"Cannot determine the source code for the var in {self.func!r}."
+ )
+ source = inspect.getsource(module).splitlines(True)[start_line - 1 : end_line]
+ # Create a python source string snippet.
+ if len(source) > 1:
+ snipped_source = "".join(
+ [
+ *source[0][start_column:],
+ *(source[1:-2] if len(source) > 2 else []),
+ *source[-1][: end_column - 1],
+ ]
+ )
+ else:
+ snipped_source = source[0][start_column : end_column - 1]
+ # Evaluate the string in the context of the function's globals and closure.
+ return eval(f"({snipped_source})", self._get_globals(), self._get_closure())
+
+ def handle_getting_var(self, instruction: dis.Instruction) -> None:
+ """Handle bytecode analysis when `get_var_value` was called in the function.
+
+ This only really works if the expression passed to `get_var_value` is
+ evaluable in the function's global scope or closure, so getting the var
+ value from a var saved in a local variable or in the class instance is
+ not possible.
+
+ Args:
+ instruction: The dis instruction to process.
+
+ Raises:
+ VarValueError: if the source code for the var cannot be determined.
+ """
+ if instruction.opname == "CALL" and self._getting_var_instructions:
+ if self._getting_var_instructions:
+ the_var = self._eval_var()
+ the_var_data = the_var._get_all_var_data()
+ if the_var_data is None:
+ raise VarValueError(
+ f"Cannot determine the source code for the var in {self.func!r}."
+ )
+ self.dependencies.setdefault(the_var_data.state, set()).add(
+ the_var_data.field_name
+ )
+ self._getting_var_instructions.clear()
+ self.scan_status = ScanStatus.SCANNING
+ else:
+ self._getting_var_instructions.append(instruction)
+
+ def _populate_dependencies(self) -> None:
+ """Update self.dependencies based on the disassembly of self.func.
+
+ Save references to attributes accessed on "self" or other fetched states.
+
+ Recursively called when the function makes a method call on "self" or
+ defines comprehensions or nested functions that may reference "self".
+ """
+ for instruction in dis.get_instructions(self.func):
+ if self.scan_status == ScanStatus.GETTING_STATE:
+ self.handle_getting_state(instruction)
+ elif self.scan_status == ScanStatus.GETTING_VAR:
+ self.handle_getting_var(instruction)
+ elif (
+ instruction.opname in ("LOAD_FAST", "LOAD_DEREF")
+ and instruction.argval in self.tracked_locals
+ ):
+ # bytecode loaded the class instance to the top of stack, next load instruction
+ # is referencing an attribute on self
+ self.top_of_stack = instruction.argval
+ self.scan_status = ScanStatus.GETTING_ATTR
+ elif self.scan_status == ScanStatus.GETTING_ATTR and instruction.opname in (
+ "LOAD_ATTR",
+ "LOAD_METHOD",
+ ):
+ self.load_attr_or_method(instruction)
+ self.top_of_stack = None
+ elif instruction.opname == "LOAD_CONST" and isinstance(
+ instruction.argval, CodeType
+ ):
+ # recurse into nested functions / comprehensions, which can reference
+ # instance attributes from the outer scope
+ self._merge_deps(
+ type(self)(
+ func=instruction.argval,
+ state_cls=self.state_cls,
+ tracked_locals=self.tracked_locals,
+ )
+ )
diff --git a/reflex/vars/sequence.py b/reflex/vars/sequence.py
index dfd9a6af8..fb797b4ec 100644
--- a/reflex/vars/sequence.py
+++ b/reflex/vars/sequence.py
@@ -53,8 +53,11 @@ from .number import (
)
if TYPE_CHECKING:
+ from .base import BASE_TYPE, DATACLASS_TYPE, SQLA_TYPE
+ from .function import FunctionVar
from .object import ObjectVar
+
STRING_TYPE = TypeVar("STRING_TYPE", default=str)
@@ -961,6 +964,24 @@ class ArrayVar(Var[ARRAY_VAR_TYPE], python_types=(list, tuple, set)):
i: int | NumberVar,
) -> ObjectVar[Dict[KEY_TYPE, VALUE_TYPE]]: ...
+ @overload
+ def __getitem__(
+ self: ARRAY_VAR_OF_LIST_ELEMENT[BASE_TYPE],
+ i: int | NumberVar,
+ ) -> ObjectVar[BASE_TYPE]: ...
+
+ @overload
+ def __getitem__(
+ self: ARRAY_VAR_OF_LIST_ELEMENT[SQLA_TYPE],
+ i: int | NumberVar,
+ ) -> ObjectVar[SQLA_TYPE]: ...
+
+ @overload
+ def __getitem__(
+ self: ARRAY_VAR_OF_LIST_ELEMENT[DATACLASS_TYPE],
+ i: int | NumberVar,
+ ) -> ObjectVar[DATACLASS_TYPE]: ...
+
@overload
def __getitem__(self, i: int | NumberVar) -> Var: ...
@@ -1648,10 +1669,6 @@ def repeat_array_operation(
)
-if TYPE_CHECKING:
- from .function import FunctionVar
-
-
@var_operation
def map_array_operation(
array: ArrayVar[ARRAY_VAR_TYPE],
diff --git a/tests/integration/test_connection_banner.py b/tests/integration/test_connection_banner.py
index 4867cf868..e6a8caef6 100644
--- a/tests/integration/test_connection_banner.py
+++ b/tests/integration/test_connection_banner.py
@@ -136,9 +136,9 @@ def _assert_token(connection_banner, driver):
driver: Selenium webdriver instance.
"""
ss = SessionStorage(driver)
- assert connection_banner._poll_for(
- lambda: ss.get("token") is not None
- ), "token not found"
+ assert connection_banner._poll_for(lambda: ss.get("token") is not None), (
+ "token not found"
+ )
@pytest.mark.asyncio
@@ -153,7 +153,6 @@ async def test_connection_banner(connection_banner: AppHarness):
driver = connection_banner.frontend()
_assert_token(connection_banner, driver)
-
assert connection_banner._poll_for(lambda: not has_error_modal(driver))
delay_button = driver.find_element(By.ID, "delay")
diff --git a/tests/integration/tests_playwright/test_table.py b/tests/integration/tests_playwright/test_table.py
index bd399a840..a88c4a621 100644
--- a/tests/integration/tests_playwright/test_table.py
+++ b/tests/integration/tests_playwright/test_table.py
@@ -3,7 +3,7 @@
from typing import Generator
import pytest
-from playwright.sync_api import Page
+from playwright.sync_api import Page, expect
from reflex.testing import AppHarness
@@ -87,12 +87,14 @@ def test_table(page: Page, table_app: AppHarness):
table = page.get_by_role("table")
# Check column headers
- headers = table.get_by_role("columnheader").all_inner_texts()
- assert headers == expected_col_headers
+ headers = table.get_by_role("columnheader")
+ for header, exp_value in zip(headers.all(), expected_col_headers, strict=True):
+ expect(header).to_have_text(exp_value)
# Check rows headers
- rows = table.get_by_role("rowheader").all_inner_texts()
- assert rows == expected_row_headers
+ rows = table.get_by_role("rowheader")
+ for row, expected_row in zip(rows.all(), expected_row_headers, strict=True):
+ expect(row).to_have_text(expected_row)
# Check cells
rows = table.get_by_role("cell").all_inner_texts()
diff --git a/tests/units/components/core/test_colors.py b/tests/units/components/core/test_colors.py
index 15490e576..31cd75b47 100644
--- a/tests/units/components/core/test_colors.py
+++ b/tests/units/components/core/test_colors.py
@@ -55,13 +55,13 @@ def create_color_var(color):
Color,
),
(
- create_color_var(f'{rx.color(ColorState.color, f"{ColorState.shade}")}'), # pyright: ignore [reportArgumentType]
+ create_color_var(f"{rx.color(ColorState.color, f'{ColorState.shade}')}"), # pyright: ignore [reportArgumentType]
f'("var(--"+{color_state_name!s}.color+"-"+{color_state_name!s}.shade+")")',
str,
),
(
create_color_var(
- f'{rx.color(f"{ColorState.color}", f"{ColorState.shade}")}' # pyright: ignore [reportArgumentType]
+ f"{rx.color(f'{ColorState.color}', f'{ColorState.shade}')}" # pyright: ignore [reportArgumentType]
),
f'("var(--"+{color_state_name!s}.color+"-"+{color_state_name!s}.shade+")")',
str,
diff --git a/tests/units/test_app.py b/tests/units/test_app.py
index 4a6c16d6e..058174a1b 100644
--- a/tests/units/test_app.py
+++ b/tests/units/test_app.py
@@ -277,9 +277,9 @@ def test_add_page_set_route_dynamic(index_page, windows_platform: bool):
assert app._pages.keys() == {"test/[dynamic]"}
assert "dynamic" in app._state.computed_vars
assert app._state.computed_vars["dynamic"]._deps(objclass=EmptyState) == {
- constants.ROUTER
+ EmptyState.get_full_name(): {constants.ROUTER},
}
- assert constants.ROUTER in app._state()._computed_var_dependencies
+ assert constants.ROUTER in app._state()._var_dependencies
def test_add_page_set_route_nested(app: App, index_page, windows_platform: bool):
@@ -995,9 +995,9 @@ async def test_dynamic_route_var_route_change_completed_on_load(
assert arg_name in app._state.vars
assert arg_name in app._state.computed_vars
assert app._state.computed_vars[arg_name]._deps(objclass=DynamicState) == {
- constants.ROUTER
+ DynamicState.get_full_name(): {constants.ROUTER},
}
- assert constants.ROUTER in app._state()._computed_var_dependencies
+ assert constants.ROUTER in app._state()._var_dependencies
substate_token = _substate_key(token, DynamicState)
sid = "mock_sid"
@@ -1274,12 +1274,23 @@ def compilable_app(tmp_path) -> Generator[tuple[App, Path], None, None]:
yield app, web_dir
-def test_app_wrap_compile_theme(compilable_app: tuple[App, Path]):
+@pytest.mark.parametrize(
+ "react_strict_mode",
+ [True, False],
+)
+def test_app_wrap_compile_theme(
+ react_strict_mode: bool, compilable_app: tuple[App, Path], mocker
+):
"""Test that the radix theme component wraps the app.
Args:
+ react_strict_mode: Whether to use React Strict Mode.
compilable_app: compilable_app fixture.
+ mocker: pytest mocker object.
"""
+ conf = rx.Config(app_name="testing", react_strict_mode=react_strict_mode)
+ mocker.patch("reflex.config._get_config", return_value=conf)
+
app, web_dir = compilable_app
app.theme = rx.theme(accent_color="plum")
app._compile()
@@ -1290,24 +1301,37 @@ def test_app_wrap_compile_theme(compilable_app: tuple[App, Path]):
assert (
"function AppWrap({children}) {"
"return ("
- ""
+ + ("" if react_strict_mode else "")
+ + ""
""
""
"{children}"
""
""
""
- ")"
+ + ("" if react_strict_mode else "")
+ + ")"
"}"
) in "".join(app_js_lines)
-def test_app_wrap_priority(compilable_app: tuple[App, Path]):
+@pytest.mark.parametrize(
+ "react_strict_mode",
+ [True, False],
+)
+def test_app_wrap_priority(
+ react_strict_mode: bool, compilable_app: tuple[App, Path], mocker
+):
"""Test that the app wrap components are wrapped in the correct order.
Args:
+ react_strict_mode: Whether to use React Strict Mode.
compilable_app: compilable_app fixture.
+ mocker: pytest mocker object.
"""
+ conf = rx.Config(app_name="testing", react_strict_mode=react_strict_mode)
+ mocker.patch("reflex.config._get_config", return_value=conf)
+
app, web_dir = compilable_app
class Fragment1(Component):
@@ -1339,8 +1363,7 @@ def test_app_wrap_priority(compilable_app: tuple[App, Path]):
]
assert (
"function AppWrap({children}) {"
- "return ("
- ""
+ "return (" + ("" if react_strict_mode else "") + ""
''
""
""
@@ -1350,8 +1373,7 @@ def test_app_wrap_priority(compilable_app: tuple[App, Path]):
""
""
""
- ""
- ")"
+ "" + ("" if react_strict_mode else "") + ")"
"}"
) in "".join(app_js_lines)
@@ -1555,6 +1577,16 @@ def test_app_with_valid_var_dependencies(compilable_app: tuple[App, Path]):
def bar(self) -> str:
return "bar"
+ class Child1(ValidDepState):
+ @computed_var(deps=["base", ValidDepState.bar])
+ def other(self) -> str:
+ return "other"
+
+ class Child2(ValidDepState):
+ @computed_var(deps=["base", Child1.other])
+ def other(self) -> str:
+ return "other"
+
app._state = ValidDepState
app._compile()
diff --git a/tests/units/test_prerequisites.py b/tests/units/test_prerequisites.py
index 3bd029077..4723d8648 100644
--- a/tests/units/test_prerequisites.py
+++ b/tests/units/test_prerequisites.py
@@ -32,7 +32,7 @@ runner = CliRunner()
app_name="test",
),
False,
- 'module.exports = {basePath: "", compress: true, reactStrictMode: true, trailingSlash: true, staticPageGenerationTimeout: 60};',
+ 'module.exports = {basePath: "", compress: true, trailingSlash: true, staticPageGenerationTimeout: 60};',
),
(
Config(
@@ -40,7 +40,7 @@ runner = CliRunner()
static_page_generation_timeout=30,
),
False,
- 'module.exports = {basePath: "", compress: true, reactStrictMode: true, trailingSlash: true, staticPageGenerationTimeout: 30};',
+ 'module.exports = {basePath: "", compress: true, trailingSlash: true, staticPageGenerationTimeout: 30};',
),
(
Config(
@@ -48,7 +48,7 @@ runner = CliRunner()
next_compression=False,
),
False,
- 'module.exports = {basePath: "", compress: false, reactStrictMode: true, trailingSlash: true, staticPageGenerationTimeout: 60};',
+ 'module.exports = {basePath: "", compress: false, trailingSlash: true, staticPageGenerationTimeout: 60};',
),
(
Config(
@@ -56,7 +56,7 @@ runner = CliRunner()
frontend_path="/test",
),
False,
- 'module.exports = {basePath: "/test", compress: true, reactStrictMode: true, trailingSlash: true, staticPageGenerationTimeout: 60};',
+ 'module.exports = {basePath: "/test", compress: true, trailingSlash: true, staticPageGenerationTimeout: 60};',
),
(
Config(
@@ -65,14 +65,14 @@ runner = CliRunner()
next_compression=False,
),
False,
- 'module.exports = {basePath: "/test", compress: false, reactStrictMode: true, trailingSlash: true, staticPageGenerationTimeout: 60};',
+ 'module.exports = {basePath: "/test", compress: false, trailingSlash: true, staticPageGenerationTimeout: 60};',
),
(
Config(
app_name="test",
),
True,
- 'module.exports = {basePath: "", compress: true, reactStrictMode: true, trailingSlash: true, staticPageGenerationTimeout: 60, output: "export", distDir: "_static"};',
+ 'module.exports = {basePath: "", compress: true, trailingSlash: true, staticPageGenerationTimeout: 60, output: "export", distDir: "_static"};',
),
],
)
diff --git a/tests/units/test_state.py b/tests/units/test_state.py
index 9e1932305..e0390c5ac 100644
--- a/tests/units/test_state.py
+++ b/tests/units/test_state.py
@@ -14,6 +14,7 @@ from typing import (
Any,
AsyncGenerator,
Callable,
+ ClassVar,
Dict,
List,
Optional,
@@ -1169,13 +1170,17 @@ def test_conditional_computed_vars():
ms = MainState()
# Initially there are no dirty computed vars.
- assert ms._dirty_computed_vars(from_vars={"flag"}) == {"rendered_var"}
- assert ms._dirty_computed_vars(from_vars={"t2"}) == {"rendered_var"}
- assert ms._dirty_computed_vars(from_vars={"t1"}) == {"rendered_var"}
+ assert ms._dirty_computed_vars(from_vars={"flag"}) == {
+ (MainState.get_full_name(), "rendered_var")
+ }
+ assert ms._dirty_computed_vars(from_vars={"t2"}) == {
+ (MainState.get_full_name(), "rendered_var")
+ }
+ assert ms._dirty_computed_vars(from_vars={"t1"}) == {
+ (MainState.get_full_name(), "rendered_var")
+ }
assert ms.computed_vars["rendered_var"]._deps(objclass=MainState) == {
- "flag",
- "t1",
- "t2",
+ MainState.get_full_name(): {"flag", "t1", "t2"}
}
@@ -1370,7 +1375,10 @@ def test_cached_var_depends_on_event_handler(use_partial: bool):
assert isinstance(HandlerState.handler, EventHandler)
s = HandlerState()
- assert "cached_x_side_effect" in s._computed_var_dependencies["x"]
+ assert (
+ HandlerState.get_full_name(),
+ "cached_x_side_effect",
+ ) in s._var_dependencies["x"]
assert s.cached_x_side_effect == 1
assert s.x == 43
s.handler()
@@ -1460,15 +1468,15 @@ def test_computed_var_dependencies():
return [z in self._z for z in range(5)]
cs = ComputedState()
- assert cs._computed_var_dependencies["v"] == {
- "comp_v",
- "comp_v_backend",
- "comp_v_via_property",
+ assert cs._var_dependencies["v"] == {
+ (ComputedState.get_full_name(), "comp_v"),
+ (ComputedState.get_full_name(), "comp_v_backend"),
+ (ComputedState.get_full_name(), "comp_v_via_property"),
}
- assert cs._computed_var_dependencies["w"] == {"comp_w"}
- assert cs._computed_var_dependencies["x"] == {"comp_x"}
- assert cs._computed_var_dependencies["y"] == {"comp_y"}
- assert cs._computed_var_dependencies["_z"] == {"comp_z"}
+ assert cs._var_dependencies["w"] == {(ComputedState.get_full_name(), "comp_w")}
+ assert cs._var_dependencies["x"] == {(ComputedState.get_full_name(), "comp_x")}
+ assert cs._var_dependencies["y"] == {(ComputedState.get_full_name(), "comp_y")}
+ assert cs._var_dependencies["_z"] == {(ComputedState.get_full_name(), "comp_z")}
def test_backend_method():
@@ -1615,7 +1623,7 @@ async def test_state_with_invalid_yield(capsys, mock_app):
id="backend_error",
position="top-center",
style={"width": "500px"},
- ) # pyright: ignore [reportCallIssue, reportArgumentType]
+ )
],
token="",
)
@@ -3180,7 +3188,7 @@ async def test_get_state_from_sibling_not_cached(mock_app: rx.App, token: str):
RxState = State
-def test_potentially_dirty_substates():
+def test_potentially_dirty_states():
"""Test that potentially_dirty_substates returns the correct substates.
Even if the name "State" is shadowed, it should still work correctly.
@@ -3196,13 +3204,19 @@ def test_potentially_dirty_substates():
def bar(self) -> str:
return ""
- assert RxState._potentially_dirty_substates() == set()
- assert State._potentially_dirty_substates() == set()
- assert C1._potentially_dirty_substates() == set()
+ assert RxState._get_potentially_dirty_states() == set()
+ assert State._get_potentially_dirty_states() == set()
+ assert C1._get_potentially_dirty_states() == set()
-def test_router_var_dep() -> None:
- """Test that router var dependencies are correctly tracked."""
+@pytest.mark.asyncio
+async def test_router_var_dep(state_manager: StateManager, token: str) -> None:
+ """Test that router var dependencies are correctly tracked.
+
+ Args:
+ state_manager: A state manager.
+ token: A token.
+ """
class RouterVarParentState(State):
"""A parent state for testing router var dependency."""
@@ -3219,30 +3233,27 @@ def test_router_var_dep() -> None:
foo = RouterVarDepState.computed_vars["foo"]
State._init_var_dependency_dicts()
- assert foo._deps(objclass=RouterVarDepState) == {"router"}
- assert RouterVarParentState._potentially_dirty_substates() == {RouterVarDepState}
- assert RouterVarParentState._substate_var_dependencies == {
- "router": {RouterVarDepState.get_name()}
- }
- assert RouterVarDepState._computed_var_dependencies == {
- "router": {"foo"},
+ assert foo._deps(objclass=RouterVarDepState) == {
+ RouterVarDepState.get_full_name(): {"router"}
}
+ assert (RouterVarDepState.get_full_name(), "foo") in State._var_dependencies[
+ "router"
+ ]
- rx_state = State()
- parent_state = RouterVarParentState()
- state = RouterVarDepState()
-
- # link states
- rx_state.substates = {RouterVarParentState.get_name(): parent_state}
- parent_state.parent_state = rx_state
- state.parent_state = parent_state
- parent_state.substates = {RouterVarDepState.get_name(): state}
+ # Get state from state manager.
+ state_manager.state = State
+ rx_state = await state_manager.get_state(_substate_key(token, State))
+ assert RouterVarParentState.get_name() in rx_state.substates
+ parent_state = rx_state.substates[RouterVarParentState.get_name()]
+ assert RouterVarDepState.get_name() in parent_state.substates
+ state = parent_state.substates[RouterVarDepState.get_name()]
assert state.dirty_vars == set()
# Reassign router var
state.router = state.router
- assert state.dirty_vars == {"foo", "router"}
+ assert rx_state.dirty_vars == {"router"}
+ assert state.dirty_vars == {"foo"}
assert parent_state.dirty_substates == {RouterVarDepState.get_name()}
@@ -3801,3 +3812,128 @@ async def test_get_var_value(state_manager: StateManager, substate_token: str):
# Generic Var with no state
with pytest.raises(UnretrievableVarValueError):
await state.get_var_value(rx.Var("undefined"))
+
+
+@pytest.mark.asyncio
+async def test_async_computed_var_get_state(mock_app: rx.App, token: str):
+ """A test where an async computed var depends on a var in another state.
+
+ Args:
+ mock_app: An app that will be returned by `get_app()`
+ token: A token.
+ """
+
+ class Parent(BaseState):
+ """A root state like rx.State."""
+
+ parent_var: int = 0
+
+ class Child2(Parent):
+ """An unconnected child state."""
+
+ pass
+
+ class Child3(Parent):
+ """A child state with a computed var causing it to be pre-fetched.
+
+ If child3_var gets set to a value, and `get_state` erroneously
+ re-fetches it from redis, the value will be lost.
+ """
+
+ child3_var: int = 0
+
+ @rx.var(cache=True)
+ def v(self) -> int:
+ return self.child3_var
+
+ class Child(Parent):
+ """A state simulating UpdateVarsInternalState."""
+
+ @rx.var(cache=True)
+ async def v(self) -> int:
+ p = await self.get_state(Parent)
+ child3 = await self.get_state(Child3)
+ return child3.child3_var + p.parent_var
+
+ mock_app.state_manager.state = mock_app._state = Parent
+
+ # Get the top level state via unconnected sibling.
+ root = await mock_app.state_manager.get_state(_substate_key(token, Child))
+ # Set value in parent_var to assert it does not get refetched later.
+ root.parent_var = 1
+
+ if isinstance(mock_app.state_manager, StateManagerRedis):
+ # When redis is used, only states with uncached computed vars are pre-fetched.
+ assert Child2.get_name() not in root.substates
+ assert Child3.get_name() not in root.substates
+
+ # Get the unconnected sibling state, which will be used to `get_state` other instances.
+ child = root.get_substate(Child.get_full_name().split("."))
+
+ # Get an uncached child state.
+ child2 = await child.get_state(Child2)
+ assert child2.parent_var == 1
+
+ # Set value on already-cached Child3 state (prefetched because it has a Computed Var).
+ child3 = await child.get_state(Child3)
+ child3.child3_var = 1
+
+ assert await child.v == 2
+ assert await child.v == 2
+ root.parent_var = 2
+ assert await child.v == 3
+
+
+class Table(rx.ComponentState):
+ """A table state."""
+
+ data: ClassVar[Var]
+
+ @rx.var(cache=True, auto_deps=False)
+ async def rows(self) -> List[Dict[str, Any]]:
+ """Computed var over the given rows.
+
+ Returns:
+ The data rows.
+ """
+ return await self.get_var_value(self.data)
+
+ @classmethod
+ def get_component(cls, data: Var) -> rx.Component:
+ """Get the component for the table.
+
+ Args:
+ data: The data var.
+
+ Returns:
+ The component.
+ """
+ cls.data = data
+ cls.computed_vars["rows"].add_dependency(cls, data)
+ return rx.foreach(data, lambda d: rx.text(d.to_string()))
+
+
+@pytest.mark.asyncio
+async def test_async_computed_var_get_var_value(mock_app: rx.App, token: str):
+ """A test where an async computed var depends on a var in another state.
+
+ Args:
+ mock_app: An app that will be returned by `get_app()`
+ token: A token.
+ """
+
+ class OtherState(rx.State):
+ """A state with a var."""
+
+ data: List[Dict[str, Any]] = [{"foo": "bar"}]
+
+ mock_app.state_manager.state = mock_app._state = rx.State
+ comp = Table.create(data=OtherState.data)
+ state = await mock_app.state_manager.get_state(_substate_key(token, OtherState))
+ other_state = await state.get_state(OtherState)
+ assert comp.State is not None
+ comp_state = await state.get_state(comp.State)
+ assert comp_state.dirty_vars == set()
+
+ other_state.data.append({"foo": "baz"})
+ assert "rows" in comp_state.dirty_vars
diff --git a/tests/units/test_var.py b/tests/units/test_var.py
index ef19e86e8..a72242814 100644
--- a/tests/units/test_var.py
+++ b/tests/units/test_var.py
@@ -1807,9 +1807,9 @@ def cv_fget(state: BaseState) -> int:
@pytest.mark.parametrize(
"deps,expected",
[
- (["a"], {"a"}),
- (["b"], {"b"}),
- ([ComputedVar(fget=cv_fget)], {"cv_fget"}),
+ (["a"], {None: {"a"}}),
+ (["b"], {None: {"b"}}),
+ ([ComputedVar(fget=cv_fget)], {None: {"cv_fget"}}),
],
)
def test_computed_var_deps(deps: List[Union[str, Var]], expected: Set[str]):
@@ -1857,6 +1857,28 @@ def test_to_string_operation():
assert single_var._var_type == Email
+@pytest.mark.asyncio
+async def test_async_computed_var():
+ side_effect_counter = 0
+
+ class AsyncComputedVarState(BaseState):
+ v: int = 1
+
+ @computed_var(cache=True)
+ async def async_computed_var(self) -> int:
+ nonlocal side_effect_counter
+ side_effect_counter += 1
+ return self.v + 1
+
+ my_state = AsyncComputedVarState()
+ assert await my_state.async_computed_var == 2
+ assert await my_state.async_computed_var == 2
+ my_state.v = 2
+ assert await my_state.async_computed_var == 3
+ assert await my_state.async_computed_var == 3
+ assert side_effect_counter == 2
+
+
def test_var_data_hooks():
var_data_str = VarData(hooks="what")
var_data_list = VarData(hooks=["what"])
diff --git a/tests/units/vars/test_object.py b/tests/units/vars/test_object.py
index efcb21166..90e34be96 100644
--- a/tests/units/vars/test_object.py
+++ b/tests/units/vars/test_object.py
@@ -1,10 +1,14 @@
+import dataclasses
+
import pytest
+from sqlalchemy.orm import DeclarativeBase, Mapped, MappedAsDataclass, mapped_column
from typing_extensions import assert_type
import reflex as rx
from reflex.utils.types import GenericType
from reflex.vars.base import Var
from reflex.vars.object import LiteralObjectVar, ObjectVar
+from reflex.vars.sequence import ArrayVar
class Bare:
@@ -32,14 +36,44 @@ class Base(rx.Base):
quantity: int = 0
+class SqlaBase(DeclarativeBase, MappedAsDataclass):
+ """Sqlalchemy declarative mapping base class."""
+
+ pass
+
+
+class SqlaModel(SqlaBase):
+ """A sqlalchemy model with a single attribute."""
+
+ __tablename__: str = "sqla_model"
+
+ id: Mapped[int] = mapped_column(primary_key=True, autoincrement=True, init=False)
+ quantity: Mapped[int] = mapped_column(default=0)
+
+
+@dataclasses.dataclass
+class Dataclass:
+ """A dataclass with a single attribute."""
+
+ quantity: int = 0
+
+
class ObjectState(rx.State):
- """A reflex state with bare and base objects."""
+ """A reflex state with bare, base and sqlalchemy base vars."""
bare: rx.Field[Bare] = rx.field(Bare())
+ bare_optional: rx.Field[Bare | None] = rx.field(None)
base: rx.Field[Base] = rx.field(Base())
+ base_optional: rx.Field[Base | None] = rx.field(None)
+ sqlamodel: rx.Field[SqlaModel] = rx.field(SqlaModel())
+ sqlamodel_optional: rx.Field[SqlaModel | None] = rx.field(None)
+ dataclass: rx.Field[Dataclass] = rx.field(Dataclass())
+ dataclass_optional: rx.Field[Dataclass | None] = rx.field(None)
+
+ base_list: rx.Field[list[Base]] = rx.field([Base()])
-@pytest.mark.parametrize("type_", [Base, Bare])
+@pytest.mark.parametrize("type_", [Base, Bare, SqlaModel, Dataclass])
def test_var_create(type_: GenericType) -> None:
my_object = type_()
var = Var.create(my_object)
@@ -49,7 +83,7 @@ def test_var_create(type_: GenericType) -> None:
assert quantity._var_type is int
-@pytest.mark.parametrize("type_", [Base, Bare])
+@pytest.mark.parametrize("type_", [Base, Bare, SqlaModel, Dataclass])
def test_literal_create(type_: GenericType) -> None:
my_object = type_()
var = LiteralObjectVar.create(my_object)
@@ -59,7 +93,7 @@ def test_literal_create(type_: GenericType) -> None:
assert quantity._var_type is int
-@pytest.mark.parametrize("type_", [Base, Bare])
+@pytest.mark.parametrize("type_", [Base, Bare, SqlaModel, Dataclass])
def test_guess(type_: GenericType) -> None:
my_object = type_()
var = Var.create(my_object)
@@ -70,7 +104,7 @@ def test_guess(type_: GenericType) -> None:
assert quantity._var_type is int
-@pytest.mark.parametrize("type_", [Base, Bare])
+@pytest.mark.parametrize("type_", [Base, Bare, SqlaModel, Dataclass])
def test_state(type_: GenericType) -> None:
attr_name = type_.__name__.lower()
var = getattr(ObjectState, attr_name)
@@ -80,7 +114,7 @@ def test_state(type_: GenericType) -> None:
assert quantity._var_type is int
-@pytest.mark.parametrize("type_", [Base, Bare])
+@pytest.mark.parametrize("type_", [Base, Bare, SqlaModel, Dataclass])
def test_state_to_operation(type_: GenericType) -> None:
attr_name = type_.__name__.lower()
original_var = getattr(ObjectState, attr_name)
@@ -100,3 +134,29 @@ def test_typing() -> None:
# Base
var = ObjectState.base
_ = assert_type(var, ObjectVar[Base])
+ optional_var = ObjectState.base_optional
+ _ = assert_type(optional_var, ObjectVar[Base | None])
+ list_var = ObjectState.base_list
+ _ = assert_type(list_var, ArrayVar[list[Base]])
+ list_var_0 = list_var[0]
+ _ = assert_type(list_var_0, ObjectVar[Base])
+
+ # Sqla
+ var = ObjectState.sqlamodel
+ _ = assert_type(var, ObjectVar[SqlaModel])
+ optional_var = ObjectState.sqlamodel_optional
+ _ = assert_type(optional_var, ObjectVar[SqlaModel | None])
+ list_var = ObjectState.base_list
+ _ = assert_type(list_var, ArrayVar[list[Base]])
+ list_var_0 = list_var[0]
+ _ = assert_type(list_var_0, ObjectVar[Base])
+
+ # Dataclass
+ var = ObjectState.dataclass
+ _ = assert_type(var, ObjectVar[Dataclass])
+ optional_var = ObjectState.dataclass_optional
+ _ = assert_type(optional_var, ObjectVar[Dataclass | None])
+ list_var = ObjectState.base_list
+ _ = assert_type(list_var, ArrayVar[list[Base]])
+ list_var_0 = list_var[0]
+ _ = assert_type(list_var_0, ObjectVar[Base])