Skip to content

Commit 0bfcec1

Browse files
refine ForEachIdxErr
1 parent 43a82a4 commit 0bfcec1

File tree

2 files changed

+23
-21
lines changed

2 files changed

+23
-21
lines changed

iter/iter.go

+8-6
Original file line numberDiff line numberDiff line change
@@ -1,12 +1,12 @@
11
package iter
22

33
import (
4+
"errors"
45
"runtime"
56
"sync"
67
"sync/atomic"
78

89
"github.com/sourcegraph/conc"
9-
"github.com/sourcegraph/conc/internal/multierror"
1010
)
1111

1212
// defaultMaxGoroutines returns the default maximum number of
@@ -127,7 +127,7 @@ func (iter Iterator[T]) ForEachIdxErr(input []T, f func(int, *T) error) error {
127127
iter.MaxGoroutines = numInput
128128
}
129129

130-
var errs error
130+
var errs []error
131131
var errsMu sync.Mutex
132132
var idx atomic.Int64
133133
var failed atomic.Bool
@@ -137,9 +137,11 @@ func (iter Iterator[T]) ForEachIdxErr(input []T, f func(int, *T) error) error {
137137
i := int(idx.Add(1) - 1)
138138
for ; i < numInput && !failed.Load(); i = int(idx.Add(1) - 1) {
139139
if err := f(i, &input[i]); err != nil {
140-
errsMu.Lock()
141-
errs = multierror.Join(errs, err)
142-
errsMu.Unlock()
140+
if alreadyFailedFast := failed.Swap(iter.FailFast); !alreadyFailedFast {
141+
errsMu.Lock()
142+
errs = append(errs, err)
143+
errsMu.Unlock()
144+
}
143145

144146
failed.Store(iter.FailFast)
145147
}
@@ -152,5 +154,5 @@ func (iter Iterator[T]) ForEachIdxErr(input []T, f func(int, *T) error) error {
152154
}
153155
wg.Wait()
154156

155-
return errs
157+
return errors.Join(errs...)
156158
}

iter/iter_test.go

+15-15
Original file line numberDiff line numberDiff line change
@@ -174,13 +174,13 @@ func TestForIterator_EachIdxErr(t *testing.T) {
174174
t.Parallel()
175175

176176
t.Run("failFast=false", func(t *testing.T) {
177-
it := Iterator[int]{MaxGoroutines: 999}
177+
it := iter.Iterator[int]{MaxGoroutines: 999}
178178
forEach := noIndex(it.ForEachIdxErr)
179179
testForEachErr(t, false, forEach)
180180
})
181181

182182
t.Run("failFast=true", func(t *testing.T) {
183-
it := Iterator[int]{MaxGoroutines: 999}
183+
it := iter.Iterator[int]{MaxGoroutines: 999}
184184
forEach := noIndex(it.ForEachIdxErr)
185185
testForEachErr(t, true, forEach)
186186
})
@@ -190,7 +190,7 @@ func TestForIterator_EachIdxErr(t *testing.T) {
190190

191191
input := []int{1, 2, 3, 4, 5}
192192
errTest := errors.New("test error")
193-
iterator := Iterator[int]{MaxGoroutines: 1, FailFast: true}
193+
iterator := iter.Iterator[int]{MaxGoroutines: 1, FailFast: true}
194194

195195
var mu sync.Mutex
196196
var results []int
@@ -211,7 +211,7 @@ func TestForIterator_EachIdxErr(t *testing.T) {
211211
t.Run("safe for reuse", func(t *testing.T) {
212212
t.Parallel()
213213

214-
iterator := Iterator[int]{MaxGoroutines: 999}
214+
iterator := iter.Iterator[int]{MaxGoroutines: 999}
215215

216216
// iter.Concurrency > numInput case that updates iter.Concurrency
217217
_ = iterator.ForEachIdxErr([]int{1, 2, 3}, func(i int, t *int) error {
@@ -224,12 +224,12 @@ func TestForIterator_EachIdxErr(t *testing.T) {
224224
t.Run("allows more than defaultMaxGoroutines() concurrent tasks", func(t *testing.T) {
225225
t.Parallel()
226226

227-
wantConcurrency := 2 * defaultMaxGoroutines()
227+
wantConcurrency := 2 * iter.DefaultMaxGoroutines()
228228

229229
maxConcurrencyHit := make(chan struct{})
230230

231231
tasks := make([]int, wantConcurrency)
232-
iterator := Iterator[int]{MaxGoroutines: wantConcurrency}
232+
iterator := iter.Iterator[int]{MaxGoroutines: wantConcurrency}
233233

234234
var concurrentTasks atomic.Int64
235235
_ = iterator.ForEachIdxErr(tasks, func(_ int, t *int) error {
@@ -257,19 +257,19 @@ func TestForIterator_EachErr(t *testing.T) {
257257
t.Parallel()
258258

259259
t.Run("failFast=false", func(t *testing.T) {
260-
it := Iterator[int]{MaxGoroutines: 999}
260+
it := iter.Iterator[int]{MaxGoroutines: 999}
261261
testForEachErr(t, false, it.ForEachErr)
262262
})
263263

264264
t.Run("failFast=true", func(t *testing.T) {
265-
it := Iterator[int]{MaxGoroutines: 999}
265+
it := iter.Iterator[int]{MaxGoroutines: 999}
266266
testForEachErr(t, true, it.ForEachErr)
267267
})
268268

269269
t.Run("safe for reuse", func(t *testing.T) {
270270
t.Parallel()
271271

272-
iterator := Iterator[int]{MaxGoroutines: 999}
272+
iterator := iter.Iterator[int]{MaxGoroutines: 999}
273273

274274
// iter.Concurrency > numInput case that updates iter.Concurrency
275275
_ = iterator.ForEachErr([]int{1, 2, 3}, func(t *int) error {
@@ -284,7 +284,7 @@ func TestForIterator_EachErr(t *testing.T) {
284284

285285
input := []int{1, 2, 3, 4, 5}
286286
errTest := errors.New("test error")
287-
iterator := Iterator[int]{MaxGoroutines: 1, FailFast: true}
287+
iterator := iter.Iterator[int]{MaxGoroutines: 1, FailFast: true}
288288

289289
var mu sync.Mutex
290290
var results []int
@@ -305,12 +305,12 @@ func TestForIterator_EachErr(t *testing.T) {
305305
t.Run("allows more than defaultMaxGoroutines() concurrent tasks", func(t *testing.T) {
306306
t.Parallel()
307307

308-
wantConcurrency := 2 * defaultMaxGoroutines()
308+
wantConcurrency := 2 * iter.DefaultMaxGoroutines()
309309

310310
maxConcurrencyHit := make(chan struct{})
311311

312312
tasks := make([]int, wantConcurrency)
313-
iterator := Iterator[int]{MaxGoroutines: wantConcurrency}
313+
iterator := iter.Iterator[int]{MaxGoroutines: wantConcurrency}
314314

315315
var concurrentTasks atomic.Int64
316316
_ = iterator.ForEachErr(tasks, func(t *int) error {
@@ -338,7 +338,7 @@ func TestForEachIdxErr(t *testing.T) {
338338
t.Parallel()
339339

340340
t.Run("standart", func(t *testing.T) {
341-
forEach := noIndex(ForEachIdxErr[int])
341+
forEach := noIndex(iter.ForEachIdxErr[int])
342342
testForEachErr(t, false, forEach)
343343
})
344344

@@ -347,7 +347,7 @@ func TestForEachIdxErr(t *testing.T) {
347347
got := []int{}
348348
gotMu := sync.Mutex{}
349349

350-
err := ForEachIdxErr(ints, func(i int, _ *int) error {
350+
err := iter.ForEachIdxErr(ints, func(i int, _ *int) error {
351351
gotMu.Lock()
352352
defer gotMu.Unlock()
353353
got = append(got, i)
@@ -362,7 +362,7 @@ func TestForEachIdxErr(t *testing.T) {
362362
func TestForEachErr(t *testing.T) {
363363
t.Parallel()
364364

365-
testForEachErr(t, false, ForEachErr[int])
365+
testForEachErr(t, false, iter.ForEachErr[int])
366366
}
367367

368368
// noIndex converts a ForEachIdxErr function (or method) into a ForEachErr function (or method).

0 commit comments

Comments
 (0)