diff --git a/reflex/constants/base.py b/reflex/constants/base.py index 05675643f..b86f083cc 100644 --- a/reflex/constants/base.py +++ b/reflex/constants/base.py @@ -199,7 +199,7 @@ class LogLevel(str, Enum): Returns: The log level for the subprocess """ - return self if self != LogLevel.DEFAULT else LogLevel.INFO + return self if self != LogLevel.DEFAULT else LogLevel.WARNING # Server socket configuration variables diff --git a/reflex/event.py b/reflex/event.py index 95358ace1..9b54eddeb 100644 --- a/reflex/event.py +++ b/reflex/event.py @@ -839,6 +839,16 @@ def call_script( ), ), } + if isinstance(javascript_code, str): + # When there is VarData, include it and eval the JS code inline on the client. + javascript_code, original_code = ( + LiteralVar.create(javascript_code), + javascript_code, + ) + if not javascript_code._get_all_var_data(): + # Without VarData, cast to string and eval the code in the event loop. + javascript_code = str(Var(_js_expr=original_code)) + return server_side( "_call_script", get_fn_signature(call_script), diff --git a/reflex/istate/data.py b/reflex/istate/data.py new file mode 100644 index 000000000..9f6e3b3f4 --- /dev/null +++ b/reflex/istate/data.py @@ -0,0 +1,126 @@ +"""This module contains the dataclasses representing the router object.""" + +import dataclasses +from typing import Optional + +from reflex import constants +from reflex.utils import format + + +@dataclasses.dataclass(frozen=True) +class HeaderData: + """An object containing headers data.""" + + host: str = "" + origin: str = "" + upgrade: str = "" + connection: str = "" + cookie: str = "" + pragma: str = "" + cache_control: str = "" + user_agent: str = "" + sec_websocket_version: str = "" + sec_websocket_key: str = "" + sec_websocket_extensions: str = "" + accept_encoding: str = "" + accept_language: str = "" + + def __init__(self, router_data: Optional[dict] = None): + """Initalize the HeaderData object based on router_data. + + Args: + router_data: the router_data dict. + """ + if router_data: + for k, v in router_data.get(constants.RouteVar.HEADERS, {}).items(): + object.__setattr__(self, format.to_snake_case(k), v) + else: + for k in dataclasses.fields(self): + object.__setattr__(self, k.name, "") + + +@dataclasses.dataclass(frozen=True) +class PageData: + """An object containing page data.""" + + host: str = "" # repeated with self.headers.origin (remove or keep the duplicate?) + path: str = "" + raw_path: str = "" + full_path: str = "" + full_raw_path: str = "" + params: dict = dataclasses.field(default_factory=dict) + + def __init__(self, router_data: Optional[dict] = None): + """Initalize the PageData object based on router_data. + + Args: + router_data: the router_data dict. + """ + if router_data: + object.__setattr__( + self, + "host", + router_data.get(constants.RouteVar.HEADERS, {}).get("origin", ""), + ) + object.__setattr__( + self, "path", router_data.get(constants.RouteVar.PATH, "") + ) + object.__setattr__( + self, "raw_path", router_data.get(constants.RouteVar.ORIGIN, "") + ) + object.__setattr__(self, "full_path", f"{self.host}{self.path}") + object.__setattr__(self, "full_raw_path", f"{self.host}{self.raw_path}") + object.__setattr__( + self, "params", router_data.get(constants.RouteVar.QUERY, {}) + ) + else: + object.__setattr__(self, "host", "") + object.__setattr__(self, "path", "") + object.__setattr__(self, "raw_path", "") + object.__setattr__(self, "full_path", "") + object.__setattr__(self, "full_raw_path", "") + object.__setattr__(self, "params", {}) + + +@dataclasses.dataclass(frozen=True, init=False) +class SessionData: + """An object containing session data.""" + + client_token: str = "" + client_ip: str = "" + session_id: str = "" + + def __init__(self, router_data: Optional[dict] = None): + """Initalize the SessionData object based on router_data. + + Args: + router_data: the router_data dict. + """ + if router_data: + client_token = router_data.get(constants.RouteVar.CLIENT_TOKEN, "") + client_ip = router_data.get(constants.RouteVar.CLIENT_IP, "") + session_id = router_data.get(constants.RouteVar.SESSION_ID, "") + else: + client_token = client_ip = session_id = "" + object.__setattr__(self, "client_token", client_token) + object.__setattr__(self, "client_ip", client_ip) + object.__setattr__(self, "session_id", session_id) + + +@dataclasses.dataclass(frozen=True, init=False) +class RouterData: + """An object containing RouterData.""" + + session: SessionData = dataclasses.field(default_factory=SessionData) + headers: HeaderData = dataclasses.field(default_factory=HeaderData) + page: PageData = dataclasses.field(default_factory=PageData) + + def __init__(self, router_data: Optional[dict] = None): + """Initialize the RouterData object. + + Args: + router_data: the router_data dict. + """ + object.__setattr__(self, "session", SessionData(router_data)) + object.__setattr__(self, "headers", HeaderData(router_data)) + object.__setattr__(self, "page", PageData(router_data)) diff --git a/reflex/state.py b/reflex/state.py index 64ea960e1..b1988e38a 100644 --- a/reflex/state.py +++ b/reflex/state.py @@ -38,6 +38,7 @@ from sqlalchemy.orm import DeclarativeBase from typing_extensions import Self from reflex.config import get_config +from reflex.istate.data import RouterData from reflex.vars.base import ( ComputedVar, DynamicRouteVar, @@ -93,125 +94,6 @@ var = computed_var TOO_LARGE_SERIALIZED_STATE = 100 * 1024 # 100kb -@dataclasses.dataclass(frozen=True) -class HeaderData: - """An object containing headers data.""" - - host: str = "" - origin: str = "" - upgrade: str = "" - connection: str = "" - cookie: str = "" - pragma: str = "" - cache_control: str = "" - user_agent: str = "" - sec_websocket_version: str = "" - sec_websocket_key: str = "" - sec_websocket_extensions: str = "" - accept_encoding: str = "" - accept_language: str = "" - - def __init__(self, router_data: Optional[dict] = None): - """Initalize the HeaderData object based on router_data. - - Args: - router_data: the router_data dict. - """ - if router_data: - for k, v in router_data.get(constants.RouteVar.HEADERS, {}).items(): - object.__setattr__(self, format.to_snake_case(k), v) - else: - for k in dataclasses.fields(self): - object.__setattr__(self, k.name, "") - - -@dataclasses.dataclass(frozen=True) -class PageData: - """An object containing page data.""" - - host: str = "" # repeated with self.headers.origin (remove or keep the duplicate?) - path: str = "" - raw_path: str = "" - full_path: str = "" - full_raw_path: str = "" - params: dict = dataclasses.field(default_factory=dict) - - def __init__(self, router_data: Optional[dict] = None): - """Initalize the PageData object based on router_data. - - Args: - router_data: the router_data dict. - """ - if router_data: - object.__setattr__( - self, - "host", - router_data.get(constants.RouteVar.HEADERS, {}).get("origin", ""), - ) - object.__setattr__( - self, "path", router_data.get(constants.RouteVar.PATH, "") - ) - object.__setattr__( - self, "raw_path", router_data.get(constants.RouteVar.ORIGIN, "") - ) - object.__setattr__(self, "full_path", f"{self.host}{self.path}") - object.__setattr__(self, "full_raw_path", f"{self.host}{self.raw_path}") - object.__setattr__( - self, "params", router_data.get(constants.RouteVar.QUERY, {}) - ) - else: - object.__setattr__(self, "host", "") - object.__setattr__(self, "path", "") - object.__setattr__(self, "raw_path", "") - object.__setattr__(self, "full_path", "") - object.__setattr__(self, "full_raw_path", "") - object.__setattr__(self, "params", {}) - - -@dataclasses.dataclass(frozen=True, init=False) -class SessionData: - """An object containing session data.""" - - client_token: str = "" - client_ip: str = "" - session_id: str = "" - - def __init__(self, router_data: Optional[dict] = None): - """Initalize the SessionData object based on router_data. - - Args: - router_data: the router_data dict. - """ - if router_data: - client_token = router_data.get(constants.RouteVar.CLIENT_TOKEN, "") - client_ip = router_data.get(constants.RouteVar.CLIENT_IP, "") - session_id = router_data.get(constants.RouteVar.SESSION_ID, "") - else: - client_token = client_ip = session_id = "" - object.__setattr__(self, "client_token", client_token) - object.__setattr__(self, "client_ip", client_ip) - object.__setattr__(self, "session_id", session_id) - - -@dataclasses.dataclass(frozen=True, init=False) -class RouterData: - """An object containing RouterData.""" - - session: SessionData = dataclasses.field(default_factory=SessionData) - headers: HeaderData = dataclasses.field(default_factory=HeaderData) - page: PageData = dataclasses.field(default_factory=PageData) - - def __init__(self, router_data: Optional[dict] = None): - """Initialize the RouterData object. - - Args: - router_data: the router_data dict. - """ - object.__setattr__(self, "session", SessionData(router_data)) - object.__setattr__(self, "headers", HeaderData(router_data)) - object.__setattr__(self, "page", PageData(router_data)) - - def _no_chain_background_task( state_cls: Type["BaseState"], name: str, fn: Callable ) -> Callable: @@ -699,11 +581,14 @@ class BaseState(Base, ABC, extra=pydantic.Extra.allow): ) @classmethod - def _evaluate(cls, f: Callable[[Self], Any]) -> Var: + def _evaluate( + cls, f: Callable[[Self], Any], of_type: Union[type, None] = None + ) -> Var: """Evaluate a function to a ComputedVar. Experimental. Args: f: The function to evaluate. + of_type: The type of the ComputedVar. Defaults to Component. Returns: The ComputedVar. @@ -711,14 +596,23 @@ class BaseState(Base, ABC, extra=pydantic.Extra.allow): console.warn( "The _evaluate method is experimental and may be removed in future versions." ) - from reflex.components.base.fragment import fragment from reflex.components.component import Component + of_type = of_type or Component + unique_var_name = get_unique_variable_name() - @computed_var(_js_expr=unique_var_name, return_type=Component) + @computed_var(_js_expr=unique_var_name, return_type=of_type) def computed_var_func(state: Self): - return fragment(f(state)) + result = f(state) + + if not isinstance(result, of_type): + console.warn( + f"Inline ComputedVar {f} expected type {of_type}, got {type(result)}. " + "You can specify expected type with `of_type` argument." + ) + + return result setattr(cls, unique_var_name, computed_var_func) cls.computed_vars[unique_var_name] = computed_var_func diff --git a/tests/integration/test_call_script.py b/tests/integration/test_call_script.py index 5a3b83abf..744d83d16 100644 --- a/tests/integration/test_call_script.py +++ b/tests/integration/test_call_script.py @@ -46,6 +46,7 @@ def CallScript(): inline_counter: int = 0 external_counter: int = 0 value: str = "Initial" + last_result: str = "" def call_script_callback(self, result): self.results.append(result) @@ -137,6 +138,32 @@ def CallScript(): callback=CallScriptState.set_external_counter, # type: ignore ) + def call_with_var_f_string(self): + return rx.call_script( + f"{rx.Var('inline_counter')} + {rx.Var('external_counter')}", + callback=CallScriptState.set_last_result, # type: ignore + ) + + def call_with_var_str_cast(self): + return rx.call_script( + f"{str(rx.Var('inline_counter'))} + {str(rx.Var('external_counter'))}", + callback=CallScriptState.set_last_result, # type: ignore + ) + + def call_with_var_f_string_wrapped(self): + return rx.call_script( + rx.Var(f"{rx.Var('inline_counter')} + {rx.Var('external_counter')}"), + callback=CallScriptState.set_last_result, # type: ignore + ) + + def call_with_var_str_cast_wrapped(self): + return rx.call_script( + rx.Var( + f"{str(rx.Var('inline_counter'))} + {str(rx.Var('external_counter'))}" + ), + callback=CallScriptState.set_last_result, # type: ignore + ) + def reset_(self): yield rx.call_script("inline_counter = 0; external_counter = 0") self.reset() @@ -234,6 +261,68 @@ def CallScript(): id="update_value", ), rx.button("Reset", id="reset", on_click=CallScriptState.reset_), + rx.input( + value=CallScriptState.last_result, + id="last_result", + read_only=True, + on_click=CallScriptState.set_last_result(""), # type: ignore + ), + rx.button( + "call_with_var_f_string", + on_click=CallScriptState.call_with_var_f_string, + id="call_with_var_f_string", + ), + rx.button( + "call_with_var_str_cast", + on_click=CallScriptState.call_with_var_str_cast, + id="call_with_var_str_cast", + ), + rx.button( + "call_with_var_f_string_wrapped", + on_click=CallScriptState.call_with_var_f_string_wrapped, + id="call_with_var_f_string_wrapped", + ), + rx.button( + "call_with_var_str_cast_wrapped", + on_click=CallScriptState.call_with_var_str_cast_wrapped, + id="call_with_var_str_cast_wrapped", + ), + rx.button( + "call_with_var_f_string_inline", + on_click=rx.call_script( + f"{rx.Var('inline_counter')} + {CallScriptState.last_result}", + callback=CallScriptState.set_last_result, # type: ignore + ), + id="call_with_var_f_string_inline", + ), + rx.button( + "call_with_var_str_cast_inline", + on_click=rx.call_script( + f"{str(rx.Var('inline_counter'))} + {str(rx.Var('external_counter'))}", + callback=CallScriptState.set_last_result, # type: ignore + ), + id="call_with_var_str_cast_inline", + ), + rx.button( + "call_with_var_f_string_wrapped_inline", + on_click=rx.call_script( + rx.Var( + f"{rx.Var('inline_counter')} + {CallScriptState.last_result}" + ), + callback=CallScriptState.set_last_result, # type: ignore + ), + id="call_with_var_f_string_wrapped_inline", + ), + rx.button( + "call_with_var_str_cast_wrapped_inline", + on_click=rx.call_script( + rx.Var( + f"{str(rx.Var('inline_counter'))} + {str(rx.Var('external_counter'))}" + ), + callback=CallScriptState.set_last_result, # type: ignore + ), + id="call_with_var_str_cast_wrapped_inline", + ), ) @@ -363,3 +452,73 @@ def test_call_script( call_script.poll_for_content(update_value_button, exp_not_equal="Initial") == "updated" ) + + +def test_call_script_w_var( + call_script: AppHarness, + driver: WebDriver, +): + """Test evaluating javascript expressions containing Vars. + + Args: + call_script: harness for CallScript app. + driver: WebDriver instance. + """ + assert_token(driver) + last_result = driver.find_element(By.ID, "last_result") + assert last_result.get_attribute("value") == "" + + inline_return_button = driver.find_element(By.ID, "inline_return") + + call_with_var_f_string_button = driver.find_element(By.ID, "call_with_var_f_string") + call_with_var_str_cast_button = driver.find_element(By.ID, "call_with_var_str_cast") + call_with_var_f_string_wrapped_button = driver.find_element( + By.ID, "call_with_var_f_string_wrapped" + ) + call_with_var_str_cast_wrapped_button = driver.find_element( + By.ID, "call_with_var_str_cast_wrapped" + ) + call_with_var_f_string_inline_button = driver.find_element( + By.ID, "call_with_var_f_string_inline" + ) + call_with_var_str_cast_inline_button = driver.find_element( + By.ID, "call_with_var_str_cast_inline" + ) + call_with_var_f_string_wrapped_inline_button = driver.find_element( + By.ID, "call_with_var_f_string_wrapped_inline" + ) + call_with_var_str_cast_wrapped_inline_button = driver.find_element( + By.ID, "call_with_var_str_cast_wrapped_inline" + ) + + inline_return_button.click() + call_with_var_f_string_button.click() + assert call_script.poll_for_value(last_result, exp_not_equal="") == "1" + + inline_return_button.click() + call_with_var_str_cast_button.click() + assert call_script.poll_for_value(last_result, exp_not_equal="1") == "2" + + inline_return_button.click() + call_with_var_f_string_wrapped_button.click() + assert call_script.poll_for_value(last_result, exp_not_equal="2") == "3" + + inline_return_button.click() + call_with_var_str_cast_wrapped_button.click() + assert call_script.poll_for_value(last_result, exp_not_equal="3") == "4" + + inline_return_button.click() + call_with_var_f_string_inline_button.click() + assert call_script.poll_for_value(last_result, exp_not_equal="4") == "9" + + inline_return_button.click() + call_with_var_str_cast_inline_button.click() + assert call_script.poll_for_value(last_result, exp_not_equal="9") == "6" + + inline_return_button.click() + call_with_var_f_string_wrapped_inline_button.click() + assert call_script.poll_for_value(last_result, exp_not_equal="6") == "13" + + inline_return_button.click() + call_with_var_str_cast_wrapped_inline_button.click() + assert call_script.poll_for_value(last_result, exp_not_equal="13") == "8" diff --git a/tests/integration/test_dynamic_components.py b/tests/integration/test_dynamic_components.py index 5a4d99f9e..aeebd10e9 100644 --- a/tests/integration/test_dynamic_components.py +++ b/tests/integration/test_dynamic_components.py @@ -65,7 +65,9 @@ def DynamicComponents(): DynamicComponentsState.client_token_component, DynamicComponentsState.button, rx.text( - DynamicComponentsState._evaluate(lambda state: factorial(state.value)), + DynamicComponentsState._evaluate( + lambda state: factorial(state.value), of_type=int + ), id="factorial", ), )