Merge branch 'main' into improve-client-state

This commit is contained in:
Khaleel Al-Adhami 2025-01-07 10:34:31 -08:00
commit ad8fb19521
17 changed files with 153 additions and 71 deletions

View File

@ -16,7 +16,6 @@ repository = "https://github.com/reflex-dev/reflex"
documentation = "https://reflex.dev/docs/getting-started/introduction"
keywords = ["web", "framework"]
classifiers = ["Development Status :: 4 - Beta"]
packages = [{ include = "reflex" }]
[tool.poetry.dependencies]
python = "^3.9"

View File

@ -1,4 +1,5 @@
{% extends "web/pages/base_page.js.jinja2" %}
{% from "web/pages/macros.js.jinja2" import renderHooks %}
{% block early_imports %}
import '$/styles/styles.css'
@ -18,10 +19,7 @@ import * as {{library_alias}} from "{{library_path}}";
{% block export %}
function AppWrap({children}) {
{% for hook in hooks %}
{{ hook }}
{% endfor %}
{{ renderHooks(hooks) }}
return (
{{utils.render(render, indent_width=0)}}

View File

@ -1,5 +1,5 @@
{% extends "web/pages/base_page.js.jinja2" %}
{% from "web/pages/macros.js.jinja2" import renderHooks %}
{% block export %}
{% for component in components %}
@ -8,9 +8,8 @@
{% endfor %}
export const {{component.name}} = memo(({ {{-component.props|join(", ")-}} }) => {
{% for hook in component.hooks %}
{{ hook }}
{% endfor %}
{{ renderHooks(component.hooks) }}
return(
{{utils.render(component.render)}}
)

View File

@ -1,4 +1,5 @@
{% extends "web/pages/base_page.js.jinja2" %}
{% from "web/pages/macros.js.jinja2" import renderHooks %}
{% block declaration %}
{% for custom_code in custom_codes %}
@ -8,9 +9,7 @@
{% block export %}
export default function Component() {
{% for hook in hooks %}
{{ hook }}
{% endfor %}
{{ renderHooks(hooks)}}
return (
{{utils.render(render, indent_width=0)}}

View File

@ -0,0 +1,38 @@
{% macro renderHooks(hooks) %}
{% set sorted_hooks = sort_hooks(hooks) %}
{# Render the grouped hooks #}
{% for hook, _ in sorted_hooks[const.hook_position.INTERNAL] %}
{{ hook }}
{% endfor %}
{% for hook, _ in sorted_hooks[const.hook_position.PRE_TRIGGER] %}
{{ hook }}
{% endfor %}
{% for hook, _ in sorted_hooks[const.hook_position.POST_TRIGGER] %}
{{ hook }}
{% endfor %}
{% endmacro %}
{% macro renderHooksWithMemo(hooks, memo)%}
{% set sorted_hooks = sort_hooks(hooks) %}
{# Render the grouped hooks #}
{% for hook, _ in sorted_hooks[const.hook_position.INTERNAL] %}
{{ hook }}
{% endfor %}
{% for hook, _ in sorted_hooks[const.hook_position.PRE_TRIGGER] %}
{{ hook }}
{% endfor %}
{% for hook in memo %}
{{ hook }}
{% endfor %}
{% for hook, _ in sorted_hooks[const.hook_position.POST_TRIGGER] %}
{{ hook }}
{% endfor %}
{% endmacro %}

View File

@ -1,22 +1,10 @@
{% import 'web/pages/utils.js.jinja2' as utils %}
{% from 'web/pages/macros.js.jinja2' import renderHooksWithMemo %}
{% set all_hooks = component._get_all_hooks() %}
export function {{tag_name}} () {
{% for hook in component._get_all_hooks_internal() %}
{{ hook }}
{% endfor %}
{% for hook, data in component._get_all_hooks().items() if not data.position or data.position == const.hook_position.PRE_TRIGGER %}
{{ hook }}
{% endfor %}
{% for hook in memo_trigger_hooks %}
{{ hook }}
{% endfor %}
{% for hook, data in component._get_all_hooks().items() if data.position and data.position == const.hook_position.POST_TRIGGER %}
{{ hook }}
{% endfor %}
{{ renderHooksWithMemo(all_hooks, memo_trigger_hooks) }}
return (
{{utils.render(component.render(), indent_width=0)}}
)

View File

@ -75,7 +75,7 @@ def _compile_app(app_root: Component) -> str:
return templates.APP_ROOT.render(
imports=utils.compile_imports(app_root._get_all_imports()),
custom_codes=app_root._get_all_custom_code(),
hooks={**app_root._get_all_hooks_internal(), **app_root._get_all_hooks()},
hooks=app_root._get_all_hooks(),
window_libraries=window_libraries,
render=app_root.render(),
)
@ -149,7 +149,7 @@ def _compile_page(
imports=imports,
dynamic_imports=component._get_all_dynamic_imports(),
custom_codes=component._get_all_custom_code(),
hooks={**component._get_all_hooks_internal(), **component._get_all_hooks()},
hooks=component._get_all_hooks(),
render=component.render(),
**kwargs,
)

View File

@ -1,9 +1,46 @@
"""Templates to use in the reflex compiler."""
from __future__ import annotations
from jinja2 import Environment, FileSystemLoader, Template
from reflex import constants
from reflex.constants import Hooks
from reflex.utils.format import format_state_name, json_dumps
from reflex.vars.base import VarData
def _sort_hooks(hooks: dict[str, VarData | None]):
"""Sort the hooks by their position.
Args:
hooks: The hooks to sort.
Returns:
The sorted hooks.
"""
sorted_hooks = {
Hooks.HookPosition.INTERNAL: [],
Hooks.HookPosition.PRE_TRIGGER: [],
Hooks.HookPosition.POST_TRIGGER: [],
}
for hook, data in hooks.items():
if data and data.position and data.position == Hooks.HookPosition.INTERNAL:
sorted_hooks[Hooks.HookPosition.INTERNAL].append((hook, data))
elif not data or (
not data.position
or data.position == constants.Hooks.HookPosition.PRE_TRIGGER
):
sorted_hooks[Hooks.HookPosition.PRE_TRIGGER].append((hook, data))
elif (
data
and data.position
and data.position == constants.Hooks.HookPosition.POST_TRIGGER
):
sorted_hooks[Hooks.HookPosition.POST_TRIGGER].append((hook, data))
return sorted_hooks
class ReflexJinjaEnvironment(Environment):
@ -47,6 +84,7 @@ class ReflexJinjaEnvironment(Environment):
"frontend_exception_state": constants.CompileVars.FRONTEND_EXCEPTION_STATE_FULL,
"hook_position": constants.Hooks.HookPosition,
}
self.globals["sort_hooks"] = _sort_hooks
def get_template(name: str) -> Template:
@ -103,6 +141,9 @@ STYLE = get_template("web/styles/styles.css.jinja2")
# Code that generate the package json file
PACKAGE_JSON = get_template("web/package.json.jinja2")
# Template containing some macros used in the web pages.
MACROS = get_template("web/pages/macros.js.jinja2")
# Code that generate the pyproject.toml file for custom components.
CUSTOM_COMPONENTS_PYPROJECT_TOML = get_template(
"custom_components/pyproject.toml.jinja2"

View File

@ -290,7 +290,7 @@ def compile_custom_component(
"name": component.tag,
"props": props,
"render": render.render(),
"hooks": {**render._get_all_hooks_internal(), **render._get_all_hooks()},
"hooks": render._get_all_hooks(),
"custom_code": render._get_all_custom_code(),
},
imports,

View File

@ -9,6 +9,7 @@ from reflex.components.tags import Tag
from reflex.components.tags.tagless import Tagless
from reflex.utils.imports import ParsedImportDict
from reflex.vars import BooleanVar, ObjectVar, Var
from reflex.vars.base import VarData
class Bare(Component):
@ -32,7 +33,7 @@ class Bare(Component):
contents = str(contents) if contents is not None else ""
return cls(contents=contents) # type: ignore
def _get_all_hooks_internal(self) -> dict[str, None]:
def _get_all_hooks_internal(self) -> dict[str, VarData | None]:
"""Include the hooks for the component.
Returns:
@ -43,7 +44,7 @@ class Bare(Component):
hooks |= self.contents._var_value._get_all_hooks_internal()
return hooks
def _get_all_hooks(self) -> dict[str, None]:
def _get_all_hooks(self) -> dict[str, VarData | None]:
"""Include the hooks for the component.
Returns:

View File

@ -102,7 +102,7 @@ class BaseComponent(Base, ABC):
"""
@abstractmethod
def _get_all_hooks_internal(self) -> dict[str, None]:
def _get_all_hooks_internal(self) -> dict[str, VarData | None]:
"""Get the reflex internal hooks for the component and its children.
Returns:
@ -110,7 +110,7 @@ class BaseComponent(Base, ABC):
"""
@abstractmethod
def _get_all_hooks(self) -> dict[str, None]:
def _get_all_hooks(self) -> dict[str, VarData | None]:
"""Get the React hooks for this component.
Returns:
@ -1272,7 +1272,7 @@ class Component(BaseComponent, ABC):
"""
_imports = {}
if self._get_ref_hook():
if self._get_ref_hook() is not None:
# Handle hooks needed for attaching react refs to DOM nodes.
_imports.setdefault("react", set()).add(ImportVar(tag="useRef"))
_imports.setdefault(f"$/{Dirs.STATE_PATH}", set()).add(
@ -1388,7 +1388,7 @@ class Component(BaseComponent, ABC):
}}
}}, []);"""
def _get_ref_hook(self) -> str | None:
def _get_ref_hook(self) -> Var | None:
"""Generate the ref hook for the component.
Returns:
@ -1396,11 +1396,12 @@ class Component(BaseComponent, ABC):
"""
ref = self.get_ref()
if ref is not None:
return (
f"const {ref} = useRef(null); {Var(_js_expr=ref)._as_ref()!s} = {ref};"
return Var(
f"const {ref} = useRef(null); {Var(_js_expr=ref)._as_ref()!s} = {ref};",
_var_data=VarData(position=Hooks.HookPosition.INTERNAL),
)
def _get_vars_hooks(self) -> dict[str, None]:
def _get_vars_hooks(self) -> dict[str, VarData | None]:
"""Get the hooks required by vars referenced in this component.
Returns:
@ -1413,27 +1414,38 @@ class Component(BaseComponent, ABC):
vars_hooks.update(
var_data.hooks
if isinstance(var_data.hooks, dict)
else {k: None for k in var_data.hooks}
else {
k: VarData(position=Hooks.HookPosition.INTERNAL)
for k in var_data.hooks
}
)
return vars_hooks
def _get_events_hooks(self) -> dict[str, None]:
def _get_events_hooks(self) -> dict[str, VarData | None]:
"""Get the hooks required by events referenced in this component.
Returns:
The hooks for the events.
"""
return {Hooks.EVENTS: None} if self.event_triggers else {}
return (
{Hooks.EVENTS: VarData(position=Hooks.HookPosition.INTERNAL)}
if self.event_triggers
else {}
)
def _get_special_hooks(self) -> dict[str, None]:
def _get_special_hooks(self) -> dict[str, VarData | None]:
"""Get the hooks required by special actions referenced in this component.
Returns:
The hooks for special actions.
"""
return {Hooks.AUTOFOCUS: None} if self.autofocus else {}
return (
{Hooks.AUTOFOCUS: VarData(position=Hooks.HookPosition.INTERNAL)}
if self.autofocus
else {}
)
def _get_hooks_internal(self) -> dict[str, None]:
def _get_hooks_internal(self) -> dict[str, VarData | None]:
"""Get the React hooks for this component managed by the framework.
Downstream components should NOT override this method to avoid breaking
@ -1444,7 +1456,7 @@ class Component(BaseComponent, ABC):
"""
return {
**{
hook: None
str(hook): VarData(position=Hooks.HookPosition.INTERNAL)
for hook in [self._get_ref_hook(), self._get_mount_lifecycle_hook()]
if hook is not None
},
@ -1493,7 +1505,7 @@ class Component(BaseComponent, ABC):
"""
return
def _get_all_hooks_internal(self) -> dict[str, None]:
def _get_all_hooks_internal(self) -> dict[str, VarData | None]:
"""Get the reflex internal hooks for the component and its children.
Returns:
@ -1508,7 +1520,7 @@ class Component(BaseComponent, ABC):
return code
def _get_all_hooks(self) -> dict[str, None]:
def _get_all_hooks(self) -> dict[str, VarData | None]:
"""Get the React hooks for this component and its children.
Returns:
@ -1516,6 +1528,9 @@ class Component(BaseComponent, ABC):
"""
code = {}
# Add the internal hooks for this component.
code.update(self._get_hooks_internal())
# Add the hook code for this component.
hooks = self._get_hooks()
if hooks is not None:
@ -2211,7 +2226,7 @@ class StatefulComponent(BaseComponent):
)
return trigger_memo
def _get_all_hooks_internal(self) -> dict[str, None]:
def _get_all_hooks_internal(self) -> dict[str, VarData | None]:
"""Get the reflex internal hooks for the component and its children.
Returns:
@ -2219,7 +2234,7 @@ class StatefulComponent(BaseComponent):
"""
return {}
def _get_all_hooks(self) -> dict[str, None]:
def _get_all_hooks(self) -> dict[str, VarData | None]:
"""Get the React hooks for this component.
Returns:
@ -2337,7 +2352,7 @@ class MemoizationLeaf(Component):
The memoization leaf
"""
comp = super().create(*children, **props)
if comp._get_all_hooks() or comp._get_all_hooks_internal():
if comp._get_all_hooks():
comp._memoization_mode = cls._memoization_mode.copy(
update={"disposition": MemoizationDisposition.ALWAYS}
)

View File

@ -182,9 +182,7 @@ class Form(BaseHTML):
props["handle_submit_unique_name"] = ""
form = super().create(*children, **props)
form.handle_submit_unique_name = md5(
str({**form._get_all_hooks_internal(), **form._get_all_hooks()}).encode(
"utf-8"
)
str(form._get_all_hooks()).encode("utf-8")
).hexdigest()
return form

View File

@ -420,11 +420,12 @@ const {_LANGUAGE!s} = match ? match[1] : '';
def _get_custom_code(self) -> str | None:
hooks = {}
from reflex.compiler.templates import MACROS
for _component in self.component_map.values():
comp = _component(_MOCK_ARG)
hooks.update(comp._get_all_hooks_internal())
hooks.update(comp._get_all_hooks())
formatted_hooks = "\n".join(hooks.keys())
formatted_hooks = MACROS.module.renderHooks(hooks) # type: ignore
return f"""
function {self._get_component_map_name()} () {{
{formatted_hooks}

View File

@ -135,6 +135,7 @@ class Hooks(SimpleNamespace):
class HookPosition(enum.Enum):
"""The position of the hook in the component."""
INTERNAL = "internal"
PRE_TRIGGER = "pre_trigger"
POST_TRIGGER = "post_trigger"

View File

@ -107,7 +107,7 @@ class ClientStateVar(Var):
else:
default_var = default
setter_name = f"set{var_name.capitalize()}"
hooks = {
hooks: dict[str, VarData | None] = {
f"const {id_name} = useId()": None,
f"const [{var_name}, {setter_name}] = useState({default_var!s})": None,
}

View File

@ -28,8 +28,8 @@ import typer
from alembic.util.exc import CommandError
from packaging import version
from redis import Redis as RedisSync
from redis import exceptions
from redis.asyncio import Redis
from redis.exceptions import RedisError
from reflex import constants, model
from reflex.compiler import templates
@ -333,10 +333,11 @@ def get_redis() -> Redis | None:
Returns:
The asynchronous redis client.
"""
if isinstance((redis_url_or_options := parse_redis_url()), str):
return Redis.from_url(redis_url_or_options)
elif isinstance(redis_url_or_options, dict):
return Redis(**redis_url_or_options)
if (redis_url := parse_redis_url()) is not None:
return Redis.from_url(
redis_url,
retry_on_error=[RedisError],
)
return None
@ -346,14 +347,15 @@ def get_redis_sync() -> RedisSync | None:
Returns:
The synchronous redis client.
"""
if isinstance((redis_url_or_options := parse_redis_url()), str):
return RedisSync.from_url(redis_url_or_options)
elif isinstance(redis_url_or_options, dict):
return RedisSync(**redis_url_or_options)
if (redis_url := parse_redis_url()) is not None:
return RedisSync.from_url(
redis_url,
retry_on_error=[RedisError],
)
return None
def parse_redis_url() -> str | dict | None:
def parse_redis_url() -> str | None:
"""Parse the REDIS_URL in config if applicable.
Returns:
@ -387,7 +389,7 @@ async def get_redis_status() -> dict[str, bool | None]:
redis_client.ping()
else:
status = None
except exceptions.RedisError:
except RedisError:
status = False
return {"redis": status}

View File

@ -127,7 +127,7 @@ class VarData:
state: str = "",
field_name: str = "",
imports: ImportDict | ParsedImportDict | None = None,
hooks: dict[str, None] | None = None,
hooks: dict[str, VarData | None] | None = None,
deps: list[Var] | None = None,
position: Hooks.HookPosition | None = None,
):
@ -194,7 +194,9 @@ class VarData:
(var_data.state for var_data in all_var_datas if var_data.state), ""
)
hooks = {hook: None for var_data in all_var_datas for hook in var_data.hooks}
hooks: dict[str, VarData | None] = {
hook: None for var_data in all_var_datas for hook in var_data.hooks
}
_imports = imports.merge_imports(
*(var_data.imports for var_data in all_var_datas)