diff --git a/; b/; new file mode 100644 index 000000000..c509105da --- /dev/null +++ b/; @@ -0,0 +1,25 @@ +use tokio::sync::Mutex; +use arc_swap::ArcSwapOption; + +pub struct EventHandler { + inner: ArcSwapOption>>, +} + +impl EventHandler { + fn empty() -> Self { + Self { inner: ArcSwapOption::empty() } + } + + async fn load(&self) -> Option<&T> { + + let guard = self.inner.load(); + let handler = guard.as_ref()?.lock().await; + //Some(&*handler) + } +} + +impl Default for EventHandler { + fn default() -> Self { + Self::empty() + } +} diff --git a/Cargo.lock b/Cargo.lock index 0efe328bd..9474e0ec4 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -3004,6 +3004,7 @@ dependencies = [ name = "webrtc-util" version = "0.8.0" dependencies = [ + "arc-swap", "async-global-executor", "async-trait", "bitflags 1.3.2", diff --git a/examples/examples/broadcast/broadcast.rs b/examples/examples/broadcast/broadcast.rs index 2626863d4..082792396 100644 --- a/examples/examples/broadcast/broadcast.rs +++ b/examples/examples/broadcast/broadcast.rs @@ -1,8 +1,9 @@ use std::io::Write; -use std::sync::Arc; +use std::sync::{Arc, Weak}; use anyhow::Result; use clap::{AppSettings, Arg, Command}; +use std::future::Future; use tokio::time::Duration; use webrtc::api::interceptor_registry::register_default_interceptors; use webrtc::api::media_engine::MediaEngine; @@ -12,12 +13,94 @@ use webrtc::interceptor::registry::Registry; use webrtc::peer_connection::configuration::RTCConfiguration; use webrtc::peer_connection::peer_connection_state::RTCPeerConnectionState; use webrtc::peer_connection::sdp::session_description::RTCSessionDescription; +use webrtc::peer_connection::{PeerConnectionEventHandler, RTCPeerConnection}; use webrtc::rtcp::payload_feedbacks::picture_loss_indication::PictureLossIndication; use webrtc::rtp_transceiver::rtp_codec::RTPCodecType; use webrtc::track::track_local::track_local_static_rtp::TrackLocalStaticRTP; use webrtc::track::track_local::{TrackLocal, TrackLocalWriter}; use webrtc::Error; +use webrtc::rtp_transceiver::rtp_receiver::RTCRtpReceiver; +use webrtc::rtp_transceiver::RTCRtpTransceiver; +use webrtc::track::track_remote::TrackRemote; + +struct ConnectionHandler { + connection: Weak, + local_track_chan_tx: Arc>>, +} + +impl PeerConnectionEventHandler for ConnectionHandler { + // Set a handler for when a new remote track starts, this handler copies inbound RTP packets, + // replaces the SSRC and sends them back + fn on_track( + &mut self, + track: Arc, + _: Arc, + _: Arc, + ) -> impl Future + Send { + async move { + // Send a PLI on an interval so that the publisher is pushing a keyframe every rtcpPLIInterval + // This is a temporary fix until we implement incoming RTCP events, then we would push a PLI only when a viewer requests it + let connection = self.connection.clone(); + let media_ssrc = track.ssrc(); + tokio::spawn(async move { + let mut result = Result::::Ok(0); + while result.is_ok() { + let timeout = tokio::time::sleep(Duration::from_secs(3)); + tokio::pin!(timeout); + + tokio::select! { + _ = timeout.as_mut() =>{ + if let Some(pc) = connection.upgrade(){ + result = pc.write_rtcp(&[Box::new(PictureLossIndication{ + sender_ssrc: 0, + media_ssrc, + })]).await.map_err(Into::into); + }else{ + break; + } + } + }; + } + }); + + let local_track_chan_tx = self.local_track_chan_tx.clone(); + tokio::spawn(async move { + // Create Track that we send video back to browser on + let local_track = Arc::new(TrackLocalStaticRTP::new( + track.codec().capability, + "video".to_owned(), + "webrtc-rs".to_owned(), + )); + let _ = local_track_chan_tx.send(local_track.clone()).await; + + // Read RTP packets being sent to webrtc-rs + while let Ok((rtp, _)) = track.read_rtp().await { + if let Err(err) = local_track.write_rtp(&rtp).await { + if Error::ErrClosedPipe != err { + print!("output track write_rtp got error: {err} and break"); + break; + } else { + print!("output track write_rtp got error: {err}"); + } + } + } + }); + } + } + + // Set the handler for Peer connection state + // This will notify you when the peer has connected/disconnected + fn on_peer_connection_state_change( + &mut self, + state: RTCPeerConnectionState, + ) -> impl Future + Send { + async move { + println!("Peer Connection State has changed: {state}"); + } + } +} + #[tokio::main] async fn main() -> Result<()> { let mut app = Command::new("broadcast") @@ -126,64 +209,11 @@ async fn main() -> Result<()> { // Set a handler for when a new remote track starts, this handler copies inbound RTP packets, // replaces the SSRC and sends them back let pc = Arc::downgrade(&peer_connection); - peer_connection.on_track(Box::new(move |track, _, _| { - // Send a PLI on an interval so that the publisher is pushing a keyframe every rtcpPLIInterval - // This is a temporary fix until we implement incoming RTCP events, then we would push a PLI only when a viewer requests it - let media_ssrc = track.ssrc(); - let pc2 = pc.clone(); - tokio::spawn(async move { - let mut result = Result::::Ok(0); - while result.is_ok() { - let timeout = tokio::time::sleep(Duration::from_secs(3)); - tokio::pin!(timeout); - - tokio::select! { - _ = timeout.as_mut() =>{ - if let Some(pc) = pc2.upgrade(){ - result = pc.write_rtcp(&[Box::new(PictureLossIndication{ - sender_ssrc: 0, - media_ssrc, - })]).await.map_err(Into::into); - }else{ - break; - } - } - }; - } - }); - - let local_track_chan_tx2 = Arc::clone(&local_track_chan_tx); - tokio::spawn(async move { - // Create Track that we send video back to browser on - let local_track = Arc::new(TrackLocalStaticRTP::new( - track.codec().capability, - "video".to_owned(), - "webrtc-rs".to_owned(), - )); - let _ = local_track_chan_tx2.send(Arc::clone(&local_track)).await; - - // Read RTP packets being sent to webrtc-rs - while let Ok((rtp, _)) = track.read_rtp().await { - if let Err(err) = local_track.write_rtp(&rtp).await { - if Error::ErrClosedPipe != err { - print!("output track write_rtp got error: {err} and break"); - break; - } else { - print!("output track write_rtp got error: {err}"); - } - } - } - }); - Box::pin(async {}) - })); - - // Set the handler for Peer connection state - // This will notify you when the peer has connected/disconnected - peer_connection.on_peer_connection_state_change(Box::new(move |s: RTCPeerConnectionState| { - println!("Peer Connection State has changed: {s}"); - Box::pin(async {}) - })); + peer_connection.with_event_handler(ConnectionHandler { + connection: pc, + local_track_chan_tx, + }); // Set the remote SessionDescription peer_connection.set_remote_description(offer).await?; @@ -264,14 +294,21 @@ async fn main() -> Result<()> { Result::<()>::Ok(()) }); - // Set the handler for Peer connection state - // This will notify you when the peer has connected/disconnected - peer_connection.on_peer_connection_state_change(Box::new( - move |s: RTCPeerConnectionState| { - println!("Peer Connection State has changed: {s}"); - Box::pin(async {}) - }, - )); + struct StateNotifier; + + impl PeerConnectionEventHandler for StateNotifier { + // Set the handler for Peer connection state + // This will notify you when the peer has connected/disconnected + fn on_peer_connection_state_change( + &mut self, + state: RTCPeerConnectionState, + ) -> impl Future + Send { + async move { + println!("Peer Connection State has changed: {state}"); + } + } + } + peer_connection.with_event_handler(StateNotifier); // Set the remote SessionDescription peer_connection diff --git a/examples/examples/data-channels-close/data-channels-close.rs b/examples/examples/data-channels-close/data-channels-close.rs index ce76b605f..3fcfd5da8 100644 --- a/examples/examples/data-channels-close/data-channels-close.rs +++ b/examples/examples/data-channels-close/data-channels-close.rs @@ -1,3 +1,4 @@ +use std::future::Future; use std::io::Write; use std::sync::atomic::{AtomicI32, Ordering}; use std::sync::Arc; @@ -10,13 +11,119 @@ use webrtc::api::interceptor_registry::register_default_interceptors; use webrtc::api::media_engine::MediaEngine; use webrtc::api::APIBuilder; use webrtc::data_channel::data_channel_message::DataChannelMessage; -use webrtc::data_channel::RTCDataChannel; +use webrtc::data_channel::{RTCDataChannel, RTCDataChannelEventHandler}; use webrtc::ice_transport::ice_server::RTCIceServer; use webrtc::interceptor::registry::Registry; use webrtc::peer_connection::configuration::RTCConfiguration; use webrtc::peer_connection::math_rand_alpha; use webrtc::peer_connection::peer_connection_state::RTCPeerConnectionState; use webrtc::peer_connection::sdp::session_description::RTCSessionDescription; +use webrtc::peer_connection::PeerConnectionEventHandler; + +struct ConnectionHandler { + done_tx: Arc>, + close_after: Arc, +} + +impl PeerConnectionEventHandler for ConnectionHandler { + // Set the handler for Peer connection state + // This will notify you when the peer has connected/disconnected + fn on_peer_connection_state_change( + &mut self, + state: RTCPeerConnectionState, + ) -> impl Future + Send { + async move { + if state == RTCPeerConnectionState::Failed { + // Wait until PeerConnection has had no network activity for 30 seconds or another failure. It may be reconnected using an ICE Restart. + // Use webrtc.PeerConnectionStateDisconnected if you are interested in detecting faster timeout. + // Note that the PeerConnection may come back from PeerConnectionStateDisconnected. + println!("Peer Connection has gone to failed exiting"); + let _ = self.done_tx.try_send(()); + } + } + } + + // Register data channel creation handling + fn on_data_channel(&mut self, channel: Arc) -> impl Future + Send { + async move { + let d_label = channel.label().to_owned(); + let d_id = channel.id(); + println!("New DataChannel {d_label} {d_id}"); + + let (done_tx, done_rx) = tokio::sync::mpsc::channel::<()>(1); + let done_tx = Arc::new(Mutex::new(Some(done_tx))); + + let dc_handle = DataChannelHandler { + label: d_label, + id: d_id, + done_rx, + done_tx, + channel: channel.clone(), + close_after: self.close_after.clone(), + }; + channel.with_event_handler(dc_handle); + } + } +} + +struct DataChannelHandler { + label: String, + id: u16, + done_rx: tokio::sync::mpsc::Receiver<()>, + done_tx: Arc>>>, + channel: Arc, + close_after: Arc, +} + +impl RTCDataChannelEventHandler for DataChannelHandler { + // Register text message handling + fn on_message(&mut self, msg: DataChannelMessage) -> impl Future + Send { + async move { + let msg_str = String::from_utf8(msg.data.to_vec()).unwrap(); + println!("Message from DataChannel '{}': '{msg_str}'", self.label); + } + } + + // Register channel opening handling + fn on_open(&mut self) -> impl Future + Send { + async move { + println!("Data channel '{}'-'{}' open. Random messages will now be sent to any connected DataChannels every 5 seconds", self.label, self.id); + + let mut result = Result::::Ok(0); + while result.is_ok() { + let timeout = tokio::time::sleep(Duration::from_secs(5)); + tokio::pin!(timeout); + + tokio::select! { + _ = self.done_rx.recv() => { + break; + } + _ = timeout.as_mut() =>{ + let message = math_rand_alpha(15); + println!("Sending '{message}'"); + result = self.channel.send_text(message).await.map_err(Into::into); + + let cnt = self.close_after.fetch_sub(1, Ordering::SeqCst); + if cnt <= 0 { + println!("Sent times out. Closing data channel '{}'-'{}'.", self.label, self.id); + let _ = self.channel.close().await; + break; + } + } + }; + } + } + } + + // Register channel closing handling + fn on_close(&mut self) -> impl Future + Send { + async move { + println!("Data channel '{}'-'{}' closed.", self.label, self.id); + let mut done = self.done_tx.lock().await; + done.take(); + } + } +} #[tokio::main] async fn main() -> Result<()> { @@ -113,85 +220,12 @@ async fn main() -> Result<()> { let peer_connection = Arc::new(api.new_peer_connection(config).await?); let (done_tx, mut done_rx) = tokio::sync::mpsc::channel::<()>(1); + let done_tx = Arc::new(done_tx); - // Set the handler for Peer connection state - // This will notify you when the peer has connected/disconnected - peer_connection.on_peer_connection_state_change(Box::new(move |s: RTCPeerConnectionState| { - println!("Peer Connection State has changed: {s}"); - - if s == RTCPeerConnectionState::Failed { - // Wait until PeerConnection has had no network activity for 30 seconds or another failure. It may be reconnected using an ICE Restart. - // Use webrtc.PeerConnectionStateDisconnected if you are interested in detecting faster timeout. - // Note that the PeerConnection may come back from PeerConnectionStateDisconnected. - println!("Peer Connection has gone to failed exiting"); - let _ = done_tx.try_send(()); - } - - Box::pin(async {}) - })); - - // Register data channel creation handling - peer_connection - .on_data_channel(Box::new(move |d: Arc| { - let d_label = d.label().to_owned(); - let d_id = d.id(); - println!("New DataChannel {d_label} {d_id}"); - - let close_after2 = Arc::clone(&close_after); - - // Register channel opening handling - Box::pin(async move { - let d2 = Arc::clone(&d); - let d_label2 = d_label.clone(); - let d_id2 = d_id; - d.on_open(Box::new(move || { - println!("Data channel '{d_label2}'-'{d_id2}' open. Random messages will now be sent to any connected DataChannels every 5 seconds"); - let (done_tx, mut done_rx) = tokio::sync::mpsc::channel::<()>(1); - let done_tx = Arc::new(Mutex::new(Some(done_tx))); - Box::pin(async move { - d2.on_close(Box::new(move || { - println!("Data channel '{d_label2}'-'{d_id2}' closed."); - let done_tx2 = Arc::clone(&done_tx); - Box::pin(async move{ - let mut done = done_tx2.lock().await; - done.take(); - }) - })); - - let mut result = Result::::Ok(0); - while result.is_ok() { - let timeout = tokio::time::sleep(Duration::from_secs(5)); - tokio::pin!(timeout); - - tokio::select! { - _ = done_rx.recv() => { - break; - } - _ = timeout.as_mut() =>{ - let message = math_rand_alpha(15); - println!("Sending '{message}'"); - result = d2.send_text(message).await.map_err(Into::into); - - let cnt = close_after2.fetch_sub(1, Ordering::SeqCst); - if cnt <= 0 { - println!("Sent times out. Closing data channel '{}'-'{}'.", d2.label(), d2.id()); - let _ = d2.close().await; - break; - } - } - }; - } - }) - })); - - // Register text message handling - d.on_message(Box::new(move |msg: DataChannelMessage| { - let msg_str = String::from_utf8(msg.data.to_vec()).unwrap(); - println!("Message from DataChannel '{d_label}': '{msg_str}'"); - Box::pin(async {}) - })); - }) - })); + peer_connection.with_event_handler(ConnectionHandler { + done_tx, + close_after: close_after.clone(), + }); // Wait for the offer to be pasted let line = signal::must_read_stdin()?; diff --git a/examples/examples/data-channels-create/data-channels-create.rs b/examples/examples/data-channels-create/data-channels-create.rs index 87d3c0789..e75af1514 100644 --- a/examples/examples/data-channels-create/data-channels-create.rs +++ b/examples/examples/data-channels-create/data-channels-create.rs @@ -1,3 +1,4 @@ +use std::future::Future; use std::io::Write; use std::sync::Arc; @@ -8,12 +9,73 @@ use webrtc::api::interceptor_registry::register_default_interceptors; use webrtc::api::media_engine::MediaEngine; use webrtc::api::APIBuilder; use webrtc::data_channel::data_channel_message::DataChannelMessage; +use webrtc::data_channel::{RTCDataChannel, RTCDataChannelEventHandler}; use webrtc::ice_transport::ice_server::RTCIceServer; use webrtc::interceptor::registry::Registry; use webrtc::peer_connection::configuration::RTCConfiguration; use webrtc::peer_connection::math_rand_alpha; use webrtc::peer_connection::peer_connection_state::RTCPeerConnectionState; use webrtc::peer_connection::sdp::session_description::RTCSessionDescription; +use webrtc::peer_connection::PeerConnectionEventHandler; + +struct ConnectionHandler { + done_tx: Arc>, +} + +impl PeerConnectionEventHandler for ConnectionHandler { + // Set the handler for Peer connection state + // This will notify you when the peer has connected/disconnected + fn on_peer_connection_state_change( + &mut self, + state: RTCPeerConnectionState, + ) -> impl Future + Send { + async move { + if state == RTCPeerConnectionState::Failed { + // Wait until PeerConnection has had no network activity for 30 seconds or another failure. It may be reconnected using an ICE Restart. + // Use webrtc.PeerConnectionStateDisconnected if you are interested in detecting faster timeout. + // Note that the PeerConnection may come back from PeerConnectionStateDisconnected. + println!("Peer Connection has gone to failed exiting"); + let _ = self.done_tx.try_send(()); + } + } + } +} + +struct DataChannelHandler { + channel: Arc, + label: String, + id: u16, +} + +impl RTCDataChannelEventHandler for DataChannelHandler { + // Register channel opening handling + fn on_open(&mut self) -> impl Future + Send { + async move { + println!("Data channel '{}'-'{}' open. Random messages will now be sent to any connected DataChannels every 5 seconds", self.label, self.id); + + let mut result = Result::::Ok(0); + while result.is_ok() { + let timeout = tokio::time::sleep(Duration::from_secs(5)); + tokio::pin!(timeout); + + tokio::select! { + _ = timeout.as_mut() =>{ + let message = math_rand_alpha(15); + println!("Sending '{message}'"); + result = self.channel.send_text(message).await.map_err(Into::into); + } + }; + } + } + } + + // Register text message handling + fn on_message(&mut self, msg: DataChannelMessage) -> impl Future + Send { + let msg_str = String::from_utf8(msg.data.to_vec()).unwrap(); + println!("Message from DataChannel '{}': '{msg_str}'", self.label); + async {} + } +} #[tokio::main] async fn main() -> Result<()> { @@ -98,54 +160,16 @@ async fn main() -> Result<()> { // Create a datachannel with label 'data' let data_channel = peer_connection.create_data_channel("data", None).await?; - let (done_tx, mut done_rx) = tokio::sync::mpsc::channel::<()>(1); - - // Set the handler for Peer connection state - // This will notify you when the peer has connected/disconnected - peer_connection.on_peer_connection_state_change(Box::new(move |s: RTCPeerConnectionState| { - println!("Peer Connection State has changed: {s}"); - - if s == RTCPeerConnectionState::Failed { - // Wait until PeerConnection has had no network activity for 30 seconds or another failure. It may be reconnected using an ICE Restart. - // Use webrtc.PeerConnectionStateDisconnected if you are interested in detecting faster timeout. - // Note that the PeerConnection may come back from PeerConnectionStateDisconnected. - println!("Peer Connection has gone to failed exiting"); - let _ = done_tx.try_send(()); - } - - Box::pin(async {}) - })); - - // Register channel opening handling - let d1 = Arc::clone(&data_channel); - data_channel.on_open(Box::new(move || { - println!("Data channel '{}'-'{}' open. Random messages will now be sent to any connected DataChannels every 5 seconds", d1.label(), d1.id()); - - let d2 = Arc::clone(&d1); - Box::pin(async move { - let mut result = Result::::Ok(0); - while result.is_ok() { - let timeout = tokio::time::sleep(Duration::from_secs(5)); - tokio::pin!(timeout); + data_channel.with_event_handler(DataChannelHandler { + channel: data_channel.clone(), + label: data_channel.label().to_owned(), + id: data_channel.id(), + }); - tokio::select! { - _ = timeout.as_mut() =>{ - let message = math_rand_alpha(15); - println!("Sending '{message}'"); - result = d2.send_text(message).await.map_err(Into::into); - } - }; - } - }) - })); + let (done_tx, mut done_rx) = tokio::sync::mpsc::channel::<()>(1); + let done_tx = Arc::new(done_tx); - // Register text message handling - let d_label = data_channel.label().to_owned(); - data_channel.on_message(Box::new(move |msg: DataChannelMessage| { - let msg_str = String::from_utf8(msg.data.to_vec()).unwrap(); - println!("Message from DataChannel '{d_label}': '{msg_str}'"); - Box::pin(async {}) - })); + peer_connection.with_event_handler(ConnectionHandler { done_tx }); // Create an offer to send to the browser let offer = peer_connection.create_offer(None).await?; diff --git a/examples/examples/data-channels-detach-create/data-channels-detach-create.rs b/examples/examples/data-channels-detach-create/data-channels-detach-create.rs index 2c6301790..65286a343 100644 --- a/examples/examples/data-channels-detach-create/data-channels-detach-create.rs +++ b/examples/examples/data-channels-detach-create/data-channels-detach-create.rs @@ -1,3 +1,4 @@ +use std::future::Future; use std::io::Write; use std::sync::Arc; @@ -9,12 +10,14 @@ use webrtc::api::interceptor_registry::register_default_interceptors; use webrtc::api::media_engine::MediaEngine; use webrtc::api::setting_engine::SettingEngine; use webrtc::api::APIBuilder; +use webrtc::data_channel::{RTCDataChannel, RTCDataChannelEventHandler}; use webrtc::ice_transport::ice_server::RTCIceServer; use webrtc::interceptor::registry::Registry; use webrtc::peer_connection::configuration::RTCConfiguration; use webrtc::peer_connection::math_rand_alpha; use webrtc::peer_connection::peer_connection_state::RTCPeerConnectionState; use webrtc::peer_connection::sdp::session_description::RTCSessionDescription; +use webrtc::peer_connection::PeerConnectionEventHandler; const MESSAGE_SIZE: usize = 1500; @@ -114,47 +117,71 @@ async fn main() -> Result<()> { // Set the handler for Peer connection state // This will notify you when the peer has connected/disconnected - peer_connection.on_peer_connection_state_change(Box::new(move |s: RTCPeerConnectionState| { - println!("Peer Connection State has changed: {s}"); - - if s == RTCPeerConnectionState::Failed { - // Wait until PeerConnection has had no network activity for 30 seconds or another failure. It may be reconnected using an ICE Restart. - // Use webrtc.PeerConnectionStateDisconnected if you are interested in detecting faster timeout. - // Note that the PeerConnection may come back from PeerConnectionStateDisconnected. - println!("Peer Connection has gone to failed exiting"); - let _ = done_tx.try_send(()); + + struct ConnectionHandler { + done_tx: tokio::sync::mpsc::Sender<()>, + } + impl PeerConnectionEventHandler for ConnectionHandler { + fn on_peer_connection_state_change( + &mut self, + state: RTCPeerConnectionState, + ) -> impl Future + Send { + async move { + println!("Peer Connection State has changed: {state}"); + + if state == RTCPeerConnectionState::Failed { + // Wait until PeerConnection has had no network activity for 30 seconds or another failure. It may be reconnected using an ICE Restart. + // Use webrtc.PeerConnectionStateDisconnected if you are interested in detecting faster timeout. + // Note that the PeerConnection may come back from PeerConnectionStateDisconnected. + println!("Peer Connection has gone to failed exiting"); + let _ = self.done_tx.try_send(()); + } + } } + } - Box::pin(async {}) - })); + peer_connection.with_event_handler(ConnectionHandler { done_tx }); // Register channel opening handling let d = Arc::clone(&data_channel); - data_channel.on_open(Box::new(move || { - println!("Data channel '{}'-'{}' open.", d.label(), d.id()); - - let d2 = Arc::clone(&d); - Box::pin(async move { - let raw = match d2.detach().await { - Ok(raw) => raw, - Err(err) => { - println!("data channel detach got err: {err}"); - return; - } - }; - - // Handle reading from the data channel - let r = Arc::clone(&raw); - tokio::spawn(async move { - let _ = read_loop(r).await; - }); - - // Handle writing to the data channel - tokio::spawn(async move { - let _ = write_loop(raw).await; - }); - }) - })); + + struct ChannelHandler { + data_channel: Arc, + }; + + impl RTCDataChannelEventHandler for ChannelHandler { + fn on_open(&mut self) -> impl Future + Send { + async move { + println!( + "Data channel '{}'-'{}' open.", + self.data_channel.label(), + self.data_channel.id() + ); + let raw = match self.data_channel.detach().await { + Ok(raw) => raw, + Err(err) => { + println!("data channel detach got err: {err}"); + return; + } + }; + + // Handle reading from the data channel + let r = Arc::clone(&raw); + tokio::spawn(async move { + let _ = read_loop(r).await; + }); + + // Handle writing to the data channel + tokio::spawn(async move { + let _ = write_loop(raw).await; + }); + } + } + } + + data_channel.with_event_handler(ChannelHandler { + data_channel: data_channel.clone(), + }); // Create an offer to send to the browser let offer = peer_connection.create_offer(None).await?; diff --git a/examples/examples/data-channels-detach/data-channels-detach.rs b/examples/examples/data-channels-detach/data-channels-detach.rs index 7c90a294a..2c2b6dbbf 100644 --- a/examples/examples/data-channels-detach/data-channels-detach.rs +++ b/examples/examples/data-channels-detach/data-channels-detach.rs @@ -1,3 +1,4 @@ +use std::future::Future; use std::io::Write; use std::sync::Arc; @@ -10,12 +11,14 @@ use webrtc::api::media_engine::MediaEngine; use webrtc::api::setting_engine::SettingEngine; use webrtc::api::APIBuilder; use webrtc::data_channel::RTCDataChannel; +use webrtc::data_channel::RTCDataChannelEventHandler; use webrtc::ice_transport::ice_server::RTCIceServer; use webrtc::interceptor::registry::Registry; use webrtc::peer_connection::configuration::RTCConfiguration; use webrtc::peer_connection::math_rand_alpha; use webrtc::peer_connection::peer_connection_state::RTCPeerConnectionState; use webrtc::peer_connection::sdp::session_description::RTCSessionDescription; +use webrtc::peer_connection::PeerConnectionEventHandler; const MESSAGE_SIZE: usize = 1500; @@ -110,59 +113,79 @@ async fn main() -> Result<()> { let (done_tx, mut done_rx) = tokio::sync::mpsc::channel::<()>(1); - // Set the handler for Peer connection state - // This will notify you when the peer has connected/disconnected - peer_connection.on_peer_connection_state_change(Box::new(move |s: RTCPeerConnectionState| { - println!("Peer Connection State has changed: {s}"); - - if s == RTCPeerConnectionState::Failed { - // Wait until PeerConnection has had no network activity for 30 seconds or another failure. It may be reconnected using an ICE Restart. - // Use webrtc.PeerConnectionStateDisconnected if you are interested in detecting faster timeout. - // Note that the PeerConnection may come back from PeerConnectionStateDisconnected. - println!("Peer Connection has gone to failed exiting"); - let _ = done_tx.try_send(()); + struct ConnectionHandler { + done_tx: tokio::sync::mpsc::Sender<()>, + } + + impl PeerConnectionEventHandler for ConnectionHandler { + // Set the handler for Peer connection state + // This will notify you when the peer has connected/disconnected + fn on_peer_connection_state_change( + &mut self, + state: RTCPeerConnectionState, + ) -> impl Future + Send { + println!("Peer Connection State has changed: {state}"); + + if state == RTCPeerConnectionState::Failed { + // Wait until PeerConnection has had no network activity for 30 seconds or another failure. It may be reconnected using an ICE Restart. + // Use webrtc.PeerConnectionStateDisconnected if you are interested in detecting faster timeout. + // Note that the PeerConnection may come back from PeerConnectionStateDisconnected. + println!("Peer Connection has gone to failed exiting"); + let _ = self.done_tx.try_send(()); + } + async {} } - Box::pin(async {}) - })); - - // Register data channel creation handling - peer_connection.on_data_channel(Box::new(move |d: Arc| { - let d_label = d.label().to_owned(); - let d_id = d.id(); - println!("New DataChannel {d_label} {d_id}"); - - // Register channel opening handling - Box::pin(async move { - let d2 = Arc::clone(&d); - let d_label2 = d_label.clone(); - let d_id2 = d_id; - d.on_open(Box::new(move || { - println!("Data channel '{d_label2}'-'{d_id2}' open."); - - Box::pin(async move { - let raw = match d2.detach().await { - Ok(raw) => raw, - Err(err) => { - println!("data channel detach got err: {err}"); - return; - } - }; - - // Handle reading from the data channel - let r = Arc::clone(&raw); - tokio::spawn(async move { - let _ = read_loop(r).await; - }); - - // Handle writing to the data channel - tokio::spawn(async move { - let _ = write_loop(raw).await; - }); - }) - })); - }) - })); + // Register data channel creation handling + fn on_data_channel( + &mut self, + data_channel: Arc, + ) -> impl Future + Send { + async move { + let d_label = data_channel.label().to_owned(); + let d_id = data_channel.id(); + println!("New DataChannel {d_label} {d_id}"); + data_channel.with_event_handler(ChannelHandler { + label: d_label, + id: d_id, + data_channel: data_channel.clone(), + }); + } + } + } + struct ChannelHandler { + label: String, + id: u16, + data_channel: Arc, + } + + impl RTCDataChannelEventHandler for ChannelHandler { + fn on_open(&mut self) -> impl Future + Send { + println!("Data channel '{}'-'{}' open.", self.label, self.id); + + Box::pin(async move { + let raw = match self.data_channel.detach().await { + Ok(raw) => raw, + Err(err) => { + println!("data channel detach got err: {err}"); + return; + } + }; + + // Handle reading from the data channel + let r = Arc::clone(&raw); + tokio::spawn(async move { + let _ = read_loop(r).await; + }); + + // Handle writing to the data channel + tokio::spawn(async move { + let _ = write_loop(raw).await; + }); + }) + } + } + peer_connection.with_event_handler(ConnectionHandler { done_tx }); // Wait for the offer to be pasted let line = signal::must_read_stdin()?; diff --git a/examples/examples/data-channels-flow-control/data-channels-flow-control.rs b/examples/examples/data-channels-flow-control/data-channels-flow-control.rs index ba1999d79..0a7e361d8 100644 --- a/examples/examples/data-channels-flow-control/data-channels-flow-control.rs +++ b/examples/examples/data-channels-flow-control/data-channels-flow-control.rs @@ -1,3 +1,4 @@ +use std::future::Future; use std::sync::atomic::{AtomicUsize, Ordering}; use std::sync::Arc; use std::time::{Duration, SystemTime}; @@ -7,12 +8,14 @@ use webrtc::api::interceptor_registry::register_default_interceptors; use webrtc::api::media_engine::MediaEngine; use webrtc::api::APIBuilder; use webrtc::data_channel::data_channel_init::RTCDataChannelInit; +use webrtc::data_channel::data_channel_message::DataChannelMessage; +use webrtc::data_channel::{RTCDataChannel, RTCDataChannelEventHandler}; use webrtc::ice_transport::ice_candidate::RTCIceCandidate; use webrtc::ice_transport::ice_server::RTCIceServer; use webrtc::interceptor::registry::Registry; use webrtc::peer_connection::configuration::RTCConfiguration; use webrtc::peer_connection::peer_connection_state::RTCPeerConnectionState; -use webrtc::peer_connection::RTCPeerConnection; +use webrtc::peer_connection::{PeerConnectionEventHandler, RTCPeerConnection}; const BUFFERED_AMOUNT_LOW_THRESHOLD: usize = 512 * 1024; // 512 KB const MAX_BUFFERED_AMOUNT: usize = 1024 * 1024; // 1 MB @@ -61,61 +64,109 @@ async fn create_requester() -> anyhow::Result { let dc = pc.create_data_channel("data", options).await?; // Use mpsc channel to send and receive a signal when more data can be sent - let (more_can_be_sent, mut maybe_more_can_be_sent) = tokio::sync::mpsc::channel(1); + let (more_can_be_sent, maybe_more_can_be_sent) = tokio::sync::mpsc::channel::<()>(1); - // Get a shared pointer to the data channel - let shared_dc = dc.clone(); - dc.on_open(Box::new(|| { - Box::pin(async move { + struct ChannelHandler { + data_channel: Arc, + maybe_more_can_be_sent: Arc>>, + more_can_be_sent: tokio::sync::mpsc::Sender<()>, + } + + impl RTCDataChannelEventHandler for ChannelHandler { + fn on_open(&mut self) -> impl Future + Send { // This callback shouldn't be blocked for a long time, so we spawn our handler + + let maybe_more_can_be_sent = self.maybe_more_can_be_sent.clone(); + let data_channel = self.data_channel.clone(); tokio::spawn(async move { let buf = Bytes::from_static(&[0u8; 1024]); loop { - if shared_dc.send(&buf).await.is_err() { + if data_channel.send(&buf).await.is_err() { break; } - let buffered_amount = shared_dc.buffered_amount().await; + let buffered_amount = data_channel.buffered_amount().await; if buffered_amount + buf.len() > MAX_BUFFERED_AMOUNT { // Wait for the signal that more can be sent - let _ = maybe_more_can_be_sent.recv().await; + let mut wait_for_more = maybe_more_can_be_sent.lock().await; + let _ = wait_for_more.recv().await; } } }); - }) - })); + async {} + } + + fn on_buffered_amount_low(&mut self, _: ()) -> impl Future + Send { + async move { + // Send a signal that more can be sent + self.more_can_be_sent.send(()).await.unwrap(); + } + } + } + dc.with_event_handler(ChannelHandler { + data_channel: dc.clone(), + more_can_be_sent, + maybe_more_can_be_sent: Arc::new(tokio::sync::Mutex::new(maybe_more_can_be_sent)), + }); dc.set_buffered_amount_low_threshold(BUFFERED_AMOUNT_LOW_THRESHOLD) .await; - dc.on_buffered_amount_low(Box::new(move || { - let more_can_be_sent = more_can_be_sent.clone(); - - Box::pin(async move { - // Send a signal that more can be sent - more_can_be_sent.send(()).await.unwrap(); - }) - })) - .await; - Ok(pc) } -async fn create_responder() -> anyhow::Result { - // Create a peer connection first - let pc = create_peer_connection().await?; +#[tokio::main] +async fn main() -> anyhow::Result<()> { + env_logger::init(); + + let requester = Arc::new(create_peer_connection().await?); + let responder = Arc::new(create_peer_connection().await?); - // Set a data channel handler so that we can receive data - pc.on_data_channel(Box::new(move |dc| { - Box::pin(async move { - let total_bytes_received = Arc::new(AtomicUsize::new(0)); + struct ResponderHandler { + maybe_responder: std::sync::Weak, + fault: tokio::sync::mpsc::Sender<()>, + } - let shared_total_bytes_received = total_bytes_received.clone(); - dc.on_open(Box::new(move || { - Box::pin(async { - // This callback shouldn't be blocked for a long time, so we spawn our handler + impl PeerConnectionEventHandler for ResponderHandler { + fn on_ice_candidate( + &mut self, + candidate: Option, + ) -> impl Future + Send { + async move { + if let (Some(requester), Some(Ok(candidate))) = ( + self.maybe_responder.upgrade(), + candidate.map(|c| c.to_json()), + ) { + if let Err(err) = requester.add_ice_candidate(candidate).await { + log::warn!("{err}"); + } + } + } + } + fn on_peer_connection_state_change( + &mut self, + state: RTCPeerConnectionState, + ) -> impl Future + Send { + async move { + if state == RTCPeerConnectionState::Failed { + self.fault.send(()).await.unwrap(); + } + } + } + + fn on_data_channel( + &mut self, + channel: Arc, + ) -> impl Future + Send { + struct ChannelHandler { + bytes_received: Arc, + } + + impl RTCDataChannelEventHandler for ChannelHandler { + fn on_open(&mut self) -> impl Future + Send { + let bytes_received = self.bytes_received.clone(); tokio::spawn(async move { let start = SystemTime::now(); @@ -123,8 +174,7 @@ async fn create_responder() -> anyhow::Result { println!(); loop { - let total_bytes_received = - shared_total_bytes_received.load(Ordering::Relaxed); + let total_bytes_received = bytes_received.load(Ordering::Relaxed); let elapsed = SystemTime::now().duration_since(start); let bps = @@ -137,84 +187,71 @@ async fn create_responder() -> anyhow::Result { tokio::time::sleep(Duration::from_secs(1)).await; } }); - }) - })); - - dc.on_message(Box::new(move |msg| { - let total_bytes_received = total_bytes_received.clone(); - - Box::pin(async move { - total_bytes_received.fetch_add(msg.data.len(), Ordering::Relaxed); - }) - })); - }) - })); - - Ok(pc) -} + async {} + } + fn on_message( + &mut self, + msg: DataChannelMessage, + ) -> impl Future + Send { + self.bytes_received + .fetch_add(msg.data.len(), Ordering::Relaxed); + async {} + } + } -#[tokio::main] -async fn main() -> anyhow::Result<()> { - env_logger::init(); + channel.with_event_handler(ChannelHandler { + bytes_received: Arc::new(AtomicUsize::new(0)), + }); + async {} + } + } - let requester = Arc::new(create_requester().await?); - let responder = Arc::new(create_responder().await?); + struct RequesterHandler { + maybe_requester: std::sync::Weak, + fault: tokio::sync::mpsc::Sender<()>, + } - let maybe_requester = Arc::downgrade(&requester); - responder.on_ice_candidate(Box::new(move |candidate: Option| { - let maybe_requester = maybe_requester.clone(); - - Box::pin(async move { - if let Some(candidate) = candidate { - if let Ok(candidate) = candidate.to_json() { - if let Some(requester) = maybe_requester.upgrade() { - if let Err(err) = requester.add_ice_candidate(candidate).await { - log::warn!("{}", err); - } + impl PeerConnectionEventHandler for RequesterHandler { + fn on_ice_candidate( + &mut self, + candidate: Option, + ) -> impl Future + Send { + async move { + if let (Some(requester), Some(Ok(candidate))) = ( + self.maybe_requester.upgrade(), + candidate.map(|c| c.to_json()), + ) { + if let Err(err) = requester.add_ice_candidate(candidate).await { + log::warn!("{err}"); } } } - }) - })); - - let maybe_responder = Arc::downgrade(&responder); - requester.on_ice_candidate(Box::new(move |candidate: Option| { - let maybe_responder = maybe_responder.clone(); - - Box::pin(async move { - if let Some(candidate) = candidate { - if let Ok(candidate) = candidate.to_json() { - if let Some(responder) = maybe_responder.upgrade() { - if let Err(err) = responder.add_ice_candidate(candidate).await { - log::warn!("{}", err); - } - } + } + fn on_peer_connection_state_change( + &mut self, + state: RTCPeerConnectionState, + ) -> impl Future + Send { + async move { + if state == RTCPeerConnectionState::Failed { + self.fault.send(()).await.unwrap(); } } - }) - })); + } + } + let maybe_requester = Arc::downgrade(&requester); let (fault, mut reqs_fault) = tokio::sync::mpsc::channel(1); - requester.on_peer_connection_state_change(Box::new(move |s: RTCPeerConnectionState| { - let fault = fault.clone(); - - Box::pin(async move { - if s == RTCPeerConnectionState::Failed { - fault.send(()).await.unwrap(); - } - }) - })); + requester.with_event_handler(RequesterHandler { + maybe_requester, + fault, + }); + let maybe_responder = Arc::downgrade(&responder); let (fault, mut resp_fault) = tokio::sync::mpsc::channel(1); - responder.on_peer_connection_state_change(Box::new(move |s: RTCPeerConnectionState| { - let fault = fault.clone(); - - Box::pin(async move { - if s == RTCPeerConnectionState::Failed { - fault.send(()).await.unwrap(); - } - }) - })); + responder.with_event_handler(ResponderHandler { + maybe_responder, + fault, + }); let reqs = requester.create_offer(None).await?; diff --git a/examples/examples/data-channels/data-channels.rs b/examples/examples/data-channels/data-channels.rs index 708f20342..79364ceec 100644 --- a/examples/examples/data-channels/data-channels.rs +++ b/examples/examples/data-channels/data-channels.rs @@ -1,3 +1,4 @@ +use std::future::Future; use std::io::Write; use std::sync::Arc; @@ -9,12 +10,14 @@ use webrtc::api::media_engine::MediaEngine; use webrtc::api::APIBuilder; use webrtc::data_channel::data_channel_message::DataChannelMessage; use webrtc::data_channel::RTCDataChannel; +use webrtc::data_channel::RTCDataChannelEventHandler; use webrtc::ice_transport::ice_server::RTCIceServer; use webrtc::interceptor::registry::Registry; use webrtc::peer_connection::configuration::RTCConfiguration; use webrtc::peer_connection::math_rand_alpha; use webrtc::peer_connection::peer_connection_state::RTCPeerConnectionState; use webrtc::peer_connection::sdp::session_description::RTCSessionDescription; +use webrtc::peer_connection::PeerConnectionEventHandler; #[tokio::main] async fn main() -> Result<()> { @@ -98,67 +101,84 @@ async fn main() -> Result<()> { let (done_tx, mut done_rx) = tokio::sync::mpsc::channel::<()>(1); - // Set the handler for Peer connection state - // This will notify you when the peer has connected/disconnected - peer_connection.on_peer_connection_state_change(Box::new(move |s: RTCPeerConnectionState| { - println!("Peer Connection State has changed: {s}"); - - if s == RTCPeerConnectionState::Failed { - // Wait until PeerConnection has had no network activity for 30 seconds or another failure. It may be reconnected using an ICE Restart. - // Use webrtc.PeerConnectionStateDisconnected if you are interested in detecting faster timeout. - // Note that the PeerConnection may come back from PeerConnectionStateDisconnected. - println!("Peer Connection has gone to failed exiting"); - let _ = done_tx.try_send(()); - } + struct ConnectionHandler { + done_tx: tokio::sync::mpsc::Sender<()>, + } - Box::pin(async {}) - })); + impl PeerConnectionEventHandler for ConnectionHandler { + // Set the handler for Peer connection state + // This will notify you when the peer has connected/disconnected + fn on_peer_connection_state_change( + &mut self, + state: RTCPeerConnectionState, + ) -> impl Future + Send { + println!("Peer Connection State has changed: {state}"); + async move { + if state == RTCPeerConnectionState::Failed { + // Wait until PeerConnection has had no network activity for 30 seconds or another failure. It may be reconnected using an ICE Restart. + // Use webrtc.PeerConnectionStateDisconnected if you are interested in detecting faster timeout. + // Note that the PeerConnection may come back from PeerConnectionStateDisconnected. + println!("Peer Connection has gone to failed exiting"); + let _ = self.done_tx.try_send(()); + } + } + } - // Register data channel creation handling - peer_connection - .on_data_channel(Box::new(move |d: Arc| { - let d_label = d.label().to_owned(); - let d_id = d.id(); + // Register data channel creation handling + fn on_data_channel( + &mut self, + channel: Arc, + ) -> impl Future + Send { + let d_label = channel.label().to_owned(); + let d_id = channel.id(); println!("New DataChannel {d_label} {d_id}"); - // Register channel opening handling - Box::pin(async move { - let d2 = Arc::clone(&d); - let d_label2 = d_label.clone(); - let d_id2 = d_id; - d.on_close(Box::new(move || { - println!("Data channel closed"); - Box::pin(async {}) - })); - - d.on_open(Box::new(move || { - println!("Data channel '{d_label2}'-'{d_id2}' open. Random messages will now be sent to any connected DataChannels every 5 seconds"); - - Box::pin(async move { - let mut result = Result::::Ok(0); - while result.is_ok() { - let timeout = tokio::time::sleep(Duration::from_secs(5)); - tokio::pin!(timeout); - - tokio::select! { - _ = timeout.as_mut() =>{ + channel.with_event_handler(ChannelHandler { + label: d_label, + id: d_id, + channel: channel.clone(), + }); + async {} + } + } + + struct ChannelHandler { + label: String, + id: u16, + channel: Arc, + } + + impl RTCDataChannelEventHandler for ChannelHandler { + fn on_close(&mut self) -> impl Future + Send { + println!("Data channel closed"); + async {} + } + + fn on_open(&mut self) -> impl Future + Send { + println!("Data channel '{}'-'{}' open. Random messages will now be sent to any connected DataChannels every 5 seconds", self.label, self.id); + async move { + let mut result = Result::::Ok(0); + while result.is_ok() { + let timeout = tokio::time::sleep(Duration::from_secs(5)); + tokio::pin!(timeout); + tokio::select! { + _ = timeout.as_mut() => { let message = math_rand_alpha(15); println!("Sending '{message}'"); - result = d2.send_text(message).await.map_err(Into::into); - } - }; + result = self.channel.send_text(message).await.map_err(Into::into); } - }) - })); - - // Register text message handling - d.on_message(Box::new(move |msg: DataChannelMessage| { - let msg_str = String::from_utf8(msg.data.to_vec()).unwrap(); - println!("Message from DataChannel '{d_label}': '{msg_str}'"); - Box::pin(async {}) - })); - }) - })); + } + } + } + } + + // Register text message handling + fn on_message(&mut self, msg: DataChannelMessage) -> impl Future + Send { + let msg_str = String::from_utf8(msg.data.to_vec()).unwrap(); + println!("Message from DataChannel '{}': '{msg_str}'", self.label); + async {} + } + } // Wait for the offer to be pasted let line = signal::must_read_stdin()?; diff --git a/examples/examples/ice-restart/ice-restart.rs b/examples/examples/ice-restart/ice-restart.rs index 11fe51d28..45a1a28c6 100644 --- a/examples/examples/ice-restart/ice-restart.rs +++ b/examples/examples/ice-restart/ice-restart.rs @@ -1,3 +1,4 @@ +use std::future::Future; use std::io::Write; use std::net::SocketAddr; use std::str::FromStr; @@ -13,12 +14,12 @@ use tokio_util::codec::{BytesCodec, FramedRead}; use webrtc::api::interceptor_registry::register_default_interceptors; use webrtc::api::media_engine::MediaEngine; use webrtc::api::APIBuilder; -use webrtc::data_channel::RTCDataChannel; +use webrtc::data_channel::{RTCDataChannel, RTCDataChannelEventHandler}; use webrtc::ice_transport::ice_connection_state::RTCIceConnectionState; use webrtc::interceptor::registry::Registry; use webrtc::peer_connection::configuration::RTCConfiguration; use webrtc::peer_connection::sdp::session_description::RTCSessionDescription; -use webrtc::peer_connection::RTCPeerConnection; +use webrtc::peer_connection::{PeerConnectionEventHandler, RTCPeerConnection}; #[macro_use] extern crate lazy_static; @@ -108,32 +109,49 @@ async fn do_signaling(req: Request) -> Result, hyper::Error }; let pc = Arc::new(pc); - // Set the handler for ICE connection state - // This will notify you when the peer has connected/disconnected - pc.on_ice_connection_state_change(Box::new( - |connection_state: RTCIceConnectionState| { - println!("ICE Connection State has changed: {connection_state}"); - Box::pin(async {}) - }, - )); - - // Send the current time via a DataChannel to the remote peer every 3 seconds - pc.on_data_channel(Box::new(|d: Arc| { - Box::pin(async move { - let d2 = Arc::clone(&d); - d.on_open(Box::new(move || { - Box::pin(async move { - while d2 - .send_text(format!("{:?}", tokio::time::Instant::now())) - .await - .is_ok() - { - tokio::time::sleep(Duration::from_secs(3)).await; - } - }) - })); - }) - })); + struct ConnectionHandler; + + struct ChannelHandler { + channel: Arc, + } + + impl RTCDataChannelEventHandler for ChannelHandler { + fn on_open(&mut self) -> impl Future + Send { + async move { + while self + .channel + .send_text(format!("{:?}", tokio::time::Instant::now())) + .await + .is_ok() + { + tokio::time::sleep(Duration::from_secs(3)).await; + } + } + } + } + + impl PeerConnectionEventHandler for ConnectionHandler { + // Set the handler for ICE connection state + // This will notify you when the peer has connected/disconnected + fn on_ice_connection_state_change( + &mut self, + state: RTCIceConnectionState, + ) -> impl Future + Send { + println!("ICE Connection State has changed: {state}"); + async {} + } + + // Send the current time via a DataChannel to the remote peer every 3 seconds + fn on_data_channel( + &mut self, + channel: Arc, + ) -> impl Future + Send { + channel.with_event_handler(ChannelHandler { + channel: channel.clone(), + }); + async {} + } + } *peer_connection = Some(Arc::clone(&pc)); pc diff --git a/examples/examples/insertable-streams/insertable-streams.rs b/examples/examples/insertable-streams/insertable-streams.rs index 4e45499fa..e259f4d51 100644 --- a/examples/examples/insertable-streams/insertable-streams.rs +++ b/examples/examples/insertable-streams/insertable-streams.rs @@ -1,4 +1,5 @@ use std::fs::File; +use std::future::Future; use std::io::{BufReader, Write}; use std::path::Path; use std::sync::Arc; @@ -18,6 +19,7 @@ use webrtc::media::Sample; use webrtc::peer_connection::configuration::RTCConfiguration; use webrtc::peer_connection::peer_connection_state::RTCPeerConnectionState; use webrtc::peer_connection::sdp::session_description::RTCSessionDescription; +use webrtc::peer_connection::PeerConnectionEventHandler; use webrtc::rtp_transceiver::rtp_codec::RTCRtpCodecCapability; use webrtc::track::track_local::track_local_static_sample::TrackLocalStaticSample; use webrtc::track::track_local::TrackLocal; @@ -194,33 +196,44 @@ async fn main() -> Result<()> { Result::<()>::Ok(()) }); - // Set the handler for ICE connection state - // This will notify you when the peer has connected/disconnected - peer_connection.on_ice_connection_state_change(Box::new( - move |connection_state: RTCIceConnectionState| { - println!("Connection State has changed {connection_state}"); - if connection_state == RTCIceConnectionState::Connected { - notify_tx.notify_waiters(); - } - Box::pin(async {}) - }, - )); + struct ConnectionHandler { + notify_tx: Arc, + done_tx: tokio::sync::mpsc::Sender<()>, + } - // Set the handler for Peer connection state - // This will notify you when the peer has connected/disconnected - peer_connection.on_peer_connection_state_change(Box::new(move |s: RTCPeerConnectionState| { - println!("Peer Connection State has changed: {s}"); - - if s == RTCPeerConnectionState::Failed { - // Wait until PeerConnection has had no network activity for 30 seconds or another failure. It may be reconnected using an ICE Restart. - // Use webrtc.PeerConnectionStateDisconnected if you are interested in detecting faster timeout. - // Note that the PeerConnection may come back from PeerConnectionStateDisconnected. - println!("Peer Connection has gone to failed exiting"); - let _ = done_tx.try_send(()); + impl PeerConnectionEventHandler for ConnectionHandler { + // Set the handler for ICE connection state + // This will notify you when the peer has connected/disconnected + fn on_ice_connection_state_change( + &mut self, + state: RTCIceConnectionState, + ) -> impl Future + Send { + println!("Connection State has changed {state}"); + if state == RTCIceConnectionState::Connected { + self.notify_tx.notify_waiters(); + } + async {} } - Box::pin(async {}) - })); + // Set the handler for Peer connection state + // This will notify you when the peer has connected/disconnected + fn on_peer_connection_state_change( + &mut self, + state: RTCPeerConnectionState, + ) -> impl Future + Send { + println!("Peer Connection State has changed: {state}"); + + if state == RTCPeerConnectionState::Failed { + // Wait until PeerConnection has had no network activity for 30 seconds or another failure. It may be reconnected using an ICE Restart. + // Use webrtc.PeerConnectionStateDisconnected if you are interested in detecting faster timeout. + // Note that the PeerConnection may come back from PeerConnectionStateDisconnected. + println!("Peer Connection has gone to failed exiting"); + let _ = self.done_tx.try_send(()); + } + async {} + } + } + peer_connection.with_event_handler(ConnectionHandler { notify_tx, done_tx }); // Wait for the offer to be pasted let line = signal::must_read_stdin()?; diff --git a/examples/examples/offer-answer/answer.rs b/examples/examples/offer-answer/answer.rs index df81857f7..4acd091f5 100644 --- a/examples/examples/offer-answer/answer.rs +++ b/examples/examples/offer-answer/answer.rs @@ -1,7 +1,8 @@ +use std::future::Future; use std::io::Write; use std::net::SocketAddr; use std::str::FromStr; -use std::sync::Arc; +use std::sync::{Arc, Weak}; use anyhow::Result; use clap::{AppSettings, Arg, Command}; @@ -13,14 +14,14 @@ use webrtc::api::interceptor_registry::register_default_interceptors; use webrtc::api::media_engine::MediaEngine; use webrtc::api::APIBuilder; use webrtc::data_channel::data_channel_message::DataChannelMessage; -use webrtc::data_channel::RTCDataChannel; +use webrtc::data_channel::{RTCDataChannel, RTCDataChannelEventHandler}; use webrtc::ice_transport::ice_candidate::{RTCIceCandidate, RTCIceCandidateInit}; use webrtc::ice_transport::ice_server::RTCIceServer; use webrtc::interceptor::registry::Registry; use webrtc::peer_connection::configuration::RTCConfiguration; use webrtc::peer_connection::peer_connection_state::RTCPeerConnectionState; use webrtc::peer_connection::sdp::session_description::RTCSessionDescription; -use webrtc::peer_connection::{math_rand_alpha, RTCPeerConnection}; +use webrtc::peer_connection::{math_rand_alpha, PeerConnectionEventHandler, RTCPeerConnection}; #[macro_use] extern crate lazy_static; @@ -279,27 +280,105 @@ async fn main() -> Result<()> { // the other Pion instance will add this candidate by calling AddICECandidate let pc = Arc::downgrade(&peer_connection); let pending_candidates2 = Arc::clone(&PENDING_CANDIDATES); - let addr2 = offer_addr.clone(); - peer_connection.on_ice_candidate(Box::new(move |c: Option| { - //println!("on_ice_candidate {:?}", c); - - let pc2 = pc.clone(); - let pending_candidates3 = Arc::clone(&pending_candidates2); - let addr3 = addr2.clone(); - Box::pin(async move { - if let Some(c) = c { - if let Some(pc) = pc2.upgrade() { - let desc = pc.remote_description().await; + + struct ConnectionHandler { + connection: Weak, + addr: String, + done_tx: tokio::sync::mpsc::Sender<()>, + } + + impl PeerConnectionEventHandler for ConnectionHandler { + fn on_ice_candidate( + &mut self, + candidate: Option, + ) -> impl Future + Send { + async move { + if let (Some(candidate), Some(connection)) = (candidate, self.connection.upgrade()) + { + let desc = connection.remote_description().await; if desc.is_none() { - let mut cs = pending_candidates3.lock().await; - cs.push(c); - } else if let Err(err) = signal_candidate(&addr3, &c).await { + } else if let Err(err) = signal_candidate(&self.addr, &candidate).await { panic!("{}", err); } } } - }) - })); + } + + // Set the handler for Peer connection state + // This will notify you when the peer has connected/disconnected + fn on_peer_connection_state_change( + &mut self, + state: RTCPeerConnectionState, + ) -> impl Future + Send { + println!("Peer Connection State has changed: {state}"); + + if state == RTCPeerConnectionState::Failed { + // Wait until PeerConnection has had no network activity for 30 seconds or another failure. It may be reconnected using an ICE Restart. + // Use webrtc.PeerConnectionStateDisconnected if you are interested in detecting faster timeout. + // Note that the PeerConnection may come back from PeerConnectionStateDisconnected. + println!("Peer Connection has gone to failed exiting"); + let _ = self.done_tx.try_send(()); + } + + async {} + } + + // Register data channel creation handling + fn on_data_channel( + &mut self, + channel: Arc, + ) -> impl Future + Send { + let d_label = channel.label().to_owned(); + let d_id = channel.id(); + println!("New DataChannel {d_label} {d_id}"); + + struct ChannelHandler { + label: String, + id: u16, + channel: Arc, + } + + impl RTCDataChannelEventHandler for ChannelHandler { + // Register channel opening handling + fn on_open(&mut self) -> impl Future + Send { + println!("Data channel '{}'-'{}' open. Random messages will now be sent to any connected DataChannels every 5 seconds", self.label, self.id); + async move { + let mut result = Result::::Ok(0); + while result.is_ok() { + let timeout = tokio::time::sleep(Duration::from_secs(5)); + tokio::pin!(timeout); + + tokio::select! { + _ = timeout.as_mut() =>{ + let message = math_rand_alpha(15); + println!("Sending '{message}'"); + result = self.channel.send_text(message).await.map_err(Into::into); + } + }; + } + } + } + + // Register text message handling + fn on_message( + &mut self, + msg: DataChannelMessage, + ) -> impl Future + Send { + let msg_str = String::from_utf8(msg.data.to_vec()).unwrap(); + println!("Message from DataChannel '{}': '{msg_str}'", self.label); + async {} + } + } + + channel.with_event_handler(ChannelHandler { + label: d_label, + id: d_id, + channel: channel.clone(), + }); + + async {} + } + } println!("Listening on http://{answer_addr}"); { @@ -319,61 +398,11 @@ async fn main() -> Result<()> { }); let (done_tx, mut done_rx) = tokio::sync::mpsc::channel::<()>(1); - - // Set the handler for Peer connection state - // This will notify you when the peer has connected/disconnected - peer_connection.on_peer_connection_state_change(Box::new(move |s: RTCPeerConnectionState| { - println!("Peer Connection State has changed: {s}"); - - if s == RTCPeerConnectionState::Failed { - // Wait until PeerConnection has had no network activity for 30 seconds or another failure. It may be reconnected using an ICE Restart. - // Use webrtc.PeerConnectionStateDisconnected if you are interested in detecting faster timeout. - // Note that the PeerConnection may come back from PeerConnectionStateDisconnected. - println!("Peer Connection has gone to failed exiting"); - let _ = done_tx.try_send(()); - } - - Box::pin(async {}) - })); - - // Register data channel creation handling - peer_connection.on_data_channel(Box::new(move |d: Arc| { - let d_label = d.label().to_owned(); - let d_id = d.id(); - println!("New DataChannel {d_label} {d_id}"); - - Box::pin(async move{ - // Register channel opening handling - let d2 = Arc::clone(&d); - let d_label2 = d_label.clone(); - let d_id2 = d_id; - d.on_open(Box::new(move || { - println!("Data channel '{d_label2}'-'{d_id2}' open. Random messages will now be sent to any connected DataChannels every 5 seconds"); - Box::pin(async move { - let mut result = Result::::Ok(0); - while result.is_ok() { - let timeout = tokio::time::sleep(Duration::from_secs(5)); - tokio::pin!(timeout); - - tokio::select! { - _ = timeout.as_mut() =>{ - let message = math_rand_alpha(15); - println!("Sending '{message}'"); - result = d2.send_text(message).await.map_err(Into::into); - } - }; - } - }) - })); - - // Register text message handling - d.on_message(Box::new(move |msg: DataChannelMessage| { - let msg_str = String::from_utf8(msg.data.to_vec()).unwrap(); - println!("Message from DataChannel '{d_label}': '{msg_str}'"); - Box::pin(async{}) - })); - }) - })); + peer_connection.with_event_handler(ConnectionHandler { + connection: pc, + done_tx, + addr: offer_addr, + }); println!("Press ctrl-c to stop"); tokio::select! { diff --git a/examples/examples/offer-answer/offer.rs b/examples/examples/offer-answer/offer.rs index f492ad9ca..2de8d0890 100644 --- a/examples/examples/offer-answer/offer.rs +++ b/examples/examples/offer-answer/offer.rs @@ -1,7 +1,8 @@ +use std::future::Future; use std::io::Write; use std::net::SocketAddr; use std::str::FromStr; -use std::sync::Arc; +use std::sync::{Arc, Weak}; use anyhow::Result; use clap::{AppSettings, Arg, Command}; @@ -13,13 +14,14 @@ use webrtc::api::interceptor_registry::register_default_interceptors; use webrtc::api::media_engine::MediaEngine; use webrtc::api::APIBuilder; use webrtc::data_channel::data_channel_message::DataChannelMessage; +use webrtc::data_channel::{RTCDataChannel, RTCDataChannelEventHandler}; use webrtc::ice_transport::ice_candidate::{RTCIceCandidate, RTCIceCandidateInit}; use webrtc::ice_transport::ice_server::RTCIceServer; use webrtc::interceptor::registry::Registry; use webrtc::peer_connection::configuration::RTCConfiguration; use webrtc::peer_connection::peer_connection_state::RTCPeerConnectionState; use webrtc::peer_connection::sdp::session_description::RTCSessionDescription; -use webrtc::peer_connection::{math_rand_alpha, RTCPeerConnection}; +use webrtc::peer_connection::{math_rand_alpha, PeerConnectionEventHandler, RTCPeerConnection}; #[macro_use] extern crate lazy_static; @@ -237,27 +239,51 @@ async fn main() -> Result<()> { // the other Pion instance will add this candidate by calling AddICECandidate let pc = Arc::downgrade(&peer_connection); let pending_candidates2 = Arc::clone(&PENDING_CANDIDATES); - let addr2 = answer_addr.clone(); - peer_connection.on_ice_candidate(Box::new(move |c: Option| { - //println!("on_ice_candidate {:?}", c); - - let pc2 = pc.clone(); - let pending_candidates3 = Arc::clone(&pending_candidates2); - let addr3 = addr2.clone(); - Box::pin(async move { - if let Some(c) = c { - if let Some(pc) = pc2.upgrade() { - let desc = pc.remote_description().await; + + struct ConnectionHandler { + connection: Weak, + addr: String, + pending_candidates: Arc>>, + done_tx: tokio::sync::mpsc::Sender<()>, + } + + impl PeerConnectionEventHandler for ConnectionHandler { + fn on_ice_candidate( + &mut self, + candidate: Option, + ) -> impl Future + Send { + async move { + if let (Some(candidate), Some(connection)) = (candidate, self.connection.upgrade()) + { + let desc = connection.remote_description().await; if desc.is_none() { - let mut cs = pending_candidates3.lock().await; - cs.push(c); - } else if let Err(err) = signal_candidate(&addr3, &c).await { + let mut candidates = self.pending_candidates.lock().await; + candidates.push(candidate); + } else if let Err(err) = signal_candidate(&self.addr, &candidate).await { panic!("{}", err); } } } - }) - })); + } + + // Set the handler for Peer connection state + // This will notify you when the peer has connected/disconnected + fn on_peer_connection_state_change( + &mut self, + state: RTCPeerConnectionState, + ) -> impl Future + Send { + println!("Peer Connection State has changed: {state}"); + + if state == RTCPeerConnectionState::Failed { + // Wait until PeerConnection has had no network activity for 30 seconds or another failure. It may be reconnected using an ICE Restart. + // Use webrtc.PeerConnectionStateDisconnected if you are interested in detecting faster timeout. + // Note that the PeerConnection may come back from PeerConnectionStateDisconnected. + println!("Peer Connection has gone to failed exiting"); + let _ = self.done_tx.try_send(()); + } + async {} + } + } println!("Listening on http://{offer_addr}"); { @@ -281,52 +307,53 @@ async fn main() -> Result<()> { let (done_tx, mut done_rx) = tokio::sync::mpsc::channel::<()>(1); - // Set the handler for Peer connection state - // This will notify you when the peer has connected/disconnected - peer_connection.on_peer_connection_state_change(Box::new(move |s: RTCPeerConnectionState| { - println!("Peer Connection State has changed: {s}"); - - if s == RTCPeerConnectionState::Failed { - // Wait until PeerConnection has had no network activity for 30 seconds or another failure. It may be reconnected using an ICE Restart. - // Use webrtc.PeerConnectionStateDisconnected if you are interested in detecting faster timeout. - // Note that the PeerConnection may come back from PeerConnectionStateDisconnected. - println!("Peer Connection has gone to failed exiting"); - let _ = done_tx.try_send(()); - } + peer_connection.with_event_handler(ConnectionHandler { + connection: pc, + addr: answer_addr.clone(), + done_tx, + pending_candidates: pending_candidates2, + }); - Box::pin(async {}) - })); - - // Register channel opening handling - let d1 = Arc::clone(&data_channel); - data_channel.on_open(Box::new(move || { - println!("Data channel '{}'-'{}' open. Random messages will now be sent to any connected DataChannels every 5 seconds", d1.label(), d1.id()); - - let d2 = Arc::clone(&d1); - Box::pin(async move { - let mut result = Result::::Ok(0); - while result.is_ok() { - let timeout = tokio::time::sleep(Duration::from_secs(5)); - tokio::pin!(timeout); - - tokio::select! { - _ = timeout.as_mut() =>{ - let message = math_rand_alpha(15); - println!("Sending '{message}'"); - result = d2.send_text(message).await.map_err(Into::into); - } - }; + struct ChannelHandler { + channel: Arc, + label: String, + id: u16, + } + + impl RTCDataChannelEventHandler for ChannelHandler { + // Register channel opening handling + fn on_open(&mut self) -> impl Future + Send { + println!("Data channel '{}'-'{}' open. Random messages will now be sent to any connected DataChannels every 5 seconds", self.label, self.id); + async move { + let mut result = Result::::Ok(0); + while result.is_ok() { + let timeout = tokio::time::sleep(Duration::from_secs(5)); + tokio::pin!(timeout); + + tokio::select! { + _ = timeout.as_mut() =>{ + let message = math_rand_alpha(15); + println!("Sending '{message}'"); + result = self.channel.send_text(message).await.map_err(Into::into); + } + }; + } } - }) - })); + } + + fn on_message(&mut self, msg: DataChannelMessage) -> impl Future + Send { + let msg_str = String::from_utf8(msg.data.to_vec()).unwrap(); + println!("Message from DataChannel '{}': '{msg_str}'", self.label); + async {} + } + } - // Register text message handling let d_label = data_channel.label().to_owned(); - data_channel.on_message(Box::new(move |msg: DataChannelMessage| { - let msg_str = String::from_utf8(msg.data.to_vec()).unwrap(); - println!("Message from DataChannel '{d_label}': '{msg_str}'"); - Box::pin(async {}) - })); + data_channel.with_event_handler(ChannelHandler { + channel: data_channel.clone(), + label: d_label, + id: data_channel.id(), + }); // Create an offer to send to the other process let offer = peer_connection.create_offer(None).await?; diff --git a/examples/examples/ortc/ortc.rs b/examples/examples/ortc/ortc.rs index f4d2fa439..d55003e01 100644 --- a/examples/examples/ortc/ortc.rs +++ b/examples/examples/ortc/ortc.rs @@ -1,3 +1,4 @@ +use std::future::Future; use std::io::Write; use std::sync::Arc; @@ -9,15 +10,17 @@ use tokio::time::Duration; use webrtc::api::APIBuilder; use webrtc::data_channel::data_channel_message::DataChannelMessage; use webrtc::data_channel::data_channel_parameters::DataChannelParameters; -use webrtc::data_channel::RTCDataChannel; +use webrtc::data_channel::{RTCDataChannel, RTCDataChannelEventHandler}; use webrtc::dtls_transport::dtls_parameters::DTLSParameters; use webrtc::ice_transport::ice_candidate::RTCIceCandidate; +use webrtc::ice_transport::ice_gatherer::IceGathererEventHandler; use webrtc::ice_transport::ice_gatherer::RTCIceGatherOptions; use webrtc::ice_transport::ice_parameters::RTCIceParameters; use webrtc::ice_transport::ice_role::RTCIceRole; use webrtc::ice_transport::ice_server::RTCIceServer; use webrtc::peer_connection::math_rand_alpha; use webrtc::sctp_transport::sctp_transport_capabilities::SCTPTransportCapabilities; +use webrtc::sctp_transport::SctpTransportEventHandler; #[tokio::main] async fn main() -> Result<()> { @@ -100,48 +103,81 @@ async fn main() -> Result<()> { let done_answer = done.clone(); let done_offer = done.clone(); - // Handle incoming data channels - sctp.on_data_channel(Box::new(move |d: Arc| { - let d_label = d.label().to_owned(); - let d_id = d.id(); - println!("New DataChannel {d_label} {d_id}"); + struct TransportHandler { + done_answer: Arc, + } - let done_answer1 = done_answer.clone(); - // Register the handlers - Box::pin(async move { - // no need to downgrade this to Weak, since on_open is FnOnce callback - let d2 = Arc::clone(&d); - let done_answer2 = done_answer1.clone(); - d.on_open(Box::new(move || { - Box::pin(async move { - tokio::select! { - _ = done_answer2.notified() => { - println!("received done_answer signal!"); - } - _ = handle_on_open(d2) => {} - }; - - println!("exit data answer"); - }) - })); - - // Register text message handling - d.on_message(Box::new(move |msg: DataChannelMessage| { - let msg_str = String::from_utf8(msg.data.to_vec()).unwrap(); - println!("Message from DataChannel '{d_label}': '{msg_str}'"); - Box::pin(async {}) - })); - }) - })); + impl SctpTransportEventHandler for TransportHandler { + // Handle incoming data channels + fn on_data_channel( + &mut self, + channel: Arc, + ) -> impl Future + Send { + let d_label = channel.label().to_owned(); + let d_id = channel.id(); + println!("New DataChannel {d_label} {d_id}"); + + channel.with_event_handler(ChannelHandler { + done_answer: self.done_answer.clone(), + id: d_id, + label: d_label, + channel: channel.clone(), + }); + async {} + } + } + + struct ChannelHandler { + done_answer: Arc, + channel: Arc, + label: String, + id: u16, + } + + impl RTCDataChannelEventHandler for ChannelHandler { + fn on_open(&mut self) -> impl Future + Send { + async move { + tokio::select! { + _ = self.done_answer.notified() => { + println!("received done_answer signal!"); + } + _ = handle_on_open(self.channel.clone()) => {} + }; + println!("exit data answer"); + } + } + + // Register text message handling + fn on_message(&mut self, msg: DataChannelMessage) -> impl Future + Send { + let msg_str = String::from_utf8(msg.data.to_vec()).unwrap(); + println!("Message from DataChannel '{}': '{msg_str}'", self.label); + async {} + } + } + + sctp.with_event_handler(TransportHandler { done_answer }); + + struct GathererHandler { + finished: Option>, + } + + impl IceGathererEventHandler for GathererHandler { + fn on_local_candidate( + &mut self, + candidate: Option, + ) -> impl Future + Send { + if candidate.is_none() { + self.finished.take(); + } + async {} + } + } let (gather_finished_tx, mut gather_finished_rx) = tokio::sync::mpsc::channel::<()>(1); let mut gather_finished_tx = Some(gather_finished_tx); - gatherer.on_local_candidate(Box::new(move |c: Option| { - if c.is_none() { - gather_finished_tx.take(); - } - Box::pin(async {}) - })); + gatherer.with_event_handler(GathererHandler { + finished: gather_finished_tx, + }); // Gather candidates gatherer.gather().await?; @@ -220,12 +256,20 @@ async fn main() -> Result<()> { println!("exit data offer"); }); - let d_label = d.label().to_owned(); - d.on_message(Box::new(move |msg: DataChannelMessage| { - let msg_str = String::from_utf8(msg.data.to_vec()).unwrap(); - println!("Message from DataChannel '{d_label}': '{msg_str}'"); - Box::pin(async {}) - })); + struct MessageHandler { + label: String, + } + + impl RTCDataChannelEventHandler for MessageHandler { + fn on_message(&mut self, msg: DataChannelMessage) -> impl Future + Send { + let msg_str = String::from_utf8(msg.data.to_vec()).unwrap(); + println!("Message from DataChannel '{}': '{msg_str}'", self.label); + async {} + } + } + d.with_event_handler(MessageHandler { + label: d.label().to_owned(), + }); } println!("Press ctrl-c to stop"); diff --git a/examples/examples/play-from-disk-h264/play-from-disk-h264.rs b/examples/examples/play-from-disk-h264/play-from-disk-h264.rs index 94e770185..4dfbfa65f 100644 --- a/examples/examples/play-from-disk-h264/play-from-disk-h264.rs +++ b/examples/examples/play-from-disk-h264/play-from-disk-h264.rs @@ -1,4 +1,5 @@ use std::fs::File; +use std::future::Future; use std::io::{BufReader, Write}; use std::path::Path; use std::sync::Arc; @@ -19,6 +20,7 @@ use webrtc::media::Sample; use webrtc::peer_connection::configuration::RTCConfiguration; use webrtc::peer_connection::peer_connection_state::RTCPeerConnectionState; use webrtc::peer_connection::sdp::session_description::RTCSessionDescription; +use webrtc::peer_connection::PeerConnectionEventHandler; use webrtc::rtp_transceiver::rtp_codec::RTCRtpCodecCapability; use webrtc::track::track_local::track_local_static_sample::TrackLocalStaticSample; use webrtc::track::track_local::TrackLocal; @@ -286,33 +288,45 @@ async fn main() -> Result<()> { }); } - // Set the handler for ICE connection state - // This will notify you when the peer has connected/disconnected - peer_connection.on_ice_connection_state_change(Box::new( - move |connection_state: RTCIceConnectionState| { - println!("Connection State has changed {connection_state}"); - if connection_state == RTCIceConnectionState::Connected { - notify_tx.notify_waiters(); + struct ConnectionHandler { + notify_tx: Arc, + done_tx: tokio::sync::mpsc::Sender<()>, + } + + impl PeerConnectionEventHandler for ConnectionHandler { + // Set the handler for ICE connection state + // This will notify you when the peer has connected/disconnected + fn on_ice_connection_state_change( + &mut self, + state: RTCIceConnectionState, + ) -> impl Future + Send { + println!("Connection State has changed {state}"); + if state == RTCIceConnectionState::Connected { + self.notify_tx.notify_waiters(); + } + async {} + } + + // Set the handler for Peer connection state + // This will notify you when the peer has connected/disconnected + fn on_peer_connection_state_change( + &mut self, + state: RTCPeerConnectionState, + ) -> impl Future + Send { + println!("Peer Connection State has changed: {state}"); + + if state == RTCPeerConnectionState::Failed { + // Wait until PeerConnection has had no network activity for 30 seconds or another failure. It may be reconnected using an ICE Restart. + // Use webrtc.PeerConnectionStateDisconnected if you are interested in detecting faster timeout. + // Note that the PeerConnection may come back from PeerConnectionStateDisconnected. + println!("Peer Connection has gone to failed exiting"); + let _ = self.done_tx.try_send(()); } - Box::pin(async {}) - }, - )); - - // Set the handler for Peer connection state - // This will notify you when the peer has connected/disconnected - peer_connection.on_peer_connection_state_change(Box::new(move |s: RTCPeerConnectionState| { - println!("Peer Connection State has changed: {s}"); - - if s == RTCPeerConnectionState::Failed { - // Wait until PeerConnection has had no network activity for 30 seconds or another failure. It may be reconnected using an ICE Restart. - // Use webrtc.PeerConnectionStateDisconnected if you are interested in detecting faster timeout. - // Note that the PeerConnection may come back from PeerConnectionStateDisconnected. - println!("Peer Connection has gone to failed exiting"); - let _ = done_tx.try_send(()); + async {} } + } - Box::pin(async {}) - })); + peer_connection.with_event_handler(ConnectionHandler { notify_tx, done_tx }); // Wait for the offer to be pasted let line = signal::must_read_stdin()?; diff --git a/examples/examples/play-from-disk-renegotiation/play-from-disk-renegotiation.rs b/examples/examples/play-from-disk-renegotiation/play-from-disk-renegotiation.rs index 503cd863d..372f767eb 100644 --- a/examples/examples/play-from-disk-renegotiation/play-from-disk-renegotiation.rs +++ b/examples/examples/play-from-disk-renegotiation/play-from-disk-renegotiation.rs @@ -1,4 +1,5 @@ use std::fs::File; +use std::future::Future; use std::io::{BufReader, Write}; use std::net::SocketAddr; use std::path::Path; @@ -22,6 +23,7 @@ use webrtc::media::Sample; use webrtc::peer_connection::configuration::RTCConfiguration; use webrtc::peer_connection::peer_connection_state::RTCPeerConnectionState; use webrtc::peer_connection::sdp::session_description::RTCSessionDescription; +use webrtc::peer_connection::PeerConnectionEventHandler; use webrtc::peer_connection::RTCPeerConnection; use webrtc::rtp_transceiver::rtp_codec::RTCRtpCodecCapability; use webrtc::track::track_local::track_local_static_sample::TrackLocalStaticSample; @@ -323,19 +325,29 @@ async fn main() -> Result<()> { // Set the handler for Peer connection state // This will notify you when the peer has connected/disconnected - peer_connection.on_peer_connection_state_change(Box::new(move |s: RTCPeerConnectionState| { - println!("Peer Connection State has changed: {s}"); - - if s == RTCPeerConnectionState::Failed { - // Wait until PeerConnection has had no network activity for 30 seconds or another failure. It may be reconnected using an ICE Restart. - // Use webrtc.PeerConnectionStateDisconnected if you are interested in detecting faster timeout. - // Note that the PeerConnection may come back from PeerConnectionStateDisconnected. - println!("Peer Connection has gone to failed exiting"); - let _ = done_tx.try_send(()); - } - Box::pin(async {}) - })); + struct ConnectionHandler { + done_tx: tokio::sync::mpsc::Sender<()>, + } + + impl PeerConnectionEventHandler for ConnectionHandler { + fn on_peer_connection_state_change( + &mut self, + state: RTCPeerConnectionState, + ) -> impl Future + Send { + println!("Peer Connection State has changed: {state}"); + + if state == RTCPeerConnectionState::Failed { + // Wait until PeerConnection has had no network activity for 30 seconds or another failure. It may be reconnected using an ICE Restart. + // Use webrtc.PeerConnectionStateDisconnected if you are interested in detecting faster timeout. + // Note that the PeerConnection may come back from PeerConnectionStateDisconnected. + println!("Peer Connection has gone to failed exiting"); + let _ = self.done_tx.try_send(()); + } + async {} + } + } + peer_connection.with_event_handler(ConnectionHandler { done_tx }); { let mut pcm = PEER_CONNECTION_MUTEX.lock().await; diff --git a/examples/examples/play-from-disk-vpx/play-from-disk-vpx.rs b/examples/examples/play-from-disk-vpx/play-from-disk-vpx.rs index d3a578dc8..91881cdb0 100644 --- a/examples/examples/play-from-disk-vpx/play-from-disk-vpx.rs +++ b/examples/examples/play-from-disk-vpx/play-from-disk-vpx.rs @@ -1,4 +1,5 @@ use std::fs::File; +use std::future::Future; use std::io::{BufReader, Write}; use std::path::Path; use std::sync::Arc; @@ -19,6 +20,7 @@ use webrtc::media::Sample; use webrtc::peer_connection::configuration::RTCConfiguration; use webrtc::peer_connection::peer_connection_state::RTCPeerConnectionState; use webrtc::peer_connection::sdp::session_description::RTCSessionDescription; +use webrtc::peer_connection::PeerConnectionEventHandler; use webrtc::rtp_transceiver::rtp_codec::RTCRtpCodecCapability; use webrtc::track::track_local::track_local_static_sample::TrackLocalStaticSample; use webrtc::track::track_local::TrackLocal; @@ -297,33 +299,42 @@ async fn main() -> Result<()> { }); } - // Set the handler for ICE connection state - // This will notify you when the peer has connected/disconnected - peer_connection.on_ice_connection_state_change(Box::new( - move |connection_state: RTCIceConnectionState| { - println!("Connection State has changed {connection_state}"); - if connection_state == RTCIceConnectionState::Connected { - notify_tx.notify_waiters(); + struct ConnectionHandler { + notify_tx: Arc, + done_tx: tokio::sync::mpsc::Sender<()>, + } + + impl PeerConnectionEventHandler for ConnectionHandler { + // Set the handler for ICE connection state + // This will notify you when the peer has connected/disconnected + fn on_ice_connection_state_change( + &mut self, + state: RTCIceConnectionState, + ) -> impl Future + Send { + println!("Connection State has changed {state}"); + if state == RTCIceConnectionState::Connected { + self.notify_tx.notify_waiters(); } - Box::pin(async {}) - }, - )); - - // Set the handler for Peer connection state - // This will notify you when the peer has connected/disconnected - peer_connection.on_peer_connection_state_change(Box::new(move |s: RTCPeerConnectionState| { - println!("Peer Connection State has changed: {s}"); - - if s == RTCPeerConnectionState::Failed { - // Wait until PeerConnection has had no network activity for 30 seconds or another failure. It may be reconnected using an ICE Restart. - // Use webrtc.PeerConnectionStateDisconnected if you are interested in detecting faster timeout. - // Note that the PeerConnection may come back from PeerConnectionStateDisconnected. - println!("Peer Connection has gone to failed exiting"); - let _ = done_tx.try_send(()); + async {} } - - Box::pin(async {}) - })); + // Set the handler for Peer connection state + // This will notify you when the peer has connected/disconnected + fn on_peer_connection_state_change( + &mut self, + state: RTCPeerConnectionState, + ) -> impl Future + Send { + println!("Peer Connection State has changed: {state}"); + + if state == RTCPeerConnectionState::Failed { + // Wait until PeerConnection has had no network activity for 30 seconds or another failure. It may be reconnected using an ICE Restart. + // Use webrtc.PeerConnectionStateDisconnected if you are interested in detecting faster timeout. + // Note that the PeerConnection may come back from PeerConnectionStateDisconnected. + println!("Peer Connection has gone to failed exiting"); + let _ = self.done_tx.try_send(()); + } + async {} + } + } // Wait for the offer to be pasted let line = signal::must_read_stdin()?; diff --git a/examples/examples/reflect/reflect.rs b/examples/examples/reflect/reflect.rs index c01875e42..117548471 100644 --- a/examples/examples/reflect/reflect.rs +++ b/examples/examples/reflect/reflect.rs @@ -1,6 +1,7 @@ use std::collections::HashMap; +use std::future::Future; use std::io::Write; -use std::sync::Arc; +use std::sync::{Arc, Weak}; use anyhow::Result; use clap::{AppSettings, Arg, Command}; @@ -13,12 +14,16 @@ use webrtc::interceptor::registry::Registry; use webrtc::peer_connection::configuration::RTCConfiguration; use webrtc::peer_connection::peer_connection_state::RTCPeerConnectionState; use webrtc::peer_connection::sdp::session_description::RTCSessionDescription; +use webrtc::peer_connection::{PeerConnectionEventHandler, RTCPeerConnection}; use webrtc::rtcp::payload_feedbacks::picture_loss_indication::PictureLossIndication; use webrtc::rtp_transceiver::rtp_codec::{ RTCRtpCodecCapability, RTCRtpCodecParameters, RTPCodecType, }; +use webrtc::rtp_transceiver::rtp_receiver::RTCRtpReceiver; +use webrtc::rtp_transceiver::RTCRtpTransceiver; use webrtc::track::track_local::track_local_static_rtp::TrackLocalStaticRTP; use webrtc::track::track_local::{TrackLocal, TrackLocalWriter}; +use webrtc::track::track_remote::TrackRemote; #[tokio::main] async fn main() -> Result<()> { @@ -197,89 +202,106 @@ async fn main() -> Result<()> { // Set a handler for when a new remote track starts, this handler copies inbound RTP packets, // replaces the SSRC and sends them back let pc = Arc::downgrade(&peer_connection); - peer_connection.on_track(Box::new(move |track, _, _| { - // Send a PLI on an interval so that the publisher is pushing a keyframe every rtcpPLIInterval - // This is a temporary fix until we implement incoming RTCP events, then we would push a PLI only when a viewer requests it - let media_ssrc = track.ssrc(); - - if track.kind() == RTPCodecType::Video { - let pc2 = pc.clone(); - tokio::spawn(async move { - let mut result = Result::::Ok(0); - while result.is_ok() { - let timeout = tokio::time::sleep(Duration::from_secs(3)); - tokio::pin!(timeout); - - tokio::select! { - _ = timeout.as_mut() =>{ - if let Some(pc) = pc2.upgrade(){ - result = pc.write_rtcp(&[Box::new(PictureLossIndication{ - sender_ssrc: 0, - media_ssrc, - })]).await.map_err(Into::into); - }else{ - break; + struct ConnectionHandler { + connection: Weak, + output_tracks: HashMap>, + done_tx: tokio::sync::mpsc::Sender<()>, + } + + impl PeerConnectionEventHandler for ConnectionHandler { + fn on_track( + &mut self, + track: Arc, + _: Arc, + _: Arc, + ) -> impl Future + Send { + // Send a PLI on an interval so that the publisher is pushing a keyframe every rtcpPLIInterval + // This is a temporary fix until we implement incoming RTCP events, then we would push a PLI only when a viewer requests it + if track.kind() == RTPCodecType::Video { + let media_ssrc = track.ssrc(); + let connection = self.connection.clone(); + tokio::spawn(async move { + let mut result = Result::::Ok(0); + while result.is_ok() { + let timeout = tokio::time::sleep(Duration::from_secs(3)); + tokio::pin!(timeout); + + tokio::select! { + _ = timeout.as_mut() =>{ + if let Some(pc) = connection.upgrade(){ + result = pc.write_rtcp(&[Box::new(PictureLossIndication{ + sender_ssrc: 0, + media_ssrc, + })]).await.map_err(Into::into); + }else{ + break; + } } + }; + } + }); + } + let kind = if track.kind() == RTPCodecType::Audio { + "audio" + } else { + "video" + }; + async move { + let output_track = if let Some(output_track) = self.output_tracks.get(kind) { + Arc::clone(output_track) + } else { + println!("output_track not found for type = {kind}"); + return; + }; + + let output_track2 = Arc::clone(&output_track); + tokio::spawn(async move { + println!( + "Track has started, of type {}: {}", + track.payload_type(), + track.codec().capability.mime_type + ); + // Read RTP packets being sent to webrtc-rs + while let Ok((rtp, _)) = track.read_rtp().await { + if let Err(err) = output_track2.write_rtp(&rtp).await { + println!("output track write_rtp got error: {err}"); + break; } - }; - } - }); + } + + println!( + "on_track finished, of type {}: {}", + track.payload_type(), + track.codec().capability.mime_type + ); + }); + } } - let kind = if track.kind() == RTPCodecType::Audio { - "audio" - } else { - "video" - }; - let output_track = if let Some(output_track) = output_tracks.get(kind) { - Arc::clone(output_track) - } else { - println!("output_track not found for type = {kind}"); - return Box::pin(async {}); - }; - - let output_track2 = Arc::clone(&output_track); - tokio::spawn(async move { - println!( - "Track has started, of type {}: {}", - track.payload_type(), - track.codec().capability.mime_type - ); - // Read RTP packets being sent to webrtc-rs - while let Ok((rtp, _)) = track.read_rtp().await { - if let Err(err) = output_track2.write_rtp(&rtp).await { - println!("output track write_rtp got error: {err}"); - break; - } + // Set the handler for Peer connection state + // This will notify you when the peer has connected/disconnected + fn on_peer_connection_state_change( + &mut self, + state: RTCPeerConnectionState, + ) -> impl Future + Send { + println!("Peer Connection State has changed: {state}"); + + if state == RTCPeerConnectionState::Failed { + // Wait until PeerConnection has had no network activity for 30 seconds or another failure. It may be reconnected using an ICE Restart. + // Use webrtc.PeerConnectionStateDisconnected if you are interested in detecting faster timeout. + // Note that the PeerConnection may come back from PeerConnectionStateDisconnected. + println!("Peer Connection has gone to failed exiting"); + let _ = self.done_tx.try_send(()); } - - println!( - "on_track finished, of type {}: {}", - track.payload_type(), - track.codec().capability.mime_type - ); - }); - - Box::pin(async {}) - })); - - let (done_tx, mut done_rx) = tokio::sync::mpsc::channel::<()>(1); - - // Set the handler for Peer connection state - // This will notify you when the peer has connected/disconnected - peer_connection.on_peer_connection_state_change(Box::new(move |s: RTCPeerConnectionState| { - println!("Peer Connection State has changed: {s}"); - - if s == RTCPeerConnectionState::Failed { - // Wait until PeerConnection has had no network activity for 30 seconds or another failure. It may be reconnected using an ICE Restart. - // Use webrtc.PeerConnectionStateDisconnected if you are interested in detecting faster timeout. - // Note that the PeerConnection may come back from PeerConnectionStateDisconnected. - println!("Peer Connection has gone to failed exiting"); - let _ = done_tx.try_send(()); + async {} } - - Box::pin(async {}) - })); + } + let (done_tx, mut done_rx) = tokio::sync::mpsc::channel::<()>(1); + peer_connection.with_event_handler(ConnectionHandler { + connection: pc, + done_tx, + output_tracks, + }); // Create an answer let answer = peer_connection.create_answer(None).await?; diff --git a/examples/examples/rtp-forwarder/rtp-forwarder.rs b/examples/examples/rtp-forwarder/rtp-forwarder.rs index c6e548ae3..baae5ab9f 100644 --- a/examples/examples/rtp-forwarder/rtp-forwarder.rs +++ b/examples/examples/rtp-forwarder/rtp-forwarder.rs @@ -1,6 +1,7 @@ use std::collections::HashMap; +use std::future::Future; use std::io::Write; -use std::sync::Arc; +use std::sync::{Arc, Weak}; use anyhow::Result; use clap::{AppSettings, Arg, Command}; @@ -15,10 +16,14 @@ use webrtc::interceptor::registry::Registry; use webrtc::peer_connection::configuration::RTCConfiguration; use webrtc::peer_connection::peer_connection_state::RTCPeerConnectionState; use webrtc::peer_connection::sdp::session_description::RTCSessionDescription; +use webrtc::peer_connection::{PeerConnectionEventHandler, RTCPeerConnection}; use webrtc::rtcp::payload_feedbacks::picture_loss_indication::PictureLossIndication; use webrtc::rtp_transceiver::rtp_codec::{ RTCRtpCodecCapability, RTCRtpCodecParameters, RTPCodecType, }; +use webrtc::rtp_transceiver::rtp_receiver::RTCRtpReceiver; +use webrtc::rtp_transceiver::RTCRtpTransceiver; +use webrtc::track::track_remote::TrackRemote; use webrtc::util::{Conn, Marshal}; #[derive(Clone)] @@ -171,108 +176,128 @@ async fn main() -> Result<()> { }, ); - // Set a handler for when a new remote track starts, this handler will forward data to - // our UDP listeners. - // In your application this is where you would handle/process audio/video - let pc = Arc::downgrade(&peer_connection); - peer_connection.on_track(Box::new(move |track, _, _| { - // Retrieve udp connection - let c = if let Some(c) = udp_conns.get(&track.kind().to_string()) { - c.clone() - } else { - return Box::pin(async {}); - }; - - // Send a PLI on an interval so that the publisher is pushing a keyframe every rtcpPLIInterval - let media_ssrc = track.ssrc(); - let pc2 = pc.clone(); - tokio::spawn(async move { - let mut result = Result::::Ok(0); - while result.is_ok() { - let timeout = tokio::time::sleep(Duration::from_secs(3)); - tokio::pin!(timeout); - - tokio::select! { - _ = timeout.as_mut() =>{ - if let Some(pc) = pc2.upgrade(){ - result = pc.write_rtcp(&[Box::new(PictureLossIndication{ - sender_ssrc: 0, - media_ssrc, - })]).await.map_err(Into::into); - }else{ - break; - } - } + struct ConnectionHandler { + udp_conns: HashMap, + connection: Weak, + done_tx: tokio::sync::mpsc::Sender<()>, + } + + impl PeerConnectionEventHandler for ConnectionHandler { + // Set a handler for when a new remote track starts, this handler will forward data to + // our UDP listeners. + // In your application this is where you would handle/process audio/video + fn on_track( + &mut self, + track: Arc, + _: Arc, + _: Arc, + ) -> impl Future + Send { + async move { + // Retrieve udp connection + let c = if let Some(c) = self.udp_conns.get(&track.kind().to_string()) { + c.clone() + } else { + return; }; - } - }); - - tokio::spawn(async move { - let mut b = vec![0u8; 1500]; - while let Ok((mut rtp_packet, _)) = track.read(&mut b).await { - // Update the PayloadType - rtp_packet.header.payload_type = c.payload_type; - - // Marshal into original buffer with updated PayloadType - - let n = rtp_packet.marshal_to(&mut b)?; - - // Write - if let Err(err) = c.conn.send(&b[..n]).await { - // For this particular example, third party applications usually timeout after a short - // amount of time during which the user doesn't have enough time to provide the answer - // to the browser. - // That's why, for this particular example, the user first needs to provide the answer - // to the browser then open the third party application. Therefore we must not kill - // the forward on "connection refused" errors - //if opError, ok := err.(*net.OpError); ok && opError.Err.Error() == "write: connection refused" { - // continue - //} - //panic(err) - if err.to_string().contains("Connection refused") { - continue; - } else { - println!("conn send err: {err}"); - break; - } - } - } - Result::<()>::Ok(()) - }); + // Send a PLI on an interval so that the publisher is pushing a keyframe every rtcpPLIInterval + let media_ssrc = track.ssrc(); + + let connection = self.connection.clone(); + tokio::spawn(async move { + let mut result = Result::::Ok(0); + while result.is_ok() { + let timeout = tokio::time::sleep(Duration::from_secs(3)); + tokio::pin!(timeout); + + tokio::select! { + _ = timeout.as_mut() =>{ + if let Some(pc) = connection.upgrade(){ + result = pc.write_rtcp(&[Box::new(PictureLossIndication{ + sender_ssrc: 0, + media_ssrc, + })]).await.map_err(Into::into); + }else{ + break; + } + } + }; + } + }); + + tokio::spawn(async move { + let mut b = vec![0u8; 1500]; + while let Ok((mut rtp_packet, _)) = track.read(&mut b).await { + // Update the PayloadType + rtp_packet.header.payload_type = c.payload_type; + + // Marshal into original buffer with updated PayloadType + + let n = rtp_packet.marshal_to(&mut b)?; + + // Write + if let Err(err) = c.conn.send(&b[..n]).await { + // For this particular example, third party applications usually timeout after a short + // amount of time during which the user doesn't have enough time to provide the answer + // to the browser. + // That's why, for this particular example, the user first needs to provide the answer + // to the browser then open the third party application. Therefore we must not kill + // the forward on "connection refused" errors + //if opError, ok := err.(*net.OpError); ok && opError.Err.Error() == "write: connection refused" { + // continue + //} + //panic(err) + if err.to_string().contains("Connection refused") { + continue; + } else { + println!("conn send err: {err}"); + break; + } + } + } - Box::pin(async {}) - })); + Result::<()>::Ok(()) + }); + } + } - // Set the handler for ICE connection state - // This will notify you when the peer has connected/disconnected - peer_connection.on_ice_connection_state_change(Box::new( - move |connection_state: RTCIceConnectionState| { - println!("Connection State has changed {connection_state}"); - if connection_state == RTCIceConnectionState::Connected { + // Set the handler for ICE connection state + // This will notify you when the peer has connected/disconnected + fn on_ice_connection_state_change( + &mut self, + state: RTCIceConnectionState, + ) -> impl Future + Send { + println!("Connection State has changed {state}"); + if state == RTCIceConnectionState::Connected { println!("Ctrl+C the remote client to stop the demo"); } - Box::pin(async {}) - }, - )); - - let (done_tx, mut done_rx) = tokio::sync::mpsc::channel::<()>(1); - - // Set the handler for Peer connection state - // This will notify you when the peer has connected/disconnected - peer_connection.on_peer_connection_state_change(Box::new(move |s: RTCPeerConnectionState| { - println!("Peer Connection State has changed: {s}"); - - if s == RTCPeerConnectionState::Failed { - // Wait until PeerConnection has had no network activity for 30 seconds or another failure. It may be reconnected using an ICE Restart. - // Use webrtc.PeerConnectionStateDisconnected if you are interested in detecting faster timeout. - // Note that the PeerConnection may come back from PeerConnectionStateDisconnected. - println!("Peer Connection has gone to failed exiting: Done forwarding"); - let _ = done_tx.try_send(()); + async {} } - Box::pin(async {}) - })); + // Set the handler for Peer connection state + // This will notify you when the peer has connected/disconnected + fn on_peer_connection_state_change( + &mut self, + state: RTCPeerConnectionState, + ) -> impl Future + Send { + println!("Peer Connection State has changed: {state}"); + + if state == RTCPeerConnectionState::Failed { + // Wait until PeerConnection has had no network activity for 30 seconds or another failure. It may be reconnected using an ICE Restart. + // Use webrtc.PeerConnectionStateDisconnected if you are interested in detecting faster timeout. + // Note that the PeerConnection may come back from PeerConnectionStateDisconnected. + println!("Peer Connection has gone to failed exiting: Done forwarding"); + let _ = self.done_tx.try_send(()); + } + async {} + } + } + let (done_tx, mut done_rx) = tokio::sync::mpsc::channel::<()>(1); + peer_connection.with_event_handler(ConnectionHandler { + done_tx, + connection: Arc::downgrade(&peer_connection), + udp_conns, + }); // Wait for the offer to be pasted let line = signal::must_read_stdin()?; diff --git a/examples/examples/rtp-to-webrtc/rtp-to-webrtc.rs b/examples/examples/rtp-to-webrtc/rtp-to-webrtc.rs index ddcb0555f..1a4730267 100644 --- a/examples/examples/rtp-to-webrtc/rtp-to-webrtc.rs +++ b/examples/examples/rtp-to-webrtc/rtp-to-webrtc.rs @@ -1,3 +1,4 @@ +use std::future::Future; use std::io::Write; use std::sync::Arc; @@ -13,6 +14,7 @@ use webrtc::interceptor::registry::Registry; use webrtc::peer_connection::configuration::RTCConfiguration; use webrtc::peer_connection::peer_connection_state::RTCPeerConnectionState; use webrtc::peer_connection::sdp::session_description::RTCSessionDescription; +use webrtc::peer_connection::PeerConnectionEventHandler; use webrtc::rtp_transceiver::rtp_codec::RTCRtpCodecCapability; use webrtc::track::track_local::track_local_static_rtp::TrackLocalStaticRTP; use webrtc::track::track_local::{TrackLocal, TrackLocalWriter}; @@ -123,35 +125,43 @@ async fn main() -> Result<()> { let (done_tx, mut done_rx) = tokio::sync::mpsc::channel::<()>(1); - let done_tx1 = done_tx.clone(); - // Set the handler for ICE connection state - // This will notify you when the peer has connected/disconnected - peer_connection.on_ice_connection_state_change(Box::new( - move |connection_state: RTCIceConnectionState| { - println!("Connection State has changed {connection_state}"); - if connection_state == RTCIceConnectionState::Failed { - let _ = done_tx1.try_send(()); - } - Box::pin(async {}) - }, - )); + struct ConnectionHandler { + done_tx: tokio::sync::mpsc::Sender<()>, + } - let done_tx2 = done_tx.clone(); - // Set the handler for Peer connection state - // This will notify you when the peer has connected/disconnected - peer_connection.on_peer_connection_state_change(Box::new(move |s: RTCPeerConnectionState| { - println!("Peer Connection State has changed: {s}"); - - if s == RTCPeerConnectionState::Failed { - // Wait until PeerConnection has had no network activity for 30 seconds or another failure. It may be reconnected using an ICE Restart. - // Use webrtc.PeerConnectionStateDisconnected if you are interested in detecting faster timeout. - // Note that the PeerConnection may come back from PeerConnectionStateDisconnected. - println!("Peer Connection has gone to failed exiting: Done forwarding"); - let _ = done_tx2.try_send(()); + impl PeerConnectionEventHandler for ConnectionHandler { + // Set the handler for ICE connection state + // This will notify you when the peer has connected/disconnected + fn on_ice_connection_state_change( + &mut self, + state: RTCIceConnectionState, + ) -> impl Future + Send { + println!("Connection State has changed {state}"); + if state == RTCIceConnectionState::Failed { + let _ = self.done_tx.try_send(()); + } + async {} } - Box::pin(async {}) - })); + fn on_peer_connection_state_change( + &mut self, + state: RTCPeerConnectionState, + ) -> impl Future + Send { + println!("Peer Connection State has changed: {state}"); + + if state == RTCPeerConnectionState::Failed { + // Wait until PeerConnection has had no network activity for 30 seconds or another failure. It may be reconnected using an ICE Restart. + // Use webrtc.PeerConnectionStateDisconnected if you are interested in detecting faster timeout. + // Note that the PeerConnection may come back from PeerConnectionStateDisconnected. + println!("Peer Connection has gone to failed exiting: Done forwarding"); + let _ = self.done_tx.try_send(()); + } + async {} + } + } + peer_connection.with_event_handler(ConnectionHandler { + done_tx: done_tx.clone(), + }); // Wait for the offer to be pasted let line = signal::must_read_stdin()?; diff --git a/examples/examples/save-to-disk-h264/save-to-disk-h264.rs b/examples/examples/save-to-disk-h264/save-to-disk-h264.rs index 238320b37..d7f3db642 100644 --- a/examples/examples/save-to-disk-h264/save-to-disk-h264.rs +++ b/examples/examples/save-to-disk-h264/save-to-disk-h264.rs @@ -1,6 +1,7 @@ use std::fs::File; +use std::future::Future; use std::io::Write; -use std::sync::Arc; +use std::sync::{Arc, Weak}; use anyhow::Result; use clap::{AppSettings, Arg, Command}; @@ -16,10 +17,13 @@ use webrtc::media::io::h264_writer::H264Writer; use webrtc::media::io::ogg_writer::OggWriter; use webrtc::peer_connection::configuration::RTCConfiguration; use webrtc::peer_connection::sdp::session_description::RTCSessionDescription; +use webrtc::peer_connection::{PeerConnectionEventHandler, RTCPeerConnection}; use webrtc::rtcp::payload_feedbacks::picture_loss_indication::PictureLossIndication; use webrtc::rtp_transceiver::rtp_codec::{ RTCRtpCodecCapability, RTCRtpCodecParameters, RTPCodecType, }; +use webrtc::rtp_transceiver::rtp_receiver::RTCRtpReceiver; +use webrtc::rtp_transceiver::RTCRtpTransceiver; use webrtc::track::track_remote::TrackRemote; async fn save_to_disk( @@ -201,75 +205,100 @@ async fn main() -> Result<()> { let notify_tx = Arc::new(Notify::new()); let notify_rx = notify_tx.clone(); - // Set a handler for when a new remote track starts, this handler saves buffers to disk as - // an ivf file, since we could have multiple video tracks we provide a counter. - // In your application this is where you would handle/process video - let pc = Arc::downgrade(&peer_connection); - peer_connection.on_track(Box::new(move |track, _, _| { - // Send a PLI on an interval so that the publisher is pushing a keyframe every rtcpPLIInterval - let media_ssrc = track.ssrc(); - let pc2 = pc.clone(); - tokio::spawn(async move { - let mut result = Result::::Ok(0); - while result.is_ok() { - let timeout = tokio::time::sleep(Duration::from_secs(3)); - tokio::pin!(timeout); - - tokio::select! { - _ = timeout.as_mut() =>{ - if let Some(pc) = pc2.upgrade(){ - result = pc.write_rtcp(&[Box::new(PictureLossIndication{ - sender_ssrc: 0, - media_ssrc, - })]).await.map_err(Into::into); - }else { - break; + struct ConnectionHandler { + connection: Weak, + notify_rx: Arc, + h264_writer: Arc>, + ogg_writer: Arc>, + notify_tx: Arc, + done_tx: tokio::sync::mpsc::Sender<()>, + } + + impl PeerConnectionEventHandler for ConnectionHandler { + // Set a handler for when a new remote track starts, this handler saves buffers to disk as + // an ivf file, since we could have multiple video tracks we provide a counter. + // In your application this is where you would handle/process video + fn on_track( + &mut self, + track: Arc, + _: Arc, + _: Arc, + ) -> impl Future + Send { + // Send a PLI on an interval so that the publisher is pushing a keyframe every rtcpPLIInterval + let media_ssrc = track.ssrc(); + let connection = self.connection.clone(); + tokio::spawn(async move { + let mut result = Result::::Ok(0); + while result.is_ok() { + let timeout = tokio::time::sleep(Duration::from_secs(3)); + tokio::pin!(timeout); + + tokio::select! { + _ = timeout.as_mut() =>{ + if let Some(pc) = connection.upgrade(){ + result = pc.write_rtcp(&[Box::new(PictureLossIndication{ + sender_ssrc: 0, + media_ssrc, + })]).await.map_err(Into::into); + }else { + break; + } } - } - }; - } - }); - - let notify_rx2 = Arc::clone(¬ify_rx); - let h264_writer2 = Arc::clone(&h264_writer); - let ogg_writer2 = Arc::clone(&ogg_writer); - Box::pin(async move { - let codec = track.codec(); - let mime_type = codec.capability.mime_type.to_lowercase(); - if mime_type == MIME_TYPE_OPUS.to_lowercase() { - println!("Got Opus track, saving to disk as output.opus (48 kHz, 2 channels)"); - tokio::spawn(async move { - let _ = save_to_disk(ogg_writer2, track, notify_rx2).await; - }); - } else if mime_type == MIME_TYPE_H264.to_lowercase() { - println!("Got h264 track, saving to disk as output.h264"); - tokio::spawn(async move { - let _ = save_to_disk(h264_writer2, track, notify_rx2).await; - }); + }; + } + }); + + let ogg_writer = self.ogg_writer.clone(); + let h264_writer = self.h264_writer.clone(); + let notify_rx = self.notify_rx.clone(); + async move { + let codec = track.codec(); + let mime_type = codec.capability.mime_type.to_lowercase(); + if mime_type == MIME_TYPE_OPUS.to_lowercase() { + println!("Got Opus track, saving to disk as output.opus (48 kHz, 2 channels)"); + tokio::spawn(async move { + let _ = save_to_disk(ogg_writer, track, notify_rx).await; + }); + } else if mime_type == MIME_TYPE_H264.to_lowercase() { + println!("Got h264 track, saving to disk as output.h264"); + tokio::spawn(async move { + let _ = save_to_disk(h264_writer, track, notify_rx).await; + }); + } } - }) - })); - - let (done_tx, mut done_rx) = tokio::sync::mpsc::channel::<()>(1); + } - // Set the handler for ICE connection state - // This will notify you when the peer has connected/disconnected - peer_connection.on_ice_connection_state_change(Box::new( - move |connection_state: RTCIceConnectionState| { - println!("Connection State has changed {connection_state}"); + // Set the handler for ICE connection state + // This will notify you when the peer has connected/disconnected + fn on_ice_connection_state_change( + &mut self, + state: RTCIceConnectionState, + ) -> impl Future + Send { + println!("Connection State has changed {state}"); - if connection_state == RTCIceConnectionState::Connected { + if state == RTCIceConnectionState::Connected { println!("Ctrl+C the remote client to stop the demo"); - } else if connection_state == RTCIceConnectionState::Failed { - notify_tx.notify_waiters(); + } else if state == RTCIceConnectionState::Failed { + self.notify_tx.notify_waiters(); println!("Done writing media files"); - let _ = done_tx.try_send(()); + let _ = self.done_tx.try_send(()); } - Box::pin(async {}) - }, - )); + async {} + } + } + + let pc = Arc::downgrade(&peer_connection); + let (done_tx, mut done_rx) = tokio::sync::mpsc::channel::<()>(1); + peer_connection.with_event_handler(ConnectionHandler { + connection: pc, + notify_rx, + done_tx, + h264_writer, + notify_tx, + ogg_writer, + }); // Wait for the offer to be pasted let line = signal::must_read_stdin()?; diff --git a/examples/examples/save-to-disk-vpx/save-to-disk-vpx.rs b/examples/examples/save-to-disk-vpx/save-to-disk-vpx.rs index 2b02986b6..c8b88cf4d 100644 --- a/examples/examples/save-to-disk-vpx/save-to-disk-vpx.rs +++ b/examples/examples/save-to-disk-vpx/save-to-disk-vpx.rs @@ -1,6 +1,7 @@ use std::fs::File; +use std::future::Future; use std::io::Write; -use std::sync::Arc; +use std::sync::{Arc, Weak}; use anyhow::Result; use clap::{AppSettings, Arg, Command}; @@ -17,10 +18,13 @@ use webrtc::media::io::ivf_writer::IVFWriter; use webrtc::media::io::ogg_writer::OggWriter; use webrtc::peer_connection::configuration::RTCConfiguration; use webrtc::peer_connection::sdp::session_description::RTCSessionDescription; +use webrtc::peer_connection::{PeerConnectionEventHandler, RTCPeerConnection}; use webrtc::rtcp::payload_feedbacks::picture_loss_indication::PictureLossIndication; use webrtc::rtp_transceiver::rtp_codec::{ RTCRtpCodecCapability, RTCRtpCodecParameters, RTPCodecType, }; +use webrtc::rtp_transceiver::rtp_receiver::RTCRtpReceiver; +use webrtc::rtp_transceiver::RTCRtpTransceiver; use webrtc::track::track_remote::TrackRemote; async fn save_to_disk( @@ -230,76 +234,104 @@ async fn main() -> Result<()> { // an ivf file, since we could have multiple video tracks we provide a counter. // In your application this is where you would handle/process video let pc = Arc::downgrade(&peer_connection); - peer_connection.on_track(Box::new(move |track, _, _| { - // Send a PLI on an interval so that the publisher is pushing a keyframe every rtcpPLIInterval - let media_ssrc = track.ssrc(); - let pc2 = pc.clone(); - tokio::spawn(async move { - let mut result = Result::::Ok(0); - while result.is_ok() { - let timeout = tokio::time::sleep(Duration::from_secs(3)); - tokio::pin!(timeout); - - tokio::select! { - _ = timeout.as_mut() =>{ - if let Some(pc) = pc2.upgrade(){ - result = pc.write_rtcp(&[Box::new(PictureLossIndication{ - sender_ssrc: 0, - media_ssrc, - })]).await.map_err(Into::into); - }else{ - break; + + struct ConnectionHandler { + connection: Weak, + ivf_writer: Arc>, + ogg_writer: Arc>, + is_vp9: bool, + notify_rx: Arc, + notify_tx: Arc, + done_tx: tokio::sync::mpsc::Sender<()>, + } + + impl PeerConnectionEventHandler for ConnectionHandler { + fn on_track( + &mut self, + track: Arc, + _: Arc, + _: Arc, + ) -> impl Future + Send { + // Send a PLI on an interval so that the publisher is pushing a keyframe every rtcpPLIInterval + let media_ssrc = track.ssrc(); + let connection = self.connection.clone(); + tokio::spawn(async move { + let mut result = Result::::Ok(0); + while result.is_ok() { + let timeout = tokio::time::sleep(Duration::from_secs(3)); + tokio::pin!(timeout); + + tokio::select! { + _ = timeout.as_mut() =>{ + if let Some(pc) = connection.upgrade(){ + result = pc.write_rtcp(&[Box::new(PictureLossIndication{ + sender_ssrc: 0, + media_ssrc, + })]).await.map_err(Into::into); + }else{ + break; + } } - } - }; - } - }); - - let notify_rx2 = Arc::clone(¬ify_rx); - let ivf_writer2 = Arc::clone(&ivf_writer); - let ogg_writer2 = Arc::clone(&ogg_writer); - Box::pin(async move { - let codec = track.codec(); - let mime_type = codec.capability.mime_type.to_lowercase(); - if mime_type == MIME_TYPE_OPUS.to_lowercase() { - println!("Got Opus track, saving to disk as output.opus (48 kHz, 2 channels)"); - tokio::spawn(async move { - let _ = save_to_disk(ogg_writer2, track, notify_rx2).await; - }); - } else if mime_type == MIME_TYPE_VP8.to_lowercase() - || mime_type == MIME_TYPE_VP9.to_lowercase() - { - println!( - "Got {} track, saving to disk as output.ivf", - if is_vp9 { "VP9" } else { "VP8" } - ); - tokio::spawn(async move { - let _ = save_to_disk(ivf_writer2, track, notify_rx2).await; - }); + }; + } + }); + + let ogg_writer = self.ogg_writer.clone(); + let ivf_writer = self.ivf_writer.clone(); + let notify_rx = self.notify_rx.clone(); + async move { + let codec = track.codec(); + let mime_type = codec.capability.mime_type.to_lowercase(); + if mime_type == MIME_TYPE_OPUS.to_lowercase() { + println!("Got Opus track, saving to disk as output.opus (48 kHz, 2 channels)"); + tokio::spawn(async move { + let _ = save_to_disk(ogg_writer, track, notify_rx).await; + }); + } else if mime_type == MIME_TYPE_VP8.to_lowercase() + || mime_type == MIME_TYPE_VP9.to_lowercase() + { + println!( + "Got {} track, saving to disk as output.ivf", + if self.is_vp9 { "VP9" } else { "VP8" } + ); + tokio::spawn(async move { + let _ = save_to_disk(ivf_writer, track, notify_rx).await; + }); + } } - }) - })); - - let (done_tx, mut done_rx) = tokio::sync::mpsc::channel::<()>(1); + } - // Set the handler for ICE connection state - // This will notify you when the peer has connected/disconnected - peer_connection.on_ice_connection_state_change(Box::new( - move |connection_state: RTCIceConnectionState| { - println!("Connection State has changed {connection_state}"); + // Set the handler for ICE connection state + // This will notify you when the peer has connected/disconnected + fn on_ice_connection_state_change( + &mut self, + state: RTCIceConnectionState, + ) -> impl Future + Send { + println!("Connection State has changed {state}"); - if connection_state == RTCIceConnectionState::Connected { + if state == RTCIceConnectionState::Connected { println!("Ctrl+C the remote client to stop the demo"); - } else if connection_state == RTCIceConnectionState::Failed { - notify_tx.notify_waiters(); + } else if state == RTCIceConnectionState::Failed { + self.notify_tx.notify_waiters(); println!("Done writing media files"); - let _ = done_tx.try_send(()); + let _ = self.done_tx.try_send(()); } - Box::pin(async {}) - }, - )); + async {} + } + } + + let (done_tx, mut done_rx) = tokio::sync::mpsc::channel::<()>(1); + peer_connection.with_event_handler(ConnectionHandler { + connection: pc, + ivf_writer, + ogg_writer, + is_vp9, + notify_rx, + notify_tx, + done_tx, + }); // Wait for the offer to be pasted let line = signal::must_read_stdin()?; diff --git a/examples/examples/simulcast/simulcast.rs b/examples/examples/simulcast/simulcast.rs index f47c7b620..77d2ef16a 100644 --- a/examples/examples/simulcast/simulcast.rs +++ b/examples/examples/simulcast/simulcast.rs @@ -1,6 +1,7 @@ use std::collections::HashMap; +use std::future::Future; use std::io::Write; -use std::sync::Arc; +use std::sync::{Arc, Weak}; use anyhow::Result; use clap::{AppSettings, Arg, Command}; @@ -13,12 +14,16 @@ use webrtc::interceptor::registry::Registry; use webrtc::peer_connection::configuration::RTCConfiguration; use webrtc::peer_connection::peer_connection_state::RTCPeerConnectionState; use webrtc::peer_connection::sdp::session_description::RTCSessionDescription; +use webrtc::peer_connection::{PeerConnectionEventHandler, RTCPeerConnection}; use webrtc::rtcp::payload_feedbacks::picture_loss_indication::PictureLossIndication; use webrtc::rtp_transceiver::rtp_codec::{ RTCRtpCodecCapability, RTCRtpHeaderExtensionCapability, RTPCodecType, }; +use webrtc::rtp_transceiver::rtp_receiver::RTCRtpReceiver; +use webrtc::rtp_transceiver::RTCRtpTransceiver; use webrtc::track::track_local::track_local_static_rtp::TrackLocalStaticRTP; use webrtc::track::track_local::{TrackLocal, TrackLocalWriter}; +use webrtc::track::track_remote::TrackRemote; use webrtc::Error; #[tokio::main] @@ -153,79 +158,97 @@ async fn main() -> Result<()> { // Set a handler for when a new remote track starts let pc = Arc::downgrade(&peer_connection); - peer_connection.on_track(Box::new(move |track, _, _| { - println!("Track has started"); - - let rid = track.rid().to_owned(); - let output_track = if let Some(output_track) = output_tracks.get(&rid) { - Arc::clone(output_track) - } else { - println!("output_track not found for rid = {rid}"); - return Box::pin(async {}); - }; - - // Start reading from all the streams and sending them to the related output track - let media_ssrc = track.ssrc(); - let pc2 = pc.clone(); - tokio::spawn(async move { - let mut result = Result::::Ok(0); - while result.is_ok() { - println!("Sending pli for stream with rid: {rid}, ssrc: {media_ssrc}"); - - let timeout = tokio::time::sleep(Duration::from_secs(3)); - tokio::pin!(timeout); - - tokio::select! { - _ = timeout.as_mut() =>{ - if let Some(pc) = pc2.upgrade(){ - result = pc.write_rtcp(&[Box::new(PictureLossIndication{ - sender_ssrc: 0, - media_ssrc, - })]).await.map_err(Into::into); - }else{ - break; - } - } + + struct ConnectionHandler { + connection: Weak, + output_tracks: HashMap>, + done_tx: tokio::sync::mpsc::Sender<()>, + } + + impl PeerConnectionEventHandler for ConnectionHandler { + fn on_track( + &mut self, + track: Arc, + _: Arc, + _: Arc, + ) -> impl Future + Send { + async move { + println!("Track has started"); + + let rid = track.rid().to_owned(); + let output_track = if let Some(output_track) = self.output_tracks.get(&rid) { + Arc::clone(output_track) + } else { + println!("output_track not found for rid = {rid}"); + return; }; - } - }); - tokio::spawn(async move { - // Read RTP packets being sent to webrtc-rs - println!("enter track loop {}", track.rid()); - while let Ok((rtp, _)) = track.read_rtp().await { - if let Err(err) = output_track.write_rtp(&rtp).await { - if Error::ErrClosedPipe != err { - println!("output track write_rtp got error: {err} and break"); - break; - } else { - println!("output track write_rtp got error: {err}"); + // Start reading from all the streams and sending them to the related output track + let media_ssrc = track.ssrc(); + let connection = self.connection.clone(); + tokio::spawn(async move { + let mut result = Result::::Ok(0); + while result.is_ok() { + println!("Sending pli for stream with rid: {rid}, ssrc: {media_ssrc}"); + + let timeout = tokio::time::sleep(Duration::from_secs(3)); + tokio::pin!(timeout); + + tokio::select! { + _ = timeout.as_mut() =>{ + if let Some(pc) = connection.upgrade(){ + result = pc.write_rtcp(&[Box::new(PictureLossIndication{ + sender_ssrc: 0, + media_ssrc, + })]).await.map_err(Into::into); + }else{ + break; + } + } + }; + } + }); + + tokio::spawn(async move { + // Read RTP packets being sent to webrtc-rs + println!("enter track loop {}", track.rid()); + while let Ok((rtp, _)) = track.read_rtp().await { + if let Err(err) = output_track.write_rtp(&rtp).await { + if Error::ErrClosedPipe != err { + println!("output track write_rtp got error: {err} and break"); + break; + } else { + println!("output track write_rtp got error: {err}"); + } + } } - } + println!("exit track loop {}", track.rid()); + }); } - println!("exit track loop {}", track.rid()); - }); - - Box::pin(async {}) - })); - - let (done_tx, mut done_rx) = tokio::sync::mpsc::channel::<()>(1); - - // Set the handler for Peer connection state - // This will notify you when the peer has connected/disconnected - peer_connection.on_peer_connection_state_change(Box::new(move |s: RTCPeerConnectionState| { - println!("Peer Connection State has changed: {s}"); - - if s == RTCPeerConnectionState::Failed { - // Wait until PeerConnection has had no network activity for 30 seconds or another failure. It may be reconnected using an ICE Restart. - // Use webrtc.PeerConnectionStateDisconnected if you are interested in detecting faster timeout. - // Note that the PeerConnection may come back from PeerConnectionStateDisconnected. - println!("Peer Connection has gone to failed exiting"); - let _ = done_tx.try_send(()); } - - Box::pin(async {}) - })); + // Set the handler for Peer connection state + // This will notify you when the peer has connected/disconnected + fn on_peer_connection_state_change( + &mut self, + state: RTCPeerConnectionState, + ) -> impl Future + Send { + println!("Peer Connection State has changed: {state}"); + if state == RTCPeerConnectionState::Failed { + // Wait until PeerConnection has had no network activity for 30 seconds or another failure. It may be reconnected using an ICE Restart. + // Use webrtc.PeerConnectionStateDisconnected if you are interested in detecting faster timeout. + // Note that the PeerConnection may come back from PeerConnectionStateDisconnected. + println!("Peer Connection has gone to failed exiting"); + let _ = self.done_tx.try_send(()); + } + async {} + } + } + let (done_tx, mut done_rx) = tokio::sync::mpsc::channel::<()>(1); + peer_connection.with_event_handler(ConnectionHandler { + connection: pc, + output_tracks, + done_tx, + }); // Create an answer let answer = peer_connection.create_answer(None).await?; diff --git a/examples/examples/swap-tracks/swap-tracks.rs b/examples/examples/swap-tracks/swap-tracks.rs index db8cd25d7..a50736b98 100644 --- a/examples/examples/swap-tracks/swap-tracks.rs +++ b/examples/examples/swap-tracks/swap-tracks.rs @@ -1,6 +1,7 @@ +use std::future::Future; use std::io::Write; use std::sync::atomic::{AtomicUsize, Ordering}; -use std::sync::Arc; +use std::sync::{Arc, Weak}; use anyhow::Result; use clap::{AppSettings, Arg, Command}; @@ -13,10 +14,15 @@ use webrtc::interceptor::registry::Registry; use webrtc::peer_connection::configuration::RTCConfiguration; use webrtc::peer_connection::peer_connection_state::RTCPeerConnectionState; use webrtc::peer_connection::sdp::session_description::RTCSessionDescription; +use webrtc::peer_connection::{PeerConnectionEventHandler, RTCPeerConnection}; use webrtc::rtcp::payload_feedbacks::picture_loss_indication::PictureLossIndication; +use webrtc::rtp::packet::Packet; use webrtc::rtp_transceiver::rtp_codec::RTCRtpCodecCapability; +use webrtc::rtp_transceiver::rtp_receiver::RTCRtpReceiver; +use webrtc::rtp_transceiver::RTCRtpTransceiver; use webrtc::track::track_local::track_local_static_rtp::TrackLocalStaticRTP; use webrtc::track::track_local::{TrackLocal, TrackLocalWriter}; +use webrtc::track::track_remote::TrackRemote; use webrtc::Error; #[tokio::main] @@ -144,83 +150,112 @@ async fn main() -> Result<()> { let pc = Arc::downgrade(&peer_connection); let curr_track1 = Arc::clone(&curr_track); let track_count1 = Arc::clone(&track_count); - peer_connection.on_track(Box::new(move |track, _, _| { - let track_num = track_count1.fetch_add(1, Ordering::SeqCst); - - let curr_track2 = Arc::clone(&curr_track1); - let pc2 = pc.clone(); - let packets_tx2 = Arc::clone(&packets_tx); - tokio::spawn(async move { - println!( - "Track has started, of type {}: {}", - track.payload_type(), - track.codec().capability.mime_type - ); - - let mut last_timestamp = 0; - let mut is_curr_track = false; - while let Ok((mut rtp, _)) = track.read_rtp().await { - // Change the timestamp to only be the delta - let old_timestamp = rtp.header.timestamp; - if last_timestamp == 0 { - rtp.header.timestamp = 0 - } else { - rtp.header.timestamp -= last_timestamp; - } - last_timestamp = old_timestamp; - - // Check if this is the current track - if curr_track2.load(Ordering::SeqCst) == track_num { - // If just switched to this track, send PLI to get picture refresh - if !is_curr_track { - is_curr_track = true; - if let Some(pc) = pc2.upgrade() { - if let Err(err) = pc - .write_rtcp(&[Box::new(PictureLossIndication { - sender_ssrc: 0, - media_ssrc: track.ssrc(), - })]) - .await - { - println!("write_rtcp err: {err}"); + struct ConnectionHandler { + track_count: Arc, + current_track: Arc, + connection: Weak, + packets_tx: Arc>, + connected_tx: tokio::sync::mpsc::Sender<()>, + done_tx: tokio::sync::mpsc::Sender<()>, + } + + impl PeerConnectionEventHandler for ConnectionHandler { + fn on_track( + &mut self, + track: Arc, + _: Arc, + _: Arc, + ) -> impl Future + Send { + let track_num = self.track_count.fetch_add(1, Ordering::SeqCst); + + let current_track = self.current_track.clone(); + let connection = self.connection.clone(); + let packets_tx = self.packets_tx.clone(); + + tokio::spawn(async move { + println!( + "Track has started, of type {}: {}", + track.payload_type(), + track.codec().capability.mime_type + ); + + let mut last_timestamp = 0; + let mut is_curr_track = false; + while let Ok((mut rtp, _)) = track.read_rtp().await { + // Change the timestamp to only be the delta + let old_timestamp = rtp.header.timestamp; + if last_timestamp == 0 { + rtp.header.timestamp = 0 + } else { + rtp.header.timestamp -= last_timestamp; + } + last_timestamp = old_timestamp; + + // Check if this is the current track + if current_track.load(Ordering::SeqCst) == track_num { + // If just switched to this track, send PLI to get picture refresh + if !is_curr_track { + is_curr_track = true; + if let Some(pc) = connection.upgrade() { + if let Err(err) = pc + .write_rtcp(&[Box::new(PictureLossIndication { + sender_ssrc: 0, + media_ssrc: track.ssrc(), + })]) + .await + { + println!("write_rtcp err: {err}"); + } + } else { + break; } - } else { - break; } + let _ = packets_tx.send(rtp).await; + } else { + is_curr_track = false; } - let _ = packets_tx2.send(rtp).await; - } else { - is_curr_track = false; } - } - println!( - "Track has ended, of type {}: {}", - track.payload_type(), - track.codec().capability.mime_type - ); - }); + println!( + "Track has ended, of type {}: {}", + track.payload_type(), + track.codec().capability.mime_type + ); + }); - Box::pin(async {}) - })); + async {} + } + + // Set the handler for Peer connection state + // This will notify you when the peer has connected/disconnected + fn on_peer_connection_state_change( + &mut self, + state: RTCPeerConnectionState, + ) -> impl Future + Send { + println!("Peer Connection State has changed: {state}"); + if state == RTCPeerConnectionState::Connected { + let _ = self.connected_tx.try_send(()); + } else if state == RTCPeerConnectionState::Failed { + // Wait until PeerConnection has had no network activity for 30 seconds or another failure. It may be reconnected using an ICE Restart. + // Use webrtc.PeerConnectionStateDisconnected if you are interested in detecting faster timeout. + // Note that the PeerConnection may come back from PeerConnectionStateDisconnected. + let _ = self.done_tx.try_send(()); + } + async {} + } + } let (connected_tx, mut connected_rx) = tokio::sync::mpsc::channel(1); let (done_tx, mut done_rx) = tokio::sync::mpsc::channel(1); - // Set the handler for Peer connection state - // This will notify you when the peer has connected/disconnected - peer_connection.on_peer_connection_state_change(Box::new(move |s: RTCPeerConnectionState| { - println!("Peer Connection State has changed: {s}"); - if s == RTCPeerConnectionState::Connected { - let _ = connected_tx.try_send(()); - } else if s == RTCPeerConnectionState::Failed { - // Wait until PeerConnection has had no network activity for 30 seconds or another failure. It may be reconnected using an ICE Restart. - // Use webrtc.PeerConnectionStateDisconnected if you are interested in detecting faster timeout. - // Note that the PeerConnection may come back from PeerConnectionStateDisconnected. - let _ = done_tx.try_send(()); - } - Box::pin(async move {}) - })); + peer_connection.with_event_handler(ConnectionHandler { + track_count: track_count.clone(), + current_track: curr_track.clone(), + connection: pc, + packets_tx, + connected_tx, + done_tx, + }); // Create an answer let answer = peer_connection.create_answer(None).await?; diff --git a/ice/examples/ping_pong.rs b/ice/examples/ping_pong.rs index 9620383b1..2d1dab819 100644 --- a/ice/examples/ping_pong.rs +++ b/ice/examples/ping_pong.rs @@ -1,12 +1,14 @@ +use std::future::Future; use std::io; use std::sync::Arc; use std::time::Duration; use clap::{App, AppSettings, Arg}; +use hyper::client::HttpConnector; use hyper::service::{make_service_fn, service_fn}; use hyper::{Body, Client, Method, Request, Response, Server, StatusCode}; use ice::agent::agent_config::AgentConfig; -use ice::agent::Agent; +use ice::agent::{Agent, AgentEventHandler}; use ice::candidate::candidate_base::*; use ice::candidate::*; use ice::network_type::*; @@ -200,19 +202,28 @@ async fn main() -> Result<(), Error> { let client = Arc::new(Client::new()); - // When we have gathered a new ICE Candidate send it to the remote peer let client2 = Arc::clone(&client); - ice_agent.on_candidate(Box::new( - move |c: Option>| { - let client3 = Arc::clone(&client2); - Box::pin(async move { - if let Some(c) = c { + struct AgentHandler { + client: Arc>, + remote_port: u16, + ice_done_tx: tokio::sync::mpsc::Sender<()>, + } + + impl AgentEventHandler for AgentHandler { + // When we have gathered a new ICE Candidate send it to the remote peer + fn on_candidate( + &mut self, + candidate: Option>, + ) -> impl Future + Send { + async move { + if let Some(c) = candidate { println!("posting remoteCandidate with {}", c.marshal()); let req = match Request::builder() .method(Method::POST) .uri(format!( - "http://localhost:{remote_http_port}/remoteCandidate" + "http://localhost:{}/remoteCandidate", + self.remote_port )) .body(Body::from(c.marshal())) { @@ -222,7 +233,7 @@ async fn main() -> Result<(), Error> { return; } }; - let resp = match client3.request(req).await { + let resp = match self.client.request(req).await { Ok(resp) => resp, Err(err) => { println!("{err}"); @@ -231,19 +242,27 @@ async fn main() -> Result<(), Error> { }; println!("Response from remoteCandidate: {}", resp.status()); } - }) - }, - )); + } + } - let (ice_done_tx, mut ice_done_rx) = mpsc::channel::<()>(1); - // When ICE Connection state has change print to stdout - ice_agent.on_connection_state_change(Box::new(move |c: ConnectionState| { - println!("ICE Connection State has changed: {c}"); - if c == ConnectionState::Failed { - let _ = ice_done_tx.try_send(()); + // When ICE Connection state has change print to stdout + fn on_connection_state_change( + &mut self, + state: ConnectionState, + ) -> impl Future + Send { + println!("ICE Connection State has changed: {state}"); + if state == ConnectionState::Failed { + let _ = self.ice_done_tx.try_send(()); + } + async {} } - Box::pin(async move {}) - })); + } + let (ice_done_tx, mut ice_done_rx) = mpsc::channel::<()>(1); + ice_agent.with_event_handler(AgentHandler { + client: client.clone(), + remote_port: remote_http_port, + ice_done_tx, + }); // Get the local auth details and send to remote peer let (local_ufrag, local_pwd) = ice_agent.get_local_user_credentials().await; diff --git a/ice/src/agent/agent_gather_test.rs b/ice/src/agent/agent_gather_test.rs index ed90f097a..53c5b5f1b 100644 --- a/ice/src/agent/agent_gather_test.rs +++ b/ice/src/agent/agent_gather_test.rs @@ -111,6 +111,24 @@ async fn test_vnet_gather_listen_udp() -> Result<()> { Ok(()) } +struct CandidateHandler { + done_tx: Arc>>>, +} + +impl AgentEventHandler for CandidateHandler { + fn on_candidate( + &mut self, + candidate: Option>, + ) -> impl Future + Send { + async move { + if candidate.is_some() { + let mut tx = self.done_tx.lock().await; + tx.take(); + } + } + } +} + #[tokio::test] async fn test_vnet_gather_with_nat_1to1_as_host_candidates() -> Result<()> { let external_ip0 = "1.2.3.4"; @@ -154,17 +172,8 @@ async fn test_vnet_gather_with_nat_1to1_as_host_candidates() -> Result<()> { let (done_tx, mut done_rx) = mpsc::channel::<()>(1); let done_tx = Arc::new(Mutex::new(Some(done_tx))); - a.on_candidate(Box::new( - move |c: Option>| { - let done_tx_clone = Arc::clone(&done_tx); - Box::pin(async move { - if c.is_none() { - let mut tx = done_tx_clone.lock().await; - tx.take(); - } - }) - }, - )); + + a.with_event_handler(CandidateHandler { done_tx }); a.gather_candidates()?; @@ -271,17 +280,8 @@ async fn test_vnet_gather_with_nat_1to1_as_srflx_candidates() -> Result<()> { let (done_tx, mut done_rx) = mpsc::channel::<()>(1); let done_tx = Arc::new(Mutex::new(Some(done_tx))); - a.on_candidate(Box::new( - move |c: Option>| { - let done_tx_clone = Arc::clone(&done_tx); - Box::pin(async move { - if c.is_none() { - let mut tx = done_tx_clone.lock().await; - tx.take(); - } - }) - }, - )); + + a.with_event_handler(CandidateHandler { done_tx }); a.gather_candidates()?; @@ -455,17 +455,8 @@ async fn test_vnet_gather_muxed_udp() -> Result<()> { let (done_tx, mut done_rx) = mpsc::channel::<()>(1); let done_tx = Arc::new(Mutex::new(Some(done_tx))); - a.on_candidate(Box::new( - move |c: Option>| { - let done_tx_clone = Arc::clone(&done_tx); - Box::pin(async move { - if c.is_none() { - let mut tx = done_tx_clone.lock().await; - tx.take(); - } - }) - }, - )); + + a.with_event_handler(CandidateHandler { done_tx }); a.gather_candidates()?; diff --git a/ice/src/agent/agent_internal.rs b/ice/src/agent/agent_internal.rs index 8fd2b22b1..dd86f3188 100644 --- a/ice/src/agent/agent_internal.rs +++ b/ice/src/agent/agent_internal.rs @@ -2,6 +2,7 @@ use std::sync::atomic::{AtomicBool, AtomicU64}; use arc_swap::ArcSwapOption; use util::sync::Mutex as SyncMutex; +use util::EventHandler; use super::agent_transport::*; use super::*; @@ -36,10 +37,7 @@ pub struct AgentInternal { pub(crate) chan_candidate_pair_tx: Mutex>>, pub(crate) chan_state_tx: Mutex>>, - pub(crate) on_connection_state_change_hdlr: ArcSwapOption>, - pub(crate) on_selected_candidate_pair_change_hdlr: - ArcSwapOption>, - pub(crate) on_candidate_hdlr: ArcSwapOption>, + pub(crate) events_handler: EventHandler, pub(crate) tie_breaker: AtomicU64, pub(crate) is_controlling: AtomicBool, @@ -108,9 +106,7 @@ impl AgentInternal { chan_candidate_pair_tx: Mutex::new(Some(chan_candidate_pair_tx)), chan_state_tx: Mutex::new(Some(chan_state_tx)), - on_connection_state_change_hdlr: ArcSwapOption::empty(), - on_selected_candidate_pair_change_hdlr: ArcSwapOption::empty(), - on_candidate_hdlr: ArcSwapOption::empty(), + events_handler: EventHandler::empty(), tie_breaker: AtomicU64::new(rand::random::()), is_controlling: AtomicBool::new(config.is_controlling), @@ -1063,11 +1059,13 @@ impl AgentInternal { // Blocking one by the other one causes deadlock. while chan_candidate_pair_rx.recv().await.is_some() { if let (Some(cb), Some(p)) = ( - &*ai.on_selected_candidate_pair_change_hdlr.load(), + &*ai.events_handler.load(), &*ai.agent_conn.selected_pair.load(), ) { - let mut f = cb.lock().await; - f(&p.local, &p.remote).await; + cb.lock() + .await + .inline_on_selected_candidate_pair_change(p.local.clone(), p.remote.clone()) + .await; } } }); @@ -1077,32 +1075,28 @@ impl AgentInternal { loop { tokio::select! { opt_state = chan_state_rx.recv() => { - if let Some(s) = opt_state { - if let Some(handler) = &*ai.on_connection_state_change_hdlr.load() { - let mut f = handler.lock().await; - f(s).await; + if let Some(state) = opt_state { + if let Some(handler) = &*ai.events_handler.load() { + handler.lock().await.inline_on_connection_state_change(state).await; } } else { - while let Some(c) = chan_candidate_rx.recv().await { - if let Some(handler) = &*ai.on_candidate_hdlr.load() { - let mut f = handler.lock().await; - f(c).await; + while let Some(candidate) = chan_candidate_rx.recv().await { + if let Some(handler) = &*ai.events_handler.load() { + handler.lock().await.inline_on_candidate(candidate).await; } } break; } }, opt_cand = chan_candidate_rx.recv() => { - if let Some(c) = opt_cand { - if let Some(handler) = &*ai.on_candidate_hdlr.load() { - let mut f = handler.lock().await; - f(c).await; + if let Some(candidate) = opt_cand { + if let Some(handler) = &*ai.events_handler.load() { + handler.lock().await.inline_on_candidate(candidate).await; } } else { - while let Some(s) = chan_state_rx.recv().await { - if let Some(handler) = &*ai.on_connection_state_change_hdlr.load() { - let mut f = handler.lock().await; - f(s).await; + while let Some(state) = chan_state_rx.recv().await { + if let Some(handler) = &*ai.events_handler.load() { + handler.lock().await.inline_on_connection_state_change(state).await; } } break; diff --git a/ice/src/agent/agent_test.rs b/ice/src/agent/agent_test.rs index 31b7babd3..c0a5900e9 100644 --- a/ice/src/agent/agent_test.rs +++ b/ice/src/agent/agent_test.rs @@ -191,14 +191,23 @@ async fn test_on_selected_candidate_pair_change() -> Result<()> { let a = Agent::new(AgentConfig::default()).await?; let (callback_called_tx, mut callback_called_rx) = mpsc::channel::<()>(1); let callback_called_tx = Arc::new(Mutex::new(Some(callback_called_tx))); - let cb: OnSelectedCandidatePairChangeHdlrFn = Box::new(move |_, _| { - let callback_called_tx_clone = Arc::clone(&callback_called_tx); - Box::pin(async move { - let mut tx = callback_called_tx_clone.lock().await; - tx.take(); - }) - }); - a.on_selected_candidate_pair_change(cb); + + struct AgentHandler { + callback_called_tx: Arc>>>, + } + impl AgentEventHandler for AgentHandler { + fn on_selected_candidate_pair_change( + &mut self, + _: Arc, + _: Arc, + ) -> impl Future + Send { + async move { + let mut tx = self.callback_called_tx.lock().await; + tx.take(); + } + } + } + a.with_event_handler(AgentHandler { callback_called_tx }); let host_config = CandidateHostConfig { base_config: CandidateBaseConfig { @@ -437,7 +446,7 @@ async fn test_connectivity_on_startup() -> Result<()> { }; let a_agent = Arc::new(Agent::new(cfg0).await?); - a_agent.on_connection_state_change(a_notifier); + a_agent.with_event_handler(a_notifier); let cfg1 = AgentConfig { network_types: supported_network_types(), @@ -450,7 +459,7 @@ async fn test_connectivity_on_startup() -> Result<()> { }; let b_agent = Arc::new(Agent::new(cfg1).await?); - b_agent.on_connection_state_change(b_notifier); + b_agent.with_event_handler(b_notifier); // Manual signaling let (a_ufrag, a_pwd) = a_agent.get_local_user_credentials().await; @@ -464,15 +473,29 @@ async fn test_connectivity_on_startup() -> Result<()> { let (_b_cancel_tx, b_cancel_rx) = mpsc::channel(1); let accepting_tx = Arc::new(Mutex::new(Some(accepting_tx))); - a_agent.on_connection_state_change(Box::new(move |s: ConnectionState| { - let accepted_tx_clone = Arc::clone(&accepting_tx); - Box::pin(async move { - if s == ConnectionState::Checking { - let mut tx = accepted_tx_clone.lock().await; - tx.take(); + + struct AgentAHandler { + accepted_tx: Arc>>>, + } + + impl AgentEventHandler for AgentAHandler { + fn on_connection_state_change( + &mut self, + state: ConnectionState, + ) -> impl Future + Send { + async move { + if state == ConnectionState::Checking { + let mut tx = self.accepted_tx.lock().await; + tx.take(); + } } - }) - })); + } + } + + let accepted_tx = Arc::new(Mutex::new(Some(accepted_tx))); + a_agent.with_event_handler(AgentAHandler { + accepted_tx: accepted_tx.clone(), + }); tokio::spawn(async move { let result = a_agent.accept(a_cancel_rx, b_ufrag, b_pwd).await; @@ -545,7 +568,7 @@ async fn test_connectivity_lite() -> Result<()> { }; let a_agent = Arc::new(Agent::new(cfg0).await?); - a_agent.on_connection_state_change(a_notifier); + a_agent.with_event_handler(a_notifier); let cfg1 = AgentConfig { urls: vec![], @@ -558,7 +581,7 @@ async fn test_connectivity_lite() -> Result<()> { }; let b_agent = Arc::new(Agent::new(cfg1).await?); - b_agent.on_connection_state_change(b_notifier); + b_agent.with_event_handler(b_notifier); let _ = connect_with_vnet(&a_agent, &b_agent).await?; @@ -969,43 +992,58 @@ async fn test_connection_state_callback() -> Result<()> { let is_failed_tx = Arc::new(Mutex::new(Some(is_failed_tx))); let is_closed_tx = Arc::new(Mutex::new(Some(is_closed_tx))); - a_agent.on_connection_state_change(Box::new(move |c: ConnectionState| { - let is_checking_tx_clone = Arc::clone(&is_checking_tx); - let is_connected_tx_clone = Arc::clone(&is_connected_tx); - let is_disconnected_tx_clone = Arc::clone(&is_disconnected_tx); - let is_failed_tx_clone = Arc::clone(&is_failed_tx); - let is_closed_tx_clone = Arc::clone(&is_closed_tx); - Box::pin(async move { - match c { - ConnectionState::Checking => { - log::debug!("drop is_checking_tx"); - let mut tx = is_checking_tx_clone.lock().await; - tx.take(); - } - ConnectionState::Connected => { - log::debug!("drop is_connected_tx"); - let mut tx = is_connected_tx_clone.lock().await; - tx.take(); - } - ConnectionState::Disconnected => { - log::debug!("drop is_disconnected_tx"); - let mut tx = is_disconnected_tx_clone.lock().await; - tx.take(); - } - ConnectionState::Failed => { - log::debug!("drop is_failed_tx"); - let mut tx = is_failed_tx_clone.lock().await; - tx.take(); - } - ConnectionState::Closed => { - log::debug!("drop is_closed_tx"); - let mut tx = is_closed_tx_clone.lock().await; - tx.take(); + struct AgentHandler { + is_checking_tx: Arc>>>, + is_connected_tx: Arc>>>, + is_disconnected_tx: Arc>>>, + is_failed_tx: Arc>>>, + is_closed_tx: Arc>>>, + } + + impl AgentEventHandler for AgentHandler { + fn on_connection_state_change( + &mut self, + state: ConnectionState, + ) -> impl Future + Send { + async move { + match state { + ConnectionState::Checking => { + log::debug!("drop is_checking_tx"); + let mut tx = self.is_checking_tx.lock().await; + tx.take(); + } + ConnectionState::Connected => { + log::debug!("drop is_connected_tx"); + let mut tx = self.is_connected_tx.lock().await; + tx.take(); + } + ConnectionState::Disconnected => { + log::debug!("drop is_disconnected_tx"); + let mut tx = self.is_disconnected_tx.lock().await; + tx.take(); + } + ConnectionState::Failed => { + log::debug!("drop is_failed_tx"); + let mut tx = self.is_failed_tx.lock().await; + tx.take(); + } + ConnectionState::Closed => { + log::debug!("drop is_closed_tx"); + let mut tx = self.is_closed_tx.lock().await; + tx.take(); + } + _ => {} } - _ => {} - }; - }) - })); + } + } + } + a_agent.with_event_handler(AgentHandler { + is_checking_tx, + is_connected_tx, + is_disconnected_tx, + is_failed_tx, + is_closed_tx, + }); connect_with_vnet(&a_agent, &b_agent).await?; @@ -1657,15 +1695,24 @@ async fn test_connection_state_failed_delete_all_candidates() -> Result<()> { let (is_failed_tx, mut is_failed_rx) = mpsc::channel::<()>(1); let is_failed_tx = Arc::new(Mutex::new(Some(is_failed_tx))); - a_agent.on_connection_state_change(Box::new(move |c: ConnectionState| { - let is_failed_tx_clone = Arc::clone(&is_failed_tx); - Box::pin(async move { - if c == ConnectionState::Failed { - let mut tx = is_failed_tx_clone.lock().await; - tx.take(); + struct AgentHandler { + is_failed_tx: Arc>>>, + } + + impl AgentEventHandler for AgentHandler { + fn on_connection_state_change( + &mut self, + state: ConnectionState, + ) -> impl Future + Send { + async move { + if state == ConnectionState::Failed { + let mut tx = self.is_failed_tx.lock().await; + tx.take(); + } } - }) - })); + } + } + a_agent.with_event_handler(AgentHandler { is_failed_tx }); connect_with_vnet(&a_agent, &b_agent).await?; let _ = is_failed_rx.recv().await; @@ -1712,32 +1759,42 @@ async fn test_connection_state_connecting_to_failed() -> Result<()> { let is_failed = WaitGroup::new(); let is_checking = WaitGroup::new(); - let connection_state_check = move |wf: Worker, wc: Worker| { - let wf = Arc::new(Mutex::new(Some(wf))); - let wc = Arc::new(Mutex::new(Some(wc))); - let hdlr_fn: OnConnectionStateChangeHdlrFn = Box::new(move |c: ConnectionState| { - let wf_clone = Arc::clone(&wf); - let wc_clone = Arc::clone(&wc); - Box::pin(async move { - if c == ConnectionState::Failed { - let mut f = wf_clone.lock().await; + struct AgentHandler { + wf: Arc>>, + wc: Arc>>, + } + + impl AgentEventHandler for AgentHandler { + fn on_connection_state_change( + &mut self, + state: ConnectionState, + ) -> impl Future + Send { + async move { + if state == ConnectionState::Failed { + let mut f = self.wf.lock().await; f.take(); - } else if c == ConnectionState::Checking { - let mut c = wc_clone.lock().await; + } else if state == ConnectionState::Checking { + let mut c = self.wc.lock().await; c.take(); - } else if c == ConnectionState::Connected || c == ConnectionState::Completed { - panic!("Unexpected ConnectionState: {c}"); + } else if state == ConnectionState::Connected || state == ConnectionState::Completed + { + panic!("Unexpected ConnectionState: {state}"); } - }) - }); - hdlr_fn - }; + } + } + } let (wf1, wc1) = (is_failed.worker(), is_checking.worker()); - a_agent.on_connection_state_change(connection_state_check(wf1, wc1)); + a_agent.with_event_handler(AgentHandler { + wf: Arc::new(Mutex::new(Some(wf1))), + wc: Arc::new(Mutex::new(Some(wc1))), + }); let (wf2, wc2) = (is_failed.worker(), is_checking.worker()); - b_agent.on_connection_state_change(connection_state_check(wf2, wc2)); + b_agent.with_event_handler(AgentHandler { + wf: Arc::new(Mutex::new(Some(wf2))), + wc: Arc::new(Mutex::new(Some(wc2))), + }); let agent_a = Arc::clone(&a_agent); tokio::spawn(async move { @@ -1824,15 +1881,24 @@ async fn test_agent_restart_one_side() -> Result<()> { let (cancel_tx, mut cancel_rx) = mpsc::channel::<()>(1); let cancel_tx = Arc::new(Mutex::new(Some(cancel_tx))); - agent_b.on_connection_state_change(Box::new(move |c: ConnectionState| { - let cancel_tx_clone = Arc::clone(&cancel_tx); - Box::pin(async move { - if c == ConnectionState::Failed || c == ConnectionState::Disconnected { - let mut tx = cancel_tx_clone.lock().await; - tx.take(); + + struct AgentHandler { + cancel_tx: Arc>>>, + } + + impl AgentEventHandler for AgentHandler { + fn on_connection_state_change( + &mut self, + state: ConnectionState, + ) -> impl Future + Send { + async move { + if state == ConnectionState::Failed || state == ConnectionState::Disconnected { + let mut tx = self.cancel_tx.lock().await; + tx.take(); + } } - }) - })); + } + } agent_a.restart("".to_owned(), "".to_owned()).await?; @@ -1886,10 +1952,10 @@ async fn test_agent_restart_both_side() -> Result<()> { generate_candidate_address_strings(agent_b.get_local_candidates().await); let (a_notifier, mut a_connected) = on_connected(); - agent_a.on_connection_state_change(a_notifier); + agent_a.with_event_handler(a_notifier); let (b_notifier, mut b_connected) = on_connected(); - agent_b.on_connection_state_change(b_notifier); + agent_b.with_event_handler(b_notifier); // Restart and Re-Signal agent_a.restart("".to_owned(), "".to_owned()).await?; @@ -1981,20 +2047,33 @@ async fn test_close_in_connection_state_callback() -> Result<()> { let (is_connected_tx, mut is_connected_rx) = mpsc::channel::<()>(1); let is_closed_tx = Arc::new(Mutex::new(Some(is_closed_tx))); let is_connected_tx = Arc::new(Mutex::new(Some(is_connected_tx))); - a_agent.on_connection_state_change(Box::new(move |c: ConnectionState| { - let is_closed_tx_clone = Arc::clone(&is_closed_tx); - let is_connected_tx_clone = Arc::clone(&is_connected_tx); - Box::pin(async move { - if c == ConnectionState::Connected { - let mut tx = is_connected_tx_clone.lock().await; - tx.take(); - } else if c == ConnectionState::Closed { - let mut tx = is_closed_tx_clone.lock().await; - tx.take(); + + struct AgentHandler { + is_closed_tx: Arc>>>, + is_connected_tx: Arc>>>, + } + + impl AgentEventHandler for AgentHandler { + fn on_connection_state_change( + &mut self, + state: ConnectionState, + ) -> impl Future + Send { + async move { + if state == ConnectionState::Connected { + let mut tx = self.is_connected_tx.lock().await; + tx.take(); + } else if state == ConnectionState::Closed { + let mut tx = self.is_closed_tx.lock().await; + tx.take(); + } } - }) - })); + } + } + a_agent.with_event_handler(AgentHandler { + is_closed_tx, + is_connected_tx, + }); connect_with_vnet(&a_agent, &b_agent).await?; let _ = is_connected_rx.recv().await; @@ -2036,15 +2115,24 @@ async fn test_run_task_in_connection_state_callback() -> Result<()> { let (is_complete_tx, mut is_complete_rx) = mpsc::channel::<()>(1); let is_complete_tx = Arc::new(Mutex::new(Some(is_complete_tx))); - a_agent.on_connection_state_change(Box::new(move |c: ConnectionState| { - let is_complete_tx_clone = Arc::clone(&is_complete_tx); - Box::pin(async move { - if c == ConnectionState::Connected { - let mut tx = is_complete_tx_clone.lock().await; - tx.take(); + struct AgentHandler { + is_complete_tx: Arc>>>, + } + + impl AgentEventHandler for AgentHandler { + fn on_connection_state_change( + &mut self, + state: ConnectionState, + ) -> impl Future + Send { + async move { + if state == ConnectionState::Connected { + let mut tx = self.is_complete_tx.lock().await; + tx.take(); + } } - }) - })); + } + } + a_agent.with_event_handler(AgentHandler { is_complete_tx }); connect_with_vnet(&a_agent, &b_agent).await?; @@ -2088,27 +2176,43 @@ async fn test_run_task_in_selected_candidate_pair_change_callback() -> Result<() let (is_tested_tx, mut is_tested_rx) = mpsc::channel::<()>(1); let is_tested_tx = Arc::new(Mutex::new(Some(is_tested_tx))); - a_agent.on_selected_candidate_pair_change(Box::new( - move |_: &Arc, _: &Arc| { - let is_tested_tx_clone = Arc::clone(&is_tested_tx); - Box::pin(async move { - let mut tx = is_tested_tx_clone.lock().await; - tx.take(); - }) - }, - )); - let (is_complete_tx, mut is_complete_rx) = mpsc::channel::<()>(1); let is_complete_tx = Arc::new(Mutex::new(Some(is_complete_tx))); - a_agent.on_connection_state_change(Box::new(move |c: ConnectionState| { - let is_complete_tx_clone = Arc::clone(&is_complete_tx); - Box::pin(async move { - if c == ConnectionState::Connected { - let mut tx = is_complete_tx_clone.lock().await; + + struct AgentHandler { + is_complete_tx: Arc>>>, + is_tested_tx: Arc>>>, + } + + impl AgentEventHandler for AgentHandler { + fn on_selected_candidate_pair_change( + &mut self, + _: Arc, + _: Arc, + ) -> impl Future + Send { + async move { + let mut tx = self.is_tested_tx.lock().await; tx.take(); } - }) - })); + } + + fn on_connection_state_change( + &mut self, + state: ConnectionState, + ) -> impl Future + Send { + async move { + if state == ConnectionState::Connected { + let mut tx = self.is_complete_tx.lock().await; + tx.take(); + } + } + } + } + + a_agent.with_event_handler(AgentHandler { + is_complete_tx, + is_tested_tx, + }); connect_with_vnet(&a_agent, &b_agent).await?; @@ -2137,7 +2241,7 @@ async fn test_lite_lifecycle() -> Result<()> { .await?, ); - a_agent.on_connection_state_change(a_notifier); + a_agent.with_event_handler(a_notifier); let disconnected_duration = Duration::from_secs(1); let failed_duration = Duration::from_secs(1); @@ -2165,24 +2269,31 @@ async fn test_lite_lifecycle() -> Result<()> { let b_disconnected_tx = Arc::new(Mutex::new(Some(b_disconnected_tx))); let b_failed_tx = Arc::new(Mutex::new(Some(b_failed_tx))); - b_agent.on_connection_state_change(Box::new(move |c: ConnectionState| { - let b_connected_tx_clone = Arc::clone(&b_connected_tx); - let b_disconnected_tx_clone = Arc::clone(&b_disconnected_tx); - let b_failed_tx_clone = Arc::clone(&b_failed_tx); + struct AgentHandler { + connected_tx: Arc>>>, + disconnected_tx: Arc>>>, + failed_tx: Arc>>>, + } - Box::pin(async move { - if c == ConnectionState::Connected { - let mut tx = b_connected_tx_clone.lock().await; - tx.take(); - } else if c == ConnectionState::Disconnected { - let mut tx = b_disconnected_tx_clone.lock().await; - tx.take(); - } else if c == ConnectionState::Failed { - let mut tx = b_failed_tx_clone.lock().await; - tx.take(); + impl AgentEventHandler for AgentHandler { + fn on_connection_state_change( + &mut self, + state: ConnectionState, + ) -> impl Future + Send { + async move { + if state == ConnectionState::Connected { + let mut tx = self.connected_tx.lock().await; + tx.take(); + } else if state == ConnectionState::Disconnected { + let mut tx = self.disconnected_tx.lock().await; + tx.take(); + } else if state == ConnectionState::Failed { + let mut tx = self.failed_tx.lock().await; + tx.take(); + } } - }) - })); + } + } connect_with_vnet(&b_agent, &a_agent).await?; diff --git a/ice/src/agent/agent_transport_test.rs b/ice/src/agent/agent_transport_test.rs index 8d4a8016f..bf5d5b59f 100644 --- a/ice/src/agent/agent_transport_test.rs +++ b/ice/src/agent/agent_transport_test.rs @@ -22,7 +22,7 @@ pub(crate) async fn pipe( cfg0.network_types = supported_network_types(); let a_agent = Arc::new(Agent::new(cfg0).await?); - a_agent.on_connection_state_change(a_notifier); + a_agent.with_event_handler(a_notifier); let mut cfg1 = if let Some(cfg) = default_config1 { cfg @@ -33,7 +33,7 @@ pub(crate) async fn pipe( cfg1.network_types = supported_network_types(); let b_agent = Arc::new(Agent::new(cfg1).await?); - b_agent.on_connection_state_change(b_notifier); + b_agent.with_event_handler(b_notifier); let (a_conn, b_conn) = connect_with_vnet(&a_agent, &b_agent).await?; diff --git a/ice/src/agent/agent_vnet_test.rs b/ice/src/agent/agent_vnet_test.rs index c461d83a9..7d716f7a5 100644 --- a/ice/src/agent/agent_vnet_test.rs +++ b/ice/src/agent/agent_vnet_test.rs @@ -8,7 +8,7 @@ use util::vnet::chunk::Chunk; use util::vnet::router::Nic; use util::vnet::*; use util::Conn; -use waitgroup::WaitGroup; +use waitgroup::{WaitGroup, Worker}; use super::*; use crate::candidate::candidate_base::unmarshal_candidate; @@ -311,7 +311,7 @@ pub(crate) async fn pipe_with_vnet( }; let a_agent = Arc::new(Agent::new(cfg0).await?); - a_agent.on_connection_state_change(a_notifier); + a_agent.with_event_handler(a_notifier); let nat_1to1_ips = if a1test_config.nat_1to1_ip_candidate_type != CandidateType::Unspecified { vec![VNET_GLOBAL_IPB.to_owned()] @@ -329,7 +329,7 @@ pub(crate) async fn pipe_with_vnet( }; let b_agent = Arc::new(Agent::new(cfg1).await?); - b_agent.on_connection_state_change(b_notifier); + b_agent.with_event_handler(b_notifier); let (a_conn, b_conn) = connect_with_vnet(&a_agent, &b_agent).await?; @@ -341,19 +341,29 @@ pub(crate) async fn pipe_with_vnet( Ok((a_conn, b_conn)) } -pub(crate) fn on_connected() -> (OnConnectionStateChangeHdlrFn, mpsc::Receiver<()>) { - let (done_tx, done_rx) = mpsc::channel::<()>(1); - let done_tx = Arc::new(Mutex::new(Some(done_tx))); - let hdlr_fn: OnConnectionStateChangeHdlrFn = Box::new(move |state: ConnectionState| { - let done_tx_clone = Arc::clone(&done_tx); - Box::pin(async move { +pub(crate) struct ConnectionStateNotifier { + done_tx: Arc>>>, +} + +impl AgentEventHandler for ConnectionStateNotifier { + fn on_connection_state_change( + &mut self, + state: ConnectionState, + ) -> impl Future + Send { + async move { if state == ConnectionState::Connected { - let mut tx = done_tx_clone.lock().await; + let mut tx = self.done_tx.lock().await; tx.take(); } - }) - }); - (hdlr_fn, done_rx) + } + } +} + +pub(crate) fn on_connected() -> (ConnectionStateNotifier, mpsc::Receiver<()>) { + let (done_tx, done_rx) = mpsc::channel::<()>(1); + let done_tx = Arc::new(Mutex::new(Some(done_tx))); + let handler = ConnectionStateNotifier { done_tx }; + (handler, done_rx) } pub(crate) async fn gather_and_exchange_candidates( @@ -362,32 +372,34 @@ pub(crate) async fn gather_and_exchange_candidates( ) -> Result<(), Error> { let wg = WaitGroup::new(); - let w1 = Arc::new(Mutex::new(Some(wg.worker()))); - a_agent.on_candidate(Box::new( - move |candidate: Option>| { - let w3 = Arc::clone(&w1); - Box::pin(async move { + struct CandidateHandler { + worker: Arc>>, + } + + impl AgentEventHandler for CandidateHandler { + fn on_candidate( + &mut self, + candidate: Option>, + ) -> impl Future + Send { + async move { if candidate.is_none() { - let mut w = w3.lock().await; - w.take(); + let mut worker = self.worker.lock().await; + worker.take(); } - }) - }, - )); + } + } + } + + let candidate_handler_1 = CandidateHandler { + worker: Arc::new(Mutex::new(Some(wg.worker()))), + }; + a_agent.with_event_handler(candidate_handler_1); a_agent.gather_candidates()?; - let w2 = Arc::new(Mutex::new(Some(wg.worker()))); - b_agent.on_candidate(Box::new( - move |candidate: Option>| { - let w3 = Arc::clone(&w2); - Box::pin(async move { - if candidate.is_none() { - let mut w = w3.lock().await; - w.take(); - } - }) - }, - )); + let candidate_handler_2 = CandidateHandler { + worker: Arc::new(Mutex::new(Some(wg.worker()))), + }; + b_agent.with_event_handler(candidate_handler_2); b_agent.gather_candidates()?; wg.wait().await; @@ -821,22 +833,34 @@ async fn test_disconnected_to_connected() -> Result<(), Error> { let (controlling_state_changes_tx, mut controlling_state_changes_rx) = mpsc::channel::(100); let controlling_state_changes_tx = Arc::new(controlling_state_changes_tx); - controlling_agent.on_connection_state_change(Box::new(move |c: ConnectionState| { - let controlling_state_changes_tx_clone = Arc::clone(&controlling_state_changes_tx); - Box::pin(async move { - let _ = controlling_state_changes_tx_clone.try_send(c); - }) - })); let (controlled_state_changes_tx, mut controlled_state_changes_rx) = mpsc::channel::(100); let controlled_state_changes_tx = Arc::new(controlled_state_changes_tx); - controlled_agent.on_connection_state_change(Box::new(move |c: ConnectionState| { - let controlled_state_changes_tx_clone = Arc::clone(&controlled_state_changes_tx); - Box::pin(async move { - let _ = controlled_state_changes_tx_clone.try_send(c); - }) - })); + struct AgentStateHandler { + changes_tx: Arc>, + } + + impl AgentEventHandler for AgentStateHandler { + fn on_connection_state_change( + &mut self, + state: ConnectionState, + ) -> impl Future + Send { + async move { + let _ = self.changes_tx.try_send(state); + } + } + } + + let controlling_state_handler = AgentStateHandler { + changes_tx: controlling_state_changes_tx, + }; + let controlled_state_handler = AgentStateHandler { + changes_tx: controlled_state_changes_tx, + }; + + controlling_agent.with_event_handler(controlling_state_handler); + controlled_agent.with_event_handler(controlled_state_handler); connect_with_vnet(&controlling_agent, &controlled_agent).await?; diff --git a/ice/src/agent/mod.rs b/ice/src/agent/mod.rs index d61b392e4..72c3662f9 100644 --- a/ice/src/agent/mod.rs +++ b/ice/src/agent/mod.rs @@ -35,7 +35,7 @@ use stun::xoraddr::*; use tokio::sync::{broadcast, mpsc, Mutex}; use tokio::time::{Duration, Instant}; use util::vnet::net::*; -use util::Buffer; +use util::{Buffer, EventHandler, FutureUnit}; use crate::agent::agent_gather::GatherCandidatesInternalParams; use crate::candidate::*; @@ -119,6 +119,84 @@ pub struct Agent { pub(crate) gather_candidate_cancel: Option, } +pub trait AgentEventHandler: Send { + /// Sets a handler that is fired when new candidates gathered. When the gathering process + /// complete the last candidate is None. + fn on_candidate( + &mut self, + candidate: Option>, + ) -> impl Future + Send { + async {} + } + + /// Sets a handler that is fired when the connection state changes. + fn on_connection_state_change( + &mut self, + connection_state: ConnectionState, + ) -> impl Future + Send { + async {} + } + + /// Sets a handler that is fired when the final candidate pair is selected. + fn on_selected_candidate_pair_change( + &mut self, + local_candidate: Arc, + remote_candidate: Arc, + ) -> impl Future + Send { + async {} + } +} + +pub trait InlineAgentEventHandler: Send { + fn inline_on_candidate( + &mut self, + candidate: Option>, + ) -> FutureUnit<'_>; + + fn inline_on_connection_state_change( + &mut self, + connection_state: ConnectionState, + ) -> FutureUnit<'_>; + + fn inline_on_selected_candidate_pair_change( + &mut self, + local_candidate: Arc, + remote_candidate: Arc, + ) -> FutureUnit<'_>; +} + +impl InlineAgentEventHandler for T +where + T: AgentEventHandler, +{ + fn inline_on_candidate( + &mut self, + candidate: Option>, + ) -> FutureUnit<'_> { + FutureUnit::from_async(async move { self.on_candidate(candidate).await }) + } + + fn inline_on_connection_state_change( + &mut self, + connection_state: ConnectionState, + ) -> FutureUnit<'_> { + FutureUnit::from_async( + async move { self.on_connection_state_change(connection_state).await }, + ) + } + + fn inline_on_selected_candidate_pair_change( + &mut self, + local_candidate: Arc, + remote_candidate: Arc, + ) -> FutureUnit<'_> { + FutureUnit::from_async(async move { + self.on_selected_candidate_pair_change(local_candidate, remote_candidate) + .await + }) + } +} + impl Agent { /// Creates a new Agent. pub async fn new(config: AgentConfig) -> Result { @@ -195,7 +273,7 @@ impl Agent { Arc::new(Net::new(None)) }; - let agent = Self { + let mut agent = Self { udp_network: config.udp_network, internal: Arc::new(ai), interface_filter: Arc::clone(&config.interface_filter), @@ -237,26 +315,8 @@ impl Agent { self.internal.agent_conn.bytes_sent() } - /// Sets a handler that is fired when the connection state changes. - pub fn on_connection_state_change(&self, f: OnConnectionStateChangeHdlrFn) { - self.internal - .on_connection_state_change_hdlr - .store(Some(Arc::new(Mutex::new(f)))) - } - - /// Sets a handler that is fired when the final candidate pair is selected. - pub fn on_selected_candidate_pair_change(&self, f: OnSelectedCandidatePairChangeHdlrFn) { - self.internal - .on_selected_candidate_pair_change_hdlr - .store(Some(Arc::new(Mutex::new(f)))) - } - - /// Sets a handler that is fired when new candidates gathered. When the gathering process - /// complete the last candidate is nil. - pub fn on_candidate(&self, f: OnCandidateHdlrFn) { - self.internal - .on_candidate_hdlr - .store(Some(Arc::new(Mutex::new(f)))); + pub fn with_event_handler(&self, handler: impl AgentEventHandler + Send + Sync + 'static) { + self.internal.events_handler.store(Box::new(handler)); } /// Adds a new remote candidate. @@ -440,7 +500,7 @@ impl Agent { return Err(Error::ErrMultipleGatherAttempted); } - if self.internal.on_candidate_hdlr.load().is_none() { + if self.internal.events_handler.load().is_none() { return Err(Error::ErrNoOnCandidateHandler); } diff --git a/ice/src/candidate/candidate_relay_test.rs b/ice/src/candidate/candidate_relay_test.rs index c1fd4bfd0..7cb4afc44 100644 --- a/ice/src/candidate/candidate_relay_test.rs +++ b/ice/src/candidate/candidate_relay_test.rs @@ -81,7 +81,7 @@ async fn test_relay_only_connection() -> Result<(), Error> { let a_agent = Arc::new(Agent::new(cfg0).await?); let (a_notifier, mut a_connected) = on_connected(); - a_agent.on_connection_state_change(a_notifier); + a_agent.with_event_handler(a_notifier); let cfg1 = AgentConfig { network_types: supported_network_types(), @@ -99,7 +99,7 @@ async fn test_relay_only_connection() -> Result<(), Error> { let b_agent = Arc::new(Agent::new(cfg1).await?); let (b_notifier, mut b_connected) = on_connected(); - b_agent.on_connection_state_change(b_notifier); + b_agent.with_event_handler(b_notifier); connect_with_vnet(&a_agent, &b_agent).await?; diff --git a/ice/src/candidate/candidate_server_reflexive_test.rs b/ice/src/candidate/candidate_server_reflexive_test.rs index ca40de1a7..083b6fd45 100644 --- a/ice/src/candidate/candidate_server_reflexive_test.rs +++ b/ice/src/candidate/candidate_server_reflexive_test.rs @@ -60,7 +60,7 @@ async fn test_server_reflexive_only_connection() -> Result<()> { let a_agent = Arc::new(Agent::new(cfg0).await?); let (a_notifier, mut a_connected) = on_connected(); - a_agent.on_connection_state_change(a_notifier); + a_agent.with_event_handler(a_notifier); let cfg1 = AgentConfig { network_types: vec![NetworkType::Udp4], @@ -76,7 +76,7 @@ async fn test_server_reflexive_only_connection() -> Result<()> { let b_agent = Arc::new(Agent::new(cfg1).await?); let (b_notifier, mut b_connected) = on_connected(); - b_agent.on_connection_state_change(b_notifier); + b_agent.with_event_handler(b_notifier); connect_with_vnet(&a_agent, &b_agent).await?; diff --git a/ice/src/mdns/mdns_test.rs b/ice/src/mdns/mdns_test.rs index 604010390..e8071505f 100644 --- a/ice/src/mdns/mdns_test.rs +++ b/ice/src/mdns/mdns_test.rs @@ -1,4 +1,5 @@ use regex::Regex; +use std::future::Future; use tokio::sync::{mpsc, Mutex}; use super::*; @@ -25,7 +26,7 @@ async fn test_multicast_dns_only_connection() -> Result<()> { let a_agent = Arc::new(Agent::new(cfg0).await?); let (a_notifier, mut a_connected) = on_connected(); - a_agent.on_connection_state_change(a_notifier); + a_agent.with_event_handler(a_notifier); let cfg1 = AgentConfig { network_types: vec![NetworkType::Udp4], @@ -36,7 +37,7 @@ async fn test_multicast_dns_only_connection() -> Result<()> { let b_agent = Arc::new(Agent::new(cfg1).await?); let (b_notifier, mut b_connected) = on_connected(); - b_agent.on_connection_state_change(b_notifier); + b_agent.with_event_handler(b_notifier); connect_with_vnet(&a_agent, &b_agent).await?; let _ = a_connected.recv().await; @@ -59,7 +60,7 @@ async fn test_multicast_dns_mixed_connection() -> Result<()> { let a_agent = Arc::new(Agent::new(cfg0).await?); let (a_notifier, mut a_connected) = on_connected(); - a_agent.on_connection_state_change(a_notifier); + a_agent.with_event_handler(a_notifier); let cfg1 = AgentConfig { network_types: vec![NetworkType::Udp4], @@ -70,7 +71,7 @@ async fn test_multicast_dns_mixed_connection() -> Result<()> { let b_agent = Arc::new(Agent::new(cfg1).await?); let (b_notifier, mut b_connected) = on_connected(); - b_agent.on_connection_state_change(b_notifier); + b_agent.with_event_handler(b_notifier); connect_with_vnet(&a_agent, &b_agent).await?; let _ = a_connected.recv().await; @@ -109,17 +110,26 @@ async fn test_multicast_dns_static_host_name() -> Result<()> { let (done_tx, mut done_rx) = mpsc::channel::<()>(1); let done_tx = Arc::new(Mutex::new(Some(done_tx))); - a.on_candidate(Box::new( - move |c: Option>| { - let done_tx_clone = Arc::clone(&done_tx); - Box::pin(async move { - if c.is_none() { - let mut tx = done_tx_clone.lock().await; + + struct CandidateHandler { + done_tx: Arc>>>, + } + + impl AgentEventHandler for CandidateHandler { + fn on_candidate( + &mut self, + candidate: Option>, + ) -> impl Future + Send { + async move { + if candidate.is_none() { + let mut tx = self.done_tx.lock().await; tx.take(); } - }) - }, - )); + } + } + } + + a.with_event_handler(CandidateHandler { done_tx }); a.gather_candidates()?; diff --git a/util/Cargo.toml b/util/Cargo.toml index 5054363c3..9b3207e04 100644 --- a/util/Cargo.toml +++ b/util/Cargo.toml @@ -27,6 +27,7 @@ log = "0.4" rand = "0.8" bytes = "1" thiserror = "1" +arc-swap = "1" [target.'cfg(not(windows))'.dependencies] nix = "0.26.2" diff --git a/util/src/event_handler.rs b/util/src/event_handler.rs new file mode 100644 index 000000000..717a61099 --- /dev/null +++ b/util/src/event_handler.rs @@ -0,0 +1,62 @@ +use arc_swap::{ArcSwapOption, Guard}; +use std::sync::Arc; +use tokio::sync::Mutex; + +pub struct EventHandler { + // FIXME: it would be preferred if we didnt have to double allocate here + // (type is ArcSwapAny>>>) but since ArcSwaps implementation uses an + // AtomicPtr (which does not support unsized types as there is not language support for atomic + // operations larger than a word), it has to be sized for now. + // + // I(jumbeldliam) am also unsure if it would be necessary to include the Arc as all of the implemented + // fields are Arc'd anyway, and I dont think we need both + inner: ArcSwapOption>>, +} + +impl EventHandler { + pub fn empty() -> Self { + Self { + inner: ArcSwapOption::empty(), + } + } + + pub fn with_handler(handler: Box) -> Self { + Self { + inner: Some(Arc::new(Mutex::new(handler))).into(), + } + } + + pub fn load(&self) -> Guard>>>> { + //FIXME: if there was a way to get a MutexGuard<'_, T> instead of + //having what we have now that would be great + self.inner.load() + } + + pub fn store(&self, handler: Box) { + self.inner.store(Some(Arc::new(Mutex::new(handler)))) + } + + pub fn swap(&mut self, handle: Box) -> Option>>> { + self.inner.swap(Some(Arc::new(Mutex::new(handle)))) + } +} + +impl Default for EventHandler { + fn default() -> Self { + Self::empty() + } +} + +mod test { + use super::*; + struct T { + a: EventHandler, + } + + impl T { + fn new(val: impl Send + Sync + 'static) -> Self { + let a: EventHandler = EventHandler::with_handler(Box::new(val)); + Self { a } + } + } +} diff --git a/util/src/future.rs b/util/src/future.rs new file mode 100644 index 000000000..8080932d3 --- /dev/null +++ b/util/src/future.rs @@ -0,0 +1,50 @@ +use core::future::Future; +use core::pin::Pin; +use core::ptr::NonNull; +use core::task::{Context, Poll}; + +#[repr(transparent)] +pub struct FutureUnit<'a> { + inner: NonNull + Send + 'a>, +} + +unsafe impl Send for FutureUnit<'_> {} + +impl<'a> FutureUnit<'a> { + pub fn from_async(async_fn: impl Future + Send + 'a) -> Self { + //FIXME: optimistically should be non-heap allocated as they should each contain only + //one byte. I(jumbeldliam) would prefer to have a more ergonomic api upfront which can be + //changed later (and as to why I made inner NonNull rather than Box) + let boxed: Box + Send + 'a> = Box::new(async_fn); + let boxed = Box::into_raw(boxed); + + // SAFETY: Box::into_raw always returns a valid ptr + let inner = unsafe { NonNull::new_unchecked(boxed) }; + FutureUnit { inner } + } +} + +impl Future for FutureUnit<'_> { + type Output = (); + + fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<()> { + let mut inner = Pin::into_inner(self).inner; + // SAFETY: the pin has not been moved so we have not moved out of the ptr + let inner_pin = unsafe { Pin::new_unchecked(inner.as_mut()) }; + inner_pin.poll(cx) + } +} + +impl Unpin for FutureUnit<'_> {} + +impl Drop for FutureUnit<'_> { + fn drop(&mut self) { + unsafe { + // SAFETY: the pointer is still valid + // so it is okay to construct a box out of it + drop(Box::from_raw(self.inner.as_ptr())) + } + } +} + +//TODO: proc macro for inline trait variants diff --git a/util/src/lib.rs b/util/src/lib.rs index b149bbb50..2be722b60 100644 --- a/util/src/lib.rs +++ b/util/src/lib.rs @@ -14,9 +14,14 @@ extern crate lazy_static; #[macro_use] extern crate bitflags; +pub mod event_handler; pub mod fixed_big_int; +pub mod future; pub mod replay_detector; +pub use event_handler::EventHandler; +pub use future::FutureUnit; + /// KeyingMaterialExporter to extract keying material. /// /// This trait sits here to avoid getting a direct dependency between diff --git a/webrtc/src/data_channel/data_channel_test.rs b/webrtc/src/data_channel/data_channel_test.rs index 65e8eb9f1..5290121e0 100644 --- a/webrtc/src/data_channel/data_channel_test.rs +++ b/webrtc/src/data_channel/data_channel_test.rs @@ -9,6 +9,7 @@ use waitgroup::WaitGroup; use super::*; use crate::api::media_engine::MediaEngine; use crate::api::{APIBuilder, API}; +use crate::data::data_channel::DataChannel; use crate::data_channel::data_channel_init::RTCDataChannelInit; //use log::LevelFilter; //use std::io::Write; @@ -17,14 +18,17 @@ use crate::dtls_transport::RTCDtlsTransport; use crate::error::flatten_errs; use crate::ice_transport::ice_candidate::RTCIceCandidate; use crate::ice_transport::ice_connection_state::RTCIceConnectionState; -use crate::ice_transport::ice_gatherer::{RTCIceGatherOptions, RTCIceGatherer}; +use crate::ice_transport::ice_gatherer::{ + IceGathererEventHandler, RTCIceGatherOptions, RTCIceGatherer, +}; use crate::ice_transport::ice_parameters::RTCIceParameters; use crate::ice_transport::ice_role::RTCIceRole; use crate::ice_transport::RTCIceTransport; use crate::peer_connection::configuration::RTCConfiguration; use crate::peer_connection::peer_connection_test::*; -use crate::peer_connection::RTCPeerConnection; +use crate::peer_connection::{PeerConnectionEventHandler, RTCPeerConnection}; use crate::sctp_transport::sctp_transport_capabilities::SCTPTransportCapabilities; +use crate::sctp_transport::SctpTransportEventHandler; // EXPECTED_LABEL represents the label of the data channel we are trying to test. // Some other channels may have been created during initialization (in the Wasm @@ -140,41 +144,75 @@ async fn test_data_channel_open() -> Result<()> { let open_calls_tx = Arc::new(open_calls_tx); let done_tx = Arc::new(done_tx); - answer_pc.on_data_channel(Box::new(move |d: Arc| { - if d.label() == EXPECTED_LABEL { - let open_calls_tx2 = Arc::clone(&open_calls_tx); - let done_tx2 = Arc::clone(&done_tx); - Box::pin(async move { - d.on_open(Box::new(move || { - Box::pin(async move { - let _ = open_calls_tx2.send(()).await; - }) - })); - d.on_message(Box::new(move |_: DataChannelMessage| { - let done_tx3 = Arc::clone(&done_tx2); + + struct ChannelHandler { + open_calls_tx: Arc>, + done_tx: Arc>, + } + + impl PeerConnectionEventHandler for ChannelHandler { + fn on_data_channel( + &mut self, + channel: Arc, + ) -> impl Future + Send { + struct DataHandler { + open_calls_tx: Arc>, + done_tx: Arc>, + } + + impl RTCDataChannelEventHandler for DataHandler { + fn on_open(&mut self) -> impl Future + Send { + async move { + let _ = self.open_calls_tx.send(()).await; + } + } + + fn on_message( + &mut self, + _: DataChannelMessage, + ) -> impl Future + Send { + let done_tx = self.done_tx.clone(); tokio::spawn(async move { // Wait a little bit to ensure all messages are processed. tokio::time::sleep(Duration::from_millis(100)).await; - let _ = done_tx3.send(()).await; + let _ = done_tx.send(()).await; }); - Box::pin(async {}) - })); - }) - } else { - Box::pin(async {}) + async {} + } + } + + async move { + if channel.label() != EXPECTED_LABEL { + return; + } + channel.with_event_handler(DataHandler { + open_calls_tx: self.open_calls_tx.clone(), + done_tx: self.done_tx.clone(), + }) + } } - })); + } + answer_pc.with_event_handler(ChannelHandler { + open_calls_tx, + done_tx, + }); let dc = offer_pc.create_data_channel(EXPECTED_LABEL, None).await?; let dc2 = Arc::clone(&dc); - dc.on_open(Box::new(move || { - Box::pin(async move { - let result = dc2.send_text("Ping".to_owned()).await; - assert!(result.is_ok(), "Failed to send string on data channel"); - }) - })); + struct DataChannelHandler { + channel: Arc, + } + + impl RTCDataChannelEventHandler for DataChannelHandler { + fn on_open(&mut self) -> impl Future + Send { + async move { + let result = self.channel.send_text("Ping".to_owned()).await; + assert!(result.is_ok(), "Failed to send string on data channel"); + } + } + } signal_pair(&mut offer_pc, &mut answer_pc).await?; close_pair(&offer_pc, &answer_pc, done_rx).await; @@ -195,47 +233,74 @@ async fn test_data_channel_send_before_signaling() -> Result<()> { let (mut offer_pc, mut answer_pc) = new_pair(&api).await?; - answer_pc.on_data_channel(Box::new(move |d: Arc| { - // Make sure this is the data channel we were looking for. (Not the one - // created in signalPair). - if d.label() != EXPECTED_LABEL { - return Box::pin(async {}); + impl PeerConnectionEventHandler for ChannelHandler { + fn on_data_channel( + &mut self, + channel: Arc, + ) -> impl Future + Send { + async move { + // Make sure this is the data channel we were looking for. (Not the one + // created in signalPair). + if channel.label() != EXPECTED_LABEL { + return; + } + + struct MessageHandler { + channel: Arc, + } + + impl RTCDataChannelEventHandler for MessageHandler { + fn on_message( + &mut self, + _: DataChannelMessage, + ) -> impl Future + Send { + async move { + let result = self.channel.send(&Bytes::from(b"Pong".to_vec())).await; + assert!(result.is_ok(), "Failed to send string on data channel"); + } + } + } + + channel.with_event_handler(MessageHandler { + channel: channel.clone(), + }); + + assert!(channel.ordered(), "Ordered should be set to true"); + } } - Box::pin(async move { - let d2 = Arc::clone(&d); - d.on_message(Box::new(move |_: DataChannelMessage| { - let d3 = Arc::clone(&d2); - Box::pin(async move { - let result = d3.send(&Bytes::from(b"Pong".to_vec())).await; - assert!(result.is_ok(), "Failed to send string on data channel"); - }) - })); - assert!(d.ordered(), "Ordered should be set to true"); - }) - })); + } let dc = offer_pc.create_data_channel(EXPECTED_LABEL, None).await?; assert!(dc.ordered(), "Ordered should be set to true"); - let dc2 = Arc::clone(&dc); - dc.on_open(Box::new(move || { - let dc3 = Arc::clone(&dc2); - Box::pin(async move { - let result = dc3.send_text("Ping".to_owned()).await; - assert!(result.is_ok(), "Failed to send string on data channel"); - }) - })); + struct ChannelHandler { + channel: Arc, + done_tx: Arc>>>, + } + + impl RTCDataChannelEventHandler for ChannelHandler { + fn on_open(&mut self) -> impl Future + Send { + async move { + let result = self.channel.send_text("Ping".to_owned()).await; + assert!(result.is_ok(), "Failed to send string on data channel"); + } + } + + fn on_message(&mut self, _: DataChannelMessage) -> impl Future + Send { + async move { + let mut done = self.done_tx.lock().await; + done.take(); + } + } + } let (done_tx, done_rx) = mpsc::channel::<()>(1); let done_tx = Arc::new(Mutex::new(Some(done_tx))); - dc.on_message(Box::new(move |_: DataChannelMessage| { - let done_tx2 = Arc::clone(&done_tx); - Box::pin(async move { - let mut done = done_tx2.lock().await; - done.take(); - }) - })); + dc.with_event_handler(ChannelHandler { + channel: dc.clone(), + done_tx, + }); signal_pair(&mut offer_pc, &mut answer_pc).await?; @@ -251,70 +316,133 @@ async fn test_data_channel_send_after_connected() -> Result<()> { let (mut offer_pc, mut answer_pc) = new_pair(&api).await?; - answer_pc.on_data_channel(Box::new(move |d: Arc| { - // Make sure this is the data channel we were looking for. (Not the one - // created in signalPair). - if d.label() != EXPECTED_LABEL { - return Box::pin(async {}); - } - Box::pin(async move { - let d2 = Arc::clone(&d); - d.on_message(Box::new(move |_: DataChannelMessage| { - let d3 = Arc::clone(&d2); + struct ConnectionHandler { + channel: Arc, + done_tx: Arc>>>, + } - Box::pin(async move { - let result = d3.send(&Bytes::from(b"Pong".to_vec())).await; - assert!(result.is_ok(), "Failed to send string on data channel"); - }) - })); - assert!(d.ordered(), "Ordered should be set to true"); - }) - })); + impl PeerConnectionEventHandler for ConnectionHandler { + fn on_data_channel( + &mut self, + channel: Arc, + ) -> impl Future + Send { + async move { + if channel.label() != EXPECTED_LABEL { + return; + } - let dc = offer_pc - .create_data_channel(EXPECTED_LABEL, None) - .await - .expect("Failed to create a PC pair for testing"); + struct MessageHandler { + channel: Arc, + } - let (done_tx, done_rx) = mpsc::channel::<()>(1); - let done_tx = Arc::new(Mutex::new(Some(done_tx))); + impl RTCDataChannelEventHandler for MessageHandler { + fn on_message( + &mut self, + _: DataChannelMessage, + ) -> impl Future + Send { + async move { + let result = self.channel.send(&Bytes::from(b"Pong".to_vec())).await; + assert!(result.is_ok(), "Failed to send string on data channel"); + } + } + } + channel.with_event_handler(MessageHandler { + channel: channel.clone(), + }); + assert!(channel.ordered(), "Ordered should be set to true"); + } + } - //once := &sync.Once{} - offer_pc.on_ice_connection_state_change(Box::new(move |state: RTCIceConnectionState| { - let done_tx1 = Arc::clone(&done_tx); - let dc1 = Arc::clone(&dc); - Box::pin(async move { - if state == RTCIceConnectionState::Connected - || state == RTCIceConnectionState::Completed - { + fn on_ice_connection_state_change( + &mut self, + state: RTCIceConnectionState, + ) -> impl Future + Send { + async move { + if state != RTCIceConnectionState::Connected + && state != RTCIceConnectionState::Completed + { + return; + } // wasm fires completed state multiple times /*once.Do(func()*/ { - assert!(dc1.ordered(), "Ordered should be set to true"); + assert!(self.channel.ordered(), "Ordered should be set to true"); - dc1.on_message(Box::new(move |_: DataChannelMessage| { - let done_tx2 = Arc::clone(&done_tx1); - Box::pin(async move { - let mut done = done_tx2.lock().await; + struct MessageHandler { + done_tx: Arc>>>, + } + + fn take_tx<'a>( + tx: &'a Arc>>>, + ) -> impl Future + Send + 'a { + async move { + let mut done = tx.lock().await; done.take(); - }) - })); + } + } + + impl RTCDataChannelEventHandler for MessageHandler { + fn on_message( + &mut self, + _: DataChannelMessage, + ) -> impl Future + Send { + take_tx(&self.done_tx) + } + } + + self.channel.with_event_handler(MessageHandler { + done_tx: self.done_tx.clone(), + }); - if dc1.send_text("Ping".to_owned()).await.is_err() { + if self.channel.send_text("Ping".to_owned()).await.is_err() { // wasm binding doesn't fire OnOpen (we probably already missed it) - let dc2 = Arc::clone(&dc1); - dc1.on_open(Box::new(move || { - let dc3 = Arc::clone(&dc2); - Box::pin(async move { - let result = dc3.send_text("Ping".to_owned()).await; - assert!(result.is_ok(), "Failed to send string on data channel"); - }) - })); + struct LateMessageHandler { + channel: Arc, + done_tx: Arc>>>, + } + + impl RTCDataChannelEventHandler for LateMessageHandler { + fn on_message( + &mut self, + _: DataChannelMessage, + ) -> impl Future + Send { + take_tx(&self.done_tx) + } + + fn on_open(&mut self) -> impl Future + Send { + async move { + let result = self.channel.send_text("Ping".to_owned()).await; + assert!( + result.is_ok(), + "Failed to send string on data channel" + ); + } + } + } + + self.channel.with_event_handler(LateMessageHandler { + channel: self.channel.clone(), + done_tx: self.done_tx.clone(), + }); } } } - }) - })); + } + } + + let dc = offer_pc + .create_data_channel(EXPECTED_LABEL, None) + .await + .expect("Failed to create a PC pair for testing"); + + let (done_tx, done_rx) = mpsc::channel::<()>(1); + let done_tx = Arc::new(Mutex::new(Some(done_tx))); + + //once := &sync.Once{} + offer_pc.with_event_handler(ConnectionHandler { + channel: dc.clone(), + done_tx, + }); signal_pair(&mut offer_pc, &mut answer_pc).await?; @@ -382,28 +510,40 @@ async fn test_data_channel_parameters_max_packet_life_time_exchange() -> Result< ); let done_tx = Arc::new(Mutex::new(Some(done_tx))); - answer_pc.on_data_channel(Box::new(move |d: Arc| { - if d.label() != EXPECTED_LABEL { - return Box::pin(async {}); - } - // Check if parameters are correctly set - assert_eq!( - d.ordered(), - ordered, - "Ordered should be same value as set in DataChannelInit" - ); - assert_eq!( - d.max_packet_lifetime(), - max_packet_life_time, - "should match" - ); - let done_tx2 = Arc::clone(&done_tx); - Box::pin(async move { - let mut done = done_tx2.lock().await; - done.take(); - }) - })); + struct ConnectionHandler { + expected_ordered: bool, + expected_max_packet_lifetime: u16, + done_tx: Arc>>>, + } + + impl PeerConnectionEventHandler for ConnectionHandler { + fn on_data_channel( + &mut self, + channel: Arc, + ) -> impl Future + Send { + async move { + if channel.label() != EXPECTED_LABEL { + return; + } + + // Check if parameters are correctly set + assert_eq!( + channel.ordered(), + self.expected_ordered, + "Ordered should be same value as set in DataChannelInit" + ); + assert_eq!( + channel.max_packet_lifetime(), + self.expected_max_packet_lifetime, + "should match" + ); + + let mut done = self.done_tx.lock().await; + done.take(); + } + } + } close_reliability_param_test(&mut offer_pc, &mut answer_pc, done_rx).await?; Ok(()) @@ -431,22 +571,39 @@ async fn test_data_channel_parameters_max_retransmits_exchange() -> Result<()> { assert_eq!(dc.max_retransmits(), max_retransmits, "should match"); let done_tx = Arc::new(Mutex::new(Some(done_tx))); - answer_pc.on_data_channel(Box::new(move |d: Arc| { - // Make sure this is the data channel we were looking for. (Not the one - // created in signalPair). - if d.label() != EXPECTED_LABEL { - return Box::pin(async {}); + + struct ConnectionHandler { + done_tx: Arc>>>, + expected_max_retransmits: u16, + } + + impl PeerConnectionEventHandler for ConnectionHandler { + fn on_data_channel( + &mut self, + channel: Arc, + ) -> impl Future + Send { + async move { + // Make sure this is the data channel we were looking for. (Not the one + // created in signalPair). + if channel.label() != EXPECTED_LABEL { + return; + } + assert!(!channel.ordered(), "Ordered should be set to false"); + assert_eq!( + self.expected_max_retransmits, + channel.max_retransmits(), + "should match" + ); + let mut done = self.done_tx.lock().await; + done.take(); + } } + } - // Check if parameters are correctly set - assert!(!d.ordered(), "Ordered should be set to false"); - assert_eq!(max_retransmits, d.max_retransmits(), "should match"); - let done_tx2 = Arc::clone(&done_tx); - Box::pin(async move { - let mut done = done_tx2.lock().await; - done.take(); - }) - })); + answer_pc.with_event_handler(ConnectionHandler { + done_tx, + expected_max_retransmits: max_retransmits, + }); close_reliability_param_test(&mut offer_pc, &mut answer_pc, done_rx).await?; @@ -476,25 +633,36 @@ async fn test_data_channel_parameters_protocol_exchange() -> Result<()> { ); let done_tx = Arc::new(Mutex::new(Some(done_tx))); - answer_pc.on_data_channel(Box::new(move |d: Arc| { - // Make sure this is the data channel we were looking for. (Not the one - // created in signalPair). - if d.label() != EXPECTED_LABEL { - return Box::pin(async {}); - } - // Check if parameters are correctly set - assert_eq!( - protocol, - d.protocol(), - "Protocol should match what channel creator declared" - ); - let done_tx2 = Arc::clone(&done_tx); - Box::pin(async move { - let mut done = done_tx2.lock().await; - done.take(); - }) - })); + struct ConnectionHandler { + expected_protocol: String, + done_tx: Arc>>>, + } + + impl PeerConnectionEventHandler for ConnectionHandler { + fn on_data_channel( + &mut self, + channel: Arc, + ) -> impl Future + Send { + async move { + // Make sure this is the data channel we were looking for. (Not the one + // created in signalPair). + if channel.label() != EXPECTED_LABEL { + return; + } + + // Check if parameters are correctly set + assert_eq!( + self.expected_protocol, + channel.protocol(), + "Protocol should match what channel creator declared" + ); + + let mut done = self.done_tx.lock().await; + done.take(); + } + } + } close_reliability_param_test(&mut offer_pc, &mut answer_pc, done_rx).await?; @@ -522,37 +690,58 @@ async fn test_data_channel_parameters_negotiated_exchange() -> Result<()> { .create_data_channel(EXPECTED_LABEL, Some(options)) .await?; - answer_pc.on_data_channel(Box::new(move |d: Arc| { - // Ignore our default channel, exists to force ICE candidates. See signalPair for more info - if d.label() == "initial_data_channel" { - return Box::pin(async {}); + struct AnswerChannelHandler; + + impl PeerConnectionEventHandler for AnswerChannelHandler { + fn on_data_channel( + &mut self, + channel: Arc, + ) -> impl Future + Send { + async move { + // Ignore our default channel, exists to force ICE candidates. See signalPair for more info + if channel.label() == "initial_data_channel" { + return; + } + panic!("OnDataChannel must not be fired when negotiated == true"); + } } - panic!("OnDataChannel must not be fired when negotiated == true"); - })); + } - offer_pc.on_data_channel(Box::new(move |_d: Arc| { - panic!("OnDataChannel must not be fired when negotiated == true"); - })); + struct OfferChannelHandler; + + impl PeerConnectionEventHandler for OfferChannelHandler { + fn on_data_channel(&mut self, _: Arc) -> impl Future + Send { + async { + panic!("OnDataChannel must not be fired when negotiated == true"); + } + } + } + + answer_pc.with_event_handler(AnswerChannelHandler); + offer_pc.with_event_handler(OfferChannelHandler); let seen_answer_message = Arc::new(AtomicBool::new(false)); let seen_offer_message = Arc::new(AtomicBool::new(false)); - let seen_answer_message2 = Arc::clone(&seen_answer_message); - answer_datachannel.on_message(Box::new(move |msg: DataChannelMessage| { - if msg.is_string && msg.data == EXPECTED_MESSAGE { - seen_answer_message2.store(true, Ordering::SeqCst); - } - - Box::pin(async {}) - })); + struct DataHandler { + seen_message: Arc, + } - let seen_offer_message2 = Arc::clone(&seen_offer_message); - offer_datachannel.on_message(Box::new(move |msg: DataChannelMessage| { - if msg.is_string && msg.data == EXPECTED_MESSAGE { - seen_offer_message2.store(true, Ordering::SeqCst); + impl RTCDataChannelEventHandler for DataHandler { + fn on_message(&mut self, msg: DataChannelMessage) -> impl Future + Send { + if msg.is_string && msg.data == EXPECTED_MESSAGE { + self.seen_message.store(true, Ordering::SeqCst); + } + async {} } - Box::pin(async {}) - })); + } + + answer_datachannel.with_event_handler(DataHandler { + seen_message: seen_answer_message.clone(), + }); + offer_datachannel.with_event_handler(DataHandler { + seen_message: seen_offer_message.clone(), + }); let done_tx = Arc::new(Mutex::new(Some(done_tx))); tokio::spawn(async move { @@ -604,26 +793,36 @@ async fn test_data_channel_event_handlers() -> Result<()> { dc.do_open(); let on_open_called_tx = Arc::new(Mutex::new(Some(on_open_called_tx))); - dc.on_open(Box::new(move || { - let on_open_called_tx2 = Arc::clone(&on_open_called_tx); - Box::pin(async move { - let mut done = on_open_called_tx2.lock().await; - done.take(); - }) - })); + struct ChannelHandler { + done_tx: Arc>>>, + on_message_called_tx: Arc>>>, + } + + impl RTCDataChannelEventHandler for ChannelHandler { + fn on_open(&mut self) -> impl Future + Send { + async move { + let mut done = self.done_tx.lock().await; + done.take(); + } + } + + fn on_message(&mut self, _: DataChannelMessage) -> impl Future + Send { + async move { + let mut called = self.on_message_called_tx.lock().await; + called.take(); + } + } + } let on_message_called_tx = Arc::new(Mutex::new(Some(on_message_called_tx))); - dc.on_message(Box::new(move |_: DataChannelMessage| { - let on_message_called_tx2 = Arc::clone(&on_message_called_tx); - Box::pin(async move { - let mut done = on_message_called_tx2.lock().await; - done.take(); - }) - })); + dc.with_event_handler(ChannelHandler { + done_tx: on_open_called_tx, + on_message_called_tx, + }); // Verify that the set handlers are called dc.do_open(); - dc.do_message(DataChannelMessage { + dc.loop_back_message(DataChannelMessage { is_string: false, data: Bytes::from_static(b"o hai"), }) @@ -650,55 +849,52 @@ async fn test_data_channel_messages_are_ordered() -> Result<()> { let out_tx = Arc::new(out_tx); - let out_tx1 = Arc::clone(&out_tx); - dc.on_message(Box::new(move |msg: DataChannelMessage| { - let out_tx2 = Arc::clone(&out_tx1); + struct MessageHandler { + out_tx: Arc>, + iter: u64, + } - Box::pin(async move { - // randomly sleep - let r = rand::random::() % m; - tokio::time::sleep(Duration::from_millis(r)).await; + impl RTCDataChannelEventHandler for MessageHandler { + fn on_message(&mut self, msg: DataChannelMessage) -> impl Future + Send { + async move { + // randomly sleep + let r = rand::random::() % self.iter; + tokio::time::sleep(Duration::from_millis(r)).await; - let mut buf = [0u8; 8]; - for i in 0..8 { - buf[i] = msg.data[i]; + let mut buf = [0u8; 8]; + for i in 0..8 { + buf[i] = msg.data[i]; + } + let s = u64::from_be_bytes(buf); + + let _ = self.out_tx.send(s).await; } - let s = u64::from_be_bytes(buf); + } + } - let _ = out_tx2.send(s).await; - }) - })); + dc.with_event_handler(MessageHandler { + out_tx: out_tx.clone(), + iter: 0, + }); tokio::spawn(async move { for j in 1..=m { let buf = j.to_be_bytes().to_vec(); - dc.do_message(DataChannelMessage { + dc.loop_back_message(DataChannelMessage { is_string: false, data: Bytes::from(buf), }) .await; + // Change the registered handler a couple of times to make sure // that everything continues to work, we don't lose messages, etc. if j % 2 == 0 { let out_tx1 = Arc::clone(&out_tx); - dc.on_message(Box::new(move |msg: DataChannelMessage| { - let out_tx2 = Arc::clone(&out_tx1); - - Box::pin(async move { - // randomly sleep - let r = rand::random::() % m; - tokio::time::sleep(Duration::from_millis(r)).await; - - let mut buf = [0u8; 8]; - for i in 0..8 { - buf[i] = msg.data[i]; - } - let s = u64::from_be_bytes(buf); - - let _ = out_tx2.send(s).await; - }) - })); + dc.with_event_handler(MessageHandler { + out_tx: out_tx.clone(), + iter: j, + }); } } }); @@ -749,27 +945,40 @@ async fn test_data_channel_parameters_go() -> Result<()> { ); let done_tx = Arc::new(Mutex::new(Some(done_tx))); - answer_pc.on_data_channel(Box::new(move |d: Arc| { - // Make sure this is the data channel we were looking for. (Not the one - // created in signalPair). - if d.label() != EXPECTED_LABEL { - return Box::pin(async {}); - } + struct ConnectionHandler { + max_packet_lifetime: u16, + done_tx: Arc>>>, + } - // Check if parameters are correctly set - assert!(d.ordered, "Ordered should be set to true"); - assert_eq!( - max_packet_life_time, - d.max_packet_lifetime(), - "should match" - ); - - let done_tx2 = Arc::clone(&done_tx); - Box::pin(async move { - let mut done = done_tx2.lock().await; - done.take(); - }) - })); + impl PeerConnectionEventHandler for ConnectionHandler { + fn on_data_channel( + &mut self, + channel: Arc, + ) -> impl Future + Send { + async move { + // Make sure this is the data channel we were looking for. (Not the one + // created in signalPair). + if channel.label() != EXPECTED_LABEL { + return; + } + + // Check if parameters are correctly set + assert!(channel.ordered, "Ordered should be set to true"); + assert_eq!( + self.max_packet_lifetime, + channel.max_packet_lifetime(), + "should match" + ); + + let mut done = self.done_tx.lock().await; + done.take(); + } + } + } + answer_pc.with_event_handler(ConnectionHandler { + max_packet_lifetime: max_packet_life_time, + done_tx, + }); close_reliability_param_test(&mut offer_pc, &mut answer_pc, done_rx).await?; } @@ -834,69 +1043,103 @@ async fn test_data_channel_buffered_amount_set_before_open() -> Result<()> { let done_tx = Arc::new(Mutex::new(Some(done_tx))); let n_packets_received = Arc::new(AtomicU16::new(0)); - answer_pc.on_data_channel(Box::new(move |d: Arc| { - // Make sure this is the data channel we were looking for. (Not the one - // created in signalPair). - if d.label() != EXPECTED_LABEL { - return Box::pin(async {}); - } - let done_tx2 = Arc::clone(&done_tx); - let n_packets_received2 = Arc::clone(&n_packets_received); - Box::pin(async move { - d.on_message(Box::new(move |_msg: DataChannelMessage| { - let n = n_packets_received2.fetch_add(1, Ordering::SeqCst); - if n == 9 { - let done_tx3 = Arc::clone(&done_tx2); - tokio::spawn(async move { - tokio::time::sleep(Duration::from_millis(10)).await; - let mut done = done_tx3.lock().await; - done.take(); - }); - } + struct ConnectionHandler { + packets_recieved: Arc, + done_tx: Arc>>>, + } - Box::pin(async {}) - })); + impl PeerConnectionEventHandler for ConnectionHandler { + fn on_data_channel( + &mut self, + channel: Arc, + ) -> impl Future + Send { + async move { + // Make sure this is the data channel we were looking for. (Not the one + // created in signalPair). + if channel.label() != EXPECTED_LABEL { + return; + } - assert!(d.ordered(), "Ordered should be set to true"); - }) - })); + channel.with_event_handler(MessageHandler { + packets_recieved: self.packets_recieved.clone(), + done_tx: self.done_tx.clone(), + }); + struct MessageHandler { + packets_recieved: Arc, + done_tx: Arc>>>, + } + impl RTCDataChannelEventHandler for MessageHandler { + fn on_message( + &mut self, + _: DataChannelMessage, + ) -> impl Future + Send { + let packets_recieved = self.packets_recieved.fetch_add(1, Ordering::SeqCst); + if packets_recieved == 9 { + let done_tx = self.done_tx.clone(); + tokio::spawn(async move { + tokio::time::sleep(Duration::from_millis(10)).await; + let mut done = done_tx.lock().await; + done.take(); + }); + } + async {} + } + } + assert!(channel.ordered(), "Ordered should be set to true"); + } + } + } let dc = offer_pc.create_data_channel(EXPECTED_LABEL, None).await?; assert!(dc.ordered(), "Ordered should be set to true"); let dc2 = Arc::clone(&dc); - dc.on_open(Box::new(move || { - let dc3 = Arc::clone(&dc2); - Box::pin(async move { - for _ in 0..10 { - assert!( - dc3.send(&buf).await.is_ok(), - "Failed to send string on data channel" - ); - assert_eq!( - 1500, - dc3.buffered_amount_low_threshold().await, - "value mismatch" - ); + + struct ChannelHandler { + channel: Arc, + buf: Bytes, + callbacks: Arc, + } + + impl RTCDataChannelEventHandler for ChannelHandler { + fn on_open(&mut self) -> impl Future + Send { + async move { + for _ in 0..10 { + assert!( + self.channel.send(&self.buf).await.is_ok(), + "Failed to send string on data channel" + ); + assert_eq!( + 1500, + self.channel.buffered_amount_low_threshold().await, + "value mismatch" + ); + } } - }) - })); + } - dc.on_message(Box::new(|_msg: DataChannelMessage| Box::pin(async {}))); + fn on_message(&mut self, _: DataChannelMessage) -> impl Future + Send { + async {} + } + + fn on_buffered_amount_low(&mut self, _: ()) -> impl Future + Send { + self.callbacks.fetch_add(1, Ordering::SeqCst); + async {} + } + } // The value is temporarily stored in the dc object // until the dc gets opened dc.set_buffered_amount_low_threshold(1500).await; // The callback function is temporarily stored in the dc object // until the dc gets opened - let n_cbs2 = Arc::clone(&n_cbs); - dc.on_buffered_amount_low(Box::new(move || { - n_cbs2.fetch_add(1, Ordering::SeqCst); - Box::pin(async {}) - })) - .await; + dc.with_event_handler(ChannelHandler { + channel: dc.clone(), + buf, + callbacks: n_cbs.clone(), + }); signal_pair(&mut offer_pc, &mut answer_pc).await?; @@ -925,67 +1168,102 @@ async fn test_data_channel_buffered_amount_set_after_open() -> Result<()> { let done_tx = Arc::new(Mutex::new(Some(done_tx))); let n_packets_received = Arc::new(AtomicU16::new(0)); - answer_pc.on_data_channel(Box::new(move |d: Arc| { - // Make sure this is the data channel we were looking for. (Not the one - // created in signalPair). - if d.label() != EXPECTED_LABEL { - return Box::pin(async {}); - } - let done_tx2 = Arc::clone(&done_tx); - let n_packets_received2 = Arc::clone(&n_packets_received); - Box::pin(async move { - d.on_message(Box::new(move |_msg: DataChannelMessage| { - let n = n_packets_received2.fetch_add(1, Ordering::SeqCst); - if n == 9 { - let done_tx3 = Arc::clone(&done_tx2); - tokio::spawn(async move { - tokio::time::sleep(Duration::from_millis(10)).await; - let mut done = done_tx3.lock().await; - done.take(); - }); + struct ConnectionHandler { + done_tx: Arc>>>, + packets_recieved: Arc, + } + + impl PeerConnectionEventHandler for ConnectionHandler { + fn on_data_channel( + &mut self, + channel: Arc, + ) -> impl Future + Send { + async move { + // Make sure this is the data channel we were looking for. (Not the one + // created in signalPair). + if channel.label() != EXPECTED_LABEL {} + + struct MessageHandler { + done_tx: Arc>>>, + packets_recieved: Arc, } - Box::pin(async {}) - })); + impl RTCDataChannelEventHandler for MessageHandler { + fn on_message( + &mut self, + _: DataChannelMessage, + ) -> impl Future + Send { + let packets_recieved = self.packets_recieved.fetch_add(1, Ordering::SeqCst); + if packets_recieved == 9 { + let done_tx = self.done_tx.clone(); + tokio::spawn(async move { + tokio::time::sleep(Duration::from_millis(10)).await; + let mut done = done_tx.lock().await; + done.take(); + }); + } + async {} + } + } - assert!(d.ordered(), "Ordered should be set to true"); - }) - })); + channel.with_event_handler(MessageHandler { + done_tx: self.done_tx.clone(), + packets_recieved: self.packets_recieved.clone(), + }); + } + } + } + answer_pc.with_event_handler(ConnectionHandler { + done_tx, + packets_recieved: n_packets_received, + }); let dc = offer_pc.create_data_channel(EXPECTED_LABEL, None).await?; - assert!(dc.ordered(), "Ordered should be set to true"); - let dc2 = Arc::clone(&dc); - let n_cbs2 = Arc::clone(&n_cbs); - dc.on_open(Box::new(move || { - let dc3 = Arc::clone(&dc2); - Box::pin(async move { - // The value should directly be passed to sctp - dc3.set_buffered_amount_low_threshold(1500).await; - // The callback function should directly be passed to sctp - dc3.on_buffered_amount_low(Box::new(move || { - n_cbs2.fetch_add(1, Ordering::SeqCst); - Box::pin(async {}) - })) - .await; + struct ChannelHandler { + channel: Arc, + callbacks: Arc, + buf: Bytes, + } - for _ in 0..10 { - assert!( - dc3.send(&buf).await.is_ok(), - "Failed to send string on data channel" - ); - assert_eq!( - 1500, - dc3.buffered_amount_low_threshold().await, - "value mismatch" - ); + impl RTCDataChannelEventHandler for ChannelHandler { + fn on_open(&mut self) -> impl Future + Send { + async move { + // The value should directly be passed to sctp + self.channel.set_buffered_amount_low_threshold(1500).await; + // The callback function should directly be passed to sctp + + for _ in 0..10 { + assert!( + self.channel.send(&self.buf).await.is_ok(), + "Failed to send string on data channel" + ); + assert_eq!( + 1500, + self.channel.buffered_amount_low_threshold().await, + "value mismatch" + ); + } } - }) - })); + } + fn on_buffered_amount_low(&mut self, _: ()) -> impl Future + Send { + self.callbacks.fetch_add(1, Ordering::SeqCst); + async {} + } + fn on_message(&mut self, _: DataChannelMessage) -> impl Future + Send { + async {} + } + } - dc.on_message(Box::new(|_msg: DataChannelMessage| Box::pin(async {}))); + let dc2 = Arc::clone(&dc); + let n_cbs2 = Arc::clone(&n_cbs); + dc.with_event_handler(ChannelHandler { + channel: dc.clone(), + callbacks: n_cbs.clone(), + buf, + }); signal_pair(&mut offer_pc, &mut answer_pc).await?; @@ -1017,32 +1295,47 @@ async fn test_eof_detach() -> Result<()> { let (dc_chan_tx, mut dc_chan_rx) = mpsc::channel(1); let dc_chan_tx = Arc::new(dc_chan_tx); - pcb.on_data_channel(Box::new(move |dc: Arc| { - if dc.label() != label { - return Box::pin(async {}); - } - log::debug!("OnDataChannel was called"); - let dc_chan_tx2 = Arc::clone(&dc_chan_tx); - let dc2 = Arc::clone(&dc); - Box::pin(async move { - let dc3 = Arc::clone(&dc2); - dc2.on_open(Box::new(move || { - let dc_chan_tx3 = Arc::clone(&dc_chan_tx2); - let dc4 = Arc::clone(&dc3); - Box::pin(async move { - let detached = match dc4.detach().await { - Ok(detached) => detached, - Err(err) => { - log::debug!("Detach failed: {}", err); - panic!(); - } - }; - let _ = dc_chan_tx3.send(detached).await; - }) - })); - }) - })); + struct ConnectionHandler<'a> { + label: &'a str, + detached_chan_tx: Arc>>, + } + impl PeerConnectionEventHandler for ConnectionHandler<'_> { + fn on_data_channel( + &mut self, + channel: Arc, + ) -> impl Future + Send { + async move { + if channel.label() != self.label { + return; + } + + struct ChannelHandler { + channel: Arc, + detached_chan_tx: Arc>>, + } + + impl RTCDataChannelEventHandler for ChannelHandler { + fn on_open(&mut self) -> impl Future + Send { + async move { + let detached = match self.channel.detach().await { + Ok(detached) => detached, + Err(err) => { + log::debug!("Detach failed: {err}"); + panic!(); + } + }; + let _ = self.detached_chan_tx.send(detached).await; + } + } + } + } + } + } + pcb.with_event_handler(ConnectionHandler { + label, + detached_chan_tx: dc_chan_tx, + }); let w = wg.worker(); tokio::spawn(async move { @@ -1072,12 +1365,19 @@ async fn test_eof_detach() -> Result<()> { log::debug!("Waiting for data channel to open"); let (open_tx, mut open_rx) = mpsc::channel::<()>(1); let open_tx = Arc::new(open_tx); - attached.on_open(Box::new(move || { - let open_tx2 = Arc::clone(&open_tx); - Box::pin(async move { - let _ = open_tx2.send(()).await; - }) - })); + + struct ChannelHandler { + open_tx: Arc>, + } + + impl RTCDataChannelEventHandler for ChannelHandler { + fn on_open(&mut self) -> impl Future + Send { + async move { + self.open_tx.send(()).await; + } + } + } + attached.with_event_handler(ChannelHandler { open_tx }); let _ = open_rx.recv().await; log::debug!("data channel opened"); @@ -1123,71 +1423,94 @@ async fn test_eof_no_detach() -> Result<()> { let (dcb_closed_ch_tx, mut dcb_closed_ch_rx) = mpsc::channel::<()>(1); let dcb_closed_ch_tx = Arc::new(dcb_closed_ch_tx); - pcb.on_data_channel(Box::new(move |dc: Arc| { - if dc.label() != label { - return Box::pin(async {}); - } - log::debug!("pcb: new datachannel: {}", dc.label()); - - let dcb_closed_ch_tx2 = Arc::clone(&dcb_closed_ch_tx); - Box::pin(async move { - // Register channel opening handling - dc.on_open(Box::new(move || { - log::debug!("pcb: datachannel opened"); - Box::pin(async {}) - })); - - dc.on_close(Box::new(move || { - // (2) - log::debug!("pcb: data channel closed"); - let dcb_closed_ch_tx3 = Arc::clone(&dcb_closed_ch_tx2); - Box::pin(async move { - let _ = dcb_closed_ch_tx3.send(()).await; + struct ConnectionHandler<'a> { + label: &'a str, + channel_closed: Arc>, + } + + impl PeerConnectionEventHandler for ConnectionHandler<'_> { + fn on_data_channel( + &mut self, + channel: Arc, + ) -> impl Future + Send { + async move { + if channel.label() != self.label { + return; + } + + struct ChannelHandler { + channel_closed: Arc>, + } + + impl RTCDataChannelEventHandler for ChannelHandler { + fn on_open(&mut self) -> impl Future + Send { + log::debug!("pcb: datachannel opened"); + async {} + } + + fn on_close(&mut self) -> impl Future + Send { + async move { + log::debug!("pcb: data channel closed"); + self.channel_closed.send(()).await; + } + } + + fn on_message( + &mut self, + msg: DataChannelMessage, + ) -> impl Future + Send { + let test_data: &'static [u8] = b"this is some test data"; + log::debug!("pcb: received ping: {:?}", msg.data); + assert_eq!(&msg.data[..], test_data, "data mismatch"); + async {} + } + } + channel.with_event_handler(ChannelHandler { + channel_closed: self.channel_closed.clone(), }) - })); - - // Register the OnMessage to handle incoming messages - log::debug!("pcb: registering onMessage callback"); - dc.on_message(Box::new(|dc_msg: DataChannelMessage| { - let test_data: &'static [u8] = b"this is some test data"; - log::debug!("pcb: received ping: {:?}", dc_msg.data); - assert_eq!(&dc_msg.data[..], test_data, "data mismatch"); - Box::pin(async {}) - })); - }) - })); + } + } + } let dca = pca.create_data_channel(label, None).await?; let dca2 = Arc::clone(&dca); - dca.on_open(Box::new(move || { - log::debug!("pca: data channel opened"); - log::debug!("pca: sending {:?}", test_data); - let dca3 = Arc::clone(&dca2); - Box::pin(async move { - let _ = dca3.send(&Bytes::from_static(test_data)).await; - log::debug!("pca: sent ping"); - assert!(dca3.close().await.is_ok(), "should succeed"); // <-- dca closes - }) - })); - - let dca_closed_ch_tx = Arc::new(dca_closed_ch_tx); - dca.on_close(Box::new(move || { - // (1) - log::debug!("pca: data channel closed"); - let dca_closed_ch_tx2 = Arc::clone(&dca_closed_ch_tx); - Box::pin(async move { - let _ = dca_closed_ch_tx2.send(()).await; - }) - })); - // Register the OnMessage to handle incoming messages - log::debug!("pca: registering onMessage callback"); - dca.on_message(Box::new(move |dc_msg: DataChannelMessage| { - log::debug!("pca: received pong: {:?}", &dc_msg.data[..]); - assert_eq!(&dc_msg.data[..], test_data, "data mismatch"); - Box::pin(async {}) - })); + struct DataChannelAccept { + channel: Arc, + test_data: &'static [u8], + closed_channel_tx: Arc>, + } + + impl RTCDataChannelEventHandler for DataChannelAccept { + fn on_open(&mut self) -> impl Future + Send { + async move { + log::debug!("pca: data channel opened"); + log::debug!("pca: sending {:?}", self.test_data); + let _ = self.channel.send(&Bytes::from_static(self.test_data)).await; + log::debug!("pca: sent ping"); + assert!(self.channel.close().await.is_ok(), "should succeed"); // <-- dca closes + } + } + + fn on_close(&mut self) -> impl Future + Send { + async move { + log::debug!("pca: data channel closed"); + self.closed_channel_tx.send(()).await; + } + } + + fn on_message(&mut self, msg: DataChannelMessage) -> impl Future + Send { + log::debug!("pca: received pong: {:?}", &msg.data[..]); + assert_eq!(&msg.data[..], self.test_data, "data mismatch"); + async {} + } + } + dca.with_event_handler(DataChannelAccept { + channel: dca.clone(), + test_data, + closed_channel_tx: Arc::new(dca_closed_ch_tx), + }); signal_pair(&mut pca, &mut pcb).await?; @@ -1216,12 +1539,20 @@ async fn test_data_channel_non_standard_session_description() -> Result<()> { let (on_data_channel_called_tx, mut on_data_channel_called_rx) = mpsc::channel::<()>(1); let on_data_channel_called_tx = Arc::new(on_data_channel_called_tx); - answer_pc.on_data_channel(Box::new(move |_: Arc| { - let on_data_channel_called_tx2 = Arc::clone(&on_data_channel_called_tx); - Box::pin(async move { - let _ = on_data_channel_called_tx2.send(()).await; - }) - })); + struct ConnectionHandler { + data_channel_called: Arc>, + } + + impl PeerConnectionEventHandler for ConnectionHandler { + fn on_data_channel(&mut self, _: Arc) -> impl Future + Send { + async move { + self.data_channel_called.send(()).await; + } + } + } + answer_pc.with_event_handler(ConnectionHandler { + data_channel_called: on_data_channel_called_tx, + }); let offer = offer_pc.create_offer(None).await?; @@ -1327,16 +1658,24 @@ impl TestOrtcStack { async fn get_signal(&self) -> Result { let (gather_finished_tx, mut gather_finished_rx) = mpsc::channel::<()>(1); let gather_finished_tx = Arc::new(gather_finished_tx); - self.gatherer - .on_local_candidate(Box::new(move |i: Option| { - let gather_finished_tx2 = Arc::clone(&gather_finished_tx); - Box::pin(async move { - if i.is_none() { - let _ = gather_finished_tx2.send(()).await; - } - }) - })); + struct GathererHandler { + gather_finished_tx: Arc>, + } + impl IceGathererEventHandler for GathererHandler { + fn on_local_candidate( + &mut self, + candidate: Option, + ) -> impl Future + Send { + async move { + if candidate.is_none() { + self.gather_finished_tx.send(()).await; + } + } + } + } + self.gatherer + .with_event_handler(GathererHandler { gather_finished_tx }); self.gatherer.gather().await?; let _ = gather_finished_rx.recv().await; @@ -1420,28 +1759,53 @@ async fn test_data_channel_ortc_e2e() -> Result<()> { let await_setup_tx = Arc::new(await_setup_tx); let await_string_tx = Arc::new(await_string_tx); let await_binary_tx = Arc::new(await_binary_tx); - stack_b - .sctp - .on_data_channel(Box::new(move |d: Arc| { - let await_setup_tx2 = Arc::clone(&await_setup_tx); - let await_string_tx2 = Arc::clone(&await_string_tx); - let await_binary_tx2 = Arc::clone(&await_binary_tx); - Box::pin(async move { - let _ = await_setup_tx2.send(()).await; - - d.on_message(Box::new(move |msg: DataChannelMessage| { - let await_string_tx3 = Arc::clone(&await_string_tx2); - let await_binary_tx3 = Arc::clone(&await_binary_tx2); - Box::pin(async move { - if msg.is_string { - let _ = await_string_tx3.send(()).await; - } else { - let _ = await_binary_tx3.send(()).await; + + struct TransportHandler { + await_setup_tx: Arc>, + await_string_tx: Arc>, + await_binary_tx: Arc>, + } + + impl SctpTransportEventHandler for TransportHandler { + fn on_data_channel( + &mut self, + channel: Arc, + ) -> impl Future + Send { + async move { + self.await_setup_tx.send(()).await; + + struct MessageHandler { + await_string_tx: Arc>, + await_binary_tx: Arc>, + } + + impl RTCDataChannelEventHandler for MessageHandler { + fn on_message( + &mut self, + msg: DataChannelMessage, + ) -> impl Future + Send { + async move { + if msg.is_string { + let _ = self.await_string_tx.send(()).await; + } else { + let _ = self.await_binary_tx.send(()).await; + } } - }) - })); - }) - })); + } + } + channel.with_event_handler(MessageHandler { + await_string_tx: self.await_string_tx.clone(), + await_binary_tx: self.await_binary_tx.clone(), + }); + } + } + } + + stack_b.sctp.with_event_handler(TransportHandler { + await_setup_tx, + await_string_tx, + await_binary_tx, + }); signal_ortc_pair(Arc::clone(&stack_a), Arc::clone(&stack_b)).await?; diff --git a/webrtc/src/data_channel/mod.rs b/webrtc/src/data_channel/mod.rs index d25db3d35..a7ae8d411 100644 --- a/webrtc/src/data_channel/mod.rs +++ b/webrtc/src/data_channel/mod.rs @@ -21,6 +21,7 @@ use data_channel_state::RTCDataChannelState; use sctp::stream::OnBufferedAmountLowFn; use tokio::sync::{Mutex, Notify}; use util::sync::Mutex as SyncMutex; +use util::{EventHandler, FutureUnit}; use crate::api::setting_engine::SettingEngine; use crate::error::{Error, OnErrorHdlrFn, Result}; @@ -67,10 +68,7 @@ pub struct RTCDataChannel { // is created, the binaryType attribute MUST be initialized to the string // "blob". This attribute controls how binary data is exposed to scripts. // binaryType string - pub(crate) on_message_handler: Arc>>, - pub(crate) on_open_handler: SyncMutex>, - pub(crate) on_close_handler: Arc>>, - pub(crate) on_error_handler: Arc>>, + pub(crate) events_handler: Arc>, pub(crate) on_buffered_amount_low: Mutex>, @@ -83,6 +81,68 @@ pub struct RTCDataChannel { pub(crate) setting_engine: Arc, } +pub trait RTCDataChannelEventHandler: Send { + /// on_message sets an event handler which is invoked on a binary + /// message arrival over the sctp transport from a remote peer. + /// OnMessage can currently receive messages up to 16384 bytes + /// in size. Check out the detach API if you want to use larger + /// message sizes. Note that browser support for larger messages + /// is also limited. + fn on_message(&mut self, message: DataChannelMessage) -> impl Future + Send { + async {} + } + /// on_error sets an event handler which is invoked when + /// the underlying data transport cannot be read. + fn on_error(&mut self, err: crate::error::Error) -> impl Future + Send { + async {} + } + /// on_open sets an event handler which is invoked when + /// the underlying data transport has been established (or re-established). + fn on_open(&mut self) -> impl Future + Send { + async {} + } + /// on_close sets an event handler which is invoked when + /// the underlying data transport has been closed. + fn on_close(&mut self) -> impl Future + Send { + async {} + } + /// on_buffered_amount_low sets an event handler which is invoked when + /// the number of bytes of outgoing data becomes lower than the + /// buffered_amount_low_threshold. + fn on_buffered_amount_low(&mut self, amt: ()) -> impl Future + Send { + async {} + } +} + +trait InlineRTCDataChannelEventHandler: Send { + fn inline_on_message(&mut self, message: DataChannelMessage) -> FutureUnit<'_>; + fn inline_on_error(&mut self, err: crate::error::Error) -> FutureUnit<'_>; + fn inline_on_open(&mut self) -> FutureUnit<'_>; + fn inline_on_close(&mut self) -> FutureUnit<'_>; + fn inline_on_buffered_amount_low(&mut self, amt: ()) -> FutureUnit<'_>; +} + +impl InlineRTCDataChannelEventHandler for T +where + T: RTCDataChannelEventHandler, +{ + fn inline_on_message(&mut self, message: DataChannelMessage) -> FutureUnit<'_> { + FutureUnit::from_async(async move { self.on_message(message).await }) + } + fn inline_on_error(&mut self, err: crate::error::Error) -> FutureUnit<'_> { + FutureUnit::from_async(async move { self.on_error(err).await }) + } + fn inline_on_open(&mut self) -> FutureUnit<'_> { + FutureUnit::from_async(async move { self.on_open().await }) + } + fn inline_on_close(&mut self) -> FutureUnit<'_> { + FutureUnit::from_async(async move { self.on_close().await }) + } + fn inline_on_buffered_amount_low(&mut self, amt: ()) -> FutureUnit<'_> { + FutureUnit::from_async(async move { self.on_buffered_amount_low(amt).await }) + } +} + impl RTCDataChannel { // create the DataChannel object before the networking is set up. pub(crate) fn new(params: DataChannelParameters, setting_engine: Arc) -> Self { @@ -198,63 +258,51 @@ impl RTCDataChannel { sctp_transport.clone() } - /// on_open sets an event handler which is invoked when - /// the underlying data transport has been established (or re-established). - pub fn on_open(&self, f: OnOpenHdlrFn) { - let _ = self.on_open_handler.lock().replace(f); + pub fn with_event_handler( + &self, + handler: impl RTCDataChannelEventHandler + Send + Sync + 'static, + ) { + self.events_handler.store(Box::new(handler)); if self.ready_state() == RTCDataChannelState::Open { self.do_open(); } + + //TODO: on_buffered_amount low attatches to the inner datachannel if it exists, + //which aquiring it is async... } - fn do_open(&self) { - let on_open_handler = self.on_open_handler.lock().take(); - if on_open_handler.is_none() { - return; + // calls [on_message] on itself + async fn loop_back_message(&self, msg: DataChannelMessage) { + if let Some(handler) = &*self.events_handler.load() { + let mut handle = handler.lock().await; + handle.inline_on_message(msg).await; } + } + fn do_open(&self) { + let Some(handler) = &*self.events_handler.load() else { + return; + }; + let handler = handler.clone(); let detach_data_channels = self.setting_engine.detach.data_channels; let detach_called = Arc::clone(&self.detach_called); + tokio::spawn(async move { - if let Some(f) = on_open_handler { - f().await; - - // self.check_detach_after_open(); - // After onOpen is complete check that the user called detach - // and provide an error message if the call was missed - if detach_data_channels && !detach_called.load(Ordering::SeqCst) { - log::warn!( - "webrtc.DetachDataChannels() enabled but didn't Detach, call Detach from OnOpen" - ); - } + let mut handle = handler.lock().await; + handle.inline_on_open().await; + + // self.check_detach_after_open(); + // After onOpen is complete check that the user called detach + // and provide an error message if the call was missed + if detach_data_channels && !detach_called.load(Ordering::SeqCst) { + log::warn!( + "webrtc.DetachDataChannels() enabled but didn't Detach, call Detach from OnOpen" + ); } }); } - /// on_close sets an event handler which is invoked when - /// the underlying data transport has been closed. - pub fn on_close(&self, f: OnCloseHdlrFn) { - self.on_close_handler.store(Some(Arc::new(Mutex::new(f)))); - } - - /// on_message sets an event handler which is invoked on a binary - /// message arrival over the sctp transport from a remote peer. - /// OnMessage can currently receive messages up to 16384 bytes - /// in size. Check out the detach API if you want to use larger - /// message sizes. Note that browser support for larger messages - /// is also limited. - pub fn on_message(&self, f: OnMessageHdlrFn) { - self.on_message_handler.store(Some(Arc::new(Mutex::new(f)))); - } - - async fn do_message(&self, msg: DataChannelMessage) { - if let Some(handler) = &*self.on_message_handler.load() { - let mut f = handler.lock().await; - f(msg).await; - } - } - pub(crate) async fn handle_open(&self, dc: Arc) { { let mut data_channel = self.data_channel.lock().await; @@ -265,40 +313,23 @@ impl RTCDataChannel { self.do_open(); if !self.setting_engine.detach.data_channels { - let ready_state = Arc::clone(&self.ready_state); - let on_message_handler = Arc::clone(&self.on_message_handler); - let on_close_handler = Arc::clone(&self.on_close_handler); - let on_error_handler = Arc::clone(&self.on_error_handler); + let ready_state = self.ready_state.clone(); let notify_rx = self.notify_tx.clone(); + let events_handler = self.events_handler.clone(); tokio::spawn(async move { - RTCDataChannel::read_loop( - notify_rx, - dc, - ready_state, - on_message_handler, - on_close_handler, - on_error_handler, - ) - .await; + RTCDataChannel::read_loop(notify_rx, dc, ready_state, events_handler).await; }); } } - /// on_error sets an event handler which is invoked when - /// the underlying data transport cannot be read. - pub fn on_error(&self, f: OnErrorHdlrFn) { - self.on_error_handler.store(Some(Arc::new(Mutex::new(f)))); - } - async fn read_loop( notify_rx: Arc, data_channel: Arc, ready_state: Arc, - on_message_handler: Arc>>, - on_close_handler: Arc>>, - on_error_handler: Arc>>, + events_handler: Arc>, ) { let mut buffer = vec![0u8; DATA_CHANNEL_BUFFER_SIZE as usize]; + loop { let (n, is_string) = tokio::select! { _ = notify_rx.notified() => break, @@ -310,11 +341,11 @@ impl RTCDataChannel { { ready_state.store(RTCDataChannelState::Closed as u8, Ordering::SeqCst); - let on_close_handler2 = Arc::clone(&on_close_handler); + let events_handler = events_handler.clone(); tokio::spawn(async move { - if let Some(handler) = &*on_close_handler2.load() { - let mut f = handler.lock().await; - f().await; + if let Some(handler) = &*events_handler.load() { + let mut handle = handler.lock().await; + handle.inline_on_close().await; } }); @@ -324,19 +355,19 @@ impl RTCDataChannel { Err(err) => { ready_state.store(RTCDataChannelState::Closed as u8, Ordering::SeqCst); - let on_error_handler2 = Arc::clone(&on_error_handler); + let error_handler = events_handler.clone(); tokio::spawn(async move { - if let Some(handler) = &*on_error_handler2.load() { - let mut f = handler.lock().await; - f(err.into()).await; + if let Some(handler) = &*error_handler.load() { + let mut handle = handler.lock().await; + handle.inline_on_error(err.into()).await; } }); - let on_close_handler2 = Arc::clone(&on_close_handler); + let close_handler = events_handler.clone(); tokio::spawn(async move { - if let Some(handler) = &*on_close_handler2.load() { - let mut f = handler.lock().await; - f().await; + if let Some(handler) = &*close_handler.load() { + let mut handle = handler.lock().await; + handle.inline_on_close().await; } }); @@ -346,13 +377,13 @@ impl RTCDataChannel { } }; - if let Some(handler) = &*on_message_handler.load() { - let mut f = handler.lock().await; - f(DataChannelMessage { + if let Some(handler) = &*events_handler.load() { + let mut handle = handler.lock().await; + let msg = DataChannelMessage { is_string, data: Bytes::from(buffer[..n].to_vec()), - }) - .await; + }; + handle.inline_on_message(msg).await } } } @@ -531,7 +562,10 @@ impl RTCDataChannel { /// the number of bytes of outgoing data becomes lower than the /// buffered_amount_low_threshold. pub async fn on_buffered_amount_low(&self, f: OnBufferedAmountLowFn) { + //RTCDataChannel and DataChannel both have this callback let data_channel = self.data_channel.lock().await; + //should be able to clone the arc and put the handler in both + //the data channel inner and this currnent? i dont see why not. if let Some(dc) = &*data_channel { dc.on_buffered_amount_low(f); } else { diff --git a/webrtc/src/dtls_transport/dtls_transport_test.rs b/webrtc/src/dtls_transport/dtls_transport_test.rs index a9e71aca0..ff1c202db 100644 --- a/webrtc/src/dtls_transport/dtls_transport_test.rs +++ b/webrtc/src/dtls_transport/dtls_transport_test.rs @@ -2,7 +2,7 @@ use ice::mdns::MulticastDnsMode; use ice::network_type::NetworkType; use regex::Regex; use tokio::time::Duration; -use waitgroup::WaitGroup; +use waitgroup::{WaitGroup, Worker}; use super::*; use crate::api::media_engine::MediaEngine; @@ -12,8 +12,9 @@ use crate::ice_transport::ice_candidate::RTCIceCandidate; use crate::peer_connection::configuration::RTCConfiguration; use crate::peer_connection::peer_connection_state::RTCPeerConnectionState; use crate::peer_connection::peer_connection_test::{ - close_pair_now, new_pair, signal_pair, until_connection_state, + close_pair_now, new_pair, signal_pair, StateHandler, }; +use crate::peer_connection::PeerConnectionEventHandler; //use log::LevelFilter; //use std::io::Write; @@ -42,36 +43,72 @@ async fn test_invalid_fingerprint_causes_failed() -> Result<()> { let (mut pc_offer, mut pc_answer) = new_pair(&api).await?; - pc_answer.on_data_channel(Box::new(|_: Arc| { - panic!("A DataChannel must not be created when Fingerprint verification fails"); - })); + struct AnswerHandler { + worker: Arc>>, + } - let (offer_chan_tx, mut offer_chan_rx) = mpsc::channel::<()>(1); + impl PeerConnectionEventHandler for AnswerHandler { + fn on_data_channel(&mut self, _: Arc) -> impl Future + Send { + async move { + panic!("A DataChannel must not be created when Fingerprint verification fails"); + } + } - let offer_chan_tx = Arc::new(offer_chan_tx); - pc_offer.on_ice_candidate(Box::new(move |candidate: Option| { - let offer_chan_tx2 = Arc::clone(&offer_chan_tx); - Box::pin(async move { - if candidate.is_none() { - let _ = offer_chan_tx2.send(()).await; + fn on_peer_connection_state_change( + &mut self, + state: RTCPeerConnectionState, + ) -> impl Future + Send { + async move { + if state == RTCPeerConnectionState::Failed { + let mut worker = self.worker.lock().await; + worker.take(); + } } - }) - })); + } + } + + struct OfferHandler { + offer_chan_tx: Arc>, + worker: Arc>>, + } + + impl PeerConnectionEventHandler for OfferHandler { + fn on_ice_candidate( + &mut self, + candidate: Option, + ) -> impl Future + Send { + async move { + if candidate.is_none() { + let _ = self.offer_chan_tx.try_send(()); + } + } + } + + fn on_peer_connection_state_change( + &mut self, + state: RTCPeerConnectionState, + ) -> impl Future + Send { + async move { + if state == RTCPeerConnectionState::Failed { + let mut worker = self.worker.lock().await; + worker.take(); + } + } + } + } + + let (offer_chan_tx, mut offer_chan_rx) = mpsc::channel::<()>(1); + let offer_chan_tx = Arc::new(offer_chan_tx); let offer_connection_has_failed = WaitGroup::new(); - until_connection_state( - &mut pc_offer, - &offer_connection_has_failed, - RTCPeerConnectionState::Failed, - ) - .await; let answer_connection_has_failed = WaitGroup::new(); - until_connection_state( - &mut pc_answer, - &answer_connection_has_failed, - RTCPeerConnectionState::Failed, - ) - .await; + pc_offer.with_event_handler(OfferHandler { + offer_chan_tx, + worker: Arc::new(Mutex::new(Some(offer_connection_has_failed.worker()))), + }); + pc_answer.with_event_handler(AnswerHandler { + worker: Arc::new(Mutex::new(Some(offer_connection_has_failed.worker()))), + }); let _ = pc_offer .create_data_channel("unusedDataChannel", None) @@ -155,7 +192,9 @@ async fn run_test(r: DTLSRole) -> Result<()> { signal_pair(&mut offer_pc, &mut answer_pc).await?; let wg = WaitGroup::new(); - until_connection_state(&mut answer_pc, &wg, RTCPeerConnectionState::Connected).await; + answer_pc.with_event_handler(StateHandler { + worker: Arc::new(Mutex::new(Some(wg.worker()))), + }); wg.wait().await; close_pair_now(&offer_pc, &answer_pc).await; diff --git a/webrtc/src/dtls_transport/mod.rs b/webrtc/src/dtls_transport/mod.rs index 8d6bcf356..07239da38 100644 --- a/webrtc/src/dtls_transport/mod.rs +++ b/webrtc/src/dtls_transport/mod.rs @@ -17,7 +17,7 @@ use srtp::protection_profile::ProtectionProfile; use srtp::session::Session; use srtp::stream::Stream; use tokio::sync::{mpsc, Mutex}; -use util::Conn; +use util::{Conn, EventHandler, FutureUnit}; use crate::api::setting_engine::SettingEngine; use crate::dtls_transport::dtls_parameters::DTLSParameters; @@ -67,7 +67,7 @@ pub struct RTCDtlsTransport { pub(crate) remote_certificate: Mutex, pub(crate) state: AtomicU8, //DTLSTransportState, pub(crate) srtp_protection_profile: Mutex, - pub(crate) on_state_change_handler: ArcSwapOption>, + pub(crate) events_handler: EventHandler, pub(crate) conn: Mutex>>, pub(crate) srtp_session: Mutex>>, @@ -84,6 +84,27 @@ pub struct RTCDtlsTransport { pub(crate) dtls_matcher: Option, } +pub trait DtlsTransportEventHandler: Send { + /// on_state_change sets a handler that is fired when the DTLS + /// connection state changes. + fn on_state_change(&mut self, state: RTCDtlsTransportState) -> impl Future + Send { + async {} + } +} + +trait InlineDtlsTransportEventHandler: Send { + fn inline_on_state_change(&mut self, state: RTCDtlsTransportState) -> FutureUnit<'_>; +} + +impl InlineDtlsTransportEventHandler for T +where + T: DtlsTransportEventHandler, +{ + fn inline_on_state_change(&mut self, state: RTCDtlsTransportState) -> FutureUnit<'_> { + FutureUnit::from_async(async move { self.on_state_change(state).await }) + } +} + impl RTCDtlsTransport { pub(crate) fn new( ice_transport: Arc, @@ -118,17 +139,17 @@ impl RTCDtlsTransport { /// state_change requires the caller holds the lock async fn state_change(&self, state: RTCDtlsTransportState) { self.state.store(state as u8, Ordering::SeqCst); - if let Some(handler) = &*self.on_state_change_handler.load() { - let mut f = handler.lock().await; - f(state).await; + if let Some(handler) = &*self.events_handler.load() { + let mut handle = handler.lock().await; + handle.inline_on_state_change(state).await; } } - /// on_state_change sets a handler that is fired when the DTLS - /// connection state changes. - pub fn on_state_change(&self, f: OnDTLSTransportStateChangeHdlrFn) { - self.on_state_change_handler - .store(Some(Arc::new(Mutex::new(f)))); + pub fn with_event_handler( + &self, + handler: impl DtlsTransportEventHandler + Send + Sync + 'static, + ) { + self.events_handler.store(Box::new(handler)); } /// state returns the current dtls_transport transport state. diff --git a/webrtc/src/ice_transport/ice_gatherer.rs b/webrtc/src/ice_transport/ice_gatherer.rs index eda930bf5..2b9fab9ad 100644 --- a/webrtc/src/ice_transport/ice_gatherer.rs +++ b/webrtc/src/ice_transport/ice_gatherer.rs @@ -21,6 +21,7 @@ use crate::peer_connection::policy::ice_transport_policy::RTCIceTransportPolicy; use crate::stats::stats_collector::StatsCollector; use crate::stats::SourceStatsType::*; use crate::stats::{ICECandidatePairStats, StatsReportType}; +use util::{EventHandler, FutureUnit}; /// ICEGatherOptions provides options relating to the gathering of ICE candidates. #[derive(Default, Debug, Clone)] @@ -57,11 +58,95 @@ pub struct RTCIceGatherer { pub(crate) state: Arc, //ICEGathererState, pub(crate) agent: Mutex>>, - pub(crate) on_local_candidate_handler: Arc>>, - pub(crate) on_state_change_handler: Arc>>, + gatherer_state: GathererState, +} + +#[derive(Default, Clone)] +#[repr(transparent)] +struct GathererState { + inner: Arc, +} + +#[derive(Default)] +struct GathererStateInner { + state: Arc, + event_handler: EventHandler, +} + +impl crate::ice::agent::AgentEventHandler for GathererState { + fn on_candidate( + &mut self, + candidate: Option>, + ) -> impl Future + Send { + async move { + match (candidate, &*self.inner.event_handler.load()) { + (Some(candidate), Some(handler)) => { + let cand = RTCIceCandidate::from(&candidate); + handler + .lock() + .await + .inline_on_local_candidate(Some(RTCIceCandidate::from(&candidate.clone()))) + .await; + } + (_, maybe_handler) => { + self.inner + .state + .store(RTCIceGathererState::Complete as u8, Ordering::SeqCst); + if let Some(handler) = maybe_handler { + let mut handler = handler.lock().await; + handler + .inline_on_state_change(RTCIceGathererState::Complete) + .await; + handler.inline_on_gathering_complete().await; + handler.inline_on_local_candidate(None).await; + } + } + } + } + } +} + +pub trait IceGathererEventHandler: Send { + /// on_local_candidate sets an event handler which fires when a new local ICE candidate is available + /// Take note that the handler is gonna be called with a nil pointer when gathering is finished. + fn on_local_candidate( + &mut self, + candidate: Option, + ) -> impl Future + Send { + async {} + } + + /// on_state_change sets an event handler which fires any time the ICEGatherer changes + fn on_state_change(&mut self, state: RTCIceGathererState) -> impl Future + Send { + async {} + } + + /// on_gathering_complete sets an event handler which fires any time the ICEGatherer changes + fn on_gathering_complete(&mut self) -> impl Future + Send { + async {} + } +} + +trait InlineIceGathererEventHandler: Send { + fn inline_on_local_candidate(&mut self, candidate: Option) -> FutureUnit<'_>; + fn inline_on_state_change(&mut self, state: RTCIceGathererState) -> FutureUnit<'_>; + fn inline_on_gathering_complete(&mut self) -> FutureUnit<'_>; +} + +impl InlineIceGathererEventHandler for T +where + T: IceGathererEventHandler, +{ + fn inline_on_local_candidate(&mut self, candidate: Option) -> FutureUnit<'_> { + FutureUnit::from_async(async move { self.on_local_candidate(candidate).await }) + } + fn inline_on_state_change(&mut self, state: RTCIceGathererState) -> FutureUnit<'_> { + FutureUnit::from_async(async move { self.on_state_change(state).await }) + } - // Used for gathering_complete_promise - pub(crate) on_gathering_complete_handler: Arc>>, + fn inline_on_gathering_complete(&mut self) -> FutureUnit<'_> { + FutureUnit::from_async(async move { self.on_gathering_complete().await }) + } } impl RTCIceGatherer { @@ -155,48 +240,7 @@ impl RTCIceGatherer { self.set_state(RTCIceGathererState::Gathering).await; if let Some(agent) = self.get_agent().await { - let state = Arc::clone(&self.state); - let on_local_candidate_handler = Arc::clone(&self.on_local_candidate_handler); - let on_state_change_handler = Arc::clone(&self.on_state_change_handler); - let on_gathering_complete_handler = Arc::clone(&self.on_gathering_complete_handler); - - agent.on_candidate(Box::new( - move |candidate: Option>| { - let state_clone = Arc::clone(&state); - let on_local_candidate_handler_clone = Arc::clone(&on_local_candidate_handler); - let on_state_change_handler_clone = Arc::clone(&on_state_change_handler); - let on_gathering_complete_handler_clone = - Arc::clone(&on_gathering_complete_handler); - - Box::pin(async move { - if let Some(cand) = candidate { - if let Some(handler) = &*on_local_candidate_handler_clone.load() { - let mut f = handler.lock().await; - f(Some(RTCIceCandidate::from(&cand))).await; - } - } else { - state_clone - .store(RTCIceGathererState::Complete as u8, Ordering::SeqCst); - - if let Some(handler) = &*on_state_change_handler_clone.load() { - let mut f = handler.lock().await; - f(RTCIceGathererState::Complete).await; - } - - if let Some(handler) = &*on_gathering_complete_handler_clone.load() { - let mut f = handler.lock().await; - f().await; - } - - if let Some(handler) = &*on_local_candidate_handler_clone.load() { - let mut f = handler.lock().await; - f(None).await; - } - } - }) - }, - )); - + agent.with_event_handler(self.gatherer_state.clone()); agent.gather_candidates()?; } @@ -249,23 +293,14 @@ impl RTCIceGatherer { Ok(rtc_ice_candidates_from_ice_candidates(&ice_candidates)) } - /// on_local_candidate sets an event handler which fires when a new local ICE candidate is available - /// Take note that the handler is gonna be called with a nil pointer when gathering is finished. - pub fn on_local_candidate(&self, f: OnLocalCandidateHdlrFn) { - self.on_local_candidate_handler - .store(Some(Arc::new(Mutex::new(f)))); - } - - /// on_state_change sets an event handler which fires any time the ICEGatherer changes - pub fn on_state_change(&self, f: OnICEGathererStateChangeHdlrFn) { - self.on_state_change_handler - .store(Some(Arc::new(Mutex::new(f)))); - } - - /// on_gathering_complete sets an event handler which fires any time the ICEGatherer changes - pub fn on_gathering_complete(&self, f: OnGatheringCompleteHdlrFn) { - self.on_gathering_complete_handler - .store(Some(Arc::new(Mutex::new(f)))); + pub fn with_event_handler( + &self, + handler: impl InlineIceGathererEventHandler + Send + Sync + 'static, + ) { + self.gatherer_state + .inner + .event_handler + .store(Box::new(handler)) } /// State indicates the current state of the ICE gatherer. @@ -276,9 +311,9 @@ impl RTCIceGatherer { pub async fn set_state(&self, s: RTCIceGathererState) { self.state.store(s as u8, Ordering::SeqCst); - if let Some(handler) = &*self.on_state_change_handler.load() { - let mut f = handler.lock().await; - f(s).await; + if let Some(handler) = &*self.gatherer_state.inner.event_handler.load() { + let mut handler = handler.lock().await; + handler.inline_on_state_change(s).await; } } @@ -322,6 +357,7 @@ mod test { use super::*; use crate::api::APIBuilder; use crate::ice_transport::ice_gatherer::RTCIceGatherOptions; + use crate::ice_transport::ice_gatherer::IceGathererEventHandler; use crate::ice_transport::ice_server::RTCIceServer; #[tokio::test] @@ -344,15 +380,22 @@ mod test { let (gather_finished_tx, mut gather_finished_rx) = mpsc::channel::<()>(1); let gather_finished_tx = Arc::new(Mutex::new(Some(gather_finished_tx))); - gatherer.on_local_candidate(Box::new(move |c: Option| { - let gather_finished_tx_clone = Arc::clone(&gather_finished_tx); - Box::pin(async move { - if c.is_none() { - let mut tx = gather_finished_tx_clone.lock().await; - tx.take(); + + struct GatherHandler { + gather_finished_tx: Arc>>>, + } + + impl IceGathererEventHandler for GatherHandler { + fn on_local_candidate(&mut self, candidate: Option) -> impl Future + Send { + async move { + if candidate.is_none() { + let mut finished = self.gather_finished_tx.lock().await; + finished.take(); + } } - }) - })); + } + } + gatherer.with_event_handler(GatherHandler {gather_finished_tx}); gatherer.gather().await?; @@ -386,17 +429,24 @@ mod test { let (done_tx, mut done_rx) = mpsc::channel::<()>(1); let done_tx = Arc::new(Mutex::new(Some(done_tx))); - gatherer.on_local_candidate(Box::new(move |c: Option| { - let done_tx_clone = Arc::clone(&done_tx); - Box::pin(async move { - if let Some(c) = c { + + struct GatherHandler { + done_tx: Arc>>>, + } + + impl IceGathererEventHandler for GatherHandler { + fn on_local_candidate(&mut self, candidate: Option) -> impl Future + Send { + async move { + if let Some(c) = candidate { if c.address.ends_with(".local") { - let mut tx = done_tx_clone.lock().await; + let mut tx = self.done_tx.lock().await; tx.take(); } } - }) - })); + } + } + } + gatherer.with_event_handler(GatherHandler{done_tx}); gatherer.gather().await?; diff --git a/webrtc/src/ice_transport/ice_transport_test.rs b/webrtc/src/ice_transport/ice_transport_test.rs index 866f78afd..48998367e 100644 --- a/webrtc/src/ice_transport/ice_transport_test.rs +++ b/webrtc/src/ice_transport/ice_transport_test.rs @@ -1,7 +1,7 @@ use std::sync::atomic::AtomicU32; use tokio::time::Duration; -use waitgroup::WaitGroup; +use waitgroup::{WaitGroup, Worker}; use super::*; use crate::api::media_engine::MediaEngine; @@ -9,9 +9,8 @@ use crate::api::APIBuilder; use crate::error::Result; use crate::ice_transport::ice_connection_state::RTCIceConnectionState; use crate::peer_connection::peer_connection_state::RTCPeerConnectionState; -use crate::peer_connection::peer_connection_test::{ - close_pair_now, new_pair, signal_pair, until_connection_state, -}; +use crate::peer_connection::peer_connection_test::{close_pair_now, new_pair, signal_pair}; +use crate::peer_connection::PeerConnectionEventHandler; #[tokio::test] async fn test_ice_transport_on_selected_candidate_pair_change() -> Result<()> { @@ -23,27 +22,43 @@ async fn test_ice_transport_on_selected_candidate_pair_change() -> Result<()> { let (ice_complete_tx, mut ice_complete_rx) = mpsc::channel::<()>(1); let ice_complete_tx = Arc::new(Mutex::new(Some(ice_complete_tx))); - pc_answer.on_ice_connection_state_change(Box::new(move |ice_state: RTCIceConnectionState| { - let ice_complete_tx2 = Arc::clone(&ice_complete_tx); - Box::pin(async move { - if ice_state == RTCIceConnectionState::Connected { + + struct AnswerHandler { + ice_complete_tx: Arc>>>, + } + + impl PeerConnectionEventHandler for AnswerHandler { + fn on_ice_connection_state_change(&mut self, state: RTCIceConnectionState) -> impl Future + Send { + async move { + + if state == RTCIceConnectionState::Connected { tokio::time::sleep(Duration::from_secs(1)).await; - let mut done = ice_complete_tx2.lock().await; + let mut done = self.ice_complete_tx.lock().await; done.take(); } - }) - })); + } + } + } + + pc_answer.with_event_handler(AnswerHandler{ ice_complete_tx }); + + struct OfferHandler { + candidate_changes: Arc, + } + + impl IceTransportEventHandler for OfferHandler { + fn on_selected_candidate_pair_change(&mut self, _: RTCIceCandidatePair) -> impl Future + Send { + self.candidate_changes.fetch_add(1, Ordering::SeqCst); + async {} + } + } let sender_called_candidate_change = Arc::new(AtomicU32::new(0)); - let sender_called_candidate_change2 = Arc::clone(&sender_called_candidate_change); pc_offer .sctp() .transport() .ice_transport() - .on_selected_candidate_pair_change(Box::new(move |_: RTCIceCandidatePair| { - sender_called_candidate_change2.store(1, Ordering::SeqCst); - Box::pin(async {}) - })); + .with_event_handler(OfferHandler{candidate_changes: sender_called_candidate_change.clone()}); signal_pair(&mut pc_offer, &mut pc_answer).await?; @@ -68,18 +83,30 @@ async fn test_ice_transport_get_selected_candidate_pair() -> Result<()> { let (mut offerer, mut answerer) = new_pair(&api).await?; let peer_connection_connected = WaitGroup::new(); - until_connection_state( - &mut offerer, - &peer_connection_connected, - RTCPeerConnectionState::Connected, - ) - .await; - until_connection_state( - &mut answerer, - &peer_connection_connected, - RTCPeerConnectionState::Connected, - ) - .await; + + struct ConnectionStateHandler { + worker: Arc>>, + } + + impl PeerConnectionEventHandler for ConnectionStateHandler { + fn on_peer_connection_state_change( + &mut self, + state: RTCPeerConnectionState, + ) -> impl Future + Send { + async move { + if state == RTCPeerConnectionState::Connected { + let mut worker = self.worker.lock().await; + worker.take(); + } + } + } + } + offerer.with_event_handler(ConnectionStateHandler { + worker: Arc::new(Mutex::new(Some(peer_connection_connected.worker()))), + }); + answerer.with_event_handler(ConnectionStateHandler { + worker: Arc::new(Mutex::new(Some(peer_connection_connected.worker()))), + }); let offerer_selected_pair = offerer .sctp() diff --git a/webrtc/src/ice_transport/mod.rs b/webrtc/src/ice_transport/mod.rs index 6bf09c06d..8b7353103 100644 --- a/webrtc/src/ice_transport/mod.rs +++ b/webrtc/src/ice_transport/mod.rs @@ -11,7 +11,7 @@ use ice_candidate_pair::RTCIceCandidatePair; use ice_gatherer::RTCIceGatherer; use ice_role::RTCIceRole; use tokio::sync::{mpsc, Mutex}; -use util::Conn; +use util::{Conn, EventHandler, FutureUnit}; use crate::error::{flatten_errs, Error, Result}; use crate::ice_transport::ice_parameters::RTCIceParameters; @@ -68,10 +68,107 @@ pub struct RTCIceTransport { on_connection_state_change_handler: Arc>>, on_selected_candidate_pair_change_handler: Arc>>, + transport_state: TransportState, state: Arc, // ICETransportState internal: Mutex, } +#[derive(Default, Clone)] +#[repr(transparent)] +struct TransportState { + inner: Arc, +} + +#[derive(Default)] +struct TransportStateInner { + state: Arc, + events_handler: Arc>, +} + +pub trait IceTransportEventHandler: Send { + /// on_connection_state_change sets a handler that is fired when the ICE + /// connection state changes. + fn on_connection_state_change( + &mut self, + state: RTCIceTransportState, + ) -> impl Future + Send { + async {} + } + /// on_selected_candidate_pair_change sets a handler that is invoked when a new + /// ICE candidate pair is selected + fn on_selected_candidate_pair_change( + &mut self, + candidate_pair: RTCIceCandidatePair, + ) -> impl Future + Send { + async {} + } +} + +trait InlineIceTransportEventHandler: Send { + fn inline_on_connection_state_change(&mut self, state: RTCIceTransportState) -> FutureUnit<'_>; + fn inline_on_selected_candidate_pair_change( + &mut self, + candidate_pair: RTCIceCandidatePair, + ) -> FutureUnit<'_>; +} + +impl InlineIceTransportEventHandler for T +where + T: IceTransportEventHandler, +{ + fn inline_on_connection_state_change(&mut self, state: RTCIceTransportState) -> FutureUnit<'_> { + FutureUnit::from_async(async move { self.on_connection_state_change(state).await }) + } + fn inline_on_selected_candidate_pair_change( + &mut self, + candidate_pair: RTCIceCandidatePair, + ) -> FutureUnit<'_> { + FutureUnit::from_async(async move { + self.on_selected_candidate_pair_change(candidate_pair).await + }) + } +} + +impl ice::agent::AgentEventHandler for TransportState { + fn on_connection_state_change( + &mut self, + state: ConnectionState, + ) -> impl Future + Send { + async move { + let ice_state = RTCIceTransportState::from(state); + self.inner.state.store(ice_state as u8, Ordering::SeqCst); + + if let Some(handler) = &*self.inner.events_handler.load() { + handler + .lock() + .await + .inline_on_connection_state_change(ice_state) + .await + } + } + } + + fn on_selected_candidate_pair_change( + &mut self, + local_candidate: Arc, + remote_candidate: Arc, + ) -> impl Future + Send { + async move { + if let Some(handler) = &*self.inner.events_handler.load() { + let local = RTCIceCandidate::from(&local_candidate); + let remote = RTCIceCandidate::from(&remote_candidate); + handler + .lock() + .await + .inline_on_selected_candidate_pair_change(RTCIceCandidatePair::new( + local, remote, + )) + .await + } + } + } +} + impl RTCIceTransport { /// creates a new new_icetransport. pub(crate) fn new(gatherer: Arc) -> Self { @@ -104,43 +201,6 @@ impl RTCIceTransport { self.ensure_gatherer().await?; if let Some(agent) = self.gatherer.get_agent().await { - let state = Arc::clone(&self.state); - - let on_connection_state_change_handler = - Arc::clone(&self.on_connection_state_change_handler); - agent.on_connection_state_change(Box::new(move |ice_state: ConnectionState| { - let s = RTCIceTransportState::from(ice_state); - let on_connection_state_change_handler_clone = - Arc::clone(&on_connection_state_change_handler); - state.store(s as u8, Ordering::SeqCst); - Box::pin(async move { - if let Some(handler) = &*on_connection_state_change_handler_clone.load() { - let mut f = handler.lock().await; - f(s).await; - } - }) - })); - - let on_selected_candidate_pair_change_handler = - Arc::clone(&self.on_selected_candidate_pair_change_handler); - agent.on_selected_candidate_pair_change(Box::new( - move |local: &Arc, - remote: &Arc| { - let on_selected_candidate_pair_change_handler_clone = - Arc::clone(&on_selected_candidate_pair_change_handler); - let local = RTCIceCandidate::from(local); - let remote = RTCIceCandidate::from(remote); - Box::pin(async move { - if let Some(handler) = - &*on_selected_candidate_pair_change_handler_clone.load() - { - let mut f = handler.lock().await; - f(RTCIceCandidatePair::new(local, remote)).await; - } - }) - }, - )); - let role = if let Some(role) = role { role } else { @@ -240,18 +300,14 @@ impl RTCIceTransport { flatten_errs(errs) } - /// on_selected_candidate_pair_change sets a handler that is invoked when a new - /// ICE candidate pair is selected - pub fn on_selected_candidate_pair_change(&self, f: OnSelectedCandidatePairChangeHdlrFn) { - self.on_selected_candidate_pair_change_handler - .store(Some(Arc::new(Mutex::new(f)))); - } - - /// on_connection_state_change sets a handler that is fired when the ICE - /// connection state changes. - pub fn on_connection_state_change(&self, f: OnConnectionStateChangeHdlrFn) { - self.on_connection_state_change_handler - .store(Some(Arc::new(Mutex::new(f)))); + pub fn with_event_handler( + &self, + handler: impl IceTransportEventHandler + Send + Sync + 'static, + ) { + self.transport_state + .inner + .events_handler + .store(Box::new(handler)) } /// Role indicates the current role of the ICE transport. @@ -265,6 +321,8 @@ impl RTCIceTransport { self.ensure_gatherer().await?; if let Some(agent) = self.gatherer.get_agent().await { + agent.with_event_handler(self.transport_state.clone()); + //agent.with_event_handler(self.) for rc in remote_candidates { let c: Arc = Arc::new(rc.to_ice()?); agent.add_remote_candidate(&c)?; diff --git a/webrtc/src/peer_connection/mod.rs b/webrtc/src/peer_connection/mod.rs index b4989cef4..82251df15 100644 --- a/webrtc/src/peer_connection/mod.rs +++ b/webrtc/src/peer_connection/mod.rs @@ -49,8 +49,8 @@ use crate::error::{flatten_errs, Error, Result}; use crate::ice_transport::ice_candidate::{RTCIceCandidate, RTCIceCandidateInit}; use crate::ice_transport::ice_connection_state::RTCIceConnectionState; use crate::ice_transport::ice_gatherer::{ - OnGatheringCompleteHdlrFn, OnICEGathererStateChangeHdlrFn, OnLocalCandidateHdlrFn, - RTCIceGatherOptions, RTCIceGatherer, + IceGathererEventHandler, OnGatheringCompleteHdlrFn, OnICEGathererStateChangeHdlrFn, + OnLocalCandidateHdlrFn, RTCIceGatherOptions, RTCIceGatherer, }; use crate::ice_transport::ice_gatherer_state::RTCIceGathererState; use crate::ice_transport::ice_gathering_state::RTCIceGatheringState; @@ -85,6 +85,7 @@ use crate::sctp_transport::RTCSctpTransport; use crate::stats::StatsReport; use crate::track::track_local::TrackLocal; use crate::track::track_remote::TrackRemote; +use util::{EventHandler, FutureUnit}; /// SIMULCAST_PROBE_COUNT is the amount of RTP Packets /// that handleUndeclaredSSRC will read and try to dispatch from @@ -170,7 +171,7 @@ struct CheckNegotiationNeededParams { #[derive(Clone)] struct NegotiationNeededParams { - on_negotiation_needed_handler: Arc>>, + events_handler: Arc>, is_closed: Arc, ops: Arc, negotiation_needed_state: Arc, @@ -179,6 +180,8 @@ struct NegotiationNeededParams { check_negotiation_needed_params: CheckNegotiationNeededParams, } +//TODO: move the rest of the impl stuff over to the new FutureUnit + /// PeerConnection represents a WebRTC connection that establishes a /// peer-to-peer communications with another PeerConnection instance in a /// browser, or to another endpoint implementing the required protocols. @@ -212,6 +215,169 @@ impl std::fmt::Display for RTCPeerConnection { } } +// NOTE: on_ice_candidate & on_ice_gathering_state_change are both from `IceGatherer` +// i also wouldnt doubt that other handlers are coming from other places as well. +pub trait PeerConnectionEventHandler: Send { + /// on_ice_candidate sets an event handler which is invoked when a new ICE + /// candidate is found. + /// Take note that the handler is gonna be called with a nil pointer when + /// gathering is finished. + fn on_ice_candidate( + &mut self, + ice_candidate: Option, + ) -> impl Future + Send { + async {} + } + + /// on_ice_gathering_state_change sets an event handler which is invoked when the + /// ICE candidate gathering state has changed. + fn on_ice_gathering_state_change( + &mut self, + state: RTCIceGathererState, + ) -> impl Future + Send { + async {} + } + + /// on_track sets an event handler which is called when remote track + /// arrives from a remote peer. + fn on_track( + &mut self, + track_remote: Arc, + receiver: Arc, + transceiver: Arc, + ) -> impl Future + Send { + async {} + } + + /// on_ice_connection_state_change sets an event handler which is called + /// when an ICE connection state is changed. + fn on_ice_connection_state_change( + &mut self, + state: RTCIceConnectionState, + ) -> impl Future + Send { + async {} + } + + /// on_peer_connection_state_change sets an event handler which is called + /// when the PeerConnectionState has changed + fn on_peer_connection_state_change( + &mut self, + state: RTCPeerConnectionState, + ) -> impl Future + Send { + async {} + } + + /// on_signaling_state_change sets an event handler which is invoked when the + /// peer connection's signaling state changes + fn on_signaling_state_change( + &mut self, + state: RTCSignalingState, + ) -> impl Future + Send { + async {} + } + + /// on_data_channel sets an event handler which is invoked when a data + /// channel message arrives from a remote peer. + fn on_data_channel(&mut self, channel: Arc) -> impl Future + Send { + async {} + } + + /// on_negotiation_needed sets an event handler which is invoked when + /// a change has occurred which requires session negotiation + fn on_negotiation_needed(&mut self) -> impl Future + Send { + async {} + } +} + +impl crate::sctp_transport::SctpTransportEventHandler + for Arc> +{ + fn on_data_channel( + &mut self, + data_channel: Arc, + ) -> impl Future + Send { + async move { + if let Some(handle) = &*self.load() { + let mut handle = handle.lock().await; + handle.inline_on_data_channel(data_channel).await + } + } + } +} + +trait InlinePeerConnectionEventHandler: Send { + fn inline_on_track( + &mut self, + track_remote: Arc, + receiver: Arc, + transceiver: Arc, + ) -> FutureUnit<'_>; + fn inline_on_ice_connection_state_change( + &mut self, + state: RTCIceConnectionState, + ) -> FutureUnit<'_>; + fn inline_on_peer_connection_state_change( + &mut self, + state: RTCPeerConnectionState, + ) -> FutureUnit<'_>; + fn inline_on_signaling_state_change(&mut self, state: RTCSignalingState) -> FutureUnit<'_>; + fn inline_on_data_channel(&mut self, channel: Arc) -> FutureUnit<'_>; + fn inline_on_negotiation_needed(&mut self) -> FutureUnit<'_>; +} + +impl InlinePeerConnectionEventHandler for T +where + T: PeerConnectionEventHandler, +{ + fn inline_on_track( + &mut self, + track_remote: Arc, + receiver: Arc, + transceiver: Arc, + ) -> FutureUnit<'_> { + FutureUnit::from_async( + async move { self.on_track(track_remote, receiver, transceiver).await }, + ) + } + fn inline_on_ice_connection_state_change( + &mut self, + state: RTCIceConnectionState, + ) -> FutureUnit<'_> { + FutureUnit::from_async(async move { self.on_ice_connection_state_change(state).await }) + } + fn inline_on_peer_connection_state_change( + &mut self, + state: RTCPeerConnectionState, + ) -> FutureUnit<'_> { + FutureUnit::from_async(async move { self.on_peer_connection_state_change(state).await }) + } + fn inline_on_signaling_state_change(&mut self, state: RTCSignalingState) -> FutureUnit<'_> { + FutureUnit::from_async(async move { self.on_signaling_state_change(state).await }) + } + fn inline_on_data_channel(&mut self, channel: Arc) -> FutureUnit<'_> { + FutureUnit::from_async(async move { self.on_data_channel(channel).await }) + } + fn inline_on_negotiation_needed(&mut self) -> FutureUnit<'_> { + FutureUnit::from_async(async move { self.on_negotiation_needed().await }) + } +} + +impl IceGathererEventHandler for T +where + T: PeerConnectionEventHandler, +{ + fn on_state_change(&mut self, state: RTCIceGathererState) -> impl Future + Send { + self.on_ice_gathering_state_change(state) + } + + fn on_local_candidate( + &mut self, + candidate: Option, + ) -> impl Future + Send { + self.on_ice_candidate(candidate) + } +} + impl RTCPeerConnection { /// creates a PeerConnection with the default codecs and /// interceptors. See register_default_codecs and register_default_interceptors. @@ -286,38 +452,21 @@ impl RTCPeerConnection { Ok(()) } - /// on_signaling_state_change sets an event handler which is invoked when the - /// peer connection's signaling state changes - pub fn on_signaling_state_change(&self, f: OnSignalingStateChangeHdlrFn) { - self.internal - .on_signaling_state_change_handler - .store(Some(Arc::new(Mutex::new(f)))) + pub fn with_event_handler( + &self, + handler: impl PeerConnectionEventHandler + Send + Sync + 'static, + ) { + self.internal.events_handler.store(Box::new(handler)) } async fn do_signaling_state_change(&self, new_state: RTCSignalingState) { log::info!("signaling state changed to {}", new_state); - if let Some(handler) = &*self.internal.on_signaling_state_change_handler.load() { - let mut f = handler.lock().await; - f(new_state).await; + if let Some(handler) = &*self.internal.events_handler.load() { + let mut handler = handler.lock().await; + handler.inline_on_signaling_state_change(new_state).await; } } - /// on_data_channel sets an event handler which is invoked when a data - /// channel message arrives from a remote peer. - pub fn on_data_channel(&self, f: OnDataChannelHdlrFn) { - self.internal - .on_data_channel_handler - .store(Some(Arc::new(Mutex::new(f)))); - } - - /// on_negotiation_needed sets an event handler which is invoked when - /// a change has occurred which requires session negotiation - pub fn on_negotiation_needed(&self, f: OnNegotiationNeededHdlrFn) { - self.internal - .on_negotiation_needed_handler - .store(Some(Arc::new(Mutex::new(f)))); - } - fn do_negotiation_needed_inner(params: &NegotiationNeededParams) -> bool { // https://w3c.github.io/webrtc-pc/#updating-the-negotiation-needed-flag // non-canon step 1 @@ -375,10 +524,9 @@ impl RTCPeerConnection { async fn negotiation_needed_op(params: NegotiationNeededParams) -> bool { // Don't run NegotiatedNeeded checks if on_negotiation_needed is not set - let handler = &*params.on_negotiation_needed_handler.load(); - if handler.is_none() { + let Some(handler) = &*params.events_handler.load() else { return false; - } + }; // https://www.w3.org/TR/webrtc/#updating-the-negotiation-needed-flag // Step 2.1 @@ -416,10 +564,8 @@ impl RTCPeerConnection { params.is_negotiation_needed.store(true, Ordering::SeqCst); // Step 2.7 - if let Some(handler) = handler { - let mut f = handler.lock().await; - f().await; - } + let mut handle = handler.lock().await; + handle.inline_on_negotiation_needed().await; RTCPeerConnection::after_negotiation_needed_op(params).await } @@ -562,28 +708,6 @@ impl RTCPeerConnection { } } - /// on_ice_candidate sets an event handler which is invoked when a new ICE - /// candidate is found. - /// Take note that the handler is gonna be called with a nil pointer when - /// gathering is finished. - pub fn on_ice_candidate(&self, f: OnLocalCandidateHdlrFn) { - self.internal.ice_gatherer.on_local_candidate(f) - } - - /// on_ice_gathering_state_change sets an event handler which is invoked when the - /// ICE candidate gathering state has changed. - pub fn on_ice_gathering_state_change(&self, f: OnICEGathererStateChangeHdlrFn) { - self.internal.ice_gatherer.on_state_change(f) - } - - /// on_track sets an event handler which is called when remote track - /// arrives from a remote peer. - pub fn on_track(&self, f: OnTrackHdlrFn) { - self.internal - .on_track_handler - .store(Some(Arc::new(Mutex::new(f)))); - } - fn do_track( on_track_handler: Arc>>, track: Arc, @@ -602,19 +726,33 @@ impl RTCPeerConnection { }); } - /// on_ice_connection_state_change sets an event handler which is called - /// when an ICE connection state is changed. - pub fn on_ice_connection_state_change(&self, f: OnICEConnectionStateChangeHdlrFn) { - self.internal - .on_ice_connection_state_change_handler - .store(Some(Arc::new(Mutex::new(f)))); - } - async fn do_ice_connection_state_change( handler: &Arc>>, ice_connection_state: &Arc, cs: RTCIceConnectionState, ) { + //FIXME: there should be some wrapper for a state + //eg: + //struct State where T: Into + From { + // inner: std::sync::AtomicU8, + // _: core::marker::PhantomData, + //} + //impl State where T: Into + From { + // fn new(state: T) -> Self { + // inner: state, + // _: core::marker::PhantomData, + // } + // + // fn load(&self) -> T { + // self.inner.load(Ordering::SeqCst).into(); + // } + // fn store(&mut self, state: T) { + // self.inner.store(state.into(), Ordering::SeqCst); + // } + //} + // for better type system and so that you dont need to cast + // to/from u8 each time + ice_connection_state.store(cs as u8, Ordering::SeqCst); log::info!("ICE connection state changed: {}", cs); @@ -624,14 +762,6 @@ impl RTCPeerConnection { } } - /// on_peer_connection_state_change sets an event handler which is called - /// when the PeerConnectionState has changed - pub fn on_peer_connection_state_change(&self, f: OnPeerConnectionStateChangeHdlrFn) { - self.internal - .on_peer_connection_state_change_handler - .store(Some(Arc::new(Mutex::new(f)))); - } - async fn do_peer_connection_state_change( handler: &Arc>>, cs: RTCPeerConnectionState, @@ -907,6 +1037,56 @@ impl RTCPeerConnection { .await; } + fn updated_connection_state( + is_closed: &Arc, + peer_connection_state: &Arc, + ice_connection_state: RTCIceConnectionState, + dtls_transport_state: RTCDtlsTransportState, + ) -> Option { + let connection_state = + // The RTCPeerConnection object's [[IsClosed]] slot is true. + if is_closed.load(Ordering::SeqCst) { + RTCPeerConnectionState::Closed + } else if ice_connection_state == RTCIceConnectionState::Failed || dtls_transport_state == RTCDtlsTransportState::Failed { + // Any of the RTCIceTransports or RTCDtlsTransports are in a "failed" state. + RTCPeerConnectionState::Failed + } else if ice_connection_state == RTCIceConnectionState::Disconnected { + // Any of the RTCIceTransports or RTCDtlsTransports are in the "disconnected" + // state and none of them are in the "failed" or "connecting" or "checking" state. + RTCPeerConnectionState::Disconnected + } else if ice_connection_state == RTCIceConnectionState::Connected && dtls_transport_state == RTCDtlsTransportState::Connected { + // All RTCIceTransports and RTCDtlsTransports are in the "connected", "completed" or "closed" + // state and at least one of them is in the "connected" or "completed" state. + RTCPeerConnectionState::Connected + } else if ice_connection_state == RTCIceConnectionState::Checking && dtls_transport_state == RTCDtlsTransportState::Connecting { + // Any of the RTCIceTransports or RTCDtlsTransports are in the "connecting" or + // "checking" state and none of them is in the "failed" state. + RTCPeerConnectionState::Connecting + } else { + RTCPeerConnectionState::New + }; + /* TODO: move this to a seperate fn + use RTCPeerConnectionState as PCS; + use RTCIceConnectionState as ICS; + use RTCDtlsTransportState as DCS; + let connection_state = if is_closed.load(Ordering::SeqCst) { + PCS::Closed + }else { + match (ice_connection_state, dtls_transport_state) { + (ICS::Failed, _) | (_, DCS::Failed) => PCS::Failed, + (ICS::Disconnected, _) => PCS::Disconnected, + (ICS::Connected, DCS::Connected) => PCS::Connected, + (ICS::Checking, DCS::Connecting) => PCS::Connecting, + _ => PCS::New + } + }; + */ + if peer_connection_state.load(Ordering::SeqCst) == connection_state as u8 { + return None; + } + Some(connection_state) + } + /// create_answer starts the PeerConnection and generates the localDescription pub async fn create_answer( &self, @@ -1964,14 +2144,23 @@ impl RTCPeerConnection { } // https://www.w3.org/TR/webrtc/#dom-rtcpeerconnection-close (step #11) - RTCPeerConnection::update_connection_state( - &self.internal.on_peer_connection_state_change_handler, - &self.internal.is_closed, - &self.internal.peer_connection_state, - self.ice_connection_state(), - self.internal.dtls_transport.state(), - ) - .await; + match ( + &*self.internal.events_handler.load(), + RTCPeerConnection::updated_connection_state( + &self.internal.is_closed, + &self.internal.peer_connection_state, + self.ice_connection_state(), + self.internal.dtls_transport.state(), + ), + ) { + (Some(handler), Some(changed_state)) => { + let mut handler = handler.lock().await; + handler + .inline_on_peer_connection_state_change(changed_state) + .await; + } + _ => (), + } if let Err(err) = self.internal.ops.close().await { close_errs.push(Error::new(format!("ops: {err}"))); diff --git a/webrtc/src/peer_connection/peer_connection_internal.rs b/webrtc/src/peer_connection/peer_connection_internal.rs index a02bd20c6..08230e91a 100644 --- a/webrtc/src/peer_connection/peer_connection_internal.rs +++ b/webrtc/src/peer_connection/peer_connection_internal.rs @@ -5,7 +5,7 @@ use std::sync::Weak; use arc_swap::ArcSwapOption; use smol_str::SmolStr; use tokio::time::Instant; -use util::Unmarshal; +use util::{EventHandler, Unmarshal}; use super::*; use crate::rtp_transceiver::create_stream_info; @@ -27,7 +27,6 @@ pub(crate) struct PeerConnectionInternal { pub(super) last_offer: Mutex, pub(super) last_answer: Mutex, - pub(super) on_negotiation_needed_handler: Arc>>, pub(super) is_closed: Arc, /// ops is an operations queue which will ensure the enqueued actions are @@ -40,20 +39,14 @@ pub(crate) struct PeerConnectionInternal { pub(super) ice_transport: Arc, pub(super) dtls_transport: Arc, - pub(super) on_peer_connection_state_change_handler: - Arc>>, pub(super) peer_connection_state: Arc, pub(super) ice_connection_state: Arc, pub(super) sctp_transport: Arc, pub(super) rtp_transceivers: Arc>>>, - pub(super) on_track_handler: Arc>>, - pub(super) on_signaling_state_change_handler: - ArcSwapOption>, - pub(super) on_ice_connection_state_change_handler: - Arc>>, - pub(super) on_data_channel_handler: Arc>>, + pub(super) events_handler: + Arc>, pub(super) ice_gatherer: Arc, @@ -69,6 +62,61 @@ pub(crate) struct PeerConnectionInternal { stats_interceptor: Arc, } +use crate::ice_transport::IceTransportEventHandler; + +struct IceTransportHandler { + events_handler: Arc>, + is_closed: Arc, + peer_connection_state: Arc, + dtls_transport: Arc, + ice_connection_state: Arc, +} + +impl IceTransportEventHandler for IceTransportHandler { + fn on_connection_state_change( + &mut self, + state: RTCIceTransportState, + ) -> impl Future + Send { + async move { + //FIXME: this should be moved into a seperate fn + let conn_state = match state { + RTCIceTransportState::New => RTCIceConnectionState::New, + RTCIceTransportState::Checking => RTCIceConnectionState::Checking, + RTCIceTransportState::Connected => RTCIceConnectionState::Connected, + RTCIceTransportState::Completed => RTCIceConnectionState::Completed, + RTCIceTransportState::Failed => RTCIceConnectionState::Failed, + RTCIceTransportState::Disconnected => RTCIceConnectionState::Disconnected, + RTCIceTransportState::Closed => RTCIceConnectionState::Closed, + _ => { + log::warn!("on_connection_state_change: unhandled ICE state: {}", state); + return; + } + }; + + if let Some(handler) = &*self.events_handler.load() { + let mut handler = handler.lock().await; + if self.ice_connection_state.load(Ordering::SeqCst) != conn_state as u8 { + self.ice_connection_state + .store(conn_state as u8, Ordering::SeqCst); + handler + .inline_on_ice_connection_state_change(conn_state) + .await + } + if let Some(changed_state) = RTCPeerConnection::updated_connection_state( + &self.is_closed, + &self.peer_connection_state, + conn_state, + self.dtls_transport.state(), + ) { + handler + .inline_on_peer_connection_state_change(changed_state) + .await; + } + } + } + } +} + impl PeerConnectionInternal { pub(super) async fn new( api: &API, @@ -82,7 +130,6 @@ impl PeerConnectionInternal { last_offer: Mutex::new("".to_owned()), last_answer: Mutex::new("".to_owned()), - on_negotiation_needed_handler: Arc::new(ArcSwapOption::empty()), ops: Arc::new(Operations::new()), is_closed: Arc::new(AtomicBool::new(false)), is_negotiation_needed: Arc::new(AtomicBool::new(false)), @@ -93,10 +140,6 @@ impl PeerConnectionInternal { ice_connection_state: Arc::new(AtomicU8::new(RTCIceConnectionState::New as u8)), sctp_transport: Arc::new(Default::default()), rtp_transceivers: Arc::new(Default::default()), - on_track_handler: Arc::new(ArcSwapOption::empty()), - on_signaling_state_change_handler: ArcSwapOption::empty(), - on_ice_connection_state_change_handler: Arc::new(ArcSwapOption::empty()), - on_data_channel_handler: Arc::new(Default::default()), ice_gatherer: Arc::new(Default::default()), current_local_description: Arc::new(Default::default()), current_remote_description: Arc::new(Default::default()), @@ -111,7 +154,7 @@ impl PeerConnectionInternal { }, interceptor, stats_interceptor, - on_peer_connection_state_change_handler: Arc::new(ArcSwapOption::empty()), + events_handler: Arc::new(EventHandler::empty()), pending_remote_description: Arc::new(Default::default()), }; @@ -133,17 +176,8 @@ impl PeerConnectionInternal { pc.sctp_transport = Arc::new(api.new_sctp_transport(Arc::clone(&pc.dtls_transport))?); // Wire up the on datachannel handler - let on_data_channel_handler = Arc::clone(&pc.on_data_channel_handler); pc.sctp_transport - .on_data_channel(Box::new(move |d: Arc| { - let on_data_channel_handler2 = Arc::clone(&on_data_channel_handler); - Box::pin(async move { - if let Some(handler) = &*on_data_channel_handler2.load() { - let mut f = handler.lock().await; - f(d).await; - } - }) - })); + .with_event_handler(pc.events_handler.clone()); Ok((Arc::new(pc), configuration)) } @@ -373,8 +407,8 @@ impl PeerConnectionInternal { self.setting_engine.get_receive_mtu(), incoming_track, receiver, - Arc::clone(t), - Arc::clone(&self.on_track_handler), + t.clone(), + self.events_handler.clone(), ) .await; track_handled = true; @@ -551,7 +585,8 @@ impl PeerConnectionInternal { /// Creates the parameters needed to trigger a negotiation needed. fn create_negotiation_needed_params(&self) -> NegotiationNeededParams { NegotiationNeededParams { - on_negotiation_needed_handler: Arc::clone(&self.on_negotiation_needed_handler), + //on_negotiation_needed_handler: Arc::clone(&self.on_negotiation_needed_handler), + events_handler: self.events_handler.clone(), is_closed: Arc::clone(&self.is_closed), ops: Arc::clone(&self.ops), negotiation_needed_state: Arc::clone(&self.negotiation_needed_state), @@ -590,7 +625,8 @@ impl PeerConnectionInternal { } pub(super) fn set_gather_complete_handler(&self, f: OnGatheringCompleteHdlrFn) { - self.ice_gatherer.on_gathering_complete(f); + //TODO: + //self.ice_gatherer.on_gathering_complete(f); } /// Start all transports. PeerConnection now has enough state @@ -631,14 +667,25 @@ impl PeerConnectionInternal { }], }) .await; - RTCPeerConnection::update_connection_state( - &self.on_peer_connection_state_change_handler, - &self.is_closed, - &self.peer_connection_state, - self.ice_connection_state.load(Ordering::SeqCst).into(), - self.dtls_transport.state(), - ) - .await; + + match ( + &*self.events_handler.load(), + RTCPeerConnection::updated_connection_state( + &self.is_closed, + &self.peer_connection_state, + self.ice_connection_state.load(Ordering::SeqCst).into(), + self.dtls_transport.state(), + ), + ) { + (Some(handler), Some(changed_state)) => { + let mut handler = handler.lock().await; + handler + .inline_on_peer_connection_state_change(changed_state) + .await; + } + _ => (), + } + if let Err(err) = result { log::warn!("Failed to start manager dtls: {}", err); } @@ -894,7 +941,7 @@ impl PeerConnectionInternal { &incoming, receiver, t, - Arc::clone(&self.on_track_handler), + self.events_handler.clone(), ) .await; Ok(true) @@ -1042,12 +1089,10 @@ impl PeerConnectionInternal { .await?; track.prepopulate_peeked_data(buffered_packets).await; - RTCPeerConnection::do_track( - Arc::clone(&self.on_track_handler), - track, - receiver, - Arc::clone(t), - ); + if let Some(handler) = &*self.events_handler.load() { + let mut handler = handler.lock().await; + handler.inline_on_track(track, receiver, t.clone()).await + } return Ok(()); } } @@ -1065,7 +1110,7 @@ impl PeerConnectionInternal { incoming: &TrackDetails, receiver: Arc, transceiver: Arc, - on_track_handler: Arc>>, + events_handler: Arc>, ) { receiver.start(incoming).await; for t in receiver.tracks().await { @@ -1075,7 +1120,7 @@ impl PeerConnectionInternal { let receiver = Arc::clone(&receiver); let transceiver = Arc::clone(&transceiver); - let on_track_handler = Arc::clone(&on_track_handler); + let track_handler = events_handler.clone(); tokio::spawn(async move { if let Some(track) = receiver.track().await { let mut b = vec![0u8; receive_mtu]; @@ -1100,66 +1145,28 @@ impl PeerConnectionInternal { return; } - RTCPeerConnection::do_track(on_track_handler, track, receiver, transceiver); + if let Some(handler) = &*track_handler.load() { + let mut handle = handler.lock().await; + handle.inline_on_track(track, receiver, transceiver).await; + } } }); } } + fn ice_transport_handler(&self) -> IceTransportHandler { + IceTransportHandler { + events_handler: self.events_handler.clone(), + is_closed: self.is_closed.clone(), + dtls_transport: self.dtls_transport.clone(), + peer_connection_state: self.peer_connection_state.clone(), + ice_connection_state: self.ice_connection_state.clone(), + } + } + pub(super) async fn create_ice_transport(&self, api: &API) -> Arc { let ice_transport = Arc::new(api.new_ice_transport(Arc::clone(&self.ice_gatherer))); - - let ice_connection_state = Arc::clone(&self.ice_connection_state); - let peer_connection_state = Arc::clone(&self.peer_connection_state); - let is_closed = Arc::clone(&self.is_closed); - let dtls_transport = Arc::clone(&self.dtls_transport); - let on_ice_connection_state_change_handler = - Arc::clone(&self.on_ice_connection_state_change_handler); - let on_peer_connection_state_change_handler = - Arc::clone(&self.on_peer_connection_state_change_handler); - - ice_transport.on_connection_state_change(Box::new(move |state: RTCIceTransportState| { - let cs = match state { - RTCIceTransportState::New => RTCIceConnectionState::New, - RTCIceTransportState::Checking => RTCIceConnectionState::Checking, - RTCIceTransportState::Connected => RTCIceConnectionState::Connected, - RTCIceTransportState::Completed => RTCIceConnectionState::Completed, - RTCIceTransportState::Failed => RTCIceConnectionState::Failed, - RTCIceTransportState::Disconnected => RTCIceConnectionState::Disconnected, - RTCIceTransportState::Closed => RTCIceConnectionState::Closed, - _ => { - log::warn!("on_connection_state_change: unhandled ICE state: {}", state); - return Box::pin(async {}); - } - }; - - let ice_connection_state2 = Arc::clone(&ice_connection_state); - let on_ice_connection_state_change_handler2 = - Arc::clone(&on_ice_connection_state_change_handler); - let on_peer_connection_state_change_handler2 = - Arc::clone(&on_peer_connection_state_change_handler); - let is_closed2 = Arc::clone(&is_closed); - let dtls_transport_state = dtls_transport.state(); - let peer_connection_state2 = Arc::clone(&peer_connection_state); - Box::pin(async move { - RTCPeerConnection::do_ice_connection_state_change( - &on_ice_connection_state_change_handler2, - &ice_connection_state2, - cs, - ) - .await; - - RTCPeerConnection::update_connection_state( - &on_peer_connection_state_change_handler2, - &is_closed2, - &peer_connection_state2, - cs, - dtls_transport_state, - ) - .await; - }) - })); - + ice_transport.with_event_handler(self.ice_transport_handler()); ice_transport } diff --git a/webrtc/src/peer_connection/peer_connection_test.rs b/webrtc/src/peer_connection/peer_connection_test.rs index 6073054de..320a83339 100644 --- a/webrtc/src/peer_connection/peer_connection_test.rs +++ b/webrtc/src/peer_connection/peer_connection_test.rs @@ -7,7 +7,7 @@ use media::Sample; use tokio::time::Duration; use util::vnet::net::{Net, NetConfig}; use util::vnet::router::{Router, RouterConfig}; -use waitgroup::WaitGroup; +use waitgroup::{WaitGroup, Worker}; use super::*; use crate::api::interceptor_registry::register_default_interceptors; @@ -15,6 +15,7 @@ use crate::api::media_engine::{MediaEngine, MIME_TYPE_VP8}; use crate::api::APIBuilder; use crate::ice_transport::ice_candidate_pair::RTCIceCandidatePair; use crate::ice_transport::ice_server::RTCIceServer; +use crate::ice_transport::IceTransportEventHandler; use crate::peer_connection::configuration::RTCConfiguration; use crate::rtp_transceiver::rtp_codec::RTCRtpCodecCapability; use crate::stats::StatsReportType; @@ -249,12 +250,16 @@ pub(crate) async fn send_video_until_done( } } +/* pub(crate) async fn until_connection_state( pc: &mut RTCPeerConnection, wg: &WaitGroup, state: RTCPeerConnectionState, ) { let w = Arc::new(Mutex::new(Some(wg.worker()))); + + + pc.on_peer_connection_state_change(Box::new(move |pcs: RTCPeerConnectionState| { let w2 = Arc::clone(&w); Box::pin(async move { @@ -265,6 +270,25 @@ pub(crate) async fn until_connection_state( }) })); } +*/ + +pub struct StateHandler { + pub worker: Arc>>, +} + +impl PeerConnectionEventHandler for StateHandler { + fn on_peer_connection_state_change( + &mut self, + state: RTCPeerConnectionState, + ) -> impl Future + Send { + async move { + if state == RTCPeerConnectionState::Connected { + let mut worker = self.worker.lock().await; + worker.take(); + } + } + } +} #[tokio::test] async fn test_get_stats() -> Result<()> { @@ -276,27 +300,74 @@ async fn test_get_stats() -> Result<()> { let (ice_complete_tx, mut ice_complete_rx) = mpsc::channel::<()>(1); let ice_complete_tx = Arc::new(Mutex::new(Some(ice_complete_tx))); - pc_answer.on_ice_connection_state_change(Box::new(move |ice_state: RTCIceConnectionState| { - let ice_complete_tx2 = Arc::clone(&ice_complete_tx); - Box::pin(async move { - if ice_state == RTCIceConnectionState::Connected { + + struct AnswerHandler { + ice_complete_tx: Arc>>>, + packet_tx: mpsc::Sender<()>, + } + + impl PeerConnectionEventHandler for AnswerHandler { + fn on_ice_connection_state_change( + &mut self, + _: RTCIceConnectionState, + ) -> impl Future + Send { + async move { tokio::time::sleep(Duration::from_secs(1)).await; - let mut done = ice_complete_tx2.lock().await; + let mut done = self.ice_complete_tx.lock().await; done.take(); } - }) - })); + } + + fn on_track( + &mut self, + track: Arc, + _: Arc, + _: Arc, + ) -> impl Future + Send { + let packet_tx = self.packet_tx.clone(); + tokio::spawn(async move { + while let Ok((pkt, _)) = track.read_rtp().await { + dbg!(&pkt); + let last = pkt.payload[pkt.payload.len() - 1]; + + if last == 0xAA { + let _ = packet_tx.send(()).await; + break; + } + } + }); + async {} + } + } + let (packet_tx, packet_rx) = mpsc::channel(1); + pc_answer.with_event_handler(AnswerHandler { + ice_complete_tx, + packet_tx, + }); + + struct OfferHandler { + sender_called_candidate_change: Arc, + } + + impl IceTransportEventHandler for OfferHandler { + fn on_selected_candidate_pair_change( + &mut self, + _: RTCIceCandidatePair, + ) -> impl Future + Send { + self.sender_called_candidate_change + .store(1, Ordering::SeqCst); + async {} + } + } let sender_called_candidate_change = Arc::new(AtomicU32::new(0)); - let sender_called_candidate_change2 = Arc::clone(&sender_called_candidate_change); pc_offer .sctp() .transport() .ice_transport() - .on_selected_candidate_pair_change(Box::new(move |_: RTCIceCandidatePair| { - sender_called_candidate_change2.store(1, Ordering::SeqCst); - Box::pin(async {}) - })); + .with_event_handler(OfferHandler { + sender_called_candidate_change, + }); let track = Arc::new(TrackLocalStaticSample::new( RTCRtpCodecCapability { mime_type: MIME_TYPE_VP8.to_owned(), @@ -309,24 +380,6 @@ async fn test_get_stats() -> Result<()> { .add_track(track.clone()) .await .expect("Failed to add track"); - let (packet_tx, packet_rx) = mpsc::channel(1); - - pc_answer.on_track(Box::new(move |track, _, _| { - let packet_tx = packet_tx.clone(); - tokio::spawn(async move { - while let Ok((pkt, _)) = track.read_rtp().await { - dbg!(&pkt); - let last = pkt.payload[pkt.payload.len() - 1]; - - if last == 0xAA { - let _ = packet_tx.send(()).await; - break; - } - } - }); - - Box::pin(async move {}) - })); signal_pair(&mut pc_offer, &mut pc_answer).await?; diff --git a/webrtc/src/rtp_transceiver/rtp_receiver/rtp_receiver_test.rs b/webrtc/src/rtp_transceiver/rtp_receiver/rtp_receiver_test.rs index 304bff8f7..cf981e07d 100644 --- a/webrtc/src/rtp_transceiver/rtp_receiver/rtp_receiver_test.rs +++ b/webrtc/src/rtp_transceiver/rtp_receiver/rtp_receiver_test.rs @@ -1,18 +1,21 @@ use bytes::Bytes; use media::Sample; +use std::future::Future; use tokio::sync::mpsc; use tokio::time::Duration; -use waitgroup::WaitGroup; +use waitgroup::{WaitGroup, Worker}; use super::*; use crate::api::media_engine::{MIME_TYPE_OPUS, MIME_TYPE_VP8}; use crate::error::Result; use crate::peer_connection::peer_connection_state::RTCPeerConnectionState; use crate::peer_connection::peer_connection_test::{ - close_pair_now, create_vnet_pair, signal_pair, until_connection_state, + close_pair_now, create_vnet_pair, signal_pair, StateHandler, }; +use crate::peer_connection::PeerConnectionEventHandler; use crate::rtp_transceiver::rtp_codec::RTCRtpHeaderExtensionParameters; use crate::rtp_transceiver::RTCPFeedback; +use crate::rtp_transceiver::RTCRtpTransceiver; use crate::track::track_local::track_local_static_sample::TrackLocalStaticSample; use crate::track::track_local::TrackLocal; @@ -87,49 +90,73 @@ async fn test_set_rtp_parameters() -> Result<()> { let (seen_packet_tx, mut seen_packet_rx) = mpsc::channel::<()>(1); let seen_packet_tx = Arc::new(Mutex::new(Some(seen_packet_tx))); - receiver.on_track(Box::new(move |_, receiver, _| { - let seen_packet_tx2 = Arc::clone(&seen_packet_tx); - Box::pin(async move { - receiver.set_rtp_parameters(P.clone()).await; - - if let Some(t) = receiver.track().await { - let incoming_track_codecs = t.codec(); - - assert_eq!(P.header_extensions, t.params().header_extensions); - assert_eq!( - P.codecs[0].capability.mime_type, - incoming_track_codecs.capability.mime_type - ); - assert_eq!( - P.codecs[0].capability.clock_rate, - incoming_track_codecs.capability.clock_rate - ); - assert_eq!( - P.codecs[0].capability.channels, - incoming_track_codecs.capability.channels - ); - assert_eq!( - P.codecs[0].capability.sdp_fmtp_line, - incoming_track_codecs.capability.sdp_fmtp_line - ); - assert_eq!( - P.codecs[0].capability.rtcp_feedback, - incoming_track_codecs.capability.rtcp_feedback - ); - assert_eq!(P.codecs[0].payload_type, incoming_track_codecs.payload_type); - { - let mut done = seen_packet_tx2.lock().await; + struct TrackHandler { + seen_packet_tx: Arc>>>, + worker: Arc>>, + } + + impl PeerConnectionEventHandler for TrackHandler { + fn on_track( + &mut self, + _: Arc, + receiver: Arc, + _: Arc, + ) -> impl Future + Send { + async move { + receiver.set_rtp_parameters(P.clone()).await; + if let Some(track) = receiver.track().await { + let incoming_track_codecs = track.codec(); + assert_eq!(P.header_extensions, track.params().header_extensions); + assert_eq!( + P.codecs[0].capability.mime_type, + incoming_track_codecs.capability.mime_type + ); + assert_eq!( + P.codecs[0].capability.clock_rate, + incoming_track_codecs.capability.clock_rate + ); + assert_eq!( + P.codecs[0].capability.channels, + incoming_track_codecs.capability.channels + ); + assert_eq!( + P.codecs[0].capability.sdp_fmtp_line, + incoming_track_codecs.capability.sdp_fmtp_line + ); + assert_eq!( + P.codecs[0].capability.rtcp_feedback, + incoming_track_codecs.capability.rtcp_feedback + ); + assert_eq!(P.codecs[0].payload_type, incoming_track_codecs.payload_type); + + let mut done = self.seen_packet_tx.lock().await; done.take(); } } - }) - })); + } + + fn on_peer_connection_state_change( + &mut self, + state: RTCPeerConnectionState, + ) -> impl Future + Send { + async move { + if state == RTCPeerConnectionState::Connected { + let mut worker = self.worker.lock().await; + worker.take(); + } + } + } + } let wg = WaitGroup::new(); - - until_connection_state(&mut sender, &wg, RTCPeerConnectionState::Connected).await; - until_connection_state(&mut receiver, &wg, RTCPeerConnectionState::Connected).await; + receiver.with_event_handler(TrackHandler { + seen_packet_tx: seen_packet_tx.clone(), + worker: Arc::new(Mutex::new(Some(wg.worker()))), + }); + sender.with_event_handler(StateHandler { + worker: Arc::new(Mutex::new(Some(wg.worker()))), + }); signal_pair(&mut sender, &mut receiver).await?; @@ -178,32 +205,61 @@ async fn test_rtp_receiver_set_read_deadline() -> Result<()> { let (seen_packet_tx, mut seen_packet_rx) = mpsc::channel::<()>(1); let seen_packet_tx = Arc::new(Mutex::new(Some(seen_packet_tx))); - receiver.on_track(Box::new(move |track, receiver, _| { - let seen_packet_tx2 = Arc::clone(&seen_packet_tx); - Box::pin(async move { - // First call will not error because we cache for probing - let result = tokio::time::timeout(Duration::from_secs(1), track.read_rtp()).await; - assert!( - result.is_ok(), - " First call will not error because we cache for probing" - ); - - let result = tokio::time::timeout(Duration::from_secs(1), track.read_rtp()).await; - assert!(result.is_err()); - - let result = tokio::time::timeout(Duration::from_secs(1), receiver.read_rtcp()).await; - assert!(result.is_err()); - - { - let mut done = seen_packet_tx2.lock().await; + + struct TrackHandler { + seen_packet_tx: Arc>>>, + worker: Arc>>, + } + + impl PeerConnectionEventHandler for TrackHandler { + fn on_track( + &mut self, + track: Arc, + receiver: Arc, + _: Arc, + ) -> impl Future + Send { + async move { + // First call will not error because we cache for probing + let result = tokio::time::timeout(Duration::from_secs(1), track.read_rtp()).await; + assert!( + result.is_ok(), + " First call will not error because we cache for probing" + ); + + let result = tokio::time::timeout(Duration::from_secs(1), track.read_rtp()).await; + assert!(result.is_err()); + + let result = + tokio::time::timeout(Duration::from_secs(1), receiver.read_rtcp()).await; + assert!(result.is_err()); + + let mut done = self.seen_packet_tx.lock().await; done.take(); } - }) - })); + } + + fn on_peer_connection_state_change( + &mut self, + state: RTCPeerConnectionState, + ) -> impl Future + Send { + async move { + if state == RTCPeerConnectionState::Connected { + let mut worker = self.worker.lock().await; + worker.take(); + } + } + } + } let wg = WaitGroup::new(); - until_connection_state(&mut sender, &wg, RTCPeerConnectionState::Connected).await; - until_connection_state(&mut receiver, &wg, RTCPeerConnectionState::Connected).await; + receiver.with_event_handler(TrackHandler { + seen_packet_tx: seen_packet_tx.clone(), + worker: Arc::new(Mutex::new(Some(wg.worker()))), + }); + + sender.with_event_handler(StateHandler { + worker: Arc::new(Mutex::new(Some(wg.worker()))), + }); signal_pair(&mut sender, &mut receiver).await?; diff --git a/webrtc/src/rtp_transceiver/rtp_sender/rtp_sender_test.rs b/webrtc/src/rtp_transceiver/rtp_sender/rtp_sender_test.rs index d06c6b30c..07601c8c8 100644 --- a/webrtc/src/rtp_transceiver/rtp_sender/rtp_sender_test.rs +++ b/webrtc/src/rtp_transceiver/rtp_sender/rtp_sender_test.rs @@ -1,6 +1,7 @@ use std::sync::atomic::AtomicU64; use bytes::Bytes; +use std::future::Future; use tokio::time::Duration; use waitgroup::WaitGroup; @@ -11,10 +12,12 @@ use crate::api::APIBuilder; use crate::error::Result; use crate::peer_connection::peer_connection_state::RTCPeerConnectionState; use crate::peer_connection::peer_connection_test::{ - close_pair_now, create_vnet_pair, new_pair, send_video_until_done, signal_pair, - until_connection_state, + close_pair_now, create_vnet_pair, new_pair, send_video_until_done, signal_pair, StateHandler, }; +use crate::peer_connection::PeerConnectionEventHandler; use crate::rtp_transceiver::rtp_codec::RTCRtpCodecCapability; +use crate::rtp_transceiver::{RTCRtpReceiver, RTCRtpTransceiver}; + use crate::track::track_local::track_local_static_sample::TrackLocalStaticSample; #[tokio::test] @@ -60,32 +63,49 @@ async fn test_rtp_sender_replace_track() -> Result<()> { let seen_packet_a_tx = Arc::new(seen_packet_a_tx); let seen_packet_b_tx = Arc::new(seen_packet_b_tx); let on_track_count = Arc::new(AtomicU64::new(0)); - receiver.on_track(Box::new(move |track, _, _| { - assert_eq!(on_track_count.fetch_add(1, Ordering::SeqCst), 0); - let seen_packet_a_tx2 = Arc::clone(&seen_packet_a_tx); - let seen_packet_b_tx2 = Arc::clone(&seen_packet_b_tx); - Box::pin(async move { - let pkt = match track.read_rtp().await { - Ok((pkt, _)) => pkt, - Err(err) => { - //assert!(errors.Is(io.EOF, err)) - log::debug!("{}", err); - return; + + struct TrackHandler { + seen_packet_a: Arc>, + seen_packet_b: Arc>, + track_count: Arc, + } + + impl PeerConnectionEventHandler for TrackHandler { + fn on_track( + &mut self, + track: Arc, + _: Arc, + _: Arc, + ) -> impl Future + Send { + async move { + assert_eq!(self.track_count.fetch_add(1, Ordering::SeqCst), 0); + let pkt = match track.read_rtp().await { + Ok((pkt, _)) => pkt, + Err(err) => { + //assert!(errors.Is(io.EOF, err)) + log::debug!("{}", err); + return; + } + }; + + let last = pkt.payload[pkt.payload.len() - 1]; + if last == 0xAA { + assert_eq!(track.codec().capability.mime_type, MIME_TYPE_VP8); + let _ = self.seen_packet_a.send(()).await; + } else if last == 0xBB { + assert_eq!(track.codec().capability.mime_type, MIME_TYPE_H264); + let _ = self.seen_packet_b.send(()).await; + } else { + panic!("Unexpected RTP Data {last:02x}"); } - }; - - let last = pkt.payload[pkt.payload.len() - 1]; - if last == 0xAA { - assert_eq!(track.codec().capability.mime_type, MIME_TYPE_VP8); - let _ = seen_packet_a_tx2.send(()).await; - } else if last == 0xBB { - assert_eq!(track.codec().capability.mime_type, MIME_TYPE_H264); - let _ = seen_packet_b_tx2.send(()).await; - } else { - panic!("Unexpected RTP Data {last:02x}"); } - }) - })); + } + } + receiver.with_event_handler(TrackHandler { + seen_packet_a: seen_packet_a_tx, + seen_packet_b: seen_packet_b_tx, + track_count: on_track_count.clone(), + }); signal_pair(&mut sender, &mut receiver).await?; @@ -163,18 +183,12 @@ async fn test_rtp_sender_set_read_deadline() -> Result<()> { .await?; let peer_connections_connected = WaitGroup::new(); - until_connection_state( - &mut sender, - &peer_connections_connected, - RTCPeerConnectionState::Connected, - ) - .await; - until_connection_state( - &mut receiver, - &peer_connections_connected, - RTCPeerConnectionState::Connected, - ) - .await; + sender.with_event_handler(StateHandler { + worker: Arc::new(Mutex::new(Some(peer_connections_connected.worker()))), + }); + receiver.with_event_handler(StateHandler { + worker: Arc::new(Mutex::new(Some(peer_connections_connected.worker()))), + }); signal_pair(&mut sender, &mut receiver).await?; @@ -192,6 +206,23 @@ async fn test_rtp_sender_set_read_deadline() -> Result<()> { Ok(()) } +struct TrackPacketHandler { + seen_packet_tx: Arc>, +} + +impl PeerConnectionEventHandler for TrackPacketHandler { + fn on_track( + &mut self, + _: Arc, + _: Arc, + _: Arc, + ) -> impl Future + Send { + async move { + let _ = self.seen_packet_tx.send(()).await; + } + } +} + #[tokio::test] async fn test_rtp_sender_replace_track_invalid_track_kind_change() -> Result<()> { let mut m = MediaEngine::default(); @@ -226,12 +257,8 @@ async fn test_rtp_sender_replace_track_invalid_track_kind_change() -> Result<()> let (seen_packet_tx, seen_packet_rx) = mpsc::channel::<()>(1); let seen_packet_tx = Arc::new(seen_packet_tx); - receiver.on_track(Box::new(move |_, _, _| { - let seen_packet_tx2 = Arc::clone(&seen_packet_tx); - Box::pin(async move { - let _ = seen_packet_tx2.send(()).await; - }) - })); + + receiver.with_event_handler(TrackPacketHandler { seen_packet_tx }); tokio::spawn(async move { send_video_until_done( @@ -308,12 +335,8 @@ async fn test_rtp_sender_replace_track_invalid_codec_change() -> Result<()> { let (seen_packet_tx, seen_packet_rx) = mpsc::channel::<()>(1); let seen_packet_tx = Arc::new(seen_packet_tx); - receiver.on_track(Box::new(move |_, _, _| { - let seen_packet_tx2 = Arc::clone(&seen_packet_tx); - Box::pin(async move { - let _ = seen_packet_tx2.send(()).await; - }) - })); + + receiver.with_event_handler(TrackPacketHandler { seen_packet_tx }); tokio::spawn(async move { send_video_until_done( diff --git a/webrtc/src/rtp_transceiver/rtp_transceiver_test.rs b/webrtc/src/rtp_transceiver/rtp_transceiver_test.rs index 3cecc542c..e378cc178 100644 --- a/webrtc/src/rtp_transceiver/rtp_transceiver_test.rs +++ b/webrtc/src/rtp_transceiver/rtp_transceiver_test.rs @@ -6,6 +6,7 @@ use crate::api::APIBuilder; use crate::dtls_transport::RTCDtlsTransport; use crate::peer_connection::configuration::RTCConfiguration; use crate::peer_connection::peer_connection_test::{close_pair_now, create_vnet_pair}; +use crate::peer_connection::PeerConnectionEventHandler; #[tokio::test] async fn test_rtp_transceiver_set_codec_preferences() -> Result<()> { @@ -259,16 +260,21 @@ async fn test_rtp_transceiver_set_direction_causing_negotiation() -> Result<()> let count = Arc::new(AtomicUsize::new(0)); - { - let count = count.clone(); - offer_pc.on_negotiation_needed(Box::new(move || { - let count = count.clone(); - Box::pin(async move { - count.fetch_add(1, Ordering::SeqCst); - }) - })); + struct NegotiationCounter { + count: Arc, } + impl PeerConnectionEventHandler for NegotiationCounter { + fn on_negotiation_needed(&mut self) -> impl Future + Send { + self.count.fetch_add(1, Ordering::SeqCst); + async {} + } + } + + offer_pc.with_event_handler(NegotiationCounter { + count: count.clone(), + }); + let offer_transceiver = offer_pc .add_transceiver_from_kind(RTPCodecType::Video, None) .await?; diff --git a/webrtc/src/sctp_transport/mod.rs b/webrtc/src/sctp_transport/mod.rs index 8927070ae..87911c581 100644 --- a/webrtc/src/sctp_transport/mod.rs +++ b/webrtc/src/sctp_transport/mod.rs @@ -10,13 +10,12 @@ use std::pin::Pin; use std::sync::atomic::{AtomicBool, AtomicU32, AtomicU8, Ordering}; use std::sync::Arc; -use arc_swap::ArcSwapOption; use data::data_channel::DataChannel; use data::message::message_channel_open::ChannelType; use sctp::association::Association; use sctp_transport_state::RTCSctpTransportState; use tokio::sync::{Mutex, Notify}; -use util::Conn; +use util::{Conn, EventHandler, FutureUnit}; use crate::api::setting_engine::SettingEngine; use crate::data_channel::data_channel_parameters::DataChannelParameters; @@ -48,9 +47,7 @@ struct AcceptDataChannelParams { notify_rx: Arc, sctp_association: Arc, data_channels: Arc>>>, - on_error_handler: Arc>>, - on_data_channel_handler: Arc>>, - on_data_channel_opened_handler: Arc>>, + events_handler: Arc>, data_channels_opened: Arc, data_channels_accepted: Arc, setting_engine: Arc, @@ -78,9 +75,7 @@ pub struct RTCSctpTransport { sctp_association: Mutex>>, - on_error_handler: Arc>>, - on_data_channel_handler: Arc>>, - on_data_channel_opened_handler: Arc>>, + pub(crate) events_handler: Arc>, // DataChannels pub(crate) data_channels: Arc>>>, @@ -93,6 +88,48 @@ pub struct RTCSctpTransport { setting_engine: Arc, } +pub trait SctpTransportEventHandler: Send { + /// on_error sets an event handler which is invoked when + /// the SCTP connection error occurs. + fn on_error(&mut self, err: crate::error::Error) -> impl Future + Send { + async {} + } + /// on_data_channel sets an event handler which is invoked when a data + /// channel message arrives from a remote peer. + fn on_data_channel(&mut self, channel: Arc) -> impl Future + Send { + async {} + } + /// on_data_channel_opened sets an event handler which is invoked when a data + /// channel is opened + fn on_data_channel_opened( + &mut self, + channel: Arc, + ) -> impl Future + Send { + async {} + } +} + +trait InlineSctpTransportEventHandler: Send { + fn inline_on_error(&mut self, err: crate::error::Error) -> FutureUnit<'_>; + fn inline_on_data_channel(&mut self, channel: Arc) -> FutureUnit<'_>; + fn inline_on_data_channel_opened(&mut self, channel: Arc) -> FutureUnit<'_>; +} + +impl InlineSctpTransportEventHandler for T +where + T: SctpTransportEventHandler, +{ + fn inline_on_error(&mut self, err: crate::error::Error) -> FutureUnit<'_> { + FutureUnit::from_async(async move { self.on_error(err).await }) + } + fn inline_on_data_channel(&mut self, channel: Arc) -> FutureUnit<'_> { + FutureUnit::from_async(async move { self.on_data_channel(channel).await }) + } + fn inline_on_data_channel_opened(&mut self, channel: Arc) -> FutureUnit<'_> { + FutureUnit::from_async(async move { self.on_data_channel_opened(channel).await }) + } +} + impl RTCSctpTransport { pub(crate) fn new( dtls_transport: Arc, @@ -105,9 +142,7 @@ impl RTCSctpTransport { max_message_size: RTCSctpTransport::calc_message_size(65536, 65536), max_channels: SCTP_MAX_CHANNELS, sctp_association: Mutex::new(None), - on_error_handler: Arc::new(ArcSwapOption::empty()), - on_data_channel_handler: Arc::new(ArcSwapOption::empty()), - on_data_channel_opened_handler: Arc::new(ArcSwapOption::empty()), + events_handler: Arc::new(EventHandler::empty()), data_channels: Arc::new(Mutex::new(vec![])), data_channels_opened: Arc::new(AtomicU32::new(0)), @@ -175,9 +210,7 @@ impl RTCSctpTransport { notify_rx: self.notify_tx.clone(), sctp_association, data_channels: Arc::clone(&self.data_channels), - on_error_handler: Arc::clone(&self.on_error_handler), - on_data_channel_handler: Arc::clone(&self.on_data_channel_handler), - on_data_channel_opened_handler: Arc::clone(&self.on_data_channel_opened_handler), + events_handler: Arc::clone(&self.events_handler), data_channels_opened: Arc::clone(&self.data_channels_opened), data_channels_accepted: Arc::clone(&self.data_channels_accepted), setting_engine: Arc::clone(&self.setting_engine), @@ -210,14 +243,15 @@ impl RTCSctpTransport { } async fn accept_data_channels(param: AcceptDataChannelParams) { - let dcs = param.data_channels.lock().await; let mut existing_data_channels = Vec::new(); - for dc in dcs.iter() { - if let Some(dc) = dc.data_channel.lock().await.clone() { - existing_data_channels.push(dc); + { + let dcs = param.data_channels.lock().await; + for dc in dcs.iter() { + if let Some(dc) = dc.data_channel.lock().await.clone() { + existing_data_channels.push(dc); + } } - } - drop(dcs); + } //we want to drop `dcs` here to free the mutex lock before looping loop { let dc = tokio::select! { @@ -232,9 +266,9 @@ impl RTCSctpTransport { Err(err) => { if data::Error::ErrStreamClosed == err { log::error!("Failed to accept data channel: {}", err); - if let Some(handler) = &*param.on_error_handler.load() { - let mut f = handler.lock().await; - f(err.into()).await; + if let Some(handle) = &*param.events_handler.load() { + let mut handle = handle.lock().await; + handle.inline_on_error(err.into()).await; } } break; @@ -248,6 +282,7 @@ impl RTCSctpTransport { let val = dc.config.reliability_parameter as u16; let ordered; + //FIXME: this should be moved into its own fn. match dc.config.channel_type { ChannelType::Reliable => { ordered = true; @@ -290,10 +325,9 @@ impl RTCSctpTransport { Arc::clone(¶m.setting_engine), )); - if let Some(handler) = &*param.on_data_channel_handler.load() { - let mut f = handler.lock().await; - f(Arc::clone(&rtc_dc)).await; - + if let Some(handle) = &*param.events_handler.load() { + let mut handle = handle.lock().await; + handle.inline_on_data_channel(rtc_dc.clone()).await; param.data_channels_accepted.fetch_add(1, Ordering::SeqCst); let mut dcs = param.data_channels.lock().await; @@ -302,32 +336,19 @@ impl RTCSctpTransport { rtc_dc.handle_open(Arc::new(dc)).await; - if let Some(handler) = &*param.on_data_channel_opened_handler.load() { - let mut f = handler.lock().await; - f(rtc_dc).await; + if let Some(handle) = &*param.events_handler.load() { + let mut handle = handle.lock().await; + handle.inline_on_data_channel_opened(rtc_dc).await; param.data_channels_opened.fetch_add(1, Ordering::SeqCst); } } } - /// on_error sets an event handler which is invoked when - /// the SCTP connection error occurs. - pub fn on_error(&self, f: OnErrorHdlrFn) { - self.on_error_handler.store(Some(Arc::new(Mutex::new(f)))); - } - - /// on_data_channel sets an event handler which is invoked when a data - /// channel message arrives from a remote peer. - pub fn on_data_channel(&self, f: OnDataChannelHdlrFn) { - self.on_data_channel_handler - .store(Some(Arc::new(Mutex::new(f)))); - } - - /// on_data_channel_opened sets an event handler which is invoked when a data - /// channel is opened - pub fn on_data_channel_opened(&self, f: OnDataChannelOpenedHdlrFn) { - self.on_data_channel_opened_handler - .store(Some(Arc::new(Mutex::new(f)))); + pub fn with_event_handler( + &self, + handler: impl SctpTransportEventHandler + Send + Sync + 'static, + ) { + self.events_handler.store(Box::new(handler)) } fn calc_message_size(remote_max_message_size: usize, can_send_size: usize) -> usize { diff --git a/webrtc/src/track/track_local/track_local_static_test.rs b/webrtc/src/track/track_local/track_local_static_test.rs index b385d98ac..e7eda6a22 100644 --- a/webrtc/src/track/track_local/track_local_static_test.rs +++ b/webrtc/src/track/track_local/track_local_static_test.rs @@ -10,6 +10,8 @@ use crate::api::media_engine::{MediaEngine, MIME_TYPE_VP8}; use crate::api::APIBuilder; use crate::peer_connection::configuration::RTCConfiguration; use crate::peer_connection::peer_connection_test::*; +use crate::peer_connection::PeerConnectionEventHandler; +use crate::rtp_transceiver::rtp_receiver::RTCRtpReceiver; // If a remote doesn't support a Codec used by a `TrackLocalStatic` // an error should be returned to the user @@ -258,18 +260,31 @@ async fn test_track_local_static_payload_type() -> Result<()> { let (on_track_fired_tx, on_track_fired_rx) = mpsc::channel::<()>(1); let on_track_fired_tx = Arc::new(Mutex::new(Some(on_track_fired_tx))); - offerer.on_track(Box::new(move |track, _, _| { - let on_track_fired_tx2 = Arc::clone(&on_track_fired_tx); - Box::pin(async move { - assert_eq!(track.payload_type(), 100); - assert_eq!(track.codec().capability.mime_type, MIME_TYPE_VP8); - { + + struct OfferHandler { + fired_tx: Arc>>>, + } + + impl PeerConnectionEventHandler for OfferHandler { + fn on_track( + &mut self, + track: Arc, + _: Arc, + _: Arc, + ) -> impl std::future::Future + Send { + async move { + assert_eq!(track.payload_type(), 100); + assert_eq!(track.codec().capability.mime_type, MIME_TYPE_VP8); log::debug!("onTrackFiredFunc!!!"); - let mut done = on_track_fired_tx2.lock().await; + let mut done = self.fired_tx.lock().await; done.take(); } - }) - })); + } + } + + offerer.with_event_handler(OfferHandler { + fired_tx: on_track_fired_tx, + }); signal_pair(&mut offerer, &mut answerer).await?; diff --git a/webrtc/src/track/track_remote/mod.rs b/webrtc/src/track/track_remote/mod.rs index b5cd4f58d..50963a17b 100644 --- a/webrtc/src/track/track_remote/mod.rs +++ b/webrtc/src/track/track_remote/mod.rs @@ -9,6 +9,7 @@ use interceptor::{Attributes, Interceptor}; use smol_str::SmolStr; use tokio::sync::Mutex; use util::sync::Mutex as SyncMutex; +use util::{EventHandler, FutureUnit}; use crate::api::media_engine::MediaEngine; use crate::error::{Error, Result}; @@ -29,6 +30,32 @@ struct Handlers { on_unmute: ArcSwapOption>, } +pub trait TrackRemoteEventHandler: Send { + fn on_mute(&mut self) -> impl Future + Send { + async {} + } + fn on_unmute(&mut self) -> impl Future + Send { + async {} + } +} + +trait InlineTrackRemoteEventHandler: Send { + fn inline_on_mute(&mut self) -> FutureUnit<'_>; + fn inline_on_unmute(&mut self) -> FutureUnit<'_>; +} + +impl InlineTrackRemoteEventHandler for T +where + T: TrackRemoteEventHandler, +{ + fn inline_on_mute(&mut self) -> FutureUnit<'_> { + FutureUnit::from_async(async move { self.on_mute().await }) + } + fn inline_on_unmute(&mut self) -> FutureUnit<'_> { + FutureUnit::from_async(async move { self.on_unmute().await }) + } +} + #[derive(Default)] struct TrackRemoteInternal { peeked: VecDeque<(rtp::packet::Packet, Attributes)>, @@ -52,7 +79,7 @@ pub struct TrackRemote { media_engine: Arc, interceptor: Arc, - handlers: Arc, + events_handler: Arc>, receiver: Option>, internal: Mutex, @@ -97,8 +124,7 @@ impl TrackRemote { receiver: Some(receiver), media_engine, interceptor, - handlers: Default::default(), - + events_handler: Default::default(), internal: Default::default(), } } @@ -191,22 +217,11 @@ impl TrackRemote { *p = params; } - pub fn onmute(&self, handler: F) - where - F: FnMut() -> Pin + Send + 'static>> + Send + 'static + Sync, - { - self.handlers - .on_mute - .store(Some(Arc::new(Mutex::new(Box::new(handler))))); - } - - pub fn onunmute(&self, handler: F) - where - F: FnMut() -> Pin + Send + 'static>> + Send + 'static + Sync, - { - self.handlers - .on_unmute - .store(Some(Arc::new(Mutex::new(Box::new(handler))))); + pub fn with_event_handler( + &self, + handler: impl TrackRemoteEventHandler + Send + Sync + 'static, + ) { + self.events_handler.store(Box::new(handler)); } /// Reads data from the track. @@ -303,18 +318,16 @@ impl TrackRemote { } pub(crate) async fn fire_onmute(&self) { - let on_mute = self.handlers.on_mute.load(); - - if let Some(f) = on_mute.as_ref() { - (f.lock().await)().await - }; + if let Some(handler) = &*self.events_handler.load() { + let mut handle = handler.lock().await; + handle.inline_on_mute().await; + } } pub(crate) async fn fire_onunmute(&self) { - let on_unmute = self.handlers.on_unmute.load(); - - if let Some(f) = on_unmute.as_ref() { - (f.lock().await)().await - }; + if let Some(handler) = &*self.events_handler.load() { + let mut handle = handler.lock().await; + handle.inline_on_unmute().await; + } } }