feat: Adding subscriptions to rooms.

TODO: Remove subscriptions when scopes change.
This commit is contained in:
Andreas Eismann 2024-07-17 11:22:00 +02:00
parent ab19508451
commit db4c73a027
No known key found for this signature in database
3 changed files with 50 additions and 3 deletions

View File

@ -1078,7 +1078,9 @@ class App(MiddlewareMixin, LifespanMixin, Base):
# When the state is modified reset dirty status and emit the delta to the frontend. # When the state is modified reset dirty status and emit the delta to the frontend.
state._clean() state._clean()
await self.event_namespace.emit_update( await self.event_namespace.emit_update(
update=StateUpdate(delta=delta), update=StateUpdate(
delta=delta, scopes=state.scopes_and_subscopes()
),
sid=state.router.session.session_id, sid=state.router.session.session_id,
) )
@ -1273,7 +1275,7 @@ async def process(
else: else:
if app._process_background(state, event) is not None: if app._process_background(state, event) is not None:
# `final=True` allows the frontend send more events immediately. # `final=True` allows the frontend send more events immediately.
yield StateUpdate(final=True) yield StateUpdate(final=True, scopes=state.scopes_and_subscopes())
return return
# Process the event synchronously. # Process the event synchronously.
@ -1470,6 +1472,13 @@ class EventNamespace(AsyncNamespace):
sid: The Socket.IO session id. sid: The Socket.IO session id.
room: The room to send the update to. room: The room to send the update to.
""" """
# TODO We don't know when to leave a room yet.
for receiver in update.scopes:
if self.sid_to_token[sid] != receiver:
room = receiver
if room not in self.rooms(sid):
await self.enter_room(sid, room)
# Creating a task prevents the update from being blocked behind other coroutines. # Creating a task prevents the update from being blocked behind other coroutines.
await asyncio.create_task( await asyncio.create_task(
self.emit(str(constants.SocketEvent.EVENT), update.json(), to=room or sid) self.emit(str(constants.SocketEvent.EVENT), update.json(), to=room or sid)

View File

@ -46,4 +46,4 @@ class HydrateMiddleware(Middleware):
state._clean() state._clean()
# Return the state update. # Return the state update.
return StateUpdate(delta=delta, events=[]) return StateUpdate(delta=delta, events=[], scopes=state.scopes_and_subscopes())

View File

@ -8,6 +8,7 @@ import copy
import functools import functools
import inspect import inspect
import os import os
import traceback
import uuid import uuid
from abc import ABC, abstractmethod from abc import ABC, abstractmethod
from collections import defaultdict from collections import defaultdict
@ -68,6 +69,22 @@ var = computed_var
TOO_LARGE_SERIALIZED_STATE = 100 * 1024 # 100kb TOO_LARGE_SERIALIZED_STATE = 100 * 1024 # 100kb
def print_stack(depth: int = 3):
"""Print the current stacktrace to the console.
Args:
depth: Depth of the stack-trace to print
"""
stack = traceback.extract_stack()
stack.reverse()
print("stacktrace")
for idx in range(1, depth + 1):
stack_info = stack[idx]
print(
f" {stack_info.name} {os.path.basename(stack_info.filename)}:{stack_info.lineno}"
)
class HeaderData(Base): class HeaderData(Base):
"""An object containing headers data.""" """An object containing headers data."""
@ -1430,6 +1447,20 @@ class BaseState(Base, ABC, extra=pydantic.Extra.allow):
f"Your handler {handler.fn.__qualname__} must only return/yield: None, Events or other EventHandlers referenced by their class (not using `self`)" f"Your handler {handler.fn.__qualname__} must only return/yield: None, Events or other EventHandlers referenced by their class (not using `self`)"
) )
def scopes_and_subscopes(self) -> list[str]:
"""Recursively gathers all scopes of self and substates.
Returns:
A unique list of the scopes/token
"""
result = [self._get_token()]
for substate in self.substates.values():
subscopes = substate.scopes_and_subscopes()
for subscope in subscopes:
if subscope not in result:
result.append(subscope)
return result
def _get_token(self, other=None) -> str: def _get_token(self, other=None) -> str:
token = self.router.session.client_token token = self.router.session.client_token
cls = other or self.__class__ cls = other or self.__class__
@ -1481,6 +1512,7 @@ class BaseState(Base, ABC, extra=pydantic.Extra.allow):
delta=delta, delta=delta,
events=fixed_events, events=fixed_events,
final=final if not handler.is_background else True, final=final if not handler.is_background else True,
scopes=state.scopes_and_subscopes(),
) )
except Exception as ex: except Exception as ex:
state._clean() state._clean()
@ -2254,6 +2286,8 @@ class StateUpdate(Base):
# Whether this is the final state update for the event. # Whether this is the final state update for the event.
final: bool = True final: bool = True
scopes: list[str] = []
class StateManager(Base, ABC): class StateManager(Base, ABC):
"""A class to manage many client states.""" """A class to manage many client states."""
@ -2531,6 +2565,10 @@ class StateManagerRedis(StateManager):
for substate_name, substate_task in tasks.items(): for substate_name, substate_task in tasks.items():
state.substates[substate_name] = await substate_task state.substates[substate_name] = await substate_task
# async def print_all_keys(self):
# for key in await self.redis.keys():
# print(f"redis_key: {key}")
async def get_state( async def get_state(
self, self,
token: str, token: str,