diff --git a/common/src/header/codec.rs b/common/src/header/codec.rs index 5ffd4d5..b2a61c4 100644 --- a/common/src/header/codec.rs +++ b/common/src/header/codec.rs @@ -9,6 +9,8 @@ use thiserror::Error; use tokio::io::{AsyncRead, AsyncReadExt, AsyncWrite, AsyncWriteExt}; use tracing::{instrument, trace}; +use crate::header::timestamp::{ReplayValidator, TimestampMsg, TIME_FRAME}; + pub const MAX_HEADER_LEN: usize = 1024; pub trait Header {} @@ -27,45 +29,71 @@ where R: reader_bounds, H: for<'de> Deserialize<'de> + std::fmt::Debug + Header, { - let mut buf = [0; MAX_HEADER_LEN + tokio_chacha20::mac::BLOCK_BYTES]; + let mut buf = [0; MAX_HEADER_LEN * 2]; + // Read nonce + { + let size = cursor.remaining_nonce_size(); + let buf = &mut buf[..size]; + let res = reader.read_exact(buf); + add_await([res])?; + let i = cursor.decrypt(buf); + if i.is_some() { + return Err(CodecError::Integrity); + } + } // Decode header length let len = { - let n = 4 + cursor.remaining_nonce_size(); - let res = reader.read_exact(&mut buf[..n]); + let size = core::mem::size_of::(); + let buf = &mut buf[..size]; + let res = reader.read_exact(buf); add_await([res])?; - let i = cursor.decrypt(&mut buf[..n]).unwrap(); - u32::from_be_bytes(buf[i..n].try_into().unwrap()) as usize + let i = cursor.decrypt(buf).unwrap(); + if i != 0 { + return Err(CodecError::Integrity); + } + let len = u32::from_be_bytes(buf.try_into().unwrap()) as usize; + trace!(len, "Read header length"); + if len > MAX_HEADER_LEN { + return Err(CodecError::Io(io::Error::new( + io::ErrorKind::InvalidData, + "Header too long", + ))); + } + len }; - trace!(len, "Read header length"); - // Read header and tag - if len > MAX_HEADER_LEN { - return Err(CodecError::Io(io::Error::new( - io::ErrorKind::InvalidData, - "Header too long", - ))); - } - let hdr_tag = &mut buf[..len + tokio_chacha20::mac::BLOCK_BYTES]; - let res = reader.read_exact(hdr_tag); - add_await([res])?; - let (hdr, tag) = hdr_tag.split_at_mut(len); - let tag: &[u8] = tag; - + let (hdr_buf, tag) = { + let buf = &mut buf[..len + tokio_chacha20::mac::BLOCK_BYTES]; + let res = reader.read_exact(buf); + add_await([res])?; + let (hdr, tag) = buf.split_at_mut(len); + let tag: &[u8] = tag; + (hdr, tag) + }; // Check MAC - let key = cursor.poly1305_key().unwrap(); - let expected_tag = tokio_chacha20::mac::poly1305_mac(key, hdr); - if tag != expected_tag { - return Err(CodecError::Integrity); + { + let key = cursor.poly1305_key().unwrap(); + let expected_tag = tokio_chacha20::mac::poly1305_mac(key, hdr_buf); + if tag != expected_tag { + return Err(CodecError::Integrity); + } } - // Decode header let header = { - let i = cursor.decrypt(hdr).unwrap(); + let i = cursor.decrypt(hdr_buf).unwrap(); assert_eq!(i, 0); - bincode::deserialize(hdr)? + let mut rdr = io::Cursor::new(&hdr_buf[..]); + let mut timestamp_buf = [0; TimestampMsg::SIZE]; + Read::read_exact(&mut rdr, &mut timestamp_buf).unwrap(); + let timestamp = TimestampMsg::decode(timestamp_buf); + if !ReplayValidator::new(TIME_FRAME).validates(timestamp.timestamp()) { + return Err(CodecError::Integrity); + } + let header = bincode::deserialize_from(&mut rdr)?; + trace!(?timestamp, ?header, "Read header"); + header }; - trace!(?header, "Read header"); Ok(header) } @@ -85,34 +113,38 @@ where W: writer_bounds, H: Serialize + std::fmt::Debug + Header, { - let mut hdr_buf = [0; MAX_HEADER_LEN]; - let mut hdr_wtr = io::Cursor::new(&mut hdr_buf[..]); - let mut buf = [0; MAX_HEADER_LEN * 2]; - // Encode header - let hdr = { + let mut hdr_buf = [0; TimestampMsg::SIZE + MAX_HEADER_LEN]; + let hdr_buf: &[u8] = { + let mut hdr_wtr = io::Cursor::new(&mut hdr_buf[..]); + let timestamp = TimestampMsg::now(); + Write::write_all(&mut hdr_wtr, ×tamp.encode()).unwrap(); bincode::serialize_into(&mut hdr_wtr, header)?; let len = hdr_wtr.position(); let hdr = &mut hdr_buf[..len as usize]; - trace!(?header, ?len, "Encoded header"); + trace!(?timestamp, ?header, ?len, "Encoded header"); hdr }; + let mut buf = [0; MAX_HEADER_LEN * 2]; let mut n = 0; // Write header length - let len = hdr.len() as u32; - let len = len.to_be_bytes(); - let (f, t) = cursor.encrypt(&len, &mut buf[n..]); - assert_eq!(len.len(), f); - n += t; - + { + let len = hdr_buf.len() as u32; + let len = len.to_be_bytes(); + let (f, t) = cursor.encrypt(&len, &mut buf[n..]); + assert_eq!(len.len(), f); + n += t; + } // Write header - let (f, t) = cursor.encrypt(hdr, &mut buf[n..]); - assert_eq!(hdr.len(), f); - let encrypted_hdr = &buf[n..n + t]; - n += t; - + let encrypted_hdr = { + let (f, t) = cursor.encrypt(hdr_buf, &mut buf[n..]); + assert_eq!(hdr_buf.len(), f); + let encrypted_hdr = &buf[n..n + t]; + n += t; + encrypted_hdr + }; // Write tag let key = cursor.poly1305_key(); let tag = tokio_chacha20::mac::poly1305_mac(key, encrypted_hdr); diff --git a/common/src/header/mod.rs b/common/src/header/mod.rs index 708be35..da76303 100644 --- a/common/src/header/mod.rs +++ b/common/src/header/mod.rs @@ -1,3 +1,4 @@ pub mod codec; pub mod heartbeat; pub mod route; +pub mod timestamp; diff --git a/common/src/header/timestamp.rs b/common/src/header/timestamp.rs new file mode 100644 index 0000000..c32411e --- /dev/null +++ b/common/src/header/timestamp.rs @@ -0,0 +1,44 @@ +use core::time::Duration; + +pub const TIME_FRAME: Duration = Duration::from_secs(5); + +#[derive(Debug, Clone)] +pub struct TimestampMsg { + pub timestamp_sec: u64, +} +impl TimestampMsg { + pub fn now() -> Self { + let timestamp_sec = std::time::SystemTime::now() + .duration_since(std::time::UNIX_EPOCH) + .unwrap() + .as_secs(); + Self { timestamp_sec } + } + pub fn timestamp(&self) -> Duration { + Duration::from_secs(self.timestamp_sec) + } + pub const SIZE: usize = std::mem::size_of::(); + pub fn encode(&self) -> [u8; Self::SIZE] { + self.timestamp_sec.to_be_bytes() + } + pub fn decode(buf: [u8; Self::SIZE]) -> Self { + let timestamp_sec = u64::from_be_bytes(buf); + Self { timestamp_sec } + } +} + +#[derive(Debug, Clone)] +pub struct ReplayValidator { + time_frame: Duration, +} +impl ReplayValidator { + pub fn new(time_frame: Duration) -> Self { + Self { time_frame } + } + pub fn validates(&self, timestamp: Duration) -> bool { + let now = std::time::SystemTime::now() + .duration_since(std::time::UNIX_EPOCH) + .unwrap(); + now.abs_diff(timestamp) < self.time_frame + } +}