diff --git a/reflex/components/component.py b/reflex/components/component.py index 2c01559ca..0d619a8cf 100644 --- a/reflex/components/component.py +++ b/reflex/components/component.py @@ -199,7 +199,6 @@ class Component(BaseComponent, ABC): Raises: TypeError: If an invalid prop is passed. - ValueError: If a prop value is invalid. """ # Set the id and children initially. children = kwargs.get("children", []) @@ -249,17 +248,10 @@ class Component(BaseComponent, ABC): raise TypeError expected_type = fields[key].outer_type_.__args__[0] - - if ( - types.is_literal(expected_type) - and value not in expected_type.__args__ - ): - allowed_values = expected_type.__args__ - if value not in allowed_values: - raise ValueError( - f"prop value for {key} of the `{type(self).__name__}` component should be one of the following: {','.join(allowed_values)}. Got '{value}' instead" - ) - + # validate literal fields. + types.validate_literal( + key, value, expected_type, type(self).__name__ + ) # Get the passed type and the var type. passed_type = kwargs[key]._var_type expected_type = ( diff --git a/reflex/components/core/colors.py b/reflex/components/core/colors.py index d146fae1e..fbc6825aa 100644 --- a/reflex/components/core/colors.py +++ b/reflex/components/core/colors.py @@ -1,13 +1,11 @@ """The colors used in Reflex are a wrapper around https://www.radix-ui.com/colors.""" from reflex.constants.colors import Color, ColorType, ShadeType +from reflex.utils.types import validate_parameter_literals -def color( - color: ColorType, - shade: ShadeType = 7, - alpha: bool = False, -) -> Color: +@validate_parameter_literals +def color(color: ColorType, shade: ShadeType = 7, alpha: bool = False) -> Color: """Create a color object. Args: diff --git a/reflex/utils/types.py b/reflex/utils/types.py index 93469e7b7..41e910dd2 100644 --- a/reflex/utils/types.py +++ b/reflex/utils/types.py @@ -3,7 +3,9 @@ from __future__ import annotations import contextlib +import inspect import types +from functools import wraps from typing import ( Any, Callable, @@ -330,6 +332,69 @@ def check_prop_in_allowed_types(prop: Any, allowed_types: Iterable) -> bool: return type_ in allowed_types +def validate_literal(key: str, value: Any, expected_type: Type, comp_name: str): + """Check that a value is a valid literal. + + Args: + key: The prop name. + value: The prop value to validate. + expected_type: The expected type(literal type). + comp_name: Name of the component. + + Raises: + ValueError: When the value is not a valid literal. + """ + from reflex.vars import Var + + if ( + is_literal(expected_type) + and not isinstance(value, Var) # validating vars is not supported yet. + and value not in expected_type.__args__ + ): + allowed_values = expected_type.__args__ + if value not in allowed_values: + value_str = ",".join( + [str(v) if not isinstance(v, str) else f"'{v}'" for v in allowed_values] + ) + raise ValueError( + f"prop value for {str(key)} of the `{comp_name}` component should be one of the following: {value_str}. Got '{value}' instead" + ) + + +def validate_parameter_literals(func): + """Decorator to check that the arguments passed to a function + correspond to the correct function parameter if it (the parameter) + is a literal type. + + Args: + func: The function to validate. + + Returns: + The wrapper function. + """ + + @wraps(func) + def wrapper(*args, **kwargs): + func_params = list(inspect.signature(func).parameters.items()) + annotations = {param[0]: param[1].annotation for param in func_params} + + # validate args + for param, arg in zip(annotations.keys(), args): + if annotations[param] is inspect.Parameter.empty: + continue + validate_literal(param, arg, annotations[param], func.__name__) + + # validate kwargs. + for key, value in kwargs.items(): + annotation = annotations.get(key) + if not annotation or annotation is inspect.Parameter.empty: + continue + validate_literal(key, value, annotation, func.__name__) + return func(*args, **kwargs) + + return wrapper + + # Store this here for performance. StateBases = get_base_class(StateVar) StateIterBases = get_base_class(StateIterVar)