Skip to content

Commit b0d56c6

Browse files
committed
Implement DTLS restart
Fixes #1636
1 parent 7948437 commit b0d56c6

File tree

3 files changed

+245
-1
lines changed

3 files changed

+245
-1
lines changed

dtlstransport.go

Lines changed: 26 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -213,6 +213,31 @@ func (t *DTLSTransport) startSRTP() error {
213213
return fmt.Errorf("%w: %v", errDtlsKeyExtractionFailed, err)
214214
}
215215

216+
isAlreadyRunning := func() bool {
217+
select {
218+
case <-t.srtpReady:
219+
return true
220+
default:
221+
return false
222+
}
223+
}()
224+
225+
if isAlreadyRunning {
226+
if sess, ok := t.srtpSession.Load().(*srtp.SessionSRTP); ok {
227+
if updateErr := sess.UpdateContext(srtpConfig); updateErr != nil {
228+
return updateErr
229+
}
230+
}
231+
232+
if sess, ok := t.srtcpSession.Load().(*srtp.SessionSRTCP); ok {
233+
if updateErr := sess.UpdateContext(srtpConfig); updateErr != nil {
234+
return updateErr
235+
}
236+
}
237+
238+
return nil
239+
}
240+
216241
srtpSession, err := srtp.NewSessionSRTP(t.srtpEndpoint, srtpConfig)
217242
if err != nil {
218243
return fmt.Errorf("%w: %v", errFailedToStartSRTP, err)
@@ -283,7 +308,7 @@ func (t *DTLSTransport) Start(remoteParameters DTLSParameters) error {
283308
return DTLSRole(0), nil, err
284309
}
285310

286-
if t.state != DTLSTransportStateNew {
311+
if t.state != DTLSTransportStateNew && t.state != DTLSTransportStateClosed {
287312
return DTLSRole(0), nil, &rtcerr.InvalidStateError{Err: fmt.Errorf("%w: %s", errInvalidDTLSStart, t.state)}
288313
}
289314

peerconnection.go

Lines changed: 34 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1108,7 +1108,41 @@ func (pc *PeerConnection) SetRemoteDescription(desc SessionDescription) error {
11081108
pc.ops.Enqueue(func() {
11091109
pc.startRTP(true, &desc, currentTransceivers)
11101110
})
1111+
} else if pc.dtlsTransport.State() != DTLSTransportStateNew {
1112+
fingerprint, fingerprintHash, fErr := extractFingerprint(desc.parsed)
1113+
if fErr != nil {
1114+
return fErr
1115+
}
1116+
1117+
fingerPrintDidChange := true
1118+
1119+
for _, fp := range pc.dtlsTransport.remoteParameters.Fingerprints {
1120+
if fingerprint == fp.Value && fingerprintHash == fp.Algorithm {
1121+
fingerPrintDidChange = false
1122+
break
1123+
}
1124+
}
1125+
1126+
if fingerPrintDidChange {
1127+
pc.ops.Enqueue(func() {
1128+
if dErr := pc.dtlsTransport.Stop(); dErr != nil {
1129+
pc.log.Warnf("Failed to stop DTLS: %s", dErr)
1130+
}
1131+
1132+
// Restart the dtls transport with updated fingerprints
1133+
err = pc.dtlsTransport.Start(DTLSParameters{
1134+
Role: dtlsRoleFromRemoteSDP(desc.parsed),
1135+
Fingerprints: []DTLSFingerprint{{Algorithm: fingerprintHash, Value: fingerprint}},
1136+
})
1137+
pc.updateConnectionState(pc.ICEConnectionState(), pc.dtlsTransport.State())
1138+
if err != nil {
1139+
pc.log.Warnf("Failed to restart DTLS: %s", err)
1140+
return
1141+
}
1142+
})
1143+
}
11111144
}
1145+
11121146
return nil
11131147
}
11141148

peerconnection_media_test.go

Lines changed: 185 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -14,10 +14,12 @@ import (
1414
"testing"
1515
"time"
1616

17+
"github.com/pion/logging"
1718
"github.com/pion/randutil"
1819
"github.com/pion/rtcp"
1920
"github.com/pion/rtp"
2021
"github.com/pion/transport/test"
22+
"github.com/pion/transport/vnet"
2123
"github.com/pion/webrtc/v3/pkg/media"
2224
"github.com/stretchr/testify/assert"
2325
"github.com/stretchr/testify/require"
@@ -1052,3 +1054,186 @@ func TestPeerConnection_RaceReplaceTrack(t *testing.T) {
10521054

10531055
assert.NoError(t, pc.Close())
10541056
}
1057+
1058+
// Issue #1636
1059+
func TestPeerConnection_DTLS_Restart(t *testing.T) {
1060+
lim := test.TimeOut(time.Second * 30)
1061+
defer lim.Stop()
1062+
1063+
// First prepare network configuration
1064+
1065+
router, err := vnet.NewRouter(&vnet.RouterConfig{
1066+
CIDR: "0.0.0.0/0",
1067+
LoggerFactory: logging.NewDefaultLoggerFactory(),
1068+
})
1069+
assert.NoError(t, err)
1070+
1071+
networkA1 := vnet.NewNet(&vnet.NetConfig{
1072+
NetworkConditioner: vnet.NewNetworkConditioner(vnet.NetworkConditionerPresetNone),
1073+
})
1074+
1075+
networkA2 := vnet.NewNet(&vnet.NetConfig{
1076+
NetworkConditioner: vnet.NewNetworkConditioner(vnet.NetworkConditionerPresetNone),
1077+
})
1078+
1079+
networkB := vnet.NewNet(&vnet.NetConfig{
1080+
NetworkConditioner: vnet.NewNetworkConditioner(vnet.NetworkConditionerPresetNone),
1081+
})
1082+
1083+
assert.NoError(t, router.AddNet(networkA1))
1084+
assert.NoError(t, router.AddNet(networkA2))
1085+
assert.NoError(t, router.AddNet(networkB))
1086+
1087+
assert.NoError(t, router.Start())
1088+
defer func() { _ = router.Stop() }()
1089+
1090+
// ... then the clients
1091+
1092+
makeClient := func(network *vnet.Net) (*PeerConnection, *TrackLocalStaticSample) {
1093+
m := &MediaEngine{}
1094+
assert.NoError(t, m.RegisterDefaultCodecs())
1095+
1096+
s := SettingEngine{}
1097+
s.SetVNet(network)
1098+
s.SetICETimeouts(2*time.Second, 5*time.Second, 1*time.Second)
1099+
1100+
api := NewAPI(WithSettingEngine(s), WithMediaEngine(m))
1101+
pc, cliErr := api.NewPeerConnection(Configuration{})
1102+
assert.NoError(t, cliErr)
1103+
1104+
track, cliErr := NewTrackLocalStaticSample(RTPCodecCapability{MimeType: MimeTypeOpus}, "audio", "test-client")
1105+
assert.NoError(t, cliErr)
1106+
1107+
_, cliErr = pc.AddTrack(track)
1108+
assert.NoError(t, cliErr)
1109+
1110+
return pc, track
1111+
}
1112+
1113+
clientA1, _ := makeClient(networkA1)
1114+
defer func() { _ = clientA1.Close() }()
1115+
1116+
clientB, localClientBTrack := makeClient(networkB)
1117+
defer func() { _ = clientB.Close() }()
1118+
1119+
// ... clientB starts publishing media
1120+
publishClientBCtx, publishCancel := context.WithCancel(context.Background())
1121+
go func() {
1122+
ticker := time.NewTicker(20 * time.Millisecond)
1123+
defer ticker.Stop()
1124+
1125+
for {
1126+
select {
1127+
case <-publishClientBCtx.Done():
1128+
return
1129+
case <-ticker.C:
1130+
_ = localClientBTrack.WriteSample(media.Sample{
1131+
Data: []byte{0xbb},
1132+
Timestamp: time.Now(),
1133+
Duration: 20 * time.Millisecond,
1134+
})
1135+
}
1136+
}
1137+
}()
1138+
defer publishCancel()
1139+
1140+
clientA1Tracks := make(chan *TrackRemote, 1)
1141+
clientA1.OnTrack(func(remote *TrackRemote, receiver *RTPReceiver) {
1142+
clientA1Tracks <- remote
1143+
})
1144+
1145+
// ClientA1 connects to ClientB
1146+
1147+
gatherCompletePromiseA1 := GatheringCompletePromise(clientA1)
1148+
offerA1, err := clientA1.CreateOffer(nil)
1149+
assert.NoError(t, err)
1150+
assert.NoError(t, clientA1.SetLocalDescription(offerA1))
1151+
<-gatherCompletePromiseA1
1152+
1153+
assert.NoError(t, clientB.SetRemoteDescription(*clientA1.LocalDescription()))
1154+
1155+
gatherCompletePromiseB := GatheringCompletePromise(clientB)
1156+
answerB, err := clientB.CreateAnswer(nil)
1157+
assert.NoError(t, err)
1158+
assert.NoError(t, clientB.SetLocalDescription(answerB))
1159+
<-gatherCompletePromiseB
1160+
1161+
clientA1Connected := make(chan struct{}, 1)
1162+
clientA1Disconnected := make(chan struct{}, 1)
1163+
clientA1.OnICEConnectionStateChange(func(s ICEConnectionState) {
1164+
if s == ICEConnectionStateConnected {
1165+
clientA1Connected <- struct{}{}
1166+
} else if s == ICEConnectionStateDisconnected {
1167+
clientA1Disconnected <- struct{}{}
1168+
}
1169+
})
1170+
1171+
assert.NoError(t, clientA1.SetRemoteDescription(answerB))
1172+
1173+
// Wait for connection
1174+
<-clientA1Connected
1175+
1176+
// At this point, clientA1 should have received a track, and some media
1177+
clientA1RemoteTrack := <-clientA1Tracks
1178+
pkt, _, err := clientA1RemoteTrack.ReadRTP()
1179+
assert.NotNil(t, pkt)
1180+
assert.NoError(t, err)
1181+
1182+
networkA1.SetNetworkConditioner(vnet.NewNetworkConditioner(vnet.NetworkConditionerPresetFullLoss))
1183+
1184+
<-clientA1Disconnected
1185+
1186+
// ClientA1 has been disconnected – in a mobile app context, this could be a switch to the background
1187+
// or a killed app.
1188+
//
1189+
// In these scenarios, the client will reconnect with a different PeerConnection – here ClientA2.
1190+
1191+
clientA2, _ := makeClient(networkA2)
1192+
defer func() { _ = clientA2.Close() }()
1193+
1194+
clientA2Connected := make(chan struct{}, 1)
1195+
clientA2.OnICEConnectionStateChange(func(s ICEConnectionState) {
1196+
if s == ICEConnectionStateConnected {
1197+
clientA2Connected <- struct{}{}
1198+
} else if s == ICEConnectionStateFailed {
1199+
assert.FailNow(t, "should not fail")
1200+
}
1201+
})
1202+
1203+
clientA2Tracks := make(chan *TrackRemote, 1)
1204+
clientA2.OnTrack(func(remote *TrackRemote, receiver *RTPReceiver) {
1205+
clientA2Tracks <- remote
1206+
})
1207+
1208+
// ClientA2 connects to ClientB
1209+
1210+
gatherCompletePromiseA2 := GatheringCompletePromise(clientA2)
1211+
// We can't do an ICE Restart here, since it's a different PeerConnection
1212+
offerA2, err := clientA2.CreateOffer(nil)
1213+
assert.NoError(t, err)
1214+
assert.NoError(t, clientA2.SetLocalDescription(offerA2))
1215+
<-gatherCompletePromiseA2
1216+
1217+
assert.NoError(t, clientB.SetRemoteDescription(*clientA2.LocalDescription()))
1218+
1219+
gatherCompletePromiseB = GatheringCompletePromise(clientB)
1220+
answerB, err = clientB.CreateAnswer(nil)
1221+
assert.NoError(t, err)
1222+
assert.NoError(t, clientB.SetLocalDescription(answerB))
1223+
<-gatherCompletePromiseB
1224+
1225+
assert.NoError(t, clientA2.SetRemoteDescription(answerB))
1226+
1227+
// Wait for connection
1228+
<-clientA2Connected
1229+
1230+
// At this point, clientA2 should have received a track, and some media
1231+
clientA2RemoteTrack := <-clientA2Tracks
1232+
1233+
// Read a bunch of RTPs
1234+
for ndx := 0; ndx < 10; ndx++ {
1235+
pkt, _, err = clientA2RemoteTrack.ReadRTP()
1236+
assert.NotNil(t, pkt)
1237+
assert.NoError(t, err)
1238+
}
1239+
}

0 commit comments

Comments
 (0)