diff --git a/internal/rangecoding/encoder.go b/internal/rangecoding/encoder.go new file mode 100644 index 0000000..4e1b735 --- /dev/null +++ b/internal/rangecoding/encoder.go @@ -0,0 +1,567 @@ +// SPDX-FileCopyrightText: 2026 The Pion community +// SPDX-License-Identifier: MIT + +package rangecoding + +import ( + "math/bits" +) + +// Range coder constants for the 32-bit encoder, matching the decoder +// parameters in Section 4.1 of RFC 6716. +// +// - codeBits = 32 — precision of the range coder integer arithmetic. +// - codeTop = 2**31 — the initial range size and the sentinel that +// separates the carry buffer from the data bits. +// - codeShift = 23 — right-shift to extract the top 9 bits (8 data + 1 +// carry) from val before normalizing. +// - symBits / symMax — the byte-sized output symbol width used by the +// renormalization loop. +const ( + codeBits = 32 + codeTop = uint32(1) << (codeBits - 1) + codeShift = codeBits - 9 + symBits = 8 + symMax = (1 << symBits) - 1 +) + +// Encoder implements the range encoder defined in RFC 6716 Section 5.1. +// +// The range coder acts as the bit-packer for Opus. It is used in three +// different ways: to encode +// +// - Entropy-coded symbols with a fixed probability model using +// ec_encode() (entenc.c), +// +// - Integers from 0 to (2**M - 1) using ec_enc_uint() or ec_enc_bits() +// (entenc.c), +// +// - Integers from 0 to (ft - 1) (where ft is not a power of two) using +// ec_enc_uint() (entenc.c). +// +// The range encoder maintains an internal state vector composed of the +// four-tuple (val, rng, rem, ext) representing the low end of the +// current range, the size of the current range, a single buffered +// output byte, and a count of additional carry-propagating output +// bytes. Both val and rng are 32-bit unsigned integer values, rem is a +// byte value less than 255 or the special value -1, and ext is an +// unsigned integer with at least 11 bits. This state vector is +// initialized at the start of each frame to the value +// (0, 2**31, -1, 0). After encoding a sequence of symbols, the value +// of rng in the encoder should exactly match the value of rng in the +// decoder after decoding the same sequence of symbols. This is a +// powerful tool for detecting errors in either an encoder or decoder +// implementation. The value of val, on the other hand, represents +// different things in the encoder and decoder, and is not expected to +// match. +// +// The decoder has no analog for rem and ext. These are used to perform +// carry propagation in the renormalization loop below. Each iteration +// of this loop produces 9 bits of output, consisting of 8 data bits and +// a carry flag. The encoder cannot determine the final value of the +// output bytes until it propagates these carry flags. Therefore, the +// reference implementation buffers a single non-propagating output byte +// (i.e., one less than 255) in rem and keeps a count of additional +// propagating (i.e., 255) output bytes in ext. +// +// Symbols may also be coded as "raw bits" packed directly into the +// bitstream, bypassing the range coder. These are packed backwards +// starting at the end of the frame, as illustrated in Figure 12 of +// RFC 6716. This reduces complexity and makes the stream more resilient +// to bit errors. Raw bits are only used in the CELT layer. +// +// 0 1 2 3 +// 0 1 2 3 4 5 6 7 8 9 0 1 2 3 4 5 6 7 8 9 0 1 2 3 4 5 6 7 8 9 0 1 +// +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ +// | Range coder data (packed MSB to LSB) -> : +// + + +// : : +// + +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ +// : | <- Boundary occurs at an arbitrary bit position : +// +-+-+-+ + +// : <- Raw bits data (packed LSB to MSB) | +// +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ +// +// Legend: +// +// LSB = Least Significant Bit +// MSB = Most Significant Bit +// +// Figure 12: Illustrative Example of Packing Range Coder +// and Raw Bits Data +// +// https://datatracker.ietf.org/doc/html/rfc6716#section-5.1 +type Encoder struct { + buf []byte // range-coded bytes flushed front-to-back (val in RFC 6716) + tail []byte // raw-bits bytes flushed in LSB-first order + + endWindow uint64 // accumulator for raw bits not yet flushed to tail + nendBits uint // number of valid bits in endWindow + + rangeSize uint32 // rng in RFC 6716 — current range size + low uint32 // val in RFC 6716 — low end of the current range + + rem int // buffered pending byte (-1 = empty); rem in RFC 6716 + extBytes int // count of carry-propagating 0xFF bytes; ext in RFC 6716 + + nbitsTotal uint // conservative bit-usage counter for Tell/TellFrac +} + +// Init resets the Encoder state for a new frame. +// +// RFC 6716 Section 5.1 specifies that the encoder state vector +// (val, rng, rem, ext) is initialized at the start of each frame to +// (0, 2**31, -1, 0). nbitsTotal is set to codeBits + 1 so that Tell() +// returns 1 after initialization, matching the decoder's post-Init value +// (the decoder consumes one bit of bootstrap input during Init). +// +// https://datatracker.ietf.org/doc/html/rfc6716#section-5.1 +func (e *Encoder) Init() { + e.buf = e.buf[:0] + e.tail = e.tail[:0] + e.endWindow = 0 + e.nendBits = 0 + e.rangeSize = codeTop + e.low = 0 + e.rem = -1 + e.extBytes = 0 + e.nbitsTotal = codeBits + 1 +} + +// EncodeSymbolWithICDF encodes a symbol using the same inverse cumulative +// distribution table format consumed by Decoder.DecodeSymbolWithICDF. +// +// This implements ec_enc_icdf() (entenc.c), which is mathematically +// equivalent to calling ec_encode() with fl[k] = (1<= len(table) { + return + } + + high := uint32(table[symbol]) //nolint:gosec // G115 + low := uint32(0) + if symbol != 0 { + low = uint32(table[symbol-1]) //nolint:gosec // G115 + } + + e.EncodeCumulative(low, high, total) +} + +// EncodeSymbolLogP encodes a single binary symbol with probability 1/(1<> logp + rangeSize := e.rangeSize - scale + + if symbol != 0 { + e.low += rangeSize + e.rangeSize = scale + } else { + e.rangeSize = rangeSize + } + + e.normalize() +} + +// EncodeCumulative encodes a pre-selected cumulative interval (low, high) +// out of total equally weighted bins. +// +// This is the main encoding function ec_encode() (entenc.c) defined in +// RFC 6716 Section 5.1.1. It encodes symbol k described by the three-tuple +// (fl[k], fh[k], ft) using the same semantics as the decoder's ec_decode(). +// +// If fl[k] (low) is greater than zero: +// +// val = val + rng - (rng / ft) * (ft - fl) +// rng = (rng / ft) * (fh - fl) +// +// Otherwise val is unchanged and: +// +// rng = rng - (rng / ft) * (ft - fh) +// +// https://datatracker.ietf.org/doc/html/rfc6716#section-5.1.1 +func (e *Encoder) EncodeCumulative(low, high, total uint32) { + if total == 0 || low >= high || high > total { + return + } + + scale := e.rangeSize / total + if low != 0 { + e.low += e.rangeSize - scale*(total-low) + e.rangeSize = scale * (high - low) + } else { + e.rangeSize -= scale * (total - high) + } + + e.normalize() +} + +// EncodeUniform encodes one of ft equiprobable symbols in the range +// [0, ft), implementing ec_enc_uint() (entenc.c). +// +// RFC 6716 Section 5.1.4 splits the value into a range-coded prefix of up +// to 8 high bits and, if ft requires more than 8 bits, a raw-bit suffix: +// +// If ftb = ilog(ft - 1) <= 8, encode t directly via ec_encode(). +// If ftb > 8, encode t>>(ftb-8) via ec_encode() and the remaining +// (ftb - 8) bits of t via ec_enc_bits(). +// +// https://datatracker.ietf.org/doc/html/rfc6716#section-5.1.4 +func (e *Encoder) EncodeUniform(total, symbol uint32) { + if total <= 1 { + return + } + + if symbol >= total { + symbol = total - 1 + } + + limit := total - 1 + bitCount := bits.Len32(limit) + if bitCount <= maxUniformRangeCoderBits { + e.EncodeCumulative(symbol, symbol+1, total) + + return + } + + rawBitCount := bitCount - maxUniformRangeCoderBits + rangeTotal := (limit >> rawBitCount) + 1 + prefix := symbol >> rawBitCount + e.EncodeCumulative(prefix, prefix+1, rangeTotal) + e.EncodeRawBits(uint(rawBitCount), symbol&bitMask(uint(rawBitCount))) +} + +// EncodeLaplace encodes a Laplace-distributed integer value using the +// same probability model as Decoder.DecodeLaplace. +// +// RFC 6716 Section 4.3.2.1 describes coarse energy deltas as +// Laplace-distributed prediction errors. The distribution is parameterized +// by fs0 (the frequency of the zero symbol, in units of 1/32768) and decay +// (the geometric decay rate of adjacent-magnitude frequencies, Q15). +// +// https://datatracker.ietf.org/doc/html/rfc6716#section-4.3.2.1 +func (e *Encoder) EncodeLaplace(fs0, decay uint32, value int) { + low, high := laplaceInterval(fs0, decay, value) + e.EncodeCumulative(low, high, laplaceTotal) +} + +// EncodeRawBits appends n bits of value to the raw-bits region at the end of +// the frame, packed in LSB-first order. +// +// RFC 6716 Section 5.1.3 specifies that raw bits used by the CELT layer are +// packed at the end of the buffer using ec_enc_bits() (entenc.c). Because the +// raw bits may continue into the last byte output by the range coder if there +// is room in the low-order bits, Done() merges the two regions into a single +// byte when they meet. +// +// https://datatracker.ietf.org/doc/html/rfc6716#section-5.1.3 +func (e *Encoder) EncodeRawBits(n uint, value uint32) { + if n == 0 { + return + } + if n < 32 { + value &= bitMask(n) + } + e.endWindow |= uint64(value) << e.nendBits + e.nendBits += n + e.nbitsTotal += n + for e.nendBits >= symBits { + e.tail = append(e.tail, byte(e.endWindow&symMax)) + e.endWindow >>= symBits + e.nendBits -= symBits + } +} + +// Tell returns a conservative upper bound, in whole bits, of the number of +// bits encoded into the current frame so far. +// +// This implements ec_tell() (entcode.h) from RFC 6716 Section 5.1.6. The +// bit allocation routines in Opus use this value to track budget consumption +// and prevent the range coder from overflowing the output buffer. After +// encoding the same symbols, the encoder and decoder must produce identical +// Tell() values. +// +// https://datatracker.ietf.org/doc/html/rfc6716#section-5.1.6 +func (e *Encoder) Tell() uint { + lg := uint(bits.Len32(e.rangeSize)) //nolint:gosec // G115: bits.Len32 returns 0..32. + if lg == 0 { + return e.nbitsTotal + } + + if e.nbitsTotal <= lg { + return 0 + } + + return e.nbitsTotal - lg +} + +// TellFrac returns a conservative upper bound in 1/8-bit units. +// +// This implements ec_tell_frac() (entcode.c) from RFC 6716 Section 5.1.6. +// It refines the Tell() estimate by squaring down the fractional part of the +// range size three times to obtain three additional sub-bit fractions. The +// encoder and decoder must produce identical TellFrac() values after encoding +// and decoding the same symbols. +// +// https://datatracker.ietf.org/doc/html/rfc6716#section-5.1.6 +func (e *Encoder) TellFrac() uint { + lg := uint(bits.Len32(e.rangeSize)) //nolint:gosec // G115: bits.Len32 returns 0..32. + if lg == 0 { + return e.nbitsTotal * 8 + } + if lg < 24 { + return e.Tell() * 8 + } + + rangeQ15 := uint64(e.rangeSize >> (lg - 16)) + for range 3 { + rangeQ15 = (rangeQ15 * rangeQ15) >> 15 + bit := rangeQ15 >> 16 + lg = 2*lg + uint(bit) + if bit != 0 { + rangeQ15 >>= 1 + } + } + + total := e.nbitsTotal * 8 + if total <= lg { + return 0 + } + + return total - lg +} + +// FinalRange exposes the current range coder range state for tests. +// +// RFC 6716 Section 5.1 states that after encoding a sequence of symbols the +// value of rng in the encoder should exactly match the value of rng in the +// decoder after decoding the same sequence of symbols. This is a powerful +// tool for detecting errors in either an encoder or decoder implementation. +func (e *Encoder) FinalRange() uint32 { + return e.rangeSize +} + +// Done flushes the range coder and raw bits into a single output frame, +// implementing ec_enc_done() (entenc.c). +// +// RFC 6716 Section 5.1.5 describes the finalization procedure: +// +// 1. Find the unsigned integer end in [val, val+rng) with the largest +// number of trailing zero bits b such that (end + (1<= 0 || e.extBytes > 0 { + e.carryOut(0) + } + + freeBitsInLastRangeByte := uint(0) + if remainingBits < 0 { + freeBitsInLastRangeByte = uint(-remainingBits) //nolint:gosec // G115 + } + + out := make([]byte, len(e.buf)+len(e.tail)+boolToInt(e.shouldWritePartialToNewByte(freeBitsInLastRangeByte))) + copy(out, e.buf) + + for index, value := range e.tail { + out[len(out)-1-index] = value + } + + if e.nendBits > 0 { + partial := byte(e.endWindow & uint64(bitMask(e.nendBits))) //nolint:gosec // G115: masked to at most 8 bits. + if e.shouldWritePartialToNewByte(freeBitsInLastRangeByte) { + out[len(e.buf)] = partial + } else { + out[len(e.buf)-1] |= partial + } + } + + return out +} + +// flushRangeCoder finalizes the range-coded portion of the frame by finding +// the integer end in [val, val+rng) with the most trailing zero bits and +// flushing its remaining bytes through the carry buffer. +// +// RFC 6716 Section 5.1.5 specifies that end is chosen so that +// (end + (1<> remainingBits + end := (e.low + mask) &^ mask + if (end | mask) >= e.low+e.rangeSize { + remainingBits++ + mask >>= 1 + end = (e.low + mask) &^ mask + } + + for remainingBits > 0 { + e.carryOut(int(end >> codeShift)) + end = (end << symBits) & (codeTop - 1) + remainingBits -= symBits + } + + return remainingBits +} + +// shouldWritePartialToNewByte reports whether the leftover raw-bits nibble +// must occupy its own byte rather than being ORed into the last range-coder +// byte. This is the case when there are no range-coder bytes yet, or when +// the partial nibble is wider than the free low-order bits of the last +// range-coder byte. +func (e *Encoder) shouldWritePartialToNewByte(freeBitsInLastRangeByte uint) bool { + return e.nendBits > 0 && (len(e.buf) == 0 || e.nendBits > freeBitsInLastRangeByte) +} + +// normalize implements ec_enc_normalize() (entenc.c), the renormalization +// step that maintains the invariant rng > 2**23 after each symbol is encoded. +// +// RFC 6716 Section 5.1.1.1 specifies: repeat until rng > 2**23. First, +// send the top 9 bits of val, (val>>23), to the carry buffer. Then set +// +// val = (val<<8) & 0x7FFFFFFF +// rng = rng<<8 +// +// https://datatracker.ietf.org/doc/html/rfc6716#section-5.1.1 +func (e *Encoder) normalize() { + for e.rangeSize <= minRangeSize { + e.carryOut(int(e.low >> codeShift)) + e.low = (e.low << symBits) & (codeTop - 1) + e.rangeSize <<= symBits + e.nbitsTotal += symBits + } +} + +// carryOut implements ec_enc_carry_out() (entenc.c), which performs carry +// propagation and output buffering for the 9-bit value produced by each +// iteration of the renormalization loop. +// +// RFC 6716 Section 5.1.1.2: the input value c consists of 8 data bits and an +// additional carry bit. +// +// - If c == 255: ext is incremented and no other state update is performed. +// +// - Otherwise let b = c >> 8 be the carry bit. Then: +// +// o If rem contains a value other than -1, output the byte (rem + b). +// o If ext is non-zero, output ext bytes of value (255 if b == 0, else 0), +// then set ext to 0. +// o Set rem = c & 255. +// +// https://datatracker.ietf.org/doc/html/rfc6716#section-5.1.1 +func (e *Encoder) carryOut(value int) { + if value != symMax { + carry := value >> symBits + if e.rem >= 0 { + e.buf = append(e.buf, byte(e.rem+carry)) //nolint:gosec // G115: carry propagation is bounded to one byte. + } + if e.extBytes > 0 { + flush := byte((symMax + carry) & symMax) + for range e.extBytes { + e.buf = append(e.buf, flush) + } + e.extBytes = 0 + } + e.rem = value & symMax + + return + } + e.extBytes++ +} + +// bitMask returns a uint32 with the n lowest bits set. +func bitMask(n uint) uint32 { + if n >= 32 { + return ^uint32(0) + } + + return (uint32(1) << n) - 1 +} + +// laplaceInterval computes the cumulative [low, high) interval for the given +// value in the Laplace distribution defined by RFC 6716 Section 4.3.2.1. +// +// The distribution is parameterized by fs0 (the cumulative frequency of the +// zero symbol, in units of 1/32768) and decay (the geometric decay rate of +// adjacent-magnitude frequencies, in Q15 fixed-point). Positive and negative +// values of equal magnitude share a frequency but occupy disjoint halves of the +// cumulative axis: the positive half comes first, the negative half follows +// immediately. laplaceFirstDecayFrequency() computes the per-step decay. +func laplaceInterval(fs0 uint32, decay uint32, value int) (uint32, uint32) { + if value == 0 { + return 0, min(fs0, uint32(laplaceTotal)) + } + + magnitude := value + if magnitude < 0 { + magnitude = -magnitude + } + + low := fs0 + frequency := laplaceFirstDecayFrequency(fs0, decay) + laplaceMinProbability + currentMagnitude := 1 + for currentMagnitude < magnitude && frequency > laplaceMinProbability { + low += 2 * frequency + frequency = ((2*frequency - 2*laplaceMinProbability) * decay) >> 15 + frequency += laplaceMinProbability + currentMagnitude++ + } + if currentMagnitude < magnitude { + deltaCount := uint32(magnitude - currentMagnitude) //nolint:gosec // G115 + low += 2 * deltaCount * laplaceMinProbability + frequency = laplaceMinProbability + } + if value < 0 { + return min(low, uint32(laplaceTotal)), min(low+frequency, uint32(laplaceTotal)) + } + + low += frequency + + return min(low, uint32(laplaceTotal)), min(low+frequency, uint32(laplaceTotal)) +} + +// boolToInt converts a bool to 0 or 1. +func boolToInt(value bool) int { + if value { + return 1 + } + + return 0 +} diff --git a/internal/rangecoding/encoder_test.go b/internal/rangecoding/encoder_test.go new file mode 100644 index 0000000..439a159 --- /dev/null +++ b/internal/rangecoding/encoder_test.go @@ -0,0 +1,360 @@ +// SPDX-FileCopyrightText: 2026 The Pion community +// SPDX-License-Identifier: MIT + +package rangecoding + +import ( + "testing" + + "github.com/stretchr/testify/assert" +) + +//nolint:gochecknoglobals +var testICDFTable = []uint{256, 32, 160, 256} + +func TestEncoderRoundTrip(t *testing.T) { + encoder := &Encoder{} + encoder.Init() + encoder.EncodeSymbolLogP(1, 0) + encoder.EncodeSymbolLogP(3, 1) + encoder.EncodeSymbolWithICDF(testICDFTable, 2) + encoder.EncodeUniform(6, 4) + encoder.EncodeUniform(300, 257) + encoder.EncodeLaplace(72<<7, 127<<6, -2) + encoder.EncodeLaplace(72<<7, 127<<6, 3) + encoder.EncodeRawBits(5, 0x16) + encoder.EncodeRawBits(11, 0x5A3) + + packet := encoder.Done() + decoder := &Decoder{} + decoder.Init(packet) + + assert.Equal(t, uint32(0), decoder.DecodeSymbolLogP(1)) + assert.Equal(t, uint32(1), decoder.DecodeSymbolLogP(3)) + assert.Equal(t, uint32(2), decoder.DecodeSymbolWithICDF(testICDFTable)) + + value, ok := decoder.DecodeUniform(6) + assert.True(t, ok) + assert.Equal(t, uint32(4), value) + + value, ok = decoder.DecodeUniform(300) + assert.True(t, ok) + assert.Equal(t, uint32(257), value) + + assert.Equal(t, -2, decoder.DecodeLaplace(72<<7, 127<<6)) + assert.Equal(t, 3, decoder.DecodeLaplace(72<<7, 127<<6)) + assert.Equal(t, uint32(0x16), decoder.DecodeRawBits(5)) + assert.Equal(t, uint32(0x5A3), decoder.DecodeRawBits(11)) + assert.NotZero(t, decoder.FinalRange()) +} + +func TestEncoderCumulativeRoundTrip(t *testing.T) { + encoder := &Encoder{} + encoder.Init() + encoder.EncodeCumulative(2, 3, 5) + encoder.EncodeRawBits(3, 0x05) + + packet := encoder.Done() + decoder := &Decoder{} + decoder.Init(packet) + + symbol := decoder.DecodeCumulative(5) + decoder.UpdateCumulative(symbol, symbol+1, 5) + + assert.Equal(t, uint32(2), symbol) + assert.Equal(t, uint32(0x05), decoder.DecodeRawBits(3)) +} + +func TestEncoderDoneEmptyFrame(t *testing.T) { + encoder := &Encoder{} + encoder.Init() + + packet := encoder.Done() + + assert.NotNil(t, packet) + decoder := &Decoder{} + assert.NotPanics(t, func() { + decoder.Init(packet) + }) +} + +func TestEncoderFinalRangeMatchesDecoder(t *testing.T) { + encoder := &Encoder{} + encoder.Init() + encoder.EncodeSymbolLogP(3, 1) + encoder.EncodeUniform(6, 4) + encoder.EncodeLaplace(72<<7, 127<<6, 2) + + packet := encoder.Done() + decoder := &Decoder{} + decoder.Init(packet) + + decoder.DecodeSymbolLogP(3) + decoder.DecodeUniform(6) + decoder.DecodeLaplace(72<<7, 127<<6) + + assert.Equal(t, encoder.FinalRange(), decoder.FinalRange()) +} + +func TestEncoderTell(t *testing.T) { + t.Run("reports one bit after initialization", func(t *testing.T) { + encoder := &Encoder{} + encoder.Init() + + assert.Equal(t, uint(1), encoder.Tell()) + assert.Equal(t, uint(8), encoder.TellFrac()) + }) + + t.Run("increases after raw bits are encoded", func(t *testing.T) { + encoder := &Encoder{} + encoder.Init() + tellBefore := encoder.Tell() + + encoder.EncodeRawBits(8, 0xAB) + + assert.Equal(t, tellBefore+8, encoder.Tell()) + }) + + t.Run("matches decoder Tell after the same symbols", func(t *testing.T) { + encoder := &Encoder{} + encoder.Init() + encoder.EncodeSymbolLogP(1, 0) + encoder.EncodeSymbolLogP(3, 1) + encoder.EncodeRawBits(8, 0xFF) + encoderTell := encoder.Tell() + + packet := encoder.Done() + decoder := &Decoder{} + decoder.Init(packet) + decoder.DecodeSymbolLogP(1) + decoder.DecodeSymbolLogP(3) + decoder.DecodeRawBits(8) + + assert.Equal(t, encoderTell, decoder.Tell()) + }) + + t.Run("TellFrac does not underflow on fresh encoder", func(t *testing.T) { + encoder := &Encoder{rangeSize: 1 << 31} + + assert.NotPanics(t, func() { + _ = encoder.TellFrac() + }) + }) +} + +func TestEncoderUniformEdgeCases(t *testing.T) { + t.Run("total of one is a no-op", func(t *testing.T) { + encoder := &Encoder{} + encoder.Init() + tellBefore := encoder.Tell() + + encoder.EncodeUniform(1, 0) + + assert.Equal(t, tellBefore, encoder.Tell()) + }) + + t.Run("symbol clamped to total-1 still decodes", func(t *testing.T) { + encoder := &Encoder{} + encoder.Init() + encoder.EncodeUniform(6, 99) + + packet := encoder.Done() + decoder := &Decoder{} + decoder.Init(packet) + + got, ok := decoder.DecodeUniform(6) + assert.True(t, ok) + assert.Equal(t, uint32(5), got) + }) + + t.Run("round-trips all values in a small alphabet", func(t *testing.T) { + for symbol := range uint32(6) { + encoder := &Encoder{} + encoder.Init() + encoder.EncodeUniform(6, symbol) + + packet := encoder.Done() + decoder := &Decoder{} + decoder.Init(packet) + + got, ok := decoder.DecodeUniform(6) + assert.True(t, ok) + assert.Equal(t, symbol, got, "symbol %d", symbol) + } + }) +} + +func TestEncoderRawBitsEdgeCases(t *testing.T) { + t.Run("zero bits is a no-op", func(t *testing.T) { + encoder := &Encoder{} + encoder.Init() + tellBefore := encoder.Tell() + + encoder.EncodeRawBits(0, 0xFF) + + assert.Equal(t, tellBefore, encoder.Tell()) + }) + + t.Run("rounds trip a full byte", func(t *testing.T) { + encoder := &Encoder{} + encoder.Init() + encoder.EncodeRawBits(8, 0xB2) + + packet := encoder.Done() + decoder := &Decoder{} + decoder.Init(packet) + + assert.Equal(t, uint32(0xB2), decoder.DecodeRawBits(8)) + }) + + t.Run("high bits beyond n are masked", func(t *testing.T) { + encoder := &Encoder{} + encoder.Init() + encoder.EncodeRawBits(4, 0xFF) + + packet := encoder.Done() + decoder := &Decoder{} + decoder.Init(packet) + + assert.Equal(t, uint32(0x0F), decoder.DecodeRawBits(4)) + }) +} + +func TestEncoderLaplaceEdgeCases(t *testing.T) { + zeroFrequency := uint32(72 << 7) + decay := uint32(127 << 6) + + for _, test := range []struct { + name string + value int + }{ + {name: "zero delta", value: 0}, + {name: "first positive delta", value: 1}, + {name: "first negative delta", value: -1}, + {name: "large positive delta", value: 10}, + {name: "large negative delta", value: -10}, + } { + t.Run(test.name, func(t *testing.T) { + encoder := &Encoder{} + encoder.Init() + encoder.EncodeLaplace(zeroFrequency, decay, test.value) + + packet := encoder.Done() + decoder := &Decoder{} + decoder.Init(packet) + + assert.Equal(t, test.value, decoder.DecodeLaplace(zeroFrequency, decay)) + }) + } +} + +func TestEncoderICDFEdgeCases(t *testing.T) { + t.Run("encodes all valid symbols in the table", func(t *testing.T) { + for symbol := range len(testICDFTable) - 1 { + symbol := uint32(symbol) + encoder := &Encoder{} + encoder.Init() + encoder.EncodeSymbolWithICDF(testICDFTable, symbol) + + packet := encoder.Done() + decoder := &Decoder{} + decoder.Init(packet) + + assert.Equal(t, symbol, decoder.DecodeSymbolWithICDF(testICDFTable), "symbol %d", symbol) + } + }) + + t.Run("ignores too-short table", func(t *testing.T) { + encoder := &Encoder{} + encoder.Init() + tellBefore := encoder.Tell() + + encoder.EncodeSymbolWithICDF([]uint{256}, 0) + + assert.Equal(t, tellBefore, encoder.Tell()) + }) +} + +func FuzzEncoderRoundTrip(f *testing.F) { + f.Add([]byte{0, 1, 2, 3, 4, 5, 6, 7}) + f.Add([]byte{255, 128, 64, 32, 16, 8, 4, 2}) + f.Add([]byte{9, 7, 5, 3, 1, 0, 2, 4, 6, 8}) + + f.Fuzz(func(t *testing.T, data []byte) { + encoder := &Encoder{} + encoder.Init() + + ops := make([]fuzzOperation, 0, len(data)) + for index := range data { + switch data[index] % 5 { + case 0: + logp := uint(data[index]%7 + 1) + symbol := uint32(data[index] & 1) + encoder.EncodeSymbolLogP(logp, symbol) + ops = append(ops, fuzzOperation{kind: 0, logp: logp, symbol: symbol}) + case 1: + symbol := uint32(data[index] % 3) + encoder.EncodeSymbolWithICDF(testICDFTable, symbol) + ops = append(ops, fuzzOperation{kind: 1, symbol: symbol, icdfTable: testICDFTable}) + case 2: + total := uint32(data[index]) + 2 + symbol := uint32(data[index]>>1) % total + encoder.EncodeUniform(total, symbol) + ops = append(ops, fuzzOperation{kind: 2, total: total, symbol: symbol}) + case 3: + value := int(data[index]%7) - 3 + encoder.EncodeLaplace(72<<7, 127<<6, value) + ops = append(ops, fuzzOperation{kind: 3, laplace: value}) + default: + rawBits := uint(data[index] % 17) + rawValue := uint32(data[index]) + if index+1 < len(data) { + rawValue = uint32(data[index]) | uint32(data[index+1])<<8 + } + if rawBits < 32 { + rawValue &= bitMask(rawBits) + } + encoder.EncodeRawBits(rawBits, rawValue) + ops = append(ops, fuzzOperation{kind: 4, rawBits: rawBits, rawValue: rawValue}) + } + } + + packet := encoder.Done() + decoder := &Decoder{} + decoder.Init(packet) + + for _, op := range ops { + assertFuzzOperation(t, decoder, op) + } + }) +} + +type fuzzOperation struct { + kind byte + logp uint + symbol uint32 + total uint32 + laplace int + rawBits uint + rawValue uint32 + icdfTable []uint +} + +func assertFuzzOperation(t *testing.T, decoder *Decoder, op fuzzOperation) { + t.Helper() + + switch op.kind { + case 0: + assert.Equal(t, op.symbol, decoder.DecodeSymbolLogP(op.logp)) + case 1: + assert.Equal(t, op.symbol, decoder.DecodeSymbolWithICDF(op.icdfTable)) + case 2: + got, ok := decoder.DecodeUniform(op.total) + assert.True(t, ok) + assert.Equal(t, op.symbol, got) + case 3: + assert.Equal(t, op.laplace, decoder.DecodeLaplace(72<<7, 127<<6)) + case 4: + assert.Equal(t, op.rawValue, decoder.DecodeRawBits(op.rawBits)) + } +} diff --git a/pkg/oggreader/oggreader_test.go b/pkg/oggreader/oggreader_test.go index 2e5e3d7..e95bc7c 100644 --- a/pkg/oggreader/oggreader_test.go +++ b/pkg/oggreader/oggreader_test.go @@ -370,6 +370,7 @@ func buildOpusIDHeader( header = append(header, streamCount, coupledCount) for _, coefficient := range demixingMatrix { packed := make([]byte, 2) + //nolint:gosec // Test data needs the int16 bit pattern encoded verbatim. binary.LittleEndian.PutUint16(packed, uint16(coefficient)) header = append(header, packed...) }