diff --git a/internal/celt/frame.go b/internal/celt/frame.go index f3b2c27..c11eafe 100644 --- a/internal/celt/frame.go +++ b/internal/celt/frame.go @@ -6,8 +6,18 @@ package celt import "fmt" const ( - postFilterPitchBase = 16 - postFilterGainStep = 0.09375 + postFilterPitchBase = 16 + postFilterGainStep = 0.09375 + bitResolution = 3 + defaultSpreadDecision = 2 + defaultAllocationTrim = 5 + initialDynamicAllocationLogP = 6 + minDynamicAllocationLogP = 2 + allocationTrimBitCost = 6 + firstTimeFrequencyChangeLogP = 4 + firstTransientFrequencyChangeLogP = 2 + nextTimeFrequencyChangeLogP = 5 + nextTransientFrequencyChangeLogP = 4 ) type frameConfig struct { @@ -29,6 +39,11 @@ type frameSideInfo struct { shortBlockCount int intraEnergy bool coarseEnergy [2][maxBands]float32 + tfChange [maxBands]int + tfSelect int + spread int + bandBoost [maxBands]int + allocationTrim int } type postFilter struct { @@ -39,8 +54,8 @@ type postFilter struct { tapset int } -// decodeFrameSideInfo consumes the initial CELT symbols through coarse energy -// in the order specified by RFC 6716 Table 56. TF changes, allocation, and PVQ +// decodeFrameSideInfo consumes the initial CELT symbols through the allocation +// header in the order specified by RFC 6716 Table 56. Pulse allocation and PVQ // residual decoding are intentionally left to the following CELT slices. func (d *Decoder) decodeFrameSideInfo(data []byte, cfg frameConfig) (frameSideInfo, error) { info, err := d.validateFrameConfig(cfg) @@ -63,6 +78,7 @@ func (d *Decoder) decodeFrameSideInfo(data []byte, cfg frameConfig) (frameSideIn d.decodeIntraEnergyFlag(&info) d.prepareCoarseEnergyHistory(&info) d.decodeCoarseEnergy(&info) + d.decodeAllocationHeader(&info) return info, nil } @@ -97,10 +113,12 @@ func (d *Decoder) validateFrameConfig(cfg frameConfig) (frameSideInfo, error) { } return frameSideInfo{ - lm: lm, - startBand: cfg.startBand, - endBand: cfg.endBand, - channelCount: cfg.channelCount, + lm: lm, + startBand: cfg.startBand, + endBand: cfg.endBand, + channelCount: cfg.channelCount, + spread: defaultSpreadDecision, + allocationTrim: defaultAllocationTrim, }, nil } @@ -226,6 +244,122 @@ func (d *Decoder) decodeCoarseEnergyDelta(info *frameSideInfo, probModel []uint8 } } +func (d *Decoder) decodeAllocationHeader(info *frameSideInfo) { + d.decodeTimeFrequencyChanges(info) + d.decodeSpread(info) + totalBitsEighth := d.decodeDynamicAllocation(info, info.totalBits< 0 && tell+uint(logP)+1 <= budget + if tfSelectReserved { + budget-- + } + + current := 0 + changed := 0 + for band := info.startBand; band < info.endBand; band++ { + if tell+uint(logP) <= budget { + current ^= int(d.rangeDecoder.DecodeSymbolLogP(uint(logP))) + tell = d.rangeDecoder.Tell() + changed |= current + } + info.tfChange[band] = current + + if info.transient { + logP = nextTransientFrequencyChangeLogP + } else { + logP = nextTimeFrequencyChangeLogP + } + } + + info.tfSelect = 0 + table := tfSelectTable[info.lm] + if tfSelectReserved && + table[4*boolIndex(info.transient)+changed] != + table[4*boolIndex(info.transient)+2+changed] { + info.tfSelect = int(d.rangeDecoder.DecodeSymbolLogP(1)) + } + + for band := info.startBand; band < info.endBand; band++ { + info.tfChange[band] = int(table[4*boolIndex(info.transient)+2*info.tfSelect+info.tfChange[band]]) + } +} + +func (d *Decoder) decodeSpread(info *frameSideInfo) { + info.spread = defaultSpreadDecision + if d.rangeDecoder.Tell()+4 <= info.totalBits { + info.spread = int(d.rangeDecoder.DecodeSymbolWithICDF(icdfSpread)) + } +} + +// decodeDynamicAllocation decodes RFC 6716 Section 4.3.3 band boost offsets in +// 1/8-bit units and returns the boost-adjusted total bit budget in 1/8-bit units. +func (d *Decoder) decodeDynamicAllocation(info *frameSideInfo, totalBitsEighth uint) uint { + caps := allocationCaps(info.lm, info.channelCount) + dynamicAllocationLogP := initialDynamicAllocationLogP + tellFrac := d.rangeDecoder.TellFrac() + + for band := info.startBand; band < info.endBand; band++ { + width := info.channelCount * (int(bandEdges[band+1]-bandEdges[band]) << info.lm) + quanta := min(width<= totalBitsEighth { + totalBitsEighth = 0 + } else { + totalBitsEighth -= quantaBits + } + loopLogP = 1 + } + + info.bandBoost[band] = boost + if boost > 0 { + dynamicAllocationLogP = max(minDynamicAllocationLogP, dynamicAllocationLogP-1) + } + } + + return totalBitsEighth +} + +func (d *Decoder) decodeAllocationTrim(info *frameSideInfo, totalBitsEighth uint) { + info.allocationTrim = defaultAllocationTrim + if d.rangeDecoder.TellFrac()+uint(allocationTrimBitCost<> 2 + } + + return caps +} + func smallEnergyDelta(symbol uint32) int { switch symbol { case 1: diff --git a/internal/celt/frame_test.go b/internal/celt/frame_test.go index 42b8897..024d48c 100644 --- a/internal/celt/frame_test.go +++ b/internal/celt/frame_test.go @@ -356,6 +356,142 @@ func TestDecodeCoarseEnergy(t *testing.T) { }) } +func TestDecodeTimeFrequencyChanges(t *testing.T) { + t.Run("decodes non-transient tf_change and tf_select", func(t *testing.T) { + decoder := NewDecoder() + decoder.rangeDecoder = rangeDecoderWithBinaryOne() + info := frameSideInfo{ + lm: 1, + totalBits: 256, + startBand: 0, + endBand: 1, + channelCount: 1, + } + + decoder.decodeTimeFrequencyChanges(&info) + + assert.Equal(t, 1, info.tfSelect) + assert.Equal(t, -2, info.tfChange[0]) + }) + + t.Run("decodes transient tf_change and tf_select", func(t *testing.T) { + decoder := NewDecoder() + decoder.rangeDecoder = rangeDecoderWithBinaryOne() + info := frameSideInfo{ + lm: 2, + totalBits: 256, + startBand: 0, + endBand: 1, + channelCount: 1, + transient: true, + } + + decoder.decodeTimeFrequencyChanges(&info) + + assert.Equal(t, 1, info.tfSelect) + assert.Equal(t, -1, info.tfChange[0]) + }) + + t.Run("maps default transient changes when budget is exhausted", func(t *testing.T) { + decoder := NewDecoder() + info := frameSideInfo{ + lm: 3, + totalBits: 0, + startBand: 0, + endBand: 2, + channelCount: 1, + transient: true, + } + + decoder.decodeTimeFrequencyChanges(&info) + + assert.Zero(t, info.tfSelect) + assert.Equal(t, 3, info.tfChange[0]) + assert.Equal(t, 3, info.tfChange[1]) + }) +} + +func TestDecodeSpread(t *testing.T) { + t.Run("decodes spread decision when enough bits remain", func(t *testing.T) { + decoder := NewDecoder() + decoder.rangeDecoder = rangeDecoderWithCDFSymbol(31, 32) + info := frameSideInfo{totalBits: 256} + + decoder.decodeSpread(&info) + + assert.Equal(t, 3, info.spread) + }) + + t.Run("defaults to normal spread without enough bits", func(t *testing.T) { + decoder := NewDecoder() + info := frameSideInfo{totalBits: 0} + + decoder.decodeSpread(&info) + + assert.Equal(t, defaultSpreadDecision, info.spread) + }) +} + +func TestDecodeDynamicAllocation(t *testing.T) { + t.Run("decodes boosts until the band cap", func(t *testing.T) { + decoder := NewDecoder() + decoder.rangeDecoder = rangeDecoderWithBinaryOne() + info := frameSideInfo{ + lm: 0, + totalBits: 256, + startBand: 0, + endBand: 1, + channelCount: 1, + } + totalBitsEighth := info.totalBits << bitResolution + + remaining := decoder.decodeDynamicAllocation(&info, totalBitsEighth) + + assert.Equal(t, 72, info.bandBoost[0]) + assert.Less(t, remaining, totalBitsEighth) + }) + + t.Run("stops immediately on a zero boost flag", func(t *testing.T) { + decoder := NewDecoder() + decoder.rangeDecoder = rangeDecoderWithBinaryZero() + info := frameSideInfo{ + lm: 0, + totalBits: 256, + startBand: 0, + endBand: 1, + channelCount: 1, + } + totalBitsEighth := info.totalBits << bitResolution + + remaining := decoder.decodeDynamicAllocation(&info, totalBitsEighth) + + assert.Zero(t, info.bandBoost[0]) + assert.Equal(t, totalBitsEighth, remaining) + }) +} + +func TestDecodeAllocationTrim(t *testing.T) { + t.Run("decodes trim when six bits remain", func(t *testing.T) { + decoder := NewDecoder() + decoder.rangeDecoder = rangeDecoderWithCDFSymbol(87, 128) + info := frameSideInfo{} + + decoder.decodeAllocationTrim(&info, uint(256)<