20
20
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21
21
SOFTWARE.
22
22
"""
23
+ import secrets
23
24
from typing import Any
24
25
25
26
import aiohttp
@@ -50,7 +51,7 @@ def __init__(self, *, session: aiohttp.ClientSession, database: core.Database) -
50
51
Middleware (AuthenticationMiddleware , backend = AuthBackend (self )),
51
52
]
52
53
53
- self .sockets : dict [int , WebSocket ] = {}
54
+ self .sockets : dict [int , dict [ str , WebSocket ] ] = {}
54
55
self .subscription_sockets : dict [str , set [int ]] = {
55
56
core .WebsocketSubscriptions .DPY_MOD_LOG : set ()
56
57
}
@@ -71,7 +72,12 @@ async def websocket_connector(self, websocket: WebSocket) -> None:
71
72
uid : int | None = core .id_from_token (token )
72
73
73
74
assert uid
74
- self .sockets [uid ] = websocket
75
+
76
+ hash_ = secrets .token_urlsafe (8 )
77
+ try :
78
+ self .sockets [uid ][hash_ ] = websocket
79
+ except KeyError :
80
+ self .sockets [uid ] = {hash_ : websocket }
75
81
76
82
# Filter out bad subscriptions...
77
83
valid : list [str ] = list (self .subscription_sockets .keys ())
@@ -116,12 +122,8 @@ async def websocket_connector(self, websocket: WebSocket) -> None:
116
122
}
117
123
await websocket .send_json (data = response )
118
124
119
- # Remove the websocket and it's subscriptions...
120
- del self .sockets [uid ]
121
-
122
- subscribed : list [str ] = [sub for sub in self .subscription_sockets if uid in self .subscription_sockets [sub ]]
123
- for sub in subscribed :
124
- self .subscription_sockets [sub ].remove (uid )
125
+ # Remove the websocket...
126
+ del self .sockets [uid ][hash_ ]
125
127
126
128
def websocket_subscribe (self , * , uid : int , message : dict [str , Any ]) -> dict [str , Any ]:
127
129
subs : list [str ] = message .get ('subscriptions' , [])
@@ -151,17 +153,24 @@ def websocket_unsubscribe(self, *, uid: int, message: dict[str, Any]) -> dict[st
151
153
# Filter out bad subscriptions...
152
154
valid : list [str ] = list (self .subscription_sockets .keys ())
153
155
subscriptions : list [str ] = [sub for sub in subs if sub in valid ]
156
+ removed : list [str ] = []
154
157
155
158
for sub in subscriptions :
156
- self .subscription_sockets [sub ].remove (uid )
159
+
160
+ try :
161
+ self .subscription_sockets [sub ].remove (uid )
162
+ except KeyError :
163
+ pass
164
+ else :
165
+ removed .append (sub )
157
166
158
167
subscribed : list [str ] = [sub for sub in self .subscription_sockets if uid in self .subscription_sockets [sub ]]
159
168
160
169
data : dict [str , Any ] = {
161
170
'op' : core .WebsocketOPCodes .NOTIFICATION ,
162
171
'type' : core .WebsocketNotificationTypes .SUBSCRIPTION_REMOVED ,
163
172
'user_id' : uid ,
164
- 'removed' : subscriptions ,
173
+ 'removed' : removed ,
165
174
'subscriptions' : subscribed
166
175
}
167
176
0 commit comments