feat: Adding subscriptions to rooms.

TODO: Remove subscriptions when scopes change.
This commit is contained in:
Andreas Eismann 2024-07-17 11:22:00 +02:00
parent ab19508451
commit db4c73a027
No known key found for this signature in database
3 changed files with 50 additions and 3 deletions

View File

@ -1078,7 +1078,9 @@ class App(MiddlewareMixin, LifespanMixin, Base):
# When the state is modified reset dirty status and emit the delta to the frontend.
state._clean()
await self.event_namespace.emit_update(
update=StateUpdate(delta=delta),
update=StateUpdate(
delta=delta, scopes=state.scopes_and_subscopes()
),
sid=state.router.session.session_id,
)
@ -1273,7 +1275,7 @@ async def process(
else:
if app._process_background(state, event) is not None:
# `final=True` allows the frontend send more events immediately.
yield StateUpdate(final=True)
yield StateUpdate(final=True, scopes=state.scopes_and_subscopes())
return
# Process the event synchronously.
@ -1470,6 +1472,13 @@ class EventNamespace(AsyncNamespace):
sid: The Socket.IO session id.
room: The room to send the update to.
"""
# TODO We don't know when to leave a room yet.
for receiver in update.scopes:
if self.sid_to_token[sid] != receiver:
room = receiver
if room not in self.rooms(sid):
await self.enter_room(sid, room)
# Creating a task prevents the update from being blocked behind other coroutines.
await asyncio.create_task(
self.emit(str(constants.SocketEvent.EVENT), update.json(), to=room or sid)

View File

@ -46,4 +46,4 @@ class HydrateMiddleware(Middleware):
state._clean()
# Return the state update.
return StateUpdate(delta=delta, events=[])
return StateUpdate(delta=delta, events=[], scopes=state.scopes_and_subscopes())

View File

@ -8,6 +8,7 @@ import copy
import functools
import inspect
import os
import traceback
import uuid
from abc import ABC, abstractmethod
from collections import defaultdict
@ -68,6 +69,22 @@ var = computed_var
TOO_LARGE_SERIALIZED_STATE = 100 * 1024 # 100kb
def print_stack(depth: int = 3):
"""Print the current stacktrace to the console.
Args:
depth: Depth of the stack-trace to print
"""
stack = traceback.extract_stack()
stack.reverse()
print("stacktrace")
for idx in range(1, depth + 1):
stack_info = stack[idx]
print(
f" {stack_info.name} {os.path.basename(stack_info.filename)}:{stack_info.lineno}"
)
class HeaderData(Base):
"""An object containing headers data."""
@ -1430,6 +1447,20 @@ class BaseState(Base, ABC, extra=pydantic.Extra.allow):
f"Your handler {handler.fn.__qualname__} must only return/yield: None, Events or other EventHandlers referenced by their class (not using `self`)"
)
def scopes_and_subscopes(self) -> list[str]:
"""Recursively gathers all scopes of self and substates.
Returns:
A unique list of the scopes/token
"""
result = [self._get_token()]
for substate in self.substates.values():
subscopes = substate.scopes_and_subscopes()
for subscope in subscopes:
if subscope not in result:
result.append(subscope)
return result
def _get_token(self, other=None) -> str:
token = self.router.session.client_token
cls = other or self.__class__
@ -1481,6 +1512,7 @@ class BaseState(Base, ABC, extra=pydantic.Extra.allow):
delta=delta,
events=fixed_events,
final=final if not handler.is_background else True,
scopes=state.scopes_and_subscopes(),
)
except Exception as ex:
state._clean()
@ -2254,6 +2286,8 @@ class StateUpdate(Base):
# Whether this is the final state update for the event.
final: bool = True
scopes: list[str] = []
class StateManager(Base, ABC):
"""A class to manage many client states."""
@ -2531,6 +2565,10 @@ class StateManagerRedis(StateManager):
for substate_name, substate_task in tasks.items():
state.substates[substate_name] = await substate_task
# async def print_all_keys(self):
# for key in await self.redis.keys():
# print(f"redis_key: {key}")
async def get_state(
self,
token: str,