diff --git a/publisher/kinesis/kinesis.go b/publisher/kinesis/kinesis.go index dc56189..68f8413 100644 --- a/publisher/kinesis/kinesis.go +++ b/publisher/kinesis/kinesis.go @@ -20,14 +20,15 @@ import ( var globalCtx = context.Background() -type KinesisClient interface { +// Client is an interface to *kinesis.Client +type Client interface { PutRecord(context.Context, *kinesis.PutRecordInput, ...func(*kinesis.Options)) (*kinesis.PutRecordOutput, error) DescribeStreamSummary(context.Context, *kinesis.DescribeStreamSummaryInput, ...func(*kinesis.Options)) (*kinesis.DescribeStreamSummaryOutput, error) CreateStream(context.Context, *kinesis.CreateStreamInput, ...func(*kinesis.Options)) (*kinesis.CreateStreamOutput, error) } type Publisher struct { - client KinesisClient + client Client streamLock sync.RWMutex streams map[string]bool @@ -215,7 +216,7 @@ func WithStreamProbleInterval(interval time.Duration) Opt { } } -func New(client *kinesis.Client, opts ...Opt) (*Publisher, error) { +func New(client Client, opts ...Opt) (*Publisher, error) { p := &Publisher{ client: client, streamPattern: "%s", diff --git a/publisher/kinesis/kinesis_test.go b/publisher/kinesis/kinesis_test.go index 8460ed5..3398c1b 100644 --- a/publisher/kinesis/kinesis_test.go +++ b/publisher/kinesis/kinesis_test.go @@ -18,93 +18,90 @@ func TestKinesisProducer_UnitTest(t *testing.T) { Type: "unknown", }, } - t.Run("should return an error if stream existence check fails", func(t *testing.T) { - client := &mockKinesisClient{} - client.On( - "DescribeStreamSummary", - mock.Anything, - &kinesis.DescribeStreamSummaryInput{ - StreamName: aws.String("unknown"), + testCases := []struct { + Desc string + Init func(*mockKinesisClient) + Opts []Opt + ExpectedErr string + }{ + { + Desc: "should return an error if stream existence check fails", + Init: func(client *mockKinesisClient) { + client.On( + "DescribeStreamSummary", + mock.Anything, + &kinesis.DescribeStreamSummaryInput{ + StreamName: aws.String("unknown"), + }, + mock.Anything, + ).Return( + &kinesis.DescribeStreamSummaryOutput{}, + fmt.Errorf("simulated error"), + ).Once() }, - mock.Anything, - ).Return( - &kinesis.DescribeStreamSummaryOutput{}, - fmt.Errorf("simulated error"), - ).Once() - defer client.AssertExpectations(t) - - p, err := New( - nil, // we will override it later - WithStreamAutocreate(true), - ) - if err != nil { - t.Errorf("error constructing client: %v", err) - return - } - p.client = client - - err = p.ProduceBulk(events, "") - assert.Error(t, err, "error when sending message: simulated error") - }) - t.Run("should return an error if stream creation exceeds resource limit", func(t *testing.T) { - client := &mockKinesisClient{} - - client.On( - "DescribeStreamSummary", - mock.Anything, - &kinesis.DescribeStreamSummaryInput{ - StreamName: aws.String("unknown"), + Opts: []Opt{ + WithStreamAutocreate(true), }, - mock.Anything, - ).Return( - &kinesis.DescribeStreamSummaryOutput{}, - &types.ResourceNotFoundException{}, - ).Once() - - client.On("CreateStream", mock.Anything, mock.Anything, mock.Anything). - Return( - &kinesis.CreateStreamOutput{}, - &types.LimitExceededException{ - Message: aws.String("stream limit reached"), - }, - ).Once() - defer client.AssertExpectations(t) - - p, err := New( - nil, // we will override it later - WithStreamAutocreate(true), - ) - if err != nil { - t.Errorf("error constructing client: %v", err) - return - } - p.client = client - - err = p.ProduceBulk(events, "") - assert.Error(t, err, "error when sending messages: LimitExceededException: stream limit reached") - }) - t.Run("should return an error if rate limit is exceeded", func(t *testing.T) { + ExpectedErr: "error when sending message: simulated error", + }, + { + Desc: "should return an error if stream creation exceeds resource limit", + Init: func(client *mockKinesisClient) { + client.On( + "DescribeStreamSummary", + mock.Anything, + &kinesis.DescribeStreamSummaryInput{ + StreamName: aws.String("unknown"), + }, + mock.Anything, + ).Return( + &kinesis.DescribeStreamSummaryOutput{}, + &types.ResourceNotFoundException{}, + ).Once() - client := &mockKinesisClient{} + client.On("CreateStream", mock.Anything, mock.Anything, mock.Anything). + Return( + &kinesis.CreateStreamOutput{}, + &types.LimitExceededException{ + Message: aws.String("stream limit reached"), + }, + ).Once() + }, + Opts: []Opt{ + WithStreamAutocreate(true), + }, + ExpectedErr: "error when sending messages: LimitExceededException: stream limit reached", + }, + { + Desc: "should return an error if rate limit is exceeded", + Init: func(client *mockKinesisClient) { + client.On("PutRecord", mock.Anything, mock.Anything, mock.Anything). + Return( + &kinesis.PutRecordOutput{}, + &types.ProvisionedThroughputExceededException{ + Message: aws.String("put limit exceeded"), + }, + ).Once() + }, + ExpectedErr: "error when sending messages: ProvisionedThroughputExceededException: put limit exceeded", + }, + } + for _, testCase := range testCases { + t.Run(testCase.Desc, func(t *testing.T) { + client := &mockKinesisClient{} + testCase.Init(client) + defer client.AssertExpectations(t) - client.On("PutRecord", mock.Anything, mock.Anything, mock.Anything). - Return( - &kinesis.PutRecordOutput{}, - &types.ProvisionedThroughputExceededException{ - Message: aws.String("put limit exceeded"), - }, - ).Once() - defer client.AssertExpectations(t) + p, err := New(client, testCase.Opts...) + if err != nil { + t.Errorf("error constructing client: %v", err) + return + } - p, err := New(nil) - if err != nil { - t.Errorf("error constructing client: %v", err) - return - } - p.client = client + err = p.ProduceBulk(events, "") + assert.Error(t, err, testCase.ExpectedErr) + }) - err = p.ProduceBulk(events, "") - assert.Error(t, err, "error when sending messages: ProvisionedThroughputExceededException: put limit exceeded") - }) + } }