diff --git a/reflex/app.py b/reflex/app.py index 03382751a..65cb5bfdf 100644 --- a/reflex/app.py +++ b/reflex/app.py @@ -100,6 +100,7 @@ from reflex.state import ( StateManager, StateUpdate, _substate_key, + all_base_state_classes, code_uses_state_contexts, ) from reflex.utils import ( @@ -117,6 +118,7 @@ from reflex.utils.imports import ImportVar if TYPE_CHECKING: from reflex.vars import Var + # Define custom types. ComponentCallable = Callable[[], Component] Reducer = Callable[[Event], Coroutine[Any, Any, StateUpdate]] @@ -375,6 +377,9 @@ class App(MiddlewareMixin, LifespanMixin): # A map from a page route to the component to render. Users should use `add_page`. _pages: Dict[str, Component] = dataclasses.field(default_factory=dict) + # A mapping of pages which created states as they were being evaluated. + _stateful_pages: Dict[str, None] = dataclasses.field(default_factory=dict) + # The backend API object. _api: FastAPI | None = None @@ -592,8 +597,10 @@ class App(MiddlewareMixin, LifespanMixin): """Add optional api endpoints (_upload).""" if not self.api: return - - if Upload.is_used: + upload_is_used_marker = ( + prerequisites.get_backend_dir() / constants.Dirs.UPLOAD_IS_USED + ) + if Upload.is_used or upload_is_used_marker.exists(): # To upload files. self.api.post(str(constants.Endpoint.UPLOAD))(upload(self)) @@ -603,10 +610,15 @@ class App(MiddlewareMixin, LifespanMixin): StaticFiles(directory=get_upload_dir()), name="uploaded_files", ) + + upload_is_used_marker.parent.mkdir(parents=True, exist_ok=True) + upload_is_used_marker.touch() if codespaces.is_running_in_codespaces(): self.api.get(str(constants.Endpoint.AUTH_CODESPACE))( codespaces.auth_codespace ) + if environment.REFLEX_ADD_ALL_ROUTES_ENDPOINT.get(): + self.add_all_routes_endpoint() def _add_cors(self): """Add CORS middleware to the app.""" @@ -747,13 +759,19 @@ class App(MiddlewareMixin, LifespanMixin): route: The route of the page to compile. save_page: If True, the compiled page is saved to self._pages. """ + n_states_before = len(all_base_state_classes) component, enable_state = compiler.compile_unevaluated_page( route, self._unevaluated_pages[route], self._state, self.style, self.theme ) + # Indicate that the app should use state. if enable_state: self._enable_state() + # Indicate that evaluating this page creates one or more state classes. + if len(all_base_state_classes) > n_states_before: + self._stateful_pages[route] = None + # Add the page. self._check_routes_conflict(route) if save_page: @@ -1042,6 +1060,20 @@ class App(MiddlewareMixin, LifespanMixin): def get_compilation_time() -> str: return str(datetime.now().time()).split(".")[0] + should_compile = self._should_compile() + backend_dir = prerequisites.get_backend_dir() + if not should_compile and backend_dir.exists(): + stateful_pages_marker = backend_dir / constants.Dirs.STATEFUL_PAGES + if stateful_pages_marker.exists(): + with stateful_pages_marker.open("r") as f: + stateful_pages = json.load(f) + for route in stateful_pages: + console.info(f"BE Evaluating stateful page: {route}") + self._compile_page(route, save_page=False) + self._enable_state() + self._add_optional_endpoints() + return + # Render a default 404 page if the user didn't supply one if constants.Page404.SLUG not in self._unevaluated_pages: self.add_page(route=constants.Page404.SLUG) @@ -1343,6 +1375,24 @@ class App(MiddlewareMixin, LifespanMixin): for output_path, code in compile_results: compiler_utils.write_page(output_path, code) + # Write list of routes that create dynamic states for backend to use. + if self._state is not None: + stateful_pages_marker = ( + prerequisites.get_backend_dir() / constants.Dirs.STATEFUL_PAGES + ) + stateful_pages_marker.parent.mkdir(parents=True, exist_ok=True) + with stateful_pages_marker.open("w") as f: + json.dump(list(self._stateful_pages), f) + + def add_all_routes_endpoint(self): + """Add an endpoint to the app that returns all the routes.""" + if not self.api: + return + + @self.api.get(str(constants.Endpoint.ALL_ROUTES)) + async def all_routes(): + return list(self._unevaluated_pages.keys()) + @contextlib.asynccontextmanager async def modify_state(self, token: str) -> AsyncIterator[BaseState]: """Modify the state out of band. diff --git a/reflex/components/base/bare.py b/reflex/components/base/bare.py index 73b0680d3..5d4020ffd 100644 --- a/reflex/components/base/bare.py +++ b/reflex/components/base/bare.py @@ -2,9 +2,9 @@ from __future__ import annotations -from typing import Any, Iterator +from typing import Any, Iterator, Sequence -from reflex.components.component import Component, LiteralComponentVar +from reflex.components.component import BaseComponent, Component, ComponentStyle from reflex.components.tags import Tag from reflex.components.tags.tagless import Tagless from reflex.config import PerformanceMode, environment @@ -12,7 +12,7 @@ from reflex.utils import console from reflex.utils.decorator import once from reflex.utils.imports import ParsedImportDict from reflex.vars import BooleanVar, ObjectVar, Var -from reflex.vars.base import VarData +from reflex.vars.base import GLOBAL_CACHE, VarData from reflex.vars.sequence import LiteralStringVar @@ -47,6 +47,11 @@ def validate_str(value: str): ) +def _components_from_var(var: Var) -> Sequence[BaseComponent]: + var_data = var._get_all_var_data() + return var_data.components if var_data else () + + class Bare(Component): """A component with no tag.""" @@ -80,8 +85,9 @@ class Bare(Component): The hooks for the component. """ hooks = super()._get_all_hooks_internal() - if isinstance(self.contents, LiteralComponentVar): - hooks |= self.contents._var_value._get_all_hooks_internal() + if isinstance(self.contents, Var): + for component in _components_from_var(self.contents): + hooks |= component._get_all_hooks_internal() return hooks def _get_all_hooks(self) -> dict[str, VarData | None]: @@ -91,18 +97,22 @@ class Bare(Component): The hooks for the component. """ hooks = super()._get_all_hooks() - if isinstance(self.contents, LiteralComponentVar): - hooks |= self.contents._var_value._get_all_hooks() + if isinstance(self.contents, Var): + for component in _components_from_var(self.contents): + hooks |= component._get_all_hooks() return hooks - def _get_all_imports(self) -> ParsedImportDict: + def _get_all_imports(self, collapse: bool = False) -> ParsedImportDict: """Include the imports for the component. + Args: + collapse: Whether to collapse the imports. + Returns: The imports for the component. """ - imports = super()._get_all_imports() - if isinstance(self.contents, LiteralComponentVar): + imports = super()._get_all_imports(collapse=collapse) + if isinstance(self.contents, Var): var_data = self.contents._get_all_var_data() if var_data: imports |= {k: list(v) for k, v in var_data.imports} @@ -115,8 +125,9 @@ class Bare(Component): The dynamic imports. """ dynamic_imports = super()._get_all_dynamic_imports() - if isinstance(self.contents, LiteralComponentVar): - dynamic_imports |= self.contents._var_value._get_all_dynamic_imports() + if isinstance(self.contents, Var): + for component in _components_from_var(self.contents): + dynamic_imports |= component._get_all_dynamic_imports() return dynamic_imports def _get_all_custom_code(self) -> set[str]: @@ -126,10 +137,24 @@ class Bare(Component): The custom code. """ custom_code = super()._get_all_custom_code() - if isinstance(self.contents, LiteralComponentVar): - custom_code |= self.contents._var_value._get_all_custom_code() + if isinstance(self.contents, Var): + for component in _components_from_var(self.contents): + custom_code |= component._get_all_custom_code() return custom_code + def _get_all_app_wrap_components(self) -> dict[tuple[int, str], Component]: + """Get the components that should be wrapped in the app. + + Returns: + The components that should be wrapped in the app. + """ + app_wrap_components = super()._get_all_app_wrap_components() + if isinstance(self.contents, Var): + for component in _components_from_var(self.contents): + if isinstance(component, Component): + app_wrap_components |= component._get_all_app_wrap_components() + return app_wrap_components + def _get_all_refs(self) -> set[str]: """Get the refs for the children of the component. @@ -137,8 +162,9 @@ class Bare(Component): The refs for the children. """ refs = super()._get_all_refs() - if isinstance(self.contents, LiteralComponentVar): - refs |= self.contents._var_value._get_all_refs() + if isinstance(self.contents, Var): + for component in _components_from_var(self.contents): + refs |= component._get_all_refs() return refs def _render(self) -> Tag: @@ -148,6 +174,33 @@ class Bare(Component): return Tagless(contents=f"{{{self.contents!s}}}") return Tagless(contents=str(self.contents)) + def _add_style_recursive( + self, style: ComponentStyle, theme: Component | None = None + ) -> Component: + """Add style to the component and its children. + + Args: + style: The style to add. + theme: The theme to add. + + Returns: + The component with the style added. + """ + new_self = super()._add_style_recursive(style, theme) + + are_components_touched = False + + if isinstance(self.contents, Var): + for component in _components_from_var(self.contents): + if isinstance(component, Component): + component._add_style_recursive(style, theme) + are_components_touched = True + + if are_components_touched: + GLOBAL_CACHE.clear() + + return new_self + def _get_vars( self, include_children: bool = False, ignore_ids: set[int] | None = None ) -> Iterator[Var]: diff --git a/reflex/components/component.py b/reflex/components/component.py index 005f7791d..af9da1b4e 100644 --- a/reflex/components/component.py +++ b/reflex/components/component.py @@ -4,6 +4,7 @@ from __future__ import annotations import copy import dataclasses +import inspect import typing from abc import ABC, abstractmethod from functools import lru_cache, wraps @@ -21,6 +22,8 @@ from typing import ( Set, Type, Union, + get_args, + get_origin, ) from typing_extensions import Self @@ -43,6 +46,7 @@ from reflex.constants import ( from reflex.constants.compiler import SpecialAttributes from reflex.constants.state import FRONTEND_EVENT_STATE from reflex.event import ( + EventActionsMixin, EventCallback, EventChain, EventHandler, @@ -191,6 +195,25 @@ def satisfies_type_hint(obj: Any, type_hint: Any) -> bool: return types._isinstance(obj, type_hint, nested=1) +def _components_from( + component_or_var: Union[BaseComponent, Var], +) -> tuple[BaseComponent, ...]: + """Get the components from a component or Var. + + Args: + component_or_var: The component or Var to get the components from. + + Returns: + The components. + """ + if isinstance(component_or_var, Var): + var_data = component_or_var._get_all_var_data() + return var_data.components if var_data else () + if isinstance(component_or_var, BaseComponent): + return (component_or_var,) + return () + + class Component(BaseComponent, ABC): """A component with style, event trigger and other props.""" @@ -489,7 +512,7 @@ class Component(BaseComponent, ABC): # Remove any keys that were added as events. for key in kwargs["event_triggers"]: - del kwargs[key] + kwargs.pop(key, None) # Place data_ and aria_ attributes into custom_attrs special_attributes = tuple( @@ -666,12 +689,21 @@ class Component(BaseComponent, ABC): return set() @classmethod - @lru_cache(maxsize=None) - def get_component_props(cls) -> set[str]: - """Get the props that expected a component as value. + def _are_fields_known(cls) -> bool: + """Check if all fields are known at compile time. True for most components. Returns: - The components props. + Whether all fields are known at compile time. + """ + return True + + @classmethod + @lru_cache(maxsize=None) + def _get_component_prop_names(cls) -> Set[str]: + """Get the names of the component props. NOTE: This assumes all fields are known. + + Returns: + The names of the component props. """ return { name @@ -680,6 +712,26 @@ class Component(BaseComponent, ABC): and types._issubclass(field.outer_type_, Component) } + def _get_components_in_props(self) -> Sequence[BaseComponent]: + """Get the components in the props. + + Returns: + The components in the props + """ + if self._are_fields_known(): + return [ + component + for name in self._get_component_prop_names() + for component in _components_from(getattr(self, name)) + ] + return [ + component + for prop in self.get_props() + if (value := getattr(self, prop)) is not None + and isinstance(value, (BaseComponent, Var)) + for component in _components_from(value) + ] + @classmethod def create(cls, *children, **props) -> Self: """Create the component. @@ -1136,6 +1188,9 @@ class Component(BaseComponent, ABC): if custom_code is not None: code.add(custom_code) + for component in self._get_components_in_props(): + code |= component._get_all_custom_code() + # Add the custom code from add_custom_code method. for clz in self._iter_parent_classes_with_method("add_custom_code"): for item in clz.add_custom_code(self): @@ -1163,7 +1218,7 @@ class Component(BaseComponent, ABC): The dynamic imports. """ # Store the import in a set to avoid duplicates. - dynamic_imports = set() + dynamic_imports: set[str] = set() # Get dynamic import for this component. dynamic_import = self._get_dynamic_imports() @@ -1174,25 +1229,12 @@ class Component(BaseComponent, ABC): for child in self.children: dynamic_imports |= child._get_all_dynamic_imports() - for prop in self.get_component_props(): - if getattr(self, prop) is not None: - dynamic_imports |= getattr(self, prop)._get_all_dynamic_imports() + for component in self._get_components_in_props(): + dynamic_imports |= component._get_all_dynamic_imports() # Return the dynamic imports return dynamic_imports - def _get_props_imports(self) -> List[ParsedImportDict]: - """Get the imports needed for components props. - - Returns: - The imports for the components props of the component. - """ - return [ - getattr(self, prop)._get_all_imports() - for prop in self.get_component_props() - if getattr(self, prop) is not None - ] - def _should_transpile(self, dep: str | None) -> bool: """Check if a dependency should be transpiled. @@ -1303,7 +1345,6 @@ class Component(BaseComponent, ABC): ) return imports.merge_imports( - *self._get_props_imports(), self._get_dependencies_imports(), self._get_hooks_imports(), _imports, @@ -1380,6 +1421,8 @@ class Component(BaseComponent, ABC): for k in var_data.hooks } ) + for component in var_data.components: + vars_hooks.update(component._get_all_hooks()) return vars_hooks def _get_events_hooks(self) -> dict[str, VarData | None]: @@ -1528,6 +1571,9 @@ class Component(BaseComponent, ABC): refs.add(ref) for child in self.children: refs |= child._get_all_refs() + for component in self._get_components_in_props(): + refs |= component._get_all_refs() + return refs def _get_all_custom_components( @@ -1551,6 +1597,9 @@ class Component(BaseComponent, ABC): if not isinstance(child, Component): continue custom_components |= child._get_all_custom_components(seen=seen) + for component in self._get_components_in_props(): + if isinstance(component, Component) and component.tag is not None: + custom_components |= component._get_all_custom_components(seen=seen) return custom_components @property @@ -1614,17 +1663,65 @@ class CustomComponent(Component): # The props of the component. props: Dict[str, Any] = {} - # Props that reference other components. - component_props: Dict[str, Component] = {} - - def __init__(self, *args, **kwargs): + def __init__(self, **kwargs): """Initialize the custom component. Args: - *args: The args to pass to the component. **kwargs: The kwargs to pass to the component. """ - super().__init__(*args, **kwargs) + component_fn = kwargs.get("component_fn") + + # Set the props. + props_types = typing.get_type_hints(component_fn) if component_fn else {} + props = {key: value for key, value in kwargs.items() if key in props_types} + kwargs = {key: value for key, value in kwargs.items() if key not in props_types} + + event_types = { + key + for key in props + if ( + (get_origin((annotation := props_types.get(key))) or annotation) + == EventHandler + ) + } + + def get_args_spec(key: str) -> types.ArgsSpec | Sequence[types.ArgsSpec]: + type_ = props_types[key] + + return ( + args[0] + if (args := get_args(type_)) + else ( + annotation_args[1] + if get_origin( + ( + annotation := inspect.getfullargspec( + component_fn + ).annotations[key] + ) + ) + is typing.Annotated + and (annotation_args := get_args(annotation)) + else no_args_event_spec + ) + ) + + super().__init__( + event_triggers={ + key: EventChain.create( + value=props[key], + args_spec=get_args_spec(key), + key=key, + ) + for key in event_types + }, + **kwargs, + ) + + to_camel_cased_props = { + format.to_camel_case(key) for key in props if key not in event_types + } + self.get_props = lambda: to_camel_cased_props # pyright: ignore [reportIncompatibleVariableOverride] # Unset the style. self.style = Style() @@ -1632,51 +1729,36 @@ class CustomComponent(Component): # Set the tag to the name of the function. self.tag = format.to_title_case(self.component_fn.__name__) - # Get the event triggers defined in the component declaration. - event_triggers_in_component_declaration = self.get_event_triggers() - - # Set the props. - props = typing.get_type_hints(self.component_fn) - for key, value in kwargs.items(): + for key, value in props.items(): # Skip kwargs that are not props. - if key not in props: + if key not in props_types: continue + camel_cased_key = format.to_camel_case(key) + # Get the type based on the annotation. - type_ = props[key] + type_ = props_types[key] # Handle event chains. - if types._issubclass(type_, EventChain): - value = EventChain.create( - value=value, - args_spec=event_triggers_in_component_declaration.get( - key, no_args_event_spec - ), - key=key, + if types._issubclass(type_, EventActionsMixin): + inspect.getfullargspec(component_fn).annotations[key] + self.props[camel_cased_key] = EventChain.create( + value=value, args_spec=get_args_spec(key), key=key ) - self.props[format.to_camel_case(key)] = value continue - # Handle subclasses of Base. - if isinstance(value, Base): - base_value = LiteralVar.create(value) + value = LiteralVar.create(value) + self.props[camel_cased_key] = value + setattr(self, camel_cased_key, value) - # Track hooks and imports associated with Component instances. - if base_value is not None and isinstance(value, Component): - self.component_props[key] = value - value = base_value._replace( - merge_var_data=VarData( - imports=value._get_all_imports(), - hooks=value._get_all_hooks(), - ) - ) - else: - value = base_value - else: - value = LiteralVar.create(value) + @classmethod + def _are_fields_known(cls) -> bool: + """Check if the fields are known. - # Set the prop. - self.props[format.to_camel_case(key)] = value + Returns: + Whether the fields are known. + """ + return False def __eq__(self, other: Any) -> bool: """Check if the component is equal to another. @@ -1698,7 +1780,7 @@ class CustomComponent(Component): return hash(self.tag) @classmethod - def get_props(cls) -> Set[str]: # pyright: ignore [reportIncompatibleVariableOverride] + def get_props(cls) -> Set[str]: """Get the props for the component. Returns: @@ -1735,27 +1817,8 @@ class CustomComponent(Component): seen=seen ) - # Fetch custom components from props as well. - for child_component in self.component_props.values(): - if child_component.tag is None: - continue - if child_component.tag not in seen: - seen.add(child_component.tag) - if isinstance(child_component, CustomComponent): - custom_components |= {child_component} - custom_components |= child_component._get_all_custom_components( - seen=seen - ) return custom_components - def _render(self) -> Tag: - """Define how to render the component in React. - - Returns: - The tag to render. - """ - return super()._render(props=self.props) - def get_prop_vars(self) -> List[Var]: """Get the prop vars. @@ -1765,29 +1828,19 @@ class CustomComponent(Component): return [ Var( _js_expr=name, - _var_type=(prop._var_type if isinstance(prop, Var) else type(prop)), + _var_type=( + prop._var_type + if isinstance(prop, Var) + else ( + type(prop) + if not isinstance(prop, EventActionsMixin) + else EventChain + ) + ), ).guess_type() for name, prop in self.props.items() ] - def _get_vars( - self, include_children: bool = False, ignore_ids: set[int] | None = None - ) -> Iterator[Var]: - """Walk all Vars used in this component. - - Args: - include_children: Whether to include Vars from children. - ignore_ids: The ids to ignore. - - Yields: - Each var referenced by the component (props, styles, event handlers). - """ - ignore_ids = ignore_ids or set() - yield from super()._get_vars( - include_children=include_children, ignore_ids=ignore_ids - ) - yield from filter(lambda prop: isinstance(prop, Var), self.props.values()) - @lru_cache(maxsize=None) # noqa: B019 def get_component(self) -> Component: """Render the component. @@ -2475,6 +2528,7 @@ class LiteralComponentVar(CachedVarOperation, LiteralVar, ComponentVar): The VarData for the var. """ return VarData.merge( + self._var_data, VarData( imports={ "@emotion/react": [ @@ -2517,9 +2571,21 @@ class LiteralComponentVar(CachedVarOperation, LiteralVar, ComponentVar): Returns: The var. """ + var_datas = [ + var_data + for var in value._get_vars(include_children=True) + if (var_data := var._get_all_var_data()) + ] + return LiteralComponentVar( _js_expr="", _var_type=type(value), - _var_data=_var_data, + _var_data=VarData.merge( + _var_data, + *var_datas, + VarData( + components=(value,), + ), + ), _var_value=value, ) diff --git a/reflex/components/core/cond.py b/reflex/components/core/cond.py index 6f9110a16..a76a8b800 100644 --- a/reflex/components/core/cond.py +++ b/reflex/components/core/cond.py @@ -61,14 +61,6 @@ class Cond(MemoizationLeaf): ) ) - def _get_props_imports(self): - """Get the imports needed for component's props. - - Returns: - The imports for the component's props of the component. - """ - return [] - def _render(self) -> Tag: return CondTag( cond=self.cond, diff --git a/reflex/components/radix/themes/components/context_menu.py b/reflex/components/radix/themes/components/context_menu.py index 60d23db1a..49f22bdfa 100644 --- a/reflex/components/radix/themes/components/context_menu.py +++ b/reflex/components/radix/themes/components/context_menu.py @@ -10,6 +10,7 @@ from reflex.vars.base import Var from ..base import LiteralAccentColor, RadixThemesComponent from .checkbox import Checkbox +from .radio_group import HighLevelRadioGroup LiteralDirType = Literal["ltr", "rtl"] @@ -226,7 +227,11 @@ class ContextMenuItem(RadixThemesComponent): # Optional text used for typeahead purposes. By default the typeahead behavior will use the content of the item. Use this when the content is complex, or you have non-textual content inside. text_value: Var[str] - _valid_parents: List[str] = ["ContextMenuContent", "ContextMenuSubContent"] + _valid_parents: List[str] = [ + "ContextMenuContent", + "ContextMenuSubContent", + "ContextMenuGroup", + ] # Fired when the item is selected. on_select: EventHandler[no_args_event_spec] @@ -247,6 +252,75 @@ class ContextMenuCheckbox(Checkbox): shortcut: Var[str] +class ContextMenuLabel(RadixThemesComponent): + """The component that contains the label.""" + + tag = "ContextMenu.Label" + + # Change the default rendered element for the one passed as a child, merging their props and behavior. Defaults to False. + as_child: Var[bool] + + +class ContextMenuGroup(RadixThemesComponent): + """The component that contains the group.""" + + tag = "ContextMenu.Group" + + # Change the default rendered element for the one passed as a child, merging their props and behavior. Defaults to False. + as_child: Var[bool] + + _valid_parents: List[str] = ["ContextMenuContent", "ContextMenuSubContent"] + + +class ContextMenuRadioGroup(RadixThemesComponent): + """The component that contains context menu radio items.""" + + tag = "ContextMenu.RadioGroup" + + # Change the default rendered element for the one passed as a child, merging their props and behavior. Defaults to False. + as_child: Var[bool] + + # The value of the selected item in the group. + value: Var[str] + + # Props to rename + _rename_props = {"onChange": "onValueChange"} + + # Fired when the value of the radio group changes. + on_change: EventHandler[passthrough_event_spec(str)] + + _valid_parents: List[str] = [ + "ContextMenuRadioItem", + "ContextMenuSubContent", + "ContextMenuContent", + "ContextMenuSub", + ] + + +class ContextMenuRadioItem(HighLevelRadioGroup): + """The component that contains context menu radio items.""" + + tag = "ContextMenu.RadioItem" + + # Override theme color for Dropdown Menu Content + color_scheme: Var[LiteralAccentColor] + + # Change the default rendered element for the one passed as a child, merging their props and behavior. Defaults to False. + as_child: Var[bool] + + # The unique value of the item. + value: Var[str] + + # When true, prevents the user from interacting with the item. + disabled: Var[bool] + + # Event handler called when the user selects an item (via mouse or keyboard). Calling event.preventDefault in this handler will prevent the context menu from closing when selecting that item. + on_select: EventHandler[no_args_event_spec] + + # Optional text used for typeahead purposes. By default the typeahead behavior will use the .textContent of the item. Use this when the content is complex, or you have non-textual content inside. + text_value: Var[str] + + class ContextMenu(ComponentNamespace): """Menu representing a set of actions, displayed at the origin of a pointer right-click or long-press.""" @@ -259,6 +333,10 @@ class ContextMenu(ComponentNamespace): item = staticmethod(ContextMenuItem.create) separator = staticmethod(ContextMenuSeparator.create) checkbox = staticmethod(ContextMenuCheckbox.create) + label = staticmethod(ContextMenuLabel.create) + group = staticmethod(ContextMenuGroup.create) + radio_group = staticmethod(ContextMenuRadioGroup.create) + radio = staticmethod(ContextMenuRadioItem.create) context_menu = ContextMenu() diff --git a/reflex/components/radix/themes/components/context_menu.pyi b/reflex/components/radix/themes/components/context_menu.pyi index 81ccb125b..34aa36f45 100644 --- a/reflex/components/radix/themes/components/context_menu.pyi +++ b/reflex/components/radix/themes/components/context_menu.pyi @@ -3,7 +3,7 @@ # ------------------- DO NOT EDIT ---------------------- # This file was generated by `reflex/utils/pyi_generator.py`! # ------------------------------------------------------ -from typing import Any, Dict, Literal, Optional, Union, overload +from typing import Any, Dict, List, Literal, Optional, Union, overload from reflex.components.component import ComponentNamespace from reflex.components.core.breakpoints import Breakpoints @@ -13,6 +13,7 @@ from reflex.vars.base import Var from ..base import RadixThemesComponent from .checkbox import Checkbox +from .radio_group import HighLevelRadioGroup LiteralDirType = Literal["ltr", "rtl"] LiteralSizeType = Literal["1", "2"] @@ -820,6 +821,320 @@ class ContextMenuCheckbox(Checkbox): """ ... +class ContextMenuLabel(RadixThemesComponent): + @overload + @classmethod + def create( # type: ignore + cls, + *children, + as_child: Optional[Union[Var[bool], bool]] = None, + style: Optional[Style] = None, + key: Optional[Any] = None, + id: Optional[Any] = None, + class_name: Optional[Any] = None, + autofocus: Optional[bool] = None, + custom_attrs: Optional[Dict[str, Union[Var, Any]]] = None, + on_blur: Optional[EventType[()]] = None, + on_click: Optional[EventType[()]] = None, + on_context_menu: Optional[EventType[()]] = None, + on_double_click: Optional[EventType[()]] = None, + on_focus: Optional[EventType[()]] = None, + on_mount: Optional[EventType[()]] = None, + on_mouse_down: Optional[EventType[()]] = None, + on_mouse_enter: Optional[EventType[()]] = None, + on_mouse_leave: Optional[EventType[()]] = None, + on_mouse_move: Optional[EventType[()]] = None, + on_mouse_out: Optional[EventType[()]] = None, + on_mouse_over: Optional[EventType[()]] = None, + on_mouse_up: Optional[EventType[()]] = None, + on_scroll: Optional[EventType[()]] = None, + on_unmount: Optional[EventType[()]] = None, + **props, + ) -> "ContextMenuLabel": + """Create a new component instance. + + Will prepend "RadixThemes" to the component tag to avoid conflicts with + other UI libraries for common names, like Text and Button. + + Args: + *children: Child components. + as_child: Change the default rendered element for the one passed as a child, merging their props and behavior. Defaults to False. + style: The style of the component. + key: A unique key for the component. + id: The id for the component. + class_name: The class name for the component. + autofocus: Whether the component should take the focus once the page is loaded + custom_attrs: custom attribute + **props: Component properties. + + Returns: + A new component instance. + """ + ... + +class ContextMenuGroup(RadixThemesComponent): + @overload + @classmethod + def create( # type: ignore + cls, + *children, + as_child: Optional[Union[Var[bool], bool]] = None, + style: Optional[Style] = None, + key: Optional[Any] = None, + id: Optional[Any] = None, + class_name: Optional[Any] = None, + autofocus: Optional[bool] = None, + custom_attrs: Optional[Dict[str, Union[Var, Any]]] = None, + on_blur: Optional[EventType[()]] = None, + on_click: Optional[EventType[()]] = None, + on_context_menu: Optional[EventType[()]] = None, + on_double_click: Optional[EventType[()]] = None, + on_focus: Optional[EventType[()]] = None, + on_mount: Optional[EventType[()]] = None, + on_mouse_down: Optional[EventType[()]] = None, + on_mouse_enter: Optional[EventType[()]] = None, + on_mouse_leave: Optional[EventType[()]] = None, + on_mouse_move: Optional[EventType[()]] = None, + on_mouse_out: Optional[EventType[()]] = None, + on_mouse_over: Optional[EventType[()]] = None, + on_mouse_up: Optional[EventType[()]] = None, + on_scroll: Optional[EventType[()]] = None, + on_unmount: Optional[EventType[()]] = None, + **props, + ) -> "ContextMenuGroup": + """Create a new component instance. + + Will prepend "RadixThemes" to the component tag to avoid conflicts with + other UI libraries for common names, like Text and Button. + + Args: + *children: Child components. + as_child: Change the default rendered element for the one passed as a child, merging their props and behavior. Defaults to False. + style: The style of the component. + key: A unique key for the component. + id: The id for the component. + class_name: The class name for the component. + autofocus: Whether the component should take the focus once the page is loaded + custom_attrs: custom attribute + **props: Component properties. + + Returns: + A new component instance. + """ + ... + +class ContextMenuRadioGroup(RadixThemesComponent): + @overload + @classmethod + def create( # type: ignore + cls, + *children, + as_child: Optional[Union[Var[bool], bool]] = None, + value: Optional[Union[Var[str], str]] = None, + style: Optional[Style] = None, + key: Optional[Any] = None, + id: Optional[Any] = None, + class_name: Optional[Any] = None, + autofocus: Optional[bool] = None, + custom_attrs: Optional[Dict[str, Union[Var, Any]]] = None, + on_blur: Optional[EventType[()]] = None, + on_change: Optional[Union[EventType[()], EventType[str]]] = None, + on_click: Optional[EventType[()]] = None, + on_context_menu: Optional[EventType[()]] = None, + on_double_click: Optional[EventType[()]] = None, + on_focus: Optional[EventType[()]] = None, + on_mount: Optional[EventType[()]] = None, + on_mouse_down: Optional[EventType[()]] = None, + on_mouse_enter: Optional[EventType[()]] = None, + on_mouse_leave: Optional[EventType[()]] = None, + on_mouse_move: Optional[EventType[()]] = None, + on_mouse_out: Optional[EventType[()]] = None, + on_mouse_over: Optional[EventType[()]] = None, + on_mouse_up: Optional[EventType[()]] = None, + on_scroll: Optional[EventType[()]] = None, + on_unmount: Optional[EventType[()]] = None, + **props, + ) -> "ContextMenuRadioGroup": + """Create a new component instance. + + Will prepend "RadixThemes" to the component tag to avoid conflicts with + other UI libraries for common names, like Text and Button. + + Args: + *children: Child components. + as_child: Change the default rendered element for the one passed as a child, merging their props and behavior. Defaults to False. + value: The value of the selected item in the group. + on_change: Fired when the value of the radio group changes. + style: The style of the component. + key: A unique key for the component. + id: The id for the component. + class_name: The class name for the component. + autofocus: Whether the component should take the focus once the page is loaded + custom_attrs: custom attribute + **props: Component properties. + + Returns: + A new component instance. + """ + ... + +class ContextMenuRadioItem(HighLevelRadioGroup): + @overload + @classmethod + def create( # type: ignore + cls, + *children, + color_scheme: Optional[ + Union[ + Literal[ + "amber", + "blue", + "bronze", + "brown", + "crimson", + "cyan", + "gold", + "grass", + "gray", + "green", + "indigo", + "iris", + "jade", + "lime", + "mint", + "orange", + "pink", + "plum", + "purple", + "red", + "ruby", + "sky", + "teal", + "tomato", + "violet", + "yellow", + ], + Var[ + Literal[ + "amber", + "blue", + "bronze", + "brown", + "crimson", + "cyan", + "gold", + "grass", + "gray", + "green", + "indigo", + "iris", + "jade", + "lime", + "mint", + "orange", + "pink", + "plum", + "purple", + "red", + "ruby", + "sky", + "teal", + "tomato", + "violet", + "yellow", + ] + ], + ] + ] = None, + as_child: Optional[Union[Var[bool], bool]] = None, + value: Optional[Union[Var[str], str]] = None, + disabled: Optional[Union[Var[bool], bool]] = None, + text_value: Optional[Union[Var[str], str]] = None, + items: Optional[Union[List[str], Var[List[str]]]] = None, + direction: Optional[ + Union[ + Literal["column", "column-reverse", "row", "row-reverse"], + Var[Literal["column", "column-reverse", "row", "row-reverse"]], + ] + ] = None, + spacing: Optional[ + Union[ + Literal["0", "1", "2", "3", "4", "5", "6", "7", "8", "9"], + Var[Literal["0", "1", "2", "3", "4", "5", "6", "7", "8", "9"]], + ] + ] = None, + size: Optional[ + Union[Literal["1", "2", "3"], Var[Literal["1", "2", "3"]]] + ] = None, + variant: Optional[ + Union[ + Literal["classic", "soft", "surface"], + Var[Literal["classic", "soft", "surface"]], + ] + ] = None, + high_contrast: Optional[Union[Var[bool], bool]] = None, + default_value: Optional[Union[Var[str], str]] = None, + name: Optional[Union[Var[str], str]] = None, + required: Optional[Union[Var[bool], bool]] = None, + style: Optional[Style] = None, + key: Optional[Any] = None, + id: Optional[Any] = None, + class_name: Optional[Any] = None, + autofocus: Optional[bool] = None, + custom_attrs: Optional[Dict[str, Union[Var, Any]]] = None, + on_blur: Optional[EventType[()]] = None, + on_click: Optional[EventType[()]] = None, + on_context_menu: Optional[EventType[()]] = None, + on_double_click: Optional[EventType[()]] = None, + on_focus: Optional[EventType[()]] = None, + on_mount: Optional[EventType[()]] = None, + on_mouse_down: Optional[EventType[()]] = None, + on_mouse_enter: Optional[EventType[()]] = None, + on_mouse_leave: Optional[EventType[()]] = None, + on_mouse_move: Optional[EventType[()]] = None, + on_mouse_out: Optional[EventType[()]] = None, + on_mouse_over: Optional[EventType[()]] = None, + on_mouse_up: Optional[EventType[()]] = None, + on_scroll: Optional[EventType[()]] = None, + on_select: Optional[EventType[()]] = None, + on_unmount: Optional[EventType[()]] = None, + **props, + ) -> "ContextMenuRadioItem": + """Create a radio group component. + + Args: + items: The items of the radio group. + color_scheme: The color of the radio group + as_child: Change the default rendered element for the one passed as a child, merging their props and behavior. Defaults to False. + value: The controlled value of the radio item to check. Should be used in conjunction with on_change. + disabled: Whether the radio group is disabled + on_select: Event handler called when the user selects an item (via mouse or keyboard). Calling event.preventDefault in this handler will prevent the context menu from closing when selecting that item. + text_value: Optional text used for typeahead purposes. By default the typeahead behavior will use the .textContent of the item. Use this when the content is complex, or you have non-textual content inside. + items: The items of the radio group. + direction: The direction of the radio group. + spacing: The gap between the items of the radio group. + size: The size of the radio group. + variant: The variant of the radio group + high_contrast: Whether to render the radio group with higher contrast color against background + default_value: The initial value of checked radio item. Should be used in conjunction with on_change. + name: The name of the group. Submitted with its owning form as part of a name/value pair. + required: Whether the radio group is required + style: The style of the component. + key: A unique key for the component. + id: The id for the component. + class_name: The class name for the component. + autofocus: Whether the component should take the focus once the page is loaded + custom_attrs: custom attribute + **props: Additional properties to apply to the accordion item. + + Returns: + The created radio group component. + + Raises: + TypeError: If the type of items is invalid. + """ + ... + class ContextMenu(ComponentNamespace): root = staticmethod(ContextMenuRoot.create) trigger = staticmethod(ContextMenuTrigger.create) @@ -830,5 +1145,9 @@ class ContextMenu(ComponentNamespace): item = staticmethod(ContextMenuItem.create) separator = staticmethod(ContextMenuSeparator.create) checkbox = staticmethod(ContextMenuCheckbox.create) + label = staticmethod(ContextMenuLabel.create) + group = staticmethod(ContextMenuGroup.create) + radio_group = staticmethod(ContextMenuRadioGroup.create) + radio = staticmethod(ContextMenuRadioItem.create) context_menu = ContextMenu() diff --git a/reflex/components/tags/tag.py b/reflex/components/tags/tag.py index 515d9e05f..7f7a8c74d 100644 --- a/reflex/components/tags/tag.py +++ b/reflex/components/tags/tag.py @@ -101,7 +101,7 @@ class Tag: """ self.props.update( { - format.to_camel_case(name, allow_hyphens=True): ( + format.to_camel_case(name, treat_hyphens_as_underscores=False): ( prop if types._isinstance(prop, (EventChain, Mapping)) else LiteralVar.create(prop) diff --git a/reflex/config.py b/reflex/config.py index b0cb34691..7900b4680 100644 --- a/reflex/config.py +++ b/reflex/config.py @@ -719,6 +719,9 @@ class EnvironmentVariables: # The timeout for the backend to do a cold start in seconds. BACKEND_COLD_START_TIMEOUT: EnvVar[int] = env_var(10) + # Used by flexgen to enumerate the pages. + REFLEX_ADD_ALL_ROUTES_ENDPOINT: EnvVar[bool] = env_var(False) + environment = EnvironmentVariables() diff --git a/reflex/constants/base.py b/reflex/constants/base.py index 7fbcdf18a..0611c7d4c 100644 --- a/reflex/constants/base.py +++ b/reflex/constants/base.py @@ -53,6 +53,12 @@ class Dirs(SimpleNamespace): POSTCSS_JS = "postcss.config.js" # The name of the states directory. STATES = ".states" + # Where compilation artifacts for the backend are stored. + BACKEND = "backend" + # JSON-encoded list of page routes that need to be evaluated on the backend. + STATEFUL_PAGES = "stateful_pages.json" + # Marker file indicating that upload component was used in the frontend. + UPLOAD_IS_USED = "upload_is_used" class Reflex(SimpleNamespace): diff --git a/reflex/constants/event.py b/reflex/constants/event.py index d454e6ea8..7b58c99cf 100644 --- a/reflex/constants/event.py +++ b/reflex/constants/event.py @@ -12,6 +12,7 @@ class Endpoint(Enum): UPLOAD = "_upload" AUTH_CODESPACE = "auth-codespace" HEALTH = "_health" + ALL_ROUTES = "_all_routes" def __str__(self) -> str: """Get the string representation of the endpoint. diff --git a/reflex/reflex.py b/reflex/reflex.py index 43f7b6184..878b32d76 100644 --- a/reflex/reflex.py +++ b/reflex/reflex.py @@ -23,8 +23,6 @@ typer.core.rich = None # pyright: ignore [reportPrivateImportUsage] cli = typer.Typer(add_completion=False, pretty_exceptions_enable=False) -SHOW_BUILT_WITH_REFLEX_INFO = "https://reflex.dev/docs/hosting/reflex-branding/" - # Get the config. config = get_config() @@ -193,15 +191,6 @@ def _run( prerequisites.check_latest_package_version(constants.Reflex.MODULE_NAME) if frontend: - if config.show_built_with_reflex is False: - # The sticky badge may be disabled at runtime for team/enterprise tiers. - prerequisites.check_config_option_in_tier( - option_name="show_built_with_reflex", - allowed_tiers=["team", "enterprise"], - fallback_value=True, - help_link=SHOW_BUILT_WITH_REFLEX_INFO, - ) - # Get the app module. prerequisites.get_compiled_app() @@ -358,15 +347,6 @@ def export( if prerequisites.needs_reinit(frontend=frontend or not backend): _init(name=config.app_name, loglevel=loglevel) - if frontend and config.show_built_with_reflex is False: - # The sticky badge may be disabled on export for team/enterprise tiers. - prerequisites.check_config_option_in_tier( - option_name="show_built_with_reflex", - allowed_tiers=["team", "enterprise"], - fallback_value=True, - help_link=SHOW_BUILT_WITH_REFLEX_INFO, - ) - export_utils.export( zipping=zipping, frontend=frontend, @@ -563,15 +543,6 @@ def deploy( environment.REFLEX_COMPILE_CONTEXT.set(constants.CompileContext.DEPLOY) - if not config.show_built_with_reflex: - # The sticky badge may be disabled on deploy for pro/team/enterprise tiers. - prerequisites.check_config_option_in_tier( - option_name="show_built_with_reflex", - allowed_tiers=["pro", "team", "enterprise"], - fallback_value=True, - help_link=SHOW_BUILT_WITH_REFLEX_INFO, - ) - # Set the log level. console.set_log_level(loglevel) diff --git a/reflex/state.py b/reflex/state.py index 2689ba910..0f0ba97f9 100644 --- a/reflex/state.py +++ b/reflex/state.py @@ -327,6 +327,9 @@ async def _resolve_delta(delta: Delta) -> Delta: return delta +all_base_state_classes: dict[str, None] = {} + + class BaseState(Base, ABC, extra=pydantic.Extra.allow): """The state of the app.""" @@ -624,6 +627,8 @@ class BaseState(Base, ABC, extra=pydantic.Extra.allow): cls._var_dependencies = {} cls._init_var_dependency_dicts() + all_base_state_classes[cls.get_full_name()] = None + @staticmethod def _copy_fn(fn: Callable) -> Callable: """Copy a function. Used to copy ComputedVars and EventHandlers from mixins. @@ -4087,6 +4092,7 @@ def reload_state_module( for subclass in tuple(state.class_subclasses): reload_state_module(module=module, state=subclass) if subclass.__module__ == module and module is not None: + all_base_state_classes.pop(subclass.get_full_name(), None) state.class_subclasses.remove(subclass) state._always_dirty_substates.discard(subclass.get_name()) state._var_dependencies = {} diff --git a/reflex/style.py b/reflex/style.py index 1d818ed06..00dc16839 100644 --- a/reflex/style.py +++ b/reflex/style.py @@ -190,11 +190,12 @@ def convert( for key, value in style_dict.items(): keys = ( format_style_key(key) - if not isinstance(value, (dict, ObjectVar)) + if not isinstance(value, (dict, ObjectVar, list)) or ( isinstance(value, Breakpoints) and all(not isinstance(v, dict) for v in value.values()) ) + or (isinstance(value, list) and all(not isinstance(v, dict) for v in value)) or ( isinstance(value, ObjectVar) and not issubclass(get_origin(value._var_type) or value._var_type, dict) @@ -236,7 +237,9 @@ def format_style_key(key: str) -> Tuple[str, ...]: Returns: Tuple of css style names corresponding to the key provided. """ - key = format.to_camel_case(key, allow_hyphens=True) + if key.startswith("--"): + return (key,) + key = format.to_camel_case(key) return STYLE_PROP_SHORTHAND_MAPPING.get(key, (key,)) diff --git a/reflex/utils/build.py b/reflex/utils/build.py index 9e35ab984..c02a30c7b 100644 --- a/reflex/utils/build.py +++ b/reflex/utils/build.py @@ -60,6 +60,7 @@ def _zip( dirs_to_exclude: set[str] | None = None, files_to_exclude: set[str] | None = None, top_level_dirs_to_exclude: set[str] | None = None, + globs_to_include: list[str] | None = None, ) -> None: """Zip utility function. @@ -72,6 +73,7 @@ def _zip( dirs_to_exclude: The directories to exclude. files_to_exclude: The files to exclude. top_level_dirs_to_exclude: The top level directory names immediately under root_dir to exclude. Do not exclude folders by these names further in the sub-directories. + globs_to_include: Apply these globs from the root_dir and always include them in the zip. """ target = Path(target) @@ -103,6 +105,13 @@ def _zip( files_to_zip += [ str(root / file) for file in files if file not in files_to_exclude ] + if globs_to_include: + for glob in globs_to_include: + files_to_zip += [ + str(file) + for file in root_dir.glob(glob) + if file.name not in files_to_exclude + ] # Create a progress bar for zipping the component. progress = Progress( @@ -160,6 +169,9 @@ def zip_app( top_level_dirs_to_exclude={"assets"}, exclude_venv_dirs=True, upload_db_file=upload_db_file, + globs_to_include=[ + str(Path(constants.Dirs.WEB) / constants.Dirs.BACKEND / "*") + ], ) diff --git a/reflex/utils/format.py b/reflex/utils/format.py index 214c845f8..14ef8fb46 100644 --- a/reflex/utils/format.py +++ b/reflex/utils/format.py @@ -168,7 +168,7 @@ def to_snake_case(text: str) -> str: return re.sub("([a-z0-9])([A-Z])", r"\1_\2", s1).lower().replace("-", "_") -def to_camel_case(text: str, allow_hyphens: bool = False) -> str: +def to_camel_case(text: str, treat_hyphens_as_underscores: bool = True) -> str: """Convert a string to camel case. The first word in the text is converted to lowercase and @@ -176,17 +176,16 @@ def to_camel_case(text: str, allow_hyphens: bool = False) -> str: Args: text: The string to convert. - allow_hyphens: Whether to allow hyphens in the string. + treat_hyphens_as_underscores: Whether to allow hyphens in the string. Returns: The camel case string. """ - char = "_" if allow_hyphens else "-_" - words = re.split(f"[{char}]", text.lstrip(char)) - leading_underscores_or_hyphens = "".join(re.findall(rf"^[{char}]+", text)) + char = "_" if not treat_hyphens_as_underscores else "-_" + words = re.split(f"[{char}]", text) # Capitalize the first letter of each word except the first one converted_word = words[0] + "".join(x.capitalize() for x in words[1:]) - return leading_underscores_or_hyphens + converted_word + return converted_word def to_title_case(text: str, sep: str = "") -> str: diff --git a/reflex/utils/prerequisites.py b/reflex/utils/prerequisites.py index 145b5324c..b5987f4e8 100644 --- a/reflex/utils/prerequisites.py +++ b/reflex/utils/prerequisites.py @@ -99,6 +99,15 @@ def get_states_dir() -> Path: return environment.REFLEX_STATES_WORKDIR.get() +def get_backend_dir() -> Path: + """Get the working directory for the backend. + + Returns: + The working directory. + """ + return get_web_dir() / constants.Dirs.BACKEND + + def check_latest_package_version(package_name: str): """Check if the latest version of the package is installed. diff --git a/reflex/vars/base.py b/reflex/vars/base.py index 89bc86fce..6654c7e22 100644 --- a/reflex/vars/base.py +++ b/reflex/vars/base.py @@ -76,6 +76,7 @@ from reflex.utils.types import ( ) if TYPE_CHECKING: + from reflex.components.component import BaseComponent from reflex.state import BaseState from .number import BooleanVar, LiteralBooleanVar, LiteralNumberVar, NumberVar @@ -132,6 +133,9 @@ class VarData: # Position of the hook in the component position: Hooks.HookPosition | None = None + # Components that are part of this var + components: Tuple[BaseComponent, ...] = dataclasses.field(default_factory=tuple) + def __init__( self, state: str = "", @@ -140,6 +144,7 @@ class VarData: hooks: Mapping[str, VarData | None] | Sequence[str] | str | None = None, deps: list[Var] | None = None, position: Hooks.HookPosition | None = None, + components: Iterable[BaseComponent] | None = None, ): """Initialize the var data. @@ -150,6 +155,7 @@ class VarData: hooks: Hooks that need to be present in the component to render this var. deps: Dependencies of the var for useCallback. position: Position of the hook in the component. + components: Components that are part of this var. """ if isinstance(hooks, str): hooks = [hooks] @@ -164,6 +170,7 @@ class VarData: object.__setattr__(self, "hooks", tuple(hooks or {})) object.__setattr__(self, "deps", tuple(deps or [])) object.__setattr__(self, "position", position or None) + object.__setattr__(self, "components", tuple(components or [])) if hooks and any(hooks.values()): merged_var_data = VarData.merge(self, *hooks.values()) @@ -174,6 +181,7 @@ class VarData: object.__setattr__(self, "hooks", merged_var_data.hooks) object.__setattr__(self, "deps", merged_var_data.deps) object.__setattr__(self, "position", merged_var_data.position) + object.__setattr__(self, "components", merged_var_data.components) def old_school_imports(self) -> ImportDict: """Return the imports as a mutable dict. @@ -242,17 +250,19 @@ class VarData: else: position = None - if state or _imports or hooks or field_name or deps or position: - return VarData( - state=state, - field_name=field_name, - imports=_imports, - hooks=hooks, - deps=deps, - position=position, - ) + components = tuple( + component for var_data in all_var_datas for component in var_data.components + ) - return None + return VarData( + state=state, + field_name=field_name, + imports=_imports, + hooks=hooks, + deps=deps, + position=position, + components=components, + ) def __bool__(self) -> bool: """Check if the var data is non-empty. @@ -267,6 +277,7 @@ class VarData: or self.field_name or self.deps or self.position + or self.components ) @classmethod diff --git a/tests/units/components/markdown/test_markdown.py b/tests/units/components/markdown/test_markdown.py index c6d395eb1..15d662ef6 100644 --- a/tests/units/components/markdown/test_markdown.py +++ b/tests/units/components/markdown/test_markdown.py @@ -157,7 +157,7 @@ def test_create_map_fn_var_subclass(cls, fn_body, fn_args, explicit_return, expe value, **props ) }, - r"""(({node, inline, className, children, ...props}) => { const match = (className || '').match(/language-(?.*)/); let _language = match ? match[1] : ''; ; return inline ? ( {children} ) : ( ); })""", + r"""(({node, inline, className, children, ...props}) => { const match = (className || '').match(/language-(?.*)/); let _language = match ? match[1] : ''; ; return inline ? ( {children} ) : ( ); })""", ), ( "h1", diff --git a/tests/units/components/test_component.py b/tests/units/components/test_component.py index 8cffa6e0e..d333a45b4 100644 --- a/tests/units/components/test_component.py +++ b/tests/units/components/test_component.py @@ -651,7 +651,7 @@ def test_create_filters_none_props(test_component): # Assert that the style prop is present in the component's props assert str(component.style["color"]) == '"white"' - assert str(component.style["text-align"]) == '"center"' + assert str(component.style["textAlign"]) == '"center"' @pytest.mark.parametrize( @@ -871,7 +871,7 @@ def test_create_custom_component(my_component): """ component = CustomComponent(component_fn=my_component, prop1="test", prop2=1) assert component.tag == "MyComponent" - assert component.get_props() == set() + assert component.get_props() == {"prop1", "prop2"} assert component._get_all_custom_components() == {component} diff --git a/tests/units/utils/test_format.py b/tests/units/utils/test_format.py index 89197a03e..053d5a3ae 100644 --- a/tests/units/utils/test_format.py +++ b/tests/units/utils/test_format.py @@ -189,11 +189,11 @@ def test_to_snake_case(input: str, output: str): ("kebab-case", "kebabCase"), ("kebab-case-two", "kebabCaseTwo"), ("snake_kebab-case", "snakeKebabCase"), - ("_hover", "_hover"), - ("-starts-with-hyphen", "-startsWithHyphen"), - ("--starts-with-double-hyphen", "--startsWithDoubleHyphen"), - ("_starts_with_underscore", "_startsWithUnderscore"), - ("__starts_with_double_underscore", "__startsWithDoubleUnderscore"), + ("_hover", "Hover"), + ("-starts-with-hyphen", "StartsWithHyphen"), + ("--starts-with-double-hyphen", "StartsWithDoubleHyphen"), + ("_starts_with_underscore", "StartsWithUnderscore"), + ("__starts_with_double_underscore", "StartsWithDoubleUnderscore"), (":start-with-colon", ":startWithColon"), (":-start-with-colon-dash", ":StartWithColonDash"), ],