Skip to content

refine code #11

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 2 commits into from
Oct 13, 2024
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
58 changes: 15 additions & 43 deletions .github/workflows/main.yml
Original file line number Diff line number Diff line change
@@ -1,5 +1,12 @@
name: CI
on: [push, pull_request]

on:
push:
branches:
- '**'
pull_request:
branches:
- '**'

env:
CARGO_INCREMENTAL: 0
@@ -50,59 +57,24 @@ jobs:
shell: bash
- uses: Swatinem/rust-cache@v2
- name: Setup android environment
if: contains(matrix.build, 'android')
uses: ./.github/actions/ndk-dev-rs
with:
rust-target: ${{ matrix.target }}
if: contains(matrix.build, 'android')
- run: cargo test ${{ matrix.no_run }} --workspace --target ${{ matrix.target }}
- run: cargo test ${{ matrix.no_run }} --workspace --target ${{ matrix.target }} --release

msrv:
name: MSRV
msrv_n_clippy:
name: MSRV & Clippy & Rustfmt
runs-on: ${{ matrix.os }}
env:
MSRV: 1.65.0
strategy:
fail-fast: false
matrix:
os: [ubuntu-latest, macos-latest, windows-latest]
steps:
- uses: actions/checkout@v4
- name: Install Rust
run: |
rustup toolchain install $MSRV --no-self-update --profile minimal
rustup toolchain install nightly --no-self-update --profile minimal
rustup default $MSRV
shell: bash
- name: Create Cargo.lock with minimal version
run: cargo +nightly update -Zminimal-versions
- name: Cache downloaded crates since minimal version is really slow in fetching
uses: Swatinem/rust-cache@v2
- run: cargo check --lib -p netstack-smoltcp --locked
- run: cargo check --lib -p netstack-smoltcp --locked --all-features

clippy:
name: Clippy
runs-on: ubuntu-latest
steps:
- uses: actions/checkout@v4
- name: Install Rust
run: |
rustup toolchain install stable --no-self-update --profile minimal --component rustfmt
rustup default stable
shell: bash
- uses: Swatinem/rust-cache@v2
- run: cargo clippy

rustfmt:
name: Rustfmt
runs-on: ubuntu-latest
steps:
- uses: actions/checkout@v4
- name: Install Rust
run: |
rustup toolchain install stable --no-self-update --profile minimal --component rustfmt
rustup default stable
shell: bash
- uses: Swatinem/rust-cache@v2
- uses: dtolnay/rust-toolchain@stable
- run: cargo fmt -- --check
- run: cargo clippy --all-features -- -D warnings
- run: cargo check --lib -p netstack-smoltcp
- run: cargo check --lib -p netstack-smoltcp --all-features
1 change: 1 addition & 0 deletions .gitignore
Original file line number Diff line number Diff line change
@@ -2,6 +2,7 @@
/Cargo.lock

.idea
.VSCodeCounter/
.vscode
.DS_Store
*.iml
5 changes: 3 additions & 2 deletions Cargo.toml
Original file line number Diff line number Diff line change
@@ -13,12 +13,13 @@ description = """
A netstack for the special purpose of turning packets from/to a TUN interface
into TCP streams and UDP packets. It uses smoltcp-rs as the backend netstack.
"""
rust-version = "1.75.0"

[dependencies]
tracing = { version = "0.1", default-features = false, features = ["std"] }
tokio = { version = "1", features = ["sync", "time", "rt", "macros"] }
tokio-util = "0.7.10"
etherparse = "0.13"
etherparse = "0.16"
futures = "0.3"
rand = "0.8"
spin = "0.9"
@@ -34,7 +35,7 @@ smoltcp = { version = "0.11", default-features = false, features = [
] }

[dev-dependencies]
tun = { package = "tun2", version = "1.0", features = ["async"] }
tun2 = { version = "3", features = ["async"] }
tokio = { version = "1", features = [
"rt",
"macros",
20 changes: 10 additions & 10 deletions examples/forward.rs
Original file line number Diff line number Diff line change
@@ -86,24 +86,24 @@ async fn main_exec(opt: Opt) {
)
.unwrap();

let mut cfg = tun::Configuration::default();
cfg.layer(tun::Layer::L3);
let mut cfg = tun2::Configuration::default();
cfg.layer(tun2::Layer::L3);
let fd = -1;
if fd >= 0 {
cfg.raw_fd(fd);
} else {
cfg.tun_name("utun8")
cfg.tun_name(&opt.interface)
.address("10.10.10.2")
.destination("10.10.10.1")
.mtu(tun::DEFAULT_MTU);
.mtu(tun2::DEFAULT_MTU);
#[cfg(not(any(target_arch = "mips", target_arch = "mips64",)))]
{
cfg.netmask("255.255.255.0");
}
cfg.up();
}

let device = tun::create_as_async(&cfg).unwrap();
let device = tun2::create_as_async(&cfg).unwrap();
let mut builder = StackBuilder::default()
.enable_tcp(true)
.enable_udp(true)
@@ -274,20 +274,20 @@ async fn new_udp_packet(addr: SocketAddr, iface: &str) -> std::io::Result<tokio:
socket
}

fn get_device_broadcast(device: &tun::AsyncDevice) -> Option<std::net::Ipv4Addr> {
use tun::AbstractDevice;
fn get_device_broadcast(device: &tun2::AsyncDevice) -> Option<std::net::Ipv4Addr> {
use tun2::AbstractDevice;

let mtu = device.as_ref().mtu().unwrap_or(tun::DEFAULT_MTU);
let mtu = device.mtu().unwrap_or(tun2::DEFAULT_MTU);

let address = match device.as_ref().address() {
let address = match device.address() {
Ok(a) => match a {
IpAddr::V4(v4) => v4,
IpAddr::V6(_) => return None,
},
Err(_) => return None,
};

let netmask = match device.as_ref().netmask() {
let netmask = match device.netmask() {
Ok(n) => match n {
IpAddr::V4(v4) => v4,
IpAddr::V6(_) => return None,
2 changes: 1 addition & 1 deletion src/device.rs
Original file line number Diff line number Diff line change
@@ -9,7 +9,7 @@ use smoltcp::{
};
use tokio::sync::mpsc::{unbounded_channel, Permit, Sender, UnboundedReceiver, UnboundedSender};

use super::packet::AnyIpPktFrame;
use crate::packet::AnyIpPktFrame;

pub(super) struct VirtualDevice {
in_buf_avail: Arc<AtomicBool>,
2 changes: 1 addition & 1 deletion src/runner.rs
Original file line number Diff line number Diff line change
@@ -38,4 +38,4 @@ impl<T> Future for BoxFuture<'_, T> {
}
}

pub type Runner = BoxFuture<'static, ()>;
pub type Runner = BoxFuture<'static, std::io::Result<()>>;
42 changes: 12 additions & 30 deletions src/stack.rs
Original file line number Diff line number Diff line change
@@ -1,5 +1,4 @@
use std::{
io,
net::IpAddr,
pin::Pin,
task::{Context, Poll},
@@ -92,6 +91,7 @@ impl StackBuilder {
self
}

#[allow(clippy::type_complexity)]
pub fn build(
self,
) -> std::io::Result<(
@@ -119,29 +119,19 @@ impl StackBuilder {
// ICMP is handled by TCP's Interface.
// smoltcp's interface will always send replies to EchoRequest
if self.enable_icmp && !self.enable_tcp {
return Err(std::io::Error::new(
std::io::ErrorKind::InvalidInput,
"Enabling icmp requires enabling tcp",
));
use std::io::{Error, ErrorKind::InvalidInput};
return Err(Error::new(InvalidInput, "ICMP requires TCP"));
}
let icmp_tx = if self.enable_icmp {
if let Some(ref tcp_tx) = tcp_tx {
Some(tcp_tx.clone())
} else {
None
}
tcp_tx.clone()
} else {
None
};

let udp_socket = if let Some(udp_rx) = udp_rx {
Some(UdpSocket::new(udp_rx, stack_tx.clone()))
} else {
None
};
let udp_socket = udp_rx.map(|udp_rx| UdpSocket::new(udp_rx, stack_tx.clone()));

let (tcp_runner, tcp_listener) = if let Some(tcp_rx) = tcp_rx {
let (tcp_runner, tcp_listener) = TcpListener::new(tcp_rx, stack_tx);
let (tcp_runner, tcp_listener) = TcpListener::new(tcp_rx, stack_tx)?;
(Some(tcp_runner), Some(tcp_listener))
} else {
(None, None)
@@ -207,7 +197,7 @@ impl Stack {

// Recv from stack.
impl Stream for Stack {
type Item = io::Result<AnyIpPktFrame>;
type Item = std::io::Result<AnyIpPktFrame>;

fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
match self.stack_rx.poll_recv(cx) {
@@ -220,7 +210,7 @@ impl Stream for Stack {

// Send to stack.
impl Sink<AnyIpPktFrame> for Stack {
type Error = io::Error;
type Error = std::io::Error;

fn poll_ready(self: Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
if self.sink_buf.is_none() {
@@ -235,24 +225,16 @@ impl Sink<AnyIpPktFrame> for Stack {
return Ok(());
}

let packet = IpPacket::new_checked(item.as_slice()).map_err(|err| {
io::Error::new(
io::ErrorKind::InvalidInput,
format!("invalid IP packet: {}", err),
)
})?;
use std::io::{Error, ErrorKind::InvalidInput};
let packet = IpPacket::new_checked(item.as_slice())
.map_err(|err| Error::new(InvalidInput, format!("invalid IP packet: {}", err)))?;

let src_ip = packet.src_addr();
let dst_ip = packet.dst_addr();

let addr_allowed = self.ip_filters.is_allowed(&src_ip, &dst_ip);
if !addr_allowed {
trace!(
"IP packet {} -> {} (allowed? {}) throwing away",
src_ip,
dst_ip,
addr_allowed,
);
trace!("IP packet {src_ip} -> {dst_ip} (allowed? {addr_allowed}) throwing away",);
return Ok(());
}

89 changes: 42 additions & 47 deletions src/tcp.rs
Original file line number Diff line number Diff line change
@@ -1,6 +1,5 @@
use std::{
collections::HashMap,
io, mem,
net::SocketAddr,
pin::Pin,
sync::{
@@ -29,7 +28,7 @@ use tokio::{
};
use tracing::{error, trace};

use super::{
use crate::{
device::VirtualDevice,
packet::{AnyIpPktFrame, IpPacket},
Runner,
@@ -79,11 +78,13 @@ impl TcpListenerRunner {
Runner::new(async move {
let notify = Arc::new(Notify::new());
let (socket_tx, socket_rx) = unbounded_channel::<TcpSocketCreation>();
tokio::select! {
_ = Self::handle_packet(notify.clone(), iface_ingress_tx, iface_ingress_tx_avail.clone(), tcp_rx, stream_tx, socket_tx) => {}
_ = Self::handle_socket(notify, device, iface, iface_ingress_tx_avail, sockets, socket_rx) => {}
}
let res = tokio::select! {
v = Self::handle_packet(notify.clone(), iface_ingress_tx, iface_ingress_tx_avail.clone(), tcp_rx, stream_tx, socket_tx) => v,
v = Self::handle_socket(notify, device, iface, iface_ingress_tx_avail, sockets, socket_rx) => v,
};
res?;
trace!("VirtDevice::poll thread exited");
Ok(())
})
}

@@ -94,7 +95,7 @@ impl TcpListenerRunner {
mut tcp_rx: Receiver<AnyIpPktFrame>,
stream_tx: UnboundedSender<TcpStream>,
socket_tx: UnboundedSender<TcpSocketCreation>,
) {
) -> std::io::Result<()> {
while let Some(frame) = tcp_rx.recv().await {
let packet = match IpPacket::new_checked(frame.as_slice()) {
Ok(p) => p,
@@ -108,25 +109,20 @@ impl TcpListenerRunner {
if matches!(packet.protocol(), IpProtocol::Icmp | IpProtocol::Icmpv6) {
iface_ingress_tx
.send(frame)
.expect("channel already closed");
.map_err(|e| std::io::Error::new(std::io::ErrorKind::BrokenPipe, e))?;
iface_ingress_tx_avail.store(true, Ordering::Release);
notify.notify_one();
continue;
}

let src_ip = packet.src_addr();
let dst_ip = packet.dst_addr();
let payload = packet.payload();

let packet = match TcpPacket::new_checked(packet.payload()) {
let packet = match TcpPacket::new_checked(payload) {
Ok(p) => p,
Err(err) => {
error!(
"invalid TCP err: {}, src_ip: {}, dst_ip: {}, payload: {:?}",
err,
packet.src_addr(),
packet.dst_addr(),
packet.payload(),
);
error!("invalid TCP err: {err}, src_ip: {src_ip}, dst_ip: {dst_ip}, payload: {payload:?}");
continue;
}
};
@@ -171,19 +167,20 @@ impl TcpListenerRunner {
notify: notify.clone(),
control: control.clone(),
})
.expect("channel already closed");
.map_err(|e| std::io::Error::new(std::io::ErrorKind::BrokenPipe, e))?;
socket_tx
.send(TcpSocketCreation { control, socket })
.expect("channel already closed");
.map_err(|e| std::io::Error::new(std::io::ErrorKind::BrokenPipe, e))?;
}

// Pipeline tcp stream packet
iface_ingress_tx
.send(frame)
.expect("channel already closed");
.map_err(|e| std::io::Error::new(std::io::ErrorKind::BrokenPipe, e))?;
iface_ingress_tx_avail.store(true, Ordering::Release);
notify.notify_one();
}
Ok(())
}

async fn handle_socket(
@@ -193,7 +190,7 @@ impl TcpListenerRunner {
iface_ingress_tx_avail: Arc<AtomicBool>,
mut sockets: HashMap<SocketHandle, SharedControl>,
mut socket_rx: UnboundedReceiver<TcpSocketCreation>,
) {
) -> std::io::Result<()> {
let mut socket_set = SocketSet::new(vec![]);
loop {
while let Ok(TcpSocketCreation { control, socket }) = socket_rx.try_recv() {
@@ -253,9 +250,7 @@ impl TcpListenerRunner {
});

match result {
Ok(..) => {
wake_receiver = true;
}
Ok(..) => wake_receiver = true,
Err(err) => {
error!("socket recv error: {:?}, {:?}", err, socket.state());

@@ -275,16 +270,16 @@ impl TcpListenerRunner {

// If socket is not in ESTABLISH, FIN-WAIT-1, FIN-WAIT-2,
// the local client have closed our receiver.
let states = [
TcpState::Listen,
TcpState::SynReceived,
TcpState::Established,
TcpState::FinWait1,
TcpState::FinWait2,
];
if matches!(control.recv_state, TcpSocketState::Normal)
&& !socket.may_recv()
&& !matches!(
socket.state(),
TcpState::Listen
| TcpState::SynReceived
| TcpState::Established
| TcpState::FinWait1
| TcpState::FinWait2
)
&& !states.contains(&socket.state())
{
trace!("closed TCP Read Half, {:?}", socket.state());

@@ -308,9 +303,7 @@ impl TcpListenerRunner {
});

match result {
Ok(..) => {
wake_sender = true;
}
Ok(..) => wake_sender = true,
Err(err) => {
error!("socket send error: {:?}, {:?}", err, socket.state());

@@ -364,9 +357,9 @@ impl TcpListener {
pub(super) fn new(
tcp_rx: Receiver<AnyIpPktFrame>,
stack_tx: Sender<AnyIpPktFrame>,
) -> (Runner, Self) {
) -> std::io::Result<(Runner, Self)> {
let (mut device, iface_ingress_tx, iface_ingress_tx_avail) = VirtualDevice::new(stack_tx);
let iface = Self::create_interface(&mut device);
let iface = Self::create_interface(&mut device)?;

let (stream_tx, stream_rx) = unbounded_channel();

@@ -380,10 +373,10 @@ impl TcpListener {
HashMap::new(),
);

(runner, Self { stream_rx })
Ok((runner, Self { stream_rx }))
}

fn create_interface<D>(device: &mut D) -> Interface
fn create_interface<D>(device: &mut D) -> std::io::Result<Interface>
where
D: Device + ?Sized,
{
@@ -401,13 +394,13 @@ impl TcpListener {
iface
.routes_mut()
.add_default_ipv4_route(Ipv4Address::new(0, 0, 0, 1))
.expect("IPv4 default route");
.map_err(|e| std::io::Error::new(std::io::ErrorKind::AddrNotAvailable, e))?;
iface
.routes_mut()
.add_default_ipv6_route(Ipv6Address::new(0, 0, 0, 0, 0, 0, 0, 1))
.expect("IPv6 default route");
.map_err(|e| std::io::Error::new(std::io::ErrorKind::AddrNotAvailable, e))?;
iface.set_any_ip(true);
iface
Ok(iface)
}
}

@@ -466,7 +459,7 @@ impl AsyncRead for TcpStream {
self: Pin<&mut Self>,
cx: &mut Context<'_>,
buf: &mut ReadBuf<'_>,
) -> Poll<io::Result<()>> {
) -> Poll<std::io::Result<()>> {
let mut control = self.control.lock();

// Read from buffer
@@ -486,7 +479,9 @@ impl AsyncRead for TcpStream {
return Poll::Pending;
}

let recv_buf = unsafe { mem::transmute::<_, &mut [u8]>(buf.unfilled_mut()) };
let recv_buf = unsafe {
std::mem::transmute::<&mut [std::mem::MaybeUninit<u8>], &mut [u8]>(buf.unfilled_mut())
};
let n = control.recv_buffer.dequeue_slice(recv_buf);
buf.advance(n);

@@ -503,12 +498,12 @@ impl AsyncWrite for TcpStream {
self: Pin<&mut Self>,
cx: &mut Context<'_>,
buf: &[u8],
) -> Poll<io::Result<usize>> {
) -> Poll<std::io::Result<usize>> {
let mut control = self.control.lock();

// If state == Close | Closing | Closed, the TCP stream WR half is closed.
if !matches!(control.send_state, TcpSocketState::Normal) {
return Err(io::ErrorKind::BrokenPipe.into()).into();
return Err(std::io::ErrorKind::BrokenPipe.into()).into();
}

// Write to buffer
@@ -532,11 +527,11 @@ impl AsyncWrite for TcpStream {
Ok(n).into()
}

fn poll_flush(self: Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll<io::Result<()>> {
fn poll_flush(self: Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll<std::io::Result<()>> {
Ok(()).into()
}

fn poll_shutdown(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
fn poll_shutdown(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<std::io::Result<()>> {
let mut control = self.control.lock();

if matches!(control.send_state, TcpSocketState::Closed) {
43 changes: 14 additions & 29 deletions src/udp.rs
Original file line number Diff line number Diff line change
@@ -1,5 +1,4 @@
use std::{
io,
net::SocketAddr,
pin::Pin,
task::{Context, Poll},
@@ -12,7 +11,7 @@ use tokio::sync::mpsc::{Receiver, Sender};
use tokio_util::sync::PollSender;
use tracing::{error, trace};

use super::packet::{AnyIpPktFrame, IpPacket};
use crate::packet::{AnyIpPktFrame, IpPacket};

pub type UdpMsg = (
Vec<u8>, /* payload */
@@ -69,17 +68,12 @@ impl Stream for ReadHalf {

let src_ip = packet.src_addr();
let dst_ip = packet.dst_addr();
let payload = packet.payload();

let packet = match UdpPacket::new_checked(packet.payload()) {
let packet = match UdpPacket::new_checked(payload) {
Ok(p) => p,
Err(err) => {
error!(
"invalid err: {}, src_ip: {}, dst_ip: {}, payload: {:?}",
err,
packet.src_addr(),
packet.dst_addr(),
packet.payload(),
);
error!("invalid err: {err}, src_ip: {src_ip}, dst_ip: {dst_ip}, payload: {payload:?}");
return None;
}
};
@@ -98,16 +92,17 @@ impl Stream for ReadHalf {
}

impl Sink<UdpMsg> for WriteHalf {
type Error = io::Error;
type Error = std::io::Error;

fn poll_ready(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
match ready!(self.stack_tx.poll_ready_unpin(cx)) {
Ok(()) => Poll::Ready(Ok(())),
Err(err) => Poll::Ready(Err(io::Error::new(io::ErrorKind::Other, err))),
Err(err) => Poll::Ready(Err(std::io::Error::new(std::io::ErrorKind::Other, err))),
}
}

fn start_send(mut self: Pin<&mut Self>, item: UdpMsg) -> Result<(), Self::Error> {
use std::io::{Error, ErrorKind::InvalidData, ErrorKind::Other};
let (data, src_addr, dst_addr) = item;

if data.is_empty() {
@@ -124,44 +119,34 @@ impl Sink<UdpMsg> for WriteHalf {
.udp(src_addr.port(), dst_addr.port())
}
_ => {
return Err(io::Error::new(
io::ErrorKind::InvalidData,
"source and destination type unmatch",
));
return Err(Error::new(InvalidData, "src or destination type unmatch"));
}
};

let mut ip_packet_writer = Vec::with_capacity(builder.size(data.len()));
builder
.write(&mut ip_packet_writer, &data)
.expect("PacketBuilder::write");
.map_err(|err| Error::new(Other, format!("PacketBuilder::write: {}", err)))?;

match self.stack_tx.start_send_unpin(ip_packet_writer.clone()) {
Ok(()) => Ok(()),
Err(err) => Err(io::Error::new(
io::ErrorKind::Other,
format!("send error: {}", err),
)),
Err(err) => Err(Error::new(Other, format!("send error: {}", err))),
}
}

fn poll_flush(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
use std::io::{Error, ErrorKind::Other};
match ready!(self.stack_tx.poll_flush_unpin(cx)) {
Ok(()) => Poll::Ready(Ok(())),
Err(err) => Poll::Ready(Err(io::Error::new(
io::ErrorKind::Other,
format!("flush error: {}", err),
))),
Err(err) => Poll::Ready(Err(Error::new(Other, format!("flush error: {}", err)))),
}
}

fn poll_close(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
use std::io::{Error, ErrorKind::Other};
match ready!(self.stack_tx.poll_close_unpin(cx)) {
Ok(()) => Poll::Ready(Ok(())),
Err(err) => Poll::Ready(Err(io::Error::new(
io::ErrorKind::Other,
format!("close error: {}", err),
))),
Err(err) => Poll::Ready(Err(Error::new(Other, format!("close error: {}", err)))),
}
}
}