diff --git a/reflex/app.py b/reflex/app.py index d0ee06ae9..03382751a 100644 --- a/reflex/app.py +++ b/reflex/app.py @@ -111,7 +111,7 @@ from reflex.utils import ( prerequisites, types, ) -from reflex.utils.exec import is_prod_mode, is_testing_env +from reflex.utils.exec import get_compile_context, is_prod_mode, is_testing_env from reflex.utils.imports import ImportVar if TYPE_CHECKING: @@ -201,14 +201,17 @@ def default_overlay_component() -> Component: Returns: The default overlay_component, which is a connection_modal. """ - config = get_config() from reflex.components.component import memo def default_overlay_components(): return Fragment.create( connection_pulser(), connection_toaster(), - *([backend_disabled()] if config.is_reflex_cloud else []), + *( + [backend_disabled()] + if get_compile_context() == constants.CompileContext.DEPLOY + else [] + ), *codespaces.codespaces_auto_redirect(), ) @@ -1136,6 +1139,16 @@ class App(MiddlewareMixin, LifespanMixin): self._validate_var_dependencies() self._setup_overlay_component() + + if config.show_built_with_reflex is None: + if ( + get_compile_context() == constants.CompileContext.DEPLOY + and prerequisites.get_user_tier() in ["pro", "team", "enterprise"] + ): + config.show_built_with_reflex = False + else: + config.show_built_with_reflex = True + if is_prod_mode() and config.show_built_with_reflex: self._setup_sticky_badge() diff --git a/reflex/config.py b/reflex/config.py index 0d48057d7..296b01805 100644 --- a/reflex/config.py +++ b/reflex/config.py @@ -589,6 +589,11 @@ class ExecutorType(enum.Enum): class EnvironmentVariables: """Environment variables class to instantiate environment variables.""" + # Indicate the current command that was invoked in the reflex CLI. + REFLEX_COMPILE_CONTEXT: EnvVar[constants.CompileContext] = env_var( + constants.CompileContext.UNDEFINED, internal=True + ) + # Whether to use npm over bun to install frontend packages. REFLEX_USE_NPM: EnvVar[bool] = env_var(False) @@ -636,7 +641,7 @@ class EnvironmentVariables: REFLEX_COMPILE_THREADS: EnvVar[Optional[int]] = env_var(None) # The directory to store reflex dependencies. - REFLEX_DIR: EnvVar[Path] = env_var(Path(constants.Reflex.DIR)) + REFLEX_DIR: EnvVar[Path] = env_var(constants.Reflex.DIR) # Whether to print the SQL queries if the log level is INFO or lower. SQLALCHEMY_ECHO: EnvVar[bool] = env_var(False) @@ -844,7 +849,7 @@ class Config(Base): env_file: Optional[str] = None # Whether to display the sticky "Built with Reflex" badge on all pages. - show_built_with_reflex: bool = True + show_built_with_reflex: Optional[bool] = None # Whether the app is running in the reflex cloud environment. is_reflex_cloud: bool = False diff --git a/reflex/constants/__init__.py b/reflex/constants/__init__.py index f5946bf5e..5a918338d 100644 --- a/reflex/constants/__init__.py +++ b/reflex/constants/__init__.py @@ -25,6 +25,7 @@ from .base import ( from .compiler import ( NOCOMPILE_FILE, SETTER_PREFIX, + CompileContext, CompileVars, ComponentName, Ext, @@ -65,6 +66,7 @@ __ALL__ = [ ColorMode, Config, COOKIES, + CompileContext, ComponentName, CustomComponents, DefaultPage, diff --git a/reflex/constants/compiler.py b/reflex/constants/compiler.py index 9bc9978dc..40134c15b 100644 --- a/reflex/constants/compiler.py +++ b/reflex/constants/compiler.py @@ -111,6 +111,15 @@ class ComponentName(Enum): return self.value.lower() + Ext.ZIP +class CompileContext(str, Enum): + """The context in which the compiler is running.""" + + RUN = "run" + EXPORT = "export" + DEPLOY = "deploy" + UNDEFINED = "undefined" + + class Imports(SimpleNamespace): """Common sets of import vars.""" diff --git a/reflex/reflex.py b/reflex/reflex.py index 4ed6f8d4a..878b32d76 100644 --- a/reflex/reflex.py +++ b/reflex/reflex.py @@ -292,6 +292,8 @@ def run( if frontend and backend: console.error("Cannot use both --frontend-only and --backend-only options.") raise typer.Exit(1) + + environment.REFLEX_COMPILE_CONTEXT.set(constants.CompileContext.RUN) environment.REFLEX_BACKEND_ONLY.set(backend) environment.REFLEX_FRONTEND_ONLY.set(frontend) @@ -338,6 +340,8 @@ def export( from reflex.utils import export as export_utils from reflex.utils import prerequisites + environment.REFLEX_COMPILE_CONTEXT.set(constants.CompileContext.EXPORT) + frontend, backend = prerequisites.check_running_mode(frontend, backend) if prerequisites.needs_reinit(frontend=frontend or not backend): @@ -537,6 +541,8 @@ def deploy( check_version() + environment.REFLEX_COMPILE_CONTEXT.set(constants.CompileContext.DEPLOY) + # Set the log level. console.set_log_level(loglevel) diff --git a/reflex/utils/exec.py b/reflex/utils/exec.py index b16aaea1c..5474ae82a 100644 --- a/reflex/utils/exec.py +++ b/reflex/utils/exec.py @@ -584,3 +584,12 @@ def is_prod_mode() -> bool: """ current_mode = environment.REFLEX_ENV_MODE.get() return current_mode == constants.Env.PROD + + +def get_compile_context() -> constants.CompileContext: + """Check if the app is compiled for deploy. + + Returns: + Whether the app is being compiled for deploy. + """ + return environment.REFLEX_COMPILE_CONTEXT.get() diff --git a/reflex/utils/prerequisites.py b/reflex/utils/prerequisites.py index 3cd65a7eb..145b5324c 100644 --- a/reflex/utils/prerequisites.py +++ b/reflex/utils/prerequisites.py @@ -2001,6 +2001,22 @@ def is_generation_hash(template: str) -> bool: return re.match(r"^[0-9a-f]{32,}$", template) is not None +def get_user_tier(): + """Get the current user's tier. + + Returns: + The current user's tier. + """ + from reflex_cli.v2.utils import hosting + + authenticated_token = hosting.authenticated_token() + return ( + authenticated_token[1].get("tier", "").lower() + if authenticated_token[0] + else "anonymous" + ) + + def check_config_option_in_tier( option_name: str, allowed_tiers: list[str], @@ -2015,23 +2031,21 @@ def check_config_option_in_tier( fallback_value: The fallback value if the option is not allowed. help_link: The help link to show to a user that is authenticated. """ - from reflex_cli.v2.utils import hosting - config = get_config() - authenticated_token = hosting.authenticated_token() - if not authenticated_token[0]: + current_tier = get_user_tier() + + if current_tier == "anonymous": the_remedy = ( "You are currently logged out. Run `reflex login` to access this option." ) - current_tier = "anonymous" else: - current_tier = authenticated_token[1].get("tier", "").lower() the_remedy = ( f"Your current subscription tier is `{current_tier}`. " f"Please upgrade to {allowed_tiers} to access this option. " ) if help_link: the_remedy += f"See {help_link} for more information." + if current_tier not in allowed_tiers: console.warn(f"Config option `{option_name}` is restricted. {the_remedy}") setattr(config, option_name, fallback_value) diff --git a/tests/integration/test_connection_banner.py b/tests/integration/test_connection_banner.py index bfc9ea0ae..f7fd7365c 100644 --- a/tests/integration/test_connection_banner.py +++ b/tests/integration/test_connection_banner.py @@ -7,24 +7,19 @@ import pytest from selenium.common.exceptions import NoSuchElementException from selenium.webdriver.common.by import By +from reflex import constants +from reflex.config import environment from reflex.testing import AppHarness, WebDriver from .utils import SessionStorage -def ConnectionBanner(is_reflex_cloud: bool = False): - """App with a connection banner. - - Args: - is_reflex_cloud: The value for config.is_reflex_cloud. - """ +def ConnectionBanner(): + """App with a connection banner.""" import asyncio import reflex as rx - # Simulate reflex cloud deploy - rx.config.get_config().is_reflex_cloud = is_reflex_cloud - class State(rx.State): foo: int = 0 @@ -49,16 +44,17 @@ def ConnectionBanner(is_reflex_cloud: bool = False): @pytest.fixture( - params=[False, True], ids=["reflex_cloud_disabled", "reflex_cloud_enabled"] + params=[constants.CompileContext.RUN, constants.CompileContext.DEPLOY], + ids=["compile_context_run", "compile_context_deploy"], ) -def simulate_is_reflex_cloud(request) -> bool: +def simulate_compile_context(request) -> constants.CompileContext: """Fixture to simulate reflex cloud deployment. Args: request: pytest request fixture. Returns: - True if reflex cloud is enabled, False otherwise. + The context to run the app with. """ return request.param @@ -66,25 +62,27 @@ def simulate_is_reflex_cloud(request) -> bool: @pytest.fixture() def connection_banner( tmp_path, - simulate_is_reflex_cloud: bool, + simulate_compile_context: constants.CompileContext, ) -> Generator[AppHarness, None, None]: """Start ConnectionBanner app at tmp_path via AppHarness. Args: tmp_path: pytest tmp_path fixture - simulate_is_reflex_cloud: Whether is_reflex_cloud is set for the app. + simulate_compile_context: Which context to run the app with. Yields: running AppHarness instance """ + environment.REFLEX_COMPILE_CONTEXT.set(simulate_compile_context) + with AppHarness.create( root=tmp_path, - app_source=functools.partial( - ConnectionBanner, is_reflex_cloud=simulate_is_reflex_cloud + app_source=functools.partial(ConnectionBanner), + app_name=( + "connection_banner_reflex_cloud" + if simulate_compile_context == constants.CompileContext.DEPLOY + else "connection_banner" ), - app_name="connection_banner_reflex_cloud" - if simulate_is_reflex_cloud - else "connection_banner", ) as harness: yield harness @@ -194,13 +192,13 @@ async def test_connection_banner(connection_banner: AppHarness): @pytest.mark.asyncio async def test_cloud_banner( - connection_banner: AppHarness, simulate_is_reflex_cloud: bool + connection_banner: AppHarness, simulate_compile_context: constants.CompileContext ): """Test that the connection banner is displayed when the websocket drops. Args: connection_banner: AppHarness instance. - simulate_is_reflex_cloud: Whether is_reflex_cloud is set for the app. + simulate_compile_context: Which context to set for the app. """ assert connection_banner.app_instance is not None assert connection_banner.backend is not None @@ -213,7 +211,7 @@ async def test_cloud_banner( driver.add_cookie({"name": "backend-enabled", "value": "false"}) driver.refresh() - if simulate_is_reflex_cloud: + if simulate_compile_context == constants.CompileContext.DEPLOY: assert connection_banner._poll_for(lambda: has_cloud_banner(driver)) else: _assert_token(connection_banner, driver)