create immutable callable var and get rid of more base vars

This commit is contained in:
Khaleel Al-Adhami 2024-08-05 16:56:29 -07:00
parent da95d0f519
commit c397cb2b04
6 changed files with 105 additions and 61 deletions

View File

@ -3,8 +3,9 @@
from reflex.components.chakra import ChakraComponent from reflex.components.chakra import ChakraComponent
from reflex.components.component import Component from reflex.components.component import Component
from reflex.components.next.link import NextLink from reflex.components.next.link import NextLink
from reflex.ivars.base import ImmutableVar
from reflex.utils.imports import ImportDict from reflex.utils.imports import ImportDict
from reflex.vars import BaseVar, Var from reflex.vars import Var
next_link = NextLink.create() next_link = NextLink.create()
@ -24,9 +25,7 @@ class Link(ChakraComponent):
text: Var[str] text: Var[str]
# What the link renders to. # What the link renders to.
as_: Var[str] = BaseVar.create( as_: Var[str] = ImmutableVar(_var_name="{NextLink}", _var_type=str)
value="{NextLink}", _var_is_local=False, _var_is_string=False
) # type: ignore
# If true, the link will open in new tab. # If true, the link will open in new tab.
is_external: Var[bool] is_external: Var[bool]

View File

@ -6,12 +6,12 @@ from typing import Any, Dict, List, Optional, Tuple, Union
from reflex.components.base import Fragment from reflex.components.base import Fragment
from reflex.components.component import BaseComponent, Component, MemoizationLeaf from reflex.components.component import BaseComponent, Component, MemoizationLeaf
from reflex.components.tags import MatchTag, Tag from reflex.components.tags import MatchTag, Tag
from reflex.ivars.base import LiteralVar from reflex.ivars.base import ImmutableVar, LiteralVar
from reflex.style import Style from reflex.style import Style
from reflex.utils import format, types from reflex.utils import format, types
from reflex.utils.exceptions import MatchTypeError from reflex.utils.exceptions import MatchTypeError
from reflex.utils.imports import ImportDict from reflex.utils.imports import ImportDict
from reflex.vars import BaseVar, Var, VarData from reflex.vars import ImmutableVarData, Var
class Match(MemoizationLeaf): class Match(MemoizationLeaf):
@ -27,7 +27,7 @@ class Match(MemoizationLeaf):
default: Any default: Any
@classmethod @classmethod
def create(cls, cond: Any, *cases) -> Union[Component, BaseVar]: def create(cls, cond: Any, *cases) -> Union[Component, Var]:
"""Create a Match Component. """Create a Match Component.
Args: Args:
@ -46,7 +46,7 @@ class Match(MemoizationLeaf):
cls._validate_return_types(match_cases) cls._validate_return_types(match_cases)
if default is None and types._issubclass(type(match_cases[0][-1]), BaseVar): if default is None and types._issubclass(type(match_cases[0][-1]), Var):
raise ValueError( raise ValueError(
"For cases with return types as Vars, a default case must be provided" "For cases with return types as Vars, a default case must be provided"
) )
@ -56,7 +56,7 @@ class Match(MemoizationLeaf):
) )
@classmethod @classmethod
def _create_condition_var(cls, cond: Any) -> BaseVar: def _create_condition_var(cls, cond: Any) -> Var:
"""Convert the condition to a Var. """Convert the condition to a Var.
Args: Args:
@ -72,12 +72,12 @@ class Match(MemoizationLeaf):
if match_cond_var is None: if match_cond_var is None:
raise ValueError("The condition must be set") raise ValueError("The condition must be set")
return match_cond_var # type: ignore return match_cond_var
@classmethod @classmethod
def _process_cases( def _process_cases(
cls, cases: List cls, cases: List
) -> Tuple[List, Optional[Union[BaseVar, BaseComponent]]]: ) -> Tuple[List, Optional[Union[Var, BaseComponent]]]:
"""Process the list of match cases and the catchall default case. """Process the list of match cases and the catchall default case.
Args: Args:
@ -103,7 +103,7 @@ class Match(MemoizationLeaf):
else default else default
) )
return cases, default # type: ignore return cases, default
@classmethod @classmethod
def _create_case_var_with_var_data(cls, case_element): def _create_case_var_with_var_data(cls, case_element):
@ -117,16 +117,12 @@ class Match(MemoizationLeaf):
Returns: Returns:
The case element Var. The case element Var.
""" """
_var_data = case_element._var_data if isinstance(case_element, Style) else None # type: ignore _var_data = case_element._var_data if isinstance(case_element, Style) else None
case_element = LiteralVar.create(case_element) case_element = LiteralVar.create(case_element, _var_data=_var_data)
if _var_data is not None:
case_element._var_data = VarData.merge(
case_element._get_all_var_data(), _var_data
) # type: ignore
return case_element return case_element
@classmethod @classmethod
def _process_match_cases(cls, cases: List) -> List[List[BaseVar]]: def _process_match_cases(cls, cases: List) -> List[List[Var]]:
"""Process the individual match cases. """Process the individual match cases.
Args: Args:
@ -158,7 +154,7 @@ class Match(MemoizationLeaf):
if not isinstance(element, BaseComponent) if not isinstance(element, BaseComponent)
else element else element
) )
if not isinstance(el, (BaseVar, BaseComponent)): if not isinstance(el, (Var, BaseComponent)):
raise ValueError("Case element must be a var or component") raise ValueError("Case element must be a var or component")
case_list.append(el) case_list.append(el)
@ -167,7 +163,7 @@ class Match(MemoizationLeaf):
return match_cases return match_cases
@classmethod @classmethod
def _validate_return_types(cls, match_cases: List[List[BaseVar]]) -> None: def _validate_return_types(cls, match_cases: List[List[Var]]) -> None:
"""Validate that match cases have the same return types. """Validate that match cases have the same return types.
Args: Args:
@ -181,14 +177,14 @@ class Match(MemoizationLeaf):
if types._isinstance(first_case_return, BaseComponent): if types._isinstance(first_case_return, BaseComponent):
return_type = BaseComponent return_type = BaseComponent
elif types._isinstance(first_case_return, BaseVar): elif types._isinstance(first_case_return, Var):
return_type = BaseVar return_type = Var
for index, case in enumerate(match_cases): for index, case in enumerate(match_cases):
if not types._issubclass(type(case[-1]), return_type): if not types._issubclass(type(case[-1]), return_type):
raise MatchTypeError( raise MatchTypeError(
f"Match cases should have the same return types. Case {index} with return " f"Match cases should have the same return types. Case {index} with return "
f"value `{case[-1]._var_name if isinstance(case[-1], BaseVar) else textwrap.shorten(str(case[-1]), width=250)}`" f"value `{case[-1]._var_name if isinstance(case[-1], Var) else textwrap.shorten(str(case[-1]), width=250)}`"
f" of type {type(case[-1])!r} is not {return_type}" f" of type {type(case[-1])!r} is not {return_type}"
) )
@ -196,9 +192,9 @@ class Match(MemoizationLeaf):
def _create_match_cond_var_or_component( def _create_match_cond_var_or_component(
cls, cls,
match_cond_var: Var, match_cond_var: Var,
match_cases: List[List[BaseVar]], match_cases: List[List[Var]],
default: Optional[Union[BaseVar, BaseComponent]], default: Optional[Union[Var, BaseComponent]],
) -> Union[Component, BaseVar]: ) -> Union[Component, Var]:
"""Create and return the match condition var or component. """Create and return the match condition var or component.
Args: Args:
@ -229,28 +225,22 @@ class Match(MemoizationLeaf):
# Validate the match cases (as well as the default case) to have Var return types. # Validate the match cases (as well as the default case) to have Var return types.
if any( if any(
case for case in match_cases if not types._isinstance(case[-1], BaseVar) case for case in match_cases if not types._isinstance(case[-1], Var)
) or not types._isinstance(default, BaseVar): ) or not types._isinstance(default, Var):
raise ValueError("Return types of match cases should be Vars.") raise ValueError("Return types of match cases should be Vars.")
# match cases and default should all be Vars at this point. return ImmutableVar(
# Retrieve var data of every var in the match cases and default.
var_data = [
*[el._var_data for case in match_cases for el in case],
default._var_data, # type: ignore
]
return match_cond_var._replace(
_var_name=format.format_match( _var_name=format.format_match(
cond=match_cond_var._var_name_unwrapped, cond=match_cond_var._var_name_unwrapped,
match_cases=match_cases, # type: ignore match_cases=match_cases, # type: ignore
default=default, # type: ignore default=default, # type: ignore
), ),
_var_type=default._var_type, # type: ignore _var_type=default._var_type, # type: ignore
_var_is_local=False, _var_data=ImmutableVarData.merge(
_var_full_name_needs_state_prefix=False, match_cond_var._get_all_var_data(),
_var_is_string=False, *[el._get_all_var_data() for case in match_cases for el in case],
merge_var_data=VarData.merge(*var_data), default._get_all_var_data(), # type: ignore
),
) )
def _render(self) -> Tag: def _render(self) -> Tag:

View File

@ -19,8 +19,10 @@ from reflex.event import (
call_script, call_script,
parse_args_spec, parse_args_spec,
) )
from reflex.ivars.base import ImmutableCallableVar, ImmutableVar
from reflex.ivars.sequence import LiteralStringVar
from reflex.utils.imports import ImportVar from reflex.utils.imports import ImportVar
from reflex.vars import BaseVar, CallableVar, Var, VarData from reflex.vars import Var, VarData
DEFAULT_UPLOAD_ID: str = "default" DEFAULT_UPLOAD_ID: str = "default"
@ -35,8 +37,8 @@ upload_files_context_var_data: VarData = VarData(
) )
@CallableVar @ImmutableCallableVar
def upload_file(id_: str = DEFAULT_UPLOAD_ID) -> BaseVar: def upload_file(id_: str = DEFAULT_UPLOAD_ID) -> ImmutableVar:
"""Get the file upload drop trigger. """Get the file upload drop trigger.
This var is passed to the dropzone component to update the file list when a This var is passed to the dropzone component to update the file list when a
@ -48,23 +50,25 @@ def upload_file(id_: str = DEFAULT_UPLOAD_ID) -> BaseVar:
Returns: Returns:
A var referencing the file upload drop trigger. A var referencing the file upload drop trigger.
""" """
id_var = Var.create_safe(id_, _var_is_string=True) id_var = LiteralStringVar.create(id_)
var_name = f"""e => setFilesById(filesById => {{ var_name = f"""e => setFilesById(filesById => {{
const updatedFilesById = Object.assign({{}}, filesById); const updatedFilesById = Object.assign({{}}, filesById);
updatedFilesById[{id_var._var_name_unwrapped}] = e; updatedFilesById[{str(id_var)}] = e;
return updatedFilesById; return updatedFilesById;
}}) }})
""" """
return BaseVar( return ImmutableVar(
_var_name=var_name, _var_name=var_name,
_var_type=EventChain, _var_type=EventChain,
_var_data=VarData.merge(upload_files_context_var_data, id_var._var_data), _var_data=VarData.merge(
upload_files_context_var_data, id_var._get_all_var_data()
),
) )
@CallableVar @ImmutableCallableVar
def selected_files(id_: str = DEFAULT_UPLOAD_ID) -> BaseVar: def selected_files(id_: str = DEFAULT_UPLOAD_ID) -> ImmutableVar:
"""Get the list of selected files. """Get the list of selected files.
Args: Args:
@ -73,12 +77,14 @@ def selected_files(id_: str = DEFAULT_UPLOAD_ID) -> BaseVar:
Returns: Returns:
A var referencing the list of selected file paths. A var referencing the list of selected file paths.
""" """
id_var = Var.create_safe(id_, _var_is_string=True) id_var = LiteralStringVar.create(id_)
return BaseVar( return ImmutableVar(
_var_name=f"(filesById[{id_var._var_name_unwrapped}] ? filesById[{id_var._var_name_unwrapped}].map((f) => (f.path || f.name)) : [])", _var_name=f"(filesById[{str(id_var)}] ? filesById[{str(id_var)}].map((f) => (f.path || f.name)) : [])",
_var_type=List[str], _var_type=List[str],
_var_data=VarData.merge(upload_files_context_var_data, id_var._var_data), _var_data=VarData.merge(
) upload_files_context_var_data, id_var._get_all_var_data()
),
).guess_type()
@CallableEventSpec @CallableEventSpec
@ -245,7 +251,7 @@ class Upload(MemoizationLeaf):
# The file input to use. # The file input to use.
upload = Input.create(type="file") upload = Input.create(type="file")
upload.special_props = { upload.special_props = {
BaseVar(_var_name="{...getInputProps()}", _var_type=None) ImmutableVar(_var_name="{...getInputProps()}", _var_type=None)
} }
# The dropzone to use. # The dropzone to use.
@ -254,7 +260,9 @@ class Upload(MemoizationLeaf):
*children, *children,
**{k: v for k, v in props.items() if k not in supported_props}, **{k: v for k, v in props.items() if k not in supported_props},
) )
zone.special_props = {BaseVar(_var_name="{...getRootProps()}", _var_type=None)} zone.special_props = {
ImmutableVar(_var_name="{...getRootProps()}", _var_type=None)
}
# Create the component. # Create the component.
upload_props["id"] = props.get("id", DEFAULT_UPLOAD_ID) upload_props["id"] = props.get("id", DEFAULT_UPLOAD_ID)

View File

@ -14,7 +14,7 @@ from reflex.event import EventChain, EventHandler
from reflex.ivars.base import ImmutableVar from reflex.ivars.base import ImmutableVar
from reflex.utils.format import format_event_chain from reflex.utils.format import format_event_chain
from reflex.utils.imports import ImportDict from reflex.utils.imports import ImportDict
from reflex.vars import BaseVar, Var from reflex.vars import Var
from .base import BaseHTML from .base import BaseHTML
@ -198,7 +198,7 @@ class Form(BaseHTML):
if EventTriggers.ON_SUBMIT in self.event_triggers: if EventTriggers.ON_SUBMIT in self.event_triggers:
render_tag.add_props( render_tag.add_props(
**{ **{
EventTriggers.ON_SUBMIT: BaseVar( EventTriggers.ON_SUBMIT: ImmutableVar(
_var_name=f"handleSubmit_{self.handle_submit_unique_name}", _var_name=f"handleSubmit_{self.handle_submit_unique_name}",
_var_type=EventChain, _var_type=EventChain,
) )

View File

@ -1027,3 +1027,50 @@ class OrOperation(ImmutableVar):
The VarData of the components and all of its children. The VarData of the components and all of its children.
""" """
return self._cached_get_all_var_data return self._cached_get_all_var_data
@dataclasses.dataclass(
eq=False,
frozen=True,
**{"slots": True} if sys.version_info >= (3, 10) else {},
)
class ImmutableCallableVar(ImmutableVar):
"""Decorate a Var-returning function to act as both a Var and a function.
This is used as a compatibility shim for replacing Var objects in the
API with functions that return a family of Var.
"""
fn: Callable[..., ImmutableVar] = dataclasses.field(
default_factory=lambda: lambda: LiteralVar.create(None)
)
original_var: ImmutableVar = dataclasses.field(
default_factory=lambda: LiteralVar.create(None)
)
def __init__(self, fn: Callable[..., ImmutableVar]):
"""Initialize a CallableVar.
Args:
fn: The function to decorate (must return Var)
"""
original_var = fn()
super(ImmutableCallableVar, self).__init__(
_var_name=original_var._var_name,
_var_type=original_var._var_type,
_var_data=original_var._var_data,
)
object.__setattr__(self, "fn", fn)
object.__setattr__(self, "original_var", original_var)
def __call__(self, *args, **kwargs) -> ImmutableVar:
"""Call the decorated function.
Args:
*args: The args to pass to the function.
**kwargs: The kwargs to pass to the function.
Returns:
The Var returned from calling the function.
"""
return self.fn(*args, **kwargs)

View File

@ -7,7 +7,7 @@ from typing import Any, Literal, Tuple, Type
from reflex import constants from reflex import constants
from reflex.components.core.breakpoints import Breakpoints, breakpoints_values from reflex.components.core.breakpoints import Breakpoints, breakpoints_values
from reflex.event import EventChain from reflex.event import EventChain
from reflex.ivars.base import ImmutableVar, LiteralVar from reflex.ivars.base import ImmutableCallableVar, ImmutableVar, LiteralVar
from reflex.ivars.function import FunctionVar from reflex.ivars.function import FunctionVar
from reflex.utils import format from reflex.utils import format
from reflex.utils.imports import ImportVar from reflex.utils.imports import ImportVar
@ -47,7 +47,7 @@ def _color_mode_var(_var_name: str, _var_type: Type = str) -> ImmutableVar:
).guess_type() ).guess_type()
# @CallableVar @ImmutableCallableVar
def set_color_mode( def set_color_mode(
new_color_mode: LiteralColorMode | Var[LiteralColorMode] | None = None, new_color_mode: LiteralColorMode | Var[LiteralColorMode] | None = None,
) -> Var[EventChain]: ) -> Var[EventChain]: