Skip to content

Commit 98e283a

Browse files
committed
Prevent panic when batchloader returns nil values by replacing them with Result errors
1 parent 3afb6ae commit 98e283a

File tree

2 files changed

+57
-0
lines changed

2 files changed

+57
-0
lines changed

dataloader.go

+7
Original file line numberDiff line numberDiff line change
@@ -460,6 +460,13 @@ func (b *batcher) batch(originalContext context.Context) {
460460
return
461461
}
462462

463+
// When a batchFunc returns a nil in it's items, we replace those by a Result struct containing an error
464+
for key, value := range items {
465+
if value == nil {
466+
items[key] = &Result{Error: fmt.Errorf("no value for key")}
467+
}
468+
}
469+
463470
for i, req := range reqs {
464471
req.channel <- items[i]
465472
close(req.channel)

dataloader_test.go

+50
Original file line numberDiff line numberDiff line change
@@ -211,6 +211,32 @@ func TestLoader(t *testing.T) {
211211
// TODO: expect to get some kind of warning
212212
})
213213

214+
t.Run("first result is a nil", func(t *testing.T) {
215+
t.Parallel()
216+
faultyLoader, _ := LoaderNilInsteadOfResult()
217+
ctx := context.Background()
218+
219+
n := 10
220+
reqs := []Thunk{}
221+
keys := Keys{}
222+
for i := 0; i < n; i++ {
223+
key := StringKey(strconv.Itoa(i))
224+
reqs = append(reqs, faultyLoader.Load(ctx, key))
225+
keys = append(keys, key)
226+
}
227+
228+
for i, future := range reqs {
229+
_, err := future()
230+
if i == 0 && err == nil {
231+
t.Error("expected first result to contain an error")
232+
}
233+
234+
if i != 0 && err != nil {
235+
t.Error("expected rest of results not to contain an error")
236+
}
237+
}
238+
})
239+
214240
t.Run("responds to max batch size", func(t *testing.T) {
215241
t.Parallel()
216242
identityLoader, loadCalls := IDLoader(2)
@@ -586,6 +612,30 @@ func FaultyLoader() (*Loader, *[][]string) {
586612
return loader, &loadCalls
587613
}
588614

615+
// LoaderNilInsteadOfResult gives a nil result back for the first key.
616+
func LoaderNilInsteadOfResult() (*Loader, *[][]string) {
617+
var mu sync.Mutex
618+
var loadCalls [][]string
619+
620+
loader := NewBatchedLoader(func(_ context.Context, keys Keys) []*Result {
621+
var results []*Result
622+
mu.Lock()
623+
loadCalls = append(loadCalls, keys.Keys())
624+
mu.Unlock()
625+
626+
for i, key := range keys {
627+
if i == 0 {
628+
results = append(results, nil)
629+
} else {
630+
results = append(results, &Result{key, nil})
631+
}
632+
}
633+
return results
634+
})
635+
636+
return loader, &loadCalls
637+
}
638+
589639
///////////////////////////////////////////////////
590640
// Benchmarks
591641
///////////////////////////////////////////////////

0 commit comments

Comments
 (0)