1
1
import asyncio
2
+ import collections
2
3
import datetime
4
+ import io
3
5
import os
4
6
import sys
5
7
import websockets
8
+ import websockets .extensions .permessage_deflate
9
+ import websockets .framing
6
10
7
11
8
12
DEBUG = 'WSDEBUG' in os .environ and os .environ ['WSDEBUG' ] == '1'
@@ -22,63 +26,121 @@ async def stdin(loop):
22
26
return reader
23
27
24
28
25
- async def stdin_to_amplifier (amplifier , loop ):
29
+ async def stdin_to_amplifier (amplifier , loop , stats ):
26
30
reader = await stdin (loop )
27
31
while True :
28
- amplifier .send ((await reader .readline ()).decode ('utf-8' ).strip ())
32
+ d = await reader .readline ()
33
+ stats ['stdin read' ] += len (d )
34
+ amplifier .send (d .decode ('utf-8' ).strip ())
35
+
36
+
37
+ def websocket_extensions_to_key (extensions ):
38
+ # Convert a list of websockets extensions into a key, handling PerMessageDeflate objects with the relevant care for server-side compression dedupe
39
+ def _inner ():
40
+ for e in extensions :
41
+ if isinstance (e , websockets .extensions .permessage_deflate .PerMessageDeflate ) and e .local_no_context_takeover :
42
+ yield (websockets .extensions .permessage_deflate .PerMessageDeflate , e .remote_max_window_bits , e .local_max_window_bits , tuple (e .compress_settings .items ()))
43
+ else :
44
+ yield e
45
+ return tuple (_inner ())
29
46
30
47
31
48
class MessageAmplifier :
32
- def __init__ (self ):
33
- self .queues = {}
49
+ def __init__ (self , stats ):
50
+ self .queues = {} # websocket -> queue
51
+ self ._stats = stats
34
52
35
53
def register (self , websocket ):
36
- self .queues [websocket ] = asyncio .Queue (maxsize = 1000 )
37
- return self .queues [websocket ]
54
+ q = asyncio .Queue (maxsize = 1000 )
55
+ self .queues [websocket ] = q
56
+ return q
38
57
39
58
def send (self , message ):
40
- for queue in self .queues .values ():
59
+ #FIXME This abuses internal API of websockets==7.0
60
+ # Using the normal `websocket.send` reencodes and recompresses the message for every client.
61
+ # So we construct the relevant Frame once instead and push that to the individual queues.
62
+ frame = websockets .framing .Frame (fin = True , opcode = websockets .framing .OP_TEXT , data = message .encode ('utf-8' ))
63
+ data = {} # tuple of extensions key → bytes
64
+ for websocket , queue in self .queues .items ():
65
+ extensionsKey = websocket_extensions_to_key (websocket .extensions )
66
+ if extensionsKey not in data :
67
+ output = io .BytesIO ()
68
+ frame .write (output .write , mask = False , extensions = websocket .extensions )
69
+ data [extensionsKey ] = output .getvalue ()
70
+ self ._stats ['frame writes' ] += len (data [extensionsKey ])
41
71
try :
42
- queue .put_nowait (message )
72
+ queue .put_nowait (data [ extensionsKey ] )
43
73
except asyncio .QueueFull :
44
74
# Pop one, try again; it should be impossible for this to fail, so no try/except here.
45
- queue .get_nowait ()
46
- queue .put_nowait (message )
75
+ dropped = queue .get_nowait ()
76
+ self ._stats ['dropped' ] += len (dropped )
77
+ queue .put_nowait (data [extensionsKey ])
47
78
48
79
def unregister (self , websocket ):
49
80
del self .queues [websocket ]
50
81
51
82
52
- async def websocket_server (amplifier , websocket , path ):
83
+ async def websocket_server (amplifier , websocket , path , stats ):
53
84
queue = amplifier .register (websocket )
54
85
try :
55
86
while True :
56
- await websocket .send (await queue .get ())
87
+ #FIXME See above; this is write_frame essentially
88
+ data = await queue .get ()
89
+ await websocket .ensure_open ()
90
+ websocket .writer .write (data )
91
+ stats ['sent' ] += len (data )
92
+ if websocket .writer .transport is not None :
93
+ if websocket .writer_is_closing ():
94
+ await asyncio .sleep (0 )
95
+ try :
96
+ async with websocket ._drain_lock :
97
+ await websocket .writer .drain ()
98
+ except ConnectionError :
99
+ websocket .fail_connection ()
100
+ await websocket .ensure_open ()
57
101
except websockets .exceptions .ConnectionClosed : # Silence connection closures
58
102
pass
59
103
finally :
60
104
amplifier .unregister (websocket )
61
105
62
106
63
- async def print_status (amplifier ):
107
+ async def print_status (amplifier , stats ):
108
+ interval = 60
64
109
previousUtime = None
110
+ previousStats = {}
65
111
while True :
66
112
currentUtime = os .times ().user
67
- cpu = (currentUtime - previousUtime ) / 60 * 100 if previousUtime is not None else float ('nan' )
68
- print (f'{ datetime .datetime .now ():%Y-%m-%d %H:%M:%S} - { len (amplifier .queues )} clients, { sum (q .qsize () for q in amplifier .queues .values ())} total queue size, { cpu :.1f} % CPU, { get_rss ()/ 1048576 :.1f} MiB RSS' )
113
+ cpu = (currentUtime - previousUtime ) / interval * 100 if previousUtime is not None else float ('nan' )
114
+ print (f'{ datetime .datetime .now ():%Y-%m-%d %H:%M:%S} - ' +
115
+ ', ' .join ([
116
+ f'{ len (amplifier .queues )} clients' ,
117
+ f'{ sum (q .qsize () for q in amplifier .queues .values ())} total queue size' ,
118
+ f'{ cpu :.1f} % CPU' ,
119
+ f'{ get_rss ()/ 1048576 :.1f} MiB RSS' ,
120
+ 'throughput: ' + ', ' .join (f'{ (stats [k ] - previousStats .get (k , 0 ))/ interval / 1000 :.1f} kB/s { k } ' for k in stats ),
121
+ ])
122
+ )
69
123
if DEBUG :
70
124
for socket in amplifier .queues :
71
125
print (f' { socket .remote_address } : { amplifier .queues [socket ].qsize ()} ' )
72
126
previousUtime = currentUtime
73
- await asyncio .sleep (60 )
127
+ previousStats .update (stats )
128
+ await asyncio .sleep (interval )
74
129
75
130
76
131
def main ():
77
- amplifier = MessageAmplifier ()
78
- start_server = websockets .serve (lambda websocket , path : websocket_server (amplifier , websocket , path ), None , 4568 )
132
+ stats = {'stdin read' : 0 , 'frame writes' : 0 , 'sent' : 0 , 'dropped' : 0 }
133
+ amplifier = MessageAmplifier (stats )
134
+ # Disable context takeover (cf. RFC 7692) so the compression can be reused
135
+ start_server = websockets .serve (
136
+ lambda websocket , path : websocket_server (amplifier , websocket , path , stats ),
137
+ None ,
138
+ 4568 ,
139
+ extensions = [websockets .extensions .permessage_deflate .ServerPerMessageDeflateFactory (server_no_context_takeover = True )]
140
+ )
79
141
loop = asyncio .get_event_loop ()
80
142
loop .run_until_complete (start_server )
81
- loop .run_until_complete (asyncio .gather (stdin_to_amplifier (amplifier , loop ), print_status (amplifier )))
143
+ loop .run_until_complete (asyncio .gather (stdin_to_amplifier (amplifier , loop , stats ), print_status (amplifier , stats )))
82
144
83
145
84
146
if __name__ == '__main__' :
0 commit comments