fix tests
This commit is contained in:
parent
982e4c110d
commit
f6938cffa6
@ -813,14 +813,14 @@ class LiteralVar(ImmutableVar):
|
||||
|
||||
if isinstance(value, Figure):
|
||||
return LiteralObjectVar.create(
|
||||
json.loads(to_json(value)), _var_type=Figure, _var_data=_var_data
|
||||
json.loads(str(to_json(value))), _var_type=Figure, _var_data=_var_data
|
||||
)
|
||||
|
||||
if isinstance(value, layout.Template):
|
||||
return LiteralObjectVar.create(
|
||||
{
|
||||
"data": json.loads(to_json(value.data)),
|
||||
"layout": json.loads(to_json(value.layout)),
|
||||
"data": json.loads(str(to_json(value.data))),
|
||||
"layout": json.loads(str(to_json(value.layout))),
|
||||
},
|
||||
_var_type=layout.Template,
|
||||
_var_data=_var_data,
|
||||
@ -1206,12 +1206,20 @@ class ImmutableCallableVar(ImmutableVar):
|
||||
return hash((self.__class__.__name__, self.original_var))
|
||||
|
||||
|
||||
RETURN_TYPE = TypeVar("RETURN_TYPE")
|
||||
|
||||
DICT_KEY = TypeVar("DICT_KEY")
|
||||
DICT_VAL = TypeVar("DICT_VAL")
|
||||
|
||||
LIST_INSIDE = TypeVar("LIST_INSIDE")
|
||||
|
||||
|
||||
@dataclasses.dataclass(
|
||||
eq=False,
|
||||
frozen=True,
|
||||
**{"slots": True} if sys.version_info >= (3, 10) else {},
|
||||
)
|
||||
class ImmutableComputedVar(ImmutableVar):
|
||||
class ImmutableComputedVar(ImmutableVar[RETURN_TYPE]):
|
||||
"""A field with computed getters."""
|
||||
|
||||
# Whether to track dependencies and cache computed values
|
||||
@ -1221,7 +1229,7 @@ class ImmutableComputedVar(ImmutableVar):
|
||||
_backend: bool = dataclasses.field(default=False)
|
||||
|
||||
# The initial value of the computed var
|
||||
_initial_value: Any | types.Unset = dataclasses.field(default=types.Unset())
|
||||
_initial_value: RETURN_TYPE | types.Unset = dataclasses.field(default=types.Unset())
|
||||
|
||||
# Explicit var dependencies to track
|
||||
_static_deps: set[str] = dataclasses.field(default_factory=set)
|
||||
@ -1232,14 +1240,14 @@ class ImmutableComputedVar(ImmutableVar):
|
||||
# Interval at which the computed var should be updated
|
||||
_update_interval: Optional[datetime.timedelta] = dataclasses.field(default=None)
|
||||
|
||||
_fget: Callable[[BaseState], Any] = dataclasses.field(
|
||||
_fget: Callable[[BaseState], RETURN_TYPE] = dataclasses.field(
|
||||
default_factory=lambda: lambda _: None
|
||||
)
|
||||
) # type: ignore
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
fget: Callable[[BaseState], Any],
|
||||
initial_value: Any | types.Unset = types.Unset(),
|
||||
fget: Callable[[BASE_STATE], RETURN_TYPE],
|
||||
initial_value: RETURN_TYPE | types.Unset = types.Unset(),
|
||||
cache: bool = False,
|
||||
deps: Optional[List[Union[str, Var]]] = None,
|
||||
auto_deps: bool = True,
|
||||
@ -1381,6 +1389,56 @@ class ImmutableComputedVar(ImmutableVar):
|
||||
return True
|
||||
return datetime.datetime.now() - last_updated > self._update_interval
|
||||
|
||||
@overload
|
||||
def __get__(
|
||||
self: ImmutableComputedVar[int] | ImmutableComputedVar[float],
|
||||
instance: None,
|
||||
owner: Type,
|
||||
) -> NumberVar: ...
|
||||
|
||||
@overload
|
||||
def __get__(
|
||||
self: ImmutableComputedVar[str],
|
||||
instance: None,
|
||||
owner: Type,
|
||||
) -> StringVar: ...
|
||||
|
||||
@overload
|
||||
def __get__(
|
||||
self: ImmutableComputedVar[dict[DICT_KEY, DICT_VAL]],
|
||||
instance: None,
|
||||
owner: Type,
|
||||
) -> ObjectVar[dict[DICT_KEY, DICT_VAL]]: ...
|
||||
|
||||
@overload
|
||||
def __get__(
|
||||
self: ImmutableComputedVar[list[LIST_INSIDE]],
|
||||
instance: None,
|
||||
owner: Type,
|
||||
) -> ArrayVar[list[LIST_INSIDE]]: ...
|
||||
|
||||
@overload
|
||||
def __get__(
|
||||
self: ImmutableComputedVar[set[LIST_INSIDE]],
|
||||
instance: None,
|
||||
owner: Type,
|
||||
) -> ArrayVar[set[LIST_INSIDE]]: ...
|
||||
|
||||
@overload
|
||||
def __get__(
|
||||
self: ImmutableComputedVar[tuple[LIST_INSIDE, ...]],
|
||||
instance: None,
|
||||
owner: Type,
|
||||
) -> ArrayVar[tuple[LIST_INSIDE, ...]]: ...
|
||||
|
||||
@overload
|
||||
def __get__(
|
||||
self, instance: None, owner: Type
|
||||
) -> ImmutableComputedVar[RETURN_TYPE]: ...
|
||||
|
||||
@overload
|
||||
def __get__(self, instance: BaseState, owner: Type) -> RETURN_TYPE: ...
|
||||
|
||||
def __get__(self, instance: BaseState | None, owner):
|
||||
"""Get the ComputedVar value.
|
||||
|
||||
@ -1556,7 +1614,7 @@ class ImmutableComputedVar(ImmutableVar):
|
||||
return ComputedVar
|
||||
|
||||
@property
|
||||
def fget(self) -> Callable[[BaseState], Any]:
|
||||
def fget(self) -> Callable[[BaseState], RETURN_TYPE]:
|
||||
"""Get the getter function.
|
||||
|
||||
Returns:
|
||||
@ -1565,8 +1623,42 @@ class ImmutableComputedVar(ImmutableVar):
|
||||
return self._fget
|
||||
|
||||
|
||||
if TYPE_CHECKING:
|
||||
BASE_STATE = TypeVar("BASE_STATE", bound=BaseState)
|
||||
|
||||
|
||||
@overload
|
||||
def immutable_computed_var(
|
||||
fget: Callable[[BaseState], Any] | None = None,
|
||||
fget: None = None,
|
||||
initial_value: Any | types.Unset = types.Unset(),
|
||||
cache: bool = False,
|
||||
deps: Optional[List[Union[str, Var]]] = None,
|
||||
auto_deps: bool = True,
|
||||
interval: Optional[Union[datetime.timedelta, int]] = None,
|
||||
backend: bool | None = None,
|
||||
_deprecated_cached_var: bool = False,
|
||||
**kwargs,
|
||||
) -> Callable[
|
||||
[Callable[[BASE_STATE], RETURN_TYPE]], ImmutableComputedVar[RETURN_TYPE]
|
||||
]: ...
|
||||
|
||||
|
||||
@overload
|
||||
def immutable_computed_var(
|
||||
fget: Callable[[BASE_STATE], RETURN_TYPE],
|
||||
initial_value: RETURN_TYPE | types.Unset = types.Unset(),
|
||||
cache: bool = False,
|
||||
deps: Optional[List[Union[str, Var]]] = None,
|
||||
auto_deps: bool = True,
|
||||
interval: Optional[Union[datetime.timedelta, int]] = None,
|
||||
backend: bool | None = None,
|
||||
_deprecated_cached_var: bool = False,
|
||||
**kwargs,
|
||||
) -> ImmutableComputedVar[RETURN_TYPE]: ...
|
||||
|
||||
|
||||
def immutable_computed_var(
|
||||
fget: Callable[[BASE_STATE], Any] | None = None,
|
||||
initial_value: Any | types.Unset = types.Unset(),
|
||||
cache: bool = False,
|
||||
deps: Optional[List[Union[str, Var]]] = None,
|
||||
@ -1576,7 +1668,7 @@ def immutable_computed_var(
|
||||
_deprecated_cached_var: bool = False,
|
||||
**kwargs,
|
||||
) -> (
|
||||
ImmutableComputedVar | Callable[[Callable[[BaseState], Any]], ImmutableComputedVar]
|
||||
ImmutableComputedVar | Callable[[Callable[[BASE_STATE], Any]], ImmutableComputedVar]
|
||||
):
|
||||
"""A ComputedVar decorator with or without kwargs.
|
||||
|
||||
@ -1615,7 +1707,7 @@ def immutable_computed_var(
|
||||
if fget is not None:
|
||||
return ImmutableComputedVar(fget, cache=cache)
|
||||
|
||||
def wrapper(fget: Callable[[BaseState], Any]) -> ImmutableComputedVar:
|
||||
def wrapper(fget: Callable[[BASE_STATE], Any]) -> ImmutableComputedVar:
|
||||
return ImmutableComputedVar(
|
||||
fget,
|
||||
initial_value=initial_value,
|
||||
|
@ -892,10 +892,10 @@ def test_literal_var():
|
||||
|
||||
|
||||
def test_function_var():
|
||||
addition_func = FunctionStringVar("((a, b) => a + b)")
|
||||
addition_func = FunctionStringVar.create("((a, b) => a + b)")
|
||||
assert str(addition_func.call(1, 2)) == "(((a, b) => a + b)(1, 2))"
|
||||
|
||||
manual_addition_func = ArgsFunctionOperation(
|
||||
manual_addition_func = ArgsFunctionOperation.create(
|
||||
("a", "b"),
|
||||
{
|
||||
"args": [ImmutableVar.create_safe("a"), ImmutableVar.create_safe("b")],
|
||||
@ -913,11 +913,11 @@ def test_function_var():
|
||||
== "(((...args) => ((((a, b) => a + b)(1, ...args))))(2))"
|
||||
)
|
||||
|
||||
create_hello_statement = ArgsFunctionOperation(
|
||||
create_hello_statement = ArgsFunctionOperation.create(
|
||||
("name",), f"Hello, {ImmutableVar.create_safe('name')}!"
|
||||
)
|
||||
first_name = LiteralStringVar("Steven")
|
||||
last_name = LiteralStringVar("Universe")
|
||||
first_name = LiteralStringVar.create("Steven")
|
||||
last_name = LiteralStringVar.create("Universe")
|
||||
assert (
|
||||
str(create_hello_statement.call(f"{first_name} {last_name}"))
|
||||
== '(((name) => (("Hello, "+name+"!")))(("Steven"+" "+"Universe")))'
|
||||
@ -932,7 +932,7 @@ def test_var_operation():
|
||||
assert str(add(1, 2)) == "(1 + 2)"
|
||||
assert str(add(a=4, b=-9)) == "(4 + -9)"
|
||||
|
||||
five = LiteralNumberVar(5)
|
||||
five = LiteralNumberVar.create(5)
|
||||
seven = add(2, five)
|
||||
|
||||
assert isinstance(seven, NumberVar)
|
||||
@ -952,7 +952,7 @@ def test_string_operations():
|
||||
|
||||
|
||||
def test_all_number_operations():
|
||||
starting_number = LiteralNumberVar(-5.4)
|
||||
starting_number = LiteralNumberVar.create(-5.4)
|
||||
|
||||
complicated_number = (((-(starting_number + 1)) * 2 / 3) // 2 % 3) ** 2
|
||||
|
||||
@ -970,16 +970,16 @@ def test_all_number_operations():
|
||||
== "!(((Math.abs(Math.floor(((Math.floor(((-((-5.4 + 1)) * 2) / 3) / 2) % 3) ** 2))) != 0) || (true && (Math.round(((Math.floor(((-((-5.4 + 1)) * 2) / 3) / 2) % 3) ** 2)) != 0))))"
|
||||
)
|
||||
|
||||
assert str(LiteralNumberVar(5) > False) == "(5 > 0)"
|
||||
assert str(LiteralBooleanVar(False) < 5) == "((false ? 1 : 0) < 5)"
|
||||
assert str(LiteralNumberVar.create(5) > False) == "(5 > 0)"
|
||||
assert str(LiteralBooleanVar.create(False) < 5) == "((false ? 1 : 0) < 5)"
|
||||
assert (
|
||||
str(LiteralBooleanVar(False) < LiteralBooleanVar(True))
|
||||
str(LiteralBooleanVar.create(False) < LiteralBooleanVar.create(True))
|
||||
== "((false ? 1 : 0) < (true ? 1 : 0))"
|
||||
)
|
||||
|
||||
|
||||
def test_index_operation():
|
||||
array_var = LiteralArrayVar([1, 2, 3, 4, 5])
|
||||
array_var = LiteralArrayVar.create([1, 2, 3, 4, 5])
|
||||
assert str(array_var[0]) == "[1, 2, 3, 4, 5].at(0)"
|
||||
assert str(array_var[1:2]) == "[1, 2, 3, 4, 5].slice(1, 2)"
|
||||
assert (
|
||||
@ -1019,7 +1019,7 @@ def test_array_operations():
|
||||
|
||||
|
||||
def test_object_operations():
|
||||
object_var = LiteralObjectVar({"a": 1, "b": 2, "c": 3})
|
||||
object_var = LiteralObjectVar.create({"a": 1, "b": 2, "c": 3})
|
||||
|
||||
assert (
|
||||
str(object_var.keys()) == 'Object.keys(({ ["a"] : 1, ["b"] : 2, ["c"] : 3 }))'
|
||||
@ -1035,13 +1035,13 @@ def test_object_operations():
|
||||
assert str(object_var.a) == '({ ["a"] : 1, ["b"] : 2, ["c"] : 3 })["a"]'
|
||||
assert str(object_var["a"]) == '({ ["a"] : 1, ["b"] : 2, ["c"] : 3 })["a"]'
|
||||
assert (
|
||||
str(object_var.merge(LiteralObjectVar({"c": 4, "d": 5})))
|
||||
str(object_var.merge(LiteralObjectVar.create({"c": 4, "d": 5})))
|
||||
== 'Object.assign(({ ["a"] : 1, ["b"] : 2, ["c"] : 3 }), ({ ["c"] : 4, ["d"] : 5 }))'
|
||||
)
|
||||
|
||||
|
||||
def test_type_chains():
|
||||
object_var = LiteralObjectVar({"a": 1, "b": 2, "c": 3})
|
||||
object_var = LiteralObjectVar.create({"a": 1, "b": 2, "c": 3})
|
||||
assert (object_var._key_type(), object_var._value_type()) == (str, int)
|
||||
assert (object_var.keys()._var_type, object_var.values()._var_type) == (
|
||||
List[str],
|
||||
@ -1062,7 +1062,7 @@ def test_type_chains():
|
||||
|
||||
|
||||
def test_nested_dict():
|
||||
arr = LiteralArrayVar([{"bar": ["foo", "bar"]}], List[Dict[str, List[str]]])
|
||||
arr = LiteralArrayVar.create([{"bar": ["foo", "bar"]}], List[Dict[str, List[str]]])
|
||||
|
||||
assert (
|
||||
str(arr[0]["bar"][0]) == '[({ ["bar"] : ["foo", "bar"] })].at(0)["bar"].at(0)'
|
||||
|
Loading…
Reference in New Issue
Block a user