[ENG-4010]Codeblock cleanup in markdown (#4233)

* Codeblock cleanup in markdown

* Initial approach to getting this working with rx.memo and reflex web

* abstract the map var logic

* the tests are not valid + pyright fix

* darglint fix

* Add unit tests plus mix components

* pyi run

* rebase on main

* fix darglint

* testing different OS

* revert

* This should fix it. Right?

* Fix tests

* minor fn signature fix

* use ArgsFunctionOperation

* use destructured args and pass the tests

* fix remaining unit tests

* fix pyi files

* rebase on main

* move language regex on codeblock to markdown

* fix tests

---------

Co-authored-by: Khaleel Al-Adhami <khaleel.aladhami@gmail.com>
This commit is contained in:
Elijah Ahianyo 2024-11-08 03:18:14 +00:00 committed by GitHub
parent 3d85936009
commit cd59ab5406
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
27 changed files with 565 additions and 164 deletions

View File

@ -8,20 +8,6 @@
{% endfor %}
export const {{component.name}} = memo(({ {{-component.props|join(", ")-}} }) => {
{% if component.name == "CodeBlock" and "language" in component.props %}
if (language) {
(async () => {
try {
const module = await import(`react-syntax-highlighter/dist/cjs/languages/prism/${language}`);
SyntaxHighlighter.registerLanguage(language, module.default);
} catch (error) {
console.error(`Error importing language module for ${language}:`, error);
}
})();
}
{% endif %}
{% for hook in component.hooks %}
{{ hook }}
{% endfor %}

View File

@ -8,13 +8,14 @@ from typing import ClassVar, Dict, Literal, Optional, Union
from reflex.components.component import Component, ComponentNamespace
from reflex.components.core.cond import color_mode_cond
from reflex.components.lucide.icon import Icon
from reflex.components.markdown.markdown import _LANGUAGE, MarkdownComponentMap
from reflex.components.radix.themes.components.button import Button
from reflex.components.radix.themes.layout.box import Box
from reflex.constants.colors import Color
from reflex.event import set_clipboard
from reflex.style import Style
from reflex.utils import console, format
from reflex.utils.imports import ImportDict, ImportVar
from reflex.utils.imports import ImportVar
from reflex.vars.base import LiteralVar, Var, VarData
LiteralCodeLanguage = Literal[
@ -378,7 +379,7 @@ for theme_name in dir(Theme):
setattr(Theme, theme_name, getattr(Theme, theme_name)._replace(_var_type=Theme))
class CodeBlock(Component):
class CodeBlock(Component, MarkdownComponentMap):
"""A code block."""
library = "react-syntax-highlighter@15.6.1"
@ -417,39 +418,6 @@ class CodeBlock(Component):
# A custom copy button to override the default one.
copy_button: Optional[Union[bool, Component]] = None
def add_imports(self) -> ImportDict:
"""Add imports for the CodeBlock component.
Returns:
The import dict.
"""
imports_: ImportDict = {}
if (
self.language is not None
and (language_without_quotes := str(self.language).replace('"', ""))
in LiteralCodeLanguage.__args__ # type: ignore
):
imports_[
f"react-syntax-highlighter/dist/cjs/languages/prism/{language_without_quotes}"
] = [
ImportVar(
tag=format.to_camel_case(language_without_quotes),
is_default=True,
install=False,
)
]
return imports_
def _get_custom_code(self) -> Optional[str]:
if (
self.language is not None
and (language_without_quotes := str(self.language).replace('"', ""))
in LiteralCodeLanguage.__args__ # type: ignore
):
return f"{self.alias}.registerLanguage('{language_without_quotes}', {format.to_camel_case(language_without_quotes)})"
@classmethod
def create(
cls,
@ -534,8 +502,8 @@ class CodeBlock(Component):
theme = self.theme
out.add_props(style=theme).remove_props("theme", "code").add_props(
children=self.code
out.add_props(style=theme).remove_props("theme", "code", "language").add_props(
children=self.code, language=_LANGUAGE
)
return out
@ -543,6 +511,46 @@ class CodeBlock(Component):
def _exclude_props(self) -> list[str]:
return ["can_copy", "copy_button"]
@classmethod
def _get_language_registration_hook(cls) -> str:
"""Get the hook to register the language.
Returns:
The hook to register the language.
"""
return f"""
if ({str(_LANGUAGE)}) {{
(async () => {{
try {{
const module = await import(`react-syntax-highlighter/dist/cjs/languages/prism/${{{str(_LANGUAGE)}}}`);
SyntaxHighlighter.registerLanguage({str(_LANGUAGE)}, module.default);
}} catch (error) {{
console.error(`Error importing language module for ${{{str(_LANGUAGE)}}}:`, error);
}}
}})();
}}
"""
@classmethod
def get_component_map_custom_code(cls) -> str:
"""Get the custom code for the component.
Returns:
The custom code for the component.
"""
return cls._get_language_registration_hook()
def add_hooks(self) -> list[str | Var]:
"""Add hooks for the component.
Returns:
The hooks for the component.
"""
return [
f"const {str(_LANGUAGE)} = {str(self.language)}",
self._get_language_registration_hook(),
]
class CodeblockNamespace(ComponentNamespace):
"""Namespace for the CodeBlock component."""

View File

@ -7,10 +7,10 @@ import dataclasses
from typing import Any, ClassVar, Dict, Literal, Optional, Union, overload
from reflex.components.component import Component, ComponentNamespace
from reflex.components.markdown.markdown import MarkdownComponentMap
from reflex.constants.colors import Color
from reflex.event import BASE_STATE, EventType
from reflex.style import Style
from reflex.utils.imports import ImportDict
from reflex.vars.base import Var
LiteralCodeLanguage = Literal[
@ -349,8 +349,7 @@ for theme_name in dir(Theme):
continue
setattr(Theme, theme_name, getattr(Theme, theme_name)._replace(_var_type=Theme))
class CodeBlock(Component):
def add_imports(self) -> ImportDict: ...
class CodeBlock(Component, MarkdownComponentMap):
@overload
@classmethod
def create( # type: ignore
@ -984,6 +983,9 @@ class CodeBlock(Component):
...
def add_style(self): ...
@classmethod
def get_component_map_custom_code(cls) -> str: ...
def add_hooks(self) -> list[str | Var]: ...
class CodeblockNamespace(ComponentNamespace):
themes = Theme

View File

@ -12,6 +12,7 @@ from reflex.components.core.colors import color
from reflex.components.core.cond import color_mode_cond
from reflex.components.el.elements.forms import Button
from reflex.components.lucide.icon import Icon
from reflex.components.markdown.markdown import MarkdownComponentMap
from reflex.components.props import NoExtrasAllowedProps
from reflex.components.radix.themes.layout.box import Box
from reflex.event import run_script, set_clipboard
@ -528,7 +529,7 @@ class ShikiJsTransformer(ShikiBaseTransformers):
super().__init__(**kwargs)
class ShikiCodeBlock(Component):
class ShikiCodeBlock(Component, MarkdownComponentMap):
"""A Code block."""
library = "/components/shiki/code"

View File

@ -7,6 +7,7 @@ from typing import Any, Dict, Literal, Optional, Union, overload
from reflex.base import Base
from reflex.components.component import Component, ComponentNamespace
from reflex.components.markdown.markdown import MarkdownComponentMap
from reflex.components.props import NoExtrasAllowedProps
from reflex.event import BASE_STATE, EventType
from reflex.style import Style
@ -350,7 +351,7 @@ class ShikiJsTransformer(ShikiBaseTransformers):
fns: list[FunctionStringVar]
style: Optional[Style]
class ShikiCodeBlock(Component):
class ShikiCodeBlock(Component, MarkdownComponentMap):
@overload
@classmethod
def create( # type: ignore

View File

@ -2,25 +2,18 @@
from __future__ import annotations
import dataclasses
import textwrap
from functools import lru_cache
from hashlib import md5
from typing import Any, Callable, Dict, Union
from typing import Any, Callable, Dict, Sequence, Union
from reflex.components.component import Component, CustomComponent
from reflex.components.radix.themes.layout.list import (
ListItem,
OrderedList,
UnorderedList,
)
from reflex.components.radix.themes.typography.heading import Heading
from reflex.components.radix.themes.typography.link import Link
from reflex.components.radix.themes.typography.text import Text
from reflex.components.tags.tag import Tag
from reflex.utils import types
from reflex.utils.imports import ImportDict, ImportVar
from reflex.vars.base import LiteralVar, Var
from reflex.vars.function import ARRAY_ISARRAY
from reflex.vars.function import ARRAY_ISARRAY, ArgsFunctionOperation, DestructuredArg
from reflex.vars.number import ternary_operation
# Special vars used in the component map.
@ -28,6 +21,7 @@ _CHILDREN = Var(_js_expr="children", _var_type=str)
_PROPS = Var(_js_expr="...props")
_PROPS_IN_TAG = Var(_js_expr="{...props}")
_MOCK_ARG = Var(_js_expr="", _var_type=str)
_LANGUAGE = Var(_js_expr="_language", _var_type=str)
# Special remark plugins.
_REMARK_MATH = Var(_js_expr="remarkMath")
@ -53,7 +47,15 @@ def get_base_component_map() -> dict[str, Callable]:
The base component map.
"""
from reflex.components.datadisplay.code import CodeBlock
from reflex.components.radix.themes.layout.list import (
ListItem,
OrderedList,
UnorderedList,
)
from reflex.components.radix.themes.typography.code import Code
from reflex.components.radix.themes.typography.heading import Heading
from reflex.components.radix.themes.typography.link import Link
from reflex.components.radix.themes.typography.text import Text
return {
"h1": lambda value: Heading.create(value, as_="h1", size="6", margin_y="0.5em"),
@ -74,6 +76,67 @@ def get_base_component_map() -> dict[str, Callable]:
}
@dataclasses.dataclass()
class MarkdownComponentMap:
"""Mixin class for handling custom component maps in Markdown components."""
_explicit_return: bool = dataclasses.field(default=False)
@classmethod
def get_component_map_custom_code(cls) -> str:
"""Get the custom code for the component map.
Returns:
The custom code for the component map.
"""
return ""
@classmethod
def create_map_fn_var(
cls,
fn_body: Var | None = None,
fn_args: Sequence[str] | None = None,
explicit_return: bool | None = None,
) -> Var:
"""Create a function Var for the component map.
Args:
fn_body: The formatted component as a string.
fn_args: The function arguments.
explicit_return: Whether to use explicit return syntax.
Returns:
The function Var for the component map.
"""
fn_args = fn_args or cls.get_fn_args()
fn_body = fn_body if fn_body is not None else cls.get_fn_body()
explicit_return = explicit_return or cls._explicit_return
return ArgsFunctionOperation.create(
args_names=(DestructuredArg(fields=tuple(fn_args)),),
return_expr=fn_body,
explicit_return=explicit_return,
)
@classmethod
def get_fn_args(cls) -> Sequence[str]:
"""Get the function arguments for the component map.
Returns:
The function arguments as a list of strings.
"""
return ["node", _CHILDREN._js_expr, _PROPS._js_expr]
@classmethod
def get_fn_body(cls) -> Var:
"""Get the function body for the component map.
Returns:
The function body as a string.
"""
return Var(_js_expr="undefined", _var_type=None)
class Markdown(Component):
"""A markdown component."""
@ -153,9 +216,6 @@ class Markdown(Component):
Returns:
The imports for the markdown component.
"""
from reflex.components.datadisplay.code import CodeBlock, Theme
from reflex.components.radix.themes.typography.code import Code
return [
{
"": "katex/dist/katex.min.css",
@ -179,10 +239,71 @@ class Markdown(Component):
component(_MOCK_ARG)._get_all_imports() # type: ignore
for component in self.component_map.values()
],
CodeBlock.create(theme=Theme.light)._get_imports(),
Code.create()._get_imports(),
]
def _get_tag_map_fn_var(self, tag: str) -> Var:
return self._get_map_fn_var_from_children(self.get_component(tag), tag)
def format_component_map(self) -> dict[str, Var]:
"""Format the component map for rendering.
Returns:
The formatted component map.
"""
components = {
tag: self._get_tag_map_fn_var(tag)
for tag in self.component_map
if tag not in ("code", "codeblock")
}
# Separate out inline code and code blocks.
components["code"] = self._get_inline_code_fn_var()
return components
def _get_inline_code_fn_var(self) -> Var:
"""Get the function variable for inline code.
This function creates a Var that represents a function to handle
both inline code and code blocks in markdown.
Returns:
The Var for inline code.
"""
# Get any custom code from the codeblock and code components.
custom_code_list = self._get_map_fn_custom_code_from_children(
self.get_component("codeblock")
)
custom_code_list.extend(
self._get_map_fn_custom_code_from_children(self.get_component("code"))
)
codeblock_custom_code = "\n".join(custom_code_list)
# Format the code to handle inline and block code.
formatted_code = f"""
const match = (className || '').match(/language-(?<lang>.*)/);
const {str(_LANGUAGE)} = match ? match[1] : '';
{codeblock_custom_code};
return inline ? (
{self.format_component("code")}
) : (
{self.format_component("codeblock", language=_LANGUAGE)}
);
""".replace("\n", " ")
return MarkdownComponentMap.create_map_fn_var(
fn_args=(
"node",
"inline",
"className",
_CHILDREN._js_expr,
_PROPS._js_expr,
),
fn_body=Var(_js_expr=formatted_code),
explicit_return=True,
)
def get_component(self, tag: str, **props) -> Component:
"""Get the component for a tag and props.
@ -239,43 +360,53 @@ class Markdown(Component):
"""
return str(self.get_component(tag, **props)).replace("\n", "")
def format_component_map(self) -> dict[str, Var]:
"""Format the component map for rendering.
def _get_map_fn_var_from_children(self, component: Component, tag: str) -> Var:
"""Create a function Var for the component map for the specified tag.
Args:
component: The component to check for custom code.
tag: The tag of the component.
Returns:
The formatted component map.
The function Var for the component map.
"""
components = {
tag: Var(
_js_expr=f"(({{node, {_CHILDREN._js_expr}, {_PROPS._js_expr}}}) => ({self.format_component(tag)}))"
)
for tag in self.component_map
}
# Separate out inline code and code blocks.
components["code"] = Var(
_js_expr=f"""(({{node, inline, className, {_CHILDREN._js_expr}, {_PROPS._js_expr}}}) => {{
const match = (className || '').match(/language-(?<lang>.*)/);
const language = match ? match[1] : '';
if (language) {{
(async () => {{
try {{
const module = await import(`react-syntax-highlighter/dist/cjs/languages/prism/${{language}}`);
SyntaxHighlighter.registerLanguage(language, module.default);
}} catch (error) {{
console.error(`Error importing language module for ${{language}}:`, error);
}}
}})();
}}
return inline ? (
{self.format_component("code")}
) : (
{self.format_component("codeblock", language=Var(_js_expr="language", _var_type=str))}
);
}})""".replace("\n", " ")
formatted_component = Var(
_js_expr=f"({self.format_component(tag)})", _var_type=str
)
if isinstance(component, MarkdownComponentMap):
return component.create_map_fn_var(fn_body=formatted_component)
return components
# fallback to the default fn Var creation if the component is not a MarkdownComponentMap.
return MarkdownComponentMap.create_map_fn_var(fn_body=formatted_component)
def _get_map_fn_custom_code_from_children(self, component) -> list[str]:
"""Recursively get markdown custom code from children components.
Args:
component: The component to check for custom code.
Returns:
A list of markdown custom code strings.
"""
custom_code_list = []
if isinstance(component, MarkdownComponentMap):
custom_code_list.append(component.get_component_map_custom_code())
# If the component is a custom component(rx.memo), obtain the underlining
# component and get the custom code from the children.
if isinstance(component, CustomComponent):
custom_code_list.extend(
self._get_map_fn_custom_code_from_children(
component.component_fn(*component.get_prop_vars())
)
)
elif isinstance(component, Component):
for child in component.children:
custom_code_list.extend(
self._get_map_fn_custom_code_from_children(child)
)
return custom_code_list
@staticmethod
def _component_map_hash(component_map) -> str:
@ -288,12 +419,12 @@ class Markdown(Component):
return f"ComponentMap_{self.component_map_hash}"
def _get_custom_code(self) -> str | None:
hooks = set()
hooks = {}
for _component in self.component_map.values():
comp = _component(_MOCK_ARG)
hooks.update(comp._get_all_hooks_internal())
hooks.update(comp._get_all_hooks())
formatted_hooks = "\n".join(hooks)
formatted_hooks = "\n".join(hooks.keys())
return f"""
function {self._get_component_map_name()} () {{
{formatted_hooks}

View File

@ -3,8 +3,9 @@
# ------------------- DO NOT EDIT ----------------------
# This file was generated by `reflex/utils/pyi_generator.py`!
# ------------------------------------------------------
import dataclasses
from functools import lru_cache
from typing import Any, Callable, Dict, Optional, Union, overload
from typing import Any, Callable, Dict, Optional, Sequence, Union, overload
from reflex.components.component import Component
from reflex.event import BASE_STATE, EventType
@ -16,6 +17,7 @@ _CHILDREN = Var(_js_expr="children", _var_type=str)
_PROPS = Var(_js_expr="...props")
_PROPS_IN_TAG = Var(_js_expr="{...props}")
_MOCK_ARG = Var(_js_expr="", _var_type=str)
_LANGUAGE = Var(_js_expr="_language", _var_type=str)
_REMARK_MATH = Var(_js_expr="remarkMath")
_REMARK_GFM = Var(_js_expr="remarkGfm")
_REMARK_UNWRAP_IMAGES = Var(_js_expr="remarkUnwrapImages")
@ -27,6 +29,21 @@ NO_PROPS_TAGS = ("ul", "ol", "li")
@lru_cache
def get_base_component_map() -> dict[str, Callable]: ...
@dataclasses.dataclass()
class MarkdownComponentMap:
@classmethod
def get_component_map_custom_code(cls) -> str: ...
@classmethod
def create_map_fn_var(
cls,
fn_body: Var | None = None,
fn_args: Sequence[str] | None = None,
explicit_return: bool | None = None,
) -> Var: ...
@classmethod
def get_fn_args(cls) -> Sequence[str]: ...
@classmethod
def get_fn_body(cls) -> Var: ...
class Markdown(Component):
@overload
@ -82,6 +99,6 @@ class Markdown(Component):
...
def add_imports(self) -> ImportDict | list[ImportDict]: ...
def format_component_map(self) -> dict[str, Var]: ...
def get_component(self, tag: str, **props) -> Component: ...
def format_component(self, tag: str, **props) -> str: ...
def format_component_map(self) -> dict[str, Var]: ...

View File

@ -8,6 +8,7 @@ from reflex.components.component import Component, ComponentNamespace
from reflex.components.core.foreach import Foreach
from reflex.components.el.elements.typography import Li, Ol, Ul
from reflex.components.lucide.icon import Icon
from reflex.components.markdown.markdown import MarkdownComponentMap
from reflex.components.radix.themes.typography.text import Text
from reflex.vars.base import Var
@ -36,7 +37,7 @@ LiteralListStyleTypeOrdered = Literal[
]
class BaseList(Component):
class BaseList(Component, MarkdownComponentMap):
"""Base class for ordered and unordered lists."""
tag = "ul"
@ -154,7 +155,7 @@ class OrderedList(BaseList, Ol):
)
class ListItem(Li):
class ListItem(Li, MarkdownComponentMap):
"""Display an item of an ordered or unordered list."""
@classmethod

View File

@ -7,6 +7,7 @@ from typing import Any, Dict, Iterable, Literal, Optional, Union, overload
from reflex.components.component import Component, ComponentNamespace
from reflex.components.el.elements.typography import Li, Ol, Ul
from reflex.components.markdown.markdown import MarkdownComponentMap
from reflex.event import BASE_STATE, EventType
from reflex.style import Style
from reflex.vars.base import Var
@ -29,7 +30,7 @@ LiteralListStyleTypeOrdered = Literal[
"katakana",
]
class BaseList(Component):
class BaseList(Component, MarkdownComponentMap):
@overload
@classmethod
def create( # type: ignore
@ -393,7 +394,7 @@ class OrderedList(BaseList, Ol):
"""
...
class ListItem(Li):
class ListItem(Li, MarkdownComponentMap):
@overload
@classmethod
def create( # type: ignore

View File

@ -7,13 +7,14 @@ from __future__ import annotations
from reflex.components.core.breakpoints import Responsive
from reflex.components.el import elements
from reflex.components.markdown.markdown import MarkdownComponentMap
from reflex.vars.base import Var
from ..base import LiteralAccentColor, LiteralVariant, RadixThemesComponent
from .base import LiteralTextSize, LiteralTextWeight
class Code(elements.Code, RadixThemesComponent):
class Code(elements.Code, RadixThemesComponent, MarkdownComponentMap):
"""A block level extended quotation."""
tag = "Code"

View File

@ -7,13 +7,14 @@ from typing import Any, Dict, Literal, Optional, Union, overload
from reflex.components.core.breakpoints import Breakpoints
from reflex.components.el import elements
from reflex.components.markdown.markdown import MarkdownComponentMap
from reflex.event import BASE_STATE, EventType
from reflex.style import Style
from reflex.vars.base import Var
from ..base import RadixThemesComponent
class Code(elements.Code, RadixThemesComponent):
class Code(elements.Code, RadixThemesComponent, MarkdownComponentMap):
@overload
@classmethod
def create( # type: ignore

View File

@ -7,13 +7,14 @@ from __future__ import annotations
from reflex.components.core.breakpoints import Responsive
from reflex.components.el import elements
from reflex.components.markdown.markdown import MarkdownComponentMap
from reflex.vars.base import Var
from ..base import LiteralAccentColor, RadixThemesComponent
from .base import LiteralTextAlign, LiteralTextSize, LiteralTextTrim, LiteralTextWeight
class Heading(elements.H1, RadixThemesComponent):
class Heading(elements.H1, RadixThemesComponent, MarkdownComponentMap):
"""A foundational text primitive based on the <span> element."""
tag = "Heading"

View File

@ -7,13 +7,14 @@ from typing import Any, Dict, Literal, Optional, Union, overload
from reflex.components.core.breakpoints import Breakpoints
from reflex.components.el import elements
from reflex.components.markdown.markdown import MarkdownComponentMap
from reflex.event import BASE_STATE, EventType
from reflex.style import Style
from reflex.vars.base import Var
from ..base import RadixThemesComponent
class Heading(elements.H1, RadixThemesComponent):
class Heading(elements.H1, RadixThemesComponent, MarkdownComponentMap):
@overload
@classmethod
def create( # type: ignore

View File

@ -12,6 +12,7 @@ from reflex.components.core.breakpoints import Responsive
from reflex.components.core.colors import color
from reflex.components.core.cond import cond
from reflex.components.el.elements.inline import A
from reflex.components.markdown.markdown import MarkdownComponentMap
from reflex.components.next.link import NextLink
from reflex.utils.imports import ImportDict
from reflex.vars.base import Var
@ -24,7 +25,7 @@ LiteralLinkUnderline = Literal["auto", "hover", "always", "none"]
next_link = NextLink.create()
class Link(RadixThemesComponent, A, MemoizationLeaf):
class Link(RadixThemesComponent, A, MemoizationLeaf, MarkdownComponentMap):
"""A semantic element for navigation between pages."""
tag = "Link"

View File

@ -8,6 +8,7 @@ from typing import Any, Dict, Literal, Optional, Union, overload
from reflex.components.component import MemoizationLeaf
from reflex.components.core.breakpoints import Breakpoints
from reflex.components.el.elements.inline import A
from reflex.components.markdown.markdown import MarkdownComponentMap
from reflex.components.next.link import NextLink
from reflex.event import BASE_STATE, EventType
from reflex.style import Style
@ -19,7 +20,7 @@ from ..base import RadixThemesComponent
LiteralLinkUnderline = Literal["auto", "hover", "always", "none"]
next_link = NextLink.create()
class Link(RadixThemesComponent, A, MemoizationLeaf):
class Link(RadixThemesComponent, A, MemoizationLeaf, MarkdownComponentMap):
def add_imports(self) -> ImportDict: ...
@overload
@classmethod

View File

@ -10,6 +10,7 @@ from typing import Literal
from reflex.components.component import ComponentNamespace
from reflex.components.core.breakpoints import Responsive
from reflex.components.el import elements
from reflex.components.markdown.markdown import MarkdownComponentMap
from reflex.vars.base import Var
from ..base import LiteralAccentColor, RadixThemesComponent
@ -37,7 +38,7 @@ LiteralType = Literal[
]
class Text(elements.Span, RadixThemesComponent):
class Text(elements.Span, RadixThemesComponent, MarkdownComponentMap):
"""A foundational text primitive based on the <span> element."""
tag = "Text"

View File

@ -8,6 +8,7 @@ from typing import Any, Dict, Literal, Optional, Union, overload
from reflex.components.component import ComponentNamespace
from reflex.components.core.breakpoints import Breakpoints
from reflex.components.el import elements
from reflex.components.markdown.markdown import MarkdownComponentMap
from reflex.event import BASE_STATE, EventType
from reflex.style import Style
from reflex.vars.base import Var
@ -35,7 +36,7 @@ LiteralType = Literal[
"sup",
]
class Text(elements.Span, RadixThemesComponent):
class Text(elements.Span, RadixThemesComponent, MarkdownComponentMap):
@overload
@classmethod
def create( # type: ignore

View File

@ -45,6 +45,7 @@ from reflex.vars import VarData
from reflex.vars.base import LiteralVar, Var
from reflex.vars.function import (
ArgsFunctionOperation,
FunctionArgs,
FunctionStringVar,
FunctionVar,
VarOperationCall,
@ -1643,7 +1644,7 @@ class LiteralEventChainVar(ArgsFunctionOperation, LiteralVar, EventChainVar):
_js_expr="",
_var_type=EventChain,
_var_data=_var_data,
_args_names=arg_def,
_args=FunctionArgs(arg_def),
_return_expr=invocation.call(
LiteralVar.create([LiteralVar.create(event) for event in value.events]),
arg_def_expr,

View File

@ -4,8 +4,9 @@ from __future__ import annotations
import dataclasses
import sys
from typing import Any, Callable, Optional, Tuple, Type, Union
from typing import Any, Callable, Optional, Sequence, Tuple, Type, Union
from reflex.utils import format
from reflex.utils.types import GenericType
from .base import CachedVarOperation, LiteralVar, Var, VarData, cached_property_no_lock
@ -126,6 +127,36 @@ class VarOperationCall(CachedVarOperation, Var):
)
@dataclasses.dataclass(frozen=True)
class DestructuredArg:
"""Class for destructured arguments."""
fields: Tuple[str, ...] = tuple()
rest: Optional[str] = None
def to_javascript(self) -> str:
"""Convert the destructured argument to JavaScript.
Returns:
The destructured argument in JavaScript.
"""
return format.wrap(
", ".join(self.fields) + (f", ...{self.rest}" if self.rest else ""),
"{",
"}",
)
@dataclasses.dataclass(
frozen=True,
)
class FunctionArgs:
"""Class for function arguments."""
args: Tuple[Union[str, DestructuredArg], ...] = tuple()
rest: Optional[str] = None
@dataclasses.dataclass(
eq=False,
frozen=True,
@ -134,8 +165,9 @@ class VarOperationCall(CachedVarOperation, Var):
class ArgsFunctionOperation(CachedVarOperation, FunctionVar):
"""Base class for immutable function defined via arguments and return expression."""
_args_names: Tuple[str, ...] = dataclasses.field(default_factory=tuple)
_args: FunctionArgs = dataclasses.field(default_factory=FunctionArgs)
_return_expr: Union[Var, Any] = dataclasses.field(default=None)
_explicit_return: bool = dataclasses.field(default=False)
@cached_property_no_lock
def _cached_var_name(self) -> str:
@ -144,13 +176,31 @@ class ArgsFunctionOperation(CachedVarOperation, FunctionVar):
Returns:
The name of the var.
"""
return f"(({', '.join(self._args_names)}) => ({str(LiteralVar.create(self._return_expr))}))"
arg_names_str = ", ".join(
[
arg if isinstance(arg, str) else arg.to_javascript()
for arg in self._args.args
]
) + (f", ...{self._args.rest}" if self._args.rest else "")
return_expr_str = str(LiteralVar.create(self._return_expr))
# Wrap return expression in curly braces if explicit return syntax is used.
return_expr_str_wrapped = (
format.wrap(return_expr_str, "{", "}")
if self._explicit_return
else return_expr_str
)
return f"(({arg_names_str}) => {return_expr_str_wrapped})"
@classmethod
def create(
cls,
args_names: Tuple[str, ...],
args_names: Sequence[Union[str, DestructuredArg]],
return_expr: Var | Any,
rest: str | None = None,
explicit_return: bool = False,
_var_type: GenericType = Callable,
_var_data: VarData | None = None,
) -> ArgsFunctionOperation:
@ -159,6 +209,8 @@ class ArgsFunctionOperation(CachedVarOperation, FunctionVar):
Args:
args_names: The names of the arguments.
return_expr: The return expression of the function.
rest: The name of the rest argument.
explicit_return: Whether to use explicit return syntax.
_var_data: Additional hooks and imports associated with the Var.
Returns:
@ -168,8 +220,9 @@ class ArgsFunctionOperation(CachedVarOperation, FunctionVar):
_js_expr="",
_var_type=_var_type,
_var_data=_var_data,
_args_names=args_names,
_args=FunctionArgs(args=tuple(args_names), rest=rest),
_return_expr=return_expr,
_explicit_return=explicit_return,
)

View File

@ -62,14 +62,14 @@ def test_script_event_handler():
)
render_dict = component.render()
assert (
f'onReady={{((...args) => ((addEvents([(Event("{EvState.get_full_name()}.on_ready", ({{ }}), ({{ }})))], args, ({{ }})))))}}'
f'onReady={{((...args) => (addEvents([(Event("{EvState.get_full_name()}.on_ready", ({{ }}), ({{ }})))], args, ({{ }}))))}}'
in render_dict["props"]
)
assert (
f'onLoad={{((...args) => ((addEvents([(Event("{EvState.get_full_name()}.on_load", ({{ }}), ({{ }})))], args, ({{ }})))))}}'
f'onLoad={{((...args) => (addEvents([(Event("{EvState.get_full_name()}.on_load", ({{ }}), ({{ }})))], args, ({{ }}))))}}'
in render_dict["props"]
)
assert (
f'onError={{((...args) => ((addEvents([(Event("{EvState.get_full_name()}.on_error", ({{ }}), ({{ }})))], args, ({{ }})))))}}'
f'onError={{((...args) => (addEvents([(Event("{EvState.get_full_name()}.on_error", ({{ }}), ({{ }})))], args, ({{ }}))))}}'
in render_dict["props"]
)

View File

@ -11,22 +11,3 @@ def test_code_light_dark_theme(theme, expected):
code_block = CodeBlock.create(theme=theme)
assert code_block.theme._js_expr == expected # type: ignore
def generate_custom_code(language, expected_case):
return f"SyntaxHighlighter.registerLanguage('{language}', {expected_case})"
@pytest.mark.parametrize(
"language, expected_case",
[
("python", "python"),
("firestore-security-rules", "firestoreSecurityRules"),
("typescript", "typescript"),
],
)
def test_get_custom_code(language, expected_case):
code_block = CodeBlock.create(language=language)
assert code_block._get_custom_code() == generate_custom_code(
language, expected_case
)

View File

@ -0,0 +1,190 @@
from typing import Type
import pytest
from reflex.components.component import Component, memo
from reflex.components.datadisplay.code import CodeBlock
from reflex.components.datadisplay.shiki_code_block import ShikiHighLevelCodeBlock
from reflex.components.markdown.markdown import Markdown, MarkdownComponentMap
from reflex.components.radix.themes.layout.box import Box
from reflex.components.radix.themes.typography.heading import Heading
from reflex.vars.base import Var
class CustomMarkdownComponent(Component, MarkdownComponentMap):
"""A custom markdown component."""
tag = "CustomMarkdownComponent"
library = "custom"
@classmethod
def get_fn_args(cls) -> tuple[str, ...]:
"""Return the function arguments.
Returns:
The function arguments.
"""
return ("custom_node", "custom_children", "custom_props")
@classmethod
def get_fn_body(cls) -> Var:
"""Return the function body.
Returns:
The function body.
"""
return Var(_js_expr="{return custom_node + custom_children + custom_props}")
def syntax_highlighter_memoized_component(codeblock: Type[Component]):
@memo
def code_block(code: str, language: str):
return Box.create(
codeblock.create(
code,
language=language,
class_name="code-block",
can_copy=True,
),
class_name="relative mb-4",
)
def code_block_markdown(*children, **props):
return code_block(
code=children[0], language=props.pop("language", "plain"), **props
)
return code_block_markdown
@pytest.mark.parametrize(
"fn_body, fn_args, explicit_return, expected",
[
(
None,
None,
False,
Var(_js_expr="(({node, children, ...props}) => undefined)"),
),
("return node", ("node",), True, Var(_js_expr="(({node}) => {return node})")),
(
"return node + children",
("node", "children"),
True,
Var(_js_expr="(({node, children}) => {return node + children})"),
),
(
"return node + props",
("node", "...props"),
True,
Var(_js_expr="(({node, ...props}) => {return node + props})"),
),
(
"return node + children + props",
("node", "children", "...props"),
True,
Var(
_js_expr="(({node, children, ...props}) => {return node + children + props})"
),
),
],
)
def test_create_map_fn_var(fn_body, fn_args, explicit_return, expected):
result = MarkdownComponentMap.create_map_fn_var(
fn_body=Var(_js_expr=fn_body, _var_type=str) if fn_body else None,
fn_args=fn_args,
explicit_return=explicit_return,
)
assert result._js_expr == expected._js_expr
@pytest.mark.parametrize(
("cls", "fn_body", "fn_args", "explicit_return", "expected"),
[
(
MarkdownComponentMap,
None,
None,
False,
Var(_js_expr="(({node, children, ...props}) => undefined)"),
),
(
MarkdownComponentMap,
"return node",
("node",),
True,
Var(_js_expr="(({node}) => {return node})"),
),
(
CustomMarkdownComponent,
None,
None,
True,
Var(
_js_expr="(({custom_node, custom_children, custom_props}) => {return custom_node + custom_children + custom_props})"
),
),
(
CustomMarkdownComponent,
"return custom_node",
("custom_node",),
True,
Var(_js_expr="(({custom_node}) => {return custom_node})"),
),
],
)
def test_create_map_fn_var_subclass(cls, fn_body, fn_args, explicit_return, expected):
result = cls.create_map_fn_var(
fn_body=Var(_js_expr=fn_body, _var_type=int) if fn_body else None,
fn_args=fn_args,
explicit_return=explicit_return,
)
assert result._js_expr == expected._js_expr
@pytest.mark.parametrize(
"key,component_map, expected",
[
(
"code",
{},
"""(({node, inline, className, children, ...props}) => { const match = (className || '').match(/language-(?<lang>.*)/); const _language = match ? match[1] : ''; if (_language) { (async () => { try { const module = await import(`react-syntax-highlighter/dist/cjs/languages/prism/${_language}`); SyntaxHighlighter.registerLanguage(_language, module.default); } catch (error) { console.error(`Error importing language module for ${_language}:`, error); } })(); } ; return inline ? ( <RadixThemesCode {...props}>{children}</RadixThemesCode> ) : ( <SyntaxHighlighter children={((Array.isArray(children)) ? children.join("\\n") : children)} css={({ ["marginTop"] : "1em", ["marginBottom"] : "1em" })} customStyle={({ ["marginTop"] : "1em", ["marginBottom"] : "1em" })} language={_language} style={((resolvedColorMode === "light") ? oneLight : oneDark)} wrapLongLines={true} {...props}/> ); })""",
),
(
"code",
{
"codeblock": lambda value, **props: ShikiHighLevelCodeBlock.create(
value, **props
)
},
"""(({node, inline, className, children, ...props}) => { const match = (className || '').match(/language-(?<lang>.*)/); const _language = match ? match[1] : ''; ; return inline ? ( <RadixThemesCode {...props}>{children}</RadixThemesCode> ) : ( <RadixThemesBox css={({ ["pre"] : ({ ["margin"] : "0", ["padding"] : "24px", ["background"] : "transparent", ["overflow-x"] : "auto", ["border-radius"] : "6px" }) })} {...props}><ShikiCode code={((Array.isArray(children)) ? children.join("\\n") : children)} decorations={[]} language={_language} theme={((resolvedColorMode === "light") ? "one-light" : "one-dark-pro")} transformers={[]}/></RadixThemesBox> ); })""",
),
(
"h1",
{
"h1": lambda value: CustomMarkdownComponent.create(
Heading.create(value, as_="h1", size="6", margin_y="0.5em")
)
},
"""(({custom_node, custom_children, custom_props}) => (<CustomMarkdownComponent {...props}><RadixThemesHeading as={"h1"} css={({ ["marginTop"] : "0.5em", ["marginBottom"] : "0.5em" })} size={"6"}>{children}</RadixThemesHeading></CustomMarkdownComponent>))""",
),
(
"code",
{"codeblock": syntax_highlighter_memoized_component(CodeBlock)},
"""(({node, inline, className, children, ...props}) => { const match = (className || '').match(/language-(?<lang>.*)/); const _language = match ? match[1] : ''; if (_language) { (async () => { try { const module = await import(`react-syntax-highlighter/dist/cjs/languages/prism/${_language}`); SyntaxHighlighter.registerLanguage(_language, module.default); } catch (error) { console.error(`Error importing language module for ${_language}:`, error); } })(); } ; return inline ? ( <RadixThemesCode {...props}>{children}</RadixThemesCode> ) : ( <CodeBlock code={((Array.isArray(children)) ? children.join("\\n") : children)} language={_language} {...props}/> ); })""",
),
(
"code",
{
"codeblock": syntax_highlighter_memoized_component(
ShikiHighLevelCodeBlock
)
},
"""(({node, inline, className, children, ...props}) => { const match = (className || '').match(/language-(?<lang>.*)/); const _language = match ? match[1] : ''; ; return inline ? ( <RadixThemesCode {...props}>{children}</RadixThemesCode> ) : ( <CodeBlock code={((Array.isArray(children)) ? children.join("\\n") : children)} language={_language} {...props}/> ); })""",
),
],
)
def test_markdown_format_component(key, component_map, expected):
markdown = Markdown.create("# header", component_map=component_map)
result = markdown.format_component_map()
assert str(result[key]) == expected

View File

@ -844,9 +844,9 @@ def test_component_event_trigger_arbitrary_args():
comp = C1.create(on_foo=C1State.mock_handler)
assert comp.render()["props"][0] == (
"onFoo={((__e, _alpha, _bravo, _charlie) => ((addEvents("
"onFoo={((__e, _alpha, _bravo, _charlie) => (addEvents("
f'[(Event("{C1State.get_full_name()}.mock_handler", ({{ ["_e"] : __e["target"]["value"], ["_bravo"] : _bravo["nested"], ["_charlie"] : (_charlie["custom"] + 42) }}), ({{ }})))], '
"[__e, _alpha, _bravo, _charlie], ({ })))))}"
"[__e, _alpha, _bravo, _charlie], ({ }))))}"
)

View File

@ -222,16 +222,16 @@ def test_event_console_log():
assert spec.handler.fn.__qualname__ == "_call_function"
assert spec.args[0][0].equals(Var(_js_expr="function"))
assert spec.args[0][1].equals(
Var('(() => ((console["log"]("message"))))', _var_type=Callable)
Var('(() => (console["log"]("message")))', _var_type=Callable)
)
assert (
format.format_event(spec)
== 'Event("_call_function", {function:(() => ((console["log"]("message"))))})'
== 'Event("_call_function", {function:(() => (console["log"]("message")))})'
)
spec = event.console_log(Var(_js_expr="message"))
assert (
format.format_event(spec)
== 'Event("_call_function", {function:(() => ((console["log"](message))))})'
== 'Event("_call_function", {function:(() => (console["log"](message)))})'
)
@ -242,16 +242,16 @@ def test_event_window_alert():
assert spec.handler.fn.__qualname__ == "_call_function"
assert spec.args[0][0].equals(Var(_js_expr="function"))
assert spec.args[0][1].equals(
Var('(() => ((window["alert"]("message"))))', _var_type=Callable)
Var('(() => (window["alert"]("message")))', _var_type=Callable)
)
assert (
format.format_event(spec)
== 'Event("_call_function", {function:(() => ((window["alert"]("message"))))})'
== 'Event("_call_function", {function:(() => (window["alert"]("message")))})'
)
spec = event.window_alert(Var(_js_expr="message"))
assert (
format.format_event(spec)
== 'Event("_call_function", {function:(() => ((window["alert"](message))))})'
== 'Event("_call_function", {function:(() => (window["alert"](message)))})'
)

View File

@ -22,7 +22,11 @@ from reflex.vars.base import (
var_operation,
var_operation_return,
)
from reflex.vars.function import ArgsFunctionOperation, FunctionStringVar
from reflex.vars.function import (
ArgsFunctionOperation,
DestructuredArg,
FunctionStringVar,
)
from reflex.vars.number import LiteralBooleanVar, LiteralNumberVar, NumberVar
from reflex.vars.object import LiteralObjectVar, ObjectVar
from reflex.vars.sequence import (
@ -921,13 +925,13 @@ def test_function_var():
)
assert (
str(manual_addition_func.call(1, 2))
== '(((a, b) => (({ ["args"] : [a, b], ["result"] : a + b })))(1, 2))'
== '(((a, b) => ({ ["args"] : [a, b], ["result"] : a + b }))(1, 2))'
)
increment_func = addition_func(1)
assert (
str(increment_func.call(2))
== "(((...args) => ((((a, b) => a + b)(1, ...args))))(2))"
== "(((...args) => (((a, b) => a + b)(1, ...args)))(2))"
)
create_hello_statement = ArgsFunctionOperation.create(
@ -937,9 +941,25 @@ def test_function_var():
last_name = LiteralStringVar.create("Universe")
assert (
str(create_hello_statement.call(f"{first_name} {last_name}"))
== '(((name) => (("Hello, "+name+"!")))("Steven Universe"))'
== '(((name) => ("Hello, "+name+"!"))("Steven Universe"))'
)
# Test with destructured arguments
destructured_func = ArgsFunctionOperation.create(
(DestructuredArg(fields=("a", "b")),),
Var(_js_expr="a + b"),
)
assert (
str(destructured_func.call({"a": 1, "b": 2}))
== '((({a, b}) => a + b)(({ ["a"] : 1, ["b"] : 2 })))'
)
# Test with explicit return
explicit_return_func = ArgsFunctionOperation.create(
("a", "b"), Var(_js_expr="return a + b"), explicit_return=True
)
assert str(explicit_return_func.call(1, 2)) == "(((a, b) => {return a + b})(1, 2))"
def test_var_operation():
@var_operation

View File

@ -374,7 +374,7 @@ def test_format_match(
events=[EventSpec(handler=EventHandler(fn=mock_event))],
args_spec=lambda: [],
),
'((...args) => ((addEvents([(Event("mock_event", ({ }), ({ })))], args, ({ })))))',
'((...args) => (addEvents([(Event("mock_event", ({ }), ({ })))], args, ({ }))))',
),
(
EventChain(
@ -395,7 +395,7 @@ def test_format_match(
],
args_spec=lambda e: [e.target.value],
),
'((_e) => ((addEvents([(Event("mock_event", ({ ["arg"] : _e["target"]["value"] }), ({ })))], [_e], ({ })))))',
'((_e) => (addEvents([(Event("mock_event", ({ ["arg"] : _e["target"]["value"] }), ({ })))], [_e], ({ }))))',
),
(
EventChain(
@ -403,7 +403,7 @@ def test_format_match(
args_spec=lambda: [],
event_actions={"stopPropagation": True},
),
'((...args) => ((addEvents([(Event("mock_event", ({ }), ({ })))], args, ({ ["stopPropagation"] : true })))))',
'((...args) => (addEvents([(Event("mock_event", ({ }), ({ })))], args, ({ ["stopPropagation"] : true }))))',
),
(
EventChain(
@ -415,7 +415,7 @@ def test_format_match(
],
args_spec=lambda: [],
),
'((...args) => ((addEvents([(Event("mock_event", ({ }), ({ ["stopPropagation"] : true })))], args, ({ })))))',
'((...args) => (addEvents([(Event("mock_event", ({ }), ({ ["stopPropagation"] : true })))], args, ({ }))))',
),
(
EventChain(
@ -423,7 +423,7 @@ def test_format_match(
args_spec=lambda: [],
event_actions={"preventDefault": True},
),
'((...args) => ((addEvents([(Event("mock_event", ({ }), ({ })))], args, ({ ["preventDefault"] : true })))))',
'((...args) => (addEvents([(Event("mock_event", ({ }), ({ })))], args, ({ ["preventDefault"] : true }))))',
),
({"a": "red", "b": "blue"}, '({ ["a"] : "red", ["b"] : "blue" })'),
(Var(_js_expr="var", _var_type=int).guess_type(), "var"),