diff --git a/pool/result_context_pool.go b/pool/result_context_pool.go index 9cac635..8560c6a 100644 --- a/pool/result_context_pool.go +++ b/pool/result_context_pool.go @@ -20,11 +20,10 @@ type ResultContextPool[T any] struct { // Go submits a task to the pool. If all goroutines in the pool // are busy, a call to Go() will block until the task can be started. func (p *ResultContextPool[T]) Go(f func(context.Context) (T, error)) { + idx := p.agg.nextIndex() p.contextPool.Go(func(ctx context.Context) error { res, err := f(ctx) - if err == nil || p.collectErrored { - p.agg.add(res) - } + p.agg.save(idx, res, err != nil) return err }) } @@ -33,7 +32,7 @@ func (p *ResultContextPool[T]) Go(f func(context.Context) (T, error)) { // returns an error if any of the tasks errored. func (p *ResultContextPool[T]) Wait() ([]T, error) { err := p.contextPool.Wait() - return p.agg.results, err + return p.agg.collect(p.collectErrored), err } // WithCollectErrored configures the pool to still collect the result of a task diff --git a/pool/result_context_pool_test.go b/pool/result_context_pool_test.go index d98e0c8..ceae5e8 100644 --- a/pool/result_context_pool_test.go +++ b/pool/result_context_pool_test.go @@ -4,7 +4,6 @@ import ( "context" "errors" "fmt" - "sort" "strconv" "sync/atomic" "testing" @@ -223,7 +222,6 @@ func TestResultContextPool(t *testing.T) { }) } res, err := g.Wait() - sort.Ints(res) require.Equal(t, expected, res) require.NoError(t, err) require.Equal(t, int64(0), currentConcurrent.Load()) diff --git a/pool/result_error_pool.go b/pool/result_error_pool.go index 4caaadc..5a0bfb9 100644 --- a/pool/result_error_pool.go +++ b/pool/result_error_pool.go @@ -8,9 +8,8 @@ import ( // type and an error. Tasks are executed in the pool with Go(), then the // results of the tasks are returned by Wait(). // -// The order of the results is not guaranteed to be the same as the order the -// tasks were submitted. 
If your use case requires consistent ordering, -// consider using the `stream` package or `Map` from the `iter` package. +// The order of the results is guaranteed to be the same as the order the +// tasks were submitted. // // The configuration methods (With*) will panic if they are used after calling // Go() for the first time. @@ -23,11 +22,10 @@ type ResultErrorPool[T any] struct { // Go submits a task to the pool. If all goroutines in the pool // are busy, a call to Go() will block until the task can be started. func (p *ResultErrorPool[T]) Go(f func() (T, error)) { + idx := p.agg.nextIndex() p.errorPool.Go(func() error { res, err := f() - if err == nil || p.collectErrored { - p.agg.add(res) - } + p.agg.save(idx, res, err != nil) return err }) } @@ -36,7 +34,7 @@ func (p *ResultErrorPool[T]) Go(f func() (T, error)) { // returning the results and any errors from tasks. func (p *ResultErrorPool[T]) Wait() ([]T, error) { err := p.errorPool.Wait() - return p.agg.results, err + return p.agg.collect(p.collectErrored), err } // WithCollectErrored configures the pool to still collect the result of a task diff --git a/pool/result_error_pool_test.go b/pool/result_error_pool_test.go index 84d83b9..c9b1b08 100644 --- a/pool/result_error_pool_test.go +++ b/pool/result_error_pool_test.go @@ -13,7 +13,7 @@ import ( "github.com/stretchr/testify/require" ) -func TestResultErrorGroup(t *testing.T) { +func TestResultErrorPool(t *testing.T) { t.Parallel() err1 := errors.New("err1") diff --git a/pool/result_pool.go b/pool/result_pool.go index ea304cb..16d8b46 100644 --- a/pool/result_pool.go +++ b/pool/result_pool.go @@ -2,6 +2,7 @@ package pool import ( "context" + "sort" "sync" ) @@ -19,9 +20,8 @@ func NewWithResults[T any]() *ResultPool[T] { // Tasks are executed in the pool with Go(), then the results of the tasks are // returned by Wait(). // -// The order of the results is not guaranteed to be the same as the order the -// tasks were submitted. 
// resultAggregator is a utility type that lets us safely gather results from
// multiple goroutines while keeping them in the order their slots were
// reserved. The zero value is valid and ready to use.
type resultAggregator[T any] struct {
	// mu guards every field below.
	mu sync.Mutex
	// len is the number of slots handed out by nextIndex so far.
	len int
	// results holds the saved values; results[i] belongs to slot i.
	results []T
	// errored lists the slots that were saved with errored == true.
	errored []int
}

// nextIndex reserves a slot for a result. The returned value should be passed
// to save() when adding a result to the aggregator.
func (r *resultAggregator[T]) nextIndex() int {
	r.mu.Lock()
	defer r.mu.Unlock()

	idx := r.len
	r.len++
	return idx
}

// save stores res in slot i, which must have been obtained from nextIndex.
// When errored is true the slot is also remembered so that collect can leave
// it out on request.
func (r *resultAggregator[T]) save(i int, res T, errored bool) {
	r.mu.Lock()
	defer r.mu.Unlock()

	if i >= len(r.results) {
		// Extend to cover every slot reserved so far. Appending zero values
		// (instead of allocating an exact-size slice and copying on every
		// growth step) keeps the total copying cost amortized linear.
		r.results = append(r.results, make([]T, r.len-len(r.results))...)
	}

	r.results[i] = res

	if errored {
		r.errored = append(r.errored, i)
	}
}

// collect returns the aggregated results in slot order. When collectErrored
// is false, slots that were saved as errored are left out; every other slot,
// including those after the last errored one, is kept.
//
// collect panics if the aggregator is in use by another goroutine, since it
// is only meant to be called once all tasks have finished. The lock is
// released before returning, so the aggregator stays usable afterwards, and
// the stored results are never modified, so repeated calls agree.
func (r *resultAggregator[T]) collect(collectErrored bool) []T {
	if !r.mu.TryLock() {
		panic("collect should not be called until all goroutines have exited")
	}
	defer r.mu.Unlock()

	n := len(r.results)
	if collectErrored || len(r.errored) == 0 {
		// Cap the capacity so that a caller appending to the returned slice
		// cannot write into spare capacity that later saves may use.
		return r.results[:n:n]
	}

	// Copy the runs between errored slots into a fresh slice. This assumes
	// each slot is saved at most once, so errored holds no duplicates.
	sort.Ints(r.errored)
	kept := make([]T, 0, n-len(r.errored))
	next := 0
	for _, e := range r.errored {
		kept = append(kept, r.results[next:e]...)
		next = e + 1
	}
	// Keep everything that follows the last errored slot.
	kept = append(kept, r.results[next:]...)
	return kept
}