Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
150 changes: 142 additions & 8 deletions internal/celt/frame.go
Original file line number Diff line number Diff line change
Expand Up @@ -6,8 +6,18 @@ package celt
import "fmt"

const (
postFilterPitchBase = 16
postFilterGainStep = 0.09375
postFilterPitchBase = 16
postFilterGainStep = 0.09375
bitResolution = 3
defaultSpreadDecision = 2
defaultAllocationTrim = 5
initialDynamicAllocationLogP = 6
minDynamicAllocationLogP = 2
allocationTrimBitCost = 6
firstTimeFrequencyChangeLogP = 4
firstTransientFrequencyChangeLogP = 2
nextTimeFrequencyChangeLogP = 5
nextTransientFrequencyChangeLogP = 4
)

type frameConfig struct {
Expand All @@ -29,6 +39,11 @@ type frameSideInfo struct {
shortBlockCount int
intraEnergy bool
coarseEnergy [2][maxBands]float32
tfChange [maxBands]int
tfSelect int
spread int
bandBoost [maxBands]int
allocationTrim int
}

type postFilter struct {
Expand All @@ -39,8 +54,8 @@ type postFilter struct {
tapset int
}

// decodeFrameSideInfo consumes the initial CELT symbols through coarse energy
// in the order specified by RFC 6716 Table 56. TF changes, allocation, and PVQ
// decodeFrameSideInfo consumes the initial CELT symbols through the allocation
// header in the order specified by RFC 6716 Table 56. Pulse allocation and PVQ
// residual decoding are intentionally left to the following CELT slices.
func (d *Decoder) decodeFrameSideInfo(data []byte, cfg frameConfig) (frameSideInfo, error) {
info, err := d.validateFrameConfig(cfg)
Expand All @@ -63,6 +78,7 @@ func (d *Decoder) decodeFrameSideInfo(data []byte, cfg frameConfig) (frameSideIn
d.decodeIntraEnergyFlag(&info)
d.prepareCoarseEnergyHistory(&info)
d.decodeCoarseEnergy(&info)
d.decodeAllocationHeader(&info)

return info, nil
}
Expand Down Expand Up @@ -97,10 +113,12 @@ func (d *Decoder) validateFrameConfig(cfg frameConfig) (frameSideInfo, error) {
}

return frameSideInfo{
lm: lm,
startBand: cfg.startBand,
endBand: cfg.endBand,
channelCount: cfg.channelCount,
lm: lm,
startBand: cfg.startBand,
endBand: cfg.endBand,
channelCount: cfg.channelCount,
spread: defaultSpreadDecision,
allocationTrim: defaultAllocationTrim,
}, nil
}

Expand Down Expand Up @@ -226,6 +244,122 @@ func (d *Decoder) decodeCoarseEnergyDelta(info *frameSideInfo, probModel []uint8
}
}

// decodeAllocationHeader decodes, in bitstream order, the time-frequency
// changes, the spread decision, the dynamic allocation band boosts, and the
// allocation trim. The budget that decodeDynamicAllocation returns (1/8-bit
// units) is what gates the trim symbol.
func (d *Decoder) decodeAllocationHeader(info *frameSideInfo) {
	d.decodeTimeFrequencyChanges(info)
	d.decodeSpread(info)

	// Convert the whole-bit frame budget to 1/8-bit units for the boost loop.
	frameBitsEighth := info.totalBits << bitResolution
	remainingEighth := d.decodeDynamicAllocation(info, frameBitsEighth)
	d.decodeAllocationTrim(info, remainingEighth)
}

// decodeTimeFrequencyChanges reads the per-band tf_change flags and the
// optional tf_select bit of RFC 6716 Section 4.3.1, then replaces each coded
// band's raw flag in info.tfChange with the adjustment looked up in
// tfSelectTable (Tables 60-63).
func (d *Decoder) decodeTimeFrequencyChanges(info *frameSideInfo) {
	// The first coded band uses a different probability than the rest.
	firstLogP, nextLogP := uint(firstTimeFrequencyChangeLogP), uint(nextTimeFrequencyChangeLogP)
	if info.transient {
		firstLogP, nextLogP = uint(firstTransientFrequencyChangeLogP), uint(nextTransientFrequencyChangeLogP)
	}

	bitBudget := info.totalBits
	used := d.rangeDecoder.Tell()

	// Hold one bit back for tf_select when it could still be coded after the
	// first flag, so the per-band flags cannot consume it.
	selectReserved := info.lm > 0 && used+firstLogP+1 <= bitBudget
	if selectReserved {
		bitBudget--
	}

	// Flags are coded differentially: each decoded bit toggles the running
	// value, and bands without enough budget inherit the previous value.
	logP := firstLogP
	running, anyChange := 0, 0
	for band := info.startBand; band < info.endBand; band++ {
		if used+logP <= bitBudget {
			running ^= int(d.rangeDecoder.DecodeSymbolLogP(logP))
			used = d.rangeDecoder.Tell()
			anyChange |= running
		}
		info.tfChange[band] = running
		logP = nextLogP
	}

	row := tfSelectTable[info.lm]
	base := 4 * boolIndex(info.transient)

	// tf_select is only coded when it was reserved and the two candidate
	// mappings actually differ for the flags seen in this frame.
	info.tfSelect = 0
	if selectReserved && row[base+anyChange] != row[base+2+anyChange] {
		info.tfSelect = int(d.rangeDecoder.DecodeSymbolLogP(1))
	}

	selected := base + 2*info.tfSelect
	for band := info.startBand; band < info.endBand; band++ {
		info.tfChange[band] = int(row[selected+info.tfChange[band]])
	}
}

// decodeSpread reads the spread decision when at least four whole bits remain
// in the frame budget; otherwise info.spread is set to defaultSpreadDecision.
func (d *Decoder) decodeSpread(info *frameSideInfo) {
	if d.rangeDecoder.Tell()+4 > info.totalBits {
		info.spread = defaultSpreadDecision

		return
	}

	info.spread = int(d.rangeDecoder.DecodeSymbolWithICDF(icdfSpread))
}

// decodeDynamicAllocation reads the per-band boost flags of RFC 6716 Section
// 4.3.3. Each coded band's boost is stored in info.bandBoost in 1/8-bit units.
// The return value is totalBitsEighth reduced by every boost step granted,
// saturating at zero.
func (d *Decoder) decodeDynamicAllocation(info *frameSideInfo, totalBitsEighth uint) uint {
	boostCaps := allocationCaps(info.lm, info.channelCount)
	bandLogP := initialDynamicAllocationLogP
	usedEighth := d.rangeDecoder.TellFrac()

	for band := info.startBand; band < info.endBand; band++ {
		bins := int(bandEdges[band+1]-bandEdges[band]) << info.lm
		width := info.channelCount * bins

		// One boost step is six bits (the same value as allocationTrimBitCost),
		// limited to at most one bit and at least 1/8 bit per coefficient.
		lowerBound := max(allocationTrimBitCost<<bitResolution, width)
		step := min(width<<bitResolution, lowerBound)
		stepBits := uint(step) // #nosec G115 -- step is positive by construction from CELT band widths.

		logP := bandLogP
		boost := 0
		for usedEighth+uint(logP<<bitResolution) < totalBitsEighth && boost < boostCaps[band] {
			bit := d.rangeDecoder.DecodeSymbolLogP(uint(logP))
			usedEighth = d.rangeDecoder.TellFrac()
			if bit == 0 {
				break
			}

			boost += step
			// Saturating subtraction: the unsigned budget must not wrap.
			totalBitsEighth -= min(stepBits, totalBitsEighth)
			// After the first granted step, further steps use logP 1.
			logP = 1
		}

		info.bandBoost[band] = boost
		if boost > 0 {
			// A boosted band lowers the first-flag logP for later bands, down
			// to minDynamicAllocationLogP.
			bandLogP = max(minDynamicAllocationLogP, bandLogP-1)
		}
	}

	return totalBitsEighth
}

// decodeAllocationTrim reads the allocation trim symbol when the range decoder
// position plus allocationTrimBitCost whole bits still fits in totalBitsEighth
// (1/8-bit units); otherwise the trim is set to defaultAllocationTrim.
func (d *Decoder) decodeAllocationTrim(info *frameSideInfo, totalBitsEighth uint) {
	trimCostEighth := uint(allocationTrimBitCost << bitResolution)

	if d.rangeDecoder.TellFrac()+trimCostEighth > totalBitsEighth {
		info.allocationTrim = defaultAllocationTrim

		return
	}

	info.allocationTrim = int(d.rangeDecoder.DecodeSymbolWithICDF(icdfAllocationTrim))
}

// allocationCaps returns the per-band boost caps for the given LM and channel
// count. Each cap is the bandCaps entry plus 64, scaled by the channel count
// and the band width at this LM, then divided by four.
func allocationCaps(lm, channelCount int) [maxBands]int {
	var caps [maxBands]int

	// bandCaps holds one run of maxBands entries per (LM, channel count) pair.
	runStart := maxBands * (2*lm + channelCount - 1)
	for band := range caps {
		bins := int(bandEdges[band+1]-bandEdges[band]) << lm
		scaled := (int(bandCaps[runStart+band]) + 64) * channelCount * bins
		caps[band] = scaled >> 2
	}

	return caps
}

func smallEnergyDelta(symbol uint32) int {
switch symbol {
case 1:
Expand Down
154 changes: 149 additions & 5 deletions internal/celt/frame_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -356,21 +356,165 @@ func TestDecodeCoarseEnergy(t *testing.T) {
})
}

func TestDecodeTimeFrequencyChanges(t *testing.T) {
t.Run("decodes non-transient tf_change and tf_select", func(t *testing.T) {
decoder := NewDecoder()
decoder.rangeDecoder = rangeDecoderWithBinaryOne()
info := frameSideInfo{
lm: 1,
totalBits: 256,
startBand: 0,
endBand: 1,
channelCount: 1,
}

decoder.decodeTimeFrequencyChanges(&info)

assert.Equal(t, 1, info.tfSelect)
assert.Equal(t, -2, info.tfChange[0])
})

t.Run("decodes transient tf_change and tf_select", func(t *testing.T) {
decoder := NewDecoder()
decoder.rangeDecoder = rangeDecoderWithBinaryOne()
info := frameSideInfo{
lm: 2,
totalBits: 256,
startBand: 0,
endBand: 1,
channelCount: 1,
transient: true,
}

decoder.decodeTimeFrequencyChanges(&info)

assert.Equal(t, 1, info.tfSelect)
assert.Equal(t, -1, info.tfChange[0])
})

t.Run("maps default transient changes when budget is exhausted", func(t *testing.T) {
decoder := NewDecoder()
info := frameSideInfo{
lm: 3,
totalBits: 0,
startBand: 0,
endBand: 2,
channelCount: 1,
transient: true,
}

decoder.decodeTimeFrequencyChanges(&info)

assert.Zero(t, info.tfSelect)
assert.Equal(t, 3, info.tfChange[0])
assert.Equal(t, 3, info.tfChange[1])
})
}

func TestDecodeSpread(t *testing.T) {
t.Run("decodes spread decision when enough bits remain", func(t *testing.T) {
decoder := NewDecoder()
decoder.rangeDecoder = rangeDecoderWithCDFSymbol(31, 32)
info := frameSideInfo{totalBits: 256}

decoder.decodeSpread(&info)

assert.Equal(t, 3, info.spread)
})

t.Run("defaults to normal spread without enough bits", func(t *testing.T) {
decoder := NewDecoder()
info := frameSideInfo{totalBits: 0}

decoder.decodeSpread(&info)

assert.Equal(t, defaultSpreadDecision, info.spread)
})
}

func TestDecodeDynamicAllocation(t *testing.T) {
t.Run("decodes boosts until the band cap", func(t *testing.T) {
decoder := NewDecoder()
decoder.rangeDecoder = rangeDecoderWithBinaryOne()
info := frameSideInfo{
lm: 0,
totalBits: 256,
startBand: 0,
endBand: 1,
channelCount: 1,
}
totalBitsEighth := info.totalBits << bitResolution

remaining := decoder.decodeDynamicAllocation(&info, totalBitsEighth)

assert.Equal(t, 72, info.bandBoost[0])
assert.Less(t, remaining, totalBitsEighth)
})

t.Run("stops immediately on a zero boost flag", func(t *testing.T) {
decoder := NewDecoder()
decoder.rangeDecoder = rangeDecoderWithBinaryZero()
info := frameSideInfo{
lm: 0,
totalBits: 256,
startBand: 0,
endBand: 1,
channelCount: 1,
}
totalBitsEighth := info.totalBits << bitResolution

remaining := decoder.decodeDynamicAllocation(&info, totalBitsEighth)

assert.Zero(t, info.bandBoost[0])
assert.Equal(t, totalBitsEighth, remaining)
})
}

// TestDecodeAllocationTrim checks that the trim symbol is decoded only when
// six whole bits (48 eighth-bits) fit in the remaining budget.
func TestDecodeAllocationTrim(t *testing.T) {
	t.Run("decodes trim when six bits remain", func(t *testing.T) {
		dec := NewDecoder()
		dec.rangeDecoder = rangeDecoderWithCDFSymbol(87, 128)
		var side frameSideInfo
		budget := uint(256) << bitResolution

		dec.decodeAllocationTrim(&side, budget)

		assert.Equal(t, 6, side.allocationTrim)
	})

	t.Run("defaults when six bits are unavailable", func(t *testing.T) {
		dec := NewDecoder()
		dec.rangeDecoder = rangeDecoderWithCDFSymbol(87, 128)
		var side frameSideInfo
		// One eighth-bit short of the 48 needed for the trim symbol.
		budget := dec.rangeDecoder.TellFrac() + 47

		dec.decodeAllocationTrim(&side, budget)

		assert.Equal(t, defaultAllocationTrim, side.allocationTrim)
	})
}

// rangeDecoderWithBinaryOne returns a range decoder whose internal state is
// preset so that, per its use in these tests, binary symbols decode as 1.
// NOTE(review): argument meanings are defined by Decoder.SetInternalValues.
func rangeDecoderWithBinaryOne() rangecoding.Decoder {
	var rd rangecoding.Decoder
	rd.SetInternalValues(nil, 40, 1<<31, 0)

	return rd
}

// rangeDecoderWithBinaryZero returns a range decoder whose internal state is
// preset so that, per its use in these tests, binary symbols decode as 0.
// NOTE(review): argument meanings are defined by Decoder.SetInternalValues.
func rangeDecoderWithBinaryZero() rangecoding.Decoder {
	var rd rangecoding.Decoder
	rd.SetInternalValues(nil, 40, 1<<31, (1<<31)-1)

	return rd
}

func rangeDecoderWithSmallEnergyCDFSymbol(symbol uint32) rangecoding.Decoder {
const (
smallEnergyTotal = 4
scale = 1 << 24
)
return rangeDecoderWithCDFSymbol(symbol, 4)
}

func rangeDecoderWithCDFSymbol(symbol, total uint32) rangecoding.Decoder {
const scale = 1 << 24

decoder := rangecoding.Decoder{}
decoder.SetInternalValues(nil, 0, smallEnergyTotal*scale, (smallEnergyTotal-symbol-1)*scale)
decoder.SetInternalValues(nil, 0, total*scale, (total-symbol-1)*scale)

return decoder
}
30 changes: 30 additions & 0 deletions internal/celt/tables.go
Original file line number Diff line number Diff line change
Expand Up @@ -49,6 +49,36 @@ var (
icdfSmallEnergy = []uint{4, 2, 3, 4}
)

// tfSelectTable is the RFC 6716 Section 4.3.1/Table 60-63 mapping from
// {LM, transient, tf_select, tf_change} to the final per-band TF adjustment.
// The first index is LM. Within a row, decodeTimeFrequencyChanges reads column
// 4*transient + 2*tf_select + tf_change, so columns 0-3 are the non-transient
// entries and columns 4-7 the transient ones.
var tfSelectTable = [maxLM + 1][8]int8{ //nolint:gochecknoglobals
{0, -1, 0, -1, 0, -1, 0, -1}, // LM 0
{0, -1, 0, -2, 1, 0, 1, -1}, // LM 1
{0, -2, 0, -3, 2, 0, 1, -1}, // LM 2
{0, -2, 0, -3, 3, 0, 1, -1}, // LM 3
}

// bandCaps is the static caps cache for the 48 kHz CELT mode used by the
// reference init_caps() helper when bounding dynamic allocation boosts.
// It is stored as consecutive runs of maxBands entries. allocationCaps reads
// the run that starts at maxBands*(2*lm+channelCount-1), so for each LM the
// run for channel count 1 is followed by the run for channel count 2.
var bandCaps = [(maxLM + 1) * 2 * maxBands]uint8{ //nolint:gochecknoglobals
224, 224, 224, 224, 224, 224, 224, 224, 160, 160, 160, 160, 185, 185,
185, 178, 178, 168, 134, 61, 37,
224, 224, 224, 224, 224, 224, 224, 224, 240, 240, 240, 240, 207, 207,
207, 198, 198, 183, 144, 66, 40,
160, 160, 160, 160, 160, 160, 160, 160, 185, 185, 185, 185, 193, 193,
193, 183, 183, 172, 138, 64, 38,
240, 240, 240, 240, 240, 240, 240, 240, 207, 207, 207, 207, 204, 204,
204, 193, 193, 180, 143, 66, 40,
185, 185, 185, 185, 185, 185, 185, 185, 193, 193, 193, 193, 193, 193,
193, 183, 183, 172, 138, 65, 39,
207, 207, 207, 207, 207, 207, 207, 207, 204, 204, 204, 204, 201, 201,
201, 188, 188, 176, 141, 66, 40,
193, 193, 193, 193, 193, 193, 193, 193, 193, 193, 193, 193, 194, 194,
194, 184, 184, 173, 139, 65, 39,
204, 204, 204, 204, 204, 204, 204, 204, 201, 201, 201, 201, 198, 198,
198, 187, 187, 175, 140, 66, 40,
}

// prediction/decay coefficients and eProbModel are the coarse-energy model
// constants used by unquant_coarse_energy() in RFC 6716 Section 4.3.2.1.
var energyPredictionCoefficients = [...]float32{ //nolint:gochecknoglobals
Expand Down
Loading