Skip to content

Concurrent safe writes #980

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 2 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
63 changes: 27 additions & 36 deletions conn.go
Original file line number Diff line number Diff line change
Expand Up @@ -244,20 +244,23 @@ type Conn struct {
subprotocol string

// Write fields
mu chan struct{} // used as mutex to protect write to conn
writeBuf []byte // frame is constructed in this buffer.
writePool BufferPool
writeBufSize int
writeDeadline time.Time
writer io.WriteCloser // the current writer returned to the application
isWriting bool // for best-effort concurrent write detection

writeErrMu sync.Mutex
writeErr error

enableWriteCompression bool
compressionLevel int
newCompressionWriter func(io.WriteCloser, int) io.WriteCloser
writeMux sync.Mutex // used to protect WriteMessage and WriteJSON from being called concurrently
mu chan struct{} // used in WriteControl to reset the timer when new messages are written
writeBuf []byte // frame is constructed in this buffer.
writePool BufferPool
writeBufSize int
writeDeadline time.Time
writeDeadlineMux sync.Mutex
writer io.WriteCloser // the current writer returned to the application
writerMux sync.Mutex // used to protect the writer attached to the connection
writeErr error
writeErrMu sync.Mutex

enableWriteCompression bool
enableWriteCompressionMux sync.Mutex
compressionLevel int
compressionLevelMux sync.Mutex
newCompressionWriter func(io.WriteCloser, int) io.WriteCloser

// Read fields
reader io.ReadCloser // the current reader returned to the application
Expand Down Expand Up @@ -377,9 +380,6 @@ func (c *Conn) read(n int) ([]byte, error) {
}

func (c *Conn) write(frameType int, deadline time.Time, buf0, buf1 []byte) error {
<-c.mu
defer func() { c.mu <- struct{}{} }()

c.writeErrMu.Lock()
err := c.writeErr
c.writeErrMu.Unlock()
Expand Down Expand Up @@ -525,6 +525,8 @@ func (c *Conn) beginMessage(mw *messageWriter, messageType int) error {
// All message types (TextMessage, BinaryMessage, CloseMessage, PingMessage and
// PongMessage) are supported.
func (c *Conn) NextWriter(messageType int) (io.WriteCloser, error) {
c.writerMux.Lock()
defer c.writerMux.Unlock()
var mw messageWriter
if err := c.beginMessage(&mw, messageType); err != nil {
return nil, err
Expand Down Expand Up @@ -622,18 +624,8 @@ func (w *messageWriter) flushFrame(final bool, extra []byte) error {
// concurrent writes. See the concurrency section in the package
// documentation for more info.

if c.isWriting {
panic("concurrent write to websocket connection")
}
c.isWriting = true

err := c.write(w.frameType, c.writeDeadline, c.writeBuf[framePos:w.pos], extra)

if !c.isWriting {
panic("concurrent write to websocket connection")
}
c.isWriting = false

if err != nil {
return w.endMessage(err)
}
Expand Down Expand Up @@ -750,22 +742,15 @@ func (c *Conn) WritePreparedMessage(pm *PreparedMessage) error {
if err != nil {
return err
}
if c.isWriting {
panic("concurrent write to websocket connection")
}
c.isWriting = true
err = c.write(frameType, c.writeDeadline, frameData, nil)
if !c.isWriting {
panic("concurrent write to websocket connection")
}
c.isWriting = false
return err
}

// WriteMessage is a helper method for getting a writer using NextWriter,
// writing the message and closing the writer.
func (c *Conn) WriteMessage(messageType int, data []byte) error {

c.writeMux.Lock()
defer c.writeMux.Unlock()
if c.isServer && (c.newCompressionWriter == nil || !c.enableWriteCompression) {
// Fast path with no allocations and single frame.

Expand Down Expand Up @@ -794,6 +779,8 @@ func (c *Conn) WriteMessage(messageType int, data []byte) error {
// all future writes will return an error. A zero value for t means writes will
// not time out.
func (c *Conn) SetWriteDeadline(t time.Time) error {
c.writeDeadlineMux.Lock()
defer c.writeDeadlineMux.Unlock()
c.writeDeadline = t
return nil
}
Expand Down Expand Up @@ -1215,6 +1202,8 @@ func (c *Conn) UnderlyingConn() net.Conn {
// subsequent text and binary messages. This function is a noop if
// compression was not negotiated with the peer.
func (c *Conn) EnableWriteCompression(enable bool) {
c.enableWriteCompressionMux.Lock()
defer c.enableWriteCompressionMux.Unlock()
c.enableWriteCompression = enable
}

Expand All @@ -1226,6 +1215,8 @@ func (c *Conn) SetCompressionLevel(level int) error {
if !isValidCompressionLevel(level) {
return errors.New("websocket: invalid compression level")
}
c.compressionLevelMux.Lock()
defer c.compressionLevelMux.Unlock()
c.compressionLevel = level
return nil
}
Expand Down
129 changes: 106 additions & 23 deletions conn_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -691,37 +691,120 @@ func TestUnexpectedCloseErrors(t *testing.T) {
}
}

type blockingWriter struct {
c1, c2 chan struct{}
func TestConcurrencyNextWriter(t *testing.T) {
loop := 10
workers := 10
for i := 0; i < loop; i++ {
var connBuf bytes.Buffer

wg := sync.WaitGroup{}
wc := newTestConn(nil, &connBuf, true)

for i := 0; i < workers; i++ {
wg.Add(1)
go func() {
defer wg.Done()
if _, err := wc.NextWriter(TextMessage); err != nil {
t.Errorf("concurrently wc.NextWriter() returned %v", err)
}
}()
}

wg.Wait()
wc.Close()
}
}

func (w blockingWriter) Write(p []byte) (int, error) {
// Allow main to continue
close(w.c1)
// Wait for panic in main
<-w.c2
return len(p), nil
func TestConcurrencyWriteMessage(t *testing.T) {
const message = "this is a pong messsage"
loop := 10
workers := 10
for i := 0; i < loop; i++ {
var connBuf bytes.Buffer

wg := sync.WaitGroup{}
wc := newTestConn(nil, &connBuf, true)

for i := 0; i < workers; i++ {
wg.Add(1)
go func() {
defer wg.Done()
if err := wc.WriteMessage(PongMessage, []byte(message)); err != nil {
t.Errorf("concurrently wc.WriteMessage() returned %v", err)
}
}()
}

wg.Wait()
wc.Close()
}
}

func TestConcurrentWritePanic(t *testing.T) {
w := blockingWriter{make(chan struct{}), make(chan struct{})}
c := newTestConn(nil, w, false)
go func() {
_ = c.WriteMessage(TextMessage, []byte{})
}()
func TestConcurrencySetWriteDeadline(t *testing.T) {
loop := 10
workers := 10
for i := 0; i < loop; i++ {
var connBuf bytes.Buffer

// wait for goroutine to block in write.
<-w.c1
wg := sync.WaitGroup{}
wc := newTestConn(nil, &connBuf, true)

defer func() {
close(w.c2)
if v := recover(); v != nil {
return
for i := 0; i < workers; i++ {
wg.Add(1)
go func() {
defer wg.Done()
if err := wc.SetWriteDeadline(time.Now()); err != nil {
t.Errorf("concurrently wc.SetWriteDeadline() returned %v", err)
}
}()
}
}()

_ = c.WriteMessage(TextMessage, []byte{})
t.Fatal("should not get here")
wg.Wait()
wc.Close()
}
}

func TestConcurrencySetCompressionLevel(t *testing.T) {
loop := 10
workers := 10
for i := 0; i < loop; i++ {
var connBuf bytes.Buffer

wg := sync.WaitGroup{}
wc := newTestConn(nil, &connBuf, true)

for i := 0; i < workers; i++ {
wg.Add(1)
go func() {
defer wg.Done()
if err := wc.SetCompressionLevel(defaultCompressionLevel); err != nil {
t.Errorf("concurrently wc.SetCompressionLevel() returned %v", err)
}
}()
}

wg.Wait()
wc.Close()
}
}

func TestConcurrentEnableWriteCompressionCalls(t *testing.T) {
var connBuf bytes.Buffer
wc := newTestConn(nil, &connBuf, false)
nGoroutines := 5
wg := &sync.WaitGroup{}
for i := 0; i < nGoroutines; i++ {
wg.Add(1)
go func(wg *sync.WaitGroup) {
wc.EnableWriteCompression(true)
wg.Done()
}(wg)
}
wg.Wait()
if !wc.enableWriteCompression {
t.Fatal("expected to enableWriteCompression to be true")
}
wc.Close()
}

type failingReader struct{}
Expand Down
Loading