Skip to content

Commit

Permalink
don't block on connection close
Browse files Browse the repository at this point in the history
  • Loading branch information
sukunrt committed Nov 18, 2024
1 parent 43cd707 commit ede18a5
Show file tree
Hide file tree
Showing 4 changed files with 22 additions and 17 deletions.
9 changes: 6 additions & 3 deletions const.go
Original file line number Diff line number Diff line change
Expand Up @@ -45,10 +45,13 @@ func (e *GoAwayError) Temporary() bool {

func (e *GoAwayError) Is(target error) bool {
// to maintain compatibility with errors returned by previous versions
if e.Remote && target == ErrRemoteGoAway {
if e.Remote && target == ErrRemoteGoAwayNormal {
return true

Check warning on line 49 in const.go

View check run for this annotation

Codecov / codecov/patch

const.go#L49

Added line #L49 was not covered by tests
} else if !e.Remote && target == ErrSessionShutdown {
return true

Check warning on line 51 in const.go

View check run for this annotation

Codecov / codecov/patch

const.go#L51

Added line #L51 was not covered by tests
} else if target == ErrStreamReset {
// A GoAway on a connection also resets all the streams.
return true
}

Check warning on line 55 in const.go

View check run for this annotation

Codecov / codecov/patch

const.go#L53-L55

Added lines #L53 - L55 were not covered by tests

if err, ok := target.(*GoAwayError); ok {
Expand Down Expand Up @@ -111,8 +114,8 @@ var (
// ErrUnexpectedFlag is set when we get an unexpected flag
ErrUnexpectedFlag = &Error{msg: "unexpected flag"}

// ErrRemoteGoAway is used when we get a go away from the other side
ErrRemoteGoAway = &GoAwayError{Remote: true, ErrorCode: goAwayNormal}
// ErrRemoteGoAwayNormal is used when we get a go away from the other side
ErrRemoteGoAwayNormal = &GoAwayError{Remote: true, ErrorCode: goAwayNormal}

// ErrStreamReset is sent if a stream is reset. This can happen
// if the backlog is exceeded, or if there was a remote GoAway.
Expand Down
23 changes: 12 additions & 11 deletions session.go
Original file line number Diff line number Diff line change
Expand Up @@ -46,9 +46,9 @@ var nullMemoryManager = &nullMemoryManagerImpl{}
type Session struct {
rtt int64 // to be accessed atomically, in nanoseconds

// remoteGoAway indicates the remote side does
// remoteGoAwayNormal indicates the remote side does
// not want futher connections. Must be first for alignment.
remoteGoAway int32
remoteGoAwayNormal int32

// localGoAway indicates that we should stop
// accepting futher connections. Must be first for alignment.
Expand Down Expand Up @@ -205,8 +205,8 @@ func (s *Session) OpenStream(ctx context.Context) (*Stream, error) {
if s.IsClosed() {
return nil, s.shutdownErr
}
if atomic.LoadInt32(&s.remoteGoAway) == 1 {
return nil, ErrRemoteGoAway
if atomic.LoadInt32(&s.remoteGoAwayNormal) == 1 {
return nil, ErrRemoteGoAwayNormal
}

// Block if we have too many inflight SYNs
Expand Down Expand Up @@ -285,15 +285,15 @@ func (s *Session) AcceptStream() (*Stream, error) {
}
}

// Close is used to close the session and all streams.
// Attempts to send a GoAway before closing the connection. The GoAway may not actually be sent depending on the
// semantics of the underlying net.Conn. For TCP connections, it may be dropped depending on LINGER value or
// if there's unread data in the kernel receive buffer.
// Close is used to close the session and all streams. It doesn't send a GoAway before
// closing the connection.
func (s *Session) Close() error {
return s.close(ErrSessionShutdown, false, goAwayNormal)
}

// CloseWithError is used to close the session and all streams after sending a GoAway message with errCode.
// Blocks for ConnectionWriteTimeout to write the GoAway message.
//
// The GoAway may not actually be sent depending on the semantics of the underlying net.Conn.
// For TCP connections, it may be dropped depending on LINGER value or if there's unread data in the kernel
// receive buffer.
Expand All @@ -315,7 +315,8 @@ func (s *Session) close(shutdownErr error, sendGoAway bool, errCode uint32) erro
close(s.shutdownCh)
s.stopKeepalive()

if sendGoAway {
// Only send GoAway if we have an error code.
if sendGoAway && errCode != goAwayNormal {
// wait for write loop to exit
// We need to write the current frame completely before sending a goaway.
// This will wait for at most s.config.ConnectionWriteTimeout
Expand All @@ -334,7 +335,7 @@ func (s *Session) close(shutdownErr error, sendGoAway bool, errCode uint32) erro
s.streamLock.Lock()
defer s.streamLock.Unlock()
for id, stream := range s.streams {
stream.forceClose(fmt.Errorf("%w: connection closed: %w", ErrStreamReset, s.shutdownErr))
stream.forceClose(s.shutdownErr)
delete(s.streams, id)
stream.memorySpan.Done()
}
Expand Down Expand Up @@ -814,7 +815,7 @@ func (s *Session) handleGoAway(hdr header) error {
code := hdr.Length()
switch code {
case goAwayNormal:
atomic.SwapInt32(&s.remoteGoAway, 1)
atomic.SwapInt32(&s.remoteGoAwayNormal, 1)
// Don't close connection on normal go away. Let the existing streams
// complete gracefully.
return nil
Expand Down
2 changes: 1 addition & 1 deletion session_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -651,7 +651,7 @@ func TestGoAway(t *testing.T) {
switch err {
case nil:
s.Close()
case ErrRemoteGoAway:
case ErrRemoteGoAwayNormal:
return
default:
t.Fatalf("err: %v", err)
Expand Down
5 changes: 3 additions & 2 deletions stream.go
Original file line number Diff line number Diff line change
Expand Up @@ -310,7 +310,7 @@ func (s *Stream) CloseWrite() error {
return nil
case halfReset:
s.stateLock.Unlock()
return ErrStreamReset
return s.writeErr
default:
panic("invalid state")
}
Expand All @@ -331,7 +331,8 @@ func (s *Stream) CloseWrite() error {
return err
}

// CloseRead is used to close the stream for writing.
// CloseRead is used to close the stream for reading.
// Note: Remote is not notified.
func (s *Stream) CloseRead() error {
cleanup := false
s.stateLock.Lock()
Expand Down

0 comments on commit ede18a5

Please sign in to comment.