diff --git a/reflex/state.py b/reflex/state.py index 6af70db14..29ad84b3f 100644 --- a/reflex/state.py +++ b/reflex/state.py @@ -35,6 +35,7 @@ from typing import ( import dill from sqlalchemy.orm import DeclarativeBase +from typing_extensions import Self from reflex.config import get_config from reflex.vars.base import ( @@ -43,6 +44,7 @@ from reflex.vars.base import ( Var, computed_var, dispatch, + get_unique_variable_name, is_computed_var, ) @@ -695,6 +697,36 @@ class BaseState(Base, ABC, extra=pydantic.Extra.allow): and hasattr(value, "__code__") ) + @classmethod + def _evaluate(cls, f: Callable[[Self], Any]) -> Var: + """Evaluate a function to a ComputedVar. Experimental. + + Args: + f: The function to evaluate. + + Returns: + The ComputedVar. + """ + console.warn( + "The _evaluate method is experimental and may be removed in future versions." + ) + from reflex.components.base.fragment import fragment + from reflex.components.component import Component + + unique_var_name = get_unique_variable_name() + + @computed_var(_js_expr=unique_var_name, return_type=Component) + def computed_var_func(state: Self): + return fragment(f(state)) + + setattr(cls, unique_var_name, computed_var_func) + cls.computed_vars[unique_var_name] = computed_var_func + cls.vars[unique_var_name] = computed_var_func + cls._update_substate_inherited_vars({unique_var_name: computed_var_func}) + cls._always_dirty_computed_vars.add(unique_var_name) + + return getattr(cls, unique_var_name) + @classmethod def _mixins(cls) -> List[Type]: """Get the mixin classes of the state. diff --git a/reflex/vars/base.py b/reflex/vars/base.py index afbc56a55..2d78a14be 100644 --- a/reflex/vars/base.py +++ b/reflex/vars/base.py @@ -1559,8 +1559,9 @@ class ComputedVar(Var[RETURN_TYPE]): Raises: TypeError: If the computed var dependencies are not Var instances or var names. """ - hints = get_type_hints(fget) - hint = hints.get("return", Any) + hint = kwargs.pop("return_type", None) or get_type_hints(fget).get( + "return", Any + ) kwargs["_js_expr"] = kwargs.pop("_js_expr", fget.__name__) kwargs["_var_type"] = kwargs.pop("_var_type", hint) diff --git a/tests/integration/test_dynamic_components.py b/tests/integration/test_dynamic_components.py index 31080223f..5a4d99f9e 100644 --- a/tests/integration/test_dynamic_components.py +++ b/tests/integration/test_dynamic_components.py @@ -16,6 +16,8 @@ def DynamicComponents(): import reflex as rx class DynamicComponentsState(rx.State): + value: int = 10 + button: rx.Component = rx.button( "Click me", custom_attrs={ @@ -52,11 +54,20 @@ def DynamicComponents(): app = rx.App() + def factorial(n: int) -> int: + if n == 0: + return 1 + return n * factorial(n - 1) + @app.add_page def index(): return rx.vstack( DynamicComponentsState.client_token_component, DynamicComponentsState.button, + rx.text( + DynamicComponentsState._evaluate(lambda state: factorial(state.value)), + id="factorial", + ), ) @@ -150,3 +161,7 @@ def test_dynamic_components(driver, dynamic_components: AppHarness): dynamic_components.poll_for_content(button, exp_not_equal="Click me") == "Clicked" ) + + factorial = poll_for_result(lambda: driver.find_element(By.ID, "factorial")) + assert factorial + assert factorial.text == "3628800"