diff --git a/decoder.go b/decoder.go index 73ad978..f52fa7b 100644 --- a/decoder.go +++ b/decoder.go @@ -42,6 +42,8 @@ type Decoder struct { silkResamplerChannels int hybridSilkResampler [2]silkresample.Resampler hybridSilkChannels int + hybridSilkBuffer []float32 + hybridSilkPCM []float32 silkRedundancyFades []silkRedundancyFade silkCeltAdditions []silkCeltAddition floatBuffer []float32 @@ -755,7 +757,7 @@ func (d *Decoder) decodeHybridFrame( d.rangeDecoder.Init(encodedFrame) silkOutputChannelCount := min(streamChannelCount, outputChannelCount) - silkInternal := make([]float32, silkSamplesPerChannel*silkOutputChannelCount) + silkInternal := resizeFloat32Buffer(&d.hybridSilkBuffer, silkSamplesPerChannel*silkOutputChannelCount) if err := d.silkDecoder.DecodeWithRangeToChannels( &d.rangeDecoder, silkInternal, @@ -792,7 +794,7 @@ func (d *Decoder) decodeHybridFrame( return err } - silkPCM := make([]float32, outputFrameSampleCount*silkOutputChannelCount) + silkPCM := resizeFloat32Buffer(&d.hybridSilkPCM, outputFrameSampleCount*silkOutputChannelCount) if err = d.resampleHybridSilk(silkInternal, silkPCM, silkOutputChannelCount); err != nil { return err } @@ -1363,6 +1365,15 @@ func float32ToInt16(in []float32, out []int16, sampleCount int) { } } +func resizeFloat32Buffer(buffer *[]float32, sampleCount int) []float32 { + if cap(*buffer) < sampleCount { + *buffer = make([]float32, sampleCount) + } + *buffer = (*buffer)[:sampleCount] + + return *buffer +} + // Decode decodes the Opus bitstream into S16LE PCM. func (d *Decoder) Decode(in, out []byte) (bandwidth Bandwidth, isStereo bool, err error) { if cap(d.floatBuffer) < len(out)/2 { diff --git a/internal/celt/celt.go b/internal/celt/celt.go index 2498d6e..712573c 100644 --- a/internal/celt/celt.go +++ b/internal/celt/celt.go @@ -10,6 +10,7 @@ const ( sampleRate = 48000 shortBlockSampleCount = 120 maxLM = 3 + maxFrameSampleCount = shortBlockSampleCount << maxLM maxBands = 21 hybridStartBand = 17 ) diff --git a/internal/celt/decoder.go b/internal/celt/decoder.go index 11572e7..004a629 100644 --- a/internal/celt/decoder.go +++ b/internal/celt/decoder.go @@ -19,6 +19,7 @@ type Decoder struct { preemphasisMem [2]float32 rng uint32 lossCount int + scratch *decoderScratch } // NewDecoder creates a CELT decoder with the static Opus 48 kHz mode. @@ -149,6 +150,7 @@ func (d *Decoder) decode( outputSampleRate int, rangeDecoder *rangecoding.Decoder, ) error { + scratch := d.scratchBuffer() channelCount := 1 if isStereo { channelCount = 2 @@ -189,10 +191,12 @@ func (d *Decoder) decode( return err } if info.silence { - x := make([]float32, frameSampleCount) + x := scratch.x[:frameSampleCount] + clear(x) var y []float32 if isStereo { - y = make([]float32, frameSampleCount) + y = scratch.y[:frameSampleCount] + clear(y) } for channel := range info.channelCount { for band := info.startBand; band < info.endBand; band++ { @@ -213,10 +217,12 @@ func (d *Decoder) decode( // RFC 6716 Sections 4.3.4 through 4.3.7 decode the normalized residual, // optionally repair collapsed transient blocks, then synthesize PCM. - x := make([]float32, frameSampleCount) + x := scratch.x[:frameSampleCount] + clear(x) var y []float32 if isStereo { - y = make([]float32, frameSampleCount) + y = scratch.y[:frameSampleCount] + clear(y) } state := bandDecodeState{ rangeDecoder: &d.rangeDecoder, @@ -284,3 +290,11 @@ func (d *Decoder) Mode() *Mode { func (d *Decoder) FinalRange() uint32 { return d.rangeDecoder.FinalRange() } + +func (d *Decoder) scratchBuffer() *decoderScratch { + if d.scratch == nil { + d.scratch = &decoderScratch{} + } + + return d.scratch +} diff --git a/internal/celt/synthesis.go b/internal/celt/synthesis.go index c453bf5..f6c608c 100644 --- a/internal/celt/synthesis.go +++ b/internal/celt/synthesis.go @@ -27,6 +27,47 @@ type complex32 struct { i float32 } +type decoderScratch struct { + x [maxFrameSampleCount]float32 + y [maxFrameSampleCount]float32 + channels [2]channelScratch + postfilter [2][postfilterHistorySampleCount + maxFrameSampleCount]float32 +} + +type channelScratch struct { + freq [maxFrameSampleCount]float32 + accumulated [maxFrameSampleCount + shortBlockSampleCount]float32 + blockFreq [maxFrameSampleCount]float32 + time [maxFrameSampleCount]float32 + mdct mdctScratch +} + +type mdctScratch struct { + preRotated [maxFrameSampleCount / 2]complex32 + fftOut [maxFrameSampleCount / 2]complex32 + fftWork [maxFrameSampleCount / 2]complex32 + postRotated [maxFrameSampleCount]float32 + deshuffled [maxFrameSampleCount]float32 + out [maxFrameSampleCount + shortBlockSampleCount]float32 +} + +type inverseTransformPlan struct { + frameSampleCount int + n4 int + sine float32 + rotateCos []float32 + rotateSinQuarter []float32 + fftCos []float32 + fftSin []float32 +} + +var inverseTransformPlans = [maxLM + 1]inverseTransformPlan{ //nolint:gochecknoglobals + newInverseTransformPlan(shortBlockSampleCount), + newInverseTransformPlan(shortBlockSampleCount << 1), + newInverseTransformPlan(shortBlockSampleCount << 2), + newInverseTransformPlan(shortBlockSampleCount << 3), +} + var energyMeans = [maxBands]float32{ //nolint:gochecknoglobals 6.437500, 6.250000, 5.750000, 5.312500, 5.062500, 4.812500, 4.500000, 4.375000, 4.875000, 4.687500, @@ -103,18 +144,21 @@ func (d *Decoder) denormaliseAndSynthesize( bandEnergy [2][maxBands]float32, out []float32, ) { + scratch := d.scratchBuffer() frameSampleCount := len(x) - freqX := make([]float32, frameSampleCount) + freqX := scratch.channels[0].freq[:frameSampleCount] + clear(freqX) denormaliseBands(info, x, freqX, bandEnergy[0]) limitOutputBandwidth(info, freqX) var freqY []float32 if info.channelCount == 2 { - freqY = make([]float32, frameSampleCount) + freqY = scratch.channels[1].freq[:frameSampleCount] + clear(freqY) denormaliseBands(info, y, freqY, bandEnergy[1]) limitOutputBandwidth(info, freqY) } if info.outputChannelCount == 2 && info.channelCount == 1 { - freqY = make([]float32, frameSampleCount) + freqY = scratch.channels[1].freq[:frameSampleCount] copy(freqY, freqX) } if info.outputChannelCount == 1 && info.channelCount == 2 { @@ -234,7 +278,7 @@ func (d *Decoder) applyPostfilter(info *frameSideInfo, time []float32, channel i d.postfilterMem[channel] = make([]float32, postfilterHistorySampleCount) } mem := d.postfilterMem[channel][:postfilterHistorySampleCount] - buf := make([]float32, postfilterHistorySampleCount+len(time)) + buf := d.scratchBuffer().postfilter[channel][:postfilterHistorySampleCount+len(time)] copy(buf, mem) copy(buf[postfilterHistorySampleCount:], time) @@ -370,8 +414,10 @@ func limitOutputBandwidth(info *frameSideInfo, freq []float32) { // inverseTransformChannel performs the RFC 6716 Section 4.3.7 IMDCT path for // one channel and carries the weighted overlap-add tail into the next frame. func (d *Decoder) inverseTransformChannel(freq []float32, channel int, info *frameSideInfo) []float32 { + channelScratch := &d.scratchBuffer().channels[channel] frameSampleCount := len(freq) - accumulated := make([]float32, frameSampleCount+shortBlockSampleCount) + accumulated := channelScratch.accumulated[:frameSampleCount+shortBlockSampleCount] + clear(accumulated) blockCount := 1 blockSampleCount := frameSampleCount stride := 1 @@ -383,7 +429,7 @@ func (d *Decoder) inverseTransformChannel(freq []float32, channel int, info *fra // Transient spectra are interleaved short MDCTs; non-transient frames are // one long transform. Accumulate either form into a single time buffer. for block := range blockCount { - blockFreq := make([]float32, blockSampleCount) + blockFreq := channelScratch.blockFreq[:blockSampleCount] if info.transient { for i := range blockSampleCount { blockFreq[i] = freq[block+i*stride] @@ -391,13 +437,13 @@ func (d *Decoder) inverseTransformChannel(freq []float32, channel int, info *fra } else { copy(blockFreq, freq) } - blockTime := inverseMDCT(blockFreq) + blockTime := inverseMDCTWithScratch(blockFreq, &channelScratch.mdct) for i := range blockSampleCount + shortBlockSampleCount { accumulated[block*blockSampleCount+i] += blockTime[i] } } - time := make([]float32, frameSampleCount) + time := channelScratch.time[:frameSampleCount] for i := range shortBlockSampleCount { time[i] = accumulated[i] + d.overlap[channel][i] } @@ -410,46 +456,53 @@ func (d *Decoder) inverseTransformChannel(freq []float32, channel int, info *fra // inverseMDCT follows the RFC 6716 Section 4.3.7 low-overlap IMDCT shape: // N frequency samples become 2*N time samples plus the CELT overlap tail. func inverseMDCT(freq []float32) []float32 { + scratch := mdctScratch{} + + return inverseMDCTWithScratch(freq, &scratch) +} + +func inverseMDCTWithScratch(freq []float32, scratch *mdctScratch) []float32 { n2 := len(freq) - n := 2 * n2 - n4 := n >> 2 - sine := float32(2 * math.Pi * 0.125 / float64(n)) - preRotated := make([]complex32, n4) + plan := inverseTransformPlanForFrameSampleCount(n2) + n4 := plan.n4 + preRotated := scratch.preRotated[:n4] // Pack the MDCT input into the complex half-size transform domain before // the inverse complex step, matching the reference mdct_backward staging. for i := range n4 { xp1 := freq[2*i] xp2 := freq[n2-1-2*i] - cosine := float32(math.Cos(2 * math.Pi * float64(i) / float64(n))) - sineQuarter := float32(math.Cos(2 * math.Pi * float64(n4-i) / float64(n))) + cosine := plan.rotateCos[i] + sineQuarter := plan.rotateSinQuarter[i] yr := -xp2*cosine + xp1*sineQuarter yi := -xp2*sineQuarter - xp1*cosine - preRotated[i] = complex32{r: yr - yi*sine, i: yi + yr*sine} + preRotated[i] = complex32{r: yr - yi*plan.sine, i: yi + yr*plan.sine} } - fftOut := inverseComplexDFT(preRotated) - postRotated := make([]float32, n2) + fftOut := scratch.fftOut[:n4] + inverseComplexDFTInto(preRotated, fftOut, scratch.fftWork[:n4], plan) + postRotated := scratch.postRotated[:n2] // Rotate back out of the complex domain and restore the packed even/odd // ordering expected by the time-domain mirror step. for i, value := range fftOut { re := value.r im := value.i - cosine := float32(math.Cos(2 * math.Pi * float64(i) / float64(n))) - sineQuarter := float32(math.Cos(2 * math.Pi * float64(n4-i) / float64(n))) + cosine := plan.rotateCos[i] + sineQuarter := plan.rotateSinQuarter[i] yr := re*cosine - im*sineQuarter yi := im*cosine + re*sineQuarter - postRotated[2*i] = yr - yi*sine - postRotated[2*i+1] = yi + yr*sine + postRotated[2*i] = yr - yi*plan.sine + postRotated[2*i+1] = yi + yr*plan.sine } - deshuffled := make([]float32, n2) + deshuffled := scratch.deshuffled[:n2] for i := range n4 { deshuffled[2*i] = -postRotated[2*i] deshuffled[2*i+1] = postRotated[n2-1-2*i] } overlap := shortBlockSampleCount - out := make([]float32, n2+overlap) + out := scratch.out[:n2+overlap] + clear(out) leftPlain := n4 - overlap/2 // Apply the low-overlap window from RFC 6716 Section 4.3.7. The middle // region is unwindowed; the edges are mirrored for TDAC overlap-add. @@ -480,22 +533,114 @@ func inverseMDCT(freq []float32) []float32 { // implementation. It is kept separate so a later FFT implementation can replace // this step without changing the surrounding RFC 6716 Section 4.3.7 mapping. func inverseComplexDFT(in []complex32) []complex32 { - n := len(in) - out := make([]complex32, n) - for k := range n { - sumR := float32(0) - sumI := float32(0) - for m, value := range in { - angle := 2 * math.Pi * float64(k*m) / float64(n) - cosine := float32(math.Cos(angle)) - sine := float32(math.Sin(angle)) - sumR += value.r*cosine - value.i*sine - sumI += value.r*sine + value.i*cosine + out := make([]complex32, len(in)) + work := make([]complex32, len(in)) + inverseComplexDFTInto(in, out, work, inverseTransformPlanForFrameSampleCount(len(in)*2)) + + return out +} + +func inverseComplexDFTInto(in []complex32, out []complex32, work []complex32, plan *inverseTransformPlan) { + inverseComplexFFTRecursive(in, 1, out, work, len(in), plan) +} + +func inverseComplexFFTRecursive( + in []complex32, + stride int, + out []complex32, + work []complex32, + n int, + plan *inverseTransformPlan, +) { + if n == 1 { + out[0] = in[0] + + return + } + + radix := fftRadix(n) + m := n / radix + for subtransform := range radix { + inverseComplexFFTRecursive( + in[subtransform*stride:], + stride*radix, + work[subtransform*m:(subtransform+1)*m], + out[subtransform*m:(subtransform+1)*m], + m, + plan, + ) + } + + for k := range m { + for frequencyGroup := range radix { + sum := complex32{} + for subtransform := range radix { + value := work[subtransform*m+k] + twiddle := plan.fftTwiddle(subtransform*(k+frequencyGroup*m), n) + sum.r += value.r*twiddle.r - value.i*twiddle.i + sum.i += value.r*twiddle.i + value.i*twiddle.r + } + out[k+frequencyGroup*m] = sum } - out[k] = complex32{r: sumR, i: sumI} } +} - return out +func fftRadix(n int) int { + switch { + case n%2 == 0: + return 2 + case n%3 == 0: + return 3 + case n%5 == 0: + return 5 + default: + return n + } +} + +func newInverseTransformPlan(frameSampleCount int) inverseTransformPlan { + n := 2 * frameSampleCount + n4 := n >> 2 + plan := inverseTransformPlan{ + frameSampleCount: frameSampleCount, + n4: n4, + sine: float32(2 * math.Pi * 0.125 / float64(n)), + rotateCos: make([]float32, n4), + rotateSinQuarter: make([]float32, n4), + fftCos: make([]float32, n4), + fftSin: make([]float32, n4), + } + for i := range n4 { + plan.rotateCos[i] = float32(math.Cos(2 * math.Pi * float64(i) / float64(n))) + plan.rotateSinQuarter[i] = float32(math.Cos(2 * math.Pi * float64(n4-i) / float64(n))) + angle := 2 * math.Pi * float64(i) / float64(n4) + plan.fftCos[i] = float32(math.Cos(angle)) + plan.fftSin[i] = float32(math.Sin(angle)) + } + + return plan +} + +func inverseTransformPlanForFrameSampleCount(frameSampleCount int) *inverseTransformPlan { + for i := range inverseTransformPlans { + if inverseTransformPlans[i].frameSampleCount == frameSampleCount { + return &inverseTransformPlans[i] + } + } + + plan := newInverseTransformPlan(frameSampleCount) + + return &plan +} + +func (p *inverseTransformPlan) fftTwiddle(index, transformSize int) complex32 { + index *= p.n4 / transformSize + index %= p.n4 + if index < 0 { + index += p.n4 + } + + return complex32{r: p.fftCos[index], i: p.fftSin[index]} } func celtWindow(i int) float32 {