Skip to content

Commit

Permalink
fix: Fix Duplicate InstanceID Batching Bug (aws#3717)
Browse files Browse the repository at this point in the history
  • Loading branch information
jonathan-innis authored Apr 7, 2023
1 parent 0bcd446 commit dc49a34
Show file tree
Hide file tree
Showing 4 changed files with 103 additions and 50 deletions.
46 changes: 23 additions & 23 deletions pkg/batcher/describeinstances.go
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@ import (
"github.com/aws/aws-sdk-go/service/ec2/ec2iface"
"github.com/mitchellh/hashstructure/v2"
"github.com/samber/lo"
"k8s.io/apimachinery/pkg/util/sets"
"knative.dev/pkg/logging"
)

Expand Down Expand Up @@ -68,27 +69,28 @@ func execDescribeInstancesBatch(ec2api ec2iface.EC2API) BatchExecutor[ec2.Descri
for _, input := range inputs[1:] {
firstInput.InstanceIds = append(firstInput.InstanceIds, input.InstanceIds...)
}

missingInstanceIDs := lo.SliceToMap(firstInput.InstanceIds, func(instanceID *string) (string, struct{}) { return *instanceID, struct{}{} })
missingInstanceIDs := sets.NewString(lo.Map(firstInput.InstanceIds, func(i *string, _ int) string { return *i })...)

// Execute fully aggregated request
// We don't care about the error here since we'll break up the batch upon any sort of failure
_ = ec2api.DescribeInstancesPagesWithContext(ctx, firstInput, func(dio *ec2.DescribeInstancesOutput, b bool) bool {
for _, r := range dio.Reservations {
for _, instance := range r.Instances {
if _, reqID, ok := lo.FindLastIndexOf(inputs, func(input *ec2.DescribeInstancesInput) bool {
return *input.InstanceIds[0] == *instance.InstanceId
}); ok {
delete(missingInstanceIDs, *instance.InstanceId)
inst := instance
results[reqID] = Result[ec2.DescribeInstancesOutput]{Output: &ec2.DescribeInstancesOutput{
Reservations: []*ec2.Reservation{{
OwnerId: r.OwnerId,
RequesterId: r.RequesterId,
ReservationId: r.ReservationId,
Instances: []*ec2.Instance{inst},
}},
}}
missingInstanceIDs.Delete(*instance.InstanceId)

// Find all indexes where we are requesting this instance and populate with the result
for reqID := range inputs {
if *inputs[reqID].InstanceIds[0] == *instance.InstanceId {
inst := instance // locally scoped to avoid pointer pollution in a range loop
results[reqID] = Result[ec2.DescribeInstancesOutput]{Output: &ec2.DescribeInstancesOutput{
Reservations: []*ec2.Reservation{{
OwnerId: r.OwnerId,
RequesterId: r.RequesterId,
ReservationId: r.ReservationId,
Instances: []*ec2.Instance{inst},
}},
}}
}
}
}
}
Expand All @@ -107,15 +109,13 @@ func execDescribeInstancesBatch(ec2api ec2iface.EC2API) BatchExecutor[ec2.Descri
out, err := ec2api.DescribeInstancesWithContext(ctx, &ec2.DescribeInstancesInput{
Filters: firstInput.Filters,
InstanceIds: []*string{aws.String(instanceID)}})
// Order by inputs' index so that instance IDs from input and output are in the same order
_, reqID, ok := lo.FindIndexOf(inputs, func(input *ec2.DescribeInstancesInput) bool {
return *input.InstanceIds[0] == instanceID
})
// if the instance ID returned from DescribeInstances was not passed as a DescribeInstancesInput, just skip
if !ok {
return

// Find all indexes where we are requesting this instance and populate with the result
for reqID := range inputs {
if *inputs[reqID].InstanceIds[0] == instanceID {
results[reqID] = Result[ec2.DescribeInstancesOutput]{Output: out, Err: err}
}
}
results[reqID] = Result[ec2.DescribeInstancesOutput]{Output: out, Err: err}
}(instanceID)
}
wg.Wait()
Expand Down
29 changes: 29 additions & 0 deletions pkg/batcher/describeinstances_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -66,6 +66,35 @@ var _ = Describe("DescribeInstances Batcher", func() {
call := fakeEC2API.DescribeInstancesBehavior.CalledWithInput.Pop()
Expect(len(call.InstanceIds)).To(BeNumerically("==", len(instanceIDs)))
})
// Regression test for duplicate instance IDs: several concurrent DescribeInstances
// callers ask for the same instance, and every one of them must receive its own
// single-instance response from one aggregated EC2 call.
It("should batch input correctly when receiving multiple calls with the same instance id", func() {
// "i-1" is requested three times and "i-2" twice, so the batch contains repeated IDs.
instanceIDs := []string{"i-1", "i-1", "i-1", "i-2", "i-2"}
// Seed the fake EC2 API with an instance for every requested ID.
for _, id := range instanceIDs {
fakeEC2API.Instances.Store(id, &ec2.Instance{InstanceId: aws.String(id)})
}

var wg sync.WaitGroup
// Number of callers that got a response; updated atomically because the
// goroutines below increment it concurrently.
var receivedInstance int64
// Issue one single-ID request per entry, each from its own goroutine.
for _, instanceID := range instanceIDs {
wg.Add(1)
go func(instanceID string) {
// GinkgoRecover lets a failed Expect inside this goroutine be reported to
// Ginkgo rather than escaping as an unhandled panic.
defer GinkgoRecover()
defer wg.Done()
rsp, err := cfb.DescribeInstances(ctx, &ec2.DescribeInstancesInput{
InstanceIds: []*string{aws.String(instanceID)},
})
Expect(err).To(BeNil())
// Counted after the error assertion and before the shape assertions below.
atomic.AddInt64(&receivedInstance, 1)
// Each caller must see exactly one reservation holding exactly one instance.
Expect(rsp.Reservations).To(HaveLen(1))
Expect(rsp.Reservations[0].Instances).To(HaveLen(1))
}(instanceID)
}
wg.Wait()

// Every caller, including those sharing an ID with another caller, got a response.
Expect(receivedInstance).To(BeNumerically("==", len(instanceIDs)))
// All requests were folded into a single call against the fake EC2 API.
Expect(fakeEC2API.DescribeInstancesBehavior.CalledWithInput.Len()).To(BeNumerically("==", 1))
call := fakeEC2API.DescribeInstancesBehavior.CalledWithInput.Pop()
// The aggregated call carries one entry per request, so repeated IDs are still counted.
Expect(len(call.InstanceIds)).To(BeNumerically("==", len(instanceIDs)))
})
It("should handle partial terminations on batched call and recover with individual requests", func() {
instanceIDs := []string{"i-1", "i-2", "i-3"}
// Output with only the first Instance
Expand Down
50 changes: 23 additions & 27 deletions pkg/batcher/terminateinstances.go
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@ import (
"github.com/aws/aws-sdk-go/service/ec2"
"github.com/aws/aws-sdk-go/service/ec2/ec2iface"
"github.com/samber/lo"
"k8s.io/apimachinery/pkg/util/sets"
"knative.dev/pkg/logging"
)

Expand Down Expand Up @@ -61,7 +62,7 @@ func execTerminateInstancesBatch(ec2api ec2iface.EC2API) BatchExecutor[ec2.Termi
firstInput.InstanceIds = append(firstInput.InstanceIds, input.InstanceIds...)
}
// Create a set of all instance IDs
stillRunning := lo.SliceToMap(firstInput.InstanceIds, func(instanceID *string) (string, struct{}) { return *instanceID, struct{}{} })
stillRunning := sets.NewString(lo.Map(firstInput.InstanceIds, func(i *string, _ int) string { return *i })...)

// Execute fully aggregated request
// We don't care about the error here since we'll break up the batch upon any sort of failure
Expand All @@ -78,24 +79,21 @@ func execTerminateInstancesBatch(ec2api ec2iface.EC2API) BatchExecutor[ec2.Termi
for _, instanceStateChanges := range output.TerminatingInstances {
// Remove all instances that successfully terminated and separate into distinct outputs
if lo.Contains([]string{ec2.InstanceStateNameShuttingDown, ec2.InstanceStateNameTerminated}, *instanceStateChanges.CurrentState.Name) {
delete(stillRunning, *instanceStateChanges.InstanceId)
// Order by inputs' index so that instance IDs from input and output are in the same order
_, reqID, ok := lo.FindIndexOf(inputs, func(input *ec2.TerminateInstancesInput) bool {
return *input.InstanceIds[0] == *instanceStateChanges.InstanceId
})
// if the instance ID returned from TerminateInstances was not passed as a TerminateInstanceInput, just skip
if !ok {
continue
}
// add instance ID as a separate output
results[reqID] = Result[ec2.TerminateInstancesOutput]{
Output: &ec2.TerminateInstancesOutput{
TerminatingInstances: []*ec2.InstanceStateChange{{
InstanceId: instanceStateChanges.InstanceId,
CurrentState: instanceStateChanges.CurrentState,
PreviousState: instanceStateChanges.PreviousState,
}},
},
stillRunning.Delete(*instanceStateChanges.InstanceId)

// Find all indexes where we are requesting this instance and populate with the result
for reqID := range inputs {
if *inputs[reqID].InstanceIds[0] == *instanceStateChanges.InstanceId {
results[reqID] = Result[ec2.TerminateInstancesOutput]{
Output: &ec2.TerminateInstancesOutput{
TerminatingInstances: []*ec2.InstanceStateChange{{
InstanceId: instanceStateChanges.InstanceId,
CurrentState: instanceStateChanges.CurrentState,
PreviousState: instanceStateChanges.PreviousState,
}},
},
}
}
}
}
}
Expand All @@ -110,15 +108,13 @@ func execTerminateInstancesBatch(ec2api ec2iface.EC2API) BatchExecutor[ec2.Termi
defer wg.Done()
// try to execute separately
out, err := ec2api.TerminateInstancesWithContext(ctx, &ec2.TerminateInstancesInput{InstanceIds: []*string{aws.String(instanceID)}})
// Order by inputs' index so that instance IDs from input and output are in the same order
_, reqID, ok := lo.FindIndexOf(inputs, func(input *ec2.TerminateInstancesInput) bool {
return *input.InstanceIds[0] == instanceID
})
// if the instance ID returned from TerminateInstances was not passed as a TerminateInstanceInput, just skip
if !ok {
return

// Find all indexes where we are requesting this instance and populate with the result
for reqID := range inputs {
if *inputs[reqID].InstanceIds[0] == instanceID {
results[reqID] = Result[ec2.TerminateInstancesOutput]{Output: out, Err: err}
}
}
results[reqID] = Result[ec2.TerminateInstancesOutput]{Output: out, Err: err}
}(instanceID)
}
wg.Wait()
Expand Down
28 changes: 28 additions & 0 deletions pkg/batcher/terminateinstances_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -65,6 +65,34 @@ var _ = Describe("TerminateInstances Batcher", func() {
call := fakeEC2API.TerminateInstancesBehavior.CalledWithInput.Pop()
Expect(len(call.InstanceIds)).To(BeNumerically("==", len(instanceIDs)))
})
// Regression test for duplicate instance IDs: several concurrent TerminateInstances
// callers target the same instance, and every one of them must receive its own
// single-entry response from one aggregated EC2 call.
It("should batch input correctly when receiving multiple calls with the same instance id", func() {
// "i-1" is requested three times and "i-2" twice, so the batch contains repeated IDs.
instanceIDs := []string{"i-1", "i-1", "i-1", "i-2", "i-2"}
// Seed the fake EC2 API with an entry for every requested ID; the stored
// ec2.Instance has no fields set.
for _, id := range instanceIDs {
fakeEC2API.Instances.Store(id, &ec2.Instance{})
}

var wg sync.WaitGroup
// Number of callers that got a response; updated atomically because the
// goroutines below increment it concurrently.
var receivedInstance int64
// Issue one single-ID request per entry, each from its own goroutine.
for _, instanceID := range instanceIDs {
wg.Add(1)
go func(instanceID string) {
// GinkgoRecover lets a failed Expect inside this goroutine be reported to
// Ginkgo rather than escaping as an unhandled panic.
defer GinkgoRecover()
defer wg.Done()
rsp, err := cfb.TerminateInstances(ctx, &ec2.TerminateInstancesInput{
InstanceIds: []*string{aws.String(instanceID)},
})
Expect(err).To(BeNil())
// Counted after the error assertion and before the shape assertion below.
atomic.AddInt64(&receivedInstance, 1)
// Each caller must see exactly one terminating-instance entry.
Expect(rsp.TerminatingInstances).To(HaveLen(1))
}(instanceID)
}
wg.Wait()

// Every caller, including those sharing an ID with another caller, got a response.
Expect(receivedInstance).To(BeNumerically("==", len(instanceIDs)))
// All requests were folded into a single call against the fake EC2 API.
Expect(fakeEC2API.TerminateInstancesBehavior.CalledWithInput.Len()).To(BeNumerically("==", 1))
call := fakeEC2API.TerminateInstancesBehavior.CalledWithInput.Pop()
// The aggregated call carries one entry per request, so repeated IDs are still counted.
Expect(len(call.InstanceIds)).To(BeNumerically("==", len(instanceIDs)))
})
It("should handle partial terminations on batched call and recover with individual requests", func() {
instanceIDs := []string{"i-1", "i-2", "i-3"}
// Output with only the first Terminating Instance
Expand Down

0 comments on commit dc49a34

Please sign in to comment.