diff --git a/internal/bitdepth/bitdepth.go b/internal/bitdepth/bitdepth.go index f382b9d..b3a0994 100644 --- a/internal/bitdepth/bitdepth.go +++ b/internal/bitdepth/bitdepth.go @@ -18,7 +18,7 @@ var ( // Float32ToSigned16 quantizes a float32 PCM sample to signed 16-bit PCM. func Float32ToSigned16(sample float32) int16 { - sample64 := math.Round(float64(sample * 32768)) + sample64 := math.Floor(0.5 + float64(sample*32768)) sample64 = math.Max(sample64, -32768) sample64 = math.Min(sample64, 32767) diff --git a/internal/celt/allocation.go b/internal/celt/allocation.go new file mode 100644 index 0000000..0cb70ae --- /dev/null +++ b/internal/celt/allocation.go @@ -0,0 +1,449 @@ +// SPDX-FileCopyrightText: 2026 The Pion community +// SPDX-License-Identifier: MIT + +//nolint:cyclop,gocognit,gocyclo,gosec,lll,maintidx,nestif,nlreturn,wastedassign // Mirrors RFC 6716 allocation flow and bounded entropy-code arithmetic. +package celt + +import "github.com/pion/opus/internal/rangecoding" + +const ( + allocationSteps = 6 + maxFineBits = 8 + fineOffset = 21 +) + +type allocationState struct { + pulses [maxBands]int + fineQuant [maxBands]int + finePriority [maxBands]int + intensity int + dualStereo int + balance int + codedBands int + bits int +} + +func (d *Decoder) computeAllocation(info *frameSideInfo, bits int) allocationState { + state := allocationState{bits: bits} + caps := allocationCaps(info.lm, info.channelCount) + balance := 0 + state.codedBands = computeAllocation( + info.startBand, + info.endBand, + info.bandBoost[:], + caps[:], + info.allocationTrim, + &state.intensity, + &state.dualStereo, + bits, + &balance, + state.pulses[:], + state.fineQuant[:], + state.finePriority[:], + info.channelCount, + info.lm, + &d.rangeDecoder, + ) + state.balance = balance + + return state +} + +func computeAllocation( + start, end int, + offsets []int, + caps []int, + allocationTrim int, + intensity *int, + dualStereo *int, + total int, + balance *int, + pulses []int, + fineQuant []int, + finePriority []int, + channelCount int, + lm int, + rangeDecoder *rangecoding.Decoder, +) int { + if total < 0 { + total = 0 + } + + skipReserved := 0 + if total >= 1< total { + intensityReserved = 0 + } else { + total -= intensityReserved + if total >= 1<>4) + trimOffset[band] = channelCount * bandWidth * (allocationTrim - defaultAllocationTrim - lm) * + (end - band - 1) * (1 << (lm + bitResolution)) >> 6 + if bandWidth<> 1 + psum := 0 + done := false + for band := end - 1; band >= start; band-- { + bandWidth := int(bandEdges[band+1] - bandEdges[band]) + bits := channelCount * bandWidth * int(bandAllocation[mid][band]) << lm >> 2 + if bits > 0 { + bits = max(0, bits+trimOffset[band]) + } + bits += offsets[band] + if bits >= threshold[band] || done { + done = true + psum += min(bits, caps[band]) + } else if bits >= channelCount< total { + hi = mid - 1 + } else { + lo = mid + 1 + } + } + + hi = lo + lo-- + skipStart := start + for band := start; band < end; band++ { + bandWidth := int(bandEdges[band+1] - bandEdges[band]) + bits1Band := channelCount * bandWidth * int(bandAllocation[lo][band]) << lm >> 2 + bits2Band := 0 + if hi >= len(bandAllocation) { + bits2Band = caps[band] + } else { + bits2Band = channelCount * bandWidth * int(bandAllocation[hi][band]) << lm >> 2 + } + if bits1Band > 0 { + bits1Band = max(0, bits1Band+trimOffset[band]) + } + if bits2Band > 0 { + bits2Band = max(0, bits2Band+trimOffset[band]) + } + if lo > 0 { + bits1Band += offsets[band] + } + bits2Band += offsets[band] + if offsets[band] > 0 { + skipStart = band + } + bits2[band] = max(0, bits2Band-bits1Band) + bits1[band] = bits1Band + } + + return interpolateBitsToPulses( + start, + end, + skipStart, + bits1, + bits2, + threshold, + caps, + total, + balance, + skipReserved, + intensity, + intensityReserved, + dualStereo, + dualStereoReserved, + pulses, + fineQuant, + finePriority, + channelCount, + lm, + rangeDecoder, + ) +} + +func interpolateBitsToPulses( + start, end, skipStart int, + bits1 []int, + bits2 []int, + threshold []int, + caps []int, + total int, + balance *int, + skipReserved int, + intensity *int, + intensityReserved int, + dualStereo *int, + dualStereoReserved int, + bits []int, + fineQuant []int, + finePriority []int, + channelCount int, + lm int, + rangeDecoder *rangecoding.Decoder, +) int { + allocationFloor := channelCount << bitResolution + stereo := boolIndex(channelCount > 1) + lo := 0 + hi := 1 << allocationSteps + for range allocationSteps { + mid := (lo + hi) >> 1 + psum := 0 + done := false + for band := end - 1; band >= start; band-- { + tmp := bits1[band] + (mid * bits2[band] >> allocationSteps) + if tmp >= threshold[band] || done { + done = true + psum += min(tmp, caps[band]) + } else if tmp >= allocationFloor { + psum += allocationFloor + } + } + if psum > total { + hi = mid + } else { + lo = mid + } + } + + psum := 0 + done := false + for band := end - 1; band >= start; band-- { + tmp := bits1[band] + (lo * bits2[band] >> allocationSteps) + if tmp < threshold[band] && !done { + if tmp >= allocationFloor { + tmp = allocationFloor + } else { + tmp = 0 + } + } else { + done = true + } + tmp = min(tmp, caps[band]) + bits[band] = tmp + psum += tmp + } + + codedBands := end + for { + codedBands-- + band := codedBands + if band <= skipStart { + total += skipReserved + codedBands++ + break + } + + left := total - psum + perCoeff := left / (int(bandEdges[codedBands+1]) - int(bandEdges[start])) + left -= (int(bandEdges[codedBands+1]) - int(bandEdges[start])) * perCoeff + rem := max(left-(int(bandEdges[band])-int(bandEdges[start])), 0) + bandWidth := int(bandEdges[codedBands+1] - bandEdges[band]) + bandBits := bits[band] + perCoeff*bandWidth + rem + if bandBits >= max(threshold[band], allocationFloor+(1< 0 { + intensityReserved = log2FracTable[band-start] + } + psum += intensityReserved + if bandBits >= allocationFloor { + psum += allocationFloor + bits[band] = allocationFloor + } else { + bits[band] = 0 + } + } + + if intensityReserved > 0 { + value, _ := rangeDecoder.DecodeUniform(uint32(codedBands + 1 - start)) + *intensity = start + int(value) + } else { + *intensity = 0 + } + if *intensity <= start { + total += dualStereoReserved + dualStereoReserved = 0 + } + if dualStereoReserved > 0 { + *dualStereo = int(rangeDecoder.DecodeSymbolLogP(1)) + } else { + *dualStereo = 0 + } + + left := total - psum + perCoeff := left / (int(bandEdges[codedBands]) - int(bandEdges[start])) + left -= (int(bandEdges[codedBands]) - int(bandEdges[start])) * perCoeff + for band := start; band < codedBands; band++ { + bits[band] += perCoeff * int(bandEdges[band+1]-bandEdges[band]) + } + for band := start; band < codedBands; band++ { + tmp := min(left, int(bandEdges[band+1]-bandEdges[band])) + bits[band] += tmp + left -= tmp + } + + currentBalance := 0 + band := start + for ; band < codedBands; band++ { + width := int(bandEdges[band+1] - bandEdges[band]) + n := width << lm + bits[band] += currentBalance + excess := 0 + if n > 1 { + excess = max(bits[band]-caps[band], 0) + bits[band] -= excess + den := channelCount * n + if channelCount == 2 && n > 2 && *dualStereo == 0 && band < *intensity { + den++ + } + ncLogN := den * (logN400[band] + (lm << bitResolution)) + offset := (ncLogN >> 1) - den*fineOffset + if n == 2 { + offset += den << bitResolution >> 2 + } + if bits[band]+offset < den*2<> 2 + } else if bits[band]+offset < den*3<> 3 + } + fineQuant[band] = max(0, (bits[band]+offset+(den<<(bitResolution-1)))/(den< bits[band]>>bitResolution { + fineQuant[band] = bits[band] >> stereo >> bitResolution + } + fineQuant[band] = min(fineQuant[band], maxFineBits) + finePriority[band] = boolIndex(fineQuant[band]*(den<= bits[band]+offset) + bits[band] -= channelCount * fineQuant[band] << bitResolution + } else { + excess = max(0, bits[band]-(channelCount< 0 { + extraFine := min(excess>>(stereo+bitResolution), maxFineBits-fineQuant[band]) + fineQuant[band] += extraFine + extraBits := extraFine * channelCount << bitResolution + finePriority[band] = boolIndex(extraBits >= excess-currentBalance) + excess -= extraBits + } + currentBalance = excess + } + *balance = currentBalance + + for ; band < end; band++ { + fineQuant[band] = bits[band] >> stereo >> bitResolution + bits[band] = 0 + finePriority[band] = boolIndex(fineQuant[band] < 1) + } + + return codedBands +} + +func getPulses(index int) int { + if index < 8 { + return index + } + + return (8 + (index & 7)) << ((index >> 3) - 1) +} + +func bitsToPulses(band, lm, bits int) int { + if bits <= 0 { + return 0 + } + lm++ + cacheStart := int(pulseCacheIndex[lm*maxBands+band]) + if cacheStart < 0 { + return 0 + } + cache := pulseCacheBits[cacheStart:] + lo := 0 + hi := int(cache[0]) + bits-- + for range 6 { + mid := (lo + hi + 1) >> 1 + if int(cache[mid]) >= bits { + hi = mid + } else { + lo = mid + } + } + loBits := -1 + if lo != 0 { + loBits = int(cache[lo]) + } + if bits-loBits <= int(cache[hi])-bits { + return lo + } + + return hi +} + +func pulsesToBits(band, lm, pulses int) int { + if pulses == 0 { + return 0 + } + lm++ + cacheStart := int(pulseCacheIndex[lm*maxBands+band]) + + return int(pulseCacheBits[cacheStart+pulses]) + 1 +} + +func (d *Decoder) decodeFineEnergy(info *frameSideInfo, fineQuant [maxBands]int) { + for band := info.startBand; band < info.endBand; band++ { + if fineQuant[band] <= 0 { + continue + } + for channel := range info.channelCount { + q2 := d.rangeDecoder.DecodeRawBits(uint(fineQuant[band])) + offset := (float32(q2)+0.5)*float32(uint(1)<<(14-fineQuant[band]))/16384 - 0.5 + d.previousLogE[channel][band] += offset + } + } +} + +func (d *Decoder) finalizeFineEnergy(info *frameSideInfo, fineQuant [maxBands]int, finePriority [maxBands]int, bitsLeft int) { + for priority := range 2 { + for band := info.startBand; band < info.endBand && bitsLeft >= info.channelCount; band++ { + if fineQuant[band] >= maxFineBits || finePriority[band] != priority { + continue + } + for channel := range info.channelCount { + q2 := d.rangeDecoder.DecodeRawBits(1) + offset := (float32(q2) - 0.5) * float32(uint(1)<<(14-fineQuant[band]-1)) / 16384 + d.previousLogE[channel][band] += offset + bitsLeft-- + } + } + } +} diff --git a/internal/celt/bands.go b/internal/celt/bands.go new file mode 100644 index 0000000..2cdbd6d --- /dev/null +++ b/internal/celt/bands.go @@ -0,0 +1,737 @@ +// SPDX-FileCopyrightText: 2026 The Pion community +// SPDX-License-Identifier: MIT + +//nolint:cyclop,gocognit,gocyclo,gosec,lll,maintidx,nestif,varnamelen,wastedassign // Keeps the PVQ band recursion close to the RFC/C reference. +package celt + +import ( + "math" + "math/bits" + + "github.com/pion/opus/internal/rangecoding" +) + +const ( + qThetaOffset = 4 + qThetaOffsetTwoPhase = 16 +) + +var orderyTable = [...]int{ //nolint:gochecknoglobals + 1, 0, + 3, 0, 2, 1, + 7, 0, 4, 3, 6, 1, 5, 2, + 15, 0, 8, 7, 12, 3, 11, 4, 14, 1, 9, 6, 13, 2, 10, 5, +} + +type bandDecodeState struct { + rangeDecoder *rangecoding.Decoder + seed uint32 +} + +func quantAllBands(info *frameSideInfo, x []float32, y []float32, totalBits int, state *bandDecodeState) []byte { + channelCount := 1 + if y != nil { + channelCount = 2 + } + blocks := 1 + if info.transient { + blocks = 1 << info.lm + } + scale := 1 << info.lm + frameBins := scale * int(bandEdges[maxBands]) + norm := make([]float32, channelCount*frameBins) + norm2 := norm[frameBins:] + lowbandScratch := make([]float32, scale*int(bandEdges[maxBands]-bandEdges[maxBands-1])) + collapseMasks := make([]byte, channelCount*maxBands) + + lowbandOffset := 0 + updateLowband := true + balance := info.allocation.balance + dualStereo := info.allocation.dualStereo != 0 + for band := info.startBand; band < info.endBand; band++ { + tell := int(state.rangeDecoder.TellFrac()) + if band != info.startBand { + balance -= tell + } + remainingBits := totalBits - tell - 1 + bandBits := 0 + if band <= info.allocation.codedBands-1 { + currentBalance := balance / min(3, info.allocation.codedBands-band) + bandBits = max(0, min(16383, min(remainingBits+1, info.allocation.pulses[band]+currentBalance))) + } + + bandStart := scale * int(bandEdges[band]) + bandEnd := scale * int(bandEdges[band+1]) + bandWidth := bandEnd - bandStart + if bandStart-bandWidth >= scale*int(bandEdges[info.startBand]) || band == info.startBand+1 { + if updateLowband || lowbandOffset == 0 { + lowbandOffset = band + } + } + if band == info.startBand+1 && info.startBand+2 <= maxBands { + n1 := scale * int(bandEdges[info.startBand+1]-bandEdges[info.startBand]) + n2 := scale * int(bandEdges[info.startBand+2]-bandEdges[info.startBand+1]) + offset := scale * int(bandEdges[info.startBand]) + if n2 > n1 { + copy(norm[offset+n1:offset+n2], norm[offset+2*n1-n2:offset+n1]) + if channelCount == 2 { + copy(norm2[offset+n1:offset+n2], norm2[offset+2*n1-n2:offset+n1]) + } + } + } + + effectiveLowband := -1 + xMask := uint(0) + yMask := uint(0) + if lowbandOffset != 0 && (info.spread != spreadAggressive || blocks > 1 || info.tfChange[band] < 0) { + effectiveLowband = max(scale*int(bandEdges[info.startBand]), scale*int(bandEdges[lowbandOffset])-bandWidth) + foldStart := lowbandOffset + for { + foldStart-- + if scale*int(bandEdges[foldStart]) <= effectiveLowband { + break + } + } + foldEnd := lowbandOffset - 1 + for { + foldEnd++ + if foldEnd >= band || scale*int(bandEdges[foldEnd]) >= effectiveLowband+bandWidth { + break + } + } + for fold := foldStart; fold < foldEnd; fold++ { + xMask |= uint(collapseMasks[fold*channelCount]) + yMask |= uint(collapseMasks[fold*channelCount+channelCount-1]) + } + } else { + xMask = (1 << blocks) - 1 + yMask = xMask + } + + if dualStereo && band == info.allocation.intensity { + dualStereo = false + for i := scale * int(bandEdges[info.startBand]); i < bandStart; i++ { + norm[i] = 0.5 * (norm[i] + norm2[i]) + } + } + + var lowband []float32 + if effectiveLowband >= 0 { + lowband = norm[effectiveLowband:] + } + if dualStereo { + xMask = quantBand( + band, + x[bandStart:bandEnd], + nil, + bandWidth, + bandBits/2, + info.spread, + blocks, + info.allocation.intensity, + info.tfChange[band], + lowband, + &remainingBits, + info.lm, + norm[bandStart:], + 0, + 1, + lowbandScratch, + xMask, + state, + ) + var lowbandY []float32 + if effectiveLowband >= 0 { + lowbandY = norm2[effectiveLowband:] + } + yMask = quantBand( + band, + y[bandStart:bandEnd], + nil, + bandWidth, + bandBits/2, + info.spread, + blocks, + info.allocation.intensity, + info.tfChange[band], + lowbandY, + &remainingBits, + info.lm, + norm2[bandStart:], + 0, + 1, + lowbandScratch, + yMask, + state, + ) + } else { + xMask = quantBand( + band, + x[bandStart:bandEnd], + yBandSlice(y, bandStart, bandEnd), + bandWidth, + bandBits, + info.spread, + blocks, + info.allocation.intensity, + info.tfChange[band], + lowband, + &remainingBits, + info.lm, + norm[bandStart:], + 0, + 1, + lowbandScratch, + xMask|yMask, + state, + ) + yMask = xMask + } + collapseMasks[band*channelCount] = byte(xMask) + collapseMasks[band*channelCount+channelCount-1] = byte(yMask) + balance += info.allocation.pulses[band] + tell + updateLowband = bandBits > bandWidth<= 1< 0 { + recombine = tfChange + } + if lowband != nil && (recombine != 0 || (nPerBlock&1) == 0 && tfChange < 0 || originalBlocks > 1) { + copy(lowbandScratch[:n], lowband[:n]) + lowband = lowbandScratch[:n] + } + for k := range recombine { + if lowband != nil { + haar1(lowband, n>>k, 1<>= recombine + nPerBlock <<= recombine + for (nPerBlock&1) == 0 && tfChange < 0 { + if lowband != nil { + haar1(lowband, nPerBlock, blocks) + } + fill |= fill << blocks + blocks <<= 1 + nPerBlock >>= 1 + timeDivide++ + tfChange++ + } + originalBlocks = blocks + if originalBlocks > 1 { + if lowband != nil { + deinterleaveHadamard(lowband, nPerBlock>>recombine, originalBlocks< 2 { + n >>= 1 + y = x[n:] + x = x[:n] + split = true + lm-- + if blocks == 1 { + fill = (fill & 1) | (fill << 1) + } + blocks = (blocks + 1) >> 1 + } + + if split { + pulseCap := logN400[band] + lm*(1<>1)-thetaOffset, pulseCap, stereo) + if stereo && band >= intensity { + qn = 1 + } + tell := int(state.rangeDecoder.TellFrac()) + itheta := 0 + if qn != 1 { + itheta = decodeBandTheta(qn, n, stereo, originalBlocks, state.rangeDecoder) + itheta = itheta * 16384 / qn + } else if stereo { + if bandBits > 2< 2< 8192 { + x2, y2 = y, x + } + sign := uint32(0) + if sideBits != 0 { + sign = state.rangeDecoder.DecodeRawBits(1) + } + signScale := float32(1) + if sign != 0 { + signScale = -1 + } + collapseMask = quantBand(band, x2, nil, n, midBits, spread, blocks, intensity, tfChange, lowband, remainingBits, lm, lowbandOut, level, gain, lowbandScratch, originalFill, state) + y2[0] = -signScale * x2[1] + y2[1] = signScale * x2[0] + x0 := mid * x[0] + x1 := mid * x[1] + y0 := side * y[0] + y1 := side * y[1] + x[0] = x0 - y0 + y[0] = x0 + y0 + x[1] = x1 - y1 + y[1] = x1 + y1 + } else { + if originalBlocks > 1 && !stereo && itheta&0x3fff != 0 { + if itheta > 8192 { + delta -= delta >> (4 - lm) + } else { + delta = min(0, delta+(n<>(5-lm))) + } + } + midBits := max(0, min(bandBits, (bandBits-delta)/2)) + sideBits := bandBits - midBits + *remainingBits -= qalloc + var nextLowband2 []float32 + if lowband != nil && !stereo { + nextLowband2 = lowband[n:] + } + var nextLowbandOut1 []float32 + nextLevel := 0 + if stereo { + nextLowbandOut1 = lowbandOut + } else { + nextLevel = level + 1 + } + collapseShift := 0 + if !stereo { + collapseShift = originalBlocks >> 1 + } + rebalance := *remainingBits + if midBits >= sideBits { + midGain := gain * mid + if stereo { + midGain = 1 + } + collapseMask = quantBand(band, x, nil, n, midBits, spread, blocks, intensity, tfChange, lowband, remainingBits, lm, nextLowbandOut1, nextLevel, midGain, lowbandScratch, fill, state) + rebalance = midBits - (rebalance - *remainingBits) + if rebalance > 3<>blocks, state) << collapseShift + } else { + collapseMask = quantBand(band, y, nil, n, sideBits, spread, blocks, intensity, tfChange, nextLowband2, remainingBits, lm, nil, nextLevel, gain*side, nil, fill>>blocks, state) << collapseShift + rebalance = sideBits - (rebalance - *remainingBits) + if rebalance > 3< 0 { + *remainingBits += currentBits + q-- + currentBits = pulsesToBits(band, lm, q) + *remainingBits -= currentBits + } + if q != 0 { + collapseMask = algUnquant(x, n, getPulses(q), spread, blocks, state.rangeDecoder, gain) + } else { + mask := uint(1<> 20) + } + collapseMask = mask + } else { + for i := range n { + state.seed = lcgRand(state.seed) + noise := float32(1.0 / 256) + if state.seed&0x8000 == 0 { + noise = -noise + } + x[i] = lowband[i] + noise + } + collapseMask = fill + } + renormaliseVector(x, n, gain) + } + } + } + + if stereo { + if n != 2 { + stereoMerge(x, y, mid, n) + } + if invert { + for i := range n { + y[i] = -y[i] + } + } + } else if level == 0 { + x = fullBand + if originalBlocks > 1 { + interleaveHadamard(x, nPerBlock>>recombine, originalBlocks<>= 1 + nPerBlock <<= 1 + collapseMask |= collapseMask >> blocks + haar1(x, nPerBlock, blocks) + } + for k := range recombine { + collapseMask = bitDeinterleave(collapseMask) + haar1(x, originalN>>k, 1< int(cache[cache[0]])+12 +} + +func decodeBandTheta(qn int, n int, stereo bool, blocks int, rangeDecoder *rangecoding.Decoder) int { + if stereo && n > 2 { + p0 := uint32(3) + x0 := uint32(qn / 2) + total := p0*(x0+1) + x0 + fs := rangeDecoder.DecodeCumulative(total) + x := uint32(0) + if fs < (x0+1)*p0 { + x = fs / p0 + } else { + x = x0 + 1 + (fs - (x0+1)*p0) + } + var low, high uint32 + if x <= x0 { + low = p0 * x + high = p0 * (x + 1) + } else { + low = (x - 1 - x0) + (x0+1)*p0 + high = (x - x0) + (x0+1)*p0 + } + rangeDecoder.UpdateCumulative(low, high, total) + + return int(x) + } + if blocks > 1 || stereo { + value, _ := rangeDecoder.DecodeUniform(uint32(qn + 1)) + + return int(value) + } + + half := qn >> 1 + total := uint32((half + 1) * (half + 1)) + fm := rangeDecoder.DecodeCumulative(total) + var itheta, symbolFrequency, low int + if fm < uint32(half*(half+1)>>1) { + itheta = (int(isqrt32(8*fm+1)) - 1) >> 1 + symbolFrequency = itheta + 1 + low = itheta * (itheta + 1) >> 1 + } else { + itheta = (2*(qn+1) - int(isqrt32(8*(total-fm-1)+1))) >> 1 + symbolFrequency = qn + 1 - itheta + low = int(total) - ((qn + 1 - itheta) * (qn + 2 - itheta) >> 1) + } + rangeDecoder.UpdateCumulative(uint32(low), uint32(low+symbolFrequency), total) + + return itheta +} + +func computeQN(n int, bitsValue int, offset int, pulseCap int, stereo bool) int { + exp2Table8 := [...]int{16384, 17866, 19483, 21247, 23170, 25267, 27554, 30048} + n2 := 2*n - 1 + if stereo && n == 2 { + n2-- + } + qb := min(bitsValue-pulseCap-(4<>1 { + return 1 + } + + return ((exp2Table8[qb&0x7] >> (14 - (qb >> bitResolution))) + 1) >> 1 << 1 +} + +func bitexactCos(x int) int { + tmp := (4096 + x*x) >> 13 + x2 := tmp + x2 = (32767 - x2) + fracMul16(x2, -7651+fracMul16(x2, 8277+fracMul16(-626, x2))) + + return 1 + x2 +} + +func bitexactLog2Tan(isin int, icos int) int { + lc := bits.Len(uint(icos)) + ls := bits.Len(uint(isin)) + icos <<= 15 - lc + isin <<= 15 - ls + + return (ls-lc)*(1<<11) + + fracMul16(isin, fracMul16(isin, -2597)+7932) - + fracMul16(icos, fracMul16(icos, -2597)+7932) +} + +func fracMul16(a int, b int) int { + return (16384 + int(int16(a))*int(int16(b))) >> 15 +} + +func isqrt32(value uint32) uint32 { + if value == 0 { + return 0 + } + g := uint32(0) + bShift := (bits.Len32(value) - 1) >> 1 + b := uint32(1) << bShift + for { + t := ((g << 1) + b) << bShift + if t <= value { + g += b + value -= t + } + if bShift == 0 { + break + } + b >>= 1 + bShift-- + } + + return g +} + +func lcgRand(seed uint32) uint32 { + return 1664525*seed + 1013904223 +} + +func stereoMerge(x []float32, y []float32, mid float32, n int) { + cross := float32(0) + sideEnergy := float32(0) + for i := range n { + cross += x[i] * y[i] + sideEnergy += y[i] * y[i] + } + cross *= mid + leftEnergy := mid*mid + sideEnergy - 2*cross + rightEnergy := mid*mid + sideEnergy + 2*cross + if leftEnergy < 6e-4 || rightEnergy < 6e-4 { + copy(y[:n], x[:n]) + + return + } + leftScale := float32(1 / math.Sqrt(float64(leftEnergy))) + rightScale := float32(1 / math.Sqrt(float64(rightEnergy))) + for i := range n { + left := mid*x[i] - y[i] + right := mid*x[i] + y[i] + x[i] = left * leftScale + y[i] = right * rightScale + } +} + +func haar1(x []float32, n0 int, stride int) { + n0 >>= 1 + for i := range stride { + for j := range n0 { + index0 := stride*2*j + i + index1 := stride*(2*j+1) + i + tmp0 := float32(math.Sqrt(0.5)) * x[index0] + tmp1 := float32(math.Sqrt(0.5)) * x[index1] + x[index0] = tmp0 + tmp1 + x[index1] = tmp0 - tmp1 + } + } +} + +func deinterleaveHadamard(x []float32, n0 int, stride int, hadamard bool) { + tmp := make([]float32, n0*stride) + if hadamard { + ordery := orderyTable[stride-2:] + for i := range stride { + for j := range n0 { + tmp[ordery[i]*n0+j] = x[j*stride+i] + } + } + } else { + for i := range stride { + for j := range n0 { + tmp[i*n0+j] = x[j*stride+i] + } + } + } + copy(x, tmp) +} + +func interleaveHadamard(x []float32, n0 int, stride int, hadamard bool) { + tmp := make([]float32, n0*stride) + if hadamard { + ordery := orderyTable[stride-2:] + for i := range stride { + for j := range n0 { + tmp[j*stride+i] = x[ordery[i]*n0+j] + } + } + } else { + for i := range stride { + for j := range n0 { + tmp[j*stride+i] = x[i*n0+j] + } + } + } + copy(x, tmp) +} + +func bitInterleave(fill uint) uint { + table := [...]uint{0, 1, 1, 1, 2, 3, 3, 3, 2, 3, 3, 3, 2, 3, 3, 3} + + return table[fill&0xF] | table[fill>>4]<<2 +} + +func bitDeinterleave(fill uint) uint { + table := [...]uint{ + 0x00, 0x03, 0x0C, 0x0F, 0x30, 0x33, 0x3C, 0x3F, + 0xC0, 0xC3, 0xCC, 0xCF, 0xF0, 0xF3, 0xFC, 0xFF, + } + + return table[fill&0xF] +} diff --git a/internal/celt/cwrs.go b/internal/celt/cwrs.go new file mode 100644 index 0000000..052e1f8 --- /dev/null +++ b/internal/celt/cwrs.go @@ -0,0 +1,97 @@ +// SPDX-FileCopyrightText: 2026 The Pion community +// SPDX-License-Identifier: MIT + +//nolint:varnamelen // CWRS notation intentionally follows the RFC/reference vector variable names. +package celt + +import "github.com/pion/opus/internal/rangecoding" + +func decodePulses(y []int, n, k int, rangeDecoder *rangecoding.Decoder) { + for i := range n { + y[i] = 0 + } + if k <= 0 { + return + } + + u := cwrsUrow(n, k) + total := u[k] + u[k+1] + index, _ := rangeDecoder.DecodeUniform(total) + cwrsDecode(y, n, k, index, u) +} + +func cwrsUrow(n, k int) []uint32 { + row := make([]uint32, k+2) + if len(row) == 0 { + return row + } + if n == 0 { + row[0] = 1 + + return row + } + row[0] = 0 + if len(row) > 1 { + row[1] = 1 + } + if n == 1 { + for i := 2; i < len(row); i++ { + row[i] = 1 + } + + return row + } + for pulses := 2; pulses < len(row); pulses++ { + row[pulses] = uint32((pulses << 1) - 1) + } + for rowIndex := 2; rowIndex < n; rowIndex++ { + cwrsNextRow(row[1:], 1) + } + + return row +} + +func cwrsNextRow(u []uint32, value0 uint32) { + value := value0 + for j := 1; j < len(u); j++ { + next := u[j] + u[j-1] + value + u[j-1] = value + value = next + } + u[len(u)-1] = value +} + +func cwrsDecode(y []int, n, k int, index uint32, u []uint32) { + for j := range n { + p := u[k+1] + negative := index >= p + if negative { + index -= p + } + + yj := k + p = u[k] + for p > index { + k-- + p = u[k] + } + index -= p + yj -= k + if negative { + y[j] = -yj + } else { + y[j] = yj + } + cwrsPreviousRow(u, k+2, 0) + } +} + +func cwrsPreviousRow(u []uint32, n int, value0 uint32) { + value := value0 + for j := 1; j < n; j++ { + next := u[j] - u[j-1] - value + u[j-1] = value + value = next + } + u[n-1] = value +} diff --git a/internal/celt/decoder.go b/internal/celt/decoder.go index 5b5debc..c27dd8b 100644 --- a/internal/celt/decoder.go +++ b/internal/celt/decoder.go @@ -1,16 +1,24 @@ // SPDX-FileCopyrightText: 2026 The Pion community // SPDX-License-Identifier: MIT +//nolint:cyclop,gosec,varnamelen // CELT decode keeps RFC/reference branch structure and vector naming. package celt import "github.com/pion/opus/internal/rangecoding" // Decoder maintains state for the RFC 6716 Section 4.3 CELT layer. type Decoder struct { - mode *Mode - rangeDecoder rangecoding.Decoder - previousLogE [2][maxBands]float32 - overlap [2][]float32 + mode *Mode + rangeDecoder rangecoding.Decoder + previousLogE [2][maxBands]float32 + previousLogE1 [2][maxBands]float32 + previousLogE2 [2][maxBands]float32 + overlap [2][]float32 + postfilterMem [2][]float32 + postfilter postFilterState + preemphasisMem [2]float32 + rng uint32 + lossCount int } // NewDecoder creates a CELT decoder with the static Opus 48 kHz mode. @@ -27,13 +35,177 @@ func (d *Decoder) Reset() { d.rangeDecoder = rangecoding.Decoder{} clear(d.previousLogE[0][:]) clear(d.previousLogE[1][:]) + for channel := range d.previousLogE1 { + for band := range d.previousLogE1[channel] { + d.previousLogE1[channel][band] = -28 + d.previousLogE2[channel][band] = -28 + } + } + clear(d.preemphasisMem[:]) + d.postfilter = postFilterState{} + d.rng = 0 + d.lossCount = 0 for channelIndex := range d.overlap { if cap(d.overlap[channelIndex]) < shortBlockSampleCount { d.overlap[channelIndex] = make([]float32, shortBlockSampleCount) } clear(d.overlap[channelIndex]) + if cap(d.postfilterMem[channelIndex]) < postfilterHistorySampleCount { + d.postfilterMem[channelIndex] = make([]float32, postfilterHistorySampleCount) + } + clear(d.postfilterMem[channelIndex]) + } +} + +// Decode decodes one CELT frame into interleaved 48 kHz float PCM. +func (d *Decoder) Decode( + in []byte, + out []float32, + isStereo bool, + outputChannelCount int, + frameSampleCount int, + startBand int, + endBand int, +) error { + return d.decode(in, out, isStereo, outputChannelCount, frameSampleCount, startBand, endBand, nil) +} + +// DecodeWithRange decodes one CELT frame using an Opus range decoder shared +// with the SILK layer in hybrid packets. +func (d *Decoder) DecodeWithRange( + in []byte, + out []float32, + isStereo bool, + outputChannelCount int, + frameSampleCount int, + startBand int, + endBand int, + rangeDecoder *rangecoding.Decoder, +) error { + return d.decode(in, out, isStereo, outputChannelCount, frameSampleCount, startBand, endBand, rangeDecoder) +} + +func (d *Decoder) decode( + in []byte, + out []float32, + isStereo bool, + outputChannelCount int, + frameSampleCount int, + startBand int, + endBand int, + rangeDecoder *rangecoding.Decoder, +) error { + channelCount := 1 + if isStereo { + channelCount = 2 + } + if outputChannelCount != 1 && outputChannelCount != 2 { + return errInvalidChannelCount + } + if len(out) < frameSampleCount*outputChannelCount { + return errInvalidFrameSize + } + + cfg := frameConfig{ + frameSampleCount: frameSampleCount, + startBand: startBand, + endBand: endBand, + channelCount: channelCount, + outputChannelCount: outputChannelCount, + } + if len(in) <= 1 { + info, err := d.validateFrameConfig(cfg) + if err != nil { + return err + } + d.decodeLostFrame(&info, out[:frameSampleCount*outputChannelCount]) + + return nil + } + + info, err := d.decodeFrameSideInfo(in, cfg, rangeDecoder) + if err != nil { + return err + } + if info.silence { + x := make([]float32, frameSampleCount) + var y []float32 + if isStereo { + y = make([]float32, frameSampleCount) + } + for channel := range info.channelCount { + for band := info.startBand; band < info.endBand; band++ { + d.previousLogE[channel][band] = -28 + } + } + d.denormaliseAndSynthesize(&info, x, y, [2][maxBands]float32{}, out) + d.updateLogEHistory(&info) + d.resetInactiveBandState(&info) + d.rng = d.rangeDecoder.FinalRange() + d.lossCount = 0 + if rangeDecoder != nil { + *rangeDecoder = d.rangeDecoder + } + + return nil + } + + x := make([]float32, frameSampleCount) + var y []float32 + if isStereo { + y = make([]float32, frameSampleCount) + } + state := bandDecodeState{ + rangeDecoder: &d.rangeDecoder, + seed: d.rng, } + totalBits := (int(info.totalBits) << bitResolution) - info.antiCollapseRsv + collapseMasks := quantAllBands(&info, x, y, totalBits, &state) + antiCollapseOn := false + if info.antiCollapseRsv > 0 { + antiCollapseOn = d.rangeDecoder.DecodeRawBits(1) != 0 + } + bitsLeft := int(info.totalBits) - int(d.rangeDecoder.Tell()) + d.finalizeFineEnergy(&info, info.allocation.fineQuant, info.allocation.finePriority, bitsLeft) + if antiCollapseOn { + d.antiCollapse(&info, x, y, collapseMasks, state.seed) + } + + bandEnergy := d.log2Amp(&info) + d.denormaliseAndSynthesize(&info, x, y, bandEnergy, out) + d.updateLogEHistory(&info) + d.resetInactiveBandState(&info) + d.rng = d.rangeDecoder.FinalRange() + d.lossCount = 0 + if rangeDecoder != nil { + *rangeDecoder = d.rangeDecoder + } + + return nil +} + +func (d *Decoder) decodeLostFrame(info *frameSideInfo, out []float32) { + clear(out) + decay := float32(1.5) + if d.lossCount > 0 { + decay = 0.5 + } + for channel := range info.channelCount { + for band := info.startBand; band < info.endBand; band++ { + d.previousLogE[channel][band] -= decay + } + } + if info.channelCount == 1 { + copy(d.previousLogE[1][:], d.previousLogE[0][:]) + } + d.resetInactiveBandState(info) + for channel := range d.overlap { + clear(d.overlap[channel]) + } + clear(d.preemphasisMem[:]) + d.rangeDecoder = rangecoding.Decoder{} + d.lossCount++ } // Mode returns the static CELT mode used by this decoder. @@ -44,3 +216,8 @@ func (d *Decoder) Mode() *Mode { return d.mode } + +// FinalRange exposes the range coder state for RFC conformance tests. +func (d *Decoder) FinalRange() uint32 { + return d.rangeDecoder.FinalRange() +} diff --git a/internal/celt/frame.go b/internal/celt/frame.go index c11eafe..1ee2795 100644 --- a/internal/celt/frame.go +++ b/internal/celt/frame.go @@ -3,7 +3,11 @@ package celt -import "fmt" +import ( + "fmt" + + "github.com/pion/opus/internal/rangecoding" +) const ( postFilterPitchBase = 16 @@ -21,29 +25,33 @@ const ( ) type frameConfig struct { - frameSampleCount int - startBand int - endBand int - channelCount int + frameSampleCount int + startBand int + endBand int + channelCount int + outputChannelCount int } type frameSideInfo struct { - lm int - totalBits uint - startBand int - endBand int - channelCount int - silence bool - postFilter postFilter - transient bool - shortBlockCount int - intraEnergy bool - coarseEnergy [2][maxBands]float32 - tfChange [maxBands]int - tfSelect int - spread int - bandBoost [maxBands]int - allocationTrim int + lm int + totalBits uint + startBand int + endBand int + channelCount int + outputChannelCount int + silence bool + postFilter postFilter + transient bool + shortBlockCount int + intraEnergy bool + coarseEnergy [2][maxBands]float32 + tfChange [maxBands]int + tfSelect int + spread int + bandBoost [maxBands]int + allocationTrim int + allocation allocationState + antiCollapseRsv int } type postFilter struct { @@ -57,14 +65,22 @@ type postFilter struct { // decodeFrameSideInfo consumes the initial CELT symbols through the allocation // header in the order specified by RFC 6716 Table 56. Pulse allocation and PVQ // residual decoding are intentionally left to the following CELT slices. -func (d *Decoder) decodeFrameSideInfo(data []byte, cfg frameConfig) (frameSideInfo, error) { +func (d *Decoder) decodeFrameSideInfo( + data []byte, + cfg frameConfig, + rangeDecoder *rangecoding.Decoder, +) (frameSideInfo, error) { info, err := d.validateFrameConfig(cfg) if err != nil { return frameSideInfo{}, err } info.totalBits = uint(len(data) * 8) - d.rangeDecoder.Init(data) + if rangeDecoder != nil { + d.rangeDecoder = *rangeDecoder + } else { + d.rangeDecoder.Init(data) + } d.decodeSilenceFlag(&info) if info.silence { @@ -79,6 +95,7 @@ func (d *Decoder) decodeFrameSideInfo(data []byte, cfg frameConfig) (frameSideIn d.prepareCoarseEnergyHistory(&info) d.decodeCoarseEnergy(&info) d.decodeAllocationHeader(&info) + d.decodeAllocationAndFineEnergy(&info) return info, nil } @@ -111,14 +128,18 @@ func (d *Decoder) validateFrameConfig(cfg frameConfig) (frameSideInfo, error) { if cfg.channelCount != 1 && cfg.channelCount != 2 { return frameSideInfo{}, errInvalidChannelCount } + if cfg.outputChannelCount != 1 && cfg.outputChannelCount != 2 { + return frameSideInfo{}, errInvalidChannelCount + } return frameSideInfo{ - lm: lm, - startBand: cfg.startBand, - endBand: cfg.endBand, - channelCount: cfg.channelCount, - spread: defaultSpreadDecision, - allocationTrim: defaultAllocationTrim, + lm: lm, + startBand: cfg.startBand, + endBand: cfg.endBand, + channelCount: cfg.channelCount, + outputChannelCount: cfg.outputChannelCount, + spread: defaultSpreadDecision, + allocationTrim: defaultAllocationTrim, }, nil } @@ -251,6 +272,22 @@ func (d *Decoder) decodeAllocationHeader(info *frameSideInfo) { d.decodeAllocationTrim(info, totalBitsEighth) } +// decodeAllocationAndFineEnergy follows RFC 6716 Section 4.3.3 after the +// allocation header: reserve the anti-collapse bit, compute PVQ/fine-energy +// budgets, then decode the first fine-energy refinement pass. +func (d *Decoder) decodeAllocationAndFineEnergy(info *frameSideInfo) { + totalBits := int(info.totalBits) //nolint:gosec // G115: CELT frame bit counts are packet-bounded. + tellFrac := int(d.rangeDecoder.TellFrac()) //nolint:gosec // G115: entropy cursor is packet-bounded. + bits := (totalBits << bitResolution) - tellFrac - 1 + info.antiCollapseRsv = 0 + if info.transient && info.lm >= 2 && bits >= (info.lm+2)< +// SPDX-License-Identifier: MIT + +//nolint:varnamelen // PVQ math uses RFC/reference scalar and vector names. +package celt + +import ( + "math" + + "github.com/pion/opus/internal/rangecoding" +) + +const ( + spreadNone = 0 + spreadLight = 1 + spreadNormal = 2 + spreadAggressive = 3 + normScaling = 1 +) + +func algUnquant( + x []float32, + n int, + k int, + spread int, + blocks int, + rangeDecoder *rangecoding.Decoder, + gain float32, +) uint { + iy := make([]int, n) + decodePulses(iy, n, k, rangeDecoder) + + energy := 0 + for i := range n { + energy += iy[i] * iy[i] + } + normaliseResidual(iy, x, n, energy, gain) + expRotation(x, n, -1, blocks, k, spread) + + return extractCollapseMask(iy, n, blocks) +} + +func normaliseResidual(iy []int, x []float32, n int, energy int, gain float32) { + if energy <= 0 { + for i := range n { + x[i] = 0 + } + + return + } + + scale := gain / float32(math.Sqrt(float64(energy))) + for i := range n { + x[i] = float32(iy[i]) * scale + } +} + +func extractCollapseMask(iy []int, n int, blocks int) uint { + if blocks <= 1 { + return 1 + } + + blockSize := n / blocks + mask := uint(0) + for block := range blocks { + for i := range blockSize { + if iy[block*blockSize+i] != 0 { + mask |= 1 << block + } + } + } + + return mask +} + +func renormaliseVector(x []float32, n int, gain float32) { + energy := float32(1e-27) + for i := range n { + energy += x[i] * x[i] + } + + scale := gain / float32(math.Sqrt(float64(energy))) + for i := range n { + x[i] *= scale + } +} + +func expRotation(x []float32, length int, direction int, stride int, pulses int, spread int) { + if 2*pulses >= length || spread == spreadNone { + return + } + + factors := [...]int{15, 10, 5} + factor := factors[spread-1] + gain := float64(length) / float64(length+factor*pulses) + theta := 0.5 * gain * gain + c := float32(math.Cos(0.5 * math.Pi * theta)) + s := float32(math.Sin(0.5 * math.Pi * theta)) + + stride2 := 0 + if length >= 8*stride { + stride2 = 1 + for (stride2*stride2+stride2)*stride+(stride>>2) < length { + stride2++ + } + } + + blockLen := length / stride + for block := range stride { + segment := x[block*blockLen : (block+1)*blockLen] + if direction < 0 { + if stride2 != 0 { + expRotation1(segment, blockLen, stride2, s, c) + } + expRotation1(segment, blockLen, 1, c, s) + } else { + expRotation1(segment, blockLen, 1, c, -s) + if stride2 != 0 { + expRotation1(segment, blockLen, stride2, s, -c) + } + } + } +} + +func expRotation1(x []float32, length int, stride int, c float32, s float32) { + for i := 0; i < length-stride; i++ { + x1 := x[i] + x2 := x[i+stride] + x[i+stride] = c*x2 + s*x1 + x[i] = c*x1 - s*x2 + } + for i := length - 2*stride - 1; i >= 0; i-- { + x1 := x[i] + x2 := x[i+stride] + x[i+stride] = c*x2 + s*x1 + x[i] = c*x1 - s*x2 + } +} diff --git a/internal/celt/synthesis.go b/internal/celt/synthesis.go new file mode 100644 index 0000000..cecc496 --- /dev/null +++ b/internal/celt/synthesis.go @@ -0,0 +1,469 @@ +// SPDX-FileCopyrightText: 2026 The Pion community +// SPDX-License-Identifier: MIT + +//nolint:cyclop,gosec,lll,modernize // Synthesis follows the RFC/reference filter and anti-collapse structure. +package celt + +import "math" + +const ( + combFilterMinPeriod = 15 + combFilterMaxPeriod = 1024 + postfilterHistoryPad = 2 + postfilterHistorySampleCount = combFilterMaxPeriod + postfilterHistoryPad +) + +type postFilterState struct { + period int + oldPeriod int + gain float32 + oldGain float32 + tapset int + oldTapset int +} + +type complex32 struct { + r float32 + i float32 +} + +var energyMeans = [maxBands]float32{ //nolint:gochecknoglobals + 6.437500, 6.250000, 5.750000, 5.312500, 5.062500, + 4.812500, 4.500000, 4.375000, 4.875000, 4.687500, + 4.562500, 4.437500, 4.875000, 4.625000, 4.312500, + 4.500000, 4.375000, 4.625000, 4.750000, 4.437500, + 3.750000, +} + +var celtWindow120 = [shortBlockSampleCount]float32{ //nolint:gochecknoglobals + 6.7286966e-05, 0.00060551348, 0.0016815970, 0.0032947962, 0.0054439943, + 0.0081276923, 0.011344001, 0.015090633, 0.019364886, 0.024163635, + 0.029483315, 0.035319905, 0.041668911, 0.048525347, 0.055883718, + 0.063737999, 0.072081616, 0.080907428, 0.090207705, 0.099974111, + 0.11019769, 0.12086883, 0.13197729, 0.14351214, 0.15546177, + 0.16781389, 0.18055550, 0.19367290, 0.20715171, 0.22097682, + 0.23513243, 0.24960208, 0.26436860, 0.27941419, 0.29472040, + 0.31026818, 0.32603788, 0.34200931, 0.35816177, 0.37447407, + 0.39092462, 0.40749142, 0.42415215, 0.44088423, 0.45766484, + 0.47447104, 0.49127978, 0.50806798, 0.52481261, 0.54149077, + 0.55807973, 0.57455701, 0.59090049, 0.60708841, 0.62309951, + 0.63891306, 0.65450896, 0.66986776, 0.68497077, 0.69980010, + 0.71433873, 0.72857055, 0.74248043, 0.75605424, 0.76927895, + 0.78214257, 0.79463430, 0.80674445, 0.81846456, 0.82978733, + 0.84070669, 0.85121779, 0.86131698, 0.87100183, 0.88027111, + 0.88912479, 0.89756398, 0.90559094, 0.91320904, 0.92042270, + 0.92723738, 0.93365955, 0.93969656, 0.94535671, 0.95064907, + 0.95558353, 0.96017067, 0.96442171, 0.96834849, 0.97196334, + 0.97527906, 0.97830883, 0.98106616, 0.98356480, 0.98581869, + 0.98784191, 0.98964856, 0.99125274, 0.99266849, 0.99390969, + 0.99499004, 0.99592297, 0.99672162, 0.99739874, 0.99796667, + 0.99843728, 0.99882195, 0.99913147, 0.99937606, 0.99956527, + 0.99970802, 0.99981248, 0.99988613, 0.99993565, 0.99996697, + 0.99998518, 0.99999457, 0.99999859, 0.99999982, 1.0000000, +} + +// SmoothFade applies the RFC 6716 CELT transition window over one 2.5 ms +// overlap. The decoder mixes in 48 kHz CELT time, so the reference window +// increment is always one sample here. +func SmoothFade(in1, in2, out []float32, overlap int, channels int) { + for channel := range channels { + for i := range overlap { + w := celtWindow120[i] * celtWindow120[i] + index := i*channels + channel + out[index] = w*in2[index] + (1-w)*in1[index] + } + } +} + +func (d *Decoder) log2Amp(info *frameSideInfo) [2][maxBands]float32 { + energy := [2][maxBands]float32{} + for channel := range info.channelCount { + for band := info.startBand; band < info.endBand; band++ { + lg := minFloat32(32, d.previousLogE[channel][band]+energyMeans[band]) + energy[channel][band] = float32(math.Pow(2, float64(lg))) + } + } + + return energy +} + +func (d *Decoder) denormaliseAndSynthesize( + info *frameSideInfo, + x []float32, + y []float32, + bandEnergy [2][maxBands]float32, + out []float32, +) { + frameSampleCount := len(x) + freqX := make([]float32, frameSampleCount) + denormaliseBands(info, x, freqX, bandEnergy[0]) + var freqY []float32 + if info.channelCount == 2 { + freqY = make([]float32, frameSampleCount) + denormaliseBands(info, y, freqY, bandEnergy[1]) + } + if info.outputChannelCount == 2 && info.channelCount == 1 { + freqY = make([]float32, frameSampleCount) + copy(freqY, freqX) + } + if info.outputChannelCount == 1 && info.channelCount == 2 { + for i := range frameSampleCount { + freqX[i] = 0.5 * (freqX[i] + freqY[i]) + } + freqY = nil + } + + timeX := d.inverseTransformChannel(freqX, 0, info) + d.applyPostfilter(info, timeX, 0) + if info.outputChannelCount == 1 { + d.updatePostfilterState(info) + d.deemphasisAndInterleave(timeX, nil, out, frameSampleCount, 1) + + return + } + timeY := d.inverseTransformChannel(freqY, 1, info) + d.applyPostfilter(info, timeY, 1) + d.updatePostfilterState(info) + d.deemphasisAndInterleave(timeX, timeY, out, frameSampleCount, 2) +} + +func (d *Decoder) antiCollapse(info *frameSideInfo, x []float32, y []float32, collapseMasks []byte, seed uint32) { + channels := [][]float32{x} + if info.channelCount == 2 { + channels = append(channels, y) + } + for band := info.startBand; band < info.endBand; band++ { + n0 := int(bandEdges[band+1] - bandEdges[band]) + n := n0 << info.lm + depth := (1 + info.allocation.pulses[band]) / n + threshold := 0.5 * math.Pow(2, -0.125*float64(depth)) + sqrtInv := 1 / math.Sqrt(float64(n)) + for channel, spectrum := range channels { + prev1 := d.previousLogE1[channel][band] + prev2 := d.previousLogE2[channel][band] + if info.channelCount == 1 { + prev1 = max(prev1, d.previousLogE1[1][band]) + prev2 = max(prev2, d.previousLogE2[1][band]) + } + energyDiff := max(float32(0), d.previousLogE[channel][band]-minFloat32(prev1, prev2)) + radius := 2 * math.Pow(2, -float64(energyDiff)) + if info.lm == maxLM { + radius *= math.Sqrt2 + } + radius = math.Min(threshold, radius) * sqrtInv + bandStart := int(bandEdges[band]) << info.lm + mask := collapseMasks[band*info.channelCount+channel] + renormalize := false + for block := 0; block < 1< shortBlockSampleCount { + current := currentPostfilter(info) + combFilter( + buf, + postfilterHistorySampleCount+shortBlockSampleCount, + period, + max(current.period, combFilterMinPeriod), + len(time)-shortBlockSampleCount, + d.postfilter.gain, + current.gain, + d.postfilter.tapset, + current.tapset, + ) + } + copy(time, buf[postfilterHistorySampleCount:postfilterHistorySampleCount+len(time)]) + copy(mem, buf[len(time):len(time)+postfilterHistorySampleCount]) +} + +func (d *Decoder) updatePostfilterState(info *frameSideInfo) { + current := currentPostfilter(info) + d.postfilter.oldPeriod = d.postfilter.period + d.postfilter.oldGain = d.postfilter.gain + d.postfilter.oldTapset = d.postfilter.tapset + d.postfilter.period = current.period + d.postfilter.gain = current.gain + d.postfilter.tapset = current.tapset + if info.lm != 0 { + d.postfilter.oldPeriod = d.postfilter.period + d.postfilter.oldGain = d.postfilter.gain + d.postfilter.oldTapset = d.postfilter.tapset + } +} + +func currentPostfilter(info *frameSideInfo) postFilterState { + if !info.postFilter.enabled { + return postFilterState{} + } + + return postFilterState{ + period: info.postFilter.period, + gain: info.postFilter.gain, + tapset: info.postFilter.tapset, + } +} + +func combFilter(buf []float32, start int, period0 int, period1 int, n int, gain0 float32, gain1 float32, tapset0 int, tapset1 int) { + gains := [3][3]float32{ + {0.3066406250, 0.2170410156, 0.1296386719}, + {0.4638671875, 0.2680664062, 0}, + {0.7998046875, 0.1000976562, 0}, + } + g00 := gain0 * gains[tapset0][0] + g01 := gain0 * gains[tapset0][1] + g02 := gain0 * gains[tapset0][2] + g10 := gain1 * gains[tapset1][0] + g11 := gain1 * gains[tapset1][1] + g12 := gain1 * gains[tapset1][2] + overlap := min(shortBlockSampleCount, n) + for i := 0; i < overlap; i++ { + window := celtWindow(i) + fade := window * window + index := start + i + buf[index] = buf[index] + + (1-fade)*g00*buf[index-period0] + + (1-fade)*g01*buf[index-period0-1] + + (1-fade)*g01*buf[index-period0+1] + + (1-fade)*g02*buf[index-period0-2] + + (1-fade)*g02*buf[index-period0+2] + + fade*g10*buf[index-period1] + + fade*g11*buf[index-period1-1] + + fade*g11*buf[index-period1+1] + + fade*g12*buf[index-period1-2] + + fade*g12*buf[index-period1+2] + } + for i := overlap; i < n; i++ { + index := start + i + buf[index] = buf[index] + + g10*buf[index-period1] + + g11*buf[index-period1-1] + + g11*buf[index-period1+1] + + g12*buf[index-period1-2] + + g12*buf[index-period1+2] + } +} + +func denormaliseBands(info *frameSideInfo, x []float32, freq []float32, bandEnergy [maxBands]float32) { + scale := 1 << info.lm + for band := info.startBand; band < info.endBand; band++ { + start := scale * int(bandEdges[band]) + end := scale * int(bandEdges[band+1]) + for i := start; i < end; i++ { + freq[i] = x[i] * bandEnergy[band] + } + } +} + +func (d *Decoder) inverseTransformChannel(freq []float32, channel int, info *frameSideInfo) []float32 { + frameSampleCount := len(freq) + accumulated := make([]float32, frameSampleCount+shortBlockSampleCount) + blockCount := 1 + blockSampleCount := frameSampleCount + stride := 1 + if info.transient { + blockCount = 1 << info.lm + blockSampleCount = shortBlockSampleCount + stride = blockCount + } + for block := range blockCount { + blockFreq := make([]float32, blockSampleCount) + if info.transient { + for i := range blockSampleCount { + blockFreq[i] = freq[block+i*stride] + } + } else { + copy(blockFreq, freq) + } + blockTime := inverseMDCT(blockFreq) + for i := range blockSampleCount + shortBlockSampleCount { + accumulated[block*blockSampleCount+i] += blockTime[i] + } + } + + time := make([]float32, frameSampleCount) + for i := range shortBlockSampleCount { + time[i] = accumulated[i] + d.overlap[channel][i] + } + copy(time[shortBlockSampleCount:], accumulated[shortBlockSampleCount:frameSampleCount]) + copy(d.overlap[channel], accumulated[frameSampleCount:frameSampleCount+shortBlockSampleCount]) + + return time +} + +func inverseMDCT(freq []float32) []float32 { + n2 := len(freq) + n := 2 * n2 + n4 := n >> 2 + sine := float32(2 * math.Pi * 0.125 / float64(n)) + preRotated := make([]complex32, n4) + for i := range n4 { + xp1 := freq[2*i] + xp2 := freq[n2-1-2*i] + cosine := float32(math.Cos(2 * math.Pi * float64(i) / float64(n))) + sineQuarter := float32(math.Cos(2 * math.Pi * float64(n4-i) / float64(n))) + yr := -xp2*cosine + xp1*sineQuarter + yi := -xp2*sineQuarter - xp1*cosine + preRotated[i] = complex32{r: yr - yi*sine, i: yi + yr*sine} + } + + fftOut := inverseComplexDFT(preRotated) + postRotated := make([]float32, n2) + for i, value := range fftOut { + re := value.r + im := value.i + cosine := float32(math.Cos(2 * math.Pi * float64(i) / float64(n))) + sineQuarter := float32(math.Cos(2 * math.Pi * float64(n4-i) / float64(n))) + yr := re*cosine - im*sineQuarter + yi := im*cosine + re*sineQuarter + postRotated[2*i] = yr - yi*sine + postRotated[2*i+1] = yi + yr*sine + } + + deshuffled := make([]float32, n2) + for i := range n4 { + deshuffled[2*i] = -postRotated[2*i] + deshuffled[2*i+1] = postRotated[n2-1-2*i] + } + + overlap := shortBlockSampleCount + out := make([]float32, n2+overlap) + leftPlain := n4 - overlap/2 + for i := 0; i < leftPlain; i++ { + out[n4+overlap/2-1-i] = float32(deshuffled[n4-1-i]) + } + for i := leftPlain; i < n4; i++ { + x1 := deshuffled[n4-1-i] + windowIndex := i - leftPlain + out[windowIndex] += -celtWindow(windowIndex) * x1 + out[n4+overlap/2-1-i] += celtWindow(overlap-1-windowIndex) * x1 + } + + for i := 0; i < leftPlain; i++ { + out[n4+overlap/2+i] = deshuffled[n4+i] + } + for i := leftPlain; i < n4; i++ { + x2 := deshuffled[n4+i] + windowIndex := i - leftPlain + out[n2+overlap-1-windowIndex] = celtWindow(windowIndex) * x2 + out[n4+overlap/2+i] = celtWindow(overlap-1-windowIndex) * x2 + } + + return out +} + +func inverseComplexDFT(in []complex32) []complex32 { + n := len(in) + out := make([]complex32, n) + for k := range n { + sumR := float32(0) + sumI := float32(0) + for m, value := range in { + angle := 2 * math.Pi * float64(k*m) / float64(n) + cosine := float32(math.Cos(angle)) + sine := float32(math.Sin(angle)) + sumR += value.r*cosine - value.i*sine + sumI += value.r*sine + value.i*cosine + } + out[k] = complex32{r: sumR, i: sumI} + } + + return out +} + +func celtWindow(i int) float32 { + return celtWindow120[i] +} + +func (d *Decoder) deemphasisAndInterleave(x []float32, y []float32, out []float32, frameSampleCount int, channelCount int) { + for sample := range frameSampleCount { + left := x[sample] + d.preemphasisMem[0] + d.preemphasisMem[0] = 0.85000610 * left + out[sample*channelCount] = left / 32768 + if channelCount == 2 { + right := y[sample] + d.preemphasisMem[1] + d.preemphasisMem[1] = 0.85000610 * right + out[sample*channelCount+1] = right / 32768 + } + } +} + +func minFloat32(a, b float32) float32 { + if a < b { + return a + } + + return b +} diff --git a/internal/celt/tables.go b/internal/celt/tables.go index f43034e..6d75911 100644 --- a/internal/celt/tables.go +++ b/internal/celt/tables.go @@ -5,33 +5,75 @@ package celt // bandEdges are the 2.5 ms CELT band edges from RFC 6716 Table 55. var bandEdges = [...]int16{ //nolint:gochecknoglobals - 0, 1, 2, 3, 4, 5, 6, 7, 8, 10, 12, 14, 16, 20, 24, 28, 32, 40, 48, 60, 78, 100, + 0, 1, 2, 3, 4, 5, 6, 7, 8, 10, 12, 14, 16, 20, 24, 28, 34, 40, 48, 60, 78, 100, } // bandAllocation is the static CELT allocation table from RFC 6716 Table 57. -// Rows are energy bands, columns are allocation vectors 0 through 10. -var bandAllocation = [maxBands][11]uint8{ //nolint:gochecknoglobals - {0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0}, - {90, 80, 75, 69, 63, 56, 49, 40, 34, 29, 20}, - {110, 100, 90, 84, 78, 71, 65, 58, 51, 45, 39}, - {118, 110, 103, 93, 86, 80, 75, 71, 65, 60, 54}, - {126, 119, 112, 104, 95, 89, 83, 80, 76, 70, 65}, - {134, 128, 120, 114, 103, 97, 91, 88, 83, 77, 72}, - {144, 137, 129, 124, 113, 107, 101, 97, 92, 86, 83}, - {152, 145, 137, 132, 123, 117, 111, 107, 102, 96, 93}, - {162, 154, 147, 142, 133, 127, 121, 117, 112, 106, 103}, - {172, 164, 157, 152, 143, 137, 131, 127, 122, 116, 113}, - {200, 200, 198, 194, 183, 177, 171, 167, 162, 156, 153}, - {200, 200, 200, 200, 198, 194, 188, 183, 179, 173, 168}, - {200, 200, 200, 200, 200, 200, 199, 194, 190, 185, 180}, - {200, 200, 200, 200, 200, 200, 200, 200, 199, 194, 190}, - {200, 200, 200, 200, 200, 200, 200, 200, 200, 200, 200}, - {200, 200, 200, 200, 200, 200, 200, 200, 200, 200, 200}, - {200, 200, 200, 200, 200, 200, 200, 200, 200, 200, 200}, - {200, 200, 200, 200, 200, 200, 200, 200, 198, 193, 188}, - {200, 200, 200, 200, 200, 200, 200, 200, 200, 200, 200}, - {200, 200, 200, 200, 200, 200, 200, 200, 200, 200, 200}, - {200, 200, 200, 200, 200, 200, 200, 200, 198, 193, 188}, +// Rows are allocation vectors 0 through 10; columns are energy bands. +var bandAllocation = [11][maxBands]uint8{ //nolint:gochecknoglobals + {0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0}, + {90, 80, 75, 69, 63, 56, 49, 40, 34, 29, 20, 18, 10, 0, 0, 0, 0, 0, 0, 0, 0}, + {110, 100, 90, 84, 78, 71, 65, 58, 51, 45, 39, 32, 26, 20, 12, 0, 0, 0, 0, 0, 0}, + {118, 110, 103, 93, 86, 80, 75, 70, 65, 59, 53, 47, 40, 31, 23, 15, 4, 0, 0, 0, 0}, + {126, 119, 112, 104, 95, 89, 83, 78, 72, 66, 60, 54, 47, 39, 32, 25, 17, 12, 1, 0, 0}, + {134, 127, 120, 114, 103, 97, 91, 85, 78, 72, 66, 60, 54, 47, 41, 35, 29, 23, 16, 10, 1}, + {144, 137, 130, 124, 113, 107, 101, 95, 88, 82, 76, 70, 64, 57, 51, 45, 39, 33, 26, 15, 1}, + {152, 145, 138, 132, 123, 117, 111, 105, 98, 92, 86, 80, 74, 67, 61, 55, 49, 43, 36, 20, 1}, + {162, 155, 148, 142, 133, 127, 121, 115, 108, 102, 96, 90, 84, 77, 71, 65, 59, 53, 46, 30, 1}, + {172, 165, 158, 152, 143, 137, 131, 125, 118, 112, 106, 100, 94, 87, 81, 75, 69, 63, 56, 45, 20}, + {200, 200, 200, 200, 200, 200, 200, 200, 198, 193, 188, 183, 178, 173, 168, 163, 158, 153, 148, 129, 104}, +} + +var logN400 = [maxBands]int{ //nolint:gochecknoglobals + 0, 0, 0, 0, 0, 0, 0, 0, 8, 8, 8, 8, 16, 16, 16, 21, 21, 24, 29, 34, 36, +} + +var log2FracTable = [24]int{ //nolint:gochecknoglobals + 0, + 8, 13, + 16, 19, 21, 23, + 24, 26, 27, 28, 29, 30, 31, 32, + 32, 33, 34, 34, 35, 36, 36, 37, 37, +} + +var pulseCacheIndex = [105]int16{ //nolint:gochecknoglobals + -1, -1, -1, -1, -1, -1, -1, -1, 0, 0, 0, 0, 41, 41, 41, + 82, 82, 123, 164, 200, 222, 0, 0, 0, 0, 0, 0, 0, 0, 41, + 41, 41, 41, 123, 123, 123, 164, 164, 240, 266, 283, 295, 41, 41, 41, + 41, 41, 41, 41, 41, 123, 123, 123, 123, 240, 240, 240, 266, 266, 305, + 318, 328, 336, 123, 123, 123, 123, 123, 123, 123, 123, 240, 240, 240, 240, + 305, 305, 305, 318, 318, 343, 351, 358, 364, 240, 240, 240, 240, 240, 240, + 240, 240, 305, 305, 305, 305, 343, 343, 343, 351, 351, 370, 376, 382, 387, +} + +var pulseCacheBits = [392]uint8{ //nolint:gochecknoglobals + 40, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, + 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, + 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 40, 15, 23, 28, + 31, 34, 36, 38, 39, 41, 42, 43, 44, 45, 46, 47, 47, 49, 50, + 51, 52, 53, 54, 55, 55, 57, 58, 59, 60, 61, 62, 63, 63, 65, + 66, 67, 68, 69, 70, 71, 71, 40, 20, 33, 41, 48, 53, 57, 61, + 64, 66, 69, 71, 73, 75, 76, 78, 80, 82, 85, 87, 89, 91, 92, + 94, 96, 98, 101, 103, 105, 107, 108, 110, 112, 114, 117, 119, 121, 123, + 124, 126, 128, 40, 23, 39, 51, 60, 67, 73, 79, 83, 87, 91, 94, + 97, 100, 102, 105, 107, 111, 115, 118, 121, 124, 126, 129, 131, 135, 139, + 142, 145, 148, 150, 153, 155, 159, 163, 166, 169, 172, 174, 177, 179, 35, + 28, 49, 65, 78, 89, 99, 107, 114, 120, 126, 132, 136, 141, 145, 149, + 153, 159, 165, 171, 176, 180, 185, 189, 192, 199, 205, 211, 216, 220, 225, + 229, 232, 239, 245, 251, 21, 33, 58, 79, 97, 112, 125, 137, 148, 157, + 166, 174, 182, 189, 195, 201, 207, 217, 227, 235, 243, 251, 17, 35, 63, + 86, 106, 123, 139, 152, 165, 177, 187, 197, 206, 214, 222, 230, 237, 250, + 25, 31, 55, 75, 91, 105, 117, 128, 138, 146, 154, 161, 168, 174, 180, + 185, 190, 200, 208, 215, 222, 229, 235, 240, 245, 255, 16, 36, 65, 89, + 110, 128, 144, 159, 173, 185, 196, 207, 217, 226, 234, 242, 250, 11, 41, + 74, 103, 128, 151, 172, 191, 209, 225, 241, 255, 9, 43, 79, 110, 138, + 163, 186, 207, 227, 246, 12, 39, 71, 99, 123, 144, 164, 182, 198, 214, + 228, 241, 253, 9, 44, 81, 113, 142, 168, 192, 214, 235, 255, 7, 49, + 90, 127, 160, 191, 220, 247, 6, 51, 95, 134, 170, 203, 234, 7, 47, + 87, 123, 155, 184, 212, 237, 6, 52, 97, 137, 174, 208, 240, 5, 57, + 106, 151, 192, 231, 5, 59, 111, 158, 202, 243, 5, 55, 103, 147, 187, + 224, 5, 60, 113, 161, 206, 248, 4, 65, 122, 175, 224, 4, 67, 127, + 182, 234, } // These ICDFs mirror the CELT symbol PDFs from RFC 6716 Table 56 and the diff --git a/internal/rangecoding/decoder.go b/internal/rangecoding/decoder.go index 0c8451b..4cecd63 100644 --- a/internal/rangecoding/decoder.go +++ b/internal/rangecoding/decoder.go @@ -111,6 +111,18 @@ func (r *Decoder) Init(data []byte) { r.normalize() } +// SetStorageSize adjusts the logical frame size without resetting decoder +// state. Opus hybrid redundancy removes tail bytes from the CELT range coder +// after the shared decoder has already consumed SILK symbols. +func (r *Decoder) SetStorageSize(size int) { + if size < 0 { + size = 0 + } + if size < len(r.data) { + r.data = r.data[:size] + } +} + // DecodeSymbolWithICDF decodes a single symbol // with a table-based context of up to 8 bits. // @@ -153,6 +165,18 @@ func (r *Decoder) decodeAndUpdateUniformSymbol(total uint32) uint32 { return symbol } +// DecodeCumulative decodes the cumulative frequency index used by CELT's +// custom range-coded symbols. Call UpdateCumulative with the selected interval. +func (r *Decoder) DecodeCumulative(total uint32) uint32 { + return r.decodeUniformSymbol(total) +} + +// UpdateCumulative commits a custom cumulative interval previously selected +// from DecodeCumulative. +func (r *Decoder) UpdateCumulative(low, high, total uint32) { + r.update(r.rangeSize/total, low, high, total) +} + // DecodeUniform decodes an RFC 6716 Section 4.1.5 ec_dec_uint() symbol. // // It returns false when the decoded raw-bit suffix produces a value outside