Skip to content

Commit ff14847

Browse files
authored
Merge pull request #9343 from ellemouton/contextGuard
fn: expand the ContextGuard and add tests
2 parents 6298f76 + f99cabf commit ff14847

File tree

3 files changed

+675
-57
lines changed

3 files changed

+675
-57
lines changed

fn/context_guard.go

+189-57
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@ package fn
33
import (
44
"context"
55
"sync"
6+
"sync/atomic"
67
"time"
78
)
89

@@ -11,103 +12,234 @@ var (
1112
DefaultTimeout = 30 * time.Second
1213
)
1314

14-
// ContextGuard is an embeddable struct that provides a wait group and main quit
15-
// channel that can be used to create guarded contexts.
15+
// ContextGuard is a struct that provides a wait group and main quit channel
16+
// that can be used to create guarded contexts.
1617
type ContextGuard struct {
17-
DefaultTimeout time.Duration
18-
Wg sync.WaitGroup
19-
Quit chan struct{}
18+
mu sync.Mutex
19+
wg sync.WaitGroup
20+
21+
quit chan struct{}
22+
stopped sync.Once
23+
24+
// id is used to generate unique ids for each context that should be
25+
// cancelled when the main quit signal is triggered.
26+
id atomic.Uint32
27+
28+
// cancelFns is a map of cancel functions that can be used to cancel
29+
// any context that should be cancelled when the main quit signal is
30+
// triggered. The key is the id of the context. The mutex must be held
31+
// when accessing this map.
32+
cancelFns map[uint32]context.CancelFunc
2033
}
2134

35+
// NewContextGuard constructs and returns a new instance of ContextGuard.
2236
func NewContextGuard() *ContextGuard {
2337
return &ContextGuard{
24-
DefaultTimeout: DefaultTimeout,
25-
Quit: make(chan struct{}),
38+
quit: make(chan struct{}),
39+
cancelFns: make(map[uint32]context.CancelFunc),
2640
}
2741
}
2842

29-
// WithCtxQuit is used to create a cancellable context that will be cancelled
30-
// if the main quit signal is triggered or after the default timeout occurred.
31-
func (g *ContextGuard) WithCtxQuit() (context.Context, func()) {
32-
return g.WithCtxQuitCustomTimeout(g.DefaultTimeout)
33-
}
43+
// Quit is used to signal the main quit channel, which will cancel all
44+
// non-blocking contexts derived from the ContextGuard.
45+
func (g *ContextGuard) Quit() {
46+
g.stopped.Do(func() {
47+
g.mu.Lock()
48+
defer g.mu.Unlock()
3449

35-
// WithCtxQuitCustomTimeout is used to create a cancellable context that will be
36-
// cancelled if the main quit signal is triggered or after the given timeout
37-
// occurred.
38-
func (g *ContextGuard) WithCtxQuitCustomTimeout(
39-
timeout time.Duration) (context.Context, func()) {
50+
for _, cancel := range g.cancelFns {
51+
cancel()
52+
}
4053

41-
timeoutTimer := time.NewTimer(timeout)
42-
ctx, cancel := context.WithCancel(context.Background())
54+
close(g.quit)
55+
})
56+
}
4357

44-
g.Wg.Add(1)
45-
go func() {
46-
defer timeoutTimer.Stop()
47-
defer cancel()
48-
defer g.Wg.Done()
58+
// Done returns a channel that will be closed when the main quit signal is
59+
// triggered.
60+
func (g *ContextGuard) Done() <-chan struct{} {
61+
return g.quit
62+
}
4963

50-
select {
51-
case <-g.Quit:
64+
// WgAdd is used to add delta to the internal wait group of the ContextGuard.
65+
func (g *ContextGuard) WgAdd(delta int) {
66+
g.wg.Add(delta)
67+
}
5268

53-
case <-timeoutTimer.C:
69+
// WgDone is used to decrement the internal wait group of the ContextGuard.
70+
func (g *ContextGuard) WgDone() {
71+
g.wg.Done()
72+
}
5473

55-
case <-ctx.Done():
56-
}
57-
}()
74+
// WgWait is used to block until the internal wait group of the ContextGuard is
75+
// empty.
76+
func (g *ContextGuard) WgWait() {
77+
g.wg.Wait()
78+
}
5879

59-
return ctx, cancel
80+
// ctxGuardOptions is used to configure the behaviour of the context derived
81+
// via the WithCtx method of the ContextGuard.
82+
type ctxGuardOptions struct {
83+
blocking bool
84+
withTimeout bool
85+
timeout time.Duration
6086
}
6187

62-
// CtxBlocking is used to create a cancellable context that will NOT be
88+
// ContextGuardOption defines the signature of a functional option that can be
89+
// used to configure the behaviour of the context derived via the WithCtx method
90+
// of the ContextGuard.
91+
type ContextGuardOption func(*ctxGuardOptions)
92+
93+
// WithBlockingCG is used to create a cancellable context that will NOT be
6394
// cancelled if the main quit signal is triggered, to block shutdown of
64-
// important tasks. The context will be cancelled if the timeout is reached.
65-
func (g *ContextGuard) CtxBlocking() (context.Context, func()) {
66-
return g.CtxBlockingCustomTimeout(g.DefaultTimeout)
95+
// important tasks.
96+
func WithBlockingCG() ContextGuardOption {
97+
return func(o *ctxGuardOptions) {
98+
o.blocking = true
99+
}
100+
}
101+
102+
// WithCustomTimeoutCG is used to create a cancellable context with a custom
103+
// timeout. Such a context will be cancelled if either the parent context is
104+
// cancelled, the timeout is reached or, if the Blocking option is not provided,
105+
// the main quit signal is triggered.
106+
func WithCustomTimeoutCG(timeout time.Duration) ContextGuardOption {
107+
return func(o *ctxGuardOptions) {
108+
o.withTimeout = true
109+
o.timeout = timeout
110+
}
111+
}
112+
113+
// WithTimeoutCG is used to create a cancellable context with a default timeout.
114+
// Such a context will be cancelled if either the parent context is cancelled,
115+
// the timeout is reached or, if the Blocking option is not provided, the main
116+
// quit signal is triggered.
117+
func WithTimeoutCG() ContextGuardOption {
118+
return func(o *ctxGuardOptions) {
119+
o.withTimeout = true
120+
o.timeout = DefaultTimeout
121+
}
67122
}
68123

69-
// CtxBlockingCustomTimeout is used to create a cancellable context with a
70-
// custom timeout that will NOT be cancelled if the main quit signal is
71-
// triggered, to block shutdown of important tasks. The context will be
72-
// cancelled if the timeout is reached.
73-
func (g *ContextGuard) CtxBlockingCustomTimeout(
74-
timeout time.Duration) (context.Context, func()) {
124+
// Create is used to derive a cancellable context from the parent. Various
125+
// options can be provided to configure the behaviour of the derived context.
126+
func (g *ContextGuard) Create(ctx context.Context,
127+
options ...ContextGuardOption) (context.Context, context.CancelFunc) {
75128

76-
timeoutTimer := time.NewTimer(timeout)
77-
ctx, cancel := context.WithCancel(context.Background())
129+
// Exit early if the parent context has already been cancelled.
130+
select {
131+
case <-ctx.Done():
132+
return ctx, func() {}
133+
default:
134+
}
135+
136+
var opts ctxGuardOptions
137+
for _, o := range options {
138+
o(&opts)
139+
}
78140

79-
g.Wg.Add(1)
141+
g.mu.Lock()
142+
defer g.mu.Unlock()
143+
144+
var cancel context.CancelFunc
145+
if opts.withTimeout {
146+
ctx, cancel = context.WithTimeout(ctx, opts.timeout)
147+
} else {
148+
ctx, cancel = context.WithCancel(ctx)
149+
}
150+
151+
if opts.blocking {
152+
g.ctxBlocking(ctx, cancel)
153+
154+
return ctx, cancel
155+
}
156+
157+
// If the call is non-blocking, then we can exit early if the main quit
158+
// signal has been triggered.
159+
select {
160+
case <-g.quit:
161+
cancel()
162+
163+
return ctx, cancel
164+
default:
165+
}
166+
167+
cancel = g.ctxQuitUnsafe(ctx, cancel)
168+
169+
return ctx, cancel
170+
}
171+
172+
// ctxQuitUnsafe spins off a goroutine that will block until the passed context
173+
// is cancelled or until the quit channel has been signaled after which it will
174+
// call the passed cancel function and decrement the wait group.
175+
//
176+
// NOTE: the caller must hold the ContextGuard's mutex before calling this
177+
// function.
178+
func (g *ContextGuard) ctxQuitUnsafe(ctx context.Context,
179+
cancel context.CancelFunc) context.CancelFunc {
180+
181+
cancel = g.addCancelFnUnsafe(cancel)
182+
183+
g.wg.Add(1)
80184
go func() {
81-
defer timeoutTimer.Stop()
82185
defer cancel()
83-
defer g.Wg.Done()
186+
defer g.wg.Done()
84187

85188
select {
86-
case <-timeoutTimer.C:
189+
case <-g.quit:
87190

88191
case <-ctx.Done():
89192
}
90193
}()
91194

92-
return ctx, cancel
195+
return cancel
93196
}
94197

95-
// WithCtxQuitNoTimeout is used to create a cancellable context that will be
96-
// cancelled if the main quit signal is triggered.
97-
func (g *ContextGuard) WithCtxQuitNoTimeout() (context.Context, func()) {
98-
ctx, cancel := context.WithCancel(context.Background())
198+
// ctxBlocking spins off a goroutine that will block until the passed context
199+
// is cancelled after which it will call the passed cancel function and
200+
// decrement the wait group.
201+
func (g *ContextGuard) ctxBlocking(ctx context.Context,
202+
cancel context.CancelFunc) {
99203

100-
g.Wg.Add(1)
204+
g.wg.Add(1)
101205
go func() {
102206
defer cancel()
103-
defer g.Wg.Done()
207+
defer g.wg.Done()
104208

105209
select {
106-
case <-g.Quit:
107-
108210
case <-ctx.Done():
109211
}
110212
}()
213+
}
111214

112-
return ctx, cancel
215+
// addCancelFnUnsafe adds a context cancel function to the manager and returns a
216+
// call-back which can safely be used to cancel the context.
217+
//
218+
// NOTE: the caller must hold the ContextGuard's mutex before calling this
219+
// function.
220+
func (g *ContextGuard) addCancelFnUnsafe(
221+
cancel context.CancelFunc) context.CancelFunc {
222+
223+
id := g.id.Add(1)
224+
g.cancelFns[id] = cancel
225+
226+
return g.cancelCtxFn(id)
227+
}
228+
229+
// cancelCtxFn returns a call-back that can be used to cancel the context
230+
// associated with the passed id.
231+
func (g *ContextGuard) cancelCtxFn(id uint32) context.CancelFunc {
232+
return func() {
233+
g.mu.Lock()
234+
235+
fn, ok := g.cancelFns[id]
236+
if !ok {
237+
g.mu.Unlock()
238+
return
239+
}
240+
delete(g.cancelFns, id)
241+
g.mu.Unlock()
242+
243+
fn()
244+
}
113245
}

0 commit comments

Comments
 (0)