diff --git a/go.mod b/go.mod index 4de015e78..e95238002 100644 --- a/go.mod +++ b/go.mod @@ -21,7 +21,7 @@ require ( github.com/alibabacloud-go/tea v1.2.2 github.com/alibabacloud-go/tea-utils/v2 v2.0.7 github.com/aliyun/credentials-go v1.3.11 - github.com/aws/aws-sdk-go-v2 v1.36.3 + github.com/aws/aws-sdk-go-v2 v1.39.2 github.com/aws/aws-sdk-go-v2/config v1.29.10 github.com/aws/aws-sdk-go-v2/credentials v1.17.66 github.com/aws/aws-sdk-go-v2/service/ecr v1.28.6 @@ -155,7 +155,7 @@ require ( github.com/aws/aws-sdk-go-v2/service/sso v1.25.3 // indirect github.com/aws/aws-sdk-go-v2/service/ssooidc v1.30.1 // indirect github.com/aws/aws-sdk-go-v2/service/sts v1.33.18 // indirect - github.com/aws/smithy-go v1.22.2 // indirect + github.com/aws/smithy-go v1.23.0 // indirect github.com/beorn7/perks v1.0.1 // indirect github.com/blang/semver v3.5.1+incompatible // indirect github.com/bshuster-repo/logrus-logstash-hook v1.1.0 diff --git a/go.sum b/go.sum index d698e76e7..86d147fd8 100644 --- a/go.sum +++ b/go.sum @@ -168,8 +168,8 @@ github.com/asaskevich/govalidator v0.0.0-20230301143203-a9d515a09cc2 h1:DklsrG3d github.com/asaskevich/govalidator v0.0.0-20230301143203-a9d515a09cc2/go.mod h1:WaHUgvxTVq04UNunO+XhnAqY/wQc+bxr74GqbsZ/Jqw= github.com/aws/aws-sdk-go v1.55.6 h1:cSg4pvZ3m8dgYcgqB97MrcdjUmZ1BeMYKUxMMB89IPk= github.com/aws/aws-sdk-go v1.55.6/go.mod h1:eRwEWoyTWFMVYVQzKMNHWP5/RV4xIUGMQfXQHfHkpNU= -github.com/aws/aws-sdk-go-v2 v1.36.3 h1:mJoei2CxPutQVxaATCzDUjcZEjVRdpsiiXi2o38yqWM= -github.com/aws/aws-sdk-go-v2 v1.36.3/go.mod h1:LLXuLpgzEbD766Z5ECcRmi8AzSwfZItDtmABVkRLGzg= +github.com/aws/aws-sdk-go-v2 v1.39.2 h1:EJLg8IdbzgeD7xgvZ+I8M1e0fL0ptn/M47lianzth0I= +github.com/aws/aws-sdk-go-v2 v1.39.2/go.mod h1:sDioUELIUO9Znk23YVmIk86/9DOpkbyyVb1i/gUNFXY= github.com/aws/aws-sdk-go-v2/config v1.29.10 h1:yNjgjiGBp4GgaJrGythyBXg2wAs+Im9fSWIUwvi1CAc= github.com/aws/aws-sdk-go-v2/config v1.29.10/go.mod h1:A0mbLXSdtob/2t59n1X0iMkPQ5d+YzYZB4rwu7SZ7aA= github.com/aws/aws-sdk-go-v2/credentials v1.17.66 h1:aKpEKaTy6n4CEJeYI1MNj97oSDLi4xro3UzQfwf5RWE= @@ -198,8 +198,8 @@ github.com/aws/aws-sdk-go-v2/service/ssooidc v1.30.1 h1:hXmVKytPfTy5axZ+fYbR5d0c github.com/aws/aws-sdk-go-v2/service/ssooidc v1.30.1/go.mod h1:MlYRNmYu/fGPoxBQVvBYr9nyr948aY/WLUvwBMBJubs= github.com/aws/aws-sdk-go-v2/service/sts v1.33.18 h1:xz7WvTMfSStb9Y8NpCT82FXLNC3QasqBfuAFHY4Pk5g= github.com/aws/aws-sdk-go-v2/service/sts v1.33.18/go.mod h1:cQnB8CUnxbMU82JvlqjKR2HBOm3fe9pWorWBza6MBJ4= -github.com/aws/smithy-go v1.22.2 h1:6D9hW43xKFrRx/tXXfAlIZc4JI+yQe6snnWcQyxSyLQ= -github.com/aws/smithy-go v1.22.2/go.mod h1:irrKGvNn1InZwb2d7fkIRNucdfwR8R+Ts3wxYa/cJHg= +github.com/aws/smithy-go v1.23.0 h1:8n6I3gXzWJB2DxBDnfxgBaSX6oe0d/t10qGz7OKqMCE= +github.com/aws/smithy-go v1.23.0/go.mod h1:t1ufH5HMublsJYulve2RKmHDC15xu1f26kHCp/HgceI= github.com/awslabs/amazon-ecr-credential-helper/ecr-login v0.0.0-20231024185945-8841054dbdb8 h1:SoFYaT9UyGkR0+nogNyD/Lj+bsixB+SNuAS4ABlEs6M= github.com/awslabs/amazon-ecr-credential-helper/ecr-login v0.0.0-20231024185945-8841054dbdb8/go.mod h1:2JF49jcDOrLStIXN/j/K1EKRq8a8R2qRnlZA6/o/c7c= github.com/beorn7/perks v1.0.1 h1:VlbKKnNfV8bJzeqoa4cOKqO6bYr3WgKZxO8Z16+hsOM= diff --git a/pkg/common/oras/authprovider/aws/awsecrbasic.go b/pkg/common/oras/authprovider/aws/awsecrbasic.go index 3d8866259..83beb40f1 100644 --- a/pkg/common/oras/authprovider/aws/awsecrbasic.go +++ b/pkg/common/oras/authprovider/aws/awsecrbasic.go @@ -24,7 +24,6 @@ import ( "github.com/aws/aws-sdk-go-v2/aws" "github.com/aws/aws-sdk-go-v2/config" - "github.com/aws/aws-sdk-go-v2/credentials/stscreds" "github.com/aws/aws-sdk-go-v2/service/ecr" "github.com/aws/aws-sdk-go-v2/service/ecr/types" "github.com/pkg/errors" @@ -48,24 +47,19 @@ const ( awsSessionName string = "ratifyEcrBasicAuth" ) -// init calls Register for AWS IRSA Basic Auth provider +// init calls Register for AWS ECR Basic Auth provider (supports both IRSA and Pod Identity) func init() { provider.Register(awsEcrAuthProviderName, &AwsEcrBasicProviderFactory{}) } -// Get ECR auth token from IRSA config +// Get ECR auth token using AWS SDK default credential chain (supports IRSA, Pod Identity, etc.) func (d *awsEcrBasicAuthProvider) getEcrAuthToken(artifact string) (EcrAuthToken, error) { region := os.Getenv("AWS_REGION") - roleArn := os.Getenv("AWS_ROLE_ARN") - tokenFilePath := os.Getenv("AWS_WEB_IDENTITY_TOKEN_FILE") apiOverrideEndpoint := os.Getenv("AWS_API_OVERRIDE_ENDPOINT") apiOverridePartition := os.Getenv("AWS_API_OVERRIDE_PARTITION") apiOverrideRegion := os.Getenv("AWS_API_OVERRIDE_REGION") - // Verify IRSA ENV is present - if region == "" || roleArn == "" || tokenFilePath == "" { - return EcrAuthToken{}, fmt.Errorf("required environment variables not set, AWS_REGION: %s, AWS_ROLE_ARN: %s, AWS_WEB_IDENTITY_TOKEN_FILE: %s", region, roleArn, tokenFilePath) - } + logrus.Debug("AWS ECR auth using default credential chain (supports IRSA, Pod Identity, instance profiles, etc.)") ctx := context.Background() // TODO: Update to use regional endpoint @@ -91,10 +85,7 @@ func (d *awsEcrBasicAuthProvider) getEcrAuthToken(artifact string) (EcrAuthToken }) // TODO: Update to use regional endpoint // nolint:staticcheck - cfg, err := config.LoadDefaultConfig(ctx, config.WithEndpointResolverWithOptions(resolver), - config.WithWebIdentityRoleCredentialOptions(func(options *stscreds.WebIdentityRoleOptions) { - options.RoleSessionName = awsSessionName - })) + cfg, err := config.LoadDefaultConfig(ctx, config.WithEndpointResolverWithOptions(resolver)) if err != nil { return EcrAuthToken{}, fmt.Errorf("failed to load default AWS basic auth config: %w", err) @@ -105,11 +96,21 @@ func (d *awsEcrBasicAuthProvider) getEcrAuthToken(artifact string) (EcrAuthToken if err != nil { return EcrAuthToken{}, fmt.Errorf("failed to get registry from image: %w", err) } - region = awsauth.RegionFromRegistry(registry) - if region == "" { + + // Derive region from registry if not set via environment variable + derivedRegion := awsauth.RegionFromRegistry(registry) + if derivedRegion == "" { return EcrAuthToken{}, fmt.Errorf("failed to get region from image") } + // Use environment variable region if set, otherwise use derived region + if region == "" { + region = derivedRegion + logrus.Debugf("Using region derived from registry: %s", region) + } else { + logrus.Debugf("Using region from AWS_REGION environment variable: %s", region) + } + logrus.Debugf("AWS ECR basic artifact=%s, registry=%s, region=%s", artifact, registry, region) cfg.Region = region @@ -156,12 +157,12 @@ func (d *awsEcrBasicAuthProvider) Enabled(_ context.Context) bool { } // Provide returns the credentials for a specified artifact. -// Uses AWS IRSA to retrieve creds from IRSA credential chain +// Uses AWS SDK default credential chain (supports IRSA, Pod Identity, instance profiles, etc.) func (d *awsEcrBasicAuthProvider) Provide(ctx context.Context, artifact string) (provider.AuthConfig, error) { logrus.Debugf("artifact = %s", artifact) if !d.Enabled(ctx) { - return provider.AuthConfig{}, fmt.Errorf("AWS IRSA basic auth provider is not properly enabled") + return provider.AuthConfig{}, fmt.Errorf("AWS ECR auth provider is not properly enabled") } registry, err := provider.GetRegistryHostName(artifact) diff --git a/pkg/common/oras/authprovider/aws/awsecrbasic_test.go b/pkg/common/oras/authprovider/aws/awsecrbasic_test.go index e6ff94dee..43aecee44 100644 --- a/pkg/common/oras/authprovider/aws/awsecrbasic_test.go +++ b/pkg/common/oras/authprovider/aws/awsecrbasic_test.go @@ -18,7 +18,6 @@ package aws import ( "context" "encoding/base64" - "os" "strings" "testing" "time" @@ -110,9 +109,8 @@ func TestAwsEcrBasicAuthProvider_ProvidesWithHost(t *testing.T) { func TestAwsEcrBasicAuthProvider_GetAuthTokenWithoutRegion(t *testing.T) { authProvider := mockAuthProvider() - os.Setenv("AWS_REGION", "placeholder") - os.Setenv("AWS_ROLE_ARN", "placeholder") - os.Setenv("AWS_WEB_IDENTITY_TOKEN_FILE", "placeholder") + // Note: AWS_REGION is optional now - will be derived from registry if not set + // This test verifies that artifacts without region information fail appropriately _, err := authProvider.getEcrAuthToken(testArtifactWithoutRegion) if err == nil { t.Fatalf("expected error: %+v", err)