Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions go.mod
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@ require (
github.com/alibabacloud-go/tea v1.2.2
github.com/alibabacloud-go/tea-utils/v2 v2.0.7
github.com/aliyun/credentials-go v1.3.11
github.com/aws/aws-sdk-go-v2 v1.36.3
github.com/aws/aws-sdk-go-v2 v1.39.2
github.com/aws/aws-sdk-go-v2/config v1.29.10
github.com/aws/aws-sdk-go-v2/credentials v1.17.66
github.com/aws/aws-sdk-go-v2/service/ecr v1.28.6
Expand Down Expand Up @@ -155,7 +155,7 @@ require (
github.com/aws/aws-sdk-go-v2/service/sso v1.25.3 // indirect
github.com/aws/aws-sdk-go-v2/service/ssooidc v1.30.1 // indirect
github.com/aws/aws-sdk-go-v2/service/sts v1.33.18 // indirect
github.com/aws/smithy-go v1.22.2 // indirect
github.com/aws/smithy-go v1.23.0 // indirect
github.com/beorn7/perks v1.0.1 // indirect
github.com/blang/semver v3.5.1+incompatible // indirect
github.com/bshuster-repo/logrus-logstash-hook v1.1.0
Expand Down
8 changes: 4 additions & 4 deletions go.sum
Original file line number Diff line number Diff line change
Expand Up @@ -168,8 +168,8 @@ github.com/asaskevich/govalidator v0.0.0-20230301143203-a9d515a09cc2 h1:DklsrG3d
github.com/asaskevich/govalidator v0.0.0-20230301143203-a9d515a09cc2/go.mod h1:WaHUgvxTVq04UNunO+XhnAqY/wQc+bxr74GqbsZ/Jqw=
github.com/aws/aws-sdk-go v1.55.6 h1:cSg4pvZ3m8dgYcgqB97MrcdjUmZ1BeMYKUxMMB89IPk=
github.com/aws/aws-sdk-go v1.55.6/go.mod h1:eRwEWoyTWFMVYVQzKMNHWP5/RV4xIUGMQfXQHfHkpNU=
github.com/aws/aws-sdk-go-v2 v1.36.3 h1:mJoei2CxPutQVxaATCzDUjcZEjVRdpsiiXi2o38yqWM=
github.com/aws/aws-sdk-go-v2 v1.36.3/go.mod h1:LLXuLpgzEbD766Z5ECcRmi8AzSwfZItDtmABVkRLGzg=
github.com/aws/aws-sdk-go-v2 v1.39.2 h1:EJLg8IdbzgeD7xgvZ+I8M1e0fL0ptn/M47lianzth0I=
github.com/aws/aws-sdk-go-v2 v1.39.2/go.mod h1:sDioUELIUO9Znk23YVmIk86/9DOpkbyyVb1i/gUNFXY=
github.com/aws/aws-sdk-go-v2/config v1.29.10 h1:yNjgjiGBp4GgaJrGythyBXg2wAs+Im9fSWIUwvi1CAc=
github.com/aws/aws-sdk-go-v2/config v1.29.10/go.mod h1:A0mbLXSdtob/2t59n1X0iMkPQ5d+YzYZB4rwu7SZ7aA=
github.com/aws/aws-sdk-go-v2/credentials v1.17.66 h1:aKpEKaTy6n4CEJeYI1MNj97oSDLi4xro3UzQfwf5RWE=
Expand Down Expand Up @@ -198,8 +198,8 @@ github.com/aws/aws-sdk-go-v2/service/ssooidc v1.30.1 h1:hXmVKytPfTy5axZ+fYbR5d0c
github.com/aws/aws-sdk-go-v2/service/ssooidc v1.30.1/go.mod h1:MlYRNmYu/fGPoxBQVvBYr9nyr948aY/WLUvwBMBJubs=
github.com/aws/aws-sdk-go-v2/service/sts v1.33.18 h1:xz7WvTMfSStb9Y8NpCT82FXLNC3QasqBfuAFHY4Pk5g=
github.com/aws/aws-sdk-go-v2/service/sts v1.33.18/go.mod h1:cQnB8CUnxbMU82JvlqjKR2HBOm3fe9pWorWBza6MBJ4=
github.com/aws/smithy-go v1.22.2 h1:6D9hW43xKFrRx/tXXfAlIZc4JI+yQe6snnWcQyxSyLQ=
github.com/aws/smithy-go v1.22.2/go.mod h1:irrKGvNn1InZwb2d7fkIRNucdfwR8R+Ts3wxYa/cJHg=
github.com/aws/smithy-go v1.23.0 h1:8n6I3gXzWJB2DxBDnfxgBaSX6oe0d/t10qGz7OKqMCE=
github.com/aws/smithy-go v1.23.0/go.mod h1:t1ufH5HMublsJYulve2RKmHDC15xu1f26kHCp/HgceI=
github.com/awslabs/amazon-ecr-credential-helper/ecr-login v0.0.0-20231024185945-8841054dbdb8 h1:SoFYaT9UyGkR0+nogNyD/Lj+bsixB+SNuAS4ABlEs6M=
github.com/awslabs/amazon-ecr-credential-helper/ecr-login v0.0.0-20231024185945-8841054dbdb8/go.mod h1:2JF49jcDOrLStIXN/j/K1EKRq8a8R2qRnlZA6/o/c7c=
github.com/beorn7/perks v1.0.1 h1:VlbKKnNfV8bJzeqoa4cOKqO6bYr3WgKZxO8Z16+hsOM=
Expand Down
35 changes: 18 additions & 17 deletions pkg/common/oras/authprovider/aws/awsecrbasic.go
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,6 @@ import (

"github.com/aws/aws-sdk-go-v2/aws"
"github.com/aws/aws-sdk-go-v2/config"
"github.com/aws/aws-sdk-go-v2/credentials/stscreds"
"github.com/aws/aws-sdk-go-v2/service/ecr"
"github.com/aws/aws-sdk-go-v2/service/ecr/types"
"github.com/pkg/errors"
Expand All @@ -48,24 +47,19 @@ const (
awsSessionName string = "ratifyEcrBasicAuth"
)

// init calls Register for AWS IRSA Basic Auth provider
// init calls Register for AWS ECR Basic Auth provider (supports both IRSA and Pod Identity)
func init() {
provider.Register(awsEcrAuthProviderName, &AwsEcrBasicProviderFactory{})
}

// Get ECR auth token from IRSA config
// Get ECR auth token using AWS SDK default credential chain (supports IRSA, Pod Identity, etc.)
func (d *awsEcrBasicAuthProvider) getEcrAuthToken(artifact string) (EcrAuthToken, error) {
region := os.Getenv("AWS_REGION")
roleArn := os.Getenv("AWS_ROLE_ARN")
tokenFilePath := os.Getenv("AWS_WEB_IDENTITY_TOKEN_FILE")
apiOverrideEndpoint := os.Getenv("AWS_API_OVERRIDE_ENDPOINT")
apiOverridePartition := os.Getenv("AWS_API_OVERRIDE_PARTITION")
apiOverrideRegion := os.Getenv("AWS_API_OVERRIDE_REGION")

// Verify IRSA ENV is present
if region == "" || roleArn == "" || tokenFilePath == "" {
return EcrAuthToken{}, fmt.Errorf("required environment variables not set, AWS_REGION: %s, AWS_ROLE_ARN: %s, AWS_WEB_IDENTITY_TOKEN_FILE: %s", region, roleArn, tokenFilePath)
}
logrus.Debug("AWS ECR auth using default credential chain (supports IRSA, Pod Identity, instance profiles, etc.)")

ctx := context.Background()
// TODO: Update to use regional endpoint
Expand All @@ -91,10 +85,7 @@ func (d *awsEcrBasicAuthProvider) getEcrAuthToken(artifact string) (EcrAuthToken
})
// TODO: Update to use regional endpoint
// nolint:staticcheck
cfg, err := config.LoadDefaultConfig(ctx, config.WithEndpointResolverWithOptions(resolver),
config.WithWebIdentityRoleCredentialOptions(func(options *stscreds.WebIdentityRoleOptions) {
options.RoleSessionName = awsSessionName
}))
cfg, err := config.LoadDefaultConfig(ctx, config.WithEndpointResolverWithOptions(resolver))

if err != nil {
return EcrAuthToken{}, fmt.Errorf("failed to load default AWS basic auth config: %w", err)
Expand All @@ -105,11 +96,21 @@ func (d *awsEcrBasicAuthProvider) getEcrAuthToken(artifact string) (EcrAuthToken
if err != nil {
return EcrAuthToken{}, fmt.Errorf("failed to get registry from image: %w", err)
}
region = awsauth.RegionFromRegistry(registry)
if region == "" {

// Derive region from registry if not set via environment variable
derivedRegion := awsauth.RegionFromRegistry(registry)
if derivedRegion == "" {
return EcrAuthToken{}, fmt.Errorf("failed to get region from image")
}

// Use environment variable region if set, otherwise use derived region
if region == "" {
region = derivedRegion
logrus.Debugf("Using region derived from registry: %s", region)
} else {
logrus.Debugf("Using region from AWS_REGION environment variable: %s", region)
}

logrus.Debugf("AWS ECR basic artifact=%s, registry=%s, region=%s", artifact, registry, region)
cfg.Region = region

Expand Down Expand Up @@ -156,12 +157,12 @@ func (d *awsEcrBasicAuthProvider) Enabled(_ context.Context) bool {
}

// Provide returns the credentials for a specified artifact.
// Uses AWS IRSA to retrieve creds from IRSA credential chain
// Uses AWS SDK default credential chain (supports IRSA, Pod Identity, instance profiles, etc.)
func (d *awsEcrBasicAuthProvider) Provide(ctx context.Context, artifact string) (provider.AuthConfig, error) {
logrus.Debugf("artifact = %s", artifact)

if !d.Enabled(ctx) {
return provider.AuthConfig{}, fmt.Errorf("AWS IRSA basic auth provider is not properly enabled")
return provider.AuthConfig{}, fmt.Errorf("AWS ECR auth provider is not properly enabled")
}

registry, err := provider.GetRegistryHostName(artifact)
Expand Down
6 changes: 2 additions & 4 deletions pkg/common/oras/authprovider/aws/awsecrbasic_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,6 @@ package aws
import (
"context"
"encoding/base64"
"os"
"strings"
"testing"
"time"
Expand Down Expand Up @@ -110,9 +109,8 @@ func TestAwsEcrBasicAuthProvider_ProvidesWithHost(t *testing.T) {
func TestAwsEcrBasicAuthProvider_GetAuthTokenWithoutRegion(t *testing.T) {
authProvider := mockAuthProvider()

os.Setenv("AWS_REGION", "placeholder")
os.Setenv("AWS_ROLE_ARN", "placeholder")
os.Setenv("AWS_WEB_IDENTITY_TOKEN_FILE", "placeholder")
// Note: AWS_REGION is optional now - will be derived from registry if not set
// This test verifies that artifacts without region information fail appropriately
_, err := authProvider.getEcrAuthToken(testArtifactWithoutRegion)
if err == nil {
t.Fatalf("expected error: %+v", err)
Expand Down
Loading