From d516b3bfc494ad2bbe1f84cb259066f3cf902195 Mon Sep 17 00:00:00 2001 From: Khaleel Al-Adhami Date: Wed, 14 Aug 2024 14:36:37 -0700 Subject: [PATCH] pass all testcases --- reflex/components/datadisplay/code.py | 5 +-- reflex/components/gridjs/datatable.py | 12 +++---- reflex/ivars/base.py | 49 +++++++++++++++++++++++---- reflex/ivars/object.py | 15 +++++--- reflex/vars.py | 4 +-- tests/components/media/test_image.py | 5 +-- tests/test_app.py | 2 +- tests/test_state.py | 29 ++++++++++------ tests/utils/test_format.py | 27 +++++++++------ 9 files changed, 102 insertions(+), 46 deletions(-) diff --git a/reflex/components/datadisplay/code.py b/reflex/components/datadisplay/code.py index 979147f99..c8106062f 100644 --- a/reflex/components/datadisplay/code.py +++ b/reflex/components/datadisplay/code.py @@ -440,9 +440,10 @@ class CodeBlock(Component): def _get_custom_code(self) -> Optional[str]: if ( self.language is not None - and self.language._var_name in LiteralCodeLanguage.__args__ # type: ignore + and (language_without_quotes := str(self.language).replace('"', "")) + in LiteralCodeLanguage.__args__ # type: ignore ): - return f"{self.alias}.registerLanguage('{self.language._var_name}', {format.to_camel_case(self.language._var_name)})" + return f"{self.alias}.registerLanguage('{language_without_quotes}', {format.to_camel_case(language_without_quotes)})" @classmethod def create( diff --git a/reflex/components/gridjs/datatable.py b/reflex/components/gridjs/datatable.py index 075c08d59..47069843b 100644 --- a/reflex/components/gridjs/datatable.py +++ b/reflex/components/gridjs/datatable.py @@ -10,7 +10,7 @@ from reflex.ivars.base import ImmutableComputedVar from reflex.utils import types from reflex.utils.imports import ImportDict from reflex.utils.serializers import serialize -from reflex.vars import BaseVar, ComputedVar, Var +from reflex.vars import ComputedVar, Var class Gridjs(Component): @@ -101,6 +101,8 @@ class DataTable(Gridjs): "column field should be specified when the data field is a list type" ) + print("props", props) + # Create the component. return super().create( *children, @@ -117,17 +119,13 @@ class DataTable(Gridjs): def _render(self) -> Tag: if isinstance(self.data, Var) and types.is_dataframe(self.data._var_type): - self.columns = BaseVar( + self.columns = self.data._replace( _var_name=f"{self.data._var_name}.columns", _var_type=List[Any], - _var_full_name_needs_state_prefix=True, - _var_data=self.data._var_data, ) - self.data = BaseVar( + self.data = self.data._replace( _var_name=f"{self.data._var_name}.data", _var_type=List[List[Any]], - _var_full_name_needs_state_prefix=True, - _var_data=self.data._var_data, ) if types.is_dataframe(type(self.data)): # If given a pandas df break up the data and columns diff --git a/reflex/ivars/base.py b/reflex/ivars/base.py index 540528eba..9b41c7dc6 100644 --- a/reflex/ivars/base.py +++ b/reflex/ivars/base.py @@ -10,6 +10,7 @@ import functools import inspect import json import sys +import warnings from types import CodeType, FunctionType from typing import ( TYPE_CHECKING, @@ -440,6 +441,8 @@ class ImmutableVar(Var, Generic[VAR_TYPE]): return self var_type = self._var_type + if types.is_optional(var_type): + var_type = types.get_args(var_type)[0] fixed_type = var_type if inspect.isclass(var_type) else get_origin(var_type) @@ -450,15 +453,15 @@ class ImmutableVar(Var, Generic[VAR_TYPE]): raise TypeError(f"Unsupported type {var_type} for guess_type.") if issubclass(fixed_type, (int, float)): - return self.to(NumberVar, var_type) + return self.to(NumberVar, self._var_type) if issubclass(fixed_type, dict): - return self.to(ObjectVar, var_type) + return self.to(ObjectVar, self._var_type) if issubclass(fixed_type, (list, tuple, set)): - return self.to(ArrayVar, var_type) + return self.to(ArrayVar, self._var_type) if issubclass(fixed_type, str): return self.to(StringVar) if issubclass(fixed_type, Base): - return self.to(ObjectVar, var_type) + return self.to(ObjectVar, self._var_type) return self def get_default_value(self) -> Any: @@ -837,16 +840,48 @@ class LiteralVar(ImmutableVar): except ImportError: pass + from .sequence import LiteralArrayVar, LiteralStringVar + + try: + import base64 + import io + + from PIL.Image import MIME + from PIL.Image import Image as Img + + if isinstance(value, Img): + with io.BytesIO() as buffer: + value.save(buffer, format=getattr(value, "format", None) or "PNG") + try: + # Newer method to get the mime type, but does not always work. + mimetype = value.get_format_mimetype() + except AttributeError: + try: + # Fallback method + mimetype = MIME[value.format] + except KeyError: + # Unknown mime_type: warn and return image/png and hope the browser can sort it out. + warnings.warn( # noqa: B028 + f"Unknown mime type for {value} {value.format}. Defaulting to image/png" + ) + mimetype = "image/png" + return LiteralStringVar.create( + f"data:{mimetype};base64,{base64.b64encode(buffer.getvalue()).decode()}", + _var_data=_var_data, + ) + except ImportError: + pass + if isinstance(value, Base): return LiteralObjectVar.create( - value.dict(), _var_type=type(value), _var_data=_var_data + {k: (None if callable(v) else v) for k, v in value.dict().items()}, + _var_type=type(value), + _var_data=_var_data, ) if isinstance(value, dict): return LiteralObjectVar.create(value, _var_data=_var_data) - from .sequence import LiteralArrayVar, LiteralStringVar - if isinstance(value, str): return LiteralStringVar.create(value, _var_data=_var_data) diff --git a/reflex/ivars/object.py b/reflex/ivars/object.py index 2bca9fa74..49441c3cd 100644 --- a/reflex/ivars/object.py +++ b/reflex/ivars/object.py @@ -22,6 +22,7 @@ from typing import ( from typing_extensions import get_origin +from reflex.utils import types from reflex.utils.exceptions import VarAttributeError from reflex.utils.types import GenericType, get_attribute_access_type from reflex.vars import ImmutableVarData, Var, VarData @@ -245,11 +246,15 @@ class ObjectVar(ImmutableVar[OBJECT_TYPE]): """ if name.startswith("__") and name.endswith("__"): return getattr(super(type(self), self), name) - fixed_type = ( - self._var_type if isclass(self._var_type) else get_origin(self._var_type) - ) + + var_type = self._var_type + + if types.is_optional(var_type): + var_type = get_args(var_type)[0] + + fixed_type = var_type if isclass(var_type) else get_origin(var_type) if isclass(fixed_type) and not issubclass(fixed_type, dict): - attribute_type = get_attribute_access_type(self._var_type, name) + attribute_type = get_attribute_access_type(var_type, name) if attribute_type is None: raise VarAttributeError( f"The State var `{self._var_name}` has no attribute '{name}' or may have been annotated " @@ -727,6 +732,8 @@ class ObjectItemOperation(ImmutableVar): Returns: The name of the operation. """ + if types.is_optional(self._object._var_type): + return f"{str(self._object)}?.[{str(self._key)}]" return f"{str(self._object)}[{str(self._key)}]" def __getattr__(self, name): diff --git a/reflex/vars.py b/reflex/vars.py index 3a5f08e38..8f3732650 100644 --- a/reflex/vars.py +++ b/reflex/vars.py @@ -753,10 +753,10 @@ class Var: return self._var_name try: return json.loads(self._var_name) - except: + except ValueError: try: return json.loads(self.json()) - except: + except (ValueError, NotImplementedError): return self._var_name def equals(self, other: Var) -> bool: diff --git a/tests/components/media/test_image.py b/tests/components/media/test_image.py index e0dd771f7..198fbc844 100644 --- a/tests/components/media/test_image.py +++ b/tests/components/media/test_image.py @@ -6,6 +6,7 @@ from PIL.Image import Image as Img import reflex as rx from reflex.components.next.image import Image # type: ignore +from reflex.ivars.sequence import StringVar from reflex.utils.serializers import serialize, serialize_image @@ -52,7 +53,7 @@ def test_set_src_img(pil_image: Img): pil_image: The image to serialize. """ image = Image.create(src=pil_image) - assert str(image.src._var_name) == serialize_image(pil_image) # type: ignore + assert str(image.src._var_name) == '"' + serialize_image(pil_image) + '"' # type: ignore def test_render(pil_image: Img): @@ -62,4 +63,4 @@ def test_render(pil_image: Img): pil_image: The image to serialize. """ image = Image.create(src=pil_image) - assert image.src._var_is_string # type: ignore + assert isinstance(image.src, StringVar) diff --git a/tests/test_app.py b/tests/test_app.py index e6b86ddf6..0acb4b5ac 100644 --- a/tests/test_app.py +++ b/tests/test_app.py @@ -1251,7 +1251,7 @@ def test_app_wrap_compile_theme(compilable_app: tuple[App, Path]): "function AppWrap({children}) {" "return (" "" - "" + "" "" "{children}" "" diff --git a/tests/test_state.py b/tests/test_state.py index 5d965cba3..11cab8df8 100644 --- a/tests/test_state.py +++ b/tests/test_state.py @@ -1061,7 +1061,8 @@ def test_dirty_computed_var_from_backend_var( Args: interdependent_state: A state with varying Var dependencies. """ - assert InterdependentState._v3._backend is True + # Accessing ._v3 returns the immutable var it represents instead of the actual computed var + # assert InterdependentState._v3._backend is True interdependent_state._v2 = 2 assert interdependent_state.get_delta() == { interdependent_state.get_full_name(): {"v2x2": 4, "v3x2": 4}, @@ -2601,15 +2602,23 @@ def test_state_union_optional(): c3r: Custom3 = Custom3(c2r=Custom2(c1r=Custom1(foo=""))) custom_union: Union[Custom1, Custom2, Custom3] = Custom1(foo="") - assert UnionState.c3.c2._var_name == "c3?.c2" # type: ignore - assert UnionState.c3.c2.c1._var_name == "c3?.c2?.c1" # type: ignore - assert UnionState.c3.c2.c1.foo._var_name == "c3?.c2?.c1?.foo" # type: ignore - assert UnionState.c3.c2.c1r.foo._var_name == "c3?.c2?.c1r.foo" # type: ignore - assert UnionState.c3.c2r.c1._var_name == "c3?.c2r.c1" # type: ignore - assert UnionState.c3.c2r.c1.foo._var_name == "c3?.c2r.c1?.foo" # type: ignore - assert UnionState.c3.c2r.c1r.foo._var_name == "c3?.c2r.c1r.foo" # type: ignore - assert UnionState.c3i.c2._var_name == "c3i.c2" # type: ignore - assert UnionState.c3r.c2._var_name == "c3r.c2" # type: ignore + assert str(UnionState.c3.c2) == f'{str(UnionState.c3)}?.["c2"]' + assert str(UnionState.c3.c2.c1) == f'{str(UnionState.c3)}?.["c2"]?.["c1"]' + assert ( + str(UnionState.c3.c2.c1.foo) == f'{str(UnionState.c3)}?.["c2"]?.["c1"]?.["foo"]' + ) + assert ( + str(UnionState.c3.c2.c1r.foo) == f'{str(UnionState.c3)}?.["c2"]?.["c1r"]["foo"]' + ) + assert str(UnionState.c3.c2r.c1) == f'{str(UnionState.c3)}?.["c2r"]["c1"]' + assert ( + str(UnionState.c3.c2r.c1.foo) == f'{str(UnionState.c3)}?.["c2r"]["c1"]?.["foo"]' + ) + assert ( + str(UnionState.c3.c2r.c1r.foo) == f'{str(UnionState.c3)}?.["c2r"]["c1r"]["foo"]' + ) + assert str(UnionState.c3i.c2) == f'{str(UnionState.c3i)}["c2"]' + assert str(UnionState.c3r.c2) == f'{str(UnionState.c3r)}["c2"]' assert UnionState.custom_union.foo is not None # type: ignore assert UnionState.custom_union.c1 is not None # type: ignore assert UnionState.custom_union.c1r is not None # type: ignore diff --git a/tests/utils/test_format.py b/tests/utils/test_format.py index 95ebc047b..4623f0fb2 100644 --- a/tests/utils/test_format.py +++ b/tests/utils/test_format.py @@ -8,6 +8,7 @@ import pytest from reflex.components.tags.tag import Tag from reflex.event import EventChain, EventHandler, EventSpec, FrontendEvent +from reflex.ivars.base import ImmutableVar, LiteralVar from reflex.style import Style from reflex.utils import format from reflex.utils.serializers import serialize_figure @@ -422,19 +423,19 @@ def test_format_cond( ( "state__state.value", [ - [Var.create(1), Var.create("red", _var_is_string=True)], - [Var.create(2), Var.create(3), Var.create("blue", _var_is_string=True)], + [LiteralVar.create(1), LiteralVar.create("red")], + [LiteralVar.create(2), LiteralVar.create(3), LiteralVar.create("blue")], [TestState.mapping, TestState.num1], [ - Var.create(f"{TestState.map_key}-key", _var_is_string=True), - Var.create("return-key", _var_is_string=True), + LiteralVar.create(f"{TestState.map_key}-key"), + LiteralVar.create("return-key"), ], ], - Var.create("yellow", _var_is_string=True), - "(() => { switch (JSON.stringify(state__state.value)) {case JSON.stringify(1): return (`red`); break;case JSON.stringify(2): case JSON.stringify(3): " - f"return (`blue`); break;case JSON.stringify({TestState.get_full_name()}.mapping): return " - f"({TestState.get_full_name()}.num1); break;case JSON.stringify(`${{{TestState.get_full_name()}.map_key}}-key`): return (`return-key`);" - " break;default: return (`yellow`); break;};})()", + LiteralVar.create("yellow"), + '(() => { switch (JSON.stringify(state__state.value)) {case JSON.stringify(1): return ("red"); break;case JSON.stringify(2): case JSON.stringify(3): ' + f'return ("blue"); break;case JSON.stringify({TestState.get_full_name()}.mapping): return ' + f'({TestState.get_full_name()}.num1); break;case JSON.stringify(({TestState.get_full_name()}.map_key+"-key")): return ("return-key");' + ' break;default: return ("yellow"); break;};})()', ) ], ) @@ -541,7 +542,7 @@ def test_format_match( { "h1": f"{{({{node, ...props}}) => }}" }, - '{{"h1": ({node, ...props}) => }}', + '{{"h1": ({node, ...props}) => }}', ), ], ) @@ -558,7 +559,11 @@ def test_format_prop(prop: Var, formatted: str): @pytest.mark.parametrize( "single_props,key_value_props,output", [ - (["string"], {"key": 42}, ["key={42}", "string"]), + ( + [ImmutableVar.create_safe("...props")], + {"key": 42}, + ["key={42}", "{...props}"], + ), ], ) def test_format_props(single_props, key_value_props, output):