From 63bf1b8817202ef70ba7e024a5e6321f2e220e55 Mon Sep 17 00:00:00 2001 From: benedikt-bartscher <31854409+benedikt-bartscher@users.noreply.github.com> Date: Wed, 11 Sep 2024 18:43:18 +0200 Subject: [PATCH] Dynamic route vars silently shadow all other vars (#3805) * fix dynamic route vars silently shadow all other vars * add test * fix: allow multiple dynamic routes with the same arg * add test for multiple dynamic args with the same name * avoid side-effects with DynamicState tests * fix dynamic route integration test which shadowed a var * fix darglint * refactor to DynamicRouteVar * old typing stuff again * from typing_extensions import Self try to keep typing backward compatible with older releases we support * Raise a specific exception when encountering dynamic route arg shadowing --------- Co-authored-by: Masen Furer <m_github@0x26.net> --- reflex/ivars/base.py | 12 ++++++-- reflex/state.py | 41 ++++++++++++++++++++++---- reflex/utils/exceptions.py | 4 +++ reflex/utils/types.py | 5 ++++ tests/test_app.py | 59 ++++++++++++++++++++++++++++++++++++-- 5 files changed, 110 insertions(+), 11 deletions(-) diff --git a/reflex/ivars/base.py b/reflex/ivars/base.py index 8bb7957c4..3c08b8119 100644 --- a/reflex/ivars/base.py +++ b/reflex/ivars/base.py @@ -43,7 +43,7 @@ from reflex.utils.exceptions import ( VarValueError, ) from reflex.utils.format import format_state_name -from reflex.utils.types import GenericType, get_origin +from reflex.utils.types import GenericType, Self, get_origin from reflex.vars import ( REPLACED_NAMES, Var, @@ -1467,7 +1467,7 @@ class ImmutableComputedVar(ImmutableVar[RETURN_TYPE]): object.__setattr__(self, "_fget", fget) @override - def _replace(self, merge_var_data=None, **kwargs: Any) -> ImmutableComputedVar: + def _replace(self, merge_var_data=None, **kwargs: Any) -> Self: """Replace the attributes of the ComputedVar. Args: @@ -1499,7 +1499,7 @@ class ImmutableComputedVar(ImmutableVar[RETURN_TYPE]): unexpected_kwargs = ", ".join(kwargs.keys()) raise TypeError(f"Unexpected keyword arguments: {unexpected_kwargs}") - return ImmutableComputedVar(**field_values) + return type(self)(**field_values) @property def _cache_attr(self) -> str: @@ -1773,6 +1773,12 @@ class ImmutableComputedVar(ImmutableVar[RETURN_TYPE]): return self._fget +class DynamicRouteVar(ImmutableComputedVar[Union[str, List[str]]]): + """A ComputedVar that represents a dynamic route.""" + + pass + + if TYPE_CHECKING: BASE_STATE = TypeVar("BASE_STATE", bound=BaseState) diff --git a/reflex/state.py b/reflex/state.py index a1b5eac8c..b815fd44f 100644 --- a/reflex/state.py +++ b/reflex/state.py @@ -35,6 +35,7 @@ from sqlalchemy.orm import DeclarativeBase from reflex.config import get_config from reflex.ivars.base import ( + DynamicRouteVar, ImmutableComputedVar, ImmutableVar, immutable_computed_var, @@ -60,7 +61,11 @@ from reflex.event import ( fix_events, ) from reflex.utils import console, format, path_ops, prerequisites, types -from reflex.utils.exceptions import ImmutableStateError, LockExpiredError +from reflex.utils.exceptions import ( + DynamicRouteArgShadowsStateVar, + ImmutableStateError, + LockExpiredError, +) from reflex.utils.exec import is_testing_env from reflex.utils.serializers import SerializedType, serialize, serializer from reflex.utils.types import override @@ -1023,17 +1028,19 @@ class BaseState(Base, ABC, extra=pydantic.Extra.allow): if not args: return + cls._check_overwritten_dynamic_args(list(args.keys())) + def argsingle_factory(param): def inner_func(self) -> str: return self.router.page.params.get(param, "") - return ImmutableComputedVar(fget=inner_func, cache=True) + return DynamicRouteVar(fget=inner_func, cache=True) def arglist_factory(param): - def inner_func(self) -> List: + def inner_func(self) -> List[str]: return self.router.page.params.get(param, []) - return ImmutableComputedVar(fget=inner_func, cache=True) + return DynamicRouteVar(fget=inner_func, cache=True) for param, value in args.items(): if value == constants.RouteArgType.SINGLE: @@ -1044,12 +1051,36 @@ class BaseState(Base, ABC, extra=pydantic.Extra.allow): continue # to allow passing as a prop, evade python frozen rules (bad practice) object.__setattr__(func, "_var_name", param) - cls.vars[param] = cls.computed_vars[param] = func._var_set_state(cls) # type: ignore + # cls.vars[param] = cls.computed_vars[param] = func._var_set_state(cls) # type: ignore + cls.vars[param] = cls.computed_vars[param] = func._replace( + _var_data=VarData.from_state(cls) + ) setattr(cls, param, func) # Reinitialize dependency tracking dicts. cls._init_var_dependency_dicts() + @classmethod + def _check_overwritten_dynamic_args(cls, args: list[str]): + """Check if dynamic args are shadowing existing vars. Recursively checks all child states. + + Args: + args: a dict of args + + Raises: + DynamicRouteArgShadowsStateVar: If a dynamic arg is shadowing an existing var. + """ + for arg in args: + if ( + arg in cls.computed_vars + and not isinstance(cls.computed_vars[arg], DynamicRouteVar) + ) or arg in cls.base_vars: + raise DynamicRouteArgShadowsStateVar( + f"Dynamic route arg '{arg}' is shadowing an existing var in {cls.__module__}.{cls.__name__}" + ) + for substate in cls.get_substates(): + substate._check_overwritten_dynamic_args(args) + def __getattribute__(self, name: str) -> Any: """Get the state var. diff --git a/reflex/utils/exceptions.py b/reflex/utils/exceptions.py index 8c1a1f07f..dbab3fb6d 100644 --- a/reflex/utils/exceptions.py +++ b/reflex/utils/exceptions.py @@ -87,3 +87,7 @@ class EventHandlerArgMismatch(ReflexError, TypeError): class EventFnArgMismatch(ReflexError, TypeError): """Raised when the number of args accepted by a lambda differs from that provided by the event trigger.""" + + +class DynamicRouteArgShadowsStateVar(ReflexError, NameError): + """Raised when a dynamic route arg shadows a state var.""" diff --git a/reflex/utils/types.py b/reflex/utils/types.py index b4b4c1090..f4463fa92 100644 --- a/reflex/utils/types.py +++ b/reflex/utils/types.py @@ -111,6 +111,11 @@ RESERVED_BACKEND_VAR_NAMES = { "_was_touched", } +if sys.version_info >= (3, 11): + from typing import Self as Self +else: + from typing_extensions import Self as Self + class Unset: """A class to represent an unset value. diff --git a/tests/test_app.py b/tests/test_app.py index 167cbf0d4..6767abe35 100644 --- a/tests/test_app.py +++ b/tests/test_app.py @@ -69,7 +69,7 @@ class EmptyState(BaseState): @pytest.fixture -def index_page(): +def index_page() -> ComponentCallable: """An index page. Returns: @@ -83,7 +83,7 @@ def index_page(): @pytest.fixture -def about_page(): +def about_page() -> ComponentCallable: """An about page. Returns: @@ -919,9 +919,62 @@ class DynamicState(BaseState): on_load_internal = OnLoadInternalState.on_load_internal.fn +def test_dynamic_arg_shadow( + index_page: ComponentCallable, + windows_platform: bool, + token: str, + app_module_mock: unittest.mock.Mock, + mocker, +): + """Create app with dynamic route var and try to add a page with a dynamic arg that shadows a state var. + + Args: + index_page: The index page. + windows_platform: Whether the system is windows. + token: a Token. + app_module_mock: Mocked app module. + mocker: pytest mocker object. + """ + arg_name = "counter" + route = f"/test/[{arg_name}]" + if windows_platform: + route.lstrip("/").replace("/", "\\") + app = app_module_mock.app = App(state=DynamicState) + assert app.state is not None + with pytest.raises(NameError): + app.add_page(index_page, route=route, on_load=DynamicState.on_load) # type: ignore + + +def test_multiple_dynamic_args( + index_page: ComponentCallable, + windows_platform: bool, + token: str, + app_module_mock: unittest.mock.Mock, + mocker, +): + """Create app with multiple dynamic route vars with the same name. + + Args: + index_page: The index page. + windows_platform: Whether the system is windows. + token: a Token. + app_module_mock: Mocked app module. + mocker: pytest mocker object. + """ + arg_name = "my_arg" + route = f"/test/[{arg_name}]" + route2 = f"/test2/[{arg_name}]" + if windows_platform: + route = route.lstrip("/").replace("/", "\\") + route2 = route2.lstrip("/").replace("/", "\\") + app = app_module_mock.app = App(state=EmptyState) + app.add_page(index_page, route=route) + app.add_page(index_page, route=route2) + + @pytest.mark.asyncio async def test_dynamic_route_var_route_change_completed_on_load( - index_page, + index_page: ComponentCallable, windows_platform: bool, token: str, app_module_mock: unittest.mock.Mock,