diff --git a/README.md b/README.md index fec2ad2..7ba1253 100644 --- a/README.md +++ b/README.md @@ -226,6 +226,38 @@ func main() { } ``` +### Streaming uploads + +`PutObject` buffers the whole body in memory. For large or unbounded sources +(query exports, log streams), use `PutObjectStream`, which uploads from an +`io.Reader` via the AWS SDK transfer manager without buffering the full payload: + +```go +out, err := s3Client.PutObjectStream(ctx, &client.PutObjectStreamInput{ + Bucket: "my-bucket", + Key: "exports/large.csv", + Body: reader, // any io.Reader + ContentType: "text/csv", + MaxBytes: 500 * 1024 * 1024, // optional: abort past this many bytes +}) +if err != nil { + if errors.Is(err, client.ErrStreamTooLarge) { + // stream exceeded MaxBytes and was aborted + } + log.Fatal(err) +} +log.Printf("uploaded, etag=%s", out.ETag) +``` + +Notes: + +- The per-operation timeout (`S3_TIMEOUT`) is **not** applied to streaming + uploads, since they can legitimately run much longer than a normal request. + Control the deadline through the supplied `context.Context`. +- `MaxBytes` bounds the stream at the library level. The read-only and + size-limit MCP extensions guard the tool layer, not direct library calls; + `PutObjectStream` is currently a library-only capability (no MCP tool). 
+ ### Extensibility Patterns **Middleware** wraps tool execution for cross-cutting concerns: diff --git a/go.mod b/go.mod index b392c91..8a76623 100644 --- a/go.mod +++ b/go.mod @@ -6,6 +6,7 @@ require ( github.com/aws/aws-sdk-go-v2 v1.41.9 github.com/aws/aws-sdk-go-v2/config v1.32.20 github.com/aws/aws-sdk-go-v2/credentials v1.19.19 + github.com/aws/aws-sdk-go-v2/feature/s3/transfermanager v0.2.3 github.com/aws/aws-sdk-go-v2/service/s3 v1.102.2 github.com/modelcontextprotocol/go-sdk v1.6.1 gopkg.in/yaml.v3 v3.0.1 diff --git a/go.sum b/go.sum index fc8de36..99c3c0c 100644 --- a/go.sum +++ b/go.sum @@ -8,6 +8,8 @@ github.com/aws/aws-sdk-go-v2/credentials v1.19.19 h1:yuFzSV1U0aRNYCQGVaTY2zW2M/L github.com/aws/aws-sdk-go-v2/credentials v1.19.19/go.mod h1:7y63L1kGzeoDlJaQ3Z578KrnmfBut96JjvJUzGwR+YE= github.com/aws/aws-sdk-go-v2/feature/ec2/imds v1.18.25 h1:0w6dCiO8iez+YKwRhRBlL1CH/E3GTfdkuzrwj1by8vo= github.com/aws/aws-sdk-go-v2/feature/ec2/imds v1.18.25/go.mod h1:9FDWUothyr5RCRAHc45XOiVCzUR8n/IhCYX+uVqw6vk= +github.com/aws/aws-sdk-go-v2/feature/s3/transfermanager v0.2.3 h1:w5OoDiMN6x53ROmiIImGzmVcxXv2q1GXY+aKV4WAJYM= +github.com/aws/aws-sdk-go-v2/feature/s3/transfermanager v0.2.3/go.mod h1:dAhgYp776bX3LuWvnSCFwQEjNs6fuFg7YXIy5PXcP3Q= github.com/aws/aws-sdk-go-v2/internal/configsources v1.4.25 h1:Uii3frf9ztec/ABM2/FSH9/z7PLzxfpG8h4RpkUFflQ= github.com/aws/aws-sdk-go-v2/internal/configsources v1.4.25/go.mod h1:G6kntsA2GorAxDPbap6xgB2F+amSLUF8GJTi7PUoX44= github.com/aws/aws-sdk-go-v2/internal/endpoints/v2 v2.7.25 h1:r1+/l6m+WaUJF9HISEsNOLHSNj5EXYQxK8VX6Cz9NlA= diff --git a/pkg/client/client.go b/pkg/client/client.go index 290738b..d42d43b 100644 --- a/pkg/client/client.go +++ b/pkg/client/client.go @@ -11,6 +11,7 @@ import ( "github.com/aws/aws-sdk-go-v2/aws" "github.com/aws/aws-sdk-go-v2/config" "github.com/aws/aws-sdk-go-v2/credentials" + "github.com/aws/aws-sdk-go-v2/feature/s3/transfermanager" "github.com/aws/aws-sdk-go-v2/service/s3" 
"github.com/aws/aws-sdk-go-v2/service/s3/types" ) @@ -19,6 +20,7 @@ import ( type Client struct { s3Client S3API presignClient PresignAPI + uploader ObjectUploader config *Config connectionName string } @@ -186,9 +188,14 @@ func New(ctx context.Context, cfg *Config) (*Client, error) { // Create presign client presignClient := s3.NewPresignClient(s3Client) + // Create the streaming/multipart uploader. It shares the same underlying + // S3 client so it honors the configured endpoint, credentials, and region. + uploader := transfermanager.New(s3Client) + return &Client{ s3Client: s3Client, presignClient: presignClient, + uploader: uploader, config: cfg.Clone(), connectionName: cfg.Name, }, nil diff --git a/pkg/client/mock_test.go b/pkg/client/mock_test.go index 7243962..1943688 100644 --- a/pkg/client/mock_test.go +++ b/pkg/client/mock_test.go @@ -7,6 +7,7 @@ import ( "time" v4 "github.com/aws/aws-sdk-go-v2/aws/signer/v4" + "github.com/aws/aws-sdk-go-v2/feature/s3/transfermanager" "github.com/aws/aws-sdk-go-v2/service/s3" ) @@ -98,6 +99,23 @@ func (m *mockPresignAPI) PresignPutObject(ctx context.Context, params *s3.PutObj }, nil } +// mockUploader is a mock implementation of ObjectUploader for testing the +// streaming upload path. +type mockUploader struct { + uploadObjectFunc func( + ctx context.Context, input *transfermanager.UploadObjectInput, opts ...func(*transfermanager.Options), + ) (*transfermanager.UploadObjectOutput, error) +} + +func (m *mockUploader) UploadObject( + ctx context.Context, input *transfermanager.UploadObjectInput, opts ...func(*transfermanager.Options), +) (*transfermanager.UploadObjectOutput, error) { + if m.uploadObjectFunc != nil { + return m.uploadObjectFunc(ctx, input, opts...) + } + return &transfermanager.UploadObjectOutput{}, nil +} + // newMockClient creates a Client with mock S3 and presign APIs for testing. 
func newMockClient(s3api *mockS3API, presignAPI *mockPresignAPI) *Client { if s3api == nil { diff --git a/pkg/client/s3api.go b/pkg/client/s3api.go index 236ec7b..e5718d5 100644 --- a/pkg/client/s3api.go +++ b/pkg/client/s3api.go @@ -4,6 +4,7 @@ import ( "context" v4 "github.com/aws/aws-sdk-go-v2/aws/signer/v4" + "github.com/aws/aws-sdk-go-v2/feature/s3/transfermanager" "github.com/aws/aws-sdk-go-v2/service/s3" ) @@ -25,8 +26,18 @@ type PresignAPI interface { PresignPutObject(ctx context.Context, params *s3.PutObjectInput, optFns ...func(*s3.PresignOptions)) (*v4.PresignedHTTPRequest, error) } +// ObjectUploader abstracts a streaming/multipart upload. It is satisfied by +// *transfermanager.Client from the AWS SDK and is defined here, at the consumer, +// so the streaming path can be mocked in unit tests. +type ObjectUploader interface { + UploadObject( + ctx context.Context, input *transfermanager.UploadObjectInput, opts ...func(*transfermanager.Options), + ) (*transfermanager.UploadObjectOutput, error) +} + // Compile-time interface checks. var ( - _ S3API = (*s3.Client)(nil) - _ PresignAPI = (*s3.PresignClient)(nil) + _ S3API = (*s3.Client)(nil) + _ PresignAPI = (*s3.PresignClient)(nil) + _ ObjectUploader = (*transfermanager.Client)(nil) ) diff --git a/pkg/client/stream.go b/pkg/client/stream.go new file mode 100644 index 0000000..bb33917 --- /dev/null +++ b/pkg/client/stream.go @@ -0,0 +1,103 @@ +package client + +import ( + "context" + "errors" + "fmt" + "io" + + "github.com/aws/aws-sdk-go-v2/aws" + "github.com/aws/aws-sdk-go-v2/feature/s3/transfermanager" +) + +// ErrStreamTooLarge indicates that a streaming upload was aborted because the +// body exceeded the caller-supplied size limit. Callers can test for it with +// errors.Is. +var ErrStreamTooLarge = errors.New("stream exceeds maximum allowed size") + +// PutObjectStreamInput contains the parameters for a streaming/multipart upload. 
+//
+// Unlike PutObjectInput, the body is an io.Reader rather than a []byte, so the
+// payload is never fully buffered in memory. This makes it suitable for large
+// or unbounded sources such as query exports.
+type PutObjectStreamInput struct {
+	// Bucket is the name of the destination bucket.
+	Bucket string
+	// Key is the destination object key.
+	Key string
+	// Body is the source of the object data. It is required: PutObjectStream
+	// returns an error when Body is nil.
+	Body io.Reader
+	// ContentType is forwarded to the uploader only when it is non-empty.
+	ContentType string
+	// Metadata is forwarded to the uploader only when it has at least one
+	// entry. The map is handed over as-is, without being copied.
+	Metadata map[string]string
+
+	// MaxBytes, when greater than zero, aborts the upload once more than
+	// MaxBytes have been read from Body, returning an error that wraps
+	// ErrStreamTooLarge. A value of zero or less means no limit is enforced
+	// here.
+	MaxBytes int64
+}
+
+// PutObjectStream uploads an object from an io.Reader using the AWS SDK transfer
+// manager, which splits the body into parts and uploads them without buffering
+// the full payload in memory.
+//
+// Unlike the buffered operations on Client, the per-operation timeout
+// (S3_TIMEOUT) is intentionally NOT applied here: a streaming upload of a large
+// object can legitimately run far longer than an ordinary request. Callers
+// control the deadline through ctx.
+//
+// Like the other write methods on Client, PutObjectStream performs the upload
+// directly; the read-only and size-limit MCP extensions guard the tool layer,
+// not direct library calls. Use MaxBytes to bound a stream at the library level.
+//
+// PutObjectStream returns an error before anything is read from Body when
+// input is nil, when Body, Bucket, or Key is missing, or when the client has no
+// uploader. When MaxBytes is set and the stream ran past it, the returned error
+// always satisfies errors.Is(err, ErrStreamTooLarge), whether or not the
+// uploader preserved the reader's error chain. On success it returns the ETag
+// and version ID reported by the uploader.
+func (c *Client) PutObjectStream(ctx context.Context, input *PutObjectStreamInput) (*PutObjectOutput, error) {
+	if input == nil {
+		return nil, fmt.Errorf("put object stream: input is required")
+	}
+	if input.Body == nil {
+		return nil, fmt.Errorf("put object stream: body is required")
+	}
+	if c.uploader == nil {
+		return nil, fmt.Errorf("put object stream: uploader is not configured")
+	}
+	// Reject an empty bucket or key up front so that a request which cannot
+	// succeed never starts consuming the caller's (possibly one-shot) reader.
+	if input.Bucket == "" {
+		return nil, fmt.Errorf("put object stream: bucket is required")
+	}
+	if input.Key == "" {
+		return nil, fmt.Errorf("put object stream: key is required")
+	}
+
+	// Keep a handle on the limit reader so its byte counters can be inspected
+	// after the upload returns.
+	body := input.Body
+	var limited *limitReader
+	if input.MaxBytes > 0 {
+		limited = &limitReader{r: input.Body, max: input.MaxBytes}
+		body = limited
+	}
+
+	uploadInput := &transfermanager.UploadObjectInput{
+		Bucket: aws.String(input.Bucket),
+		Key:    aws.String(input.Key),
+		Body:   body,
+	}
+	if input.ContentType != "" {
+		uploadInput.ContentType = aws.String(input.ContentType)
+	}
+	if len(input.Metadata) > 0 {
+		uploadInput.Metadata = input.Metadata
+	}
+
+	output, err := c.uploader.UploadObject(ctx, uploadInput)
+	if err != nil {
+		// The size-limit contract must not depend on how the uploader wraps
+		// reader errors: if the limit tripped, make sure ErrStreamTooLarge is
+		// in the chain. UploadObject has returned, so the uploader is assumed
+		// to have stopped reading from body.
+		if limited != nil && limited.read > limited.max && !errors.Is(err, ErrStreamTooLarge) {
+			return nil, fmt.Errorf("failed to stream object: %w: %w", ErrStreamTooLarge, err)
+		}
+		return nil, fmt.Errorf("failed to stream object: %w", err)
+	}
+	// Guard against an ObjectUploader that reports success without a result;
+	// dereferencing a nil output below would panic.
+	if output == nil {
+		return nil, fmt.Errorf("failed to stream object: uploader returned no output")
+	}
+
+	return &PutObjectOutput{
+		ETag:      aws.ToString(output.ETag),
+		VersionID: aws.ToString(output.VersionID),
+	}, nil
+}
+
+// limitReader wraps an io.Reader and returns an error wrapping ErrStreamTooLarge
+// once more than max bytes have been read. It enforces an upper bound on a
+// stream whose length is not known in advance.
+type limitReader struct {
+	// r is the wrapped source.
+	r io.Reader
+	// max is the largest number of bytes that may pass through.
+	max int64
+	// read is the raw number of bytes obtained from r so far. It is not
+	// clamped, so it stays above max once the limit has tripped.
+	read int64
+}
+
+// Read reads from the wrapped reader and counts the bytes. While the running
+// total stays at or below max it behaves exactly like the wrapped reader. On
+// the read that crosses the limit it reports only the bytes that still fit
+// under max, together with an error wrapping ErrStreamTooLarge, so a consumer
+// never receives more than max bytes. Every later call returns the same kind of
+// error without touching the source again.
+func (l *limitReader) Read(p []byte) (int, error) {
+	if l.read > l.max {
+		// Sticky failure: do not keep draining the source after the limit.
+		return 0, l.tooLarge()
+	}
+	n, err := l.r.Read(p)
+	l.read += int64(n)
+	if l.read > l.max {
+		// The total before this call was at most max, so the overshoot is
+		// between 1 and n; trimming it leaves exactly the remaining allowance.
+		n -= int(l.read - l.max)
+		return n, l.tooLarge()
+	}
+	return n, err
+}
+
+// tooLarge builds the limit-exceeded error, which wraps ErrStreamTooLarge and
+// reports the raw byte count alongside the configured maximum.
+func (l *limitReader) tooLarge() error {
+	return fmt.Errorf("read %d bytes: %w of %d bytes", l.read, ErrStreamTooLarge, l.max)
+}
diff --git a/pkg/client/stream_test.go b/pkg/client/stream_test.go
new file mode 100644
index 0000000..ec3e444
--- /dev/null
+++ b/pkg/client/stream_test.go
@@ -0,0 +1,194 @@
+package client
+
+import (
+	"context"
+	"errors"
+	"io"
+	"strings"
+	"testing"
+
+	"github.com/aws/aws-sdk-go-v2/aws"
+	"github.com/aws/aws-sdk-go-v2/feature/s3/transfermanager"
+)
+
+// newStreamClient builds a Client wired with a mock uploader for streaming tests.
+func newStreamClient(up *mockUploader) *Client {
+	c := newMockClient(nil, nil)
+	c.uploader = up
+	return c
+}
+
+func TestClient_PutObjectStream_Success(t *testing.T) {
+	const payload = "streamed content"
+
+	var (
+		gotBucket, gotKey, gotContentType string
+		gotMetadata                       map[string]string
+		gotBody                           string
+	)
+
+	up := &mockUploader{
+		uploadObjectFunc: func(
+			_ context.Context, input *transfermanager.UploadObjectInput, _ ...func(*transfermanager.Options),
+		) (*transfermanager.UploadObjectOutput, error) {
+			gotBucket = aws.ToString(input.Bucket)
+			gotKey = aws.ToString(input.Key)
+			gotContentType = aws.ToString(input.ContentType)
+			gotMetadata = input.Metadata
+			b, err := io.ReadAll(input.Body)
+			if err != nil {
+				return nil, err
+			}
+			gotBody = string(b)
+			return &transfermanager.UploadObjectOutput{
+				ETag:      aws.String("\"streametag\""),
+				VersionID: aws.String("v9"),
+			}, nil
+		},
+	}
+
+	result, err := newStreamClient(up).PutObjectStream(context.Background(), &PutObjectStreamInput{
+		Bucket:      "my-bucket",
+		Key:         "export.csv",
+		Body:        strings.NewReader(payload),
+		ContentType: "text/csv",
+		Metadata:    map[string]string{"author": "test"},
+	})
+	if err != nil {
+		t.Fatalf("unexpected error: %v", err)
+	}
+
+	if gotBucket != "my-bucket" {
+		t.Errorf("bucket: got %q, want my-bucket", gotBucket)
+	}
+	if gotKey != "export.csv" {
+		t.Errorf("key: got %q, want export.csv", gotKey)
+	}
+	if gotContentType != "text/csv" {
+		t.Errorf("content type: got %q, want text/csv", gotContentType)
+	}
+	if gotMetadata["author"] != "test" {
+		t.Errorf("metadata author: got %q, want test", gotMetadata["author"])
+	}
+	if gotBody != payload {
+		t.Errorf("body: got %q, want %q", gotBody, payload)
+	}
+	if result.ETag != "\"streametag\"" {
+		t.Errorf("etag: got %q, want '\"streametag\"'", result.ETag)
+	}
+	if result.VersionID != "v9" {
+		t.Errorf("version: got %q, want v9", result.VersionID)
+	}
+}
+
+func TestClient_PutObjectStream_Validation(t *testing.T) {
+	tests := []struct {
+		name  string
+		setup func() (*Client, *PutObjectStreamInput)
+	}{
+		{
+			name: "nil input",
+			setup: func() (*Client, *PutObjectStreamInput) {
+				return newStreamClient(&mockUploader{}), nil
+			},
+		},
+		{
+			name: "nil body",
+			setup: func() (*Client, *PutObjectStreamInput) {
+				return newStreamClient(&mockUploader{}), &PutObjectStreamInput{Bucket: "b", Key: "k"}
+			},
+		},
+		{
+			name: "uploader not configured",
+			setup: func() (*Client, *PutObjectStreamInput) {
+				c := newMockClient(nil, nil) // no uploader wired
+				return c, &PutObjectStreamInput{Bucket: "b", Key: "k", Body: strings.NewReader("x")}
+			},
+		},
+	}
+
+	for _, tt := range tests {
+		t.Run(tt.name, func(t *testing.T) {
+			c, in := tt.setup()
+			if _, err := c.PutObjectStream(context.Background(), in); err == nil {
+				t.Fatal("expected error, got nil")
+			}
+		})
+	}
+}
+
+func TestClient_PutObjectStream_UploaderError(t *testing.T) {
+	up := &mockUploader{
+		uploadObjectFunc: func(
+			_ context.Context, _ *transfermanager.UploadObjectInput, _ ...func(*transfermanager.Options),
+		) (*transfermanager.UploadObjectOutput, error) {
+			return nil, errors.New("access denied")
+		},
+	}
+
+	_, err := newStreamClient(up).PutObjectStream(context.Background(), &PutObjectStreamInput{
+		Bucket: "b", Key: "k", Body: strings.NewReader("data"),
+	})
+	if err == nil {
+		t.Fatal("expected error, got nil")
+	}
+	if !strings.Contains(err.Error(), "access denied") {
+		t.Errorf("error should wrap underlying cause, got: %v", err)
+	}
+}
+
+func TestClient_PutObjectStream_MaxBytes(t *testing.T) {
+	// The uploader drains the body so the limit reader is exercised.
+	drain := &mockUploader{
+		uploadObjectFunc: func(
+			_ context.Context, input *transfermanager.UploadObjectInput, _ ...func(*transfermanager.Options),
+		) (*transfermanager.UploadObjectOutput, error) {
+			if _, err := io.ReadAll(input.Body); err != nil {
+				return nil, err
+			}
+			return &transfermanager.UploadObjectOutput{}, nil
+		},
+	}
+
+	t.Run("under limit succeeds", func(t *testing.T) {
+		_, err := newStreamClient(drain).PutObjectStream(context.Background(), &PutObjectStreamInput{
+			Bucket: "b", Key: "k", Body: strings.NewReader("12345"), MaxBytes: 10,
+		})
+		if err != nil {
+			t.Fatalf("unexpected error: %v", err)
+		}
+	})
+
+	t.Run("over limit aborts", func(t *testing.T) {
+		_, err := newStreamClient(drain).PutObjectStream(context.Background(), &PutObjectStreamInput{
+			Bucket: "b", Key: "k", Body: strings.NewReader("this body is too large"), MaxBytes: 4,
+		})
+		if err == nil {
+			t.Fatal("expected error, got nil")
+		}
+		if !errors.Is(err, ErrStreamTooLarge) {
+			t.Errorf("expected ErrStreamTooLarge, got: %v", err)
+		}
+	})
+}
+
+func TestLimitReader(t *testing.T) {
+	t.Run("passes through under limit", func(t *testing.T) {
+		lr := &limitReader{r: strings.NewReader("hello"), max: 5}
+		b, err := io.ReadAll(lr)
+		if err != nil {
+			t.Fatalf("unexpected error: %v", err)
+		}
+		if string(b) != "hello" {
+			t.Errorf("got %q, want hello", string(b))
+		}
+	})
+
+	t.Run("errors over limit", func(t *testing.T) {
+		lr := &limitReader{r: strings.NewReader("hello world"), max: 5}
+		_, err := io.ReadAll(lr)
+		if !errors.Is(err, ErrStreamTooLarge) {
+			t.Errorf("expected ErrStreamTooLarge, got: %v", err)
+		}
+	})
+}