diff --git a/lib/websocket/driver.rb b/lib/websocket/driver.rb index 586f6a6..3c04ae7 100644 --- a/lib/websocket/driver.rb +++ b/lib/websocket/driver.rb @@ -78,6 +78,7 @@ def initialize(socket, options = {}) @options = options @max_length = options[:max_length] || MAX_LENGTH @headers = Headers.new + @custom_headers = Headers.new @queue = [] @ready_state = 0 @@ -95,7 +96,7 @@ def add_extension(extension) def set_header(name, value) return false unless @ready_state <= 0 - @headers[name] = value + @custom_headers[name] = value true end diff --git a/lib/websocket/driver/client.rb b/lib/websocket/driver/client.rb index fadda79..6da49e4 100644 --- a/lib/websocket/driver/client.rb +++ b/lib/websocket/driver/client.rb @@ -18,28 +18,13 @@ def initialize(socket, options = {}) @accept = Hybi.generate_accept(@key) @http = HTTP::Response.new - uri = URI.parse(@socket.url) - unless VALID_SCHEMES.include?(uri.scheme) + @uri = URI.parse(@socket.url) + unless VALID_SCHEMES.include?(@uri.scheme) raise URIError, "#{ socket.url } is not a valid WebSocket URL" end - path = (uri.path == '') ? '/' : uri.path - @pathname = path + (uri.query ? '?' + uri.query : '') - - @headers['Host'] = Driver.host_header(uri) - @headers['Upgrade'] = 'websocket' - @headers['Connection'] = 'Upgrade' - @headers['Sec-WebSocket-Key'] = @key - @headers['Sec-WebSocket-Version'] = VERSION - - if @protocols.size > 0 - @headers['Sec-WebSocket-Protocol'] = @protocols * ', ' - end - - if uri.user - auth = Base64.strict_encode64([uri.user, uri.password] * ':') - @headers['Authorization'] = 'Basic ' + auth - end + path = (@uri.path == '') ? '/' : @uri.path + @pathname = path + (@uri.query ? '?' + @uri.query : '') end def version @@ -75,11 +60,26 @@ def parse(chunk) private def handshake_request + @headers['Host'] = Driver.host_header(@uri) + @headers['Upgrade'] = 'websocket' + @headers['Connection'] = 'Upgrade' + @headers['Sec-WebSocket-Key'] = @key + @headers['Sec-WebSocket-Version'] = VERSION + + if @protocols.size > 0 + @headers['Sec-WebSocket-Protocol'] = @protocols * ', ' + end + + if @uri.user + auth = Base64.strict_encode64([@uri.user, @uri.password] * ':') + @headers['Authorization'] = 'Basic ' + auth + end + extensions = @extensions.generate_offer @headers['Sec-WebSocket-Extensions'] = extensions if extensions start = "GET #{ @pathname } HTTP/1.1" - headers = [start, @headers.to_s, ''] + headers = [start, @headers.merge(@custom_headers).to_s, ''] headers.join("\r\n") end diff --git a/lib/websocket/driver/headers.rb b/lib/websocket/driver/headers.rb index 769dd49..12d6fbf 100644 --- a/lib/websocket/driver/headers.rb +++ b/lib/websocket/driver/headers.rb @@ -14,7 +14,7 @@ def initialize(received = {}) def clear @sent = Set.new - @lines = [] + @headers = {} end def [](name) @@ -25,7 +25,17 @@ def []=(name, value) return if value.nil? key = HTTP.normalize_header(name) return unless @sent.add?(key) or ALLOWED_DUPLICATES.include?(key) - @lines << "#{ name.strip }: #{ value.to_s.strip }\r\n" + + header = name.strip + unless @headers.include? header + @headers[header] = [] + end + + @headers[header] << value.to_s.strip + end + + def headers_set + @headers end def inspect @@ -36,8 +46,24 @@ def to_h @raw.dup end + def merge(other) + other.headers_set.each {|key, value| + if value.is_a? Array + value.each {|v| self[key] = v } + else + self[key] = value + end + } + + self + end + def to_s - @lines.join('') + @headers.flat_map {|header, values| + values.map {|value| + "#{ header.strip }: #{ value.to_s.strip }\r\n" + } + }.join('') end end diff --git a/lib/websocket/driver/hybi.rb b/lib/websocket/driver/hybi.rb index 4a0fb1b..a1289c9 100644 --- a/lib/websocket/driver/hybi.rb +++ b/lib/websocket/driver/hybi.rb @@ -62,21 +62,12 @@ def initialize(socket, options = {}) @extensions = ::WebSocket::Extensions.new @stage = 0 @masking = options[:masking] - @protocols = options[:protocols] || [] - @protocols = @protocols.strip.split(/ *, */) if String === @protocols @require_masking = options[:require_masking] @ping_callbacks = {} @frame = @message = nil - return unless @socket.respond_to?(:env) - - if protos = @socket.env['HTTP_SEC_WEBSOCKET_PROTOCOL'] - protos = protos.split(/ *, */) if String === protos - @protocol = protos.find { |p| @protocols.include?(p) } - else - @protocol = nil - end + set_protocols(options[:protocols] || []) end def version @@ -88,6 +79,22 @@ def add_extension(extension) true end + def set_protocols(protocols) + return false unless @ready_state <= 0 + + @protocols = protocols + @protocols = @protocols.strip.split(/ *, */) if String === @protocols + + return unless @socket.respond_to?(:env) + + if protos = @socket.env['HTTP_SEC_WEBSOCKET_PROTOCOL'] + protos = protos.split(/ *, */) if String === protos + @protocol = protos.find { |p| @protocols.include?(p) } + else + @protocol = nil + end + end + def parse(chunk) @reader.put(chunk) buffer = true @@ -253,7 +260,7 @@ def handshake_response @headers['Sec-WebSocket-Extensions'] = extensions if extensions start = 'HTTP/1.1 101 Switching Protocols' - headers = [start, @headers.to_s, ''] + headers = [start, @custom_headers.merge(@headers).to_s, ''] headers.join("\r\n") end diff --git a/spec/websocket/driver/client_spec.rb b/spec/websocket/driver/client_spec.rb index 2dd2ea5..977caa0 100644 --- a/spec/websocket/driver/client_spec.rb +++ b/spec/websocket/driver/client_spec.rb @@ -61,6 +61,23 @@ expect(@close).to eq [1000, ""] end end + + describe :set_protocols do + it "writes the handshake with Sec-WebSocket-Protocol" do + expect(socket).to receive(:write).with( + "GET /socket HTTP/1.1\r\n" + + "Host: www.example.com\r\n" + + "Upgrade: websocket\r\n" + + "Connection: Upgrade\r\n" + + "Sec-WebSocket-Key: 2vBVWg4Qyk3ZoM/5d3QD9Q==\r\n" + + "Sec-WebSocket-Version: 13\r\n" + + "Sec-WebSocket-Protocol: foo, bar, xmpp\r\n" + + "\r\n") + + driver.set_protocols(["foo", "bar", "xmpp"]) + driver.start + end + end describe :start do it "writes the handshake request to the socket" do