From c397cb2b04a426ef6af3ff08ca7e35206dc9459c Mon Sep 17 00:00:00 2001 From: Khaleel Al-Adhami Date: Mon, 5 Aug 2024 16:56:29 -0700 Subject: [PATCH] create immutable callable var and get rid of more base vars --- reflex/components/chakra/navigation/link.py | 7 +-- reflex/components/core/match.py | 64 +++++++++------------ reflex/components/core/upload.py | 40 +++++++------ reflex/components/el/elements/forms.py | 4 +- reflex/ivars/base.py | 47 +++++++++++++++ reflex/style.py | 4 +- 6 files changed, 105 insertions(+), 61 deletions(-) diff --git a/reflex/components/chakra/navigation/link.py b/reflex/components/chakra/navigation/link.py index e30644acd..e9ce9de1a 100644 --- a/reflex/components/chakra/navigation/link.py +++ b/reflex/components/chakra/navigation/link.py @@ -3,8 +3,9 @@ from reflex.components.chakra import ChakraComponent from reflex.components.component import Component from reflex.components.next.link import NextLink +from reflex.ivars.base import ImmutableVar from reflex.utils.imports import ImportDict -from reflex.vars import BaseVar, Var +from reflex.vars import Var next_link = NextLink.create() @@ -24,9 +25,7 @@ class Link(ChakraComponent): text: Var[str] # What the link renders to. - as_: Var[str] = BaseVar.create( - value="{NextLink}", _var_is_local=False, _var_is_string=False - ) # type: ignore + as_: Var[str] = ImmutableVar(_var_name="{NextLink}", _var_type=str) # If true, the link will open in new tab. is_external: Var[bool] diff --git a/reflex/components/core/match.py b/reflex/components/core/match.py index a88757263..557844c81 100644 --- a/reflex/components/core/match.py +++ b/reflex/components/core/match.py @@ -6,12 +6,12 @@ from typing import Any, Dict, List, Optional, Tuple, Union from reflex.components.base import Fragment from reflex.components.component import BaseComponent, Component, MemoizationLeaf from reflex.components.tags import MatchTag, Tag -from reflex.ivars.base import LiteralVar +from reflex.ivars.base import ImmutableVar, LiteralVar from reflex.style import Style from reflex.utils import format, types from reflex.utils.exceptions import MatchTypeError from reflex.utils.imports import ImportDict -from reflex.vars import BaseVar, Var, VarData +from reflex.vars import ImmutableVarData, Var class Match(MemoizationLeaf): @@ -27,7 +27,7 @@ class Match(MemoizationLeaf): default: Any @classmethod - def create(cls, cond: Any, *cases) -> Union[Component, BaseVar]: + def create(cls, cond: Any, *cases) -> Union[Component, Var]: """Create a Match Component. Args: @@ -46,7 +46,7 @@ class Match(MemoizationLeaf): cls._validate_return_types(match_cases) - if default is None and types._issubclass(type(match_cases[0][-1]), BaseVar): + if default is None and types._issubclass(type(match_cases[0][-1]), Var): raise ValueError( "For cases with return types as Vars, a default case must be provided" ) @@ -56,7 +56,7 @@ class Match(MemoizationLeaf): ) @classmethod - def _create_condition_var(cls, cond: Any) -> BaseVar: + def _create_condition_var(cls, cond: Any) -> Var: """Convert the condition to a Var. Args: @@ -72,12 +72,12 @@ class Match(MemoizationLeaf): if match_cond_var is None: raise ValueError("The condition must be set") - return match_cond_var # type: ignore + return match_cond_var @classmethod def _process_cases( cls, cases: List - ) -> Tuple[List, Optional[Union[BaseVar, BaseComponent]]]: + ) -> Tuple[List, Optional[Union[Var, BaseComponent]]]: """Process the list of match cases and the catchall default case. Args: @@ -103,7 +103,7 @@ class Match(MemoizationLeaf): else default ) - return cases, default # type: ignore + return cases, default @classmethod def _create_case_var_with_var_data(cls, case_element): @@ -117,16 +117,12 @@ class Match(MemoizationLeaf): Returns: The case element Var. """ - _var_data = case_element._var_data if isinstance(case_element, Style) else None # type: ignore - case_element = LiteralVar.create(case_element) - if _var_data is not None: - case_element._var_data = VarData.merge( - case_element._get_all_var_data(), _var_data - ) # type: ignore + _var_data = case_element._var_data if isinstance(case_element, Style) else None + case_element = LiteralVar.create(case_element, _var_data=_var_data) return case_element @classmethod - def _process_match_cases(cls, cases: List) -> List[List[BaseVar]]: + def _process_match_cases(cls, cases: List) -> List[List[Var]]: """Process the individual match cases. Args: @@ -158,7 +154,7 @@ class Match(MemoizationLeaf): if not isinstance(element, BaseComponent) else element ) - if not isinstance(el, (BaseVar, BaseComponent)): + if not isinstance(el, (Var, BaseComponent)): raise ValueError("Case element must be a var or component") case_list.append(el) @@ -167,7 +163,7 @@ class Match(MemoizationLeaf): return match_cases @classmethod - def _validate_return_types(cls, match_cases: List[List[BaseVar]]) -> None: + def _validate_return_types(cls, match_cases: List[List[Var]]) -> None: """Validate that match cases have the same return types. Args: @@ -181,14 +177,14 @@ class Match(MemoizationLeaf): if types._isinstance(first_case_return, BaseComponent): return_type = BaseComponent - elif types._isinstance(first_case_return, BaseVar): - return_type = BaseVar + elif types._isinstance(first_case_return, Var): + return_type = Var for index, case in enumerate(match_cases): if not types._issubclass(type(case[-1]), return_type): raise MatchTypeError( f"Match cases should have the same return types. Case {index} with return " - f"value `{case[-1]._var_name if isinstance(case[-1], BaseVar) else textwrap.shorten(str(case[-1]), width=250)}`" + f"value `{case[-1]._var_name if isinstance(case[-1], Var) else textwrap.shorten(str(case[-1]), width=250)}`" f" of type {type(case[-1])!r} is not {return_type}" ) @@ -196,9 +192,9 @@ class Match(MemoizationLeaf): def _create_match_cond_var_or_component( cls, match_cond_var: Var, - match_cases: List[List[BaseVar]], - default: Optional[Union[BaseVar, BaseComponent]], - ) -> Union[Component, BaseVar]: + match_cases: List[List[Var]], + default: Optional[Union[Var, BaseComponent]], + ) -> Union[Component, Var]: """Create and return the match condition var or component. Args: @@ -229,28 +225,22 @@ class Match(MemoizationLeaf): # Validate the match cases (as well as the default case) to have Var return types. if any( - case for case in match_cases if not types._isinstance(case[-1], BaseVar) - ) or not types._isinstance(default, BaseVar): + case for case in match_cases if not types._isinstance(case[-1], Var) + ) or not types._isinstance(default, Var): raise ValueError("Return types of match cases should be Vars.") - # match cases and default should all be Vars at this point. - # Retrieve var data of every var in the match cases and default. - var_data = [ - *[el._var_data for case in match_cases for el in case], - default._var_data, # type: ignore - ] - - return match_cond_var._replace( + return ImmutableVar( _var_name=format.format_match( cond=match_cond_var._var_name_unwrapped, match_cases=match_cases, # type: ignore default=default, # type: ignore ), _var_type=default._var_type, # type: ignore - _var_is_local=False, - _var_full_name_needs_state_prefix=False, - _var_is_string=False, - merge_var_data=VarData.merge(*var_data), + _var_data=ImmutableVarData.merge( + match_cond_var._get_all_var_data(), + *[el._get_all_var_data() for case in match_cases for el in case], + default._get_all_var_data(), # type: ignore + ), ) def _render(self) -> Tag: diff --git a/reflex/components/core/upload.py b/reflex/components/core/upload.py index 3c2113253..f45ba276e 100644 --- a/reflex/components/core/upload.py +++ b/reflex/components/core/upload.py @@ -19,8 +19,10 @@ from reflex.event import ( call_script, parse_args_spec, ) +from reflex.ivars.base import ImmutableCallableVar, ImmutableVar +from reflex.ivars.sequence import LiteralStringVar from reflex.utils.imports import ImportVar -from reflex.vars import BaseVar, CallableVar, Var, VarData +from reflex.vars import Var, VarData DEFAULT_UPLOAD_ID: str = "default" @@ -35,8 +37,8 @@ upload_files_context_var_data: VarData = VarData( ) -@CallableVar -def upload_file(id_: str = DEFAULT_UPLOAD_ID) -> BaseVar: +@ImmutableCallableVar +def upload_file(id_: str = DEFAULT_UPLOAD_ID) -> ImmutableVar: """Get the file upload drop trigger. This var is passed to the dropzone component to update the file list when a @@ -48,23 +50,25 @@ def upload_file(id_: str = DEFAULT_UPLOAD_ID) -> BaseVar: Returns: A var referencing the file upload drop trigger. """ - id_var = Var.create_safe(id_, _var_is_string=True) + id_var = LiteralStringVar.create(id_) var_name = f"""e => setFilesById(filesById => {{ const updatedFilesById = Object.assign({{}}, filesById); - updatedFilesById[{id_var._var_name_unwrapped}] = e; + updatedFilesById[{str(id_var)}] = e; return updatedFilesById; }}) """ - return BaseVar( + return ImmutableVar( _var_name=var_name, _var_type=EventChain, - _var_data=VarData.merge(upload_files_context_var_data, id_var._var_data), + _var_data=VarData.merge( + upload_files_context_var_data, id_var._get_all_var_data() + ), ) -@CallableVar -def selected_files(id_: str = DEFAULT_UPLOAD_ID) -> BaseVar: +@ImmutableCallableVar +def selected_files(id_: str = DEFAULT_UPLOAD_ID) -> ImmutableVar: """Get the list of selected files. Args: @@ -73,12 +77,14 @@ def selected_files(id_: str = DEFAULT_UPLOAD_ID) -> BaseVar: Returns: A var referencing the list of selected file paths. """ - id_var = Var.create_safe(id_, _var_is_string=True) - return BaseVar( - _var_name=f"(filesById[{id_var._var_name_unwrapped}] ? filesById[{id_var._var_name_unwrapped}].map((f) => (f.path || f.name)) : [])", + id_var = LiteralStringVar.create(id_) + return ImmutableVar( + _var_name=f"(filesById[{str(id_var)}] ? filesById[{str(id_var)}].map((f) => (f.path || f.name)) : [])", _var_type=List[str], - _var_data=VarData.merge(upload_files_context_var_data, id_var._var_data), - ) + _var_data=VarData.merge( + upload_files_context_var_data, id_var._get_all_var_data() + ), + ).guess_type() @CallableEventSpec @@ -245,7 +251,7 @@ class Upload(MemoizationLeaf): # The file input to use. upload = Input.create(type="file") upload.special_props = { - BaseVar(_var_name="{...getInputProps()}", _var_type=None) + ImmutableVar(_var_name="{...getInputProps()}", _var_type=None) } # The dropzone to use. @@ -254,7 +260,9 @@ class Upload(MemoizationLeaf): *children, **{k: v for k, v in props.items() if k not in supported_props}, ) - zone.special_props = {BaseVar(_var_name="{...getRootProps()}", _var_type=None)} + zone.special_props = { + ImmutableVar(_var_name="{...getRootProps()}", _var_type=None) + } # Create the component. upload_props["id"] = props.get("id", DEFAULT_UPLOAD_ID) diff --git a/reflex/components/el/elements/forms.py b/reflex/components/el/elements/forms.py index 5c78816a2..0ea9aefd4 100644 --- a/reflex/components/el/elements/forms.py +++ b/reflex/components/el/elements/forms.py @@ -14,7 +14,7 @@ from reflex.event import EventChain, EventHandler from reflex.ivars.base import ImmutableVar from reflex.utils.format import format_event_chain from reflex.utils.imports import ImportDict -from reflex.vars import BaseVar, Var +from reflex.vars import Var from .base import BaseHTML @@ -198,7 +198,7 @@ class Form(BaseHTML): if EventTriggers.ON_SUBMIT in self.event_triggers: render_tag.add_props( **{ - EventTriggers.ON_SUBMIT: BaseVar( + EventTriggers.ON_SUBMIT: ImmutableVar( _var_name=f"handleSubmit_{self.handle_submit_unique_name}", _var_type=EventChain, ) diff --git a/reflex/ivars/base.py b/reflex/ivars/base.py index d7d75a35d..ff8c087bf 100644 --- a/reflex/ivars/base.py +++ b/reflex/ivars/base.py @@ -1027,3 +1027,50 @@ class OrOperation(ImmutableVar): The VarData of the components and all of its children. """ return self._cached_get_all_var_data + + +@dataclasses.dataclass( + eq=False, + frozen=True, + **{"slots": True} if sys.version_info >= (3, 10) else {}, +) +class ImmutableCallableVar(ImmutableVar): + """Decorate a Var-returning function to act as both a Var and a function. + + This is used as a compatibility shim for replacing Var objects in the + API with functions that return a family of Var. + """ + + fn: Callable[..., ImmutableVar] = dataclasses.field( + default_factory=lambda: lambda: LiteralVar.create(None) + ) + original_var: ImmutableVar = dataclasses.field( + default_factory=lambda: LiteralVar.create(None) + ) + + def __init__(self, fn: Callable[..., ImmutableVar]): + """Initialize a CallableVar. + + Args: + fn: The function to decorate (must return Var) + """ + original_var = fn() + super(ImmutableCallableVar, self).__init__( + _var_name=original_var._var_name, + _var_type=original_var._var_type, + _var_data=original_var._var_data, + ) + object.__setattr__(self, "fn", fn) + object.__setattr__(self, "original_var", original_var) + + def __call__(self, *args, **kwargs) -> ImmutableVar: + """Call the decorated function. + + Args: + *args: The args to pass to the function. + **kwargs: The kwargs to pass to the function. + + Returns: + The Var returned from calling the function. + """ + return self.fn(*args, **kwargs) diff --git a/reflex/style.py b/reflex/style.py index a4aeecdc3..35c8cd1aa 100644 --- a/reflex/style.py +++ b/reflex/style.py @@ -7,7 +7,7 @@ from typing import Any, Literal, Tuple, Type from reflex import constants from reflex.components.core.breakpoints import Breakpoints, breakpoints_values from reflex.event import EventChain -from reflex.ivars.base import ImmutableVar, LiteralVar +from reflex.ivars.base import ImmutableCallableVar, ImmutableVar, LiteralVar from reflex.ivars.function import FunctionVar from reflex.utils import format from reflex.utils.imports import ImportVar @@ -47,7 +47,7 @@ def _color_mode_var(_var_name: str, _var_type: Type = str) -> ImmutableVar: ).guess_type() -# @CallableVar +@ImmutableCallableVar def set_color_mode( new_color_mode: LiteralColorMode | Var[LiteralColorMode] | None = None, ) -> Var[EventChain]: