diff --git a/conn.go b/conn.go index 9562ffd4..32963199 100644 --- a/conn.go +++ b/conn.go @@ -244,20 +244,23 @@ type Conn struct { subprotocol string // Write fields - mu chan struct{} // used as mutex to protect write to conn - writeBuf []byte // frame is constructed in this buffer. - writePool BufferPool - writeBufSize int - writeDeadline time.Time - writer io.WriteCloser // the current writer returned to the application - isWriting bool // for best-effort concurrent write detection - - writeErrMu sync.Mutex - writeErr error - - enableWriteCompression bool - compressionLevel int - newCompressionWriter func(io.WriteCloser, int) io.WriteCloser + writeMux sync.Mutex // used to protect WriteMessage and WriteJSON from being called concurrently + mu chan struct{} // used in WriteControl to reset the timer when new messages are written + writeBuf []byte // frame is constructed in this buffer. + writePool BufferPool + writeBufSize int + writeDeadline time.Time + writeDeadlineMux sync.Mutex + writer io.WriteCloser // the current writer returned to the application + writerMux sync.Mutex // used to protect the writer attached to the connection + writeErr error + writeErrMu sync.Mutex + + enableWriteCompression bool + enableWriteCompressionMux sync.Mutex + compressionLevel int + compressionLevelMux sync.Mutex + newCompressionWriter func(io.WriteCloser, int) io.WriteCloser // Read fields reader io.ReadCloser // the current reader returned to the application @@ -377,9 +380,6 @@ func (c *Conn) read(n int) ([]byte, error) { } func (c *Conn) write(frameType int, deadline time.Time, buf0, buf1 []byte) error { - <-c.mu - defer func() { c.mu <- struct{}{} }() - c.writeErrMu.Lock() err := c.writeErr c.writeErrMu.Unlock() @@ -525,6 +525,8 @@ func (c *Conn) beginMessage(mw *messageWriter, messageType int) error { // All message types (TextMessage, BinaryMessage, CloseMessage, PingMessage and // PongMessage) are supported. func (c *Conn) NextWriter(messageType int) (io.WriteCloser, error) { + c.writerMux.Lock() + defer c.writerMux.Unlock() var mw messageWriter if err := c.beginMessage(&mw, messageType); err != nil { return nil, err @@ -622,18 +624,8 @@ func (w *messageWriter) flushFrame(final bool, extra []byte) error { // concurrent writes. See the concurrency section in the package // documentation for more info. - if c.isWriting { - panic("concurrent write to websocket connection") - } - c.isWriting = true - err := c.write(w.frameType, c.writeDeadline, c.writeBuf[framePos:w.pos], extra) - if !c.isWriting { - panic("concurrent write to websocket connection") - } - c.isWriting = false - if err != nil { return w.endMessage(err) } @@ -750,22 +742,15 @@ func (c *Conn) WritePreparedMessage(pm *PreparedMessage) error { if err != nil { return err } - if c.isWriting { - panic("concurrent write to websocket connection") - } - c.isWriting = true err = c.write(frameType, c.writeDeadline, frameData, nil) - if !c.isWriting { - panic("concurrent write to websocket connection") - } - c.isWriting = false return err } // WriteMessage is a helper method for getting a writer using NextWriter, // writing the message and closing the writer. func (c *Conn) WriteMessage(messageType int, data []byte) error { - + c.writeMux.Lock() + defer c.writeMux.Unlock() if c.isServer && (c.newCompressionWriter == nil || !c.enableWriteCompression) { // Fast path with no allocations and single frame. @@ -794,6 +779,8 @@ func (c *Conn) WriteMessage(messageType int, data []byte) error { // all future writes will return an error. A zero value for t means writes will // not time out. func (c *Conn) SetWriteDeadline(t time.Time) error { + c.writeDeadlineMux.Lock() + defer c.writeDeadlineMux.Unlock() c.writeDeadline = t return nil } @@ -1215,6 +1202,8 @@ func (c *Conn) UnderlyingConn() net.Conn { // subsequent text and binary messages. This function is a noop if // compression was not negotiated with the peer. func (c *Conn) EnableWriteCompression(enable bool) { + c.enableWriteCompressionMux.Lock() + defer c.enableWriteCompressionMux.Unlock() c.enableWriteCompression = enable } @@ -1226,6 +1215,8 @@ func (c *Conn) SetCompressionLevel(level int) error { if !isValidCompressionLevel(level) { return errors.New("websocket: invalid compression level") } + c.compressionLevelMux.Lock() + defer c.compressionLevelMux.Unlock() c.compressionLevel = level return nil } diff --git a/conn_test.go b/conn_test.go index 28f5c4a3..f8457ba1 100644 --- a/conn_test.go +++ b/conn_test.go @@ -691,37 +691,120 @@ func TestUnexpectedCloseErrors(t *testing.T) { } } -type blockingWriter struct { - c1, c2 chan struct{} +func TestConcurrencyNextWriter(t *testing.T) { + loop := 10 + workers := 10 + for i := 0; i < loop; i++ { + var connBuf bytes.Buffer + + wg := sync.WaitGroup{} + wc := newTestConn(nil, &connBuf, true) + + for i := 0; i < workers; i++ { + wg.Add(1) + go func() { + defer wg.Done() + if _, err := wc.NextWriter(TextMessage); err != nil { + t.Errorf("concurrently wc.NextWriter() returned %v", err) + } + }() + } + + wg.Wait() + wc.Close() + } } -func (w blockingWriter) Write(p []byte) (int, error) { - // Allow main to continue - close(w.c1) - // Wait for panic in main - <-w.c2 - return len(p), nil +func TestConcurrencyWriteMessage(t *testing.T) { + const message = "this is a pong messsage" + loop := 10 + workers := 10 + for i := 0; i < loop; i++ { + var connBuf bytes.Buffer + + wg := sync.WaitGroup{} + wc := newTestConn(nil, &connBuf, true) + + for i := 0; i < workers; i++ { + wg.Add(1) + go func() { + defer wg.Done() + if err := wc.WriteMessage(PongMessage, []byte(message)); err != nil { + t.Errorf("concurrently wc.WriteMessage() returned %v", err) + } + }() + } + + wg.Wait() + wc.Close() + } } -func TestConcurrentWritePanic(t *testing.T) { - w := blockingWriter{make(chan struct{}), make(chan struct{})} - c := newTestConn(nil, w, false) - go func() { - _ = c.WriteMessage(TextMessage, []byte{}) - }() +func TestConcurrencySetWriteDeadline(t *testing.T) { + loop := 10 + workers := 10 + for i := 0; i < loop; i++ { + var connBuf bytes.Buffer - // wait for goroutine to block in write. - <-w.c1 + wg := sync.WaitGroup{} + wc := newTestConn(nil, &connBuf, true) - defer func() { - close(w.c2) - if v := recover(); v != nil { - return + for i := 0; i < workers; i++ { + wg.Add(1) + go func() { + defer wg.Done() + if err := wc.SetWriteDeadline(time.Now()); err != nil { + t.Errorf("concurrently wc.SetWriteDeadline() returned %v", err) + } + }() } - }() - _ = c.WriteMessage(TextMessage, []byte{}) - t.Fatal("should not get here") + wg.Wait() + wc.Close() + } +} + +func TestConcurrencySetCompressionLevel(t *testing.T) { + loop := 10 + workers := 10 + for i := 0; i < loop; i++ { + var connBuf bytes.Buffer + + wg := sync.WaitGroup{} + wc := newTestConn(nil, &connBuf, true) + + for i := 0; i < workers; i++ { + wg.Add(1) + go func() { + defer wg.Done() + if err := wc.SetCompressionLevel(defaultCompressionLevel); err != nil { + t.Errorf("concurrently wc.SetCompressionLevel() returned %v", err) + } + }() + } + + wg.Wait() + wc.Close() + } +} + +func TestConcurrentEnableWriteCompressionCalls(t *testing.T) { + var connBuf bytes.Buffer + wc := newTestConn(nil, &connBuf, false) + nGoroutines := 5 + wg := &sync.WaitGroup{} + for i := 0; i < nGoroutines; i++ { + wg.Add(1) + go func(wg *sync.WaitGroup) { + wc.EnableWriteCompression(true) + wg.Done() + }(wg) + } + wg.Wait() + if !wc.enableWriteCompression { + t.Fatal("expected to enableWriteCompression to be true") + } + wc.Close() } type failingReader struct{} diff --git a/doc.go b/doc.go index 8db0cef9..154c6ae6 100644 --- a/doc.go +++ b/doc.go @@ -4,40 +4,40 @@ // Package websocket implements the WebSocket protocol defined in RFC 6455. // -// Overview +// # Overview // // The Conn type represents a WebSocket connection. A server application calls // the Upgrader.Upgrade method from an HTTP request handler to get a *Conn: // -// var upgrader = websocket.Upgrader{ -// ReadBufferSize: 1024, -// WriteBufferSize: 1024, -// } +// var upgrader = websocket.Upgrader{ +// ReadBufferSize: 1024, +// WriteBufferSize: 1024, +// } // -// func handler(w http.ResponseWriter, r *http.Request) { -// conn, err := upgrader.Upgrade(w, r, nil) -// if err != nil { -// log.Println(err) -// return -// } -// ... Use conn to send and receive messages. -// } +// func handler(w http.ResponseWriter, r *http.Request) { +// conn, err := upgrader.Upgrade(w, r, nil) +// if err != nil { +// log.Println(err) +// return +// } +// ... Use conn to send and receive messages. +// } // // Call the connection's WriteMessage and ReadMessage methods to send and // receive messages as a slice of bytes. This snippet of code shows how to echo // messages using these methods: // -// for { -// messageType, p, err := conn.ReadMessage() -// if err != nil { -// log.Println(err) -// return -// } -// if err := conn.WriteMessage(messageType, p); err != nil { -// log.Println(err) -// return -// } -// } +// for { +// messageType, p, err := conn.ReadMessage() +// if err != nil { +// log.Println(err) +// return +// } +// if err := conn.WriteMessage(messageType, p); err != nil { +// log.Println(err) +// return +// } +// } // // In above snippet of code, p is a []byte and messageType is an int with value // websocket.BinaryMessage or websocket.TextMessage. @@ -49,24 +49,24 @@ // method to get an io.Reader and read until io.EOF is returned. This snippet // shows how to echo messages using the NextWriter and NextReader methods: // -// for { -// messageType, r, err := conn.NextReader() -// if err != nil { -// return -// } -// w, err := conn.NextWriter(messageType) -// if err != nil { -// return err -// } -// if _, err := io.Copy(w, r); err != nil { -// return err -// } -// if err := w.Close(); err != nil { -// return err -// } -// } -// -// Data Messages +// for { +// messageType, r, err := conn.NextReader() +// if err != nil { +// return +// } +// w, err := conn.NextWriter(messageType) +// if err != nil { +// return err +// } +// if _, err := io.Copy(w, r); err != nil { +// return err +// } +// if err := w.Close(); err != nil { +// return err +// } +// } +// +// # Data Messages // // The WebSocket protocol distinguishes between text and binary data messages. // Text messages are interpreted as UTF-8 encoded text. The interpretation of @@ -80,7 +80,7 @@ // It is the application's responsibility to ensure that text messages are // valid UTF-8 encoded text. // -// Control Messages +// # Control Messages // // The WebSocket protocol defines three types of control messages: close, ping // and pong. Call the connection WriteControl, WriteMessage or NextWriter @@ -110,30 +110,30 @@ // in messages from the peer, then the application should start a goroutine to // read and discard messages from the peer. A simple example is: // -// func readLoop(c *websocket.Conn) { -// for { -// if _, _, err := c.NextReader(); err != nil { -// c.Close() -// break -// } -// } -// } +// func readLoop(c *websocket.Conn) { +// for { +// if _, _, err := c.NextReader(); err != nil { +// c.Close() +// break +// } +// } +// } // -// Concurrency +// # Concurrency // -// Connections support one concurrent reader and one concurrent writer. +// Connections support one concurrent reader. // -// Applications are responsible for ensuring that no more than one goroutine -// calls the write methods (NextWriter, SetWriteDeadline, WriteMessage, -// WriteJSON, EnableWriteCompression, SetCompressionLevel) concurrently and -// that no more than one goroutine calls the read methods (NextReader, -// SetReadDeadline, ReadMessage, ReadJSON, SetPongHandler, SetPingHandler) +// Applications are responsible for ensuring that +// no more than one goroutine calls the read methods +// (NextReader, SetReadDeadline, ReadMessage, ReadJSON, SetPongHandler, SetPingHandler) // concurrently. // -// The Close and WriteControl methods can be called concurrently with all other -// methods. +// The write methods (NextWriter, SetWriteDeadline, WriteMessage, +// WriteJSON, EnableWriteCompression, SetCompressionLevel, +// Close and WriteControl) +// can be called concurrently with all other methods. // -// Origin Considerations +// # Origin Considerations // // Web browsers allow Javascript applications to open a WebSocket connection to // any host. It's up to the server to enforce an origin policy using the Origin @@ -151,7 +151,7 @@ // checking. The application is responsible for checking the Origin header // before calling the Upgrade function. // -// Buffers +// # Buffers // // Connections buffer network input and output to reduce the number // of system calls when reading or writing messages. @@ -198,16 +198,16 @@ // buffer size has a reduced impact on total memory use and has the benefit of // reducing system calls and frame overhead. // -// Compression EXPERIMENTAL +// # Compression EXPERIMENTAL // // Per message compression extensions (RFC 7692) are experimentally supported // by this package in a limited capacity. Setting the EnableCompression option // to true in Dialer or Upgrader will attempt to negotiate per message deflate // support. // -// var upgrader = websocket.Upgrader{ -// EnableCompression: true, -// } +// var upgrader = websocket.Upgrader{ +// EnableCompression: true, +// } // // If compression was successfully negotiated with the connection's peer, any // message received in compressed form will be automatically decompressed. @@ -216,7 +216,7 @@ // Per message compression of messages written to a connection can be enabled // or disabled by calling the corresponding Conn method: // -// conn.EnableWriteCompression(false) +// conn.EnableWriteCompression(false) // // Currently this package does not support compression with "context takeover". // This means that messages must be compressed and decompressed in isolation, diff --git a/json.go b/json.go index dc2c1f64..aff9baae 100644 --- a/json.go +++ b/json.go @@ -21,6 +21,8 @@ func WriteJSON(c *Conn, v interface{}) error { // See the documentation for encoding/json Marshal for details about the // conversion of Go values to JSON. func (c *Conn) WriteJSON(v interface{}) error { + c.writeMux.Lock() + defer c.writeMux.Unlock() w, err := c.NextWriter(TextMessage) if err != nil { return err diff --git a/json_test.go b/json_test.go index e4c4bdfe..1cd75c35 100644 --- a/json_test.go +++ b/json_test.go @@ -37,6 +37,31 @@ func TestJSON(t *testing.T) { } } +func TestConcurrentWriteJsonCalls(t *testing.T) { + var buf bytes.Buffer + wc := newTestConn(nil, &buf, false) + var jsonMsg struct { + A int + B string + } + jsonMsg.A = 1 + jsonMsg.B = "hello" + nGoroutines := 5 + done := make(chan error, nGoroutines) + for i := 0; i < nGoroutines; i++ { + go func() { + err := wc.WriteJSON(jsonMsg) + done <- err + }() + } + for i := 0; i < nGoroutines; i++ { + err := <-done + if err != nil { + t.Fatal(err) + } + } +} + func TestPartialJSONRead(t *testing.T) { var buf0, buf1 bytes.Buffer wc := newTestConn(nil, &buf0, true) diff --git a/prepared.go b/prepared.go index c854225e..092c0764 100644 --- a/prepared.go +++ b/prepared.go @@ -73,12 +73,9 @@ func (pm *PreparedMessage) frame(key prepareKey) (int, []byte, error) { // Prepare a frame using a 'fake' connection. // TODO: Refactor code in conn.go to allow more direct construction of // the frame. - mu := make(chan struct{}, 1) - mu <- struct{}{} var nc prepareConn c := &Conn{ conn: &nc, - mu: mu, isServer: key.isServer, compressionLevel: key.compressionLevel, enableWriteCompression: true, diff --git a/server.go b/server.go index 02ea01fd..bd734f01 100644 --- a/server.go +++ b/server.go @@ -370,4 +370,3 @@ func (b *brNetConn) Read(p []byte) (n int, err error) { func (b *brNetConn) NetConn() net.Conn { return b.Conn } -