Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

taskgroup: add an OnError method to Group #9

Merged
merged 1 commit into from
Oct 6, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
32 changes: 22 additions & 10 deletions taskgroup.go
Original file line number Diff line number Diff line change
Expand Up @@ -22,8 +22,7 @@ type Task func() error
// non-nil error reported by any task (and not otherwise filtered) is returned
// from the Wait method.
type Group struct {
wg sync.WaitGroup // counter for active goroutines
onError ErrorFunc // called each time a task returns non-nil
wg sync.WaitGroup // counter for active goroutines

// active is nonzero when the group is "active", meaning there has been at
// least one call to Go since the group was created or the last Wait.
Expand All @@ -32,8 +31,9 @@ type Group struct {
// path reads active and only acquires μ if it discovers setup is needed.
active atomic.Uint32

μ sync.Mutex // guards err
err error // error returned from Wait
μ sync.Mutex // guards err
err error // error returned from Wait
onError ErrorFunc // called each time a task returns non-nil
}

// activate resets the state of the group and marks it as active. This is
Expand All @@ -47,13 +47,25 @@ func (g *Group) activate() {
}
}

// New constructs a new empty group. If ef != nil, it is called for each error
// reported by a task running in the group. The value returned by ef replaces
// the task's error. If ef == nil, errors are not filtered.
// New constructs a new empty group with the specified error filter.
// See [Group.OnError] for a description of how errors are filtered.
// If ef == nil, no filtering is performed.
func New(ef ErrorFunc) *Group { return new(Group).OnError(ef) }

// OnError sets the error filter for g. If ef == nil, the error filter is
// removed and errors are no longer filtered. Otherwise, each non-nil error
// reported by a task running in g is passed to ef, and the value it returns
// replaces the task's error.
//
// Calls to ef are issued by a single goroutine, so it is safe for ef to
// manipulate local data structures without additional locking.
func New(ef ErrorFunc) *Group { return &Group{onError: ef} }
// Calls to ef are synchronized so that it is safe for ef to manipulate local
// data structures without additional locking. It is safe to call OnError while
// tasks are active in g.
func (g *Group) OnError(ef ErrorFunc) *Group {
g.μ.Lock()
defer g.μ.Unlock()
g.onError = ef
return g
}

// Go runs task in a new goroutine in g.
func (g *Group) Go(task Task) {
Expand Down
21 changes: 14 additions & 7 deletions taskgroup_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,7 @@ func TestBasic(t *testing.T) {
t.Logf("Group value is %d bytes", reflect.TypeOf((*taskgroup.Group)(nil)).Elem().Size())

// Verify that the group works at all.
g := taskgroup.New(nil)
var g taskgroup.Group
g.Go(busyWork(25, nil))
if err := g.Wait(); err != nil {
t.Errorf("Unexpected task error: %v", err)
Expand All @@ -44,7 +44,7 @@ func TestBasic(t *testing.T) {
}

t.Run("Zero", func(t *testing.T) {
var g taskgroup.Group
g := taskgroup.New(nil)
g.Go(busyWork(30, nil))
if err := g.Wait(); err != nil {
t.Errorf("Unexpected task error: %v", err)
Expand All @@ -62,11 +62,18 @@ func TestErrorPropagation(t *testing.T) {
defer leaktest.Check(t)()

var errBogus = errors.New("bogus")

var g taskgroup.Group
g.Go(func() error { return errBogus })
if err := g.Wait(); err != errBogus {
t.Errorf("Wait: got error %v, wanted %v", err, errBogus)
}

g.OnError(func(error) error { return nil }) // discard
g.Go(func() error { return errBogus })
if err := g.Wait(); err != nil {
t.Errorf("Wait: got error %v, wanted nil", err)
}
}

func TestCancellation(t *testing.T) {
Expand Down Expand Up @@ -154,7 +161,7 @@ func TestCapacity(t *testing.T) {
func TestRegression(t *testing.T) {
t.Run("WaitRace", func(t *testing.T) {
ready := make(chan struct{})
g := taskgroup.New(nil)
var g taskgroup.Group
g.Go(func() error {
<-ready
return nil
Expand All @@ -174,7 +181,7 @@ func TestRegression(t *testing.T) {
t.Errorf("Unexpected panic: %v", x)
}
}()
g := taskgroup.New(nil)
var g taskgroup.Group
g.Wait()
})
}
Expand Down Expand Up @@ -208,7 +215,7 @@ func TestSingleTask(t *testing.T) {
return <-release
})

g := taskgroup.New(nil)
var g taskgroup.Group
g.Run(func() {
if err := s.Wait(); err != sentinel {
t.Errorf("Background Wait: got %v, want %v", err, sentinel)
Expand All @@ -231,7 +238,7 @@ func TestWaitMoreTasks(t *testing.T) {
results++
})

g := taskgroup.New(nil)
var g taskgroup.Group

// Test that if a task spawns more tasks on its own recognizance, waiting
// correctly waits for all of them provided we do not let the group go empty
Expand Down Expand Up @@ -283,7 +290,7 @@ func TestCollector(t *testing.T) {
c := taskgroup.Collect(func(v int) { sum += v })

vs := rand.Perm(15)
g := taskgroup.New(nil)
var g taskgroup.Group

for i, v := range vs {
v := v
Expand Down