Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
68 changes: 45 additions & 23 deletions pkg/utils/batcher/batcher.go
Original file line number Diff line number Diff line change
Expand Up @@ -133,6 +133,8 @@ func (b *Batcher[RequestPayload, ResponsePayload]) Enqueue(payload RequestPayloa
Key: key,
}

var fullBatch *Batch[RequestPayload, ResponsePayload]

b.mu.Lock()

batch, exists := b.pendingBatches[req.Key]
Expand All @@ -144,11 +146,23 @@ func (b *Batcher[RequestPayload, ResponsePayload]) Enqueue(payload RequestPayloa
Requests: make([]*BatchedRequest[RequestPayload, ResponsePayload], 0, b.opts.MaxBatchSize),
}
b.pendingBatches[req.Key] = batch
} else if len(batch.Requests) >= b.opts.MaxBatchSize {
fullBatch = batch
batch = &Batch[RequestPayload, ResponsePayload]{
ID: uuid.New().String(),
Key: req.Key,
Requests: make([]*BatchedRequest[RequestPayload, ResponsePayload], 0, b.opts.MaxBatchSize),
}
b.pendingBatches[req.Key] = batch
}
batch.Requests = append(batch.Requests, req)

b.mu.Unlock()

if fullBatch != nil {
b.executeBatchAsync(uuid.New().String(), fullBatch)
}

// Alert the background loop (e.g., start timer, check execution conditions)
// Non-blocking signal (buffer=1 coalesces multiple enqueues)
select {
Expand Down Expand Up @@ -185,12 +199,16 @@ func (b *Batcher[RequestPayload, ResponsePayload]) run() {
// This is tolerable because requests typically arrive in bursts from the provisioner.
// Suggestion: if needed, we could add per-batch-key timers for more precise control, but it adds complexity.

b.mu.Lock()
batchCount := len(b.pendingBatches)
b.mu.Unlock()

// TODO: use metrics instead?
log.FromContext(b.ctx).V(2).Info("batcher iteration finishing wait, ready to execute batches",
"batcherIterationID", batcherIterationID,
"waitStartTime", waitStartTime,
"waitDuration", time.Since(waitStartTime),
"batchCount", len(b.pendingBatches))
"batchCount", batchCount)
b.executeBatches(batcherIterationID)
}
}
Expand Down Expand Up @@ -263,31 +281,35 @@ func (b *Batcher[RequestPayload, ResponsePayload]) executeBatches(batcherIterati

// Dispatch batches in parallel, as they are independent (different keys).
for _, batch := range batches {
// TODO: use metrics instead?
log.FromContext(b.ctx).V(2).Info("begin executing batch",
"batcherIterationID", batcherIterationID,
"ID", batch.ID,
"key", batch.Key,
"size", len(batch.Requests))
go func(batch *Batch[RequestPayload, ResponsePayload]) {
defer func() {
if r := recover(); r != nil {
log.FromContext(b.ctx).Error(fmt.Errorf("%v", r), "panic in batch executor, distributing error to callers")
err := fmt.Errorf("batch execution panicked: %v", r)
for _, req := range batch.Requests {
// Non-blocking: if executeBatch already wrote a response before
// panicking, the buffer is full — skip to avoid goroutine leak.
select {
case req.ResponseChan <- &Response[ResponsePayload]{Err: err}:
default:
}
b.executeBatchAsync(batcherIterationID, batch)
}
}

// executeBatchAsync logs the start of the given batch and then runs
// b.executeBatch for it on a new goroutine, returning without waiting for the
// execution to finish. If that goroutine panics, the deferred recover logs the
// panic and sends an error Response to each request in the batch using a
// non-blocking send.
func (b *Batcher[RequestPayload, ResponsePayload]) executeBatchAsync(batcherIterationID string, batch *Batch[RequestPayload, ResponsePayload]) {
// TODO: use metrics instead?
log.FromContext(b.ctx).V(2).Info("begin executing batch",
"batcherIterationID", batcherIterationID,
"ID", batch.ID,
"key", batch.Key,
"size", len(batch.Requests))
go func(batch *Batch[RequestPayload, ResponsePayload]) {
defer func() {
if r := recover(); r != nil {
log.FromContext(b.ctx).Error(fmt.Errorf("%v", r), "panic in batch executor, distributing error to callers")
err := fmt.Errorf("batch execution panicked: %v", r)
for _, req := range batch.Requests {
// Non-blocking: if executeBatch already wrote a response before
// panicking, the buffer is full — skip to avoid goroutine leak.
select {
case req.ResponseChan <- &Response[ResponsePayload]{Err: err}:
default:
}
}
}()
}
}()

b.executeBatch(b.ctx, batch)
}(batch)
}
b.executeBatch(b.ctx, batch)
}(batch)
}

// drain fails all in-flight requests with a shutdown error.
Expand Down
47 changes: 47 additions & 0 deletions pkg/utils/batcher/batcher_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -312,6 +312,53 @@ func TestBatcherFiresWhenMaxBatchSizeReached(t *testing.T) {
g.Expect(batchSizes).ToNot(gomega.BeEmpty(), "at least one batch should have fired")
}

// TestBatcherDoesNotExceedMaxBatchSizeWhenEnqueuedBeforeStart enqueues more
// same-key requests than MaxBatchSize before Start is called, then verifies
// that every request receives a successful response, that no executed batch
// is larger than MaxBatchSize, and that the executed batch sizes add up to the
// number of enqueued requests.
func TestBatcherDoesNotExceedMaxBatchSizeWhenEnqueuedBeforeStart(t *testing.T) {
	t.Parallel()

	const (
		// maxBatchSize is shared by the batcher options and the assertions so
		// the two cannot drift apart.
		maxBatchSize = 3
		// requestCount overflows two full batches by one request.
		requestCount = 2*maxBatchSize + 1
		// responseTimeout bounds the wait for each individual response.
		responseTimeout = 2 * time.Second
	)

	g := gomega.NewWithT(t)
	ctx, cancel := context.WithCancel(context.Background())
	defer cancel()

	// batchSizes records the size of every batch handed to the executor;
	// mu guards it because the executor runs on other goroutines.
	var mu sync.Mutex
	var batchSizes []int

	b := New(ctx, testKeyFunc, func(ctx context.Context, batch *Batch[testItem, struct{}]) {
		// Record the size before answering, so that once the test has seen a
		// response the corresponding batch size is already recorded.
		mu.Lock()
		batchSizes = append(batchSizes, len(batch.Requests))
		mu.Unlock()
		for _, req := range batch.Requests {
			req.ResponseChan <- &Response[struct{}]{Err: nil}
		}
	}, Options{
		IdleTimeout:  50 * time.Millisecond,
		MaxTimeout:   5 * time.Second,
		MaxBatchSize: maxBatchSize,
	})

	// Enqueue everything under a single key before the batcher is started.
	responseChans := make([]chan *Response[struct{}], 0, requestCount)
	for i := 0; i < requestCount; i++ {
		ch, err := b.Enqueue(testItem{Group: "same-group", Name: fmt.Sprintf("item-%d", i)})
		g.Expect(err).ToNot(gomega.HaveOccurred())
		responseChans = append(responseChans, ch)
	}
	b.Start()

	for i, ch := range responseChans {
		select {
		case resp := <-ch:
			g.Expect(resp.Err).ToNot(gomega.HaveOccurred())
		case <-time.After(responseTimeout):
			t.Fatalf("request %d timed out", i)
		}
	}

	mu.Lock()
	defer mu.Unlock()
	g.Expect(batchSizes).ToNot(gomega.BeEmpty())
	total := 0
	for _, size := range batchSizes {
		g.Expect(size).To(gomega.BeNumerically("<=", maxBatchSize))
		total += size
	}
	// Every request was answered above, so the recorded batches must account
	// for all of them.
	g.Expect(total).To(gomega.Equal(requestCount), "executed batch sizes should sum to the number of enqueued requests")
}

func TestBatcherFiresAtMaxTimeout(t *testing.T) {
t.Parallel()
ctx, cancel := context.WithCancel(context.Background())
Expand Down