[REF-2523] Implement new public Component API (#3203)

This commit is contained in:
Masen Furer 2024-05-01 14:46:27 -07:00 committed by GitHub
parent e31b458a69
commit 19d8f6c752
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
4 changed files with 348 additions and 0 deletions

View File

@ -167,6 +167,7 @@ _MAPPING = {
"reflex.style": ["style", "toggle_color_mode"],
"reflex.testing": ["testing"],
"reflex.utils": ["utils"],
"reflex.utils.imports": ["ImportVar"],
"reflex.vars": ["vars", "cached_var", "Var"],
}

View File

@ -150,6 +150,7 @@ from reflex import style as style
from reflex.style import toggle_color_mode as toggle_color_mode
from reflex import testing as testing
from reflex import utils as utils
from reflex.utils.imports import ImportVar as ImportVar
from reflex import vars as vars
from reflex.vars import cached_var as cached_var
from reflex.vars import Var as Var

View File

@ -213,6 +213,91 @@ class Component(BaseComponent, ABC):
# State class associated with this component instance
State: Optional[Type[reflex.state.State]] = None
def add_imports(self) -> dict[str, str | ImportVar | list[str | ImportVar]]:
"""Add imports for the component.
This method should be implemented by subclasses to add new imports for the component.
Implementations do NOT need to call super(). The result of calling
add_imports in each parent class will be merged internally.
Returns:
The additional imports for this component subclass.
The format of the return value is a dictionary where the keys are the
library names (with optional npm-style version specifications) mapping
to a single name to be imported, or a list names to be imported.
For advanced use cases, the values can be ImportVar instances (for
example, to provide an alias or mark that an import is the default
export from the given library).
```python
return {
"react": "useEffect",
"react-draggable": ["DraggableCore", rx.ImportVar(tag="Draggable", is_default=True)],
}
```
"""
return {}
def add_hooks(self) -> list[str]:
"""Add hooks inside the component function.
Hooks are pieces of literal Javascript code that is inserted inside the
React component function.
Each logical hook should be a separate string in the list.
Common strings will be deduplicated and inserted into the component
function only once, so define const variables and other identical code
in their own strings to avoid defining the same const or hook multiple
times.
If a hook depends on specific data from the component instance, be sure
to use unique values inside the string to _avoid_ deduplication.
Implementations do NOT need to call super(). The result of calling
add_hooks in each parent class will be merged and deduplicated internally.
Returns:
The additional hooks for this component subclass.
```python
return [
"const [count, setCount] = useState(0);",
"useEffect(() => { setCount((prev) => prev + 1); console.log(`mounted ${count} times`); }, []);",
]
```
"""
return []
def add_custom_code(self) -> list[str]:
"""Add custom Javascript code into the page that contains this component.
Custom code is inserted at module level, after any imports.
Each string of custom code is deduplicated per-page, so take care to
avoid defining the same const or function differently from different
component instances.
Custom code is useful for defining global functions or constants which
can then be referenced inside hooks or used by component vars.
Implementations do NOT need to call super(). The result of calling
add_custom_code in each parent class will be merged and deduplicated internally.
Returns:
The additional custom code for this component subclass.
```python
return [
"const translatePoints = (event) => { return { x: event.clientX, y: event.clientY }; };",
]
```
"""
return []
@classmethod
def __init_subclass__(cls, **kwargs):
"""Set default properties.
@ -949,6 +1034,30 @@ class Component(BaseComponent, ABC):
return True
return False
@classmethod
def _iter_parent_classes_with_method(cls, method: str) -> Iterator[Type[Component]]:
"""Iterate through parent classes that define a given method.
Used for handling the `add_*` API functions that internally simulate a super() call chain.
Args:
method: The method to look for.
Yields:
The parent classes that define the method (differently than the base).
"""
seen_methods = set([getattr(Component, method)])
for clz in cls.mro():
if clz is Component:
break
if not issubclass(clz, Component):
continue
method_func = getattr(clz, method, None)
if not callable(method_func) or method_func in seen_methods:
continue
seen_methods.add(method_func)
yield clz
def _get_custom_code(self) -> str | None:
"""Get custom code for the component.
@ -971,6 +1080,11 @@ class Component(BaseComponent, ABC):
if custom_code is not None:
code.add(custom_code)
# Add the custom code from add_custom_code method.
for clz in self._iter_parent_classes_with_method("add_custom_code"):
for item in clz.add_custom_code(self):
code.add(item)
# Add the custom code for the children.
for child in self.children:
code |= child._get_all_custom_code()
@ -1106,6 +1220,26 @@ class Component(BaseComponent, ABC):
var._var_data.imports for var in self._get_vars() if var._var_data
]
# If any subclass implements add_imports, merge the imports.
def _make_list(
value: str | ImportVar | list[str | ImportVar],
) -> list[str | ImportVar]:
if isinstance(value, (str, ImportVar)):
return [value]
return value
_added_import_dicts = []
for clz in self._iter_parent_classes_with_method("add_imports"):
_added_import_dicts.append(
{
package: [
ImportVar(tag=tag) if not isinstance(tag, ImportVar) else tag
for tag in _make_list(maybe_tags)
]
for package, maybe_tags in clz.add_imports(self).items()
}
)
return imports.merge_imports(
*self._get_props_imports(),
self._get_dependencies_imports(),
@ -1113,6 +1247,7 @@ class Component(BaseComponent, ABC):
_imports,
event_imports,
*var_imports,
*_added_import_dicts,
)
def _get_all_imports(self, collapse: bool = False) -> imports.ImportDict:
@ -1248,6 +1383,12 @@ class Component(BaseComponent, ABC):
if hooks is not None:
code[hooks] = None
# Add the hook code from add_hooks for each parent class (this is reversed to preserve
# the order of the hooks in the final output)
for clz in reversed(tuple(self._iter_parent_classes_with_method("add_hooks"))):
for hook in clz.add_hooks(self):
code[hook] = None
# Add the hook code for the children.
for child in self.children:
code = {**code, **child._get_all_hooks()}

View File

@ -1746,3 +1746,208 @@ def test_invalid_event_trigger():
with pytest.raises(ValueError):
trigger_comp(on_b=rx.console_log("log"))
@pytest.mark.parametrize(
"tags",
(
["Component"],
["Component", "useState"],
[ImportVar(tag="Component")],
[ImportVar(tag="Component"), ImportVar(tag="useState")],
["Component", ImportVar(tag="useState")],
),
)
def test_component_add_imports(tags):
def _list_to_import_vars(tags: List[str]) -> List[ImportVar]:
return [
ImportVar(tag=tag) if not isinstance(tag, ImportVar) else tag
for tag in tags
]
class BaseComponent(Component):
def _get_imports(self) -> imports.ImportDict:
return {}
class Reference(Component):
def _get_imports(self) -> imports.ImportDict:
return imports.merge_imports(
super()._get_imports(),
{"react": _list_to_import_vars(tags)},
{"foo": [ImportVar(tag="bar")]},
)
class TestBase(Component):
def add_imports(
self,
) -> Dict[str, Union[str, ImportVar, List[str], List[ImportVar]]]:
return {"foo": "bar"}
class Test(TestBase):
def add_imports(
self,
) -> Dict[str, Union[str, ImportVar, List[str], List[ImportVar]]]:
return {"react": (tags[0] if len(tags) == 1 else tags)}
baseline = Reference.create()
test = Test.create()
assert baseline._get_all_imports() == {
"react": _list_to_import_vars(tags),
"foo": [ImportVar(tag="bar")],
}
assert test._get_all_imports() == baseline._get_all_imports()
def test_component_add_hooks():
class BaseComponent(Component):
def _get_hooks(self):
return "const hook1 = 42"
class ChildComponent1(BaseComponent):
pass
class GrandchildComponent1(ChildComponent1):
def add_hooks(self):
return [
"const hook2 = 43",
"const hook3 = 44",
]
class GreatGrandchildComponent1(GrandchildComponent1):
def add_hooks(self):
return [
"const hook4 = 45",
]
class GrandchildComponent2(ChildComponent1):
def _get_hooks(self):
return "const hook5 = 46"
class GreatGrandchildComponent2(GrandchildComponent2):
def add_hooks(self):
return [
"const hook2 = 43",
"const hook6 = 47",
]
assert list(BaseComponent()._get_all_hooks()) == ["const hook1 = 42"]
assert list(ChildComponent1()._get_all_hooks()) == ["const hook1 = 42"]
assert list(GrandchildComponent1()._get_all_hooks()) == [
"const hook1 = 42",
"const hook2 = 43",
"const hook3 = 44",
]
assert list(GreatGrandchildComponent1()._get_all_hooks()) == [
"const hook1 = 42",
"const hook2 = 43",
"const hook3 = 44",
"const hook4 = 45",
]
assert list(GrandchildComponent2()._get_all_hooks()) == ["const hook5 = 46"]
assert list(GreatGrandchildComponent2()._get_all_hooks()) == [
"const hook5 = 46",
"const hook2 = 43",
"const hook6 = 47",
]
assert list(
BaseComponent.create(
GrandchildComponent1.create(GreatGrandchildComponent2()),
GreatGrandchildComponent1(),
)._get_all_hooks(),
) == [
"const hook1 = 42",
"const hook2 = 43",
"const hook3 = 44",
"const hook5 = 46",
"const hook6 = 47",
"const hook4 = 45",
]
assert list(
Fragment.create(
GreatGrandchildComponent2(),
GreatGrandchildComponent1(),
)._get_all_hooks()
) == [
"const hook5 = 46",
"const hook2 = 43",
"const hook6 = 47",
"const hook1 = 42",
"const hook3 = 44",
"const hook4 = 45",
]
def test_component_add_custom_code():
class BaseComponent(Component):
def _get_custom_code(self):
return "const custom_code1 = 42"
class ChildComponent1(BaseComponent):
pass
class GrandchildComponent1(ChildComponent1):
def add_custom_code(self):
return [
"const custom_code2 = 43",
"const custom_code3 = 44",
]
class GreatGrandchildComponent1(GrandchildComponent1):
def add_custom_code(self):
return [
"const custom_code4 = 45",
]
class GrandchildComponent2(ChildComponent1):
def _get_custom_code(self):
return "const custom_code5 = 46"
class GreatGrandchildComponent2(GrandchildComponent2):
def add_custom_code(self):
return [
"const custom_code2 = 43",
"const custom_code6 = 47",
]
assert BaseComponent()._get_all_custom_code() == {"const custom_code1 = 42"}
assert ChildComponent1()._get_all_custom_code() == {"const custom_code1 = 42"}
assert GrandchildComponent1()._get_all_custom_code() == {
"const custom_code1 = 42",
"const custom_code2 = 43",
"const custom_code3 = 44",
}
assert GreatGrandchildComponent1()._get_all_custom_code() == {
"const custom_code1 = 42",
"const custom_code2 = 43",
"const custom_code3 = 44",
"const custom_code4 = 45",
}
assert GrandchildComponent2()._get_all_custom_code() == {"const custom_code5 = 46"}
assert GreatGrandchildComponent2()._get_all_custom_code() == {
"const custom_code2 = 43",
"const custom_code5 = 46",
"const custom_code6 = 47",
}
assert BaseComponent.create(
GrandchildComponent1.create(GreatGrandchildComponent2()),
GreatGrandchildComponent1(),
)._get_all_custom_code() == {
"const custom_code1 = 42",
"const custom_code2 = 43",
"const custom_code3 = 44",
"const custom_code4 = 45",
"const custom_code5 = 46",
"const custom_code6 = 47",
}
assert Fragment.create(
GreatGrandchildComponent2(),
GreatGrandchildComponent1(),
)._get_all_custom_code() == {
"const custom_code1 = 42",
"const custom_code2 = 43",
"const custom_code3 = 44",
"const custom_code4 = 45",
"const custom_code5 = 46",
"const custom_code6 = 47",
}