diff --git a/.github/scripts/run-rfc6716-conformance.sh b/.github/scripts/run-rfc6716-conformance.sh index f5b8bc6..05ed797 100644 --- a/.github/scripts/run-rfc6716-conformance.sh +++ b/.github/scripts/run-rfc6716-conformance.sh @@ -280,7 +280,7 @@ main() { export OPUS_CONFORMANCE_MARKDOWN="${matrix_file}" set +e - go test -v -tags conformance -run TestRFC6716Conformance . 2>&1 | tee "${log_file}" + go test -v -timeout 60m -tags conformance -run TestRFC6716Conformance . 2>&1 | tee "${log_file}" local test_status="${PIPESTATUS[0]}" set -e diff --git a/conformance_test.go b/conformance_test.go index a05723d..dfcb4d0 100644 --- a/conformance_test.go +++ b/conformance_test.go @@ -15,14 +15,14 @@ import ( "runtime" "strconv" "strings" + "sync" "testing" ) type conformanceKey struct { - vectorSet string - rate int - channels int - vector string + rate int + channels int + vector string } type conformanceResult struct { @@ -48,58 +48,56 @@ func TestRFC6716Conformance(t *testing.T) { "01", "02", "03", "04", "05", "06", "07", "08", "09", "10", "11", "12", } - vectorSets := []string{"rfc6716", "rfc8251"} refDir, vectorRoot := conformanceDataPaths(t) opusCompare := buildRFC6716ReferenceTools(t, refDir) results := make(map[conformanceKey]conformanceResult) + var resultsMu sync.Mutex + + t.Run("vectors", func(t *testing.T) { + for _, rate := range rates { + for _, channels := range channelCounts { + for _, vector := range vectors { + key := conformanceKey{ + rate: rate, + channels: channels, + vector: vector, + } + t.Run( + fmt.Sprintf("rate_%d/channels_%d/testvector%s", rate, channels, vector), + func(t *testing.T) { + t.Parallel() - for _, vectorSet := range vectorSets { - vectorDir := filepath.Join(vectorRoot, vectorSet) - t.Run(vectorSet, func(t *testing.T) { - for _, rate := range rates { - for _, channels := range channelCounts { - t.Run(fmt.Sprintf("rate_%d/channels_%d", rate, channels), func(t *testing.T) { - for _, vector := range vectors { - key := conformanceKey{ - vectorSet: vectorSet, - rate: rate, - channels: channels, - vector: vector, - } - ran := false quality := "" - passed := t.Run("testvector"+vector, func(t *testing.T) { - ran = true - bitstream := filepath.Join(vectorDir, "testvector"+vector+".bit") - referencePCM := filepath.Join(vectorDir, "testvector"+vector+".dec") - alternateReferencePCM := filepath.Join(vectorDir, "testvector"+vector+"m.dec") - goPCM := filepath.Join(t.TempDir(), "go.pcm") - - decodeRFC6716Vector(t, rate, channels, bitstream, goPCM) - quality = compareRFC6716Output( - t, - opusCompare, - rate, - channels, - referencePCM, - alternateReferencePCM, - goPCM, - ) - }) - if ran { - results[key] = conformanceResult{passed: passed, quality: quality} - } - } - }) + defer func() { + resultsMu.Lock() + results[key] = conformanceResult{passed: !t.Failed(), quality: quality} + resultsMu.Unlock() + }() + + bitstream := conformanceBitstreamPath(t, vectorRoot, vector) + referencePCMs := conformanceReferencePCMs(vectorRoot, vector) + goPCM := filepath.Join(t.TempDir(), "go.pcm") + + decodeRFC6716Vector(t, rate, channels, bitstream, goPCM) + quality = compareRFC6716Output( + t, + opusCompare, + rate, + channels, + referencePCMs, + goPCM, + ) + }, + ) } } - }) - } + } + }) - printConformanceMatrix(results, vectorSets, rates, channelCounts, vectors) - writeConformanceMarkdown(t, os.Getenv(envConformanceMarkdown), results, vectorSets, rates, channelCounts, vectors) + printConformanceMatrix(results, rates, channelCounts, vectors) + writeConformanceMarkdown(t, os.Getenv(envConformanceMarkdown), results, rates, channelCounts, vectors) } func conformanceDataPaths(t *testing.T) (refDir, vectorRoot string) { @@ -118,42 +116,62 @@ func compareRFC6716Output( t *testing.T, opusCompare string, rate, channels int, - referencePCM, alternateReferencePCM, goPCM string, + referencePCMs []string, + goPCM string, ) string { t.Helper() - out, err := runOpusCompare(opusCompare, rate, channels, referencePCM, goPCM) - if err == nil { - quality := opusCompareQuality(out) - printOpusCompareQuality(t, quality) + checkedReference := false + var failures []string + for _, referencePCM := range referencePCMs { + if _, err := os.Stat(referencePCM); err != nil { + continue + } + checkedReference = true + out, err := runOpusCompare(opusCompare, rate, channels, referencePCM, goPCM) + if err == nil { + quality := opusCompareQuality(out) + printOpusCompareQuality(t, quality) - return quality + return quality + } + failures = append(failures, fmt.Sprintf("%s: %v\n%s", referencePCM, err, out)) + } + if !checkedReference { + t.Fatalf("no reference PCM found among %v", referencePCMs) } - primaryErr := err - primaryOut := out - if _, err := os.Stat(alternateReferencePCM); err != nil { - t.Fatalf("opus_compare failed: %v\n%s", primaryErr, primaryOut) + t.Fatalf("opus_compare failed for all references:\n%s", strings.Join(failures, "\n")) - return "" - } + return "" +} - out, err = runOpusCompare(opusCompare, rate, channels, alternateReferencePCM, goPCM) - if err != nil { - t.Fatalf( - "opus_compare failed for both references: primary=%v alternate=%v\nprimary:\n%s\nalternate:\n%s", - primaryErr, - err, - primaryOut, - out, - ) +func conformanceBitstreamPath(t *testing.T, vectorRoot, vector string) string { + t.Helper() - return "" + // RFC 8251 Section 11 keeps the decoder input bitstreams unchanged, so the + // newer archive is the preferred source and the RFC 6716 archive is an + // equivalent fallback when only the legacy bundle is available. + for _, vectorSet := range []string{"rfc8251", "rfc6716"} { + path := filepath.Join(vectorRoot, vectorSet, "testvector"+vector+".bit") + if _, err := os.Stat(path); err == nil { + return path + } } - quality := opusCompareQuality(out) - printOpusCompareQuality(t, quality) - return quality + t.Fatalf("missing testvector%s.bit in RFC 8251 or RFC 6716 vectors", vector) + + return "" +} + +func conformanceReferencePCMs(vectorRoot, vector string) []string { + // RFC 8251 Section 11 permits either the original RFC 6716 output set or + // one of the updated output sets for the same unchanged input bitstreams. + return []string{ + filepath.Join(vectorRoot, "rfc8251", "testvector"+vector+".dec"), + filepath.Join(vectorRoot, "rfc8251", "testvector"+vector+"m.dec"), + filepath.Join(vectorRoot, "rfc6716", "testvector"+vector+".dec"), + } } func buildRFC6716ReferenceTools(t *testing.T, refDir string) (opusCompare string) { @@ -309,7 +327,6 @@ func printOpusCompareQuality(t *testing.T, quality string) { func printConformanceMatrix( results map[conformanceKey]conformanceResult, - vectorSets []string, rates []int, channelCounts []int, vectors []string, @@ -318,36 +335,33 @@ func printConformanceMatrix( return } - fmt.Println("RFC 6716 / 8251 conformation matrix") + fmt.Println("Opus conformance matrix") fmt.Println("Legend: numeric cells are opus_compare quality percentages; FAIL means the vector did not pass.") + fmt.Println("Inputs use the shared RFC 6716 / RFC 8251 bitstream corpus; accepted references follow RFC 8251 Section 11.") - for _, vectorSet := range vectorSets { - fmt.Printf("\nvector set: %s\n", vectorSet) - printConformanceMatrixRule(vectors) - fmt.Printf("| %-8s | %-2s |", "rate", "ch") - for _, vector := range vectors { - fmt.Printf(" %-*s |", conformanceMatrixVectorCellWidth, vector) - } - fmt.Println() - printConformanceMatrixRule(vectors) - - for _, rate := range rates { - for _, channels := range channelCounts { - fmt.Printf("| %-8d | %-2d |", rate, channels) - for _, vector := range vectors { - key := conformanceKey{ - vectorSet: vectorSet, - rate: rate, - channels: channels, - vector: vector, - } - fmt.Printf(" %-*s |", conformanceMatrixVectorCellWidth, conformanceMatrixCell(results, key)) + printConformanceMatrixRule(vectors) + fmt.Printf("| %-8s | %-2s |", "rate", "ch") + for _, vector := range vectors { + fmt.Printf(" %-*s |", conformanceMatrixVectorCellWidth, vector) + } + fmt.Println() + printConformanceMatrixRule(vectors) + + for _, rate := range rates { + for _, channels := range channelCounts { + fmt.Printf("| %-8d | %-2d |", rate, channels) + for _, vector := range vectors { + key := conformanceKey{ + rate: rate, + channels: channels, + vector: vector, } - fmt.Println() + fmt.Printf(" %-*s |", conformanceMatrixVectorCellWidth, conformanceMatrixCell(results, key)) } + fmt.Println() } - printConformanceMatrixRule(vectors) } + printConformanceMatrixRule(vectors) } const conformanceMatrixVectorCellWidth = 5 @@ -364,7 +378,6 @@ func writeConformanceMarkdown( t *testing.T, path string, results map[conformanceKey]conformanceResult, - vectorSets []string, rates []int, channelCounts []int, vectors []string, @@ -377,35 +390,32 @@ func writeConformanceMarkdown( var b strings.Builder b.WriteString("Legend: numeric cells are `opus_compare` quality percentages; `FAIL` means the vector did not pass.\n\n") - for _, vectorSet := range vectorSets { - fmt.Fprintf(&b, "### %s\n\n", vectorSet) - b.WriteString("| rate | ch |") - for _, vector := range vectors { - fmt.Fprintf(&b, " %s |", vector) - } - b.WriteString("\n| --- | --- |") - for range vectors { - b.WriteString(" --- |") - } - b.WriteString("\n") - - for _, rate := range rates { - for _, channels := range channelCounts { - fmt.Fprintf(&b, "| %d | %d |", rate, channels) - for _, vector := range vectors { - key := conformanceKey{ - vectorSet: vectorSet, - rate: rate, - channels: channels, - vector: vector, - } - fmt.Fprintf(&b, " %s |", conformanceMatrixCell(results, key)) + b.WriteString("Inputs use the shared RFC 6716 / RFC 8251 bitstream corpus; accepted references follow RFC 8251 Section 11.\n\n") + b.WriteString("| rate | ch |") + for _, vector := range vectors { + fmt.Fprintf(&b, " %s |", vector) + } + b.WriteString("\n| --- | --- |") + for range vectors { + b.WriteString(" --- |") + } + b.WriteString("\n") + + for _, rate := range rates { + for _, channels := range channelCounts { + fmt.Fprintf(&b, "| %d | %d |", rate, channels) + for _, vector := range vectors { + key := conformanceKey{ + rate: rate, + channels: channels, + vector: vector, } - b.WriteString("\n") + fmt.Fprintf(&b, " %s |", conformanceMatrixCell(results, key)) } + b.WriteString("\n") } - b.WriteString("\n") } + b.WriteString("\n") if err := os.WriteFile(path, []byte(b.String()), 0o600); err != nil { t.Fatalf("write conformance markdown: %v", err) @@ -428,7 +438,16 @@ func conformanceMatrixCell(results map[conformanceKey]conformanceResult, key con } func conformanceFinalRange(d *Decoder) (uint32, error) { - return d.silkDecoder.FinalRange(), nil + switch d.previousMode { + case configurationModeCELTOnly: + return d.celtDecoder.FinalRange(), nil + case configurationModeSilkOnly: + return d.rangeFinal, nil + case configurationModeHybrid: + return d.rangeFinal, nil + default: + return 0, fmt.Errorf("unsupported final range mode: %s", d.previousMode) + } } func conformancePacketSamplesPerChannel(packet []byte, rate int) (int, error) { diff --git a/decoder.go b/decoder.go index 8aa825e..2dd4f2f 100644 --- a/decoder.go +++ b/decoder.go @@ -219,7 +219,7 @@ func (d *Decoder) resetModeState(mode configurationMode) { // copySilkResamplerToHybrid preserves the WB SILK resampler history across the // normatively continuous WB SILK -> Hybrid transition in RFC 6716 Section 4.5. func (d *Decoder) copySilkResamplerToHybrid() { - if d.sampleRate != celtSampleRate || d.silkResamplerBandwidth != BandwidthWideband || d.silkResamplerChannels == 0 { + if d.silkResamplerBandwidth != BandwidthWideband || d.silkResamplerChannels == 0 { return } for i := range d.hybridSilkResampler { @@ -231,7 +231,7 @@ func (d *Decoder) copySilkResamplerToHybrid() { // copyHybridSilkResamplerToSilk preserves the same WB SILK history for the // reverse Hybrid -> WB SILK transition described by RFC 6716 Section 4.5. func (d *Decoder) copyHybridSilkResamplerToSilk() { - if d.sampleRate != celtSampleRate || d.hybridSilkChannels == 0 { + if d.hybridSilkChannels == 0 { return } for i := range d.silkResampler { @@ -290,6 +290,24 @@ func (c Configuration) decodedSampleRate() int { } } +// sampleCountAtRate converts a 48 kHz CELT-domain length into the caller's +// requested Opus API output rate. +func sampleCountAtRate(samples48 int, outputSampleRate int) int { + return samples48 * outputSampleRate / celtSampleRate +} + +// celtFadeSampleCount returns the 2.5 ms CELT transition overlap at the +// caller's output rate. +func celtFadeSampleCount(outputSampleRate int) int { + return outputSampleRate / 400 +} + +// celtRedundantFrameSampleCount returns the 5 ms redundant CELT frame length +// at the caller's output rate. +func celtRedundantFrameSampleCount(outputSampleRate int) int { + return outputSampleRate / 200 +} + func parseFrameLength(in []byte) (frameLength int, bytesRead int, err error) { if len(in) < 1 { return 0, 0, fmt.Errorf("%w: missing frame length", errMalformedPacket) @@ -356,28 +374,37 @@ func parsePacketFramesCode2(in []byte) ([][]byte, error) { return [][]byte{in[firstFrameStart:firstFrameEnd], in[firstFrameEnd:]}, nil } -func parsePacketPadding(in []byte, offset int) (padding int, newOffset int, err error) { +func parsePacketPadding(in []byte, offset int) (newOffset int, payloadEnd int, err error) { + remaining := len(in) - offset for { // [R6][R7] Padding length bytes are part of the Code 3 header and // must be present before any frame data. - if offset >= len(in) { + if remaining <= 0 { return 0, 0, fmt.Errorf("%w: truncated padding length", errMalformedPacket) } paddingByte := int(in[offset]) offset++ + remaining-- + paddingLength := paddingByte + if paddingByte == 255 { + paddingLength = 254 + } + // RFC 8251 Section 4 hardens the reference parser by decrementing the + // remaining packet length as each padding byte is consumed, rather than + // accumulating a potentially overflowing padding total. + if paddingLength > remaining { + return 0, 0, fmt.Errorf("%w: padding overruns packet", errMalformedPacket) + } + remaining -= paddingLength if paddingByte == 255 { - padding += 254 - continue } - padding += paddingByte - break } - return padding, offset, nil + return offset, offset + remaining, nil } func parsePacketFramesCode3(in []byte, tocHeader tableOfContentsHeader) ([][]byte, error) { @@ -398,16 +425,15 @@ func parsePacketFramesCode3(in []byte, tocHeader tableOfContentsHeader) ([][]byt } offset := 2 - padding := 0 + payloadEnd := len(in) var err error if hasPadding { - padding, offset, err = parsePacketPadding(in, offset) + offset, payloadEnd, err = parsePacketPadding(in, offset) if err != nil { return nil, err } } - payloadEnd := len(in) - padding // [R6] In CBR Code 3, the padding-length bytes plus trailing padding // must fit within the packet, leaving at least TOC + frame count. // [R7] In VBR Code 3, the same bound applies before frame data. @@ -507,9 +533,16 @@ func parsePacketFrames(in []byte, tocHeader tableOfContentsHeader) ([][]byte, er func (d *Decoder) decode( in []byte, out []float32, -) (bandwidth Bandwidth, isStereo bool, sampleCount int, decodedChannelCount int, err error) { +) ( + bandwidth Bandwidth, + decodedSampleRate int, + isStereo bool, + sampleCount int, + decodedChannelCount int, + err error, +) { if len(in) < 1 { - return 0, false, 0, 0, errTooShortForTableOfContentsHeader + return 0, 0, false, 0, 0, errTooShortForTableOfContentsHeader } tocHeader := tableOfContentsHeader(in[0]) @@ -517,7 +550,7 @@ func (d *Decoder) decode( encodedFrames, err := parsePacketFrames(in, tocHeader) if err != nil { - return 0, false, 0, 0, err + return 0, 0, false, 0, 0, err } switch cfg.mode() { @@ -534,18 +567,27 @@ func (d *Decoder) decode( return d.decodeHybridFrames(cfg, tocHeader, encodedFrames, out) default: - return 0, false, 0, 0, fmt.Errorf("%w: %d", errUnsupportedConfigurationMode, cfg.mode()) + return 0, 0, false, 0, 0, fmt.Errorf("%w: %d", errUnsupportedConfigurationMode, cfg.mode()) } } -// decodeCeltFrames decodes the CELT-only path at CELT's internal 48 kHz rate. +// decodeCeltFrames keeps CELT synthesis in the internal 48 kHz mode while +// emitting PCM at the caller-requested Opus API output rate. func (d *Decoder) decodeCeltFrames( cfg Configuration, tocHeader tableOfContentsHeader, encodedFrames [][]byte, out []float32, -) (bandwidth Bandwidth, isStereo bool, sampleCount int, decodedChannelCount int, err error) { +) ( + bandwidth Bandwidth, + decodedSampleRate int, + isStereo bool, + sampleCount int, + decodedChannelCount int, + err error, +) { frameSampleCount := cfg.celtFrameSampleCount() + outputFrameSampleCount := sampleCountAtRate(frameSampleCount, d.sampleRate) streamChannelCount := 1 if tocHeader.isStereo() { streamChannelCount = 2 @@ -554,7 +596,7 @@ func (d *Decoder) decodeCeltFrames( if decodedChannelCount == 0 { decodedChannelCount = streamChannelCount } - requiredSamples := frameSampleCount * len(encodedFrames) * decodedChannelCount + requiredSamples := outputFrameSampleCount * len(encodedFrames) * decodedChannelCount if cap(out) < requiredSamples { d.silkBuffer = make([]float32, requiredSamples) out = d.silkBuffer @@ -566,12 +608,12 @@ func (d *Decoder) decodeCeltFrames( startBand, endBand, err := d.celtDecoder.Mode().BandRangeForSampleRate(cfg.bandwidth().SampleRate()) if err != nil { - return 0, false, 0, 0, err + return 0, 0, false, 0, 0, err } - frameOutputSamples := frameSampleCount * decodedChannelCount + frameOutputSamples := outputFrameSampleCount * decodedChannelCount for i, encodedFrame := range encodedFrames { frameOut := out[i*frameOutputSamples : (i+1)*frameOutputSamples] - if err = d.celtDecoder.Decode( + if err = d.celtDecoder.DecodeToSampleRate( encodedFrame, frameOut, tocHeader.isStereo(), @@ -579,8 +621,9 @@ func (d *Decoder) decodeCeltFrames( frameSampleCount, startBand, endBand, + d.sampleRate, ); err != nil { - return 0, false, 0, 0, err + return 0, 0, false, 0, 0, err } d.previousMode = configurationModeCELTOnly d.previousRedundancy = false @@ -591,7 +634,7 @@ func (d *Decoder) decodeCeltFrames( } } - return cfg.bandwidth(), tocHeader.isStereo(), requiredSamples, decodedChannelCount, nil + return cfg.bandwidth(), d.sampleRate, tocHeader.isStereo(), requiredSamples, decodedChannelCount, nil } // decodeHybridFrames combines the SILK and CELT layers for Hybrid packets. @@ -600,8 +643,16 @@ func (d *Decoder) decodeHybridFrames( tocHeader tableOfContentsHeader, encodedFrames [][]byte, out []float32, -) (bandwidth Bandwidth, isStereo bool, sampleCount int, decodedChannelCount int, err error) { +) ( + bandwidth Bandwidth, + decodedSampleRate int, + isStereo bool, + sampleCount int, + decodedChannelCount int, + err error, +) { frameSampleCount := cfg.hybridFrameSampleCount() + outputFrameSampleCount := sampleCountAtRate(frameSampleCount, d.sampleRate) streamChannelCount := 1 if tocHeader.isStereo() { streamChannelCount = 2 @@ -610,7 +661,7 @@ func (d *Decoder) decodeHybridFrames( if decodedChannelCount == 0 { decodedChannelCount = streamChannelCount } - requiredSamples := frameSampleCount * len(encodedFrames) * decodedChannelCount + requiredSamples := outputFrameSampleCount * len(encodedFrames) * decodedChannelCount if cap(out) < requiredSamples { d.silkBuffer = make([]float32, requiredSamples) out = d.silkBuffer @@ -622,9 +673,9 @@ func (d *Decoder) decodeHybridFrames( startBand, endBand, err := d.celtDecoder.Mode().HybridBandRange(cfg.bandwidth().SampleRate()) if err != nil { - return 0, false, 0, 0, err + return 0, 0, false, 0, 0, err } - frameOutputSamples := frameSampleCount * decodedChannelCount + frameOutputSamples := outputFrameSampleCount * decodedChannelCount silkSamplesPerChannel := frameSampleCount * BandwidthWideband.SampleRate() / celtSampleRate for i, encodedFrame := range encodedFrames { frameOut := out[i*frameOutputSamples : (i+1)*frameOutputSamples] @@ -635,16 +686,17 @@ func (d *Decoder) decodeHybridFrames( streamChannelCount, decodedChannelCount, frameSampleCount, + outputFrameSampleCount, silkSamplesPerChannel, cfg.frameDuration().nanoseconds(), startBand, endBand, ); err != nil { - return 0, false, 0, 0, err + return 0, 0, false, 0, 0, err } } - return cfg.bandwidth(), tocHeader.isStereo(), requiredSamples, decodedChannelCount, nil + return cfg.bandwidth(), d.sampleRate, tocHeader.isStereo(), requiredSamples, decodedChannelCount, nil } type hybridRedundancy struct { @@ -669,6 +721,7 @@ func (d *Decoder) decodeHybridFrame( streamChannelCount int, outputChannelCount int, frameSampleCount int, + outputFrameSampleCount int, silkSamplesPerChannel int, frameNanoseconds int, startBand int, @@ -676,11 +729,13 @@ func (d *Decoder) decodeHybridFrame( ) error { d.rangeDecoder.Init(encodedFrame) - silkInternal := make([]float32, silkSamplesPerChannel*streamChannelCount) - if err := d.silkDecoder.DecodeWithRange( + silkOutputChannelCount := min(streamChannelCount, outputChannelCount) + silkInternal := make([]float32, silkSamplesPerChannel*silkOutputChannelCount) + if err := d.silkDecoder.DecodeWithRangeToChannels( &d.rangeDecoder, silkInternal, isStereo, + silkOutputChannelCount, frameNanoseconds, silk.Bandwidth(BandwidthWideband), ); err != nil { @@ -698,7 +753,7 @@ func (d *Decoder) decodeHybridFrame( d.celtDecoder.Reset() clear(d.celtBuffer) } - if err = d.celtDecoder.DecodeWithRange( + if err = d.celtDecoder.DecodeWithRangeToSampleRate( encodedFrame[:redundancy.celtDataLen], out, isStereo, @@ -706,46 +761,51 @@ func (d *Decoder) decodeHybridFrame( frameSampleCount, startBand, endBand, + d.sampleRate, &d.rangeDecoder, ); err != nil { return err } - silk48 := make([]float32, frameSampleCount*streamChannelCount) - if err = d.resampleHybridSilkTo48(silkInternal, silk48, streamChannelCount); err != nil { + silkPCM := make([]float32, outputFrameSampleCount*silkOutputChannelCount) + if err = d.resampleHybridSilk(silkInternal, silkPCM, silkOutputChannelCount); err != nil { return err } - d.addHybridSilk(out, silk48, streamChannelCount, outputChannelCount, frameSampleCount) + d.addHybridSilk(out, silkPCM, silkOutputChannelCount, outputChannelCount, outputFrameSampleCount) if redundancy.present && !redundancy.celtToSilk { d.celtDecoder.Reset() clear(d.celtBuffer) if err = d.decodeHybridRedundantFrame(&redundancy, isStereo, outputChannelCount, endBand); err != nil { return err } - fadeStart := (frameSampleCount - hybridFadeSampleCount) * outputChannelCount - redundantStart := hybridFadeSampleCount * outputChannelCount - celt.SmoothFade( + fadeSampleCount := celtFadeSampleCount(d.sampleRate) + fadeStart := (outputFrameSampleCount - fadeSampleCount) * outputChannelCount + redundantStart := fadeSampleCount * outputChannelCount + celt.SmoothFadeWithSampleRate( out[fadeStart:], redundancy.audio[redundantStart:], out[fadeStart:], - hybridFadeSampleCount, + fadeSampleCount, outputChannelCount, + d.sampleRate, ) } if redundancy.present && redundancy.celtToSilk { - for sample := range hybridFadeSampleCount { + fadeSampleCount := celtFadeSampleCount(d.sampleRate) + for sample := range fadeSampleCount { for channel := range outputChannelCount { index := sample*outputChannelCount + channel out[index] = redundancy.audio[index] } } - fadeStart := hybridFadeSampleCount * outputChannelCount - celt.SmoothFade( + fadeStart := fadeSampleCount * outputChannelCount + celt.SmoothFadeWithSampleRate( redundancy.audio[fadeStart:], out[fadeStart:], out[fadeStart:], - hybridFadeSampleCount, + fadeSampleCount, outputChannelCount, + d.sampleRate, ) } if len(encodedFrame) <= 1 { @@ -842,8 +902,9 @@ func (d *Decoder) decodeHybridRedundantFrame( outputChannelCount int, endBand int, ) error { - redundancy.audio = make([]float32, hybridRedundantFrameSampleCount*outputChannelCount) - if err := d.celtDecoder.Decode( + redundantOutputSamples := celtRedundantFrameSampleCount(d.sampleRate) + redundancy.audio = make([]float32, redundantOutputSamples*outputChannelCount) + if err := d.celtDecoder.DecodeToSampleRate( redundancy.data, redundancy.audio, isStereo, @@ -851,6 +912,7 @@ func (d *Decoder) decodeHybridRedundantFrame( hybridRedundantFrameSampleCount, 0, endBand, + d.sampleRate, ); err != nil { return err } @@ -859,12 +921,12 @@ func (d *Decoder) decodeHybridRedundantFrame( return nil } -// resampleHybridSilkTo48 lifts the Hybrid packet's WB SILK layer to the 48 kHz -// CELT domain before the two layers are summed. -func (d *Decoder) resampleHybridSilkTo48(in []float32, out []float32, channelCount int) error { +// resampleHybridSilk lifts the Hybrid packet's WB SILK layer into the decoder's +// output-rate domain before the two layers are summed. +func (d *Decoder) resampleHybridSilk(in []float32, out []float32, channelCount int) error { if d.hybridSilkChannels == 0 { for i := range d.hybridSilkResampler { - if err := d.hybridSilkResampler[i].Init(BandwidthWideband.SampleRate(), celtSampleRate); err != nil { + if err := d.hybridSilkResampler[i].Init(BandwidthWideband.SampleRate(), d.sampleRate); err != nil { return err } } @@ -919,16 +981,16 @@ func (d *Decoder) resampleHybridSilkChannel( } // addHybridSilk combines the decoded WB SILK contribution with the CELT layer -// after both are represented at 48 kHz. +// after both are represented in the decoder's output-rate domain. func (d *Decoder) addHybridSilk( out []float32, - silk48 []float32, + silkPCM []float32, streamChannelCount int, outputChannelCount int, samplesPerChannel int, ) { - for i := range silk48 { - silk48[i] = float32(bitdepth.Float32ToSigned16(silk48[i])) / 32768 + for i := range silkPCM { + silkPCM[i] = float32(bitdepth.Float32ToSigned16(silkPCM[i])) / 32768 } for sample := range samplesPerChannel { silkIndex := sample * streamChannelCount @@ -936,13 +998,13 @@ func (d *Decoder) addHybridSilk( switch { case streamChannelCount == outputChannelCount: for channel := range outputChannelCount { - out[outIndex+channel] += silk48[silkIndex+channel] + out[outIndex+channel] += silkPCM[silkIndex+channel] } case streamChannelCount == 1 && outputChannelCount == 2: - out[outIndex] += silk48[silkIndex] - out[outIndex+1] += silk48[silkIndex] + out[outIndex] += silkPCM[silkIndex] + out[outIndex+1] += silkPCM[silkIndex] case streamChannelCount == 2 && outputChannelCount == 1: - out[outIndex] += 0.5 * (silk48[silkIndex] + silk48[silkIndex+1]) + out[outIndex] += 0.5 * (silkPCM[silkIndex] + silkPCM[silkIndex+1]) } } } @@ -956,14 +1018,17 @@ func (d *Decoder) decodeSilkFrames( tocHeader tableOfContentsHeader, encodedFrames [][]byte, out []float32, -) (bandwidth Bandwidth, isStereo bool, sampleCount int, decodedChannelCount int, err error) { +) ( + bandwidth Bandwidth, + decodedSampleRate int, + isStereo bool, + sampleCount int, + decodedChannelCount int, + err error, +) { frameSamplesPerChannel := cfg.silkFrameSampleCount() - frameSampleCount := frameSamplesPerChannel - decodedChannelCount = 1 - if tocHeader.isStereo() { - frameSampleCount *= 2 - decodedChannelCount = 2 - } + decodedChannelCount = silkOutputChannelCount(tocHeader.isStereo(), d.channels) + frameSampleCount := frameSamplesPerChannel * decodedChannelCount d.silkRedundancyFades = d.silkRedundancyFades[:0] d.silkCeltAdditions = d.silkCeltAdditions[:0] requiredSamples := frameSampleCount * len(encodedFrames) @@ -981,19 +1046,20 @@ func (d *Decoder) decodeSilkFrames( previousRedundancy := d.previousRedundancy frameOut := out[i*frameSampleCount : (i+1)*frameSampleCount] d.rangeDecoder.Init(encodedFrame) - err := d.silkDecoder.DecodeWithRange( + err := d.silkDecoder.DecodeWithRangeToChannels( &d.rangeDecoder, frameOut, tocHeader.isStereo(), + decodedChannelCount, cfg.frameDuration().nanoseconds(), silk.Bandwidth(cfg.bandwidth()), ) if err != nil { - return 0, false, 0, 0, err + return 0, 0, false, 0, 0, err } redundancy, err := d.decodeSilkOnlyRedundancyHeader(encodedFrame, cfg.bandwidth()) if err != nil { - return 0, false, 0, 0, err + return 0, 0, false, 0, 0, err } if redundancy.present { if !redundancy.celtToSilk { @@ -1006,13 +1072,13 @@ func (d *Decoder) decodeSilkFrames( decodedChannelCount, redundancy.endBand, ); err != nil { - return 0, false, 0, 0, err + return 0, 0, false, 0, 0, err } d.silkRedundancyFades = append(d.silkRedundancyFades, silkRedundancyFade{ celtToSilk: redundancy.celtToSilk, audio: redundancy.audio, - startSample: i * frameSamplesPerChannel * celtSampleRate / cfg.bandwidth().SampleRate(), - frameSampleCount: frameSamplesPerChannel * celtSampleRate / cfg.bandwidth().SampleRate(), + startSample: i * frameSamplesPerChannel * d.sampleRate / cfg.bandwidth().SampleRate(), + frameSampleCount: frameSamplesPerChannel * d.sampleRate / cfg.bandwidth().SampleRate(), channelCount: decodedChannelCount, }) } @@ -1020,10 +1086,11 @@ func (d *Decoder) decodeSilkFrames( (!redundancy.present || !redundancy.celtToSilk || !previousRedundancy) { endBand, err := d.celtEndBandForSilkBandwidth(cfg.bandwidth()) if err != nil { - return 0, false, 0, 0, err + return 0, 0, false, 0, 0, err } - transitionAudio := make([]float32, hybridFadeSampleCount*decodedChannelCount) - if err = d.celtDecoder.Decode( + fadeSampleCount := celtFadeSampleCount(d.sampleRate) + transitionAudio := make([]float32, fadeSampleCount*decodedChannelCount) + if err = d.celtDecoder.DecodeToSampleRate( []byte{0xff, 0xff}, transitionAudio, tocHeader.isStereo(), @@ -1031,12 +1098,13 @@ func (d *Decoder) decodeSilkFrames( hybridFadeSampleCount, 0, endBand, + d.sampleRate, ); err != nil { - return 0, false, 0, 0, err + return 0, 0, false, 0, 0, err } d.silkCeltAdditions = append(d.silkCeltAdditions, silkCeltAddition{ audio: transitionAudio, - startSample: i * frameSamplesPerChannel * celtSampleRate / cfg.bandwidth().SampleRate(), + startSample: i * frameSamplesPerChannel * d.sampleRate / cfg.bandwidth().SampleRate(), channelCount: decodedChannelCount, }) } @@ -1051,9 +1119,18 @@ func (d *Decoder) decodeSilkFrames( sampleCount = requiredSamples - return cfg.bandwidth(), tocHeader.isStereo(), sampleCount, decodedChannelCount, nil + return cfg.bandwidth(), cfg.bandwidth().SampleRate(), tocHeader.isStereo(), sampleCount, decodedChannelCount, nil +} + +func silkOutputChannelCount(isStereo bool, requestedChannelCount int) int { + if isStereo && requestedChannelCount == 2 { + return 2 + } + + return 1 } +//nolint:cyclop func (d *Decoder) decodeToFloat32( in []byte, out []float32, @@ -1065,20 +1142,29 @@ func (d *Decoder) decodeToFloat32( return 0, 0, false, errInvalidChannelCount } - bandwidth, isStereo, sampleCount, decodedChannelCount, err := d.decode(in, d.silkBuffer) + bandwidth, decodedSampleRate, isStereo, sampleCount, decodedChannelCount, err := d.decode(in, d.silkBuffer) if err != nil { return 0, 0, false, err } - samplesPerChannel = (sampleCount / decodedChannelCount) * d.sampleRate / bandwidth.SampleRate() + samplesPerChannel = (sampleCount / decodedChannelCount) * d.sampleRate / decodedSampleRate requiredSamples := samplesPerChannel * decodedChannelCount if cap(d.resampleBuffer) < requiredSamples { d.resampleBuffer = make([]float32, requiredSamples) } d.resampleBuffer = d.resampleBuffer[:requiredSamples] - if d.sampleRate == bandwidth.SampleRate() { + decodedMode := d.previousMode + switch { + case decodedMode == configurationModeSilkOnly && + decodedSampleRate == bandwidth.SampleRate() && + bandwidth != BandwidthFullband: + // The RFC SILK decoder resampler has delay even for same-rate copy paths. + if err = d.resampleSilk(d.silkBuffer[:sampleCount], d.resampleBuffer, decodedChannelCount, bandwidth); err != nil { + return 0, 0, false, err + } + case d.sampleRate == decodedSampleRate: copy(d.resampleBuffer, d.silkBuffer[:sampleCount]) - } else { + default: if err = d.resampleSilk(d.silkBuffer[:sampleCount], d.resampleBuffer, decodedChannelCount, bandwidth); err != nil { return 0, 0, false, err } @@ -1103,9 +1189,7 @@ func (d *Decoder) applySilkRedundancyFades(channelCount int) { additions := d.silkCeltAdditions d.silkRedundancyFades = d.silkRedundancyFades[:0] d.silkCeltAdditions = d.silkCeltAdditions[:0] - if d.sampleRate != celtSampleRate { - return - } + fadeSampleCount := celtFadeSampleCount(d.sampleRate) for _, addition := range additions { if addition.channelCount != channelCount { continue @@ -1124,34 +1208,36 @@ func (d *Decoder) applySilkRedundancyFades(channelCount int) { } frameStart := fade.startSample * channelCount if fade.celtToSilk { - copyCount := hybridFadeSampleCount * channelCount + copyCount := fadeSampleCount * channelCount if frameStart+2*copyCount > len(d.resampleBuffer) || copyCount > len(fade.audio) { continue } copy(d.resampleBuffer[frameStart:frameStart+copyCount], fade.audio[:copyCount]) - celt.SmoothFade( + celt.SmoothFadeWithSampleRate( fade.audio[copyCount:], d.resampleBuffer[frameStart+copyCount:], d.resampleBuffer[frameStart+copyCount:], - hybridFadeSampleCount, + fadeSampleCount, channelCount, + d.sampleRate, ) continue } - fadeStart := (fade.startSample + fade.frameSampleCount - hybridFadeSampleCount) * channelCount - redundantStart := hybridFadeSampleCount * channelCount - if fadeStart < 0 || fadeStart+hybridFadeSampleCount*channelCount > len(d.resampleBuffer) || - redundantStart+hybridFadeSampleCount*channelCount > len(fade.audio) { + fadeStart := (fade.startSample + fade.frameSampleCount - fadeSampleCount) * channelCount + redundantStart := fadeSampleCount * channelCount + if fadeStart < 0 || fadeStart+fadeSampleCount*channelCount > len(d.resampleBuffer) || + redundantStart+fadeSampleCount*channelCount > len(fade.audio) { continue } - celt.SmoothFade( + celt.SmoothFadeWithSampleRate( d.resampleBuffer[fadeStart:], fade.audio[redundantStart:], d.resampleBuffer[fadeStart:], - hybridFadeSampleCount, + fadeSampleCount, channelCount, + d.sampleRate, ) } } diff --git a/decoder_test.go b/decoder_test.go index 06d2941..49c0bc4 100644 --- a/decoder_test.go +++ b/decoder_test.go @@ -125,7 +125,7 @@ func TestNewDecoderWithOutput(t *testing.T) { func TestInitResetsCeltState(t *testing.T) { decoder := NewDecoder() - _, stereo, sampleCount, decodedChannelCount, err := decoder.decode( + _, _, stereo, sampleCount, decodedChannelCount, err := decoder.decode( []byte{byte(16<<3) | byte(frameCodeOneFrame), 0xff, 0xff}, nil, ) @@ -159,6 +159,30 @@ func TestDecodeToFloat32(t *testing.T) { assert.ErrorIs(t, err, errOutBufferTooSmall) } +func TestDecodeCeltAtBandwidthSampleRateSkipsSilkResampler(t *testing.T) { + decoder, err := NewDecoderWithOutput(8000, 1) + assert.NoError(t, err) + out := make([]float32, 20) + + sampleCount, err := decoder.DecodeToFloat32([]byte{byte(16<<3) | byte(frameCodeOneFrame)}, out) + + assert.NoError(t, err) + assert.Equal(t, 20, sampleCount) + assert.Zero(t, decoder.silkResamplerBandwidth) +} + +func TestDecodeHybridAtBandwidthSampleRateSkipsSilkResampler(t *testing.T) { + decoder, err := NewDecoderWithOutput(24000, 1) + assert.NoError(t, err) + out := make([]float32, 240) + + sampleCount, err := decoder.DecodeToFloat32([]byte{byte(12<<3) | byte(frameCodeOneFrame)}, out) + + assert.NoError(t, err) + assert.Equal(t, 240, sampleCount) + assert.Zero(t, decoder.silkResamplerBandwidth) +} + func TestDecodeToInt16(t *testing.T) { decoder, err := NewDecoderWithOutput(8000, 1) assert.NoError(t, err) @@ -182,7 +206,7 @@ func TestDecodeSilkFrameDurations(t *testing.T) { } { t.Run(test.name, func(t *testing.T) { decoder := NewDecoder() - _, _, _, _, err := decoder.decode([]byte{byte(test.configuration<<3) | byte(frameCodeOneFrame)}, nil) + _, _, _, _, _, err := decoder.decode([]byte{byte(test.configuration<<3) | byte(frameCodeOneFrame)}, nil) assert.NoError(t, err) assert.Len(t, decoder.silkBuffer, test.sampleCount) }) @@ -218,7 +242,7 @@ func TestDecodedSampleRate(t *testing.T) { func TestDecodeCeltOnly(t *testing.T) { decoder := NewDecoder() - bandwidth, isStereo, sampleCount, _, err := decoder.decode([]byte{byte(16<<3) | byte(frameCodeOneFrame)}, nil) + bandwidth, _, isStereo, sampleCount, _, err := decoder.decode([]byte{byte(16<<3) | byte(frameCodeOneFrame)}, nil) assert.NoError(t, err) assert.Equal(t, BandwidthNarrowband, bandwidth) @@ -231,7 +255,7 @@ func TestDecodeCeltOnly(t *testing.T) { func TestDecodeHybrid(t *testing.T) { decoder := NewDecoder() - bandwidth, isStereo, sampleCount, _, err := decoder.decode([]byte{byte(12<<3) | byte(frameCodeOneFrame)}, nil) + bandwidth, _, isStereo, sampleCount, _, err := decoder.decode([]byte{byte(12<<3) | byte(frameCodeOneFrame)}, nil) assert.NoError(t, err) assert.Equal(t, BandwidthSuperwideband, bandwidth) @@ -341,28 +365,28 @@ func TestAddHybridSilkMapsChannels(t *testing.T) { name string streamChannelCount int outputChannelCount int - silk48 []float32 + silkPCM []float32 expected []float32 }{ { name: "mono", streamChannelCount: 1, outputChannelCount: 1, - silk48: []float32{0.25}, + silkPCM: []float32{0.25}, expected: []float32{0.25}, }, { name: "mono to stereo", streamChannelCount: 1, outputChannelCount: 2, - silk48: []float32{0.25}, + silkPCM: []float32{0.25}, expected: []float32{0.25, 0.25}, }, { name: "stereo to mono", streamChannelCount: 2, outputChannelCount: 1, - silk48: []float32{0.25, 0.5}, + silkPCM: []float32{0.25, 0.5}, expected: []float32{0.375}, }, } { @@ -370,7 +394,7 @@ func TestAddHybridSilkMapsChannels(t *testing.T) { decoder := NewDecoder() out := make([]float32, len(test.expected)) - decoder.addHybridSilk(out, test.silk48, test.streamChannelCount, test.outputChannelCount, 1) + decoder.addHybridSilk(out, test.silkPCM, test.streamChannelCount, test.outputChannelCount, 1) assert.Equal(t, test.expected, out) }) @@ -381,7 +405,7 @@ func TestDecodeSilkFramesAddsHybridTransitionAudio(t *testing.T) { decoder := NewDecoder() decoder.previousMode = configurationModeHybrid - bandwidth, isStereo, sampleCount, decodedChannelCount, err := decoder.decodeSilkFrames( + bandwidth, _, isStereo, sampleCount, decodedChannelCount, err := decoder.decodeSilkFrames( Configuration(8), tableOfContentsHeader(byte(8<<3)|byte(frameCodeOneFrame)), [][]byte{nil}, diff --git a/internal/celt/decoder.go b/internal/celt/decoder.go index 7b53129..11572e7 100644 --- a/internal/celt/decoder.go +++ b/internal/celt/decoder.go @@ -68,7 +68,23 @@ func (d *Decoder) Decode( startBand int, endBand int, ) error { - return d.decode(in, out, isStereo, outputChannelCount, frameSampleCount, startBand, endBand, nil) + return d.DecodeToSampleRate(in, out, isStereo, outputChannelCount, frameSampleCount, startBand, endBand, sampleRate) +} + +// DecodeToSampleRate decodes one CELT frame into interleaved float PCM at the +// requested Opus API sample rate. RFC 6716 keeps the MDCT mode at 48 kHz and +// decimates during CELT deemphasis for lower output rates. +func (d *Decoder) DecodeToSampleRate( + in []byte, + out []float32, + isStereo bool, + outputChannelCount int, + frameSampleCount int, + startBand int, + endBand int, + outputSampleRate int, +) error { + return d.decode(in, out, isStereo, outputChannelCount, frameSampleCount, startBand, endBand, outputSampleRate, nil) } // DecodeWithRange decodes one CELT frame using an Opus range decoder shared @@ -83,7 +99,43 @@ func (d *Decoder) DecodeWithRange( endBand int, rangeDecoder *rangecoding.Decoder, ) error { - return d.decode(in, out, isStereo, outputChannelCount, frameSampleCount, startBand, endBand, rangeDecoder) + return d.DecodeWithRangeToSampleRate( + in, + out, + isStereo, + outputChannelCount, + frameSampleCount, + startBand, + endBand, + sampleRate, + rangeDecoder, + ) +} + +// DecodeWithRangeToSampleRate decodes one CELT frame with a shared range coder +// and emits PCM at the requested Opus API sample rate. +func (d *Decoder) DecodeWithRangeToSampleRate( + in []byte, + out []float32, + isStereo bool, + outputChannelCount int, + frameSampleCount int, + startBand int, + endBand int, + outputSampleRate int, + rangeDecoder *rangecoding.Decoder, +) error { + return d.decode( + in, + out, + isStereo, + outputChannelCount, + frameSampleCount, + startBand, + endBand, + outputSampleRate, + rangeDecoder, + ) } func (d *Decoder) decode( @@ -94,6 +146,7 @@ func (d *Decoder) decode( frameSampleCount int, startBand int, endBand int, + outputSampleRate int, rangeDecoder *rangecoding.Decoder, ) error { channelCount := 1 @@ -103,7 +156,11 @@ func (d *Decoder) decode( if outputChannelCount != 1 && outputChannelCount != 2 { return errInvalidChannelCount } - if len(out) < frameSampleCount*outputChannelCount { + outputFrameSampleCount, err := frameSampleCountAtRate(frameSampleCount, outputSampleRate) + if err != nil { + return err + } + if len(out) < outputFrameSampleCount*outputChannelCount { return errInvalidFrameSize } @@ -113,6 +170,7 @@ func (d *Decoder) decode( endBand: endBand, channelCount: channelCount, outputChannelCount: outputChannelCount, + outputSampleRate: outputSampleRate, } // The reference decoder routes empty and one-byte CELT frames to PLC before // trying to parse side information. @@ -121,7 +179,7 @@ func (d *Decoder) decode( if validateErr != nil { return validateErr } - d.decodeLostFrame(&lostInfo, out[:frameSampleCount*outputChannelCount]) + d.decodeLostFrame(&lostInfo, out[:outputFrameSampleCount*outputChannelCount]) return nil } diff --git a/internal/celt/frame.go b/internal/celt/frame.go index 5920adb..e606d64 100644 --- a/internal/celt/frame.go +++ b/internal/celt/frame.go @@ -30,6 +30,7 @@ type frameConfig struct { endBand int channelCount int outputChannelCount int + outputSampleRate int } type frameSideInfo struct { @@ -39,6 +40,7 @@ type frameSideInfo struct { endBand int channelCount int outputChannelCount int + outputSampleRate int silence bool postFilter postFilter transient bool @@ -114,22 +116,20 @@ func (d *Decoder) prepareCoarseEnergyHistory(info *frameSideInfo) { } func (d *Decoder) validateFrameConfig(cfg frameConfig) (frameSideInfo, error) { + cfg.outputSampleRate = normalizeOutputSampleRate(cfg.outputSampleRate) // RFC 6716 Section 4.3.3 defines LM as log2(frame_size/120). lm, err := d.Mode().LMForFrameSampleCount(cfg.frameSampleCount) if err != nil { return frameSideInfo{}, err } - if cfg.startBand < 0 || cfg.startBand >= d.Mode().BandCount() { - return frameSideInfo{}, errInvalidBand - } - if cfg.endBand <= cfg.startBand || cfg.endBand > d.Mode().BandCount() { - return frameSideInfo{}, errInvalidBand + if _, err = frameSampleCountAtRate(cfg.frameSampleCount, cfg.outputSampleRate); err != nil { + return frameSideInfo{}, err } - if cfg.channelCount != 1 && cfg.channelCount != 2 { - return frameSideInfo{}, errInvalidChannelCount + if err = d.validateBandRange(cfg.startBand, cfg.endBand); err != nil { + return frameSideInfo{}, err } - if cfg.outputChannelCount != 1 && cfg.outputChannelCount != 2 { - return frameSideInfo{}, errInvalidChannelCount + if err = validateChannelCounts(cfg.channelCount, cfg.outputChannelCount); err != nil { + return frameSideInfo{}, err } return frameSideInfo{ @@ -138,11 +138,61 @@ func (d *Decoder) validateFrameConfig(cfg frameConfig) (frameSideInfo, error) { endBand: cfg.endBand, channelCount: cfg.channelCount, outputChannelCount: cfg.outputChannelCount, + outputSampleRate: cfg.outputSampleRate, spread: defaultSpreadDecision, allocationTrim: defaultAllocationTrim, }, nil } +func normalizeOutputSampleRate(outputSampleRate int) int { + if outputSampleRate == 0 { + return sampleRate + } + + return outputSampleRate +} + +func (d *Decoder) validateBandRange(startBand int, endBand int) error { + if startBand < 0 || startBand >= d.Mode().BandCount() { + return errInvalidBand + } + if endBand <= startBand || endBand > d.Mode().BandCount() { + return errInvalidBand + } + + return nil +} + +func validateChannelCounts(channelCount int, outputChannelCount int) error { + if channelCount != 1 && channelCount != 2 { + return errInvalidChannelCount + } + if outputChannelCount != 1 && outputChannelCount != 2 { + return errInvalidChannelCount + } + + return nil +} + +// frameSampleCountAtRate validates an Opus API output rate and maps a 48 kHz +// CELT-domain frame size into the emitted PCM length. +func frameSampleCountAtRate(frameSampleCount int, outputSampleRate int) (int, error) { + switch outputSampleRate { + case 8000, 12000, 16000, 24000, sampleRate: + default: + return 0, errInvalidSampleRate + } + if sampleRate%outputSampleRate != 0 { + return 0, errInvalidSampleRate + } + downsample := sampleRate / outputSampleRate + if frameSampleCount%downsample != 0 { + return 0, errInvalidFrameSize + } + + return frameSampleCount / downsample, nil +} + func (d *Decoder) decodeSilenceFlag(info *frameSideInfo) { tell := d.rangeDecoder.Tell() switch { diff --git a/internal/celt/frame_test.go b/internal/celt/frame_test.go index 1873737..617fea6 100644 --- a/internal/celt/frame_test.go +++ b/internal/celt/frame_test.go @@ -42,6 +42,18 @@ func TestDecodeFrameSideInfoValidatesConfig(t *testing.T) { assert.ErrorIs(t, err, errInvalidChannelCount) } +func TestFrameSampleCountAtRate(t *testing.T) { + samples, err := frameSampleCountAtRate(shortBlockSampleCount, 8000) + require.NoError(t, err) + assert.Equal(t, 20, samples) + + _, err = frameSampleCountAtRate(shortBlockSampleCount, 44100) + assert.ErrorIs(t, err, errInvalidSampleRate) + + _, err = frameSampleCountAtRate(shortBlockSampleCount+1, 8000) + assert.ErrorIs(t, err, errInvalidFrameSize) +} + func TestDecodeFrameSideInfoSilence(t *testing.T) { decoder := NewDecoder() diff --git a/internal/celt/synthesis.go b/internal/celt/synthesis.go index c8a27c4..c453bf5 100644 --- a/internal/celt/synthesis.go +++ b/internal/celt/synthesis.go @@ -63,12 +63,18 @@ var celtWindow120 = [shortBlockSampleCount]float32{ //nolint:gochecknoglobals } // SmoothFade applies the RFC 6716 CELT transition window over one 2.5 ms -// overlap. The decoder mixes in 48 kHz CELT time, so the reference window -// increment is always one sample here. +// overlap at the internal 48 kHz CELT rate. func SmoothFade(in1, in2, out []float32, overlap int, channels int) { + SmoothFadeWithSampleRate(in1, in2, out, overlap, channels, sampleRate) +} + +// SmoothFadeWithSampleRate applies the CELT transition window at an Opus API +// output rate. RFC 6716 indexes the 48 kHz window by 48000/Fs for lower rates. +func SmoothFadeWithSampleRate(in1, in2, out []float32, overlap int, channels int, outputSampleRate int) { + inc := sampleRate / outputSampleRate for channel := range channels { for i := range overlap { - w := celtWindow120[i] * celtWindow120[i] + w := celtWindow120[i*inc] * celtWindow120[i*inc] index := i*channels + channel out[index] = w*in2[index] + (1-w)*in1[index] } @@ -100,10 +106,12 @@ func (d *Decoder) denormaliseAndSynthesize( frameSampleCount := len(x) freqX := make([]float32, frameSampleCount) denormaliseBands(info, x, freqX, bandEnergy[0]) + limitOutputBandwidth(info, freqX) var freqY []float32 if info.channelCount == 2 { freqY = make([]float32, frameSampleCount) denormaliseBands(info, y, freqY, bandEnergy[1]) + limitOutputBandwidth(info, freqY) } if info.outputChannelCount == 2 && info.channelCount == 1 { freqY = make([]float32, frameSampleCount) @@ -120,14 +128,14 @@ func (d *Decoder) denormaliseAndSynthesize( d.applyPostfilter(info, timeX, 0) if info.outputChannelCount == 1 { d.updatePostfilterState(info) - d.deemphasisAndInterleave(timeX, nil, out, frameSampleCount, 1) + d.deemphasisAndInterleave(timeX, nil, out, frameSampleCount, 1, info.outputSampleRate) return } timeY := d.inverseTransformChannel(freqY, 1, info) d.applyPostfilter(info, timeY, 1) d.updatePostfilterState(info) - d.deemphasisAndInterleave(timeX, timeY, out, frameSampleCount, 2) + d.deemphasisAndInterleave(timeX, timeY, out, frameSampleCount, 2, info.outputSampleRate) } // antiCollapse implements RFC 6716 Section 4.3.5 by injecting low-energy @@ -343,6 +351,22 @@ func denormaliseBands(info *frameSideInfo, x []float32, freq []float32, bandEner } } +// limitOutputBandwidth zeros bins above the requested API output rate before +// synthesis, matching RFC 6716's lower-rate CELT output path. +func limitOutputBandwidth(info *frameSideInfo, freq []float32) { + outputSampleRate := info.outputSampleRate + if outputSampleRate == 0 { + outputSampleRate = sampleRate + } + if outputSampleRate == sampleRate { + return + } + bound := len(freq) * outputSampleRate / sampleRate + for i := bound; i < len(freq); i++ { + freq[i] = 0 + } +} + // inverseTransformChannel performs the RFC 6716 Section 4.3.7 IMDCT path for // one channel and carries the weighted overlap-add tail into the next frame. func (d *Decoder) inverseTransformChannel(freq []float32, channel int, info *frameSideInfo) []float32 { @@ -480,16 +504,37 @@ func celtWindow(i int) float32 { // deemphasisAndInterleave applies the decoder-side pre-emphasis inversion after // RFC 6716 synthesis and writes interleaved PCM samples for the caller. -func (d *Decoder) deemphasisAndInterleave(x []float32, y []float32, out []float32, frameSampleCount int, channelCount int) { +func (d *Decoder) deemphasisAndInterleave( + timeX []float32, + timeY []float32, + out []float32, + frameSampleCount int, + channelCount int, + outputSampleRate int, +) { + if outputSampleRate == 0 { + outputSampleRate = sampleRate + } + downsample := sampleRate / outputSampleRate + outputSample := 0 for sample := range frameSampleCount { - left := x[sample] + d.preemphasisMem[0] + left := timeX[sample] + d.preemphasisMem[0] d.preemphasisMem[0] = 0.85000610 * left - out[sample*channelCount] = left / 32768 + if sample%downsample != 0 { + if channelCount == 2 { + right := timeY[sample] + d.preemphasisMem[1] + d.preemphasisMem[1] = 0.85000610 * right + } + + continue + } + out[outputSample*channelCount] = left / 32768 if channelCount == 2 { - right := y[sample] + d.preemphasisMem[1] + right := timeY[sample] + d.preemphasisMem[1] d.preemphasisMem[1] = 0.85000610 * right - out[sample*channelCount+1] = right / 32768 + out[outputSample*channelCount+1] = right / 32768 } + outputSample++ } } diff --git a/internal/celt/synthesis_test.go b/internal/celt/synthesis_test.go index 1058f40..dd2377f 100644 --- a/internal/celt/synthesis_test.go +++ b/internal/celt/synthesis_test.go @@ -53,6 +53,15 @@ func TestLog2AmpAndDenormaliseBands(t *testing.T) { assert.Equal(t, float32(-1), minFloat32(2, -1)) } +func TestLimitOutputBandwidth(t *testing.T) { + freq := []float32{1, 2, 3, 4, 5, 6} + limitOutputBandwidth(&frameSideInfo{outputSampleRate: sampleRate}, freq) + assert.Equal(t, []float32{1, 2, 3, 4, 5, 6}, freq) + + limitOutputBandwidth(&frameSideInfo{outputSampleRate: 16000}, freq) + assert.Equal(t, []float32{1, 2, 0, 0, 0, 0}, freq) +} + func TestDenormaliseAndSynthesizeLayouts(t *testing.T) { tests := []struct { name string @@ -242,7 +251,7 @@ func TestInverseMDCTAndDeemphasisHelpers(t *testing.T) { decoder := NewDecoder() out := make([]float32, 4) - decoder.deemphasisAndInterleave([]float32{32768, 0}, []float32{16384, 0}, out, 2, 2) + decoder.deemphasisAndInterleave([]float32{32768, 0}, []float32{16384, 0}, out, 2, 2, sampleRate) assert.Equal(t, float32(1), out[0]) assert.Equal(t, float32(0.5), out[1]) assert.NotZero(t, decoder.preemphasisMem[0]) diff --git a/internal/silk/decoder.go b/internal/silk/decoder.go index c9349dc..f1f9842 100644 --- a/internal/silk/decoder.go +++ b/internal/silk/decoder.go @@ -16,6 +16,10 @@ type Decoder struct { rangeDecoder rangecoding.Decoder sideDecoder *Decoder + // SILK resets its per-channel prediction state whenever the internal + // decoder rate changes between NB, MB, and WB. + previousBandwidth Bandwidth + // Have we decoded a frame yet? haveDecoded bool @@ -63,6 +67,16 @@ func newChannelDecoder() *Decoder { } } +func (d *Decoder) resetPredictionState() { + d.haveDecoded = false + d.isPreviousFrameVoiced = false + d.previousLag = 100 + d.previousLogGain = 10 + d.previousFrameLPCValues = nil + clear(d.finalOutValues) + d.n0Q15 = nil +} + // RFC 6716 Sections 4.2.7.4, 4.2.7.5.5, and 4.2.7.6.1 require the side // channel to restart gain, LSF, and pitch prediction after an uncoded frame. func (d *Decoder) resetSideDecoderPrediction() { @@ -70,13 +84,18 @@ func (d *Decoder) resetSideDecoderPrediction() { d.sideDecoder = newChannelDecoder() } - d.sideDecoder.haveDecoded = false - d.sideDecoder.isPreviousFrameVoiced = false - d.sideDecoder.previousLag = 100 - d.sideDecoder.previousLogGain = 10 - d.sideDecoder.previousFrameLPCValues = nil - clear(d.sideDecoder.finalOutValues) - d.sideDecoder.n0Q15 = nil + d.sideDecoder.resetPredictionState() +} + +// silk_decoder_set_fs() in the RFC 6716 reference implementation resets the +// predictor history whenever the internal SILK rate changes. The normative +// predictor dependencies are described in Sections 4.2.7.4, 4.2.7.5.5, and +// 4.2.7.6.1, so carrying them across NB/MB/WB switches changes later frames. +func (d *Decoder) resetPredictionForBandwidthChange(bandwidth Bandwidth) { + if d.previousBandwidth != 0 && d.previousBandwidth != bandwidth { + d.resetPredictionState() + } + d.previousBandwidth = bandwidth } // The LP layer begins with two to eight header bits These consist of one @@ -94,6 +113,35 @@ func (d *Decoder) decodeHeaderBits(frameCount int) (voiceActivityDetected []bool return } +// decodeLowBitrateRedundancyFlags expands RFC 6716 Section 4.2.4's global +// LBRR-present bit into one flag per SILK frame. +func (d *Decoder) decodeLowBitrateRedundancyFlags(frameCount int, present bool) []bool { + flags := make([]bool, frameCount) + if !present { + return flags + } + + switch frameCount { + case 1: + flags[0] = true + case 2: + d.decodeLowBitrateRedundancyFlagSymbol(flags, icdfLowBitrateRedundancyFlags40Ms) + case 3: + d.decodeLowBitrateRedundancyFlagSymbol(flags, icdfLowBitrateRedundancyFlags60Ms) + } + + return flags +} + +// decodeLowBitrateRedundancyFlagSymbol decodes the Table 4 bitmap symbol used +// for 40 ms and 60 ms SILK packets. +func (d *Decoder) decodeLowBitrateRedundancyFlagSymbol(flags []bool, icdf []uint) { + symbol := d.rangeDecoder.DecodeSymbolWithICDF(icdf) + for i := range flags { + flags[i] = symbol&(1<