allow optional props with None default value (#3179)
This commit is contained in:
parent
73e9123733
commit
e31b458a69
@ -303,6 +303,8 @@ class Component(BaseComponent, ABC):
|
||||
|
||||
# Check whether the key is a component prop.
|
||||
if types._issubclass(field_type, Var):
|
||||
# Used to store the passed types if var type is a union.
|
||||
passed_types = None
|
||||
try:
|
||||
# Try to create a var from the value.
|
||||
kwargs[key] = Var.create(value)
|
||||
@ -327,10 +329,25 @@ class Component(BaseComponent, ABC):
|
||||
# If it is not a valid var, check the base types.
|
||||
passed_type = type(value)
|
||||
expected_type = fields[key].outer_type_
|
||||
if not types._issubclass(passed_type, expected_type):
|
||||
if types.is_union(passed_type):
|
||||
# We need to check all possible types in the union.
|
||||
passed_types = (
|
||||
arg for arg in passed_type.__args__ if arg is not type(None)
|
||||
)
|
||||
if (
|
||||
# If the passed var is a union, check if all possible types are valid.
|
||||
passed_types
|
||||
and not all(
|
||||
types._issubclass(pt, expected_type) for pt in passed_types
|
||||
)
|
||||
) or (
|
||||
# Else just check if the passed var type is valid.
|
||||
not passed_types
|
||||
and not types._issubclass(passed_type, expected_type)
|
||||
):
|
||||
value_name = value._var_name if isinstance(value, Var) else value
|
||||
raise TypeError(
|
||||
f"Invalid var passed for prop {type(self).__name__}.{key}, expected type {expected_type}, got value {value_name} of type {passed_type}."
|
||||
f"Invalid var passed for prop {type(self).__name__}.{key}, expected type {expected_type}, got value {value_name} of type {passed_types or passed_type}."
|
||||
)
|
||||
|
||||
# Check if the key is an event trigger.
|
||||
@ -1523,7 +1540,7 @@ class CustomComponent(Component):
|
||||
|
||||
|
||||
def custom_component(
|
||||
component_fn: Callable[..., Component]
|
||||
component_fn: Callable[..., Component],
|
||||
) -> Callable[..., CustomComponent]:
|
||||
"""Create a custom component from a function.
|
||||
|
||||
|
@ -1,4 +1,5 @@
|
||||
from typing import Any, Dict, List, Type
|
||||
from contextlib import nullcontext
|
||||
from typing import Any, Dict, List, Optional, Type, Union
|
||||
|
||||
import pytest
|
||||
|
||||
@ -20,7 +21,7 @@ from reflex.state import BaseState
|
||||
from reflex.style import Style
|
||||
from reflex.utils import imports
|
||||
from reflex.utils.imports import ImportVar
|
||||
from reflex.vars import Var, VarData
|
||||
from reflex.vars import BaseVar, Var, VarData
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
@ -52,6 +53,9 @@ def component1() -> Type[Component]:
|
||||
# A test number prop.
|
||||
number: Var[int]
|
||||
|
||||
# A test string/number prop.
|
||||
text_or_number: Var[Union[int, str]]
|
||||
|
||||
def _get_imports(self) -> imports.ImportDict:
|
||||
return {"react": [ImportVar(tag="Component")]}
|
||||
|
||||
@ -253,6 +257,154 @@ def test_create_component(component1):
|
||||
assert c.style == {"color": "white", "textAlign": "center"}
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"prop_name,var,expected",
|
||||
[
|
||||
pytest.param(
|
||||
"text",
|
||||
Var.create("hello"),
|
||||
None,
|
||||
id="text",
|
||||
),
|
||||
pytest.param(
|
||||
"text",
|
||||
BaseVar(_var_name="hello", _var_type=Optional[str]),
|
||||
None,
|
||||
id="text-optional",
|
||||
),
|
||||
pytest.param(
|
||||
"text",
|
||||
BaseVar(_var_name="hello", _var_type=Union[str, None]),
|
||||
None,
|
||||
id="text-union-str-none",
|
||||
),
|
||||
pytest.param(
|
||||
"text",
|
||||
BaseVar(_var_name="hello", _var_type=Union[None, str]),
|
||||
None,
|
||||
id="text-union-none-str",
|
||||
),
|
||||
pytest.param(
|
||||
"text",
|
||||
Var.create(1),
|
||||
TypeError,
|
||||
id="text-int",
|
||||
),
|
||||
pytest.param(
|
||||
"number",
|
||||
Var.create(1),
|
||||
None,
|
||||
id="number",
|
||||
),
|
||||
pytest.param(
|
||||
"number",
|
||||
BaseVar(_var_name="1", _var_type=Optional[int]),
|
||||
None,
|
||||
id="number-optional",
|
||||
),
|
||||
pytest.param(
|
||||
"number",
|
||||
BaseVar(_var_name="1", _var_type=Union[int, None]),
|
||||
None,
|
||||
id="number-union-int-none",
|
||||
),
|
||||
pytest.param(
|
||||
"number",
|
||||
BaseVar(_var_name="1", _var_type=Union[None, int]),
|
||||
None,
|
||||
id="number-union-none-int",
|
||||
),
|
||||
pytest.param(
|
||||
"number",
|
||||
Var.create("1"),
|
||||
TypeError,
|
||||
id="number-str",
|
||||
),
|
||||
pytest.param(
|
||||
"text_or_number",
|
||||
Var.create("hello"),
|
||||
None,
|
||||
id="text_or_number-str",
|
||||
),
|
||||
pytest.param(
|
||||
"text_or_number",
|
||||
Var.create(1),
|
||||
None,
|
||||
id="text_or_number-int",
|
||||
),
|
||||
pytest.param(
|
||||
"text_or_number",
|
||||
BaseVar(_var_name="hello", _var_type=Optional[str]),
|
||||
None,
|
||||
id="text_or_number-optional-str",
|
||||
),
|
||||
pytest.param(
|
||||
"text_or_number",
|
||||
BaseVar(_var_name="hello", _var_type=Union[str, None]),
|
||||
None,
|
||||
id="text_or_number-union-str-none",
|
||||
),
|
||||
pytest.param(
|
||||
"text_or_number",
|
||||
BaseVar(_var_name="hello", _var_type=Union[None, str]),
|
||||
None,
|
||||
id="text_or_number-union-none-str",
|
||||
),
|
||||
pytest.param(
|
||||
"text_or_number",
|
||||
BaseVar(_var_name="1", _var_type=Optional[int]),
|
||||
None,
|
||||
id="text_or_number-optional-int",
|
||||
),
|
||||
pytest.param(
|
||||
"text_or_number",
|
||||
BaseVar(_var_name="1", _var_type=Union[int, None]),
|
||||
None,
|
||||
id="text_or_number-union-int-none",
|
||||
),
|
||||
pytest.param(
|
||||
"text_or_number",
|
||||
BaseVar(_var_name="1", _var_type=Union[None, int]),
|
||||
None,
|
||||
id="text_or_number-union-none-int",
|
||||
),
|
||||
pytest.param(
|
||||
"text_or_number",
|
||||
Var.create(1.0),
|
||||
TypeError,
|
||||
id="text_or_number-float",
|
||||
),
|
||||
pytest.param(
|
||||
"text_or_number",
|
||||
BaseVar(_var_name="hello", _var_type=Optional[Union[str, int]]),
|
||||
None,
|
||||
id="text_or_number-optional-union-str-int",
|
||||
),
|
||||
],
|
||||
)
|
||||
def test_create_component_prop_validation(
|
||||
component1: Type[Component],
|
||||
prop_name: str,
|
||||
var: Union[Var, str, int],
|
||||
expected: Type[Exception],
|
||||
):
|
||||
"""Test that component props are validated correctly.
|
||||
|
||||
Args:
|
||||
component1: A test component.
|
||||
prop_name: The name of the prop.
|
||||
var: The value of the prop.
|
||||
expected: The expected exception.
|
||||
"""
|
||||
ctx = pytest.raises(expected) if expected else nullcontext()
|
||||
kwargs = {prop_name: var}
|
||||
with ctx:
|
||||
c = component1.create(**kwargs)
|
||||
assert isinstance(c, component1)
|
||||
assert c.children == []
|
||||
assert c.style == {}
|
||||
|
||||
|
||||
def test_add_style(component1, component2):
|
||||
"""Test adding a style to a component.
|
||||
|
||||
@ -338,7 +490,7 @@ def test_get_props(component1, component2):
|
||||
component1: A test component.
|
||||
component2: A test component.
|
||||
"""
|
||||
assert component1.get_props() == {"text", "number"}
|
||||
assert component1.get_props() == {"text", "number", "text_or_number"}
|
||||
assert component2.get_props() == {"arr"}
|
||||
|
||||
|
||||
|
@ -1,4 +1,4 @@
|
||||
from typing import Literal
|
||||
from typing import Any, List, Literal, Tuple, Union
|
||||
|
||||
import pytest
|
||||
|
||||
@ -20,3 +20,28 @@ def test_validate_literal_error_msg(params, allowed_value_str, value_str):
|
||||
err.value.args[0] == f"prop value for {str(params[0])} of the `{params[-1]}` "
|
||||
f"component should be one of the following: {allowed_value_str}. Got {value_str} instead"
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"cls,cls_check,expected",
|
||||
[
|
||||
(int, Any, True),
|
||||
(Tuple[int], Any, True),
|
||||
(List[int], Any, True),
|
||||
(int, int, True),
|
||||
(int, object, True),
|
||||
(int, Union[int, str], True),
|
||||
(int, Union[str, int], True),
|
||||
(str, Union[str, int], True),
|
||||
(str, Union[int, str], True),
|
||||
(int, Union[str, float, int], True),
|
||||
(int, Union[str, float], False),
|
||||
(int, Union[float, str], False),
|
||||
(int, str, False),
|
||||
(int, List[int], False),
|
||||
],
|
||||
)
|
||||
def test_issubclass(
|
||||
cls: types.GenericType, cls_check: types.GenericType, expected: bool
|
||||
) -> None:
|
||||
assert types._issubclass(cls, cls_check) == expected
|
||||
|
Loading…
Reference in New Issue
Block a user