diff --git a/reflex/.templates/jinja/web/pages/_app.js.jinja2 b/reflex/.templates/jinja/web/pages/_app.js.jinja2 index 654f7a2a4..97c31925d 100644 --- a/reflex/.templates/jinja/web/pages/_app.js.jinja2 +++ b/reflex/.templates/jinja/web/pages/_app.js.jinja2 @@ -7,6 +7,10 @@ import '/styles/styles.css' {% block declaration %} import { EventLoopProvider, StateProvider, defaultColorMode } from "/utils/context.js"; import { ThemeProvider } from 'next-themes' +import * as React from "react"; +import * as utils_context from "/utils/context.js"; +import * as utils_state from "/utils/state.js"; +import * as radix from "@radix-ui/themes"; {% for custom_code in custom_codes %} {{custom_code}} @@ -26,6 +30,16 @@ function AppWrap({children}) { } export default function MyApp({ Component, pageProps }) { + React.useEffect(() => { + // Make contexts and state objects available globally for dynamic eval'd components + let windowImports = { + "react": React, + "@radix-ui/themes": radix, + "/utils/context": utils_context, + "/utils/state": utils_state, + }; + window["__reflex"] = windowImports; + }, []); return ( diff --git a/reflex/.templates/web/utils/state.js b/reflex/.templates/web/utils/state.js index 26b2d0d0c..66b50b1b4 100644 --- a/reflex/.templates/web/utils/state.js +++ b/reflex/.templates/web/utils/state.js @@ -15,6 +15,7 @@ import { } from "utils/context.js"; import debounce from "/utils/helpers/debounce"; import throttle from "/utils/helpers/throttle"; +import * as Babel from "@babel/standalone"; // Endpoint URLs. const EVENTURL = env.EVENT; @@ -117,8 +118,8 @@ export const isStateful = () => { if (event_queue.length === 0) { return false; } - return event_queue.some(event => event.name.startsWith("reflex___state")); -} + return event_queue.some((event) => event.name.startsWith("reflex___state")); +}; /** * Apply a delta to the state. @@ -129,6 +130,22 @@ export const applyDelta = (state, delta) => { return { ...state, ...delta }; }; +/** + * Evaluate a dynamic component. + * @param component The component to evaluate. + * @returns The evaluated component. + */ +export const evalReactComponent = async (component) => { + if (!window.React && window.__reflex) { + window.React = window.__reflex.react; + } + const output = Babel.transform(component, { presets: ["react"] }).code; + const encodedJs = encodeURIComponent(output); + const dataUri = "data:text/javascript;charset=utf-8," + encodedJs; + const module = await eval(`import(dataUri)`); + return module.default; +}; + /** * Only Queue and process events when websocket connection exists. * @param event The event to queue. @@ -141,7 +158,7 @@ export const queueEventIfSocketExists = async (events, socket) => { return; } await queueEvents(events, socket); -} +}; /** * Handle frontend event or send the event to the backend via Websocket. @@ -208,7 +225,10 @@ export const applyEvent = async (event, socket) => { const a = document.createElement("a"); a.hidden = true; // Special case when linking to uploaded files - a.href = event.payload.url.replace("${getBackendURL(env.UPLOAD)}", getBackendURL(env.UPLOAD)) + a.href = event.payload.url.replace( + "${getBackendURL(env.UPLOAD)}", + getBackendURL(env.UPLOAD) + ); a.download = event.payload.filename; a.click(); a.remove(); @@ -249,7 +269,7 @@ export const applyEvent = async (event, socket) => { } catch (e) { console.log("_call_script", e); if (window && window?.onerror) { - window.onerror(e.message, null, null, null, e) + window.onerror(e.message, null, null, null, e); } } return false; @@ -290,10 +310,9 @@ export const applyEvent = async (event, socket) => { export const applyRestEvent = async (event, socket) => { let eventSent = false; if (event.handler === "uploadFiles") { - if (event.payload.files === undefined || event.payload.files.length === 0) { // Submit the event over the websocket to trigger the event handler. - return await applyEvent(Event(event.name), socket) + return await applyEvent(Event(event.name), socket); } // Start upload, but do not wait for it, which would block other events. @@ -397,7 +416,7 @@ export const connect = async ( console.log("Disconnect backend before bfcache on navigation"); socket.current.disconnect(); } - } + }; // Once the socket is open, hydrate the page. socket.current.on("connect", () => { @@ -416,7 +435,7 @@ export const connect = async ( }); // On each received message, queue the updates and events. - socket.current.on("event", (message) => { + socket.current.on("event", async (message) => { const update = JSON5.parse(message); for (const substate in update.delta) { dispatch[substate](update.delta[substate]); @@ -574,7 +593,11 @@ export const hydrateClientStorage = (client_storage) => { } } } - if (client_storage.cookies || client_storage.local_storage || client_storage.session_storage) { + if ( + client_storage.cookies || + client_storage.local_storage || + client_storage.session_storage + ) { return client_storage_values; } return {}; @@ -614,15 +637,17 @@ const applyClientStorageDelta = (client_storage, delta) => { ) { const options = client_storage.local_storage[state_key]; localStorage.setItem(options.name || state_key, delta[substate][key]); - } else if( + } else if ( client_storage.session_storage && state_key in client_storage.session_storage && typeof window !== "undefined" ) { const session_options = client_storage.session_storage[state_key]; - sessionStorage.setItem(session_options.name || state_key, delta[substate][key]); + sessionStorage.setItem( + session_options.name || state_key, + delta[substate][key] + ); } - } } }; @@ -651,7 +676,7 @@ export const useEventLoop = ( if (!(args instanceof Array)) { args = [args]; } - const _e = args.filter((o) => o?.preventDefault !== undefined)[0] + const _e = args.filter((o) => o?.preventDefault !== undefined)[0]; if (event_actions?.preventDefault && _e?.preventDefault) { _e.preventDefault(); @@ -671,7 +696,7 @@ export const useEventLoop = ( debounce( combined_name, () => queueEvents(events, socket), - event_actions.debounce, + event_actions.debounce ); } else { queueEvents(events, socket); @@ -696,30 +721,32 @@ export const useEventLoop = ( } }, [router.isReady]); - // Handle frontend errors and send them to the backend via websocket. - useEffect(() => { - - if (typeof window === 'undefined') { - return; - } - - window.onerror = function (msg, url, lineNo, columnNo, error) { - addEvents([Event(`${exception_state_name}.handle_frontend_exception`, { - stack: error.stack, - })]) - return false; - } + // Handle frontend errors and send them to the backend via websocket. + useEffect(() => { + if (typeof window === "undefined") { + return; + } - //NOTE: Only works in Chrome v49+ - //https://github.com/mknichel/javascript-errors?tab=readme-ov-file#promise-rejection-events - window.onunhandledrejection = function (event) { - addEvents([Event(`${exception_state_name}.handle_frontend_exception`, { - stack: event.reason.stack, - })]) - return false; - } - - },[]) + window.onerror = function (msg, url, lineNo, columnNo, error) { + addEvents([ + Event(`${exception_state_name}.handle_frontend_exception`, { + stack: error.stack, + }), + ]); + return false; + }; + + //NOTE: Only works in Chrome v49+ + //https://github.com/mknichel/javascript-errors?tab=readme-ov-file#promise-rejection-events + window.onunhandledrejection = function (event) { + addEvents([ + Event(`${exception_state_name}.handle_frontend_exception`, { + stack: event.reason.stack, + }), + ]); + return false; + }; + }, []); // Main event loop. useEffect(() => { @@ -782,11 +809,11 @@ export const useEventLoop = ( // Route after the initial page hydration. useEffect(() => { const change_start = () => { - const main_state_dispatch = dispatch["reflex___state____state"] + const main_state_dispatch = dispatch["reflex___state____state"]; if (main_state_dispatch !== undefined) { - main_state_dispatch({ is_hydrated: false }) + main_state_dispatch({ is_hydrated: false }); } - } + }; const change_complete = () => addEvents(onLoadInternalEvent()); router.events.on("routeChangeStart", change_start); router.events.on("routeChangeComplete", change_complete); diff --git a/reflex/base.py b/reflex/base.py index 22b52cfb9..c334ddf56 100644 --- a/reflex/base.py +++ b/reflex/base.py @@ -47,6 +47,9 @@ def validate_field_name(bases: List[Type["BaseModel"]], field_name: str) -> None # shadowed state vars when reloading app via utils.prerequisites.get_app(reload=True) pydantic_main.validate_field_name = validate_field_name # type: ignore +if TYPE_CHECKING: + from reflex.vars import Var + class Base(BaseModel): # pyright: ignore [reportUnboundVariable] """The base class subclassed by all Reflex classes. @@ -92,7 +95,7 @@ class Base(BaseModel): # pyright: ignore [reportUnboundVariable] return self @classmethod - def get_fields(cls) -> dict[str, Any]: + def get_fields(cls) -> dict[str, ModelField]: """Get the fields of the object. Returns: @@ -101,7 +104,7 @@ class Base(BaseModel): # pyright: ignore [reportUnboundVariable] return cls.__fields__ @classmethod - def add_field(cls, var: Any, default_value: Any): + def add_field(cls, var: Var, default_value: Any): """Add a pydantic field after class definition. Used by State.add_var() to correctly handle the new variable. @@ -110,7 +113,7 @@ class Base(BaseModel): # pyright: ignore [reportUnboundVariable] var: The variable to add a pydantic field for. default_value: The default value of the field """ - var_name = var._js_expr.split(".")[-1] + var_name = var._var_field_name new_field = ModelField.infer( name=var_name, value=default_value, @@ -133,13 +136,4 @@ class Base(BaseModel): # pyright: ignore [reportUnboundVariable] # Seems like this function signature was wrong all along? # If the user wants a field that we know of, get it and pass it off to _get_value key = getattr(self, key) - return self._get_value( - key, - to_dict=True, - by_alias=False, - include=None, - exclude=None, - exclude_unset=False, - exclude_defaults=False, - exclude_none=False, - ) + return key diff --git a/reflex/components/component.py b/reflex/components/component.py index d3ae05c89..e6bdfef06 100644 --- a/reflex/components/component.py +++ b/reflex/components/component.py @@ -25,6 +25,7 @@ import reflex.state from reflex.base import Base from reflex.compiler.templates import STATEFUL_COMPONENT from reflex.components.core.breakpoints import Breakpoints +from reflex.components.dynamic import load_dynamic_serializer from reflex.components.tags import Tag from reflex.constants import ( Dirs, @@ -52,7 +53,6 @@ from reflex.utils.imports import ( ParsedImportDict, parse_imports, ) -from reflex.utils.serializers import serializer from reflex.vars import VarData from reflex.vars.base import LiteralVar, Var @@ -615,8 +615,8 @@ class Component(BaseComponent, ABC): if types._issubclass(field.type_, EventHandler): args_spec = None annotation = field.annotation - if hasattr(annotation, "__metadata__"): - args_spec = annotation.__metadata__[0] + if (metadata := getattr(annotation, "__metadata__", None)) is not None: + args_spec = metadata[0] default_triggers[field.name] = args_spec or (lambda: []) return default_triggers @@ -1882,19 +1882,6 @@ class NoSSRComponent(Component): return "".join((library_import, mod_import, opts_fragment)) -@serializer -def serialize_component(comp: Component): - """Serialize a component. - - Args: - comp: The component to serialize. - - Returns: - The serialized component. - """ - return str(comp) - - class StatefulComponent(BaseComponent): """A component that depends on state and is rendered outside of the page component. @@ -2307,3 +2294,6 @@ class MemoizationLeaf(Component): update={"disposition": MemoizationDisposition.ALWAYS} ) return comp + + +load_dynamic_serializer() diff --git a/reflex/components/datadisplay/code.py b/reflex/components/datadisplay/code.py index d9ab46e53..0b26e0c04 100644 --- a/reflex/components/datadisplay/code.py +++ b/reflex/components/datadisplay/code.py @@ -2,11 +2,12 @@ from __future__ import annotations +import enum from typing import Any, Dict, Literal, Optional, Union from typing_extensions import get_args -from reflex.components.component import Component +from reflex.components.component import Component, ComponentNamespace from reflex.components.core.cond import color_mode_cond from reflex.components.lucide.icon import Icon from reflex.components.radix.themes.components.button import Button @@ -14,9 +15,9 @@ from reflex.components.radix.themes.layout.box import Box from reflex.constants.colors import Color from reflex.event import set_clipboard from reflex.style import Style -from reflex.utils import format +from reflex.utils import console, format from reflex.utils.imports import ImportDict, ImportVar -from reflex.vars.base import LiteralVar, Var +from reflex.vars.base import LiteralVar, Var, VarData LiteralCodeBlockTheme = Literal[ "a11y-dark", @@ -405,31 +406,6 @@ class CodeBlock(Component): """ imports_: ImportDict = {} - themeString = str(self.theme) - - selected_themes = [] - - for possibleTheme in get_args(LiteralCodeBlockTheme): - if format.to_camel_case(possibleTheme) in themeString: - selected_themes.append(possibleTheme) - if possibleTheme in themeString: - selected_themes.append(possibleTheme) - - selected_themes = sorted(set(map(self.convert_theme_name, selected_themes))) - - imports_.update( - { - f"react-syntax-highlighter/dist/cjs/styles/prism/{theme}": [ - ImportVar( - tag=format.to_camel_case(theme), - is_default=True, - install=False, - ) - ] - for theme in selected_themes - } - ) - if ( self.language is not None and (language_without_quotes := str(self.language).replace('"', "")) @@ -480,14 +456,20 @@ class CodeBlock(Component): if "theme" not in props: # Default color scheme responds to global color mode. props["theme"] = color_mode_cond( - light=Var(_js_expr="oneLight"), - dark=Var(_js_expr="oneDark"), + light=Theme.one_light, + dark=Theme.one_dark, ) # react-syntax-highlighter doesnt have an explicit "light" or "dark" theme so we use one-light and one-dark # themes respectively to ensure code compatibility. if "theme" in props and not isinstance(props["theme"], Var): - props["theme"] = cls.convert_theme_name(props["theme"]) + props["theme"] = getattr(Theme, format.to_snake_case(props["theme"])) # type: ignore + console.deprecate( + feature_name="theme prop as string", + reason="Use code_block.themes instead.", + deprecation_version="0.6.0", + removal_version="0.7.0", + ) if can_copy: code = children[0] @@ -533,9 +515,7 @@ class CodeBlock(Component): def _render(self): out = super()._render() - theme = self.theme._replace( - _js_expr=replace_quotes_with_camel_case(str(self.theme)) - ) + theme = self.theme out.add_props(style=theme).remove_props("theme", "code").add_props( children=self.code @@ -558,4 +538,83 @@ class CodeBlock(Component): return theme -code_block = CodeBlock.create +def construct_theme_var(theme: str) -> Var: + """Construct a theme var. + + Args: + theme: The theme to construct. + + Returns: + The constructed theme var. + """ + return Var( + theme, + _var_data=VarData( + imports={ + f"react-syntax-highlighter/dist/cjs/styles/prism/{format.to_kebab_case(theme)}": [ + ImportVar(tag=theme, is_default=True, install=False) + ] + } + ), + ) + + +class Theme(enum.Enum): + """Themes for the CodeBlock component.""" + + a11y_dark = construct_theme_var("a11yDark") + atom_dark = construct_theme_var("atomDark") + cb = construct_theme_var("cb") + coldark_cold = construct_theme_var("coldarkCold") + coldark_dark = construct_theme_var("coldarkDark") + coy = construct_theme_var("coy") + coy_without_shadows = construct_theme_var("coyWithoutShadows") + darcula = construct_theme_var("darcula") + dark = construct_theme_var("oneDark") + dracula = construct_theme_var("dracula") + duotone_dark = construct_theme_var("duotoneDark") + duotone_earth = construct_theme_var("duotoneEarth") + duotone_forest = construct_theme_var("duotoneForest") + duotone_light = construct_theme_var("duotoneLight") + duotone_sea = construct_theme_var("duotoneSea") + duotone_space = construct_theme_var("duotoneSpace") + funky = construct_theme_var("funky") + ghcolors = construct_theme_var("ghcolors") + gruvbox_dark = construct_theme_var("gruvboxDark") + gruvbox_light = construct_theme_var("gruvboxLight") + holi_theme = construct_theme_var("holiTheme") + hopscotch = construct_theme_var("hopscotch") + light = construct_theme_var("oneLight") + lucario = construct_theme_var("lucario") + material_dark = construct_theme_var("materialDark") + material_light = construct_theme_var("materialLight") + material_oceanic = construct_theme_var("materialOceanic") + night_owl = construct_theme_var("nightOwl") + nord = construct_theme_var("nord") + okaidia = construct_theme_var("okaidia") + one_dark = construct_theme_var("oneDark") + one_light = construct_theme_var("oneLight") + pojoaque = construct_theme_var("pojoaque") + prism = construct_theme_var("prism") + shades_of_purple = construct_theme_var("shadesOfPurple") + solarized_dark_atom = construct_theme_var("solarizedDarkAtom") + solarizedlight = construct_theme_var("solarizedlight") + synthwave84 = construct_theme_var("synthwave84") + tomorrow = construct_theme_var("tomorrow") + twilight = construct_theme_var("twilight") + vs = construct_theme_var("vs") + vs_dark = construct_theme_var("vsDark") + vsc_dark_plus = construct_theme_var("vscDarkPlus") + xonokai = construct_theme_var("xonokai") + z_touch = construct_theme_var("zTouch") + + +class CodeblockNamespace(ComponentNamespace): + """Namespace for the CodeBlock component.""" + + themes = Theme + + __call__ = CodeBlock.create + + +code_block = CodeblockNamespace() diff --git a/reflex/components/datadisplay/code.pyi b/reflex/components/datadisplay/code.pyi index 49a7bbb51..7a074d1f9 100644 --- a/reflex/components/datadisplay/code.pyi +++ b/reflex/components/datadisplay/code.pyi @@ -3,9 +3,10 @@ # ------------------- DO NOT EDIT ---------------------- # This file was generated by `reflex/utils/pyi_generator.py`! # ------------------------------------------------------ +import enum from typing import Any, Callable, Dict, Literal, Optional, Union, overload -from reflex.components.component import Component +from reflex.components.component import Component, ComponentNamespace from reflex.constants.colors import Color from reflex.event import EventHandler, EventSpec from reflex.style import Style @@ -1001,4 +1002,706 @@ class CodeBlock(Component): @staticmethod def convert_theme_name(theme) -> str: ... -code_block = CodeBlock.create +def construct_theme_var(theme: str) -> Var: ... + +class Theme(enum.Enum): + a11y_dark = construct_theme_var("a11yDark") + atom_dark = construct_theme_var("atomDark") + cb = construct_theme_var("cb") + coldark_cold = construct_theme_var("coldarkCold") + coldark_dark = construct_theme_var("coldarkDark") + coy = construct_theme_var("coy") + coy_without_shadows = construct_theme_var("coyWithoutShadows") + darcula = construct_theme_var("darcula") + dark = construct_theme_var("oneDark") + dracula = construct_theme_var("dracula") + duotone_dark = construct_theme_var("duotoneDark") + duotone_earth = construct_theme_var("duotoneEarth") + duotone_forest = construct_theme_var("duotoneForest") + duotone_light = construct_theme_var("duotoneLight") + duotone_sea = construct_theme_var("duotoneSea") + duotone_space = construct_theme_var("duotoneSpace") + funky = construct_theme_var("funky") + ghcolors = construct_theme_var("ghcolors") + gruvbox_dark = construct_theme_var("gruvboxDark") + gruvbox_light = construct_theme_var("gruvboxLight") + holi_theme = construct_theme_var("holiTheme") + hopscotch = construct_theme_var("hopscotch") + light = construct_theme_var("oneLight") + lucario = construct_theme_var("lucario") + material_dark = construct_theme_var("materialDark") + material_light = construct_theme_var("materialLight") + material_oceanic = construct_theme_var("materialOceanic") + night_owl = construct_theme_var("nightOwl") + nord = construct_theme_var("nord") + okaidia = construct_theme_var("okaidia") + one_dark = construct_theme_var("oneDark") + one_light = construct_theme_var("oneLight") + pojoaque = construct_theme_var("pojoaque") + prism = construct_theme_var("prism") + shades_of_purple = construct_theme_var("shadesOfPurple") + solarized_dark_atom = construct_theme_var("solarizedDarkAtom") + solarizedlight = construct_theme_var("solarizedlight") + synthwave84 = construct_theme_var("synthwave84") + tomorrow = construct_theme_var("tomorrow") + twilight = construct_theme_var("twilight") + vs = construct_theme_var("vs") + vs_dark = construct_theme_var("vsDark") + vsc_dark_plus = construct_theme_var("vscDarkPlus") + xonokai = construct_theme_var("xonokai") + z_touch = construct_theme_var("zTouch") + +class CodeblockNamespace(ComponentNamespace): + themes = Theme + + @staticmethod + def __call__( + *children, + can_copy: Optional[bool] = False, + copy_button: Optional[Union[Component, bool]] = None, + theme: Optional[Union[Any, Var[Any]]] = None, + language: Optional[ + Union[ + Literal[ + "abap", + "abnf", + "actionscript", + "ada", + "agda", + "al", + "antlr4", + "apacheconf", + "apex", + "apl", + "applescript", + "aql", + "arduino", + "arff", + "asciidoc", + "asm6502", + "asmatmel", + "aspnet", + "autohotkey", + "autoit", + "avisynth", + "avro-idl", + "bash", + "basic", + "batch", + "bbcode", + "bicep", + "birb", + "bison", + "bnf", + "brainfuck", + "brightscript", + "bro", + "bsl", + "c", + "cfscript", + "chaiscript", + "cil", + "clike", + "clojure", + "cmake", + "cobol", + "coffeescript", + "concurnas", + "coq", + "core", + "cpp", + "crystal", + "csharp", + "cshtml", + "csp", + "css", + "css-extras", + "csv", + "cypher", + "d", + "dart", + "dataweave", + "dax", + "dhall", + "diff", + "django", + "dns-zone-file", + "docker", + "dot", + "ebnf", + "editorconfig", + "eiffel", + "ejs", + "elixir", + "elm", + "erb", + "erlang", + "etlua", + "excel-formula", + "factor", + "false", + "firestore-security-rules", + "flow", + "fortran", + "fsharp", + "ftl", + "gap", + "gcode", + "gdscript", + "gedcom", + "gherkin", + "git", + "glsl", + "gml", + "gn", + "go", + "go-module", + "graphql", + "groovy", + "haml", + "handlebars", + "haskell", + "haxe", + "hcl", + "hlsl", + "hoon", + "hpkp", + "hsts", + "http", + "ichigojam", + "icon", + "icu-message-format", + "idris", + "iecst", + "ignore", + "index", + "inform7", + "ini", + "io", + "j", + "java", + "javadoc", + "javadoclike", + "javascript", + "javastacktrace", + "jexl", + "jolie", + "jq", + "js-extras", + "js-templates", + "jsdoc", + "json", + "json5", + "jsonp", + "jsstacktrace", + "jsx", + "julia", + "keepalived", + "keyman", + "kotlin", + "kumir", + "kusto", + "latex", + "latte", + "less", + "lilypond", + "liquid", + "lisp", + "livescript", + "llvm", + "log", + "lolcode", + "lua", + "magma", + "makefile", + "markdown", + "markup", + "markup-templating", + "matlab", + "maxscript", + "mel", + "mermaid", + "mizar", + "mongodb", + "monkey", + "moonscript", + "n1ql", + "n4js", + "nand2tetris-hdl", + "naniscript", + "nasm", + "neon", + "nevod", + "nginx", + "nim", + "nix", + "nsis", + "objectivec", + "ocaml", + "opencl", + "openqasm", + "oz", + "parigp", + "parser", + "pascal", + "pascaligo", + "pcaxis", + "peoplecode", + "perl", + "php", + "php-extras", + "phpdoc", + "plsql", + "powerquery", + "powershell", + "processing", + "prolog", + "promql", + "properties", + "protobuf", + "psl", + "pug", + "puppet", + "pure", + "purebasic", + "purescript", + "python", + "q", + "qml", + "qore", + "qsharp", + "r", + "racket", + "reason", + "regex", + "rego", + "renpy", + "rest", + "rip", + "roboconf", + "robotframework", + "ruby", + "rust", + "sas", + "sass", + "scala", + "scheme", + "scss", + "shell-session", + "smali", + "smalltalk", + "smarty", + "sml", + "solidity", + "solution-file", + "soy", + "sparql", + "splunk-spl", + "sqf", + "sql", + "squirrel", + "stan", + "stylus", + "swift", + "systemd", + "t4-cs", + "t4-templating", + "t4-vb", + "tap", + "tcl", + "textile", + "toml", + "tremor", + "tsx", + "tt2", + "turtle", + "twig", + "typescript", + "typoscript", + "unrealscript", + "uorazor", + "uri", + "v", + "vala", + "vbnet", + "velocity", + "verilog", + "vhdl", + "vim", + "visual-basic", + "warpscript", + "wasm", + "web-idl", + "wiki", + "wolfram", + "wren", + "xeora", + "xml-doc", + "xojo", + "xquery", + "yaml", + "yang", + "zig", + ], + Var[ + Literal[ + "abap", + "abnf", + "actionscript", + "ada", + "agda", + "al", + "antlr4", + "apacheconf", + "apex", + "apl", + "applescript", + "aql", + "arduino", + "arff", + "asciidoc", + "asm6502", + "asmatmel", + "aspnet", + "autohotkey", + "autoit", + "avisynth", + "avro-idl", + "bash", + "basic", + "batch", + "bbcode", + "bicep", + "birb", + "bison", + "bnf", + "brainfuck", + "brightscript", + "bro", + "bsl", + "c", + "cfscript", + "chaiscript", + "cil", + "clike", + "clojure", + "cmake", + "cobol", + "coffeescript", + "concurnas", + "coq", + "core", + "cpp", + "crystal", + "csharp", + "cshtml", + "csp", + "css", + "css-extras", + "csv", + "cypher", + "d", + "dart", + "dataweave", + "dax", + "dhall", + "diff", + "django", + "dns-zone-file", + "docker", + "dot", + "ebnf", + "editorconfig", + "eiffel", + "ejs", + "elixir", + "elm", + "erb", + "erlang", + "etlua", + "excel-formula", + "factor", + "false", + "firestore-security-rules", + "flow", + "fortran", + "fsharp", + "ftl", + "gap", + "gcode", + "gdscript", + "gedcom", + "gherkin", + "git", + "glsl", + "gml", + "gn", + "go", + "go-module", + "graphql", + "groovy", + "haml", + "handlebars", + "haskell", + "haxe", + "hcl", + "hlsl", + "hoon", + "hpkp", + "hsts", + "http", + "ichigojam", + "icon", + "icu-message-format", + "idris", + "iecst", + "ignore", + "index", + "inform7", + "ini", + "io", + "j", + "java", + "javadoc", + "javadoclike", + "javascript", + "javastacktrace", + "jexl", + "jolie", + "jq", + "js-extras", + "js-templates", + "jsdoc", + "json", + "json5", + "jsonp", + "jsstacktrace", + "jsx", + "julia", + "keepalived", + "keyman", + "kotlin", + "kumir", + "kusto", + "latex", + "latte", + "less", + "lilypond", + "liquid", + "lisp", + "livescript", + "llvm", + "log", + "lolcode", + "lua", + "magma", + "makefile", + "markdown", + "markup", + "markup-templating", + "matlab", + "maxscript", + "mel", + "mermaid", + "mizar", + "mongodb", + "monkey", + "moonscript", + "n1ql", + "n4js", + "nand2tetris-hdl", + "naniscript", + "nasm", + "neon", + "nevod", + "nginx", + "nim", + "nix", + "nsis", + "objectivec", + "ocaml", + "opencl", + "openqasm", + "oz", + "parigp", + "parser", + "pascal", + "pascaligo", + "pcaxis", + "peoplecode", + "perl", + "php", + "php-extras", + "phpdoc", + "plsql", + "powerquery", + "powershell", + "processing", + "prolog", + "promql", + "properties", + "protobuf", + "psl", + "pug", + "puppet", + "pure", + "purebasic", + "purescript", + "python", + "q", + "qml", + "qore", + "qsharp", + "r", + "racket", + "reason", + "regex", + "rego", + "renpy", + "rest", + "rip", + "roboconf", + "robotframework", + "ruby", + "rust", + "sas", + "sass", + "scala", + "scheme", + "scss", + "shell-session", + "smali", + "smalltalk", + "smarty", + "sml", + "solidity", + "solution-file", + "soy", + "sparql", + "splunk-spl", + "sqf", + "sql", + "squirrel", + "stan", + "stylus", + "swift", + "systemd", + "t4-cs", + "t4-templating", + "t4-vb", + "tap", + "tcl", + "textile", + "toml", + "tremor", + "tsx", + "tt2", + "turtle", + "twig", + "typescript", + "typoscript", + "unrealscript", + "uorazor", + "uri", + "v", + "vala", + "vbnet", + "velocity", + "verilog", + "vhdl", + "vim", + "visual-basic", + "warpscript", + "wasm", + "web-idl", + "wiki", + "wolfram", + "wren", + "xeora", + "xml-doc", + "xojo", + "xquery", + "yaml", + "yang", + "zig", + ] + ], + ] + ] = None, + code: Optional[Union[Var[str], str]] = None, + show_line_numbers: Optional[Union[Var[bool], bool]] = None, + starting_line_number: Optional[Union[Var[int], int]] = None, + wrap_long_lines: Optional[Union[Var[bool], bool]] = None, + custom_style: Optional[Dict[str, Union[str, Var, Color]]] = None, + code_tag_props: Optional[Union[Dict[str, str], Var[Dict[str, str]]]] = None, + style: Optional[Style] = None, + key: Optional[Any] = None, + id: Optional[Any] = None, + class_name: Optional[Any] = None, + autofocus: Optional[bool] = None, + custom_attrs: Optional[Dict[str, Union[Var, str]]] = None, + on_blur: Optional[Union[EventHandler, EventSpec, list, Callable, Var]] = None, + on_click: Optional[Union[EventHandler, EventSpec, list, Callable, Var]] = None, + on_context_menu: Optional[ + Union[EventHandler, EventSpec, list, Callable, Var] + ] = None, + on_double_click: Optional[ + Union[EventHandler, EventSpec, list, Callable, Var] + ] = None, + on_focus: Optional[Union[EventHandler, EventSpec, list, Callable, Var]] = None, + on_mount: Optional[Union[EventHandler, EventSpec, list, Callable, Var]] = None, + on_mouse_down: Optional[ + Union[EventHandler, EventSpec, list, Callable, Var] + ] = None, + on_mouse_enter: Optional[ + Union[EventHandler, EventSpec, list, Callable, Var] + ] = None, + on_mouse_leave: Optional[ + Union[EventHandler, EventSpec, list, Callable, Var] + ] = None, + on_mouse_move: Optional[ + Union[EventHandler, EventSpec, list, Callable, Var] + ] = None, + on_mouse_out: Optional[ + Union[EventHandler, EventSpec, list, Callable, Var] + ] = None, + on_mouse_over: Optional[ + Union[EventHandler, EventSpec, list, Callable, Var] + ] = None, + on_mouse_up: Optional[ + Union[EventHandler, EventSpec, list, Callable, Var] + ] = None, + on_scroll: Optional[Union[EventHandler, EventSpec, list, Callable, Var]] = None, + on_unmount: Optional[ + Union[EventHandler, EventSpec, list, Callable, Var] + ] = None, + **props, + ) -> "CodeBlock": + """Create a text component. + + Args: + *children: The children of the component. + can_copy: Whether a copy button should appears. + copy_button: A custom copy button to override the default one. + theme: The theme to use ("light" or "dark"). + language: The language to use. + code: The code to display. + show_line_numbers: If this is enabled line numbers will be shown next to the code block. + starting_line_number: The starting line number to use. + wrap_long_lines: Whether to wrap long lines. + custom_style: A custom style for the code block. + code_tag_props: Props passed down to the code tag. + style: The style of the component. + key: A unique key for the component. + id: The id for the component. + class_name: The class name for the component. + autofocus: Whether the component should take the focus once the page is loaded + custom_attrs: custom attribute + **props: The props to pass to the component. + + Returns: + The text component. + """ + ... + +code_block = CodeblockNamespace() diff --git a/reflex/components/dynamic.py b/reflex/components/dynamic.py new file mode 100644 index 000000000..6ae78161f --- /dev/null +++ b/reflex/components/dynamic.py @@ -0,0 +1,143 @@ +"""Components that are dynamically generated on the backend.""" + +from reflex import constants +from reflex.utils import imports +from reflex.utils.serializers import serializer +from reflex.vars import Var, get_unique_variable_name +from reflex.vars.base import VarData, transform + + +def get_cdn_url(lib: str) -> str: + """Get the CDN URL for a library. + + Args: + lib: The library to get the CDN URL for. + + Returns: + The CDN URL for the library. + """ + return f"https://cdn.jsdelivr.net/npm/{lib}" + "/+esm" + + +def load_dynamic_serializer(): + """Load the serializer for dynamic components.""" + # Causes a circular import, so we import here. + from reflex.components.component import Component + + @serializer + def make_component(component: Component) -> str: + """Generate the code for a dynamic component. + + Args: + component: The component to generate code for. + + Returns: + The generated code + """ + # Causes a circular import, so we import here. + from reflex.compiler import templates, utils + + rendered_components = {} + # Include dynamic imports in the shared component. + if dynamic_imports := component._get_all_dynamic_imports(): + rendered_components.update( + {dynamic_import: None for dynamic_import in dynamic_imports} + ) + + # Include custom code in the shared component. + rendered_components.update( + {code: None for code in component._get_all_custom_code()}, + ) + + rendered_components[ + templates.STATEFUL_COMPONENT.render( + tag_name="MySSRComponent", + memo_trigger_hooks=[], + component=component, + ) + ] = None + + imports = {} + for lib, names in component._get_all_imports().items(): + if ( + not lib.startswith((".", "/")) + and not lib.startswith("http") + and lib != "react" + ): + imports[get_cdn_url(lib)] = names + else: + imports[lib] = names + + module_code_lines = templates.STATEFUL_COMPONENTS.render( + imports=utils.compile_imports(imports), + memoized_code="\n".join(rendered_components), + ).splitlines()[1:] + + # Rewrite imports from `/` to destructure from window + for ix, line in enumerate(module_code_lines[:]): + if line.startswith("import "): + if 'from "/' in line: + module_code_lines[ix] = ( + line.replace("import ", "const ", 1).replace( + " from ", " = window['__reflex'][", 1 + ) + + "]" + ) + elif 'from "react"' in line: + module_code_lines[ix] = line.replace( + "import ", "const ", 1 + ).replace(' from "react"', " = window.__reflex.react", 1) + if line.startswith("export function"): + module_code_lines[ix] = line.replace( + "export function", "export default function", 1 + ) + + module_code_lines.insert(0, "const React = window.__reflex.react;") + + return "//__reflex_evaluate\n" + "\n".join(module_code_lines) + + @transform + def evaluate_component(js_string: Var[str]) -> Var[Component]: + """Evaluate a component. + + Args: + js_string: The JavaScript string to evaluate. + + Returns: + The evaluated JavaScript string. + """ + unique_var_name = get_unique_variable_name() + + return js_string._replace( + _js_expr=unique_var_name, + _var_type=Component, + merge_var_data=VarData.merge( + VarData( + imports={ + f"/{constants.Dirs.STATE_PATH}": [ + imports.ImportVar(tag="evalReactComponent"), + ], + "react": [ + imports.ImportVar(tag="useState"), + imports.ImportVar(tag="useEffect"), + ], + }, + hooks={ + f"const [{unique_var_name}, set_{unique_var_name}] = useState(null);": None, + "useEffect(() => {" + "let isMounted = true;" + f"evalReactComponent({str(js_string)})" + ".then((component) => {" + "if (isMounted) {" + f"set_{unique_var_name}(component);" + "}" + "});" + "return () => {" + "isMounted = false;" + "};" + "}" + f", [{str(js_string)}]);": None, + }, + ), + ), + ) diff --git a/reflex/constants/installer.py b/reflex/constants/installer.py index 4a7027ee8..6f31b08f1 100644 --- a/reflex/constants/installer.py +++ b/reflex/constants/installer.py @@ -111,6 +111,7 @@ class PackageJson(SimpleNamespace): PATH = "package.json" DEPENDENCIES = { + "@babel/standalone": "7.25.3", "@emotion/react": "11.11.1", "axios": "1.6.0", "json5": "2.2.3", diff --git a/reflex/state.py b/reflex/state.py index 6dac48d8f..8b32d1a07 100644 --- a/reflex/state.py +++ b/reflex/state.py @@ -40,6 +40,7 @@ from reflex.vars.base import ( DynamicRouteVar, Var, computed_var, + dispatch, is_computed_var, ) @@ -336,6 +337,29 @@ class EventHandlerSetVar(EventHandler): return super().__call__(*args) +if TYPE_CHECKING: + from pydantic.v1.fields import ModelField + + +def get_var_for_field(cls: Type[BaseState], f: ModelField): + """Get a Var instance for a Pydantic field. + + Args: + cls: The state class. + f: The Pydantic field. + + Returns: + The Var instance. + """ + field_name = format.format_state_name(cls.get_full_name()) + "." + f.name + + return dispatch( + field_name=field_name, + var_data=VarData.from_state(cls, f.name), + result_var_type=f.outer_type_, + ) + + class BaseState(Base, ABC, extra=pydantic.Extra.allow): """The state of the app.""" @@ -556,11 +580,7 @@ class BaseState(Base, ABC, extra=pydantic.Extra.allow): # Set the base and computed vars. cls.base_vars = { - f.name: Var( - _js_expr=format.format_state_name(cls.get_full_name()) + "." + f.name, - _var_type=f.outer_type_, - _var_data=VarData.from_state(cls), - ).guess_type() + f.name: get_var_for_field(cls, f) for f in cls.get_fields().values() if f.name not in cls.get_skip_vars() } @@ -948,7 +968,7 @@ class BaseState(Base, ABC, extra=pydantic.Extra.allow): var = Var( _js_expr=format.format_state_name(cls.get_full_name()) + "." + name, _var_type=type_, - _var_data=VarData.from_state(cls), + _var_data=VarData.from_state(cls, name), ).guess_type() # add the pydantic field dynamically (must be done before _init_var) @@ -974,10 +994,7 @@ class BaseState(Base, ABC, extra=pydantic.Extra.allow): Args: prop: The var instance to set. """ - acutal_var_name = ( - prop._js_expr if "." not in prop._js_expr else prop._js_expr.split(".")[-1] - ) - setattr(cls, acutal_var_name, prop) + setattr(cls, prop._var_field_name, prop) @classmethod def _create_event_handler(cls, fn): @@ -1017,10 +1034,7 @@ class BaseState(Base, ABC, extra=pydantic.Extra.allow): prop: The var to set the default value for. """ # Get the pydantic field for the var. - if "." in prop._js_expr: - field = cls.get_fields()[prop._js_expr.split(".")[-1]] - else: - field = cls.get_fields()[prop._js_expr] + field = cls.get_fields()[prop._var_field_name] if field.required: default_value = prop.get_default_value() if default_value is not None: @@ -1761,11 +1775,12 @@ class BaseState(Base, ABC, extra=pydantic.Extra.allow): .union(self._always_dirty_computed_vars) ) - subdelta = { - prop: getattr(self, prop) + subdelta: Dict[str, Any] = { + prop: self.get_value(getattr(self, prop)) for prop in delta_vars if not types.is_backend_base_variable(prop, type(self)) } + if len(subdelta) > 0: delta[self.get_full_name()] = subdelta diff --git a/reflex/utils/format.py b/reflex/utils/format.py index e8b040230..c804b2946 100644 --- a/reflex/utils/format.py +++ b/reflex/utils/format.py @@ -672,6 +672,8 @@ def format_library_name(library_fullname: str): Returns: The name without the @version if it was part of the name """ + if library_fullname.startswith("https://"): + return library_fullname lib, at, version = library_fullname.rpartition("@") if not lib: lib = at + version diff --git a/reflex/vars/base.py b/reflex/vars/base.py index 37498911f..9fac6fd1f 100644 --- a/reflex/vars/base.py +++ b/reflex/vars/base.py @@ -20,6 +20,7 @@ from typing import ( Any, Callable, Dict, + FrozenSet, Generic, Iterable, List, @@ -72,6 +73,7 @@ if TYPE_CHECKING: VAR_TYPE = TypeVar("VAR_TYPE", covariant=True) +OTHER_VAR_TYPE = TypeVar("OTHER_VAR_TYPE") warnings.filterwarnings("ignore", message="fields may not start with an underscore") @@ -119,6 +121,17 @@ class Var(Generic[VAR_TYPE]): """ return self._js_expr + @property + def _var_field_name(self) -> str: + """The name of the field. + + Returns: + The name of the field. + """ + var_data = self._get_all_var_data() + field_name = var_data.field_name if var_data else None + return field_name or self._js_expr + @property @deprecated("Use `_js_expr` instead.") def _var_name_unwrapped(self) -> str: @@ -181,7 +194,19 @@ class Var(Generic[VAR_TYPE]): and self._get_all_var_data() == other._get_all_var_data() ) - def _replace(self, merge_var_data=None, **kwargs: Any): + @overload + def _replace( + self, _var_type: Type[OTHER_VAR_TYPE], merge_var_data=None, **kwargs: Any + ) -> Var[OTHER_VAR_TYPE]: ... + + @overload + def _replace( + self, _var_type: GenericType | None = None, merge_var_data=None, **kwargs: Any + ) -> Self: ... + + def _replace( + self, _var_type: GenericType | None = None, merge_var_data=None, **kwargs: Any + ) -> Self | Var: """Make a copy of this Var with updated fields. Args: @@ -205,14 +230,20 @@ class Var(Generic[VAR_TYPE]): "The _var_full_name_needs_state_prefix argument is not supported for Var." ) - return dataclasses.replace( + value_with_replaced = dataclasses.replace( self, + _var_type=_var_type or self._var_type, _var_data=VarData.merge( kwargs.get("_var_data", self._var_data), merge_var_data ), **kwargs, ) + if (js_expr := kwargs.get("_js_expr", None)) is not None: + object.__setattr__(value_with_replaced, "_js_expr", js_expr) + + return value_with_replaced + @classmethod def create( cls, @@ -566,8 +597,7 @@ class Var(Generic[VAR_TYPE]): Returns: The name of the setter function. """ - var_name_parts = self._js_expr.split(".") - setter = constants.SETTER_PREFIX + var_name_parts[-1] + setter = constants.SETTER_PREFIX + self._var_field_name var_data = self._get_all_var_data() if var_data is None: return setter @@ -581,7 +611,7 @@ class Var(Generic[VAR_TYPE]): Returns: A function that that creates a setter for the var. """ - actual_name = self._js_expr.split(".")[-1] + actual_name = self._var_field_name def setter(state: BaseState, value: Any): """Get the setter for the var. @@ -623,7 +653,9 @@ class Var(Generic[VAR_TYPE]): return StateOperation.create( formatted_state_name, self, - _var_data=VarData.merge(VarData.from_state(state), self._var_data), + _var_data=VarData.merge( + VarData.from_state(state, self._js_expr), self._var_data + ), ).guess_type() def __eq__(self, other: Var | Any) -> BooleanVar: @@ -1706,12 +1738,18 @@ class ComputedVar(Var[RETURN_TYPE]): while self._js_expr in state_where_defined.inherited_vars: state_where_defined = state_where_defined.get_parent_state() - return self._replace( - _js_expr=format_state_name(state_where_defined.get_full_name()) + field_name = ( + format_state_name(state_where_defined.get_full_name()) + "." - + self._js_expr, - merge_var_data=VarData.from_state(state_where_defined), - ).guess_type() + + self._js_expr + ) + + return dispatch( + field_name, + var_data=VarData.from_state(state_where_defined, self._js_expr), + result_var_type=self._var_type, + existing_var=self, + ) if not self._cache: return self.fget(instance) @@ -2339,6 +2377,9 @@ class VarData: # The name of the enclosing state. state: str = dataclasses.field(default="") + # The name of the field in the state. + field_name: str = dataclasses.field(default="") + # Imports needed to render this var imports: ImmutableParsedImportDict = dataclasses.field(default_factory=tuple) @@ -2348,6 +2389,7 @@ class VarData: def __init__( self, state: str = "", + field_name: str = "", imports: ImportDict | ParsedImportDict | None = None, hooks: dict[str, None] | None = None, ): @@ -2355,6 +2397,7 @@ class VarData: Args: state: The name of the enclosing state. + field_name: The name of the field in the state. imports: Imports needed to render this var. hooks: Hooks that need to be present in the component to render this var. """ @@ -2364,6 +2407,7 @@ class VarData: ) ) object.__setattr__(self, "state", state) + object.__setattr__(self, "field_name", field_name) object.__setattr__(self, "imports", immutable_imports) object.__setattr__(self, "hooks", tuple(hooks or {})) @@ -2386,12 +2430,14 @@ class VarData: The merged var data object. """ state = "" + field_name = "" _imports = {} hooks = {} for var_data in others: if var_data is None: continue state = state or var_data.state + field_name = field_name or var_data.field_name _imports = imports.merge_imports(_imports, var_data.imports) hooks.update( var_data.hooks @@ -2399,9 +2445,10 @@ class VarData: else {k: None for k in var_data.hooks} ) - if state or _imports or hooks: + if state or _imports or hooks or field_name: return VarData( state=state, + field_name=field_name, imports=_imports, hooks=hooks, ) @@ -2413,38 +2460,15 @@ class VarData: Returns: True if any field is set to a non-default value. """ - return bool(self.state or self.imports or self.hooks) - - def __eq__(self, other: Any) -> bool: - """Check if two var data objects are equal. - - Args: - other: The other var data object to compare. - - Returns: - True if all fields are equal and collapsed imports are equal. - """ - if not isinstance(other, VarData): - return False - - # Don't compare interpolations - that's added in by the decoder, and - # not part of the vardata itself. - return ( - self.state == other.state - and self.hooks - == ( - other.hooks if isinstance(other, VarData) else tuple(other.hooks.keys()) - ) - and imports.collapse_imports(self.imports) - == imports.collapse_imports(other.imports) - ) + return bool(self.state or self.imports or self.hooks or self.field_name) @classmethod - def from_state(cls, state: Type[BaseState] | str) -> VarData: + def from_state(cls, state: Type[BaseState] | str, field_name: str = "") -> VarData: """Set the state of the var. Args: state: The state to set or the full name of the state. + field_name: The name of the field in the state. Optional. Returns: The var with the set state. @@ -2452,8 +2476,9 @@ class VarData: from reflex.utils import format state_name = state if isinstance(state, str) else state.get_full_name() - new_var_data = VarData( + return VarData( state=state_name, + field_name=field_name, hooks={ "const {0} = useContext(StateContexts.{0})".format( format.format_state_name(state_name) @@ -2464,7 +2489,6 @@ class VarData: "react": [ImportVar(tag="useContext")], }, ) - return new_var_data def _decode_var_immutable(value: str) -> tuple[VarData | None, str]: @@ -2561,3 +2585,238 @@ REPLACED_NAMES = { "set_state": "_var_set_state", "deps": "_deps", } + + +dispatchers: Dict[GenericType, Callable[[Var], Var]] = {} + + +def transform(fn: Callable[[Var], Var]) -> Callable[[Var], Var]: + """Register a function to transform a Var. + + Args: + fn: The function to register. + + Returns: + The decorator. + + Raises: + TypeError: If the return type of the function is not a Var. + TypeError: If the Var return type does not have a generic type. + ValueError: If a function for the generic type is already registered. + """ + return_type = fn.__annotations__["return"] + + origin = get_origin(return_type) + + if origin is not Var: + raise TypeError( + f"Expected return type of {fn.__name__} to be a Var, got {origin}." + ) + + generic_args = get_args(return_type) + + if not generic_args: + raise TypeError( + f"Expected Var return type of {fn.__name__} to have a generic type." + ) + + generic_type = get_origin(generic_args[0]) or generic_args[0] + + if generic_type in dispatchers: + raise ValueError(f"Function for {generic_type} already registered.") + + dispatchers[generic_type] = fn + + return fn + + +def generic_type_to_actual_type_map( + generic_type: GenericType, actual_type: GenericType +) -> Dict[TypeVar, GenericType]: + """Map the generic type to the actual type. + + Args: + generic_type: The generic type. + actual_type: The actual type. + + Returns: + The mapping of type variables to actual types. + + Raises: + TypeError: If the generic type and actual type do not match. + TypeError: If the number of generic arguments and actual arguments do not match. + """ + generic_origin = get_origin(generic_type) or generic_type + actual_origin = get_origin(actual_type) or actual_type + + if generic_origin is not actual_origin: + if isinstance(generic_origin, TypeVar): + return {generic_origin: actual_origin} + raise TypeError( + f"Type mismatch: expected {generic_origin}, got {actual_origin}." + ) + + generic_args = get_args(generic_type) + actual_args = get_args(actual_type) + + if len(generic_args) != len(actual_args): + raise TypeError( + f"Number of generic arguments mismatch: expected {len(generic_args)}, got {len(actual_args)}." + ) + + # call recursively for nested generic types and merge the results + return { + k: v + for generic_arg, actual_arg in zip(generic_args, actual_args) + for k, v in generic_type_to_actual_type_map(generic_arg, actual_arg).items() + } + + +def resolve_generic_type_with_mapping( + generic_type: GenericType, type_mapping: Dict[TypeVar, GenericType] +): + """Resolve a generic type with a type mapping. + + Args: + generic_type: The generic type. + type_mapping: The type mapping. + + Returns: + The resolved generic type. + """ + if isinstance(generic_type, TypeVar): + return type_mapping.get(generic_type, generic_type) + + generic_origin = get_origin(generic_type) or generic_type + + generic_args = get_args(generic_type) + + if not generic_args: + return generic_type + + mapping_for_older_python = { + list: List, + set: Set, + dict: Dict, + tuple: Tuple, + frozenset: FrozenSet, + } + + return mapping_for_older_python.get(generic_origin, generic_origin)[ + tuple( + resolve_generic_type_with_mapping(arg, type_mapping) for arg in generic_args + ) + ] + + +def resolve_arg_type_from_return_type( + arg_type: GenericType, return_type: GenericType, actual_return_type: GenericType +) -> GenericType: + """Resolve the argument type from the return type. + + Args: + arg_type: The argument type. + return_type: The return type. + actual_return_type: The requested return type. + + Returns: + The argument type without the generics that are resolved. + """ + return resolve_generic_type_with_mapping( + arg_type, generic_type_to_actual_type_map(return_type, actual_return_type) + ) + + +def dispatch( + field_name: str, + var_data: VarData, + result_var_type: GenericType, + existing_var: Var | None = None, +) -> Var: + """Dispatch a Var to the appropriate transformation function. + + Args: + field_name: The name of the field. + var_data: The VarData associated with the Var. + result_var_type: The type of the Var. + existing_var: The existing Var to transform. Optional. + + Returns: + The transformed Var. + + Raises: + TypeError: If the return type of the function is not a Var. + TypeError: If the Var return type does not have a generic type. + TypeError: If the first argument of the function is not a Var. + TypeError: If the first argument of the function does not have a generic type + """ + result_origin_var_type = get_origin(result_var_type) or result_var_type + + if result_origin_var_type in dispatchers: + fn = dispatchers[result_origin_var_type] + fn_first_arg_type = list(inspect.signature(fn).parameters.values())[ + 0 + ].annotation + + fn_return = inspect.signature(fn).return_annotation + + fn_return_origin = get_origin(fn_return) or fn_return + + if fn_return_origin is not Var: + raise TypeError( + f"Expected return type of {fn.__name__} to be a Var, got {fn_return}." + ) + + fn_return_generic_args = get_args(fn_return) + + if not fn_return_generic_args: + raise TypeError(f"Expected generic type of {fn_return} to be a type.") + + arg_origin = get_origin(fn_first_arg_type) or fn_first_arg_type + + if arg_origin is not Var: + raise TypeError( + f"Expected first argument of {fn.__name__} to be a Var, got {fn_first_arg_type}." + ) + + arg_generic_args = get_args(fn_first_arg_type) + + if not arg_generic_args: + raise TypeError( + f"Expected generic type of {fn_first_arg_type} to be a type." + ) + + arg_type = arg_generic_args[0] + fn_return_type = fn_return_generic_args[0] + + var = ( + Var( + field_name, + _var_data=var_data, + _var_type=resolve_arg_type_from_return_type( + arg_type, fn_return_type, result_var_type + ), + ).guess_type() + if existing_var is None + else existing_var._replace( + _var_type=resolve_arg_type_from_return_type( + arg_type, fn_return_type, result_var_type + ), + _var_data=var_data, + _js_expr=field_name, + ).guess_type() + ) + + return fn(var) + + if existing_var is not None: + return existing_var._replace( + _js_expr=field_name, + _var_data=var_data, + _var_type=result_var_type, + ).guess_type() + return Var( + field_name, + _var_data=var_data, + _var_type=result_var_type, + ).guess_type() diff --git a/tests/components/datadisplay/test_code.py b/tests/components/datadisplay/test_code.py index 000ae2d26..809c68fe5 100644 --- a/tests/components/datadisplay/test_code.py +++ b/tests/components/datadisplay/test_code.py @@ -1,10 +1,11 @@ import pytest -from reflex.components.datadisplay.code import CodeBlock +from reflex.components.datadisplay.code import CodeBlock, Theme @pytest.mark.parametrize( - "theme, expected", [("light", '"one-light"'), ("dark", '"one-dark"')] + "theme, expected", + [(Theme.one_light, "oneLight"), (Theme.one_dark, "oneDark")], ) def test_code_light_dark_theme(theme, expected): code_block = CodeBlock.create(theme=theme) diff --git a/tests/test_state.py b/tests/test_state.py index d12727615..21584fff9 100644 --- a/tests/test_state.py +++ b/tests/test_state.py @@ -2520,7 +2520,7 @@ def test_json_dumps_with_mutables(): items: List[Foo] = [Foo()] dict_val = MutableContainsBase().dict() - assert isinstance(dict_val[MutableContainsBase.get_full_name()]["items"][0], dict) + assert isinstance(dict_val[MutableContainsBase.get_full_name()]["items"][0], Foo) val = json_dumps(dict_val) f_items = '[{"tags": ["123", "456"]}]' f_formatted_router = str(formatted_router).replace("'", '"') diff --git a/tests/test_var.py b/tests/test_var.py index 90bf3ee05..5cd816c9b 100644 --- a/tests/test_var.py +++ b/tests/test_var.py @@ -6,6 +6,7 @@ from typing import Dict, List, Optional, Set, Tuple, Union, cast import pytest from pandas import DataFrame +import reflex as rx from reflex.base import Base from reflex.constants.base import REFLEX_VAR_CLOSING_TAG, REFLEX_VAR_OPENING_TAG from reflex.state import BaseState @@ -1052,6 +1053,29 @@ def test_object_operations(): ) +def test_var_component(): + class ComponentVarState(rx.State): + field_var: rx.Component = rx.text("I am a field var") + + @rx.var + def computed_var(self) -> rx.Component: + return rx.text("I am a computed var") + + def has_eval_react_component(var: Var): + var_data = var._get_all_var_data() + assert var_data is not None + assert any( + any( + imported_object.name == "evalReactComponent" + for imported_object in imported_objects + ) + for _, imported_objects in var_data.imports + ) + + has_eval_react_component(ComponentVarState.field_var) # type: ignore + has_eval_react_component(ComponentVarState.computed_var) + + def test_type_chains(): object_var = LiteralObjectVar.create({"a": 1, "b": 2, "c": 3}) assert (object_var._key_type(), object_var._value_type()) == (str, int)