From f9cd0e2220796995240a78f59c5f98c86511b3d5 Mon Sep 17 00:00:00 2001 From: Robin Deeboonchai Date: Mon, 29 Jun 2026 12:23:18 -0700 Subject: [PATCH] fix: honor batch max size Co-Authored-By: Claude --- pkg/utils/batcher/batcher.go | 68 ++++++++++++++++++++----------- pkg/utils/batcher/batcher_test.go | 47 +++++++++++++++++++++ 2 files changed, 92 insertions(+), 23 deletions(-) diff --git a/pkg/utils/batcher/batcher.go b/pkg/utils/batcher/batcher.go index 37e9322c65..1f898efa37 100644 --- a/pkg/utils/batcher/batcher.go +++ b/pkg/utils/batcher/batcher.go @@ -133,6 +133,8 @@ func (b *Batcher[RequestPayload, ResponsePayload]) Enqueue(payload RequestPayloa Key: key, } + var fullBatch *Batch[RequestPayload, ResponsePayload] + b.mu.Lock() batch, exists := b.pendingBatches[req.Key] @@ -144,11 +146,23 @@ func (b *Batcher[RequestPayload, ResponsePayload]) Enqueue(payload RequestPayloa Requests: make([]*BatchedRequest[RequestPayload, ResponsePayload], 0, b.opts.MaxBatchSize), } b.pendingBatches[req.Key] = batch + } else if len(batch.Requests) >= b.opts.MaxBatchSize { + fullBatch = batch + batch = &Batch[RequestPayload, ResponsePayload]{ + ID: uuid.New().String(), + Key: req.Key, + Requests: make([]*BatchedRequest[RequestPayload, ResponsePayload], 0, b.opts.MaxBatchSize), + } + b.pendingBatches[req.Key] = batch } batch.Requests = append(batch.Requests, req) b.mu.Unlock() + if fullBatch != nil { + b.executeBatchAsync(uuid.New().String(), fullBatch) + } + // Alert the background loop (e.g., start timer, check execution conditions) // Non-blocking signal (buffer=1 coalesces multiple enqueues) select { @@ -185,12 +199,16 @@ func (b *Batcher[RequestPayload, ResponsePayload]) run() { // This is tolerable because requests typically arrive in bursts from the provisioner. // Suggestion: if needed, we could add per-batch-key timers for more precise control, but it adds complexity. + b.mu.Lock() + batchCount := len(b.pendingBatches) + b.mu.Unlock() + // TODO: use metrics instead? log.FromContext(b.ctx).V(2).Info("batcher iteration finishing wait, ready to execute batches", "batcherIterationID", batcherIterationID, "waitStartTime", waitStartTime, "waitDuration", time.Since(waitStartTime), - "batchCount", len(b.pendingBatches)) + "batchCount", batchCount) b.executeBatches(batcherIterationID) } } @@ -263,31 +281,35 @@ func (b *Batcher[RequestPayload, ResponsePayload]) executeBatches(batcherIterati // Dispatch batches in parallel, as they are independent (different keys). for _, batch := range batches { - // TODO: use metrics instead? - log.FromContext(b.ctx).V(2).Info("begin executing batch", - "batcherIterationID", batcherIterationID, - "ID", batch.ID, - "key", batch.Key, - "size", len(batch.Requests)) - go func(batch *Batch[RequestPayload, ResponsePayload]) { - defer func() { - if r := recover(); r != nil { - log.FromContext(b.ctx).Error(fmt.Errorf("%v", r), "panic in batch executor, distributing error to callers") - err := fmt.Errorf("batch execution panicked: %v", r) - for _, req := range batch.Requests { - // Non-blocking: if executeBatch already wrote a response before - // panicking, the buffer is full — skip to avoid goroutine leak. - select { - case req.ResponseChan <- &Response[ResponsePayload]{Err: err}: - default: - } + b.executeBatchAsync(batcherIterationID, batch) + } +} + +func (b *Batcher[RequestPayload, ResponsePayload]) executeBatchAsync(batcherIterationID string, batch *Batch[RequestPayload, ResponsePayload]) { + // TODO: use metrics instead? + log.FromContext(b.ctx).V(2).Info("begin executing batch", + "batcherIterationID", batcherIterationID, + "ID", batch.ID, + "key", batch.Key, + "size", len(batch.Requests)) + go func(batch *Batch[RequestPayload, ResponsePayload]) { + defer func() { + if r := recover(); r != nil { + log.FromContext(b.ctx).Error(fmt.Errorf("%v", r), "panic in batch executor, distributing error to callers") + err := fmt.Errorf("batch execution panicked: %v", r) + for _, req := range batch.Requests { + // Non-blocking: if executeBatch already wrote a response before + // panicking, the buffer is full — skip to avoid goroutine leak. + select { + case req.ResponseChan <- &Response[ResponsePayload]{Err: err}: + default: } } - }() + } + }() - b.executeBatch(b.ctx, batch) - }(batch) - } + b.executeBatch(b.ctx, batch) + }(batch) } // drain fails all in-flight requests with a shutdown error. diff --git a/pkg/utils/batcher/batcher_test.go b/pkg/utils/batcher/batcher_test.go index 8bb2f4df92..e4574f9980 100644 --- a/pkg/utils/batcher/batcher_test.go +++ b/pkg/utils/batcher/batcher_test.go @@ -312,6 +312,53 @@ func TestBatcherFiresWhenMaxBatchSizeReached(t *testing.T) { g.Expect(batchSizes).ToNot(gomega.BeEmpty(), "at least one batch should have fired") } +func TestBatcherDoesNotExceedMaxBatchSizeWhenEnqueuedBeforeStart(t *testing.T) { + t.Parallel() + g := gomega.NewWithT(t) + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + + var mu sync.Mutex + var batchSizes []int + + b := New(ctx, testKeyFunc, func(ctx context.Context, batch *Batch[testItem, struct{}]) { + mu.Lock() + batchSizes = append(batchSizes, len(batch.Requests)) + mu.Unlock() + for _, req := range batch.Requests { + req.ResponseChan <- &Response[struct{}]{Err: nil} + } + }, Options{ + IdleTimeout: 50 * time.Millisecond, + MaxTimeout: 5 * time.Second, + MaxBatchSize: 3, + }) + + var responseChans []chan *Response[struct{}] + for i := 0; i < 7; i++ { + ch, err := b.Enqueue(testItem{Group: "same-group", Name: fmt.Sprintf("item-%d", i)}) + g.Expect(err).ToNot(gomega.HaveOccurred()) + responseChans = append(responseChans, ch) + } + b.Start() + + for i, ch := range responseChans { + select { + case resp := <-ch: + g.Expect(resp.Err).ToNot(gomega.HaveOccurred()) + case <-time.After(2 * time.Second): + t.Fatalf("request %d timed out", i) + } + } + + mu.Lock() + defer mu.Unlock() + g.Expect(batchSizes).ToNot(gomega.BeEmpty()) + for _, size := range batchSizes { + g.Expect(size).To(gomega.BeNumerically("<=", 3)) + } +} + func TestBatcherFiresAtMaxTimeout(t *testing.T) { t.Parallel() ctx, cancel := context.WithCancel(context.Background())