diff --git a/prometheus/prometheus.go b/prometheus/prometheus.go index 0a897c1..c0b4a8f 100644 --- a/prometheus/prometheus.go +++ b/prometheus/prometheus.go @@ -11,6 +11,7 @@ import ( "log" "strings" "sync" + "sync/atomic" "time" "github.com/hashicorp/go-metrics" @@ -74,9 +75,8 @@ type GaugeDefinition struct { type gauge struct { prometheus.Gauge - updatedAt time.Time - // canDelete is set if the metric is created during runtime so we know it's ephemeral and can delete it on expiry. - canDelete bool + updatedAtNS atomic.Int64 + canDelete atomic.Bool } // SummaryDefinition can be provided to PrometheusOpts to declare a constant summary that is not deleted on expiry. @@ -88,8 +88,8 @@ type SummaryDefinition struct { type summary struct { prometheus.Summary - updatedAt time.Time - canDelete bool + updatedAtNS atomic.Int64 + canDelete atomic.Bool } // CounterDefinition can be provided to PrometheusOpts to declare a constant counter that is not deleted on expiry. @@ -101,8 +101,8 @@ type CounterDefinition struct { type counter struct { prometheus.Counter - updatedAt time.Time - canDelete bool + updatedAtNS atomic.Int64 + canDelete atomic.Bool } // NewPrometheusSink creates a new PrometheusSink using the default options. @@ -158,50 +158,39 @@ func (p *PrometheusSink) Collect(c chan<- prometheus.Metric) { // collectAtTime allows internal testing of the expiry based logic here without // mocking clocks or making tests timing sensitive. func (p *PrometheusSink) collectAtTime(c chan<- prometheus.Metric, t time.Time) { - expire := p.expiration != 0 - p.gauges.Range(func(k, v interface{}) bool { - if v == nil { - return true + // Counters + p.counters.Range(func(k, v any) bool { + cnt := v.(*counter) + cnt.Collect(c) + last := time.Unix(0, cnt.updatedAtNS.Load()) + stale := p.expiration > 0 && t.Sub(last) > p.expiration + if stale && cnt.canDelete.Load() { + p.counters.Delete(k) } + return true + }) + + // Gauges + p.gauges.Range(func(k, v any) bool { g := v.(*gauge) - lastUpdate := g.updatedAt - if expire && lastUpdate.Add(p.expiration).Before(t) { - if g.canDelete { - p.gauges.Delete(k) - return true - } - } g.Collect(c) + last := time.Unix(0, g.updatedAtNS.Load()) + stale := p.expiration > 0 && t.Sub(last) > p.expiration + if stale && g.canDelete.Load() { + p.gauges.Delete(k) + } return true }) - p.summaries.Range(func(k, v interface{}) bool { - if v == nil { - return true - } + + // Summaries + p.summaries.Range(func(k, v any) bool { s := v.(*summary) - lastUpdate := s.updatedAt - if expire && lastUpdate.Add(p.expiration).Before(t) { - if s.canDelete { - p.summaries.Delete(k) - return true - } - } s.Collect(c) - return true - }) - p.counters.Range(func(k, v interface{}) bool { - if v == nil { - return true + last := time.Unix(0, s.updatedAtNS.Load()) + stale := p.expiration > 0 && t.Sub(last) > p.expiration + if stale && s.canDelete.Load() { + p.summaries.Delete(k) } - count := v.(*counter) - lastUpdate := count.updatedAt - if expire && lastUpdate.Add(p.expiration).Before(t) { - if count.canDelete { - p.counters.Delete(k) - return true - } - } - count.Collect(c) return true }) } @@ -283,40 +272,32 @@ func (p *PrometheusSink) SetPrecisionGauge(parts []string, val float64) { func (p *PrometheusSink) SetPrecisionGaugeWithLabels(parts []string, val float64, labels []metrics.Label) { key, hash := flattenKey(parts, labels) - pg, ok := p.gauges.Load(hash) - - // The sync.Map underlying gauges stores pointers to our structs. If we need to make updates, - // rather than modifying the underlying value directly, which would be racy, we make a local - // copy by dereferencing the pointer we get back, making the appropriate changes, and then - // storing a pointer to our local copy. The underlying Prometheus types are threadsafe, - // so there's no issues there. It's possible for racy updates to occur to the updatedAt - // value, but since we're always setting it to time.Now(), it doesn't really matter. - if ok { - localGauge := *pg.(*gauge) - localGauge.Set(val) - localGauge.updatedAt = time.Now() - p.gauges.Store(hash, &localGauge) - - // The gauge does not exist, create the gauge and allow it to be deleted - } else { - help := key - existingHelp, ok := p.help[fmt.Sprintf("gauge.%s", key)] - if ok { - help = existingHelp - } - g := prometheus.NewGauge(prometheus.GaugeOpts{ + + // Fast path: use existing + if v, ok := p.gauges.Load(hash); ok { + g := v.(*gauge) + g.Set(val) + g.updatedAtNS.Store(time.Now().UnixNano()) + return + } + + // Create-or-get single instance + help := p.help[fmt.Sprintf("gauge.%s", key)] + if help == "" { + help = key + } + w := &gauge{ + Gauge: prometheus.NewGauge(prometheus.GaugeOpts{ Name: key, Help: help, ConstLabels: prometheusLabels(labels), - }) - g.Set(val) - pg = &gauge{ - Gauge: g, - updatedAt: time.Now(), - canDelete: true, - } - p.gauges.Store(hash, pg) + }), } + w.canDelete.Store(true) + actual, _ := p.gauges.LoadOrStore(hash, w) + g := actual.(*gauge) + g.Set(val) + g.updatedAtNS.Store(time.Now().UnixNano()) } func (p *PrometheusSink) AddSample(parts []string, val float32) { @@ -325,37 +306,32 @@ func (p *PrometheusSink) AddSample(parts []string, val float32) { func (p *PrometheusSink) AddSampleWithLabels(parts []string, val float32, labels []metrics.Label) { key, hash := flattenKey(parts, labels) - ps, ok := p.summaries.Load(hash) - - // Does the summary already exist for this sample type? - if ok { - localSummary := *ps.(*summary) - localSummary.Observe(float64(val)) - localSummary.updatedAt = time.Now() - p.summaries.Store(hash, &localSummary) - - // The summary does not exist, create the Summary and allow it to be deleted - } else { - help := key - existingHelp, ok := p.help[fmt.Sprintf("summary.%s", key)] - if ok { - help = existingHelp - } - s := prometheus.NewSummary(prometheus.SummaryOpts{ + + if v, ok := p.summaries.Load(hash); ok { + s := v.(*summary) + s.Observe(float64(val)) + s.updatedAtNS.Store(time.Now().UnixNano()) + return + } + + help := p.help[fmt.Sprintf("summary.%s", key)] + if help == "" { + help = key + } + w := &summary{ + Summary: prometheus.NewSummary(prometheus.SummaryOpts{ Name: key, Help: help, MaxAge: 10 * time.Second, ConstLabels: prometheusLabels(labels), Objectives: map[float64]float64{0.5: 0.05, 0.9: 0.01, 0.99: 0.001}, - }) - s.Observe(float64(val)) - ps = &summary{ - Summary: s, - updatedAt: time.Now(), - canDelete: true, - } - p.summaries.Store(hash, ps) + }), } + w.canDelete.Store(true) + actual, _ := p.summaries.LoadOrStore(hash, w) + s := actual.(*summary) + s.Observe(float64(val)) + s.updatedAtNS.Store(time.Now().UnixNano()) } // EmitKey is not implemented. Prometheus doesn’t offer a type for which an @@ -370,42 +346,36 @@ func (p *PrometheusSink) IncrCounter(parts []string, val float32) { func (p *PrometheusSink) IncrCounterWithLabels(parts []string, val float32, labels []metrics.Label) { key, hash := flattenKey(parts, labels) - pc, ok := p.counters.Load(hash) - // Prometheus Counter.Add() panics if val < 0. We don't want this to - // cause applications to crash, so log an error instead. + // Prometheus Counter.Add() panics if val < 0; log and return. if val < 0 { - log.Printf("[ERR] Attempting to increment Prometheus counter %v with value negative value %v", key, val) + log.Printf("[ERR] 'IncrCounterWithLabels' called with a negative value: %v", val) return } - // Does the counter exist? - if ok { - localCounter := *pc.(*counter) - localCounter.Add(float64(val)) - localCounter.updatedAt = time.Now() - p.counters.Store(hash, &localCounter) - - // The counter does not exist yet, create it and allow it to be deleted - } else { - help := key - existingHelp, ok := p.help[fmt.Sprintf("counter.%s", key)] - if ok { - help = existingHelp - } - c := prometheus.NewCounter(prometheus.CounterOpts{ + if v, ok := p.counters.Load(hash); ok { + c := v.(*counter) + c.Add(float64(val)) + c.updatedAtNS.Store(time.Now().UnixNano()) + return + } + + help := p.help[fmt.Sprintf("counter.%s", key)] + if help == "" { + help = key + } + w := &counter{ + Counter: prometheus.NewCounter(prometheus.CounterOpts{ Name: key, Help: help, ConstLabels: prometheusLabels(labels), - }) - c.Add(float64(val)) - pc = &counter{ - Counter: c, - updatedAt: time.Now(), - canDelete: true, - } - p.counters.Store(hash, pc) + }), } + w.canDelete.Store(true) + actual, _ := p.counters.LoadOrStore(hash, w) + c := actual.(*counter) + c.Add(float64(val)) + c.updatedAtNS.Store(time.Now().UnixNano()) } // PrometheusPushSink wraps a normal prometheus sink and provides an address and facilities to export it to an address diff --git a/prometheus/prometheus_race_test.go b/prometheus/prometheus_race_test.go new file mode 100644 index 0000000..bc312ad --- /dev/null +++ b/prometheus/prometheus_race_test.go @@ -0,0 +1,65 @@ +// Copyright (c) HashiCorp, Inc. +// SPDX-License-Identifier: MIT + +package prometheus + +// This test demonstrates a race condition when using PrometheusSink when run from multiple +// goroutines concurrently resulting in missed updates. + +import ( + "sync" + "testing" + "time" + + "github.com/prometheus/client_golang/prometheus" + dto "github.com/prometheus/client_model/go" +) + +func TestPrometheusRaceCondition(t *testing.T) { + reg := prometheus.NewRegistry() + + promSink, err := NewPrometheusSinkFrom(PrometheusOpts{ + Registerer: reg, + }) + if err != nil { + t.Fatal(err) + } + + nrGoroutines := 20 + incrementsPerGoroutine := 1000 + expectedTotal := int64(nrGoroutines * incrementsPerGoroutine) + + var wg sync.WaitGroup + for range nrGoroutines { + wg.Add(1) + go func() { + for range incrementsPerGoroutine { + promSink.IncrCounter([]string{"race", "test", "counter"}, 1) + } + wg.Done() + }() + } + wg.Wait() + + // Collect metrics after all updates + timeAfterUpdates := time.Now() + ch := make(chan prometheus.Metric, 10) + promSink.collectAtTime(ch, timeAfterUpdates) + + // Read and verify the counter + select { + case m := <-ch: + var pb dto.Metric + if err := m.Write(&pb); err != nil { + t.Fatalf("unexpected error reading metric: %s", err) + } + if pb.Counter == nil { + t.Fatalf("expected counter metric, got %v", pb) + } + if *pb.Counter.Value != float64(expectedTotal) { + t.Fatalf("expected counter value %d, got %f", expectedTotal, *pb.Counter.Value) + } + case <-time.After(100 * time.Millisecond): + t.Fatalf("timed out waiting to collect counter metric") + } +} diff --git a/prometheus/prometheus_test.go b/prometheus/prometheus_test.go index 6829314..19f058c 100644 --- a/prometheus/prometheus_test.go +++ b/prometheus/prometheus_test.go @@ -258,9 +258,9 @@ func fakeServer(q chan string) *httptest.Server { Help: proto.String("default_one_two"), Type: dto.MetricType_GAUGE.Enum(), Metric: []*dto.Metric{ - &dto.Metric{ + { Label: []*dto.LabelPair{ - &dto.LabelPair{ + { Name: proto.String("host"), Value: proto.String(MockGetHostname()), }, @@ -358,9 +358,9 @@ func TestDefinitionsWithLabels(t *testing.T) { {Name: "version", Value: "some info"}, }) sink.gauges.Range(func(key, value interface{}) bool { - localGauge := *value.(*gauge) - if !strings.Contains(localGauge.Desc().String(), gaugeDef.Help) { - t.Fatalf("expected gauge to include correct help=%s, but was %s", gaugeDef.Help, localGauge.Desc().String()) + g := value.(*gauge) + if !strings.Contains(g.Desc().String(), gaugeDef.Help) { + t.Fatalf("expected gauge to include correct help=%s, but was %s", gaugeDef.Help, g.Desc().String()) } return true }) @@ -369,9 +369,9 @@ func TestDefinitionsWithLabels(t *testing.T) { {Name: "version", Value: "some info"}, }) sink.summaries.Range(func(key, value interface{}) bool { - metric := *value.(*summary) - if !strings.Contains(metric.Desc().String(), summaryDef.Help) { - t.Fatalf("expected gauge to include correct help=%s, but was %s", summaryDef.Help, metric.Desc().String()) + s := value.(*summary) + if !strings.Contains(s.Desc().String(), summaryDef.Help) { + t.Fatalf("expected gauge to include correct help=%s, but was %s", summaryDef.Help, s.Desc().String()) } return true }) @@ -380,9 +380,9 @@ func TestDefinitionsWithLabels(t *testing.T) { {Name: "version", Value: "some info"}, }) sink.counters.Range(func(key, value interface{}) bool { - metric := *value.(*counter) - if !strings.Contains(metric.Desc().String(), counterDef.Help) { - t.Fatalf("expected gauge to include correct help=%s, but was %s", counterDef.Help, metric.Desc().String()) + c := value.(*counter) + if !strings.Contains(c.Desc().String(), counterDef.Help) { + t.Fatalf("expected gauge to include correct help=%s, but was %s", counterDef.Help, c.Desc().String()) } return true })