Skip to content

Commit 69d0eb9

Browse files
authored
Add check for Sec-WebSocket-Key header (gorilla#752)
* add Sec-WebSocket-Key header verification * add testcase to Sec-WebSocket-Key header verification
1 parent 9111bb8 commit 69d0eb9

File tree

3 files changed

+36
-2
lines changed

3 files changed

+36
-2
lines changed

server.go

+2-2
Original file line numberDiff line numberDiff line change
@@ -154,8 +154,8 @@ func (u *Upgrader) Upgrade(w http.ResponseWriter, r *http.Request, responseHeade
154154
}
155155

156156
challengeKey := r.Header.Get("Sec-Websocket-Key")
157-
if challengeKey == "" {
158-
return u.returnError(w, r, http.StatusBadRequest, "websocket: not a websocket handshake: 'Sec-WebSocket-Key' header is missing or blank")
157+
if !isValidChallengeKey(challengeKey) {
158+
return u.returnError(w, r, http.StatusBadRequest, "websocket: not a websocket handshake: 'Sec-WebSocket-Key' header must be Base64 encoded value of 16-byte in length")
159159
}
160160

161161
subprotocol := u.selectSubprotocol(r, responseHeader)

util.go

+15
Original file line numberDiff line numberDiff line change
@@ -281,3 +281,18 @@ headers:
281281
}
282282
return result
283283
}
284+
285+
// isValidChallengeKey checks if the argument meets RFC6455 specification.
286+
func isValidChallengeKey(s string) bool {
287+
// From RFC6455:
288+
//
289+
// A |Sec-WebSocket-Key| header field with a base64-encoded (see
290+
// Section 4 of [RFC4648]) value that, when decoded, is 16 bytes in
291+
// length.
292+
293+
if s == "" {
294+
return false
295+
}
296+
decoded, err := base64.StdEncoding.DecodeString(s)
297+
return err == nil && len(decoded) == 16
298+
}

util_test.go

+19
Original file line numberDiff line numberDiff line change
@@ -53,6 +53,25 @@ func TestTokenListContainsValue(t *testing.T) {
5353
}
5454
}
5555

56+
var isValidChallengeKeyTests = []struct {
57+
key string
58+
ok bool
59+
}{
60+
{"dGhlIHNhbXBsZSBub25jZQ==", true},
61+
{"", false},
62+
{"InvalidKey", false},
63+
{"WHQ4eXhscUtKYjBvOGN3WEdtOEQ=", false},
64+
}
65+
66+
func TestIsValidChallengeKey(t *testing.T) {
67+
for _, tt := range isValidChallengeKeyTests {
68+
ok := isValidChallengeKey(tt.key)
69+
if ok != tt.ok {
70+
t.Errorf("isValidChallengeKey returns %v, want %v", ok, tt.ok)
71+
}
72+
}
73+
}
74+
5675
var parseExtensionTests = []struct {
5776
value string
5877
extensions []map[string]string

0 commit comments

Comments
 (0)