diff --git a/common/archiver/s3store/aws_credentials.go b/common/archiver/s3store/aws_credentials.go new file mode 100644 index 0000000000..b82f044870 --- /dev/null +++ b/common/archiver/s3store/aws_credentials.go @@ -0,0 +1,76 @@ +// AWS Credential Provider for S3 Archiver + +package s3store + +import ( + "fmt" + "strings" + + "github.com/aws/aws-sdk-go/aws" + "github.com/aws/aws-sdk-go/aws/credentials" + "github.com/aws/aws-sdk-go/aws/session" + "go.temporal.io/server/common/config" +) + +// newS3Credentials creates AWS credentials based on the provided configuration. +// It supports three credential provider types: +// - "static": Uses explicitly provided access key and secret +// - "environment": Reads credentials from environment variables +// - "aws-sdk-default" or empty: Uses AWS SDK default credential chain +// +// Returns nil for "aws-sdk-default" to allow session.NewSession to use the default credential chain. +func newS3Credentials(cfg *config.S3Archiver) (*credentials.Credentials, error) { + // Default to aws-sdk-default if not specified (backward compatibility) + provider := strings.ToLower(cfg.CredentialProvider) + if provider == "" { + provider = "aws-sdk-default" + } + + switch provider { + case "static": + return credentials.NewStaticCredentials( + cfg.Static.AccessKeyID, + cfg.Static.SecretAccessKey, + cfg.Static.Token, + ), nil + + case "environment": + return credentials.NewEnvCredentials(), nil + + case "aws-sdk-default": + // Return nil to let session.NewSession use default credential chain + return nil, nil + + default: + return nil, fmt.Errorf( + "unknown AWS credential provider specified: %q. Accepted options are 'static', 'environment', or 'aws-sdk-default'", + cfg.CredentialProvider, + ) + } +} + +// createS3Session creates an AWS session with the provided configuration and credentials. +func createS3Session(cfg *config.S3Archiver) (*session.Session, error) { + if len(cfg.Region) == 0 { + return nil, errEmptyAwsRegion + } + + creds, err := newS3Credentials(cfg) + if err != nil { + return nil, err + } + + s3Config := &aws.Config{ + Endpoint: cfg.Endpoint, + Region: aws.String(cfg.Region), + S3ForcePathStyle: aws.Bool(cfg.S3ForcePathStyle), + LogLevel: (*aws.LogLevelType)(&cfg.LogLevel), + } + + // Only set credentials if explicitly provided (not aws-sdk-default) + if creds != nil { + s3Config.Credentials = creds + } + + return session.NewSession(s3Config) +} diff --git a/common/archiver/s3store/aws_credentials_test.go b/common/archiver/s3store/aws_credentials_test.go new file mode 100644 index 0000000000..02645924d7 --- /dev/null +++ b/common/archiver/s3store/aws_credentials_test.go @@ -0,0 +1,292 @@ +package s3store + +import ( + "os" + "testing" + + "github.com/aws/aws-sdk-go/aws/credentials" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + "go.temporal.io/server/common/config" +) + +func TestNewS3Credentials_Static(t *testing.T) { + cfg := &config.S3Archiver{ + Region: "us-east-1", + CredentialProvider: "static", + Static: config.S3StaticCredentialProvider{ + AccessKeyID: "test-access-key", + SecretAccessKey: "test-secret-key", + }, + } + + creds, err := newS3Credentials(cfg) + require.NoError(t, err) + require.NotNil(t, creds) + + // Verify credentials can be retrieved + value, err := creds.Get() + require.NoError(t, err) + assert.Equal(t, "test-access-key", value.AccessKeyID) + assert.Equal(t, "test-secret-key", value.SecretAccessKey) + assert.Empty(t, value.SessionToken) +} + +func TestNewS3Credentials_StaticWithToken(t *testing.T) { + cfg := &config.S3Archiver{ + Region: "us-east-1", + CredentialProvider: "static", + Static: config.S3StaticCredentialProvider{ + AccessKeyID: "test-access-key", + SecretAccessKey: "test-secret-key", + Token: "test-session-token", + }, + } + + creds, err := newS3Credentials(cfg) + require.NoError(t, err) + require.NotNil(t, creds) + + // Verify credentials can be retrieved + value, err := creds.Get() + require.NoError(t, err) + assert.Equal(t, "test-access-key", value.AccessKeyID) + assert.Equal(t, "test-secret-key", value.SecretAccessKey) + assert.Equal(t, "test-session-token", value.SessionToken) +} + +func TestNewS3Credentials_Environment(t *testing.T) { + cfg := &config.S3Archiver{ + Region: "us-east-1", + CredentialProvider: "environment", + } + + creds, err := newS3Credentials(cfg) + require.NoError(t, err) + require.NotNil(t, creds) + + // Note: We can't validate the actual credential values without setting env vars, + // but we can verify the credentials object was created + assert.NotNil(t, creds) +} + +func TestNewS3Credentials_AwsSdkDefault(t *testing.T) { + cfg := &config.S3Archiver{ + Region: "us-east-1", + CredentialProvider: "aws-sdk-default", + } + + creds, err := newS3Credentials(cfg) + require.NoError(t, err) + // Should return nil to allow session to use default credential chain + assert.Nil(t, creds) +} + +func TestNewS3Credentials_EmptyProvider(t *testing.T) { + cfg := &config.S3Archiver{ + Region: "us-east-1", + CredentialProvider: "", + } + + creds, err := newS3Credentials(cfg) + require.NoError(t, err) + // Empty provider should default to aws-sdk-default (nil) + assert.Nil(t, creds) +} + +func TestNewS3Credentials_InvalidProvider(t *testing.T) { + cfg := &config.S3Archiver{ + Region: "us-east-1", + CredentialProvider: "invalid-provider", + } + + creds, err := newS3Credentials(cfg) + require.Error(t, err) + assert.Nil(t, creds) + assert.Contains(t, err.Error(), "unknown AWS credential provider") + assert.Contains(t, err.Error(), "invalid-provider") +} + +func TestNewS3Credentials_CaseInsensitive(t *testing.T) { + testCases := []struct { + name string + provider string + expectNil bool + }{ + { + name: "uppercase STATIC", + provider: "STATIC", + expectNil: false, + }, + { + name: "mixed case Static", + provider: "Static", + expectNil: false, + }, + { + name: "uppercase ENVIRONMENT", + provider: "ENVIRONMENT", + expectNil: false, + }, + { + name: "mixed case Environment", + provider: "Environment", + expectNil: false, + }, + { + name: "uppercase AWS-SDK-DEFAULT", + provider: "AWS-SDK-DEFAULT", + expectNil: true, + }, + { + name: "mixed case Aws-Sdk-Default", + provider: "Aws-Sdk-Default", + expectNil: true, + }, + } + + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + cfg := &config.S3Archiver{ + Region: "us-east-1", + CredentialProvider: tc.provider, + Static: config.S3StaticCredentialProvider{ + AccessKeyID: "test-key", + SecretAccessKey: "test-secret", + }, + } + + creds, err := newS3Credentials(cfg) + require.NoError(t, err) + if tc.expectNil { + assert.Nil(t, creds) + } else { + assert.NotNil(t, creds) + } + }) + } +} + +func TestCreateS3Session_Success(t *testing.T) { + cfg := &config.S3Archiver{ + Region: "us-east-1", + CredentialProvider: "static", + Static: config.S3StaticCredentialProvider{ + AccessKeyID: "test-access-key", + SecretAccessKey: "test-secret-key", + }, + } + + sess, err := createS3Session(cfg) + require.NoError(t, err) + require.NotNil(t, sess) + + // Verify session config + assert.Equal(t, "us-east-1", *sess.Config.Region) +} + +func TestCreateS3Session_WithEndpoint(t *testing.T) { + endpoint := "http://localhost:4566" + cfg := &config.S3Archiver{ + Region: "us-east-1", + Endpoint: &endpoint, + S3ForcePathStyle: true, + CredentialProvider: "static", + Static: config.S3StaticCredentialProvider{ + AccessKeyID: "test-access-key", + SecretAccessKey: "test-secret-key", + }, + } + + sess, err := createS3Session(cfg) + require.NoError(t, err) + require.NotNil(t, sess) + + // Verify session config + assert.Equal(t, "us-east-1", *sess.Config.Region) + assert.Equal(t, endpoint, *sess.Config.Endpoint) + assert.True(t, *sess.Config.S3ForcePathStyle) +} + +func TestCreateS3Session_EmptyRegion(t *testing.T) { + cfg := &config.S3Archiver{ + Region: "", + CredentialProvider: "static", + Static: config.S3StaticCredentialProvider{ + AccessKeyID: "test-access-key", + SecretAccessKey: "test-secret-key", + }, + } + + sess, err := createS3Session(cfg) + require.Error(t, err) + assert.Nil(t, sess) + assert.Equal(t, errEmptyAwsRegion, err) +} + +func TestCreateS3Session_InvalidCredentialProvider(t *testing.T) { + cfg := &config.S3Archiver{ + Region: "us-east-1", + CredentialProvider: "invalid", + } + + sess, err := createS3Session(cfg) + require.Error(t, err) + assert.Nil(t, sess) + assert.Contains(t, err.Error(), "unknown AWS credential provider") +} + +func TestCreateS3Session_BackwardCompatibility(t *testing.T) { + // Test that existing configs without credential provider still work + cfg := &config.S3Archiver{ + Region: "us-west-2", + } + + sess, err := createS3Session(cfg) + require.NoError(t, err) + require.NotNil(t, sess) + + // Verify session uses default credential chain + assert.Equal(t, "us-west-2", *sess.Config.Region) + // Credentials should not be explicitly set (using default chain) + assert.NotNil(t, sess.Config.Credentials) +} + +func TestNewS3Credentials_EnvironmentWithRealEnvVars(t *testing.T) { + // Save original env vars + originalAccessKey := os.Getenv("AWS_ACCESS_KEY_ID") + originalSecretKey := os.Getenv("AWS_SECRET_ACCESS_KEY") + defer func() { + // Restore original env vars + if originalAccessKey != "" { + os.Setenv("AWS_ACCESS_KEY_ID", originalAccessKey) + } else { + os.Unsetenv("AWS_ACCESS_KEY_ID") + } + if originalSecretKey != "" { + os.Setenv("AWS_SECRET_ACCESS_KEY", originalSecretKey) + } else { + os.Unsetenv("AWS_SECRET_ACCESS_KEY") + } + }() + + // Set test env vars + os.Setenv("AWS_ACCESS_KEY_ID", "env-test-key") + os.Setenv("AWS_SECRET_ACCESS_KEY", "env-test-secret") + + cfg := &config.S3Archiver{ + Region: "us-east-1", + CredentialProvider: "environment", + } + + creds, err := newS3Credentials(cfg) + require.NoError(t, err) + require.NotNil(t, creds) + + // Verify credentials can be retrieved from environment + value, err := creds.Get() + require.NoError(t, err) + assert.Equal(t, "env-test-key", value.AccessKeyID) + assert.Equal(t, "env-test-secret", value.SecretAccessKey) + assert.Equal(t, credentials.EnvProviderName, value.ProviderName) +} diff --git a/common/archiver/s3store/history_archiver.go b/common/archiver/s3store/history_archiver.go index fa6ef5f1e4..151fd05ecf 100644 --- a/common/archiver/s3store/history_archiver.go +++ b/common/archiver/s3store/history_archiver.go @@ -13,7 +13,6 @@ import ( "github.com/aws/aws-sdk-go/aws" "github.com/aws/aws-sdk-go/aws/awserr" "github.com/aws/aws-sdk-go/aws/request" - "github.com/aws/aws-sdk-go/aws/session" "github.com/aws/aws-sdk-go/service/s3" "github.com/aws/aws-sdk-go/service/s3/s3iface" "go.temporal.io/api/serviceerror" @@ -83,16 +82,7 @@ func newHistoryArchiver( config *config.S3Archiver, historyIterator archiver.HistoryIterator, ) (*historyArchiver, error) { - if len(config.Region) == 0 { - return nil, errEmptyAwsRegion - } - s3Config := &aws.Config{ - Endpoint: config.Endpoint, - Region: aws.String(config.Region), - S3ForcePathStyle: aws.Bool(config.S3ForcePathStyle), - LogLevel: (*aws.LogLevelType)(&config.LogLevel), - } - sess, err := session.NewSession(s3Config) + sess, err := createS3Session(config) if err != nil { return nil, err } diff --git a/common/archiver/s3store/visibility_archiver.go b/common/archiver/s3store/visibility_archiver.go index ae032942a6..ce588d5237 100644 --- a/common/archiver/s3store/visibility_archiver.go +++ b/common/archiver/s3store/visibility_archiver.go @@ -7,7 +7,6 @@ import ( "time" "github.com/aws/aws-sdk-go/aws" - "github.com/aws/aws-sdk-go/aws/session" "github.com/aws/aws-sdk-go/service/s3" "github.com/aws/aws-sdk-go/service/s3/s3iface" "go.temporal.io/api/serviceerror" @@ -66,13 +65,7 @@ func newVisibilityArchiver( logger log.Logger, metricsHandler metrics.Handler, config *config.S3Archiver) (*visibilityArchiver, error) { - s3Config := &aws.Config{ - Endpoint: config.Endpoint, - Region: aws.String(config.Region), - S3ForcePathStyle: aws.Bool(config.S3ForcePathStyle), - LogLevel: (*aws.LogLevelType)(&config.LogLevel), - } - sess, err := session.NewSession(s3Config) + sess, err := createS3Session(config) if err != nil { return nil, err } diff --git a/common/config/config.go b/common/config/config.go index 0ca4baaa6b..d68e049aa8 100644 --- a/common/config/config.go +++ b/common/config/config.go @@ -520,6 +520,25 @@ type ( Endpoint *string `yaml:"endpoint"` S3ForcePathStyle bool `yaml:"s3ForcePathStyle"` LogLevel uint `yaml:"logLevel"` + + // Possible options for CredentialProvider include: + // 1) static (fill out Static Credential Provider) + // 2) environment + // a) AccessKeyID from either AWS_ACCESS_KEY_ID or AWS_ACCESS_KEY environment variable + // b) SecretAccessKey from either AWS_SECRET_ACCESS_KEY or AWS_SECRET_KEY environment variable + // 3) aws-sdk-default (default if not specified) + // a) Follows aws-go-sdk default credential resolution for session.NewSession + CredentialProvider string `yaml:"credentialProvider"` + Static S3StaticCredentialProvider `yaml:"static"` + } + + // S3StaticCredentialProvider represents static AWS credentials for S3 archiver + S3StaticCredentialProvider struct { + AccessKeyID string `yaml:"accessKeyID"` + SecretAccessKey string `yaml:"secretAccessKey"` + + // Token only required for temporary security credentials retrieved via STS. Otherwise, this is optional. + Token string `yaml:"token"` } // PublicClient is the config for internal nodes (history/matching/worker) connecting to