Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

add support for sending error codes on session close #121

Open
wants to merge 12 commits into
base: master
Choose a base branch
from
40 changes: 38 additions & 2 deletions const.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
import (
"encoding/binary"
"fmt"
"time"
)

type Error struct {
Expand All @@ -22,6 +23,40 @@
return ye.temporary
}

type GoAwayError struct {
ErrorCode uint32
Remote bool
}

func (e *GoAwayError) Error() string {
if e.Remote {
return fmt.Sprintf("remote sent go away, code: %d", e.ErrorCode)

Check warning on line 33 in const.go

View check run for this annotation

Codecov / codecov/patch

const.go#L33

Added line #L33 was not covered by tests
}
return fmt.Sprintf("sent go away, code: %d", e.ErrorCode)
}

func (e *GoAwayError) Timeout() bool {
return false

Check warning on line 39 in const.go

View check run for this annotation

Codecov / codecov/patch

const.go#L38-L39

Added lines #L38 - L39 were not covered by tests
}

func (e *GoAwayError) Temporary() bool {
return false

Check warning on line 43 in const.go

View check run for this annotation

Codecov / codecov/patch

const.go#L42-L43

Added lines #L42 - L43 were not covered by tests
}

func (e *GoAwayError) Is(target error) bool {
// to maintain compatibility with errors returned by previous versions
if e.Remote && target == ErrRemoteGoAway {
return true

Check warning on line 49 in const.go

View check run for this annotation

Codecov / codecov/patch

const.go#L49

Added line #L49 was not covered by tests
} else if !e.Remote && target == ErrSessionShutdown {
return true

Check warning on line 51 in const.go

View check run for this annotation

Codecov / codecov/patch

const.go#L51

Added line #L51 was not covered by tests
}

if err, ok := target.(*GoAwayError); ok {
return *e == *err
}
return false

Check warning on line 57 in const.go

View check run for this annotation

Codecov / codecov/patch

const.go#L57

Added line #L57 was not covered by tests
}

var (
// ErrInvalidVersion means we received a frame with an
// invalid version
Expand All @@ -33,7 +68,7 @@

// ErrSessionShutdown is used if there is a shutdown during
// an operation
ErrSessionShutdown = &Error{msg: "session shutdown"}
ErrSessionShutdown = &GoAwayError{ErrorCode: goAwayNormal, Remote: false}

// ErrStreamsExhausted is returned if we have no more
// stream ids to issue
Expand All @@ -56,7 +91,7 @@
ErrUnexpectedFlag = &Error{msg: "unexpected flag"}

// ErrRemoteGoAway is used when we get a go away from the other side
ErrRemoteGoAway = &Error{msg: "remote end is not accepting connections"}
ErrRemoteGoAway = &GoAwayError{Remote: true, ErrorCode: goAwayNormal}

// ErrStreamReset is sent if a stream is reset. This can happen
// if the backlog is exceeded, or if there was a remote GoAway.
Expand Down Expand Up @@ -117,6 +152,7 @@
// It's not an implementation choice, the value defined in the specification.
initialStreamWindow = 256 * 1024
maxStreamWindow = 16 * 1024 * 1024
goAwayWaitTime = 5 * time.Second
sukunrt marked this conversation as resolved.
Show resolved Hide resolved
)

const (
Expand Down
81 changes: 57 additions & 24 deletions session.go
Original file line number Diff line number Diff line change
Expand Up @@ -102,6 +102,8 @@ type Session struct {
// recvDoneCh is closed when recv() exits to avoid a race
// between stream registration and stream shutdown
recvDoneCh chan struct{}
// recvErr is the error the receive loop ended with
recvErr error

// sendDoneCh is closed when send() exits to avoid a race
// between returning from a Stream.Write and exiting from the send loop
Expand Down Expand Up @@ -284,8 +286,22 @@ func (s *Session) AcceptStream() (*Stream, error) {
}

// Close is used to close the session and all streams.
// Attempts to send a GoAway before closing the connection.
// Attempts to send a GoAway before closing the connection. The GoAway may not actually be sent depending on the
// semantics of the underlying net.Conn. For TCP connections, it may be dropped depending on LINGER value or
// if there's unread data in the kernel receive buffer.
sukunrt marked this conversation as resolved.
Show resolved Hide resolved
func (s *Session) Close() error {
return s.close(ErrSessionShutdown, true, goAwayNormal)
}

// CloseWithError is used to close the session and all streams after sending a GoAway message with errCode.
// The GoAway may not actually be sent depending on the semantics of the underlying net.Conn.
// For TCP connections, it may be dropped depending on LINGER value or if there's unread data in the kernel
// receive buffer.
func (s *Session) CloseWithError(errCode uint32) error {
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Have we updated the connection manager to be able to deal with potential blocking here? Also, we should probably document it.

return s.close(&GoAwayError{Remote: false, ErrorCode: errCode}, true, errCode)
}

func (s *Session) close(shutdownErr error, sendGoAway bool, errCode uint32) error {
s.shutdownLock.Lock()
defer s.shutdownLock.Unlock()

Expand All @@ -294,13 +310,26 @@ func (s *Session) Close() error {
}
s.shutdown = true
if s.shutdownErr == nil {
s.shutdownErr = ErrSessionShutdown
s.shutdownErr = shutdownErr
}
close(s.shutdownCh)
s.conn.Close()
s.stopKeepalive()
<-s.recvDoneCh

if sendGoAway {
// wait for write loop to exit
// We need to write the current frame completely before sending a goaway.
// This will wait for at most s.config.ConnectionWriteTimeout
<-s.sendDoneCh
ga := s.goAway(errCode)
if err := s.conn.SetWriteDeadline(time.Now().Add(goAwayWaitTime)); err == nil {
_, _ = s.conn.Write(ga[:]) // there's nothing we can do on error here
}
s.conn.SetWriteDeadline(time.Time{})
}

s.conn.Close()
<-s.sendDoneCh
<-s.recvDoneCh

s.streamLock.Lock()
defer s.streamLock.Unlock()
Expand All @@ -312,17 +341,6 @@ func (s *Session) Close() error {
return nil
}

// exitErr is used to handle an error that is causing the
// session to terminate.
func (s *Session) exitErr(err error) {
s.shutdownLock.Lock()
if s.shutdownErr == nil {
s.shutdownErr = err
}
s.shutdownLock.Unlock()
s.Close()
}

// GoAway can be used to prevent accepting further
// connections. It does not close the underlying conn.
func (s *Session) GoAway() error {
Expand Down Expand Up @@ -451,7 +469,7 @@ func (s *Session) startKeepalive() {

if err != nil {
s.logger.Printf("[ERR] yamux: keepalive failed: %v", err)
s.exitErr(ErrKeepAliveTimeout)
s.close(ErrKeepAliveTimeout, false, 0)
}
})
}
Expand Down Expand Up @@ -516,7 +534,19 @@ func (s *Session) sendMsg(hdr header, body []byte, deadline <-chan struct{}) err
// send is a long running goroutine that sends data
func (s *Session) send() {
if err := s.sendLoop(); err != nil {
s.exitErr(err)
// Prefer the recvLoop error over the sendLoop error. The receive loop might have the error code
// received in a GoAway frame received just before the TCP RST that closed the sendLoop
s.shutdownLock.Lock()
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Add a comment as to why you are holding the shutdownLock around this section.

Copy link
Member Author

@sukunrt sukunrt Nov 19, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Added some comment, can you review once more?

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yes. That looks better, thanks.

if s.shutdownErr == nil {
s.conn.Close()
<-s.recvDoneCh
if _, ok := s.recvErr.(*GoAwayError); ok {
err = s.recvErr
}
s.shutdownErr = err
}
s.shutdownLock.Unlock()
s.close(err, false, 0)
}
}

Expand Down Expand Up @@ -644,7 +674,7 @@ func (s *Session) sendLoop() (err error) {
// recv is a long running goroutine that accepts new data
func (s *Session) recv() {
if err := s.recvLoop(); err != nil {
s.exitErr(err)
s.close(err, false, 0)
}
}

Expand All @@ -666,7 +696,10 @@ func (s *Session) recvLoop() (err error) {
err = fmt.Errorf("panic in yamux receive loop: %s", rerr)
}
}()
defer close(s.recvDoneCh)
defer func() {
s.recvErr = err
close(s.recvDoneCh)
}()
var hdr header
for {
// fmt.Printf("ReadFull from %#v\n", s.reader)
Expand Down Expand Up @@ -782,17 +815,17 @@ func (s *Session) handleGoAway(hdr header) error {
switch code {
case goAwayNormal:
atomic.SwapInt32(&s.remoteGoAway, 1)
// Don't close connection on normal go away. Let the existing streams
// complete gracefully.
return nil
case goAwayProtoErr:
s.logger.Printf("[ERR] yamux: received protocol error go away")
return fmt.Errorf("yamux protocol error")
case goAwayInternalErr:
s.logger.Printf("[ERR] yamux: received internal error go away")
return fmt.Errorf("remote yamux internal error")
default:
s.logger.Printf("[ERR] yamux: received unexpected go away")
return fmt.Errorf("unexpected go away received")
s.logger.Printf("[ERR] yamux: received go away with error code: %d", code)
}
return nil
return &GoAwayError{Remote: true, ErrorCode: code}
}

// incomingStream is used to create a new incoming stream
Expand Down
43 changes: 39 additions & 4 deletions session_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@ package yamux
import (
"bytes"
"context"
"errors"
"fmt"
"io"
"math/rand"
Expand Down Expand Up @@ -39,6 +40,8 @@ type pipeConn struct {
writeDeadline pipeDeadline
writeBlocker chan struct{}
closeCh chan struct{}
closeOnce sync.Once
closeErr error
}

func (p *pipeConn) SetDeadline(t time.Time) error {
Expand All @@ -65,10 +68,12 @@ func (p *pipeConn) Write(b []byte) (int, error) {
}

func (p *pipeConn) Close() error {
p.writeDeadline.set(time.Time{})
err := p.Conn.Close()
close(p.closeCh)
return err
p.closeOnce.Do(func() {
p.writeDeadline.set(time.Time{})
p.closeErr = p.Conn.Close()
close(p.closeCh)
})
return p.closeErr
}

func (p *pipeConn) BlockWrites() {
Expand Down Expand Up @@ -650,6 +655,35 @@ func TestGoAway(t *testing.T) {
default:
t.Fatalf("err: %v", err)
}
time.Sleep(50 * time.Millisecond)
}
t.Fatalf("expected GoAway error")
}

func TestCloseWithError(t *testing.T) {
// This test is noisy.
conf := testConf()
conf.LogOutput = io.Discard

client, server := testClientServerConfig(conf)
defer client.Close()
defer server.Close()

if err := server.CloseWithError(42); err != nil {
t.Fatalf("err: %v", err)
}

for i := 0; i < 100; i++ {
s, err := client.Open(context.Background())
if err == nil {
s.Close()
time.Sleep(50 * time.Millisecond)
continue
}
if !errors.Is(err, &GoAwayError{ErrorCode: 42, Remote: true}) {
t.Fatalf("err: %v", err)
}
return
}
t.Fatalf("expected GoAway error")
}
Expand Down Expand Up @@ -1048,6 +1082,7 @@ func TestKeepAlive_Timeout(t *testing.T) {
// Prevent the client from responding
clientConn := client.conn.(*pipeConn)
clientConn.BlockWrites()
defer clientConn.UnblockWrites()

select {
case err := <-errCh:
Expand Down