Skip to content

Commit f05ec8c

Browse files
feat(provider): emit events about outgoing transfers
1 parent 2bc899a commit f05ec8c

File tree

2 files changed

+108
-22
lines changed

2 files changed

+108
-22
lines changed

src/lib.rs

Lines changed: 44 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -28,8 +28,8 @@ mod tests {
2828
use anyhow::{anyhow, Context, Result};
2929
use rand::RngCore;
3030
use testdir::testdir;
31-
use tokio::fs;
3231
use tokio::io::{self, AsyncReadExt, AsyncWriteExt};
32+
use tokio::{fs, sync::broadcast};
3333
use tracing_subscriber::{prelude::*, EnvFilter};
3434

3535
use crate::protocol::AuthToken;
@@ -224,20 +224,31 @@ mod tests {
224224
let mut provider_events = provider.subscribe();
225225
let events_task = tokio::task::spawn(async move {
226226
let mut events = Vec::new();
227-
while let Ok(event) = provider_events.recv().await {
228-
match event {
229-
Event::TransferCompleted { .. } | Event::TransferAborted { .. } => {
230-
events.push(event);
231-
break;
232-
}
233-
_ => events.push(event),
227+
loop {
228+
match provider_events.recv().await {
229+
Ok(event) => match event {
230+
Event::TransferCollectionCompleted { .. }
231+
| Event::TransferAborted { .. } => {
232+
events.push(event);
233+
break;
234+
}
235+
_ => events.push(event),
236+
},
237+
Err(e) => match e {
238+
broadcast::error::RecvError::Closed => {
239+
break;
240+
}
241+
broadcast::error::RecvError::Lagged(num) => {
242+
panic!("unable to keep up, skipped {num} messages");
243+
}
244+
},
234245
}
235246
}
236247
events
237248
});
238249

239250
let opts = get::Options {
240-
addr: provider.listen_addr(),
251+
addr: dbg!(provider.listen_addr()),
241252
peer_id: Some(provider.peer_id()),
242253
keylog: true,
243254
};
@@ -281,16 +292,35 @@ mod tests {
281292
provider.shutdown();
282293
provider.await?;
283294

284-
assert_events(events);
295+
assert_events(events, num_blobs);
285296

286297
Ok(())
287298
}
288299

289-
fn assert_events(events: Vec<Event>) {
290-
assert_eq!(events.len(), 3);
300+
fn assert_events(events: Vec<Event>, num_blobs: usize) {
301+
let num_basic_events = 4;
302+
let num_total_events = num_basic_events + num_blobs;
303+
assert_eq!(
304+
events.len(),
305+
num_total_events,
306+
"missing events, only got {:#?}",
307+
events
308+
);
291309
assert!(matches!(events[0], Event::ClientConnected { .. }));
292310
assert!(matches!(events[1], Event::RequestReceived { .. }));
293-
assert!(matches!(events[2], Event::TransferCompleted { .. }));
311+
assert!(matches!(events[2], Event::TransferCollectionStarted { .. }));
312+
for (i, event) in events[3..num_total_events - 1].iter().enumerate() {
313+
match event {
314+
Event::TransferBlobCompleted { index, .. } => {
315+
assert_eq!(*index, i as u64);
316+
}
317+
_ => panic!("unexpected event {:?}", event),
318+
}
319+
}
320+
assert!(matches!(
321+
events.last().unwrap(),
322+
Event::TransferCollectionCompleted { .. }
323+
));
294324
}
295325

296326
fn setup_logging() {
@@ -328,7 +358,7 @@ mod tests {
328358
match maybe_event {
329359
Ok(event) => {
330360
match event {
331-
Event::TransferCompleted { .. } => provider.shutdown(),
361+
Event::TransferCollectionCompleted { .. } => provider.shutdown(),
332362
Event::TransferAborted { .. } => {
333363
break Err(anyhow!("transfer aborted"));
334364
}

src/provider/mod.rs

Lines changed: 64 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -327,12 +327,36 @@ pub enum Event {
327327
/// The hash for which the client wants to receive data.
328328
hash: Hash,
329329
},
330-
/// A request was completed and the data was sent to the client.
331-
TransferCompleted {
330+
/// A collection has been found and is being transferred.
331+
TransferCollectionStarted {
332332
/// An unique connection id.
333333
connection_id: u64,
334334
/// An identifier uniquely identifying this transfer request.
335335
request_id: u64,
336+
/// The number of blobs in the collection.
337+
num_blobs: u64,
338+
/// The total blob size of the data.
339+
total_blobs_size: u64,
340+
},
341+
/// A collection request was completed and the data was sent to the client.
342+
TransferCollectionCompleted {
343+
/// An unique connection id.
344+
connection_id: u64,
345+
/// An identifier uniquely identifying this transfer request.
346+
request_id: u64,
347+
},
348+
/// A blob in a collection was transferred.
349+
TransferBlobCompleted {
350+
/// An unique connection id.
351+
connection_id: u64,
352+
/// An identifier uniquely identifying this transfer request.
353+
request_id: u64,
354+
/// The hash of the blob
355+
hash: Hash,
356+
/// The index of the blob in the collection.
357+
index: u64,
358+
/// The size of the blob transferred.
359+
size: u64,
336360
},
337361
/// A request was aborted because the client disconnected.
338362
TransferAborted {
@@ -625,6 +649,7 @@ async fn read_request(mut reader: quinn::RecvStream, buffer: &mut BytesMut) -> R
625649
/// close the writer, and return with `Ok(SentStatus::NotFound)`.
626650
///
627651
/// If the transfer does _not_ end in error, the buffer will be empty and the writer is gracefully closed.
652+
#[allow(clippy::too_many_arguments)]
628653
async fn transfer_collection(
629654
// Database from which to fetch blobs.
630655
db: &Database,
@@ -636,6 +661,9 @@ async fn transfer_collection(
636661
outboard: &Bytes,
637662
// The actual blob data.
638663
data: &Bytes,
664+
events: broadcast::Sender<Event>,
665+
connection_id: u64,
666+
request_id: u64,
639667
) -> Result<SentStatus> {
640668
// We only respond to requests for collections, not individual blobs
641669
let mut extractor = SliceExtractor::new_outboard(
@@ -652,6 +680,13 @@ async fn transfer_collection(
652680

653681
let c: Collection = postcard::from_bytes(data)?;
654682

683+
let _ = events.send(Event::TransferCollectionStarted {
684+
connection_id,
685+
request_id,
686+
num_blobs: c.blobs.len() as u64,
687+
total_blobs_size: c.total_blobs_size,
688+
});
689+
655690
// TODO: we should check if the blobs referenced in this container
656691
// actually exist in this provider before returning `FoundCollection`
657692
write_response(
@@ -667,12 +702,21 @@ async fn transfer_collection(
667702
writer.write_buf(&mut data).await?;
668703
for (i, blob) in c.blobs.iter().enumerate() {
669704
debug!("writing blob {}/{}", i, c.blobs.len());
670-
let (status, writer1) = send_blob(db.clone(), blob.hash, writer, buffer).await?;
705+
tokio::task::yield_now().await;
706+
let (status, writer1, size) = send_blob(db.clone(), blob.hash, writer, buffer).await?;
671707
writer = writer1;
672708
if SentStatus::NotFound == status {
673709
writer.finish().await?;
674710
return Ok(status);
675711
}
712+
713+
let _ = events.send(Event::TransferBlobCompleted {
714+
connection_id,
715+
request_id,
716+
hash: blob.hash,
717+
index: i as u64,
718+
size,
719+
});
676720
}
677721

678722
writer.finish().await?;
@@ -740,9 +784,20 @@ async fn handle_stream(
740784
};
741785

742786
// 5. Transfer data!
743-
match transfer_collection(&db, writer, &mut out_buffer, &outboard, &data).await {
787+
match transfer_collection(
788+
&db,
789+
writer,
790+
&mut out_buffer,
791+
&outboard,
792+
&data,
793+
events.clone(),
794+
connection_id,
795+
request_id,
796+
)
797+
.await
798+
{
744799
Ok(SentStatus::Sent) => {
745-
let _ = events.send(Event::TransferCompleted {
800+
let _ = events.send(Event::TransferCollectionCompleted {
746801
connection_id,
747802
request_id,
748803
});
@@ -771,7 +826,7 @@ async fn send_blob<W: AsyncWrite + Unpin + Send + 'static>(
771826
name: Hash,
772827
mut writer: W,
773828
buffer: &mut BytesMut,
774-
) -> Result<(SentStatus, W)> {
829+
) -> Result<(SentStatus, W, u64)> {
775830
match db.get(&name) {
776831
Some(BlobOrCollection::Blob(Data {
777832
outboard,
@@ -796,11 +851,12 @@ async fn send_blob<W: AsyncWrite + Unpin + Send + 'static>(
796851
std::io::Result::Ok(writer)
797852
})
798853
.await??;
799-
Ok((SentStatus::Sent, writer))
854+
855+
Ok((SentStatus::Sent, writer, size))
800856
}
801857
_ => {
802858
write_response(&mut writer, buffer, Res::NotFound).await?;
803-
Ok((SentStatus::NotFound, writer))
859+
Ok((SentStatus::NotFound, writer, 0))
804860
}
805861
}
806862
}

0 commit comments

Comments
 (0)