diff --git a/icegatherer.go b/icegatherer.go index 0fe7e876d41..100ac32252d 100644 --- a/icegatherer.go +++ b/icegatherer.go @@ -114,6 +114,40 @@ func (api *API) NewICEGatherer(opts ICEGatherOptions) (*ICEGatherer, error) { }, nil } +// updateServers updates the ICE servers and gather policy. +// If called before gathering starts, the new servers will be used for initial gathering. +// If called after gathering has started, the new servers will be used on the next ICE restart. +func (g *ICEGatherer) updateServers(servers []ICEServer, policy ICETransportPolicy) error { + g.lock.Lock() + defer g.lock.Unlock() + + var validatedServers []*stun.URI + for _, server := range servers { + urls, err := server.urls() + if err != nil { + return err + } + validatedServers = append(validatedServers, urls...) + } + + g.validatedServers = validatedServers + g.gatherPolicy = policy + + if g.agent != nil { + return g.agent.UpdateOptions(ice.WithUrls(validatedServers)) + } + + return nil +} + +// validatedServersCount returns the number of validated ICE server URLs. +func (g *ICEGatherer) validatedServersCount() int { + g.lock.RLock() + defer g.lock.RUnlock() + + return len(g.validatedServers) +} + func (g *ICEGatherer) createAgent() error { g.lock.Lock() defer g.lock.Unlock() diff --git a/icegatherer_test.go b/icegatherer_test.go index e50f79e3a36..3506f1b941e 100644 --- a/icegatherer_test.go +++ b/icegatherer_test.go @@ -111,6 +111,26 @@ func TestICEGatherer_InvalidMDNSHostName(t *testing.T) { assert.ErrorIs(t, err, ice.ErrInvalidMulticastDNSHostName) } +func TestICEGatherer_updateServers(t *testing.T) { + lim := test.TimeOut(time.Second * 10) + defer lim.Stop() + + report := test.CheckRoutines(t) + defer report() + + gatherer, err := NewAPI().NewICEGatherer(ICEGatherOptions{}) + require.NoError(t, err) + + assert.Equal(t, 0, gatherer.validatedServersCount()) + + newServers := []ICEServer{{URLs: []string{"stun:stun.l.google.com:19302"}}} + err = gatherer.updateServers(newServers, ICETransportPolicyAll) + assert.NoError(t, err) + assert.Equal(t, 1, gatherer.validatedServersCount()) + + assert.NoError(t, gatherer.Close()) +} + func TestLegacyNAT1To1AddressRewriteRules(t *testing.T) { t.Run("empty", func(t *testing.T) { assert.Empty(t, legacyNAT1To1AddressRewriteRules(nil, ice.CandidateTypeHost)) diff --git a/peerconnection.go b/peerconnection.go index f331d428d5a..2cd6215effe 100644 --- a/peerconnection.go +++ b/peerconnection.go @@ -519,13 +519,14 @@ func (pc *PeerConnection) onConnectionStateChange(cs PeerConnectionState) { } // SetConfiguration updates the configuration of this PeerConnection object. +// https://www.w3.org/TR/webrtc/#dom-rtcpeerconnection-setconfiguration func (pc *PeerConnection) SetConfiguration(configuration Configuration) error { //nolint:gocognit,cyclop // https://www.w3.org/TR/webrtc/#dom-rtcpeerconnection-setconfiguration (step #2) if pc.isClosed.Load() { return &rtcerr.InvalidStateError{Err: ErrConnectionClosed} } - // https://www.w3.org/TR/webrtc/#set-the-configuration (step #3) + // Not in W3C spec, but we validate PeerIdentity cannot be modified. if configuration.PeerIdentity != "" { if configuration.PeerIdentity != pc.configuration.PeerIdentity { return &rtcerr.InvalidModificationError{Err: ErrModifyingPeerIdentity} @@ -533,7 +534,7 @@ func (pc *PeerConnection) SetConfiguration(configuration Configuration) error { pc.configuration.PeerIdentity = configuration.PeerIdentity } - // https://www.w3.org/TR/webrtc/#set-the-configuration (step #4) + // https://www.w3.org/TR/webrtc/#set-the-configuration (step #3.1 - #3.3) if len(configuration.Certificates) > 0 { if len(configuration.Certificates) != len(pc.configuration.Certificates) { return &rtcerr.InvalidModificationError{Err: ErrModifyingCertificates} @@ -547,7 +548,7 @@ func (pc *PeerConnection) SetConfiguration(configuration Configuration) error { pc.configuration.Certificates = configuration.Certificates } - // https://www.w3.org/TR/webrtc/#set-the-configuration (step #5) + // https://www.w3.org/TR/webrtc/#set-the-configuration (step #3.4) if configuration.BundlePolicy != BundlePolicyUnknown { if configuration.BundlePolicy != pc.configuration.BundlePolicy { return &rtcerr.InvalidModificationError{Err: ErrModifyingBundlePolicy} @@ -555,7 +556,7 @@ func (pc *PeerConnection) SetConfiguration(configuration Configuration) error { pc.configuration.BundlePolicy = configuration.BundlePolicy } - // https://www.w3.org/TR/webrtc/#set-the-configuration (step #6) + // https://www.w3.org/TR/webrtc/#set-the-configuration (step #3.5) if configuration.RTCPMuxPolicy != RTCPMuxPolicyUnknown { if configuration.RTCPMuxPolicy != pc.configuration.RTCPMuxPolicy { return &rtcerr.InvalidModificationError{Err: ErrModifyingRTCPMuxPolicy} @@ -563,7 +564,7 @@ func (pc *PeerConnection) SetConfiguration(configuration Configuration) error { pc.configuration.RTCPMuxPolicy = configuration.RTCPMuxPolicy } - // https://www.w3.org/TR/webrtc/#set-the-configuration (step #7) + // https://www.w3.org/TR/webrtc/#set-the-configuration (step #3.6) if configuration.ICECandidatePoolSize != 0 { if pc.configuration.ICECandidatePoolSize != configuration.ICECandidatePoolSize && pc.LocalDescription() != nil { @@ -572,20 +573,29 @@ func (pc *PeerConnection) SetConfiguration(configuration Configuration) error { pc.configuration.ICECandidatePoolSize = configuration.ICECandidatePoolSize } - // https://www.w3.org/TR/webrtc/#set-the-configuration (step #8) + // https://www.w3.org/TR/webrtc/#set-the-configuration (step #4-6) + for _, server := range configuration.ICEServers { + if err := server.validate(); err != nil { + return err + } + } + + // https://www.w3.org/TR/webrtc/#set-the-configuration (step #7) pc.configuration.ICETransportPolicy = configuration.ICETransportPolicy - // https://www.w3.org/TR/webrtc/#set-the-configuration (step #11) - if len(configuration.ICEServers) > 0 { - // https://www.w3.org/TR/webrtc/#set-the-configuration (step #11.3) - for _, server := range configuration.ICEServers { - if err := server.validate(); err != nil { - return err - } + // Step #8: ICE candidate pool size is not implemented in pion/webrtc. + // The value is stored in configuration but candidate pooling is not supported. + + // https://www.w3.org/TR/webrtc/#set-the-configuration (step #9) + // Update the ICE gatherer so new servers take effect at the next gathering phase. + if pc.iceGatherer != nil { + if err := pc.iceGatherer.updateServers(configuration.ICEServers, pc.configuration.ICETransportPolicy); err != nil { + pc.log.Debugf("Could not update ICE gatherer servers: %v", err) } - pc.configuration.ICEServers = configuration.ICEServers } + pc.configuration.ICEServers = configuration.ICEServers + return nil } diff --git a/peerconnection_go_test.go b/peerconnection_go_test.go index 83d2fb81fcc..50702540887 100644 --- a/peerconnection_go_test.go +++ b/peerconnection_go_test.go @@ -30,6 +30,7 @@ import ( "github.com/pion/rtp" "github.com/pion/transport/v4/test" "github.com/pion/transport/v4/vnet" + "github.com/pion/turn/v4" "github.com/pion/webrtc/v4/internal/util" "github.com/pion/webrtc/v4/pkg/rtcerr" "github.com/stretchr/testify/assert" @@ -204,6 +205,10 @@ func TestPeerConnection_SetConfiguration_Go(t *testing.T) { return pc, err } + if pc.iceGatherer.validatedServersCount() != 0 { + return pc, fmt.Errorf("%w: expected 0 validated servers", ErrUnknownType) + } + err = pc.SetConfiguration(Configuration{ ICEServers: []ICEServer{ { @@ -230,6 +235,11 @@ func TestPeerConnection_SetConfiguration_Go(t *testing.T) { return pc, err } + // Verify ICE gatherer received the new servers. + if pc.iceGatherer.validatedServersCount() != 2 { + return pc, fmt.Errorf("%w: expected 2 validated servers", ErrUnknownType) + } + return pc, nil }, config: Configuration{}, @@ -1040,6 +1050,140 @@ func TestICERestart_Error_Handling(t *testing.T) { closePairNow(t, offerPeerConnection, answerPeerConnection) } +func TestPeerConnection_ICERestart_SetConfiguration_NewServers(t *testing.T) { + lim := test.TimeOut(time.Second * 30) + defer lim.Stop() + + report := test.CheckRoutines(t) + defer report() + + // Set up vnet with a STUN server. + const ( + offerIP = "1.2.3.4" + answerIP = "1.2.3.5" + stunIP = "1.2.3.100" + stunPort = 3478 + ) + + loggerFactory := logging.NewDefaultLoggerFactory() + + wan, err := vnet.NewRouter(&vnet.RouterConfig{ + CIDR: "1.2.3.0/24", + LoggerFactory: loggerFactory, + }) + assert.NoError(t, err) + + offerNet, err := vnet.NewNet(&vnet.NetConfig{StaticIPs: []string{offerIP}}) + assert.NoError(t, err) + answerNet, err := vnet.NewNet(&vnet.NetConfig{StaticIPs: []string{answerIP}}) + assert.NoError(t, err) + stunNet, err := vnet.NewNet(&vnet.NetConfig{StaticIPs: []string{stunIP}}) + assert.NoError(t, err) + + assert.NoError(t, wan.AddNet(offerNet)) + assert.NoError(t, wan.AddNet(answerNet)) + assert.NoError(t, wan.AddNet(stunNet)) + assert.NoError(t, wan.Start()) + + // Create STUN server. + stunListener, err := stunNet.ListenPacket("udp4", fmt.Sprintf("%s:%d", stunIP, stunPort)) + assert.NoError(t, err) + stunServer, err := turn.NewServer(turn.ServerConfig{ + Realm: "pion.ly", + LoggerFactory: loggerFactory, + PacketConnConfigs: []turn.PacketConnConfig{ + { + PacketConn: stunListener, + RelayAddressGenerator: &turn.RelayAddressGeneratorStatic{ + RelayAddress: net.ParseIP(stunIP), + Address: "0.0.0.0", + Net: stunNet, + }, + }, + }, + }) + assert.NoError(t, err) + + // Create peer connections. + offerSettingEngine := SettingEngine{} + offerSettingEngine.SetNet(offerNet) + offerSettingEngine.SetICETimeouts(time.Second, time.Second, time.Millisecond*200) + + answerSettingEngine := SettingEngine{} + answerSettingEngine.SetNet(answerNet) + answerSettingEngine.SetICETimeouts(time.Second, time.Second, time.Millisecond*200) + + offerPC, err := NewAPI(WithSettingEngine(offerSettingEngine)).NewPeerConnection(Configuration{}) + assert.NoError(t, err) + answerPC, err := NewAPI(WithSettingEngine(answerSettingEngine)).NewPeerConnection(Configuration{}) + assert.NoError(t, err) + + _, err = offerPC.CreateDataChannel("test", nil) + assert.NoError(t, err) + + // Initial negotiation without STUN servers. + offer, err := offerPC.CreateOffer(nil) + assert.NoError(t, err) + + offerGatherComplete := GatheringCompletePromise(offerPC) + assert.NoError(t, offerPC.SetLocalDescription(offer)) + <-offerGatherComplete + + // Verify initial offer has no srflx candidates. + assert.NotContains(t, offerPC.LocalDescription().SDP, "srflx", + "should not have srflx candidates without STUN servers") + + assert.NoError(t, answerPC.SetRemoteDescription(*offerPC.LocalDescription())) + + answer, err := answerPC.CreateAnswer(nil) + assert.NoError(t, err) + + answerGatherComplete := GatheringCompletePromise(answerPC) + assert.NoError(t, answerPC.SetLocalDescription(answer)) + <-answerGatherComplete + + assert.NoError(t, offerPC.SetRemoteDescription(*answerPC.LocalDescription())) + + // Update configuration with local STUN server. + stunURL := fmt.Sprintf("stun:%s:%d", stunIP, stunPort) + newConfig := Configuration{ + ICEServers: []ICEServer{ + {URLs: []string{stunURL}}, + }, + } + assert.Equal(t, 0, offerPC.iceGatherer.validatedServersCount()) + err = offerPC.SetConfiguration(newConfig) + assert.NoError(t, err) + assert.Equal(t, 1, offerPC.iceGatherer.validatedServersCount()) + + // Trigger ICE restart. + offer, err = offerPC.CreateOffer(&OfferOptions{ICERestart: true}) + assert.NoError(t, err) + + offerGatherComplete = GatheringCompletePromise(offerPC) + assert.NoError(t, offerPC.SetLocalDescription(offer)) + <-offerGatherComplete + + // Verify the offer now has srflx candidates from the STUN server. + assert.Contains(t, offerPC.LocalDescription().SDP, "srflx", + "should have srflx candidates after restart with STUN servers") + + assert.NoError(t, answerPC.SetRemoteDescription(*offerPC.LocalDescription())) + + answer, err = answerPC.CreateAnswer(nil) + assert.NoError(t, err) + + answerGatherComplete = GatheringCompletePromise(answerPC) + assert.NoError(t, answerPC.SetLocalDescription(answer)) + <-answerGatherComplete + + assert.NoError(t, offerPC.SetRemoteDescription(*answerPC.LocalDescription())) + + assert.NoError(t, stunServer.Close()) + assert.NoError(t, wan.Stop()) + closePairNow(t, offerPC, answerPC) +} + type trackRecords struct { mu sync.Mutex trackIDs map[string]struct{}