From 5babfa7d6b893cd987a166966d0bf79c03c6f4db Mon Sep 17 00:00:00 2001 From: Anand Rajagopal Date: Tue, 6 May 2025 21:32:41 +0000 Subject: [PATCH 1/2] Receivers initial commit Signed-off-by: Anand Rajagopal --- cmd/alertmanager/main.go | 25 ++- config/notifiers.go | 39 ++--- config/receiver/receiver.go | 7 +- go.mod | 14 ++ go.sum | 28 ++++ notify/pagerduty/pagerduty.go | 42 +++-- secrets/generic_secret.go | 32 ++++ secrets/providers/aws_secrets_manager.go | 185 +++++++++++++++++++++++ secrets/secrets_provider.go | 118 +++++++++++++++ 9 files changed, 453 insertions(+), 37 deletions(-) create mode 100644 secrets/generic_secret.go create mode 100644 secrets/providers/aws_secrets_manager.go create mode 100644 secrets/secrets_provider.go diff --git a/cmd/alertmanager/main.go b/cmd/alertmanager/main.go index 2b918061f4..c394673060 100644 --- a/cmd/alertmanager/main.go +++ b/cmd/alertmanager/main.go @@ -17,6 +17,8 @@ import ( "context" "errors" "fmt" + "github.com/prometheus/alertmanager/secrets" + "github.com/prometheus/alertmanager/secrets/providers" "log/slog" "net" "net/http" @@ -158,10 +160,10 @@ func run() int { httpTimeout = kingpin.Flag("web.timeout", "Timeout for HTTP requests. If negative or zero, no timeout is set.").Default("0").Duration() memlimitRatio = kingpin.Flag("auto-gomemlimit.ratio", "The ratio of reserved GOMEMLIMIT memory to the detected maximum container or system memory. The value must be greater than 0 and less than or equal to 1."). - Default("0.9").Float64() + Default("0.9").Float64() clusterBindAddr = kingpin.Flag("cluster.listen-address", "Listen address for cluster. Set to empty string to disable HA mode."). - Default(defaultClusterAddr).String() + Default(defaultClusterAddr).String() clusterAdvertiseAddr = kingpin.Flag("cluster.advertise-address", "Explicit address to advertise in cluster.").String() peers = kingpin.Flag("cluster.peer", "Initial peers (may be repeated).").Strings() peerTimeout = kingpin.Flag("cluster.peer-timeout", "Time to wait between peers to send notifications.").Default("15s").Duration() @@ -402,8 +404,9 @@ func run() int { } var ( - inhibitor *inhibit.Inhibitor - tmpl *template.Template + inhibitor *inhibit.Inhibitor + tmpl *template.Template + secretsProviderRegistry *secrets.SecretsProviderRegistry ) dispMetrics := dispatch.NewDispatcherMetrics(false, prometheus.DefaultRegisterer) @@ -414,6 +417,9 @@ func run() int { prometheus.DefaultRegisterer, configLogger, ) + defer func() { + secretsProviderRegistry.Stop() + }() configCoordinator.Subscribe(func(conf *config.Config) error { tmpl, err = template.FromGlobs(conf.Templates) if err != nil { @@ -428,6 +434,10 @@ func run() int { activeReceivers[r.RouteOpts.Receiver] = struct{}{} }) + spRegistry := secrets.NewSecretsProviderRegistry(logger, prometheus.NewRegistry()) + // currently only one secrets providers is supported + spRegistry.Register(providers.AWSSecretsManagerSecretProviderDiscoveryConfig{}) + spRegistry.Init() // Build the map of receiver to integrations. receivers := make(map[string][]notify.Integration, len(activeReceivers)) var integrationsNum int @@ -437,7 +447,7 @@ func run() int { configLogger.Info("skipping creation of receiver not referenced by any route", "receiver", rcv.Name) continue } - integrations, err := receiver.BuildReceiverIntegrations(rcv, tmpl, logger) + integrations, err := receiver.BuildReceiverIntegrations(rcv, tmpl, logger, spRegistry) if err != nil { return err } @@ -460,10 +470,13 @@ func run() int { inhibitor.Stop() disp.Stop() + if secretsProviderRegistry != nil { + secretsProviderRegistry.Stop() + } inhibitor = inhibit.NewInhibitor(alerts, conf.InhibitRules, marker, logger) silencer := silence.NewSilencer(silences, marker, logger) - + secretsProviderRegistry = spRegistry // An interface value that holds a nil concrete value is non-nil. // Therefore we explicly pass an empty interface, to detect if the // cluster is not enabled in notify. diff --git a/config/notifiers.go b/config/notifiers.go index 87f806aa27..ff3a0d1394 100644 --- a/config/notifiers.go +++ b/config/notifiers.go @@ -16,6 +16,7 @@ package config import ( "errors" "fmt" + "github.com/prometheus/alertmanager/secrets" "net/textproto" "regexp" "strings" @@ -328,22 +329,22 @@ type PagerdutyConfig struct { HTTPConfig *commoncfg.HTTPClientConfig `yaml:"http_config,omitempty" json:"http_config,omitempty"` - ServiceKey Secret `yaml:"service_key,omitempty" json:"service_key,omitempty"` - ServiceKeyFile string `yaml:"service_key_file,omitempty" json:"service_key_file,omitempty"` - RoutingKey Secret `yaml:"routing_key,omitempty" json:"routing_key,omitempty"` - RoutingKeyFile string `yaml:"routing_key_file,omitempty" json:"routing_key_file,omitempty"` - URL *URL `yaml:"url,omitempty" json:"url,omitempty"` - Client string `yaml:"client,omitempty" json:"client,omitempty"` - ClientURL string `yaml:"client_url,omitempty" json:"client_url,omitempty"` - Description string `yaml:"description,omitempty" json:"description,omitempty"` - Details map[string]string `yaml:"details,omitempty" json:"details,omitempty"` - Images []PagerdutyImage `yaml:"images,omitempty" json:"images,omitempty"` - Links []PagerdutyLink `yaml:"links,omitempty" json:"links,omitempty"` - Source string `yaml:"source,omitempty" json:"source,omitempty"` - Severity string `yaml:"severity,omitempty" json:"severity,omitempty"` - Class string `yaml:"class,omitempty" json:"class,omitempty"` - Component string `yaml:"component,omitempty" json:"component,omitempty"` - Group string `yaml:"group,omitempty" json:"group,omitempty"` + ServiceKey *secrets.GenericSecret `yaml:"service_key,omitempty" json:"service_key,omitempty"` + ServiceKeyFile string `yaml:"service_key_file,omitempty" json:"service_key_file,omitempty"` + RoutingKey *secrets.GenericSecret `yaml:"routing_key,omitempty" json:"routing_key,omitempty"` + RoutingKeyFile string `yaml:"routing_key_file,omitempty" json:"routing_key_file,omitempty"` + URL *URL `yaml:"url,omitempty" json:"url,omitempty"` + Client string `yaml:"client,omitempty" json:"client,omitempty"` + ClientURL string `yaml:"client_url,omitempty" json:"client_url,omitempty"` + Description string `yaml:"description,omitempty" json:"description,omitempty"` + Details map[string]string `yaml:"details,omitempty" json:"details,omitempty"` + Images []PagerdutyImage `yaml:"images,omitempty" json:"images,omitempty"` + Links []PagerdutyLink `yaml:"links,omitempty" json:"links,omitempty"` + Source string `yaml:"source,omitempty" json:"source,omitempty"` + Severity string `yaml:"severity,omitempty" json:"severity,omitempty"` + Class string `yaml:"class,omitempty" json:"class,omitempty"` + Component string `yaml:"component,omitempty" json:"component,omitempty"` + Group string `yaml:"group,omitempty" json:"group,omitempty"` } // PagerdutyLink is a link. @@ -366,13 +367,13 @@ func (c *PagerdutyConfig) UnmarshalYAML(unmarshal func(interface{}) error) error if err := unmarshal((*plain)(c)); err != nil { return err } - if c.RoutingKey == "" && c.ServiceKey == "" && c.RoutingKeyFile == "" && c.ServiceKeyFile == "" { + if c.RoutingKey == nil && c.ServiceKey == nil && c.RoutingKeyFile == "" && c.ServiceKeyFile == "" { return errors.New("missing service or routing key in PagerDuty config") } - if len(c.RoutingKey) > 0 && len(c.RoutingKeyFile) > 0 { + if c.RoutingKey != nil && len(c.RoutingKeyFile) > 0 { return errors.New("at most one of routing_key & routing_key_file must be configured") } - if len(c.ServiceKey) > 0 && len(c.ServiceKeyFile) > 0 { + if c.ServiceKey != nil && len(c.ServiceKeyFile) > 0 { return errors.New("at most one of service_key & service_key_file must be configured") } if c.Details == nil { diff --git a/config/receiver/receiver.go b/config/receiver/receiver.go index d92a19a4c5..33a85850a3 100644 --- a/config/receiver/receiver.go +++ b/config/receiver/receiver.go @@ -14,6 +14,7 @@ package receiver import ( + "github.com/prometheus/alertmanager/secrets" "log/slog" commoncfg "github.com/prometheus/common/config" @@ -43,7 +44,7 @@ import ( // BuildReceiverIntegrations builds a list of integration notifiers off of a // receiver config. -func BuildReceiverIntegrations(nc config.Receiver, tmpl *template.Template, logger *slog.Logger, httpOpts ...commoncfg.HTTPClientOption) ([]notify.Integration, error) { +func BuildReceiverIntegrations(nc config.Receiver, tmpl *template.Template, logger *slog.Logger, spRegistry *secrets.SecretsProviderRegistry, httpOpts ...commoncfg.HTTPClientOption) ([]notify.Integration, error) { if logger == nil { logger = promslog.NewNopLogger() } @@ -68,7 +69,9 @@ func BuildReceiverIntegrations(nc config.Receiver, tmpl *template.Template, logg add("email", i, c, func(l *slog.Logger) (notify.Notifier, error) { return email.New(c, tmpl, l), nil }) } for i, c := range nc.PagerdutyConfigs { - add("pagerduty", i, c, func(l *slog.Logger) (notify.Notifier, error) { return pagerduty.New(c, tmpl, l, httpOpts...) }) + add("pagerduty", i, c, func(l *slog.Logger) (notify.Notifier, error) { + return pagerduty.New(c, tmpl, l, spRegistry, httpOpts...) + }) } for i, c := range nc.OpsGenieConfigs { add("opsgenie", i, c, func(l *slog.Logger) (notify.Notifier, error) { return opsgenie.New(c, tmpl, l, httpOpts...) }) diff --git a/go.mod b/go.mod index 886f16f31d..3b2fcf7f74 100644 --- a/go.mod +++ b/go.mod @@ -7,6 +7,9 @@ require ( github.com/alecthomas/kingpin/v2 v2.4.0 github.com/alecthomas/units v0.0.0-20240927000941-0f3dac36c52b github.com/aws/aws-sdk-go v1.55.5 + github.com/aws/aws-sdk-go-v2 v1.36.3 + github.com/aws/aws-sdk-go-v2/config v1.29.14 + github.com/aws/aws-sdk-go-v2/service/secretsmanager v1.35.4 github.com/cenkalti/backoff/v4 v4.3.0 github.com/cespare/xxhash/v2 v2.3.0 github.com/coder/quartz v0.1.2 @@ -53,6 +56,17 @@ require ( require ( github.com/armon/go-metrics v0.3.10 // indirect github.com/asaskevich/govalidator v0.0.0-20230301143203-a9d515a09cc2 // indirect + github.com/aws/aws-sdk-go-v2/credentials v1.17.67 // indirect + github.com/aws/aws-sdk-go-v2/feature/ec2/imds v1.16.30 // indirect + github.com/aws/aws-sdk-go-v2/internal/configsources v1.3.34 // indirect + github.com/aws/aws-sdk-go-v2/internal/endpoints/v2 v2.6.34 // indirect + github.com/aws/aws-sdk-go-v2/internal/ini v1.8.3 // indirect + github.com/aws/aws-sdk-go-v2/service/internal/accept-encoding v1.12.3 // indirect + github.com/aws/aws-sdk-go-v2/service/internal/presigned-url v1.12.15 // indirect + github.com/aws/aws-sdk-go-v2/service/sso v1.25.3 // indirect + github.com/aws/aws-sdk-go-v2/service/ssooidc v1.30.1 // indirect + github.com/aws/aws-sdk-go-v2/service/sts v1.33.19 // indirect + github.com/aws/smithy-go v1.22.2 // indirect github.com/beorn7/perks v1.0.1 // indirect github.com/coreos/go-systemd/v22 v22.5.0 // indirect github.com/davecgh/go-spew v1.1.1 // indirect diff --git a/go.sum b/go.sum index 8b90e15442..60511aaa2b 100644 --- a/go.sum +++ b/go.sum @@ -80,6 +80,34 @@ github.com/asaskevich/govalidator v0.0.0-20230301143203-a9d515a09cc2 h1:DklsrG3d github.com/asaskevich/govalidator v0.0.0-20230301143203-a9d515a09cc2/go.mod h1:WaHUgvxTVq04UNunO+XhnAqY/wQc+bxr74GqbsZ/Jqw= github.com/aws/aws-sdk-go v1.55.5 h1:KKUZBfBoyqy5d3swXyiC7Q76ic40rYcbqH7qjh59kzU= github.com/aws/aws-sdk-go v1.55.5/go.mod h1:eRwEWoyTWFMVYVQzKMNHWP5/RV4xIUGMQfXQHfHkpNU= +github.com/aws/aws-sdk-go-v2 v1.36.3 h1:mJoei2CxPutQVxaATCzDUjcZEjVRdpsiiXi2o38yqWM= +github.com/aws/aws-sdk-go-v2 v1.36.3/go.mod h1:LLXuLpgzEbD766Z5ECcRmi8AzSwfZItDtmABVkRLGzg= +github.com/aws/aws-sdk-go-v2/config v1.29.14 h1:f+eEi/2cKCg9pqKBoAIwRGzVb70MRKqWX4dg1BDcSJM= +github.com/aws/aws-sdk-go-v2/config v1.29.14/go.mod h1:wVPHWcIFv3WO89w0rE10gzf17ZYy+UVS1Geq8Iei34g= +github.com/aws/aws-sdk-go-v2/credentials v1.17.67 h1:9KxtdcIA/5xPNQyZRgUSpYOE6j9Bc4+D7nZua0KGYOM= +github.com/aws/aws-sdk-go-v2/credentials v1.17.67/go.mod h1:p3C44m+cfnbv763s52gCqrjaqyPikj9Sg47kUVaNZQQ= +github.com/aws/aws-sdk-go-v2/feature/ec2/imds v1.16.30 h1:x793wxmUWVDhshP8WW2mlnXuFrO4cOd3HLBroh1paFw= +github.com/aws/aws-sdk-go-v2/feature/ec2/imds v1.16.30/go.mod h1:Jpne2tDnYiFascUEs2AWHJL9Yp7A5ZVy3TNyxaAjD6M= +github.com/aws/aws-sdk-go-v2/internal/configsources v1.3.34 h1:ZK5jHhnrioRkUNOc+hOgQKlUL5JeC3S6JgLxtQ+Rm0Q= +github.com/aws/aws-sdk-go-v2/internal/configsources v1.3.34/go.mod h1:p4VfIceZokChbA9FzMbRGz5OV+lekcVtHlPKEO0gSZY= +github.com/aws/aws-sdk-go-v2/internal/endpoints/v2 v2.6.34 h1:SZwFm17ZUNNg5Np0ioo/gq8Mn6u9w19Mri8DnJ15Jf0= +github.com/aws/aws-sdk-go-v2/internal/endpoints/v2 v2.6.34/go.mod h1:dFZsC0BLo346mvKQLWmoJxT+Sjp+qcVR1tRVHQGOH9Q= +github.com/aws/aws-sdk-go-v2/internal/ini v1.8.3 h1:bIqFDwgGXXN1Kpp99pDOdKMTTb5d2KyU5X/BZxjOkRo= +github.com/aws/aws-sdk-go-v2/internal/ini v1.8.3/go.mod h1:H5O/EsxDWyU+LP/V8i5sm8cxoZgc2fdNR9bxlOFrQTo= +github.com/aws/aws-sdk-go-v2/service/internal/accept-encoding v1.12.3 h1:eAh2A4b5IzM/lum78bZ590jy36+d/aFLgKF/4Vd1xPE= +github.com/aws/aws-sdk-go-v2/service/internal/accept-encoding v1.12.3/go.mod h1:0yKJC/kb8sAnmlYa6Zs3QVYqaC8ug2AbnNChv5Ox3uA= +github.com/aws/aws-sdk-go-v2/service/internal/presigned-url v1.12.15 h1:dM9/92u2F1JbDaGooxTq18wmmFzbJRfXfVfy96/1CXM= +github.com/aws/aws-sdk-go-v2/service/internal/presigned-url v1.12.15/go.mod h1:SwFBy2vjtA0vZbjjaFtfN045boopadnoVPhu4Fv66vY= +github.com/aws/aws-sdk-go-v2/service/secretsmanager v1.35.4 h1:EKXYJ8kgz4fiqef8xApu7eH0eae2SrVG+oHCLFybMRI= +github.com/aws/aws-sdk-go-v2/service/secretsmanager v1.35.4/go.mod h1:yGhDiLKguA3iFJYxbrQkQiNzuy+ddxesSZYWVeeEH5Q= +github.com/aws/aws-sdk-go-v2/service/sso v1.25.3 h1:1Gw+9ajCV1jogloEv1RRnvfRFia2cL6c9cuKV2Ps+G8= +github.com/aws/aws-sdk-go-v2/service/sso v1.25.3/go.mod h1:qs4a9T5EMLl/Cajiw2TcbNt2UNo/Hqlyp+GiuG4CFDI= +github.com/aws/aws-sdk-go-v2/service/ssooidc v1.30.1 h1:hXmVKytPfTy5axZ+fYbR5d0cFmC3JvwLm5kM83luako= +github.com/aws/aws-sdk-go-v2/service/ssooidc v1.30.1/go.mod h1:MlYRNmYu/fGPoxBQVvBYr9nyr948aY/WLUvwBMBJubs= +github.com/aws/aws-sdk-go-v2/service/sts v1.33.19 h1:1XuUZ8mYJw9B6lzAkXhqHlJd/XvaX32evhproijJEZY= +github.com/aws/aws-sdk-go-v2/service/sts v1.33.19/go.mod h1:cQnB8CUnxbMU82JvlqjKR2HBOm3fe9pWorWBza6MBJ4= +github.com/aws/smithy-go v1.22.2 h1:6D9hW43xKFrRx/tXXfAlIZc4JI+yQe6snnWcQyxSyLQ= +github.com/aws/smithy-go v1.22.2/go.mod h1:irrKGvNn1InZwb2d7fkIRNucdfwR8R+Ts3wxYa/cJHg= github.com/beorn7/perks v0.0.0-20180321164747-3a771d992973/go.mod h1:Dwedo/Wpr24TaqPxmxbtue+5NUziq4I4S80YR8gNf3Q= github.com/beorn7/perks v1.0.0/go.mod h1:KWe93zE9D1o94FZ5RNwFwVgaQK1VOXiVxmqh+CedLV8= github.com/beorn7/perks v1.0.1 h1:VlbKKnNfV8bJzeqoa4cOKqO6bYr3WgKZxO8Z16+hsOM= diff --git a/notify/pagerduty/pagerduty.go b/notify/pagerduty/pagerduty.go index abab5a70be..6a2ba01be2 100644 --- a/notify/pagerduty/pagerduty.go +++ b/notify/pagerduty/pagerduty.go @@ -19,6 +19,7 @@ import ( "encoding/json" "errors" "fmt" + "github.com/prometheus/alertmanager/secrets" "io" "log/slog" "net/http" @@ -45,27 +46,30 @@ const ( // Notifier implements a Notifier for PagerDuty notifications. type Notifier struct { - conf *config.PagerdutyConfig - tmpl *template.Template - logger *slog.Logger - apiV1 string // for tests. - client *http.Client - retrier *notify.Retrier + conf *config.PagerdutyConfig + tmpl *template.Template + logger *slog.Logger + apiV1 string // for tests. + client *http.Client + retrier *notify.Retrier + secretsFetcher secrets.SecretsFetcher } // New returns a new PagerDuty notifier. -func New(c *config.PagerdutyConfig, t *template.Template, l *slog.Logger, httpOpts ...commoncfg.HTTPClientOption) (*Notifier, error) { +func New(c *config.PagerdutyConfig, t *template.Template, l *slog.Logger, spRegistry *secrets.SecretsProviderRegistry, httpOpts ...commoncfg.HTTPClientOption) (*Notifier, error) { client, err := commoncfg.NewClientFromConfig(*c.HTTPConfig, "pagerduty", httpOpts...) if err != nil { return nil, err } n := &Notifier{conf: c, tmpl: t, logger: l, client: client} - if c.ServiceKey != "" || c.ServiceKeyFile != "" { + if c.ServiceKey != nil || c.ServiceKeyFile != "" { + n.secretsFetcher, err = spRegistry.RegisterSecret(c.ServiceKey) n.apiV1 = "https://events.pagerduty.com/generic/2010-04-15/create_event.json" // Retrying can solve the issue on 403 (rate limiting) and 5xx response codes. // https://v2.developer.pagerduty.com/docs/trigger-events n.retrier = ¬ify.Retrier{RetryCodes: []int{http.StatusForbidden}, CustomDetailsFunc: errDetails} } else { + n.secretsFetcher, err = spRegistry.RegisterSecret(c.RoutingKey) // Retrying can solve the issue on 429 (rate limiting) and 5xx response codes. // https://v2.developer.pagerduty.com/docs/events-api-v2#api-response-codes--retry-logic n.retrier = ¬ify.Retrier{RetryCodes: []int{http.StatusTooManyRequests}, CustomDetailsFunc: errDetails} @@ -143,6 +147,22 @@ func (n *Notifier) encodeMessage(msg *pagerDutyMessage) (bytes.Buffer, error) { return buf, nil } +func (n *Notifier) getSecret(ctx context.Context) string { + var secret *secrets.GenericSecret + if n.conf.ServiceKey != nil { + secret = n.conf.ServiceKey + } else { + secret = n.conf.RoutingKey + } + + if sec, err := n.secretsFetcher.FetchSecret(ctx, secret); err != nil { + n.logger.Error("unable to fetch secret", err) + return "" + } else { + return sec + } +} + func (n *Notifier) notifyV1( ctx context.Context, eventType string, @@ -159,7 +179,8 @@ func (n *Notifier) notifyV1( n.logger.Warn("Truncated description", "key", key, "max_runes", maxV1DescriptionLenRunes) } - serviceKey := string(n.conf.ServiceKey) + //serviceKey := string(n.conf.ServiceKey) + serviceKey := n.getSecret(ctx) if serviceKey == "" { content, fileErr := os.ReadFile(n.conf.ServiceKeyFile) if fileErr != nil { @@ -224,7 +245,8 @@ func (n *Notifier) notifyV2( n.logger.Warn("Truncated summary", "key", key, "max_runes", maxV2SummaryLenRunes) } - routingKey := string(n.conf.RoutingKey) + //routingKey := string(n.conf.RoutingKey) + routingKey := n.getSecret(ctx) if routingKey == "" { content, fileErr := os.ReadFile(n.conf.RoutingKeyFile) if fileErr != nil { diff --git a/secrets/generic_secret.go b/secrets/generic_secret.go new file mode 100644 index 0000000000..8742a705af --- /dev/null +++ b/secrets/generic_secret.go @@ -0,0 +1,32 @@ +package secrets + +import ( + "errors" + "time" +) + +type GenericSecret struct { + AWSSecretsManagerConfig *AWSSecretsManagerConfig `yaml:"aws_secrets_manager" json:"aws_secrets_manager_config"` +} + +// TODO implement this correctly +func (gs *GenericSecret) String() string { + return "" +} + +// TODO implement Marshal and JSON equivalent methods +func (gs *GenericSecret) UnmarshalYAML(unmarshalFn func(any) error) error { + var inlineForm string + if err := unmarshalFn(&inlineForm); err == nil { + return errors.New("inline form is not supported") + } + type plain GenericSecret + // We need to do this to avoid infinite recursion. + return unmarshalFn((*plain)(gs)) +} + +type AWSSecretsManagerConfig struct { + SecretARN string `yaml:"secret_arn"` + SecretKey string `yaml:"secret_key"` + RefreshInterval time.Duration `yaml:"refresh_interval"` +} diff --git a/secrets/providers/aws_secrets_manager.go b/secrets/providers/aws_secrets_manager.go new file mode 100644 index 0000000000..d569ad8405 --- /dev/null +++ b/secrets/providers/aws_secrets_manager.go @@ -0,0 +1,185 @@ +package providers + +import ( + "context" + "encoding/json" + "errors" + "github.com/aws/aws-sdk-go-v2/aws" + "github.com/aws/aws-sdk-go-v2/aws/arn" + awsconfig "github.com/aws/aws-sdk-go-v2/config" + "github.com/aws/aws-sdk-go-v2/service/secretsmanager" + "github.com/prometheus/alertmanager/secrets" + "github.com/prometheus/client_golang/prometheus" + "log/slog" + "sync" + "time" +) + +//TODO metrics + +type AWSSecretsManagerProvider struct { + mtx sync.RWMutex + fetchers map[string]*secretFetcher + logger *slog.Logger + reg prometheus.Registerer + ctx context.Context +} + +func (a *AWSSecretsManagerProvider) Register(secret *secrets.GenericSecret) secrets.SecretsFetcher { + s := secret.AWSSecretsManagerConfig + if s == nil { + a.logger.Error("secret is nil. nothing to register") + return nil + } + a.logger.Info("registering secret") + a.mtx.Lock() + defer a.mtx.Unlock() + if f, OK := a.fetchers[s.SecretARN]; OK { + a.logger.Info("found an existing secret fetcher") + f.update(s.RefreshInterval) + return f + } + a.logger.Info("no secret fetcher found. creating a new one") + a.fetchers[s.SecretARN] = newSecretFetcher(a.ctx, a.logger, a.reg, s.SecretARN, s.RefreshInterval) + return a.fetchers[s.SecretARN] +} + +func (a *AWSSecretsManagerProvider) Stop() { + a.mtx.Lock() + defer a.mtx.Unlock() + for name, fetcher := range a.fetchers { + a.logger.Info("stopping secrets fetcher", "name", name) + fetcher.Stop() + } + a.logger.Info("aws secrets manager providers stopped") +} + +type secretFetcher struct { + secrets map[string]string + mtx sync.RWMutex + logger *slog.Logger + reg prometheus.Registerer + arn string + interval time.Duration + ctx context.Context + client *secretsmanager.Client + done chan struct{} + ticker *time.Ticker + initialFetch bool +} + +func newSecretFetcher(ctx context.Context, logger *slog.Logger, reg prometheus.Registerer, arn string, interval time.Duration) *secretFetcher { + sf := &secretFetcher{ + secrets: make(map[string]string), + logger: logger, + reg: reg, + arn: arn, + interval: interval, + ctx: ctx, + done: make(chan struct{}), + ticker: time.NewTicker(interval), + } + sf.createSecretsManagerClient() + go sf.run() + return sf +} + +func (s *secretFetcher) createSecretsManagerClient() { + parsedARN, err := arn.Parse(s.arn) + if err != nil { + s.logger.Error("unable to create secret manager client", err) + return + } + config, err := awsconfig.LoadDefaultConfig(s.ctx, awsconfig.WithRegion(parsedARN.Region)) + if err != nil { + s.logger.Error("unable to load config", err) + return + } + s.client = secretsmanager.NewFromConfig(config) +} + +func (s *secretFetcher) Stop() { + <-s.done + s.logger.Info("secret fetcher stopped") +} + +func (s *secretFetcher) run() { + defer close(s.done) + defer s.ticker.Stop() + input := &secretsmanager.GetSecretValueInput{ + SecretId: aws.String(s.arn), + } + s.logger.Debug("fetch secret", "reason", "initial") + s.retrieveSecret(input) + s.initialFetch = true + for { + select { + case <-s.ticker.C: + s.logger.Debug("fetching secret", "reason", "periodic") + s.retrieveSecret(input) + s.initialFetch = true + case <-s.ctx.Done(): + s.logger.Info("stopping secrets fetcher") + return + } + } +} + +func (s *secretFetcher) retrieveSecret(input *secretsmanager.GetSecretValueInput) { + result, err := s.client.GetSecretValue(s.ctx, input) + if err != nil { + s.logger.Error("unable to fetch secret for ARN", "arn", s.arn, "error", err) + return + } + secretString := *result.SecretString + var m map[string]string + if err = json.Unmarshal([]byte(secretString), &m); err != nil { + s.logger.Error("unable to unmarshal payload", "arn", s.arn, "error", err) + return + } + s.logger.Debug("retrieved keys", "key count", len(m)) + s.mtx.Lock() + defer s.mtx.Unlock() + s.secrets = nil + s.secrets = m +} + +func (s *secretFetcher) update(interval time.Duration) { + s.mtx.Lock() + defer s.mtx.Unlock() + if s.interval > interval { + s.interval = interval + s.ticker.Reset(s.interval) + } +} + +func (s *secretFetcher) FetchSecret(_ context.Context, secret *secrets.GenericSecret) (string, error) { + sec := secret.AWSSecretsManagerConfig + if sec == nil { + return "", errors.New("cannot fetch empty secret") + } + + s.mtx.RLock() + value, exists := s.secrets[sec.SecretKey] + s.mtx.RUnlock() + if !exists { + return "", errors.New("secret not found") + } + return value, nil +} + +type AWSSecretsManagerSecretProviderDiscoveryConfig struct { +} + +func (a AWSSecretsManagerSecretProviderDiscoveryConfig) Name() string { + return "aws_secrets_manager" +} + +func (a AWSSecretsManagerSecretProviderDiscoveryConfig) NewSecretsProvider(options secrets.SecretProviderOptions) (secrets.SecretsProvider, error) { + return &AWSSecretsManagerProvider{ + fetchers: make(map[string]*secretFetcher), + logger: options.Logger, + reg: options.Registerer, + ctx: options.Context, + }, nil +} diff --git a/secrets/secrets_provider.go b/secrets/secrets_provider.go new file mode 100644 index 0000000000..407a9505e0 --- /dev/null +++ b/secrets/secrets_provider.go @@ -0,0 +1,118 @@ +package secrets + +import ( + "context" + "errors" + "github.com/prometheus/client_golang/prometheus" + "github.com/prometheus/common/config" + "log/slog" + "sync" +) + +var ( + AWS_SECRETS_MANAGER_PROVIDER = "aws_secrets_manager" +) + +type SecretsFetcher interface { + FetchSecret(ctx context.Context, secret *GenericSecret) (string, error) + Stop() +} + +type SecretsProvider interface { + Register(secret *GenericSecret) SecretsFetcher + Stop() +} + +type SecretsProviderRegistry struct { + mtx sync.RWMutex + providers map[string]SecretsProvider + logger *slog.Logger + reg prometheus.Registerer + configs map[string]SecretProviderDiscoveryConfig + ctx context.Context + cancel context.CancelFunc +} + +func NewSecretsProviderRegistry(logger *slog.Logger, reg prometheus.Registerer) *SecretsProviderRegistry { + registry := &SecretsProviderRegistry{ + providers: make(map[string]SecretsProvider), + configs: make(map[string]SecretProviderDiscoveryConfig), + logger: logger, + reg: reg, + } + return registry +} + +func (s *SecretsProviderRegistry) Register(config SecretProviderDiscoveryConfig) { + s.mtx.Lock() + defer s.mtx.Unlock() + s.logger.Info("registering secret providers", "name", config.Name()) + s.configs[config.Name()] = config +} + +func (s *SecretsProviderRegistry) Init() { + s.mtx.Lock() + defer s.mtx.Unlock() + s.ctx, s.cancel = context.WithCancel(context.Background()) + for name, providerConfig := range s.configs { + s.logger.Info("initializing secret providers", "name", name) + provider, err := providerConfig.NewSecretsProvider(SecretProviderOptions{ + Logger: s.logger, + Registerer: s.reg, + Context: s.ctx, + }) + if err != nil { + s.logger.Error("unable to initialize secrets provider", "name", name, "error", err.Error()) + continue + } + s.providers[name] = provider + } +} + +func (s *SecretsProviderRegistry) Stop() { + if s == nil { + return + } + s.mtx.Lock() + defer s.mtx.Unlock() + if s.cancel == nil { + return + } + s.cancel() + s.cancel = nil + for name, provider := range s.providers { + s.logger.Info("stopping secrets providers", "name", name) + provider.Stop() + } + s.logger.Info("stopped secrets providers registry") +} + +func (s *SecretsProviderRegistry) RegisterSecret(secret *GenericSecret) (SecretsFetcher, error) { + s.mtx.RLock() + defer s.mtx.RUnlock() + + s.logger.Info("registering secret") + if secret.AWSSecretsManagerConfig != nil { + s.logger.Info("registering aws_secret_manager secret") + return s.providers[AWS_SECRETS_MANAGER_PROVIDER].Register(secret), nil + } + return nil, errors.New("no secrets fetcher found for the given secret") +} + +type SecretProviderDiscoveryConfig interface { + // Name returns the name of the discovery mechanism. + Name() string + + NewSecretsProvider(SecretProviderOptions) (SecretsProvider, error) +} + +type SecretProviderOptions struct { + Logger *slog.Logger + + // A registerer for the SecretProvider's metrics. + Registerer prometheus.Registerer + + HTTPClientOptions []config.HTTPClientOption + + Context context.Context +} From e355ec66819224cd9a14c583043312515b24486c Mon Sep 17 00:00:00 2001 From: Anand Rajagopal Date: Sun, 1 Jun 2025 23:05:03 +0000 Subject: [PATCH 2/2] Updated to use non-pointer version Signed-off-by: Anand Rajagopal --- cmd/alertmanager/main.go | 30 +- config/config_test.go | 4 + config/notifiers.go | 45 +- config/notifiers_test.go | 2 +- config/receiver/receiver.go | 3 +- config/receiver/receiver_test.go | 8 +- notify/pagerduty/pagerduty.go | 43 +- notify/pagerduty/pagerduty_test.go | 103 +++- secrets/generic_secret.go | 72 ++- secrets/generic_secret_test.go | 230 +++++++++ secrets/providers/aws_secrets_manager.go | 387 ++++++++++++--- .../providers/aws_secrets_manager_internal.go | 78 +++ secrets/providers/aws_secrets_manager_test.go | 468 ++++++++++++++++++ secrets/secrets_provider.go | 75 ++- secrets/secrets_provider_test.go | 290 +++++++++++ 15 files changed, 1674 insertions(+), 164 deletions(-) create mode 100644 secrets/generic_secret_test.go create mode 100644 secrets/providers/aws_secrets_manager_internal.go create mode 100644 secrets/providers/aws_secrets_manager_test.go create mode 100644 secrets/secrets_provider_test.go diff --git a/cmd/alertmanager/main.go b/cmd/alertmanager/main.go index c394673060..c8d16acc49 100644 --- a/cmd/alertmanager/main.go +++ b/cmd/alertmanager/main.go @@ -17,8 +17,6 @@ import ( "context" "errors" "fmt" - "github.com/prometheus/alertmanager/secrets" - "github.com/prometheus/alertmanager/secrets/providers" "log/slog" "net" "net/http" @@ -32,6 +30,9 @@ import ( "syscall" "time" + "github.com/prometheus/alertmanager/secrets" + "github.com/prometheus/alertmanager/secrets/providers" + "github.com/KimMachineGun/automemlimit/memlimit" "github.com/alecthomas/kingpin/v2" "github.com/prometheus/client_golang/prometheus" @@ -160,10 +161,10 @@ func run() int { httpTimeout = kingpin.Flag("web.timeout", "Timeout for HTTP requests. If negative or zero, no timeout is set.").Default("0").Duration() memlimitRatio = kingpin.Flag("auto-gomemlimit.ratio", "The ratio of reserved GOMEMLIMIT memory to the detected maximum container or system memory. The value must be greater than 0 and less than or equal to 1."). - Default("0.9").Float64() + Default("0.9").Float64() clusterBindAddr = kingpin.Flag("cluster.listen-address", "Listen address for cluster. Set to empty string to disable HA mode."). - Default(defaultClusterAddr).String() + Default(defaultClusterAddr).String() clusterAdvertiseAddr = kingpin.Flag("cluster.advertise-address", "Explicit address to advertise in cluster.").String() peers = kingpin.Flag("cluster.peer", "Initial peers (may be repeated).").Strings() peerTimeout = kingpin.Flag("cluster.peer-timeout", "Time to wait between peers to send notifications.").Default("15s").Duration() @@ -434,10 +435,15 @@ func run() int { activeReceivers[r.RouteOpts.Receiver] = struct{}{} }) - spRegistry := secrets.NewSecretsProviderRegistry(logger, prometheus.NewRegistry()) - // currently only one secrets providers is supported - spRegistry.Register(providers.AWSSecretsManagerSecretProviderDiscoveryConfig{}) - spRegistry.Init() + if secretsProviderRegistry == nil { + secretsProviderRegistry = secrets.NewSecretsProviderRegistry(logger, prometheus.NewRegistry()) + // currently only one secrets provider is registered. Inline secrets provider is always available + if secretsProviderRegistry.Register(providers.AWSSecretsManagerSecretProviderDiscoveryConfig{}) != nil { + configLogger.Error("failed to register secrets provider", "err", err) + } + secretsProviderRegistry.Init() + } + // Build the map of receiver to integrations. receivers := make(map[string][]notify.Integration, len(activeReceivers)) var integrationsNum int @@ -447,7 +453,7 @@ func run() int { configLogger.Info("skipping creation of receiver not referenced by any route", "receiver", rcv.Name) continue } - integrations, err := receiver.BuildReceiverIntegrations(rcv, tmpl, logger, spRegistry) + integrations, err := receiver.BuildReceiverIntegrations(rcv, tmpl, logger, secretsProviderRegistry) if err != nil { return err } @@ -470,13 +476,9 @@ func run() int { inhibitor.Stop() disp.Stop() - if secretsProviderRegistry != nil { - secretsProviderRegistry.Stop() - } inhibitor = inhibit.NewInhibitor(alerts, conf.InhibitRules, marker, logger) silencer := silence.NewSilencer(silences, marker, logger) - secretsProviderRegistry = spRegistry // An interface value that holds a nil concrete value is non-nil. // Therefore we explicly pass an empty interface, to detect if the // cluster is not enabled in notify. @@ -535,7 +537,7 @@ func run() int { go disp.Run() go inhibitor.Run() - + secretsProviderRegistry.UpdateComplete() return nil }) diff --git a/config/config_test.go b/config/config_test.go index 728d7670b1..9951272e34 100644 --- a/config/config_test.go +++ b/config/config_test.go @@ -23,6 +23,8 @@ import ( "testing" "time" + "github.com/prometheus/alertmanager/secrets" + commoncfg "github.com/prometheus/common/config" "github.com/prometheus/common/model" "github.com/prometheus/common/promslog" @@ -528,6 +530,8 @@ func TestHideConfigSecrets(t *testing.T) { func TestShowMarshalSecretValues(t *testing.T) { MarshalSecretValue = true defer func() { MarshalSecretValue = false }() + secrets.MarshalSecretValue = true + defer func() { secrets.MarshalSecretValue = false }() c, err := LoadFile("testdata/conf.good.yml") if err != nil { diff --git a/config/notifiers.go b/config/notifiers.go index ff3a0d1394..d5b6025f63 100644 --- a/config/notifiers.go +++ b/config/notifiers.go @@ -16,13 +16,14 @@ package config import ( "errors" "fmt" - "github.com/prometheus/alertmanager/secrets" "net/textproto" "regexp" "strings" "text/template" "time" + "github.com/prometheus/alertmanager/secrets" + commoncfg "github.com/prometheus/common/config" "github.com/prometheus/common/model" "github.com/prometheus/sigv4" @@ -329,22 +330,22 @@ type PagerdutyConfig struct { HTTPConfig *commoncfg.HTTPClientConfig `yaml:"http_config,omitempty" json:"http_config,omitempty"` - ServiceKey *secrets.GenericSecret `yaml:"service_key,omitempty" json:"service_key,omitempty"` - ServiceKeyFile string `yaml:"service_key_file,omitempty" json:"service_key_file,omitempty"` - RoutingKey *secrets.GenericSecret `yaml:"routing_key,omitempty" json:"routing_key,omitempty"` - RoutingKeyFile string `yaml:"routing_key_file,omitempty" json:"routing_key_file,omitempty"` - URL *URL `yaml:"url,omitempty" json:"url,omitempty"` - Client string `yaml:"client,omitempty" json:"client,omitempty"` - ClientURL string `yaml:"client_url,omitempty" json:"client_url,omitempty"` - Description string `yaml:"description,omitempty" json:"description,omitempty"` - Details map[string]string `yaml:"details,omitempty" json:"details,omitempty"` - Images []PagerdutyImage `yaml:"images,omitempty" json:"images,omitempty"` - Links []PagerdutyLink `yaml:"links,omitempty" json:"links,omitempty"` - Source string `yaml:"source,omitempty" json:"source,omitempty"` - Severity string `yaml:"severity,omitempty" json:"severity,omitempty"` - Class string `yaml:"class,omitempty" json:"class,omitempty"` - Component string `yaml:"component,omitempty" json:"component,omitempty"` - Group string `yaml:"group,omitempty" json:"group,omitempty"` + ServiceKey secrets.GenericSecret `yaml:"service_key,omitempty" json:"service_key,omitempty"` + ServiceKeyFile string `yaml:"service_key_file,omitempty" json:"service_key_file,omitempty"` + RoutingKey secrets.GenericSecret `yaml:"routing_key,omitempty" json:"routing_key,omitempty"` + RoutingKeyFile string `yaml:"routing_key_file,omitempty" json:"routing_key_file,omitempty"` + URL *URL `yaml:"url,omitempty" json:"url,omitempty"` + Client string `yaml:"client,omitempty" json:"client,omitempty"` + ClientURL string `yaml:"client_url,omitempty" json:"client_url,omitempty"` + Description string `yaml:"description,omitempty" json:"description,omitempty"` + Details map[string]string `yaml:"details,omitempty" json:"details,omitempty"` + Images []PagerdutyImage `yaml:"images,omitempty" json:"images,omitempty"` + Links []PagerdutyLink `yaml:"links,omitempty" json:"links,omitempty"` + Source string `yaml:"source,omitempty" json:"source,omitempty"` + Severity string `yaml:"severity,omitempty" json:"severity,omitempty"` + Class string `yaml:"class,omitempty" json:"class,omitempty"` + Component string `yaml:"component,omitempty" json:"component,omitempty"` + Group string `yaml:"group,omitempty" json:"group,omitempty"` } // PagerdutyLink is a link. @@ -360,6 +361,10 @@ type PagerdutyImage struct { Href string `yaml:"href,omitempty" json:"href,omitempty"` } +func (c *PagerdutyConfig) isKeyZero() bool { + return c.ServiceKey.IsZero() && c.RoutingKey.IsZero() +} + // UnmarshalYAML implements the yaml.Unmarshaler interface. func (c *PagerdutyConfig) UnmarshalYAML(unmarshal func(interface{}) error) error { *c = DefaultPagerdutyConfig @@ -367,13 +372,13 @@ func (c *PagerdutyConfig) UnmarshalYAML(unmarshal func(interface{}) error) error if err := unmarshal((*plain)(c)); err != nil { return err } - if c.RoutingKey == nil && c.ServiceKey == nil && c.RoutingKeyFile == "" && c.ServiceKeyFile == "" { + if c.isKeyZero() && c.RoutingKeyFile == "" && c.ServiceKeyFile == "" { return errors.New("missing service or routing key in PagerDuty config") } - if c.RoutingKey != nil && len(c.RoutingKeyFile) > 0 { + if !c.RoutingKey.IsZero() && len(c.RoutingKeyFile) > 0 { return errors.New("at most one of routing_key & routing_key_file must be configured") } - if c.ServiceKey != nil && len(c.ServiceKeyFile) > 0 { + if !c.ServiceKey.IsZero() && len(c.ServiceKeyFile) > 0 { return errors.New("at most one of service_key & service_key_file must be configured") } if c.Details == nil { diff --git a/config/notifiers_test.go b/config/notifiers_test.go index af348b9ff5..cc72bb15d6 100644 --- a/config/notifiers_test.go +++ b/config/notifiers_test.go @@ -142,7 +142,7 @@ routing_key_file: 'xyz' func TestPagerdutyServiceKey(t *testing.T) { t.Run("error if no service key or key file", func(t *testing.T) { in := ` -service_key: '' +service_key: ` var cfg PagerdutyConfig err := yaml.UnmarshalStrict([]byte(in), &cfg) diff --git a/config/receiver/receiver.go b/config/receiver/receiver.go index 33a85850a3..6af913b902 100644 --- a/config/receiver/receiver.go +++ b/config/receiver/receiver.go @@ -14,9 +14,10 @@ package receiver import ( - "github.com/prometheus/alertmanager/secrets" "log/slog" + "github.com/prometheus/alertmanager/secrets" + commoncfg "github.com/prometheus/common/config" "github.com/prometheus/common/promslog" diff --git a/config/receiver/receiver_test.go b/config/receiver/receiver_test.go index 3d146a98d0..f0d5690a00 100644 --- a/config/receiver/receiver_test.go +++ b/config/receiver/receiver_test.go @@ -16,6 +16,11 @@ package receiver import ( "testing" + "github.com/prometheus/client_golang/prometheus" + "github.com/prometheus/common/promslog" + + "github.com/prometheus/alertmanager/secrets" + commoncfg "github.com/prometheus/common/config" "github.com/stretchr/testify/require" @@ -71,7 +76,8 @@ func TestBuildReceiverIntegrations(t *testing.T) { } { tc := tc t.Run("", func(t *testing.T) { - integrations, err := BuildReceiverIntegrations(tc.receiver, nil, nil) + sp := secrets.NewSecretsProviderRegistry(promslog.NewNopLogger(), prometheus.DefaultRegisterer) + integrations, err := BuildReceiverIntegrations(tc.receiver, nil, nil, sp) if tc.err { require.Error(t, err) return diff --git a/notify/pagerduty/pagerduty.go b/notify/pagerduty/pagerduty.go index 6a2ba01be2..93fbbe4464 100644 --- a/notify/pagerduty/pagerduty.go +++ b/notify/pagerduty/pagerduty.go @@ -19,13 +19,14 @@ import ( "encoding/json" "errors" "fmt" - "github.com/prometheus/alertmanager/secrets" "io" "log/slog" "net/http" "os" "strings" + "github.com/prometheus/alertmanager/secrets" + "github.com/alecthomas/units" commoncfg "github.com/prometheus/common/config" "github.com/prometheus/common/model" @@ -62,14 +63,21 @@ func New(c *config.PagerdutyConfig, t *template.Template, l *slog.Logger, spRegi return nil, err } n := &Notifier{conf: c, tmpl: t, logger: l, client: client} - if c.ServiceKey != nil || c.ServiceKeyFile != "" { + + if !c.ServiceKey.IsZero() { n.secretsFetcher, err = spRegistry.RegisterSecret(c.ServiceKey) + } else if !c.RoutingKey.IsZero() { + n.secretsFetcher, err = spRegistry.RegisterSecret(c.RoutingKey) + } + if err != nil { + l.Error("error registering secret", "err", err) + } + if !c.ServiceKey.IsZero() || c.ServiceKeyFile != "" { n.apiV1 = "https://events.pagerduty.com/generic/2010-04-15/create_event.json" // Retrying can solve the issue on 403 (rate limiting) and 5xx response codes. // https://v2.developer.pagerduty.com/docs/trigger-events n.retrier = ¬ify.Retrier{RetryCodes: []int{http.StatusForbidden}, CustomDetailsFunc: errDetails} } else { - n.secretsFetcher, err = spRegistry.RegisterSecret(c.RoutingKey) // Retrying can solve the issue on 429 (rate limiting) and 5xx response codes. // https://v2.developer.pagerduty.com/docs/events-api-v2#api-response-codes--retry-logic n.retrier = ¬ify.Retrier{RetryCodes: []int{http.StatusTooManyRequests}, CustomDetailsFunc: errDetails} @@ -148,19 +156,22 @@ func (n *Notifier) encodeMessage(msg *pagerDutyMessage) (bytes.Buffer, error) { } func (n *Notifier) getSecret(ctx context.Context) string { - var secret *secrets.GenericSecret - if n.conf.ServiceKey != nil { + var secret secrets.GenericSecret + if !n.conf.ServiceKey.IsZero() { secret = n.conf.ServiceKey - } else { + } else if !n.conf.RoutingKey.IsZero() { secret = n.conf.RoutingKey } + if secret.IsZero() || n.secretsFetcher == nil { + return "" + } - if sec, err := n.secretsFetcher.FetchSecret(ctx, secret); err != nil { - n.logger.Error("unable to fetch secret", err) + sec, err := n.secretsFetcher.FetchSecret(ctx, secret) + if err != nil { + n.logger.Error("unable to fetch secret", "error", err) return "" - } else { - return sec } + return sec } func (n *Notifier) notifyV1( @@ -179,9 +190,8 @@ func (n *Notifier) notifyV1( n.logger.Warn("Truncated description", "key", key, "max_runes", maxV1DescriptionLenRunes) } - //serviceKey := string(n.conf.ServiceKey) serviceKey := n.getSecret(ctx) - if serviceKey == "" { + if serviceKey == "" && n.conf.ServiceKeyFile != "" { content, fileErr := os.ReadFile(n.conf.ServiceKeyFile) if fileErr != nil { return false, fmt.Errorf("failed to read service key from file: %w", fileErr) @@ -220,6 +230,9 @@ func (n *Notifier) notifyV1( if err != nil { return true, fmt.Errorf("failed to post message to PagerDuty v1: %w", err) } + if resp.StatusCode == 403 { + n.secretsFetcher.RefreshCredentialsAsync() + } defer notify.Drain(resp) return n.retrier.Check(resp.StatusCode, resp.Body) @@ -245,9 +258,8 @@ func (n *Notifier) notifyV2( n.logger.Warn("Truncated summary", "key", key, "max_runes", maxV2SummaryLenRunes) } - //routingKey := string(n.conf.RoutingKey) routingKey := n.getSecret(ctx) - if routingKey == "" { + if routingKey == "" && n.conf.RoutingKeyFile != "" { content, fileErr := os.ReadFile(n.conf.RoutingKeyFile) if fileErr != nil { return false, fmt.Errorf("failed to read routing key from file: %w", fileErr) @@ -317,6 +329,9 @@ func (n *Notifier) notifyV2( } defer notify.Drain(resp) + if resp.StatusCode == 403 { + n.secretsFetcher.RefreshCredentialsAsync() + } retry, err := n.retrier.Check(resp.StatusCode, resp.Body) if err != nil { return retry, notify.NewErrorWithReason(notify.GetFailureReasonFromStatusCode(resp.StatusCode), err) diff --git a/notify/pagerduty/pagerduty_test.go b/notify/pagerduty/pagerduty_test.go index 60e2e49259..5d23dacfef 100644 --- a/notify/pagerduty/pagerduty_test.go +++ b/notify/pagerduty/pagerduty_test.go @@ -27,6 +27,10 @@ import ( "testing" "time" + "github.com/prometheus/client_golang/prometheus" + + "github.com/prometheus/alertmanager/secrets" + commoncfg "github.com/prometheus/common/config" "github.com/prometheus/common/model" "github.com/prometheus/common/promslog" @@ -41,11 +45,14 @@ import ( func TestPagerDutyRetryV1(t *testing.T) { notifier, err := New( &config.PagerdutyConfig{ - ServiceKey: config.Secret("01234567890123456789012345678901"), + ServiceKey: secrets.GenericSecret{ + Inline: secrets.Inline{Secret: "01234567890123456789012345678901"}, + }, HTTPConfig: &commoncfg.HTTPClientConfig{}, }, test.CreateTmpl(t), promslog.NewNopLogger(), + secrets.NewSecretsProviderRegistry(promslog.NewNopLogger(), prometheus.DefaultRegisterer), ) require.NoError(t, err) @@ -59,11 +66,14 @@ func TestPagerDutyRetryV1(t *testing.T) { func TestPagerDutyRetryV2(t *testing.T) { notifier, err := New( &config.PagerdutyConfig{ - RoutingKey: config.Secret("01234567890123456789012345678901"), + RoutingKey: secrets.GenericSecret{ + Inline: secrets.Inline{Secret: "01234567890123456789012345678901"}, + }, HTTPConfig: &commoncfg.HTTPClientConfig{}, }, test.CreateTmpl(t), promslog.NewNopLogger(), + secrets.NewSecretsProviderRegistry(promslog.NewNopLogger(), prometheus.DefaultRegisterer), ) require.NoError(t, err) @@ -81,11 +91,14 @@ func TestPagerDutyRedactedURLV1(t *testing.T) { key := "01234567890123456789012345678901" notifier, err := New( &config.PagerdutyConfig{ - ServiceKey: config.Secret(key), + ServiceKey: secrets.GenericSecret{ + Inline: secrets.Inline{Secret: "01234567890123456789012345678901"}, + }, HTTPConfig: &commoncfg.HTTPClientConfig{}, }, test.CreateTmpl(t), promslog.NewNopLogger(), + secrets.NewSecretsProviderRegistry(promslog.NewNopLogger(), prometheus.DefaultRegisterer), ) require.NoError(t, err) notifier.apiV1 = u.String() @@ -100,12 +113,15 @@ func TestPagerDutyRedactedURLV2(t *testing.T) { key := "01234567890123456789012345678901" notifier, err := New( &config.PagerdutyConfig{ - URL: &config.URL{URL: u}, - RoutingKey: config.Secret(key), + URL: &config.URL{URL: u}, + RoutingKey: secrets.GenericSecret{ + Inline: secrets.Inline{Secret: "01234567890123456789012345678901"}, + }, HTTPConfig: &commoncfg.HTTPClientConfig{}, }, test.CreateTmpl(t), promslog.NewNopLogger(), + secrets.NewSecretsProviderRegistry(promslog.NewNopLogger(), prometheus.DefaultRegisterer), ) require.NoError(t, err) @@ -129,6 +145,7 @@ func TestPagerDutyV1ServiceKeyFromFile(t *testing.T) { }, test.CreateTmpl(t), promslog.NewNopLogger(), + secrets.NewSecretsProviderRegistry(promslog.NewNopLogger(), prometheus.DefaultRegisterer), ) require.NoError(t, err) notifier.apiV1 = u.String() @@ -154,6 +171,7 @@ func TestPagerDutyV2RoutingKeyFromFile(t *testing.T) { }, test.CreateTmpl(t), promslog.NewNopLogger(), + secrets.NewSecretsProviderRegistry(promslog.NewNopLogger(), prometheus.DefaultRegisterer), ) require.NoError(t, err) @@ -182,7 +200,9 @@ func TestPagerDutyTemplating(t *testing.T) { { title: "full-blown message", cfg: &config.PagerdutyConfig{ - RoutingKey: config.Secret("01234567890123456789012345678901"), + RoutingKey: secrets.GenericSecret{ + Inline: secrets.Inline{Secret: "01234567890123456789012345678901"}, + }, Images: []config.PagerdutyImage{ { Src: "{{ .Status }}", @@ -207,7 +227,9 @@ func TestPagerDutyTemplating(t *testing.T) { { title: "details with templating errors", cfg: &config.PagerdutyConfig{ - RoutingKey: config.Secret("01234567890123456789012345678901"), + RoutingKey: secrets.GenericSecret{ + Inline: secrets.Inline{Secret: "01234567890123456789012345678901"}, + }, Details: map[string]string{ "firing": `{{ template "pagerduty.default.instances" .Alerts.Firing`, "resolved": `{{ template "pagerduty.default.instances" .Alerts.Resolved }}`, @@ -220,38 +242,69 @@ func TestPagerDutyTemplating(t *testing.T) { { title: "v2 message with templating errors", cfg: &config.PagerdutyConfig{ - RoutingKey: config.Secret("01234567890123456789012345678901"), - Severity: "{{ ", + RoutingKey: secrets.GenericSecret{ + Inline: secrets.Inline{Secret: "01234567890123456789012345678901"}, + }, + Severity: "{{ ", }, errMsg: "failed to template", }, { title: "v1 message with templating errors", cfg: &config.PagerdutyConfig{ - ServiceKey: config.Secret("01234567890123456789012345678901"), - Client: "{{ ", + ServiceKey: secrets.GenericSecret{ + Inline: secrets.Inline{Secret: "01234567890123456789012345678901"}, + }, + Client: "{{ ", }, errMsg: "failed to template", }, { title: "routing key cannot be empty", cfg: &config.PagerdutyConfig{ - RoutingKey: config.Secret(`{{ "" }}`), + RoutingKey: secrets.GenericSecret{ + Inline: secrets.Inline{Secret: `{{ "" }}`}, + }, }, errMsg: "routing key cannot be empty", }, { title: "service_key cannot be empty", cfg: &config.PagerdutyConfig{ - ServiceKey: config.Secret(`{{ "" }}`), + ServiceKey: secrets.GenericSecret{ + Inline: secrets.Inline{Secret: `{{ "" }}`}, + }, }, errMsg: "service key cannot be empty", }, + { + title: "service_key cannot be empty - AWS Secrets Manager", + cfg: &config.PagerdutyConfig{ + ServiceKey: secrets.GenericSecret{ + AWSSecretsManagerConfig: secrets.AWSSecretsManagerConfig{ + SecretARN: `{{ "" }}`, + }, + }, + }, + errMsg: "service key cannot be empty", + }, + { + title: "routing_key cannot be empty - AWS Secrets Manager", + cfg: &config.PagerdutyConfig{ + RoutingKey: secrets.GenericSecret{ + AWSSecretsManagerConfig: secrets.AWSSecretsManagerConfig{ + SecretARN: `{{ "" }}`, + }, + }, + }, + errMsg: "routing key cannot be empty", + }, } { t.Run(tc.title, func(t *testing.T) { tc.cfg.URL = &config.URL{URL: u} tc.cfg.HTTPConfig = &commoncfg.HTTPClientConfig{} - pd, err := New(tc.cfg, test.CreateTmpl(t), promslog.NewNopLogger()) + spRegistry := secrets.NewSecretsProviderRegistry(promslog.NewNopLogger(), prometheus.DefaultRegisterer) + pd, err := New(tc.cfg, test.CreateTmpl(t), promslog.NewNopLogger(), spRegistry) require.NoError(t, err) if pd.apiV1 != "" { pd.apiV1 = u.String() @@ -336,11 +389,14 @@ func TestEventSizeEnforcement(t *testing.T) { notifierV1, err := New( &config.PagerdutyConfig{ - ServiceKey: config.Secret("01234567890123456789012345678901"), + ServiceKey: secrets.GenericSecret{ + Inline: secrets.Inline{Secret: "01234567890123456789012345678901"}, + }, HTTPConfig: &commoncfg.HTTPClientConfig{}, }, test.CreateTmpl(t), promslog.NewNopLogger(), + secrets.NewSecretsProviderRegistry(promslog.NewNopLogger(), prometheus.DefaultRegisterer), ) require.NoError(t, err) @@ -359,11 +415,14 @@ func TestEventSizeEnforcement(t *testing.T) { notifierV2, err := New( &config.PagerdutyConfig{ - RoutingKey: config.Secret("01234567890123456789012345678901"), + RoutingKey: secrets.GenericSecret{ + Inline: secrets.Inline{Secret: "01234567890123456789012345678901"}, + }, HTTPConfig: &commoncfg.HTTPClientConfig{}, }, test.CreateTmpl(t), promslog.NewNopLogger(), + secrets.NewSecretsProviderRegistry(promslog.NewNopLogger(), prometheus.DefaultRegisterer), ) require.NoError(t, err) @@ -472,13 +531,15 @@ func TestPagerDutyEmptySrcHref(t *testing.T) { pagerDutyConfig := config.PagerdutyConfig{ HTTPConfig: &commoncfg.HTTPClientConfig{}, - RoutingKey: config.Secret("01234567890123456789012345678901"), - URL: &config.URL{URL: url}, - Images: images, - Links: links, + RoutingKey: secrets.GenericSecret{ + Inline: secrets.Inline{Secret: "01234567890123456789012345678901"}, + }, + URL: &config.URL{URL: url}, + Images: images, + Links: links, } - pagerDuty, err := New(&pagerDutyConfig, test.CreateTmpl(t), promslog.NewNopLogger()) + pagerDuty, err := New(&pagerDutyConfig, test.CreateTmpl(t), promslog.NewNopLogger(), secrets.NewSecretsProviderRegistry(promslog.NewNopLogger(), prometheus.DefaultRegisterer)) require.NoError(t, err) ctx := context.Background() diff --git a/secrets/generic_secret.go b/secrets/generic_secret.go index 8742a705af..4d1e7c747a 100644 --- a/secrets/generic_secret.go +++ b/secrets/generic_secret.go @@ -1,32 +1,84 @@ +// Copyright 2019 Prometheus Team +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + package secrets import ( - "errors" + "encoding/json" "time" ) +var MarshalSecretValue = false + +const secretToken = "" + type GenericSecret struct { - AWSSecretsManagerConfig *AWSSecretsManagerConfig `yaml:"aws_secrets_manager" json:"aws_secrets_manager_config"` + Inline Inline `yaml:",inline,omitempty" json:",inline,omitempty"` + AWSSecretsManagerConfig AWSSecretsManagerConfig `yaml:"aws_secrets_manager,omitempty" json:"aws_secrets_manager_config,omitempty"` } -// TODO implement this correctly -func (gs *GenericSecret) String() string { - return "" +func (gs GenericSecret) String() string { + if MarshalSecretValue { + return gs.Inline.Secret + } + return secretToken +} + +func (gs GenericSecret) IsZero() bool { + return gs.Inline.Secret == "" && gs.AWSSecretsManagerConfig.IsZero() } -// TODO implement Marshal and JSON equivalent methods func (gs *GenericSecret) UnmarshalYAML(unmarshalFn func(any) error) error { var inlineForm string if err := unmarshalFn(&inlineForm); err == nil { - return errors.New("inline form is not supported") + gs.Inline = Inline{inlineForm} + return nil } type plain GenericSecret // We need to do this to avoid infinite recursion. return unmarshalFn((*plain)(gs)) } +func (gs GenericSecret) MarshalYAML() (interface{}, error) { + if MarshalSecretValue { + return gs.String(), nil + } + if !gs.IsZero() { + return secretToken, nil + } + return nil, nil +} + +func (gs GenericSecret) MarshalJSON() ([]byte, error) { + if MarshalSecretValue { + return json.Marshal(gs) + } + if !gs.IsZero() { + return json.Marshal("") + } + return json.Marshal(secretToken) +} + type AWSSecretsManagerConfig struct { - SecretARN string `yaml:"secret_arn"` - SecretKey string `yaml:"secret_key"` - RefreshInterval time.Duration `yaml:"refresh_interval"` + SecretARN string `yaml:"secret_arn" json:"secret_arn"` + SecretKey string `yaml:"secret_key" json:"secret_key"` + RefreshInterval time.Duration `yaml:"refresh_interval" json:"refresh_interval"` +} + +func (a AWSSecretsManagerConfig) IsZero() bool { + return a.SecretARN == "" && a.SecretKey == "" +} + +type Inline struct { + Secret string } diff --git a/secrets/generic_secret_test.go b/secrets/generic_secret_test.go new file mode 100644 index 0000000000..640693014c --- /dev/null +++ b/secrets/generic_secret_test.go @@ -0,0 +1,230 @@ +// Copyright 2019 Prometheus Team +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +package secrets + +import ( + "testing" + "time" + + "gopkg.in/yaml.v2" +) + +func TestGenericSecret_String(t *testing.T) { + tests := []struct { + name string + gs GenericSecret + marshal bool + expected string + }{ + { + name: "with marshal secret value true", + gs: GenericSecret{Inline: Inline{Secret: "test-secret"}}, + marshal: true, + expected: "test-secret", + }, + { + name: "with marshal secret value false", + gs: GenericSecret{Inline: Inline{Secret: "test-secret"}}, + marshal: false, + expected: "", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + MarshalSecretValue = tt.marshal + if got := tt.gs.String(); got != tt.expected { + t.Errorf("GenericSecret.String() = %v, want %v", got, tt.expected) + } + }) + } +} + +func TestGenericSecret_IsZero(t *testing.T) { + tests := []struct { + name string + gs GenericSecret + expected bool + }{ + { + name: "empty generic secret", + gs: GenericSecret{}, + expected: true, + }, + { + name: "non-empty inline secret", + gs: GenericSecret{ + Inline: Inline{Secret: "test-secret"}, + }, + expected: false, + }, + { + name: "non-empty AWS config", + gs: GenericSecret{ + AWSSecretsManagerConfig: AWSSecretsManagerConfig{ + SecretARN: "arn:aws:secretsmanager:test", + }, + }, + expected: false, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + if got := tt.gs.IsZero(); got != tt.expected { + t.Errorf("GenericSecret.IsZero() = %v, want %v", got, tt.expected) + } + }) + } +} + +func TestGenericSecret_MarshalYAML(t *testing.T) { + tests := []struct { + name string + gs GenericSecret + marshal bool + expected interface{} + wantErr bool + }{ + { + name: "marshal with secret value true", + gs: GenericSecret{ + Inline: Inline{Secret: "test-secret"}, + }, + marshal: true, + expected: "test-secret", + wantErr: false, + }, + { + name: "marshal with secret value false", + gs: GenericSecret{ + Inline: Inline{Secret: "test-secret"}, + }, + marshal: false, + expected: "", + wantErr: false, + }, + { + name: "marshal zero value", + gs: GenericSecret{}, + marshal: false, + expected: nil, + wantErr: false, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + MarshalSecretValue = tt.marshal + got, err := tt.gs.MarshalYAML() + if (err != nil) != tt.wantErr { + t.Errorf("GenericSecret.MarshalYAML() error = %v, wantErr %v", err, tt.wantErr) + return + } + if got != tt.expected { + t.Errorf("GenericSecret.MarshalYAML() = %v, want %v", got, tt.expected) + } + }) + } +} + +func TestAWSSecretsManagerConfig_IsZero(t *testing.T) { + tests := []struct { + name string + config AWSSecretsManagerConfig + expected bool + }{ + { + name: "empty config", + config: AWSSecretsManagerConfig{}, + expected: true, + }, + { + name: "config with ARN only", + config: AWSSecretsManagerConfig{ + SecretARN: "arn:aws:secretsmanager:test", + }, + expected: false, + }, + { + name: "config with key only", + config: AWSSecretsManagerConfig{ + SecretKey: "test-key", + }, + expected: false, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + if got := tt.config.IsZero(); got != tt.expected { + t.Errorf("AWSSecretsManagerConfig.IsZero() = %v, want %v", got, tt.expected) + } + }) + } +} + +func TestGenericSecret_UnmarshalYAML(t *testing.T) { + tests := []struct { + name string + input string + expected GenericSecret + wantErr bool + }{ + { + name: "inline string secret", + input: "test-secret", + expected: GenericSecret{ + Inline: Inline{Secret: "test-secret"}, + }, + wantErr: false, + }, + { + name: "aws secrets manager config", + input: ` +aws_secrets_manager: + secret_arn: arn:aws:secretsmanager:test + secret_key: test-key + refresh_interval: 1h +`, + expected: GenericSecret{ + AWSSecretsManagerConfig: AWSSecretsManagerConfig{ + SecretARN: "arn:aws:secretsmanager:test", + SecretKey: "test-key", + RefreshInterval: time.Hour, + }, + }, + wantErr: false, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + var got GenericSecret + err := yaml.Unmarshal([]byte(tt.input), &got) + if (err != nil) != tt.wantErr { + t.Errorf("GenericSecret.UnmarshalYAML() error = %v, wantErr %v", err, tt.wantErr) + return + } + if !tt.wantErr { + // Compare the relevant fields + if got.Inline.Secret != tt.expected.Inline.Secret { + t.Errorf("Inline.Secret = %v, want %v", got.Inline.Secret, tt.expected.Inline.Secret) + } + if got.AWSSecretsManagerConfig.SecretARN != tt.expected.AWSSecretsManagerConfig.SecretARN { + t.Errorf("SecretARN = %v, want %v", got.AWSSecretsManagerConfig.SecretARN, tt.expected.AWSSecretsManagerConfig.SecretARN) + } + } + }) + } +} diff --git a/secrets/providers/aws_secrets_manager.go b/secrets/providers/aws_secrets_manager.go index d569ad8405..8f1acb6c0f 100644 --- a/secrets/providers/aws_secrets_manager.go +++ b/secrets/providers/aws_secrets_manager.go @@ -1,33 +1,112 @@ +// Copyright 2019 Prometheus Team +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + package providers import ( "context" "encoding/json" "errors" + "fmt" + "log/slog" + "net/http" + "sync" + "time" + "github.com/aws/aws-sdk-go-v2/aws" "github.com/aws/aws-sdk-go-v2/aws/arn" awsconfig "github.com/aws/aws-sdk-go-v2/config" "github.com/aws/aws-sdk-go-v2/service/secretsmanager" - "github.com/prometheus/alertmanager/secrets" "github.com/prometheus/client_golang/prometheus" - "log/slog" - "sync" - "time" + commoncfg "github.com/prometheus/common/config" + + "github.com/prometheus/alertmanager/secrets" ) -//TODO metrics +type SecretFetchState int + +const ( + Success SecretFetchState = iota + Stale + Error +) + +var MinTimeInterval = 5 * time.Second + +func (s SecretFetchState) String() string { + switch s { + case Success: + return "success" + case Stale: + return "stale" + case Error: + return "error" + default: + return fmt.Sprintf("unknown(%d)", int(s)) + } +} + +func (s SecretFetchState) Value() float64 { + return float64(s) +} + +type RoundTripperFn func(rt http.RoundTripper) (http.RoundTripper, error) type AWSSecretsManagerProvider struct { - mtx sync.RWMutex - fetchers map[string]*secretFetcher - logger *slog.Logger - reg prometheus.Registerer - ctx context.Context + mtx sync.RWMutex + fetchers map[string]*secretFetcher + logger *slog.Logger + reg prometheus.Registerer + ctx context.Context + RoundTripper RoundTripperFn + secretFetchers prometheus.Gauge + + newFetchers map[string]struct{} +} + +func (a *AWSSecretsManagerProvider) fetchersCount() int { + a.mtx.RLock() + defer a.mtx.RUnlock() + return len(a.fetchers) +} + +func (a *AWSSecretsManagerProvider) validateARN(secretARN string) error { + parsedARN, err := arn.Parse(secretARN) + if err != nil { + return err + } + if parsedARN.Service != "secretsmanager" { + return errors.New("invalid service") + } + if parsedARN.Resource == "" { + return errors.New("invalid resource") + } + if parsedARN.AccountID == "" { + return errors.New("invalid account ID") + } + if parsedARN.Partition == "" { + return errors.New("invalid partition") + } + return nil } -func (a *AWSSecretsManagerProvider) Register(secret *secrets.GenericSecret) secrets.SecretsFetcher { +func (a *AWSSecretsManagerProvider) Register(secret secrets.GenericSecret) secrets.SecretsFetcher { s := secret.AWSSecretsManagerConfig - if s == nil { + if err := a.validateARN(s.SecretARN); err != nil { + a.logger.Error("invalid secret ARN", "error", err) + return nil + } + if s.IsZero() { a.logger.Error("secret is nil. nothing to register") return nil } @@ -35,12 +114,15 @@ func (a *AWSSecretsManagerProvider) Register(secret *secrets.GenericSecret) secr a.mtx.Lock() defer a.mtx.Unlock() if f, OK := a.fetchers[s.SecretARN]; OK { - a.logger.Info("found an existing secret fetcher") + a.logger.Info("found an existing secret fetcher", "ARN", s.SecretARN) f.update(s.RefreshInterval) + a.newFetchers[s.SecretARN] = struct{}{} return f } - a.logger.Info("no secret fetcher found. creating a new one") - a.fetchers[s.SecretARN] = newSecretFetcher(a.ctx, a.logger, a.reg, s.SecretARN, s.RefreshInterval) + a.logger.Info("no secret fetcher found. creating a new one", "ARN", s.SecretARN) + a.fetchers[s.SecretARN] = newSecretFetcher(a.ctx, a.logger, a.reg, a.RoundTripper, s) + a.secretFetchers.Set(float64(len(a.fetchers))) + a.newFetchers[s.SecretARN] = struct{}{} return a.fetchers[s.SecretARN] } @@ -49,116 +131,227 @@ func (a *AWSSecretsManagerProvider) Stop() { defer a.mtx.Unlock() for name, fetcher := range a.fetchers { a.logger.Info("stopping secrets fetcher", "name", name) - fetcher.Stop() + fetcher.AwaitStop() + delete(a.fetchers, name) + a.secretFetchers.Dec() } a.logger.Info("aws secrets manager providers stopped") } +func (a *AWSSecretsManagerProvider) UpdateComplete() { + a.mtx.Lock() + defer a.mtx.Unlock() + a.logger.Debug("Update begin", "fetchers", len(a.fetchers), "new fetchers", len(a.newFetchers)) + for name, fetcher := range a.fetchers { + if _, OK := a.newFetchers[name]; !OK { + fetcher.Stop() + delete(a.fetchers, name) + a.secretFetchers.Dec() + } + } + // reset new fetchers + a.newFetchers = make(map[string]struct{}) + a.logger.Debug("Update complete", "fetchers", len(a.fetchers)) +} + +func NewAWSSecretsManagerProvider(options secrets.SecretProviderOptions) *AWSSecretsManagerProvider { + provider := &AWSSecretsManagerProvider{ + fetchers: make(map[string]*secretFetcher), + newFetchers: make(map[string]struct{}), + logger: options.Logger, + reg: options.Registerer, + ctx: options.Context, + secretFetchers: prometheus.NewGauge(prometheus.GaugeOpts{ + Name: "alertmanager_secrets_fetchers", + Help: "Number of AWS Secrets Manager fetchers", + ConstLabels: prometheus.Labels{ + "provider": "aws_secrets_manager", + }, + }), + } + provider.reg.MustRegister(provider.secretFetchers) + return provider +} + type secretFetcher struct { - secrets map[string]string - mtx sync.RWMutex - logger *slog.Logger - reg prometheus.Registerer - arn string - interval time.Duration - ctx context.Context - client *secretsmanager.Client - done chan struct{} - ticker *time.Ticker - initialFetch bool -} - -func newSecretFetcher(ctx context.Context, logger *slog.Logger, reg prometheus.Registerer, arn string, interval time.Duration) *secretFetcher { + secretName string + secretConfig secrets.AWSSecretsManagerConfig + secrets map[string]string + mtx sync.RWMutex + logger *slog.Logger + reg prometheus.Registerer + ctx context.Context + client AWSSecretsManagerOperations + done chan struct{} + ticker *time.Ticker + roundTripper RoundTripperFn + asyncCh chan struct{} + lastRetrieved time.Time + metrics *SecretFetcherMetrics + everSucceeded bool + stopCh chan struct{} +} + +func newSecretFetcher(ctx context.Context, logger *slog.Logger, reg prometheus.Registerer, roundTripper RoundTripperFn, sc secrets.AWSSecretsManagerConfig) *secretFetcher { + if sc.RefreshInterval <= 0 { + sc.RefreshInterval = time.Minute * 5 + } + parsedARN, _ := arn.Parse(sc.SecretARN) sf := &secretFetcher{ - secrets: make(map[string]string), - logger: logger, - reg: reg, - arn: arn, - interval: interval, - ctx: ctx, - done: make(chan struct{}), - ticker: time.NewTicker(interval), + secrets: make(map[string]string), + logger: logger, + reg: reg, + secretConfig: sc, + ctx: ctx, + done: make(chan struct{}), + roundTripper: roundTripper, + asyncCh: make(chan struct{}), + lastRetrieved: time.Time{}, + secretName: parsedARN.Resource, + stopCh: make(chan struct{}), } + sf.metrics = NewSecretFetcherMetrics(reg, sf.secretName) + sf.ticker = time.NewTicker(sf.secretConfig.RefreshInterval) sf.createSecretsManagerClient() go sf.run() return sf } +func (s *secretFetcher) RefreshCredentialsAsync() { + s.asyncCh <- struct{}{} +} + func (s *secretFetcher) createSecretsManagerClient() { - parsedARN, err := arn.Parse(s.arn) + parsedARN, err := arn.Parse(s.secretConfig.SecretARN) + if err != nil { + s.logger.Error("unable to create secret manager client", "error", err) + return + } + var httpClient *http.Client + httpClient, err = commoncfg.NewClientFromConfig(commoncfg.DefaultHTTPClientConfig, "aws_secrets_manager") if err != nil { - s.logger.Error("unable to create secret manager client", err) + s.logger.Error("unable to create a new http client", "error", err) return } + if s.roundTripper != nil { + httpClient.Transport, err = s.roundTripper(httpClient.Transport) + if err != nil { + s.logger.Warn("unable to create round tripper. proceeding ", "error", err) + } + } + config, err := awsconfig.LoadDefaultConfig(s.ctx, awsconfig.WithRegion(parsedARN.Region)) + config.HTTPClient = httpClient if err != nil { - s.logger.Error("unable to load config", err) + s.logger.Error("unable to load config", "error", err) return } s.client = secretsmanager.NewFromConfig(config) } +func (s *secretFetcher) AwaitStop() { + <-s.done + s.metrics.Unregister(s.reg) + s.logger.Info("secret fetcher stopped", "secret id", s.secretName) +} + func (s *secretFetcher) Stop() { + close(s.stopCh) <-s.done - s.logger.Info("secret fetcher stopped") + s.metrics.Unregister(s.reg) + s.logger.Info("secret fetcher stopped", "secret id", s.secretName) } func (s *secretFetcher) run() { defer close(s.done) defer s.ticker.Stop() input := &secretsmanager.GetSecretValueInput{ - SecretId: aws.String(s.arn), + SecretId: aws.String(s.secretConfig.SecretARN), } - s.logger.Debug("fetch secret", "reason", "initial") - s.retrieveSecret(input) - s.initialFetch = true + s.retrieveSecret(input, "initial") for { select { case <-s.ticker.C: - s.logger.Debug("fetching secret", "reason", "periodic") - s.retrieveSecret(input) - s.initialFetch = true + s.retrieveSecret(input, "periodic") + case <-s.asyncCh: + s.retrieveSecret(input, "async refresh") + case <-s.stopCh: + s.logger.Info("stopping secrets fetcher via stop signal") + return case <-s.ctx.Done(): - s.logger.Info("stopping secrets fetcher") + s.logger.Info("stopping secrets fetcher via context cancellation") return } } } -func (s *secretFetcher) retrieveSecret(input *secretsmanager.GetSecretValueInput) { - result, err := s.client.GetSecretValue(s.ctx, input) +func (s *secretFetcher) retrieveSecret(input *secretsmanager.GetSecretValueInput, reason string) { + updateState := func(state SecretFetchState) { + if state == Error && s.everSucceeded { + state = Stale + } + s.metrics.secretState.WithLabelValues(state.String()).Set(state.Value()) + } + if time.Since(s.lastRetrieved) < MinTimeInterval { + s.logger.Debug("not refreshing secret", "reason", "too soon") + return + } + if !s.lastRetrieved.IsZero() { + s.metrics.timeSinceLastSuccessfulFetch.Set(float64(time.Since(s.lastRetrieved))) + } + s.logger.Debug("fetching secret", "reason", reason) + ctx, cancelFunc := context.WithTimeoutCause(s.ctx, time.Second, errors.New("timed out while retrieving secret")) + defer cancelFunc() + if s.client == nil { + s.logger.Error("secret manager client is nil", "arn", s.secretConfig.SecretARN) + updateState(Error) + return + } + retrievalTime := time.Now() + result, err := s.client.GetSecretValue(ctx, input) if err != nil { - s.logger.Error("unable to fetch secret for ARN", "arn", s.arn, "error", err) + s.logger.Error("unable to fetch secret for ARN", "arn", s.secretConfig.SecretARN, "error", err) + s.metrics.secretFetchErrors.Inc() + updateState(Error) return } secretString := *result.SecretString var m map[string]string if err = json.Unmarshal([]byte(secretString), &m); err != nil { - s.logger.Error("unable to unmarshal payload", "arn", s.arn, "error", err) + s.logger.Error("unable to unmarshal payload", "arn", s.secretConfig.SecretARN, "error", err) + s.metrics.secretFetchErrors.Inc() + updateState(Error) return } s.logger.Debug("retrieved keys", "key count", len(m)) s.mtx.Lock() defer s.mtx.Unlock() + s.lastRetrieved = retrievalTime s.secrets = nil s.secrets = m + s.metrics.secretFetchSuccess.Inc() + s.everSucceeded = true + updateState(Success) } func (s *secretFetcher) update(interval time.Duration) { s.mtx.Lock() defer s.mtx.Unlock() - if s.interval > interval { - s.interval = interval - s.ticker.Reset(s.interval) - } + s.secretConfig.RefreshInterval = interval + s.ticker.Reset(s.secretConfig.RefreshInterval) } -func (s *secretFetcher) FetchSecret(_ context.Context, secret *secrets.GenericSecret) (string, error) { +func (s *secretFetcher) FetchSecret(_ context.Context, secret secrets.GenericSecret) (string, error) { sec := secret.AWSSecretsManagerConfig - if sec == nil { + if sec.IsZero() { return "", errors.New("cannot fetch empty secret") } + // Pre-check if SecretKey is empty to avoid unnecessary lock + if sec.SecretKey == "" { + return "", errors.New("secret key is empty") + } + s.mtx.RLock() value, exists := s.secrets[sec.SecretKey] s.mtx.RUnlock() @@ -168,18 +361,76 @@ func (s *secretFetcher) FetchSecret(_ context.Context, secret *secrets.GenericSe return value, nil } -type AWSSecretsManagerSecretProviderDiscoveryConfig struct { -} +type AWSSecretsManagerSecretProviderDiscoveryConfig struct{} func (a AWSSecretsManagerSecretProviderDiscoveryConfig) Name() string { return "aws_secrets_manager" } func (a AWSSecretsManagerSecretProviderDiscoveryConfig) NewSecretsProvider(options secrets.SecretProviderOptions) (secrets.SecretsProvider, error) { - return &AWSSecretsManagerProvider{ - fetchers: make(map[string]*secretFetcher), - logger: options.Logger, - reg: options.Registerer, - ctx: options.Context, - }, nil + return NewAWSSecretsManagerProvider(options), nil +} + +type SecretFetcherMetrics struct { + secretFetchErrors prometheus.Counter + secretFetchSuccess prometheus.Counter + secretState *prometheus.GaugeVec + timeSinceLastSuccessfulFetch prometheus.Gauge +} + +func (sm *SecretFetcherMetrics) Unregister(reg prometheus.Registerer) { + reg.Unregister(sm.secretFetchErrors) + reg.Unregister(sm.secretFetchSuccess) + reg.Unregister(sm.secretState) + reg.Unregister(sm.timeSinceLastSuccessfulFetch) +} + +func NewSecretFetcherMetrics(reg prometheus.Registerer, secretName string) *SecretFetcherMetrics { + m := &SecretFetcherMetrics{ + secretFetchErrors: prometheus.NewCounter(prometheus.CounterOpts{ + Name: "alertmanager_remote_secret_fetch_failures_total", + Help: "Total number of failed secret fetches", + ConstLabels: prometheus.Labels{ + "secret_id": secretName, + "provider": "aws_secrets_manager", + }, + }), + secretFetchSuccess: prometheus.NewCounter(prometheus.CounterOpts{ + Name: "alertmanager_remote_secret_fetch_success_total", + Help: "Total number of successful secret fetches", + ConstLabels: prometheus.Labels{ + "secret_id": secretName, + "provider": "aws_secrets_manager", + }, + }), + secretState: prometheus.NewGaugeVec(prometheus.GaugeOpts{ + Name: "alertmanager_remote_secret_state", + Help: "State of the secret", + ConstLabels: prometheus.Labels{ + "secret_id": secretName, + "provider": "aws_secrets_manager", + }, + }, []string{"state"}), + timeSinceLastSuccessfulFetch: prometheus.NewGauge(prometheus.GaugeOpts{ + Name: "alertmanager_remote_secret_time_since_last_successful_fetch", + Help: "Time since last successful secret fetch", + ConstLabels: prometheus.Labels{ + "secret_id": secretName, + "provider": "aws_secrets_manager", + }, + }), + } + if reg != nil { + reg.MustRegister( + m.secretFetchErrors, + m.secretFetchSuccess, + m.secretState, + m.timeSinceLastSuccessfulFetch, + ) + } + return m +} + +type AWSSecretsManagerOperations interface { + GetSecretValue(ctx context.Context, params *secretsmanager.GetSecretValueInput, optFns ...func(*secretsmanager.Options)) (*secretsmanager.GetSecretValueOutput, error) } diff --git a/secrets/providers/aws_secrets_manager_internal.go b/secrets/providers/aws_secrets_manager_internal.go new file mode 100644 index 0000000000..e7cb4da72b --- /dev/null +++ b/secrets/providers/aws_secrets_manager_internal.go @@ -0,0 +1,78 @@ +// Copyright 2019 Prometheus Team +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package providers + +import ( + "fmt" + "log/slog" + "net/http" + "os" + "strings" + + "github.com/aws/aws-sdk-go-v2/aws/arn" + + "github.com/prometheus/alertmanager/secrets" +) + +type AMPAWSSecretsManagerSecretProviderDiscoveryConfig struct { + UserID string + workspaceARN string +} + +func (a AMPAWSSecretsManagerSecretProviderDiscoveryConfig) Name() string { + return "aws_secrets_manager" +} + +func (a AMPAWSSecretsManagerSecretProviderDiscoveryConfig) NewSecretsProvider(options secrets.SecretProviderOptions) (secrets.SecretsProvider, error) { + secretsManagerProvider := NewAWSSecretsManagerProvider(options) + userComponents := strings.Split(a.UserID, "_") + if len(userComponents) != 2 { + options.Logger.Info("user id is not in the correct format", "user id", a.UserID) + return secretsManagerProvider, nil + } + account := userComponents[0] + workspaceID := userComponents[1] + region := os.Getenv("AWS_REGION") + partition := os.Getenv("AWS_PARTITION") + a.workspaceARN = fmt.Sprintf("arn:%s:aps:%s:%s:workspace/%s", partition, region, account, workspaceID) + secretsManagerProvider.RoundTripper = newConfusedDeputyRoundTripper(a.workspaceARN, options.Logger) + return secretsManagerProvider, nil +} + +type confusedDeputyRoundTripper struct { + workspaceARN arn.ARN + rt http.RoundTripper + logger *slog.Logger +} + +func (rt *confusedDeputyRoundTripper) RoundTrip(req *http.Request) (*http.Response, error) { + req.Header.Set("x-amz-source-account", rt.workspaceARN.AccountID) + req.Header.Set("x-amz-source-arn", rt.workspaceARN.String()) + rt.logger.Debug("round tripper called", "account id", rt.workspaceARN.AccountID, "arn", rt.workspaceARN.String()) + return rt.rt.RoundTrip(req) +} + +func newConfusedDeputyRoundTripper(workspaceARN string, logger *slog.Logger) RoundTripperFn { + return func(rt http.RoundTripper) (tripper http.RoundTripper, err error) { + if workspaceARN == "" { + return rt, nil + } + + parsedARN, err := arn.Parse(workspaceARN) + if err != nil { + return nil, fmt.Errorf("%s is not a valid arn", workspaceARN) + } + return &confusedDeputyRoundTripper{parsedARN, rt, logger}, nil + } +} diff --git a/secrets/providers/aws_secrets_manager_test.go b/secrets/providers/aws_secrets_manager_test.go new file mode 100644 index 0000000000..7e5289a952 --- /dev/null +++ b/secrets/providers/aws_secrets_manager_test.go @@ -0,0 +1,468 @@ +// Copyright 2019 Prometheus Team +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +package providers + +import ( + "context" + "encoding/json" + "errors" + "testing" + "time" + + "github.com/aws/aws-sdk-go-v2/service/secretsmanager" + "github.com/prometheus/client_golang/prometheus" + "github.com/prometheus/client_golang/prometheus/testutil" + "github.com/prometheus/common/promslog" + "github.com/stretchr/testify/require" + + "github.com/prometheus/alertmanager/secrets" +) + +func TestFetchSecret(t *testing.T) { + tests := []struct { + name string + secret secrets.GenericSecret + setupSecrets map[string]string + expectedValue string + expectError bool + }{ + { + name: "successful fetch", + secret: secrets.GenericSecret{ + AWSSecretsManagerConfig: secrets.AWSSecretsManagerConfig{ + SecretARN: "arn:aws:secretsmanager:us-west-2:123456789:secret:receiver-pager-duty", + SecretKey: "test-secret", + RefreshInterval: 5 * time.Millisecond, + }, + }, + setupSecrets: map[string]string{ + "arn:aws:secretsmanager:us-west-2:123456789:secret:receiver-pager-duty": ` + { + "test-secret": "secret-value" + } + `, + }, + expectedValue: "secret-value", + expectError: false, + }, + { + name: "successful fetch multi-value", + secret: secrets.GenericSecret{ + AWSSecretsManagerConfig: secrets.AWSSecretsManagerConfig{ + SecretARN: "arn:aws:secretsmanager:us-west-2:123456789:secret:receiver-pager-duty", + SecretKey: "test-secret", + RefreshInterval: 5 * time.Millisecond, + }, + }, + setupSecrets: map[string]string{ + "arn:aws:secretsmanager:us-west-2:123456789:secret:receiver-pager-duty": ` + { + "test-secret": "secret-value", + "test-secret2": "secret-value2" + } + `, + }, + expectedValue: "secret-value", + expectError: false, + }, + { + name: "empty secret config", + secret: secrets.GenericSecret{ + AWSSecretsManagerConfig: secrets.AWSSecretsManagerConfig{}, + }, + setupSecrets: nil, + expectedValue: "", + expectError: true, + }, + { + name: "empty secret key", + secret: secrets.GenericSecret{ + AWSSecretsManagerConfig: secrets.AWSSecretsManagerConfig{ + SecretKey: "", + }, + }, + setupSecrets: nil, + expectedValue: "", + expectError: true, + }, + { + name: "non-existent secret key", + secret: secrets.GenericSecret{ + AWSSecretsManagerConfig: secrets.AWSSecretsManagerConfig{ + SecretARN: "arn:aws:secretsmanager:us-west-2:123456789:secret:receiver-pager-duty", + SecretKey: "test-secret3", + RefreshInterval: 5 * time.Millisecond, + }, + }, + setupSecrets: map[string]string{ + "arn:aws:secretsmanager:us-west-2:123456789:secret:receiver-pager-duty": ` + { + "test-secret": "secret-value", + "test-secret2": "secret-value2" + } + `, + }, + expectedValue: "", + expectError: true, + }, + { + name: "bad secret json", + secret: secrets.GenericSecret{ + AWSSecretsManagerConfig: secrets.AWSSecretsManagerConfig{ + SecretARN: "arn:aws:secretsmanager:us-west-2:123456789:secret:receiver-pager-duty", + SecretKey: "test-secret3", + RefreshInterval: 5 * time.Millisecond, + }, + }, + setupSecrets: map[string]string{ + "arn:aws:secretsmanager:us-west-2:123456789:secret:receiver-pager-duty": ` + { + "test-secret": "secret-value" + "test-secret2": "secret-value2" + } + `, + }, + expectedValue: "", + expectError: true, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + ctx, cancel := context.WithTimeoutCause(context.Background(), 10*time.Millisecond, errors.New("timeout")) + reg := prometheus.NewPedanticRegistry() + fetcher := &secretFetcher{ + secrets: make(map[string]string), + logger: promslog.NewNopLogger(), + ctx: ctx, + secretConfig: tt.secret.AWSSecretsManagerConfig, + ticker: time.NewTicker(time.Second), + client: &MockSecretsManagerClient{ + secrets: tt.setupSecrets, + }, + done: make(chan struct{}), + asyncCh: make(chan struct{}), + lastRetrieved: time.Time{}, + metrics: NewSecretFetcherMetrics(reg, tt.secret.AWSSecretsManagerConfig.SecretARN), + reg: reg, + } + fetcher.run() + value, err := fetcher.FetchSecret(context.Background(), tt.secret) + cancel() + fetcher.AwaitStop() + if tt.expectError { + require.Error(t, err) + require.Empty(t, value) + } else { + require.NoError(t, err) + require.Equal(t, tt.expectedValue, value) + } + }) + } +} + +func TestNewSecretFetcherMetrics_Error(t *testing.T) { + fetcher, reg := setupFetcher() + fetcher.client = &MockSecretsManagerClient{ + GetSecretValueFunc: func(ctx context.Context, input *secretsmanager.GetSecretValueInput, opts ...func(*secretsmanager.Options)) (*secretsmanager.GetSecretValueOutput, error) { + return nil, errors.New("test error") + }, + } + // Call retrieveSecret + secretID := "test-secret" + input := &secretsmanager.GetSecretValueInput{SecretId: &secretID} + fetcher.retrieveSecret(input, "testing") + errorCount := testutil.ToFloat64(fetcher.metrics.secretFetchErrors) + require.Equal(t, 1.0, errorCount, "Error counter should be 1") + require.False(t, fetcher.everSucceeded, "everSucceeded should still be false") + + metricFamily, err := reg.Gather() + require.NoError(t, err) + + // Find and check the state metric + var stateValue float64 + for _, mf := range metricFamily { + if mf.GetName() == "alertmanager_remote_secret_state" { + for _, m := range mf.GetMetric() { + for _, l := range m.GetLabel() { + if l.GetName() == "state" && l.GetValue() == "error" { + stateValue = m.GetGauge().GetValue() + } + } + } + } + } + require.Equal(t, float64(Error), stateValue, "State should be Error (2)") +} + +func TestNewSecretFetcherMetrics_Success(t *testing.T) { + fetcher, reg := setupFetcher() + + // Mock successful secret retrieval + secretData, _ := json.Marshal(map[string]string{"key": "value"}) + secretStr := string(secretData) + fetcher.client = &MockSecretsManagerClient{ + GetSecretValueFunc: func(ctx context.Context, input *secretsmanager.GetSecretValueInput, opts ...func(*secretsmanager.Options)) (*secretsmanager.GetSecretValueOutput, error) { + return &secretsmanager.GetSecretValueOutput{ + SecretString: &secretStr, + }, nil + }, + } + // Call retrieveSecret + secretID := "test-secret" + input := &secretsmanager.GetSecretValueInput{SecretId: &secretID} + fetcher.retrieveSecret(input, "testing") + successCount := testutil.ToFloat64(fetcher.metrics.secretFetchSuccess) + require.Equal(t, 1.0, successCount, "Success counter should be 1") + require.True(t, fetcher.everSucceeded, "everSucceeded should be true") + + metricFamily, err := reg.Gather() + require.NoError(t, err) + + // Find and check the state metric + var stateValue float64 + for _, mf := range metricFamily { + if mf.GetName() == "alertmanager_remote_secret_state" { + for _, m := range mf.GetMetric() { + for _, l := range m.GetLabel() { + if l.GetName() == "state" && l.GetValue() == "success" { + stateValue = m.GetGauge().GetValue() + } + } + } + } + } + require.Equal(t, float64(Success), stateValue, "State should be Success (0)") +} + +func TestNewSecretFetcherMetrics_Stale(t *testing.T) { + MinTimeInterval = 5 * time.Millisecond + defer func() { + MinTimeInterval = 5 * time.Second + }() + // Mock successful secret retrieval + secretData, _ := json.Marshal(map[string]string{"key": "value"}) + secretStr := string(secretData) + + fetcher, reg := setupFetcher() + fetcher.client = &MockSecretsManagerClient{ + GetSecretValueFunc: func(ctx context.Context, input *secretsmanager.GetSecretValueInput, opts ...func(*secretsmanager.Options)) (*secretsmanager.GetSecretValueOutput, error) { + return &secretsmanager.GetSecretValueOutput{ + SecretString: &secretStr, + }, nil + }, + } + // Call retrieveSecret + secretID := "test-secret" + input := &secretsmanager.GetSecretValueInput{SecretId: &secretID} + fetcher.retrieveSecret(input, "testing") + successCount := testutil.ToFloat64(fetcher.metrics.secretFetchSuccess) + require.Equal(t, 1.0, successCount, "Success counter should be 1") + require.True(t, fetcher.everSucceeded, "everSucceeded should be true") + + // Now mock a failure + fetcher.client = &MockSecretsManagerClient{ + GetSecretValueFunc: func(ctx context.Context, input *secretsmanager.GetSecretValueInput, opts ...func(*secretsmanager.Options)) (*secretsmanager.GetSecretValueOutput, error) { + return nil, errors.New("test error") + }, + } + + time.Sleep(MinTimeInterval) + fetcher.retrieveSecret(input, "testing") + + // Verify stale state (should be Stale since everSucceeded is true) + errorCount := testutil.ToFloat64(fetcher.metrics.secretFetchErrors) + require.Equal(t, 1.0, errorCount, "Error counter should be 1") + + metricFamily, err := reg.Gather() + require.NoError(t, err) + + // Find and check the state metric + var stateValue float64 + for _, mf := range metricFamily { + if mf.GetName() == "alertmanager_remote_secret_state" { + for _, m := range mf.GetMetric() { + for _, l := range m.GetLabel() { + if l.GetName() == "state" && l.GetValue() == "stale" { + stateValue = m.GetGauge().GetValue() + } + } + } + } + } + require.Equal(t, float64(Stale), stateValue, "State should be Stale (1)") + require.True(t, fetcher.everSucceeded, "everSucceeded should still be true") +} + +func TestAWSSecretsManagerSecretProviderDiscoveryConfig(t *testing.T) { + config := AWSSecretsManagerSecretProviderDiscoveryConfig{} + + t.Run("test name", func(t *testing.T) { + require.Equal(t, "aws_secrets_manager", config.Name()) + }) + + var provider *AWSSecretsManagerProvider + t.Run("test provider creation", func(t *testing.T) { + options := secrets.SecretProviderOptions{ + Logger: promslog.NewNopLogger(), + Registerer: prometheus.NewRegistry(), + Context: context.Background(), + } + + p, err := config.NewSecretsProvider(options) + + require.NoError(t, err) + require.NotNil(t, p) + + // Type assertion to ensure correct type + pt, ok := p.(*AWSSecretsManagerProvider) + provider = pt + require.True(t, ok) + }) + + t.Run("test register secret", func(t *testing.T) { + fetcher := provider.Register(secrets.GenericSecret{ + AWSSecretsManagerConfig: secrets.AWSSecretsManagerConfig{ + SecretARN: "arn:aws:secretsmanager:us-west-2:123456789:secret:receiver-pager-duty", + SecretKey: "test-secret", + RefreshInterval: 5 * time.Millisecond, + }, + }) + require.NotNil(t, fetcher) + require.Equal(t, 1, provider.fetchersCount()) + }) + + t.Run("test register second secret", func(t *testing.T) { + fetcher := provider.Register(secrets.GenericSecret{ + AWSSecretsManagerConfig: secrets.AWSSecretsManagerConfig{ + SecretARN: "arn:aws:secretsmanager:us-west-2:123456789:secret:receiver-pager-duty2", + SecretKey: "test-secret", + RefreshInterval: 5 * time.Millisecond, + }, + }) + require.NotNil(t, fetcher) + require.Equal(t, 2, provider.fetchersCount()) + }) + + t.Run("test register same secret again", func(t *testing.T) { + fetcher := provider.Register(secrets.GenericSecret{ + AWSSecretsManagerConfig: secrets.AWSSecretsManagerConfig{ + SecretARN: "arn:aws:secretsmanager:us-west-2:123456789:secret:receiver-pager-duty2", + SecretKey: "test-secret", + RefreshInterval: 5 * time.Millisecond, + }, + }) + require.NotNil(t, fetcher) + require.Equal(t, 2, provider.fetchersCount()) + }) + + t.Run("test register secret with bad ARN", func(t *testing.T) { + fetcher := provider.Register(secrets.GenericSecret{ + AWSSecretsManagerConfig: secrets.AWSSecretsManagerConfig{ + SecretARN: "arn:aws::us-west-2:123456789:secret:receiver-pager-duty2", + SecretKey: "test-secret", + RefreshInterval: 5 * time.Millisecond, + }, + }) + require.Nil(t, fetcher) + require.Equal(t, 2, provider.fetchersCount()) + }) +} + +func TestRegisterSecret(t *testing.T) { + reg := prometheus.NewPedanticRegistry() + ctx, cancel := context.WithCancel(context.Background()) + options := secrets.SecretProviderOptions{ + Logger: promslog.NewNopLogger(), + Registerer: reg, + Context: ctx, + } + provider := NewAWSSecretsManagerProvider(options) + secretOne := secrets.GenericSecret{ + AWSSecretsManagerConfig: secrets.AWSSecretsManagerConfig{ + SecretARN: "arn:aws:secretsmanager:us-west-2:123456789:secret:receiver-pager-duty", + SecretKey: "key1", + RefreshInterval: 5 * time.Millisecond, + }, + } + secretOneCopy := secrets.GenericSecret{ + AWSSecretsManagerConfig: secrets.AWSSecretsManagerConfig{ + SecretARN: "arn:aws:secretsmanager:us-west-2:123456789:secret:receiver-pager-duty", + SecretKey: "key2", + RefreshInterval: 5 * time.Millisecond, + }, + } + provider.Register(secretOne) + require.Equal(t, 1, provider.fetchersCount()) + + provider.Register(secretOneCopy) + require.Equal(t, 1, provider.fetchersCount()) + + secretTwo := secrets.GenericSecret{ + AWSSecretsManagerConfig: secrets.AWSSecretsManagerConfig{ + SecretARN: "arn:aws:secretsmanager:us-west-2:123456789:secret:receiver-pager-duty-2", + SecretKey: "key1", + RefreshInterval: 5 * time.Millisecond, + }, + } + provider.Register(secretTwo) + require.Equal(t, 2, provider.fetchersCount()) + provider.UpdateComplete() + require.Equal(t, 2, provider.fetchersCount()) + + // simulate an update + provider.Register(secretTwo) + provider.UpdateComplete() + require.Equal(t, 1, provider.fetchersCount()) + cancel() + provider.Stop() +} + +type MockSecretsManagerClient struct { + secretsmanager.Client + secrets map[string]string + GetSecretValueFunc func(context.Context, *secretsmanager.GetSecretValueInput, ...func(*secretsmanager.Options)) (*secretsmanager.GetSecretValueOutput, error) +} + +func (m *MockSecretsManagerClient) GetSecretValue(ctx context.Context, params *secretsmanager.GetSecretValueInput, optFns ...func(*secretsmanager.Options)) (*secretsmanager.GetSecretValueOutput, error) { + if m.GetSecretValueFunc != nil { + return m.GetSecretValueFunc(ctx, params, optFns...) + } + secretKey := *params.SecretId + if secretValue, ok := m.secrets[secretKey]; ok { + return &secretsmanager.GetSecretValueOutput{ + SecretString: &secretValue, + }, nil + } + return nil, errors.New("not found") +} + +func setupFetcher() (*secretFetcher, *prometheus.Registry) { + reg := prometheus.NewPedanticRegistry() + secretARN := "arn:aws:secretsmanager:us-west-2:123456789:secret:receiver-pager-duty" + return &secretFetcher{ + secrets: make(map[string]string), + logger: promslog.NewNopLogger(), + secretConfig: secrets.AWSSecretsManagerConfig{ + SecretARN: secretARN, + SecretKey: "test-secret", + RefreshInterval: 5 * time.Millisecond, + }, + ticker: time.NewTicker(time.Second), + done: make(chan struct{}), + asyncCh: make(chan struct{}), + lastRetrieved: time.Time{}, + ctx: context.Background(), + metrics: NewSecretFetcherMetrics(reg, secretARN), + }, reg +} diff --git a/secrets/secrets_provider.go b/secrets/secrets_provider.go index 407a9505e0..b51a6b4462 100644 --- a/secrets/secrets_provider.go +++ b/secrets/secrets_provider.go @@ -1,26 +1,40 @@ +// Copyright 2019 Prometheus Team +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + package secrets import ( "context" "errors" - "github.com/prometheus/client_golang/prometheus" - "github.com/prometheus/common/config" "log/slog" "sync" -) -var ( - AWS_SECRETS_MANAGER_PROVIDER = "aws_secrets_manager" + "github.com/prometheus/client_golang/prometheus" ) +var AWSSecretsManagerProviderName = "aws_secrets_manager" + type SecretsFetcher interface { - FetchSecret(ctx context.Context, secret *GenericSecret) (string, error) + FetchSecret(ctx context.Context, secret GenericSecret) (string, error) + RefreshCredentialsAsync() Stop() + AwaitStop() } type SecretsProvider interface { - Register(secret *GenericSecret) SecretsFetcher + Register(secret GenericSecret) SecretsFetcher Stop() + UpdateComplete() } type SecretsProviderRegistry struct { @@ -43,11 +57,19 @@ func NewSecretsProviderRegistry(logger *slog.Logger, reg prometheus.Registerer) return registry } -func (s *SecretsProviderRegistry) Register(config SecretProviderDiscoveryConfig) { +func (s *SecretsProviderRegistry) Register(config SecretProviderDiscoveryConfig) error { + if config == nil { + return errors.New("nil config provided") + } + name := config.Name() + if name == "" { + return errors.New("empty provider name") + } s.mtx.Lock() defer s.mtx.Unlock() - s.logger.Info("registering secret providers", "name", config.Name()) + s.logger.Info("registering secret provider", "name", config.Name()) s.configs[config.Name()] = config + return nil } func (s *SecretsProviderRegistry) Init() { @@ -69,6 +91,15 @@ func (s *SecretsProviderRegistry) Init() { } } +func (s *SecretsProviderRegistry) UpdateComplete() { + s.mtx.RLock() + defer s.mtx.RUnlock() + for name, provider := range s.providers { + s.logger.Info("update complete invoked on provider", "name", name) + provider.UpdateComplete() + } +} + func (s *SecretsProviderRegistry) Stop() { if s == nil { return @@ -84,17 +115,25 @@ func (s *SecretsProviderRegistry) Stop() { s.logger.Info("stopping secrets providers", "name", name) provider.Stop() } + s.providers = nil s.logger.Info("stopped secrets providers registry") } -func (s *SecretsProviderRegistry) RegisterSecret(secret *GenericSecret) (SecretsFetcher, error) { +func (s *SecretsProviderRegistry) RegisterSecret(secret GenericSecret) (SecretsFetcher, error) { s.mtx.RLock() defer s.mtx.RUnlock() s.logger.Info("registering secret") - if secret.AWSSecretsManagerConfig != nil { + if !secret.AWSSecretsManagerConfig.IsZero() { + provider, exists := s.providers[AWSSecretsManagerProviderName] + if !exists { + return nil, errors.New("AWS secrets manager provider not initialized") + } s.logger.Info("registering aws_secret_manager secret") - return s.providers[AWS_SECRETS_MANAGER_PROVIDER].Register(secret), nil + return provider.Register(secret), nil + } + if secret.Inline.Secret != "" { + return InlineSecretsFetcher{}, nil } return nil, errors.New("no secrets fetcher found for the given secret") } @@ -112,7 +151,15 @@ type SecretProviderOptions struct { // A registerer for the SecretProvider's metrics. Registerer prometheus.Registerer - HTTPClientOptions []config.HTTPClientOption - Context context.Context } + +type InlineSecretsFetcher struct{} + +func (i InlineSecretsFetcher) FetchSecret(ctx context.Context, secret GenericSecret) (string, error) { + return secret.Inline.Secret, nil +} + +func (i InlineSecretsFetcher) RefreshCredentialsAsync() {} +func (i InlineSecretsFetcher) Stop() {} +func (i InlineSecretsFetcher) AwaitStop() {} diff --git a/secrets/secrets_provider_test.go b/secrets/secrets_provider_test.go new file mode 100644 index 0000000000..60a2aa1209 --- /dev/null +++ b/secrets/secrets_provider_test.go @@ -0,0 +1,290 @@ +// Copyright 2019 Prometheus Team +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package secrets + +import ( + "context" + "errors" + "log/slog" + "testing" + + "github.com/prometheus/client_golang/prometheus" + "github.com/prometheus/common/promslog" + + "github.com/stretchr/testify/require" +) + +func TestNewSecretsProviderRegistry(t *testing.T) { + logger := slog.Default() + reg := prometheus.NewRegistry() + registry := NewSecretsProviderRegistry(logger, reg) + + require.NotNil(t, registry) + require.NotNil(t, registry.providers) + require.NotNil(t, registry.configs) + require.NotNil(t, registry.logger) + require.NotNil(t, registry.reg) +} + +func TestNewSecretsProviderRegistryInit(t *testing.T) { + logger := slog.Default() + reg := prometheus.NewRegistry() + + tests := []struct { + name string + discoveryConfig SecretProviderDiscoveryConfig + wantErr bool + registries []string + }{ + { + name: "successful initialization", + discoveryConfig: MockSecretProviderDiscoveryConfig{}, + wantErr: false, + registries: []string{"mock secret provider"}, + }, + { + name: "unsuccessful initialization", + discoveryConfig: nil, + wantErr: true, + }, + { + name: "unsuccessful initialization with no name provider", + discoveryConfig: MockSecretProviderDiscoveryConfigNoName{}, + wantErr: true, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + registry := NewSecretsProviderRegistry(logger, reg) + err := registry.Register(tt.discoveryConfig) + if tt.wantErr { + require.Error(t, err) + return + } + require.NoError(t, err) + registry.Init() + + for _, r := range tt.registries { + require.NotNil(t, registry.providers[r]) + } + }) + } +} + +func TestRegisterSecret(t *testing.T) { + tests := []struct { + name string + secret GenericSecret + setupRegistry func(*SecretsProviderRegistry) + wantErr bool + errMessage string + }{ + { + name: "successful AWS secrets manager registration", + secret: GenericSecret{ + AWSSecretsManagerConfig: AWSSecretsManagerConfig{ + SecretARN: "arn:aws:secretsmanager:region:123456789012:secret:my-secret", + }, + }, + setupRegistry: func(registry *SecretsProviderRegistry) { + mockProvider := &MockSecretsProvider{} + registry.providers[AWSSecretsManagerProviderName] = mockProvider + }, + wantErr: false, + }, + { + name: "AWS secrets manager provider not initialized", + secret: GenericSecret{ + AWSSecretsManagerConfig: AWSSecretsManagerConfig{ + SecretARN: "arn:aws:secretsmanager:region:123456789012:secret:my-secret", + }, + }, + setupRegistry: func(registry *SecretsProviderRegistry) {}, + wantErr: true, + errMessage: "AWS secrets manager provider not initialized", + }, + { + name: "successful inline secret registration", + secret: GenericSecret{ + Inline: Inline{ + Secret: "test-secret", + }, + }, + setupRegistry: func(registry *SecretsProviderRegistry) {}, + wantErr: false, + }, + { + name: "no valid secret configuration", + secret: GenericSecret{}, + setupRegistry: func(registry *SecretsProviderRegistry) {}, + wantErr: true, + errMessage: "no secrets fetcher found for the given secret", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + registry := &SecretsProviderRegistry{ + logger: slog.Default(), + providers: make(map[string]SecretsProvider), + } + tt.setupRegistry(registry) + + fetcher, err := registry.RegisterSecret(tt.secret) + + if tt.wantErr { + require.Error(t, err) + require.Equal(t, tt.errMessage, err.Error()) + require.Nil(t, fetcher) + } else { + require.NoError(t, err) + require.NotNil(t, fetcher) + } + registry.Stop() + }) + } +} + +func TestSecretsProviderRegistry_Stop(t *testing.T) { + t.Run("stop all providers", func(t *testing.T) { + registry := NewSecretsProviderRegistry(promslog.NewNopLogger(), prometheus.NewPedanticRegistry()) + sp1 := &MockSecretsProvider{} + sp2 := &MockSecretsProvider{} + discoveryconfig1 := &MockCustomizableSecretProviderDiscoveryConfig{ + ProviderName: "provider-1", + SecretsProvider: sp1, + } + discoveryconfig2 := &MockCustomizableSecretProviderDiscoveryConfig{ + ProviderName: "provider-2", + SecretsProvider: sp2, + } + err := registry.Register(discoveryconfig1) + require.NoError(t, err) + err = registry.Register(discoveryconfig2) + require.NoError(t, err) + registry.Init() + registry.Stop() + require.Nil(t, registry.providers) + require.True(t, sp1.Stopped) + require.True(t, sp2.Stopped) + }) +} + +func TestInlineSecretsFetcher(t *testing.T) { + tests := []struct { + name string + secret GenericSecret + want string + wantErr bool + }{ + { + name: "fetch inline secret successfully", + secret: GenericSecret{ + Inline: Inline{ + Secret: "test-secret-value", + }, + }, + want: "test-secret-value", + wantErr: false, + }, + { + name: "fetch empty inline secret", + secret: GenericSecret{ + Inline: Inline{ + Secret: "", + }, + }, + want: "", + wantErr: false, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + fetcher := InlineSecretsFetcher{} + got, err := fetcher.FetchSecret(context.Background(), tt.secret) + + if tt.wantErr { + require.Error(t, err) + } else { + require.NoError(t, err) + require.Equal(t, tt.want, got) + } + }) + } +} + +type MockSecretsProvider struct { + Stopped bool +} + +func (m *MockSecretsProvider) Stop() { + m.Stopped = true +} + +func (m *MockSecretsProvider) Register(secret GenericSecret) SecretsFetcher { + return &MockSecretsFetcher{} +} + +func (m *MockSecretsProvider) UpdateComplete() {} + +type MockSecretsFetcher struct{} + +func (m *MockSecretsFetcher) FetchSecret(ctx context.Context, secret GenericSecret) (string, error) { + return "", nil +} + +func (m *MockSecretsFetcher) RefreshCredentialsAsync() {} + +func (m *MockSecretsFetcher) Stop() {} + +func (m *MockSecretsFetcher) AwaitStop() {} + +type MockSecretProviderDiscoveryConfig struct { + wantErr bool +} + +func (m MockSecretProviderDiscoveryConfig) Name() string { + return "mock secret provider" +} + +func (m MockSecretProviderDiscoveryConfig) NewSecretsProvider(options SecretProviderOptions) (SecretsProvider, error) { + if m.wantErr { + return nil, errors.New("unable to initialize secrets provider") + } + return &MockSecretsProvider{}, nil +} + +type MockSecretProviderDiscoveryConfigNoName struct{} + +func (m MockSecretProviderDiscoveryConfigNoName) Name() string { + return "" +} + +func (m MockSecretProviderDiscoveryConfigNoName) NewSecretsProvider(options SecretProviderOptions) (SecretsProvider, error) { + panic("not implemented") +} + +type MockCustomizableSecretProviderDiscoveryConfig struct { + ProviderName string + SecretsProvider SecretsProvider +} + +func (m MockCustomizableSecretProviderDiscoveryConfig) Name() string { + return m.ProviderName +} + +func (m MockCustomizableSecretProviderDiscoveryConfig) NewSecretsProvider(options SecretProviderOptions) (SecretsProvider, error) { + return m.SecretsProvider, nil +}