diff --git a/pynecone/app.py b/pynecone/app.py index cf0c75b0c..4bb5e168d 100644 --- a/pynecone/app.py +++ b/pynecone/app.py @@ -59,7 +59,7 @@ class App(Base): self.middleware.append(HydrateMiddleware()) # Set up the state manager. - self.state_manager.set(state=self.state) + self.state_manager.setup(state=self.state) # Set up the API. self.api = fastapi.FastAPI() diff --git a/pynecone/state.py b/pynecone/state.py index bbf62c1cf..6418003d6 100644 --- a/pynecone/state.py +++ b/pynecone/state.py @@ -451,9 +451,6 @@ class StateUpdate(Base): events: List[Event] = [] -redis = None - - class StateManager(Base): """A class to manage many client states.""" @@ -466,16 +463,13 @@ class StateManager(Base): # The token expiration time (s). token_expiration: int = constants.TOKEN_EXPIRATION - def __init__(self, *args, **kwargs): - """Initialize the state manager. + # The redis client to use. + redis: Any = None - Args: - *args: Args to pass to the base class. - **kwargs: Kwargs to pass to the base class. - """ - super().__init__(*args, **kwargs) - global redis - redis = utils.get_redis() + def setup(self, state: Type[State]): + """Setup the state manager.""" + self.state = state + self.redis = utils.get_redis() def get_state(self, token: str) -> State: """Get the state for a token. @@ -486,8 +480,8 @@ class StateManager(Base): Returns: The state for the token. """ - if redis is not None: - redis_state = redis.get(token) + if self.redis is not None: + redis_state = self.redis.get(token) if redis_state is None: self.set_state(token, self.state()) return self.get_state(token) @@ -504,6 +498,6 @@ class StateManager(Base): token: The token to set the state for. state: The state to set. """ - if redis is None: + if self.redis is None: return - redis.set(token, pickle.dumps(state), ex=self.token_expiration) + self.redis.set(token, pickle.dumps(state), ex=self.token_expiration) diff --git a/pynecone/utils.py b/pynecone/utils.py index fc7ee4f5e..704e764ac 100644 --- a/pynecone/utils.py +++ b/pynecone/utils.py @@ -860,12 +860,12 @@ def get_redis(): """ try: import redis - - config = get_config() - if config.redis_url is None: - return None - redis_url, redis_port = config.redis_url.split(":") - print("Using redis at", config.redis_url) - return redis.Redis(host=redis_url, port=int(redis_port), db=0) except: return None + + config = get_config() + if config.redis_url is None: + return None + redis_url, redis_port = config.redis_url.split(":") + print("Using redis at", config.redis_url) + return redis.Redis(host=redis_url, port=int(redis_port), db=0)