diff --git a/examples/warp/main.go b/examples/warp/main.go index 240148162ac..8e11386bb0f 100644 --- a/examples/warp/main.go +++ b/examples/warp/main.go @@ -45,8 +45,13 @@ func setupOfferHandler(pc **webrtc.PeerConnection) { return } + // Enable SNAP. + s := webrtc.SettingEngine{} + s.EnableSctpSnap(true) + api := webrtc.NewAPI(webrtc.WithSettingEngine(s)) + var err error - *pc, err = webrtc.NewPeerConnection(webrtc.Configuration{ + *pc, err = api.NewPeerConnection(webrtc.Configuration{ BundlePolicy: webrtc.BundlePolicyMaxBundle, }) if err != nil { diff --git a/peerconnection.go b/peerconnection.go index 2cd6215effe..fa0d5e465ff 100644 --- a/peerconnection.go +++ b/peerconnection.go @@ -1615,10 +1615,11 @@ func (pc *PeerConnection) startRTPSenders(currentTransceivers []*RTPTransceiver) } // Start SCTP subsystem. -func (pc *PeerConnection) startSCTP(maxMessageSize uint32) { +func (pc *PeerConnection) startSCTP(maxMessageSize uint32, remoteSctpInit []byte) { // Start sctp if err := pc.sctpTransport.Start(SCTPCapabilities{ MaxMessageSize: maxMessageSize, + SctpInit: remoteSctpInit, }); err != nil { pc.log.Warnf("Failed to start SCTP: %s", err) if err = pc.sctpTransport.Stop(); err != nil { @@ -2791,7 +2792,8 @@ func (pc *PeerConnection) startRTP( pc.startRTPReceivers(remoteDesc, currentTransceivers) if d := haveDataChannel(remoteDesc); d != nil { - pc.startSCTP(getMaxMessageSize(d)) + remoteSctpInit, _ := getSctpInit(d) + pc.startSCTP(getMaxMessageSize(d), remoteSctpInit) } } @@ -2824,6 +2826,11 @@ func (pc *PeerConnection) generateUnmatchedSDP( // Needed for pc.sctpTransport.dataChannelsRequested pc.sctpTransport.lock.Lock() + + var localSctpInit []byte + if pc.sctpTransport.dataChannelsRequested != 0 && pc.api.settingEngine.sctp.enableSnap { + localSctpInit = pc.sctpTransport.GetSctpInit() + } defer pc.sctpTransport.lock.Unlock() if isPlanB { //nolint:nestif @@ -2860,7 +2867,7 @@ func (pc *PeerConnection) generateUnmatchedSDP( } if pc.sctpTransport.dataChannelsRequested != 0 { - mediaSections = append(mediaSections, mediaSection{id: strconv.Itoa(len(mediaSections)), data: true}) + mediaSections = append(mediaSections, mediaSection{id: strconv.Itoa(len(mediaSections)), data: true, sctpInit: localSctpInit}) } } @@ -2889,7 +2896,7 @@ func (pc *PeerConnection) generateUnmatchedSDP( } // generateMatchedSDP generates a SDP and takes the remote state into account -// this is used everytime we have a RemoteDescription +// This is used everytime we have a RemoteDescription // //nolint:gocognit,gocyclo,cyclop func (pc *PeerConnection) generateMatchedSDP( @@ -2929,6 +2936,7 @@ func (pc *PeerConnection) generateMatchedSDP( mediaSections := []mediaSection{} alreadyHaveApplicationMediaSection := false + var localSctpInit []byte for _, media := range remoteDescription.parsed.MediaDescriptions { midValue := getMidValue(media) if midValue == "" { @@ -2936,7 +2944,14 @@ func (pc *PeerConnection) generateMatchedSDP( } if media.MediaName.Media == mediaSectionApplication { - mediaSections = append(mediaSections, mediaSection{id: midValue, data: true}) + init, _ := getSctpInit(media) + if init != nil && pc.api.settingEngine.sctp.enableSnap { + pc.sctpTransport.lock.Lock() + localSctpInit = pc.sctpTransport.GetSctpInit() + pc.sctpTransport.lock.Unlock() + } + + mediaSections = append(mediaSections, mediaSection{id: midValue, data: true, sctpInit: localSctpInit}) alreadyHaveApplicationMediaSection = true continue @@ -3023,7 +3038,7 @@ func (pc *PeerConnection) generateMatchedSDP( if detectedPlanB { mediaSections = append(mediaSections, mediaSection{id: "data", data: true}) } else { - mediaSections = append(mediaSections, mediaSection{id: strconv.Itoa(len(mediaSections)), data: true}) + mediaSections = append(mediaSections, mediaSection{id: strconv.Itoa(len(mediaSections)), data: true, sctpInit: localSctpInit}) } } } else if remoteDescription != nil { diff --git a/peerconnection_test.go b/peerconnection_test.go index aa722ccce07..6fca7d90e23 100644 --- a/peerconnection_test.go +++ b/peerconnection_test.go @@ -807,6 +807,23 @@ func TestPeerConnection_SessionID(t *testing.T) { closePairNow(t, pcOffer, pcAnswer) } +func TestSctpSnap(t *testing.T) { + s := SettingEngine{} + s.EnableSnap(true) + api := NewAPI(WithSettingEngine(s)) + + offer, err := api.NewPeerConnection(Configuration{}) + assert.NoError(t, err) + answer, err := api.NewPeerConnection(Configuration{}) + assert.NoError(t, err) + + peerConnectionsConnected := untilConnectionState(PeerConnectionStateConnected, offer, answer) + assert.NoError(t, signalPair(offer, answer)) + peerConnectionsConnected.Wait() + + closePairNow(t, offer, answer) +} + func TestICETrickleCapabilityString(t *testing.T) { tests := []struct { value ICETrickleCapability diff --git a/sctpcapabilities.go b/sctpcapabilities.go index c4c5ff5eb8e..e1b8deaae06 100644 --- a/sctpcapabilities.go +++ b/sctpcapabilities.go @@ -6,4 +6,5 @@ package webrtc // SCTPCapabilities indicates the capabilities of the SCTPTransport. type SCTPCapabilities struct { MaxMessageSize uint32 `json:"maxMessageSize"` + SctpInit []byte `json:"sctpInit"` } diff --git a/sctptransport.go b/sctptransport.go index 32617bbcc60..2edff9345d1 100644 --- a/sctptransport.go +++ b/sctptransport.go @@ -53,6 +53,8 @@ type SCTPTransport struct { dataChannelsRequested uint32 dataChannelsAccepted uint32 + localSctpInit []byte + api *API log logging.LeveledLogger } @@ -107,6 +109,7 @@ func (r *SCTPTransport) Start(capabilities SCTPCapabilities) error { if maxMessageSize == 0 { maxMessageSize = sctpMaxMessageSizeUnsetValue } + remoteSctpInit := capabilities.SctpInit dtlsTransport := r.Transport() if dtlsTransport == nil || dtlsTransport.conn == nil { @@ -119,11 +122,14 @@ func (r *SCTPTransport) Start(capabilities SCTPCapabilities) error { LoggerFactory: r.api.settingEngine.LoggerFactory, RTOMax: float64(r.api.settingEngine.sctp.rtoMax) / float64(time.Millisecond), BlockWrite: r.api.settingEngine.detach.DataChannels && r.api.settingEngine.dataChannelBlockWrite, - MaxMessageSize: maxMessageSize, MTU: outboundMTU, MinCwnd: r.api.settingEngine.sctp.minCwnd, FastRtxWnd: r.api.settingEngine.sctp.fastRtxWnd, CwndCAStep: r.api.settingEngine.sctp.cwndCAStep, + }, sctp.SctpParameters{ + MaxMessageSize: maxMessageSize, + LocalSctpInit: r.localSctpInit, + RemoteSctpInit: remoteSctpInit, }) if err != nil { return err @@ -456,3 +462,15 @@ func (r *SCTPTransport) BufferedAmount() int { return r.sctpAssociation.BufferedAmount() } + +// The caller should hold the lock. +func (r *SCTPTransport) GetSctpInit() []byte { + if len(r.localSctpInit) == 0 { + r.localSctpInit, _ = sctp.GenerateOutOfBandToken(sctp.Config{ + MaxReceiveBufferSize: r.api.settingEngine.sctp.maxReceiveBufferSize, + EnableZeroChecksum: r.api.settingEngine.sctp.enableZeroChecksum, + }) + } + + return r.localSctpInit +} diff --git a/sdp.go b/sdp.go index 3a710772fb6..baf01c643a5 100644 --- a/sdp.go +++ b/sdp.go @@ -7,6 +7,7 @@ package webrtc import ( + "encoding/base64" "errors" "fmt" "net/url" @@ -369,6 +370,7 @@ func addDataMediaSection( dtlsRole sdp.ConnectionRole, iceGatheringState ICEGatheringState, sctpMaxMessageSize uint32, + sctpInit []byte, ) error { media := (&sdp.MediaDescription{ MediaName: sdp.MediaName{ @@ -388,10 +390,14 @@ func addDataMediaSection( WithValueAttribute(sdp.AttrKeyConnectionSetup, dtlsRole.String()). WithValueAttribute(sdp.AttrKeyMID, midValue). WithPropertyAttribute(RTPTransceiverDirectionSendrecv.String()). + // TODO: do not hardcode this. WithPropertyAttribute("sctp-port:5000"). WithValueAttribute("max-message-size", fmt.Sprintf("%d", sctpMaxMessageSize)). WithICECredentials(iceParams.UsernameFragment, iceParams.Password) + if len(sctpInit) != 0 { + media = media.WithValueAttribute("sctp-init", base64.StdEncoding.EncodeToString(sctpInit)) + } for _, f := range dtlsFingerprints { media = media.WithFingerprint(f.Algorithm, strings.ToUpper(f.Value)) } @@ -669,6 +675,7 @@ type mediaSection struct { id string transceivers []*RTPTransceiver data bool + sctpInit []byte matchExtensions map[string]int rids []*simulcastRid } @@ -742,6 +749,7 @@ func populateSDP( connectionRole, iceGatheringState, sctpMaxMessageSize, + section.sctpInit, ); err != nil { return nil, err } @@ -1210,3 +1218,20 @@ func getMaxMessageSize(desc *sdp.MediaDescription) uint32 { return 0 } + +func getSctpInit(desc *sdp.MediaDescription) ([]byte, error) { + var err error + var decoded []byte + for _, a := range desc.Attributes { + if strings.TrimSpace(a.Key) == "sctp-init" { + decoded, err = base64.StdEncoding.DecodeString(a.Value) + if err == nil { + return decoded, nil + } + + return nil, err + } + } + + return nil, nil +} diff --git a/settingengine.go b/settingengine.go index ea665cdfbf1..aa6aa6bbfa5 100644 --- a/settingengine.go +++ b/settingengine.go @@ -91,6 +91,7 @@ type SettingEngine struct { minCwnd uint32 fastRtxWnd uint32 cwndCAStep uint32 + enableSnap bool } sdpMediaLevelFingerprints bool answeringDTLSRole DTLSRole @@ -590,6 +591,11 @@ func (e *SettingEngine) EnableSCTPZeroChecksum(isEnabled bool) { e.sctp.enableZeroChecksum = isEnabled } +// EnableSctpSnap enables the use of the SCTP SNAP connect optimization. +func (e *SettingEngine) EnableSctpSnap(isEnabled bool) { + e.sctp.enableSnap = isEnabled +} + // SetSCTPMaxMessageSize sets the largest message we are willing to accept. // Leave this 0 for the default max message size. func (e *SettingEngine) SetSCTPMaxMessageSize(maxMessageSize uint32) {