[REF-2787] add_hooks supports Var-wrapped hooks (#3248)
* [REF-2787] add_hooks supports Var-wrapped hooks * Fix VarData definition in .pyi file to allow removal of type ignore comments * Var.create and Var.create_safe accept _var_data parameter * Replace instances where a set of imports was being passed to VarData * Update code throughout reduce use of `._replace` to add VarData * Fixup: user hooks _var_data.imports will never be iterable, just a single ImportDict
This commit is contained in:
parent
d96baac7d9
commit
c5f32db756
@ -241,7 +241,7 @@ class Component(BaseComponent, ABC):
|
||||
"""
|
||||
return {}
|
||||
|
||||
def add_hooks(self) -> list[str]:
|
||||
def add_hooks(self) -> list[str | Var]:
|
||||
"""Add hooks inside the component function.
|
||||
|
||||
Hooks are pieces of literal Javascript code that is inserted inside the
|
||||
@ -1265,11 +1265,20 @@ class Component(BaseComponent, ABC):
|
||||
},
|
||||
)
|
||||
|
||||
other_imports = []
|
||||
user_hooks = self._get_hooks()
|
||||
if user_hooks is not None and isinstance(user_hooks, Var):
|
||||
_imports = imports.merge_imports(_imports, user_hooks._var_data.imports) # type: ignore
|
||||
if (
|
||||
user_hooks is not None
|
||||
and isinstance(user_hooks, Var)
|
||||
and user_hooks._var_data is not None
|
||||
and user_hooks._var_data.imports
|
||||
):
|
||||
other_imports.append(user_hooks._var_data.imports)
|
||||
other_imports.extend(
|
||||
hook_imports for hook_imports in self._get_added_hooks().values()
|
||||
)
|
||||
|
||||
return _imports
|
||||
return imports.merge_imports(_imports, *other_imports)
|
||||
|
||||
def _get_imports(self) -> imports.ImportDict:
|
||||
"""Get all the libraries and fields that are used by the component.
|
||||
@ -1416,6 +1425,36 @@ class Component(BaseComponent, ABC):
|
||||
**self._get_special_hooks(),
|
||||
}
|
||||
|
||||
def _get_added_hooks(self) -> dict[str, imports.ImportDict]:
|
||||
"""Get the hooks added via `add_hooks` method.
|
||||
|
||||
Returns:
|
||||
The deduplicated hooks and imports added by the component and parent components.
|
||||
"""
|
||||
code = {}
|
||||
|
||||
def extract_var_hooks(hook: Var):
|
||||
_imports = {}
|
||||
if hook._var_data is not None:
|
||||
for sub_hook in hook._var_data.hooks:
|
||||
code[sub_hook] = {}
|
||||
if hook._var_data.imports:
|
||||
_imports = hook._var_data.imports
|
||||
if str(hook) in code:
|
||||
code[str(hook)] = imports.merge_imports(code[str(hook)], _imports)
|
||||
else:
|
||||
code[str(hook)] = _imports
|
||||
|
||||
# Add the hook code from add_hooks for each parent class (this is reversed to preserve
|
||||
# the order of the hooks in the final output)
|
||||
for clz in reversed(tuple(self._iter_parent_classes_with_method("add_hooks"))):
|
||||
for hook in clz.add_hooks(self):
|
||||
if isinstance(hook, Var):
|
||||
extract_var_hooks(hook)
|
||||
else:
|
||||
code[hook] = {}
|
||||
return code
|
||||
|
||||
def _get_hooks(self) -> str | None:
|
||||
"""Get the React hooks for this component.
|
||||
|
||||
@ -1454,11 +1493,7 @@ class Component(BaseComponent, ABC):
|
||||
if hooks is not None:
|
||||
code[hooks] = None
|
||||
|
||||
# Add the hook code from add_hooks for each parent class (this is reversed to preserve
|
||||
# the order of the hooks in the final output)
|
||||
for clz in reversed(tuple(self._iter_parent_classes_with_method("add_hooks"))):
|
||||
for hook in clz.add_hooks(self):
|
||||
code[hook] = None
|
||||
code.update(self._get_added_hooks())
|
||||
|
||||
# Add the hook code for the children.
|
||||
for child in self.children:
|
||||
@ -2092,8 +2127,8 @@ class StatefulComponent(BaseComponent):
|
||||
var_deps.extend(cls._get_hook_deps(hook))
|
||||
memo_var_data = VarData.merge(
|
||||
*[var._var_data for var in event_args],
|
||||
VarData( # type: ignore
|
||||
imports={"react": {ImportVar(tag="useCallback")}},
|
||||
VarData(
|
||||
imports={"react": [ImportVar(tag="useCallback")]},
|
||||
),
|
||||
)
|
||||
|
||||
|
@ -29,23 +29,27 @@ connection_error: Var = Var.create_safe(
|
||||
value="(connectErrors.length > 0) ? connectErrors[connectErrors.length - 1].message : ''",
|
||||
_var_is_local=False,
|
||||
_var_is_string=False,
|
||||
)._replace(merge_var_data=connect_error_var_data)
|
||||
_var_data=connect_error_var_data,
|
||||
)
|
||||
|
||||
connection_errors_count: Var = Var.create_safe(
|
||||
value="connectErrors.length",
|
||||
_var_is_string=False,
|
||||
_var_is_local=False,
|
||||
)._replace(merge_var_data=connect_error_var_data)
|
||||
_var_data=connect_error_var_data,
|
||||
)
|
||||
|
||||
has_connection_errors: Var = Var.create_safe(
|
||||
value="connectErrors.length > 0",
|
||||
_var_is_string=False,
|
||||
)._replace(_var_type=bool, merge_var_data=connect_error_var_data)
|
||||
_var_data=connect_error_var_data,
|
||||
).to(bool)
|
||||
|
||||
has_too_many_connection_errors: Var = Var.create_safe(
|
||||
value="connectErrors.length >= 2",
|
||||
_var_is_string=False,
|
||||
)._replace(_var_type=bool, merge_var_data=connect_error_var_data)
|
||||
_var_data=connect_error_var_data,
|
||||
).to(bool)
|
||||
|
||||
|
||||
class WebsocketTargetURL(Bare):
|
||||
|
@ -13,7 +13,7 @@ from reflex.utils import format, imports
|
||||
from reflex.vars import BaseVar, Var, VarData
|
||||
|
||||
_IS_TRUE_IMPORT = {
|
||||
f"/{Dirs.STATE_PATH}": {imports.ImportVar(tag="isTrue")},
|
||||
f"/{Dirs.STATE_PATH}": [imports.ImportVar(tag="isTrue")],
|
||||
}
|
||||
|
||||
|
||||
|
@ -109,13 +109,11 @@ class DebounceInput(Component):
|
||||
"{%s}" % (child.alias or child.tag),
|
||||
_var_is_local=False,
|
||||
_var_is_string=False,
|
||||
)._replace(
|
||||
_var_type=Type[Component],
|
||||
merge_var_data=VarData( # type: ignore
|
||||
_var_data=VarData(
|
||||
imports=child._get_imports(),
|
||||
hooks=child._get_hooks_internal(),
|
||||
),
|
||||
),
|
||||
).to(Type[Component]),
|
||||
)
|
||||
|
||||
component = super().create(**props)
|
||||
|
@ -24,12 +24,12 @@ from reflex.vars import BaseVar, CallableVar, Var, VarData
|
||||
|
||||
DEFAULT_UPLOAD_ID: str = "default"
|
||||
|
||||
upload_files_context_var_data: VarData = VarData( # type: ignore
|
||||
upload_files_context_var_data: VarData = VarData(
|
||||
imports={
|
||||
"react": {imports.ImportVar(tag="useContext")},
|
||||
f"/{Dirs.CONTEXTS_PATH}": {
|
||||
"react": [imports.ImportVar(tag="useContext")],
|
||||
f"/{Dirs.CONTEXTS_PATH}": [
|
||||
imports.ImportVar(tag="UploadFilesContext"),
|
||||
},
|
||||
],
|
||||
},
|
||||
hooks={
|
||||
"const [filesById, setFilesById] = useContext(UploadFilesContext);": None,
|
||||
@ -118,14 +118,13 @@ def get_upload_dir() -> Path:
|
||||
|
||||
|
||||
uploaded_files_url_prefix: Var = Var.create_safe(
|
||||
"${getBackendURL(env.UPLOAD)}"
|
||||
)._replace(
|
||||
merge_var_data=VarData( # type: ignore
|
||||
"${getBackendURL(env.UPLOAD)}",
|
||||
_var_data=VarData(
|
||||
imports={
|
||||
f"/{Dirs.STATE_PATH}": {imports.ImportVar(tag="getBackendURL")},
|
||||
"/env.json": {imports.ImportVar(tag="env", is_default=True)},
|
||||
f"/{Dirs.STATE_PATH}": [imports.ImportVar(tag="getBackendURL")],
|
||||
"/env.json": [imports.ImportVar(tag="env", is_default=True)],
|
||||
}
|
||||
)
|
||||
),
|
||||
)
|
||||
|
||||
|
||||
|
@ -216,13 +216,17 @@ class Form(BaseHTML):
|
||||
if ref.startswith("refs_"):
|
||||
ref_var = Var.create_safe(ref[:-3]).as_ref()
|
||||
form_refs[ref[5:-3]] = Var.create_safe(
|
||||
f"getRefValues({str(ref_var)})", _var_is_local=False
|
||||
)._replace(merge_var_data=ref_var._var_data)
|
||||
f"getRefValues({str(ref_var)})",
|
||||
_var_is_local=False,
|
||||
_var_data=ref_var._var_data,
|
||||
)
|
||||
else:
|
||||
ref_var = Var.create_safe(ref).as_ref()
|
||||
form_refs[ref[4:]] = Var.create_safe(
|
||||
f"getRefValue({str(ref_var)})", _var_is_local=False
|
||||
)._replace(merge_var_data=ref_var._var_data)
|
||||
f"getRefValue({str(ref_var)})",
|
||||
_var_is_local=False,
|
||||
_var_data=ref_var._var_data,
|
||||
)
|
||||
return form_refs
|
||||
|
||||
def _get_vars(self, include_children: bool = True) -> Iterator[Var]:
|
||||
@ -619,14 +623,16 @@ class Textarea(BaseHTML):
|
||||
on_key_down=Var.create_safe(
|
||||
f"(e) => enterKeySubmitOnKeyDown(e, {self.enter_key_submit._var_name_unwrapped})",
|
||||
_var_is_local=False,
|
||||
)._replace(merge_var_data=self.enter_key_submit._var_data),
|
||||
_var_data=self.enter_key_submit._var_data,
|
||||
)
|
||||
)
|
||||
if self.auto_height is not None:
|
||||
tag.add_props(
|
||||
on_input=Var.create_safe(
|
||||
f"(e) => autoHeightOnInput(e, {self.auto_height._var_name_unwrapped})",
|
||||
_var_is_local=False,
|
||||
)._replace(merge_var_data=self.auto_height._var_data),
|
||||
_var_data=self.auto_height._var_data,
|
||||
)
|
||||
)
|
||||
return tag
|
||||
|
||||
|
@ -114,12 +114,14 @@ class DataTable(Gridjs):
|
||||
_var_name=f"{self.data._var_name}.columns",
|
||||
_var_type=List[Any],
|
||||
_var_full_name_needs_state_prefix=True,
|
||||
)._replace(merge_var_data=self.data._var_data)
|
||||
_var_data=self.data._var_data,
|
||||
)
|
||||
self.data = BaseVar(
|
||||
_var_name=f"{self.data._var_name}.data",
|
||||
_var_type=List[List[Any]],
|
||||
_var_full_name_needs_state_prefix=True,
|
||||
)._replace(merge_var_data=self.data._var_data)
|
||||
_var_data=self.data._var_data,
|
||||
)
|
||||
if types.is_dataframe(type(self.data)):
|
||||
# If given a pandas df break up the data and columns
|
||||
data = serialize(self.data)
|
||||
|
@ -68,7 +68,7 @@ class TabsTrigger(RadixThemesComponent):
|
||||
_valid_parents: List[str] = ["TabsList"]
|
||||
|
||||
@classmethod
|
||||
def create(self, *children, **props) -> Component:
|
||||
def create(cls, *children, **props) -> Component:
|
||||
"""Create a TabsTrigger component.
|
||||
|
||||
Args:
|
||||
|
@ -162,7 +162,7 @@ class ToastProps(PropsBase):
|
||||
class Toaster(Component):
|
||||
"""A Toaster Component for displaying toast notifications."""
|
||||
|
||||
library = "sonner@1.4.41"
|
||||
library: str = "sonner@1.4.41"
|
||||
|
||||
tag = "Toaster"
|
||||
|
||||
@ -209,12 +209,15 @@ class Toaster(Component):
|
||||
pause_when_page_is_hidden: Var[bool]
|
||||
|
||||
def _get_hooks(self) -> Var[str]:
|
||||
hook = Var.create_safe(f"{toast_ref} = toast", _var_is_local=True)
|
||||
hook._var_data = VarData( # type: ignore
|
||||
hook = Var.create_safe(
|
||||
f"{toast_ref} = toast",
|
||||
_var_is_local=True,
|
||||
_var_data=VarData(
|
||||
imports={
|
||||
"/utils/state": [ImportVar(tag="refs")],
|
||||
self.library: [ImportVar(tag="toast", install=False)],
|
||||
}
|
||||
),
|
||||
)
|
||||
return hook
|
||||
|
||||
|
@ -103,9 +103,9 @@ class Imports(SimpleNamespace):
|
||||
"""Common sets of import vars."""
|
||||
|
||||
EVENTS = {
|
||||
"react": {ImportVar(tag="useContext")},
|
||||
f"/{Dirs.CONTEXTS_PATH}": {ImportVar(tag="EventLoopContext")},
|
||||
f"/{Dirs.STATE_PATH}": {ImportVar(tag=CompileVars.TO_EVENT)},
|
||||
"react": [ImportVar(tag="useContext")],
|
||||
f"/{Dirs.CONTEXTS_PATH}": [ImportVar(tag="EventLoopContext")],
|
||||
f"/{Dirs.STATE_PATH}": [ImportVar(tag=CompileVars.TO_EVENT)],
|
||||
}
|
||||
|
||||
|
||||
|
@ -16,10 +16,10 @@ LIGHT_COLOR_MODE: str = "light"
|
||||
DARK_COLOR_MODE: str = "dark"
|
||||
|
||||
# Reference the global ColorModeContext
|
||||
color_mode_var_data = VarData( # type: ignore
|
||||
color_mode_var_data = VarData(
|
||||
imports={
|
||||
f"/{constants.Dirs.CONTEXTS_PATH}": {ImportVar(tag="ColorModeContext")},
|
||||
"react": {ImportVar(tag="useContext")},
|
||||
f"/{constants.Dirs.CONTEXTS_PATH}": [ImportVar(tag="ColorModeContext")],
|
||||
"react": [ImportVar(tag="useContext")],
|
||||
},
|
||||
hooks={
|
||||
f"const [ {constants.ColorMode.NAME}, {constants.ColorMode.TOGGLE} ] = useContext(ColorModeContext)": None,
|
||||
|
@ -341,7 +341,11 @@ class Var:
|
||||
|
||||
@classmethod
|
||||
def create(
|
||||
cls, value: Any, _var_is_local: bool = True, _var_is_string: bool = False
|
||||
cls,
|
||||
value: Any,
|
||||
_var_is_local: bool = True,
|
||||
_var_is_string: bool = False,
|
||||
_var_data: Optional[VarData] = None,
|
||||
) -> Var | None:
|
||||
"""Create a var from a value.
|
||||
|
||||
@ -349,6 +353,7 @@ class Var:
|
||||
value: The value to create the var from.
|
||||
_var_is_local: Whether the var is local.
|
||||
_var_is_string: Whether the var is a string literal.
|
||||
_var_data: Additional hooks and imports associated with the Var.
|
||||
|
||||
Returns:
|
||||
The var.
|
||||
@ -365,9 +370,8 @@ class Var:
|
||||
return value
|
||||
|
||||
# Try to pull the imports and hooks from contained values.
|
||||
_var_data = None
|
||||
if not isinstance(value, str):
|
||||
_var_data = VarData.merge(*_extract_var_data(value))
|
||||
_var_data = VarData.merge(*_extract_var_data(value), _var_data)
|
||||
|
||||
# Try to serialize the value.
|
||||
type_ = type(value)
|
||||
@ -388,7 +392,11 @@ class Var:
|
||||
|
||||
@classmethod
|
||||
def create_safe(
|
||||
cls, value: Any, _var_is_local: bool = True, _var_is_string: bool = False
|
||||
cls,
|
||||
value: Any,
|
||||
_var_is_local: bool = True,
|
||||
_var_is_string: bool = False,
|
||||
_var_data: Optional[VarData] = None,
|
||||
) -> Var:
|
||||
"""Create a var from a value, asserting that it is not None.
|
||||
|
||||
@ -396,6 +404,7 @@ class Var:
|
||||
value: The value to create the var from.
|
||||
_var_is_local: Whether the var is local.
|
||||
_var_is_string: Whether the var is a string literal.
|
||||
_var_data: Additional hooks and imports associated with the Var.
|
||||
|
||||
Returns:
|
||||
The var.
|
||||
@ -404,6 +413,7 @@ class Var:
|
||||
value,
|
||||
_var_is_local=_var_is_local,
|
||||
_var_is_string=_var_is_string,
|
||||
_var_data=_var_data,
|
||||
)
|
||||
assert var is not None
|
||||
return var
|
||||
|
@ -34,10 +34,10 @@ def _decode_var(value: str) -> tuple[VarData, str]: ...
|
||||
def _extract_var_data(value: Iterable) -> list[VarData | None]: ...
|
||||
|
||||
class VarData(Base):
|
||||
state: str
|
||||
imports: dict[str, set[ImportVar]]
|
||||
hooks: Dict[str, None]
|
||||
interpolations: List[Tuple[int, int]]
|
||||
state: str = ""
|
||||
imports: dict[str, List[ImportVar]] = {}
|
||||
hooks: Dict[str, None] = {}
|
||||
interpolations: List[Tuple[int, int]] = []
|
||||
@classmethod
|
||||
def merge(cls, *others: VarData | None) -> VarData | None: ...
|
||||
|
||||
@ -50,11 +50,11 @@ class Var:
|
||||
_var_data: VarData | None = None
|
||||
@classmethod
|
||||
def create(
|
||||
cls, value: Any, _var_is_local: bool = False, _var_is_string: bool = False
|
||||
cls, value: Any, _var_is_local: bool = False, _var_is_string: bool = False, _var_data: VarData | None = None,
|
||||
) -> Optional[Var]: ...
|
||||
@classmethod
|
||||
def create_safe(
|
||||
cls, value: Any, _var_is_local: bool = False, _var_is_string: bool = False
|
||||
cls, value: Any, _var_is_local: bool = False, _var_is_string: bool = False, _var_data: VarData | None = None,
|
||||
) -> Var: ...
|
||||
@classmethod
|
||||
def __class_getitem__(cls, type_: Type) -> _GenericAlias: ...
|
||||
|
@ -1063,7 +1063,7 @@ def test_stateful_banner():
|
||||
TEST_VAR = Var.create_safe("test")._replace(
|
||||
merge_var_data=VarData(
|
||||
hooks={"useTest": None},
|
||||
imports={"test": {ImportVar(tag="test")}},
|
||||
imports={"test": [ImportVar(tag="test")]},
|
||||
state="Test",
|
||||
interpolations=[],
|
||||
)
|
||||
@ -1953,6 +1953,44 @@ def test_component_add_custom_code():
|
||||
}
|
||||
|
||||
|
||||
def test_component_add_hooks_var():
|
||||
class HookComponent(Component):
|
||||
def add_hooks(self):
|
||||
return [
|
||||
"const hook3 = useRef(null)",
|
||||
"const hook1 = 42",
|
||||
Var.create(
|
||||
"useEffect(() => () => {}, [])",
|
||||
_var_data=VarData(
|
||||
hooks={
|
||||
"const hook2 = 43": None,
|
||||
"const hook3 = useRef(null)": None,
|
||||
},
|
||||
imports={"react": [ImportVar(tag="useEffect")]},
|
||||
),
|
||||
),
|
||||
Var.create(
|
||||
"const hook3 = useRef(null)",
|
||||
_var_data=VarData(
|
||||
imports={"react": [ImportVar(tag="useRef")]},
|
||||
),
|
||||
),
|
||||
]
|
||||
|
||||
assert list(HookComponent()._get_all_hooks()) == [
|
||||
"const hook3 = useRef(null)",
|
||||
"const hook1 = 42",
|
||||
"const hook2 = 43",
|
||||
"useEffect(() => () => {}, [])",
|
||||
]
|
||||
imports = HookComponent()._get_all_imports()
|
||||
assert len(imports) == 1
|
||||
assert "react" in imports
|
||||
assert len(imports["react"]) == 2
|
||||
assert ImportVar(tag="useRef") in imports["react"]
|
||||
assert ImportVar(tag="useEffect") in imports["react"]
|
||||
|
||||
|
||||
def test_add_style_embedded_vars(test_state: BaseState):
|
||||
"""Test that add_style works with embedded vars when returning a plain dict.
|
||||
|
||||
|
Loading…
Reference in New Issue
Block a user