Skip to content

Commit

Permalink
taskgroup: add an OnError method to Group (#9)
Browse files Browse the repository at this point in the history
This allows an error filter to be set after the group is constructed.
The original New constructor is widely used, so it has been left alone, apart
from being rewritten in terms of the new method.

Apart from the added method, no functional changes.  Update a few tests to
tickle the edges of it a bit.
  • Loading branch information
creachadair authored Oct 6, 2024
1 parent afebe58 commit 68ef45d
Show file tree
Hide file tree
Showing 2 changed files with 36 additions and 17 deletions.
32 changes: 22 additions & 10 deletions taskgroup.go
Original file line number Diff line number Diff line change
Expand Up @@ -22,8 +22,7 @@ type Task func() error
// non-nil error reported by any task (and not otherwise filtered) is returned
// from the Wait method.
type Group struct {
wg sync.WaitGroup // counter for active goroutines
onError ErrorFunc // called each time a task returns non-nil
wg sync.WaitGroup // counter for active goroutines

// active is nonzero when the group is "active", meaning there has been at
// least one call to Go since the group was created or the last Wait.
Expand All @@ -32,8 +31,9 @@ type Group struct {
// path reads active and only acquires μ if it discovers setup is needed.
active atomic.Uint32

μ sync.Mutex // guards err
err error // error returned from Wait
μ sync.Mutex // guards err
err error // error returned from Wait
onError ErrorFunc // called each time a task returns non-nil
}

// activate resets the state of the group and marks it as active. This is
Expand All @@ -47,13 +47,25 @@ func (g *Group) activate() {
}
}

// New constructs a new empty group. If ef != nil, it is called for each error
// reported by a task running in the group. The value returned by ef replaces
// the task's error. If ef == nil, errors are not filtered.
// New constructs a new empty group with the specified error filter.
// See [Group.OnError] for a description of how errors are filtered.
// If ef == nil, no filtering is performed.
func New(ef ErrorFunc) *Group { return new(Group).OnError(ef) }

// OnError sets the error filter for g. If ef == nil, the error filter is
// removed and errors are no longer filtered. Otherwise, each non-nil error
// reported by a task running in g is passed to ef, and the value it returns
// replaces the task's error.
//
// Calls to ef are issued by a single goroutine, so it is safe for ef to
// manipulate local data structures without additional locking.
func New(ef ErrorFunc) *Group { return &Group{onError: ef} }
// Calls to ef are synchronized so that it is safe for ef to manipulate local
// data structures without additional locking. It is safe to call OnError while
// tasks are active in g.
func (g *Group) OnError(ef ErrorFunc) *Group {
g.μ.Lock()
defer g.μ.Unlock()
g.onError = ef
return g
}

// Go runs task in a new goroutine in g.
func (g *Group) Go(task Task) {
Expand Down
21 changes: 14 additions & 7 deletions taskgroup_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,7 @@ func TestBasic(t *testing.T) {
t.Logf("Group value is %d bytes", reflect.TypeOf((*taskgroup.Group)(nil)).Elem().Size())

// Verify that the group works at all.
g := taskgroup.New(nil)
var g taskgroup.Group
g.Go(busyWork(25, nil))
if err := g.Wait(); err != nil {
t.Errorf("Unexpected task error: %v", err)
Expand All @@ -44,7 +44,7 @@ func TestBasic(t *testing.T) {
}

t.Run("Zero", func(t *testing.T) {
var g taskgroup.Group
g := taskgroup.New(nil)
g.Go(busyWork(30, nil))
if err := g.Wait(); err != nil {
t.Errorf("Unexpected task error: %v", err)
Expand All @@ -62,11 +62,18 @@ func TestErrorPropagation(t *testing.T) {
defer leaktest.Check(t)()

var errBogus = errors.New("bogus")

var g taskgroup.Group
g.Go(func() error { return errBogus })
if err := g.Wait(); err != errBogus {
t.Errorf("Wait: got error %v, wanted %v", err, errBogus)
}

g.OnError(func(error) error { return nil }) // discard
g.Go(func() error { return errBogus })
if err := g.Wait(); err != nil {
t.Errorf("Wait: got error %v, wanted nil", err)
}
}

func TestCancellation(t *testing.T) {
Expand Down Expand Up @@ -154,7 +161,7 @@ func TestCapacity(t *testing.T) {
func TestRegression(t *testing.T) {
t.Run("WaitRace", func(t *testing.T) {
ready := make(chan struct{})
g := taskgroup.New(nil)
var g taskgroup.Group
g.Go(func() error {
<-ready
return nil
Expand All @@ -174,7 +181,7 @@ func TestRegression(t *testing.T) {
t.Errorf("Unexpected panic: %v", x)
}
}()
g := taskgroup.New(nil)
var g taskgroup.Group
g.Wait()
})
}
Expand Down Expand Up @@ -208,7 +215,7 @@ func TestSingleTask(t *testing.T) {
return <-release
})

g := taskgroup.New(nil)
var g taskgroup.Group
g.Run(func() {
if err := s.Wait(); err != sentinel {
t.Errorf("Background Wait: got %v, want %v", err, sentinel)
Expand All @@ -231,7 +238,7 @@ func TestWaitMoreTasks(t *testing.T) {
results++
})

g := taskgroup.New(nil)
var g taskgroup.Group

// Test that if a task spawns more tasks on its own recognizance, waiting
// correctly waits for all of them provided we do not let the group go empty
Expand Down Expand Up @@ -283,7 +290,7 @@ func TestCollector(t *testing.T) {
c := taskgroup.Collect(func(v int) { sum += v })

vs := rand.Perm(15)
g := taskgroup.New(nil)
var g taskgroup.Group

for i, v := range vs {
v := v
Expand Down

0 comments on commit 68ef45d

Please sign in to comment.