Skip to content

Commit 03b3376

Browse files
Reduce server load from the WebSocket server
Previously, the server was encoding and compressing each message for every client. This resulted in high server load, to the point that the WebSocket server couldn't keep up and dropped messages. Unfortunately, the websockets package doesn't officially expose its innards. Therefore, this is a disgusting hack that inspects the enabled extensions, prepares the data, and then sends that directly, all using internal APIs of websockets version 7.0. Compression (cf. RFC 7692) introduces a further complexity: context takeover. Normally, the compression context is reused across messages, but because clients connect at different times and might not be receiving all messages, the same compressor cannot be used across connections. Therefore, context takeover is disabled here. This also adds some stats about the throughput.
1 parent dd5ce62 commit 03b3376

File tree

2 files changed

+83
-21
lines changed

2 files changed

+83
-21
lines changed

INSTALL.backend

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,7 @@ To run the backend, you will need:
99
- Bundler
1010
- ExecJS supported runtime (for the dashboard)
1111
(see https://github.com/sstephenson/execjs)
12-
- Python 3.6+ and websockets (for the dashboard WebSocket)
12+
- Python 3.6+ and websockets 7.0 (for the dashboard WebSocket)
1313

1414
(Little known fact: ArchiveBot is made to be as hard as possible to set
1515
up. If you have trouble with these instructions, drop by in IRC for
@@ -26,7 +26,7 @@ Quick install, for Debian and Debian-esque systems like Ubuntu:
2626
cd ArchiveBot
2727
git submodule update --init
2828
bundle install
29-
pip install websockets # Or apt install python3-websockets, or whichever method you prefer.
29+
pip install websockets==7.0 # Or apt install python3-websockets, or whichever method you prefer, but it must be version 7.0.
3030

3131

3232
** STEP 2: INSTALL REDIS **

dashboard/websocket.py

Lines changed: 81 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,12 @@
11
import asyncio
2+
import collections
23
import datetime
4+
import io
35
import os
46
import sys
57
import websockets
8+
import websockets.extensions.permessage_deflate
9+
import websockets.framing
610

711

812
DEBUG = 'WSDEBUG' in os.environ and os.environ['WSDEBUG'] == '1'
@@ -22,63 +26,121 @@ async def stdin(loop):
2226
return reader
2327

2428

25-
async def stdin_to_amplifier(amplifier, loop):
29+
async def stdin_to_amplifier(amplifier, loop, stats):
2630
reader = await stdin(loop)
2731
while True:
28-
amplifier.send((await reader.readline()).decode('utf-8').strip())
32+
d = await reader.readline()
33+
stats['stdin read'] += len(d)
34+
amplifier.send(d.decode('utf-8').strip())
35+
36+
37+
def websocket_extensions_to_key(extensions):
38+
# Convert a list of websockets extensions into a key, handling PerMessageDeflate objects with the relevant care for server-side compression dedupe
39+
def _inner():
40+
for e in extensions:
41+
if isinstance(e, websockets.extensions.permessage_deflate.PerMessageDeflate) and e.local_no_context_takeover:
42+
yield (websockets.extensions.permessage_deflate.PerMessageDeflate, e.remote_max_window_bits, e.local_max_window_bits, tuple(e.compress_settings.items()))
43+
else:
44+
yield e
45+
return tuple(_inner())
2946

3047

3148
class MessageAmplifier:
32-
def __init__(self):
33-
self.queues = {}
49+
def __init__(self, stats):
50+
self.queues = {} # websocket -> queue
51+
self._stats = stats
3452

3553
def register(self, websocket):
36-
self.queues[websocket] = asyncio.Queue(maxsize = 1000)
37-
return self.queues[websocket]
54+
q = asyncio.Queue(maxsize = 1000)
55+
self.queues[websocket] = q
56+
return q
3857

3958
def send(self, message):
40-
for queue in self.queues.values():
59+
#FIXME This abuses internal API of websockets==7.0
60+
# Using the normal `websocket.send` reencodes and recompresses the message for every client.
61+
# So we construct the relevant Frame once instead and push that to the individual queues.
62+
frame = websockets.framing.Frame(fin = True, opcode = websockets.framing.OP_TEXT, data = message.encode('utf-8'))
63+
data = {} # tuple of extensions key → bytes
64+
for websocket, queue in self.queues.items():
65+
extensionsKey = websocket_extensions_to_key(websocket.extensions)
66+
if extensionsKey not in data:
67+
output = io.BytesIO()
68+
frame.write(output.write, mask = False, extensions = websocket.extensions)
69+
data[extensionsKey] = output.getvalue()
70+
self._stats['frame writes'] += len(data[extensionsKey])
4171
try:
42-
queue.put_nowait(message)
72+
queue.put_nowait(data[extensionsKey])
4373
except asyncio.QueueFull:
4474
# Pop one, try again; it should be impossible for this to fail, so no try/except here.
45-
queue.get_nowait()
46-
queue.put_nowait(message)
75+
dropped = queue.get_nowait()
76+
self._stats['dropped'] += len(dropped)
77+
queue.put_nowait(data[extensionsKey])
4778

4879
def unregister(self, websocket):
4980
del self.queues[websocket]
5081

5182

52-
async def websocket_server(amplifier, websocket, path):
83+
async def websocket_server(amplifier, websocket, path, stats):
5384
queue = amplifier.register(websocket)
5485
try:
5586
while True:
56-
await websocket.send(await queue.get())
87+
#FIXME See above; this is write_frame essentially
88+
data = await queue.get()
89+
await websocket.ensure_open()
90+
websocket.writer.write(data)
91+
stats['sent'] += len(data)
92+
if websocket.writer.transport is not None:
93+
if websocket.writer_is_closing():
94+
await asyncio.sleep(0)
95+
try:
96+
async with websocket._drain_lock:
97+
await websocket.writer.drain()
98+
except ConnectionError:
99+
websocket.fail_connection()
100+
await websocket.ensure_open()
57101
except websockets.exceptions.ConnectionClosed: # Silence connection closures
58102
pass
59103
finally:
60104
amplifier.unregister(websocket)
61105

62106

63-
async def print_status(amplifier):
107+
async def print_status(amplifier, stats):
108+
interval = 60
64109
previousUtime = None
110+
previousStats = {}
65111
while True:
66112
currentUtime = os.times().user
67-
cpu = (currentUtime - previousUtime) / 60 * 100 if previousUtime is not None else float('nan')
68-
print(f'{datetime.datetime.now():%Y-%m-%d %H:%M:%S} - {len(amplifier.queues)} clients, {sum(q.qsize() for q in amplifier.queues.values())} total queue size, {cpu:.1f} % CPU, {get_rss()/1048576:.1f} MiB RSS')
113+
cpu = (currentUtime - previousUtime) / interval * 100 if previousUtime is not None else float('nan')
114+
print(f'{datetime.datetime.now():%Y-%m-%d %H:%M:%S} - ' +
115+
', '.join([
116+
f'{len(amplifier.queues)} clients',
117+
f'{sum(q.qsize() for q in amplifier.queues.values())} total queue size',
118+
f'{cpu:.1f} % CPU',
119+
f'{get_rss()/1048576:.1f} MiB RSS',
120+
'throughput: ' + ', '.join(f'{(stats[k] - previousStats.get(k, 0))/interval/1000:.1f} kB/s {k}' for k in stats),
121+
])
122+
)
69123
if DEBUG:
70124
for socket in amplifier.queues:
71125
print(f' {socket.remote_address}: {amplifier.queues[socket].qsize()}')
72126
previousUtime = currentUtime
73-
await asyncio.sleep(60)
127+
previousStats.update(stats)
128+
await asyncio.sleep(interval)
74129

75130

76131
def main():
77-
amplifier = MessageAmplifier()
78-
start_server = websockets.serve(lambda websocket, path: websocket_server(amplifier, websocket, path), None, 4568)
132+
stats = {'stdin read': 0, 'frame writes': 0, 'sent': 0, 'dropped': 0}
133+
amplifier = MessageAmplifier(stats)
134+
# Disable context takeover (cf. RFC 7692) so the compression can be reused
135+
start_server = websockets.serve(
136+
lambda websocket, path: websocket_server(amplifier, websocket, path, stats),
137+
None,
138+
4568,
139+
extensions = [websockets.extensions.permessage_deflate.ServerPerMessageDeflateFactory(server_no_context_takeover = True)]
140+
)
79141
loop = asyncio.get_event_loop()
80142
loop.run_until_complete(start_server)
81-
loop.run_until_complete(asyncio.gather(stdin_to_amplifier(amplifier, loop), print_status(amplifier)))
143+
loop.run_until_complete(asyncio.gather(stdin_to_amplifier(amplifier, loop, stats), print_status(amplifier, stats)))
82144

83145

84146
if __name__ == '__main__':

0 commit comments

Comments
 (0)