From 7cbdebb5104eed150a2a5b216b04426fd7d0e154 Mon Sep 17 00:00:00 2001 From: Carlton Gibson Date: Tue, 24 May 2022 12:34:10 +0200 Subject: [PATCH] Set a default Server header for HTTP responses. --- CHANGELOG.txt | 9 ++++++++- daphne/cli.py | 7 +++++-- daphne/http_protocol.py | 4 ++-- daphne/server.py | 2 +- tests/test_cli.py | 15 +++++++++++++++ tests/test_http_response.py | 9 ++++++--- 6 files changed, 37 insertions(+), 9 deletions(-) diff --git a/CHANGELOG.txt b/CHANGELOG.txt index 630c0fa1..31e46504 100644 --- a/CHANGELOG.txt +++ b/CHANGELOG.txt @@ -11,7 +11,14 @@ Unreleased range of versions does not represent a good use of maintainer time. Going forward the latest Twisted version will be required. -* Added `log-fmt` CLI argument. +* Set ``daphne`` as default ``Server`` header. + + This can be configured with the ``--server-name`` CLI argument. + + Added the new ``--no-server-name`` CLI argument to disable the ``Server`` + header, which is equivalent to ``--server-name=` (an empty name). + +* Added ``--log-fmt`` CLI argument. 3.0.2 (2021-04-07) ------------------ diff --git a/daphne/cli.py b/daphne/cli.py index 2e83a5c1..accafe10 100755 --- a/daphne/cli.py +++ b/daphne/cli.py @@ -93,7 +93,7 @@ def __init__(self): self.parser.add_argument( "--log-fmt", help="Log format to use", - default="%(asctime)-15s %(levelname)-8s %(message)s" + default="%(asctime)-15s %(levelname)-8s %(message)s", ) self.parser.add_argument( "--ping-interval", @@ -162,7 +162,10 @@ def __init__(self): "--server-name", dest="server_name", help="specify which value should be passed to response header Server attribute", - default="Daphne", + default="daphne", + ) + self.parser.add_argument( + "--no-server-name", dest="server_name", action="store_const", const="" ) self.server = None diff --git a/daphne/http_protocol.py b/daphne/http_protocol.py index a289e936..f0657fdb 100755 --- a/daphne/http_protocol.py +++ b/daphne/http_protocol.py @@ -249,8 +249,8 @@ def handle_reply(self, message): # Write headers for header, value in message.get("headers", {}): self.responseHeaders.addRawHeader(header, value) - if self.server.server_name and self.server.server_name.lower() != "daphne": - self.setHeader(b"server", self.server.server_name.encode("utf-8")) + if self.server.server_name and not self.responseHeaders.hasHeader("server"): + self.setHeader(b"server", self.server.server_name.encode()) logger.debug( "HTTP %s response started for %s", message["status"], self.client_addr ) diff --git a/daphne/server.py b/daphne/server.py index 0d463d03..43342175 100755 --- a/daphne/server.py +++ b/daphne/server.py @@ -56,7 +56,7 @@ def __init__( websocket_handshake_timeout=5, application_close_timeout=10, ready_callable=None, - server_name="Daphne", + server_name="daphne", # Deprecated and does not work, remove in version 2.2 ws_protocols=None, ): diff --git a/tests/test_cli.py b/tests/test_cli.py index 17335ed0..51eab2ea 100644 --- a/tests/test_cli.py +++ b/tests/test_cli.py @@ -240,3 +240,18 @@ def test_custom_proxyport(self): exc.exception.message, "--proxy-headers has to be passed for this parameter.", ) + + def test_custom_servername(self): + """ + Passing `--server-name` will set the default server header + from 'daphne' to the passed one. + """ + self.assertCLI([], {"server_name": "daphne"}) + self.assertCLI(["--server-name", ""], {"server_name": ""}) + self.assertCLI(["--server-name", "python"], {"server_name": "python"}) + + def test_no_servername(self): + """ + Passing `--no-server-name` will set server name to '' (empty string) + """ + self.assertCLI(["--no-server-name"], {"server_name": ""}) diff --git a/tests/test_http_response.py b/tests/test_http_response.py index 1fc24395..22f6480a 100644 --- a/tests/test_http_response.py +++ b/tests/test_http_response.py @@ -13,9 +13,12 @@ def normalize_headers(self, headers): Lowercases and sorts headers, and strips transfer-encoding ones. """ return sorted( - (name.lower(), value.strip()) - for name, value in headers - if name.lower() != b"transfer-encoding" + [(b"server", b"daphne")] + + [ + (name.lower(), value.strip()) + for name, value in headers + if name.lower() not in (b"server", b"transfer-encoding") + ] ) def encode_headers(self, headers):