diff --git a/examples/date/publisher.go b/examples/date/publisher.go index 6505fcd8..44c9942b 100644 --- a/examples/date/publisher.go +++ b/examples/date/publisher.go @@ -16,10 +16,10 @@ func (p *publisher) SendDatagram(o moqtransport.Object) error { return p.p.SendDatagram(o) } -func (p *publisher) OpenSubgroup(groupID, subgroupID uint64, priority uint8) (*moqtransport.Subgroup, error) { +func (p *publisher) OpenSubgroup(groupID, subgroupID uint64, priority uint8, opts ...moqtransport.SubgroupOption) (*moqtransport.Subgroup, error) { log.Printf("sessionNr: %d, requestID: %d, groupID: %d, subgroupID: %v", p.sessionID, p.requestID, groupID, subgroupID) - return p.p.OpenSubgroup(groupID, subgroupID, priority) + return p.p.OpenSubgroup(groupID, subgroupID, priority, opts...) } func (p *publisher) CloseWithError(code uint64, reason string) error { diff --git a/handler.go b/handler.go index 7d1a40a4..80427843 100644 --- a/handler.go +++ b/handler.go @@ -50,13 +50,29 @@ type ResponseWriter interface { Reject(code uint64, reason string) error } +// SubgroupOption configures optional subgroup properties. +type SubgroupOption func(*subgroupOptions) + +type subgroupOptions struct { + endOfGroup bool +} + +// WithEndOfGroup signals that this subgroup stream will contain the last +// object in the group. Per draft-14, this sets the "Contains End of Group" +// bit in the SUBGROUP_HEADER stream type. +func WithEndOfGroup() SubgroupOption { + return func(o *subgroupOptions) { + o.endOfGroup = true + } +} + // Publisher is the interface implemented by SubscribeResponseWriters type Publisher interface { // SendDatagram sends an object in a datagram. SendDatagram(Object) error // OpenSubgroup opens and returns a new subgroup. - OpenSubgroup(groupID, subgroupID uint64, priority uint8) (*Subgroup, error) + OpenSubgroup(groupID, subgroupID uint64, priority uint8, opts ...SubgroupOption) (*Subgroup, error) // CloseWithError closes the track and sends SUBSCRIBE_DONE with code and // reason. diff --git a/internal/wire/announce_ok_message.go b/internal/wire/announce_ok_message.go deleted file mode 100644 index 67683a21..00000000 --- a/internal/wire/announce_ok_message.go +++ /dev/null @@ -1,30 +0,0 @@ -package wire - -import ( - "log/slog" - - "github.com/quic-go/quic-go/quicvarint" -) - -type AnnounceOkMessage struct { - RequestID uint64 -} - -func (m *AnnounceOkMessage) LogValue() slog.Value { - return slog.GroupValue( - slog.String("type", "announce_ok"), - ) -} - -func (m AnnounceOkMessage) Type() controlMessageType { - return messageTypeAnnounceOk -} - -func (m *AnnounceOkMessage) Append(buf []byte) []byte { - return quicvarint.Append(buf, m.RequestID) -} - -func (m *AnnounceOkMessage) parse(_ Version, data []byte) (err error) { - m.RequestID, _, err = quicvarint.Parse(data) - return err -} diff --git a/internal/wire/client_setup_message_test.go b/internal/wire/client_setup_message_test.go index 776ccac0..d9dc17a7 100644 --- a/internal/wire/client_setup_message_test.go +++ b/internal/wire/client_setup_message_test.go @@ -26,17 +26,17 @@ func TestClientSetupMessageAppend(t *testing.T) { }, { csm: ClientSetupMessage{ - SupportedVersions: []Version{Draft_ietf_moq_transport_00}, + SupportedVersions: []Version{CurrentVersion}, SetupParameters: KVPList{}, }, buf: []byte{}, expect: []byte{ - 0x01, 0xc0, 0x00, 0x00, 0x00, 0xff, 0x00, 0x00, 0x00, 0x00, + 0x01, 0xc0, 0x00, 0x00, 0x00, 0xff, 0x00, 0x00, 0x0e, 0x00, }, }, { csm: ClientSetupMessage{ - SupportedVersions: []Version{Draft_ietf_moq_transport_00}, + SupportedVersions: []Version{CurrentVersion}, SetupParameters: KVPList{ KeyValuePair{ Type: PathParameterKey, @@ -46,7 +46,7 @@ func TestClientSetupMessageAppend(t *testing.T) { }, buf: []byte{}, expect: []byte{ - 0x01, 0xc0, 0x00, 0x00, 0x00, 0xff, 0x00, 0x00, 0x00, 0x01, 0x01, 0x01, 'A', + 0x01, 0xc0, 0x00, 0x00, 0x00, 0xff, 0x00, 0x00, 0x0e, 0x01, 0x01, 0x01, 'A', }, }, } diff --git a/internal/wire/control_message_parser.go b/internal/wire/control_message_parser.go index 15342da6..cc346923 100644 --- a/internal/wire/control_message_parser.go +++ b/internal/wire/control_message_parser.go @@ -67,8 +67,8 @@ func (p *ControlMessageParser) Parse() (ControlMessage, error) { m = &UnsubscribeMessage{} case messageTypeSubscribeUpdate: m = &SubscribeUpdateMessage{} - case messageTypeSubscribeDone: - m = &SubscribeDoneMessage{} + case messageTypePublishDone: + m = &PublishDoneMessage{} case messageTypeFetch: m = &FetchMessage{} @@ -80,29 +80,29 @@ func (p *ControlMessageParser) Parse() (ControlMessage, error) { m = &FetchCancelMessage{} case messageTypeTrackStatus: - m = &TrackStatusRequestMessage{} - case messageTypeTrackStatusOk: m = &TrackStatusMessage{} + case messageTypeTrackStatusOk: + m = &TrackStatusOkMessage{} - case messageTypeAnnounce: - m = &AnnounceMessage{} - case messageTypeAnnounceOk: - m = &AnnounceOkMessage{} - case messageTypeAnnounceError: - m = &AnnounceErrorMessage{} - case messageTypeUnannounce: - m = &UnannounceMessage{} - case messageTypeAnnounceCancel: - m = &AnnounceCancelMessage{} + case messageTypePublishNamespace: + m = &PublishNamespaceMessage{} + case messageTypePublishNamespaceOk: + m = &PublishNamespaceOkMessage{} + case messageTypePublishNamespaceError: + m = &PublishNamespaceErrorMessage{} + case messageTypePublishNamespaceDone: + m = &PublishNamespaceDoneMessage{} + case messageTypePublishNamespaceCancel: + m = &PublishNamespaceCancelMessage{} case messageTypeSubscribeNamespace: - m = &SubscribeAnnouncesMessage{} + m = &SubscribeNamespaceMessage{} case messageTypeSubscribeNamespaceOk: - m = &SubscribeAnnouncesOkMessage{} + m = &SubscribeNamespaceOkMessage{} case messageTypeSubscribeNamespaceError: - m = &SubscribeAnnouncesErrorMessage{} + m = &SubscribeNamespaceErrorMessage{} case messageTypeUnsubscribeNamespace: - m = &UnsubscribeAnnouncesMessage{} + m = &UnsubscribeNamespaceMessage{} default: return nil, fmt.Errorf("%w: %v", errInvalidMessageType, mt) } diff --git a/internal/wire/control_message_type.go b/internal/wire/control_message_type.go index a6bc1931..a1cd585e 100644 --- a/internal/wire/control_message_type.go +++ b/internal/wire/control_message_type.go @@ -22,8 +22,8 @@ const ( messageTypeSubscribeError controlMessageType = 0x05 messageTypeSubscribeUpdate controlMessageType = 0x02 messageTypeUnsubscribe controlMessageType = 0x0a - messageTypeSubscribeDone controlMessageType = 0x0b + messageTypePublishDone controlMessageType = 0x0b messageTypePublish controlMessageType = 0x1d messageTypePublishOk controlMessageType = 0x1e messageTypePublishError controlMessageType = 0x1f @@ -37,11 +37,11 @@ const ( messageTypeTrackStatusOk controlMessageType = 0x0e messageTypeTrackStatusError controlMessageType = 0x0f - messageTypeAnnounce controlMessageType = 0x06 - messageTypeAnnounceOk controlMessageType = 0x07 - messageTypeAnnounceError controlMessageType = 0x08 - messageTypeUnannounce controlMessageType = 0x09 - messageTypeAnnounceCancel controlMessageType = 0x0c + messageTypePublishNamespace controlMessageType = 0x06 + messageTypePublishNamespaceOk controlMessageType = 0x07 + messageTypePublishNamespaceError controlMessageType = 0x08 + messageTypePublishNamespaceDone controlMessageType = 0x09 + messageTypePublishNamespaceCancel controlMessageType = 0x0c messageTypeSubscribeNamespace controlMessageType = 0x11 messageTypeSubscribeNamespaceOk controlMessageType = 0x12 @@ -74,9 +74,9 @@ func (mt controlMessageType) String() string { return "Unsubscribe" case messageTypeSubscribeUpdate: return "SubscribeUpdate" - case messageTypeSubscribeDone: - return "SubscribeDone" + case messageTypePublishDone: + return "PublishDone" case messageTypePublish: return "Publish" case messageTypePublishOk: @@ -100,16 +100,16 @@ func (mt controlMessageType) String() string { case messageTypeTrackStatusError: return "TrackStatusError" - case messageTypeAnnounce: - return "Announce" - case messageTypeAnnounceOk: - return "AnnounceOk" - case messageTypeAnnounceError: - return "AnnounceError" - case messageTypeUnannounce: - return "Unannounce" - case messageTypeAnnounceCancel: - return "AnnounceCancel" + case messageTypePublishNamespace: + return "PublishNamespace" + case messageTypePublishNamespaceOk: + return "PublishNamespaceOk" + case messageTypePublishNamespaceError: + return "PublishNamespaceError" + case messageTypePublishNamespaceDone: + return "PublishNamespaceDone" + case messageTypePublishNamespaceCancel: + return "PublishNamespaceCancel" case messageTypeSubscribeNamespace: return "SubscribeNamespace" diff --git a/internal/wire/object_stream_parser.go b/internal/wire/object_stream_parser.go index aba4b187..41487894 100644 --- a/internal/wire/object_stream_parser.go +++ b/internal/wire/object_stream_parser.go @@ -16,19 +16,47 @@ import ( type StreamType uint64 const ( - StreamTypeFetch StreamType = 0x05 - StreamTypeSubgroupZeroSIDNoExt StreamType = 0x08 - StreamTypeSubgroupZeroSIDExt StreamType = 0x09 - StreamTypeSubgroupNoSIDNoExt StreamType = 0x0a - StreamTypeSubgroupNoSIDExt StreamType = 0x0b - StreamTypeSubgroupSIDNoExt StreamType = 0x0c - StreamTypeSubgroupSIDExt StreamType = 0x0d + StreamTypeFetch StreamType = 0x05 + + // Subgroup header types without End of Group (0x10-0x15) + StreamTypeSubgroupZeroSIDNoExt StreamType = 0x10 + StreamTypeSubgroupZeroSIDExt StreamType = 0x11 + StreamTypeSubgroupNoSIDNoExt StreamType = 0x12 + StreamTypeSubgroupNoSIDExt StreamType = 0x13 + StreamTypeSubgroupSIDNoExt StreamType = 0x14 + StreamTypeSubgroupSIDExt StreamType = 0x15 + + // Subgroup header types with End of Group (0x18-0x1D) + StreamTypeSubgroupZeroSIDNoExtEOG StreamType = 0x18 + StreamTypeSubgroupZeroSIDExtEOG StreamType = 0x19 + StreamTypeSubgroupNoSIDNoExtEOG StreamType = 0x1A + StreamTypeSubgroupNoSIDExtEOG StreamType = 0x1B + StreamTypeSubgroupSIDNoExtEOG StreamType = 0x1C + StreamTypeSubgroupSIDExtEOG StreamType = 0x1D ) var ( errInvalidStreamType = errors.New("invalid stream type") ) +func isSubgroupStreamType(st StreamType) bool { + return (st >= 0x10 && st <= 0x15) || (st >= 0x18 && st <= 0x1D) +} + +func subgroupHasExplicitSID(st StreamType) bool { + return st == StreamTypeSubgroupSIDNoExt || st == StreamTypeSubgroupSIDExt || + st == StreamTypeSubgroupSIDNoExtEOG || st == StreamTypeSubgroupSIDExtEOG +} + +func subgroupSIDIsFirstObjectID(st StreamType) bool { + return st == StreamTypeSubgroupNoSIDNoExt || st == StreamTypeSubgroupNoSIDExt || + st == StreamTypeSubgroupNoSIDNoExtEOG || st == StreamTypeSubgroupNoSIDExtEOG +} + +func subgroupContainsEndOfGroup(st StreamType) bool { + return st >= 0x18 && st <= 0x1D +} + type ObjectStreamParser struct { qlogger *qlog.Logger streamID uint64 @@ -42,6 +70,7 @@ type ObjectStreamParser struct { PublisherPriority uint8 GroupID uint64 SubgroupID uint64 + EndOfGroup bool } func (p *ObjectStreamParser) Type() StreamType { @@ -83,7 +112,7 @@ func NewObjectStreamParser(r io.Reader, streamID uint64, qlogger *qlog.Logger) ( SubgroupID: 0, }, nil } - if streamType >= 0x08 && streamType <= 0x0d { + if isSubgroupStreamType(streamType) { if qlogger != nil { qlogger.Log(moqt.StreamTypeSetEvent{ Owner: moqt.GetOwner(moqt.OwnerRemote), @@ -95,9 +124,9 @@ func NewObjectStreamParser(r io.Reader, streamID uint64, qlogger *qlog.Logger) ( // objects ext := streamType&0x01 > 0 - // Only read subgroup ID from header if type is 0x0c or 0x0d. In all - // other cases, it is either zero or will be read from the first object. - sid := streamType == 0x0c || streamType == 0x0d + // Only read subgroup ID from header if type has explicit SID field. + // In all other cases, it is either zero or will be read from the first object. + sid := subgroupHasExplicitSID(streamType) var shsm SubgroupHeaderMessage if err := shsm.parse(br, sid); err != nil { @@ -109,13 +138,14 @@ func NewObjectStreamParser(r io.Reader, streamID uint64, qlogger *qlog.Logger) ( reader: br, typ: streamType, identifier: shsm.TrackAlias, - // if stream type is 0x0a or 0x0b, we don't yet know the subgroup ID + // if subgroup ID comes from first object ID, we don't yet know it // because it will only be read when the first object is parsed. - hasSubgroupID: streamType != 0x0a && streamType != 0x0b, + hasSubgroupID: !subgroupSIDIsFirstObjectID(streamType), hasExtensions: ext, PublisherPriority: shsm.PublisherPriority, GroupID: shsm.GroupID, SubgroupID: shsm.SubgroupID, + EndOfGroup: subgroupContainsEndOfGroup(streamType), }, nil } return nil, fmt.Errorf("%w: %v", errInvalidStreamType, st) @@ -239,7 +269,7 @@ func (p *ObjectStreamParser) Parse() (*ObjectMessage, error) { if p.typ == StreamTypeFetch { return p.parseFetchObject() } - if p.typ >= 0x08 && p.typ <= 0x0d { + if isSubgroupStreamType(p.typ) { return p.parseSubgroupObject() } return nil, errInvalidStreamType diff --git a/internal/wire/subscribe_done_message.go b/internal/wire/publish_done_message.go similarity index 73% rename from internal/wire/subscribe_done_message.go rename to internal/wire/publish_done_message.go index 7ff58e55..f2f54b7c 100644 --- a/internal/wire/subscribe_done_message.go +++ b/internal/wire/publish_done_message.go @@ -6,16 +6,16 @@ import ( "github.com/quic-go/quic-go/quicvarint" ) -type SubscribeDoneMessage struct { +type PublishDoneMessage struct { RequestID uint64 StatusCode uint64 StreamCount uint64 ReasonPhrase string } -func (m *SubscribeDoneMessage) LogValue() slog.Value { +func (m *PublishDoneMessage) LogValue() slog.Value { return slog.GroupValue( - slog.String("type", "subscribe_done"), + slog.String("type", "publish_done"), slog.Uint64("request_id", m.RequestID), slog.Uint64("status_code", m.StatusCode), slog.Uint64("stream_count", m.StreamCount), @@ -23,11 +23,11 @@ func (m *SubscribeDoneMessage) LogValue() slog.Value { ) } -func (m SubscribeDoneMessage) Type() controlMessageType { - return messageTypeSubscribeDone +func (m PublishDoneMessage) Type() controlMessageType { + return messageTypePublishDone } -func (m *SubscribeDoneMessage) Append(buf []byte) []byte { +func (m *PublishDoneMessage) Append(buf []byte) []byte { buf = quicvarint.Append(buf, m.RequestID) buf = quicvarint.Append(buf, m.StatusCode) buf = quicvarint.Append(buf, m.StreamCount) @@ -35,7 +35,7 @@ func (m *SubscribeDoneMessage) Append(buf []byte) []byte { return buf } -func (m *SubscribeDoneMessage) parse(_ Version, data []byte) (err error) { +func (m *PublishDoneMessage) parse(_ Version, data []byte) (err error) { var n int m.RequestID, n, err = quicvarint.Parse(data) if err != nil { diff --git a/internal/wire/subscribe_done_message_test.go b/internal/wire/publish_done_message_test.go similarity index 82% rename from internal/wire/subscribe_done_message_test.go rename to internal/wire/publish_done_message_test.go index 03c8541f..d768bdd8 100644 --- a/internal/wire/subscribe_done_message_test.go +++ b/internal/wire/publish_done_message_test.go @@ -8,14 +8,14 @@ import ( "github.com/stretchr/testify/assert" ) -func TestSubscribeDoneMessageAppend(t *testing.T) { +func TestPublishDoneMessageAppend(t *testing.T) { cases := []struct { - srm SubscribeDoneMessage + srm PublishDoneMessage buf []byte expect []byte }{ { - srm: SubscribeDoneMessage{ + srm: PublishDoneMessage{ RequestID: 0, StatusCode: 0, StreamCount: 0, @@ -25,7 +25,7 @@ func TestSubscribeDoneMessageAppend(t *testing.T) { expect: []byte{0x00, 0x00, 0x00, 0x00}, }, { - srm: SubscribeDoneMessage{ + srm: PublishDoneMessage{ RequestID: 0, StatusCode: 1, StreamCount: 2, @@ -40,7 +40,7 @@ func TestSubscribeDoneMessageAppend(t *testing.T) { }, }, { - srm: SubscribeDoneMessage{ + srm: PublishDoneMessage{ RequestID: 17, StatusCode: 1, StreamCount: 4, @@ -56,7 +56,7 @@ func TestSubscribeDoneMessageAppend(t *testing.T) { }, }, { - srm: SubscribeDoneMessage{ + srm: PublishDoneMessage{ RequestID: 0, StatusCode: 0, StreamCount: 0, @@ -66,7 +66,7 @@ func TestSubscribeDoneMessageAppend(t *testing.T) { expect: []byte{0x00, 0x00, 0x00, 0x00}, }, { - srm: SubscribeDoneMessage{ + srm: PublishDoneMessage{ RequestID: 0, StatusCode: 1, StreamCount: 2, @@ -81,7 +81,7 @@ func TestSubscribeDoneMessageAppend(t *testing.T) { }, }, { - srm: SubscribeDoneMessage{ + srm: PublishDoneMessage{ RequestID: 17, StatusCode: 1, StreamCount: 2, @@ -105,27 +105,27 @@ func TestSubscribeDoneMessageAppend(t *testing.T) { } } -func TestParseSubscribeDoneMessage(t *testing.T) { +func TestParsePublishDoneMessage(t *testing.T) { cases := []struct { data []byte - expect *SubscribeDoneMessage + expect *PublishDoneMessage err error }{ { data: nil, - expect: &SubscribeDoneMessage{}, + expect: &PublishDoneMessage{}, err: io.EOF, }, { data: []byte{}, - expect: &SubscribeDoneMessage{}, + expect: &PublishDoneMessage{}, err: io.EOF, }, { data: []byte{ 0x00, 0x00, 0x00, 0x00, 0x01, 0x00, 0x00, }, - expect: &SubscribeDoneMessage{ + expect: &PublishDoneMessage{ RequestID: 0, StatusCode: 0, StreamCount: 0, @@ -143,7 +143,7 @@ func TestParseSubscribeDoneMessage(t *testing.T) { 0x02, 0x03, }, - expect: &SubscribeDoneMessage{ + expect: &PublishDoneMessage{ RequestID: 0, StatusCode: 1, StreamCount: 2, @@ -159,7 +159,7 @@ func TestParseSubscribeDoneMessage(t *testing.T) { 0x06, 'r', 'e', 'a', 's', 'o', 'n', 0x00, }, - expect: &SubscribeDoneMessage{ + expect: &PublishDoneMessage{ RequestID: 0, StatusCode: 1, StreamCount: 2, @@ -171,7 +171,7 @@ func TestParseSubscribeDoneMessage(t *testing.T) { data: []byte{ 0x00, 0x00, 0x00, 0x00, }, - expect: &SubscribeDoneMessage{ + expect: &PublishDoneMessage{ RequestID: 0, StatusCode: 0, StreamCount: 0, @@ -182,7 +182,7 @@ func TestParseSubscribeDoneMessage(t *testing.T) { } for i, tc := range cases { t.Run(fmt.Sprintf("%v", i), func(t *testing.T) { - res := &SubscribeDoneMessage{} + res := &PublishDoneMessage{} err := res.parse(CurrentVersion, tc.data) assert.Equal(t, tc.expect, res) if tc.err != nil { diff --git a/internal/wire/announce_cancel_message.go b/internal/wire/publish_namespace_cancel_message.go similarity index 62% rename from internal/wire/announce_cancel_message.go rename to internal/wire/publish_namespace_cancel_message.go index 323a23e6..41abddb7 100644 --- a/internal/wire/announce_cancel_message.go +++ b/internal/wire/publish_namespace_cancel_message.go @@ -6,37 +6,37 @@ import ( "github.com/quic-go/quic-go/quicvarint" ) -type AnnounceCancelMessage struct { +type PublishNamespaceCancelMessage struct { TrackNamespace Tuple ErrorCode uint64 ReasonPhrase string } -func (m *AnnounceCancelMessage) LogValue() slog.Value { +func (m *PublishNamespaceCancelMessage) LogValue() slog.Value { return slog.GroupValue( - slog.String("type", "announce_cancel"), + slog.String("type", "publish_namespace_cancel"), slog.Any("track_namespace", m.TrackNamespace), slog.Uint64("error_code", m.ErrorCode), slog.String("reason", m.ReasonPhrase), ) } -func (m AnnounceCancelMessage) GetTrackNamespace() string { +func (m PublishNamespaceCancelMessage) GetTrackNamespace() string { return m.TrackNamespace.String() } -func (m AnnounceCancelMessage) Type() controlMessageType { - return messageTypeAnnounce +func (m PublishNamespaceCancelMessage) Type() controlMessageType { + return messageTypePublishNamespaceCancel } -func (m *AnnounceCancelMessage) Append(buf []byte) []byte { +func (m *PublishNamespaceCancelMessage) Append(buf []byte) []byte { buf = m.TrackNamespace.append(buf) buf = quicvarint.Append(buf, m.ErrorCode) buf = appendVarIntBytes(buf, []byte(m.ReasonPhrase)) return buf } -func (m *AnnounceCancelMessage) parse(_ Version, data []byte) (err error) { +func (m *PublishNamespaceCancelMessage) parse(_ Version, data []byte) (err error) { var n int m.TrackNamespace, n, err = parseTuple(data) if err != nil { diff --git a/internal/wire/announce_cancel_message_test.go b/internal/wire/publish_namespace_cancel_message_test.go similarity index 78% rename from internal/wire/announce_cancel_message_test.go rename to internal/wire/publish_namespace_cancel_message_test.go index a61e6467..730c2aaf 100644 --- a/internal/wire/announce_cancel_message_test.go +++ b/internal/wire/publish_namespace_cancel_message_test.go @@ -8,14 +8,14 @@ import ( "github.com/stretchr/testify/assert" ) -func TestAnnounceCancelMessageAppend(t *testing.T) { +func TestPublishNamespaceCancelMessageAppend(t *testing.T) { cases := []struct { - aom AnnounceCancelMessage + aom PublishNamespaceCancelMessage buf []byte expect []byte }{ { - aom: AnnounceCancelMessage{ + aom: PublishNamespaceCancelMessage{ TrackNamespace: []string{""}, ErrorCode: 1, ReasonPhrase: "reason", @@ -26,7 +26,7 @@ func TestAnnounceCancelMessageAppend(t *testing.T) { }, }, { - aom: AnnounceCancelMessage{ + aom: PublishNamespaceCancelMessage{ TrackNamespace: []string{"tracknamespace"}, ErrorCode: 1, ReasonPhrase: "reason", @@ -48,22 +48,22 @@ func TestAnnounceCancelMessageAppend(t *testing.T) { } } -func TestParseAnnounceCancelMessage(t *testing.T) { +func TestParsePublishNamespaceCancelMessage(t *testing.T) { cases := []struct { data []byte - expect *AnnounceCancelMessage + expect *PublishNamespaceCancelMessage err error }{ { data: nil, - expect: &AnnounceCancelMessage{}, + expect: &PublishNamespaceCancelMessage{}, err: io.EOF, }, { data: append( []byte{0x01, 0x0E}, append([]byte("tracknamespace"), 0x00, 0x00)..., ), - expect: &AnnounceCancelMessage{ + expect: &PublishNamespaceCancelMessage{ TrackNamespace: []string{"tracknamespace"}, ErrorCode: 0, ReasonPhrase: "", @@ -72,7 +72,7 @@ func TestParseAnnounceCancelMessage(t *testing.T) { }, { data: append([]byte{0x01, 0x05}, append([]byte("track"), []byte{0x01, 0x06, 'r', 'e', 'a', 's', 'o', 'n', 'p', 'h', 'r', 'a', 's', 'e'}...)...), - expect: &AnnounceCancelMessage{ + expect: &PublishNamespaceCancelMessage{ TrackNamespace: []string{"track"}, ErrorCode: 1, ReasonPhrase: "reason", @@ -81,7 +81,7 @@ func TestParseAnnounceCancelMessage(t *testing.T) { }, { data: append([]byte{0x01, 0x0F}, "tracknamespace"...), - expect: &AnnounceCancelMessage{ + expect: &PublishNamespaceCancelMessage{ TrackNamespace: []string{}, }, err: errLengthMismatch, @@ -89,7 +89,7 @@ func TestParseAnnounceCancelMessage(t *testing.T) { } for i, tc := range cases { t.Run(fmt.Sprintf("%v", i), func(t *testing.T) { - res := &AnnounceCancelMessage{} + res := &PublishNamespaceCancelMessage{} err := res.parse(CurrentVersion, tc.data) assert.Equal(t, tc.expect, res) if tc.err != nil { diff --git a/internal/wire/publish_namespace_done_message.go b/internal/wire/publish_namespace_done_message.go new file mode 100644 index 00000000..f404cea5 --- /dev/null +++ b/internal/wire/publish_namespace_done_message.go @@ -0,0 +1,30 @@ +package wire + +import ( + "log/slog" +) + +type PublishNamespaceDoneMessage struct { + TrackNamespace Tuple +} + +func (m *PublishNamespaceDoneMessage) LogValue() slog.Value { + return slog.GroupValue( + slog.String("type", "publish_namespace_done"), + slog.Any("track_namespace", m.TrackNamespace), + ) +} + +func (m PublishNamespaceDoneMessage) Type() controlMessageType { + return messageTypePublishNamespaceDone +} + +func (m *PublishNamespaceDoneMessage) Append(buf []byte) []byte { + buf = m.TrackNamespace.append(buf) + return buf +} + +func (p *PublishNamespaceDoneMessage) parse(_ Version, data []byte) (err error) { + p.TrackNamespace, _, err = parseTuple(data) + return err +} diff --git a/internal/wire/unannounce_message_test.go b/internal/wire/publish_namespace_done_message_test.go similarity index 74% rename from internal/wire/unannounce_message_test.go rename to internal/wire/publish_namespace_done_message_test.go index 47d984d2..dc4d1f49 100644 --- a/internal/wire/unannounce_message_test.go +++ b/internal/wire/publish_namespace_done_message_test.go @@ -8,14 +8,14 @@ import ( "github.com/stretchr/testify/assert" ) -func TestUnannounceMessageAppend(t *testing.T) { +func TestPublishNamespaceDoneMessageAppend(t *testing.T) { cases := []struct { - uam UnannounceMessage + uam PublishNamespaceDoneMessage buf []byte expect []byte }{ { - uam: UnannounceMessage{ + uam: PublishNamespaceDoneMessage{ TrackNamespace: []string{""}, }, buf: []byte{}, @@ -24,7 +24,7 @@ func TestUnannounceMessageAppend(t *testing.T) { }, }, { - uam: UnannounceMessage{ + uam: PublishNamespaceDoneMessage{ TrackNamespace: []string{"tracknamespace"}, }, buf: []byte{0x0a, 0x0b}, @@ -39,34 +39,34 @@ func TestUnannounceMessageAppend(t *testing.T) { } } -func TestParseUnannounceMessage(t *testing.T) { +func TestParsePublishNamespaceDoneMessage(t *testing.T) { cases := []struct { data []byte - expect *UnannounceMessage + expect *PublishNamespaceDoneMessage err error }{ { data: nil, - expect: &UnannounceMessage{}, + expect: &PublishNamespaceDoneMessage{}, err: io.EOF, }, { data: append([]byte{0x01, 0x0E}, "tracknamespace"...), - expect: &UnannounceMessage{ + expect: &PublishNamespaceDoneMessage{ TrackNamespace: []string{"tracknamespace"}, }, err: nil, }, { data: append([]byte{0x01, 0x05}, "tracknamespace"...), - expect: &UnannounceMessage{ + expect: &PublishNamespaceDoneMessage{ TrackNamespace: []string{"track"}, }, err: nil, }, { data: append([]byte{0x01, 0x0F}, "tracknamespace"...), - expect: &UnannounceMessage{ + expect: &PublishNamespaceDoneMessage{ TrackNamespace: []string{}, }, err: errLengthMismatch, @@ -74,7 +74,7 @@ func TestParseUnannounceMessage(t *testing.T) { } for i, tc := range cases { t.Run(fmt.Sprintf("%v", i), func(t *testing.T) { - res := &UnannounceMessage{} + res := &PublishNamespaceDoneMessage{} err := res.parse(CurrentVersion, tc.data) if tc.err != nil { assert.Equal(t, tc.err, err) diff --git a/internal/wire/announce_error_message.go b/internal/wire/publish_namespace_error_message.go similarity index 63% rename from internal/wire/announce_error_message.go rename to internal/wire/publish_namespace_error_message.go index bbf21955..71a3146f 100644 --- a/internal/wire/announce_error_message.go +++ b/internal/wire/publish_namespace_error_message.go @@ -6,32 +6,32 @@ import ( "github.com/quic-go/quic-go/quicvarint" ) -type AnnounceErrorMessage struct { +type PublishNamespaceErrorMessage struct { RequestID uint64 ErrorCode uint64 ReasonPhrase string } -func (m *AnnounceErrorMessage) LogValue() slog.Value { +func (m *PublishNamespaceErrorMessage) LogValue() slog.Value { return slog.GroupValue( - slog.String("type", "announce_error"), + slog.String("type", "publish_namespace_error"), slog.Uint64("error_code", m.ErrorCode), slog.String("reason", m.ReasonPhrase), ) } -func (m AnnounceErrorMessage) Type() controlMessageType { - return messageTypeAnnounceError +func (m PublishNamespaceErrorMessage) Type() controlMessageType { + return messageTypePublishNamespaceError } -func (m *AnnounceErrorMessage) Append(buf []byte) []byte { +func (m *PublishNamespaceErrorMessage) Append(buf []byte) []byte { buf = quicvarint.Append(buf, m.RequestID) buf = quicvarint.Append(buf, m.ErrorCode) buf = appendVarIntBytes(buf, []byte(m.ReasonPhrase)) return buf } -func (m *AnnounceErrorMessage) parse(_ Version, data []byte) (err error) { +func (m *PublishNamespaceErrorMessage) parse(_ Version, data []byte) (err error) { var n int m.RequestID, n, err = quicvarint.Parse(data) if err != nil { diff --git a/internal/wire/announce_error_message_test.go b/internal/wire/publish_namespace_error_message_test.go similarity index 76% rename from internal/wire/announce_error_message_test.go rename to internal/wire/publish_namespace_error_message_test.go index 685e7f8d..ec4391ef 100644 --- a/internal/wire/announce_error_message_test.go +++ b/internal/wire/publish_namespace_error_message_test.go @@ -8,14 +8,14 @@ import ( "github.com/stretchr/testify/assert" ) -func TestAnnounceErrorMessageAppend(t *testing.T) { +func TestPublishNamespaceErrorMessageAppend(t *testing.T) { cases := []struct { - aem AnnounceErrorMessage + aem PublishNamespaceErrorMessage buf []byte expect []byte }{ { - aem: AnnounceErrorMessage{ + aem: PublishNamespaceErrorMessage{ RequestID: 0, ErrorCode: 0, ReasonPhrase: "", @@ -26,7 +26,7 @@ func TestAnnounceErrorMessageAppend(t *testing.T) { }, }, { - aem: AnnounceErrorMessage{ + aem: PublishNamespaceErrorMessage{ RequestID: 1, ErrorCode: 1, ReasonPhrase: "reason", @@ -35,7 +35,7 @@ func TestAnnounceErrorMessageAppend(t *testing.T) { expect: append([]byte{0x01, 0x01, 0x06}, "reason"...), }, { - aem: AnnounceErrorMessage{ + aem: PublishNamespaceErrorMessage{ RequestID: 1, ErrorCode: 1, ReasonPhrase: "reason", @@ -52,20 +52,20 @@ func TestAnnounceErrorMessageAppend(t *testing.T) { } } -func TestParseAnnounceErrorMessage(t *testing.T) { +func TestParsePublishNamespaceErrorMessage(t *testing.T) { cases := []struct { data []byte - expect *AnnounceErrorMessage + expect *PublishNamespaceErrorMessage err error }{ { data: nil, - expect: &AnnounceErrorMessage{}, + expect: &PublishNamespaceErrorMessage{}, err: io.EOF, }, { data: []byte{0x01, 0x03, 0x03, 'e', 'r'}, - expect: &AnnounceErrorMessage{ + expect: &PublishNamespaceErrorMessage{ RequestID: 1, ErrorCode: 3, ReasonPhrase: "", @@ -74,7 +74,7 @@ func TestParseAnnounceErrorMessage(t *testing.T) { }, { data: append([]byte{0x00, 0x01, 0x0d}, "reason phrase"...), - expect: &AnnounceErrorMessage{ + expect: &PublishNamespaceErrorMessage{ RequestID: 0, ErrorCode: 1, ReasonPhrase: "reason phrase", @@ -84,7 +84,7 @@ func TestParseAnnounceErrorMessage(t *testing.T) { } for i, tc := range cases { t.Run(fmt.Sprintf("%v", i), func(t *testing.T) { - res := &AnnounceErrorMessage{} + res := &PublishNamespaceErrorMessage{} err := res.parse(CurrentVersion, tc.data) if tc.err != nil { assert.Equal(t, tc.err, err) diff --git a/internal/wire/announce_message.go b/internal/wire/publish_namespace_message.go similarity index 66% rename from internal/wire/announce_message.go rename to internal/wire/publish_namespace_message.go index 029edae8..92069c28 100644 --- a/internal/wire/announce_message.go +++ b/internal/wire/publish_namespace_message.go @@ -6,15 +6,15 @@ import ( "github.com/quic-go/quic-go/quicvarint" ) -type AnnounceMessage struct { +type PublishNamespaceMessage struct { RequestID uint64 TrackNamespace Tuple Parameters KVPList } -func (m *AnnounceMessage) LogValue() slog.Value { +func (m *PublishNamespaceMessage) LogValue() slog.Value { attrs := []slog.Attr{ - slog.String("type", "announce"), + slog.String("type", "publish_namespace"), slog.Any("track_namespace", m.TrackNamespace), slog.Uint64("number_of_parameters", uint64(len(m.Parameters))), } @@ -26,21 +26,21 @@ func (m *AnnounceMessage) LogValue() slog.Value { return slog.GroupValue(attrs...) } -func (m AnnounceMessage) GetTrackNamespace() string { +func (m PublishNamespaceMessage) GetTrackNamespace() string { return m.TrackNamespace.String() } -func (m AnnounceMessage) Type() controlMessageType { - return messageTypeAnnounce +func (m PublishNamespaceMessage) Type() controlMessageType { + return messageTypePublishNamespace } -func (m *AnnounceMessage) Append(buf []byte) []byte { +func (m *PublishNamespaceMessage) Append(buf []byte) []byte { buf = quicvarint.Append(buf, m.RequestID) buf = m.TrackNamespace.append(buf) return m.Parameters.appendNum(buf) } -func (m *AnnounceMessage) parse(_ Version, data []byte) (err error) { +func (m *PublishNamespaceMessage) parse(_ Version, data []byte) (err error) { var n int m.RequestID, n, err = quicvarint.Parse(data) if err != nil { diff --git a/internal/wire/announce_message_test.go b/internal/wire/publish_namespace_message_test.go similarity index 77% rename from internal/wire/announce_message_test.go rename to internal/wire/publish_namespace_message_test.go index adebd50e..544d85b9 100644 --- a/internal/wire/announce_message_test.go +++ b/internal/wire/publish_namespace_message_test.go @@ -8,14 +8,14 @@ import ( "github.com/stretchr/testify/assert" ) -func TestAnnounceMessageAppend(t *testing.T) { +func TestPublishNamespaceMessageAppend(t *testing.T) { cases := []struct { - am AnnounceMessage + am PublishNamespaceMessage buf []byte expect []byte }{ { - am: AnnounceMessage{ + am: PublishNamespaceMessage{ RequestID: 0, TrackNamespace: []string{""}, Parameters: KVPList{}, @@ -26,7 +26,7 @@ func TestAnnounceMessageAppend(t *testing.T) { }, }, { - am: AnnounceMessage{ + am: PublishNamespaceMessage{ RequestID: 1, TrackNamespace: []string{"tracknamespace"}, Parameters: KVPList{}, @@ -43,25 +43,25 @@ func TestAnnounceMessageAppend(t *testing.T) { } } -func TestParseAnnounceMessage(t *testing.T) { +func TestParsePublishNamespaceMessage(t *testing.T) { cases := []struct { data []byte - expect *AnnounceMessage + expect *PublishNamespaceMessage err error }{ { data: nil, - expect: &AnnounceMessage{}, + expect: &PublishNamespaceMessage{}, err: io.EOF, }, { data: []byte{}, - expect: &AnnounceMessage{}, + expect: &PublishNamespaceMessage{}, err: io.EOF, }, { data: append(append([]byte{0x00, 0x01, 0x09}, "trackname"...), 0x00), - expect: &AnnounceMessage{ + expect: &PublishNamespaceMessage{ RequestID: 0, TrackNamespace: []string{"trackname"}, Parameters: KVPList{}, @@ -71,7 +71,7 @@ func TestParseAnnounceMessage(t *testing.T) { } for i, tc := range cases { t.Run(fmt.Sprintf("%v", i), func(t *testing.T) { - res := &AnnounceMessage{} + res := &PublishNamespaceMessage{} err := res.parse(CurrentVersion, tc.data) assert.Equal(t, tc.expect, res) if tc.err != nil { diff --git a/internal/wire/publish_namespace_ok_message.go b/internal/wire/publish_namespace_ok_message.go new file mode 100644 index 00000000..831189d0 --- /dev/null +++ b/internal/wire/publish_namespace_ok_message.go @@ -0,0 +1,30 @@ +package wire + +import ( + "log/slog" + + "github.com/quic-go/quic-go/quicvarint" +) + +type PublishNamespaceOkMessage struct { + RequestID uint64 +} + +func (m *PublishNamespaceOkMessage) LogValue() slog.Value { + return slog.GroupValue( + slog.String("type", "publish_namespace_ok"), + ) +} + +func (m PublishNamespaceOkMessage) Type() controlMessageType { + return messageTypePublishNamespaceOk +} + +func (m *PublishNamespaceOkMessage) Append(buf []byte) []byte { + return quicvarint.Append(buf, m.RequestID) +} + +func (m *PublishNamespaceOkMessage) parse(_ Version, data []byte) (err error) { + m.RequestID, _, err = quicvarint.Parse(data) + return err +} diff --git a/internal/wire/announce_ok_message_test.go b/internal/wire/publish_namespace_ok_message_test.go similarity index 70% rename from internal/wire/announce_ok_message_test.go rename to internal/wire/publish_namespace_ok_message_test.go index baadab6f..865c3476 100644 --- a/internal/wire/announce_ok_message_test.go +++ b/internal/wire/publish_namespace_ok_message_test.go @@ -8,14 +8,14 @@ import ( "github.com/stretchr/testify/assert" ) -func TestAnnounceOkMessageAppend(t *testing.T) { +func TestPublishNamespaceOkMessageAppend(t *testing.T) { cases := []struct { - aom AnnounceOkMessage + aom PublishNamespaceOkMessage buf []byte expect []byte }{ { - aom: AnnounceOkMessage{ + aom: PublishNamespaceOkMessage{ RequestID: 1, }, buf: []byte{}, @@ -24,7 +24,7 @@ func TestAnnounceOkMessageAppend(t *testing.T) { }, }, { - aom: AnnounceOkMessage{ + aom: PublishNamespaceOkMessage{ RequestID: 1, }, buf: []byte{0x0a, 0x0b}, @@ -39,34 +39,34 @@ func TestAnnounceOkMessageAppend(t *testing.T) { } } -func TestParseAnnounceOkMessage(t *testing.T) { +func TestParsePublishNamespaceOkMessage(t *testing.T) { cases := []struct { data []byte - expect *AnnounceOkMessage + expect *PublishNamespaceOkMessage err error }{ { data: nil, - expect: &AnnounceOkMessage{}, + expect: &PublishNamespaceOkMessage{}, err: io.EOF, }, { data: []byte{0x01}, - expect: &AnnounceOkMessage{ + expect: &PublishNamespaceOkMessage{ RequestID: 1, }, err: nil, }, { data: []byte{0x01}, - expect: &AnnounceOkMessage{ + expect: &PublishNamespaceOkMessage{ RequestID: 1, }, err: nil, }, { data: []byte{}, - expect: &AnnounceOkMessage{ + expect: &PublishNamespaceOkMessage{ RequestID: 0, }, err: io.EOF, @@ -74,7 +74,7 @@ func TestParseAnnounceOkMessage(t *testing.T) { } for i, tc := range cases { t.Run(fmt.Sprintf("%v", i), func(t *testing.T) { - res := &AnnounceOkMessage{} + res := &PublishNamespaceOkMessage{} err := res.parse(CurrentVersion, tc.data) assert.Equal(t, tc.expect, res) if tc.err != nil { diff --git a/internal/wire/server_setup_message_test.go b/internal/wire/server_setup_message_test.go index 99544121..ef6d80ac 100644 --- a/internal/wire/server_setup_message_test.go +++ b/internal/wire/server_setup_message_test.go @@ -99,10 +99,10 @@ func TestParseServerSetupMessage(t *testing.T) { }, { data: []byte{ - 0xc0, 0x00, 0x00, 0x00, 0xff, 0x00, 0x00, 0x00, 0x00, + 0xc0, 0x00, 0x00, 0x00, 0xff, 0x00, 0x00, 0x0e, 0x00, }, expect: &ServerSetupMessage{ - SelectedVersion: Draft_ietf_moq_transport_00, + SelectedVersion: CurrentVersion, SetupParameters: KVPList{}, }, err: nil, diff --git a/internal/wire/stream_header_subgroup_message.go b/internal/wire/stream_header_subgroup_message.go index c251d86e..0eedbbf0 100644 --- a/internal/wire/stream_header_subgroup_message.go +++ b/internal/wire/stream_header_subgroup_message.go @@ -9,10 +9,15 @@ type SubgroupHeaderMessage struct { GroupID uint64 SubgroupID uint64 PublisherPriority uint8 + EndOfGroup bool } func (m *SubgroupHeaderMessage) Append(buf []byte) []byte { - buf = quicvarint.Append(buf, uint64(StreamTypeSubgroupSIDExt)) + st := StreamTypeSubgroupSIDExt + if m.EndOfGroup { + st = StreamTypeSubgroupSIDExtEOG + } + buf = quicvarint.Append(buf, uint64(st)) buf = quicvarint.Append(buf, m.TrackAlias) buf = quicvarint.Append(buf, m.GroupID) buf = quicvarint.Append(buf, m.SubgroupID) diff --git a/internal/wire/stream_header_subgroup_message_test.go b/internal/wire/stream_header_subgroup_message_test.go index 2cde153d..071dd729 100644 --- a/internal/wire/stream_header_subgroup_message_test.go +++ b/internal/wire/stream_header_subgroup_message_test.go @@ -46,6 +46,28 @@ func TestStreamHeaderSubgroupMessageAppend(t *testing.T) { buf: []byte{0x0a, 0x0b}, expect: []byte{0x0a, 0x0b, byte(StreamTypeSubgroupSIDExt), 0x01, 0x02, 0x03, 0x04}, }, + { + shgm: SubgroupHeaderMessage{ + TrackAlias: 1, + GroupID: 2, + SubgroupID: 3, + PublisherPriority: 4, + EndOfGroup: true, + }, + buf: []byte{}, + expect: []byte{byte(StreamTypeSubgroupSIDExtEOG), 0x01, 0x02, 0x03, 0x04}, + }, + { + shgm: SubgroupHeaderMessage{ + TrackAlias: 0, + GroupID: 0, + SubgroupID: 0, + PublisherPriority: 0, + EndOfGroup: true, + }, + buf: []byte{}, + expect: []byte{byte(StreamTypeSubgroupSIDExtEOG), 0x00, 0x00, 0x00, 0x00}, + }, } for i, tc := range cases { t.Run(fmt.Sprintf("%v", i), func(t *testing.T) { @@ -55,6 +77,61 @@ func TestStreamHeaderSubgroupMessageAppend(t *testing.T) { } } +func TestSubgroupStreamTypeHelpers(t *testing.T) { + // All non-EOG subgroup types + for _, st := range []StreamType{0x10, 0x11, 0x12, 0x13, 0x14, 0x15} { + assert.True(t, isSubgroupStreamType(st), "expected 0x%x to be subgroup", st) + assert.False(t, subgroupContainsEndOfGroup(st), "expected 0x%x to not be EOG", st) + } + // All EOG subgroup types + for _, st := range []StreamType{0x18, 0x19, 0x1A, 0x1B, 0x1C, 0x1D} { + assert.True(t, isSubgroupStreamType(st), "expected 0x%x to be subgroup", st) + assert.True(t, subgroupContainsEndOfGroup(st), "expected 0x%x to be EOG", st) + } + // Gap between non-EOG and EOG (0x16, 0x17) should not be valid + for _, st := range []StreamType{0x16, 0x17} { + assert.False(t, isSubgroupStreamType(st), "expected 0x%x to not be subgroup", st) + } + // Values outside the range + for _, st := range []StreamType{0x00, 0x05, 0x08, 0x0d, 0x0f, 0x1E, 0xFF} { + assert.False(t, isSubgroupStreamType(st), "expected 0x%x to not be subgroup", st) + } + // Explicit SID types + assert.True(t, subgroupHasExplicitSID(StreamTypeSubgroupSIDNoExt)) + assert.True(t, subgroupHasExplicitSID(StreamTypeSubgroupSIDExt)) + assert.True(t, subgroupHasExplicitSID(StreamTypeSubgroupSIDNoExtEOG)) + assert.True(t, subgroupHasExplicitSID(StreamTypeSubgroupSIDExtEOG)) + assert.False(t, subgroupHasExplicitSID(StreamTypeSubgroupZeroSIDNoExt)) + assert.False(t, subgroupHasExplicitSID(StreamTypeSubgroupNoSIDNoExt)) + + // SID from first object ID types + assert.True(t, subgroupSIDIsFirstObjectID(StreamTypeSubgroupNoSIDNoExt)) + assert.True(t, subgroupSIDIsFirstObjectID(StreamTypeSubgroupNoSIDExt)) + assert.True(t, subgroupSIDIsFirstObjectID(StreamTypeSubgroupNoSIDNoExtEOG)) + assert.True(t, subgroupSIDIsFirstObjectID(StreamTypeSubgroupNoSIDExtEOG)) + assert.False(t, subgroupSIDIsFirstObjectID(StreamTypeSubgroupSIDExt)) + assert.False(t, subgroupSIDIsFirstObjectID(StreamTypeSubgroupZeroSIDNoExt)) +} + +func TestNewObjectStreamParserEndOfGroup(t *testing.T) { + // Build a stream with EOG subgroup header: type 0x1D (SID+Ext+EOG), + // followed by TrackAlias=1, GroupID=2, SubgroupID=3, Priority=4 + data := []byte{byte(StreamTypeSubgroupSIDExtEOG), 0x01, 0x02, 0x03, 0x04} + parser, err := NewObjectStreamParser(bytes.NewReader(data), 0, nil) + assert.NoError(t, err) + assert.True(t, parser.EndOfGroup) + assert.Equal(t, uint64(1), parser.Identifier()) + assert.Equal(t, uint64(2), parser.GroupID) + assert.Equal(t, uint64(3), parser.SubgroupID) + assert.Equal(t, uint8(4), parser.PublisherPriority) + + // Non-EOG variant + data2 := []byte{byte(StreamTypeSubgroupSIDExt), 0x01, 0x02, 0x03, 0x04} + parser2, err := NewObjectStreamParser(bytes.NewReader(data2), 0, nil) + assert.NoError(t, err) + assert.False(t, parser2.EndOfGroup) +} + func TestParseStreamHeaderSubgroupMessage(t *testing.T) { cases := []struct { data []byte diff --git a/internal/wire/subscribe_announces_ok_message.go b/internal/wire/subscribe_announces_ok_message.go deleted file mode 100644 index 0a43aba3..00000000 --- a/internal/wire/subscribe_announces_ok_message.go +++ /dev/null @@ -1,31 +0,0 @@ -package wire - -import ( - "log/slog" - - "github.com/quic-go/quic-go/quicvarint" -) - -// TODO: Add tests -type SubscribeAnnouncesOkMessage struct { - RequestID uint64 -} - -func (m *SubscribeAnnouncesOkMessage) LogValue() slog.Value { - return slog.GroupValue( - slog.String("type", "subscribe_announces_ok"), - ) -} - -func (m SubscribeAnnouncesOkMessage) Type() controlMessageType { - return messageTypeSubscribeNamespaceOk -} - -func (m *SubscribeAnnouncesOkMessage) Append(buf []byte) []byte { - return quicvarint.Append(buf, m.RequestID) -} - -func (m *SubscribeAnnouncesOkMessage) parse(_ Version, data []byte) (err error) { - m.RequestID, _, err = quicvarint.Parse(data) - return err -} diff --git a/internal/wire/subscribe_announces_error_message.go b/internal/wire/subscribe_namespace_error_message.go similarity index 68% rename from internal/wire/subscribe_announces_error_message.go rename to internal/wire/subscribe_namespace_error_message.go index 1be04c11..3fda20d2 100644 --- a/internal/wire/subscribe_announces_error_message.go +++ b/internal/wire/subscribe_namespace_error_message.go @@ -6,32 +6,31 @@ import ( "github.com/quic-go/quic-go/quicvarint" ) -// TODO: Add tests -type SubscribeAnnouncesErrorMessage struct { +type SubscribeNamespaceErrorMessage struct { RequestID uint64 ErrorCode uint64 ReasonPhrase string } -func (m *SubscribeAnnouncesErrorMessage) LogValue() slog.Value { +func (m *SubscribeNamespaceErrorMessage) LogValue() slog.Value { return slog.GroupValue( - slog.String("type", "subscribe_announces_error"), + slog.String("type", "subscribe_namespace_error"), slog.Uint64("error_code", m.ErrorCode), slog.String("reason", m.ReasonPhrase), ) } -func (m SubscribeAnnouncesErrorMessage) Type() controlMessageType { +func (m SubscribeNamespaceErrorMessage) Type() controlMessageType { return messageTypeSubscribeNamespaceError } -func (m *SubscribeAnnouncesErrorMessage) Append(buf []byte) []byte { +func (m *SubscribeNamespaceErrorMessage) Append(buf []byte) []byte { buf = quicvarint.Append(buf, m.RequestID) buf = quicvarint.Append(buf, m.ErrorCode) return appendVarIntBytes(buf, []byte(m.ReasonPhrase)) } -func (m *SubscribeAnnouncesErrorMessage) parse(_ Version, data []byte) (err error) { +func (m *SubscribeNamespaceErrorMessage) parse(_ Version, data []byte) (err error) { var n int m.RequestID, n, err = quicvarint.Parse(data) if err != nil { diff --git a/internal/wire/subscribe_announces_message.go b/internal/wire/subscribe_namespace_message.go similarity index 72% rename from internal/wire/subscribe_announces_message.go rename to internal/wire/subscribe_namespace_message.go index 23883812..7df6971b 100644 --- a/internal/wire/subscribe_announces_message.go +++ b/internal/wire/subscribe_namespace_message.go @@ -6,16 +6,15 @@ import ( "github.com/quic-go/quic-go/quicvarint" ) -// TODO: Add tests -type SubscribeAnnouncesMessage struct { +type SubscribeNamespaceMessage struct { RequestID uint64 TrackNamespacePrefix Tuple Parameters KVPList } -func (m *SubscribeAnnouncesMessage) LogValue() slog.Value { +func (m *SubscribeNamespaceMessage) LogValue() slog.Value { attrs := []slog.Attr{ - slog.String("type", "subscribe_announces"), + slog.String("type", "subscribe_namespace"), slog.Any("track_namespace_prefix", m.TrackNamespacePrefix), slog.Uint64("number_of_parameters", uint64(len(m.Parameters))), } @@ -27,17 +26,17 @@ func (m *SubscribeAnnouncesMessage) LogValue() slog.Value { return slog.GroupValue(attrs...) } -func (m SubscribeAnnouncesMessage) Type() controlMessageType { +func (m SubscribeNamespaceMessage) Type() controlMessageType { return messageTypeSubscribeNamespace } -func (m *SubscribeAnnouncesMessage) Append(buf []byte) []byte { +func (m *SubscribeNamespaceMessage) Append(buf []byte) []byte { buf = quicvarint.Append(buf, m.RequestID) buf = m.TrackNamespacePrefix.append(buf) return m.Parameters.appendNum(buf) } -func (m *SubscribeAnnouncesMessage) parse(_ Version, data []byte) (err error) { +func (m *SubscribeNamespaceMessage) parse(_ Version, data []byte) (err error) { var n int m.RequestID, n, err = quicvarint.Parse(data) if err != nil { diff --git a/internal/wire/subscribe_namespace_ok_message.go b/internal/wire/subscribe_namespace_ok_message.go new file mode 100644 index 00000000..568cf5c8 --- /dev/null +++ b/internal/wire/subscribe_namespace_ok_message.go @@ -0,0 +1,31 @@ +package wire + +import ( + "log/slog" + + "github.com/quic-go/quic-go/quicvarint" +) + +// TODO: Add tests +type SubscribeNamespaceOkMessage struct { + RequestID uint64 +} + +func (m *SubscribeNamespaceOkMessage) LogValue() slog.Value { + return slog.GroupValue( + slog.String("type", "subscribe_namespace_ok"), + ) +} + +func (m SubscribeNamespaceOkMessage) Type() controlMessageType { + return messageTypeSubscribeNamespaceOk +} + +func (m *SubscribeNamespaceOkMessage) Append(buf []byte) []byte { + return quicvarint.Append(buf, m.RequestID) +} + +func (m *SubscribeNamespaceOkMessage) parse(_ Version, data []byte) (err error) { + m.RequestID, _, err = quicvarint.Parse(data) + return err +} diff --git a/internal/wire/subscribe_update_message.go b/internal/wire/subscribe_update_message.go index 341085f2..e23c17f6 100644 --- a/internal/wire/subscribe_update_message.go +++ b/internal/wire/subscribe_update_message.go @@ -7,22 +7,25 @@ import ( ) type SubscribeUpdateMessage struct { - RequestID uint64 - StartLocation Location - EndGroup uint64 - SubscriberPriority uint8 - Forward uint8 - Parameters KVPList + RequestID uint64 + SubscriptionRequestID uint64 + StartLocation Location + EndGroup uint64 + SubscriberPriority uint8 + Forward uint8 + Parameters KVPList } func (m *SubscribeUpdateMessage) LogValue() slog.Value { attrs := []slog.Attr{ slog.String("type", "subscribe_update"), slog.Uint64("request_id", m.RequestID), + slog.Uint64("subscription_request_id", m.SubscriptionRequestID), slog.Uint64("start_group", m.StartLocation.Group), slog.Uint64("start_object", m.StartLocation.Object), slog.Uint64("end_group", m.EndGroup), - slog.Any("subscriber_priority", m.SubscriberPriority), + slog.Uint64("subscriber_priority", uint64(m.SubscriberPriority)), + slog.Uint64("forward", uint64(m.Forward)), slog.Uint64("number_of_parameters", uint64(len(m.Parameters))), } if len(m.Parameters) > 0 { @@ -39,6 +42,7 @@ func (m SubscribeUpdateMessage) Type() controlMessageType { func (m *SubscribeUpdateMessage) Append(buf []byte) []byte { buf = quicvarint.Append(buf, m.RequestID) + buf = quicvarint.Append(buf, m.SubscriptionRequestID) buf = m.StartLocation.append(buf) buf = quicvarint.Append(buf, m.EndGroup) buf = append(buf, m.SubscriberPriority) @@ -55,6 +59,12 @@ func (m *SubscribeUpdateMessage) parse(v Version, data []byte) (err error) { } data = data[n:] + m.SubscriptionRequestID, n, err = quicvarint.Parse(data) + if err != nil { + return err + } + data = data[n:] + n, err = m.StartLocation.parse(v, data) if err != nil { return err diff --git a/internal/wire/subscribe_update_message_test.go b/internal/wire/subscribe_update_message_test.go index 674ce856..747e60fc 100644 --- a/internal/wire/subscribe_update_message_test.go +++ b/internal/wire/subscribe_update_message_test.go @@ -16,7 +16,8 @@ func TestSubscribeUpdateMessageAppend(t *testing.T) { }{ { sum: SubscribeUpdateMessage{ - RequestID: 0, + RequestID: 2, + SubscriptionRequestID: 1, StartLocation: Location{ Group: 0, Object: 0, @@ -28,38 +29,40 @@ func TestSubscribeUpdateMessageAppend(t *testing.T) { }, buf: []byte{}, expect: []byte{ - 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, + 0x02, 0x01, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, }, }, { sum: SubscribeUpdateMessage{ - RequestID: 1, + RequestID: 2, + SubscriptionRequestID: 1, StartLocation: Location{ - Group: 2, - Object: 3, + Group: 3, + Object: 4, }, - EndGroup: 4, - SubscriberPriority: 5, + EndGroup: 5, + SubscriberPriority: 6, Forward: 1, Parameters: KVPList{KeyValuePair{Type: PathParameterKey, ValueBytes: []byte("A")}}, }, buf: []byte{}, - expect: []byte{0x01, 0x02, 0x03, 0x04, 0x05, 0x01, 0x01, 0x01, 0x01, 'A'}, + expect: []byte{0x02, 0x01, 0x03, 0x04, 0x05, 0x06, 0x01, 0x01, 0x01, 0x01, 'A'}, }, { sum: SubscribeUpdateMessage{ - RequestID: 1, + RequestID: 2, + SubscriptionRequestID: 1, StartLocation: Location{ - Group: 2, - Object: 3, + Group: 3, + Object: 4, }, - EndGroup: 4, - SubscriberPriority: 5, + EndGroup: 5, + SubscriberPriority: 6, Forward: 1, Parameters: KVPList{KeyValuePair{Type: PathParameterKey, ValueBytes: []byte("A")}}, }, buf: []byte{0x0a, 0x0b}, - expect: []byte{0x0a, 0x0b, 0x01, 0x02, 0x03, 0x04, 0x05, 0x01, 0x01, 0x01, 0x01, 'A'}, + expect: []byte{0x0a, 0x0b, 0x02, 0x01, 0x03, 0x04, 0x05, 0x06, 0x01, 0x01, 0x01, 0x01, 'A'}, }, } for i, tc := range cases { @@ -87,12 +90,13 @@ func TestParseSubscribeUpdateMessage(t *testing.T) { err: io.EOF, }, { - data: []byte{0x00, 0x01, 0x02}, + data: []byte{0x02, 0x01, 0x02, 0x03}, expect: &SubscribeUpdateMessage{ - RequestID: 0, + RequestID: 2, + SubscriptionRequestID: 1, StartLocation: Location{ - Group: 1, - Object: 2, + Group: 2, + Object: 3, }, EndGroup: 0, SubscriberPriority: 0, @@ -102,15 +106,16 @@ func TestParseSubscribeUpdateMessage(t *testing.T) { err: io.EOF, }, { - data: []byte{0x01, 0x02, 0x03, 0x04, 0x05, 0x01, 0x01, 0x01, 0x01, 'P'}, + data: []byte{0x02, 0x01, 0x03, 0x04, 0x05, 0x06, 0x01, 0x01, 0x01, 0x01, 'P'}, expect: &SubscribeUpdateMessage{ - RequestID: 1, + RequestID: 2, + SubscriptionRequestID: 1, StartLocation: Location{ - Group: 2, - Object: 3, + Group: 3, + Object: 4, }, - EndGroup: 4, - SubscriberPriority: 5, + EndGroup: 5, + SubscriberPriority: 6, Forward: 1, Parameters: KVPList{KeyValuePair{Type: PathParameterKey, ValueBytes: []byte("P")}}, }, diff --git a/internal/wire/track_status_message.go b/internal/wire/track_status_message.go index 489d5c7c..19c66e46 100644 --- a/internal/wire/track_status_message.go +++ b/internal/wire/track_status_message.go @@ -3,37 +3,41 @@ package wire import ( "log/slog" + "github.com/mengelbart/qlog" "github.com/quic-go/quic-go/quicvarint" ) type TrackStatusMessage struct { - RequestID uint64 - StatusCode uint64 - LargestLocation Location - Parameters KVPList + RequestID uint64 + TrackNamespace Tuple + TrackName []byte + Parameters KVPList } func (m *TrackStatusMessage) LogValue() slog.Value { return slog.GroupValue( slog.String("type", "track_status"), - slog.Uint64("status_code", m.StatusCode), - slog.Uint64("last_group_id", m.LargestLocation.Group), - slog.Uint64("last_object_id", m.LargestLocation.Object), + slog.Any("track_namespace", m.TrackNamespace), + slog.Any("track_name", qlog.RawInfo{ + Length: uint64(len(m.TrackName)), + PayloadLength: uint64(len(m.TrackName)), + Data: []byte(m.TrackName), + }), ) } func (m TrackStatusMessage) Type() controlMessageType { - return messageTypeTrackStatusOk + return messageTypeTrackStatus } func (m *TrackStatusMessage) Append(buf []byte) []byte { buf = quicvarint.Append(buf, m.RequestID) - buf = quicvarint.Append(buf, m.StatusCode) - buf = m.LargestLocation.append(buf) + buf = m.TrackNamespace.append(buf) + buf = appendVarIntBytes(buf, []byte(m.TrackName)) return m.Parameters.appendNum(buf) } -func (m *TrackStatusMessage) parse(v Version, data []byte) (err error) { +func (m *TrackStatusMessage) parse(_ Version, data []byte) (err error) { var n int m.RequestID, n, err = quicvarint.Parse(data) if err != nil { @@ -41,15 +45,15 @@ func (m *TrackStatusMessage) parse(v Version, data []byte) (err error) { } data = data[n:] - m.StatusCode, n, err = quicvarint.Parse(data) + m.TrackNamespace, n, err = parseTuple(data) if err != nil { return } data = data[n:] - n, err = m.LargestLocation.parse(v, data) + m.TrackName, n, err = parseVarIntBytes(data) if err != nil { - return + return err } data = data[n:] diff --git a/internal/wire/track_status_message_test.go b/internal/wire/track_status_message_test.go index 4cf6a578..e28b5186 100644 --- a/internal/wire/track_status_message_test.go +++ b/internal/wire/track_status_message_test.go @@ -10,39 +10,36 @@ import ( func TestTrackStatusMessageAppend(t *testing.T) { cases := []struct { - tsm TrackStatusMessage + aom TrackStatusMessage buf []byte expect []byte }{ { - tsm: TrackStatusMessage{ - RequestID: 0, - StatusCode: 0, - LargestLocation: Location{ - Group: 0, - Object: 0, - }, + aom: TrackStatusMessage{ + RequestID: 0, + TrackNamespace: []string{""}, + TrackName: []byte(""), + Parameters: KVPList{}, + }, + buf: []byte{}, + expect: []byte{ + 0x00, 0x01, 0x00, 0x00, 0x00, }, - buf: []byte{}, - expect: []byte{0x00, 0x00, 0x00, 0x00, 0x00}, }, { - tsm: TrackStatusMessage{ - RequestID: 1, - StatusCode: 2, - LargestLocation: Location{ - Group: 1, - Object: 2, - }, - Parameters: KVPList{}, + aom: TrackStatusMessage{ + RequestID: 0, + TrackNamespace: []string{"tracknamespace"}, + TrackName: []byte("track"), + Parameters: KVPList{}, }, buf: []byte{0x0a, 0x0b}, - expect: []byte{0x0a, 0x0b, 0x01, 0x02, 0x01, 0x02, 0x00}, + expect: []byte{0x0a, 0x0b, 0x00, 0x01, 0x0e, 't', 'r', 'a', 'c', 'k', 'n', 'a', 'm', 'e', 's', 'p', 'a', 'c', 'e', 0x05, 't', 'r', 'a', 'c', 'k', 0x00}, }, } for i, tc := range cases { t.Run(fmt.Sprintf("%v", i), func(t *testing.T) { - res := tc.tsm.Append(tc.buf) + res := tc.aom.Append(tc.buf) assert.Equal(t, tc.expect, res) }) } @@ -60,22 +57,24 @@ func TestParseTrackStatusMessage(t *testing.T) { err: io.EOF, }, { - data: []byte{}, - expect: &TrackStatusMessage{}, - err: io.EOF, + data: []byte{0x00, 0x01, 0x0e, 't', 'r', 'a', 'c', 'k', 'n', 'a', 'm', 'e', 's', 'p', 'a', 'c', 'e', 0x05, 't', 'r', 'a', 'c', 'k', 0x00}, + expect: &TrackStatusMessage{ + RequestID: 0, + TrackNamespace: []string{"tracknamespace"}, + TrackName: []byte("track"), + Parameters: KVPList{}, + }, + err: nil, }, { - data: []byte{0x01, 0x02, 0x03, 0x04, 0x00}, + data: append([]byte{0x00, 0x10}, append([]byte("tracknamespace"), 0x00)...), expect: &TrackStatusMessage{ - RequestID: 1, - StatusCode: 2, - LargestLocation: Location{ - Group: 3, - Object: 4, - }, - Parameters: KVPList{}, + RequestID: 0, + TrackNamespace: []string{}, + TrackName: nil, + Parameters: nil, }, - err: nil, + err: errLengthMismatch, }, } for i, tc := range cases { diff --git a/internal/wire/track_status_ok_message.go b/internal/wire/track_status_ok_message.go new file mode 100644 index 00000000..51a62f21 --- /dev/null +++ b/internal/wire/track_status_ok_message.go @@ -0,0 +1,58 @@ +package wire + +import ( + "log/slog" + + "github.com/quic-go/quic-go/quicvarint" +) + +type TrackStatusOkMessage struct { + RequestID uint64 + StatusCode uint64 + LargestLocation Location + Parameters KVPList +} + +func (m *TrackStatusOkMessage) LogValue() slog.Value { + return slog.GroupValue( + slog.String("type", "track_status_ok"), + slog.Uint64("status_code", m.StatusCode), + slog.Uint64("last_group_id", m.LargestLocation.Group), + slog.Uint64("last_object_id", m.LargestLocation.Object), + ) +} + +func (m TrackStatusOkMessage) Type() controlMessageType { + return messageTypeTrackStatusOk +} + +func (m *TrackStatusOkMessage) Append(buf []byte) []byte { + buf = quicvarint.Append(buf, m.RequestID) + buf = quicvarint.Append(buf, m.StatusCode) + buf = m.LargestLocation.append(buf) + return m.Parameters.appendNum(buf) +} + +func (m *TrackStatusOkMessage) parse(v Version, data []byte) (err error) { + var n int + m.RequestID, n, err = quicvarint.Parse(data) + if err != nil { + return + } + data = data[n:] + + m.StatusCode, n, err = quicvarint.Parse(data) + if err != nil { + return + } + data = data[n:] + + n, err = m.LargestLocation.parse(v, data) + if err != nil { + return + } + data = data[n:] + + m.Parameters = KVPList{} + return m.Parameters.parseNum(data) +} diff --git a/internal/wire/track_status_ok_message_test.go b/internal/wire/track_status_ok_message_test.go new file mode 100644 index 00000000..da945b1e --- /dev/null +++ b/internal/wire/track_status_ok_message_test.go @@ -0,0 +1,93 @@ +package wire + +import ( + "fmt" + "io" + "testing" + + "github.com/stretchr/testify/assert" +) + +func TestTrackStatusOkMessageAppend(t *testing.T) { + cases := []struct { + tsm TrackStatusOkMessage + buf []byte + expect []byte + }{ + { + tsm: TrackStatusOkMessage{ + RequestID: 0, + StatusCode: 0, + LargestLocation: Location{ + Group: 0, + Object: 0, + }, + }, + buf: []byte{}, + expect: []byte{0x00, 0x00, 0x00, 0x00, 0x00}, + }, + { + tsm: TrackStatusOkMessage{ + RequestID: 1, + StatusCode: 2, + LargestLocation: Location{ + Group: 1, + Object: 2, + }, + Parameters: KVPList{}, + }, + buf: []byte{0x0a, 0x0b}, + expect: []byte{0x0a, 0x0b, 0x01, 0x02, 0x01, 0x02, 0x00}, + }, + } + for i, tc := range cases { + t.Run(fmt.Sprintf("%v", i), func(t *testing.T) { + res := tc.tsm.Append(tc.buf) + assert.Equal(t, tc.expect, res) + }) + } +} + +func TestParseTrackStatusOkMessage(t *testing.T) { + cases := []struct { + data []byte + expect *TrackStatusOkMessage + err error + }{ + { + data: nil, + expect: &TrackStatusOkMessage{}, + err: io.EOF, + }, + { + data: []byte{}, + expect: &TrackStatusOkMessage{}, + err: io.EOF, + }, + { + data: []byte{0x01, 0x02, 0x03, 0x04, 0x00}, + expect: &TrackStatusOkMessage{ + RequestID: 1, + StatusCode: 2, + LargestLocation: Location{ + Group: 3, + Object: 4, + }, + Parameters: KVPList{}, + }, + err: nil, + }, + } + for i, tc := range cases { + t.Run(fmt.Sprintf("%v", i), func(t *testing.T) { + res := &TrackStatusOkMessage{} + err := res.parse(CurrentVersion, tc.data) + assert.Equal(t, tc.expect, res) + if tc.err != nil { + assert.Equal(t, tc.err, err) + } else { + assert.NoError(t, err) + } + }) + } +} diff --git a/internal/wire/track_status_request_message.go b/internal/wire/track_status_request_message.go deleted file mode 100644 index 08d8bfc5..00000000 --- a/internal/wire/track_status_request_message.go +++ /dev/null @@ -1,62 +0,0 @@ -package wire - -import ( - "log/slog" - - "github.com/mengelbart/qlog" - "github.com/quic-go/quic-go/quicvarint" -) - -type TrackStatusRequestMessage struct { - RequestID uint64 - TrackNamespace Tuple - TrackName []byte - Parameters KVPList -} - -func (m *TrackStatusRequestMessage) LogValue() slog.Value { - return slog.GroupValue( - slog.String("type", "track_status_request"), - slog.Any("track_namespace", m.TrackNamespace), - slog.Any("track_name", qlog.RawInfo{ - Length: uint64(len(m.TrackName)), - PayloadLength: uint64(len(m.TrackName)), - Data: []byte(m.TrackName), - }), - ) -} - -func (m TrackStatusRequestMessage) Type() controlMessageType { - return messageTypeTrackStatus -} - -func (m *TrackStatusRequestMessage) Append(buf []byte) []byte { - buf = quicvarint.Append(buf, m.RequestID) - buf = m.TrackNamespace.append(buf) - buf = appendVarIntBytes(buf, []byte(m.TrackName)) - return m.Parameters.appendNum(buf) -} - -func (m *TrackStatusRequestMessage) parse(_ Version, data []byte) (err error) { - var n int - m.RequestID, n, err = quicvarint.Parse(data) - if err != nil { - return - } - data = data[n:] - - m.TrackNamespace, n, err = parseTuple(data) - if err != nil { - return - } - data = data[n:] - - m.TrackName, n, err = parseVarIntBytes(data) - if err != nil { - return err - } - data = data[n:] - - m.Parameters = KVPList{} - return m.Parameters.parseNum(data) -} diff --git a/internal/wire/track_status_request_message_test.go b/internal/wire/track_status_request_message_test.go deleted file mode 100644 index 34cedab7..00000000 --- a/internal/wire/track_status_request_message_test.go +++ /dev/null @@ -1,92 +0,0 @@ -package wire - -import ( - "fmt" - "io" - "testing" - - "github.com/stretchr/testify/assert" -) - -func TestTrackStatusRequestMessageAppend(t *testing.T) { - cases := []struct { - aom TrackStatusRequestMessage - buf []byte - expect []byte - }{ - { - aom: TrackStatusRequestMessage{ - RequestID: 0, - TrackNamespace: []string{""}, - TrackName: []byte(""), - Parameters: KVPList{}, - }, - buf: []byte{}, - expect: []byte{ - 0x00, 0x01, 0x00, 0x00, 0x00, - }, - }, - { - aom: TrackStatusRequestMessage{ - RequestID: 0, - TrackNamespace: []string{"tracknamespace"}, - TrackName: []byte("track"), - Parameters: KVPList{}, - }, - buf: []byte{0x0a, 0x0b}, - expect: []byte{0x0a, 0x0b, 0x00, 0x01, 0x0e, 't', 'r', 'a', 'c', 'k', 'n', 'a', 'm', 'e', 's', 'p', 'a', 'c', 'e', 0x05, 't', 'r', 'a', 'c', 'k', 0x00}, - }, - } - for i, tc := range cases { - t.Run(fmt.Sprintf("%v", i), func(t *testing.T) { - res := tc.aom.Append(tc.buf) - assert.Equal(t, tc.expect, res) - }) - } -} - -func TestParseTrackStatusRequestMessage(t *testing.T) { - cases := []struct { - data []byte - expect *TrackStatusRequestMessage - err error - }{ - { - data: nil, - expect: &TrackStatusRequestMessage{}, - err: io.EOF, - }, - { - data: []byte{0x00, 0x01, 0x0e, 't', 'r', 'a', 'c', 'k', 'n', 'a', 'm', 'e', 's', 'p', 'a', 'c', 'e', 0x05, 't', 'r', 'a', 'c', 'k', 0x00}, - expect: &TrackStatusRequestMessage{ - RequestID: 0, - TrackNamespace: []string{"tracknamespace"}, - TrackName: []byte("track"), - Parameters: KVPList{}, - }, - err: nil, - }, - { - data: append([]byte{0x00, 0x10}, append([]byte("tracknamespace"), 0x00)...), - expect: &TrackStatusRequestMessage{ - RequestID: 0, - TrackNamespace: []string{}, - TrackName: nil, - Parameters: nil, - }, - err: errLengthMismatch, - }, - } - for i, tc := range cases { - t.Run(fmt.Sprintf("%v", i), func(t *testing.T) { - res := &TrackStatusRequestMessage{} - err := res.parse(CurrentVersion, tc.data) - assert.Equal(t, tc.expect, res) - if tc.err != nil { - assert.Equal(t, tc.err, err) - } else { - assert.NoError(t, err) - } - }) - } -} diff --git a/internal/wire/unannounce_message.go b/internal/wire/unannounce_message.go deleted file mode 100644 index 232b7ea8..00000000 --- a/internal/wire/unannounce_message.go +++ /dev/null @@ -1,30 +0,0 @@ -package wire - -import ( - "log/slog" -) - -type UnannounceMessage struct { - TrackNamespace Tuple -} - -func (m *UnannounceMessage) LogValue() slog.Value { - return slog.GroupValue( - slog.String("type", "unannounce"), - slog.Any("track_namespace", m.TrackNamespace), - ) -} - -func (m UnannounceMessage) Type() controlMessageType { - return messageTypeUnannounce -} - -func (m *UnannounceMessage) Append(buf []byte) []byte { - buf = m.TrackNamespace.append(buf) - return buf -} - -func (p *UnannounceMessage) parse(_ Version, data []byte) (err error) { - p.TrackNamespace, _, err = parseTuple(data) - return err -} diff --git a/internal/wire/unsubscribe_announces_message.go b/internal/wire/unsubscribe_namespace_message.go similarity index 50% rename from internal/wire/unsubscribe_announces_message.go rename to internal/wire/unsubscribe_namespace_message.go index 634224db..7fddefa6 100644 --- a/internal/wire/unsubscribe_announces_message.go +++ b/internal/wire/unsubscribe_namespace_message.go @@ -4,27 +4,27 @@ import ( "log/slog" ) -// TODO: Add tests -type UnsubscribeAnnouncesMessage struct { +type UnsubscribeNamespaceMessage struct { TrackNamespacePrefix Tuple } -func (m *UnsubscribeAnnouncesMessage) LogValue() slog.Value { +// TODO: Add tests +func (m *UnsubscribeNamespaceMessage) LogValue() slog.Value { return slog.GroupValue( - slog.String("type", "unsubscribe_announces"), + slog.String("type", "unsubscribe_namespace"), slog.Any("track_namespace_prefix", m.TrackNamespacePrefix), ) } -func (m UnsubscribeAnnouncesMessage) Type() controlMessageType { +func (m UnsubscribeNamespaceMessage) Type() controlMessageType { return messageTypeUnsubscribeNamespace } -func (m *UnsubscribeAnnouncesMessage) Append(buf []byte) []byte { +func (m *UnsubscribeNamespaceMessage) Append(buf []byte) []byte { return m.TrackNamespacePrefix.append(buf) } -func (m *UnsubscribeAnnouncesMessage) parse(_ Version, data []byte) (err error) { +func (m *UnsubscribeNamespaceMessage) parse(_ Version, data []byte) (err error) { m.TrackNamespacePrefix, _, err = parseTuple(data) return err } diff --git a/internal/wire/version.go b/internal/wire/version.go index 975ed807..c3fb158c 100644 --- a/internal/wire/version.go +++ b/internal/wire/version.go @@ -9,19 +9,7 @@ import ( type Version uint64 const ( - Draft_ietf_moq_transport_00 Version = 0xff000000 - Draft_ietf_moq_transport_01 Version = 0xff000001 - Draft_ietf_moq_transport_02 Version = 0xff000002 - Draft_ietf_moq_transport_03 Version = 0xff000003 - Draft_ietf_moq_transport_04 Version = 0xff000004 - Draft_ietf_moq_transport_05 Version = 0xff000005 - Draft_ietf_moq_transport_06 Version = 0xff000006 - Draft_ietf_moq_transport_07 Version = 0xff000007 - Draft_ietf_moq_transport_08 Version = 0xff000008 - Draft_ietf_moq_transport_10 Version = 0xff00000a - Draft_ietf_moq_transport_11 Version = 0xff00000b - - CurrentVersion = Draft_ietf_moq_transport_11 + CurrentVersion Version = 0xff00000e // draft-ietf-moq-transport-14 ) var SupportedVersions = []Version{CurrentVersion} diff --git a/local_track.go b/local_track.go index 656483bd..c7d2e026 100644 --- a/local_track.go +++ b/local_track.go @@ -120,16 +120,20 @@ func (p *localTrack) sendDatagram(o Object) error { return p.conn.SendDatagram(buf) } -func (p *localTrack) openSubgroup(groupID, subgroupID uint64, priority uint8) (*Subgroup, error) { +func (p *localTrack) openSubgroup(groupID, subgroupID uint64, priority uint8, opts ...SubgroupOption) (*Subgroup, error) { if err := p.closed(); err != nil { return nil, err } + var o subgroupOptions + for _, opt := range opts { + opt(&o) + } stream, err := p.conn.OpenUniStream() if err != nil { return nil, err } p.subgroupCount++ - return newSubgroup(stream, p.trackAlias, groupID, subgroupID, priority, p.qlogger) + return newSubgroup(stream, p.trackAlias, groupID, subgroupID, priority, o.endOfGroup, p.qlogger) } func (s *localTrack) close(code uint64, reason string) error { diff --git a/messages.go b/messages.go index 438cd2c7..a381b23c 100644 --- a/messages.go +++ b/messages.go @@ -214,7 +214,8 @@ type SubscribeMessage struct { // SubscribeUpdateMessage represents a SUBSCRIBE_UPDATE message from the peer. type SubscribeUpdateMessage struct { - RequestID uint64 + RequestID uint64 + SubscriptionRequestID uint64 // The Request ID of the original SUBSCRIBE message being updated // Subscribe update specific fields StartLocation Location // New start position for the subscription diff --git a/session.go b/session.go index 0895bfe6..acfc459a 100644 --- a/session.go +++ b/session.go @@ -556,12 +556,18 @@ func (s *Session) Subscribe( // - SubscriberPriority: 128 (medium priority) // - Forward: true (forward preference) // - Parameters: empty -func (s *Session) UpdateSubscription(ctx context.Context, requestID uint64, options ...SubscribeUpdateOption) error { +func (s *Session) UpdateSubscription(ctx context.Context, subscriptionRequestID uint64, options ...SubscribeUpdateOption) error { // Validate that the subscription exists - if _, exists := s.remoteTracks.findByRequestID(requestID); !exists { + if _, exists := s.remoteTracks.findByRequestID(subscriptionRequestID); !exists { return errUnknownRequestID } + // Get a new request ID for this update message + newRequestID, err := s.getRequestID() + if err != nil { + return err + } + // Set default values opts := &SubscribeUpdateOptions{ StartLocation: Location{ @@ -581,12 +587,13 @@ func (s *Session) UpdateSubscription(ctx context.Context, requestID uint64, opti // Create and send SUBSCRIBE_UPDATE message cm := &wire.SubscribeUpdateMessage{ - RequestID: requestID, - StartLocation: opts.StartLocation, - EndGroup: opts.EndGroup, - SubscriberPriority: opts.SubscriberPriority, - Forward: boolToUint8(opts.Forward), - Parameters: opts.Parameters.ToWire(), + RequestID: newRequestID, + SubscriptionRequestID: subscriptionRequestID, + StartLocation: opts.StartLocation, + EndGroup: opts.EndGroup, + SubscriberPriority: opts.SubscriberPriority, + Forward: boolToUint8(opts.Forward), + Parameters: opts.Parameters.ToWire(), } return s.controlStream.write(cm) @@ -594,7 +601,7 @@ func (s *Session) UpdateSubscription(ctx context.Context, requestID uint64, opti // acceptSubscriptionWithOptions accepts a subscription with relevant options. func (s *Session) acceptSubscriptionWithOptions(id uint64, opts *SubscribeOkOptions) error { - _, ok := s.localTracks.confirm(id) + lt, ok := s.localTracks.confirm(id) if !ok { return errUnknownRequestID } @@ -612,6 +619,7 @@ func (s *Session) acceptSubscriptionWithOptions(id uint64, opts *SubscribeOkOpti msg := &wire.SubscribeOkMessage{ RequestID: id, + TrackAlias: lt.trackAlias, Expires: opts.Expires, GroupOrder: uint8(opts.GroupOrder), ContentExists: opts.ContentExists, @@ -649,7 +657,7 @@ func (s *Session) subscriptionDone(id, code, count uint64, reason string) error if !ok { return errUnknownRequestID } - return s.controlStream.write(&wire.SubscribeDoneMessage{ + return s.controlStream.write(&wire.PublishDoneMessage{ RequestID: lt.requestID, StatusCode: code, StreamCount: count, @@ -768,7 +776,7 @@ func (s *Session) RequestTrackStatus(ctx context.Context, namespace []string, tr } s.outgoingTrackStatusRequests.add(tsr) - tsrm := &wire.TrackStatusRequestMessage{ + tsrm := &wire.TrackStatusMessage{ TrackNamespace: namespace, TrackName: []byte(track), } @@ -785,7 +793,7 @@ func (s *Session) RequestTrackStatus(ctx context.Context, namespace []string, tr } func (s *Session) sendTrackStatus(ts TrackStatus) error { - return s.controlStream.write(&wire.TrackStatusMessage{ + return s.controlStream.write(&wire.TrackStatusOkMessage{ StatusCode: ts.StatusCode, RequestID: 0, LargestLocation: wire.Location{}, @@ -808,7 +816,7 @@ func (s *Session) Announce(ctx context.Context, namespace []string) error { response: make(chan error, 1), } s.outgoingAnnouncements.add(a) - am := &wire.AnnounceMessage{ + am := &wire.PublishNamespaceMessage{ RequestID: a.requestID, TrackNamespace: a.namespace, Parameters: a.parameters, @@ -829,7 +837,7 @@ func (s *Session) acceptAnnouncement(requestID uint64) error { if _, err := s.incomingAnnouncements.confirmAndGet(requestID); err != nil { return err } - if err := s.controlStream.write(&wire.AnnounceOkMessage{ + if err := s.controlStream.write(&wire.PublishNamespaceOkMessage{ RequestID: requestID, }); err != nil { return err @@ -838,7 +846,7 @@ func (s *Session) acceptAnnouncement(requestID uint64) error { } func (s *Session) rejectAnnouncement(requestID uint64, c uint64, r string) error { - return s.controlStream.write(&wire.AnnounceErrorMessage{ + return s.controlStream.write(&wire.PublishNamespaceErrorMessage{ RequestID: requestID, ErrorCode: c, ReasonPhrase: r, @@ -849,7 +857,7 @@ func (s *Session) Unannounce(ctx context.Context, namespace []string) error { if ok := s.outgoingAnnouncements.delete(namespace); ok { return errUnknownAnnouncementNamespace } - u := &wire.UnannounceMessage{ + u := &wire.PublishNamespaceDoneMessage{ TrackNamespace: namespace, } return s.controlStream.write(u) @@ -859,7 +867,7 @@ func (s *Session) AnnounceCancel(ctx context.Context, namespace []string, errorC if !s.incomingAnnouncements.delete(namespace) { return errUnknownAnnouncementNamespace } - acm := &wire.AnnounceCancelMessage{ + acm := &wire.PublishNamespaceCancelMessage{ TrackNamespace: namespace, ErrorCode: errorCode, ReasonPhrase: reason, @@ -880,7 +888,7 @@ func (s *Session) SubscribeAnnouncements(ctx context.Context, prefix []string) e response: make(chan announcementSubscriptionResponse, 1), } s.pendingOutgointAnnouncementSubscriptions.add(as) - sam := &wire.SubscribeAnnouncesMessage{ + sam := &wire.SubscribeNamespaceMessage{ RequestID: as.requestID, TrackNamespacePrefix: as.namespace, Parameters: wire.KVPList{}, @@ -898,13 +906,13 @@ func (s *Session) SubscribeAnnouncements(ctx context.Context, prefix []string) e } func (s *Session) acceptAnnouncementSubscription(requestID uint64) error { - return s.controlStream.write(&wire.SubscribeAnnouncesOkMessage{ + return s.controlStream.write(&wire.SubscribeNamespaceOkMessage{ RequestID: requestID, }) } func (s *Session) rejectAnnouncementSubscription(requestID uint64, c uint64, r string) error { - return s.controlStream.write(&wire.SubscribeAnnouncesErrorMessage{ + return s.controlStream.write(&wire.SubscribeNamespaceErrorMessage{ RequestID: requestID, ErrorCode: c, ReasonPhrase: r, @@ -913,7 +921,7 @@ func (s *Session) rejectAnnouncementSubscription(requestID uint64, c uint64, r s func (s *Session) UnsubscribeAnnouncements(ctx context.Context, namespace []string) error { s.pendingOutgointAnnouncementSubscriptions.delete(namespace) - uam := &wire.UnsubscribeAnnouncesMessage{ + uam := &wire.UnsubscribeNamespaceMessage{ TrackNamespacePrefix: namespace, } return s.controlStream.write(uam) @@ -952,8 +960,8 @@ func (s *Session) receive(msg wire.ControlMessage) error { err = s.onSubscribeUpdate(m) case *wire.UnsubscribeMessage: err = s.onUnsubscribe(m) - case *wire.SubscribeDoneMessage: - err = s.onSubscribeDone(m) + case *wire.PublishDoneMessage: + err = s.onPublishDone(m) case *wire.FetchMessage: err = s.onFetch(m) case *wire.FetchOkMessage: @@ -962,28 +970,28 @@ func (s *Session) receive(msg wire.ControlMessage) error { err = s.onFetchError(m) case *wire.FetchCancelMessage: err = s.onFetchCancel(m) - case *wire.TrackStatusRequestMessage: - err = s.onTrackStatusRequest(m) case *wire.TrackStatusMessage: err = s.onTrackStatus(m) - case *wire.AnnounceMessage: - err = s.onAnnounce(m) - case *wire.AnnounceOkMessage: - err = s.onAnnounceOk(m) - case *wire.AnnounceErrorMessage: - err = s.onAnnounceError(m) - case *wire.UnannounceMessage: - err = s.onUnannounce(m) - case *wire.AnnounceCancelMessage: - err = s.onAnnounceCancel(m) - case *wire.SubscribeAnnouncesMessage: - err = s.onSubscribeAnnounces(m) - case *wire.SubscribeAnnouncesOkMessage: - err = s.onSubscribeAnnouncesOk(m) - case *wire.SubscribeAnnouncesErrorMessage: - err = s.onSubscribeAnnouncesError(m) - case *wire.UnsubscribeAnnouncesMessage: - s.onUnsubscribeAnnounces(m) + case *wire.TrackStatusOkMessage: + err = s.onTrackStatusOk(m) + case *wire.PublishNamespaceMessage: + err = s.onPublishNamespace(m) + case *wire.PublishNamespaceOkMessage: + err = s.onPublishNamespaceOk(m) + case *wire.PublishNamespaceErrorMessage: + err = s.onPublishNamespaceError(m) + case *wire.PublishNamespaceDoneMessage: + err = s.onPublishNamespaceDone(m) + case *wire.PublishNamespaceCancelMessage: + err = s.onPublishNamespaceCancel(m) + case *wire.SubscribeNamespaceMessage: + err = s.onSubscribeNamespace(m) + case *wire.SubscribeNamespaceOkMessage: + err = s.onSubscribeNamespaceOk(m) + case *wire.SubscribeNamespaceErrorMessage: + err = s.onSubscribeNamespaceError(m) + case *wire.UnsubscribeNamespaceMessage: + s.onUnsubscribeNamespace(m) default: err = errUnexpectedMessageType } @@ -1179,22 +1187,29 @@ func (s *Session) onSubscribeError(msg *wire.SubscribeErrorMessage) error { } func (s *Session) onSubscribeUpdate(msg *wire.SubscribeUpdateMessage) error { - // Find the local track for this request ID to validate it exists - _, ok := s.localTracks.findByID(msg.RequestID) + // Validate Request ID against flow control limit (per draft-14, SUBSCRIBE_UPDATE + // increments the request ID counter and is subject to flow control) + if msg.RequestID >= s.localMaxRequestID.Load() { + return errMaxRequestIDViolated + } + + // Find the local track by SubscriptionRequestID to validate the subscription exists + _, ok := s.localTracks.findByID(msg.SubscriptionRequestID) if !ok { - // According to draft-11, should close session with Protocol Violation - // if Request ID doesn't exist + // According to draft-14, should close session with Protocol Violation + // if Subscription Request ID doesn't exist return errUnknownRequestID } // Convert wire message to public message struct publicMsg := &SubscribeUpdateMessage{ - RequestID: msg.RequestID, - StartLocation: msg.StartLocation, - EndGroup: msg.EndGroup, - SubscriberPriority: msg.SubscriberPriority, - Forward: msg.Forward, - Parameters: FromWire(msg.Parameters), + RequestID: msg.RequestID, + SubscriptionRequestID: msg.SubscriptionRequestID, + StartLocation: msg.StartLocation, + EndGroup: msg.EndGroup, + SubscriberPriority: msg.SubscriberPriority, + Forward: msg.Forward, + Parameters: FromWire(msg.Parameters), } // Propagate to application handler if available @@ -1203,7 +1218,7 @@ func (s *Session) onSubscribeUpdate(msg *wire.SubscribeUpdateMessage) error { } // For now, accept the update without enforcing constraints - // A full implementation would validate narrowing constraints per draft-11 + // A full implementation would validate narrowing constraints per draft-14 return nil } @@ -1218,7 +1233,7 @@ func (s *Session) onUnsubscribe(msg *wire.UnsubscribeMessage) error { return nil } -func (s *Session) onSubscribeDone(msg *wire.SubscribeDoneMessage) error { +func (s *Session) onPublishDone(msg *wire.PublishDoneMessage) error { sub, ok := s.remoteTracks.findByRequestID(msg.RequestID) if !ok { return errUnknownRequestID @@ -1304,7 +1319,7 @@ func (s *Session) onFetchCancel(msg *wire.FetchCancelMessage) error { return nil } -func (s *Session) onTrackStatusRequest(msg *wire.TrackStatusRequestMessage) error { +func (s *Session) onTrackStatus(msg *wire.TrackStatusMessage) error { if len(msg.TrackNamespace) == 0 || len(msg.TrackNamespace) > 32 { return errInvalidNamespaceLength } @@ -1330,7 +1345,7 @@ func (s *Session) onTrackStatusRequest(msg *wire.TrackStatusRequestMessage) erro return nil } -func (s *Session) onTrackStatus(msg *wire.TrackStatusMessage) error { +func (s *Session) onTrackStatusOk(msg *wire.TrackStatusOkMessage) error { tsr, ok := s.outgoingTrackStatusRequests.delete(msg.RequestID) if !ok { return errUnknownTrackStatusRequest @@ -1349,7 +1364,7 @@ func (s *Session) onTrackStatus(msg *wire.TrackStatusMessage) error { return nil } -func (s *Session) onAnnounce(msg *wire.AnnounceMessage) error { +func (s *Session) onPublishNamespace(msg *wire.PublishNamespaceMessage) error { if len(msg.TrackNamespace) == 0 || len(msg.TrackNamespace) > 32 { return errInvalidNamespaceLength } @@ -1377,7 +1392,7 @@ func (s *Session) onAnnounce(msg *wire.AnnounceMessage) error { return nil } -func (s *Session) onAnnounceOk(msg *wire.AnnounceOkMessage) error { +func (s *Session) onPublishNamespaceOk(msg *wire.PublishNamespaceOkMessage) error { announcement, err := s.outgoingAnnouncements.confirmAndGet(msg.RequestID) if err != nil { return errUnknownAnnouncement @@ -1390,7 +1405,7 @@ func (s *Session) onAnnounceOk(msg *wire.AnnounceOkMessage) error { return nil } -func (s *Session) onAnnounceError(msg *wire.AnnounceErrorMessage) error { +func (s *Session) onPublishNamespaceError(msg *wire.PublishNamespaceErrorMessage) error { announcement, ok := s.outgoingAnnouncements.reject(msg.RequestID) if !ok { return errUnknownAnnouncement @@ -1406,7 +1421,7 @@ func (s *Session) onAnnounceError(msg *wire.AnnounceErrorMessage) error { return nil } -func (s *Session) onUnannounce(msg *wire.UnannounceMessage) error { +func (s *Session) onPublishNamespaceDone(msg *wire.PublishNamespaceDoneMessage) error { if len(msg.TrackNamespace) == 0 || len(msg.TrackNamespace) > 32 { return errInvalidNamespaceLength } @@ -1420,7 +1435,7 @@ func (s *Session) onUnannounce(msg *wire.UnannounceMessage) error { return nil } -func (s *Session) onAnnounceCancel(msg *wire.AnnounceCancelMessage) error { +func (s *Session) onPublishNamespaceCancel(msg *wire.PublishNamespaceCancelMessage) error { if len(msg.TrackNamespace) == 0 || len(msg.TrackNamespace) > 32 { return errInvalidNamespaceLength } @@ -1433,7 +1448,7 @@ func (s *Session) onAnnounceCancel(msg *wire.AnnounceCancelMessage) error { return nil } -func (s *Session) onSubscribeAnnounces(msg *wire.SubscribeAnnouncesMessage) error { +func (s *Session) onSubscribeNamespace(msg *wire.SubscribeNamespaceMessage) error { s.pendingIncomingAnnouncementSubscriptions.add(&announcementSubscription{ requestID: msg.RequestID, namespace: msg.TrackNamespacePrefix, @@ -1455,7 +1470,7 @@ func (s *Session) onSubscribeAnnounces(msg *wire.SubscribeAnnouncesMessage) erro return nil } -func (s *Session) onSubscribeAnnouncesOk(msg *wire.SubscribeAnnouncesOkMessage) error { +func (s *Session) onSubscribeNamespaceOk(msg *wire.SubscribeNamespaceOkMessage) error { as, ok := s.pendingOutgointAnnouncementSubscriptions.deleteByID(msg.RequestID) if !ok { return errUnknownSubscribeAnnouncesPrefix @@ -1465,12 +1480,12 @@ func (s *Session) onSubscribeAnnouncesOk(msg *wire.SubscribeAnnouncesOkMessage) err: nil, }: default: - s.logger.Info("dropping unhandled SubscribeAnnounces response") + s.logger.Info("dropping unhandled SubscribeNamespace response") } return nil } -func (s *Session) onSubscribeAnnouncesError(msg *wire.SubscribeAnnouncesErrorMessage) error { +func (s *Session) onSubscribeNamespaceError(msg *wire.SubscribeNamespaceErrorMessage) error { as, ok := s.pendingOutgointAnnouncementSubscriptions.deleteByID(msg.RequestID) if !ok { return errUnknownSubscribeAnnouncesPrefix @@ -1483,12 +1498,12 @@ func (s *Session) onSubscribeAnnouncesError(msg *wire.SubscribeAnnouncesErrorMes }, }: default: - s.logger.Info("dropping unhandled SubscribeAnnounces response") + s.logger.Info("dropping unhandled SubscribeNamespace response") } return nil } -func (s *Session) onUnsubscribeAnnounces(msg *wire.UnsubscribeAnnouncesMessage) { +func (s *Session) onUnsubscribeNamespace(msg *wire.UnsubscribeNamespaceMessage) { s.Handler.Handle(nil, &Message{ Method: MessageUnsubscribeAnnounces, Namespace: msg.TrackNamespacePrefix, diff --git a/session_test.go b/session_test.go index 07196d4e..531e2f24 100644 --- a/session_test.go +++ b/session_test.go @@ -371,12 +371,12 @@ func TestSession(t *testing.T) { s := newSession(conn, cs, nil) s.handshakeDone.Store(true) - cs.EXPECT().write(&wire.AnnounceMessage{ + cs.EXPECT().write(&wire.PublishNamespaceMessage{ RequestID: 0, TrackNamespace: []string{"namespace"}, Parameters: wire.KVPList{}, }).DoAndReturn(func(_ wire.ControlMessage) error { - err := s.receive(&wire.AnnounceOkMessage{ + err := s.receive(&wire.PublishNamespaceOkMessage{ RequestID: 0, }) assert.NoError(t, err) @@ -404,10 +404,10 @@ func TestSession(t *testing.T) { }).DoAndReturn(func(rw ResponseWriter, req *Message) { assert.NoError(t, rw.Accept()) }) - cs.EXPECT().write(&wire.AnnounceOkMessage{ + cs.EXPECT().write(&wire.PublishNamespaceOkMessage{ RequestID: 2, }) - err := s.receive(&wire.AnnounceMessage{ + err := s.receive(&wire.PublishNamespaceMessage{ RequestID: 2, TrackNamespace: []string{"namespace"}, Parameters: wire.KVPList{}, @@ -503,15 +503,21 @@ func TestSession_UpdateSubscription(t *testing.T) { _, err := s.remoteTracks.confirm(123) assert.NoError(t, err) - // Expect SUBSCRIBE_UPDATE message to be written - cs.EXPECT().write(&wire.SubscribeUpdateMessage{ - RequestID: 123, - StartLocation: wire.Location{Group: 100, Object: 5}, - EndGroup: 200, - SubscriberPriority: 64, - Forward: 1, - Parameters: wire.KVPList{}, - }).Return(nil) + // Set max request ID to allow getting a new request ID + _ = s.requestIDs.setMax(100) + + // Expect SUBSCRIBE_UPDATE message to be written with a new RequestID + // The SubscriptionRequestID is 123 (the original subscription) + cs.EXPECT().write(gomock.Any()).DoAndReturn(func(msg wire.ControlMessage) error { + sum, ok := msg.(*wire.SubscribeUpdateMessage) + assert.True(t, ok) + assert.Equal(t, uint64(123), sum.SubscriptionRequestID) + assert.Equal(t, wire.Location{Group: 100, Object: 5}, sum.StartLocation) + assert.Equal(t, uint64(200), sum.EndGroup) + assert.Equal(t, uint8(64), sum.SubscriberPriority) + assert.Equal(t, uint8(1), sum.Forward) + return nil + }) // Test UpdateSubscription err = s.UpdateSubscription(context.Background(), 123, diff --git a/subgroup.go b/subgroup.go index 8d0266a7..98ee39e1 100644 --- a/subgroup.go +++ b/subgroup.go @@ -14,12 +14,13 @@ type Subgroup struct { subgroupID uint64 } -func newSubgroup(stream SendStream, trackAlias, groupID, subgroupID uint64, publisherPriority uint8, qlogger *qlog.Logger) (*Subgroup, error) { +func newSubgroup(stream SendStream, trackAlias, groupID, subgroupID uint64, publisherPriority uint8, endOfGroup bool, qlogger *qlog.Logger) (*Subgroup, error) { shgm := &wire.SubgroupHeaderMessage{ TrackAlias: trackAlias, GroupID: groupID, SubgroupID: subgroupID, PublisherPriority: publisherPriority, + EndOfGroup: endOfGroup, } buf := make([]byte, 0, 40) buf = shgm.Append(buf) diff --git a/subscribe_response_writer.go b/subscribe_response_writer.go index 43cd19b6..c59d07f7 100644 --- a/subscribe_response_writer.go +++ b/subscribe_response_writer.go @@ -89,8 +89,8 @@ func (w *SubscribeResponseWriter) SendDatagram(o Object) error { return w.localTrack.sendDatagram(o) } -func (w *SubscribeResponseWriter) OpenSubgroup(groupID, subgroupID uint64, priority uint8) (*Subgroup, error) { - return w.localTrack.openSubgroup(groupID, subgroupID, priority) +func (w *SubscribeResponseWriter) OpenSubgroup(groupID, subgroupID uint64, priority uint8, opts ...SubgroupOption) (*Subgroup, error) { + return w.localTrack.openSubgroup(groupID, subgroupID, priority, opts...) } func (w *SubscribeResponseWriter) CloseWithError(code uint64, reason string) error {