Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
37 changes: 34 additions & 3 deletions association.go
Original file line number Diff line number Diff line change
Expand Up @@ -200,6 +200,8 @@ type Association struct {

willSendAbort bool
willSendAbortCause errorCause
abortSentOnce sync.Once
abortSentCh chan struct{}

// Reconfig
myNextRSN uint32
Expand Down Expand Up @@ -452,6 +454,7 @@ func createAssociation(config Config) *Association {
name: config.Name,
blockWrite: config.BlockWrite,
writeNotify: make(chan struct{}, 1),
abortSentCh: make(chan struct{}),
}

// adaptive burst mitigation defaults
Expand Down Expand Up @@ -659,16 +662,31 @@ func (a *Association) Abort(reason string) {

a.lock.Unlock()

flushTimeout := 200 * time.Millisecond

// short bound for abort flush.
_ = a.netConn.SetWriteDeadline(time.Now().Add(200 * time.Millisecond))
_ = a.netConn.SetWriteDeadline(time.Now().Add(flushTimeout))
a.awakeWriteLoop()

// unblock readLoop even if the underlying TCP connection is half-open.
// Give writeLoop a chance to write the ABORT before we force readLoop to exit
// (readLoop exit closes closeWriteLoopCh and can race the ABORT send).
select {
case <-a.abortSentCh:
case <-time.After(flushTimeout):
}

// unblock readLoop even if the underlying connection is half-open.
// We want Abort to return promptly during shutdown.
_ = a.netConn.SetReadDeadline(time.Now())

// Wait for readLoop to end
<-a.readLoopCloseCh

// Ensure ABORT write was at least attempted before returning (bounded).
select {
case <-a.abortSentCh:
case <-time.After(flushTimeout):
}
}

func (a *Association) closeAllTimers() {
Expand Down Expand Up @@ -734,7 +752,7 @@ func (a *Association) readLoop() {
a.log.Debugf("[%s] readLoop exited %s", a.name, closeErr)
}

func (a *Association) writeLoop() {
func (a *Association) writeLoop() { // nolint:cyclop
a.log.Debugf("[%s] writeLoop entered", a.name)
defer a.log.Debugf("[%s] writeLoop exited", a.name)

Expand All @@ -743,7 +761,11 @@ loop:
rawPackets, ok := a.gatherOutbound()

for _, raw := range rawPackets {
isAbortPacket := len(raw) > int(commonHeaderSize) && raw[commonHeaderSize] == byte(ctAbort)
_, err := a.netConn.Write(raw)
if isAbortPacket {
a.abortSentOnce.Do(func() { close(a.abortSentCh) })
}
Comment on lines +764 to +768
Copy link

Copilot AI Jan 22, 2026

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

If gatherAbortPacket returns an error (line 1204), the abortSentCh channel is never closed because no packet is written. This causes Abort() to always wait for the full 200ms timeout on both wait blocks (lines 673-676 and 686-689), resulting in a 400ms delay even though the ABORT cannot be sent.

Consider closing abortSentCh immediately when gatherAbortPacket fails, or handle this error case in the writeLoop to signal that the ABORT attempt completed (even if it failed). This ensures Abort() doesn't unnecessarily block when the ABORT cannot be marshaled.

Copilot uses AI. Check for mistakes.
if err != nil {
if !errors.Is(err, io.EOF) {
a.log.Warnf("[%s] failed to write packets on netConn: %v", a.name, err)
Expand All @@ -767,6 +789,15 @@ loop:
select {
case <-a.awakeWriteLoopCh:
case <-a.closeWriteLoopCh:
a.lock.Lock()
abortPending := a.willSendAbort
a.lock.Unlock()
if abortPending {
// If an ABORT is pending, prefer sending it even if readLoop has
// already ended and closed closeWriteLoopCh.
continue
}

break loop
}
}
Expand Down
151 changes: 151 additions & 0 deletions association_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -4700,3 +4700,154 @@ func TestTLR_GetDataPacketsToRetransmit_RespectsBurstBudget_LaterRTT(t *testing.
assert.Equal(t, 2, nChunks)
assert.True(t, consumed)
}

// dummyAddr is a placeholder address for test connections that have no real
// network endpoint. Both of its methods return the constant string "dummy".
type dummyAddr struct{}

// Network returns the fixed network name "dummy".
func (dummyAddr) Network() string {
	return "dummy"
}

// String returns the fixed address text "dummy".
func (dummyAddr) String() string {
	return "dummy"
}

// recordConn is an in-memory connection used by tests. It stores a private
// copy of every payload passed to Write, and its Read blocks until the
// connection has been closed. Deadlines are accepted but ignored.
type recordConn struct {
	mu     sync.Mutex    // guards writes
	writes [][]byte      // copies of every written payload, in write order
	done   chan struct{} // closed by Close; unblocks Read
}

// Read blocks until Close has been called and then reports io.EOF.
// The supplied buffer is never filled.
func (c *recordConn) Read(_ []byte) (int, error) {
	<-c.done

	return 0, io.EOF
}

// Write appends a copy of p to the recorded writes and reports the whole
// payload as written. Copying keeps the record stable if the caller reuses p.
func (c *recordConn) Write(p []byte) (int, error) {
	snapshot := append(make([]byte, 0, len(p)), p...)

	c.mu.Lock()
	defer c.mu.Unlock()
	c.writes = append(c.writes, snapshot)

	return len(p), nil
}

// Close closes the done channel, releasing any blocked Read. Calling it again
// after the channel is already closed is a no-op. It always returns nil.
func (c *recordConn) Close() error {
	select {
	case <-c.done:
		// Already closed; nothing left to do.
		return nil
	default:
	}

	close(c.done)

	return nil
}

// LocalAddr returns a placeholder address.
func (c *recordConn) LocalAddr() net.Addr {
	return dummyAddr{}
}

// RemoteAddr returns a placeholder address.
func (c *recordConn) RemoteAddr() net.Addr {
	return dummyAddr{}
}

// SetDeadline ignores the deadline and returns nil.
func (c *recordConn) SetDeadline(_ time.Time) error {
	return nil
}

// SetReadDeadline ignores the deadline and returns nil.
func (c *recordConn) SetReadDeadline(_ time.Time) error {
	return nil
}

// SetWriteDeadline ignores the deadline and returns nil.
func (c *recordConn) SetWriteDeadline(_ time.Time) error {
	return nil
}

// firstWrite returns the earliest recorded payload. The boolean is false when
// nothing has been written yet.
func (c *recordConn) firstWrite() ([]byte, bool) {
	c.mu.Lock()
	defer c.mu.Unlock()

	if len(c.writes) > 0 {
		return c.writes[0], true
	}

	return nil, false
}

// TestAbortStillSendsWhenWriteLoopClosing marks an ABORT as pending, closes
// closeWriteLoopCh, and then requires that the first packet written to the
// connection is an ABORT chunk.
func TestAbortStillSendsWhenWriteLoopClosing(t *testing.T) {
	conn := &recordConn{done: make(chan struct{})}
	defer conn.Close() // nolint:errcheck

	cfg := Config{
		NetConn:            conn,
		LoggerFactory:      logging.NewDefaultLoggerFactory(),
		EnableZeroChecksum: false,
	}
	assoc := createAssociation(cfg)
	assoc.init(false)

	// Reproduce the problematic timing: an ABORT is pending at the moment
	// closeWriteLoopCh gets closed (e.g. because readLoop exited) while
	// writeLoop is waiting in its select.
	assoc.lock.Lock()
	assoc.willSendAbortCause = &errorCauseUserInitiatedAbort{upperLayerAbortReason: []byte("x")}
	assoc.willSendAbort = true
	assoc.lock.Unlock()

	assoc.closeWriteLoopOnce.Do(func() { close(assoc.closeWriteLoopCh) })

	// The chunk type is the byte that immediately follows the common header.
	firstWriteIsAbort := func() bool {
		raw, ok := conn.firstWrite()

		return ok && len(raw) > int(commonHeaderSize) && raw[commonHeaderSize] == uint8(ctAbort)
	}
	require.Eventually(t, firstWriteIsAbort, 1*time.Second, 5*time.Millisecond)

	require.NoError(t, assoc.close())
}

// abortOrderingConn is a test connection for observing the order of an ABORT
// write relative to the read deadline. Read stays blocked until
// SetReadDeadline is called, and an ABORT write is artificially slowed down
// before it is reported on wroteAbortCh.
type abortOrderingConn struct {
	readDeadlineOnce sync.Once
	readDeadlineCh   chan struct{} // closed on the first SetReadDeadline call

	wroteAbortOnce sync.Once
	wroteAbortCh   chan struct{} // closed after the first ABORT write finishes
}

// newAbortOrderingConn returns a connection with both signal channels open.
func newAbortOrderingConn() *abortOrderingConn {
	conn := new(abortOrderingConn)
	conn.readDeadlineCh = make(chan struct{})
	conn.wroteAbortCh = make(chan struct{})

	return conn
}

// Read blocks until SetReadDeadline has been called, then fails with
// net.ErrClosed. The supplied buffer is never filled.
func (c *abortOrderingConn) Read(_ []byte) (int, error) {
	<-c.readDeadlineCh

	return 0, net.ErrClosed
}

// Write reports every payload as fully written. A payload whose first chunk
// type (the byte right after the common header) is ABORT is delayed by 100ms
// and then signalled on wroteAbortCh.
func (c *abortOrderingConn) Write(b []byte) (int, error) {
	if len(b) <= int(commonHeaderSize) || b[commonHeaderSize] != byte(ctAbort) {
		return len(b), nil
	}

	// Make the ABORT write slow so a caller that does not wait for it is caught.
	time.Sleep(100 * time.Millisecond)
	c.wroteAbortOnce.Do(func() { close(c.wroteAbortCh) })

	return len(b), nil
}

// Close does nothing and returns nil.
func (c *abortOrderingConn) Close() error {
	return nil
}

// LocalAddr returns an empty IP address.
func (c *abortOrderingConn) LocalAddr() net.Addr {
	return &net.IPAddr{}
}

// RemoteAddr returns an empty IP address.
func (c *abortOrderingConn) RemoteAddr() net.Addr {
	return &net.IPAddr{}
}

// SetDeadline ignores the deadline and returns nil.
func (c *abortOrderingConn) SetDeadline(_ time.Time) error {
	return nil
}

// SetReadDeadline releases any blocked Read by closing readDeadlineCh on the
// first call. The deadline value itself is ignored.
func (c *abortOrderingConn) SetReadDeadline(_ time.Time) error {
	c.readDeadlineOnce.Do(func() { close(c.readDeadlineCh) })

	return nil
}

// SetWriteDeadline ignores the deadline and returns nil.
func (c *abortOrderingConn) SetWriteDeadline(_ time.Time) error {
	return nil
}

// TestAbort_WaitsForAbortWriteAttempt requires that, by the time Abort
// returns, the connection has already signalled an ABORT write on
// wroteAbortCh.
func TestAbort_WaitsForAbortWriteAttempt(t *testing.T) {
	conn := newAbortOrderingConn()

	cfg := Config{
		NetConn:        conn,
		LoggerFactory:  logging.NewDefaultLoggerFactory(),
		MaxMessageSize: 1200,
	}
	assoc := createAssociation(cfg)
	assoc.init(false)

	assoc.Abort("test")

	// Non-blocking check: the signal must already be there, not merely on its way.
	select {
	case <-conn.wroteAbortCh:
		return
	default:
	}

	require.Fail(t, "Abort returned before ABORT write attempt")
}
Loading