diff --git a/.golangci.yml b/.golangci.yml index 88cb4fbf..120faf29 100644 --- a/.golangci.yml +++ b/.golangci.yml @@ -19,12 +19,16 @@ linters-settings: recommendations: - errors forbidigo: + analyze-types: true forbid: - ^fmt.Print(f|ln)?$ - ^log.(Panic|Fatal|Print)(f|ln)?$ - ^os.Exit$ - ^panic$ - ^print(ln)?$ + - p: ^testing.T.(Error|Errorf|Fatal|Fatalf|Fail|FailNow)$ + pkg: ^testing$ + msg: "use testify/assert instead" varnamelen: max-distance: 12 min-name-length: 2 @@ -127,9 +131,12 @@ issues: exclude-dirs-use-default: false exclude-rules: # Allow complex tests and examples, better to be self contained - - path: (examples|main\.go|_test\.go) + - path: (examples|main\.go) linters: + - gocognit - forbidigo + - path: _test\.go + linters: - gocognit # Allow forbidden identifiers in CLI commands diff --git a/abscapturetimeextension_test.go b/abscapturetimeextension_test.go index 60c397f0..d11c072e 100644 --- a/abscapturetimeextension_test.go +++ b/abscapturetimeextension_test.go @@ -6,81 +6,56 @@ package rtp import ( "testing" "time" + + "github.com/stretchr/testify/assert" ) func TestAbsCaptureTimeExtension_Roundtrip(t *testing.T) { //nolint:cyclop t.Run("positive captureClockOffset", func(t *testing.T) { t0 := time.Now() e1 := NewAbsCaptureTimeExtension(t0) - b1, err1 := e1.Marshal() - if err1 != nil { - t.Fatal(err1) - } + b1, err := e1.Marshal() + assert.NoError(t, err) var o1 AbsCaptureTimeExtension - if err := o1.Unmarshal(b1); err != nil { - t.Fatal(err) - } + assert.NoError(t, o1.Unmarshal(b1)) dt1 := o1.CaptureTime().Sub(t0).Seconds() - if dt1 < -0.001 || dt1 > 0.001 { - t.Fatalf("timestamp differs, want %v got %v (dt=%f)", t0, o1.CaptureTime(), dt1) - } - if o1.EstimatedCaptureClockOffsetDuration() != nil { - t.Fatalf("duration differs, want nil got %d", o1.EstimatedCaptureClockOffsetDuration()) - } + assert.GreaterOrEqual(t, dt1, -0.001) + assert.LessOrEqual(t, dt1, 0.001) + assert.Nil(t, o1.EstimatedCaptureClockOffsetDuration()) e2 := NewAbsCaptureTimeExtensionWithCaptureClockOffset(t0, 1250*time.Millisecond) - b2, err2 := e2.Marshal() - if err2 != nil { - t.Fatal(err2) - } + b2, err := e2.Marshal() + assert.NoError(t, err) var o2 AbsCaptureTimeExtension - if err := o2.Unmarshal(b2); err != nil { - t.Fatal(err) - } + assert.NoError(t, o2.Unmarshal(b2)) dt2 := o1.CaptureTime().Sub(t0).Seconds() - if dt2 < -0.001 || dt2 > 0.001 { - t.Fatalf("timestamp differs, want %v got %v (dt=%f)", t0, o2.CaptureTime(), dt2) - } - if *o2.EstimatedCaptureClockOffsetDuration() != 1250*time.Millisecond { - t.Fatalf("duration differs, want 250ms got %d", *o2.EstimatedCaptureClockOffsetDuration()) - } + assert.GreaterOrEqual(t, dt2, -0.001) + assert.LessOrEqual(t, dt2, 0.001) + assert.Equal(t, 1250*time.Millisecond, *o2.EstimatedCaptureClockOffsetDuration()) }) // This test can verify the for for the issue 247 t.Run("negative captureClockOffset", func(t *testing.T) { t0 := time.Now() e1 := NewAbsCaptureTimeExtension(t0) - b1, err1 := e1.Marshal() - if err1 != nil { - t.Fatal(err1) - } + b1, err := e1.Marshal() + assert.NoError(t, err) var o1 AbsCaptureTimeExtension - if err := o1.Unmarshal(b1); err != nil { - t.Fatal(err) - } + assert.NoError(t, o1.Unmarshal(b1)) dt1 := o1.CaptureTime().Sub(t0).Seconds() - if dt1 < -0.001 || dt1 > 0.001 { - t.Fatalf("timestamp differs, want %v got %v (dt=%f)", t0, o1.CaptureTime(), dt1) - } - if o1.EstimatedCaptureClockOffsetDuration() != nil { - t.Fatalf("duration differs, want nil got %d", o1.EstimatedCaptureClockOffsetDuration()) - } + assert.GreaterOrEqual(t, dt1, -0.001) + assert.LessOrEqual(t, dt1, 0.001) + assert.Nil(t, o1.EstimatedCaptureClockOffsetDuration()) e2 := NewAbsCaptureTimeExtensionWithCaptureClockOffset(t0, -250*time.Millisecond) - b2, err2 := e2.Marshal() - if err2 != nil { - t.Fatal(err2) - } + b2, err := e2.Marshal() + assert.NoError(t, err) + var o2 AbsCaptureTimeExtension - if err := o2.Unmarshal(b2); err != nil { - t.Fatal(err) - } + assert.NoError(t, o2.Unmarshal(b2)) dt2 := o1.CaptureTime().Sub(t0).Seconds() - if dt2 < -0.001 || dt2 > 0.001 { - t.Fatalf("timestamp differs, want %v got %v (dt=%f)", t0, o2.CaptureTime(), dt2) - } - if *o2.EstimatedCaptureClockOffsetDuration() != -250*time.Millisecond { - t.Fatalf("duration differs, want -250ms got %v", *o2.EstimatedCaptureClockOffsetDuration()) - } + assert.GreaterOrEqual(t, dt2, -0.001) + assert.LessOrEqual(t, dt2, 0.001) + assert.Equal(t, -250*time.Millisecond, *o2.EstimatedCaptureClockOffsetDuration()) }) } diff --git a/abssendtimeextension_test.go b/abssendtimeextension_test.go index 16834d9a..3446b12a 100644 --- a/abssendtimeextension_test.go +++ b/abssendtimeextension_test.go @@ -6,6 +6,8 @@ package rtp import ( "testing" "time" + + "github.com/stretchr/testify/assert" ) const absSendTimeResolution = 3800 * time.Nanosecond @@ -24,20 +26,22 @@ func TestNtpConversion(t *testing.T) { for i, in := range tests { out := toNtpTime(in.t) - if out != in.n { - t.Errorf("[%d] Converted NTP time from time.Time differs, expected: %d, got: %d", - i, in.n, out, - ) - } + assert.Equalf( + t, in.n, out, + "[%d] Converted NTP time from time.Time differs", i, + ) } for i, in := range tests { out := toTime(in.n) diff := in.t.Sub(out) - if diff < -absSendTimeResolution || absSendTimeResolution < diff { - t.Errorf("[%d] Converted time.Time from NTP time differs, expected: %v, got: %v", - i, in.t.UTC(), out.UTC(), - ) - } + assert.GreaterOrEqualf( + t, diff, -absSendTimeResolution, + "[%d] Converted time.Time from NTP time differs", i, + ) + assert.LessOrEqual( + t, diff, absSendTimeResolution, + "[%d] Converted time.Time from NTP time differs", i, + ) } } @@ -52,16 +56,14 @@ func TestAbsSendTimeExtension_Roundtrip(t *testing.T) { } for i, in := range tests { b, err := in.Marshal() - if err != nil { - t.Fatal(err) - } + assert.NoError(t, err) + var out AbsSendTimeExtension - if err = out.Unmarshal(b); err != nil { - t.Fatal(err) - } - if in.Timestamp != out.Timestamp { - t.Errorf("[%d] Timestamp differs, expected: %d, got: %d", i, in.Timestamp, out.Timestamp) - } + assert.NoError(t, out.Unmarshal(b)) + assert.Equalf( + t, in.Timestamp, out.Timestamp, + "[%d] Timestamp differs", i, + ) } } @@ -77,20 +79,21 @@ func TestAbsSendTimeExtension_Estimate(t *testing.T) { inTime := toTime(in.sendNTP) send := &AbsSendTimeExtension{in.sendNTP >> 14} b, err := send.Marshal() - if err != nil { - t.Fatal(err) - } + assert.NoError(t, err) var received AbsSendTimeExtension - if err = received.Unmarshal(b); err != nil { - t.Fatal(err) - } + assert.NoError(t, received.Unmarshal(b)) estimated := received.Estimate(toTime(in.receiveNTP)) diff := estimated.Sub(inTime) - if diff < -absSendTimeResolution || absSendTimeResolution < diff { - t.Errorf("[%d] Estimated time differs, expected: %v, estimated: %v (receive time: %v)", - i, inTime.UTC(), estimated.UTC(), toTime(in.receiveNTP).UTC(), - ) - } + assert.GreaterOrEqualf( + t, diff, -absSendTimeResolution, + "[%d] Estimated time differs, expected: %v, estimated: %v (receive time: %v)", + i, inTime.UTC(), estimated.UTC(), toTime(in.receiveNTP).UTC(), + ) + assert.LessOrEqual( + t, diff, absSendTimeResolution, + "[%d] Estimated time differs, expected: %v, estimated: %v (receive time: %v)", + i, inTime.UTC(), estimated.UTC(), toTime(in.receiveNTP).UTC(), + ) } } diff --git a/audiolevelextension_test.go b/audiolevelextension_test.go index 9c651a0b..25c06253 100644 --- a/audiolevelextension_test.go +++ b/audiolevelextension_test.go @@ -4,71 +4,49 @@ package rtp import ( - "bytes" - "errors" "testing" + + "github.com/stretchr/testify/assert" ) func TestAudioLevelExtensionTooSmall(t *testing.T) { a := AudioLevelExtension{} - rawData := []byte{} - - if err := a.Unmarshal(rawData); !errors.Is(err, errTooSmall) { - t.Fatal("err != errTooSmall") - } + assert.ErrorIs(t, a.Unmarshal(rawData), errTooSmall) } func TestAudioLevelExtensionVoiceTrue(t *testing.T) { a1 := AudioLevelExtension{} - rawData := []byte{ 0x88, } - - if err := a1.Unmarshal(rawData); err != nil { - t.Fatal("Unmarshal error on extension data") - } + assert.NoError(t, a1.Unmarshal(rawData)) a2 := AudioLevelExtension{ Level: 8, Voice: true, } - - if a1 != a2 { - t.Error("Unmarshal failed") - } + assert.Equal(t, a2, a1) dstData, _ := a2.Marshal() - if !bytes.Equal(dstData, rawData) { - t.Error("Marshal failed") - } + assert.Equal(t, rawData, dstData) } func TestAudioLevelExtensionVoiceFalse(t *testing.T) { a1 := AudioLevelExtension{} - rawData := []byte{ 0x8, } - - if err := a1.Unmarshal(rawData); err != nil { - t.Fatal("Unmarshal error on extension data") - } + assert.NoError(t, a1.Unmarshal(rawData)) a2 := AudioLevelExtension{ Level: 8, Voice: false, } - - if a1 != a2 { - t.Error("Unmarshal failed") - } + assert.Equal(t, a2, a1) dstData, _ := a2.Marshal() - if !bytes.Equal(dstData, rawData) { - t.Error("Marshal failed") - } + assert.Equal(t, rawData, dstData) } func TestAudioLevelExtensionLevelOverflow(t *testing.T) { @@ -77,7 +55,6 @@ func TestAudioLevelExtensionLevelOverflow(t *testing.T) { Voice: false, } - if _, err := a.Marshal(); !errors.Is(err, errAudioLevelOverflow) { - t.Fatal("err != errAudioLevelOverflow") - } + _, err := a.Marshal() + assert.ErrorIs(t, err, errAudioLevelOverflow) } diff --git a/codecs/av1/frame/av1_test.go b/codecs/av1/frame/av1_test.go index 2e6c2b40..65f73974 100644 --- a/codecs/av1/frame/av1_test.go +++ b/codecs/av1/frame/av1_test.go @@ -4,10 +4,10 @@ package frame import ( - "reflect" "testing" "github.com/pion/rtp/codecs" + "github.com/stretchr/testify/assert" ) // First is Fragment (and no buffer) @@ -17,34 +17,22 @@ func TestAV1_ReadFrames(t *testing.T) { // First is Fragment of OBU, but no OBU Elements is cached fragm := &AV1{} frames, err := fragm.ReadFrames(&codecs.AV1Packet{Z: true, OBUElements: [][]byte{{0x01}}}) - if err != nil { - t.Fatal(err) - } else if !reflect.DeepEqual(frames, [][]byte{}) { - t.Fatalf("No frames should be generated, %v", frames) - } + assert.NoError(t, err) + assert.Equal(t, [][]byte{}, frames, "No frames should be generated") fragm = &AV1{} frames, err = fragm.ReadFrames(&codecs.AV1Packet{OBUElements: [][]byte{{0x01}}}) - if err != nil { - t.Fatal(err) - } else if !reflect.DeepEqual(frames, [][]byte{{0x01}}) { - t.Fatalf("One frame should be generated, %v", frames) - } + assert.NoError(t, err) + assert.Equal(t, [][]byte{{0x01}}, frames, "One frame should be generated") fragm = &AV1{} frames, err = fragm.ReadFrames(&codecs.AV1Packet{Y: true, OBUElements: [][]byte{{0x00}}}) - if err != nil { - t.Fatal(err) - } else if !reflect.DeepEqual(frames, [][]byte{}) { - t.Fatalf("No frames should be generated, %v", frames) - } + assert.NoError(t, err) + assert.Equal(t, [][]byte{}, frames, "No frames should be generated") frames, err = fragm.ReadFrames(&codecs.AV1Packet{Z: true, OBUElements: [][]byte{{0x01}}}) - if err != nil { - t.Fatal(err) - } else if !reflect.DeepEqual(frames, [][]byte{{0x00, 0x01}}) { - t.Fatalf("One frame should be generated, %v", frames) - } + assert.NoError(t, err) + assert.Equal(t, [][]byte{{0x00, 0x01}}, frames, "One frame should be generated") } // Marshal some AV1 Frames to RTP, assert that AV1 can get them back in the original format. @@ -81,14 +69,14 @@ func TestAV1_ReadFrames_E2E(t *testing.T) { for _, originalFrame := range frames { for _, payload := range payloader.Payload(mtu, originalFrame) { rtpPacket := &codecs.AV1Packet{} - if _, err := rtpPacket.Unmarshal(payload); err != nil { - t.Fatal(err) - } + _, err := rtpPacket.Unmarshal(payload) + assert.NoError(t, err) + decodedFrame, err := f.ReadFrames(rtpPacket) - if err != nil { - t.Fatal(err) - } else if len(decodedFrame) != 0 && !reflect.DeepEqual(originalFrame, decodedFrame[0]) { - t.Fatalf("Decode(%02x) and Original(%02x) are not equal", decodedFrame[0], originalFrame) + assert.NoError(t, err) + + if len(decodedFrame) != 0 { + assert.Equal(t, originalFrame, decodedFrame[0]) } } } diff --git a/codecs/av1/obu/leb128_test.go b/codecs/av1/obu/leb128_test.go index 54646498..a7bf2150 100644 --- a/codecs/av1/obu/leb128_test.go +++ b/codecs/av1/obu/leb128_test.go @@ -5,10 +5,11 @@ package obu import ( "encoding/hex" - "errors" "fmt" "math" "testing" + + "github.com/stretchr/testify/assert" ) func TestLEB128(t *testing.T) { @@ -23,25 +24,19 @@ func TestLEB128(t *testing.T) { test := test encoded := EncodeLEB128(test.Value) - if encoded != test.Encoded { - t.Fatalf("Actual(%d) did not equal expected(%d)", encoded, test.Encoded) - } + assert.Equal(t, test.Encoded, encoded) decoded := decodeLEB128(encoded) - if decoded != test.Value { - t.Fatalf("Actual(%d) did not equal expected(%d)", decoded, test.Value) - } + assert.Equal(t, test.Value, decoded) } } func TestReadLeb128(t *testing.T) { - if _, _, err := ReadLeb128(nil); !errors.Is(err, ErrFailedToReadLEB128) { - t.Fatal("ReadLeb128 on a nil buffer should return an error") - } + _, _, err := ReadLeb128(nil) + assert.ErrorIs(t, err, ErrFailedToReadLEB128, "ReadLeb128 on a nil buffer should return an error") - if _, _, err := ReadLeb128([]byte{0xFF}); !errors.Is(err, ErrFailedToReadLEB128) { - t.Fatal("ReadLeb128 on a buffer with all MSB set should fail") - } + _, _, err = ReadLeb128([]byte{0xFF}) + assert.ErrorIs(t, err, ErrFailedToReadLEB128, "ReadLeb128 on a buffer with all MSB set should return an error") } func TestWriteToLeb128(t *testing.T) { @@ -64,9 +59,7 @@ func TestWriteToLeb128(t *testing.T) { t.Helper() b := WriteToLeb128(v.value) - if v.leb128 != hex.EncodeToString(b) { - t.Errorf("Expected %s, got %s", v.leb128, hex.EncodeToString(b)) - } + assert.Equal(t, v.leb128, hex.EncodeToString(b)) } for _, v := range testVectors { diff --git a/codecs/av1/obu/obu_test.go b/codecs/av1/obu/obu_test.go index a8ab1cd2..6267e84b 100644 --- a/codecs/av1/obu/obu_test.go +++ b/codecs/av1/obu/obu_test.go @@ -4,9 +4,9 @@ package obu import ( - "bytes" - "errors" "testing" + + "github.com/stretchr/testify/assert" ) func TestOBUType(t *testing.T) { @@ -28,14 +28,8 @@ func TestOBUType(t *testing.T) { {Type(9), 9, "OBU_RESERVED"}, } { test := test - - if test.Type.String() != test.Str { - t.Errorf("Expected %s, got %s", test.Str, test.Type.String()) - } - - if uint8(test.Type) != test.TypeValue { - t.Errorf("Expected %d, got %d", test.TypeValue, uint8(test.Type)) - } + assert.Equal(t, test.Str, test.Type.String()) + assert.Equal(t, test.TypeValue, uint8(test.Type)) } } @@ -65,27 +59,13 @@ func TestOBUHeader_NoExtension(t *testing.T) { buff := []byte{test.Value} header, err := ParseOBUHeader(buff) - if err != nil { - t.Fatalf("Unexpected error: %v", err) - } - - if *header != test.Header { - t.Errorf("Expected %v, got %v", test.Header, *header) - } - - if test.Header.Size() != 1 { - t.Errorf("Expected size 1 for header without extension, got %d", test.Header.Size()) - } + assert.NoError(t, err) + assert.Equal(t, test.Header, *header) + assert.Equal(t, 1, header.Size()) value := test.Header.Marshal() - - if len(value) != 1 { - t.Errorf("Expected size 1 for header without extension, got %d", len(value)) - } - - if value[0] != test.Value { - t.Errorf("Expected %d for header value, got %d", test.Value, value[0]) - } + assert.Len(t, value, 1, "Expected size 1 for header without extension") + assert.Equal(t, test.Value, value[0]) } } @@ -139,103 +119,62 @@ func TestOBUHeader_Extension(t *testing.T) { buff := []byte{test.HeaderValue, test.ExtensionHeaderValue} header, err := ParseOBUHeader(buff) - if err != nil { - t.Fatalf("Unexpected error: %v", err) - } + assert.NoError(t, err) expected := Header{ Type: test.Header.Type, HasSizeField: test.Header.HasSizeField, Reserved1Bit: test.Header.Reserved1Bit, } - if expected != test.Header { - t.Errorf("Expected %v, got %v", test.Header, *header) - } - - if header.Size() != 2 { - t.Errorf("Expected size 2 for header with extension, got %d", test.Header.Size()) - } + assert.Equal(t, expected, test.Header) + assert.Equal(t, 2, header.Size()) extension := header.ExtensionHeader - if extension == nil { - t.Fatalf("Expected extension header to be present") - } - - if *extension != test.ExtensionHeader { - t.Errorf("Expected %v, got %v", test.ExtensionHeader, *extension) - } - - if extension.Marshal() != test.ExtensionHeaderValue { - t.Errorf("Expected %d for extension header value, got %d", test.ExtensionHeaderValue, extension.Marshal()) - } + assert.NotNil(t, extension) + assert.Equal(t, test.ExtensionHeader, *extension) + assert.Equal(t, test.ExtensionHeaderValue, extension.Marshal()) value := header.Marshal() - if len(value) != 2 { - t.Errorf("Expected size 2 for header with extension, got %d", len(value)) - } - - if !bytes.Equal(value, buff) { - t.Errorf("Expected %v for header value, got %v", buff, value) - } + assert.Lenf( + t, value, 2, + "Expected size 2 for header with extension, got %d", len(value), + ) + assert.Equal(t, buff, value) } } func TestOBUHeader_Short(t *testing.T) { _, err := ParseOBUHeader([]byte{}) - if err == nil { - t.Fatalf("Expected error, got nil") - } - if !errors.Is(err, ErrShortHeader) { - t.Errorf("Expected ErrShortHeader, got %v", err) - } + assert.ErrorIs(t, err, ErrShortHeader) // Missing extension header _, err = ParseOBUHeader([]byte{0b0_0000_1_0_0}) - if err == nil { - t.Fatalf("Expected error, got nil") - } - - if !errors.Is(err, ErrShortHeader) { - t.Errorf("Expected ErrShortHeader, got %v", err) - } + assert.ErrorIs(t, err, ErrShortHeader) } func TestOBUHeader_Invalid(t *testing.T) { + // forbidden bit is set _, err := ParseOBUHeader([]byte{0b1_0010_0_0_1}) - if err == nil { - t.Fatalf("Expected error, got nil") - } - if !errors.Is(err, ErrInvalidOBUHeader) { - t.Errorf("Expected ErrInvalidOBUHeader, got %v", err) - } + assert.ErrorIs(t, err, ErrInvalidOBUHeader) } func TestOBUHeader_MarshalOutbound(t *testing.T) { // Marshal should turnicate the extension header values. header := Header{Type: Type(255)} - if header.Marshal()[0] != 0b0_1111_000 { - t.Errorf("Expected 0b0_1111_000, got %b", header.Marshal()[0]) - } + assert.Equal(t, uint8(0b0_1111_000), header.Marshal()[0]) extentionHeader := ExtensionHeader{TemporalID: 255} - - if extentionHeader.Marshal() != 0b111_00_000 { - t.Errorf("Expected 0b111_00_000, got %b", extentionHeader.Marshal()) - } + assert.Equal(t, uint8(0b111_00_000), extentionHeader.Marshal()) extensionHeader := ExtensionHeader{SpatialID: 255} - if extensionHeader.Marshal() != 0b000_11_000 { - t.Errorf("Expected 0b000_11_000, got %b", extensionHeader.Marshal()) - } + assert.Equal(t, uint8(0b000_11_000), extensionHeader.Marshal()) extensionHeader = ExtensionHeader{Reserved3Bits: 255} - if extensionHeader.Marshal() != 0b000_00_111 { - t.Errorf("Expected 0b000_00_111, got %b", extensionHeader.Marshal()) - } + assert.Equal(t, uint8(0b000_00_111), extensionHeader.Marshal()) } func TestOBUMarshal(t *testing.T) { - obu := OBU{ + testOBU := OBU{ Header: Header{ Type: OBUFrame, HasSizeField: false, @@ -244,23 +183,14 @@ func TestOBUMarshal(t *testing.T) { Payload: []byte{0x01, 0x02, 0x03}, } - data := obu.Marshal() - - if len(data) != 4 { - t.Fatalf("Expected 4 bytes, got %d", len(data)) - } - - if data[0] != obu.Header.Marshal()[0] { - t.Errorf("Expected header to be %v, got %v", obu.Header.Marshal(), data[0]) - } - - if !bytes.Equal(data[1:], obu.Payload) { - t.Errorf("Expected payload to be %v, got %v", obu.Payload, data[1:]) - } + data := testOBU.Marshal() + assert.Len(t, data, 4) + assert.Equal(t, testOBU.Header.Marshal()[0], data[0], "Expected header to be equal") + assert.Equal(t, testOBU.Payload, data[1:]) } func TestOBUMarshal_ExtensionHeader(t *testing.T) { - obu := OBU{ + testOBU := OBU{ Header: Header{ Type: OBUFrame, HasSizeField: false, @@ -272,24 +202,11 @@ func TestOBUMarshal_ExtensionHeader(t *testing.T) { }, Payload: []byte{0x01, 0x02, 0x03}, } - - data := obu.Marshal() - - if len(data) != 5 { - t.Fatalf("Expected 5 bytes, got %d", len(data)) - } - - if data[0] != obu.Header.Marshal()[0] { - t.Errorf("Expected header to be %v, got %v", obu.Header.Marshal(), data[0]) - } - - if data[1] != obu.Header.ExtensionHeader.Marshal() { - t.Errorf("Expected extension header to be %v, got %v", obu.Header.ExtensionHeader.Marshal(), data[1]) - } - - if !bytes.Equal(data[2:], obu.Payload) { - t.Errorf("Expected payload to be %v, got %v", obu.Payload, data[1:]) - } + data := testOBU.Marshal() + assert.Len(t, data, 5) + assert.Equal(t, testOBU.Header.Marshal()[0], data[0], "Expected header to be equal") + assert.Equal(t, testOBU.Header.ExtensionHeader.Marshal(), data[1], "Expected extension header to equal") + assert.Equal(t, testOBU.Payload, data[2:]) } func TestOBUMarshal_HasOBUSize(t *testing.T) { @@ -300,7 +217,7 @@ func TestOBUMarshal_HasOBUSize(t *testing.T) { payload[i] = byte(i) } - obu := OBU{ + testOBU := OBU{ Header: Header{ Type: OBUFrame, HasSizeField: true, @@ -309,57 +226,37 @@ func TestOBUMarshal_HasOBUSize(t *testing.T) { Payload: payload, } expected := append( - obu.Header.Marshal(), + testOBU.Header.Marshal(), append( // obu_size leb128 (128) []byte{0x80, 0x01}, - obu.Payload..., + testOBU.Payload..., )..., ) - data := obu.Marshal() - - if len(data) != payloadSize+3 { - t.Fatalf("Expected 4 bytes, got %d", len(data)) - } - - if data[0] != obu.Header.Marshal()[0] { - t.Errorf("Expected header to be %v, got %v", obu.Header.Marshal(), data[0]) - } - - if !bytes.Equal(data, expected) { - t.Errorf("Expected payload to be %v, got %v", expected, data) - } + data := testOBU.Marshal() + assert.Len(t, data, payloadSize+3) + assert.Equal(t, testOBU.Header.Marshal()[0], data[0], "Expected header to be equal") + assert.Equal(t, expected, data) } func TestOBUMarshal_ZeroPayload(t *testing.T) { - obu := OBU{ + testOBU := OBU{ Header: Header{ Type: OBUTemporalDelimiter, HasSizeField: false, }, } + data := testOBU.Marshal() + assert.Len(t, data, 1) - data := obu.Marshal() - - if len(data) != 1 { - t.Fatalf("Expected 1 byte, got %d", len(data)) - } - - obu = OBU{ + testOBU = OBU{ Header: Header{ Type: OBUTemporalDelimiter, HasSizeField: true, }, } - - data = obu.Marshal() - - if len(data) != 2 { - t.Fatalf("Expected two bytes, got %d", len(data)) - } - - if data[1] != 0 { - t.Errorf("Expected 0 for size, got %d", data[1]) - } + data = testOBU.Marshal() + assert.Len(t, data, 2) + assert.Equal(t, uint8(0), data[1], "Expected 0 for size") } diff --git a/codecs/av1_depacketizer_test.go b/codecs/av1_depacketizer_test.go index 73d68cd3..6249feb3 100644 --- a/codecs/av1_depacketizer_test.go +++ b/codecs/av1_depacketizer_test.go @@ -4,11 +4,10 @@ package codecs import ( - "bytes" - "errors" "testing" "github.com/pion/rtp/codecs/av1/obu" + "github.com/stretchr/testify/assert" ) // Create an AV1 OBU for testing. Returns one without the obu_size_field and another with it included. @@ -36,23 +35,16 @@ func createTestPayload(obuHeader obu.Header, payload []byte) []byte { func TestAV1Depacketizer_invalidPackets(t *testing.T) { depacketizer := AV1Depacketizer{} _, err := depacketizer.Unmarshal([]byte{}) - if !errors.Is(err, errShortPacket) { - t.Fatalf("Unexpected error: %v", err) - } + assert.ErrorIs(t, err, errShortPacket) + _, err = depacketizer.Unmarshal([]byte{0b11000000, 0xFF}) - if !errors.Is(err, obu.ErrFailedToReadLEB128) { - t.Fatalf("Unexpected error: %v", err) - } + assert.ErrorIs(t, err, obu.ErrFailedToReadLEB128) _, err = depacketizer.Unmarshal(append([]byte{0b00000000}, obu.WriteToLeb128(0x99)...)) - if !errors.Is(err, errShortPacket) { - t.Fatalf("Unexpected error: %v", err) - } + assert.ErrorIs(t, err, errShortPacket) _, err = depacketizer.Unmarshal(append([]byte{0b00000000}, obu.WriteToLeb128(0x01)...)) - if !errors.Is(err, errShortPacket) { - t.Fatalf("Unexpected error: %v", err) - } + assert.ErrorIs(t, err, errShortPacket) _, err = depacketizer.Unmarshal( append( @@ -63,9 +55,7 @@ func TestAV1Depacketizer_invalidPackets(t *testing.T) { )..., ), ) - if !errors.Is(err, errShortPacket) { - t.Fatalf("Unexpected error: %v", err) - } + assert.ErrorIs(t, err, errShortPacket) } func TestAV1Depacketizer_singleOBU(t *testing.T) { @@ -80,13 +70,8 @@ func TestAV1Depacketizer_singleOBU(t *testing.T) { d := AV1Depacketizer{} obu, err := d.Unmarshal(packet) - if err != nil { - t.Fatalf("Unexpected error: %v", err) - } - - if !bytes.Equal(obu, expectedOBU) { - t.Fatalf("OBU data mismatch, expected %v, got %v", expectedOBU, obu) - } + assert.NoError(t, err) + assert.Equal(t, expectedOBU, obu) } func TestAV1Depacketizer_singleOBUWithPadding(t *testing.T) { @@ -103,13 +88,8 @@ func TestAV1Depacketizer_singleOBUWithPadding(t *testing.T) { d := AV1Depacketizer{} obu, err := d.Unmarshal(packet) - if err != nil { - t.Fatalf("Unexpected error: %v", err) - } - - if !bytes.Equal(obu, expectedOBU) { - t.Fatalf("OBU data mismatch, expected %v, got %v", expectedOBU, obu) - } + assert.NoError(t, err) + assert.Equal(t, expectedOBU, obu) } // AV1 OBUs shouldn't include the obu_size_field when packetized in RTP, @@ -126,13 +106,8 @@ func TestAV1Depacketizer_withOBUSize(t *testing.T) { d := AV1Depacketizer{} obu, err := d.Unmarshal(packet) - if err != nil { - t.Fatalf("Unexpected error: %v", err) - } - - if !bytes.Equal(obu, obuData) { - t.Fatalf("OBU data mismatch, expected %v, got %v", obuData, obu) - } + assert.NoError(t, err) + assert.Equal(t, obuData, obu) } func TestAV1Depacketizer_validateOBUSize(t *testing.T) { @@ -179,9 +154,7 @@ func TestAV1Depacketizer_validateOBUSize(t *testing.T) { t.Run(tt.name, func(t *testing.T) { d := AV1Depacketizer{} _, err := d.Unmarshal(tt.payload) - if !errors.Is(err, tt.err) { - t.Fatalf("Expected error %v, got %v", tt.err, err) - } + assert.ErrorIs(t, err, tt.err) }) } } @@ -189,13 +162,8 @@ func TestAV1Depacketizer_validateOBUSize(t *testing.T) { func TestAV1Depacketizer_dropBuffer(t *testing.T) { depacketizer := &AV1Depacketizer{} empty, err := depacketizer.Unmarshal([]byte{0x41, 0x02, 0x00, 0x01}) - if err != nil { - t.Fatalf("Unexpected error: %v", err) - } - - if len(empty) != 0 { - t.Fatalf("Expected empty OBU") - } + assert.NoError(t, err) + assert.Len(t, empty, 0) payload := []byte{0x08, 0x02, 0x03} obuData, expectedOBU := createAV1OBU(4, payload) @@ -208,13 +176,8 @@ func TestAV1Depacketizer_dropBuffer(t *testing.T) { packet = append(packet, obuData...) obu, err := depacketizer.Unmarshal(packet) - if err != nil { - t.Fatalf("Unexpected error: %v", err) - } - - if !bytes.Equal(obu, expectedOBU) { - t.Fatalf("OBU data mismatch, expected %v, got %v", expectedOBU, obu) - } + assert.NoError(t, err) + assert.Equal(t, expectedOBU, obu) } func TestAV1Depacketizer_singleOBUWithW(t *testing.T) { @@ -225,13 +188,8 @@ func TestAV1Depacketizer_singleOBUWithW(t *testing.T) { d := AV1Depacketizer{} obu, err := d.Unmarshal(packet) - if err != nil { - t.Fatalf("Unexpected error: %v", err) - } - - if !bytes.Equal(obu, expectedOBU) { - t.Fatalf("OBU data mismatch, expected %v, got %v", obuData, obu) - } + assert.NoError(t, err) + assert.Equal(t, expectedOBU, obu) } func TestDepacketizer_multipleFullOBUs(t *testing.T) { @@ -252,13 +210,8 @@ func TestDepacketizer_multipleFullOBUs(t *testing.T) { d := AV1Depacketizer{} obus, err := d.Unmarshal(packet) - if err != nil { - t.Fatalf("Unexpected error: %v", err) - } - - if !bytes.Equal(obus, expected) { - t.Fatalf("OBU data mismatch, expected %v, got %v", expected, obus) - } + assert.NoError(t, err) + assert.Equal(t, expected, obus) } func TestAV1Depacketizer_multipleFullOBUsWithW(t *testing.T) { @@ -279,13 +232,8 @@ func TestAV1Depacketizer_multipleFullOBUsWithW(t *testing.T) { depacketizer := AV1Depacketizer{} obus, err := depacketizer.Unmarshal(packet) - if err != nil { - t.Fatalf("Unexpected error: %v", err) - } - - if !bytes.Equal(obus, expected) { - t.Fatalf("OBU data mismatch, expected %v, got %v", expected, obus) - } + assert.NoError(t, err) + assert.Equal(t, expected, obus) } func TestDepacketizer_fragmentedOBUS(t *testing.T) { @@ -316,17 +264,12 @@ func TestDepacketizer_fragmentedOBUS(t *testing.T) { packet = append(packet, obu3f1...) obus, err := depacketizer.Unmarshal(packet) - if err != nil { - t.Fatalf("Unexpected error: %v", err) - } + assert.NoError(t, err) expected := make([]byte, 0) expected = append(expected, expectedOBU1...) expected = append(expected, expectedOBU2...) - - if !bytes.Equal(obus, expected) { - t.Fatalf("OBU data mismatch, expected %v, got %v", expected, obus) - } + assert.Equal(t, expected, obus) packet = make([]byte, 0) packet = append(packet, []byte{0b11000000}...) @@ -340,14 +283,10 @@ func TestDepacketizer_fragmentedOBUS(t *testing.T) { packet = append(packet, obu6f1...) obus, err = depacketizer.Unmarshal(packet) - if err != nil { - t.Fatalf("Unexpected error: %v", err) - } + assert.NoError(t, err) expected = append(append(expectedOBU3, expectedOBU4...), expectedOBU5...) - if !bytes.Equal(obus, expected) { - t.Fatalf("OBU data mismatch, expected %v, got %v", expected, obus) - } + assert.Equal(t, expected, obus) packet = make([]byte, 0) packet = append(packet, []byte{0b10100000}...) @@ -357,17 +296,12 @@ func TestDepacketizer_fragmentedOBUS(t *testing.T) { packet = append(packet, obu7...) obus, err = depacketizer.Unmarshal(packet) - if err != nil { - t.Fatalf("Unexpected error: %v", err) - } + assert.NoError(t, err) expected = make([]byte, 0) expected = append(expected, expectedOBU6...) expected = append(expected, expectedOBU7...) - - if !bytes.Equal(obus, expected) { - t.Fatalf("OBU data mismatch, expected %v, got %v", expected, obus) - } + assert.Equal(t, expected, obus) packet = make([]byte, 0) packet = append(packet, []byte{0b00000000}...) @@ -375,13 +309,8 @@ func TestDepacketizer_fragmentedOBUS(t *testing.T) { packet = append(packet, obu8...) obus, err = depacketizer.Unmarshal(packet) - if err != nil { - t.Fatalf("Unexpected error: %v", err) - } - - if !bytes.Equal(obus, expectedOBU8) { - t.Fatalf("OBU data mismatch, expected %v, got %v", expected, obus) - } + assert.NoError(t, err) + assert.Equal(t, expectedOBU8, obus) } func TestAV1Depacketizer_dropLostFragment(t *testing.T) { @@ -393,13 +322,8 @@ func TestAV1Depacketizer_dropLostFragment(t *testing.T) { []byte{0x01, 0x02, 0x03}..., ), ) - if err != nil { - t.Fatalf("Unexpected error: %v", err) - } - - if len(obus) != 0 { - t.Fatalf("Expected empty OBU for fragmented OBU") - } + assert.NoError(t, err) + assert.Len(t, obus, 0, "Expected empty OBU for fragmented OBU") newOBU, expected := createAV1OBU(obu.OBUTileGroup, []byte{0x04, 0x05, 0x06}) obus, err = depacketizer.Unmarshal( @@ -408,13 +332,8 @@ func TestAV1Depacketizer_dropLostFragment(t *testing.T) { newOBU..., ), ) - if err != nil { - t.Fatalf("Unexpected error: %v", err) - } - - if !bytes.Equal(obus, expected) { - t.Fatalf("Expected OBU data to be %v, got %v", newOBU, obus) - } + assert.NoError(t, err) + assert.Equal(t, expected, obus) } func TestAV1Depacketizer_dropIfLostFragment(t *testing.T) { @@ -426,13 +345,8 @@ func TestAV1Depacketizer_dropIfLostFragment(t *testing.T) { []byte{0x01, 0x02, 0x03}..., ), ) - if err != nil { - t.Fatalf("Unexpected error: %v", err) - } - - if len(obus) != 0 { - t.Fatalf("Expected empty OBU for fragmented OBU") - } + assert.NoError(t, err) + assert.Len(t, obus, 0, "Expected empty OBU for fragmented OBU") newOBU, expected := createAV1OBU(obu.OBUTileGroup, []byte{0x04, 0x05, 0x06}) obus, err = depacketizer.Unmarshal( @@ -441,13 +355,8 @@ func TestAV1Depacketizer_dropIfLostFragment(t *testing.T) { newOBU..., ), ) - if err != nil { - t.Fatalf("Unexpected error: %v", err) - } - - if !bytes.Equal(obus, expected) { - t.Fatalf("Expected OBU data to be %v, got %v", newOBU, obus) - } + assert.NoError(t, err) + assert.Equal(t, expected, obus) packet := make([]byte, 0) packet = append(packet, []byte{0b10000000}...) @@ -457,13 +366,8 @@ func TestAV1Depacketizer_dropIfLostFragment(t *testing.T) { packet = append(packet, newOBU...) obus, err = depacketizer.Unmarshal(packet) - if err != nil { - t.Fatalf("Unexpected error: %v", err) - } - - if !bytes.Equal(obus, expected) { - t.Fatalf("Expected OBU data to be %v, got %v", newOBU, obus) - } + assert.NoError(t, err) + assert.Equal(t, expected, obus) } func TestAV1Depacketizer_IsPartitionTail(t *testing.T) { @@ -471,37 +375,18 @@ func TestAV1Depacketizer_IsPartitionTail(t *testing.T) { buffer: []byte{1, 2}, } - if depacketizer.IsPartitionTail(false, []byte{1, 2}) { - t.Fatalf("Expected false") - } - - if !bytes.Equal(depacketizer.buffer, []byte{1, 2}) { - t.Fatalf("Buffer was modified") - } - - if !depacketizer.IsPartitionTail(true, []byte{1, 2}) { - t.Fatalf("Expected true") - } + assert.False(t, depacketizer.IsPartitionTail(false, []byte{1, 2})) + assert.Equal(t, depacketizer.buffer, []byte{1, 2}) + assert.True(t, depacketizer.IsPartitionTail(true, []byte{1, 2})) } func TestAV1Depacketizer_IsPartitionHead(t *testing.T) { depacketizer := &AV1Depacketizer{} - if depacketizer.IsPartitionHead(nil) { - t.Fatalf("Expected false") - } - - if depacketizer.IsPartitionHead([]byte{}) { - t.Fatalf("Expected false") - } - - if depacketizer.IsPartitionHead([]byte{0b11000000}) { - t.Fatalf("Expected false") - } - - if !depacketizer.IsPartitionHead([]byte{0b00000000}) { - t.Fatalf("Expected true") - } + assert.False(t, depacketizer.IsPartitionHead(nil)) + assert.False(t, depacketizer.IsPartitionHead([]byte{})) + assert.False(t, depacketizer.IsPartitionHead([]byte{0b11000000})) + assert.True(t, depacketizer.IsPartitionHead([]byte{0b00000000})) } func TestAV1Depacketizer_ignoreBadOBUs(t *testing.T) { @@ -521,13 +406,8 @@ func TestAV1Depacketizer_ignoreBadOBUs(t *testing.T) { depacketizer := AV1Depacketizer{} obu, err := depacketizer.Unmarshal(packet) - if err != nil { - t.Fatalf("Unexpected error: %v", err) - } - - if len(obu) != 0 { - t.Fatalf("Expected empty OBU for OBU type %d", obuType) - } + assert.NoError(t, err) + assert.Len(t, obu, 0, "Expected empty payload for OBU type %d", obuType) } } @@ -549,13 +429,8 @@ func TestAV1Depacketizer_fragmentedOverMultiple(t *testing.T) { packet = append(packet, obuf1...) obus, err := depacketizer.Unmarshal(packet) - if err != nil { - t.Fatalf("Unexpected error: %v", err) - } - - if len(obus) != 0 { - t.Fatalf("Expected empty OBU for fragmented OBU") - } + assert.NoError(t, err) + assert.Len(t, obus, 0, "Expected empty OBU for fragmented OBU") packet = make([]byte, 0) packet = append(packet, []byte{0b11000000}...) @@ -563,13 +438,8 @@ func TestAV1Depacketizer_fragmentedOverMultiple(t *testing.T) { packet = append(packet, obuf2...) obus, err = depacketizer.Unmarshal(packet) - if err != nil { - t.Fatalf("Unexpected error: %v", err) - } - - if len(obus) != 0 { - t.Fatalf("Expected empty OBU for fragmented OBU") - } + assert.NoError(t, err) + assert.Len(t, obus, 0, "Expected empty OBU for fragmented OBU") packet = make([]byte, 0) packet = append(packet, []byte{0b11000000}...) @@ -577,13 +447,8 @@ func TestAV1Depacketizer_fragmentedOverMultiple(t *testing.T) { packet = append(packet, obuf3...) obus, err = depacketizer.Unmarshal(packet) - if err != nil { - t.Fatalf("Unexpected error: %v", err) - } - - if len(obus) != 0 { - t.Fatalf("Expected empty OBU for fragmented OBU") - } + assert.NoError(t, err) + assert.Len(t, obus, 0, "Expected empty OBU for fragmented OBU") packet = make([]byte, 0) packet = append(packet, []byte{0b10000000}...) @@ -591,26 +456,16 @@ func TestAV1Depacketizer_fragmentedOverMultiple(t *testing.T) { packet = append(packet, obuf4...) obus, err = depacketizer.Unmarshal(packet) - if err != nil { - t.Fatalf("Unexpected error: %v", err) - } - - if !bytes.Equal(obus, expected) { - t.Fatalf("Expected OBU data to be %v, got %v", expected, obus) - } + assert.NoError(t, err) + assert.Equal(t, expected, obus) } func TestAV1Depacketizer_shortOBUHeader(t *testing.T) { d := AV1Depacketizer{} payload, err := d.Unmarshal([]byte{0x00, 0x01, 0x04}) - if err == nil { - t.Fatalf("Expected error") - } - - if len(payload) != 0 { - t.Fatalf("Expected empty payload") - } + assert.Error(t, err) + assert.Len(t, payload, 0, "Expected empty payload for short OBU header") } func TestAV1Depacketizer_aggregationHeader(t *testing.T) { @@ -667,25 +522,12 @@ func TestAV1Depacketizer_aggregationHeader(t *testing.T) { for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { payload, err := depacketizer.Unmarshal(tt.input) - if err != nil { - t.Fatalf("Unexpected error: %v", err) - } - - if !bytes.Equal(payload, tt.payload) { - t.Fatalf("Expected payload to be %v, got %v", tt.payload, payload) - } - - if depacketizer.Z != tt.Z { - t.Fatalf("Expected Z to be %v, got %v", tt.Z, depacketizer.Z) - } - - if depacketizer.Y != tt.Y { - t.Fatalf("Expected Y to be %v, got %v", tt.Y, depacketizer.Y) - } + assert.NoError(t, err) - if depacketizer.N != tt.N { - t.Fatalf("Expected N to be %v, got %v", tt.N, depacketizer.N) - } + assert.Equal(t, tt.payload, payload) + assert.Equal(t, tt.Z, depacketizer.Z) + assert.Equal(t, tt.Y, depacketizer.Y) + assert.Equal(t, tt.N, depacketizer.N) }) } } diff --git a/codecs/av1_packet_test.go b/codecs/av1_packet_test.go index 3b26d017..a4b1201d 100644 --- a/codecs/av1_packet_test.go +++ b/codecs/av1_packet_test.go @@ -4,13 +4,11 @@ package codecs import ( - "bytes" - "errors" "fmt" - "reflect" "testing" "github.com/pion/rtp/codecs/av1/obu" + "github.com/stretchr/testify/assert" ) type testAV1AggregationHeader struct { @@ -101,14 +99,10 @@ func testAV1TestRun(t *testing.T, tests []testAV1Tests) { t.Run(test.Name, func(t *testing.T) { result := payloader.Payload(test.MTU, test.InputPayload) - if len(result) != len(test.OutputPayloads) { - t.Fatalf("Expected %d payloads but got %d", len(test.OutputPayloads), len(result)) - } + assert.Equal(t, len(test.OutputPayloads), len(result)) for i := range result { - if !bytes.Equal(result[i], test.OutputPayloads[i]) { - t.Fatalf("Expected %v but got %v for payload #%d", test.OutputPayloads[i], result[i], i+1) - } + assert.Equal(t, test.OutputPayloads[i], result[i]) } }) } @@ -117,14 +111,10 @@ func testAV1TestRun(t *testing.T, tests []testAV1Tests) { func TestAV1Payloader_ShortMtU(t *testing.T) { p := &AV1Payloader{} - if result := p.Payload(0, []byte{0x00, 0x01, 0x18}); len(result) != 0 { - t.Errorf("Expected empty payload but got %v", result) - } - + assert.Len(t, p.Payload(0, []byte{0x00, 0x01, 0x18}), 0, "Expected empty payload") + assert.Len(t, p.Payload(1, []byte{0x00, 0x01, 0x18}), 0, "Expected empty payload") // 2 is the minimum MTU for AV1 (aggregate header + 1 byte) - if result := p.Payload(1, []byte{0x00, 0x01, 0x18}); len(result) != 0 { - t.Errorf("Expected empty payload but got %v", result) - } + assert.Greater(t, len(p.Payload(2, []byte{0x00, 0x01, 0x18})), 0) } func TestAV1Payloader_SinglePacket(t *testing.T) { @@ -1758,13 +1748,9 @@ func TestAV1Payloader_Leb128Size(t *testing.T) { for _, test := range tests { actual, edge := payloader.leb128Size(test.leb128) - if actual != test.size { - t.Fatalf("Expected size %d but got %d", test.size, actual) - } - if edge != test.edge { - t.Fatalf("Expected edge %t but got %t", test.edge, edge) - } + assert.Equal(t, test.size, actual) + assert.Equal(t, test.edge, edge) } } @@ -1821,24 +1807,20 @@ func TestAV1_depacketizer_to_packetizer(t *testing.T) { packets := payloader.Payload(mtu, payload) for _, packet := range packets { p, err := depacketizer.Unmarshal(packet) - if err != nil { - t.Fatalf("Failed to depacketize: %v", err) - } - - if len(packet) > int(mtu) { - t.Fatalf("Expected packet size to be %d but got %d", mtu, len(packet)) - } + assert.NoError(t, err) + assert.GreaterOrEqual(t, int(mtu), len(packet), "Expected packet size to be smaller or equal to %d", mtu) result = append(result, p...) } - if len(payload) != len(result) { - t.Fatalf("Expected to packetize and depacketize to be the same for MTU=%d", mtu) - } - - if !bytes.Equal(payload, result) { - t.Fatalf("Expected to packetize and depacketize to be the same for MTU=%d", mtu) - } + assert.Equalf( + t, + len(payload), + len(result), + "Expected to packetize and depacketize to be the same for MTU=%d", + mtu, + ) + assert.Equalf(t, payload, result, "Expected to packetize and depacketize to be the same for MTU=%d", mtu) }) } } @@ -1857,9 +1839,8 @@ func TestAV1_Unmarshal_Error(t *testing.T) { test := test av1Pkt := &AV1Packet{} - if _, err := av1Pkt.Unmarshal(test.input); !errors.Is(err, test.expectedError) { - t.Fatalf("Expected error(%s) but got (%s)", test.expectedError, err) - } + _, err := av1Pkt.Unmarshal(test.input) + assert.ErrorIs(t, err, test.expectedError) } } @@ -2038,11 +2019,10 @@ func TestAV1_Unmarshal(t *testing.T) { } av1Pkt := &AV1Packet{} - if _, err := av1Pkt.Unmarshal(av1Payload); err != nil { - t.Fatal(err) - } + _, err := av1Pkt.Unmarshal(av1Payload) + assert.NoError(t, err) - if !reflect.DeepEqual(av1Pkt, &AV1Packet{ + expect := &AV1Packet{ Z: false, Y: true, W: 2, @@ -2051,7 +2031,6 @@ func TestAV1_Unmarshal(t *testing.T) { av1Payload[2:14], av1Payload[14:], }, - }) { - t.Fatal("AV1 Unmarshal didn't store the expected results in the packet") } + assert.Equal(t, expect, av1Pkt, "AV1 Unmarshal didn't store the expected results in the packet") } diff --git a/codecs/common_test.go b/codecs/common_test.go index c472d2ab..bb009ba3 100644 --- a/codecs/common_test.go +++ b/codecs/common_test.go @@ -5,23 +5,19 @@ package codecs import ( "testing" + + "github.com/stretchr/testify/assert" ) func TestCommon_Min(t *testing.T) { res := minInt(1, -1) - if res != -1 { - t.Fatal("Error: -1 < 1") - } + assert.Equal(t, -1, res) res = minInt(1, 2) - if res != 1 { - t.Fatal("Error: 1 < 2") - } + assert.Equal(t, 1, res) res = minInt(3, 3) - if res != 3 { - t.Fatal("Error: 3 == 3") - } + assert.Equal(t, 3, res) } func TestZeroAllocations(t *testing.T) { //nolint:maintidx @@ -240,13 +236,9 @@ func TestZeroAllocations(t *testing.T) { //nolint:maintidx d.SetZeroAllocation(true) } _, err := test.packet.Unmarshal(test.data) - if err != nil { - t.Errorf("Unmarshal failed: %v", err) - } + assert.NoError(t, err) }) - if allocs != 0 { - t.Errorf("%T: %v allocs", test.packet, allocs) - } + assert.Equal(t, 0.0, allocs) } } diff --git a/codecs/g711_packet_test.go b/codecs/g711_packet_test.go index 20563274..3fee2c32 100644 --- a/codecs/g711_packet_test.go +++ b/codecs/g711_packet_test.go @@ -8,6 +8,8 @@ import ( "crypto/rand" "math" "testing" + + "github.com/stretchr/testify/assert" ) func TestG711Payloader(t *testing.T) { @@ -21,9 +23,7 @@ func TestG711Payloader(t *testing.T) { // generate random 8-bit g722 samples samples := make([]byte, testlen) _, err := rand.Read(samples) - if err != nil { - t.Fatal("RNG Error: ", err) - } + assert.NoError(t, err) // make a copy, for payloader input samplesIn := make([]byte, testlen) @@ -33,43 +33,27 @@ func TestG711Payloader(t *testing.T) { payloads := payloader.Payload(testmtu, samplesIn) outcnt := int(math.Ceil(float64(testlen) / testmtu)) - if len(payloads) != outcnt { - t.Fatalf("Generated %d payloads instead of %d", len(payloads), outcnt) - } - - if !bytes.Equal(samplesIn, samples) { - t.Fatal("Modified input samples") - } + assert.Len(t, payloads, outcnt) + assert.Equal(t, samplesIn, samples, "Modified input samples") samplesOut := bytes.Join(payloads, []byte{}) - - if !bytes.Equal(samplesIn, samplesOut) { - t.Fatal("Output samples don't match") - } + assert.Equal(t, samplesIn, samplesOut) payload := []byte{0x90, 0x90, 0x90} // 0 MTU, small payload res := payloader.Payload(0, payload) - if len(res) != 0 { - t.Fatal("Generated payload should be empty") - } + assert.Len(t, res, 0, "Generated payload should be empty") // Positive MTU, small payload res = payloader.Payload(1, payload) - if len(res) != len(payload) { - t.Fatal("Generated payload should be the same size as original payload size") - } + assert.Len(t, res, len(payload), "Generated payload should be the same size as original payload size") // Positive MTU, small payload res = payloader.Payload(uint16(len(payload)-1), payload) // nolint: gosec // G115 - if len(res) != len(payload)-1 { - t.Fatal("Generated payload should be the same smaller than original payload size") - } + assert.Len(t, res, len(payload)-1, "Generated payload should be the same smaller than original payload size") // Positive MTU, small payload res = payloader.Payload(10, payload) - if len(res) != 1 { - t.Fatal("Generated payload should be 1") - } + assert.Len(t, res, 1, "Generated payload should be the 1") } diff --git a/codecs/g722_packet_test.go b/codecs/g722_packet_test.go index d098e73a..8bcf791a 100644 --- a/codecs/g722_packet_test.go +++ b/codecs/g722_packet_test.go @@ -8,6 +8,8 @@ import ( "crypto/rand" "math" "testing" + + "github.com/stretchr/testify/assert" ) func TestG722Payloader(t *testing.T) { @@ -21,9 +23,7 @@ func TestG722Payloader(t *testing.T) { // generate random 8-bit g722 samples samples := make([]byte, testlen) _, err := rand.Read(samples) - if err != nil { - t.Fatal("RNG Error: ", err) - } + assert.NoError(t, err) // make a copy, for payloader input samplesIn := make([]byte, testlen) @@ -33,43 +33,27 @@ func TestG722Payloader(t *testing.T) { payloads := payloader.Payload(testmtu, samplesIn) outcnt := int(math.Ceil(float64(testlen) / testmtu)) - if len(payloads) != outcnt { - t.Fatalf("Generated %d payloads instead of %d", len(payloads), outcnt) - } - - if !bytes.Equal(samplesIn, samples) { - t.Fatal("Modified input samples") - } + assert.Len(t, payloads, outcnt) + assert.Equal(t, samplesIn, samples, "Modified input samples") samplesOut := bytes.Join(payloads, []byte{}) - - if !bytes.Equal(samplesIn, samplesOut) { - t.Fatal("Output samples don't match") - } + assert.Equal(t, samplesIn, samplesOut, "Output samples don't match") payload := []byte{0x90, 0x90, 0x90} // 0 MTU, small payload res := payloader.Payload(0, payload) - if len(res) != 0 { - t.Fatal("Generated payload should be empty") - } + assert.Len(t, res, 0, "Generated payload should be empty") // Positive MTU, small payload res = payloader.Payload(1, payload) - if len(res) != len(payload) { - t.Fatal("Generated payload should be the same size as original payload size") - } + assert.Len(t, res, len(payload), "Generated payload should be the same size as original payload size") // Positive MTU, small payload res = payloader.Payload(uint16(len(payload)-1), payload) // nolint: gosec // G115 - if len(res) != len(payload)-1 { - t.Fatal("Generated payload should be the same smaller than original payload size") - } + assert.Len(t, res, len(payload)-1, "Generated payload should be the same smaller than original payload size") // Positive MTU, small payload res = payloader.Payload(10, payload) - if len(res) != 1 { - t.Fatal("Generated payload should be 1") - } + assert.Len(t, res, 1, "Generated payload should be the 1") } diff --git a/codecs/h264_packet_test.go b/codecs/h264_packet_test.go index e83de9ef..64b09ae5 100644 --- a/codecs/h264_packet_test.go +++ b/codecs/h264_packet_test.go @@ -4,11 +4,12 @@ package codecs import ( - "reflect" "testing" + + "github.com/stretchr/testify/assert" ) -func TestH264Payloader_Payload(t *testing.T) { //nolint:cyclop +func TestH264Payloader_Payload(t *testing.T) { pck := H264Payloader{} smallpayload := []byte{0x90, 0x90, 0x90} multiplepayload := []byte{0x00, 0x00, 0x01, 0x90, 0x00, 0x00, 0x01, 0x90} @@ -33,85 +34,57 @@ func TestH264Payloader_Payload(t *testing.T) { //nolint:cyclop // Positive MTU, nil payload res := pck.Payload(1, nil) - if len(res) != 0 { - t.Fatal("Generated payload should be empty") - } + assert.Len(t, res, 0, "Generated payload should be empty") // Positive MTU, empty payload res = pck.Payload(1, []byte{}) - if len(res) != 0 { - t.Fatal("Generated payload should be empty") - } + assert.Len(t, res, 0, "Generated payload should be empty") // Positive MTU, empty NAL res = pck.Payload(1, []byte{0x00, 0x00, 0x01}) - if len(res) != 0 { - t.Fatal("Generated payload should be empty") - } + assert.Len(t, res, 0, "Generated payload should be empty") // Negative MTU, small payload res = pck.Payload(0, smallpayload) - if len(res) != 0 { - t.Fatal("Generated payload should be empty") - } + assert.Len(t, res, 0, "Generated payload should be empty") // 0 MTU, small payload res = pck.Payload(0, smallpayload) - if len(res) != 0 { - t.Fatal("Generated payload should be empty") - } + assert.Len(t, res, 0, "Generated payload should be empty") // Positive MTU, small payload res = pck.Payload(1, smallpayload) - if len(res) != 0 { - t.Fatal("Generated payload should be empty") - } + assert.Len(t, res, 0, "Generated payload should be empty") // Positive MTU, small payload res = pck.Payload(5, smallpayload) - if len(res) != 1 { - t.Fatal("Generated payload shouldn't be empty") - } - if len(res[0]) != len(smallpayload) { - t.Fatal("Generated payload should be the same size as original payload size") - } + assert.Len(t, res, 1, "Generated payload should be the 1") + assert.Len(t, smallpayload, len(res[0]), "Generated payload should be the same size as original payload size") // Multiple NALU in a single payload res = pck.Payload(5, multiplepayload) - if len(res) != 2 { - t.Fatal("2 nal units should be broken out") - } + assert.Len(t, res, 2, "2 nal units should be broken out") for i := 0; i < 2; i++ { - if len(res[i]) != 1 { - t.Fatalf("Payload %d of 2 is packed incorrectly", i+1) - } + assert.Lenf(t, res[i], 1, "Payload %d of 2 is packed incorrectly", i+1) } // Multiple NALU in a single payload with 3-byte and 4-byte start sequences res = pck.Payload(5, mixednalupayload) - if len(res) != 4 { - t.Fatal("4 nal units should be broken out", len(res), res) - } + assert.Len(t, res, 4, "4 nal units should be broken out") for i := 0; i < 4; i++ { - if len(res[i]) != 1 { - t.Fatalf("Payload %d of 4 is packed incorrectly: %v", i+1, res[i]) - } + assert.Lenf(t, res[i], 1, "Payload %d of 4 is packed incorrectly", i+1) } // Large Payload split across multiple RTP Packets res = pck.Payload(5, largepayload) - if !reflect.DeepEqual(res, largePayloadPacketized) { - t.Fatal("FU-A packetization failed") - } + assert.Equal(t, largePayloadPacketized, res, "FU-A packetization failed") // Nalu type 9 or 12 res = pck.Payload(5, []byte{0x09, 0x00, 0x00}) - if len(res) != 0 { - t.Fatal("Generated payload should be empty") - } + assert.Len(t, res, 0, "Generated payload should be empty") } -func TestH264Packet_Unmarshal(t *testing.T) { //nolint:cyclop +func TestH264Packet_Unmarshal(t *testing.T) { singlePayload := []byte{0x90, 0x90, 0x90} singlePayloadUnmarshaled := []byte{0x00, 0x00, 0x00, 0x01, 0x90, 0x90, 0x90} singlePayloadUnmarshaledAVC := []byte{0x00, 0x00, 0x00, 0x03, 0x90, 0x90, 0x90} @@ -164,138 +137,88 @@ func TestH264Packet_Unmarshal(t *testing.T) { //nolint:cyclop pkt := H264Packet{} avcPkt := H264Packet{IsAVC: true} - if _, err := pkt.Unmarshal(nil); err == nil { - t.Fatal("Unmarshal did not fail on nil payload") - } + _, err := pkt.Unmarshal(nil) + assert.Error(t, err, "Unmarshal did not fail on nil payload") - if _, err := pkt.Unmarshal([]byte{}); err == nil { - t.Fatal("Unmarshal did not fail on []byte{}") - } + _, err = pkt.Unmarshal([]byte{}) + assert.Error(t, err, "Unmarshal did not fail on []byte{}") - if _, err := pkt.Unmarshal([]byte{0xFC}); err == nil { - t.Fatal("Unmarshal accepted a FU-A packet that is too small for a payload and header") - } + _, err = pkt.Unmarshal([]byte{0xFC}) + assert.Error(t, err, "Unmarshal accepted a FU-A packet that is too small for a payload and header") - if _, err := pkt.Unmarshal([]byte{0x0A}); err != nil { - t.Fatal("Unmarshaling end of sequence(NALU Type : 10) should succeed") - } + _, err = pkt.Unmarshal([]byte{0x0A}) + assert.NoError(t, err, "Unmarshaling end of sequence(NALU Type : 10) should succeed") - if _, err := pkt.Unmarshal([]byte{0xFF, 0x00, 0x00}); err == nil { - t.Fatal("Unmarshal accepted a packet with a NALU Type we don't handle") - } + _, err = pkt.Unmarshal([]byte{0xFF, 0x00, 0x00}) + assert.Error(t, err, "Unmarshal accepted a packet with a NALU Type we don't handle") - if _, err := pkt.Unmarshal(incompleteSinglePayloadMultiNALU); err == nil { - t.Fatal("Unmarshal accepted a STAP-A packet with insufficient data") - } + _, err = pkt.Unmarshal(incompleteSinglePayloadMultiNALU) + assert.Error(t, err, "Unmarshal accepted a STAP-A packet with insufficient data") res, err := pkt.Unmarshal(singlePayload) - if err != nil { - t.Fatal(err) - } else if !reflect.DeepEqual(res, singlePayloadUnmarshaled) { - t.Fatal("Unmarshaling a single payload shouldn't modify the payload") - } + assert.NoError(t, err) + assert.Equal(t, singlePayloadUnmarshaled, res) res, err = avcPkt.Unmarshal(singlePayload) - if err != nil { - t.Fatal(err) - } else if !reflect.DeepEqual(res, singlePayloadUnmarshaledAVC) { - t.Fatal("Unmarshaling a single payload into avc stream shouldn't modify the payload") - } + assert.NoError(t, err) + assert.Equal(t, singlePayloadUnmarshaledAVC, res) largePayloadResult := []byte{} for i := range largePayloadPacketized { res, err = pkt.Unmarshal(largePayloadPacketized[i]) - if err != nil { - t.Fatal(err) - } + assert.NoError(t, err) largePayloadResult = append(largePayloadResult, res...) } - if !reflect.DeepEqual(largePayloadResult, largepayload) { - t.Fatal("Failed to unmarshal a large payload") - } + assert.Equal(t, largepayload, largePayloadResult) largePayloadResultAVC := []byte{} for i := range largePayloadPacketized { res, err = avcPkt.Unmarshal(largePayloadPacketized[i]) - if err != nil { - t.Fatal(err) - } + assert.NoError(t, err) largePayloadResultAVC = append(largePayloadResultAVC, res...) } - if !reflect.DeepEqual(largePayloadResultAVC, largepayloadAVC) { - t.Fatal("Failed to unmarshal a large payload into avc stream") - } + assert.Equal(t, largepayloadAVC, largePayloadResultAVC) res, err = pkt.Unmarshal(singlePayloadMultiNALU) - if err != nil { - t.Fatal(err) - } else if !reflect.DeepEqual(res, singlePayloadMultiNALUUnmarshaled) { - t.Fatal("Failed to unmarshal a single packet with multiple NALUs") - } + assert.NoError(t, err) + assert.Equal(t, singlePayloadMultiNALUUnmarshaled, res) res, err = avcPkt.Unmarshal(singlePayloadMultiNALU) - if err != nil { - t.Fatal(err) - } else if !reflect.DeepEqual(res, singlePayloadMultiNALUUnmarshaledAVC) { - t.Fatal("Failed to unmarshal a single packet with multiple NALUs into avc stream") - } + assert.NoError(t, err) + assert.Equal(t, singlePayloadMultiNALUUnmarshaledAVC, res) res, err = pkt.Unmarshal(singlePayloadWithBrokenSecondNALU) - if err != nil { - t.Fatal(err) - } else if !reflect.DeepEqual(res, singlePayloadWithBrokenSecondNALUUnmarshaled) { - t.Fatal("Failed to unmarshal a single packet with broken second NALUs") - } + assert.NoError(t, err) + assert.Equal(t, singlePayloadWithBrokenSecondNALUUnmarshaled, res) res, err = avcPkt.Unmarshal(singlePayloadWithBrokenSecondNALU) - if err != nil { - t.Fatal(err) - } else if !reflect.DeepEqual(res, singlePayloadWithBrokenSecondUnmarshaledAVC) { - t.Fatal("Failed to unmarshal a single packet with broken second NALUs into avc stream") - } + assert.NoError(t, err) + assert.Equal(t, singlePayloadWithBrokenSecondUnmarshaledAVC, res) } func TestH264IsPartitionHead(t *testing.T) { h264 := H264Packet{} - if h264.IsPartitionHead(nil) { - t.Fatal("nil must not be a partition head") - } - - emptyNalu := []byte{} - if h264.IsPartitionHead(emptyNalu) { - t.Fatal("empty nalu must not be a partition head") - } + assert.False(t, h264.IsPartitionHead(nil), "nil must not be a partition head") + assert.False(t, h264.IsPartitionHead([]byte{}), "empty nalu must not be a partition head") singleNalu := []byte{1, 0} - if h264.IsPartitionHead(singleNalu) == false { - t.Fatal("single nalu must be a partition head") - } + assert.True(t, h264.IsPartitionHead(singleNalu), "single nalu must be a partition head") stapaNalu := []byte{stapaNALUType, 0} - if h264.IsPartitionHead(stapaNalu) == false { - t.Fatal("stapa nalu must be a partition head") - } + assert.True(t, h264.IsPartitionHead(stapaNalu), "stapa nalu must be a partition head") fuaStartNalu := []byte{fuaNALUType, fuStartBitmask} - if h264.IsPartitionHead(fuaStartNalu) == false { - t.Fatal("fua start nalu must be a partition head") - } + assert.True(t, h264.IsPartitionHead(fuaStartNalu), "fua start nalu must be a partition head") fuaEndNalu := []byte{fuaNALUType, fuEndBitmask} - if h264.IsPartitionHead(fuaEndNalu) { - t.Fatal("fua end nalu must not be a partition head") - } + assert.False(t, h264.IsPartitionHead(fuaEndNalu), "fua end nalu must not be a partition head") fubStartNalu := []byte{fubNALUType, fuStartBitmask} - if h264.IsPartitionHead(fubStartNalu) == false { - t.Fatal("fub start nalu must be a partition head") - } + assert.True(t, h264.IsPartitionHead(fubStartNalu), "fub start nalu must be a partition head") fubEndNalu := []byte{fubNALUType, fuEndBitmask} - if h264.IsPartitionHead(fubEndNalu) { - t.Fatal("fub end nalu must not be a partition head") - } + assert.False(t, h264.IsPartitionHead(fubEndNalu), "fub end nalu must not be a partition head") } func TestH264Payloader_Payload_SPS_and_PPS_handling(t *testing.T) { @@ -307,18 +230,11 @@ func TestH264Payloader_Payload_SPS_and_PPS_handling(t *testing.T) { // When packetizing SPS and PPS are emitted with following NALU res := pck.Payload(1500, []byte{0x07, 0x00, 0x01}) - if len(res) != 0 { - t.Fatal("Generated payload should be empty") - } + assert.Len(t, res, 0, "Generated payload should be empty") res = pck.Payload(1500, []byte{0x08, 0x02, 0x03}) - if len(res) != 0 { - t.Fatal("Generated payload should be empty") - } - - if !reflect.DeepEqual(pck.Payload(1500, []byte{0x05, 0x04, 0x05}), expected) { - t.Fatal("SPS and PPS aren't packed together") - } + assert.Len(t, res, 0, "Generated payload should be empty") + assert.Equal(t, expected, pck.Payload(1500, []byte{0x05, 0x04, 0x05}), "SPS and PPS aren't packed together") } func TestH264Payloader_Payload_SPS_and_PPS_handling_no_stapA(t *testing.T) { @@ -328,20 +244,11 @@ func TestH264Payloader_Payload_SPS_and_PPS_handling_no_stapA(t *testing.T) { expectedSps := []byte{0x07, 0x00, 0x01} // The SPS is packed as a single NALU res := pck.Payload(1500, expectedSps) - if len(res) != 1 { - t.Fatal("Generated payload should not be empty") - } - if !reflect.DeepEqual(res[0], expectedSps) { - t.Fatal("SPS has not been packed correctly") - } + assert.Len(t, res, 1, "Generated payload should not be empty") + assert.Equal(t, expectedSps, res[0], "SPS has not been packed correctly") // The PPS is packed as a single NALU expectedPps := []byte{0x08, 0x02, 0x03} res = pck.Payload(1500, expectedPps) - if len(res) != 1 { - t.Fatal("Generated payload should not be empty") - } - - if !reflect.DeepEqual(res[0], expectedPps) { - t.Fatal("PPS has not been packed correctly") - } + assert.Len(t, res, 1, "Generated payload should not be empty") + assert.Equal(t, expectedPps, res[0], "PPS has not been packed correctly") } diff --git a/codecs/h265_packet_test.go b/codecs/h265_packet_test.go index dff6a84f..cadf7f6a 100644 --- a/codecs/h265_packet_test.go +++ b/codecs/h265_packet_test.go @@ -4,8 +4,9 @@ package codecs import ( - "reflect" "testing" + + "github.com/stretchr/testify/assert" ) func TestH265_NALU_Header(t *testing.T) { @@ -70,38 +71,15 @@ func TestH265_NALU_Header(t *testing.T) { for _, cur := range tt { header := newH265NALUHeader(cur.RawHeader[0], cur.RawHeader[1]) - if header.F() != cur.FBit { - t.Fatal("invalid F bit") - } - - if header.Type() != cur.Type { - t.Fatal("invalid Type") - } - + assert.Equal(t, cur.FBit, header.F()) + assert.Equal(t, cur.Type, header.Type()) // For any type < 32, NAL is a VLC NAL unit. - if header.IsTypeVCLUnit() != (header.Type() < 32) { - t.Fatal("invalid IsTypeVCLUnit") - } - - if header.IsAggregationPacket() != cur.IsAP { - t.Fatal("invalid Type (aggregation packet)") - } - - if header.IsFragmentationUnit() != cur.IsFU { - t.Fatal("invalid Type (fragmentation unit)") - } - - if header.IsPACIPacket() != cur.IsPACI { - t.Fatal("invalid Type (PACI)") - } - - if header.LayerID() != cur.LayerID { - t.Fatal("invalid LayerID") - } - - if header.TID() != cur.TID { - t.Fatal("invalid TID") - } + assert.Equal(t, header.IsTypeVCLUnit(), header.Type() < 32) + assert.Equal(t, cur.IsAP, header.IsAggregationPacket()) + assert.Equal(t, cur.IsFU, header.IsFragmentationUnit()) + assert.Equal(t, cur.IsPACI, header.IsPACIPacket()) + assert.Equal(t, cur.LayerID, header.LayerID()) + assert.Equal(t, cur.TID, header.TID()) } } @@ -158,21 +136,13 @@ func TestH265_FU_Header(t *testing.T) { } for _, cur := range tt { - if cur.header.S() != cur.S { - t.Fatal("invalid S field") - } - - if cur.header.E() != cur.E { - t.Fatal("invalid E field") - } - - if cur.header.FuType() != cur.Type { - t.Fatal("invalid FuType field") - } + assert.Equal(t, cur.S, cur.header.S()) + assert.Equal(t, cur.E, cur.header.E()) + assert.Equal(t, cur.header.FuType(), cur.Type) } } -func TestH265_SingleNALUnitPacket(t *testing.T) { //nolint:cyclop +func TestH265_SingleNALUnitPacket(t *testing.T) { tt := [...]struct { Raw []byte WithDONL bool @@ -193,7 +163,7 @@ func TestH265_SingleNALUnitPacket(t *testing.T) { //nolint:cyclop }, { Raw: []byte{0x62, 0x01, 0x93}, - ExpectedErr: errShortPacket, + ExpectedErr: errInvalidH265PacketType, }, // FBit enabled in H265NALUHeader { @@ -245,31 +215,29 @@ func TestH265_SingleNALUnitPacket(t *testing.T) { //nolint:cyclop _, err := parsed.Unmarshal(cur.Raw) - if cur.ExpectedErr != nil && err == nil { - t.Fatal("should error") - } else if cur.ExpectedErr == nil && err != nil { - t.Fatal("should not error") + if cur.ExpectedErr == nil { + assert.NoError(t, err) + } else { + assert.ErrorIs(t, err, cur.ExpectedErr) } if cur.ExpectedPacket == nil { continue } - if cur.ExpectedPacket.PayloadHeader() != parsed.PayloadHeader() { - t.Fatal("invalid payload header") - } + assert.Equal(t, cur.ExpectedPacket.PayloadHeader(), parsed.PayloadHeader()) - if cur.ExpectedPacket.DONL() != nil && (*parsed.DONL() != *cur.ExpectedPacket.DONL()) { - t.Fatal("invalid DONL") + if cur.ExpectedPacket.DONL() != nil { + assert.Equal(t, *cur.ExpectedPacket.DONL(), *parsed.DONL()) + } else { + assert.Nil(t, parsed.DONL()) } - if !reflect.DeepEqual(cur.ExpectedPacket.Payload(), parsed.Payload()) { - t.Fatal("invalid payload") - } + assert.Equal(t, cur.ExpectedPacket.Payload(), parsed.Payload()) } } -func TestH265_AggregationPacket(t *testing.T) { //nolint:cyclop +func TestH265_AggregationPacket(t *testing.T) { tt := [...]struct { Raw []byte WithDONL bool @@ -290,7 +258,7 @@ func TestH265_AggregationPacket(t *testing.T) { //nolint:cyclop }, { Raw: []byte{0x62, 0x01, 0x93}, - ExpectedErr: errShortPacket, + ExpectedErr: errInvalidH265PacketType, }, // FBit enabled in H265NALUHeader { @@ -300,7 +268,7 @@ func TestH265_AggregationPacket(t *testing.T) { //nolint:cyclop // Type '48' in H265NALUHeader { Raw: []byte{0xE0, 0x01, 0x93, 0xaf, 0xaf, 0xaf, 0xaf}, - ExpectedErr: errInvalidH265PacketType, + ExpectedErr: errH265CorruptedPacket, }, // Small payload { @@ -387,10 +355,10 @@ func TestH265_AggregationPacket(t *testing.T) { //nolint:cyclop _, err := parsed.Unmarshal(cur.Raw) - if cur.ExpectedErr != nil && err == nil { - t.Fatal("should error") - } else if cur.ExpectedErr == nil && err != nil { - t.Fatal("should not error") + if cur.ExpectedErr == nil { + assert.NoError(t, err) + } else { + assert.ErrorIs(t, err, cur.ExpectedErr) } if cur.ExpectedPacket == nil { @@ -398,45 +366,38 @@ func TestH265_AggregationPacket(t *testing.T) { //nolint:cyclop } if cur.ExpectedPacket.FirstUnit() != nil { - if parsed.FirstUnit().NALUSize() != cur.ExpectedPacket.FirstUnit().NALUSize() { - t.Fatal("invalid first unit NALUSize") - } + assert.Equal(t, cur.ExpectedPacket.FirstUnit().NALUSize(), parsed.FirstUnit().NALUSize()) - if cur.ExpectedPacket.FirstUnit().DONL() != nil && - *cur.ExpectedPacket.FirstUnit().DONL() != *parsed.FirstUnit().DONL() { - t.Fatal("invalid first unit DONL") + if cur.ExpectedPacket.FirstUnit().DONL() != nil { + assert.Equal(t, *cur.ExpectedPacket.FirstUnit().DONL(), *parsed.FirstUnit().DONL()) + } else { + assert.Nil(t, parsed.FirstUnit().DONL()) } - if !reflect.DeepEqual(cur.ExpectedPacket.FirstUnit().NalUnit(), parsed.FirstUnit().NalUnit()) { - t.Fatal("invalid first unit NalUnit") - } + assert.Equal( + t, cur.ExpectedPacket.FirstUnit().NalUnit(), parsed.FirstUnit().NalUnit(), + ) } - if len(cur.ExpectedPacket.OtherUnits()) != len(parsed.OtherUnits()) { - t.Fatal("number of other units mismatch") - } + assert.Len(t, cur.ExpectedPacket.OtherUnits(), len(parsed.OtherUnits())) for ndx, unit := range cur.ExpectedPacket.OtherUnits() { - if parsed.OtherUnits()[ndx].NALUSize() != unit.NALUSize() { - t.Fatal("invalid unit NALUSize") - } + assert.Equal(t, unit.NALUSize(), parsed.OtherUnits()[ndx].NALUSize()) - if unit.DOND() != nil && *unit.DOND() != *parsed.OtherUnits()[ndx].DOND() { - t.Fatal("invalid unit DOND") + if unit.DOND() != nil { + assert.Equal(t, *unit.DOND(), *parsed.OtherUnits()[ndx].DOND()) + } else { + assert.Nil(t, parsed.OtherUnits()[ndx].DOND()) } - if !reflect.DeepEqual(unit.NalUnit(), parsed.OtherUnits()[ndx].NalUnit()) { - t.Fatal("invalid first unit NalUnit") - } + assert.Equal(t, unit.NalUnit(), parsed.OtherUnits()[ndx].NalUnit()) } - if !reflect.DeepEqual(cur.ExpectedPacket.OtherUnits(), parsed.OtherUnits()) { - t.Fatal("invalid payload") - } + assert.Equal(t, cur.ExpectedPacket.OtherUnits(), parsed.OtherUnits()) } } -func TestH265_FragmentationUnitPacket(t *testing.T) { //nolint:cyclop +func TestH265_FragmentationUnitPacket(t *testing.T) { tt := [...]struct { Raw []byte WithDONL bool @@ -509,32 +470,26 @@ func TestH265_FragmentationUnitPacket(t *testing.T) { //nolint:cyclop parsed.isH265Packet() _, err := parsed.Unmarshal(cur.Raw) - - if cur.ExpectedErr != nil && err == nil { - t.Fatal("should error") - } else if cur.ExpectedErr == nil && err != nil { - t.Fatal("should not error") + if cur.ExpectedErr != nil { + assert.ErrorIs(t, err, cur.ExpectedErr) + } else { + assert.NoError(t, err) } if cur.ExpectedFU == nil { continue } - if parsed.PayloadHeader() != cur.ExpectedFU.PayloadHeader() { - t.Fatal("invalid payload header") - } + assert.Equal(t, cur.ExpectedFU.PayloadHeader(), parsed.PayloadHeader()) + assert.Equal(t, cur.ExpectedFU.FuHeader(), parsed.FuHeader()) - if parsed.FuHeader() != cur.ExpectedFU.FuHeader() { - t.Fatal("invalid FU header") + if cur.ExpectedFU.DONL() != nil { + assert.Equal(t, *cur.ExpectedFU.DONL(), *parsed.DONL()) + } else { + assert.Nil(t, parsed.DONL()) } - if cur.ExpectedFU.DONL() != nil && (*parsed.DONL() != *cur.ExpectedFU.DONL()) { - t.Fatal("invalid DONL") - } - - if !reflect.DeepEqual(parsed.Payload(), cur.ExpectedFU.Payload()) { - t.Fatal("invalid Payload") - } + assert.Equal(t, cur.ExpectedFU.Payload(), parsed.Payload()) } } @@ -573,29 +528,15 @@ func TestH265_TemporalScalabilityControlInformation(t *testing.T) { } for _, cur := range tt { - if cur.Value.TL0PICIDX() != cur.ExpectedTL0PICIDX { - t.Fatal("invalid TL0PICIDX") - } - - if cur.Value.IrapPicID() != cur.ExpectedIrapPicID { - t.Fatal("invalid IrapPicID") - } - - if cur.Value.S() != cur.ExpectedS { - t.Fatal("invalid S") - } - - if cur.Value.E() != cur.ExpectedE { - t.Fatal("invalid E") - } - - if cur.Value.RES() != cur.ExpectedRES { - t.Fatal("invalid RES") - } + assert.Equal(t, cur.ExpectedTL0PICIDX, cur.Value.TL0PICIDX()) + assert.Equal(t, cur.ExpectedIrapPicID, cur.Value.IrapPicID()) + assert.Equal(t, cur.ExpectedS, cur.Value.S()) + assert.Equal(t, cur.ExpectedE, cur.Value.E()) + assert.Equal(t, cur.ExpectedRES, cur.Value.RES()) } } -func TestH265_PACI_Packet(t *testing.T) { //nolint:cyclop +func TestH265_PACI_Packet(t *testing.T) { tt := [...]struct { Raw []byte ExpectedFU *H265PACIPacket @@ -626,7 +567,7 @@ func TestH265_PACI_Packet(t *testing.T) { //nolint:cyclop // Invalid header extension size { Raw: []byte{0x64, 0x01, 0x93, 0xaf, 0xaf, 0xaf, 0xaf}, - ExpectedErr: errInvalidH265PacketType, + ExpectedErr: errShortPacket, }, // No Header Extension { @@ -667,58 +608,30 @@ func TestH265_PACI_Packet(t *testing.T) { //nolint:cyclop // Just for code coverage sake parsed.isH265Packet() - if cur.ExpectedErr != nil && err == nil { - t.Fatal("should error") - } else if cur.ExpectedErr == nil && err != nil { - t.Fatal("should not error") + if cur.ExpectedErr != nil { + assert.ErrorIs(t, err, cur.ExpectedErr) + } else { + assert.NoError(t, err) } if cur.ExpectedFU == nil { continue } - if cur.ExpectedFU.PayloadHeader() != parsed.PayloadHeader() { - t.Fatal("invalid PayloadHeader") - } - - if cur.ExpectedFU.A() != parsed.A() { - t.Fatal("invalid A") - } - - if cur.ExpectedFU.CType() != parsed.CType() { - t.Fatal("invalid CType") - } - - if cur.ExpectedFU.PHSsize() != parsed.PHSsize() { - t.Fatal("invalid PHSsize") - } - - if cur.ExpectedFU.F0() != parsed.F0() { - t.Fatal("invalid F0") - } - - if cur.ExpectedFU.F1() != parsed.F1() { - t.Fatal("invalid F1") - } - - if cur.ExpectedFU.F2() != parsed.F2() { - t.Fatal("invalid F2") - } - - if cur.ExpectedFU.Y() != parsed.Y() { - t.Fatal("invalid Y") - } - - if !reflect.DeepEqual(cur.ExpectedFU.PHES(), parsed.PHES()) { - t.Fatal("invalid PHES") - } - - if !reflect.DeepEqual(cur.ExpectedFU.Payload(), parsed.Payload()) { - t.Fatal("invalid Payload") - } - - if cur.ExpectedFU.TSCI() != nil && (*cur.ExpectedFU.TSCI() != *parsed.TSCI()) { - t.Fatal("invalid TSCI") + assert.Equal(t, cur.ExpectedFU.PayloadHeader(), parsed.PayloadHeader()) + assert.Equal(t, cur.ExpectedFU.A(), parsed.A()) + assert.Equal(t, cur.ExpectedFU.CType(), parsed.CType()) + assert.Equal(t, cur.ExpectedFU.PHSsize(), parsed.PHSsize()) + assert.Equal(t, cur.ExpectedFU.F0(), parsed.F0()) + assert.Equal(t, cur.ExpectedFU.F1(), parsed.F1()) + assert.Equal(t, cur.ExpectedFU.F2(), parsed.F2()) + assert.Equal(t, cur.ExpectedFU.Y(), parsed.Y()) + assert.Equal(t, cur.ExpectedFU.PHES(), parsed.PHES()) + assert.Equal(t, cur.ExpectedFU.Payload(), parsed.Payload()) + if cur.ExpectedFU.TSCI() != nil { + assert.Equal(t, cur.ExpectedFU.TSCI(), parsed.TSCI()) + } else { + assert.Nil(t, parsed.TSCI()) } } } @@ -727,7 +640,7 @@ func TestH265_Packet(t *testing.T) { tt := [...]struct { Raw []byte WithDONL bool - ExpectedPacketType reflect.Type + ExpectedPacketType interface{} ExpectedErr error }{ { @@ -759,7 +672,7 @@ func TestH265_Packet(t *testing.T) { // Valid H265SingleNALUnitPacket { Raw: []byte{0x01, 0x01, 0xab, 0xcd, 0xef}, - ExpectedPacketType: reflect.TypeOf((*H265SingleNALUnitPacket)(nil)), + ExpectedPacketType: &H265SingleNALUnitPacket{}, }, // Invalid H265SingleNALUnitPacket { @@ -770,18 +683,18 @@ func TestH265_Packet(t *testing.T) { // Valid H265PACIPacket { Raw: []byte{0x64, 0x01, 0x64, 0b00111000, 0xaa, 0xbb, 0x80, 0xab, 0xcd, 0xef}, - ExpectedPacketType: reflect.TypeOf((*H265PACIPacket)(nil)), + ExpectedPacketType: &H265PACIPacket{}, }, // Valid H265FragmentationUnitPacket { Raw: []byte{0x62, 0x01, 0x93, 0xcc, 0xdd, 0xaf, 0x0d, 0x5a}, - ExpectedPacketType: reflect.TypeOf((*H265FragmentationUnitPacket)(nil)), + ExpectedPacketType: &H265FragmentationUnitPacket{}, WithDONL: true, }, // Valid H265AggregationPacket { Raw: []byte{0x60, 0x01, 0xcc, 0xdd, 0x00, 0x02, 0xff, 0xee, 0x77, 0x00, 0x01, 0xaa}, - ExpectedPacketType: reflect.TypeOf((*H265AggregationPacket)(nil)), + ExpectedPacketType: &H265AggregationPacket{}, WithDONL: true, }, // Invalid H265AggregationPacket @@ -799,54 +712,37 @@ func TestH265_Packet(t *testing.T) { } _, err := pck.Unmarshal(cur.Raw) - - if cur.ExpectedErr != nil && err == nil { - t.Fatal("should error") - } else if cur.ExpectedErr == nil && err != nil { - t.Fatal("should not error") + if cur.ExpectedErr == nil { + assert.NoError(t, err) + } else { + assert.ErrorIs(t, err, cur.ExpectedErr) } if cur.ExpectedErr != nil { continue } - if reflect.TypeOf(pck.Packet()) != cur.ExpectedPacketType { - t.Fatal("invalid packet type") - } + assert.IsType(t, cur.ExpectedPacketType, pck.Packet()) } } func TestH265IsPartitionHead(t *testing.T) { h265 := H265Packet{} - if h265.IsPartitionHead(nil) { - t.Fatal("nil must not be a partition head") - } - - emptyNalu := []byte{} - if h265.IsPartitionHead(emptyNalu) { - t.Fatal("empty nalu must not be a partition head") - } + assert.False(t, h265.IsPartitionHead(nil), "nil must not be a partition head") + assert.False(t, h265.IsPartitionHead([]byte{}), "empty nalu must not be a partition head") singleNalu := []byte{0x01, 0x01, 0xab, 0xcd, 0xef} - if h265.IsPartitionHead(singleNalu) == false { - t.Fatal("single nalu must be a partition head") - } + assert.True(t, h265.IsPartitionHead(singleNalu), "single nalu must be a partition head") fbitNalu := []byte{0x80, 0x00, 0x00} - if h265.IsPartitionHead(fbitNalu) == false { - t.Fatal("fbit nalu must be a partition head") - } + assert.True(t, h265.IsPartitionHead(fbitNalu), "fbit nalu must be a partition head") fuStartNalu := []byte{0x62, 0x01, 0x93} - if h265.IsPartitionHead(fuStartNalu) == false { - t.Fatal("fu start nalu must be a partition head") - } + assert.True(t, h265.IsPartitionHead(fuStartNalu), "fu start nalu must be a partition head") fuEndNalu := []byte{0x62, 0x01, 0x53} - if h265.IsPartitionHead(fuEndNalu) { - t.Fatal("fu end nalu must not be a partition head") - } + assert.False(t, h265.IsPartitionHead(fuEndNalu), "fu end nalu must not be a partition head") } func TestH265_Packet_Real(t *testing.T) { @@ -867,9 +763,7 @@ func TestH265_Packet_Real(t *testing.T) { for _, cur := range tt { pck := &H265Packet{} _, err := pck.Unmarshal([]byte(cur)) - if err != nil { - t.Fatal("invalid packet type") - } + assert.NoError(t, err) } } @@ -1125,13 +1019,9 @@ func TestH265Payloader_Payload(t *testing.T) { pck := H265Payloader{AddDONL: cur.AddDONL, SkipAggregation: cur.SkipAggregation} res := pck.Payload(cur.MTU, cur.Data) if cur.ExpectedData != nil { - if !reflect.DeepEqual(res, *cur.ExpectedData) { - t.Fatal(cur.Msg) - } + assert.Equal(t, *cur.ExpectedData, res) } else { - if len(res) != cur.ExpectedLen { - t.Fatal(cur.Msg) - } + assert.Len(t, res, cur.ExpectedLen) } }) } @@ -1151,12 +1041,10 @@ func TestH265Payloader_Real(t *testing.T) { } pck := H265Payloader{} res := pck.Payload(1400, payload) - if len(res) != 3 { - // 1. Aggregating three NALUs into a single payload - // 2. Fragmented packets divided by MTU=1400 - // 3. Remaining fragment packets split by MTU - t.Fatal("Generated payload should be 3") - } + // 1. Aggregating three NALUs into a single payload + // 2. Fragmented packets divided by MTU=1400 + // 3. Remaining fragment packets split by MTU + assert.Len(t, res, 3, "Generated payload should be 3") } func uint8ptr(v uint8) *uint8 { diff --git a/codecs/opus_packet_test.go b/codecs/opus_packet_test.go index 665d17b5..45dd473a 100644 --- a/codecs/opus_packet_test.go +++ b/codecs/opus_packet_test.go @@ -4,8 +4,9 @@ package codecs import ( - "errors" "testing" + + "github.com/stretchr/testify/assert" ) func TestOpusPacket_Unmarshal(t *testing.T) { @@ -13,30 +14,18 @@ func TestOpusPacket_Unmarshal(t *testing.T) { // Nil packet raw, err := pck.Unmarshal(nil) - if raw != nil { - t.Fatal("Result should be nil in case of error") - } - if err == nil || err.Error() != errNilPacket.Error() { - t.Fatal("Error should be:", errNilPacket) - } + assert.ErrorIs(t, err, errNilPacket) + assert.Nil(t, raw, "Result should be nil in case of error") // Empty packet raw, err = pck.Unmarshal([]byte{}) - if raw != nil { - t.Fatal("Result should be nil in case of error") - } - if !errors.Is(err, errShortPacket) { - t.Fatal("Error should be:", errShortPacket) - } + assert.ErrorIs(t, err, errShortPacket) + assert.Nil(t, raw, "Result should be nil in case of error") // Normal packet raw, err = pck.Unmarshal([]byte{0x00, 0x11, 0x22, 0x33, 0x44, 0x55, 0x66, 0x90}) - if raw == nil { - t.Fatal("Result shouldn't be nil in case of success") - } - if err != nil { - t.Fatal("Error should be nil in case of success") - } + assert.NoError(t, err) + assert.NotNil(t, raw, "Result shouldn't be nil in case of success") } func TestOpusPayloader_Payload(t *testing.T) { @@ -45,28 +34,23 @@ func TestOpusPayloader_Payload(t *testing.T) { // Positive MTU, nil payload res := pck.Payload(1, nil) - if len(res) != 0 { - t.Fatal("Generated payload should be empty") - } + assert.Len(t, res, 0, "Generated payload should be empty") // Positive MTU, small payload res = pck.Payload(1, payload) - if len(res) != 1 { - t.Fatal("Generated payload should be the 1") - } + assert.Len(t, res, 1, "Generated payload should be the 1") // Positive MTU, small payload res = pck.Payload(2, payload) - if len(res) != 1 { - t.Fatal("Generated payload should be the 1") - } + assert.Len(t, res, 1, "Generated payload should be the 1") } func TestOpusIsPartitionHead(t *testing.T) { opus := &OpusPacket{} t.Run("NormalPacket", func(t *testing.T) { - if !opus.IsPartitionHead([]byte{0x00, 0x00}) { - t.Fatal("All OPUS RTP packet should be the head of a new partition") - } + assert.True( + t, opus.IsPartitionHead([]byte{0x00, 0x00}), + "All OPUS RTP packet should be the head of a new partition", + ) }) } diff --git a/codecs/vp8_packet_test.go b/codecs/vp8_packet_test.go index 446bbe1f..0f94a4c8 100644 --- a/codecs/vp8_packet_test.go +++ b/codecs/vp8_packet_test.go @@ -4,103 +4,63 @@ package codecs import ( - "errors" - "reflect" "testing" + + "github.com/stretchr/testify/assert" ) -func TestVP8Packet_Unmarshal(t *testing.T) { //nolint:cyclop +func TestVP8Packet_Unmarshal(t *testing.T) { pck := VP8Packet{} // Nil packet raw, err := pck.Unmarshal(nil) - if raw != nil { - t.Fatal("Result should be nil in case of error") - } - if !errors.Is(err, errNilPacket) { - t.Fatal("Error should be:", errNilPacket) - } + assert.ErrorIs(t, err, errNilPacket) + assert.Nil(t, raw, "Result should be nil in case of error") // Nil payload raw, err = pck.Unmarshal([]byte{}) - if raw != nil { - t.Fatal("Result should be nil in case of error") - } - if !errors.Is(err, errShortPacket) { - t.Fatal("Error should be:", errShortPacket) - } + assert.ErrorIs(t, err, errShortPacket) + assert.Nil(t, raw, "Result should be nil in case of error") // Normal payload raw, err = pck.Unmarshal([]byte{0x00, 0x11, 0x22, 0x33, 0x44, 0x55, 0x66, 0x90}) - if raw == nil { - t.Fatal("Result shouldn't be nil in case of success") - } - if err != nil { - t.Fatal("Error should be nil in case of success") - } + assert.NoError(t, err) + assert.NotNil(t, raw, "Result shouldn't be nil in case of success") // Header size, only X raw, err = pck.Unmarshal([]byte{0x80, 0x00, 0x00, 0x00}) - if raw == nil { - t.Fatal("Result shouldn't be nil in case of success") - } - if err != nil { - t.Fatal("Error should be nil in case of success") - } + assert.NoError(t, err) + assert.NotNil(t, raw, "Result shouldn't be nil in case of success") // Header size, X and I raw, err = pck.Unmarshal([]byte{0x80, 0x80, 0x00, 0x00}) - if raw == nil { - t.Fatal("Result shouldn't be nil in case of success") - } - if err != nil { - t.Fatal("Error should be nil in case of success") - } + assert.NoError(t, err) + assert.NotNil(t, raw, "Result shouldn't be nil in case of success") // Header size, X and I, PID 16bits raw, err = pck.Unmarshal([]byte{0x80, 0x80, 0x81, 0x00}) - if raw == nil { - t.Fatal("Result shouldn't be nil in case of success") - } - if err != nil { - t.Fatal("Error should be nil in case of success") - } + assert.NoError(t, err) + assert.NotNil(t, raw, "Result shouldn't be nil in case of success") // Header size, X and L raw, err = pck.Unmarshal([]byte{0x80, 0x40, 0x00, 0x00}) - if raw == nil { - t.Fatal("Result shouldn't be nil in case of success") - } - if err != nil { - t.Fatal("Error should be nil in case of success") - } + assert.NoError(t, err) + assert.NotNil(t, raw, "Result shouldn't be nil in case of success") // Header size, X and T raw, err = pck.Unmarshal([]byte{0x80, 0x20, 0x00, 0x00}) - if raw == nil { - t.Fatal("Result shouldn't be nil in case of success") - } - if err != nil { - t.Fatal("Error should be nil in case of success") - } + assert.NoError(t, err) + assert.NotNil(t, raw, "Result shouldn't be nil in case of success") // Header size, X and K raw, err = pck.Unmarshal([]byte{0x80, 0x10, 0x00, 0x00}) - if raw == nil { - t.Fatal("Result shouldn't be nil in case of success") - } - if err != nil { - t.Fatal("Error should be nil in case of success") - } + assert.NoError(t, err) + assert.NotNil(t, raw, "Result shouldn't be nil in case of success") // Header size, all flags raw, err = pck.Unmarshal([]byte{0xff, 0xff, 0x00, 0x00}) - if raw != nil { - t.Fatal("Result should be nil in case of error") - } - if !errors.Is(err, errShortPacket) { - t.Fatal("Error should be:", errShortPacket) - } + assert.ErrorIs(t, err, errShortPacket) + assert.Nil(t, raw, "Result should be nil in case of error") // According to RFC 7741 Section 4.4, the packetizer need not pay // attention to partition boundaries. In that case, it may @@ -108,30 +68,20 @@ func TestVP8Packet_Unmarshal(t *testing.T) { //nolint:cyclop // The next three have been witnessed in nature. _, err = pck.Unmarshal([]byte{0x00}) - if err != nil { - t.Errorf("Empty packet with trivial header: %v", err) - } + assert.NoError(t, err, "Empty packet with trivial header") _, err = pck.Unmarshal([]byte{0x00, 0x2a, 0x94}) - if err != nil { - t.Errorf("Non-empty packet with trivial header: %v", err) - } + assert.NoError(t, err, "Non-empty packet with trivial header") + raw, err = pck.Unmarshal([]byte{0x81, 0x81, 0x94}) - if raw != nil { - t.Fatal("Result should be nil in case of error") - } - if !errors.Is(err, errShortPacket) { - t.Fatal("Error should be:", errShortPacket) - } + assert.ErrorIs(t, err, errShortPacket) + assert.Nil(t, raw, "Result should be nil in case of error") // The following two were invented. _, err = pck.Unmarshal([]byte{0x80, 0x00}) - if err != nil { - t.Errorf("Empty packet with trivial extension: %v", err) - } + assert.NoError(t, err, "Empty packet with trivial extension") + _, err = pck.Unmarshal([]byte{0x80, 0x80, 42}) - if err != nil { - t.Errorf("Header with PictureID: %v", err) - } + assert.NoError(t, err, "Header with PictureID") } func TestVP8Payloader_Payload(t *testing.T) { @@ -201,9 +151,7 @@ func TestVP8Payloader_Payload(t *testing.T) { for i := range testCase.payload { res := pck.Payload(testCase.mtu, testCase.payload[i]) - if !reflect.DeepEqual(testCase.expected[i], res) { - t.Fatalf("Generated packet[%d] differs, expected: %v, got: %v", i, testCase.expected[i], res) - } + assert.Equal(t, testCase.expected[i], res, "Generated packet differs") } }) } @@ -214,34 +162,30 @@ func TestVP8Payloader_Payload(t *testing.T) { // Positive MTU, nil payload res := pck.Payload(1, nil) - if len(res) != 0 { - t.Fatal("Generated payload should be empty") - } + assert.Len(t, res, 0, "Generated payload should be empty") // Positive MTU, small payload // MTU of 1 results in fragment size of 0 res = pck.Payload(1, payload) - if len(res) != 0 { - t.Fatal("Generated payload should be empty") - } + assert.Len(t, res, 0, "Generated payload should be empty") }) } func TestVP8IsPartitionHead(t *testing.T) { vp8 := &VP8Packet{} t.Run("SmallPacket", func(t *testing.T) { - if vp8.IsPartitionHead([]byte{0x00}) { - t.Fatal("Small packet should not be the head of a new partition") - } + assert.False(t, vp8.IsPartitionHead([]byte{0x00}), "Small packet should not be the head of a new partition") }) t.Run("SFlagON", func(t *testing.T) { - if !vp8.IsPartitionHead([]byte{0x10, 0x00, 0x00, 0x00}) { - t.Fatal("Packet with S flag should be the head of a new partition") - } + assert.True( + t, vp8.IsPartitionHead([]byte{0x10, 0x00, 0x00, 0x00}), + "Packet with S flag should be the head of a new partition", + ) }) t.Run("SFlagOFF", func(t *testing.T) { - if vp8.IsPartitionHead([]byte{0x00, 0x00, 0x00, 0x00}) { - t.Fatal("Packet without S flag should not be the head of a new partition") - } + assert.False( + t, vp8.IsPartitionHead([]byte{0x00, 0x00, 0x00, 0x00}), + "Packet without S flag should not be the head of a new partition", + ) }) } diff --git a/codecs/vp9/header_test.go b/codecs/vp9/header_test.go index 41fb46a9..d7041230 100644 --- a/codecs/vp9/header_test.go +++ b/codecs/vp9/header_test.go @@ -4,8 +4,9 @@ package vp9 import ( - "reflect" "testing" + + "github.com/stretchr/testify/assert" ) func TestHeaderUnmarshal(t *testing.T) { @@ -66,20 +67,10 @@ func TestHeaderUnmarshal(t *testing.T) { for _, ca := range cases { t.Run(ca.name, func(t *testing.T) { var sh Header - err := sh.Unmarshal(ca.byts) - if err != nil { - t.Fatal("unexpected error") - } - - if !reflect.DeepEqual(ca.sh, sh) { - t.Fatalf("expected %#+v, got %#+v", ca.sh, sh) - } - if ca.width != sh.Width() { - t.Fatalf("unexpected width") - } - if ca.height != sh.Height() { - t.Fatalf("unexpected height") - } + assert.NoError(t, sh.Unmarshal(ca.byts)) + assert.Equal(t, ca.sh, sh) + assert.Equal(t, ca.width, sh.Width()) + assert.Equal(t, ca.height, sh.Height()) }) } } diff --git a/codecs/vp9_packet_test.go b/codecs/vp9_packet_test.go index 7ad6ee4b..7e3f0acb 100644 --- a/codecs/vp9_packet_test.go +++ b/codecs/vp9_packet_test.go @@ -4,10 +4,10 @@ package codecs import ( - "errors" "math/rand" - "reflect" "testing" + + "github.com/stretchr/testify/assert" ) func TestVP9Packet_Unmarshal(t *testing.T) { @@ -211,29 +211,19 @@ func TestVP9Packet_Unmarshal(t *testing.T) { t.Run(name, func(t *testing.T) { p := VP9Packet{} raw, err := p.Unmarshal(testCase.b) - if testCase.err == nil { // nolint: nestif - if raw == nil { - t.Error("Result shouldn't be nil in case of success") - } - if err != nil { - t.Error("Error should be nil in case of success") - } - if !reflect.DeepEqual(testCase.pkt, p) { - t.Errorf("Unmarshalled packet expected to be:\n %v\ngot:\n %v", testCase.pkt, p) - } + if testCase.err == nil { + assert.NoError(t, err) + assert.NotNil(t, raw) + assert.Equal(t, testCase.pkt, p) } else { - if raw != nil { - t.Error("Result should be nil in case of error") - } - if !errors.Is(err, testCase.err) { - t.Errorf("Error should be '%v', got '%v'", testCase.err, err) - } + assert.Nil(t, raw, "Result should be nil in case of error") + assert.ErrorIs(t, err, testCase.err) } }) } } -func TestVP9Payloader_Payload(t *testing.T) { //nolint:cyclop +func TestVP9Payloader_Payload(t *testing.T) { r0 := int(rand.New(rand.NewSource(0)).Int31n(0x7FFF)) //nolint:gosec var rands [][2]byte for i := 0; i < 10; i++ { @@ -374,9 +364,7 @@ func TestVP9Payloader_Payload(t *testing.T) { //nolint:cyclop for _, b := range testCase.b { res = append(res, pck.Payload(testCase.mtu, b)...) } - if !reflect.DeepEqual(testCase.res, res) { - t.Errorf("Payloaded packet expected to be:\n %v\ngot:\n %v", testCase.res, res) - } + assert.Equal(t, testCase.res, res) }) } @@ -392,17 +380,16 @@ func TestVP9Payloader_Payload(t *testing.T) { //nolint:cyclop res := pck.Payload(4, []byte{0x01}) packet := VP9Packet{} _, err := packet.Unmarshal(res[0]) - if err != nil { - t.Fatalf("Unexpected error: %v", err) - } + assert.NoError(t, err) if i > 0 { if pPrev.PictureID == 0x7FFF { - if packet.PictureID != 0 { - t.Errorf("Picture ID next to 0x7FFF must be 0, got %d", packet.PictureID) - } - } else if pPrev.PictureID+1 != packet.PictureID { - t.Errorf("Picture ID next must be incremented by 1: %d -> %d", pPrev.PictureID, packet.PictureID) + assert.Equal( + t, uint16(0), packet.PictureID, + "Picture ID next to 0x7FFF must be 0", + ) + } else { + assert.Equal(t, pPrev.PictureID+1, packet.PictureID, "Picture ID next must be incremented by 1") } } @@ -414,16 +401,16 @@ func TestVP9Payloader_Payload(t *testing.T) { //nolint:cyclop func TestVP9IsPartitionHead(t *testing.T) { vp9 := &VP9Packet{} t.Run("SmallPacket", func(t *testing.T) { - if vp9.IsPartitionHead([]byte{}) { - t.Fatal("Small packet should not be the head of a new partition") - } + assert.False(t, vp9.IsPartitionHead(nil), "Small packet should not be the head of a new partition") }) t.Run("NormalPacket", func(t *testing.T) { - if !vp9.IsPartitionHead([]byte{0x18, 0x00, 0x00}) { - t.Error("VP9 RTP packet with B flag should be head of a new partition") - } - if vp9.IsPartitionHead([]byte{0x10, 0x00, 0x00}) { - t.Error("VP9 RTP packet without B flag should not be head of a new partition") - } + assert.True( + t, vp9.IsPartitionHead([]byte{0x18, 0x00, 0x00}), + "VP9 RTP packet with B flag should be head of a new partition", + ) + assert.False( + t, vp9.IsPartitionHead([]byte{0x10, 0x00, 0x00}), + "VP9 RTP packet without B flag should not be head of a new partition", + ) }) } diff --git a/go.mod b/go.mod index 3b37fe0f..db07d1d5 100644 --- a/go.mod +++ b/go.mod @@ -2,4 +2,13 @@ module github.com/pion/rtp go 1.20 -require github.com/pion/randutil v0.1.0 +require ( + github.com/pion/randutil v0.1.0 + github.com/stretchr/testify v1.10.0 +) + +require ( + github.com/davecgh/go-spew v1.1.1 // indirect + github.com/pmezard/go-difflib v1.0.0 // indirect + gopkg.in/yaml.v3 v3.0.1 // indirect +) diff --git a/go.sum b/go.sum index 401b903b..7269fd1d 100644 --- a/go.sum +++ b/go.sum @@ -1,2 +1,12 @@ +github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c= +github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= github.com/pion/randutil v0.1.0 h1:CFG1UdESneORglEsnimhUjf33Rwjubwj6xfiOXBa3mA= github.com/pion/randutil v0.1.0/go.mod h1:XcJrSMMbbMRhASFVOlj/5hQial/Y8oH/HVo7TBZq+j8= +github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM= +github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4= +github.com/stretchr/testify v1.10.0 h1:Xv5erBjTwe/5IxqUQTdXv5kgmIvbHo3QQyRwhJsOfJA= +github.com/stretchr/testify v1.10.0/go.mod h1:r2ic/lqez/lEtzL7wO/rwa5dbSLXVDPFyf8C91i36aY= +gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405 h1:yhCVgyC4o1eVCa2tZl7eS0r+SDo693bJlVdllGtEeKM= +gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= +gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA= +gopkg.in/yaml.v3 v3.0.1/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= diff --git a/header_extension_test.go b/header_extension_test.go index f12362e7..31ed7bcd 100644 --- a/header_extension_test.go +++ b/header_extension_test.go @@ -4,9 +4,9 @@ package rtp import ( - "bytes" - "encoding/hex" "testing" + + "github.com/stretchr/testify/assert" ) func TestHeaderExtension_RFC8285OneByteExtension(t *testing.T) { @@ -16,14 +16,11 @@ func TestHeaderExtension_RFC8285OneByteExtension(t *testing.T) { 0xBE, 0xDE, 0x00, 0x01, 0x50, 0xAA, 0x00, 0x00, 0x98, 0x36, 0xbe, 0x88, 0x9e, } - if _, err := p.Unmarshal(rawPkt); err != nil { - t.Fatal("Unmarshal err for valid extension") - } + _, err := p.Unmarshal(rawPkt) + assert.NoError(t, err, "Unmarshal err for valid extension") dstData, _ := p.Marshal() - if !bytes.Equal(dstData, rawPkt) { - t.Errorf("Marshal failed raw \nMarshaled:\n%s\nrawPkt:\n%s", hex.Dump(dstData), hex.Dump(rawPkt)) - } + assert.Equal(t, rawPkt, dstData) } func TestHeaderExtension_RFC8285OneByteTwoExtensionOfTwoBytes(t *testing.T) { @@ -39,26 +36,19 @@ func TestHeaderExtension_RFC8285OneByteTwoExtensionOfTwoBytes(t *testing.T) { rawPkt := []byte{ 0xBE, 0xDE, 0x00, 0x01, 0x10, 0xAA, 0x20, 0xBB, } - if _, err := ext.Unmarshal(rawPkt); err != nil { - t.Fatal("Unmarshal err for valid extension") - } + _, err := ext.Unmarshal(rawPkt) + assert.NoError(t, err, "Unmarshal err for valid extension") ext1 := ext.Get(1) ext1Expect := []byte{0xAA} - if !bytes.Equal(ext1, ext1Expect) { - t.Errorf("Extension has incorrect data. Got: %+v, Expected: %+v", ext1, ext1Expect) - } + assert.Equal(t, ext1Expect, ext1, "Extension has incorrect data") ext2 := ext.Get(2) ext2Expect := []byte{0xBB} - if !bytes.Equal(ext2, ext2Expect) { - t.Errorf("Extension has incorrect data. Got: %+v, Expected: %+v", ext2, ext2Expect) - } + assert.Equal(t, ext2Expect, ext2, "Extension has incorrect data") dstData, _ := ext.Marshal() - if !bytes.Equal(dstData, rawPkt) { - t.Errorf("Marshal failed raw \nMarshaled:\n%s\nrawPkt:\n%s", hex.Dump(dstData), hex.Dump(rawPkt)) - } + assert.Equal(t, rawPkt, dstData) } func TestHeaderExtension_RFC8285OneByteMultipleExtensionsWithPadding(t *testing.T) { @@ -79,27 +69,20 @@ func TestHeaderExtension_RFC8285OneByteMultipleExtensionsWithPadding(t *testing. 0xBE, 0xDE, 0x00, 0x03, 0x10, 0xAA, 0x21, 0xBB, 0xBB, 0x00, 0x00, 0x33, 0xCC, 0xCC, 0xCC, 0xCC, } - if _, err := ext.Unmarshal(rawPkt); err != nil { - t.Fatal("Unmarshal err for valid extension") - } + _, err := ext.Unmarshal(rawPkt) + assert.NoError(t, err, "Unmarshal err for valid extension") ext1 := ext.Get(1) ext1Expect := []byte{0xAA} - if !bytes.Equal(ext1, ext1Expect) { - t.Errorf("Extension has incorrect data. Got: %v+, Expected: %v+", ext1, ext1Expect) - } + assert.Equal(t, ext1Expect, ext1, "Extension has incorrect data") ext2 := ext.Get(2) ext2Expect := []byte{0xBB, 0xBB} - if !bytes.Equal(ext2, ext2Expect) { - t.Errorf("Extension has incorrect data. Got: %v+, Expected: %v+", ext2, ext2Expect) - } + assert.Equal(t, ext2Expect, ext2, "Extension has incorrect data") ext3 := ext.Get(3) ext3Expect := []byte{0xCC, 0xCC, 0xCC, 0xCC} - if !bytes.Equal(ext3, ext3Expect) { - t.Errorf("Extension has incorrect data. Got: %v+, Expected: %v+", ext3, ext3Expect) - } + assert.Equal(t, ext3Expect, ext3, "Extension has incorrect data") dstBuf := map[string][]byte{ "CleanBuffer": make([]byte, 1000), @@ -112,12 +95,9 @@ func TestHeaderExtension_RFC8285OneByteMultipleExtensionsWithPadding(t *testing. buf := buf t.Run(name, func(t *testing.T) { n, err := ext.MarshalTo(buf) - if err != nil { - t.Fatal(err) - } - if !bytes.Equal(buf[:n], rawPkt) { - t.Errorf("Marshal failed raw \nMarshaled:\n%s\nrawPkt:\n%s", hex.Dump(buf[:n]), hex.Dump(rawPkt)) - } + assert.NoError(t, err) + + assert.Equal(t, rawPkt, buf[:n]) }) } } @@ -131,14 +111,11 @@ func TestHeaderExtension_RFC8285TwoByteExtension(t *testing.T) { 0xAA, 0xAA, 0xAA, 0xAA, 0xAA, 0xAA, 0xAA, 0xAA, 0xAA, 0xAA, 0xAA, 0xAA, 0x00, 0x00, } - if _, err := ext.Unmarshal(rawPkt); err != nil { - t.Fatal("Unmarshal err for valid extension") - } + _, err := ext.Unmarshal(rawPkt) + assert.NoError(t, err, "Unmarshal err for valid extension") dstData, _ := ext.Marshal() - if !bytes.Equal(dstData, rawPkt) { - t.Errorf("Marshal failed raw \nMarshaled:\n%s\nrawPkt:\n%s", hex.Dump(dstData), hex.Dump(rawPkt)) - } + assert.Equal(t, rawPkt, dstData) } func TestHeaderExtension_RFC8285TwoByteMultipleExtensionsWithPadding(t *testing.T) { @@ -160,27 +137,20 @@ func TestHeaderExtension_RFC8285TwoByteMultipleExtensionsWithPadding(t *testing. 0xBB, 0x00, 0x03, 0x04, 0xCC, 0xCC, 0xCC, 0xCC, } - if _, err := ext.Unmarshal(rawPkt); err != nil { - t.Fatal("Unmarshal err for valid extension") - } + _, err := ext.Unmarshal(rawPkt) + assert.NoError(t, err, "Unmarshal err for valid extension") ext1 := ext.Get(1) ext1Expect := []byte{} - if !bytes.Equal(ext1, ext1Expect) { - t.Errorf("Extension has incorrect data. Got: %v+, Expected: %v+", ext1, ext1Expect) - } + assert.Equal(t, ext1Expect, ext1, "Extension has incorrect data") ext2 := ext.Get(2) ext2Expect := []byte{0xBB} - if !bytes.Equal(ext2, ext2Expect) { - t.Errorf("Extension has incorrect data. Got: %v+, Expected: %v+", ext2, ext2Expect) - } + assert.Equal(t, ext2Expect, ext2, "Extension has incorrect data") ext3 := ext.Get(3) ext3Expect := []byte{0xCC, 0xCC, 0xCC, 0xCC} - if !bytes.Equal(ext3, ext3Expect) { - t.Errorf("Extension has incorrect data. Got: %v+, Expected: %v+", ext3, ext3Expect) - } + assert.Equal(t, ext3Expect, ext3, "Extension has incorrect data") } func TestHeaderExtension_RFC8285TwoByteMultipleExtensionsWithLargeExtension(t *testing.T) { @@ -209,97 +179,54 @@ func TestHeaderExtension_RFC8285TwoByteMultipleExtensionsWithLargeExtension(t *t 0xCC, 0xCC, 0xCC, 0xCC, 0xCC, 0xCC, 0xCC, 0xCC, 0xCC, 0xCC, } - if _, err := ext.Unmarshal(rawPkt); err != nil { - t.Fatal("Unmarshal err for valid extension") - } + _, err := ext.Unmarshal(rawPkt) + assert.NoError(t, err, "Unmarshal err for valid extension") ext1 := ext.Get(1) ext1Expect := []byte{} - if !bytes.Equal(ext1, ext1Expect) { - t.Errorf("Extension has incorrect data. Got: %v+, Expected: %v+", ext1, ext1Expect) - } + assert.Equal(t, ext1Expect, ext1, "Extension has incorrect data") ext2 := ext.Get(2) ext2Expect := []byte{0xBB} - if !bytes.Equal(ext2, ext2Expect) { - t.Errorf("Extension has incorrect data. Got: %v+, Expected: %v+", ext2, ext2Expect) - } + assert.Equal(t, ext2Expect, ext2, "Extension has incorrect data") ext3 := ext.Get(3) ext3Expect := []byte{ 0xCC, 0xCC, 0xCC, 0xCC, 0xCC, 0xCC, 0xCC, 0xCC, 0xCC, 0xCC, 0xCC, 0xCC, 0xCC, 0xCC, 0xCC, 0xCC, 0xCC, } - if !bytes.Equal(ext3, ext3Expect) { - t.Errorf("Extension has incorrect data. Got: %v+, Expected: %v+", ext3, ext3Expect) - } + assert.Equal(t, ext3Expect, ext3, "Extension has incorrect data") dstData, _ := ext.Marshal() - if !bytes.Equal(dstData, rawPkt) { - t.Errorf("Marshal failed raw \nMarshaled: %+v,\nrawPkt: %+v", dstData, rawPkt) - } + assert.Equal(t, rawPkt, dstData) } func TestHeaderExtension_RFC8285OneByteDelExtension(t *testing.T) { ext := &OneByteHeaderExtension{} - if _, err := ext.Unmarshal([]byte{0xBE, 0xDE, 0x00, 0x00}); err != nil { - t.Fatal("Unmarshal err for valid extension") - } - - if err := ext.Set(1, []byte{0xBB}); err != nil { - t.Fatal("Set err for valid extension") - } - - extExtension := ext.Get(1) - if extExtension == nil { - t.Error("Extension should exist") - } - - err := ext.Del(1) - if err != nil { - t.Error("Should successfully delete extension") - } - - extExtension = ext.Get(1) - if extExtension != nil { - t.Error("Extension should not exist") - } - - err = ext.Del(1) - if err == nil { - t.Error("Should return error when deleting extension that doesnt exist") - } + _, err := ext.Unmarshal([]byte{0xBE, 0xDE, 0x00, 0x00}) + assert.NoError(t, err, "Unmarshal err for valid extension") + assert.NoError(t, ext.Set(1, []byte{0xBB}), "Set err for valid extension") + assert.NotNil(t, ext.Get(1), "Extension should exist") + assert.NoError(t, ext.Del(1), "Should successfully delete extension") + assert.Nil(t, ext.Get(1), "Extension should not") + assert.Error(t, ext.Del(1), "Should return error when deleting extension that doesnt exist") } func TestHeaderExtension_RFC8285TwoByteDelExtension(t *testing.T) { ext := &TwoByteHeaderExtension{} - if _, err := ext.Unmarshal([]byte{0x10, 0x00, 0x00, 0x00}); err != nil { - t.Fatal("Unmarshal err for valid extension") - } + _, err := ext.Unmarshal([]byte{0x10, 0x00, 0x00, 0x00}) + assert.NoError(t, err, "Unmarshal err for valid extension") - if err := ext.Set(1, []byte{0xBB}); err != nil { - t.Fatal("Set err for valid extension") - } + assert.NoError(t, ext.Set(1, []byte{0xBB}), "Set err for valid extension") extExtension := ext.Get(1) - if extExtension == nil { - t.Error("Extension should exist") - } + assert.NotNil(t, extExtension, "Extension should exist") - err := ext.Del(1) - if err != nil { - t.Error("Should successfully delete extension") - } + assert.NoError(t, ext.Del(1), "Should successfully delete extension") extExtension = ext.Get(1) - if extExtension != nil { - t.Error("Extension should not exist") - } - - err = ext.Del(1) - if err == nil { - t.Error("Should return error when deleting extension that doesnt exist") - } + assert.Nil(t, extExtension, "Extension should exist") + assert.Error(t, ext.Del(1), "Should return error when deleting extension that doesnt exist") } diff --git a/packet_test.go b/packet_test.go index fd58088b..bfe75672 100644 --- a/packet_test.go +++ b/packet_test.go @@ -4,20 +4,17 @@ package rtp import ( - "bytes" - "encoding/hex" - "errors" "fmt" - "reflect" "testing" + + "github.com/stretchr/testify/assert" ) func TestBasic(t *testing.T) { // nolint:maintidx,cyclop packet := &Packet{} - if err := packet.Unmarshal([]byte{}); err == nil { - t.Fatal("Unmarshal did not error on zero length packet") - } + assert.Error(t, packet.Unmarshal([]byte{}), "Unmarshal did not error on zero length packet") + assert.ErrorIs(t, packet.Unmarshal([]byte{}), errHeaderSizeInsufficient) rawPkt := []byte{ 0x90, 0xe0, 0x69, 0x8f, 0xd9, 0xc2, 0x93, 0xda, 0x1c, 0x64, @@ -48,24 +45,15 @@ func TestBasic(t *testing.T) { // nolint:maintidx,cyclop // Unmarshal to the used Packet should work as well. for i := 0; i < 2; i++ { t.Run(fmt.Sprintf("Run%d", i+1), func(t *testing.T) { - if err := packet.Unmarshal(rawPkt); err != nil { - t.Error(err) - } else if !reflect.DeepEqual(packet, parsedPacket) { - t.Errorf("TestBasic unmarshal: got %#v, want %#v", packet, parsedPacket) - } + assert.NoError(t, packet.Unmarshal(rawPkt)) + assert.Equal(t, packet, parsedPacket) - if parsedPacket.Header.MarshalSize() != 20 { - t.Errorf("wrong computed header marshal size") - } else if parsedPacket.MarshalSize() != len(rawPkt) { - t.Errorf("wrong computed marshal size") - } + assert.Equal(t, packet.Header.MarshalSize(), 20, "wrong computed header marshal size") + assert.Equal(t, packet.MarshalSize(), len(rawPkt), "wrong computed marshal size") raw, err := packet.Marshal() - if err != nil { - t.Error(err) - } else if !reflect.DeepEqual(raw, rawPkt) { - t.Errorf("TestBasic marshal: got %#v, want %#v", raw, rawPkt) - } + assert.NoError(t, err) + assert.Equal(t, rawPkt, raw) }) } @@ -95,11 +83,8 @@ func TestBasic(t *testing.T) { // nolint:maintidx,cyclop Payload: rawPkt[20:21], PaddingSize: 4, } - if err := packet.Unmarshal(rawPkt); err != nil { - t.Error(err) - } else if !reflect.DeepEqual(packet, parsedPacket) { - t.Errorf("TestBasic padding unmarshal: got %#v, want %#v", packet, parsedPacket) - } + assert.NoError(t, packet.Unmarshal(rawPkt)) + assert.Equal(t, packet, parsedPacket) // packet with zero padding following packet with non-zero padding rawPkt = []byte{ @@ -127,11 +112,8 @@ func TestBasic(t *testing.T) { // nolint:maintidx,cyclop Payload: rawPkt[20:], PaddingSize: 0, } - if err := packet.Unmarshal(rawPkt); err != nil { - t.Error(err) - } else if !reflect.DeepEqual(packet, parsedPacket) { - t.Errorf("TestBasic zero padding unmarshal: got %#v, want %#v", packet, parsedPacket) - } + assert.NoError(t, packet.Unmarshal(rawPkt)) + assert.Equal(t, packet, parsedPacket) // packet with only padding rawPkt = []byte{ @@ -159,14 +141,9 @@ func TestBasic(t *testing.T) { // nolint:maintidx,cyclop Payload: []byte{}, PaddingSize: 5, } - if err := packet.Unmarshal(rawPkt); err != nil { - t.Error(err) - } else if !reflect.DeepEqual(packet, parsedPacket) { - t.Errorf("TestBasic padding only unmarshal: got %#v, want %#v", packet, parsedPacket) - } - if len(packet.Payload) != 0 { - t.Errorf("Unmarshal of padding only packet has payload of non-zero length: %d", len(packet.Payload)) - } + assert.NoError(t, packet.Unmarshal(rawPkt)) + assert.Equal(t, packet, parsedPacket) + assert.Len(t, packet.Payload, 0, "Unmarshal of padding only packet has payload of non-zero length") // packet with excessive padding rawPkt = []byte{ @@ -195,12 +172,8 @@ func TestBasic(t *testing.T) { // nolint:maintidx,cyclop PaddingSize: 0, } err := packet.Unmarshal(rawPkt) - if err == nil { - t.Fatal("Unmarshal did not error on packet with excessive padding") - } - if !errors.Is(err, errTooSmall) { - t.Errorf("Expected error: %v, got: %v", errTooSmall, err) - } + assert.Error(t, err, "Unmarshal did not error on packet with excessive padding") + assert.ErrorIs(t, err, errTooSmall) // marshal packet with padding rawPkt = []byte{ @@ -229,12 +202,8 @@ func TestBasic(t *testing.T) { // nolint:maintidx,cyclop PaddingSize: 4, } buf, err := parsedPacket.Marshal() - if err != nil { - t.Error(err) - } - if !reflect.DeepEqual(buf, rawPkt) { - t.Errorf("TestBasic padding marshal: got %#v, want %#v", buf, rawPkt) - } + assert.NoError(t, err) + assert.Equal(t, rawPkt, buf) // marshal packet with padding only rawPkt = []byte{ @@ -263,12 +232,8 @@ func TestBasic(t *testing.T) { // nolint:maintidx,cyclop PaddingSize: 5, } buf, err = parsedPacket.Marshal() - if err != nil { - t.Error(err) - } - if !reflect.DeepEqual(buf, rawPkt) { - t.Errorf("TestBasic padding marshal: got %#v, want %#v", buf, rawPkt) - } + assert.NoError(t, err) + assert.Equal(t, rawPkt, buf) // marshal packet with padding only without setting Padding explicitly in Header rawPkt = []byte{ @@ -297,12 +262,8 @@ func TestBasic(t *testing.T) { // nolint:maintidx,cyclop PaddingSize: 5, } buf, err = parsedPacket.Marshal() - if err != nil { - t.Error(err) - } - if !reflect.DeepEqual(buf, rawPkt) { - t.Errorf("TestBasic padding marshal: got %#v, want %#v", buf, rawPkt) - } + assert.NoError(t, err) + assert.Equal(t, rawPkt, buf) } func TestExtension(t *testing.T) { @@ -312,17 +273,19 @@ func TestExtension(t *testing.T) { 0x90, 0x60, 0x69, 0x8f, 0xd9, 0xc2, 0x93, 0xda, 0x1c, 0x64, 0x27, 0x82, } - if err := packet.Unmarshal(missingExtensionPkt); err == nil { - t.Fatal("Unmarshal did not error on packet with missing extension data") - } + assert.Error( + t, packet.Unmarshal(missingExtensionPkt), + "Unmarshal did not error on packet with missing extension data", + ) invalidExtensionLengthPkt := []byte{ 0x90, 0x60, 0x69, 0x8f, 0xd9, 0xc2, 0x93, 0xda, 0x1c, 0x64, 0x27, 0x82, 0x99, 0x99, 0x99, 0x99, } - if err := packet.Unmarshal(invalidExtensionLengthPkt); err == nil { - t.Fatal("Unmarshal did not error on packet with invalid extension length") - } + assert.Error( + t, packet.Unmarshal(invalidExtensionLengthPkt), + "Unmarshal did not error on packet with invalid extension length", + ) packet = &Packet{ Header: Header{ @@ -336,9 +299,8 @@ func TestExtension(t *testing.T) { }, Payload: []byte{}, } - if _, err := packet.Marshal(); err == nil { - t.Fatal("Marshal did not error on packet with invalid extension length") - } + _, err := packet.Marshal() + assert.Error(t, err, "Marshal did not error on packet with invalid extension length") } func TestRFC8285OneByteExtension(t *testing.T) { @@ -349,9 +311,7 @@ func TestRFC8285OneByteExtension(t *testing.T) { 0x27, 0x82, 0xBE, 0xDE, 0x00, 0x01, 0x50, 0xAA, 0x00, 0x00, 0x98, 0x36, 0xbe, 0x88, 0x9e, } - if err := packet.Unmarshal(rawPkt); err != nil { - t.Fatal("Unmarshal err for valid extension") - } + assert.NoError(t, packet.Unmarshal(rawPkt)) packet = &Packet{ Header: Header{ @@ -374,9 +334,7 @@ func TestRFC8285OneByteExtension(t *testing.T) { } dstData, _ := packet.Marshal() - if !bytes.Equal(dstData, rawPkt) { - t.Errorf("Marshal failed raw \nMarshaled:\n%s\nrawPkt:\n%s", hex.Dump(dstData), hex.Dump(rawPkt)) - } + assert.Equal(t, rawPkt, dstData) } func TestRFC8285OneByteTwoExtensionOfTwoBytes(t *testing.T) { @@ -395,21 +353,15 @@ func TestRFC8285OneByteTwoExtensionOfTwoBytes(t *testing.T) { // Payload 0x98, 0x36, 0xbe, 0x88, 0x9e, } - if err := packet.Unmarshal(rawPkt); err != nil { - t.Fatal("Unmarshal err for valid extension") - } + assert.NoError(t, packet.Unmarshal(rawPkt)) ext1 := packet.GetExtension(1) ext1Expect := []byte{0xAA} - if !bytes.Equal(ext1, ext1Expect) { - t.Errorf("Extension has incorrect data. Got: %+v, Expected: %+v", ext1, ext1Expect) - } + assert.Equal(t, ext1Expect, ext1, "Extension has incorrect data") ext2 := packet.GetExtension(2) ext2Expect := []byte{0xBB} - if !bytes.Equal(ext2, ext2Expect) { - t.Errorf("Extension has incorrect data. Got: %+v, Expected: %+v", ext2, ext2Expect) - } + assert.Equal(t, ext2Expect, ext2, "Extension has incorrect data") // Test Marshal packet = &Packet{ @@ -436,9 +388,7 @@ func TestRFC8285OneByteTwoExtensionOfTwoBytes(t *testing.T) { } dstData, _ := packet.Marshal() - if !bytes.Equal(dstData, rawPkt) { - t.Errorf("Marshal failed raw \nMarshaled:\n%s\nrawPkt:\n%s", hex.Dump(dstData), hex.Dump(rawPkt)) - } + assert.Equal(t, rawPkt, dstData) } func TestRFC8285OneByteMultipleExtensionsWithPadding(t *testing.T) { @@ -462,27 +412,19 @@ func TestRFC8285OneByteMultipleExtensionsWithPadding(t *testing.T) { // Payload 0x98, 0x36, 0xbe, 0x88, 0x9e, } - if err := packet.Unmarshal(rawPkt); err != nil { - t.Fatal("Unmarshal err for valid extension") - } + assert.NoError(t, packet.Unmarshal(rawPkt)) ext1 := packet.GetExtension(1) ext1Expect := []byte{0xAA} - if !bytes.Equal(ext1, ext1Expect) { - t.Errorf("Extension has incorrect data. Got: %v+, Expected: %v+", ext1, ext1Expect) - } + assert.Equal(t, ext1Expect, ext1, "Extension has incorrect data") ext2 := packet.GetExtension(2) ext2Expect := []byte{0xBB, 0xBB} - if !bytes.Equal(ext2, ext2Expect) { - t.Errorf("Extension has incorrect data. Got: %v+, Expected: %v+", ext2, ext2Expect) - } + assert.Equal(t, ext2Expect, ext2, "Extension has incorrect data") ext3 := packet.GetExtension(3) ext3Expect := []byte{0xCC, 0xCC, 0xCC, 0xCC} - if !bytes.Equal(ext3, ext3Expect) { - t.Errorf("Extension has incorrect data. Got: %v+, Expected: %v+", ext3, ext3Expect) - } + assert.Equal(t, ext3Expect, ext3, "Extension has incorrect data") rawPktReMarshal := []byte{ 0x90, 0xe0, 0x69, 0x8f, 0xd9, 0xc2, 0x93, 0xda, 0x1c, 0x64, @@ -502,12 +444,8 @@ func TestRFC8285OneByteMultipleExtensionsWithPadding(t *testing.T) { buf := buf t.Run(name, func(t *testing.T) { n, err := packet.MarshalTo(buf) - if err != nil { - t.Fatal(err) - } - if !bytes.Equal(buf[:n], rawPktReMarshal) { - t.Errorf("Marshal failed raw \nMarshaled:\n%s\nrawPkt:\n%s", hex.Dump(buf[:n]), hex.Dump(rawPktReMarshal)) - } + assert.NoError(t, err) + assert.Equal(t, rawPktReMarshal, buf[:n]) }) } } @@ -559,9 +497,7 @@ func TestRFC8285OneByteMultipleExtensions(t *testing.T) { } dstData, _ := packet.Marshal() - if !bytes.Equal(dstData, rawPkt) { - t.Errorf("Marshal failed raw \nMarshaled:\n%s\nrawPkt:\n%s", hex.Dump(dstData), hex.Dump(rawPkt)) - } + assert.Equal(t, rawPkt, dstData) } func TestRFC8285TwoByteExtension(t *testing.T) { @@ -574,9 +510,7 @@ func TestRFC8285TwoByteExtension(t *testing.T) { 0xAA, 0xAA, 0xAA, 0xAA, 0xAA, 0xAA, 0xAA, 0xAA, 0xAA, 0xAA, 0xAA, 0xAA, 0x00, 0x00, 0x98, 0x36, 0xbe, 0x88, 0x9e, } - if err := packet.Unmarshal(rawPkt); err != nil { - t.Fatal("Unmarshal err for valid extension") - } + assert.NoError(t, packet.Unmarshal(rawPkt)) packet = &Packet{ Header: Header{ @@ -601,9 +535,7 @@ func TestRFC8285TwoByteExtension(t *testing.T) { } dstData, _ := packet.Marshal() - if !bytes.Equal(dstData, rawPkt) { - t.Errorf("Marshal failed raw \nMarshaled:\n%s\nrawPkt:\n%s", hex.Dump(dstData), hex.Dump(rawPkt)) - } + assert.Equal(t, rawPkt, dstData) } func TestRFC8285TwoByteMultipleExtensionsWithPadding(t *testing.T) { @@ -626,27 +558,19 @@ func TestRFC8285TwoByteMultipleExtensionsWithPadding(t *testing.T) { 0xBB, 0x00, 0x03, 0x04, 0xCC, 0xCC, 0xCC, 0xCC, 0x98, 0x36, 0xbe, 0x88, 0x9e, } - if err := packet.Unmarshal(rawPkt); err != nil { - t.Fatal("Unmarshal err for valid extension") - } + assert.NoError(t, packet.Unmarshal(rawPkt)) ext1 := packet.GetExtension(1) ext1Expect := []byte{} - if !bytes.Equal(ext1, ext1Expect) { - t.Errorf("Extension has incorrect data. Got: %v+, Expected: %v+", ext1, ext1Expect) - } + assert.Equal(t, ext1Expect, ext1, "Extension has incorrect data") ext2 := packet.GetExtension(2) ext2Expect := []byte{0xBB} - if !bytes.Equal(ext2, ext2Expect) { - t.Errorf("Extension has incorrect data. Got: %v+, Expected: %v+", ext2, ext2Expect) - } + assert.Equal(t, ext2Expect, ext2, "Extension has incorrect data") ext3 := packet.GetExtension(3) ext3Expect := []byte{0xCC, 0xCC, 0xCC, 0xCC} - if !bytes.Equal(ext3, ext3Expect) { - t.Errorf("Extension has incorrect data. Got: %v+, Expected: %v+", ext3, ext3Expect) - } + assert.Equal(t, ext3Expect, ext3, "Extension has incorrect data") } func TestRFC8285TwoByteMultipleExtensionsWithLargeExtension(t *testing.T) { @@ -700,11 +624,8 @@ func TestRFC8285TwoByteMultipleExtensionsWithLargeExtension(t *testing.T) { }, Payload: rawPkt[40:], } - dstData, _ := packet.Marshal() - if !bytes.Equal(dstData, rawPkt) { - t.Errorf("Marshal failed raw \nMarshaled: %+v,\nrawPkt: %+v", dstData, rawPkt) - } + assert.Equal(t, rawPkt, dstData) } func TestRFC8285GetExtensionReturnsNilWhenExtensionsDisabled(t *testing.T) { @@ -725,11 +646,7 @@ func TestRFC8285GetExtensionReturnsNilWhenExtensionsDisabled(t *testing.T) { }, Payload: payload, } - - err := packet.GetExtension(1) - if err != nil { - t.Error("Should return nil on GetExtension when h.Extension: false") - } + assert.Nil(t, packet.GetExtension(1), "Should return nil on GetExtension when h.Extension: false") } func TestRFC8285DelExtension(t *testing.T) { @@ -756,26 +673,10 @@ func TestRFC8285DelExtension(t *testing.T) { }, Payload: payload, } - - ext := packet.GetExtension(1) - if ext == nil { - t.Error("Extension should exist") - } - - err := packet.DelExtension(1) - if err != nil { - t.Error("Should successfully delete extension") - } - - ext = packet.GetExtension(1) - if ext != nil { - t.Error("Extension should not exist") - } - - err = packet.DelExtension(1) - if err == nil { - t.Error("Should return error when deleting extension that doesnt exist") - } + assert.NotNil(t, packet.GetExtension(1), "Extension should exist") + assert.NoError(t, packet.DelExtension(1), "Should successfully delete extension") + assert.Nil(t, packet.GetExtension(1), "Extension should not exist") + assert.Error(t, packet.DelExtension(1), "Should return error when deleting extension that doesnt exist") } func TestRFC8285GetExtensionIDs(t *testing.T) { @@ -805,24 +706,13 @@ func TestRFC8285GetExtensionIDs(t *testing.T) { }, Payload: payload, } - ids := packet.GetExtensionIDs() - if ids == nil { - t.Error("Extension should exist") - } - if len(ids) != len(packet.Extensions) { - t.Errorf( - "The number of IDs should be equal to the number of extensions,want=%d,have=%d", - len(packet.Extensions), - len(ids), - ) - } + assert.NotNil(t, ids, "Extension should exist") + assert.Len(t, ids, len(packet.Extensions), "The number of IDs should be equal to the number of extensions") for _, id := range ids { ext := packet.GetExtension(id) - if ext == nil { - t.Error("Extension should exist") - } + assert.NotNil(t, ext, "Extension should exist") } } @@ -844,11 +734,7 @@ func TestRFC8285GetExtensionIDsReturnsErrorWhenExtensionsDisabled(t *testing.T) }, Payload: payload, } - - ids := packet.GetExtensionIDs() - if ids != nil { - t.Error("Should return nil on GetExtensionIDs when h.Extensions is nil") - } + assert.Nil(t, packet.GetExtensionIDs(), "Should return nil on GetExtensionIDs when h.Extensions is nil") } func TestRFC8285DelExtensionReturnsErrorWhenExtensionsDisabled(t *testing.T) { @@ -869,11 +755,9 @@ func TestRFC8285DelExtensionReturnsErrorWhenExtensionsDisabled(t *testing.T) { }, Payload: payload, } - - err := packet.DelExtension(1) - if err == nil { - t.Error("Should return error on DelExtension when h.Extension: false") - } + assert.Error( + t, packet.DelExtension(1), "DelExtension did not error on h.Extension: false", + ) } func TestRFC8285OneByteSetExtensionShouldEnableExensionsWhenAdding(t *testing.T) { @@ -894,28 +778,12 @@ func TestRFC8285OneByteSetExtensionShouldEnableExensionsWhenAdding(t *testing.T) }, Payload: payload, } - extension := []byte{0xAA, 0xAA} - err := packet.SetExtension(1, extension) - if err != nil { - t.Error("Error setting extension") - } - - if packet.Extension != true { - t.Error("Extension should be set to true") - } - - if packet.ExtensionProfile != 0xBEDE { - t.Error("Extension profile should be set to 0xBEDE") - } - - if len(packet.Extensions) != 1 { - t.Error("Extensions should be set to 1") - } - - if !bytes.Equal(packet.GetExtension(1), extension) { - t.Error("Extension value is not set") - } + assert.NoError(t, packet.SetExtension(1, extension)) + assert.True(t, packet.Extension) + assert.Equal(t, uint16(0xBEDE), packet.ExtensionProfile) + assert.Len(t, packet.Extensions, 1) + assert.Equal(t, extension, packet.GetExtension(1)) } func TestRFC8285OneByteSetExtensionShouldSetCorrectExtensionProfileFor16ByteExtension(t *testing.T) { @@ -944,13 +812,8 @@ func TestRFC8285OneByteSetExtensionShouldSetCorrectExtensionProfileFor16ByteExte 0xAA, 0xAA, 0xAA, 0xAA, } err := packet.SetExtension(1, extension) - if err != nil { - t.Error("Error setting extension") - } - - if packet.ExtensionProfile != 0xBEDE { - t.Error("Extension profile should be set to 0xBEDE") - } + assert.NoError(t, err, "Error setting extension") + assert.Equal(t, uint16(0xBEDE), packet.ExtensionProfile) } func TestRFC8285OneByteSetExtensionShouldUpdateExistingExension(t *testing.T) { @@ -977,20 +840,12 @@ func TestRFC8285OneByteSetExtensionShouldUpdateExistingExension(t *testing.T) { }, Payload: payload, } - - if !bytes.Equal(packet.GetExtension(1), []byte{0xAA}) { - t.Error("Extension value not initialize properly") - } + assert.Equal(t, []byte{0xAA}, packet.GetExtension(1)) extension := []byte{0xBB} err := packet.SetExtension(1, extension) - if err != nil { - t.Error("Error setting extension") - } - - if !bytes.Equal(packet.GetExtension(1), extension) { - t.Error("Extension value was not set") - } + assert.NoError(t, err, "Error setting extension") + assert.Equal(t, extension, packet.GetExtension(1)) } func TestRFC8285OneByteSetExtensionShouldErrorWhenInvalidIDProvided(t *testing.T) { @@ -1017,14 +872,14 @@ func TestRFC8285OneByteSetExtensionShouldErrorWhenInvalidIDProvided(t *testing.T }, Payload: payload, } - - if packet.SetExtension(0, []byte{0xBB}) == nil { - t.Error("SetExtension did not error on invalid id") - } - - if packet.SetExtension(15, []byte{0xBB}) == nil { - t.Error("SetExtension did not error on invalid id") - } + assert.Error( + t, packet.SetExtension(0, []byte{0xBB}), + "SetExtension did not error on invalid id", + ) + assert.Error( + t, packet.SetExtension(15, []byte{0xBB}), + "SetExtension did not error on invalid id", + ) } func TestRFC8285OneByteExtensionTermianteProcessingWhenReservedIDEncountered(t *testing.T) { @@ -1034,20 +889,14 @@ func TestRFC8285OneByteExtensionTermianteProcessingWhenReservedIDEncountered(t * 0x90, 0xe0, 0x69, 0x8f, 0xd9, 0xc2, 0x93, 0xda, 0x1c, 0x64, 0x27, 0x82, 0xBE, 0xDE, 0x00, 0x01, 0xF0, 0xAA, 0x98, 0x36, 0xbe, 0x88, 0x9e, } - if err := packet.Unmarshal(reservedIDPkt); err != nil { - t.Error("Unmarshal error on packet with reserved extension id") - } - - if len(packet.Extensions) != 0 { - t.Error("Extensions should be empty for invalid id") - } + assert.NoError( + t, packet.Unmarshal(reservedIDPkt), + "Unmarshal error on packet with reserved extension id", + ) + assert.Len(t, packet.Extensions, 0, "Extensions should be empty for invalid id") payload := reservedIDPkt[17:] - if !bytes.Equal(packet.Payload, payload) { - t.Errorf("p.Payload must be same as payload.\n p.Payload: %+v,\n payload: %+v", - packet.Payload, payload, - ) - } + assert.Equal(t, payload, packet.Payload) } func TestRFC8285OneByteSetExtensionShouldErrorWhenPayloadTooLarge(t *testing.T) { @@ -1074,13 +923,14 @@ func TestRFC8285OneByteSetExtensionShouldErrorWhenPayloadTooLarge(t *testing.T) }, Payload: payload, } - - if packet.SetExtension(1, []byte{ - 0xBB, 0xBB, 0xBB, 0xBB, 0xBB, 0xBB, 0xBB, 0xBB, 0xBB, 0xBB, - 0xBB, 0xBB, 0xBB, 0xBB, 0xBB, 0xBB, 0xBB, - }) == nil { - t.Error("SetExtension did not error on too large payload") - } + assert.Error( + t, + packet.SetExtension(1, []byte{ + 0xBB, 0xBB, 0xBB, 0xBB, 0xBB, 0xBB, 0xBB, 0xBB, 0xBB, 0xBB, + 0xBB, 0xBB, 0xBB, 0xBB, 0xBB, 0xBB, 0xBB, + }), + "SetExtension did not error on too large payload", + ) } func TestRFC8285TwoByteSetExtensionShouldEnableExensionsWhenAdding(t *testing.T) { @@ -1106,26 +956,11 @@ func TestRFC8285TwoByteSetExtensionShouldEnableExensionsWhenAdding(t *testing.T) 0xAA, 0xAA, 0xAA, 0xAA, 0xAA, 0xAA, 0xAA, 0xAA, 0xAA, 0xAA, 0xAA, 0xAA, 0xAA, 0xAA, 0xAA, 0xAA, 0xAA, } - err := packet.SetExtension(1, extension) - if err != nil { - t.Error("Error setting extension") - } - - if packet.Extension != true { - t.Error("Extension should be set to true") - } - - if packet.ExtensionProfile != 0x1000 { - t.Error("Extension profile should be set to 0xBEDE") - } - - if len(packet.Extensions) != 1 { - t.Error("Extensions should be set to 1") - } - - if !bytes.Equal(packet.GetExtension(1), extension) { - t.Error("Extension value is not set") - } + assert.NoError(t, packet.SetExtension(1, extension)) + assert.True(t, packet.Extension) + assert.Equal(t, uint16(0x1000), packet.ExtensionProfile) + assert.Len(t, packet.Extensions, 1) + assert.Equal(t, extension, packet.GetExtension(1)) } func TestRFC8285TwoByteSetExtensionShouldUpdateExistingExension(t *testing.T) { @@ -1152,23 +987,15 @@ func TestRFC8285TwoByteSetExtensionShouldUpdateExistingExension(t *testing.T) { }, Payload: payload, } - - if !bytes.Equal(packet.GetExtension(1), []byte{0xAA}) { - t.Error("Extension value not initialize properly") - } + assert.Equal(t, []byte{0xAA}, packet.GetExtension(1), "Extension value not initialize properly") extension := []byte{ 0xBB, 0xBB, 0xBB, 0xBB, 0xBB, 0xBB, 0xBB, 0xBB, 0xBB, 0xBB, 0xBB, 0xBB, 0xBB, 0xBB, 0xBB, 0xBB, 0xBB, } err := packet.SetExtension(1, extension) - if err != nil { - t.Error("Error setting extension") - } - - if !bytes.Equal(packet.GetExtension(1), extension) { - t.Error("Extension value was not set") - } + assert.NoError(t, err) + assert.Equal(t, packet.GetExtension(1), extension) } func TestRFC8285TwoByteSetExtensionShouldErrorWhenPayloadTooLarge(t *testing.T) { @@ -1196,7 +1023,7 @@ func TestRFC8285TwoByteSetExtensionShouldErrorWhenPayloadTooLarge(t *testing.T) Payload: payload, } - if packet.SetExtension(1, []byte{ + err := packet.SetExtension(1, []byte{ 0xBB, 0xBB, 0xBB, 0xBB, 0xBB, 0xBB, 0xBB, 0xBB, 0xBB, 0xBB, 0xBB, 0xBB, 0xBB, 0xBB, 0xBB, 0xBB, 0xBB, 0xBB, 0xBB, 0xBB, 0xBB, 0xBB, 0xBB, 0xBB, 0xBB, 0xBB, 0xBB, 0xBB, 0xBB, 0xBB, @@ -1223,9 +1050,8 @@ func TestRFC8285TwoByteSetExtensionShouldErrorWhenPayloadTooLarge(t *testing.T) 0xBB, 0xBB, 0xBB, 0xBB, 0xBB, 0xBB, 0xBB, 0xBB, 0xBB, 0xBB, 0xBB, 0xBB, 0xBB, 0xBB, 0xBB, 0xBB, 0xBB, 0xBB, 0xBB, 0xBB, 0xBB, 0xBB, 0xBB, 0xBB, 0xBB, 0xBB, - }) == nil { - t.Error("SetExtension did not error on too large payload") - } + }) + assert.Error(t, err, "SetExtension did not error on too large payload") } func TestRFC8285Padding(t *testing.T) { @@ -1250,9 +1076,7 @@ func TestRFC8285Padding(t *testing.T) { }, } { _, err := header.Unmarshal(payload) - if !errors.Is(err, errHeaderSizeInsufficientForExtension) { - t.Fatal("Expected errHeaderSizeInsufficientForExtension") - } + assert.ErrorIs(t, err, errHeaderSizeInsufficientForExtension) } } @@ -1282,14 +1106,10 @@ func TestRFC3550SetExtensionShouldErrorWhenNonZero(t *testing.T) { } expect := []byte{0xBB} - if packet.SetExtension(0, expect) != nil { - t.Error("SetExtension should not error on valid id") - } + assert.NoError(t, packet.SetExtension(0, expect), "SetExtension should not error on valid id") actual := packet.GetExtension(0) - if !bytes.Equal(actual, expect) { - t.Error("p.GetExtension returned incorrect value.") - } + assert.Equal(t, expect, actual) } func TestRFC3550SetExtensionShouldRaiseErrorWhenSettingNonzeroID(t *testing.T) { @@ -1312,9 +1132,7 @@ func TestRFC3550SetExtensionShouldRaiseErrorWhenSettingNonzeroID(t *testing.T) { Payload: payload, } - if packet.SetExtension(1, []byte{0xBB}) == nil { - t.Error("SetExtension did not error on invalid id") - } + assert.Error(t, packet.SetExtension(1, []byte{0xBB}), "SetExtension should error on invalid id") } func TestUnmarshal_ErrorHandling(t *testing.T) { @@ -1372,9 +1190,7 @@ func TestUnmarshal_ErrorHandling(t *testing.T) { t.Run(name, func(t *testing.T) { h := &Header{} _, err := h.Unmarshal(testCase.input) - if !errors.Is(err, testCase.err) { - t.Errorf("Expected error: %v, got: %v", testCase.err, err) - } + assert.ErrorIs(t, err, testCase.err) }) } } @@ -1387,27 +1203,13 @@ func TestRoundtrip(t *testing.T) { payload := rawPkt[12:] packet := &Packet{} - if err := packet.Unmarshal(rawPkt); err != nil { - t.Fatal(err) - } - if !bytes.Equal(payload, packet.Payload) { - t.Errorf("p.Payload must be same as payload.\n payload: %+v,\np.Payload: %+v", - payload, packet.Payload, - ) - } + assert.NoError(t, packet.Unmarshal(rawPkt)) + assert.Equal(t, payload, packet.Payload) buf, err := packet.Marshal() - if err != nil { - t.Fatal(err) - } - if !bytes.Equal(rawPkt, buf) { - t.Errorf("buf must be same as rawPkt.\n buf: %+v,\nrawPkt: %+v", buf, rawPkt) - } - if !bytes.Equal(payload, packet.Payload) { - t.Errorf("p.Payload must be same as payload.\n payload: %+v,\np.Payload: %+v", - payload, packet.Payload, - ) - } + assert.NoError(t, err) + assert.Equal(t, rawPkt, buf) + assert.Equal(t, payload, packet.Payload) } func TestCloneHeader(t *testing.T) { @@ -1428,18 +1230,12 @@ func TestCloneHeader(t *testing.T) { CSRC: []uint32{}, } clone := header.Clone() - if !reflect.DeepEqual(header, clone) { - t.Errorf("Cloned clone does not match the original") - } + assert.Equal(t, header, clone) header.CSRC = append(header.CSRC, 1) - if len(clone.CSRC) == len(header.CSRC) { - t.Errorf("Expected CSRC to be unchanged") - } + assert.NotEqual(t, len(clone.CSRC), len(header.CSRC), "Expected CSRC to be unchanged") header.Extensions[0].payload[0] = 0x1F - if clone.Extensions[0].payload[0] == 0x1F { - t.Errorf("Expected Extensions to be unchanged") - } + assert.NotEqual(t, clone.Extensions[0].payload[0], byte(0x1F), "Expected extension to be unchanged") } func TestClonePacket(t *testing.T) { @@ -1453,14 +1249,10 @@ func TestClonePacket(t *testing.T) { } clone := packet.Clone() - if !reflect.DeepEqual(packet, clone) { - t.Errorf("Cloned Packet does not match the original") - } + assert.Equal(t, packet, clone) packet.Payload[0] = 0x1F - if clone.Payload[0] == 0x1F { - t.Errorf("Expected Payload to be unchanged") - } + assert.NotEqual(t, clone.Payload[0], 0x1F, "Expected payload to be unchanged") } func BenchmarkMarshal(b *testing.B) { diff --git a/packetizer_test.go b/packetizer_test.go index 765d44e1..d0a33b68 100644 --- a/packetizer_test.go +++ b/packetizer_test.go @@ -5,11 +5,11 @@ package rtp import ( "fmt" - "reflect" "testing" "time" "github.com/pion/rtp/codecs" + "github.com/stretchr/testify/assert" ) func TestPacketizer(t *testing.T) { @@ -18,12 +18,16 @@ func TestPacketizer(t *testing.T) { packetizer := NewPacketizer(100, 98, 0x1234ABCD, &codecs.G722Payloader{}, NewRandomSequencer(), 90000) packets := packetizer.Packetize(multiplepayload, 2000) - if len(packets) != 2 { + expectedLen := 2 + if len(packets) != expectedLen { packetlengths := "" for i := 0; i < len(packets); i++ { packetlengths += fmt.Sprintf("Packet %d length %d\n", i, len(packets[i].Payload)) } - t.Fatalf("Generated %d packets instead of 2\n%s", len(packets), packetlengths) + assert.Failf( + t, "Packetize failed", "Generated %d packets instead of %d\n%s", + len(packets), expectedLen, packetlengths, + ) } } @@ -31,9 +35,7 @@ func TestPacketizer_AbsSendTime(t *testing.T) { // use the G722 payloader here, because it's very simple and all 0s is valid G722 data. pktizer := NewPacketizer(100, 98, 0x1234ABCD, &codecs.G722Payloader{}, NewFixedSequencer(1234), 90000) p, ok := pktizer.(*packetizer) - if !ok { - t.Fatal("Failed to access packetizer") - } + assert.True(t, ok, "Failed to cast to *packetizer") p.Timestamp = 45678 p.timegen = func() time.Time { @@ -67,15 +69,11 @@ func TestPacketizer_AbsSendTime(t *testing.T) { Payload: []byte{0x11, 0x12, 0x13, 0x14}, } - if len(packets) != 1 { - t.Fatalf("Generated %d packets instead of 1", len(packets)) - } - if !reflect.DeepEqual(expected, packets[0]) { - t.Errorf("Packetize failed\nexpected: %v\n got: %v", expected, packets[0]) - } + assert.Lenf(t, packets, 1, "Generated %d packets instead of 1", len(packets)) + assert.Equal(t, expected, packets[0], "Packetize failed") } -func TestPacketizer_Roundtrip(t *testing.T) { //nolint:cyclop +func TestPacketizer_Roundtrip(t *testing.T) { multiplepayload := make([]byte, 128) packetizer := NewPacketizer(100, 98, 0x1234ABCD, &codecs.G722Payloader{}, NewRandomSequencer(), 90000) packets := packetizer.Packetize(multiplepayload, 1000) @@ -83,9 +81,7 @@ func TestPacketizer_Roundtrip(t *testing.T) { //nolint:cyclop rawPkts := make([][]byte, 0, 1400) for _, pkt := range packets { raw, err := pkt.Marshal() - if err != nil { - t.Errorf("Packet Marshal failed: %v", err) - } + assert.NoError(t, err) rawPkts = append(rawPkts, raw) } @@ -94,59 +90,23 @@ func TestPacketizer_Roundtrip(t *testing.T) { //nolint:cyclop expectedPkt := packets[ndx] pkt := &Packet{} - err := pkt.Unmarshal(raw) - if err != nil { - t.Errorf("Packet Unmarshal failed: %v", err) - } - - if len(raw) != pkt.MarshalSize() { - t.Errorf("Packet sizes don't match, expected %d but got %d", len(raw), pkt.MarshalSize()) - } - if expectedPkt.MarshalSize() != pkt.MarshalSize() { - t.Errorf("Packet marshal sizes don't match, expected %d but got %d", expectedPkt.MarshalSize(), pkt.MarshalSize()) - } - - if expectedPkt.Version != pkt.Version { - t.Errorf("Packet versions don't match, expected %d but got %d", expectedPkt.Version, pkt.Version) - } - if expectedPkt.Padding != pkt.Padding { - t.Errorf("Packet versions don't match, expected %t but got %t", expectedPkt.Padding, pkt.Padding) - } - if expectedPkt.Extension != pkt.Extension { - t.Errorf("Packet versions don't match, expected %v but got %v", expectedPkt.Extension, pkt.Extension) - } - if expectedPkt.Marker != pkt.Marker { - t.Errorf("Packet versions don't match, expected %v but got %v", expectedPkt.Marker, pkt.Marker) - } - if expectedPkt.PayloadType != pkt.PayloadType { - t.Errorf("Packet versions don't match, expected %d but got %d", expectedPkt.PayloadType, pkt.PayloadType) - } - if expectedPkt.SequenceNumber != pkt.SequenceNumber { - t.Errorf("Packet versions don't match, expected %d but got %d", expectedPkt.SequenceNumber, pkt.SequenceNumber) - } - if expectedPkt.Timestamp != pkt.Timestamp { - t.Errorf("Packet versions don't match, expected %d but got %d", expectedPkt.Timestamp, pkt.Timestamp) - } - if expectedPkt.SSRC != pkt.SSRC { - t.Errorf("Packet versions don't match, expected %d but got %d", expectedPkt.SSRC, pkt.SSRC) - } - if !reflect.DeepEqual(expectedPkt.CSRC, pkt.CSRC) { - t.Errorf("Packet versions don't match, expected %v but got %v", expectedPkt.CSRC, pkt.CSRC) - } - if expectedPkt.ExtensionProfile != pkt.ExtensionProfile { - t.Errorf("Packet versions don't match, expected %d but got %d", expectedPkt.ExtensionProfile, pkt.ExtensionProfile) - } - if !reflect.DeepEqual(expectedPkt.Extensions, pkt.Extensions) { - t.Errorf("Packet versions don't match, expected %v but got %v", expectedPkt.Extensions, pkt.Extensions) - } - if !reflect.DeepEqual(expectedPkt.Payload, pkt.Payload) { - t.Errorf("Packet versions don't match, expected %v but got %v", expectedPkt.Payload, pkt.Payload) - } + assert.NoError(t, pkt.Unmarshal(raw)) + assert.Equal(t, len(raw), pkt.MarshalSize()) + assert.Equal(t, expectedPkt.MarshalSize(), pkt.MarshalSize()) + assert.Equal(t, expectedPkt.Version, pkt.Version) + assert.Equal(t, expectedPkt.Padding, pkt.Padding) + assert.Equal(t, expectedPkt.Extension, pkt.Extension) + assert.Equal(t, expectedPkt.Marker, pkt.Marker) + assert.Equal(t, expectedPkt.PayloadType, pkt.PayloadType) + assert.Equal(t, expectedPkt.SequenceNumber, pkt.SequenceNumber) + assert.Equal(t, expectedPkt.Timestamp, pkt.Timestamp) + assert.Equal(t, expectedPkt.SSRC, pkt.SSRC) + assert.Equal(t, expectedPkt.CSRC, pkt.CSRC) + assert.Equal(t, expectedPkt.ExtensionProfile, pkt.ExtensionProfile) + assert.Equal(t, expectedPkt.Extensions, pkt.Extensions) + assert.Equal(t, expectedPkt.Payload, pkt.Payload) pkt.PaddingSize = 0 - - if !reflect.DeepEqual(expectedPkt, pkt) { - t.Errorf("Packets don't match, expected %v but got %v", expectedPkt, pkt) - } + assert.Equal(t, expectedPkt, pkt) } } diff --git a/playoutdelayextension_test.go b/playoutdelayextension_test.go index b08cbdbf..d67b6e7d 100644 --- a/playoutdelayextension_test.go +++ b/playoutdelayextension_test.go @@ -4,9 +4,9 @@ package rtp import ( - "bytes" - "errors" "testing" + + "github.com/stretchr/testify/assert" ) func TestPlayoutDelayExtensionTooSmall(t *testing.T) { @@ -14,17 +14,15 @@ func TestPlayoutDelayExtensionTooSmall(t *testing.T) { var rawData []byte - if err := t1.Unmarshal(rawData); !errors.Is(err, errTooSmall) { - t.Fatal("err != errTooSmall") - } + err := t1.Unmarshal(rawData) + assert.ErrorIs(t, err, errTooSmall) } func TestPlayoutDelayExtensionTooLarge(t *testing.T) { t1 := PlayoutDelayExtension{MinDelay: 1 << 12, MaxDelay: 1 << 12} - if _, err := t1.Marshal(); !errors.Is(err, errPlayoutDelayInvalidValue) { - t.Fatal("err != errPlayoutDelayInvalidValue") - } + _, err := t1.Marshal() + assert.ErrorIs(t, err, errPlayoutDelayInvalidValue) } func TestPlayoutDelayExtension(t *testing.T) { @@ -34,22 +32,17 @@ func TestPlayoutDelayExtension(t *testing.T) { 0x01, 0x01, 0x00, } - if err := t1.Unmarshal(rawData); err != nil { - t.Fatal("Unmarshal error on extension data") - } + err := t1.Unmarshal(rawData) + assert.NoError(t, err) t2 := PlayoutDelayExtension{ MinDelay: 1 << 4, MaxDelay: 1 << 8, } - if t1 != t2 { - t.Error("Unmarshal failed") - } + assert.Equal(t, t1, t2) dstData, _ := t2.Marshal() - if !bytes.Equal(dstData, rawData) { - t.Error("Marshal failed") - } + assert.Equal(t, dstData, rawData) } func TestPlayoutDelayExtensionExtraBytes(t *testing.T) { @@ -59,15 +52,12 @@ func TestPlayoutDelayExtensionExtraBytes(t *testing.T) { 0x01, 0x01, 0x00, 0xff, 0xff, } - if err := t1.Unmarshal(rawData); err != nil { - t.Fatal("Unmarshal error on extension data") - } + err := t1.Unmarshal(rawData) + assert.NoError(t, err) t2 := PlayoutDelayExtension{ MinDelay: 1 << 4, MaxDelay: 1 << 8, } - if t1 != t2 { - t.Error("Unmarshal failed") - } + assert.Equal(t, t1, t2) } diff --git a/transportccextension_test.go b/transportccextension_test.go index 5a06f2cd..b03db2a4 100644 --- a/transportccextension_test.go +++ b/transportccextension_test.go @@ -4,9 +4,9 @@ package rtp import ( - "bytes" - "errors" "testing" + + "github.com/stretchr/testify/assert" ) func TestTransportCCExtensionTooSmall(t *testing.T) { @@ -14,9 +14,8 @@ func TestTransportCCExtensionTooSmall(t *testing.T) { rawData := []byte{} - if err := t1.Unmarshal(rawData); !errors.Is(err, errTooSmall) { - t.Fatal("err != errTooSmall") - } + err := t1.Unmarshal(rawData) + assert.ErrorIs(t, err, errTooSmall) } func TestTransportCCExtension(t *testing.T) { @@ -26,22 +25,17 @@ func TestTransportCCExtension(t *testing.T) { 0x00, 0x02, } - if err := t1.Unmarshal(rawData); err != nil { - t.Fatal("Unmarshal error on extension data") - } + err := t1.Unmarshal(rawData) + assert.NoError(t, err) t2 := TransportCCExtension{ TransportSequence: 2, } - if t1 != t2 { - t.Error("Unmarshal failed") - } + assert.Equal(t, t1, t2) dstData, _ := t2.Marshal() - if !bytes.Equal(dstData, rawData) { - t.Error("Marshal failed") - } + assert.Equal(t, dstData, rawData) } func TestTransportCCExtensionExtraBytes(t *testing.T) { @@ -51,15 +45,12 @@ func TestTransportCCExtensionExtraBytes(t *testing.T) { 0x00, 0x02, 0x00, 0xff, 0xff, } - if err := t1.Unmarshal(rawData); err != nil { - t.Fatal("Unmarshal error on extension data") - } + err := t1.Unmarshal(rawData) + assert.NoError(t, err) t2 := TransportCCExtension{ TransportSequence: 2, } - if t1 != t2 { - t.Error("Unmarshal failed") - } + assert.Equal(t, t1, t2) } diff --git a/vlaextension_test.go b/vlaextension_test.go index e25d87af..a9894ed0 100644 --- a/vlaextension_test.go +++ b/vlaextension_test.go @@ -4,22 +4,13 @@ package rtp import ( - "bytes" "encoding/hex" - "errors" - "reflect" "testing" -) - -func TestVLAMarshal(t *testing.T) { //nolint:cyclop - requireNoError := func(t *testing.T, err error) { - t.Helper() - if err != nil { - t.Fatal(err) - } - } + "github.com/stretchr/testify/assert" +) +func TestVLAMarshal(t *testing.T) { t.Run("3 streams no resolution and framerate", func(t *testing.T) { vla := &VLA{ RTPStreamID: 0, @@ -44,12 +35,13 @@ func TestVLAMarshal(t *testing.T) { //nolint:cyclop } bytesActual, err := vla.Marshal() - requireNoError(t, err) + assert.NoError(t, err) bytesExpected, err := hex.DecodeString("21149601f0019003d005b009") - requireNoError(t, err) - if !bytes.Equal(bytesExpected, bytesActual) { - t.Fatalf("expected %s, actual %s", hex.EncodeToString(bytesExpected), hex.EncodeToString(bytesActual)) - } + assert.NoError(t, err) + assert.Equal( + t, bytesExpected, bytesActual, + "expected %s, actual %s", hex.EncodeToString(bytesExpected), hex.EncodeToString(bytesActual), + ) }) t.Run("3 streams with resolution and framerate", func(t *testing.T) { @@ -86,12 +78,13 @@ func TestVLAMarshal(t *testing.T) { //nolint:cyclop } bytesActual, err := vla.Marshal() - requireNoError(t, err) + assert.NoError(t, err) bytesExpected, err := hex.DecodeString("a1149601f0019003d005b009013f00b31e027f01671e04ff02cf1e") - requireNoError(t, err) - if !bytes.Equal(bytesExpected, bytesActual) { - t.Fatalf("expected %s, actual %s", hex.EncodeToString(bytesExpected), hex.EncodeToString(bytesActual)) - } + assert.NoError(t, err) + assert.Equal( + t, bytesExpected, bytesActual, + "expected %s, actual %s", hex.EncodeToString(bytesExpected), hex.EncodeToString(bytesActual), + ) }) t.Run("Negative RTPStreamCount", func(t *testing.T) { @@ -101,9 +94,7 @@ func TestVLAMarshal(t *testing.T) { //nolint:cyclop ActiveSpatialLayer: []SpatialLayer{}, } _, err := vla.Marshal() - if !errors.Is(err, ErrVLAInvalidStreamCount) { - t.Fatal("expected ErrVLAInvalidRTPStreamCount") - } + assert.ErrorIs(t, err, ErrVLAInvalidStreamCount) }) t.Run("RTPStreamCount too large", func(t *testing.T) { @@ -113,9 +104,7 @@ func TestVLAMarshal(t *testing.T) { //nolint:cyclop ActiveSpatialLayer: []SpatialLayer{{}, {}, {}, {}, {}}, } _, err := vla.Marshal() - if !errors.Is(err, ErrVLAInvalidStreamCount) { - t.Fatal("expected ErrVLAInvalidRTPStreamCount") - } + assert.ErrorIs(t, err, ErrVLAInvalidStreamCount) }) t.Run("Negative RTPStreamID", func(t *testing.T) { @@ -125,9 +114,7 @@ func TestVLAMarshal(t *testing.T) { //nolint:cyclop ActiveSpatialLayer: []SpatialLayer{{}}, } _, err := vla.Marshal() - if !errors.Is(err, ErrVLAInvalidStreamID) { - t.Fatalf("expected ErrVLAInvalidRTPStreamID, actual %v", err) - } + assert.ErrorIs(t, err, ErrVLAInvalidStreamID) }) t.Run("RTPStreamID to large", func(t *testing.T) { @@ -137,9 +124,7 @@ func TestVLAMarshal(t *testing.T) { //nolint:cyclop ActiveSpatialLayer: []SpatialLayer{{}}, } _, err := vla.Marshal() - if !errors.Is(err, ErrVLAInvalidStreamID) { - t.Fatalf("expected ErrVLAInvalidRTPStreamID: %v", err) - } + assert.ErrorIs(t, err, ErrVLAInvalidStreamID) }) t.Run("Invalid stream ID in the spatial layer", func(t *testing.T) { @@ -151,9 +136,7 @@ func TestVLAMarshal(t *testing.T) { //nolint:cyclop }}, } _, err := vla.Marshal() - if !errors.Is(err, ErrVLAInvalidStreamID) { - t.Fatalf("expected ErrVLAInvalidStreamID: %v", err) - } + assert.ErrorIs(t, err, ErrVLAInvalidStreamID) vla = &VLA{ RTPStreamID: 0, RTPStreamCount: 1, @@ -162,9 +145,7 @@ func TestVLAMarshal(t *testing.T) { //nolint:cyclop }}, } _, err = vla.Marshal() - if !errors.Is(err, ErrVLAInvalidStreamID) { - t.Fatalf("expected ErrVLAInvalidStreamID: %v", err) - } + assert.ErrorIs(t, err, ErrVLAInvalidStreamID) }) t.Run("Invalid spatial ID in the spatial layer", func(t *testing.T) { @@ -177,9 +158,7 @@ func TestVLAMarshal(t *testing.T) { //nolint:cyclop }}, } _, err := vla.Marshal() - if !errors.Is(err, ErrVLAInvalidSpatialID) { - t.Fatalf("expected ErrVLAInvalidSpatialID: %v", err) - } + assert.ErrorIs(t, err, ErrVLAInvalidSpatialID) vla = &VLA{ RTPStreamID: 0, RTPStreamCount: 1, @@ -189,9 +168,7 @@ func TestVLAMarshal(t *testing.T) { //nolint:cyclop }}, } _, err = vla.Marshal() - if !errors.Is(err, ErrVLAInvalidSpatialID) { - t.Fatalf("expected ErrVLAInvalidSpatialID: %v", err) - } + assert.ErrorIs(t, err, ErrVLAInvalidSpatialID) }) t.Run("Invalid temporal layer in the spatial layer", func(t *testing.T) { @@ -205,9 +182,7 @@ func TestVLAMarshal(t *testing.T) { //nolint:cyclop }}, } _, err := vla.Marshal() - if !errors.Is(err, ErrVLAInvalidTemporalLayer) { - t.Fatalf("expected ErrVLAInvalidTemporalLayer: %v", err) - } + assert.ErrorIs(t, err, ErrVLAInvalidTemporalLayer) vla = &VLA{ RTPStreamID: 0, RTPStreamCount: 1, @@ -218,9 +193,7 @@ func TestVLAMarshal(t *testing.T) { //nolint:cyclop }}, } _, err = vla.Marshal() - if !errors.Is(err, ErrVLAInvalidTemporalLayer) { - t.Fatalf("expected ErrVLAInvalidTemporalLayer: %v", err) - } + assert.ErrorIs(t, err, ErrVLAInvalidTemporalLayer) }) t.Run("Duplicate spatial ID in the spatial layer", func(t *testing.T) { @@ -238,219 +211,170 @@ func TestVLAMarshal(t *testing.T) { //nolint:cyclop }}, } _, err := vla.Marshal() - if !errors.Is(err, ErrVLADuplicateSpatialID) { - t.Fatalf("expected ErrVLADuplicateSpatialID: %v", err) - } + assert.ErrorIs(t, err, ErrVLADuplicateSpatialID) }) } func TestVLAUnmarshal(t *testing.T) { - requireEqualInt := func(t *testing.T, expected, actual int) { - t.Helper() - - if expected != actual { - t.Fatalf("expected %d, actual %d", expected, actual) - } - } - requireNoError := func(t *testing.T, err error) { - t.Helper() - - if err != nil { - t.Fatal(err) - } - } - requireTrue := func(t *testing.T, val bool) { - t.Helper() - - if !val { - t.Fatal("expected true") - } - } - requireFalse := func(t *testing.T, val bool) { - t.Helper() - - if val { - t.Fatal("expected false") - } - } - t.Run("3 streams no resolution and framerate", func(t *testing.T) { // two layer ("low", "high") b, err := hex.DecodeString("21149601f0019003d005b009") - requireNoError(t, err) - if err != nil { - t.Fatal("failed to decode input data") - } + assert.NoError(t, err) vla := &VLA{} n, err := vla.Unmarshal(b) - requireNoError(t, err) - requireEqualInt(t, len(b), n) - - requireEqualInt(t, 0, vla.RTPStreamID) - requireEqualInt(t, 3, vla.RTPStreamCount) - requireEqualInt(t, 3, len(vla.ActiveSpatialLayer)) - - requireEqualInt(t, 0, vla.ActiveSpatialLayer[0].RTPStreamID) - requireEqualInt(t, 0, vla.ActiveSpatialLayer[0].SpatialID) - requireEqualInt(t, 1, len(vla.ActiveSpatialLayer[0].TargetBitrates)) - requireEqualInt(t, 150, vla.ActiveSpatialLayer[0].TargetBitrates[0]) - - requireEqualInt(t, 1, vla.ActiveSpatialLayer[1].RTPStreamID) - requireEqualInt(t, 0, vla.ActiveSpatialLayer[1].SpatialID) - requireEqualInt(t, 2, len(vla.ActiveSpatialLayer[1].TargetBitrates)) - requireEqualInt(t, 240, vla.ActiveSpatialLayer[1].TargetBitrates[0]) - requireEqualInt(t, 400, vla.ActiveSpatialLayer[1].TargetBitrates[1]) - - requireFalse(t, vla.HasResolutionAndFramerate) - - requireEqualInt(t, 2, vla.ActiveSpatialLayer[2].RTPStreamID) - requireEqualInt(t, 0, vla.ActiveSpatialLayer[2].SpatialID) - requireEqualInt(t, 2, len(vla.ActiveSpatialLayer[2].TargetBitrates)) - requireEqualInt(t, 720, vla.ActiveSpatialLayer[2].TargetBitrates[0]) - requireEqualInt(t, 1200, vla.ActiveSpatialLayer[2].TargetBitrates[1]) + assert.NoError(t, err) + assert.Equal(t, len(b), n) + + assert.Equal(t, 0, vla.RTPStreamID) + assert.Equal(t, 3, vla.RTPStreamCount) + assert.Equal(t, 3, len(vla.ActiveSpatialLayer)) + + assert.Equal(t, 0, vla.ActiveSpatialLayer[0].RTPStreamID) + assert.Equal(t, 0, vla.ActiveSpatialLayer[0].SpatialID) + assert.Equal(t, 1, len(vla.ActiveSpatialLayer[0].TargetBitrates)) + assert.Equal(t, 150, vla.ActiveSpatialLayer[0].TargetBitrates[0]) + + assert.Equal(t, 1, vla.ActiveSpatialLayer[1].RTPStreamID) + assert.Equal(t, 0, vla.ActiveSpatialLayer[1].SpatialID) + assert.Equal(t, 2, len(vla.ActiveSpatialLayer[1].TargetBitrates)) + assert.Equal(t, 240, vla.ActiveSpatialLayer[1].TargetBitrates[0]) + assert.Equal(t, 400, vla.ActiveSpatialLayer[1].TargetBitrates[1]) + + assert.False(t, vla.HasResolutionAndFramerate) + + assert.Equal(t, 2, vla.ActiveSpatialLayer[2].RTPStreamID) + assert.Equal(t, 0, vla.ActiveSpatialLayer[2].SpatialID) + assert.Equal(t, 2, len(vla.ActiveSpatialLayer[2].TargetBitrates)) + assert.Equal(t, 720, vla.ActiveSpatialLayer[2].TargetBitrates[0]) + assert.Equal(t, 1200, vla.ActiveSpatialLayer[2].TargetBitrates[1]) }) t.Run("3 streams with resolution and framerate", func(t *testing.T) { b, err := hex.DecodeString("a1149601f0019003d005b009013f00b31e027f01671e04ff02cf1e") - requireNoError(t, err) + assert.NoError(t, err) vla := &VLA{} n, err := vla.Unmarshal(b) - requireNoError(t, err) - requireEqualInt(t, len(b), n) - - requireEqualInt(t, 2, vla.RTPStreamID) - requireEqualInt(t, 3, vla.RTPStreamCount) - - requireEqualInt(t, 0, vla.ActiveSpatialLayer[0].RTPStreamID) - requireEqualInt(t, 0, vla.ActiveSpatialLayer[0].SpatialID) - requireEqualInt(t, 1, len(vla.ActiveSpatialLayer[0].TargetBitrates)) - requireEqualInt(t, 150, vla.ActiveSpatialLayer[0].TargetBitrates[0]) - - requireEqualInt(t, 1, vla.ActiveSpatialLayer[1].RTPStreamID) - requireEqualInt(t, 0, vla.ActiveSpatialLayer[1].SpatialID) - requireEqualInt(t, 2, len(vla.ActiveSpatialLayer[1].TargetBitrates)) - requireEqualInt(t, 240, vla.ActiveSpatialLayer[1].TargetBitrates[0]) - requireEqualInt(t, 400, vla.ActiveSpatialLayer[1].TargetBitrates[1]) - - requireEqualInt(t, 2, vla.ActiveSpatialLayer[2].RTPStreamID) - requireEqualInt(t, 0, vla.ActiveSpatialLayer[2].SpatialID) - requireEqualInt(t, 2, len(vla.ActiveSpatialLayer[2].TargetBitrates)) - requireEqualInt(t, 720, vla.ActiveSpatialLayer[2].TargetBitrates[0]) - requireEqualInt(t, 1200, vla.ActiveSpatialLayer[2].TargetBitrates[1]) - - requireTrue(t, vla.HasResolutionAndFramerate) - - requireEqualInt(t, 320, vla.ActiveSpatialLayer[0].Width) - requireEqualInt(t, 180, vla.ActiveSpatialLayer[0].Height) - requireEqualInt(t, 30, vla.ActiveSpatialLayer[0].Framerate) - requireEqualInt(t, 640, vla.ActiveSpatialLayer[1].Width) - requireEqualInt(t, 360, vla.ActiveSpatialLayer[1].Height) - requireEqualInt(t, 30, vla.ActiveSpatialLayer[1].Framerate) - requireEqualInt(t, 1280, vla.ActiveSpatialLayer[2].Width) - requireEqualInt(t, 720, vla.ActiveSpatialLayer[2].Height) - requireEqualInt(t, 30, vla.ActiveSpatialLayer[2].Framerate) + assert.NoError(t, err) + assert.Equal(t, len(b), n) + + assert.Equal(t, 2, vla.RTPStreamID) + assert.Equal(t, 3, vla.RTPStreamCount) + + assert.Equal(t, 0, vla.ActiveSpatialLayer[0].RTPStreamID) + assert.Equal(t, 0, vla.ActiveSpatialLayer[0].SpatialID) + assert.Equal(t, 1, len(vla.ActiveSpatialLayer[0].TargetBitrates)) + assert.Equal(t, 150, vla.ActiveSpatialLayer[0].TargetBitrates[0]) + + assert.Equal(t, 1, vla.ActiveSpatialLayer[1].RTPStreamID) + assert.Equal(t, 0, vla.ActiveSpatialLayer[1].SpatialID) + assert.Equal(t, 2, len(vla.ActiveSpatialLayer[1].TargetBitrates)) + assert.Equal(t, 240, vla.ActiveSpatialLayer[1].TargetBitrates[0]) + assert.Equal(t, 400, vla.ActiveSpatialLayer[1].TargetBitrates[1]) + + assert.Equal(t, 2, vla.ActiveSpatialLayer[2].RTPStreamID) + assert.Equal(t, 0, vla.ActiveSpatialLayer[2].SpatialID) + assert.Equal(t, 2, len(vla.ActiveSpatialLayer[2].TargetBitrates)) + assert.Equal(t, 720, vla.ActiveSpatialLayer[2].TargetBitrates[0]) + assert.Equal(t, 1200, vla.ActiveSpatialLayer[2].TargetBitrates[1]) + + assert.True(t, vla.HasResolutionAndFramerate) + + assert.Equal(t, 320, vla.ActiveSpatialLayer[0].Width) + assert.Equal(t, 180, vla.ActiveSpatialLayer[0].Height) + assert.Equal(t, 30, vla.ActiveSpatialLayer[0].Framerate) + assert.Equal(t, 640, vla.ActiveSpatialLayer[1].Width) + assert.Equal(t, 360, vla.ActiveSpatialLayer[1].Height) + assert.Equal(t, 30, vla.ActiveSpatialLayer[1].Framerate) + assert.Equal(t, 1280, vla.ActiveSpatialLayer[2].Width) + assert.Equal(t, 720, vla.ActiveSpatialLayer[2].Height) + assert.Equal(t, 30, vla.ActiveSpatialLayer[2].Framerate) }) t.Run("2 streams", func(t *testing.T) { // two layer ("low", "high") b, err := hex.DecodeString("1110c801d005b009") - requireNoError(t, err) + assert.NoError(t, err) vla := &VLA{} n, err := vla.Unmarshal(b) - requireNoError(t, err) - requireEqualInt(t, len(b), n) + assert.NoError(t, err) + assert.Equal(t, len(b), n) - requireEqualInt(t, 0, vla.RTPStreamID) - requireEqualInt(t, 2, vla.RTPStreamCount) - requireEqualInt(t, 2, len(vla.ActiveSpatialLayer)) + assert.Equal(t, 0, vla.RTPStreamID) + assert.Equal(t, 2, vla.RTPStreamCount) + assert.Equal(t, 2, len(vla.ActiveSpatialLayer)) - requireEqualInt(t, 0, vla.ActiveSpatialLayer[0].RTPStreamID) - requireEqualInt(t, 0, vla.ActiveSpatialLayer[0].SpatialID) - requireEqualInt(t, 1, len(vla.ActiveSpatialLayer[0].TargetBitrates)) - requireEqualInt(t, 200, vla.ActiveSpatialLayer[0].TargetBitrates[0]) + assert.Equal(t, 0, vla.ActiveSpatialLayer[0].RTPStreamID) + assert.Equal(t, 0, vla.ActiveSpatialLayer[0].SpatialID) + assert.Equal(t, 1, len(vla.ActiveSpatialLayer[0].TargetBitrates)) + assert.Equal(t, 200, vla.ActiveSpatialLayer[0].TargetBitrates[0]) - requireEqualInt(t, 1, vla.ActiveSpatialLayer[1].RTPStreamID) - requireEqualInt(t, 0, vla.ActiveSpatialLayer[1].SpatialID) - requireEqualInt(t, 2, len(vla.ActiveSpatialLayer[1].TargetBitrates)) - requireEqualInt(t, 720, vla.ActiveSpatialLayer[1].TargetBitrates[0]) - requireEqualInt(t, 1200, vla.ActiveSpatialLayer[1].TargetBitrates[1]) + assert.Equal(t, 1, vla.ActiveSpatialLayer[1].RTPStreamID) + assert.Equal(t, 0, vla.ActiveSpatialLayer[1].SpatialID) + assert.Equal(t, 2, len(vla.ActiveSpatialLayer[1].TargetBitrates)) + assert.Equal(t, 720, vla.ActiveSpatialLayer[1].TargetBitrates[0]) + assert.Equal(t, 1200, vla.ActiveSpatialLayer[1].TargetBitrates[1]) - requireFalse(t, vla.HasResolutionAndFramerate) + assert.False(t, vla.HasResolutionAndFramerate) }) t.Run("3 streams mid paused with resolution and framerate", func(t *testing.T) { b, err := hex.DecodeString("601010109601d005b009013f00b31e04ff02cf1e") - requireNoError(t, err) + assert.NoError(t, err) vla := &VLA{} n, err := vla.Unmarshal(b) - requireNoError(t, err) - requireEqualInt(t, len(b), n) - - requireEqualInt(t, 1, vla.RTPStreamID) - requireEqualInt(t, 3, vla.RTPStreamCount) - - requireEqualInt(t, 0, vla.ActiveSpatialLayer[0].RTPStreamID) - requireEqualInt(t, 0, vla.ActiveSpatialLayer[0].SpatialID) - requireEqualInt(t, 1, len(vla.ActiveSpatialLayer[0].TargetBitrates)) - requireEqualInt(t, 150, vla.ActiveSpatialLayer[0].TargetBitrates[0]) - - requireEqualInt(t, 2, vla.ActiveSpatialLayer[1].RTPStreamID) - requireEqualInt(t, 0, vla.ActiveSpatialLayer[1].SpatialID) - requireEqualInt(t, 2, len(vla.ActiveSpatialLayer[1].TargetBitrates)) - requireEqualInt(t, 720, vla.ActiveSpatialLayer[1].TargetBitrates[0]) - requireEqualInt(t, 1200, vla.ActiveSpatialLayer[1].TargetBitrates[1]) - - requireTrue(t, vla.HasResolutionAndFramerate) - - requireEqualInt(t, 320, vla.ActiveSpatialLayer[0].Width) - requireEqualInt(t, 180, vla.ActiveSpatialLayer[0].Height) - requireEqualInt(t, 30, vla.ActiveSpatialLayer[0].Framerate) - requireEqualInt(t, 1280, vla.ActiveSpatialLayer[1].Width) - requireEqualInt(t, 720, vla.ActiveSpatialLayer[1].Height) - requireEqualInt(t, 30, vla.ActiveSpatialLayer[1].Framerate) + assert.NoError(t, err) + assert.Equal(t, len(b), n) + + assert.Equal(t, 1, vla.RTPStreamID) + assert.Equal(t, 3, vla.RTPStreamCount) + + assert.Equal(t, 0, vla.ActiveSpatialLayer[0].RTPStreamID) + assert.Equal(t, 0, vla.ActiveSpatialLayer[0].SpatialID) + assert.Equal(t, 1, len(vla.ActiveSpatialLayer[0].TargetBitrates)) + assert.Equal(t, 150, vla.ActiveSpatialLayer[0].TargetBitrates[0]) + + assert.Equal(t, 2, vla.ActiveSpatialLayer[1].RTPStreamID) + assert.Equal(t, 0, vla.ActiveSpatialLayer[1].SpatialID) + assert.Equal(t, 2, len(vla.ActiveSpatialLayer[1].TargetBitrates)) + assert.Equal(t, 720, vla.ActiveSpatialLayer[1].TargetBitrates[0]) + assert.Equal(t, 1200, vla.ActiveSpatialLayer[1].TargetBitrates[1]) + + assert.True(t, vla.HasResolutionAndFramerate) + + assert.Equal(t, 320, vla.ActiveSpatialLayer[0].Width) + assert.Equal(t, 180, vla.ActiveSpatialLayer[0].Height) + assert.Equal(t, 30, vla.ActiveSpatialLayer[0].Framerate) + assert.Equal(t, 1280, vla.ActiveSpatialLayer[1].Width) + assert.Equal(t, 720, vla.ActiveSpatialLayer[1].Height) + assert.Equal(t, 30, vla.ActiveSpatialLayer[1].Framerate) }) t.Run("extra 1", func(t *testing.T) { b, err := hex.DecodeString("a0001040ac02f403") - requireNoError(t, err) + assert.NoError(t, err) vla := &VLA{} n, err := vla.Unmarshal(b) - requireNoError(t, err) - requireEqualInt(t, len(b), n) + assert.NoError(t, err) + assert.Equal(t, len(b), n) }) t.Run("extra 2", func(t *testing.T) { b, err := hex.DecodeString("a00010409405cc08") - requireNoError(t, err) + assert.NoError(t, err) vla := &VLA{} n, err := vla.Unmarshal(b) - requireNoError(t, err) - requireEqualInt(t, len(b), n) + assert.NoError(t, err) + assert.Equal(t, len(b), n) }) } -func TestVLAMarshalThenUnmarshal(t *testing.T) { //nolint:cyclop - requireEqualInt := func(t *testing.T, expected, actual int) { - t.Helper() - - if expected != actual { - t.Fatalf("expected %d, actual %d", expected, actual) - } - } - requireNoError := func(t *testing.T, err error) { - t.Helper() - - if err != nil { - t.Fatal(err) - } - } - +func TestVLAMarshalThenUnmarshal(t *testing.T) { t.Run("multiple spatial layers", func(t *testing.T) { var spatialLayers []SpatialLayer for streamID := 0; streamID < 3; streamID++ { @@ -474,16 +398,13 @@ func TestVLAMarshalThenUnmarshal(t *testing.T) { //nolint:cyclop } b, err := vla0.Marshal() - requireNoError(t, err) + assert.NoError(t, err) vla1 := &VLA{} n, err := vla1.Unmarshal(b) - requireNoError(t, err) - requireEqualInt(t, len(b), n) - - if !reflect.DeepEqual(vla0, vla1) { - t.Fatalf("expected %v, actual %v", vla0, vla1) - } + assert.NoError(t, err) + assert.Equal(t, len(b), n) + assert.Equal(t, vla0, vla1) }) t.Run("different spatial layer bitmasks", func(t *testing.T) { @@ -509,26 +430,18 @@ func TestVLAMarshalThenUnmarshal(t *testing.T) { //nolint:cyclop } b, err := vla0.Marshal() - requireNoError(t, err) - if b[0]&0x0f != 0 { - t.Error("expects sl_bm to be 0") - } - if b[1] != 0x13 { - t.Error("expects sl0_bm,sl1_bm to be b0001,b0011") - } - if b[2] != 0x7f { - t.Error("expects sl1_bm,sl2_bm to be b0111,b1111") - } + assert.NoError(t, err) + assert.Equal(t, byte(0x00), b[0]&0x0f, "expects sl_bm to be 0") + assert.Equal(t, byte(0x13), b[1], "expects sl0_bm,sl1_bm to be b0001,b0011") + assert.Equal(t, byte(0x7f), b[2], "expects sl1_bm,sl2_bm to be b0111,b1111") t.Logf("b: %s", hex.EncodeToString(b)) vla1 := &VLA{} n, err := vla1.Unmarshal(b) - requireNoError(t, err) - requireEqualInt(t, len(b), n) + assert.NoError(t, err) + assert.Equal(t, len(b), n) - if !reflect.DeepEqual(vla0, vla1) { - t.Fatalf("expected %v, actual %v", vla0, vla1) - } + assert.Equal(t, vla0, vla1) }) }