diff --git a/gcc/arrival_group_accumulator.go b/gcc/arrival_group_accumulator.go new file mode 100644 index 0000000..9cc8719 --- /dev/null +++ b/gcc/arrival_group_accumulator.go @@ -0,0 +1,85 @@ +// SPDX-FileCopyrightText: 2025 The Pion community +// SPDX-License-Identifier: MIT + +package gcc + +import ( + "time" +) + +type arrivalGroupItem struct { + SequenceNumber uint64 + Departure time.Time + Arrival time.Time + Size int +} + +type arrivalGroup []arrivalGroupItem + +type arrivalGroupAccumulator struct { + next arrivalGroup + burstInterval time.Duration + maxBurstDuration time.Duration +} + +func newArrivalGroupAccumulator() *arrivalGroupAccumulator { + return &arrivalGroupAccumulator{ + next: make([]arrivalGroupItem, 0), + burstInterval: 5 * time.Millisecond, + maxBurstDuration: 5 * time.Millisecond, + } +} + +func (a *arrivalGroupAccumulator) onPacketAcked( + sequenceNumber uint64, + size int, + departure, arrival time.Time, +) arrivalGroup { + if len(a.next) == 0 { + a.next = append(a.next, arrivalGroupItem{ + SequenceNumber: sequenceNumber, + Size: size, + Departure: departure, + Arrival: arrival, + }) + + return nil + } + + sendTimeDelta := departure.Sub(a.next[0].Departure) + if sendTimeDelta < a.burstInterval { + a.next = append(a.next, arrivalGroupItem{ + SequenceNumber: sequenceNumber, + Size: size, + Departure: departure, + Arrival: arrival, + }) + + return nil + } + + arrivalTimeDeltaFirst := arrival.Sub(a.next[0].Arrival) + propagationDelta := arrivalTimeDeltaFirst - sendTimeDelta + + if propagationDelta < 0 && arrivalTimeDeltaFirst < a.maxBurstDuration { + a.next = append(a.next, arrivalGroupItem{ + SequenceNumber: sequenceNumber, + Size: size, + Departure: departure, + Arrival: arrival, + }) + + return nil + } + + group := make(arrivalGroup, len(a.next)) + copy(group, a.next) + a.next = arrivalGroup{arrivalGroupItem{ + SequenceNumber: sequenceNumber, + Size: size, + Departure: departure, + Arrival: arrival, + }} + + return group +} diff --git a/gcc/arrival_group_accumulator_test.go b/gcc/arrival_group_accumulator_test.go new file mode 100644 index 0000000..213856b --- /dev/null +++ b/gcc/arrival_group_accumulator_test.go @@ -0,0 +1,244 @@ +// SPDX-FileCopyrightText: 2025 The Pion community +// SPDX-License-Identifier: MIT + +package gcc + +import ( + "testing" + "time" + + "github.com/stretchr/testify/assert" +) + +func TestArrivalGroupAccumulator(t *testing.T) { + type logItem struct { + SequenceNumber uint64 + Departure time.Time + Arrival time.Time + } + triggerNewGroupElement := logItem{ + Departure: time.Time{}.Add(time.Second), + Arrival: time.Time{}.Add(time.Second), + } + cases := []struct { + name string + log []logItem + exp []arrivalGroup + }{ + { + name: "emptyCreatesNoGroups", + log: []logItem{}, + exp: []arrivalGroup{}, + }, + { + name: "createsSingleElementGroup", + log: []logItem{ + { + Departure: time.Time{}, + Arrival: time.Time{}.Add(time.Millisecond), + }, + triggerNewGroupElement, + }, + exp: []arrivalGroup{ + { + { + Departure: time.Time{}, + Arrival: time.Time{}.Add(time.Millisecond), + }, + }, + }, + }, + { + name: "createsTwoElementGroup", + log: []logItem{ + { + Departure: time.Time{}, + Arrival: time.Time{}.Add(15 * time.Millisecond), + }, + { + Departure: time.Time{}.Add(3 * time.Millisecond), + Arrival: time.Time{}.Add(20 * time.Millisecond), + }, + triggerNewGroupElement, + }, + exp: []arrivalGroup{{ + { + Departure: time.Time{}, + Arrival: time.Time{}.Add(15 * time.Millisecond), + }, + { + Departure: time.Time{}.Add(3 * time.Millisecond), + Arrival: time.Time{}.Add(20 * time.Millisecond), + }, + }}, + }, + { + name: "createsTwoArrivalGroups1", + log: []logItem{ + { + Departure: time.Time{}, + Arrival: time.Time{}.Add(15 * time.Millisecond), + }, + { + Departure: time.Time{}.Add(3 * time.Millisecond), + Arrival: time.Time{}.Add(20 * time.Millisecond), + }, + { + Departure: time.Time{}.Add(9 * time.Millisecond), + Arrival: time.Time{}.Add(24 * time.Millisecond), + }, + triggerNewGroupElement, + }, + exp: []arrivalGroup{ + { + { + Departure: time.Time{}, + Arrival: time.Time{}.Add(15 * time.Millisecond), + }, + { + Departure: time.Time{}.Add(3 * time.Millisecond), + Arrival: time.Time{}.Add(20 * time.Millisecond), + }, + }, + { + { + Departure: time.Time{}.Add(9 * time.Millisecond), + Arrival: time.Time{}.Add(24 * time.Millisecond), + }, + }, + }, + }, + { + name: "ignoresOutOfOrderPackets", + log: []logItem{ + { + Departure: time.Time{}, + Arrival: time.Time{}.Add(15 * time.Millisecond), + }, + { + Departure: time.Time{}.Add(6 * time.Millisecond), + Arrival: time.Time{}.Add(34 * time.Millisecond), + }, + { + Departure: time.Time{}.Add(8 * time.Millisecond), + Arrival: time.Time{}.Add(30 * time.Millisecond), + }, + triggerNewGroupElement, + }, + exp: []arrivalGroup{ + { + { + Departure: time.Time{}, + Arrival: time.Time{}.Add(15 * time.Millisecond), + }, + }, + { + { + Departure: time.Time{}.Add(6 * time.Millisecond), + Arrival: time.Time{}.Add(34 * time.Millisecond), + }, + { + Departure: time.Time{}.Add(8 * time.Millisecond), + Arrival: time.Time{}.Add(30 * time.Millisecond), + }, + }, + }, + }, + { + name: "newGroupBecauseOfInterDepartureTime", + log: []logItem{ + { + SequenceNumber: 0, + Departure: time.Time{}, + Arrival: time.Time{}.Add(4 * time.Millisecond), + }, + { + SequenceNumber: 1, + Departure: time.Time{}.Add(3 * time.Millisecond), + Arrival: time.Time{}.Add(4 * time.Millisecond), + }, + { + SequenceNumber: 2, + Departure: time.Time{}.Add(6 * time.Millisecond), + Arrival: time.Time{}.Add(10 * time.Millisecond), + }, + { + SequenceNumber: 3, + Departure: time.Time{}.Add(9 * time.Millisecond), + Arrival: time.Time{}.Add(10 * time.Millisecond), + }, + triggerNewGroupElement, + }, + exp: []arrivalGroup{ + { + { + SequenceNumber: 0, + Departure: time.Time{}, + Arrival: time.Time{}.Add(4 * time.Millisecond), + }, + { + SequenceNumber: 1, + Departure: time.Time{}.Add(3 * time.Millisecond), + Arrival: time.Time{}.Add(4 * time.Millisecond), + }, + }, + { + { + SequenceNumber: 2, + Departure: time.Time{}.Add(6 * time.Millisecond), + Arrival: time.Time{}.Add(10 * time.Millisecond), + }, + { + SequenceNumber: 3, + Departure: time.Time{}.Add(9 * time.Millisecond), + Arrival: time.Time{}.Add(10 * time.Millisecond), + }, + }, + }, + }, + { + name: "createsSingleGroupArrivalBurst", + log: []logItem{ + { + SequenceNumber: 0, + Departure: time.Time{}, + Arrival: time.Time{}.Add(10 * time.Millisecond), + }, + { + SequenceNumber: 1, + Departure: time.Time{}.Add(10 * time.Millisecond), + Arrival: time.Time{}.Add(12 * time.Millisecond), + }, + triggerNewGroupElement, + }, + exp: []arrivalGroup{ + { + { + SequenceNumber: 0, + Departure: time.Time{}, + Arrival: time.Time{}.Add(10 * time.Millisecond), + }, + { + SequenceNumber: 1, + Departure: time.Time{}.Add(10 * time.Millisecond), + Arrival: time.Time{}.Add(12 * time.Millisecond), + }, + }, + }, + }, + } + + for _, tc := range cases { + t.Run(tc.name, func(t *testing.T) { + aga := newArrivalGroupAccumulator() + received := []arrivalGroup{} + for _, ack := range tc.log { + next := aga.onPacketAcked(ack.SequenceNumber, 0, ack.Departure, ack.Arrival) + if next != nil { + received = append(received, next) + } + } + assert.Equal(t, tc.exp, received) + }) + } +} diff --git a/gcc/delay_rate_controller.go b/gcc/delay_rate_controller.go new file mode 100644 index 0000000..ebcba6b --- /dev/null +++ b/gcc/delay_rate_controller.go @@ -0,0 +1,147 @@ +// SPDX-FileCopyrightText: 2025 The Pion community +// SPDX-License-Identifier: MIT + +package gcc + +import ( + "math" + "time" + + "github.com/pion/logging" +) + +const ( + defaultDecreaseFactor = 0.85 +) + +type delayRateController struct { + log logging.LeveledLogger + decreaseFactor float64 + arrivalGroups *arrivalGroupAccumulator + lastArrivalGroup arrivalGroup + trend *trendlineEstimator + overuse *overuseDetector + samples int + usage usage + state state + lastDecreaseRate *ewma + lastUpdate time.Time + targetRate int + minTarget int + maxTarget int +} + +func newDelayRateController(initialRate, minRate, maxRate int, logger logging.LeveledLogger) *delayRateController { + return &delayRateController{ + log: logger, + decreaseFactor: defaultDecreaseFactor, + arrivalGroups: newArrivalGroupAccumulator(), + lastArrivalGroup: []arrivalGroupItem{}, + trend: newTrendlineEstimator(), + overuse: newOveruseDetector(false), + usage: 0, + samples: 0, + state: 0, + lastDecreaseRate: newEWMA(0.95), + targetRate: initialRate, + minTarget: minRate, + maxTarget: maxRate, + } +} + +func (c *delayRateController) onPacketAcked(sequenceNumber uint64, size int, departure, arrival time.Time) { + next := c.arrivalGroups.onPacketAcked( + sequenceNumber, + size, + departure, + arrival, + ) + if next == nil { + return + } + if len(next) == 0 { + // ignore empty groups, should never occur + return + } + if len(c.lastArrivalGroup) == 0 { + c.lastArrivalGroup = next + + return + } + + interArrivalTime := next[len(next)-1].Arrival.Sub(c.lastArrivalGroup[len(c.lastArrivalGroup)-1].Arrival) + interDepartureTime := next[len(next)-1].Departure.Sub(c.lastArrivalGroup[len(c.lastArrivalGroup)-1].Departure) + interGroupDelay := interArrivalTime - interDepartureTime + + trend := c.trend.update(arrival, interGroupDelay) + c.samples++ + c.usage = c.overuse.update(arrival, trend, c.samples) + c.lastArrivalGroup = next + + c.log.Tracef( + "ts=%v.%06d, seq=%v, interArrivalTime=%v, interDepartureTime=%v, interGroupDelay=%v, estimate=%f, threshold=%f, usage=%v, state=%v", // nolint + c.lastArrivalGroup[0].Departure.UTC().Format("2006/01/02 15:04:05"), + c.lastArrivalGroup[0].Departure.UTC().Nanosecond()/1e3, + next[0].SequenceNumber, + interArrivalTime.Microseconds(), + interDepartureTime.Microseconds(), + interGroupDelay.Microseconds(), + trend, + c.overuse.delayThreshold, + int(c.usage), + int(c.state), + ) +} + +func (c *delayRateController) update(ts time.Time, deliveryRate int, rtt time.Duration) int { + deliveredRate := float64(deliveryRate) + c.state = c.state.transition(c.usage) + if c.state == stateIncrease { + window := ts.Sub(c.lastUpdate) + if c.canIncreaseMultiplicatively(deliveredRate) { + c.targetRate = max(c.targetRate, multiplicativeIncrease(c.targetRate, window)) + } else { + c.targetRate = additiveIncrease(c.targetRate, rtt, window) + } + c.targetRate = min(c.targetRate, int(1.5*deliveredRate)) + } + if c.state == stateDecrease { + c.lastDecreaseRate.update(float64(deliveryRate)) + c.targetRate = int(c.decreaseFactor * float64(deliveryRate)) + } + c.lastUpdate = ts + + c.targetRate = max(c.targetRate, c.minTarget) + c.targetRate = min(c.targetRate, c.maxTarget) + + return c.targetRate +} + +func (c *delayRateController) canIncreaseMultiplicatively(deliveredRate float64) bool { + avg := c.lastDecreaseRate.avg() + if avg == 0 { + return true + } + stdDev := math.Sqrt(c.lastDecreaseRate.varr()) + lower := avg - 3*stdDev + upper := avg + 3*stdDev + + return deliveredRate < lower || deliveredRate > upper +} + +func multiplicativeIncrease(rate int, window time.Duration) int { + exponent := min(window.Seconds(), 1.0) + eta := math.Pow(1.08, exponent) + + return int(eta * float64(rate)) +} + +func additiveIncrease(rate int, rtt, window time.Duration) int { + responseTime := 100 + rtt.Milliseconds() + alpha := 0.5 * min(float64(window.Milliseconds())/float64(responseTime), 1.0) + bitsPerFrame := float64(rate) / 30.0 + packetsPerFrame := math.Ceil(bitsPerFrame / (1200 * 8)) + expectedPacketSizeBits := bitsPerFrame / packetsPerFrame + + return rate + max(1000, int(alpha*float64(expectedPacketSizeBits))) +} diff --git a/gcc/delay_rate_controller_test.go b/gcc/delay_rate_controller_test.go new file mode 100644 index 0000000..74ad75e --- /dev/null +++ b/gcc/delay_rate_controller_test.go @@ -0,0 +1,89 @@ +// SPDX-FileCopyrightText: 2025 The Pion community +// SPDX-License-Identifier: MIT + +package gcc + +import ( + "fmt" + "testing" + "time" + + "github.com/stretchr/testify/assert" +) + +func TestDelayRateController(t *testing.T) { + t.Run("init", func(t *testing.T) { + controller := newDelayRateController(1_000_000, 500_000, 2_000_000, nil) + assert.Nil(t, controller.log) + assert.Equal(t, controller.decreaseFactor, defaultDecreaseFactor) + assert.NotNil(t, controller.arrivalGroups) + assert.NotNil(t, controller.lastArrivalGroup) + assert.NotNil(t, controller.trend) + assert.NotNil(t, controller.overuse) + assert.Equal(t, controller.samples, 0) + assert.Equal(t, controller.usage, usage(0)) + assert.Equal(t, controller.state, state(0)) + assert.NotNil(t, controller.lastDecreaseRate) + assert.Zero(t, controller.lastUpdate) + assert.Equal(t, controller.minTarget, 500_000) + assert.Equal(t, controller.maxTarget, 2_000_000) + assert.Equal(t, controller.targetRate, 1_000_000) + }) + + t.Run("canIncreaseMultiplicatively", func(t *testing.T) { + cases := []struct { + deliveredRate float64 + decreaseRate ewma + expected bool + }{ + {deliveredRate: 1000, decreaseRate: ewma{average: 0, variance: 0}, expected: true}, + {deliveredRate: 1000, decreaseRate: ewma{average: 1500, variance: 100}, expected: true}, + {deliveredRate: 1000, decreaseRate: ewma{average: 1020, variance: 100}, expected: false}, + {deliveredRate: 1000, decreaseRate: ewma{average: 800, variance: 50}, expected: true}, + {deliveredRate: 1000, decreaseRate: ewma{average: 995, variance: 100}, expected: false}, + } + + for i, c := range cases { + t.Run(fmt.Sprintf("%v", i), func(t *testing.T) { + controller := newDelayRateController(1000, 500, 2000, nil) + controller.lastDecreaseRate = &c.decreaseRate + assert.Equal(t, c.expected, controller.canIncreaseMultiplicatively(c.deliveredRate)) + }) + } + }) + + t.Run("multiplicativeIncrease", func(t *testing.T) { + cases := []struct { + initialRate int + rate int + window time.Duration + expected float64 + }{ + {initialRate: 1000, rate: 1000, window: 100 * time.Millisecond, expected: 1007}, + } + for i, c := range cases { + t.Run(fmt.Sprintf("%v", i), func(t *testing.T) { + res := multiplicativeIncrease(c.rate, c.window) + assert.InDelta(t, res, c.expected, 1) + }) + } + }) + + t.Run("additiveIncrease", func(t *testing.T) { + cases := []struct { + initialRate int + rate int + window time.Duration + expected int + }{ + {initialRate: 1000, rate: 1000, window: 100 * time.Millisecond, expected: 2000}, + {initialRate: 1_000_000, rate: 1_500_000, window: 100 * time.Millisecond, expected: 1_500_000 + 2083}, + } + for i, c := range cases { + t.Run(fmt.Sprintf("%v", i), func(t *testing.T) { + res := additiveIncrease(c.rate, 100*time.Millisecond, c.window) + assert.InDelta(t, res, c.expected, 1) + }) + } + }) +} diff --git a/gcc/overuse_detector.go b/gcc/overuse_detector.go new file mode 100644 index 0000000..53b60a9 --- /dev/null +++ b/gcc/overuse_detector.go @@ -0,0 +1,100 @@ +// SPDX-FileCopyrightText: 2025 The Pion community +// SPDX-License-Identifier: MIT + +package gcc + +import ( + "math" + "time" +) + +const ( + kUp = 0.0087 + kDown = 0.039 + + minNumDeltas = 60 +) + +const ( + defaultThresholdGain = 4.0 + defaultOveruseTimeThreshold = 5 * time.Millisecond +) + +type overuseDetector struct { + adaptiveThreshold bool + thresholdGain float64 + overUseTimeThreshold time.Duration + delayThreshold float64 + lastUpdate time.Time + firstOverUse time.Time + overUseCounter int + previousTrend float64 +} + +func newOveruseDetector(adaptive bool) *overuseDetector { + return &overuseDetector{ + adaptiveThreshold: adaptive, + thresholdGain: defaultThresholdGain, + overUseTimeThreshold: defaultOveruseTimeThreshold, + delayThreshold: 6, + lastUpdate: time.Time{}, + firstOverUse: time.Time{}, + overUseCounter: 0, + previousTrend: 0, + } +} + +func (d *overuseDetector) update(ts time.Time, trend float64, numDeltas int) usage { + if d.lastUpdate.IsZero() { + d.lastUpdate = ts + } + if numDeltas < 2 { + return usageNormal + } + modifiedTrend := float64(min(numDeltas, minNumDeltas)) * trend * d.thresholdGain + + var currentUsage usage + switch { + case modifiedTrend > d.delayThreshold: + if d.firstOverUse.IsZero() { + delta := ts.Sub(d.lastUpdate) + d.firstOverUse = ts.Add(-delta / 2) + } + d.overUseCounter++ + if ts.Sub(d.firstOverUse) > d.overUseTimeThreshold && d.overUseCounter > 1 && trend >= d.previousTrend { + d.firstOverUse = time.Time{} + d.overUseCounter = 0 + currentUsage = usageOver + } + case modifiedTrend < -d.delayThreshold: + d.firstOverUse = time.Time{} + d.overUseCounter = 0 + currentUsage = usageUnder + default: + d.firstOverUse = time.Time{} + d.overUseCounter = 0 + currentUsage = usageNormal + } + d.adaptThreshold(ts, modifiedTrend) + d.previousTrend = trend + d.lastUpdate = ts + + return currentUsage +} + +func (d *overuseDetector) adaptThreshold(ts time.Time, modifiedTrend float64) { + if !d.adaptiveThreshold { + return + } + if math.Abs(modifiedTrend) > d.delayThreshold+15 { + return + } + k := kUp + if math.Abs(modifiedTrend) < d.delayThreshold { + k = kDown + } + delta := min(ts.Sub(d.lastUpdate), 100*time.Millisecond) + d.delayThreshold += k * (math.Abs(modifiedTrend) - d.delayThreshold) * float64(delta.Milliseconds()) + d.delayThreshold = min(d.delayThreshold, 600.0) + d.delayThreshold = max(d.delayThreshold, 6.0) +} diff --git a/gcc/overuse_detector_test.go b/gcc/overuse_detector_test.go new file mode 100644 index 0000000..58d4555 --- /dev/null +++ b/gcc/overuse_detector_test.go @@ -0,0 +1,194 @@ +// SPDX-FileCopyrightText: 2025 The Pion community +// SPDX-License-Identifier: MIT + +package gcc + +import ( + "testing" + "time" + + "github.com/stretchr/testify/assert" +) + +func TestOveruseDetectorUpdate(t *testing.T) { + type estimate struct { + ts time.Time + estimate float64 + numDeltas int + } + cases := []struct { + name string + adaptive bool + values []estimate + expected []usage + }{ + { + name: "noEstimateNoUsageStatic", + adaptive: false, + values: []estimate{}, + expected: []usage{}, + }, + { + name: "overuseStatic", + adaptive: false, + values: []estimate{ + {time.Time{}, 1.0, 1}, + {time.Time{}.Add(5 * time.Millisecond), 20, 2}, + {time.Time{}.Add(20 * time.Millisecond), 30, 3}, + }, + expected: []usage{usageNormal, usageNormal, usageOver}, + }, + { + name: "normaluseStatic", + adaptive: false, + values: []estimate{{estimate: 0}}, + expected: []usage{usageNormal}, + }, + { + name: "underuseStatic", + adaptive: false, + values: []estimate{{time.Time{}, -20, 2}}, + expected: []usage{usageUnder}, + }, + { + name: "noOverUseBeforeDelayStatic", + adaptive: false, + values: []estimate{ + {time.Time{}.Add(time.Millisecond), 20, 1}, + {time.Time{}.Add(2 * time.Millisecond), 30, 2}, + {time.Time{}.Add(30 * time.Millisecond), 50, 3}, + }, + expected: []usage{usageNormal, usageNormal, usageOver}, + }, + { + name: "noOverUseIfEstimateDecreasedStatic", + adaptive: false, + values: []estimate{ + {time.Time{}.Add(time.Millisecond), 20, 1}, + {time.Time{}.Add(10 * time.Millisecond), 40, 2}, + {time.Time{}.Add(30 * time.Millisecond), 50, 3}, + {time.Time{}.Add(35 * time.Millisecond), 3, 4}, + }, + expected: []usage{usageNormal, usageNormal, usageOver, usageNormal}, + }, + { + name: "noEstimateNoUsageAdaptive", + adaptive: true, + values: []estimate{}, + expected: []usage{}, + }, + { + name: "overuseAdaptive", + adaptive: true, + values: []estimate{ + {time.Time{}, 1, 1}, + {time.Time{}.Add(5 * time.Millisecond), 20, 2}, + {time.Time{}.Add(20 * time.Millisecond), 30, 3}, + }, + expected: []usage{usageNormal, usageNormal, usageOver}, + }, + { + name: "normaluseAdaptive", + adaptive: true, + values: []estimate{{estimate: 0}}, + expected: []usage{usageNormal}, + }, + { + name: "underuseAdaptive", + adaptive: true, + values: []estimate{{time.Time{}, -20, 2}}, + expected: []usage{usageUnder}, + }, + { + name: "noOverUseBeforeDelayAdaptive", + adaptive: true, + values: []estimate{ + {time.Time{}.Add(time.Millisecond), 20, 1}, + {time.Time{}.Add(2 * time.Millisecond), 30, 2}, + {time.Time{}.Add(30 * time.Millisecond), 50, 3}, + }, + expected: []usage{usageNormal, usageNormal, usageOver}, + }, + { + name: "noOverUseIfEstimateDecreasedAdaptive", + adaptive: true, + values: []estimate{ + {time.Time{}.Add(time.Millisecond), 20, 1}, + {time.Time{}.Add(10 * time.Millisecond), 40, 2}, + {time.Time{}.Add(30 * time.Millisecond), 50, 3}, + {time.Time{}.Add(35 * time.Millisecond), 3, 4}, + }, + expected: []usage{usageNormal, usageNormal, usageOver, usageNormal}, + }, + } + for _, tc := range cases { + t.Run(tc.name, func(t *testing.T) { + od := newOveruseDetector(tc.adaptive) + received := []usage{} + for _, e := range tc.values { + u := od.update(e.ts, e.estimate, e.numDeltas) + received = append(received, u) + } + assert.Equal(t, tc.expected, received) + }) + } +} + +func TestOveruseDetectorAdaptThreshold(t *testing.T) { + cases := []struct { + name string + od *overuseDetector + ts time.Time + estimate float64 + expectedThreshold float64 + }{ + { + name: "minThreshold", + od: &overuseDetector{ + adaptiveThreshold: true, + }, + ts: time.Time{}, + estimate: 0, + expectedThreshold: 6, + }, + { + name: "increase", + od: &overuseDetector{ + adaptiveThreshold: true, + delayThreshold: 12.5, + lastUpdate: time.Time{}.Add(time.Second), + }, + ts: time.Time{}.Add(2 * time.Second), + estimate: 25, + expectedThreshold: 23.375, + }, + { + name: "maxThreshold", + od: &overuseDetector{ + adaptiveThreshold: true, + delayThreshold: 600, + lastUpdate: time.Time{}, + }, + ts: time.Time{}.Add(time.Second), + estimate: 610, + expectedThreshold: 600, + }, + { + name: "decrease", + od: &overuseDetector{ + adaptiveThreshold: true, + delayThreshold: 12.5, + lastUpdate: time.Time{}, + }, + ts: time.Time{}.Add(10 * time.Millisecond), + estimate: 1, + expectedThreshold: 8.015, + }, + } + for _, tc := range cases { + t.Run(tc.name, func(t *testing.T) { + tc.od.adaptThreshold(tc.ts, tc.estimate) + assert.Equal(t, tc.expectedThreshold, tc.od.delayThreshold) + }) + } +} diff --git a/gcc/send_side_bwe.go b/gcc/send_side_bwe.go new file mode 100644 index 0000000..06b0bef --- /dev/null +++ b/gcc/send_side_bwe.go @@ -0,0 +1,92 @@ +// SPDX-FileCopyrightText: 2025 The Pion community +// SPDX-License-Identifier: MIT + +package gcc + +import ( + "time" + + "github.com/pion/logging" +) + +// Option is a functional option for a SendSideController. +type Option func(*SendSideController) error + +// WithLoggerFactory configures a custom logger factory for a +// SendSideController. +func WithLoggerFactory(lf logging.LoggerFactory) Option { + return func(ssc *SendSideController) error { + ssc.logFactory = lf + + return nil + } +} + +// SendSideController is a sender side congestion controller. +type SendSideController struct { + logFactory logging.LoggerFactory + log logging.LeveledLogger + dre *deliveryRateEstimator + lrc *lossRateController + drc *delayRateController + targetRate int +} + +// NewSendSideController creates a new SendSideController with initial, min and +// max rates. +func NewSendSideController(initialRate, minRate, maxRate int, opts ...Option) (*SendSideController, error) { + ssc := &SendSideController{ + logFactory: logging.NewDefaultLoggerFactory(), + dre: newDeliveryRateEstimator(time.Second), + lrc: newLossRateController(initialRate, minRate, maxRate), + targetRate: initialRate, + } + for _, opt := range opts { + if err := opt(ssc); err != nil { + return nil, err + } + } + ssc.log = ssc.logFactory.NewLogger("bwe_send_side_controller") + ssc.drc = newDelayRateController(initialRate, minRate, maxRate, ssc.logFactory.NewLogger("bwe_delay_rate_controller")) + + return ssc, nil +} + +func (c *SendSideController) OnLoss() { + c.lrc.onPacketLost() +} + +// OnAck must be called when new acknowledgments arrive. Packets MUST not be +// acknowledged more than once. +func (c *SendSideController) OnAck(sequenceNumber uint64, size int, departure, arrival time.Time) { + c.lrc.onPacketAcked() + if !arrival.IsZero() { + c.dre.onPacketAcked(arrival, size) + c.drc.onPacketAcked( + sequenceNumber, + size, + departure, + arrival, + ) + } +} + +// OnFeedback must be called when a new feedback report arrives. ts is the +// arrival timestamp of the feedback report. rtt is the latest RTT sample. It +// returns the new target rate. +func (c *SendSideController) OnFeedback(ts time.Time, rtt time.Duration) int { + delivered := c.dre.getRate() + lossTarget := c.lrc.update(delivered) + delayTarget := c.drc.update(ts, delivered, rtt) + c.targetRate = min(lossTarget, delayTarget) + c.log.Tracef( + "rtt=%v, delivered=%v, lossTarget=%v, delayTarget=%v, target=%v", + rtt.Nanoseconds(), + delivered, + lossTarget, + delayTarget, + c.targetRate, + ) + + return c.targetRate +}