From 14915ed7f9b15c6b265ee71a5f822962e46ded30 Mon Sep 17 00:00:00 2001 From: "M. J. Fromberger" Date: Sat, 5 Oct 2024 15:34:51 -0700 Subject: [PATCH] taskgroup: add an OnError method to Group This allows an error filter to be set after the group is constructed. The original New constructor is widely used, so it has been left alone, apart from being rewritten in terms of the new method. Apart from the added method, no functional changes. Update a few tests to tickle the edges of it a bit. --- taskgroup.go | 32 ++++++++++++++++++++++---------- taskgroup_test.go | 21 ++++++++++++++------- 2 files changed, 36 insertions(+), 17 deletions(-) diff --git a/taskgroup.go b/taskgroup.go index 19c0afc..8f144e4 100644 --- a/taskgroup.go +++ b/taskgroup.go @@ -22,8 +22,7 @@ type Task func() error // non-nil error reported by any task (and not otherwise filtered) is returned // from the Wait method. type Group struct { - wg sync.WaitGroup // counter for active goroutines - onError ErrorFunc // called each time a task returns non-nil + wg sync.WaitGroup // counter for active goroutines // active is nonzero when the group is "active", meaning there has been at // least one call to Go since the group was created or the last Wait. @@ -32,8 +31,9 @@ type Group struct { // path reads active and only acquires μ if it discovers setup is needed. active atomic.Uint32 - μ sync.Mutex // guards err - err error // error returned from Wait + μ sync.Mutex // guards err + err error // error returned from Wait + onError ErrorFunc // called each time a task returns non-nil } // activate resets the state of the group and marks it as active. This is @@ -47,13 +47,25 @@ func (g *Group) activate() { } } -// New constructs a new empty group. If ef != nil, it is called for each error -// reported by a task running in the group. The value returned by ef replaces -// the task's error. If ef == nil, errors are not filtered. +// New constructs a new empty group with the specified error filter. +// See [Group.OnError] for a description of how errors are filtered. +// If ef == nil, no filtering is performed. +func New(ef ErrorFunc) *Group { return new(Group).OnError(ef) } + +// OnError sets the error filter for g. If ef == nil, the error filter is +// removed and errors are no longer filtered. Otherwise, each non-nil error +// reported by a task running in g is passed to ef, and the value it returns +// replaces the task's error. // -// Calls to ef are issued by a single goroutine, so it is safe for ef to -// manipulate local data structures without additional locking. -func New(ef ErrorFunc) *Group { return &Group{onError: ef} } +// Calls to ef are synchronized so that it is safe for ef to manipulate local +// data structures without additional locking. It is safe to call OnError while +// tasks are active in g. +func (g *Group) OnError(ef ErrorFunc) *Group { + g.μ.Lock() + defer g.μ.Unlock() + g.onError = ef + return g +} // Go runs task in a new goroutine in g. func (g *Group) Go(task Task) { diff --git a/taskgroup_test.go b/taskgroup_test.go index 91df2af..b35eaa4 100644 --- a/taskgroup_test.go +++ b/taskgroup_test.go @@ -30,7 +30,7 @@ func TestBasic(t *testing.T) { t.Logf("Group value is %d bytes", reflect.TypeOf((*taskgroup.Group)(nil)).Elem().Size()) // Verify that the group works at all. - g := taskgroup.New(nil) + var g taskgroup.Group g.Go(busyWork(25, nil)) if err := g.Wait(); err != nil { t.Errorf("Unexpected task error: %v", err) @@ -44,7 +44,7 @@ func TestBasic(t *testing.T) { } t.Run("Zero", func(t *testing.T) { - var g taskgroup.Group + g := taskgroup.New(nil) g.Go(busyWork(30, nil)) if err := g.Wait(); err != nil { t.Errorf("Unexpected task error: %v", err) @@ -62,11 +62,18 @@ func TestErrorPropagation(t *testing.T) { defer leaktest.Check(t)() var errBogus = errors.New("bogus") + var g taskgroup.Group g.Go(func() error { return errBogus }) if err := g.Wait(); err != errBogus { t.Errorf("Wait: got error %v, wanted %v", err, errBogus) } + + g.OnError(func(error) error { return nil }) // discard + g.Go(func() error { return errBogus }) + if err := g.Wait(); err != nil { + t.Errorf("Wait: got error %v, wanted nil", err) + } } func TestCancellation(t *testing.T) { @@ -154,7 +161,7 @@ func TestCapacity(t *testing.T) { func TestRegression(t *testing.T) { t.Run("WaitRace", func(t *testing.T) { ready := make(chan struct{}) - g := taskgroup.New(nil) + var g taskgroup.Group g.Go(func() error { <-ready return nil @@ -174,7 +181,7 @@ func TestRegression(t *testing.T) { t.Errorf("Unexpected panic: %v", x) } }() - g := taskgroup.New(nil) + var g taskgroup.Group g.Wait() }) } @@ -208,7 +215,7 @@ func TestSingleTask(t *testing.T) { return <-release }) - g := taskgroup.New(nil) + var g taskgroup.Group g.Run(func() { if err := s.Wait(); err != sentinel { t.Errorf("Background Wait: got %v, want %v", err, sentinel) @@ -231,7 +238,7 @@ func TestWaitMoreTasks(t *testing.T) { results++ }) - g := taskgroup.New(nil) + var g taskgroup.Group // Test that if a task spawns more tasks on its own recognizance, waiting // correctly waits for all of them provided we do not let the group go empty @@ -283,7 +290,7 @@ func TestCollector(t *testing.T) { c := taskgroup.Collect(func(v int) { sum += v }) vs := rand.Perm(15) - g := taskgroup.New(nil) + var g taskgroup.Group for i, v := range vs { v := v