Skip to content

Commit

Permalink
tcp: add netsim test.
Browse files Browse the repository at this point in the history
This test simulates a network with a given latency and packet loss, and measures the
throughput between two virtual smoltcp instances.
  • Loading branch information
Dirbaio committed Dec 30, 2024
1 parent 7a248ae commit 3673322
Show file tree
Hide file tree
Showing 3 changed files with 381 additions and 0 deletions.
2 changes: 2 additions & 0 deletions Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,8 @@ getopts = "0.2"
rand = "0.8"
url = "2.0"
rstest = "0.17"
insta = "1.41.1"
rand_chacha = "0.3.1"

[features]
std = ["managed/std", "alloc"]
Expand Down
364 changes: 364 additions & 0 deletions tests/netsim.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,364 @@
use std::cell::RefCell;
use std::collections::BinaryHeap;
use std::fmt::Write as _;
use std::io::Write as _;
use std::sync::Mutex;

use rand::{Rng, SeedableRng};
use rand_chacha::ChaCha20Rng;
use smoltcp::iface::{Config, Interface, SocketSet};
use smoltcp::phy::Tracer;
use smoltcp::phy::{self, ChecksumCapabilities, Device, DeviceCapabilities, Medium};
use smoltcp::socket::tcp;
use smoltcp::time::{Duration, Instant};
use smoltcp::wire::{EthernetAddress, HardwareAddress, IpAddress, IpCidr};

const MAC_A: HardwareAddress = HardwareAddress::Ethernet(EthernetAddress([2, 0, 0, 0, 0, 1]));
const MAC_B: HardwareAddress = HardwareAddress::Ethernet(EthernetAddress([2, 0, 0, 0, 0, 2]));
const IP_A: IpAddress = IpAddress::v4(10, 0, 0, 1);
const IP_B: IpAddress = IpAddress::v4(10, 0, 0, 2);

const BYTES: usize = 10 * 1024 * 1024;

static CLOCK: Mutex<(Instant, char)> = Mutex::new((Instant::ZERO, ' '));

#[test]
fn netsim() {
setup_logging();

let buffers = [128, 256, 512, 1024, 2048, 4096, 8192, 16384, 32768];
let losses = [0.0, 0.001, 0.01, 0.02, 0.05, 0.10, 0.20, 0.30];

let mut s = String::new();

write!(&mut s, "buf\\loss").unwrap();
for loss in losses {
write!(&mut s, "{loss:9.3} ").unwrap();
}
writeln!(&mut s).unwrap();

for buffer in buffers {
write!(&mut s, "{buffer:7}").unwrap();
for loss in losses {
let r = run_test(TestCase {
rtt: Duration::from_millis(100),
buffer,
loss,
});
write!(&mut s, " {r:9.2}").unwrap();
}
writeln!(&mut s).unwrap();
}

insta::assert_snapshot!(s);
}

struct TestCase {
rtt: Duration,
loss: f64,
buffer: usize,
}

fn run_test(case: TestCase) -> f64 {
let mut time = Instant::ZERO;

let params = QueueParams {
latency: case.rtt / 2,
loss: case.loss,
};
let queue_a_to_b = RefCell::new(PacketQueue::new(params.clone(), 0));
let queue_b_to_a = RefCell::new(PacketQueue::new(params.clone(), 1));
let device_a = QueueDevice::new(&queue_a_to_b, &queue_b_to_a, Medium::Ethernet);
let device_b = QueueDevice::new(&queue_b_to_a, &queue_a_to_b, Medium::Ethernet);

let mut device_a = Tracer::new(device_a, |_timestamp, _printer| log::trace!("{}", _printer));
let mut device_b = Tracer::new(device_b, |_timestamp, _printer| log::trace!("{}", _printer));

let mut iface_a = Interface::new(Config::new(MAC_A), &mut device_a, time);
iface_a.update_ip_addrs(|a| a.push(IpCidr::new(IP_A, 8)).unwrap());
let mut iface_b = Interface::new(Config::new(MAC_B), &mut device_b, time);
iface_b.update_ip_addrs(|a| a.push(IpCidr::new(IP_B, 8)).unwrap());

// Create sockets
let socket_a = {
let tcp_rx_buffer = tcp::SocketBuffer::new(vec![0; case.buffer]);
let tcp_tx_buffer = tcp::SocketBuffer::new(vec![0; case.buffer]);
tcp::Socket::new(tcp_rx_buffer, tcp_tx_buffer)
};

let socket_b = {
let tcp_rx_buffer = tcp::SocketBuffer::new(vec![0; case.buffer]);
let tcp_tx_buffer = tcp::SocketBuffer::new(vec![0; case.buffer]);
tcp::Socket::new(tcp_rx_buffer, tcp_tx_buffer)
};

let mut sockets_a: [_; 2] = Default::default();
let mut sockets_a = SocketSet::new(&mut sockets_a[..]);
let socket_a_handle = sockets_a.add(socket_a);

let mut sockets_b: [_; 2] = Default::default();
let mut sockets_b = SocketSet::new(&mut sockets_b[..]);
let socket_b_handle = sockets_b.add(socket_b);

let mut did_listen = false;
let mut did_connect = false;
let mut processed = 0;
while processed < BYTES {
*CLOCK.lock().unwrap() = (time, ' ');
log::info!("loop");
//println!("t = {}", time);

*CLOCK.lock().unwrap() = (time, 'A');

iface_a.poll(time, &mut device_a, &mut sockets_a);

let socket = sockets_a.get_mut::<tcp::Socket>(socket_a_handle);
if !socket.is_active() && !socket.is_listening() && !did_listen {
//println!("listening");
socket.listen(1234).unwrap();
did_listen = true;
}

while socket.can_recv() {
let received = socket.recv(|buffer| (buffer.len(), buffer.len())).unwrap();
//println!("got {:?}", received,);
processed += received;
}

*CLOCK.lock().unwrap() = (time, 'B');
iface_b.poll(time, &mut device_b, &mut sockets_b);
let socket = sockets_b.get_mut::<tcp::Socket>(socket_b_handle);
let cx = iface_b.context();
if !socket.is_open() && !did_connect {
//println!("connecting");
socket.connect(cx, (IP_A, 1234), 65000).unwrap();
did_connect = true;
}

while socket.can_send() {
//println!("sending");
socket.send(|buffer| (buffer.len(), ())).unwrap();
}

*CLOCK.lock().unwrap() = (time, ' ');

let mut next_time = queue_a_to_b.borrow_mut().next_expiration();
next_time = next_time.min(queue_b_to_a.borrow_mut().next_expiration());
if let Some(t) = iface_a.poll_at(time, &sockets_a) {
next_time = next_time.min(t);
}
if let Some(t) = iface_b.poll_at(time, &sockets_b) {
next_time = next_time.min(t);
}
assert!(next_time.total_micros() != i64::MAX);
time = time.max(next_time);
}

let duration = time - Instant::ZERO;
processed as f64 / duration.total_micros() as f64 * 1e6
}

struct Packet {
timestamp: Instant,
id: u64,
data: Vec<u8>,
}

impl PartialEq for Packet {
fn eq(&self, other: &Self) -> bool {
(other.timestamp, other.id) == (self.timestamp, self.id)
}
}

impl Eq for Packet {}

impl PartialOrd for Packet {
fn partial_cmp(&self, other: &Self) -> Option<std::cmp::Ordering> {
(other.timestamp, other.id).partial_cmp(&(self.timestamp, self.id))
}
}

impl Ord for Packet {
fn cmp(&self, other: &Self) -> std::cmp::Ordering {
(other.timestamp, other.id).cmp(&(self.timestamp, self.id))
}
}

#[derive(Clone)]
struct QueueParams {
latency: Duration,
loss: f64,
}

struct PacketQueue {
queue: BinaryHeap<Packet>,
next_id: u64,
params: QueueParams,
rng: ChaCha20Rng,
}

impl PacketQueue {
pub fn new(params: QueueParams, seed: u64) -> Self {
Self {
queue: BinaryHeap::new(),
next_id: 0,
params,
rng: ChaCha20Rng::seed_from_u64(seed),
}
}

pub fn next_expiration(&self) -> Instant {
self.queue
.peek()
.map(|p| p.timestamp)
.unwrap_or(Instant::from_micros(i64::MAX))
}

pub fn push(&mut self, data: Vec<u8>, timestamp: Instant) {
if self.rng.gen::<f64>() < self.params.loss {
log::info!("PACKET LOST!");
return;
}

self.queue.push(Packet {
data,
id: self.next_id,
timestamp: timestamp + self.params.latency,
});
self.next_id += 1;
}

pub fn pop(&mut self, timestamp: Instant) -> Option<Vec<u8>> {
let p = self.queue.peek()?;
if p.timestamp > timestamp {
return None;
}
Some(self.queue.pop().unwrap().data)
}
}

pub struct QueueDevice<'a> {
tx_queue: &'a RefCell<PacketQueue>,
rx_queue: &'a RefCell<PacketQueue>,
medium: Medium,
}

impl<'a> QueueDevice<'a> {
fn new(
tx_queue: &'a RefCell<PacketQueue>,
rx_queue: &'a RefCell<PacketQueue>,
medium: Medium,
) -> Self {
Self {
tx_queue,
rx_queue,
medium,
}
}
}

impl Device for QueueDevice<'_> {
type RxToken<'a>
= RxToken
where
Self: 'a;
type TxToken<'a>
= TxToken<'a>
where
Self: 'a;

fn capabilities(&self) -> DeviceCapabilities {
let mut caps = DeviceCapabilities::default();
caps.max_transmission_unit = 1514;
caps.medium = self.medium;
caps.checksum = ChecksumCapabilities::ignored();
caps
}

fn receive(&mut self, timestamp: Instant) -> Option<(Self::RxToken<'_>, Self::TxToken<'_>)> {
self.rx_queue
.borrow_mut()
.pop(timestamp)
.map(move |buffer| {
let rx = RxToken { buffer };
let tx = TxToken {
queue: &mut self.tx_queue,
timestamp,
};
(rx, tx)
})
}

fn transmit(&mut self, timestamp: Instant) -> Option<Self::TxToken<'_>> {
Some(TxToken {
queue: &mut self.tx_queue,
timestamp,
})
}
}

pub struct RxToken {
buffer: Vec<u8>,
}

impl phy::RxToken for RxToken {
fn consume<R, F>(self, f: F) -> R
where
F: FnOnce(&[u8]) -> R,
{
f(&self.buffer)
}
}

pub struct TxToken<'a> {
queue: &'a RefCell<PacketQueue>,
timestamp: Instant,
}

impl<'a> phy::TxToken for TxToken<'a> {
fn consume<R, F>(self, len: usize, f: F) -> R
where
F: FnOnce(&mut [u8]) -> R,
{
let mut buffer = vec![0; len];
let result = f(&mut buffer);
self.queue.borrow_mut().push(buffer, self.timestamp);
result
}
}

pub fn setup_logging() {
env_logger::Builder::new()
.format(move |buf, record| {
let (elapsed, side) = *CLOCK.lock().unwrap();

let timestamp = format!("[{elapsed} {side}]");
if record.target().starts_with("smoltcp::") {
writeln!(
buf,
"{} ({}): {}",
timestamp,
record.target().replace("smoltcp::", ""),
record.args()
)
} else if record.level() == log::Level::Trace {
let message = format!("{}", record.args());
writeln!(
buf,
"{} {}",
timestamp,
message.replace('\n', "\n ")
)
} else {
writeln!(
buf,
"{} ({}): {}",
timestamp,
record.target(),
record.args()
)
}
})
.parse_env("RUST_LOG")
.init();
}
15 changes: 15 additions & 0 deletions tests/snapshots/netsim__netsim.snap
Original file line number Diff line number Diff line change
@@ -0,0 +1,15 @@
---
source: tests/netsim.rs
expression: s
snapshot_kind: text
---
buf\loss 0.000 0.001 0.010 0.020 0.050 0.100 0.200 0.300
128 1279.98 1255.76 1054.15 886.36 538.66 227.84 33.99 7.18
256 2559.91 2507.27 2100.03 1770.30 1070.71 468.24 66.71 14.35
512 5119.63 5011.95 4172.36 3531.57 2098.73 942.38 144.73 29.45
1024 10238.50 10023.19 8340.90 7084.25 4003.34 1869.94 290.74 60.92
2048 17535.11 17171.82 14093.50 12063.90 7205.27 3379.12 824.76 131.54
4096 35062.41 33852.31 27011.08 22073.09 13680.70 7631.11 1617.81 302.65
8192 77374.28 72409.99 58428.68 48310.75 29123.30 14314.36 2880.39 551.60
16384 161842.28 159448.56 141467.31 127073.06 78239.08 38637.20 7565.64 1112.31
32768 322944.88 314313.90 266384.37 245985.29 138762.29 83162.99 10739.10 1951.95

0 comments on commit 3673322

Please sign in to comment.