diff --git a/internal/celt/allocation.go b/internal/celt/allocation.go new file mode 100644 index 0000000..f86d1e0 --- /dev/null +++ b/internal/celt/allocation.go @@ -0,0 +1,497 @@ +// SPDX-FileCopyrightText: 2026 The Pion community +// SPDX-License-Identifier: MIT + +//nolint:cyclop,gocognit,gocyclo,gosec,lll,maintidx,nestif,nlreturn,wastedassign // Mirrors RFC 6716 allocation flow and bounded entropy-code arithmetic. +package celt + +import "github.com/pion/opus/internal/rangecoding" + +const ( + allocationSteps = 6 + maxFineBits = 8 + fineOffset = 21 +) + +type allocationState struct { + pulses [maxBands]int // Shape budget in 1/8-bit units; PVQ converts this to a pulse count later. + fineQuant [maxBands]int + finePriority [maxBands]int + intensity int + dualStereo int + balance int + codedBands int + bits int +} + +// computeAllocation derives the per-band shape and fine-energy budgets in +// 1/8-bit units. The implementation follows the RFC 6716 Section 4.3.3 +// structure: reserve later side-symbol budgets, choose and interpolate static +// allocation vectors, decide coded bands and stereo side data, then split each +// coded-band budget between fine energy and shape coding. +func (d *Decoder) computeAllocation(info *frameSideInfo, bits int) allocationState { + state := allocationState{bits: bits} + caps := allocationCaps(info.lm, info.channelCount) + balance := 0 + state.codedBands = computeAllocation( + info.startBand, + info.endBand, + info.bandBoost[:], + caps[:], + info.allocationTrim, + &state.intensity, + &state.dualStereo, + bits, + &balance, + state.pulses[:], + state.fineQuant[:], + state.finePriority[:], + info.channelCount, + info.lm, + &d.rangeDecoder, + ) + state.balance = balance + + return state +} + +func computeAllocation( + start, end int, + offsets []int, + caps []int, + allocationTrim int, + intensity *int, + dualStereo *int, + total int, + balance *int, + pulses []int, + fineQuant []int, + finePriority []int, + channelCount int, + lm int, + rangeDecoder *rangecoding.Decoder, +) int { + if total < 0 { + total = 0 + } + + // Section 4.3.3 recovers the exact budget the encoder saw. The symbols + // decoded after allocation still need their entropy space reserved here. + skipReserved := 0 + if total >= 1< total { + intensityReserved = 0 + } else { + total -= intensityReserved + if total >= 1<>4) + trimOffset[band] = channelCount * bandWidth * (allocationTrim - defaultAllocationTrim - lm) * + (end - band - 1) * (1 << (lm + bitResolution)) >> 6 + if bandWidth<> 1 + psum := 0 + done := false + for band := end - 1; band >= start; band-- { + bandWidth := int(bandEdges[band+1] - bandEdges[band]) + bits := channelCount * bandWidth * int(bandAllocation[mid][band]) << lm >> 2 + if bits > 0 { + bits = max(0, bits+trimOffset[band]) + } + bits += offsets[band] + if bits >= threshold[band] || done { + done = true + psum += min(bits, caps[band]) + } else if bits >= channelCount< total { + hi = mid - 1 + } else { + lo = mid + 1 + } + } + + hi = lo + lo-- + skipStart := start + // Convert the bracketing allocation vectors into a base budget plus delta + // per band. Boosted bands become the lower bound for band skipping. + for band := start; band < end; band++ { + bandWidth := int(bandEdges[band+1] - bandEdges[band]) + bits1Band := channelCount * bandWidth * int(bandAllocation[lo][band]) << lm >> 2 + bits2Band := 0 + if hi >= len(bandAllocation) { + bits2Band = caps[band] + } else { + bits2Band = channelCount * bandWidth * int(bandAllocation[hi][band]) << lm >> 2 + } + if bits1Band > 0 { + bits1Band = max(0, bits1Band+trimOffset[band]) + } + if bits2Band > 0 { + bits2Band = max(0, bits2Band+trimOffset[band]) + } + if lo > 0 { + bits1Band += offsets[band] + } + bits2Band += offsets[band] + if offsets[band] > 0 { + skipStart = band + } + bits2[band] = max(0, bits2Band-bits1Band) + bits1[band] = bits1Band + } + + return interpolateBitsToPulses( + start, + end, + skipStart, + bits1, + bits2, + threshold, + caps, + total, + balance, + skipReserved, + intensity, + intensityReserved, + dualStereo, + dualStereoReserved, + pulses, + fineQuant, + finePriority, + channelCount, + lm, + rangeDecoder, + ) +} + +// interpolateBitsToPulses completes the RFC 6716 Section 4.3.3 allocation +// computation after the static table search. The bits slice leaves this +// function as the per-band shape budget in 1/8-bit units, after fine-energy +// bits have been removed. +func interpolateBitsToPulses( + start, end, skipStart int, + bits1 []int, + bits2 []int, + threshold []int, + caps []int, + total int, + balance *int, + skipReserved int, + intensity *int, + intensityReserved int, + dualStereo *int, + dualStereoReserved int, + bits []int, + fineQuant []int, + finePriority []int, + channelCount int, + lm int, + rangeDecoder *rangecoding.Decoder, +) int { + allocationFloor := channelCount << bitResolution + stereo := boolIndex(channelCount > 1) + + // Interpolate between the two neighboring static allocation vectors with + // six fractional steps, matching the reference CELT allocation search. + lo := 0 + hi := 1 << allocationSteps + for range allocationSteps { + mid := (lo + hi) >> 1 + psum := 0 + done := false + for band := end - 1; band >= start; band-- { + tmp := bits1[band] + (mid * bits2[band] >> allocationSteps) + if tmp >= threshold[band] || done { + done = true + psum += min(tmp, caps[band]) + } else if tmp >= allocationFloor { + psum += allocationFloor + } + } + if psum > total { + hi = mid + } else { + lo = mid + } + } + + // Apply the interpolation point, thresholding, and per-band caps to get + // the provisional allocation sum. + psum := 0 + done := false + for band := end - 1; band >= start; band-- { + tmp := bits1[band] + (lo * bits2[band] >> allocationSteps) + if tmp < threshold[band] && !done { + if tmp >= allocationFloor { + tmp = allocationFloor + } else { + tmp = 0 + } + } else { + done = true + } + tmp = min(tmp, caps[band]) + bits[band] = tmp + psum += tmp + } + + // Walk backward to find the coded band boundary. The optional skip bit + // lets the stream keep a high band coded when enough bits remain. + codedBands := end + for { + codedBands-- + band := codedBands + if band <= skipStart { + total += skipReserved + codedBands++ + break + } + + left := total - psum + perCoeff := left / (int(bandEdges[codedBands+1]) - int(bandEdges[start])) + left -= (int(bandEdges[codedBands+1]) - int(bandEdges[start])) * perCoeff + rem := max(left-(int(bandEdges[band])-int(bandEdges[start])), 0) + bandWidth := int(bandEdges[codedBands+1] - bandEdges[band]) + bandBits := bits[band] + perCoeff*bandWidth + rem + if bandBits >= max(threshold[band], allocationFloor+(1< 0 { + intensityReserved = log2FracTable[band-start] + } + psum += intensityReserved + if bandBits >= allocationFloor { + psum += allocationFloor + bits[band] = allocationFloor + } else { + bits[band] = 0 + } + } + + // Intensity and dual-stereo symbols are decoded only after codedBands is + // known, because their alphabets depend on the surviving band range. + if intensityReserved > 0 { + value, _ := rangeDecoder.DecodeUniform(uint32(codedBands + 1 - start)) + *intensity = start + int(value) + } else { + *intensity = 0 + } + if *intensity <= start { + total += dualStereoReserved + dualStereoReserved = 0 + } + if dualStereoReserved > 0 { + *dualStereo = int(rangeDecoder.DecodeSymbolLogP(1)) + } else { + *dualStereo = 0 + } + + // Distribute any slack bits uniformly over coded MDCT bins before + // separating each band into fine-energy and shape budgets. + left := total - psum + perCoeff := left / (int(bandEdges[codedBands]) - int(bandEdges[start])) + left -= (int(bandEdges[codedBands]) - int(bandEdges[start])) * perCoeff + for band := start; band < codedBands; band++ { + bits[band] += perCoeff * int(bandEdges[band+1]-bandEdges[band]) + } + for band := start; band < codedBands; band++ { + tmp := min(left, int(bandEdges[band+1]-bandEdges[band])) + bits[band] += tmp + left -= tmp + } + + currentBalance := 0 + band := start + for ; band < codedBands; band++ { + // Fine energy gets whole raw bits per channel first; the remainder is + // carried forward as shape budget for PVQ in the next implementation slice. + width := int(bandEdges[band+1] - bandEdges[band]) + n := width << lm + bits[band] += currentBalance + excess := 0 + if n > 1 { + excess = max(bits[band]-caps[band], 0) + bits[band] -= excess + den := channelCount * n + if channelCount == 2 && n > 2 && *dualStereo == 0 && band < *intensity { + den++ + } + ncLogN := den * (logN400[band] + (lm << bitResolution)) + offset := (ncLogN >> 1) - den*fineOffset + if n == 2 { + offset += den << bitResolution >> 2 + } + if bits[band]+offset < den*2<> 2 + } else if bits[band]+offset < den*3<> 3 + } + fineQuant[band] = max(0, (bits[band]+offset+(den<<(bitResolution-1)))/(den< bits[band]>>bitResolution { + fineQuant[band] = bits[band] >> stereo >> bitResolution + } + fineQuant[band] = min(fineQuant[band], maxFineBits) + finePriority[band] = boolIndex(fineQuant[band]*(den<= bits[band]+offset) + bits[band] -= channelCount * fineQuant[band] << bitResolution + } else { + excess = max(0, bits[band]-(channelCount< 0 { + extraFine := min(excess>>(stereo+bitResolution), maxFineBits-fineQuant[band]) + fineQuant[band] += extraFine + extraBits := extraFine * channelCount << bitResolution + finePriority[band] = boolIndex(extraBits >= excess-currentBalance) + excess -= extraBits + } + currentBalance = excess + } + *balance = currentBalance + + // Bands above codedBands do not get shape bits, but may retain final + // fine-energy priority bookkeeping if they had enough provisional budget. + for ; band < end; band++ { + fineQuant[band] = bits[band] >> stereo >> bitResolution + bits[band] = 0 + finePriority[band] = boolIndex(fineQuant[band] < 1) + } + + return codedBands +} + +// getPulses expands the compact pulse-count cache indices used by the RFC +// 6716 Section 4.3.3 allocation tables. +func getPulses(index int) int { + if index < 8 { + return index + } + + return (8 + (index & 7)) << ((index >> 3) - 1) +} + +// bitsToPulses implements the RFC 6716 Section 4.3.4.1 cache search from a +// 1/8-bit shape budget to the nearest allowed integer pulse count. +func bitsToPulses(band, lm, bits int) int { + if bits <= 0 { + return 0 + } + lm++ + cacheStart := int(pulseCacheIndex[lm*maxBands+band]) + if cacheStart < 0 { + return 0 + } + cache := pulseCacheBits[cacheStart:] + lo := 0 + hi := int(cache[0]) + bits-- + for range 6 { + mid := (lo + hi + 1) >> 1 + if int(cache[mid]) >= bits { + hi = mid + } else { + lo = mid + } + } + loBits := -1 + if lo != 0 { + loBits = int(cache[lo]) + } + if bits-loBits <= int(cache[hi])-bits { + return lo + } + + return hi +} + +// pulsesToBits maps an allowed pulse count back to its cached 1/8-bit cost. +func pulsesToBits(band, lm, pulses int) int { + if pulses == 0 { + return 0 + } + lm++ + cacheStart := int(pulseCacheIndex[lm*maxBands+band]) + + return int(pulseCacheBits[cacheStart+pulses]) + 1 +} + +// decodeFineEnergy applies the first RFC 6716 Section 4.3.2.2 fine-energy +// refinement, using the number of raw bits assigned by Section 4.3.3. +func (d *Decoder) decodeFineEnergy(info *frameSideInfo, fineQuant [maxBands]int) { + for band := info.startBand; band < info.endBand; band++ { + if fineQuant[band] <= 0 { + continue + } + for channel := range info.channelCount { + // Fine energy uses raw tail bits so refinement does not perturb the + // range coder state used by the main CELT symbols. + q2 := d.rangeDecoder.DecodeRawBits(uint(fineQuant[band])) + offset := (float32(q2)+0.5)*float32(uint(1)<<(14-fineQuant[band]))/16384 - 0.5 + d.previousLogE[channel][band] += offset + } + } +} + +// finalizeFineEnergy consumes the RFC 6716 Section 4.3.2.2 final fine-energy +// priority bits that refine bands after PVQ decoding. +func (d *Decoder) finalizeFineEnergy( + info *frameSideInfo, + fineQuant [maxBands]int, + finePriority [maxBands]int, + bitsLeft int, +) { + for priority := range 2 { + for band := info.startBand; band < info.endBand && bitsLeft >= info.channelCount; band++ { + if fineQuant[band] >= maxFineBits || finePriority[band] != priority { + continue + } + for channel := range info.channelCount { + q2 := d.rangeDecoder.DecodeRawBits(1) + offset := (float32(q2) - 0.5) * float32(uint(1)<<(14-fineQuant[band]-1)) / 16384 + d.previousLogE[channel][band] += offset + bitsLeft-- + } + } + } +} diff --git a/internal/celt/allocation_test.go b/internal/celt/allocation_test.go new file mode 100644 index 0000000..1e9e022 --- /dev/null +++ b/internal/celt/allocation_test.go @@ -0,0 +1,82 @@ +// SPDX-FileCopyrightText: 2026 The Pion community +// SPDX-License-Identifier: MIT + +package celt + +import ( + "testing" + + "github.com/pion/opus/internal/rangecoding" + "github.com/stretchr/testify/assert" +) + +func TestPulseCacheHelpers(t *testing.T) { + assert.Equal(t, 0, getPulses(0)) + assert.Equal(t, 7, getPulses(7)) + assert.Equal(t, 8, getPulses(8)) + assert.Equal(t, 15, getPulses(15)) + assert.Equal(t, 16, getPulses(16)) + assert.Equal(t, 60, getPulses(31)) + + assert.Zero(t, bitsToPulses(0, 0, 0)) + assert.Equal(t, 1, bitsToPulses(0, 0, 8)) + assert.Equal(t, 1, bitsToPulses(0, 1, 16)) + assert.Equal(t, 2, bitsToPulses(0, 1, 24)) + assert.Equal(t, 1, bitsToPulses(1, 1, 16)) + assert.Equal(t, 16, pulsesToBits(0, 1, 1)) + assert.Equal(t, 24, pulsesToBits(0, 1, 2)) + assert.Equal(t, 16, pulsesToBits(1, 1, 1)) +} + +func TestDecodeFineEnergy(t *testing.T) { + decoder := NewDecoder() + decoder.rangeDecoder = rangeDecoderWithRawBits(0b00001010) + decoder.previousLogE[0][0] = 1 + decoder.previousLogE[1][0] = 2 + decoder.previousLogE[0][1] = 3 + decoder.previousLogE[1][1] = 4 + info := frameSideInfo{ + startBand: 0, + endBand: 2, + channelCount: 2, + } + var fineQuant [maxBands]int + fineQuant[0] = 2 + + decoder.decodeFineEnergy(&info, fineQuant) + + assert.InDelta(t, 1.125, decoder.previousLogE[0][0], 0.000001) + assert.InDelta(t, 2.125, decoder.previousLogE[1][0], 0.000001) + assert.Equal(t, float32(3), decoder.previousLogE[0][1]) + assert.Equal(t, float32(4), decoder.previousLogE[1][1]) +} + +func TestFinalizeFineEnergy(t *testing.T) { + decoder := NewDecoder() + decoder.rangeDecoder = rangeDecoderWithRawBits(0b00000101) + info := frameSideInfo{ + startBand: 0, + endBand: 3, + channelCount: 1, + } + var fineQuant [maxBands]int + var finePriority [maxBands]int + fineQuant[0] = 2 + fineQuant[1] = 2 + fineQuant[2] = maxFineBits + finePriority[0] = 0 + finePriority[1] = 1 + + decoder.finalizeFineEnergy(&info, fineQuant, finePriority, 2) + + assert.InDelta(t, 0.0625, decoder.previousLogE[0][0], 0.000001) + assert.InDelta(t, -0.0625, decoder.previousLogE[0][1], 0.000001) + assert.Zero(t, decoder.previousLogE[0][2]) +} + +func rangeDecoderWithRawBits(bits byte) rangecoding.Decoder { + decoder := rangecoding.Decoder{} + decoder.SetInternalValues([]byte{bits}, 0, 1<<31, 0) + + return decoder +} diff --git a/internal/celt/frame.go b/internal/celt/frame.go index c11eafe..72ce1f0 100644 --- a/internal/celt/frame.go +++ b/internal/celt/frame.go @@ -44,6 +44,8 @@ type frameSideInfo struct { spread int bandBoost [maxBands]int allocationTrim int + allocation allocationState + antiCollapseRsv int } type postFilter struct { @@ -79,6 +81,7 @@ func (d *Decoder) decodeFrameSideInfo(data []byte, cfg frameConfig) (frameSideIn d.prepareCoarseEnergyHistory(&info) d.decodeCoarseEnergy(&info) d.decodeAllocationHeader(&info) + d.decodeAllocationAndFineEnergy(&info) return info, nil } @@ -251,6 +254,22 @@ func (d *Decoder) decodeAllocationHeader(info *frameSideInfo) { d.decodeAllocationTrim(info, totalBitsEighth) } +// decodeAllocationAndFineEnergy follows RFC 6716 Section 4.3.3 after the +// allocation header: reserve the Section 4.3.5 anti-collapse bit, compute +// shape/fine-energy budgets, then decode the first fine-energy refinement pass. +func (d *Decoder) decodeAllocationAndFineEnergy(info *frameSideInfo) { + totalBits := int(info.totalBits) //nolint:gosec // G115: CELT frame bit counts are packet-bounded. + tellFrac := int(d.rangeDecoder.TellFrac()) //nolint:gosec // G115: entropy cursor is packet-bounded. + bits := (totalBits << bitResolution) - tellFrac - 1 + info.antiCollapseRsv = 0 + if info.transient && info.lm >= 2 && bits >= (info.lm+2)<