allow optional props with None default value (#3179)

This commit is contained in:
benedikt-bartscher 2024-05-01 22:33:38 +02:00 committed by GitHub
parent 73e9123733
commit e31b458a69
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
3 changed files with 201 additions and 7 deletions

View File

@ -303,6 +303,8 @@ class Component(BaseComponent, ABC):
# Check whether the key is a component prop.
if types._issubclass(field_type, Var):
# Used to store the passed types if var type is a union.
passed_types = None
try:
# Try to create a var from the value.
kwargs[key] = Var.create(value)
@ -327,10 +329,25 @@ class Component(BaseComponent, ABC):
# If it is not a valid var, check the base types.
passed_type = type(value)
expected_type = fields[key].outer_type_
if not types._issubclass(passed_type, expected_type):
if types.is_union(passed_type):
# We need to check all possible types in the union.
passed_types = (
arg for arg in passed_type.__args__ if arg is not type(None)
)
if (
# If the passed var is a union, check if all possible types are valid.
passed_types
and not all(
types._issubclass(pt, expected_type) for pt in passed_types
)
) or (
# Else just check if the passed var type is valid.
not passed_types
and not types._issubclass(passed_type, expected_type)
):
value_name = value._var_name if isinstance(value, Var) else value
raise TypeError(
f"Invalid var passed for prop {type(self).__name__}.{key}, expected type {expected_type}, got value {value_name} of type {passed_type}."
f"Invalid var passed for prop {type(self).__name__}.{key}, expected type {expected_type}, got value {value_name} of type {passed_types or passed_type}."
)
# Check if the key is an event trigger.
@ -1523,7 +1540,7 @@ class CustomComponent(Component):
def custom_component(
component_fn: Callable[..., Component]
component_fn: Callable[..., Component],
) -> Callable[..., CustomComponent]:
"""Create a custom component from a function.

View File

@ -1,4 +1,5 @@
from typing import Any, Dict, List, Type
from contextlib import nullcontext
from typing import Any, Dict, List, Optional, Type, Union
import pytest
@ -20,7 +21,7 @@ from reflex.state import BaseState
from reflex.style import Style
from reflex.utils import imports
from reflex.utils.imports import ImportVar
from reflex.vars import Var, VarData
from reflex.vars import BaseVar, Var, VarData
@pytest.fixture
@ -52,6 +53,9 @@ def component1() -> Type[Component]:
# A test number prop.
number: Var[int]
# A test string/number prop.
text_or_number: Var[Union[int, str]]
def _get_imports(self) -> imports.ImportDict:
return {"react": [ImportVar(tag="Component")]}
@ -253,6 +257,154 @@ def test_create_component(component1):
assert c.style == {"color": "white", "textAlign": "center"}
@pytest.mark.parametrize(
"prop_name,var,expected",
[
pytest.param(
"text",
Var.create("hello"),
None,
id="text",
),
pytest.param(
"text",
BaseVar(_var_name="hello", _var_type=Optional[str]),
None,
id="text-optional",
),
pytest.param(
"text",
BaseVar(_var_name="hello", _var_type=Union[str, None]),
None,
id="text-union-str-none",
),
pytest.param(
"text",
BaseVar(_var_name="hello", _var_type=Union[None, str]),
None,
id="text-union-none-str",
),
pytest.param(
"text",
Var.create(1),
TypeError,
id="text-int",
),
pytest.param(
"number",
Var.create(1),
None,
id="number",
),
pytest.param(
"number",
BaseVar(_var_name="1", _var_type=Optional[int]),
None,
id="number-optional",
),
pytest.param(
"number",
BaseVar(_var_name="1", _var_type=Union[int, None]),
None,
id="number-union-int-none",
),
pytest.param(
"number",
BaseVar(_var_name="1", _var_type=Union[None, int]),
None,
id="number-union-none-int",
),
pytest.param(
"number",
Var.create("1"),
TypeError,
id="number-str",
),
pytest.param(
"text_or_number",
Var.create("hello"),
None,
id="text_or_number-str",
),
pytest.param(
"text_or_number",
Var.create(1),
None,
id="text_or_number-int",
),
pytest.param(
"text_or_number",
BaseVar(_var_name="hello", _var_type=Optional[str]),
None,
id="text_or_number-optional-str",
),
pytest.param(
"text_or_number",
BaseVar(_var_name="hello", _var_type=Union[str, None]),
None,
id="text_or_number-union-str-none",
),
pytest.param(
"text_or_number",
BaseVar(_var_name="hello", _var_type=Union[None, str]),
None,
id="text_or_number-union-none-str",
),
pytest.param(
"text_or_number",
BaseVar(_var_name="1", _var_type=Optional[int]),
None,
id="text_or_number-optional-int",
),
pytest.param(
"text_or_number",
BaseVar(_var_name="1", _var_type=Union[int, None]),
None,
id="text_or_number-union-int-none",
),
pytest.param(
"text_or_number",
BaseVar(_var_name="1", _var_type=Union[None, int]),
None,
id="text_or_number-union-none-int",
),
pytest.param(
"text_or_number",
Var.create(1.0),
TypeError,
id="text_or_number-float",
),
pytest.param(
"text_or_number",
BaseVar(_var_name="hello", _var_type=Optional[Union[str, int]]),
None,
id="text_or_number-optional-union-str-int",
),
],
)
def test_create_component_prop_validation(
component1: Type[Component],
prop_name: str,
var: Union[Var, str, int],
expected: Type[Exception],
):
"""Test that component props are validated correctly.
Args:
component1: A test component.
prop_name: The name of the prop.
var: The value of the prop.
expected: The expected exception.
"""
ctx = pytest.raises(expected) if expected else nullcontext()
kwargs = {prop_name: var}
with ctx:
c = component1.create(**kwargs)
assert isinstance(c, component1)
assert c.children == []
assert c.style == {}
def test_add_style(component1, component2):
"""Test adding a style to a component.
@ -338,7 +490,7 @@ def test_get_props(component1, component2):
component1: A test component.
component2: A test component.
"""
assert component1.get_props() == {"text", "number"}
assert component1.get_props() == {"text", "number", "text_or_number"}
assert component2.get_props() == {"arr"}

View File

@ -1,4 +1,4 @@
from typing import Literal
from typing import Any, List, Literal, Tuple, Union
import pytest
@ -20,3 +20,28 @@ def test_validate_literal_error_msg(params, allowed_value_str, value_str):
err.value.args[0] == f"prop value for {str(params[0])} of the `{params[-1]}` "
f"component should be one of the following: {allowed_value_str}. Got {value_str} instead"
)
@pytest.mark.parametrize(
"cls,cls_check,expected",
[
(int, Any, True),
(Tuple[int], Any, True),
(List[int], Any, True),
(int, int, True),
(int, object, True),
(int, Union[int, str], True),
(int, Union[str, int], True),
(str, Union[str, int], True),
(str, Union[int, str], True),
(int, Union[str, float, int], True),
(int, Union[str, float], False),
(int, Union[float, str], False),
(int, str, False),
(int, List[int], False),
],
)
def test_issubclass(
cls: types.GenericType, cls_check: types.GenericType, expected: bool
) -> None:
assert types._issubclass(cls, cls_check) == expected