diff --git a/reflex/components/datadisplay/code.py b/reflex/components/datadisplay/code.py index 124e93ce2..a8658127f 100644 --- a/reflex/components/datadisplay/code.py +++ b/reflex/components/datadisplay/code.py @@ -8,14 +8,17 @@ from typing import ClassVar, Dict, Literal, Optional, Union from reflex.components.component import Component, ComponentNamespace from reflex.components.core.cond import color_mode_cond from reflex.components.lucide.icon import Icon +from reflex.components.markdown.markdown import ( + _LANGUAGE, + MarkdownComponentMap, +) from reflex.components.radix.themes.components.button import Button from reflex.components.radix.themes.layout.box import Box -from reflex.components.markdown.markdown import MarkdownComponentMapMixin, _PROPS, _CHILDREN from reflex.constants.colors import Color from reflex.event import set_clipboard from reflex.style import Style from reflex.utils import console, format -from reflex.utils.imports import ImportDict, ImportVar +from reflex.utils.imports import ImportVar from reflex.vars.base import LiteralVar, Var, VarData LiteralCodeLanguage = Literal[ @@ -379,9 +382,7 @@ for theme_name in dir(Theme): setattr(Theme, theme_name, getattr(Theme, theme_name)._replace(_var_type=Theme)) -LANGUAGE_VAR = Var(_js_expr="__language") - -class CodeBlock(Component, MarkdownComponentMapMixin): +class CodeBlock(Component, MarkdownComponentMap): """A code block.""" library = "react-syntax-highlighter@15.6.1" @@ -520,9 +521,8 @@ class CodeBlock(Component, MarkdownComponentMapMixin): theme = self.theme - out.add_props(style=theme).remove_props("theme", "code", "language").add_props( - children=self.code, language=LANGUAGE_VAR + children=self.code, language=_LANGUAGE ) return out @@ -530,6 +530,22 @@ class CodeBlock(Component, MarkdownComponentMapMixin): def _exclude_props(self) -> list[str]: return ["can_copy", "copy_button"] + @classmethod + def _get_language_registration_hook(cls) -> str: + """Get the hook to register the language.""" + return f""" + if ({str(_LANGUAGE)}) {{ + (async () => {{ + try {{ + const module = await import(`react-syntax-highlighter/dist/cjs/languages/prism/${{{str(_LANGUAGE)}}}`); + SyntaxHighlighter.registerLanguage({str(_LANGUAGE)}, module.default); + }} catch (error) {{ + console.error(`Error importing language module for ${{{str(_LANGUAGE)}}}:`, error); + }} + }})(); + }} +""" + @classmethod def get_component_map_custom_code(cls) -> str: """Get the custom code for the component. @@ -537,21 +553,11 @@ class CodeBlock(Component, MarkdownComponentMapMixin): Returns: The custom code for the component. """ - return """ + return f""" const match = (className || '').match(/language-(?.*)/); - const language = match ? match[1] : ''; - if (language) { - (async () => { - try { - const module = await import(`react-syntax-highlighter/dist/cjs/languages/prism/${{language}}`); - SyntaxHighlighter.registerLanguage(language, module.default); - } catch (error) { - console.error(`Error importing language module for ${language}:`, error); - } - })(); - } - """ - +const {str(_LANGUAGE)} = match ? match[1] : ''; +{cls._get_language_registration_hook()} +""" def add_hooks(self) -> list[str | Var]: """Add hooks for the component. @@ -560,19 +566,8 @@ const match = (className || '').match(/language-(?.*)/); The hooks for the component. """ return [ - f"const {str(LANGUAGE_VAR)} = {str(self.language)}", - f""" - if ({str(LANGUAGE_VAR)}) {{ - (async () => {{ - try {{ - const module = await import(`react-syntax-highlighter/dist/cjs/languages/prism/${{{str(LANGUAGE_VAR)}}}`); - SyntaxHighlighter.registerLanguage({str(LANGUAGE_VAR)}, module.default); - }} catch (error) {{ - console.error(`Error importing language module for ${{{str(LANGUAGE_VAR)}}}:`, error); - }} - }})(); - }} - """ + f"const {str(_LANGUAGE)} = {str(self.language)}", + self._get_language_registration_hook(), ] diff --git a/reflex/components/datadisplay/code.pyi b/reflex/components/datadisplay/code.pyi index eadcb524f..8cd37b3e3 100644 --- a/reflex/components/datadisplay/code.pyi +++ b/reflex/components/datadisplay/code.pyi @@ -7,10 +7,10 @@ import dataclasses from typing import Any, ClassVar, Dict, Literal, Optional, Union, overload from reflex.components.component import Component, ComponentNamespace +from reflex.components.markdown.markdown import MarkdownComponentMap from reflex.constants.colors import Color from reflex.event import BASE_STATE, EventType from reflex.style import Style -from reflex.utils.imports import ImportDict from reflex.vars.base import Var LiteralCodeLanguage = Literal[ @@ -349,8 +349,7 @@ for theme_name in dir(Theme): continue setattr(Theme, theme_name, getattr(Theme, theme_name)._replace(_var_type=Theme)) -class CodeBlock(Component): - def add_imports(self) -> ImportDict: ... +class CodeBlock(Component, MarkdownComponentMap): @overload @classmethod def create( # type: ignore @@ -984,6 +983,9 @@ class CodeBlock(Component): ... def add_style(self): ... + @classmethod + def get_component_map_custom_code(cls) -> str: ... + def add_hooks(self) -> list[str | Var]: ... class CodeblockNamespace(ComponentNamespace): themes = Theme diff --git a/reflex/components/markdown/markdown.py b/reflex/components/markdown/markdown.py index b50bbe9bc..73bee1de7 100644 --- a/reflex/components/markdown/markdown.py +++ b/reflex/components/markdown/markdown.py @@ -28,6 +28,7 @@ _CHILDREN = Var(_js_expr="children", _var_type=str) _PROPS = Var(_js_expr="...props") _PROPS_IN_TAG = Var(_js_expr="{...props}") _MOCK_ARG = Var(_js_expr="", _var_type=str) +_LANGUAGE = Var(_js_expr="_language", _var_type=str) # Special remark plugins. _REMARK_MATH = Var(_js_expr="remarkMath") @@ -74,10 +75,55 @@ def get_base_component_map() -> dict[str, Callable]: } -class MarkdownComponentMapMixin: - def get_component_map_custom_code(self) -> str: +class MarkdownComponentMap: + """Mixin class for handling custom component maps in Markdown components.""" + + @classmethod + def get_component_map_custom_code(cls) -> str: + """Get the custom code for the component map. + + Returns: + The custom code for the component map. + """ return "" + @classmethod + def create_map_fn_var( + cls, fn_body: str | None = None, fn_args: list | None = None + ) -> Var: + """Create a function Var for the component map. + + Args: + fn_body: The formatted component as a string. + fn_args: The function arguments. + + Returns: + The function Var for the component map. + """ + fn_args = fn_args or cls.get_fn_args() + fn_body = fn_body or cls.get_fn_body() + fn_args_str = ", ".join(fn_args) + + return Var(_js_expr=f"(({{{fn_args_str}}}) => {fn_body})") + + @classmethod + def get_fn_args(cls) -> list[str]: + """Get the function arguments for the component map. + + Returns: + The function arguments as a list of strings. + """ + return ["node", _CHILDREN._js_expr, _PROPS._js_expr] + + @classmethod + def get_fn_body(cls) -> str: + """Get the function body for the component map. + + Returns: + The function body as a string. + """ + return "()" + class Markdown(Component): """A markdown component.""" @@ -132,7 +178,7 @@ class Markdown(Component): ) def _get_all_custom_components( - self, seen: set[str] | None = None + self, seen: set[str] | None = None ) -> set[CustomComponent]: """Get all the custom components used by the component. @@ -158,9 +204,6 @@ class Markdown(Component): Returns: The imports for the markdown component. """ - from reflex.components.datadisplay.code import CodeBlock, Theme - from reflex.components.radix.themes.typography.code import Code - return [ { "": "katex/dist/katex.min.css", @@ -184,10 +227,67 @@ class Markdown(Component): component(_MOCK_ARG)._get_all_imports() # type: ignore for component in self.component_map.values() ], - # CodeBlock.create(theme=Theme.light)._get_imports(), - # Code.create()._get_imports(), ] + def _get_tag_map_fn_var(self, tag: str) -> Var: + return self._get_map_fn_var_from_children(self.get_component(tag), tag) + + def format_component_map(self) -> dict[str, Var]: + """Format the component map for rendering. + + Returns: + The formatted component map. + """ + components = { + tag: self._get_tag_map_fn_var(tag) + for tag in self.component_map + if tag not in ("code", "codeblock") + } + + # Separate out inline code and code blocks. + components["code"] = self._get_inline_code_fn_var() + + return components + + def _get_inline_code_fn_var(self) -> Var: + """Get the function variable for inline code. + + This function creates a Var that represents a function to handle + both inline code and code blocks in markdown. + + Returns: + The Var for inline code. + """ + # Get any custom code from the codeblock and code components. + custom_code_list = self._get_map_fn_custom_code_from_children( + self.get_component("codeblock") + ) + custom_code_list.extend( + self._get_map_fn_custom_code_from_children(self.get_component("code")) + ) + + codeblock_custom_code = "\n".join(custom_code_list) + + # Format the code to handle inline and block code. + formatted_code = f"""{{{codeblock_custom_code}; + return inline ? ( + {self.format_component("code")} + ) : ( + {self.format_component("codeblock", language=_LANGUAGE)} + ); + }}""".replace("\n", " ") + + return MarkdownComponentMap.create_map_fn_var( + fn_args=[ + "node", + "inline", + "className", + _CHILDREN._js_expr, + _PROPS._js_expr, + ], + fn_body=formatted_code, + ) + def get_component(self, tag: str, **props) -> Component: """Get the component for a tag and props. @@ -244,36 +344,23 @@ class Markdown(Component): """ return str(self.get_component(tag, **props)).replace("\n", "") - def format_component_map(self) -> dict[str, Var]: - """Format the component map for rendering. + def _get_map_fn_var_from_children(self, component: Component, tag: str) -> Var: + """Create a function Var for the component map for the specified tag. + + Args: + component: The component to check for custom code. + tag: The tag of the component. Returns: - The formatted component map. + The function Var for the component map. """ - components = { - tag: Var( - _js_expr=f"(({{node, {_CHILDREN._js_expr}, {_PROPS._js_expr}}}) => ({self.format_component(tag)}))" - ) - for tag in self.component_map - } - codeblock_component = self.get_component("codeblock") - custom_code_list = self._get_custom_code_from_children(codeblock_component) - codeblock_custom_code = "\n".join(custom_code_list) - # Separate out inline code and code blocks. - components["code"] = Var( - _js_expr=f"""(({{node, inline, className, {_CHILDREN._js_expr}, {_PROPS._js_expr}}}) => {{ - {codeblock_custom_code}; - return inline ? ( - {self.format_component("code")} - ) : ( - {self.format_component("codeblock", language=Var(_js_expr="language", _var_type=str))} - ); - }})""".replace("\n", " ") - ) + if isinstance(component, MarkdownComponentMap): + return component.create_map_fn_var(f"({self.format_component(tag)})") - return components + # fallback to the default fn Var creation if the component is not a MarkdownComponentMap. + return MarkdownComponentMap.create_map_fn_var(f"({self.format_component(tag)})") - def _get_custom_code_from_children(self, component) -> list[str]: + def _get_map_fn_custom_code_from_children(self, component) -> list[str]: """Recursively get markdown custom code from children components. Args: @@ -283,16 +370,25 @@ class Markdown(Component): A list of markdown custom code strings. """ custom_code_list = [] - if hasattr(component, "get_component_map_custom_code"): + if isinstance(component, MarkdownComponentMap): custom_code_list.append(component.get_component_map_custom_code()) + # If the component is a custom component(rx.memo), obtain the underlining + # component and get the custom code from the children. if isinstance(component, CustomComponent): - custom_code_list.extend(self._get_custom_code_from_children(component.component_fn(*component.get_prop_vars()))) + custom_code_list.extend( + self._get_map_fn_custom_code_from_children( + component.component_fn(*component.get_prop_vars()) + ) + ) else: for child in component.children: - custom_code_list.extend(self._get_custom_code_from_children(child)) + custom_code_list.extend( + self._get_map_fn_custom_code_from_children(child) + ) return custom_code_list + @staticmethod def _component_map_hash(component_map) -> str: inp = str( diff --git a/reflex/components/markdown/markdown.pyi b/reflex/components/markdown/markdown.pyi index 25d6d4c00..e61d04cd4 100644 --- a/reflex/components/markdown/markdown.pyi +++ b/reflex/components/markdown/markdown.pyi @@ -16,6 +16,7 @@ _CHILDREN = Var(_js_expr="children", _var_type=str) _PROPS = Var(_js_expr="...props") _PROPS_IN_TAG = Var(_js_expr="{...props}") _MOCK_ARG = Var(_js_expr="", _var_type=str) +_LANGUAGE = Var(_js_expr="_language", _var_type=str) _REMARK_MATH = Var(_js_expr="remarkMath") _REMARK_GFM = Var(_js_expr="remarkGfm") _REMARK_UNWRAP_IMAGES = Var(_js_expr="remarkUnwrapImages") @@ -28,6 +29,18 @@ NO_PROPS_TAGS = ("ul", "ol", "li") @lru_cache def get_base_component_map() -> dict[str, Callable]: ... +class MarkdownComponentMap: + @classmethod + def get_component_map_custom_code(cls) -> str: ... + @classmethod + def create_map_fn_var( + cls, fn_body: str | None = None, fn_args: list | None = None + ) -> Var: ... + @classmethod + def get_fn_args(cls) -> list[str]: ... + @classmethod + def get_fn_body(cls) -> str: ... + class Markdown(Component): @overload @classmethod @@ -82,6 +95,6 @@ class Markdown(Component): ... def add_imports(self) -> ImportDict | list[ImportDict]: ... + def format_component_map(self) -> dict[str, Var]: ... def get_component(self, tag: str, **props) -> Component: ... def format_component(self, tag: str, **props) -> str: ... - def format_component_map(self) -> dict[str, Var]: ...