diff --git a/h3i/examples/content_length_mismatch.rs b/h3i/examples/content_length_mismatch.rs index cff1bfacb7..b3fd29bf1a 100644 --- a/h3i/examples/content_length_mismatch.rs +++ b/h3i/examples/content_length_mismatch.rs @@ -67,7 +67,7 @@ fn main() { ]; let summary = - sync_client::connect(config, &actions).expect("connection failed"); + sync_client::connect(config, &actions, None).expect("connection failed"); println!( "=== received connection summary! ===\n\n{}", diff --git a/h3i/src/client/connection_summary.rs b/h3i/src/client/connection_summary.rs index 9f72a2a73f..9945e3afa6 100644 --- a/h3i/src/client/connection_summary.rs +++ b/h3i/src/client/connection_summary.rs @@ -39,6 +39,7 @@ use std::collections::HashMap; use std::iter::FromIterator; use crate::frame::EnrichedHeaders; +use crate::frame::ExpectedFrame; use crate::frame::H3iFrame; /// Maximum length of any serialized element's unstructured data such as reason @@ -57,6 +58,8 @@ pub struct ConnectionSummary { pub path_stats: Vec, /// Details about why the connection closed. pub conn_close_details: ConnectionCloseDetails, + /// [`ExpectedFrame`]s that were not received. + pub missing_frames: Option>, } impl Serialize for ConnectionSummary { @@ -74,6 +77,7 @@ impl Serialize for ConnectionSummary { self.path_stats.iter().map(SerializablePathStats).collect(); state.serialize_field("path_stats", &p)?; state.serialize_field("error", &self.conn_close_details)?; + state.serialize_field("missed_expected_frames", &self.missing_frames)?; state.end() } } @@ -81,7 +85,10 @@ impl Serialize for ConnectionSummary { /// A read-only aggregation of frames received over a connection, mapped to the /// stream ID over which they were received. #[derive(Clone, Debug, Default, Serialize)] -pub struct StreamMap(HashMap>); +pub struct StreamMap { + map: HashMap>, + expected_frames: Option, +} impl From for StreamMap where @@ -89,7 +96,10 @@ where { fn from(value: T) -> Self { let map = HashMap::from_iter(value); - Self(map) + Self { + map, + expected_frames: None, + } } } @@ -113,7 +123,7 @@ impl StreamMap { /// assert_eq!(stream_map.all_frames(), vec![headers]); /// ``` pub fn all_frames(&self) -> Vec { - self.0 + self.map .values() .flatten() .map(Clone::clone) @@ -140,7 +150,7 @@ impl StreamMap { /// assert_eq!(stream_map.stream(0), vec![headers]); /// ``` pub fn stream(&self, stream_id: u64) -> Vec { - self.0.get(&stream_id).cloned().unwrap_or_default() + self.map.get(&stream_id).cloned().unwrap_or_default() } /// Check if a provided [`H3iFrame`] was received, regardless of what stream @@ -189,7 +199,7 @@ impl StreamMap { pub fn received_frame_on_stream( &self, stream: u64, frame: &H3iFrame, ) -> bool { - self.0.get(&stream).map(|v| v.contains(frame)).is_some() + self.map.get(&stream).map(|v| v.contains(frame)).is_some() } /// Check if the stream map is empty, e.g., no frames were received. @@ -213,7 +223,7 @@ impl StreamMap { /// assert!(!stream_map.is_empty()); /// ``` pub fn is_empty(&self) -> bool { - self.0.is_empty() + self.map.is_empty() } /// See all HEADERS received on a given stream. @@ -246,8 +256,57 @@ impl StreamMap { .collect() } + pub(crate) fn new(expected: Option>) -> Self { + Self { + expected_frames: expected.map(|e| ExpectedFrames::new(e)), + ..Default::default() + } + } + pub(crate) fn insert(&mut self, stream_id: u64, frame: H3iFrame) { - self.0.entry(stream_id).or_default().push(frame); + if let Some(expected) = self.expected_frames.as_mut() { + expected.receive_frame(stream_id, &frame); + } + + self.map.entry(stream_id).or_default().push(frame); + } + + pub(crate) fn saw_all_expected_frames(&self) -> bool { + self.expected_frames + .as_ref() + .is_some_and(|e| e.saw_all_frames()) + } + + pub(crate) fn missing_frames(&self) -> Option> { + self.expected_frames.as_ref().map(|e| e.missing_frames()) + } +} + +#[derive(Serialize, Clone, Debug)] +struct ExpectedFrames { + missing: Vec, +} + +impl ExpectedFrames { + fn new(frames: Vec) -> Self { + Self { missing: frames } + } + + fn receive_frame(&mut self, stream_id: u64, frame: &H3iFrame) { + for (i, ef) in self.missing.iter_mut().enumerate() { + if ef.is_equivalent(frame) && ef.stream_id() == stream_id { + self.missing.remove(i); + break; + } + } + } + + fn saw_all_frames(&self) -> bool { + self.missing.is_empty() + } + + fn missing_frames(&self) -> Vec { + self.missing.clone() } } @@ -422,3 +481,73 @@ impl Serialize for SerializableConnectionError<'_> { state.end() } } + +#[cfg(test)] +mod tests { + use super::*; + use quiche::h3::Header; + + fn h3i_frame() -> H3iFrame { + vec![Header::new(b"hello", b"world")].into() + } + + #[test] + fn expected_frame() { + let frame = h3i_frame(); + let mut expected = + ExpectedFrames::new(vec![ExpectedFrame::new(0, frame.clone())]); + + expected.receive_frame(0, &frame); + + assert!(expected.saw_all_frames()); + } + + #[test] + fn expected_frame_missing() { + let frame = h3i_frame(); + let expected_frames = vec![ + ExpectedFrame::new(0, frame.clone()), + ExpectedFrame::new(4, frame.clone()), + ExpectedFrame::new(8, vec![Header::new(b"go", b"jets")].into()), + ]; + let mut expected = ExpectedFrames::new(expected_frames.clone()); + + expected.receive_frame(0, &frame); + + assert!(!expected.saw_all_frames()); + assert_eq!(expected.missing_frames(), expected_frames[1..].to_vec()); + } + + fn stream_map_data() -> Vec { + let headers = + H3iFrame::Headers(EnrichedHeaders::from(vec![Header::new( + b"hello", b"world", + )])); + let data = H3iFrame::QuicheH3(quiche::h3::frame::Frame::Data { + payload: b"hello world".to_vec(), + }); + + vec![headers, data] + } + + #[test] + fn test_stream_map_expected_frames_with_none() { + let stream_map: StreamMap = vec![(0, stream_map_data())].into(); + assert!(!stream_map.saw_all_expected_frames()); + } + + #[test] + fn test_stream_map_expected_frames() { + let data = stream_map_data(); + let mut stream_map = StreamMap::new(Some(vec![ + ExpectedFrame::new(0, data[0].clone()), + ExpectedFrame::new(0, data[1].clone()), + ])); + + stream_map.insert(0, data[0].clone()); + assert!(!stream_map.saw_all_expected_frames()); + assert_eq!(stream_map.missing_frames().unwrap(), vec![ + ExpectedFrame::new(0, data[1].clone()) + ]); + } +} diff --git a/h3i/src/client/sync_client.rs b/h3i/src/client/sync_client.rs index a24b3cecfd..178afa246b 100644 --- a/h3i/src/client/sync_client.rs +++ b/h3i/src/client/sync_client.rs @@ -31,6 +31,7 @@ use std::slice::Iter; use std::time::Duration; use std::time::Instant; +use crate::frame::ExpectedFrame; use crate::frame::H3iFrame; use crate::quiche; @@ -57,6 +58,15 @@ struct SyncClient { stream_parsers: StreamParserMap, } +impl SyncClient { + fn new(expected_frames: Option>) -> Self { + Self { + streams: StreamMap::new(expected_frames), + ..Default::default() + } + } +} + impl Client for SyncClient { fn stream_parsers_mut(&mut self) -> &mut StreamParserMap { &mut self.stream_parsers @@ -74,7 +84,7 @@ impl Client for SyncClient { /// /// Returns a [ConnectionSummary] on success, [ClientError] on failure. pub fn connect( - args: Config, actions: &[Action], + args: Config, actions: &[Action], expected_frames: Option>, ) -> std::result::Result { let mut buf = [0; 65535]; let mut out = [0; MAX_DATAGRAM_SIZE]; @@ -142,8 +152,7 @@ pub fn connect( let mut wait_duration = None; let mut wait_instant = None; - let mut client = SyncClient::default(); - + let mut client = SyncClient::new(expected_frames); let mut waiting_for = WaitingFor::default(); loop { @@ -277,6 +286,14 @@ pub fn connect( wait_cleared = true; } + if client.streams.saw_all_expected_frames() { + let _ = conn.close( + true, + quiche::h3::WireErrorCode::NoError as u64, + b"saw all expected frames", + ); + } + if wait_cleared { check_duration_and_do_actions( &mut wait_duration, @@ -370,11 +387,13 @@ pub fn connect( } } + let missing_frames = client.streams.missing_frames(); Ok(ConnectionSummary { stream_map: client.streams, stats: Some(conn.stats()), path_stats: conn.path_stats().collect(), conn_close_details: ConnectionCloseDetails::new(&conn), + missing_frames, }) } diff --git a/h3i/src/config.rs b/h3i/src/config.rs index bb57984448..a057b0ed23 100644 --- a/h3i/src/config.rs +++ b/h3i/src/config.rs @@ -28,6 +28,7 @@ use std::io; /// Server details and QUIC connection properties. +#[derive(Clone)] pub struct Config { /// A string representing the host and port to connect to using the format /// `:`. diff --git a/h3i/src/frame.rs b/h3i/src/frame.rs index dd642e6518..304e13c58c 100644 --- a/h3i/src/frame.rs +++ b/h3i/src/frame.rs @@ -48,7 +48,7 @@ pub type BoxError = Box; /// An internal representation of a QUIC or HTTP/3 frame. This type exists so /// that we can extend types defined in Quiche. -#[derive(Debug, Clone, Eq, PartialEq)] +#[derive(Debug, Eq, PartialEq, Clone)] pub enum H3iFrame { /// A wrapper around a quiche HTTP/3 frame. QuicheH3(QFrame), @@ -69,6 +69,49 @@ impl H3iFrame { None } } + + /// Check if this [`H3iFrame`] is equivalent to another. For + /// QuicheH3/ResetStream variants, equivalence is the same as equality. + /// For Headers variants, this [`H3iFrame`] is equivalent to another if + /// the other frame contains all [`Header`]s in _this_ frame. + /// + /// # Example + /// + /// ``` + /// use h3i::frame::H3iFrame; + /// use quiche::h3::Header; + /// + /// let this: H3iFrame = vec![Header::new(b"hello", b"world")].into(); + /// let other: H3iFrame = vec![ + /// Header::new(b"hello", b"world"), + /// Header::new(b"go", b"jets") + /// ].into(); + /// + /// assert!(this.is_equivalent(&other)); + /// // `this` does not contain the `go: jets` header, so `other` is not equivalent to `this`. + /// assert!(!other.is_equivalent(&this)); + /// ``` + pub fn is_equivalent(&self, other: &H3iFrame) -> bool { + match self { + Self::Headers(me) => { + let H3iFrame::Headers(other) = other else { + return false; + }; + + // We expect pretty small expected frame vectors, so complexity + // here isn't too bad + me.headers().iter().all(|m| other.headers().contains(m)) + }, + Self::QuicheH3(me) => match other { + H3iFrame::QuicheH3(other) => me == other, + _ => false, + }, + Self::ResetStream(me) => match other { + H3iFrame::ResetStream(rs) => me == rs, + _ => false, + }, + } + } } impl Serialize for H3iFrame { @@ -432,3 +475,30 @@ impl Serialize for SerializableQFrame<'_> { } } } + +/// A combination of stream ID and [`H3iFrame`] which is used to instruct h3i to +/// watch for specific frames. If h3i receives all the frames it expects, it +/// will send an application CONNECTION_CLOSE frame with an error code of 0x100. +/// This bypasses the idle timeout and vastly quickens test suites which depend +/// heavily on h3i. +#[derive(Debug, Eq, PartialEq, Serialize, Clone)] +pub struct ExpectedFrame { + stream_id: u64, + frame: H3iFrame, +} + +impl ExpectedFrame { + pub fn new(stream_id: u64, frame: H3iFrame) -> Self { + Self { stream_id, frame } + } + + pub(crate) fn stream_id(&self) -> u64 { + self.stream_id + } + + pub(crate) fn is_equivalent(&self, other: &H3iFrame) -> bool { + // TODO(evanrittenhouse): allow users to specify custom equivalence + // functions + self.frame.is_equivalent(other) + } +} diff --git a/h3i/src/lib.rs b/h3i/src/lib.rs index 10cfb8ef29..b70ec09996 100644 --- a/h3i/src/lib.rs +++ b/h3i/src/lib.rs @@ -112,7 +112,7 @@ //! ]; //! //! let summary = -//! sync_client::connect(config, &actions).expect("connection failed"); +//! sync_client::connect(config, &actions, None).expect("connection failed"); //! //! println!( //! "=== received connection summary! ===\n\n{}", diff --git a/h3i/src/main.rs b/h3i/src/main.rs index b27075b51f..5ddf3fdca8 100644 --- a/h3i/src/main.rs +++ b/h3i/src/main.rs @@ -298,7 +298,8 @@ fn config_from_clap() -> std::result::Result { fn sync_client( config: Config, actions: &[Action], ) -> Result { - h3i::client::sync_client::connect(config.library_config, actions) + // TODO: CLI doesn't support passing expected frames at the moment + h3i::client::sync_client::connect(config.library_config, actions, None) } fn read_qlog(filename: &str, host_override: Option<&str>) -> Vec {