forked from cumulo-autumn/StreamDiffusion
-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathconnection_manager.py
116 lines (99 loc) · 3.92 KB
/
connection_manager.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
from typing import Dict, Union
from uuid import UUID
import asyncio
from fastapi import WebSocket
from starlette.websockets import WebSocketState
import logging
from types import SimpleNamespace
# Type of the connection registry: maps a user's UUID to a per-user dict whose
# string keys hold either a WebSocket or an asyncio.Queue.
Connections = Dict[UUID, Dict[str, Union[WebSocket, asyncio.Queue]]]
class ServerFullException(Exception):
    """Signal that a new connection was refused because the server is at capacity."""
class ConnectionManager:
    """Registry of connected WebSocket users and their per-user data queues.

    ``active_connections`` maps a user id to a session dict with two keys:
    ``"websocket"`` (the socket passed to :meth:`connect`) and ``"queue"``
    (an ``asyncio.Queue`` that :meth:`update_data` keeps at no more than one
    item, the most recently supplied one).
    """

    def __init__(self):
        self.active_connections: Connections = {}

    @staticmethod
    def _drain_queue(queue: asyncio.Queue) -> None:
        """Discard every item currently held in *queue* without waiting."""
        while not queue.empty():
            try:
                queue.get_nowait()
            except asyncio.QueueEmpty:
                continue

    async def connect(
        self, user_id: UUID, websocket: "WebSocket", max_queue_size: int = 0
    ):
        """Accept *websocket* and register it under *user_id*.

        The socket is accepted before the capacity check so that a refusal can
        be reported to the client as a JSON message before the socket is closed.

        Args:
            user_id: Key under which the session is stored.
            websocket: The incoming socket; it is accepted here.
            max_queue_size: Maximum number of registered users. The limit is
                only enforced when this value is greater than zero.

        Raises:
            ServerFullException: If the limit is enforced and the number of
                registered users has already reached it. The socket is sent an
                error message and closed first.
        """
        await websocket.accept()
        user_count = self.get_user_count()
        print(f"User count: {user_count}")
        if max_queue_size > 0 and user_count >= max_queue_size:
            print("Server is full")
            await websocket.send_json({"status": "error", "message": "Server is full"})
            await websocket.close()
            raise ServerFullException("Server is full")
        print(f"New user connected: {user_id}")
        self.active_connections[user_id] = {
            "websocket": websocket,
            "queue": asyncio.Queue(),
        }
        await websocket.send_json(
            {"status": "connected", "message": "Connected"},
        )
        await websocket.send_json({"status": "wait"})
        await websocket.send_json({"status": "send_frame"})

    def check_user(self, user_id: UUID) -> bool:
        """Return whether *user_id* currently has a registered session."""
        return user_id in self.active_connections

    async def update_data(self, user_id: UUID, new_data: SimpleNamespace):
        """Replace whatever is queued for *user_id* with *new_data*.

        Older queued items are dropped first, so a consumer only ever sees the
        most recent value. Does nothing when the user has no session.
        """
        user_session = self.active_connections.get(user_id)
        if not user_session:
            return
        queue = user_session["queue"]
        self._drain_queue(queue)
        await queue.put(new_data)

    async def get_latest_data(self, user_id: UUID) -> "Union[SimpleNamespace, None]":
        """Wait for and return the next queued item for *user_id*.

        Returns ``None`` immediately when the user has no session. Otherwise
        this awaits ``Queue.get()``, which suspends until an item is available
        (it never raises ``QueueEmpty``).
        """
        user_session = self.active_connections.get(user_id)
        if not user_session:
            return None
        return await user_session["queue"].get()

    def delete_user(self, user_id: UUID):
        """Remove *user_id* from the registry and empty its queue.

        Unknown users are ignored. The socket itself is not closed here; see
        :meth:`disconnect` for that.
        """
        user_session = self.active_connections.pop(user_id, None)
        if user_session:
            self._drain_queue(user_session["queue"])

    def get_user_count(self) -> int:
        """Return the number of registered sessions."""
        return len(self.active_connections)

    def get_websocket(self, user_id: UUID) -> "Union[WebSocket, None]":
        """Return the user's socket if it is registered and still connected.

        Returns ``None`` when the user has no session or when the socket's
        ``client_state`` is not ``WebSocketState.CONNECTED``.
        """
        user_session = self.active_connections.get(user_id)
        if user_session:
            websocket = user_session["websocket"]
            if websocket.client_state == WebSocketState.CONNECTED:
                return websocket
        return None

    async def disconnect(self, user_id: UUID):
        """Close the user's socket (if still connected) and drop the session.

        The session is removed even when closing the socket raises, so a
        failed close cannot leave a stale entry counting against the user
        limit. Any exception from the close still propagates to the caller.
        """
        try:
            websocket = self.get_websocket(user_id)
            if websocket:
                await websocket.close()
        finally:
            self.delete_user(user_id)

    async def send_json(self, user_id: UUID, data: Dict):
        """Send *data* as JSON to the user's socket, if it is connected.

        Best effort: any exception is logged and swallowed.
        """
        try:
            websocket = self.get_websocket(user_id)
            if websocket:
                await websocket.send_json(data)
        except Exception as e:
            logging.error("Error: Send json: %s", e)

    async def receive_json(self, user_id: UUID) -> "Union[Dict, None]":
        """Receive one JSON message from the user's socket.

        Returns ``None`` when the socket is unavailable or when receiving
        fails; failures are logged rather than raised.
        """
        try:
            websocket = self.get_websocket(user_id)
            if websocket:
                return await websocket.receive_json()
        except Exception as e:
            logging.error("Error: Receive json: %s", e)
        return None

    async def receive_bytes(self, user_id: UUID) -> "Union[bytes, None]":
        """Receive one binary message from the user's socket.

        Returns ``None`` when the socket is unavailable or when receiving
        fails; failures are logged rather than raised.
        """
        try:
            websocket = self.get_websocket(user_id)
            if websocket:
                return await websocket.receive_bytes()
        except Exception as e:
            logging.error("Error: Receive bytes: %s", e)
        return None