[REF-2787] add_hooks supports Var-wrapped hooks (#3248)

* [REF-2787] add_hooks supports Var-wrapped hooks

* Fix VarData definition in .pyi file to allow removal of type ignore comments
* Var.create and Var.create_safe accept _var_data parameter
* Replace instances where a set of imports was being passed to VarData
* Update code throughout reduce use of `._replace` to add VarData

* Fixup: user hooks _var_data.imports will never be iterable, just a single ImportDict
This commit is contained in:
Masen Furer 2024-05-15 14:59:45 -07:00 committed by GitHub
parent d96baac7d9
commit c5f32db756
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
14 changed files with 158 additions and 63 deletions

View File

@ -241,7 +241,7 @@ class Component(BaseComponent, ABC):
""" """
return {} return {}
def add_hooks(self) -> list[str]: def add_hooks(self) -> list[str | Var]:
"""Add hooks inside the component function. """Add hooks inside the component function.
Hooks are pieces of literal Javascript code that is inserted inside the Hooks are pieces of literal Javascript code that is inserted inside the
@ -1265,11 +1265,20 @@ class Component(BaseComponent, ABC):
}, },
) )
other_imports = []
user_hooks = self._get_hooks() user_hooks = self._get_hooks()
if user_hooks is not None and isinstance(user_hooks, Var): if (
_imports = imports.merge_imports(_imports, user_hooks._var_data.imports) # type: ignore user_hooks is not None
and isinstance(user_hooks, Var)
and user_hooks._var_data is not None
and user_hooks._var_data.imports
):
other_imports.append(user_hooks._var_data.imports)
other_imports.extend(
hook_imports for hook_imports in self._get_added_hooks().values()
)
return _imports return imports.merge_imports(_imports, *other_imports)
def _get_imports(self) -> imports.ImportDict: def _get_imports(self) -> imports.ImportDict:
"""Get all the libraries and fields that are used by the component. """Get all the libraries and fields that are used by the component.
@ -1416,6 +1425,36 @@ class Component(BaseComponent, ABC):
**self._get_special_hooks(), **self._get_special_hooks(),
} }
def _get_added_hooks(self) -> dict[str, imports.ImportDict]:
"""Get the hooks added via `add_hooks` method.
Returns:
The deduplicated hooks and imports added by the component and parent components.
"""
code = {}
def extract_var_hooks(hook: Var):
_imports = {}
if hook._var_data is not None:
for sub_hook in hook._var_data.hooks:
code[sub_hook] = {}
if hook._var_data.imports:
_imports = hook._var_data.imports
if str(hook) in code:
code[str(hook)] = imports.merge_imports(code[str(hook)], _imports)
else:
code[str(hook)] = _imports
# Add the hook code from add_hooks for each parent class (this is reversed to preserve
# the order of the hooks in the final output)
for clz in reversed(tuple(self._iter_parent_classes_with_method("add_hooks"))):
for hook in clz.add_hooks(self):
if isinstance(hook, Var):
extract_var_hooks(hook)
else:
code[hook] = {}
return code
def _get_hooks(self) -> str | None: def _get_hooks(self) -> str | None:
"""Get the React hooks for this component. """Get the React hooks for this component.
@ -1454,11 +1493,7 @@ class Component(BaseComponent, ABC):
if hooks is not None: if hooks is not None:
code[hooks] = None code[hooks] = None
# Add the hook code from add_hooks for each parent class (this is reversed to preserve code.update(self._get_added_hooks())
# the order of the hooks in the final output)
for clz in reversed(tuple(self._iter_parent_classes_with_method("add_hooks"))):
for hook in clz.add_hooks(self):
code[hook] = None
# Add the hook code for the children. # Add the hook code for the children.
for child in self.children: for child in self.children:
@ -2092,8 +2127,8 @@ class StatefulComponent(BaseComponent):
var_deps.extend(cls._get_hook_deps(hook)) var_deps.extend(cls._get_hook_deps(hook))
memo_var_data = VarData.merge( memo_var_data = VarData.merge(
*[var._var_data for var in event_args], *[var._var_data for var in event_args],
VarData( # type: ignore VarData(
imports={"react": {ImportVar(tag="useCallback")}}, imports={"react": [ImportVar(tag="useCallback")]},
), ),
) )

View File

@ -29,23 +29,27 @@ connection_error: Var = Var.create_safe(
value="(connectErrors.length > 0) ? connectErrors[connectErrors.length - 1].message : ''", value="(connectErrors.length > 0) ? connectErrors[connectErrors.length - 1].message : ''",
_var_is_local=False, _var_is_local=False,
_var_is_string=False, _var_is_string=False,
)._replace(merge_var_data=connect_error_var_data) _var_data=connect_error_var_data,
)
connection_errors_count: Var = Var.create_safe( connection_errors_count: Var = Var.create_safe(
value="connectErrors.length", value="connectErrors.length",
_var_is_string=False, _var_is_string=False,
_var_is_local=False, _var_is_local=False,
)._replace(merge_var_data=connect_error_var_data) _var_data=connect_error_var_data,
)
has_connection_errors: Var = Var.create_safe( has_connection_errors: Var = Var.create_safe(
value="connectErrors.length > 0", value="connectErrors.length > 0",
_var_is_string=False, _var_is_string=False,
)._replace(_var_type=bool, merge_var_data=connect_error_var_data) _var_data=connect_error_var_data,
).to(bool)
has_too_many_connection_errors: Var = Var.create_safe( has_too_many_connection_errors: Var = Var.create_safe(
value="connectErrors.length >= 2", value="connectErrors.length >= 2",
_var_is_string=False, _var_is_string=False,
)._replace(_var_type=bool, merge_var_data=connect_error_var_data) _var_data=connect_error_var_data,
).to(bool)
class WebsocketTargetURL(Bare): class WebsocketTargetURL(Bare):

View File

@ -13,7 +13,7 @@ from reflex.utils import format, imports
from reflex.vars import BaseVar, Var, VarData from reflex.vars import BaseVar, Var, VarData
_IS_TRUE_IMPORT = { _IS_TRUE_IMPORT = {
f"/{Dirs.STATE_PATH}": {imports.ImportVar(tag="isTrue")}, f"/{Dirs.STATE_PATH}": [imports.ImportVar(tag="isTrue")],
} }

View File

@ -109,13 +109,11 @@ class DebounceInput(Component):
"{%s}" % (child.alias or child.tag), "{%s}" % (child.alias or child.tag),
_var_is_local=False, _var_is_local=False,
_var_is_string=False, _var_is_string=False,
)._replace( _var_data=VarData(
_var_type=Type[Component],
merge_var_data=VarData( # type: ignore
imports=child._get_imports(), imports=child._get_imports(),
hooks=child._get_hooks_internal(), hooks=child._get_hooks_internal(),
), ),
), ).to(Type[Component]),
) )
component = super().create(**props) component = super().create(**props)

View File

@ -24,12 +24,12 @@ from reflex.vars import BaseVar, CallableVar, Var, VarData
DEFAULT_UPLOAD_ID: str = "default" DEFAULT_UPLOAD_ID: str = "default"
upload_files_context_var_data: VarData = VarData( # type: ignore upload_files_context_var_data: VarData = VarData(
imports={ imports={
"react": {imports.ImportVar(tag="useContext")}, "react": [imports.ImportVar(tag="useContext")],
f"/{Dirs.CONTEXTS_PATH}": { f"/{Dirs.CONTEXTS_PATH}": [
imports.ImportVar(tag="UploadFilesContext"), imports.ImportVar(tag="UploadFilesContext"),
}, ],
}, },
hooks={ hooks={
"const [filesById, setFilesById] = useContext(UploadFilesContext);": None, "const [filesById, setFilesById] = useContext(UploadFilesContext);": None,
@ -118,14 +118,13 @@ def get_upload_dir() -> Path:
uploaded_files_url_prefix: Var = Var.create_safe( uploaded_files_url_prefix: Var = Var.create_safe(
"${getBackendURL(env.UPLOAD)}" "${getBackendURL(env.UPLOAD)}",
)._replace( _var_data=VarData(
merge_var_data=VarData( # type: ignore
imports={ imports={
f"/{Dirs.STATE_PATH}": {imports.ImportVar(tag="getBackendURL")}, f"/{Dirs.STATE_PATH}": [imports.ImportVar(tag="getBackendURL")],
"/env.json": {imports.ImportVar(tag="env", is_default=True)}, "/env.json": [imports.ImportVar(tag="env", is_default=True)],
} }
) ),
) )

View File

@ -216,13 +216,17 @@ class Form(BaseHTML):
if ref.startswith("refs_"): if ref.startswith("refs_"):
ref_var = Var.create_safe(ref[:-3]).as_ref() ref_var = Var.create_safe(ref[:-3]).as_ref()
form_refs[ref[5:-3]] = Var.create_safe( form_refs[ref[5:-3]] = Var.create_safe(
f"getRefValues({str(ref_var)})", _var_is_local=False f"getRefValues({str(ref_var)})",
)._replace(merge_var_data=ref_var._var_data) _var_is_local=False,
_var_data=ref_var._var_data,
)
else: else:
ref_var = Var.create_safe(ref).as_ref() ref_var = Var.create_safe(ref).as_ref()
form_refs[ref[4:]] = Var.create_safe( form_refs[ref[4:]] = Var.create_safe(
f"getRefValue({str(ref_var)})", _var_is_local=False f"getRefValue({str(ref_var)})",
)._replace(merge_var_data=ref_var._var_data) _var_is_local=False,
_var_data=ref_var._var_data,
)
return form_refs return form_refs
def _get_vars(self, include_children: bool = True) -> Iterator[Var]: def _get_vars(self, include_children: bool = True) -> Iterator[Var]:
@ -619,14 +623,16 @@ class Textarea(BaseHTML):
on_key_down=Var.create_safe( on_key_down=Var.create_safe(
f"(e) => enterKeySubmitOnKeyDown(e, {self.enter_key_submit._var_name_unwrapped})", f"(e) => enterKeySubmitOnKeyDown(e, {self.enter_key_submit._var_name_unwrapped})",
_var_is_local=False, _var_is_local=False,
)._replace(merge_var_data=self.enter_key_submit._var_data), _var_data=self.enter_key_submit._var_data,
)
) )
if self.auto_height is not None: if self.auto_height is not None:
tag.add_props( tag.add_props(
on_input=Var.create_safe( on_input=Var.create_safe(
f"(e) => autoHeightOnInput(e, {self.auto_height._var_name_unwrapped})", f"(e) => autoHeightOnInput(e, {self.auto_height._var_name_unwrapped})",
_var_is_local=False, _var_is_local=False,
)._replace(merge_var_data=self.auto_height._var_data), _var_data=self.auto_height._var_data,
)
) )
return tag return tag

View File

@ -114,12 +114,14 @@ class DataTable(Gridjs):
_var_name=f"{self.data._var_name}.columns", _var_name=f"{self.data._var_name}.columns",
_var_type=List[Any], _var_type=List[Any],
_var_full_name_needs_state_prefix=True, _var_full_name_needs_state_prefix=True,
)._replace(merge_var_data=self.data._var_data) _var_data=self.data._var_data,
)
self.data = BaseVar( self.data = BaseVar(
_var_name=f"{self.data._var_name}.data", _var_name=f"{self.data._var_name}.data",
_var_type=List[List[Any]], _var_type=List[List[Any]],
_var_full_name_needs_state_prefix=True, _var_full_name_needs_state_prefix=True,
)._replace(merge_var_data=self.data._var_data) _var_data=self.data._var_data,
)
if types.is_dataframe(type(self.data)): if types.is_dataframe(type(self.data)):
# If given a pandas df break up the data and columns # If given a pandas df break up the data and columns
data = serialize(self.data) data = serialize(self.data)

View File

@ -68,7 +68,7 @@ class TabsTrigger(RadixThemesComponent):
_valid_parents: List[str] = ["TabsList"] _valid_parents: List[str] = ["TabsList"]
@classmethod @classmethod
def create(self, *children, **props) -> Component: def create(cls, *children, **props) -> Component:
"""Create a TabsTrigger component. """Create a TabsTrigger component.
Args: Args:

View File

@ -162,7 +162,7 @@ class ToastProps(PropsBase):
class Toaster(Component): class Toaster(Component):
"""A Toaster Component for displaying toast notifications.""" """A Toaster Component for displaying toast notifications."""
library = "sonner@1.4.41" library: str = "sonner@1.4.41"
tag = "Toaster" tag = "Toaster"
@ -209,12 +209,15 @@ class Toaster(Component):
pause_when_page_is_hidden: Var[bool] pause_when_page_is_hidden: Var[bool]
def _get_hooks(self) -> Var[str]: def _get_hooks(self) -> Var[str]:
hook = Var.create_safe(f"{toast_ref} = toast", _var_is_local=True) hook = Var.create_safe(
hook._var_data = VarData( # type: ignore f"{toast_ref} = toast",
imports={ _var_is_local=True,
"/utils/state": [ImportVar(tag="refs")], _var_data=VarData(
self.library: [ImportVar(tag="toast", install=False)], imports={
} "/utils/state": [ImportVar(tag="refs")],
self.library: [ImportVar(tag="toast", install=False)],
}
),
) )
return hook return hook

View File

@ -103,9 +103,9 @@ class Imports(SimpleNamespace):
"""Common sets of import vars.""" """Common sets of import vars."""
EVENTS = { EVENTS = {
"react": {ImportVar(tag="useContext")}, "react": [ImportVar(tag="useContext")],
f"/{Dirs.CONTEXTS_PATH}": {ImportVar(tag="EventLoopContext")}, f"/{Dirs.CONTEXTS_PATH}": [ImportVar(tag="EventLoopContext")],
f"/{Dirs.STATE_PATH}": {ImportVar(tag=CompileVars.TO_EVENT)}, f"/{Dirs.STATE_PATH}": [ImportVar(tag=CompileVars.TO_EVENT)],
} }

View File

@ -16,10 +16,10 @@ LIGHT_COLOR_MODE: str = "light"
DARK_COLOR_MODE: str = "dark" DARK_COLOR_MODE: str = "dark"
# Reference the global ColorModeContext # Reference the global ColorModeContext
color_mode_var_data = VarData( # type: ignore color_mode_var_data = VarData(
imports={ imports={
f"/{constants.Dirs.CONTEXTS_PATH}": {ImportVar(tag="ColorModeContext")}, f"/{constants.Dirs.CONTEXTS_PATH}": [ImportVar(tag="ColorModeContext")],
"react": {ImportVar(tag="useContext")}, "react": [ImportVar(tag="useContext")],
}, },
hooks={ hooks={
f"const [ {constants.ColorMode.NAME}, {constants.ColorMode.TOGGLE} ] = useContext(ColorModeContext)": None, f"const [ {constants.ColorMode.NAME}, {constants.ColorMode.TOGGLE} ] = useContext(ColorModeContext)": None,

View File

@ -341,7 +341,11 @@ class Var:
@classmethod @classmethod
def create( def create(
cls, value: Any, _var_is_local: bool = True, _var_is_string: bool = False cls,
value: Any,
_var_is_local: bool = True,
_var_is_string: bool = False,
_var_data: Optional[VarData] = None,
) -> Var | None: ) -> Var | None:
"""Create a var from a value. """Create a var from a value.
@ -349,6 +353,7 @@ class Var:
value: The value to create the var from. value: The value to create the var from.
_var_is_local: Whether the var is local. _var_is_local: Whether the var is local.
_var_is_string: Whether the var is a string literal. _var_is_string: Whether the var is a string literal.
_var_data: Additional hooks and imports associated with the Var.
Returns: Returns:
The var. The var.
@ -365,9 +370,8 @@ class Var:
return value return value
# Try to pull the imports and hooks from contained values. # Try to pull the imports and hooks from contained values.
_var_data = None
if not isinstance(value, str): if not isinstance(value, str):
_var_data = VarData.merge(*_extract_var_data(value)) _var_data = VarData.merge(*_extract_var_data(value), _var_data)
# Try to serialize the value. # Try to serialize the value.
type_ = type(value) type_ = type(value)
@ -388,7 +392,11 @@ class Var:
@classmethod @classmethod
def create_safe( def create_safe(
cls, value: Any, _var_is_local: bool = True, _var_is_string: bool = False cls,
value: Any,
_var_is_local: bool = True,
_var_is_string: bool = False,
_var_data: Optional[VarData] = None,
) -> Var: ) -> Var:
"""Create a var from a value, asserting that it is not None. """Create a var from a value, asserting that it is not None.
@ -396,6 +404,7 @@ class Var:
value: The value to create the var from. value: The value to create the var from.
_var_is_local: Whether the var is local. _var_is_local: Whether the var is local.
_var_is_string: Whether the var is a string literal. _var_is_string: Whether the var is a string literal.
_var_data: Additional hooks and imports associated with the Var.
Returns: Returns:
The var. The var.
@ -404,6 +413,7 @@ class Var:
value, value,
_var_is_local=_var_is_local, _var_is_local=_var_is_local,
_var_is_string=_var_is_string, _var_is_string=_var_is_string,
_var_data=_var_data,
) )
assert var is not None assert var is not None
return var return var

View File

@ -34,10 +34,10 @@ def _decode_var(value: str) -> tuple[VarData, str]: ...
def _extract_var_data(value: Iterable) -> list[VarData | None]: ... def _extract_var_data(value: Iterable) -> list[VarData | None]: ...
class VarData(Base): class VarData(Base):
state: str state: str = ""
imports: dict[str, set[ImportVar]] imports: dict[str, List[ImportVar]] = {}
hooks: Dict[str, None] hooks: Dict[str, None] = {}
interpolations: List[Tuple[int, int]] interpolations: List[Tuple[int, int]] = []
@classmethod @classmethod
def merge(cls, *others: VarData | None) -> VarData | None: ... def merge(cls, *others: VarData | None) -> VarData | None: ...
@ -50,11 +50,11 @@ class Var:
_var_data: VarData | None = None _var_data: VarData | None = None
@classmethod @classmethod
def create( def create(
cls, value: Any, _var_is_local: bool = False, _var_is_string: bool = False cls, value: Any, _var_is_local: bool = False, _var_is_string: bool = False, _var_data: VarData | None = None,
) -> Optional[Var]: ... ) -> Optional[Var]: ...
@classmethod @classmethod
def create_safe( def create_safe(
cls, value: Any, _var_is_local: bool = False, _var_is_string: bool = False cls, value: Any, _var_is_local: bool = False, _var_is_string: bool = False, _var_data: VarData | None = None,
) -> Var: ... ) -> Var: ...
@classmethod @classmethod
def __class_getitem__(cls, type_: Type) -> _GenericAlias: ... def __class_getitem__(cls, type_: Type) -> _GenericAlias: ...

View File

@ -1063,7 +1063,7 @@ def test_stateful_banner():
TEST_VAR = Var.create_safe("test")._replace( TEST_VAR = Var.create_safe("test")._replace(
merge_var_data=VarData( merge_var_data=VarData(
hooks={"useTest": None}, hooks={"useTest": None},
imports={"test": {ImportVar(tag="test")}}, imports={"test": [ImportVar(tag="test")]},
state="Test", state="Test",
interpolations=[], interpolations=[],
) )
@ -1953,6 +1953,44 @@ def test_component_add_custom_code():
} }
def test_component_add_hooks_var():
class HookComponent(Component):
def add_hooks(self):
return [
"const hook3 = useRef(null)",
"const hook1 = 42",
Var.create(
"useEffect(() => () => {}, [])",
_var_data=VarData(
hooks={
"const hook2 = 43": None,
"const hook3 = useRef(null)": None,
},
imports={"react": [ImportVar(tag="useEffect")]},
),
),
Var.create(
"const hook3 = useRef(null)",
_var_data=VarData(
imports={"react": [ImportVar(tag="useRef")]},
),
),
]
assert list(HookComponent()._get_all_hooks()) == [
"const hook3 = useRef(null)",
"const hook1 = 42",
"const hook2 = 43",
"useEffect(() => () => {}, [])",
]
imports = HookComponent()._get_all_imports()
assert len(imports) == 1
assert "react" in imports
assert len(imports["react"]) == 2
assert ImportVar(tag="useRef") in imports["react"]
assert ImportVar(tag="useEffect") in imports["react"]
def test_add_style_embedded_vars(test_state: BaseState): def test_add_style_embedded_vars(test_state: BaseState):
"""Test that add_style works with embedded vars when returning a plain dict. """Test that add_style works with embedded vars when returning a plain dict.