pass all testcases

This commit is contained in:
Khaleel Al-Adhami 2024-08-14 14:36:37 -07:00
parent 39bc0c0b57
commit d516b3bfc4
9 changed files with 102 additions and 46 deletions

View File

@ -440,9 +440,10 @@ class CodeBlock(Component):
def _get_custom_code(self) -> Optional[str]: def _get_custom_code(self) -> Optional[str]:
if ( if (
self.language is not None self.language is not None
and self.language._var_name in LiteralCodeLanguage.__args__ # type: ignore and (language_without_quotes := str(self.language).replace('"', ""))
in LiteralCodeLanguage.__args__ # type: ignore
): ):
return f"{self.alias}.registerLanguage('{self.language._var_name}', {format.to_camel_case(self.language._var_name)})" return f"{self.alias}.registerLanguage('{language_without_quotes}', {format.to_camel_case(language_without_quotes)})"
@classmethod @classmethod
def create( def create(

View File

@ -10,7 +10,7 @@ from reflex.ivars.base import ImmutableComputedVar
from reflex.utils import types from reflex.utils import types
from reflex.utils.imports import ImportDict from reflex.utils.imports import ImportDict
from reflex.utils.serializers import serialize from reflex.utils.serializers import serialize
from reflex.vars import BaseVar, ComputedVar, Var from reflex.vars import ComputedVar, Var
class Gridjs(Component): class Gridjs(Component):
@ -101,6 +101,8 @@ class DataTable(Gridjs):
"column field should be specified when the data field is a list type" "column field should be specified when the data field is a list type"
) )
print("props", props)
# Create the component. # Create the component.
return super().create( return super().create(
*children, *children,
@ -117,17 +119,13 @@ class DataTable(Gridjs):
def _render(self) -> Tag: def _render(self) -> Tag:
if isinstance(self.data, Var) and types.is_dataframe(self.data._var_type): if isinstance(self.data, Var) and types.is_dataframe(self.data._var_type):
self.columns = BaseVar( self.columns = self.data._replace(
_var_name=f"{self.data._var_name}.columns", _var_name=f"{self.data._var_name}.columns",
_var_type=List[Any], _var_type=List[Any],
_var_full_name_needs_state_prefix=True,
_var_data=self.data._var_data,
) )
self.data = BaseVar( self.data = self.data._replace(
_var_name=f"{self.data._var_name}.data", _var_name=f"{self.data._var_name}.data",
_var_type=List[List[Any]], _var_type=List[List[Any]],
_var_full_name_needs_state_prefix=True,
_var_data=self.data._var_data,
) )
if types.is_dataframe(type(self.data)): if types.is_dataframe(type(self.data)):
# If given a pandas df break up the data and columns # If given a pandas df break up the data and columns

View File

@ -10,6 +10,7 @@ import functools
import inspect import inspect
import json import json
import sys import sys
import warnings
from types import CodeType, FunctionType from types import CodeType, FunctionType
from typing import ( from typing import (
TYPE_CHECKING, TYPE_CHECKING,
@ -440,6 +441,8 @@ class ImmutableVar(Var, Generic[VAR_TYPE]):
return self return self
var_type = self._var_type var_type = self._var_type
if types.is_optional(var_type):
var_type = types.get_args(var_type)[0]
fixed_type = var_type if inspect.isclass(var_type) else get_origin(var_type) fixed_type = var_type if inspect.isclass(var_type) else get_origin(var_type)
@ -450,15 +453,15 @@ class ImmutableVar(Var, Generic[VAR_TYPE]):
raise TypeError(f"Unsupported type {var_type} for guess_type.") raise TypeError(f"Unsupported type {var_type} for guess_type.")
if issubclass(fixed_type, (int, float)): if issubclass(fixed_type, (int, float)):
return self.to(NumberVar, var_type) return self.to(NumberVar, self._var_type)
if issubclass(fixed_type, dict): if issubclass(fixed_type, dict):
return self.to(ObjectVar, var_type) return self.to(ObjectVar, self._var_type)
if issubclass(fixed_type, (list, tuple, set)): if issubclass(fixed_type, (list, tuple, set)):
return self.to(ArrayVar, var_type) return self.to(ArrayVar, self._var_type)
if issubclass(fixed_type, str): if issubclass(fixed_type, str):
return self.to(StringVar) return self.to(StringVar)
if issubclass(fixed_type, Base): if issubclass(fixed_type, Base):
return self.to(ObjectVar, var_type) return self.to(ObjectVar, self._var_type)
return self return self
def get_default_value(self) -> Any: def get_default_value(self) -> Any:
@ -837,16 +840,48 @@ class LiteralVar(ImmutableVar):
except ImportError: except ImportError:
pass pass
from .sequence import LiteralArrayVar, LiteralStringVar
try:
import base64
import io
from PIL.Image import MIME
from PIL.Image import Image as Img
if isinstance(value, Img):
with io.BytesIO() as buffer:
value.save(buffer, format=getattr(value, "format", None) or "PNG")
try:
# Newer method to get the mime type, but does not always work.
mimetype = value.get_format_mimetype()
except AttributeError:
try:
# Fallback method
mimetype = MIME[value.format]
except KeyError:
# Unknown mime_type: warn and return image/png and hope the browser can sort it out.
warnings.warn( # noqa: B028
f"Unknown mime type for {value} {value.format}. Defaulting to image/png"
)
mimetype = "image/png"
return LiteralStringVar.create(
f"data:{mimetype};base64,{base64.b64encode(buffer.getvalue()).decode()}",
_var_data=_var_data,
)
except ImportError:
pass
if isinstance(value, Base): if isinstance(value, Base):
return LiteralObjectVar.create( return LiteralObjectVar.create(
value.dict(), _var_type=type(value), _var_data=_var_data {k: (None if callable(v) else v) for k, v in value.dict().items()},
_var_type=type(value),
_var_data=_var_data,
) )
if isinstance(value, dict): if isinstance(value, dict):
return LiteralObjectVar.create(value, _var_data=_var_data) return LiteralObjectVar.create(value, _var_data=_var_data)
from .sequence import LiteralArrayVar, LiteralStringVar
if isinstance(value, str): if isinstance(value, str):
return LiteralStringVar.create(value, _var_data=_var_data) return LiteralStringVar.create(value, _var_data=_var_data)

View File

@ -22,6 +22,7 @@ from typing import (
from typing_extensions import get_origin from typing_extensions import get_origin
from reflex.utils import types
from reflex.utils.exceptions import VarAttributeError from reflex.utils.exceptions import VarAttributeError
from reflex.utils.types import GenericType, get_attribute_access_type from reflex.utils.types import GenericType, get_attribute_access_type
from reflex.vars import ImmutableVarData, Var, VarData from reflex.vars import ImmutableVarData, Var, VarData
@ -245,11 +246,15 @@ class ObjectVar(ImmutableVar[OBJECT_TYPE]):
""" """
if name.startswith("__") and name.endswith("__"): if name.startswith("__") and name.endswith("__"):
return getattr(super(type(self), self), name) return getattr(super(type(self), self), name)
fixed_type = (
self._var_type if isclass(self._var_type) else get_origin(self._var_type) var_type = self._var_type
)
if types.is_optional(var_type):
var_type = get_args(var_type)[0]
fixed_type = var_type if isclass(var_type) else get_origin(var_type)
if isclass(fixed_type) and not issubclass(fixed_type, dict): if isclass(fixed_type) and not issubclass(fixed_type, dict):
attribute_type = get_attribute_access_type(self._var_type, name) attribute_type = get_attribute_access_type(var_type, name)
if attribute_type is None: if attribute_type is None:
raise VarAttributeError( raise VarAttributeError(
f"The State var `{self._var_name}` has no attribute '{name}' or may have been annotated " f"The State var `{self._var_name}` has no attribute '{name}' or may have been annotated "
@ -727,6 +732,8 @@ class ObjectItemOperation(ImmutableVar):
Returns: Returns:
The name of the operation. The name of the operation.
""" """
if types.is_optional(self._object._var_type):
return f"{str(self._object)}?.[{str(self._key)}]"
return f"{str(self._object)}[{str(self._key)}]" return f"{str(self._object)}[{str(self._key)}]"
def __getattr__(self, name): def __getattr__(self, name):

View File

@ -753,10 +753,10 @@ class Var:
return self._var_name return self._var_name
try: try:
return json.loads(self._var_name) return json.loads(self._var_name)
except: except ValueError:
try: try:
return json.loads(self.json()) return json.loads(self.json())
except: except (ValueError, NotImplementedError):
return self._var_name return self._var_name
def equals(self, other: Var) -> bool: def equals(self, other: Var) -> bool:

View File

@ -6,6 +6,7 @@ from PIL.Image import Image as Img
import reflex as rx import reflex as rx
from reflex.components.next.image import Image # type: ignore from reflex.components.next.image import Image # type: ignore
from reflex.ivars.sequence import StringVar
from reflex.utils.serializers import serialize, serialize_image from reflex.utils.serializers import serialize, serialize_image
@ -52,7 +53,7 @@ def test_set_src_img(pil_image: Img):
pil_image: The image to serialize. pil_image: The image to serialize.
""" """
image = Image.create(src=pil_image) image = Image.create(src=pil_image)
assert str(image.src._var_name) == serialize_image(pil_image) # type: ignore assert str(image.src._var_name) == '"' + serialize_image(pil_image) + '"' # type: ignore
def test_render(pil_image: Img): def test_render(pil_image: Img):
@ -62,4 +63,4 @@ def test_render(pil_image: Img):
pil_image: The image to serialize. pil_image: The image to serialize.
""" """
image = Image.create(src=pil_image) image = Image.create(src=pil_image)
assert image.src._var_is_string # type: ignore assert isinstance(image.src, StringVar)

View File

@ -1251,7 +1251,7 @@ def test_app_wrap_compile_theme(compilable_app: tuple[App, Path]):
"function AppWrap({children}) {" "function AppWrap({children}) {"
"return (" "return ("
"<RadixThemesColorModeProvider>" "<RadixThemesColorModeProvider>"
"<RadixThemesTheme accentColor={`plum`} css={{...theme.styles.global[':root'], ...theme.styles.global.body}}>" "<RadixThemesTheme accentColor={\"plum\"} css={{...theme.styles.global[':root'], ...theme.styles.global.body}}>"
"<Fragment>" "<Fragment>"
"{children}" "{children}"
"</Fragment>" "</Fragment>"

View File

@ -1061,7 +1061,8 @@ def test_dirty_computed_var_from_backend_var(
Args: Args:
interdependent_state: A state with varying Var dependencies. interdependent_state: A state with varying Var dependencies.
""" """
assert InterdependentState._v3._backend is True # Accessing ._v3 returns the immutable var it represents instead of the actual computed var
# assert InterdependentState._v3._backend is True
interdependent_state._v2 = 2 interdependent_state._v2 = 2
assert interdependent_state.get_delta() == { assert interdependent_state.get_delta() == {
interdependent_state.get_full_name(): {"v2x2": 4, "v3x2": 4}, interdependent_state.get_full_name(): {"v2x2": 4, "v3x2": 4},
@ -2601,15 +2602,23 @@ def test_state_union_optional():
c3r: Custom3 = Custom3(c2r=Custom2(c1r=Custom1(foo=""))) c3r: Custom3 = Custom3(c2r=Custom2(c1r=Custom1(foo="")))
custom_union: Union[Custom1, Custom2, Custom3] = Custom1(foo="") custom_union: Union[Custom1, Custom2, Custom3] = Custom1(foo="")
assert UnionState.c3.c2._var_name == "c3?.c2" # type: ignore assert str(UnionState.c3.c2) == f'{str(UnionState.c3)}?.["c2"]'
assert UnionState.c3.c2.c1._var_name == "c3?.c2?.c1" # type: ignore assert str(UnionState.c3.c2.c1) == f'{str(UnionState.c3)}?.["c2"]?.["c1"]'
assert UnionState.c3.c2.c1.foo._var_name == "c3?.c2?.c1?.foo" # type: ignore assert (
assert UnionState.c3.c2.c1r.foo._var_name == "c3?.c2?.c1r.foo" # type: ignore str(UnionState.c3.c2.c1.foo) == f'{str(UnionState.c3)}?.["c2"]?.["c1"]?.["foo"]'
assert UnionState.c3.c2r.c1._var_name == "c3?.c2r.c1" # type: ignore )
assert UnionState.c3.c2r.c1.foo._var_name == "c3?.c2r.c1?.foo" # type: ignore assert (
assert UnionState.c3.c2r.c1r.foo._var_name == "c3?.c2r.c1r.foo" # type: ignore str(UnionState.c3.c2.c1r.foo) == f'{str(UnionState.c3)}?.["c2"]?.["c1r"]["foo"]'
assert UnionState.c3i.c2._var_name == "c3i.c2" # type: ignore )
assert UnionState.c3r.c2._var_name == "c3r.c2" # type: ignore assert str(UnionState.c3.c2r.c1) == f'{str(UnionState.c3)}?.["c2r"]["c1"]'
assert (
str(UnionState.c3.c2r.c1.foo) == f'{str(UnionState.c3)}?.["c2r"]["c1"]?.["foo"]'
)
assert (
str(UnionState.c3.c2r.c1r.foo) == f'{str(UnionState.c3)}?.["c2r"]["c1r"]["foo"]'
)
assert str(UnionState.c3i.c2) == f'{str(UnionState.c3i)}["c2"]'
assert str(UnionState.c3r.c2) == f'{str(UnionState.c3r)}["c2"]'
assert UnionState.custom_union.foo is not None # type: ignore assert UnionState.custom_union.foo is not None # type: ignore
assert UnionState.custom_union.c1 is not None # type: ignore assert UnionState.custom_union.c1 is not None # type: ignore
assert UnionState.custom_union.c1r is not None # type: ignore assert UnionState.custom_union.c1r is not None # type: ignore

View File

@ -8,6 +8,7 @@ import pytest
from reflex.components.tags.tag import Tag from reflex.components.tags.tag import Tag
from reflex.event import EventChain, EventHandler, EventSpec, FrontendEvent from reflex.event import EventChain, EventHandler, EventSpec, FrontendEvent
from reflex.ivars.base import ImmutableVar, LiteralVar
from reflex.style import Style from reflex.style import Style
from reflex.utils import format from reflex.utils import format
from reflex.utils.serializers import serialize_figure from reflex.utils.serializers import serialize_figure
@ -422,19 +423,19 @@ def test_format_cond(
( (
"state__state.value", "state__state.value",
[ [
[Var.create(1), Var.create("red", _var_is_string=True)], [LiteralVar.create(1), LiteralVar.create("red")],
[Var.create(2), Var.create(3), Var.create("blue", _var_is_string=True)], [LiteralVar.create(2), LiteralVar.create(3), LiteralVar.create("blue")],
[TestState.mapping, TestState.num1], [TestState.mapping, TestState.num1],
[ [
Var.create(f"{TestState.map_key}-key", _var_is_string=True), LiteralVar.create(f"{TestState.map_key}-key"),
Var.create("return-key", _var_is_string=True), LiteralVar.create("return-key"),
], ],
], ],
Var.create("yellow", _var_is_string=True), LiteralVar.create("yellow"),
"(() => { switch (JSON.stringify(state__state.value)) {case JSON.stringify(1): return (`red`); break;case JSON.stringify(2): case JSON.stringify(3): " '(() => { switch (JSON.stringify(state__state.value)) {case JSON.stringify(1): return ("red"); break;case JSON.stringify(2): case JSON.stringify(3): '
f"return (`blue`); break;case JSON.stringify({TestState.get_full_name()}.mapping): return " f'return ("blue"); break;case JSON.stringify({TestState.get_full_name()}.mapping): return '
f"({TestState.get_full_name()}.num1); break;case JSON.stringify(`${{{TestState.get_full_name()}.map_key}}-key`): return (`return-key`);" f'({TestState.get_full_name()}.num1); break;case JSON.stringify(({TestState.get_full_name()}.map_key+"-key")): return ("return-key");'
" break;default: return (`yellow`); break;};})()", ' break;default: return ("yellow"); break;};})()',
) )
], ],
) )
@ -541,7 +542,7 @@ def test_format_match(
{ {
"h1": f"{{({{node, ...props}}) => <Heading {{...props}} {''.join(Tag(name='', props=Style({'as_': 'h1'})).format_props())} />}}" "h1": f"{{({{node, ...props}}) => <Heading {{...props}} {''.join(Tag(name='', props=Style({'as_': 'h1'})).format_props())} />}}"
}, },
'{{"h1": ({node, ...props}) => <Heading {...props} as={`h1`} />}}', '{{"h1": ({node, ...props}) => <Heading {...props} as={"h1"} />}}',
), ),
], ],
) )
@ -558,7 +559,11 @@ def test_format_prop(prop: Var, formatted: str):
@pytest.mark.parametrize( @pytest.mark.parametrize(
"single_props,key_value_props,output", "single_props,key_value_props,output",
[ [
(["string"], {"key": 42}, ["key={42}", "string"]), (
[ImmutableVar.create_safe("...props")],
{"key": 42},
["key={42}", "{...props}"],
),
], ],
) )
def test_format_props(single_props, key_value_props, output): def test_format_props(single_props, key_value_props, output):