Skip to content

Commit e8bb2b4

Browse files
committed
clientconn: Wait for all goroutines on close
Three goroutines could outlive a call to ClientConn.close(). Add mechanisms to cancel them and wait for them to complete when closing a client connection. RELEASE NOTES: - Closing a client connection will cancel all pending goroutines and block until they complete. Signed-off-by: Tom Wieczorek <[email protected]>
1 parent 7472d57 commit e8bb2b4

File tree

6 files changed

+111
-38
lines changed

6 files changed

+111
-38
lines changed

balancer_wrapper.go

Lines changed: 35 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -282,6 +282,10 @@ type acBalancerWrapper struct {
282282
// dropped or updated. This is required as closures can't be compared for
283283
// equality.
284284
healthData *healthData
285+
286+
shutdownMu sync.Mutex
287+
shutdownCh chan struct{}
288+
activeGofuncs sync.WaitGroup
285289
}
286290

287291
// healthData holds data related to health state reporting.
@@ -347,16 +351,45 @@ func (acbw *acBalancerWrapper) String() string {
347351
}
348352

349353
func (acbw *acBalancerWrapper) UpdateAddresses(addrs []resolver.Address) {
350-
acbw.ac.updateAddrs(addrs)
354+
acbw.goFunc(func(shutdown <-chan struct{}) {
355+
acbw.ac.updateAddrs(shutdown, addrs)
356+
})
351357
}
352358

353359
func (acbw *acBalancerWrapper) Connect() {
354-
go acbw.ac.connect()
360+
acbw.goFunc(acbw.ac.connect)
361+
}
362+
363+
func (acbw *acBalancerWrapper) goFunc(fn func(shutdown <-chan struct{})) {
364+
acbw.shutdownMu.Lock()
365+
defer acbw.shutdownMu.Unlock()
366+
367+
shutdown := acbw.shutdownCh
368+
if shutdown == nil {
369+
shutdown = make(chan struct{})
370+
acbw.shutdownCh = shutdown
371+
}
372+
373+
acbw.activeGofuncs.Add(1)
374+
go func() {
375+
defer acbw.activeGofuncs.Done()
376+
fn(shutdown)
377+
}()
355378
}
356379

357380
func (acbw *acBalancerWrapper) Shutdown() {
358381
acbw.closeProducers()
359382
acbw.ccb.cc.removeAddrConn(acbw.ac, errConnDrain)
383+
384+
acbw.shutdownMu.Lock()
385+
defer acbw.shutdownMu.Unlock()
386+
387+
shutdown := acbw.shutdownCh
388+
acbw.shutdownCh = nil
389+
if shutdown != nil {
390+
close(shutdown)
391+
acbw.activeGofuncs.Wait()
392+
}
360393
}
361394

362395
// NewStream begins a streaming RPC on the addrConn. If the addrConn is not

clientconn.go

Lines changed: 48 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -925,25 +925,24 @@ func (cc *ClientConn) incrCallsFailed() {
925925
// connect starts creating a transport.
926926
// It does nothing if the ac is not IDLE.
927927
// TODO(bar) Move this to the addrConn section.
928-
func (ac *addrConn) connect() error {
928+
func (ac *addrConn) connect(abort <-chan struct{}) {
929929
ac.mu.Lock()
930930
if ac.state == connectivity.Shutdown {
931931
if logger.V(2) {
932932
logger.Infof("connect called on shutdown addrConn; ignoring.")
933933
}
934934
ac.mu.Unlock()
935-
return errConnClosing
935+
return
936936
}
937937
if ac.state != connectivity.Idle {
938938
if logger.V(2) {
939939
logger.Infof("connect called on addrConn in non-idle state (%v); ignoring.", ac.state)
940940
}
941941
ac.mu.Unlock()
942-
return nil
942+
return
943943
}
944944

945-
ac.resetTransportAndUnlock()
946-
return nil
945+
ac.resetTransportAndUnlock(abort)
947946
}
948947

949948
// equalAddressIgnoringBalAttributes returns true if a and b are considered equal.
@@ -962,7 +961,7 @@ func equalAddressesIgnoringBalAttributes(a, b []resolver.Address) bool {
962961

963962
// updateAddrs updates ac.addrs with the new addresses list and handles active
964963
// connections or connection attempts.
965-
func (ac *addrConn) updateAddrs(addrs []resolver.Address) {
964+
func (ac *addrConn) updateAddrs(abort <-chan struct{}, addrs []resolver.Address) {
966965
addrs = copyAddresses(addrs)
967966
limit := len(addrs)
968967
if limit > 5 {
@@ -1018,7 +1017,7 @@ func (ac *addrConn) updateAddrs(addrs []resolver.Address) {
10181017

10191018
// Since we were connecting/connected, we should start a new connection
10201019
// attempt.
1021-
go ac.resetTransportAndUnlock()
1020+
ac.resetTransportAndUnlock(abort)
10221021
}
10231022

10241023
// getServerName determines the serverName to be used in the connection
@@ -1249,9 +1248,17 @@ func (ac *addrConn) adjustParams(r transport.GoAwayReason) {
12491248
// resetTransportAndUnlock unconditionally connects the addrConn.
12501249
//
12511250
// ac.mu must be held by the caller, and this function will guarantee it is released.
1252-
func (ac *addrConn) resetTransportAndUnlock() {
1253-
acCtx := ac.ctx
1254-
if acCtx.Err() != nil {
1251+
func (ac *addrConn) resetTransportAndUnlock(abort <-chan struct{}) {
1252+
ctx, cancel := context.WithCancel(ac.ctx)
1253+
go func() {
1254+
select {
1255+
case <-abort:
1256+
cancel()
1257+
case <-ctx.Done():
1258+
}
1259+
}()
1260+
1261+
if ctx.Err() != nil {
12551262
ac.mu.Unlock()
12561263
return
12571264
}
@@ -1279,12 +1286,12 @@ func (ac *addrConn) resetTransportAndUnlock() {
12791286
ac.updateConnectivityState(connectivity.Connecting, nil)
12801287
ac.mu.Unlock()
12811288

1282-
if err := ac.tryAllAddrs(acCtx, addrs, connectDeadline); err != nil {
1289+
if err := ac.tryAllAddrs(ctx, addrs, connectDeadline); err != nil {
12831290
// TODO: #7534 - Move re-resolution requests into the pick_first LB policy
12841291
// to ensure one resolution request per pass instead of per subconn failure.
12851292
ac.cc.resolveNow(resolver.ResolveNowOptions{})
12861293
ac.mu.Lock()
1287-
if acCtx.Err() != nil {
1294+
if ctx.Err() != nil {
12881295
// addrConn was torn down.
12891296
ac.mu.Unlock()
12901297
return
@@ -1305,13 +1312,13 @@ func (ac *addrConn) resetTransportAndUnlock() {
13051312
ac.mu.Unlock()
13061313
case <-b:
13071314
timer.Stop()
1308-
case <-acCtx.Done():
1315+
case <-ctx.Done():
13091316
timer.Stop()
13101317
return
13111318
}
13121319

13131320
ac.mu.Lock()
1314-
if acCtx.Err() == nil {
1321+
if ctx.Err() == nil {
13151322
ac.updateConnectivityState(connectivity.Idle, err)
13161323
}
13171324
ac.mu.Unlock()
@@ -1366,6 +1373,9 @@ func (ac *addrConn) tryAllAddrs(ctx context.Context, addrs []resolver.Address, c
13661373
// new transport.
13671374
func (ac *addrConn) createTransport(ctx context.Context, addr resolver.Address, copts transport.ConnectOptions, connectDeadline time.Time) error {
13681375
addr.ServerName = ac.cc.getServerName(addr)
1376+
1377+
var healthCheckStarted atomic.Bool
1378+
healthCheckDone := make(chan struct{})
13691379
hctx, hcancel := context.WithCancel(ctx)
13701380

13711381
onClose := func(r transport.GoAwayReason) {
@@ -1394,6 +1404,9 @@ func (ac *addrConn) createTransport(ctx context.Context, addr resolver.Address,
13941404
// Always go idle and wait for the LB policy to initiate a new
13951405
// connection attempt.
13961406
ac.updateConnectivityState(connectivity.Idle, nil)
1407+
if healthCheckStarted.Load() {
1408+
<-healthCheckDone
1409+
}
13971410
}
13981411

13991412
connectCtx, cancel := context.WithDeadline(ctx, connectDeadline)
@@ -1406,29 +1419,35 @@ func (ac *addrConn) createTransport(ctx context.Context, addr resolver.Address,
14061419
logger.Infof("Creating new client transport to %q: %v", addr, err)
14071420
}
14081421
// newTr is either nil, or closed.
1409-
hcancel()
14101422
channelz.Warningf(logger, ac.channelz, "grpc: addrConn.createTransport failed to connect to %s. Err: %v", addr, err)
14111423
return err
14121424
}
14131425

1414-
ac.mu.Lock()
1415-
defer ac.mu.Unlock()
1426+
acMu := &ac.mu
1427+
acMu.Lock()
1428+
defer func() {
1429+
if acMu != nil {
1430+
acMu.Unlock()
1431+
}
1432+
}()
14161433
if ctx.Err() != nil {
14171434
// This can happen if the subConn was removed while in `Connecting`
14181435
// state. tearDown() would have set the state to `Shutdown`, but
14191436
// would not have closed the transport since ac.transport would not
14201437
// have been set at that point.
1421-
//
1422-
// We run this in a goroutine because newTr.Close() calls onClose()
1438+
1439+
// We unlock ac.mu because newTr.Close() calls onClose()
14231440
// inline, which requires locking ac.mu.
1424-
//
1441+
acMu.Unlock()
1442+
acMu = nil
1443+
14251444
// The error we pass to Close() is immaterial since there are no open
14261445
// streams at this point, so no trailers with error details will be sent
14271446
// out. We just need to pass a non-nil error.
14281447
//
14291448
// This can also happen when updateAddrs is called during a connection
14301449
// attempt.
1431-
go newTr.Close(transport.ErrConnClosing)
1450+
newTr.Close(transport.ErrConnClosing)
14321451
return nil
14331452
}
14341453
if hctx.Err() != nil {
@@ -1440,7 +1459,7 @@ func (ac *addrConn) createTransport(ctx context.Context, addr resolver.Address,
14401459
}
14411460
ac.curAddr = addr
14421461
ac.transport = newTr
1443-
ac.startHealthCheck(hctx) // Will set state to READY if appropriate.
1462+
healthCheckStarted.Store(ac.startHealthCheck(hctx, healthCheckDone)) // Will set state to READY if appropriate.
14441463
return nil
14451464
}
14461465

@@ -1456,7 +1475,7 @@ func (ac *addrConn) createTransport(ctx context.Context, addr resolver.Address,
14561475
// It sets addrConn to READY if the health checking stream is not started.
14571476
//
14581477
// Caller must hold ac.mu.
1459-
func (ac *addrConn) startHealthCheck(ctx context.Context) {
1478+
func (ac *addrConn) startHealthCheck(ctx context.Context, done chan<- struct{}) bool {
14601479
var healthcheckManagingState bool
14611480
defer func() {
14621481
if !healthcheckManagingState {
@@ -1465,22 +1484,22 @@ func (ac *addrConn) startHealthCheck(ctx context.Context) {
14651484
}()
14661485

14671486
if ac.cc.dopts.disableHealthCheck {
1468-
return
1487+
return false
14691488
}
14701489
healthCheckConfig := ac.cc.healthCheckConfig()
14711490
if healthCheckConfig == nil {
1472-
return
1491+
return false
14731492
}
14741493
if !ac.scopts.HealthCheckEnabled {
1475-
return
1494+
return false
14761495
}
14771496
healthCheckFunc := internal.HealthCheckFunc
14781497
if healthCheckFunc == nil {
14791498
// The health package is not imported to set health check function.
14801499
//
14811500
// TODO: add a link to the health check doc in the error message.
14821501
channelz.Error(logger, ac.channelz, "Health check is requested but health check function is not set.")
1483-
return
1502+
return false
14841503
}
14851504

14861505
healthcheckManagingState = true
@@ -1506,6 +1525,7 @@ func (ac *addrConn) startHealthCheck(ctx context.Context) {
15061525
}
15071526
// Start the health checking stream.
15081527
go func() {
1528+
defer close(done)
15091529
err := healthCheckFunc(ctx, newStream, setConnectivityState, healthCheckConfig.ServiceName)
15101530
if err != nil {
15111531
if status.Code(err) == codes.Unimplemented {
@@ -1515,6 +1535,7 @@ func (ac *addrConn) startHealthCheck(ctx context.Context) {
15151535
}
15161536
}
15171537
}()
1538+
return true
15181539
}
15191540

15201541
func (ac *addrConn) resetConnectBackoff() {

internal/balancer/gracefulswitch/gracefulswitch.go

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -67,6 +67,8 @@ type Balancer struct {
6767
// balancerCurrent before the UpdateSubConnState is called on the
6868
// balancerCurrent.
6969
currentMu sync.Mutex
70+
71+
pendingSwaps sync.WaitGroup
7072
}
7173

7274
// swap swaps out the current lb with the pending lb and updates the ClientConn.
@@ -76,7 +78,9 @@ func (gsb *Balancer) swap() {
7678
cur := gsb.balancerCurrent
7779
gsb.balancerCurrent = gsb.balancerPending
7880
gsb.balancerPending = nil
81+
gsb.pendingSwaps.Add(1)
7982
go func() {
83+
defer gsb.pendingSwaps.Done()
8084
gsb.currentMu.Lock()
8185
defer gsb.currentMu.Unlock()
8286
cur.Close()
@@ -274,6 +278,7 @@ func (gsb *Balancer) Close() {
274278

275279
currentBalancerToClose.Close()
276280
pendingBalancerToClose.Close()
281+
gsb.pendingSwaps.Wait()
277282
}
278283

279284
// balancerWrapper wraps a balancer.Balancer, and overrides some Balancer

internal/testutils/pipe_listener.go

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,7 @@
2020
package testutils
2121

2222
import (
23+
"context"
2324
"errors"
2425
"net"
2526
"time"
@@ -81,11 +82,20 @@ func (p *PipeListener) Addr() net.Addr {
8182
// Dialer dials a connection.
8283
func (p *PipeListener) Dialer() func(string, time.Duration) (net.Conn, error) {
8384
return func(string, time.Duration) (net.Conn, error) {
85+
return p.ContextDialer()(context.Background(), "")
86+
}
87+
}
88+
89+
// ContextDialer dials a connection using a context.
90+
func (p *PipeListener) ContextDialer() func(context.Context, string) (net.Conn, error) {
91+
return func(ctx context.Context, _ string) (net.Conn, error) {
8492
connChan := make(chan net.Conn)
8593
select {
8694
case p.c <- connChan:
8795
case <-p.done:
8896
return nil, errClosed
97+
case <-ctx.Done():
98+
return nil, context.Cause(ctx)
8999
}
90100
conn, ok := <-connChan
91101
if !ok {

test/clientconn_state_transition_test.go

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -166,7 +166,7 @@ func testStateTransitionSingleAddress(t *testing.T, want []connectivity.State, s
166166
client, err := grpc.NewClient("passthrough:///",
167167
grpc.WithTransportCredentials(insecure.NewCredentials()),
168168
grpc.WithDefaultServiceConfig(fmt.Sprintf(`{"loadBalancingConfig": [{"%s":{}}]}`, stateRecordingBalancerName)),
169-
grpc.WithDialer(pl.Dialer()),
169+
grpc.WithContextDialer(pl.ContextDialer()),
170170
grpc.WithConnectParams(grpc.ConnectParams{
171171
Backoff: backoff.Config{},
172172
MinConnectTimeout: 100 * time.Millisecond,

0 commit comments

Comments
 (0)