From e31b458a692caf780e2738fa31a4cf0f801ad071 Mon Sep 17 00:00:00 2001 From: benedikt-bartscher <31854409+benedikt-bartscher@users.noreply.github.com> Date: Wed, 1 May 2024 22:33:38 +0200 Subject: [PATCH] allow optional props with None default value (#3179) --- reflex/components/component.py | 23 ++++- tests/components/test_component.py | 158 ++++++++++++++++++++++++++++- tests/utils/test_types.py | 27 ++++- 3 files changed, 201 insertions(+), 7 deletions(-) diff --git a/reflex/components/component.py b/reflex/components/component.py index 9dd11254c..013181a58 100644 --- a/reflex/components/component.py +++ b/reflex/components/component.py @@ -303,6 +303,8 @@ class Component(BaseComponent, ABC): # Check whether the key is a component prop. if types._issubclass(field_type, Var): + # Used to store the passed types if var type is a union. + passed_types = None try: # Try to create a var from the value. kwargs[key] = Var.create(value) @@ -327,10 +329,25 @@ class Component(BaseComponent, ABC): # If it is not a valid var, check the base types. passed_type = type(value) expected_type = fields[key].outer_type_ - if not types._issubclass(passed_type, expected_type): + if types.is_union(passed_type): + # We need to check all possible types in the union. + passed_types = ( + arg for arg in passed_type.__args__ if arg is not type(None) + ) + if ( + # If the passed var is a union, check if all possible types are valid. + passed_types + and not all( + types._issubclass(pt, expected_type) for pt in passed_types + ) + ) or ( + # Else just check if the passed var type is valid. + not passed_types + and not types._issubclass(passed_type, expected_type) + ): value_name = value._var_name if isinstance(value, Var) else value raise TypeError( - f"Invalid var passed for prop {type(self).__name__}.{key}, expected type {expected_type}, got value {value_name} of type {passed_type}." + f"Invalid var passed for prop {type(self).__name__}.{key}, expected type {expected_type}, got value {value_name} of type {passed_types or passed_type}." ) # Check if the key is an event trigger. @@ -1523,7 +1540,7 @@ class CustomComponent(Component): def custom_component( - component_fn: Callable[..., Component] + component_fn: Callable[..., Component], ) -> Callable[..., CustomComponent]: """Create a custom component from a function. diff --git a/tests/components/test_component.py b/tests/components/test_component.py index 21ec409af..15ceee7e4 100644 --- a/tests/components/test_component.py +++ b/tests/components/test_component.py @@ -1,4 +1,5 @@ -from typing import Any, Dict, List, Type +from contextlib import nullcontext +from typing import Any, Dict, List, Optional, Type, Union import pytest @@ -20,7 +21,7 @@ from reflex.state import BaseState from reflex.style import Style from reflex.utils import imports from reflex.utils.imports import ImportVar -from reflex.vars import Var, VarData +from reflex.vars import BaseVar, Var, VarData @pytest.fixture @@ -52,6 +53,9 @@ def component1() -> Type[Component]: # A test number prop. number: Var[int] + # A test string/number prop. + text_or_number: Var[Union[int, str]] + def _get_imports(self) -> imports.ImportDict: return {"react": [ImportVar(tag="Component")]} @@ -253,6 +257,154 @@ def test_create_component(component1): assert c.style == {"color": "white", "textAlign": "center"} +@pytest.mark.parametrize( + "prop_name,var,expected", + [ + pytest.param( + "text", + Var.create("hello"), + None, + id="text", + ), + pytest.param( + "text", + BaseVar(_var_name="hello", _var_type=Optional[str]), + None, + id="text-optional", + ), + pytest.param( + "text", + BaseVar(_var_name="hello", _var_type=Union[str, None]), + None, + id="text-union-str-none", + ), + pytest.param( + "text", + BaseVar(_var_name="hello", _var_type=Union[None, str]), + None, + id="text-union-none-str", + ), + pytest.param( + "text", + Var.create(1), + TypeError, + id="text-int", + ), + pytest.param( + "number", + Var.create(1), + None, + id="number", + ), + pytest.param( + "number", + BaseVar(_var_name="1", _var_type=Optional[int]), + None, + id="number-optional", + ), + pytest.param( + "number", + BaseVar(_var_name="1", _var_type=Union[int, None]), + None, + id="number-union-int-none", + ), + pytest.param( + "number", + BaseVar(_var_name="1", _var_type=Union[None, int]), + None, + id="number-union-none-int", + ), + pytest.param( + "number", + Var.create("1"), + TypeError, + id="number-str", + ), + pytest.param( + "text_or_number", + Var.create("hello"), + None, + id="text_or_number-str", + ), + pytest.param( + "text_or_number", + Var.create(1), + None, + id="text_or_number-int", + ), + pytest.param( + "text_or_number", + BaseVar(_var_name="hello", _var_type=Optional[str]), + None, + id="text_or_number-optional-str", + ), + pytest.param( + "text_or_number", + BaseVar(_var_name="hello", _var_type=Union[str, None]), + None, + id="text_or_number-union-str-none", + ), + pytest.param( + "text_or_number", + BaseVar(_var_name="hello", _var_type=Union[None, str]), + None, + id="text_or_number-union-none-str", + ), + pytest.param( + "text_or_number", + BaseVar(_var_name="1", _var_type=Optional[int]), + None, + id="text_or_number-optional-int", + ), + pytest.param( + "text_or_number", + BaseVar(_var_name="1", _var_type=Union[int, None]), + None, + id="text_or_number-union-int-none", + ), + pytest.param( + "text_or_number", + BaseVar(_var_name="1", _var_type=Union[None, int]), + None, + id="text_or_number-union-none-int", + ), + pytest.param( + "text_or_number", + Var.create(1.0), + TypeError, + id="text_or_number-float", + ), + pytest.param( + "text_or_number", + BaseVar(_var_name="hello", _var_type=Optional[Union[str, int]]), + None, + id="text_or_number-optional-union-str-int", + ), + ], +) +def test_create_component_prop_validation( + component1: Type[Component], + prop_name: str, + var: Union[Var, str, int], + expected: Type[Exception], +): + """Test that component props are validated correctly. + + Args: + component1: A test component. + prop_name: The name of the prop. + var: The value of the prop. + expected: The expected exception. + """ + ctx = pytest.raises(expected) if expected else nullcontext() + kwargs = {prop_name: var} + with ctx: + c = component1.create(**kwargs) + assert isinstance(c, component1) + assert c.children == [] + assert c.style == {} + + def test_add_style(component1, component2): """Test adding a style to a component. @@ -338,7 +490,7 @@ def test_get_props(component1, component2): component1: A test component. component2: A test component. """ - assert component1.get_props() == {"text", "number"} + assert component1.get_props() == {"text", "number", "text_or_number"} assert component2.get_props() == {"arr"} diff --git a/tests/utils/test_types.py b/tests/utils/test_types.py index 3ad3a26e1..fc9261e04 100644 --- a/tests/utils/test_types.py +++ b/tests/utils/test_types.py @@ -1,4 +1,4 @@ -from typing import Literal +from typing import Any, List, Literal, Tuple, Union import pytest @@ -20,3 +20,28 @@ def test_validate_literal_error_msg(params, allowed_value_str, value_str): err.value.args[0] == f"prop value for {str(params[0])} of the `{params[-1]}` " f"component should be one of the following: {allowed_value_str}. Got {value_str} instead" ) + + +@pytest.mark.parametrize( + "cls,cls_check,expected", + [ + (int, Any, True), + (Tuple[int], Any, True), + (List[int], Any, True), + (int, int, True), + (int, object, True), + (int, Union[int, str], True), + (int, Union[str, int], True), + (str, Union[str, int], True), + (str, Union[int, str], True), + (int, Union[str, float, int], True), + (int, Union[str, float], False), + (int, Union[float, str], False), + (int, str, False), + (int, List[int], False), + ], +) +def test_issubclass( + cls: types.GenericType, cls_check: types.GenericType, expected: bool +) -> None: + assert types._issubclass(cls, cls_check) == expected