diff --git a/reflex/components/markdown/markdown.py b/reflex/components/markdown/markdown.py index 00e782399..9280ba682 100644 --- a/reflex/components/markdown/markdown.py +++ b/reflex/components/markdown/markdown.py @@ -8,14 +8,6 @@ from hashlib import md5 from typing import Any, Callable, Dict, Union from reflex.components.component import Component, CustomComponent -from reflex.components.radix.themes.layout.list import ( - ListItem, - OrderedList, - UnorderedList, -) -from reflex.components.radix.themes.typography.heading import Heading -from reflex.components.radix.themes.typography.link import Link -from reflex.components.radix.themes.typography.text import Text from reflex.components.tags.tag import Tag from reflex.utils import types from reflex.utils.imports import ImportDict, ImportVar @@ -54,7 +46,15 @@ def get_base_component_map() -> dict[str, Callable]: The base component map. """ from reflex.components.datadisplay.code import CodeBlock + from reflex.components.radix.themes.layout.list import ( + ListItem, + OrderedList, + UnorderedList, + ) from reflex.components.radix.themes.typography.code import Code + from reflex.components.radix.themes.typography.heading import Heading + from reflex.components.radix.themes.typography.link import Link + from reflex.components.radix.themes.typography.text import Text return { "h1": lambda value: Heading.create(value, as_="h1", size="6", margin_y="0.5em"), diff --git a/reflex/components/radix/themes/layout/list.py b/reflex/components/radix/themes/layout/list.py index d83fd168b..96fa169a0 100644 --- a/reflex/components/radix/themes/layout/list.py +++ b/reflex/components/radix/themes/layout/list.py @@ -8,6 +8,7 @@ from reflex.components.component import Component, ComponentNamespace from reflex.components.core.foreach import Foreach from reflex.components.el.elements.typography import Li, Ol, Ul from reflex.components.lucide.icon import Icon +from reflex.components.markdown.markdown import MarkdownComponentMap from reflex.components.radix.themes.typography.text import Text from reflex.vars.base import Var @@ -36,7 +37,7 @@ LiteralListStyleTypeOrdered = Literal[ ] -class BaseList(Component): +class BaseList(Component, MarkdownComponentMap): """Base class for ordered and unordered lists.""" tag = "ul" @@ -154,7 +155,7 @@ class OrderedList(BaseList, Ol): ) -class ListItem(Li): +class ListItem(Li, MarkdownComponentMap): """Display an item of an ordered or unordered list.""" @classmethod diff --git a/reflex/components/radix/themes/typography/code.py b/reflex/components/radix/themes/typography/code.py index ca19859d3..ab610b505 100644 --- a/reflex/components/radix/themes/typography/code.py +++ b/reflex/components/radix/themes/typography/code.py @@ -7,13 +7,14 @@ from __future__ import annotations from reflex.components.core.breakpoints import Responsive from reflex.components.el import elements +from reflex.components.markdown.markdown import MarkdownComponentMap from reflex.vars.base import Var from ..base import LiteralAccentColor, LiteralVariant, RadixThemesComponent from .base import LiteralTextSize, LiteralTextWeight -class Code(elements.Code, RadixThemesComponent): +class Code(elements.Code, RadixThemesComponent, MarkdownComponentMap): """A block level extended quotation.""" tag = "Code" diff --git a/reflex/components/radix/themes/typography/heading.py b/reflex/components/radix/themes/typography/heading.py index 03e109717..ce1eaa68f 100644 --- a/reflex/components/radix/themes/typography/heading.py +++ b/reflex/components/radix/themes/typography/heading.py @@ -7,13 +7,14 @@ from __future__ import annotations from reflex.components.core.breakpoints import Responsive from reflex.components.el import elements +from reflex.components.markdown.markdown import MarkdownComponentMap from reflex.vars.base import Var from ..base import LiteralAccentColor, RadixThemesComponent from .base import LiteralTextAlign, LiteralTextSize, LiteralTextTrim, LiteralTextWeight -class Heading(elements.H1, RadixThemesComponent): +class Heading(elements.H1, RadixThemesComponent, MarkdownComponentMap): """A foundational text primitive based on the element.""" tag = "Heading" diff --git a/reflex/components/radix/themes/typography/link.py b/reflex/components/radix/themes/typography/link.py index 6e3d2f983..1cc673536 100644 --- a/reflex/components/radix/themes/typography/link.py +++ b/reflex/components/radix/themes/typography/link.py @@ -12,6 +12,7 @@ from reflex.components.core.breakpoints import Responsive from reflex.components.core.colors import color from reflex.components.core.cond import cond from reflex.components.el.elements.inline import A +from reflex.components.markdown.markdown import MarkdownComponentMap from reflex.components.next.link import NextLink from reflex.utils.imports import ImportDict from reflex.vars.base import Var @@ -24,7 +25,7 @@ LiteralLinkUnderline = Literal["auto", "hover", "always", "none"] next_link = NextLink.create() -class Link(RadixThemesComponent, A, MemoizationLeaf): +class Link(RadixThemesComponent, A, MemoizationLeaf, MarkdownComponentMap): """A semantic element for navigation between pages.""" tag = "Link" diff --git a/reflex/components/radix/themes/typography/text.py b/reflex/components/radix/themes/typography/text.py index e3576360a..1663ddedf 100644 --- a/reflex/components/radix/themes/typography/text.py +++ b/reflex/components/radix/themes/typography/text.py @@ -10,6 +10,7 @@ from typing import Literal from reflex.components.component import ComponentNamespace from reflex.components.core.breakpoints import Responsive from reflex.components.el import elements +from reflex.components.markdown.markdown import MarkdownComponentMap from reflex.vars.base import Var from ..base import LiteralAccentColor, RadixThemesComponent @@ -37,7 +38,7 @@ LiteralType = Literal[ ] -class Text(elements.Span, RadixThemesComponent): +class Text(elements.Span, RadixThemesComponent, MarkdownComponentMap): """A foundational text primitive based on the element.""" tag = "Text" diff --git a/tests/units/components/markdown/__init__.py b/tests/units/components/markdown/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/tests/units/components/markdown/test_markdown.py b/tests/units/components/markdown/test_markdown.py new file mode 100644 index 000000000..ec3f89083 --- /dev/null +++ b/tests/units/components/markdown/test_markdown.py @@ -0,0 +1,162 @@ +from typing import Type + +import pytest + +from reflex.components.component import Component, memo +from reflex.components.datadisplay.code import CodeBlock +from reflex.components.datadisplay.shiki_code_block import ShikiHighLevelCodeBlock +from reflex.components.markdown.markdown import Markdown, MarkdownComponentMap +from reflex.components.radix.themes.layout.box import Box +from reflex.components.radix.themes.typography.heading import Heading +from reflex.vars.base import Var + + +class CustomMarkdownComponent(Component, MarkdownComponentMap): + """A custom markdown component.""" + + tag = "CustomMarkdownComponent" + library = "custom" + + @classmethod + def get_fn_args(cls) -> list[str]: + """Return the function arguments.""" + return ["custom_node", "custom_children", "custom_props"] + + @classmethod + def get_fn_body(cls) -> str: + """Return the function body.""" + return "{return custom_node + custom_children + custom_props;}" + + +def syntax_highlighter_memoized_component(codeblock: Type[Component]): + @memo + def code_block(code: str, language: str): + return Box.create( + codeblock.create( + code, + language=language, + class_name="code-block", + can_copy=True, + ), + class_name="relative mb-4", + ) + + def code_block_markdown(*children, **props): + return code_block( + code=children[0], language=props.pop("language", "plain"), **props + ) + + return code_block_markdown + + +@pytest.mark.parametrize( + "fn_body, fn_args, expected", + [ + (None, None, Var(_js_expr="(({node, children, ...props}) => ())")), + ("{return node;}", ["node"], Var(_js_expr="(({node}) => {return node;})")), + ( + "{return node + children;}", + ["node", "children"], + Var(_js_expr="(({node, children}) => {return node + children;})"), + ), + ( + "{return node + props;}", + ["node", "...props"], + Var(_js_expr="(({node, ...props}) => {return node + props;})"), + ), + ( + "{return node + children + props;}", + ["node", "children", "...props"], + Var( + _js_expr="(({node, children, ...props}) => {return node + children + props;})" + ), + ), + ], +) +def test_create_map_fn_var(fn_body, fn_args, expected): + result = MarkdownComponentMap.create_map_fn_var(fn_body, fn_args) + assert result._js_expr == expected._js_expr + + +@pytest.mark.parametrize( + "cls, fn_body, fn_args, expected", + [ + ( + MarkdownComponentMap, + None, + None, + Var(_js_expr="(({node, children, ...props}) => ())"), + ), + ( + MarkdownComponentMap, + "{return node};", + ["node"], + Var(_js_expr="(({node}) => {return node};)"), + ), + ( + CustomMarkdownComponent, + None, + None, + Var( + _js_expr="(({custom_node, custom_children, custom_props}) => {return custom_node + custom_children + custom_props;})" + ), + ), + ( + CustomMarkdownComponent, + "{return custom_node;}", + ["custom_node"], + Var(_js_expr="(({custom_node}) => {return custom_node;})"), + ), + ], +) +def test_create_map_fn_var_subclass(cls, fn_body, fn_args, expected): + result = cls.create_map_fn_var(fn_body, fn_args) + assert result._js_expr == expected._js_expr + + +@pytest.mark.parametrize( + "key,component_map, expected", + [ + ( + "code", + {}, + """(({node, inline, className, children, ...props}) => { const match = (className || '').match(/language-(?.*)/); const _language = match ? match[1] : ''; if (_language) { (async () => { try { const module = await import(`react-syntax-highlighter/dist/cjs/languages/prism/${_language}`); SyntaxHighlighter.registerLanguage(_language, module.default); } catch (error) { console.error(`Error importing language module for ${_language}:`, error); } })(); } ; return inline ? ( {children} ) : ( ); })""", + ), + ( + "code", + { + "codeblock": lambda value, **props: ShikiHighLevelCodeBlock.create( + value, **props + ) + }, + """(({node, inline, className, children, ...props}) => {; return inline ? ( {children} ) : ( ); })""", + ), + ( + "h1", + { + "h1": lambda value: CustomMarkdownComponent.create( + Heading.create(value, as_="h1", size="6", margin_y="0.5em") + ) + }, + """(({custom_node, custom_children, custom_props}) => ({children}))""", + ), + ( + "code", + {"codeblock": syntax_highlighter_memoized_component(CodeBlock)}, + """(({node, inline, className, children, ...props}) => { const match = (className || '').match(/language-(?.*)/); const _language = match ? match[1] : ''; if (_language) { (async () => { try { const module = await import(`react-syntax-highlighter/dist/cjs/languages/prism/${_language}`); SyntaxHighlighter.registerLanguage(_language, module.default); } catch (error) { console.error(`Error importing language module for ${_language}:`, error); } })(); } ; return inline ? ( {children} ) : ( ); })""", + ), + ( + "code", + { + "codeblock": syntax_highlighter_memoized_component( + ShikiHighLevelCodeBlock + ) + }, + """(({node, inline, className, children, ...props}) => {; return inline ? ( {children} ) : ( ); })""", + ), + ], +) +def test_markdown_format_component(key, component_map, expected): + markdown = Markdown.create("# header", component_map=component_map) + result = markdown.format_component_map() + assert str(result[key]) == expected