diff --git a/dtlstransport.go b/dtlstransport.go index 402518217a1..502b819f787 100644 --- a/dtlstransport.go +++ b/dtlstransport.go @@ -44,8 +44,9 @@ type DTLSTransport struct { state DTLSTransportState srtpProtectionProfile srtp.ProtectionProfile - onStateChangeHandler func(DTLSTransportState) - internalOnCloseHandler func() + onStateChangeHandler func(DTLSTransportState) + internalOnStateChangeHandler func(DTLSTransportState) + internalOnCloseHandler func() conn *dtls.Conn @@ -120,8 +121,10 @@ func (t *DTLSTransport) ICETransport() *ICETransport { // onStateChange requires the caller holds the lock. func (t *DTLSTransport) onStateChange(state DTLSTransportState) { t.state = state - handler := t.onStateChangeHandler - if handler != nil { + if handler := t.onStateChangeHandler; handler != nil { + handler(state) + } + if handler := t.internalOnStateChangeHandler; handler != nil { handler(state) } } @@ -386,10 +389,18 @@ func (t *DTLSTransport) Start(remoteParameters DTLSParameters) error { //nolint: parsedRemoteCert, parseErr := x509.ParseCertificate(t.remoteCertificate) if parseErr != nil { + t.onStateChange(DTLSTransportStateFailed) + return parseErr } - return t.validateFingerPrint(parsedRemoteCert) + if fpErr := t.validateFingerPrint(parsedRemoteCert); fpErr != nil { + t.onStateChange(DTLSTransportStateFailed) + + return fpErr + } + + return nil } // Connect as DTLS Client/Server, function is blocking and we diff --git a/dtlstransport_test.go b/dtlstransport_test.go index 8820df390c4..cdc42c5ff90 100644 --- a/dtlstransport_test.go +++ b/dtlstransport_test.go @@ -67,9 +67,44 @@ func TestInvalidFingerprintCausesFailed(t *testing.T) { //nolint:cyclop } }) - // Also wait for PeerConnection to close (may take longer due to cleanup) - offerConnectionHasClosed := untilConnectionState(PeerConnectionStateClosed, pcOffer) - answerConnectionHasClosed := untilConnectionState(PeerConnectionStateClosed, pcAnswer) + // Track PeerConnection state transitions - need single handler since OnConnectionStateChange replaces + offerConnectionHasFailed := make(chan struct{}) + offerConnectionHasClosed := make(chan struct{}) + pcOffer.OnConnectionStateChange(func(state PeerConnectionState) { + if state == PeerConnectionStateFailed { + select { + case <-offerConnectionHasFailed: + default: + close(offerConnectionHasFailed) + } + } + if state == PeerConnectionStateClosed { + select { + case <-offerConnectionHasClosed: + default: + close(offerConnectionHasClosed) + } + } + }) + + answerConnectionHasFailed := make(chan struct{}) + answerConnectionHasClosed := make(chan struct{}) + pcAnswer.OnConnectionStateChange(func(state PeerConnectionState) { + if state == PeerConnectionStateFailed { + select { + case <-answerConnectionHasFailed: + default: + close(answerConnectionHasFailed) + } + } + if state == PeerConnectionStateClosed { + select { + case <-answerConnectionHasClosed: + default: + close(answerConnectionHasClosed) + } + } + }) _, err = pcOffer.CreateDataChannel("unusedDataChannel", nil) assert.NoError(t, err) @@ -119,9 +154,35 @@ func TestInvalidFingerprintCausesFailed(t *testing.T) { //nolint:cyclop assert.Fail(t, "timed out waiting for answer DTLS to fail") } + // Verify PeerConnection state transitions to "failed" (per W3C WebRTC spec) + // "Any of the RTCIceTransports or RTCDtlsTransports are in a 'failed' state" + // should result in PeerConnectionStateFailed + select { + case <-offerConnectionHasFailed: + // Expected - offer PeerConnection reached failed state + case <-time.After(7 * time.Second): + assert.Fail(t, "timed out waiting for offer PeerConnection to reach failed state") + } + + select { + case <-answerConnectionHasFailed: + // Expected - answer PeerConnection reached failed state + case <-time.After(7 * time.Second): + assert.Fail(t, "timed out waiting for answer PeerConnection to reach failed state") + } + // Wait for PeerConnection to close (may take longer due to cleanup) - offerConnectionHasClosed.Wait() - answerConnectionHasClosed.Wait() + select { + case <-offerConnectionHasClosed: + case <-time.After(7 * time.Second): + assert.Fail(t, "timed out waiting for offer PeerConnection to close") + } + + select { + case <-answerConnectionHasClosed: + case <-time.After(7 * time.Second): + assert.Fail(t, "timed out waiting for answer PeerConnection to close") + } assert.Contains( t, []DTLSTransportState{DTLSTransportStateClosed, DTLSTransportStateFailed}, pcOffer.SCTP().Transport().State(), diff --git a/peerconnection.go b/peerconnection.go index d6ba18eb28f..1e99b620983 100644 --- a/peerconnection.go +++ b/peerconnection.go @@ -187,6 +187,11 @@ func (api *API) NewPeerConnection(configuration Configuration) (*PeerConnection, } pc.dtlsTransport = dtlsTransport + // Wire up DTLS state change handler to update PeerConnection state + pc.dtlsTransport.internalOnStateChangeHandler = func(state DTLSTransportState) { + pc.updateConnectionState(pc.ICEConnectionState(), state) + } + // Create the SCTP transport pc.sctpTransport = pc.api.NewSCTPTransport(pc.dtlsTransport) @@ -2425,6 +2430,18 @@ func (pc *PeerConnection) close(shouldGracefullyClose bool) error { //nolint:cyc // https://www.w3.org/TR/webrtc/#dom-rtcpeerconnection-close (step #1) // https://www.w3.org/TR/webrtc/#dom-rtcpeerconnection-close (step #2) + // Per W3C WebRTC spec, if DTLS is in "failed" state, the PeerConnection should + // transition to "failed" before "closed". Check and report failed state before + // setting isClosed, which would prevent the failed state from being reported. + // Also check for "connecting" state - if DTLS handshake started but never completed + // when close is called, treat it as a failure (e.g., remote peer closed during handshake). + if pc.dtlsTransport != nil { + dtlsState := pc.dtlsTransport.State() + if dtlsState == DTLSTransportStateFailed || dtlsState == DTLSTransportStateConnecting { + pc.updateConnectionState(pc.ICEConnectionState(), DTLSTransportStateFailed) + } + } + pc.mu.Lock() // A lock in this critical section is needed because pc.isClosed and // pc.isGracefullyClosingOrClosed are related to each other in that we