Skip to content

Commit

Permalink
feat: time frame validation
Browse files Browse the repository at this point in the history
  • Loading branch information
Banyc committed Dec 19, 2024
1 parent 575cb76 commit 5cc9a87
Show file tree
Hide file tree
Showing 3 changed files with 121 additions and 44 deletions.
120 changes: 76 additions & 44 deletions common/src/header/codec.rs
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,8 @@ use thiserror::Error;
use tokio::io::{AsyncRead, AsyncReadExt, AsyncWrite, AsyncWriteExt};
use tracing::{instrument, trace};

use crate::header::timestamp::{ReplayValidator, TimestampMsg, TIME_FRAME};

pub const MAX_HEADER_LEN: usize = 1024;

pub trait Header {}
Expand All @@ -27,45 +29,71 @@ where
R: reader_bounds,
H: for<'de> Deserialize<'de> + std::fmt::Debug + Header,
{
let mut buf = [0; MAX_HEADER_LEN + tokio_chacha20::mac::BLOCK_BYTES];
let mut buf = [0; MAX_HEADER_LEN * 2];

// Read nonce
{
let size = cursor.remaining_nonce_size();
let buf = &mut buf[..size];
let res = reader.read_exact(buf);
add_await([res])?;
let i = cursor.decrypt(buf);
if i.is_some() {
return Err(CodecError::Integrity);
}
}
// Decode header length
let len = {
let n = 4 + cursor.remaining_nonce_size();
let res = reader.read_exact(&mut buf[..n]);
let size = core::mem::size_of::<u32>();
let buf = &mut buf[..size];
let res = reader.read_exact(buf);
add_await([res])?;
let i = cursor.decrypt(&mut buf[..n]).unwrap();
u32::from_be_bytes(buf[i..n].try_into().unwrap()) as usize
let i = cursor.decrypt(buf).unwrap();
if i != 0 {
return Err(CodecError::Integrity);
}
let len = u32::from_be_bytes(buf.try_into().unwrap()) as usize;
trace!(len, "Read header length");
if len > MAX_HEADER_LEN {
return Err(CodecError::Io(io::Error::new(
io::ErrorKind::InvalidData,
"Header too long",
)));
}
len
};
trace!(len, "Read header length");

// Read header and tag
if len > MAX_HEADER_LEN {
return Err(CodecError::Io(io::Error::new(
io::ErrorKind::InvalidData,
"Header too long",
)));
}
let hdr_tag = &mut buf[..len + tokio_chacha20::mac::BLOCK_BYTES];
let res = reader.read_exact(hdr_tag);
add_await([res])?;
let (hdr, tag) = hdr_tag.split_at_mut(len);
let tag: &[u8] = tag;

let (hdr_buf, tag) = {
let buf = &mut buf[..len + tokio_chacha20::mac::BLOCK_BYTES];
let res = reader.read_exact(buf);
add_await([res])?;
let (hdr, tag) = buf.split_at_mut(len);
let tag: &[u8] = tag;
(hdr, tag)
};
// Check MAC
let key = cursor.poly1305_key().unwrap();
let expected_tag = tokio_chacha20::mac::poly1305_mac(key, hdr);
if tag != expected_tag {
return Err(CodecError::Integrity);
{
let key = cursor.poly1305_key().unwrap();
let expected_tag = tokio_chacha20::mac::poly1305_mac(key, hdr_buf);
if tag != expected_tag {
return Err(CodecError::Integrity);
}
}

// Decode header
let header = {
let i = cursor.decrypt(hdr).unwrap();
let i = cursor.decrypt(hdr_buf).unwrap();
assert_eq!(i, 0);
bincode::deserialize(hdr)?
let mut rdr = io::Cursor::new(&hdr_buf[..]);
let mut timestamp_buf = [0; TimestampMsg::SIZE];
Read::read_exact(&mut rdr, &mut timestamp_buf).unwrap();
let timestamp = TimestampMsg::decode(timestamp_buf);
if !ReplayValidator::new(TIME_FRAME).validates(timestamp.timestamp()) {
return Err(CodecError::Integrity);
}
let header = bincode::deserialize_from(&mut rdr)?;
trace!(?timestamp, ?header, "Read header");
header
};
trace!(?header, "Read header");

Ok(header)
}
Expand All @@ -85,34 +113,38 @@ where
W: writer_bounds,
H: Serialize + std::fmt::Debug + Header,
{
let mut hdr_buf = [0; MAX_HEADER_LEN];
let mut hdr_wtr = io::Cursor::new(&mut hdr_buf[..]);
let mut buf = [0; MAX_HEADER_LEN * 2];

// Encode header
let hdr = {
let mut hdr_buf = [0; TimestampMsg::SIZE + MAX_HEADER_LEN];
let hdr_buf: &[u8] = {
let mut hdr_wtr = io::Cursor::new(&mut hdr_buf[..]);
let timestamp = TimestampMsg::now();
Write::write_all(&mut hdr_wtr, &timestamp.encode()).unwrap();
bincode::serialize_into(&mut hdr_wtr, header)?;
let len = hdr_wtr.position();
let hdr = &mut hdr_buf[..len as usize];
trace!(?header, ?len, "Encoded header");
trace!(?timestamp, ?header, ?len, "Encoded header");
hdr
};

let mut buf = [0; MAX_HEADER_LEN * 2];
let mut n = 0;

// Write header length
let len = hdr.len() as u32;
let len = len.to_be_bytes();
let (f, t) = cursor.encrypt(&len, &mut buf[n..]);
assert_eq!(len.len(), f);
n += t;

{
let len = hdr_buf.len() as u32;
let len = len.to_be_bytes();
let (f, t) = cursor.encrypt(&len, &mut buf[n..]);
assert_eq!(len.len(), f);
n += t;
}
// Write header
let (f, t) = cursor.encrypt(hdr, &mut buf[n..]);
assert_eq!(hdr.len(), f);
let encrypted_hdr = &buf[n..n + t];
n += t;

let encrypted_hdr = {
let (f, t) = cursor.encrypt(hdr_buf, &mut buf[n..]);
assert_eq!(hdr_buf.len(), f);
let encrypted_hdr = &buf[n..n + t];
n += t;
encrypted_hdr
};
// Write tag
let key = cursor.poly1305_key();
let tag = tokio_chacha20::mac::poly1305_mac(key, encrypted_hdr);
Expand Down
1 change: 1 addition & 0 deletions common/src/header/mod.rs
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
pub mod codec;
pub mod heartbeat;
pub mod route;
pub mod timestamp;
44 changes: 44 additions & 0 deletions common/src/header/timestamp.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,44 @@
use core::time::Duration;

pub const TIME_FRAME: Duration = Duration::from_secs(5);

#[derive(Debug, Clone)]
pub struct TimestampMsg {
pub timestamp_sec: u64,
}
impl TimestampMsg {
pub fn now() -> Self {
let timestamp_sec = std::time::SystemTime::now()
.duration_since(std::time::UNIX_EPOCH)
.unwrap()
.as_secs();
Self { timestamp_sec }
}
pub fn timestamp(&self) -> Duration {
Duration::from_secs(self.timestamp_sec)
}
pub const SIZE: usize = std::mem::size_of::<u64>();
pub fn encode(&self) -> [u8; Self::SIZE] {
self.timestamp_sec.to_be_bytes()
}
pub fn decode(buf: [u8; Self::SIZE]) -> Self {
let timestamp_sec = u64::from_be_bytes(buf);
Self { timestamp_sec }
}
}

#[derive(Debug, Clone)]
pub struct ReplayValidator {
time_frame: Duration,
}
impl ReplayValidator {
pub fn new(time_frame: Duration) -> Self {
Self { time_frame }
}
pub fn validates(&self, timestamp: Duration) -> bool {
let now = std::time::SystemTime::now()
.duration_since(std::time::UNIX_EPOCH)
.unwrap();
now.abs_diff(timestamp) < self.time_frame
}
}

0 comments on commit 5cc9a87

Please sign in to comment.