Skip to content

Commit 2b580dc

Browse files
authored
feat: add IRSA support (#52)
* feat: implement IRSA support for AWS EKS * refactor: use io instead of deprecated ioutils * refactor: Small refactoring, get rid of if else statement * refactor: set 777 to executable to enable k8s users to use the app with uid different than 100 * refactor: set default IRSA client id to aws-signing-proxy * refactor: extract logic for client creation logic to separate methods * docs: Adapt Readme for IRSA
1 parent 14c7282 commit 2b580dc

File tree

7 files changed

+281
-56
lines changed

7 files changed

+281
-56
lines changed

Dockerfile

+3-3
Original file line numberDiff line numberDiff line change
@@ -1,14 +1,14 @@
1-
FROM golang:alpine3.16 AS builder
1+
FROM docker.io/golang:alpine3.16 AS builder
22
COPY . /build
33
WORKDIR /build/cmd/aws-signing-proxy
44
RUN GOOS=linux go build
55

6-
FROM alpine:3.16
6+
FROM docker.io/alpine:3.16
77
RUN apk --no-cache add ca-certificates
88
WORKDIR /app
99
COPY --from=builder /build/cmd/aws-signing-proxy/aws-signing-proxy .
1010

11-
RUN addgroup -S proxy && adduser -S proxy -G proxy && chown -R proxy:proxy /app && chmod 750 /app
11+
RUN addgroup -S proxy && adduser -S proxy -G proxy && chown -R proxy:proxy /app && chmod 777 /app
1212

1313
USER proxy
1414
ENTRYPOINT ["/app/aws-signing-proxy"]

README.md

+18
Original file line numberDiff line numberDiff line change
@@ -14,11 +14,16 @@ Supported AWS credentials:
1414
* Fetching short-lived credentials from AWS via a OAuth2 authorization server
1515
and [OpenID Connect (OIDC)](https://openid.net/connect/)
1616
* Additionally, you can fetch these credentials asynchronously
17+
* Fetching short-lived credentials via AWS [IRSA](https://docs.aws.amazon.com/eks/latest/userguide/iam-roles-for-service-accounts.html) (IAM Roles for Service Accounts)
1718

1819
For ready-to-use binaries have a look at [Releases](https://github.com/idealo/aws-signing-proxy/releases).
1920

2021
Additionally, we provide a [Docker image](https://hub.docker.com/r/idealo/aws-signing-proxy) which can be used as a sidecar in Kubernetes.
2122

23+
24+
## 🎉 Version 2.1.0 Update 🎉
25+
* Support for AWS [IRSA](https://docs.aws.amazon.com/eks/latest/userguide/iam-roles-for-service-accounts.html)
26+
2227
## 🎉 Version 2.0.0 Update 🎉
2328

2429
* Version 2.0.0 comes
@@ -78,6 +83,19 @@ ASP_OPEN_ID_CLIENT_SECRET=someverysecurepassword; \
7883
aws-signing-proxy
7984
```
8085

86+
#### With Credentials via IRSA (IAM Roles for Service Accounts)
87+
88+
Execute the binary with either the required environment variables set or via cli flags:
89+
90+
```
91+
ASP_CREDENTIALS_PROVIDER=irsa; \
92+
ASP_TARGET_URL=https://someAWSServiceSupportingSignedHttpRequests; \
93+
ASP_ROLE_ARN=arn:aws:iam::123456242:role/some-access-role; \
94+
aws-signing-proxy
95+
```
96+
97+
Make sure, your AWS_WEB_IDENTITY_TOKEN_FILE environment variable is set!
98+
8199
#### Adjusting the Circuit Breaker Behaviour
82100

83101
If you want to adjust the built-in authorization server circuit breaker, you can set the following environment variables according to your needs.

cmd/aws-signing-proxy/main.go

+68-50
Original file line numberDiff line numberDiff line change
@@ -4,14 +4,14 @@ import (
44
"flag"
55
"fmt"
66
"github.com/go-co-op/gocron"
7+
"github.com/idealo/aws-signing-proxy/pkg/irsa"
78
. "github.com/idealo/aws-signing-proxy/pkg/logging"
89
"github.com/idealo/aws-signing-proxy/pkg/oidc"
910
"github.com/idealo/aws-signing-proxy/pkg/proxy"
1011
"github.com/idealo/aws-signing-proxy/pkg/vault"
1112
"github.com/kelseyhightower/envconfig"
1213
"github.com/prometheus/client_golang/prometheus/promhttp"
1314
"go.uber.org/zap"
14-
"log"
1515
"net/http"
1616
"net/url"
1717
"os"
@@ -24,15 +24,16 @@ type EnvConfig struct {
2424
MgmtPort int `split_words:"true" default:"8081"`
2525
Service string `default:"es"`
2626
CredentialsProvider string `split_words:"true"`
27-
VaultUrl string `split_words:"true"` // 'https://vaulthost'
28-
VaultAuthToken string `split_words:"true"` // auth-token for accessing Vault
27+
VaultUrl string `split_words:"true"`
28+
VaultAuthToken string `split_words:"true"`
2929
VaultCredentialsPath string `split_words:"true"` // path where aws credentials can be generated/retrieved (e.g: 'aws/creds/my-role')
3030
OpenIdAuthServerUrl string `split_words:"true"`
3131
OpenIdClientId string `split_words:"true"`
3232
OpenIdClientSecret string `split_words:"true"`
3333
AsyncOpenIdCredentialsFetch bool `split_words:"true" default:"false"`
3434
RoleArn string `split_words:"true"`
3535
MetricsPath string `split_words:"true" default:"/status/metrics"`
36+
IrsaClientId string `split_words:"true" default:"aws-signing-proxy"`
3637
}
3738

3839
type Flags struct {
@@ -54,6 +55,7 @@ type Flags struct {
5455
IdleConnTimeout *time.Duration
5556
DialTimeout *time.Duration
5657
MetricsPath *string
58+
IrsaClientId *string
5759
}
5860

5961
func main() {
@@ -71,7 +73,7 @@ func main() {
7173

7274
// Validate target URL
7375
if anyFlagEmpty(*flags.Service, *flags.Target) {
74-
log.Fatal("required parameter target (e.g. foo.eu-central-1.es.amazonaws.com) OR service (e.g. es) missing!")
76+
Logger.Fatal("required parameter target (e.g. foo.eu-central-1.es.amazonaws.com) OR service (e.g. es) missing!")
7577
}
7678
targetURL, err := url.Parse(*flags.Target)
7779
if err != nil {
@@ -87,24 +89,15 @@ func main() {
8789

8890
var client proxy.ReadClient
8991

90-
if *flags.CredentialsProvider == "oidc" {
91-
if anyFlagEmpty(*flags.OpenIdClientId, *flags.OpenIdClientSecret, *flags.OpenIdAuthServerUrl, *flags.RoleArn) {
92-
log.Fatal("Missing some needed flags for OIDC! Either: openIdClientId, openIdClientSecret, openIdAuthServerUrl or roleArn")
93-
} else {
94-
client = newOidcClient(&flags, client, e)
95-
}
96-
} else if *flags.CredentialsProvider == "vault" {
97-
if anyFlagEmpty(*flags.VaultUrl, *flags.VaultPath, *flags.VaultAuthToken) {
98-
Logger.Warn("Disabling vault credentials source due to missing flags/environment variables.")
99-
} else {
100-
client = vault.NewVaultClient().
101-
WithBaseUrl(*flags.VaultUrl).
102-
WithToken(*flags.VaultAuthToken).
103-
ReadFrom(*flags.VaultPath)
104-
Logger.Info("Using Credentials from Vault.", zap.String("vault-url", e.VaultUrl), zap.String("path", e.VaultCredentialsPath))
105-
}
106-
} else {
107-
Logger.Warn("Using static credentials is unsafe. Please consider using some short-living credentials mechanism like Vault or OIDC.")
92+
switch *flags.CredentialsProvider {
93+
case "irsa":
94+
client = newIrsaClient(flags, client, region)
95+
case "oidc":
96+
client = newOidcClient(flags, client, e)
97+
case "vault":
98+
client = newVaultClient(flags, client, e)
99+
default:
100+
Logger.Warn("Using static credentials is unsafe. Please consider using some short-living credentials mechanism like IRSA, OIDC or Vault.")
108101
}
109102

110103
signingProxy := proxy.NewSigningProxy(proxy.Config{
@@ -116,6 +109,7 @@ func main() {
116109
DialTimeout: *flags.DialTimeout,
117110
AuthClient: client,
118111
})
112+
119113
listenString := fmt.Sprintf(":%v", *flags.Port)
120114
mgmtPortString := fmt.Sprintf(":%v", *flags.MgmtPort)
121115
Logger.Info("Listening", zap.String("port", listenString))
@@ -128,36 +122,31 @@ func main() {
128122

129123
}
130124

131-
func parseFlags(flags *Flags, e EnvConfig) {
132-
flags.Target = flag.String("target", e.TargetUrl, "target url to proxy to (e.g. foo.eu-central-1.es.amazonaws.com)")
133-
flags.Port = flag.Int("port", e.Port, "Listening port for proxy (e.g. 8080)")
134-
flags.MgmtPort = flag.Int("mgmt-port", e.MgmtPort, "Management port for proxy (e.g. 8081)")
135-
flags.MetricsPath = flag.String("metrics-path", e.MetricsPath, "")
136-
flags.Service = flag.String("service", e.Service, "AWS Service (e.g. es)")
137-
138-
flags.CredentialsProvider = flag.String("credentials-provider", e.CredentialsProvider, "Either retrieve credentials via OpenID or Vault. Valid values are: oidc, vault")
139-
140-
// Vault
141-
flags.VaultUrl = flag.String("vault-url", e.VaultUrl, "base url of vault (e.g. 'https://foo.vault.invalid')")
142-
flags.VaultPath = flag.String("vault-path", e.VaultCredentialsPath, "path for credentials (e.g. '/some-aws-engine/creds/some-aws-role')")
143-
flags.VaultAuthToken = flag.String("vault-token", e.VaultAuthToken, "token for authenticating with vault (NOTE: use the environment variable ASP_VAULT_AUTH_TOKEN instead)")
144-
145-
// openID Connect
146-
flags.OpenIdAuthServerUrl = flag.String("openid-auth-server-url", e.OpenIdAuthServerUrl, "The authorization server url")
147-
flags.OpenIdClientId = flag.String("openid-client-id", e.OpenIdClientId, "OAuth client id")
148-
flags.OpenIdClientSecret = flag.String("openid-client-secret", e.OpenIdClientSecret, "Oauth client secret")
149-
flags.AsyncOpenIdCredentialsFetch = flag.Bool("async-open-id-creds-fetch", e.AsyncOpenIdCredentialsFetch, "Fetch AWS Credentials via OIDC asynchronously")
150-
flags.RoleArn = flag.String("role-arn", e.RoleArn, "AWS role ARN to assume to")
151-
152-
flags.Region = flag.String("region", os.Getenv("AWS_REGION"), "AWS region for credentials (e.g. eu-central-1)")
153-
flags.FlushInterval = flag.Duration("flush-interval", 0, "non essential: flush interval to flush to the client while copying the response body.")
154-
flags.IdleConnTimeout = flag.Duration("idle-conn-timeout", 90*time.Second, "non essential: the maximum amount of time an idle (keep-alive) connection will remain idle before closing itself. zero means no limit.")
155-
flags.DialTimeout = flag.Duration("dial-timeout", 30*time.Second, "non essential: the maximum amount of time a dial will wait for a connect to complete.")
125+
func newVaultClient(flags Flags, client proxy.ReadClient, e EnvConfig) proxy.ReadClient {
126+
if anyFlagEmpty(*flags.VaultUrl, *flags.VaultPath, *flags.VaultAuthToken) {
127+
Logger.Fatal("Missing some needed flags for using Vault! Either: vaultUrl, vaultPath or vaultAuthToken")
128+
}
129+
Logger.Info("Using Credentials from Vault.", zap.String("vault-url", e.VaultUrl), zap.String("path", e.VaultCredentialsPath))
130+
client = vault.NewVaultClient().
131+
WithBaseUrl(*flags.VaultUrl).
132+
WithToken(*flags.VaultAuthToken).
133+
ReadFrom(*flags.VaultPath)
134+
return client
135+
}
156136

157-
flag.Parse()
137+
func newIrsaClient(flags Flags, client proxy.ReadClient, region string) proxy.ReadClient {
138+
if anyFlagEmpty(*flags.RoleArn) {
139+
zap.S().Fatal("Missing needed role-arn flag for IRSA!")
140+
}
141+
client = irsa.NewIRSAClient(region, *flags.IrsaClientId, *flags.RoleArn)
142+
return client
158143
}
159144

160-
func newOidcClient(flags *Flags, client proxy.ReadClient, e EnvConfig) proxy.ReadClient {
145+
func newOidcClient(flags Flags, client proxy.ReadClient, e EnvConfig) proxy.ReadClient {
146+
if anyFlagEmpty(*flags.OpenIdClientId, *flags.OpenIdClientSecret, *flags.OpenIdAuthServerUrl, *flags.RoleArn) {
147+
zap.S().Fatal("Missing some needed flags for OIDC! Either: openIdClientId, openIdClientSecret, openIdAuthServerUrl or roleArn")
148+
}
149+
161150
var oidcClient oidc.ReadClient
162151
oidcClient = *oidc.NewOIDCClient(*flags.Region).
163152
WithAuthServerUrl(*flags.OpenIdAuthServerUrl).
@@ -186,6 +175,35 @@ func newOidcClient(flags *Flags, client proxy.ReadClient, e EnvConfig) proxy.Rea
186175
return client
187176
}
188177

178+
func parseFlags(flags *Flags, e EnvConfig) {
179+
flags.Target = flag.String("target", e.TargetUrl, "target url to proxy to (e.g. foo.eu-central-1.es.amazonaws.com)")
180+
flags.Port = flag.Int("port", e.Port, "Listening port for proxy (e.g. 8080)")
181+
flags.MgmtPort = flag.Int("mgmt-port", e.MgmtPort, "Management port for proxy (e.g. 8081)")
182+
flags.MetricsPath = flag.String("metrics-path", e.MetricsPath, "")
183+
flags.Service = flag.String("service", e.Service, "AWS Service (e.g. es)")
184+
185+
flags.CredentialsProvider = flag.String("credentials-provider", e.CredentialsProvider, "Either retrieve credentials via IRSA, OpenID Connect or Vault. Valid values are: irsa, oidc, vault. Leave empty if you would like to use static credentials.")
186+
187+
flags.VaultUrl = flag.String("vault-url", e.VaultUrl, "base url of vault (e.g. 'https://foo.vault.invalid')")
188+
flags.VaultPath = flag.String("vault-path", e.VaultCredentialsPath, "path for credentials (e.g. '/some-aws-engine/creds/some-aws-role')")
189+
flags.VaultAuthToken = flag.String("vault-token", e.VaultAuthToken, "token for authenticating with vault (NOTE: use the environment variable ASP_VAULT_AUTH_TOKEN instead)")
190+
191+
flags.OpenIdAuthServerUrl = flag.String("openid-auth-server-url", e.OpenIdAuthServerUrl, "The authorization server url")
192+
flags.OpenIdClientId = flag.String("openid-client-id", e.OpenIdClientId, "OAuth client id")
193+
flags.OpenIdClientSecret = flag.String("openid-client-secret", e.OpenIdClientSecret, "Oauth client secret")
194+
flags.AsyncOpenIdCredentialsFetch = flag.Bool("async-open-id-creds-fetch", e.AsyncOpenIdCredentialsFetch, "Fetch AWS Credentials via OIDC asynchronously")
195+
flags.RoleArn = flag.String("role-arn", e.RoleArn, "AWS role ARN to assume to")
196+
197+
flags.IrsaClientId = flag.String("irsa-client-id", e.IrsaClientId, "IRSA client id")
198+
199+
flags.Region = flag.String("region", os.Getenv("AWS_REGION"), "AWS region for credentials (e.g. eu-central-1)")
200+
flags.FlushInterval = flag.Duration("flush-interval", 0, "non essential: flush interval to flush to the client while copying the response body.")
201+
flags.IdleConnTimeout = flag.Duration("idle-conn-timeout", 90*time.Second, "non essential: the maximum amount of time an idle (keep-alive) connection will remain idle before closing itself. zero means no limit.")
202+
flags.DialTimeout = flag.Duration("dial-timeout", 30*time.Second, "non essential: the maximum amount of time a dial will wait for a connect to complete.")
203+
204+
flag.Parse()
205+
}
206+
189207
func provideMgmtEndpoint(mgmtPort string, metricsPath string) {
190208

191209
http.HandleFunc("/status/health", func(w http.ResponseWriter, request *http.Request) {
@@ -196,7 +214,7 @@ func provideMgmtEndpoint(mgmtPort string, metricsPath string) {
196214

197215
http.Handle(metricsPath, promhttp.Handler())
198216

199-
log.Fatal(http.ListenAndServe(mgmtPort, nil))
217+
zap.S().Fatal(http.ListenAndServe(mgmtPort, nil))
200218
}
201219

202220
func anyFlagEmpty(flags ...string) bool {

go.mod

+1
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@ require (
66
github.com/aws/aws-sdk-go v1.44.152
77
github.com/go-co-op/gocron v1.18.0
88
github.com/kelseyhightower/envconfig v1.3.1-0.20170420212316-202b52d1dba0
9+
github.com/pkg/errors v0.9.1
910
github.com/prometheus/client_golang v1.14.0
1011
github.com/sony/gobreaker v0.5.0
1112
github.com/stretchr/testify v1.8.1

pkg/irsa/irsa.go

+99
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,99 @@
1+
package irsa
2+
3+
import (
4+
"github.com/aws/aws-sdk-go/aws"
5+
"github.com/aws/aws-sdk-go/aws/credentials"
6+
"github.com/aws/aws-sdk-go/aws/session"
7+
"github.com/aws/aws-sdk-go/service/sts"
8+
"github.com/aws/aws-sdk-go/service/sts/stsiface"
9+
. "github.com/idealo/aws-signing-proxy/pkg/logging"
10+
"github.com/idealo/aws-signing-proxy/pkg/proxy"
11+
"go.uber.org/zap"
12+
"os"
13+
"time"
14+
)
15+
16+
var cachedCredentials *sts.Credentials
17+
18+
type ReadClient struct {
19+
stsClient stsiface.STSAPI
20+
clientId string
21+
roleArn string
22+
}
23+
24+
func NewIRSAClient(region string, clientId string, roleArn string) *ReadClient {
25+
return &ReadClient{
26+
stsClient: InitClient(region),
27+
clientId: clientId,
28+
roleArn: roleArn,
29+
}
30+
}
31+
32+
func (c *ReadClient) WithRoleArn(roleArn string) *ReadClient {
33+
c.roleArn = roleArn
34+
return c
35+
}
36+
37+
func (c *ReadClient) retrieveShortLivingCredentialsFromAwsSts(roleArn string, webToken string, roleSessionName string) *sts.Credentials {
38+
identity, err := c.stsClient.AssumeRoleWithWebIdentity(&sts.AssumeRoleWithWebIdentityInput{
39+
RoleArn: &roleArn,
40+
RoleSessionName: &roleSessionName,
41+
WebIdentityToken: &webToken,
42+
})
43+
44+
if err != nil {
45+
Logger.Error("Something went wrong with the STS Client", zap.Error(err))
46+
}
47+
48+
return identity.Credentials
49+
}
50+
51+
func InitClient(region string) stsiface.STSAPI {
52+
sess := session.Must(session.NewSession(&aws.Config{
53+
Region: aws.String(region),
54+
Credentials: credentials.AnonymousCredentials},
55+
))
56+
57+
return sts.New(sess, aws.NewConfig().WithRegion(region))
58+
}
59+
60+
func (c *ReadClient) RefreshCredentials(result interface{}) error {
61+
refreshedCredentials := result.(*proxy.RefreshedCredentials)
62+
63+
err := RetrieveCredentials(c)
64+
if err != nil {
65+
return err
66+
}
67+
stsCredentials := cachedCredentials
68+
69+
refreshedCredentials.ExpiresAt = *stsCredentials.Expiration
70+
refreshedCredentials.Data.AccessKey = *stsCredentials.AccessKeyId
71+
refreshedCredentials.Data.SecretKey = *stsCredentials.SecretAccessKey
72+
refreshedCredentials.Data.SecurityToken = *stsCredentials.SessionToken
73+
74+
return nil
75+
}
76+
77+
func RetrieveCredentials(c *ReadClient) error {
78+
if cachedCredentials == nil || isExpired(cachedCredentials.Expiration) {
79+
80+
tokenFile, ok := os.LookupEnv("AWS_WEB_IDENTITY_TOKEN_FILE")
81+
if !ok {
82+
zap.S().Fatalf("Environment variable 'AWS_WEB_IDENTITY_TOKEN_FILE' is not set!")
83+
}
84+
85+
bytes, err := os.ReadFile(tokenFile)
86+
if err != nil {
87+
return err
88+
}
89+
90+
cachedCredentials = c.retrieveShortLivingCredentialsFromAwsSts(c.roleArn, string(bytes), c.clientId)
91+
Logger.Info("Refreshed short living credentials.")
92+
}
93+
return nil
94+
}
95+
96+
func isExpired(expiration *time.Time) bool {
97+
// subtract 5 minutes from the actual expiration to retrieve every 55 minutes new credentials
98+
return time.Now().After(expiration.Add(-time.Minute * 5))
99+
}

0 commit comments

Comments
 (0)