diff --git a/README.rst b/README.rst index aa187986..4a3f49a9 100644 --- a/README.rst +++ b/README.rst @@ -105,6 +105,20 @@ should start with a slash, but not end with one; for example:: daphne --root-path=/forum django_project.asgi:application +Permessage compression +---------------------- + +Daphne supports and by default accepts ``permessage-deflate`` compression +(`permessage-deflate specification `_). +Additional ``permessage-bzip2``, ``permessage-snappy`` compressions will be also enabled by default if +``bz2`` and `snappy `_ python packages are available in daphne environment. +The compression implementation is provided by +`Autobahn|Python `_ package, see: +`permessage-deflate `_, +`permessage-bzip2 `_, +`permessage-snappy `_. + + Python Support -------------- diff --git a/daphne/server.py b/daphne/server.py index a6d38198..8ca22146 100755 --- a/daphne/server.py +++ b/daphne/server.py @@ -33,6 +33,7 @@ from concurrent.futures import CancelledError from functools import partial +from autobahn.websocket.compress import PERMESSAGE_COMPRESSION_EXTENSION as EXTENSIONS from twisted.internet import defer, reactor from twisted.internet.endpoints import serverFromString from twisted.logger import STDLibLogObserver, globalLogBeginner @@ -55,6 +56,11 @@ def __init__( request_buffer_size=8192, websocket_timeout=86400, websocket_connect_timeout=20, + websocket_permessage_compression_extensions=( + "permessage-deflate", + "permessage-bzip2", + "permessage-snappy", + ), ping_interval=20, ping_timeout=30, root_path="", @@ -83,6 +89,9 @@ def __init__( self.websocket_timeout = websocket_timeout self.websocket_connect_timeout = websocket_connect_timeout self.websocket_handshake_timeout = websocket_handshake_timeout + self.websocket_permessage_compression_extensions = ( + websocket_permessage_compression_extensions + ) self.application_close_timeout = application_close_timeout self.root_path = root_path self.verbosity = verbosity @@ -104,6 +113,7 @@ def run(self): autoPingTimeout=self.ping_timeout, allowNullOrigin=True, openHandshakeTimeout=self.websocket_handshake_timeout, + perMessageCompressionAccept=self.accept_permessage_compression_extension, ) if self.verbosity <= 1: # Redirect the Twisted log to nowhere @@ -258,6 +268,21 @@ def check_headers_type(message): ) ) + def accept_permessage_compression_extension(self, offers): + """ + Accepts websocket compression extension as required by `autobahn` package. + """ + for offer in offers: + for ext in self.websocket_permessage_compression_extensions: + if ext in EXTENSIONS and isinstance(offer, EXTENSIONS[ext]["Offer"]): + return EXTENSIONS[ext]["OfferAccept"](offer) + elif ext not in EXTENSIONS: + logger.warning( + "Compression extension %s could not be accepted. " + "It is not supported or a dependency is missing.", + ext, + ) + ### Utility def application_checker(self): diff --git a/daphne/ws_protocol.py b/daphne/ws_protocol.py index 975b1a9e..bf28c7e5 100755 --- a/daphne/ws_protocol.py +++ b/daphne/ws_protocol.py @@ -182,7 +182,10 @@ def handle_reply(self, message): if "type" not in message: raise ValueError("Message has no type defined") if message["type"] == "websocket.accept": - self.serverAccept(message.get("subprotocol", None)) + self.serverAccept( + message.get("subprotocol", None), message.get("headers", None) + ) + elif message["type"] == "websocket.close": if self.state == self.STATE_CONNECTING: self.serverReject() @@ -214,11 +217,15 @@ def handle_exception(self, exception): else: self.sendCloseFrame(code=1011) - def serverAccept(self, subprotocol=None): + def serverAccept(self, subprotocol=None, headers=None): """ Called when we get a message saying to accept the connection. """ - self.handshake_deferred.callback(subprotocol) + if headers is None: + self.handshake_deferred.callback(subprotocol) + else: + headers_dict = {key.decode(): value.decode() for key, value in headers} + self.handshake_deferred.callback((subprotocol, headers_dict)) del self.handshake_deferred logger.debug("WebSocket %s accepted by application", self.client_addr) diff --git a/tests/test_websocket.py b/tests/test_websocket.py index e9544863..9cbb9a1d 100644 --- a/tests/test_websocket.py +++ b/tests/test_websocket.py @@ -139,6 +139,60 @@ def test_subprotocols(self): self.assert_valid_websocket_scope(scope, subprotocols=subprotocols) self.assert_valid_websocket_connect_message(messages[0]) + def test_accept_permessage_deflate_extension(self): + """ + Tests that permessage-deflate extension is successfuly accepted + by underlying `autobahn` package. + """ + + headers = [ + ( + b"Sec-WebSocket-Extensions", + b"permessage-deflate; client_max_window_bits", + ), + ] + + with DaphneTestingInstance() as test_app: + test_app.add_send_messages( + [ + { + "type": "websocket.accept", + } + ] + ) + + sock, subprotocol = self.websocket_handshake( + test_app, + headers=headers, + ) + # Validate the scope and messages we got + scope, messages = test_app.get_received() + self.assert_valid_websocket_connect_message(messages[0]) + + def test_accept_custom_extension(self): + """ + Tests that custom headers can be accpeted during handshake. + """ + with DaphneTestingInstance() as test_app: + test_app.add_send_messages( + [ + { + "type": "websocket.accept", + "headers": [(b"Sec-WebSocket-Extensions", b"custom-extension")], + } + ] + ) + + sock, subprotocol = self.websocket_handshake( + test_app, + headers=[ + (b"Sec-WebSocket-Extensions", b"custom-extension"), + ], + ) + # Validate the scope and messages we got + scope, messages = test_app.get_received() + self.assert_valid_websocket_connect_message(messages[0]) + def test_xff(self): """ Tests that X-Forwarded-For headers get parsed right