From d2fd0d3b9253831bd19840bda0e2e4586aee8b14 Mon Sep 17 00:00:00 2001 From: Elijah Ahianyo Date: Wed, 31 Jan 2024 02:14:20 +0000 Subject: [PATCH] [REF-1742] Radio group prop types fix (#2452) --- .../radix/themes/components/radiogroup.py | 44 +++++++++++---- .../radix/themes/components/radiogroup.pyi | 2 +- reflex/vars.py | 36 +++++++++++-- tests/test_var.py | 53 +++++++++++++++++++ 4 files changed, 120 insertions(+), 15 deletions(-) diff --git a/reflex/components/radix/themes/components/radiogroup.py b/reflex/components/radix/themes/components/radiogroup.py index 01ceca188..099a7c287 100644 --- a/reflex/components/radix/themes/components/radiogroup.py +++ b/reflex/components/radix/themes/components/radiogroup.py @@ -1,5 +1,5 @@ """Interactive components provided by @radix-ui/themes.""" -from typing import Any, Dict, List, Literal +from typing import Any, Dict, List, Literal, Optional, Union import reflex as rx from reflex.components.component import Component @@ -97,7 +97,11 @@ class HighLevelRadioGroup(RadioGroupRoot): size: Var[Literal["1", "2", "3"]] = Var.create_safe("2") @classmethod - def create(cls, items: Var[List[str]], **props) -> Component: + def create( + cls, + items: Var[List[Optional[Union[str, int, float, list, dict, bool]]]], + **props + ) -> Component: """Create a radio group component. Args: @@ -110,29 +114,49 @@ class HighLevelRadioGroup(RadioGroupRoot): direction = props.pop("direction", "column") gap = props.pop("gap", "2") size = props.pop("size", "2") + default_value = props.pop("default_value", "") + + # convert only non-strings to json(JSON.stringify) so quotes are not rendered + # for string literal types. + if ( + type(default_value) is str + or isinstance(default_value, Var) + and default_value._var_type is str + ): + default_value = Var.create(default_value, _var_is_string=True) # type: ignore + else: + default_value = ( + Var.create(default_value).to_string()._replace(_var_is_local=False) # type: ignore + ) + + def radio_group_item(value: str | Var) -> Component: + item_value = Var.create(value) # type: ignore + item_value = rx.cond( + item_value._type() == str, # type: ignore + item_value, + item_value.to_string()._replace(_var_is_local=False), # type: ignore + )._replace(_var_type=str) - def radio_group_item(value: str) -> Component: return Text.create( Flex.create( - RadioGroupItem.create(value=value), - value, + RadioGroupItem.create(value=item_value), + item_value, gap="2", ), size=size, as_="label", ) - if isinstance(items, Var): - child = [rx.foreach(items, radio_group_item)] - else: - child = [radio_group_item(value) for value in items] # type: ignore + items = Var.create(items) # type: ignore + children = [rx.foreach(items, radio_group_item)] return RadioGroupRoot.create( Flex.create( - *child, + *children, direction=direction, gap=gap, ), size=size, + default_value=default_value, **props, ) diff --git a/reflex/components/radix/themes/components/radiogroup.pyi b/reflex/components/radix/themes/components/radiogroup.pyi index 1f3b4075b..8f4b0a15c 100644 --- a/reflex/components/radix/themes/components/radiogroup.pyi +++ b/reflex/components/radix/themes/components/radiogroup.pyi @@ -7,7 +7,7 @@ from typing import Any, Dict, Literal, Optional, Union, overload from reflex.vars import Var, BaseVar, ComputedVar from reflex.event import EventChain, EventHandler, EventSpec from reflex.style import Style -from typing import Any, Dict, List, Literal +from typing import Any, Dict, List, Literal, Optional, Union import reflex as rx from reflex.components.component import Component from reflex.components.radix.themes.layout.flex import Flex diff --git a/reflex/vars.py b/reflex/vars.py index 3fba2bdde..414bb357b 100644 --- a/reflex/vars.py +++ b/reflex/vars.py @@ -87,6 +87,15 @@ REPLACED_NAMES = { "deps": "_deps", } +PYTHON_JS_TYPE_MAP = { + (int, float): "number", + (str,): "string", + (bool,): "boolean", + (list, tuple): "Array", + (dict,): "Object", + (None,): "null", +} + def get_unique_variable_name() -> str: """Get a unique variable name. @@ -739,13 +748,13 @@ class Var: operation_name = format.wrap(operation_name, "(") else: # apply operator to left operand ( left_operand) - operation_name = f"{op}{self._var_full_name}" + operation_name = f"{op}{get_operand_full_name(self)}" # apply function to operands if fn is not None: operation_name = ( f"{fn}({operation_name})" if not invoke_fn - else f"{self._var_full_name}.{fn}()" + else f"{get_operand_full_name(self)}.{fn}()" ) return self._replace( @@ -839,7 +848,20 @@ class Var: _var_is_string=False, ) - def __eq__(self, other: Var) -> Var: + def _type(self) -> Var: + """Get the type of the Var in Javascript. + + Returns: + A var representing the type check. + """ + return self._replace( + _var_name=f"typeof {self._var_full_name}", + _var_type=str, + _var_is_string=False, + _var_full_name_needs_state_prefix=False, + ) + + def __eq__(self, other: Union[Var, Type]) -> Var: """Perform an equality comparison. Args: @@ -848,9 +870,12 @@ class Var: Returns: A var representing the equality comparison. """ + for python_types, js_type in PYTHON_JS_TYPE_MAP.items(): + if not isinstance(other, Var) and other in python_types: + return self.compare("===", Var.create(js_type, _var_is_string=True)) # type: ignore return self.compare("===", other) - def __ne__(self, other: Var) -> Var: + def __ne__(self, other: Union[Var, Type]) -> Var: """Perform an inequality comparison. Args: @@ -859,6 +884,9 @@ class Var: Returns: A var representing the inequality comparison. """ + for python_types, js_type in PYTHON_JS_TYPE_MAP.items(): + if not isinstance(other, Var) and other in python_types: + return self.compare("!==", Var.create(js_type, _var_is_string=True)) # type: ignore return self.compare("!==", other) def __gt__(self, other: Var) -> Var: diff --git a/tests/test_var.py b/tests/test_var.py index 5833e28f0..df9b7453e 100644 --- a/tests/test_var.py +++ b/tests/test_var.py @@ -316,6 +316,59 @@ def test_basic_operations(TestObj): str(BaseVar(_var_name="foo", _var_type=list).reverse()) == "{[...foo].reverse()}" ) + assert str(BaseVar(_var_name="foo", _var_type=str)._type()) == "{typeof foo}" # type: ignore + assert ( + str(BaseVar(_var_name="foo", _var_type=str)._type() == str) # type: ignore + == "{(typeof foo === `string`)}" + ) + assert ( + str(BaseVar(_var_name="foo", _var_type=str)._type() == str) # type: ignore + == "{(typeof foo === `string`)}" + ) + assert ( + str(BaseVar(_var_name="foo", _var_type=str)._type() == int) # type: ignore + == "{(typeof foo === `number`)}" + ) + assert ( + str(BaseVar(_var_name="foo", _var_type=str)._type() == list) # type: ignore + == "{(typeof foo === `Array`)}" + ) + assert ( + str(BaseVar(_var_name="foo", _var_type=str)._type() == float) # type: ignore + == "{(typeof foo === `number`)}" + ) + assert ( + str(BaseVar(_var_name="foo", _var_type=str)._type() == tuple) # type: ignore + == "{(typeof foo === `Array`)}" + ) + assert ( + str(BaseVar(_var_name="foo", _var_type=str)._type() == dict) # type: ignore + == "{(typeof foo === `Object`)}" + ) + assert ( + str(BaseVar(_var_name="foo", _var_type=str)._type() != str) # type: ignore + == "{(typeof foo !== `string`)}" + ) + assert ( + str(BaseVar(_var_name="foo", _var_type=str)._type() != int) # type: ignore + == "{(typeof foo !== `number`)}" + ) + assert ( + str(BaseVar(_var_name="foo", _var_type=str)._type() != list) # type: ignore + == "{(typeof foo !== `Array`)}" + ) + assert ( + str(BaseVar(_var_name="foo", _var_type=str)._type() != float) # type: ignore + == "{(typeof foo !== `number`)}" + ) + assert ( + str(BaseVar(_var_name="foo", _var_type=str)._type() != tuple) # type: ignore + == "{(typeof foo !== `Array`)}" + ) + assert ( + str(BaseVar(_var_name="foo", _var_type=str)._type() != dict) # type: ignore + == "{(typeof foo !== `Object`)}" + ) @pytest.mark.parametrize(