diff --git a/Cargo.lock b/Cargo.lock index ca9bbc9..d8e0621 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -286,9 +286,9 @@ checksum = "1fd0f2584146f6f2ef48085050886acf353beff7305ebd1ae69500e27c67f64b" [[package]] name = "bytes" -version = "1.9.0" +version = "1.10.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "325918d6fe32f23b19878fe4b34794ae41fc19ddbe53b10571a4874d44ffd39b" +checksum = "f61dac84819c6588b558454b194026eb1f09c293b9036ae9b159e74e73ab6cf9" [[package]] name = "cc" @@ -732,6 +732,7 @@ name = "dumbpipe" version = "0.23.0" dependencies = [ "anyhow", + "bytes", "clap", "duct", "hex", diff --git a/Cargo.toml b/Cargo.toml index 8f77204..260e0da 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -15,6 +15,7 @@ rust-version = "1.81" [dependencies] anyhow = "1.0.75" +bytes = "1.10.0" clap = { version = "4.4.10", features = ["derive"] } hex = "0.4.3" iroh = { version = "0.31", default-features = false } diff --git a/src/main.rs b/src/main.rs index 68a2cf9..28a4041 100644 --- a/src/main.rs +++ b/src/main.rs @@ -7,15 +7,13 @@ use iroh::{ Endpoint, NodeAddr, SecretKey, }; use std::{ - io, - net::{SocketAddr, SocketAddrV4, SocketAddrV6, ToSocketAddrs}, - str::FromStr, + io, net::{SocketAddr, SocketAddrV4, SocketAddrV6, ToSocketAddrs}, str::FromStr }; use tokio::{ - io::{AsyncRead, AsyncWrite, AsyncWriteExt}, - select, + io::{AsyncRead, AsyncWrite, AsyncWriteExt}, select }; use tokio_util::sync::CancellationToken; +mod udp; /// Create a dumb pipe between two machines, using an iroh magicsocket. /// @@ -54,6 +52,15 @@ pub enum Commands { /// connecting to a TCP socket for which you have to specify the host and port. ListenTcp(ListenTcpArgs), + /// Listen on a magicsocket and forward incoming connections to the specified + /// host and port. Every incoming bidi stream is forwarded to a new connection. + /// + /// Will print a node ticket on stderr that can be used to connect. + /// + /// As far as the magic socket is concerned, this is listening. But it is + /// connecting to a UDP socket for which you have to specify the host and port. + ListenUdp(ListenUdpArgs), + /// Connect to a magicsocket, open a bidi stream, and forward stdin/stdout. /// /// A node ticket is required to connect. @@ -67,6 +74,15 @@ pub enum Commands { /// As far as the magic socket is concerned, this is connecting. But it is /// listening on a TCP socket for which you have to specify the interface and port. ConnectTcp(ConnectTcpArgs), + + /// Connect to a magicsocket, open a bidi stream, and forward stdin/stdout + /// to it. + /// + /// A node ticket is required to connect. + /// + /// As far as the magic socket is concerned, this is connecting. But it is + /// listening on a UDP socket for which you have to specify the interface and port. + ConnectUdp(ConnectUdpArgs), } #[derive(Parser, Debug)] @@ -140,6 +156,15 @@ pub struct ListenTcpArgs { pub common: CommonArgs, } +#[derive(Parser, Debug)] +pub struct ListenUdpArgs { + #[clap(long)] + pub host: String, + + #[clap(flatten)] + pub common: CommonArgs, +} + #[derive(Parser, Debug)] pub struct ConnectTcpArgs { /// The addresses to listen on for incoming tcp connections. @@ -155,6 +180,21 @@ pub struct ConnectTcpArgs { pub common: CommonArgs, } +#[derive(Parser, Debug)] +pub struct ConnectUdpArgs { + /// The addresses to listen on for incoming udp connections. + /// + /// To listen on all network interfaces, use 0.0.0.0:12345 + #[clap(long)] + pub addr: String, + + /// The node to connect to + pub ticket: NodeTicket, + + #[clap(flatten)] + pub common: CommonArgs, +} + #[derive(Parser, Debug)] pub struct ConnectArgs { /// The node to connect to @@ -540,8 +580,10 @@ async fn main() -> anyhow::Result<()> { let res = match args.command { Commands::Listen(args) => listen_stdio(args).await, Commands::ListenTcp(args) => listen_tcp(args).await, + Commands::ListenUdp(args) => udp::listen_udp(args).await, Commands::Connect(args) => connect_stdio(args).await, Commands::ConnectTcp(args) => connect_tcp(args).await, + Commands::ConnectUdp(args) => udp::connect_udp(args).await, }; match res { Ok(()) => std::process::exit(0), diff --git a/src/udp.rs b/src/udp.rs new file mode 100644 index 0000000..9c55072 --- /dev/null +++ b/src/udp.rs @@ -0,0 +1,377 @@ +use anyhow::{Context, Result}; +use bytes::Bytes; +use dumbpipe::NodeTicket; +use iroh::{ + endpoint::{get_remote_node_id, Connecting}, + Endpoint, +}; +use quinn::Connection; +use std::{ + collections::HashMap, + net::{SocketAddr, ToSocketAddrs}, +}; +use tokio::{net::UdpSocket, select, signal}; +use tokio_util::sync::CancellationToken; + +use std::sync::Arc; + +use crate::{get_or_create_secret, ConnectUdpArgs, ListenUdpArgs}; + +// 1- Receives request message from socket +// 2- Forwards it to the connection datagram +// 3- Receives response message back from connection datagram +// 4- Forwards it back to the socket +pub async fn connect_udp(args: ConnectUdpArgs) -> anyhow::Result<()> { + let addrs = args + .addr + .to_socket_addrs() + .context(format!("invalid host string {}", args.addr))?; + let secret_key = get_or_create_secret()?; + let mut builder = Endpoint::builder().secret_key(secret_key).alpns(vec![]); + if let Some(addr) = args.common.magic_ipv4_addr { + builder = builder.bind_addr_v4(addr); + } + if let Some(addr) = args.common.magic_ipv6_addr { + builder = builder.bind_addr_v6(addr); + } + let endpoint = builder.bind().await.context("unable to bind magicsock")?; + tracing::info!("udp listening on {:?}", addrs); + let socket = Arc::new(UdpSocket::bind(addrs.as_slice()).await?); + + let node_addr = args.ticket.node_addr(); + let mut buf: Vec = vec![0u8; 65535]; + let conns = Arc::new(tokio::sync::Mutex::new( + HashMap::::new(), + )); + loop { + tokio::select! { + _ = signal::ctrl_c() => { + eprintln!("Received CTRL-C, shutting down..."); + break; + } + result = socket.recv_from(&mut buf) => { + match result { + Ok((size, sock_addr)) => { + // Check if we already have a connection for this socket address + let mut cnns = conns.lock().await; + let connection = match cnns.get_mut(&sock_addr) { + Some(conn) => conn, + None => { + // If we don't have a connection, drop the previous lock to create a new one later on + drop(cnns); + + // Create a new connection since this address is not in the hashmap + let endpoint = endpoint.clone(); + let addr = node_addr.clone(); + let handshake = !args.common.is_custom_alpn(); + let alpn = args.common.alpn()?; + + let remote_node_id = addr.node_id; + tracing::info!("creating a connection to be forwarding UDP to {}", remote_node_id); + + // connect to the node, try only once + let connection = endpoint + .connect(addr.clone(), &alpn) + .await + .context(format!("error connecting to {}", remote_node_id))?; + tracing::info!("connected to {}", remote_node_id); + + // send the handshake unless we are using a custom alpn + if handshake { + connection.send_datagram(Bytes::from_static(&dumbpipe::HANDSHAKE))?; + } + + let sock_send = socket.clone(); + let conn_clone = connection.clone(); + let conns_clone = conns.clone(); + // Spawn a task for listening the connection datagram, and forward the data to the UDP socket + tokio::spawn(async move { + // 3- Receives response message back from connection datagram + // 4- Forwards it back to the socket + if let Err(cause) = handle_udp_accept(sock_addr, sock_send, conn_clone).await { + // log error at warn level + // + // we should know about it, but it's not fatal + tracing::warn!("error handling connection: {}", cause); + } + // Cleanup resources for this connection since it's `Connection` is closed or errored out + let mut cn = conns_clone.lock().await; + cn.remove(&sock_addr); + }); + + // Store the connection and return + let mut cn = conns.lock().await; + cn.insert(sock_addr, connection.clone()); + &mut connection.clone() + } + }; + + // 1- Receives request message from socket + // 2- Forwards it to the connection datagram + if let Err(e) = connection.send_datagram(Bytes::copy_from_slice(&buf[..size])) { // Is Bytes::copy_from_slice most efficient way to do this?. Investigate. + tracing::error!("Error writing to connection datagram: {}", e); + return Err(e.into()); + } + } + Err(e) => { + tracing::warn!("error receiving from UDP socket: {}", e); + break; + } + } + } + } + } + Ok(()) +} + +/// Listen on a magicsocket and forward incoming connections to a udp socket. +pub async fn listen_udp(args: ListenUdpArgs) -> anyhow::Result<()> { + let addrs = match args.host.to_socket_addrs() { + Ok(addrs) => addrs.collect::>(), + Err(e) => anyhow::bail!("invalid host string {}: {}", args.host, e), + }; + let secret_key = get_or_create_secret()?; + let mut builder = Endpoint::builder() + .alpns(vec![args.common.alpn()?]) + .secret_key(secret_key); + if let Some(addr) = args.common.magic_ipv4_addr { + builder = builder.bind_addr_v4(addr); + } + if let Some(addr) = args.common.magic_ipv6_addr { + builder = builder.bind_addr_v6(addr); + } + let endpoint = builder.bind().await?; + // wait for the endpoint to figure out its address before making a ticket + endpoint.home_relay().initialized().await?; + let node_addr = endpoint.node_addr().await?; + let mut short = node_addr.clone(); + let ticket = NodeTicket::new(node_addr); + short.direct_addresses.clear(); + let short = NodeTicket::new(short); + + // print the ticket on stderr so it doesn't interfere with the data itself + // + // note that the tests rely on the ticket being the last thing printed + eprintln!("Forwarding incoming requests to '{}'.", args.host); + eprintln!("To connect, use e.g.:"); + eprintln!("dumbpipe connect-udp {ticket}"); + if args.common.verbose > 0 { + eprintln!("or:\ndumbpipe connect-udp {}", short); + } + tracing::info!("node id is {}", ticket.node_addr().node_id); + tracing::info!("derp url is {:?}", ticket.node_addr().relay_url); + + // handle a new incoming connection on the magic endpoint + async fn handle_magic_accept( + connecting: Connecting, + addrs: Vec, + handshake: bool, + ) -> anyhow::Result<()> { + let connection = connecting.await.context("error accepting connection")?; + let remote_node_id = get_remote_node_id(&connection)?; + tracing::info!("got connection from {}", remote_node_id); + if handshake { + // read the handshake and verify it + let bytes = connection.read_datagram().await?; + anyhow::ensure!(*bytes == dumbpipe::HANDSHAKE, "invalid handshake"); + } + + // 1- Receives request message from connection datagram + // 2- Forwards it to the (addrs) via UDP socket + // 3- Receives response message back from UDP socket + // 4- Forwards it back to the connection datagram + handle_udp_listen(addrs.as_slice(), connection).await?; + Ok(()) + } + + loop { + let incoming = select! { + incoming = endpoint.accept() => incoming, + _ = tokio::signal::ctrl_c() => { + eprintln!("got ctrl-c, exiting"); + break; + } + }; + let Some(incoming) = incoming else { + break; + }; + let Ok(connecting) = incoming.accept() else { + break; + }; + let addrs = addrs.clone(); + let handshake = !args.common.is_custom_alpn(); + tokio::spawn(async move { + if let Err(cause) = handle_magic_accept(connecting, addrs, handshake).await { + // log error at warn level + // + // we should know about it, but it's not fatal + tracing::warn!("error handling connection: {}", cause); + } + }); + } + Ok(()) +} + +async fn handle_udp_accept( + client_addr: SocketAddr, + udp_socket: Arc, + connection: Connection, +) -> Result<()> { + // Create a cancellation token to coordinate shutdown + let token = CancellationToken::new(); + let token_conn = token.clone(); + let token_ctrl_c = token.clone(); + + // Create buffer for receiving data + let connection_to_udp = { + let socket = udp_socket.clone(); + tokio::spawn(async move { + loop { + // Check if we should stop + if token_conn.is_cancelled() { + break; + } + + // Read from connection datagram + match connection.read_datagram().await { + Ok(bytes) => { + // Forward to UDP peer + if let Err(e) = socket.send_to(&bytes, client_addr).await { + tracing::error!("Error sending to UDP: {}", e); + token_conn.cancel(); + break; + } + } + Err(e) => { + tracing::error!("Connection read_datagram error: {}", e); + token_conn.cancel(); + break; + } + } + } + }) + }; + + // Handle Ctrl+C signal + let ctrl_c = tokio::spawn(async move { + if let Ok(()) = tokio::signal::ctrl_c().await { + token_ctrl_c.cancel(); + } + }); + + // Wait for any task to complete (or Ctrl+C) + tokio::select! { + _ = connection_to_udp => {}, + _ = ctrl_c => {}, + } + + Ok(()) +} + +// Every new connection is a new socket to the `connect udp` command +async fn handle_udp_listen(peer_addrs: &[SocketAddr], connection: Connection) -> Result<()> { + // Create a cancellation token to coordinate shutdown + let token = CancellationToken::new(); + let token_udp = token.clone(); + let token_conn = token.clone(); + let token_ctrl_c = token.clone(); + + // Create a new socket for this connection, representing the client connected to UDP server at the other side. + // This socket will be used to send data to the actual server, receive response back and forward it to the conn. + let socket = Arc::new(UdpSocket::bind("0.0.0:0").await?); + + let udp_buf_size = 65535; // Maximum UDP packet size + let conn_to_udp = { + let socket_send = socket.clone(); + let p_addr = peer_addrs.to_vec(); + let conn_clone = connection.clone(); + tokio::spawn(async move { + loop { + // Check if we should stop + if token_conn.is_cancelled() { + tracing::info!("Token cancellation was requested. Ending QUIC to UDP task."); + break; + } + + // Read from connection datagram + match conn_clone.read_datagram().await { + Ok(bytes) => { + // Forward to UDP peer + for addr in p_addr.iter() { + if let Err(e) = socket_send.send_to(&bytes, addr).await { + tracing::error!("Error sending to UDP: {}", e); + token_conn.cancel(); + break; + } + } + } + Err(e) => { + tracing::error!("Connection read_datagram error: {}", e); + token_conn.cancel(); + break; + } + } + } + tracing::info!("Token cancellation was requested or error received. connection datagram task ended."); + }) + }; + + let udp_to_conn = { + // Task for listening to the response to the UDP server + let socket_listen = socket.clone(); + let conn_clone = connection.clone(); + tokio::spawn(async move { + let mut buf = vec![0u8; udp_buf_size]; + loop { + // Check if we should stop + if token_udp.is_cancelled() { + tracing::info!("Token cancellation was requested. Ending UDP to QUIC task."); + break; + } + + // Use timeout to periodically check cancellation + match tokio::time::timeout( + tokio::time::Duration::from_millis(100), + socket_listen.recv_from(&mut buf), + ) + .await + { + Ok(Ok((n, _addr))) => { + // Forward the buf back to the connection datagram + if let Err(e) = conn_clone.send_datagram(Bytes::copy_from_slice(&buf[..n])) + { + tracing::error!("Error on connection send_datagram: {}", e); + token_udp.cancel(); + break; + } + } + Ok(Err(e)) => { + tracing::error!("UDP receive error: {}", e); + token_udp.cancel(); + break; + } + Err(_) => continue, // Timeout, check cancellation + } + } + tracing::info!( + "Token cancellation was requested or error received. UDP socket task ended." + ); + }) + }; + + // Handle Ctrl+C signal + let ctrl_c = tokio::spawn(async move { + if let Ok(()) = tokio::signal::ctrl_c().await { + token_ctrl_c.cancel(); + } + }); + + // Wait for any task to complete (or Ctrl+C) + tokio::select! { + _ = conn_to_udp => {}, + _ = udp_to_conn => {}, + _ = ctrl_c => {}, + } + + Ok(()) +}