create immutable callable var and get rid of more base vars

This commit is contained in:
Khaleel Al-Adhami 2024-08-05 16:56:29 -07:00
parent da95d0f519
commit c397cb2b04
6 changed files with 105 additions and 61 deletions

View File

@ -3,8 +3,9 @@
from reflex.components.chakra import ChakraComponent
from reflex.components.component import Component
from reflex.components.next.link import NextLink
from reflex.ivars.base import ImmutableVar
from reflex.utils.imports import ImportDict
from reflex.vars import BaseVar, Var
from reflex.vars import Var
next_link = NextLink.create()
@ -24,9 +25,7 @@ class Link(ChakraComponent):
text: Var[str]
# What the link renders to.
as_: Var[str] = BaseVar.create(
value="{NextLink}", _var_is_local=False, _var_is_string=False
) # type: ignore
as_: Var[str] = ImmutableVar(_var_name="{NextLink}", _var_type=str)
# If true, the link will open in new tab.
is_external: Var[bool]

View File

@ -6,12 +6,12 @@ from typing import Any, Dict, List, Optional, Tuple, Union
from reflex.components.base import Fragment
from reflex.components.component import BaseComponent, Component, MemoizationLeaf
from reflex.components.tags import MatchTag, Tag
from reflex.ivars.base import LiteralVar
from reflex.ivars.base import ImmutableVar, LiteralVar
from reflex.style import Style
from reflex.utils import format, types
from reflex.utils.exceptions import MatchTypeError
from reflex.utils.imports import ImportDict
from reflex.vars import BaseVar, Var, VarData
from reflex.vars import ImmutableVarData, Var
class Match(MemoizationLeaf):
@ -27,7 +27,7 @@ class Match(MemoizationLeaf):
default: Any
@classmethod
def create(cls, cond: Any, *cases) -> Union[Component, BaseVar]:
def create(cls, cond: Any, *cases) -> Union[Component, Var]:
"""Create a Match Component.
Args:
@ -46,7 +46,7 @@ class Match(MemoizationLeaf):
cls._validate_return_types(match_cases)
if default is None and types._issubclass(type(match_cases[0][-1]), BaseVar):
if default is None and types._issubclass(type(match_cases[0][-1]), Var):
raise ValueError(
"For cases with return types as Vars, a default case must be provided"
)
@ -56,7 +56,7 @@ class Match(MemoizationLeaf):
)
@classmethod
def _create_condition_var(cls, cond: Any) -> BaseVar:
def _create_condition_var(cls, cond: Any) -> Var:
"""Convert the condition to a Var.
Args:
@ -72,12 +72,12 @@ class Match(MemoizationLeaf):
if match_cond_var is None:
raise ValueError("The condition must be set")
return match_cond_var # type: ignore
return match_cond_var
@classmethod
def _process_cases(
cls, cases: List
) -> Tuple[List, Optional[Union[BaseVar, BaseComponent]]]:
) -> Tuple[List, Optional[Union[Var, BaseComponent]]]:
"""Process the list of match cases and the catchall default case.
Args:
@ -103,7 +103,7 @@ class Match(MemoizationLeaf):
else default
)
return cases, default # type: ignore
return cases, default
@classmethod
def _create_case_var_with_var_data(cls, case_element):
@ -117,16 +117,12 @@ class Match(MemoizationLeaf):
Returns:
The case element Var.
"""
_var_data = case_element._var_data if isinstance(case_element, Style) else None # type: ignore
case_element = LiteralVar.create(case_element)
if _var_data is not None:
case_element._var_data = VarData.merge(
case_element._get_all_var_data(), _var_data
) # type: ignore
_var_data = case_element._var_data if isinstance(case_element, Style) else None
case_element = LiteralVar.create(case_element, _var_data=_var_data)
return case_element
@classmethod
def _process_match_cases(cls, cases: List) -> List[List[BaseVar]]:
def _process_match_cases(cls, cases: List) -> List[List[Var]]:
"""Process the individual match cases.
Args:
@ -158,7 +154,7 @@ class Match(MemoizationLeaf):
if not isinstance(element, BaseComponent)
else element
)
if not isinstance(el, (BaseVar, BaseComponent)):
if not isinstance(el, (Var, BaseComponent)):
raise ValueError("Case element must be a var or component")
case_list.append(el)
@ -167,7 +163,7 @@ class Match(MemoizationLeaf):
return match_cases
@classmethod
def _validate_return_types(cls, match_cases: List[List[BaseVar]]) -> None:
def _validate_return_types(cls, match_cases: List[List[Var]]) -> None:
"""Validate that match cases have the same return types.
Args:
@ -181,14 +177,14 @@ class Match(MemoizationLeaf):
if types._isinstance(first_case_return, BaseComponent):
return_type = BaseComponent
elif types._isinstance(first_case_return, BaseVar):
return_type = BaseVar
elif types._isinstance(first_case_return, Var):
return_type = Var
for index, case in enumerate(match_cases):
if not types._issubclass(type(case[-1]), return_type):
raise MatchTypeError(
f"Match cases should have the same return types. Case {index} with return "
f"value `{case[-1]._var_name if isinstance(case[-1], BaseVar) else textwrap.shorten(str(case[-1]), width=250)}`"
f"value `{case[-1]._var_name if isinstance(case[-1], Var) else textwrap.shorten(str(case[-1]), width=250)}`"
f" of type {type(case[-1])!r} is not {return_type}"
)
@ -196,9 +192,9 @@ class Match(MemoizationLeaf):
def _create_match_cond_var_or_component(
cls,
match_cond_var: Var,
match_cases: List[List[BaseVar]],
default: Optional[Union[BaseVar, BaseComponent]],
) -> Union[Component, BaseVar]:
match_cases: List[List[Var]],
default: Optional[Union[Var, BaseComponent]],
) -> Union[Component, Var]:
"""Create and return the match condition var or component.
Args:
@ -229,28 +225,22 @@ class Match(MemoizationLeaf):
# Validate the match cases (as well as the default case) to have Var return types.
if any(
case for case in match_cases if not types._isinstance(case[-1], BaseVar)
) or not types._isinstance(default, BaseVar):
case for case in match_cases if not types._isinstance(case[-1], Var)
) or not types._isinstance(default, Var):
raise ValueError("Return types of match cases should be Vars.")
# match cases and default should all be Vars at this point.
# Retrieve var data of every var in the match cases and default.
var_data = [
*[el._var_data for case in match_cases for el in case],
default._var_data, # type: ignore
]
return match_cond_var._replace(
return ImmutableVar(
_var_name=format.format_match(
cond=match_cond_var._var_name_unwrapped,
match_cases=match_cases, # type: ignore
default=default, # type: ignore
),
_var_type=default._var_type, # type: ignore
_var_is_local=False,
_var_full_name_needs_state_prefix=False,
_var_is_string=False,
merge_var_data=VarData.merge(*var_data),
_var_data=ImmutableVarData.merge(
match_cond_var._get_all_var_data(),
*[el._get_all_var_data() for case in match_cases for el in case],
default._get_all_var_data(), # type: ignore
),
)
def _render(self) -> Tag:

View File

@ -19,8 +19,10 @@ from reflex.event import (
call_script,
parse_args_spec,
)
from reflex.ivars.base import ImmutableCallableVar, ImmutableVar
from reflex.ivars.sequence import LiteralStringVar
from reflex.utils.imports import ImportVar
from reflex.vars import BaseVar, CallableVar, Var, VarData
from reflex.vars import Var, VarData
DEFAULT_UPLOAD_ID: str = "default"
@ -35,8 +37,8 @@ upload_files_context_var_data: VarData = VarData(
)
@CallableVar
def upload_file(id_: str = DEFAULT_UPLOAD_ID) -> BaseVar:
@ImmutableCallableVar
def upload_file(id_: str = DEFAULT_UPLOAD_ID) -> ImmutableVar:
"""Get the file upload drop trigger.
This var is passed to the dropzone component to update the file list when a
@ -48,23 +50,25 @@ def upload_file(id_: str = DEFAULT_UPLOAD_ID) -> BaseVar:
Returns:
A var referencing the file upload drop trigger.
"""
id_var = Var.create_safe(id_, _var_is_string=True)
id_var = LiteralStringVar.create(id_)
var_name = f"""e => setFilesById(filesById => {{
const updatedFilesById = Object.assign({{}}, filesById);
updatedFilesById[{id_var._var_name_unwrapped}] = e;
updatedFilesById[{str(id_var)}] = e;
return updatedFilesById;
}})
"""
return BaseVar(
return ImmutableVar(
_var_name=var_name,
_var_type=EventChain,
_var_data=VarData.merge(upload_files_context_var_data, id_var._var_data),
_var_data=VarData.merge(
upload_files_context_var_data, id_var._get_all_var_data()
),
)
@CallableVar
def selected_files(id_: str = DEFAULT_UPLOAD_ID) -> BaseVar:
@ImmutableCallableVar
def selected_files(id_: str = DEFAULT_UPLOAD_ID) -> ImmutableVar:
"""Get the list of selected files.
Args:
@ -73,12 +77,14 @@ def selected_files(id_: str = DEFAULT_UPLOAD_ID) -> BaseVar:
Returns:
A var referencing the list of selected file paths.
"""
id_var = Var.create_safe(id_, _var_is_string=True)
return BaseVar(
_var_name=f"(filesById[{id_var._var_name_unwrapped}] ? filesById[{id_var._var_name_unwrapped}].map((f) => (f.path || f.name)) : [])",
id_var = LiteralStringVar.create(id_)
return ImmutableVar(
_var_name=f"(filesById[{str(id_var)}] ? filesById[{str(id_var)}].map((f) => (f.path || f.name)) : [])",
_var_type=List[str],
_var_data=VarData.merge(upload_files_context_var_data, id_var._var_data),
)
_var_data=VarData.merge(
upload_files_context_var_data, id_var._get_all_var_data()
),
).guess_type()
@CallableEventSpec
@ -245,7 +251,7 @@ class Upload(MemoizationLeaf):
# The file input to use.
upload = Input.create(type="file")
upload.special_props = {
BaseVar(_var_name="{...getInputProps()}", _var_type=None)
ImmutableVar(_var_name="{...getInputProps()}", _var_type=None)
}
# The dropzone to use.
@ -254,7 +260,9 @@ class Upload(MemoizationLeaf):
*children,
**{k: v for k, v in props.items() if k not in supported_props},
)
zone.special_props = {BaseVar(_var_name="{...getRootProps()}", _var_type=None)}
zone.special_props = {
ImmutableVar(_var_name="{...getRootProps()}", _var_type=None)
}
# Create the component.
upload_props["id"] = props.get("id", DEFAULT_UPLOAD_ID)

View File

@ -14,7 +14,7 @@ from reflex.event import EventChain, EventHandler
from reflex.ivars.base import ImmutableVar
from reflex.utils.format import format_event_chain
from reflex.utils.imports import ImportDict
from reflex.vars import BaseVar, Var
from reflex.vars import Var
from .base import BaseHTML
@ -198,7 +198,7 @@ class Form(BaseHTML):
if EventTriggers.ON_SUBMIT in self.event_triggers:
render_tag.add_props(
**{
EventTriggers.ON_SUBMIT: BaseVar(
EventTriggers.ON_SUBMIT: ImmutableVar(
_var_name=f"handleSubmit_{self.handle_submit_unique_name}",
_var_type=EventChain,
)

View File

@ -1027,3 +1027,50 @@ class OrOperation(ImmutableVar):
The VarData of the components and all of its children.
"""
return self._cached_get_all_var_data
@dataclasses.dataclass(
eq=False,
frozen=True,
**{"slots": True} if sys.version_info >= (3, 10) else {},
)
class ImmutableCallableVar(ImmutableVar):
"""Decorate a Var-returning function to act as both a Var and a function.
This is used as a compatibility shim for replacing Var objects in the
API with functions that return a family of Var.
"""
fn: Callable[..., ImmutableVar] = dataclasses.field(
default_factory=lambda: lambda: LiteralVar.create(None)
)
original_var: ImmutableVar = dataclasses.field(
default_factory=lambda: LiteralVar.create(None)
)
def __init__(self, fn: Callable[..., ImmutableVar]):
"""Initialize a CallableVar.
Args:
fn: The function to decorate (must return Var)
"""
original_var = fn()
super(ImmutableCallableVar, self).__init__(
_var_name=original_var._var_name,
_var_type=original_var._var_type,
_var_data=original_var._var_data,
)
object.__setattr__(self, "fn", fn)
object.__setattr__(self, "original_var", original_var)
def __call__(self, *args, **kwargs) -> ImmutableVar:
"""Call the decorated function.
Args:
*args: The args to pass to the function.
**kwargs: The kwargs to pass to the function.
Returns:
The Var returned from calling the function.
"""
return self.fn(*args, **kwargs)

View File

@ -7,7 +7,7 @@ from typing import Any, Literal, Tuple, Type
from reflex import constants
from reflex.components.core.breakpoints import Breakpoints, breakpoints_values
from reflex.event import EventChain
from reflex.ivars.base import ImmutableVar, LiteralVar
from reflex.ivars.base import ImmutableCallableVar, ImmutableVar, LiteralVar
from reflex.ivars.function import FunctionVar
from reflex.utils import format
from reflex.utils.imports import ImportVar
@ -47,7 +47,7 @@ def _color_mode_var(_var_name: str, _var_type: Type = str) -> ImmutableVar:
).guess_type()
# @CallableVar
@ImmutableCallableVar
def set_color_mode(
new_color_mode: LiteralColorMode | Var[LiteralColorMode] | None = None,
) -> Var[EventChain]: