diff --git a/internal/celt/bands.go b/internal/celt/bands.go new file mode 100644 index 0000000..f7d741e --- /dev/null +++ b/internal/celt/bands.go @@ -0,0 +1,786 @@ +// SPDX-FileCopyrightText: 2026 The Pion community +// SPDX-License-Identifier: MIT + +//nolint:cyclop,gocognit,gocyclo,gosec,lll,maintidx,nestif,varnamelen,wastedassign // Keeps the PVQ band recursion close to the RFC/C reference. +package celt + +import ( + "math" + "math/bits" + + "github.com/pion/opus/internal/rangecoding" +) + +const ( + qThetaOffset = 4 + qThetaOffsetTwoPhase = 16 +) + +var orderyTable = [...]int{ //nolint:gochecknoglobals + 1, 0, + 3, 0, 2, 1, + 7, 0, 4, 3, 6, 1, 5, 2, + 15, 0, 8, 7, 12, 3, 11, 4, 14, 1, 9, 6, 13, 2, 10, 5, +} + +type bandDecodeState struct { + rangeDecoder *rangecoding.Decoder + seed uint32 + pulseScratch []int + tmpScratch []float32 +} + +// quantAllBands drives RFC 6716 Section 4.3.4 shape decoding across the coded +// band range. It keeps the allocation balance, lowband folding source, and +// per-band collapse masks needed by later anti-collapse and synthesis stages. +func quantAllBands(info *frameSideInfo, x []float32, y []float32, totalBits int, state *bandDecodeState) []byte { + channelCount := 1 + if y != nil { + channelCount = 2 + } + blocks := 1 + if info.transient { + blocks = 1 << info.lm + } + scale := 1 << info.lm + frameBins := scale * int(bandEdges[maxBands]) + norm := make([]float32, channelCount*frameBins) + norm2 := norm[frameBins:] + lowbandScratch := make([]float32, scale*int(bandEdges[maxBands]-bandEdges[maxBands-1])) + collapseMasks := make([]byte, channelCount*maxBands) + + lowbandOffset := 0 + updateLowband := true + balance := info.allocation.balance + dualStereo := channelCount == 2 && info.allocation.dualStereo != 0 + for band := info.startBand; band < info.endBand; band++ { + tell := int(state.rangeDecoder.TellFrac()) + if band != info.startBand { + balance -= tell + } + remainingBits := totalBits - tell - 1 + bandBits := 0 + if band <= info.allocation.codedBands-1 { + currentBalance := balance / min(3, info.allocation.codedBands-band) + bandBits = max(0, min(16383, min(remainingBits+1, info.allocation.pulses[band]+currentBalance))) + } + + bandStart := scale * int(bandEdges[band]) + bandEnd := scale * int(bandEdges[band+1]) + bandWidth := bandEnd - bandStart + // Shape folding reuses an earlier decoded band with matching width when + // the current band has too few pulses to code independently. + if bandStart-bandWidth >= scale*int(bandEdges[info.startBand]) || band == info.startBand+1 { + if updateLowband || lowbandOffset == 0 { + lowbandOffset = band + } + } + if band == info.startBand+1 && info.startBand+2 <= maxBands { + n1 := scale * int(bandEdges[info.startBand+1]-bandEdges[info.startBand]) + n2 := scale * int(bandEdges[info.startBand+2]-bandEdges[info.startBand+1]) + offset := scale * int(bandEdges[info.startBand]) + if n2 > n1 { + copy(norm[offset+n1:offset+n2], norm[offset+2*n1-n2:offset+n1]) + if channelCount == 2 { + copy(norm2[offset+n1:offset+n2], norm2[offset+2*n1-n2:offset+n1]) + } + } + } + + effectiveLowband := -1 + xMask := uint(0) + yMask := uint(0) + if lowbandOffset != 0 && (info.spread != spreadAggressive || blocks > 1 || info.tfChange[band] < 0) { + effectiveLowband = max(scale*int(bandEdges[info.startBand]), scale*int(bandEdges[lowbandOffset])-bandWidth) + foldStart := lowbandOffset + for { + foldStart-- + if scale*int(bandEdges[foldStart]) <= effectiveLowband { + break + } + } + foldEnd := lowbandOffset - 1 + for { + foldEnd++ + if foldEnd >= band || scale*int(bandEdges[foldEnd]) >= effectiveLowband+bandWidth { + break + } + } + for fold := foldStart; fold < foldEnd; fold++ { + xMask |= uint(collapseMasks[fold*channelCount]) + yMask |= uint(collapseMasks[fold*channelCount+channelCount-1]) + } + } else { + xMask = (1 << blocks) - 1 + yMask = xMask + } + + if dualStereo && band == info.allocation.intensity { + dualStereo = false + for i := scale * int(bandEdges[info.startBand]); i < bandStart; i++ { + norm[i] = 0.5 * (norm[i] + norm2[i]) + } + } + + var lowband []float32 + if effectiveLowband >= 0 { + lowband = norm[effectiveLowband:] + } + if dualStereo { + xMask = quantBand( + band, + x[bandStart:bandEnd], + nil, + bandWidth, + bandBits/2, + info.spread, + blocks, + info.allocation.intensity, + info.tfChange[band], + lowband, + &remainingBits, + info.lm, + norm[bandStart:], + 0, + 1, + lowbandScratch, + xMask, + state, + ) + var lowbandY []float32 + if effectiveLowband >= 0 { + lowbandY = norm2[effectiveLowband:] + } + yMask = quantBand( + band, + y[bandStart:bandEnd], + nil, + bandWidth, + bandBits/2, + info.spread, + blocks, + info.allocation.intensity, + info.tfChange[band], + lowbandY, + &remainingBits, + info.lm, + norm2[bandStart:], + 0, + 1, + lowbandScratch, + yMask, + state, + ) + } else { + xMask = quantBand( + band, + x[bandStart:bandEnd], + yBandSlice(y, bandStart, bandEnd), + bandWidth, + bandBits, + info.spread, + blocks, + info.allocation.intensity, + info.tfChange[band], + lowband, + &remainingBits, + info.lm, + norm[bandStart:], + 0, + 1, + lowbandScratch, + xMask|yMask, + state, + ) + yMask = xMask + } + collapseMasks[band*channelCount] = byte(xMask) + collapseMasks[band*channelCount+channelCount-1] = byte(yMask) + balance += info.allocation.pulses[band] + tell + updateLowband = bandBits > bandWidth<= 1< 0 { + recombine = tfChange + } + if lowband != nil && (recombine != 0 || (nPerBlock&1) == 0 && tfChange < 0 || originalBlocks > 1) { + copy(lowbandScratch[:n], lowband[:n]) + lowband = lowbandScratch[:n] + } + for k := range recombine { + if lowband != nil { + haar1(lowband, n>>k, 1<>= recombine + nPerBlock <<= recombine + for (nPerBlock&1) == 0 && tfChange < 0 { + if lowband != nil { + haar1(lowband, nPerBlock, blocks) + } + fill |= fill << blocks + blocks <<= 1 + nPerBlock >>= 1 + timeDivide++ + tfChange++ + } + originalBlocks = blocks + if originalBlocks > 1 { + if lowband != nil { + deinterleaveHadamard( + lowband, + nPerBlock>>recombine, + originalBlocks< 2 { + // Section 4.3.4.4 splits oversized codebooks recursively so PVQ + // indices stay within the range coder's bounded integer coding. + n >>= 1 + y = x[n:] + x = x[:n] + split = true + lm-- + if blocks == 1 { + fill = (fill & 1) | (fill << 1) + } + blocks = (blocks + 1) >> 1 + } + + if split { + pulseCap := logN400[band] + lm*(1<>1)-thetaOffset, pulseCap, stereo) + if stereo && band >= intensity { + qn = 1 + } + tell := int(state.rangeDecoder.TellFrac()) + itheta := 0 + if qn != 1 { + itheta = decodeBandTheta(qn, n, stereo, originalBlocks, state.rangeDecoder) + itheta = itheta * 16384 / qn + } else if stereo { + if bandBits > 2< 2< 8192 { + x2, y2 = y, x + } + sign := uint32(0) + if sideBits != 0 { + sign = state.rangeDecoder.DecodeRawBits(1) + } + signScale := float32(1) + if sign != 0 { + signScale = -1 + } + collapseMask = quantBand(band, x2, nil, n, midBits, spread, blocks, intensity, tfChange, lowband, remainingBits, lm, lowbandOut, level, gain, lowbandScratch, originalFill, state) + y2[0] = -signScale * x2[1] + y2[1] = signScale * x2[0] + x0 := mid * x[0] + x1 := mid * x[1] + y0 := side * y[0] + y1 := side * y[1] + x[0] = x0 - y0 + y[0] = x0 + y0 + x[1] = x1 - y1 + y[1] = x1 + y1 + } else { + if originalBlocks > 1 && !stereo && itheta&0x3fff != 0 { + if itheta > 8192 { + delta -= delta >> (4 - lm) + } else { + delta = min(0, delta+(n<>(5-lm))) + } + } + midBits := max(0, min(bandBits, (bandBits-delta)/2)) + sideBits := bandBits - midBits + *remainingBits -= qalloc + var nextLowband2 []float32 + if lowband != nil && !stereo { + nextLowband2 = lowband[n:] + } + var nextLowbandOut1 []float32 + nextLevel := 0 + if stereo { + nextLowbandOut1 = lowbandOut + } else { + nextLevel = level + 1 + } + collapseShift := 0 + if !stereo { + collapseShift = originalBlocks >> 1 + } + rebalance := *remainingBits + if midBits >= sideBits { + midGain := gain * mid + if stereo { + midGain = 1 + } + collapseMask = quantBand(band, x, nil, n, midBits, spread, blocks, intensity, tfChange, lowband, remainingBits, lm, nextLowbandOut1, nextLevel, midGain, lowbandScratch, fill, state) + rebalance = midBits - (rebalance - *remainingBits) + if rebalance > 3<>blocks, state) << collapseShift + } else { + collapseMask = quantBand(band, y, nil, n, sideBits, spread, blocks, intensity, tfChange, nextLowband2, remainingBits, lm, nil, nextLevel, gain*side, nil, fill>>blocks, state) << collapseShift + rebalance = sideBits - (rebalance - *remainingBits) + if rebalance > 3< 0 { + *remainingBits += currentBits + q-- + currentBits = pulsesToBits(band, lm, q) + *remainingBits -= currentBits + } + if q != 0 { + collapseMask = algUnquant(x, n, getPulses(q), spread, blocks, state.rangeDecoder, gain, state) + } else { + mask := uint(1<> 20) + } + collapseMask = mask + } else { + for i := range n { + state.seed = lcgRand(state.seed) + noise := float32(1.0 / 256) + if state.seed&0x8000 == 0 { + noise = -noise + } + x[i] = lowband[i] + noise + } + collapseMask = fill + } + renormaliseVector(x, n, gain) + } + } + } + + if stereo { + if n != 2 { + stereoMerge(x, y, mid, n) + } + if invert { + for i := range n { + y[i] = -y[i] + } + } + } else if level == 0 { + x = fullBand + if originalBlocks > 1 { + interleaveHadamard(x, nPerBlock>>recombine, originalBlocks<>= 1 + nPerBlock <<= 1 + collapseMask |= collapseMask >> blocks + haar1(x, nPerBlock, blocks) + } + for k := range recombine { + collapseMask = bitDeinterleave(collapseMask) + haar1(x, originalN>>k, 1< int(cache[cache[0]])+12 +} + +// decodeBandTheta decodes the split angle used for mono split bands and stereo +// mid/side coupling in RFC 6716 Section 4.3.4.4. +func decodeBandTheta(qn int, n int, stereo bool, blocks int, rangeDecoder *rangecoding.Decoder) int { + if stereo && n > 2 { + p0 := uint32(3) + x0 := uint32(qn / 2) + total := p0*(x0+1) + x0 + fs := rangeDecoder.DecodeCumulative(total) + x := uint32(0) + if fs < (x0+1)*p0 { + x = fs / p0 + } else { + x = x0 + 1 + (fs - (x0+1)*p0) + } + var low, high uint32 + if x <= x0 { + low = p0 * x + high = p0 * (x + 1) + } else { + low = (x - 1 - x0) + (x0+1)*p0 + high = (x - x0) + (x0+1)*p0 + } + rangeDecoder.UpdateCumulative(low, high, total) + + return int(x) + } + if blocks > 1 || stereo { + value, _ := rangeDecoder.DecodeUniform(uint32(qn + 1)) + + return int(value) + } + + half := qn >> 1 + total := uint32((half + 1) * (half + 1)) + fm := rangeDecoder.DecodeCumulative(total) + var itheta, symbolFrequency, low int + if fm < uint32(half*(half+1)>>1) { + itheta = (int(isqrt32(8*fm+1)) - 1) >> 1 + symbolFrequency = itheta + 1 + low = itheta * (itheta + 1) >> 1 + } else { + itheta = (2*(qn+1) - int(isqrt32(8*(total-fm-1)+1))) >> 1 + symbolFrequency = qn + 1 - itheta + low = int(total) - ((qn + 1 - itheta) * (qn + 2 - itheta) >> 1) + } + rangeDecoder.UpdateCumulative(uint32(low), uint32(low+symbolFrequency), total) + + return itheta +} + +func computeQN(n int, bitsValue int, offset int, pulseCap int, stereo bool) int { + exp2Table8 := [...]int{16384, 17866, 19483, 21247, 23170, 25267, 27554, 30048} + n2 := 2*n - 1 + if stereo && n == 2 { + n2-- + } + qb := min(bitsValue-pulseCap-(4<>1 { + return 1 + } + + return ((exp2Table8[qb&0x7] >> (14 - (qb >> bitResolution))) + 1) >> 1 << 1 +} + +func bitexactCos(x int) int { + tmp := (4096 + x*x) >> 13 + x2 := tmp + x2 = (32767 - x2) + fracMul16(x2, -7651+fracMul16(x2, 8277+fracMul16(-626, x2))) + + return 1 + x2 +} + +func bitexactLog2Tan(isin int, icos int) int { + lc := bits.Len(uint(icos)) + ls := bits.Len(uint(isin)) + icos <<= 15 - lc + isin <<= 15 - ls + + return (ls-lc)*(1<<11) + + fracMul16(isin, fracMul16(isin, -2597)+7932) - + fracMul16(icos, fracMul16(icos, -2597)+7932) +} + +func fracMul16(a int, b int) int { + return (16384 + int(int16(a))*int(int16(b))) >> 15 +} + +func isqrt32(value uint32) uint32 { + if value == 0 { + return 0 + } + g := uint32(0) + bShift := (bits.Len32(value) - 1) >> 1 + b := uint32(1) << bShift + for { + t := ((g << 1) + b) << bShift + if t <= value { + g += b + value -= t + } + if bShift == 0 { + break + } + b >>= 1 + bShift-- + } + + return g +} + +func lcgRand(seed uint32) uint32 { + return 1664525*seed + 1013904223 +} + +func stereoMerge(x []float32, y []float32, mid float32, n int) { + cross := float32(0) + sideEnergy := float32(0) + for i := range n { + cross += x[i] * y[i] + sideEnergy += y[i] * y[i] + } + cross *= mid + leftEnergy := mid*mid + sideEnergy - 2*cross + rightEnergy := mid*mid + sideEnergy + 2*cross + if leftEnergy < 6e-4 || rightEnergy < 6e-4 { + copy(y[:n], x[:n]) + + return + } + leftScale := float32(1 / math.Sqrt(float64(leftEnergy))) + rightScale := float32(1 / math.Sqrt(float64(rightEnergy))) + for i := range n { + left := mid*x[i] - y[i] + right := mid*x[i] + y[i] + x[i] = left * leftScale + y[i] = right * rightScale + } +} + +func haar1(x []float32, n0 int, stride int) { + n0 >>= 1 + for i := range stride { + for j := range n0 { + index0 := stride*2*j + i + index1 := stride*(2*j+1) + i + tmp0 := float32(math.Sqrt(0.5)) * x[index0] + tmp1 := float32(math.Sqrt(0.5)) * x[index1] + x[index0] = tmp0 + tmp1 + x[index1] = tmp0 - tmp1 + } + } +} + +func deinterleaveHadamard(x []float32, n0 int, stride int, hadamard bool, state *bandDecodeState) { + tmp := state.floatScratch(n0 * stride) + if hadamard { + ordery := orderyTable[stride-2:] + for i := range stride { + for j := range n0 { + tmp[ordery[i]*n0+j] = x[j*stride+i] + } + } + } else { + for i := range stride { + for j := range n0 { + tmp[i*n0+j] = x[j*stride+i] + } + } + } + copy(x, tmp) +} + +func interleaveHadamard(x []float32, n0 int, stride int, hadamard bool, state *bandDecodeState) { + tmp := state.floatScratch(n0 * stride) + if hadamard { + ordery := orderyTable[stride-2:] + for i := range stride { + for j := range n0 { + tmp[j*stride+i] = x[ordery[i]*n0+j] + } + } + } else { + for i := range stride { + for j := range n0 { + tmp[j*stride+i] = x[i*n0+j] + } + } + } + copy(x, tmp) +} + +func (s *bandDecodeState) intScratch(n int) []int { + if cap(s.pulseScratch) < n { + s.pulseScratch = make([]int, n) + } + s.pulseScratch = s.pulseScratch[:n] + clear(s.pulseScratch) + + return s.pulseScratch +} + +func (s *bandDecodeState) floatScratch(n int) []float32 { + if cap(s.tmpScratch) < n { + s.tmpScratch = make([]float32, n) + } + s.tmpScratch = s.tmpScratch[:n] + + return s.tmpScratch +} + +func bitInterleave(fill uint) uint { + table := [...]uint{0, 1, 1, 1, 2, 3, 3, 3, 2, 3, 3, 3, 2, 3, 3, 3} + + return table[fill&0xF] | table[fill>>4]<<2 +} + +func bitDeinterleave(fill uint) uint { + table := [...]uint{ + 0x00, 0x03, 0x0C, 0x0F, 0x30, 0x33, 0x3C, 0x3F, + 0xC0, 0xC3, 0xCC, 0xCF, 0xF0, 0xF3, 0xFC, 0xFF, + } + + return table[fill&0xF] +} diff --git a/internal/celt/bands_test.go b/internal/celt/bands_test.go new file mode 100644 index 0000000..41ff4d2 --- /dev/null +++ b/internal/celt/bands_test.go @@ -0,0 +1,246 @@ +// SPDX-FileCopyrightText: 2026 The Pion community +// SPDX-License-Identifier: MIT + +package celt + +import ( + "testing" + + "github.com/pion/opus/internal/rangecoding" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +func TestQuantBandSingleBin(t *testing.T) { + decoder := rangeDecoderWithRawBits(0b00000001) + state := bandDecodeState{rangeDecoder: &decoder} + x := []float32{0} + y := []float32{0} + remainingBits := 2 << bitResolution + + mask := quantBand( + 0, x, y, 1, 2< +// SPDX-License-Identifier: MIT + +//nolint:varnamelen // CWRS notation follows RFC/reference vector names. +package celt + +import "github.com/pion/opus/internal/rangecoding" + +// decodePulses implements the RFC 6716 Section 4.3.4.2 CWRS index decode for +// the PVQ pulse vector. The row buffer stores one recurrence row of V(N,K). +func decodePulses(y []int, n, k int, rangeDecoder *rangecoding.Decoder) { + for i := range n { + y[i] = 0 + } + if k <= 0 { + return + } + + u := cwrsUrow(n, k) + total := u[k] + u[k+1] + index, _ := rangeDecoder.DecodeUniform(total) + cwrsDecode(y, n, k, index, u) +} + +// cwrsUrow initializes the recurrence row needed to count PVQ codewords for a +// vector of n dimensions and up to k pulses. +func cwrsUrow(n, k int) []uint32 { + row := make([]uint32, k+2) + if n == 0 { + row[0] = 1 + + return row + } + row[0] = 0 + if len(row) > 1 { + row[1] = 1 + } + if n == 1 { + for i := 2; i < len(row); i++ { + row[i] = 1 + } + + return row + } + for pulses := 2; pulses < len(row); pulses++ { + row[pulses] = uint32((pulses << 1) - 1) + } + for rowIndex := 2; rowIndex < n; rowIndex++ { + cwrsNextRow(row[1:], 1) + } + + return row +} + +// cwrsNextRow advances the V(N,K) recurrence by one dimension. +func cwrsNextRow(u []uint32, value0 uint32) { + value := value0 + for j := 1; j < len(u); j++ { + next := u[j] + u[j-1] + value + u[j-1] = value + value = next + } + u[len(u)-1] = value +} + +// cwrsDecode walks the recurrence row to recover signs and pulse magnitudes +// from the uniformly decoded codeword index. +func cwrsDecode(y []int, n, k int, index uint32, u []uint32) { + for j := range n { + p := u[k+1] + negative := index >= p + if negative { + index -= p + } + + yj := k + p = u[k] + for p > index { + k-- + p = u[k] + } + index -= p + yj -= k + if negative { + y[j] = -yj + } else { + y[j] = yj + } + cwrsPreviousRow(u, k+2, 0) + } +} + +// cwrsPreviousRow rewinds the recurrence after one coefficient has been +// decoded, matching the row update used by the reference CWRS decoder. +func cwrsPreviousRow(u []uint32, n int, value0 uint32) { + value := value0 + for j := 1; j < n; j++ { + next := u[j] - u[j-1] - value + u[j-1] = value + value = next + } + u[n-1] = value +} diff --git a/internal/celt/cwrs_test.go b/internal/celt/cwrs_test.go new file mode 100644 index 0000000..a97c230 --- /dev/null +++ b/internal/celt/cwrs_test.go @@ -0,0 +1,35 @@ +// SPDX-FileCopyrightText: 2026 The Pion community +// SPDX-License-Identifier: MIT + +package celt + +import ( + "testing" + + "github.com/stretchr/testify/assert" +) + +func TestCWRSRows(t *testing.T) { + assert.Equal(t, []uint32{0, 1, 3, 5, 7}, cwrsUrow(2, 3)) + + row := []uint32{1, 3, 5, 7} + cwrsNextRow(row, 1) + assert.Equal(t, []uint32{1, 5, 13, 25}, row) + + cwrsPreviousRow(row, 4, 1) + assert.Equal(t, []uint32{1, 3, 5, 7}, row) +} + +func TestCWRSDecode(t *testing.T) { + y := []int{99, 99, 99} + decodePulses(y, len(y), 0, nil) + assert.Equal(t, []int{0, 0, 0}, y) + + row := cwrsUrow(3, 2) + cwrsDecode(y, len(y), 2, 0, row) + assert.Equal(t, []int{2, 0, 0}, y) + + decoder := rangeDecoderWithCDFSymbol(0, cwrsUrow(3, 2)[2]+cwrsUrow(3, 2)[3]) + decodePulses(y, len(y), 2, &decoder) + assert.Equal(t, []int{2, 0, 0}, y) +} diff --git a/internal/celt/pvq.go b/internal/celt/pvq.go new file mode 100644 index 0000000..7e82f7a --- /dev/null +++ b/internal/celt/pvq.go @@ -0,0 +1,147 @@ +// SPDX-FileCopyrightText: 2026 The Pion community +// SPDX-License-Identifier: MIT + +//nolint:varnamelen // PVQ math uses RFC/reference scalar and vector names. +package celt + +import ( + "math" + + "github.com/pion/opus/internal/rangecoding" +) + +const ( + spreadNone = 0 + spreadLight = 1 + spreadNormal = 2 + spreadAggressive = 3 + normScaling = 1 +) + +// algUnquant decodes the RFC 6716 Section 4.3.4.2 PVQ pulse vector, scales it +// to the requested gain, and applies Section 4.3.4.3 spreading rotation. +func algUnquant( + x []float32, + n int, + k int, + spread int, + blocks int, + rangeDecoder *rangecoding.Decoder, + gain float32, + state *bandDecodeState, +) uint { + iy := state.intScratch(n) + decodePulses(iy, n, k, rangeDecoder) + + energy := 0 + for i := range n { + energy += iy[i] * iy[i] + } + normaliseResidual(iy, x, n, energy, gain) + expRotation(x, n, -1, blocks, k, spread) + + return extractCollapseMask(iy, n, blocks) +} + +// normaliseResidual maps integer PVQ pulses back to a floating-point unit +// vector while preserving the band gain supplied by the split decoder. +func normaliseResidual(iy []int, x []float32, n int, energy int, gain float32) { + if energy <= 0 { + for i := range n { + x[i] = 0 + } + + return + } + + scale := gain / float32(math.Sqrt(float64(energy))) + for i := range n { + x[i] = float32(iy[i]) * scale + } +} + +// extractCollapseMask records which transient blocks received non-zero pulses. +func extractCollapseMask(iy []int, n int, blocks int) uint { + if blocks <= 1 { + return 1 + } + + blockSize := n / blocks + mask := uint(0) + for block := range blocks { + for i := range blockSize { + if iy[block*blockSize+i] != 0 { + mask |= 1 << block + } + } + } + + return mask +} + +// renormaliseVector restores unit energy after lowband folding or noise fill. +func renormaliseVector(x []float32, n int, gain float32) { + energy := float32(1e-27) + for i := range n { + energy += x[i] * x[i] + } + + scale := gain / float32(math.Sqrt(float64(energy))) + for i := range n { + x[i] *= scale + } +} + +// expRotation applies RFC 6716 Section 4.3.4.3 spreading rotation. Direction is +// negative when undoing the encoder rotation during decode. +func expRotation(x []float32, length int, direction int, stride int, pulses int, spread int) { + if 2*pulses >= length || spread == spreadNone { + return + } + + factors := [...]int{15, 10, 5} + factor := factors[spread-1] + gain := float64(length) / float64(length+factor*pulses) + theta := 0.5 * gain * gain + c := float32(math.Cos(0.5 * math.Pi * theta)) + s := float32(math.Sin(0.5 * math.Pi * theta)) + + stride2 := 0 + if length >= 8*stride { + stride2 = 1 + for (stride2*stride2+stride2)*stride+(stride>>2) < length { + stride2++ + } + } + + blockLen := length / stride + for block := range stride { + segment := x[block*blockLen : (block+1)*blockLen] + if direction < 0 { + if stride2 != 0 { + expRotation1(segment, blockLen, stride2, s, c) + } + expRotation1(segment, blockLen, 1, c, s) + } else { + expRotation1(segment, blockLen, 1, c, -s) + if stride2 != 0 { + expRotation1(segment, blockLen, stride2, s, -c) + } + } + } +} + +func expRotation1(x []float32, length int, stride int, c float32, s float32) { + for i := 0; i < length-stride; i++ { + x1 := x[i] + x2 := x[i+stride] + x[i+stride] = c*x2 + s*x1 + x[i] = c*x1 - s*x2 + } + for i := length - 2*stride - 1; i >= 0; i-- { + x1 := x[i] + x2 := x[i+stride] + x[i+stride] = c*x2 + s*x1 + x[i] = c*x1 - s*x2 + } +} diff --git a/internal/celt/pvq_test.go b/internal/celt/pvq_test.go new file mode 100644 index 0000000..e27b89c --- /dev/null +++ b/internal/celt/pvq_test.go @@ -0,0 +1,74 @@ +// SPDX-FileCopyrightText: 2026 The Pion community +// SPDX-License-Identifier: MIT + +package celt + +import ( + "math" + "testing" + + "github.com/stretchr/testify/assert" +) + +func TestPVQResidualHelpers(t *testing.T) { + x := []float32{9, 9, 9} + normaliseResidual([]int{0, 0, 0}, x, len(x), 0, 1) + assert.Equal(t, []float32{0, 0, 0}, x) + + normaliseResidual([]int{3, 4}, x, 2, 25, 2) + assert.InDelta(t, 1.2, x[0], 0.000001) + assert.InDelta(t, 1.6, x[1], 0.000001) + + assert.Equal(t, uint(1), extractCollapseMask([]int{0, 0}, 2, 1)) + assert.Equal(t, uint(0b101), extractCollapseMask([]int{1, 0, 0, 0, -1, 0}, 6, 3)) + + renormaliseVector(x[:2], 2, 1) + assert.InDelta(t, 1, vectorEnergy(x[:2]), 0.000001) +} + +func TestPVQRotation(t *testing.T) { + x := []float32{1, 2, 3, 4} + expRotation(x, len(x), -1, 1, 1, spreadNone) + assert.Equal(t, []float32{1, 2, 3, 4}, x) + + expRotation(x, len(x), -1, 1, 1, spreadNormal) + assert.NotEqual(t, []float32{1, 2, 3, 4}, x) + assert.InDelta(t, 30, vectorEnergy(x), 0.0001) + + expRotation(x, len(x), 1, 1, 1, spreadNormal) + assert.InDelta(t, 30, vectorEnergy(x), 0.0001) +} + +func TestAlgUnquant(t *testing.T) { + decoder := rangeDecoderWithCDFSymbol(0, cwrsUrow(4, 2)[2]+cwrsUrow(4, 2)[3]) + state := bandDecodeState{} + x := make([]float32, 4) + + mask := algUnquant(x, len(x), 2, spreadNormal, 2, &decoder, 1, &state) + + assert.Equal(t, uint(1), mask) + assert.InDelta(t, 1, vectorEnergy(x), 0.000001) + assert.Len(t, state.pulseScratch, len(x)) +} + +func TestStereoMerge(t *testing.T) { + x := []float32{1, 0} + y := []float32{1, 0} + stereoMerge(x, y, 1, len(x)) + assert.Equal(t, x, y) + + x = []float32{1, 0} + y = []float32{0, 1} + stereoMerge(x, y, 0.5, len(x)) + assert.InDelta(t, 1, vectorEnergy(x), 0.000001) + assert.InDelta(t, 1, vectorEnergy(y), 0.000001) +} + +func vectorEnergy(x []float32) float64 { + energy := float64(0) + for _, value := range x { + energy += math.Pow(float64(value), 2) + } + + return energy +} diff --git a/internal/rangecoding/decoder.go b/internal/rangecoding/decoder.go index 0c8451b..4cecd63 100644 --- a/internal/rangecoding/decoder.go +++ b/internal/rangecoding/decoder.go @@ -111,6 +111,18 @@ func (r *Decoder) Init(data []byte) { r.normalize() } +// SetStorageSize adjusts the logical frame size without resetting decoder +// state. Opus hybrid redundancy removes tail bytes from the CELT range coder +// after the shared decoder has already consumed SILK symbols. +func (r *Decoder) SetStorageSize(size int) { + if size < 0 { + size = 0 + } + if size < len(r.data) { + r.data = r.data[:size] + } +} + // DecodeSymbolWithICDF decodes a single symbol // with a table-based context of up to 8 bits. // @@ -153,6 +165,18 @@ func (r *Decoder) decodeAndUpdateUniformSymbol(total uint32) uint32 { return symbol } +// DecodeCumulative decodes the cumulative frequency index used by CELT's +// custom range-coded symbols. Call UpdateCumulative with the selected interval. +func (r *Decoder) DecodeCumulative(total uint32) uint32 { + return r.decodeUniformSymbol(total) +} + +// UpdateCumulative commits a custom cumulative interval previously selected +// from DecodeCumulative. +func (r *Decoder) UpdateCumulative(low, high, total uint32) { + r.update(r.rangeSize/total, low, high, total) +} + // DecodeUniform decodes an RFC 6716 Section 4.1.5 ec_dec_uint() symbol. // // It returns false when the decoded raw-bit suffix produces a value outside diff --git a/internal/rangecoding/decoder_test.go b/internal/rangecoding/decoder_test.go index 46c3ed6..61c89a8 100644 --- a/internal/rangecoding/decoder_test.go +++ b/internal/rangecoding/decoder_test.go @@ -271,6 +271,27 @@ func TestDecodeUniform(t *testing.T) { }) } +func TestDecodeCumulative(t *testing.T) { + decoder := decoderWithUniformSymbol(2, 5) + + symbol := decoder.DecodeCumulative(5) + decoder.UpdateCumulative(symbol, symbol+1, 5) + + assert.Equal(t, uint32(2), symbol) + assert.NotZero(t, decoder.FinalRange()) +} + +func TestSetStorageSize(t *testing.T) { + decoder := &Decoder{} + decoder.Init([]byte{0x00, 0x01, 0x02}) + + decoder.SetStorageSize(2) + assert.Equal(t, -8, decoder.RemainingBits()) + + decoder.SetStorageSize(-1) + assert.Equal(t, -24, decoder.RemainingBits()) +} + func TestDecodeLaplace(t *testing.T) { zeroFrequency := uint32(72 << 7) decay := uint32(127 << 6)