diff --git a/Cargo.lock b/Cargo.lock index a020626738a..83e58f1b814 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -3693,6 +3693,7 @@ dependencies = [ "prettytable-rs", "rcgen", "reqwest 0.12.5", + "rstest", "rustls 0.23.12", "rustls-pemfile 2.1.3", "semver 1.0.23", @@ -3703,6 +3704,7 @@ dependencies = [ "thiserror", "tokio", "tokio-rustls 0.26.0", + "tokio-stream", "tokio-util", "tracing", "tracing-subscriber", diff --git a/changelog.d/567.added.md b/changelog.d/567.added.md new file mode 100644 index 00000000000..077b3051802 --- /dev/null +++ b/changelog.d/567.added.md @@ -0,0 +1,5 @@ +Add port forwarding feature which can be used to proxy data from a local port to a remote one - +if the local port is not specified, it will default to the same as the remote +``` +mirrord port-forward [options] -L [local_port:]remote_ip:remote_port +``` \ No newline at end of file diff --git a/mirrord/analytics/src/lib.rs b/mirrord/analytics/src/lib.rs index c4c44d3ed0e..4ca01f2b6c6 100644 --- a/mirrord/analytics/src/lib.rs +++ b/mirrord/analytics/src/lib.rs @@ -35,6 +35,7 @@ pub enum ExecutionKind { Container = 1, #[default] Exec = 2, + PortForward = 3, Other = 0, } @@ -43,6 +44,7 @@ impl From for ExecutionKind { match kind { 1 => ExecutionKind::Container, 2 => ExecutionKind::Exec, + 3 => ExecutionKind::PortForward, _ => ExecutionKind::Other, } } diff --git a/mirrord/cli/Cargo.toml b/mirrord/cli/Cargo.toml index eedc6345768..1eab5ebd1c4 100644 --- a/mirrord/cli/Cargo.toml +++ b/mirrord/cli/Cargo.toml @@ -19,7 +19,11 @@ workspace = true # See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html [dependencies] -mirrord-operator = { path = "../operator", features = ["client", "license-fetch", "setup"] } +mirrord-operator = { path = "../operator", features = [ + "client", + "license-fetch", + "setup", +] } mirrord-progress = { path = "../progress" } mirrord-kube = { path = "../kube" } mirrord-config = { path = "../config" } @@ -40,13 +44,13 @@ semver.workspace = true exec.workspace = true reqwest.workspace = true const-random = "0.1.15" -tokio = { workspace = true, features = ["rt", "net", "macros", "process"]} +tokio = { workspace = true, features = ["rt", "net", "macros", "process"] } kube.workspace = true k8s-openapi.workspace = true miette = { version = "7", features = ["fancy"] } thiserror.workspace = true humantime = "2" -nix = {workspace = true, features = ["process", "resource"]} +nix = { workspace = true, features = ["process", "resource"] } tokio-util.workspace = true socket2.workspace = true drain.workspace = true @@ -58,6 +62,7 @@ tempfile = "3" rcgen = "0.13" rustls-pemfile = "2" tokio-rustls = "0.26" +tokio-stream = { workspace = true, features = ["net"] } [target.'cfg(target_os = "macos")'.dependencies] @@ -65,4 +70,7 @@ mirrord-sip = { path = "../sip" } [build-dependencies] -mirrord-layer = { artifact = "cdylib", path="../layer" } +mirrord-layer = { artifact = "cdylib", path = "../layer" } + +[dev-dependencies] +rstest = "0.21" diff --git a/mirrord/cli/src/config.rs b/mirrord/cli/src/config.rs index 19fc1a981d2..799b6de2f62 100644 --- a/mirrord/cli/src/config.rs +++ b/mirrord/cli/src/config.rs @@ -1,11 +1,19 @@ #![deny(missing_docs)] -use std::{collections::HashMap, ffi::OsString, fmt::Display, path::PathBuf}; +use std::{ + collections::HashMap, + ffi::OsString, + fmt::Display, + net::{IpAddr, Ipv4Addr, SocketAddr}, + path::PathBuf, + str::FromStr, +}; use clap::{ArgGroup, Args, Parser, Subcommand, ValueEnum, ValueHint}; use clap_complete::Shell; use mirrord_config::MIRRORD_CONFIG_FILE_ENV; use mirrord_operator::setup::OperatorNamespace; +use thiserror::Error; use crate::error::CliError; @@ -58,6 +66,10 @@ pub(super) enum Commands { #[command(hide = true, name = "intproxy")] InternalProxy, + /// Port forwarding - UNSTABLE FEATURE + #[command(name = "port-forward")] + PortForward(Box), + /// Verify config file without starting mirrord. #[command(hide = true)] VerifyConfig(VerifyConfigArgs), @@ -96,15 +108,9 @@ impl Display for FsMode { #[derive(Args, Debug)] /// Parameters to override any values from mirrord-config as part of `exec` or `container` commands. pub(super) struct ExecParams { - /// Target name to mirror. - /// Target can either be a deployment or a pod. - /// Valid formats: deployment/name, pod/name, pod/name/container/name - #[arg(short = 't', long)] - pub target: Option, - - /// Namespace of the pod to mirror. Defaults to "default". - #[arg(short = 'n', long)] - pub target_namespace: Option, + /// Parameters for the target + #[clap(flatten)] + pub target: TargetParams, /// Namespace to place agent in. #[arg(short = 'a', long)] @@ -191,7 +197,7 @@ impl ExecParams { pub fn as_env_vars(&self) -> Result, CliError> { let mut envs: HashMap = HashMap::new(); - if let Some(target) = &self.target { + if let Some(target) = &self.target.target { envs.insert("MIRRORD_IMPERSONATED_TARGET".into(), target.into()); } @@ -203,7 +209,7 @@ impl ExecParams { envs.insert("MIRRORD_SKIP_PROCESSES".into(), skip_processes.into()); } - if let Some(namespace) = &self.target_namespace { + if let Some(namespace) = &self.target.target_namespace { envs.insert("MIRRORD_TARGET_NAMESPACE".into(), namespace.into()); } @@ -303,6 +309,200 @@ pub(super) struct ExecArgs { pub(super) binary_args: Vec, } +#[derive(Args, Debug)] +pub(super) struct TargetParams { + /// Target name to mirror. + /// Target can either be a deployment or a pod. + /// Valid formats: deployment/name, pod/name, pod/name/container/name + #[arg(short = 't', long)] + pub target: Option, + + /// Namespace of the pod to mirror. Defaults to "default". + #[arg(short = 'n', long)] + pub target_namespace: Option, +} + +impl TargetParams { + pub fn as_env_vars(&self) -> Result, CliError> { + let mut envs: HashMap = HashMap::new(); + + if let Some(target) = &self.target { + envs.insert("MIRRORD_IMPERSONATED_TARGET".into(), target.into()); + } + if let Some(namespace) = &self.target_namespace { + envs.insert("MIRRORD_TARGET_NAMESPACE".into(), namespace.into()); + } + + Ok(envs) + } +} + +#[derive(Args, Debug)] +#[command(group(ArgGroup::new("port-forward")))] +pub(super) struct PortForwardArgs { + /// Parameters for the target + #[clap(flatten)] + pub target: TargetParams, + + /// Namespace to place agent in + #[arg(short = 'a', long)] + pub agent_namespace: Option, + + /// Agent log level + #[arg(short = 'l', long)] + pub agent_log_level: Option, + + /// Agent image + #[arg(short = 'i', long)] + pub agent_image: Option, + + /// Agent TTL + #[arg(long)] + pub agent_ttl: Option, + + /// Agent Startup Timeout seconds + #[arg(long)] + pub agent_startup_timeout: Option, + + /// Accept/reject invalid certificates + #[arg(short = 'c', long)] + pub accept_invalid_certificates: bool, + + /// Use an Ephemeral Container to mirror traffic + #[arg(short, long)] + pub ephemeral_container: bool, + + /// Disable telemetry - see + #[arg(long)] + pub no_telemetry: bool, + + #[arg(long)] + /// Disable version check on startup + pub disable_version_check: bool, + + /// Load config from config file + #[arg(short = 'f', long, value_hint = ValueHint::FilePath)] + pub config_file: Option, + + /// Kube context to use from Kubeconfig + #[arg(long)] + pub context: Option, + + /// Mappings for port forwarding + #[arg(short = 'L', long)] + pub port_mappings: Vec, +} + +#[derive(Clone, Debug, PartialEq)] +pub struct PortMapping { + pub local: SocketAddr, + pub remote: SocketAddr, +} + +impl FromStr for PortMapping { + type Err = PortMappingParseErr; + + fn from_str(string: &str) -> Result { + fn parse_port(string: &str, original: &str) -> Result { + match string.parse::() { + Ok(0) => Err(PortMappingParseErr::PortZeroInvalid(string.to_string())), + Ok(port) => Ok(port), + Err(_error) => Err(PortMappingParseErr::PortParseErr( + string.to_string(), + original.to_string(), + )), + } + } + + fn parse_ip(string: &str, original: &str) -> Result { + match string.parse::() { + Ok(ip) => Ok(ip), + Err(_error) => Err(PortMappingParseErr::IpParseErr( + string.to_string(), + original.to_string(), + )), + } + } + + // expected format = local_port:dest_server:remote_port + // alternatively, = dest_server:remote_port + let vec: Vec<&str> = string.split(':').collect(); + let (local_port, remote_ip, remote_port) = match vec.as_slice() { + [local_port, remote_ip, remote_port] => { + let local_port = parse_port(local_port, string)?; + let remote_port = parse_port(remote_port, string)?; + let remote_ip = parse_ip(remote_ip, string)?; + (local_port, remote_ip, remote_port) + } + [remote_ip, remote_port] => { + let remote_port = parse_port(remote_port, string)?; + let remote_ip = parse_ip(remote_ip, string)?; + (remote_port, remote_ip, remote_port) + } + _ => { + return Err(PortMappingParseErr::InvalidFormat(string.to_string())); + } + }; + + Ok(Self { + local: SocketAddr::new(IpAddr::V4(std::net::Ipv4Addr::LOCALHOST), local_port), + remote: SocketAddr::new(IpAddr::V4(remote_ip), remote_port), + }) + } +} + +#[derive(Error, Debug, PartialEq)] +pub enum PortMappingParseErr { + #[error("Invalid format of argument `{0}`, expected `[local-port:]remote-ipv4:remote-port`")] + InvalidFormat(String), + + #[error("Failed to parse port `{0}` in argument `{1}`")] + PortParseErr(String, String), + + #[error("Failed to parse IPv4 address `{0}` in argument `{1}`")] + IpParseErr(String, String), + + #[error("Port `0` is not allowed in argument `{0}`")] + PortZeroInvalid(String), +} + +#[cfg(test)] +mod test { + use std::str::FromStr; + + use rstest::rstest; + + use super::PortMapping; + + #[rstest] + #[case("3030:152.37.110.132:3038", "127.0.0.1:3030", "152.37.110.132:3038")] + #[case("152.37.110.132:3038", "127.0.0.1:3038", "152.37.110.132:3038")] + fn parse_valid_mapping( + #[case] input: &str, + #[case] expected_local: &str, + #[case] expected_remote: &str, + ) { + let expected = PortMapping { + local: expected_local.parse().unwrap(), + remote: expected_remote.parse().unwrap(), + }; + assert_eq!(PortMapping::from_str(input).unwrap(), expected); + } + + #[rstest] + #[case("3030:152.37.110.132:3038:2027")] + #[case("152.37.110.132:3030:3038")] + #[case("3030:152.37.110.132:0")] + #[case("3o3o:152.37.11o.132:3o38")] + #[case("3030:152110.132:3038")] + #[case("30303030:152.37.110.132:3038")] + #[case("")] + #[should_panic] + fn parse_invalid_mapping(#[case] input: &str) { + PortMapping::from_str(input).unwrap(); + } +} + #[derive(Args, Debug)] pub(super) struct OperatorArgs { #[command(subcommand)] diff --git a/mirrord/cli/src/error.rs b/mirrord/cli/src/error.rs index e1c6f57dd96..0958785d712 100644 --- a/mirrord/cli/src/error.rs +++ b/mirrord/cli/src/error.rs @@ -10,6 +10,8 @@ use mirrord_operator::client::error::{HttpError, OperatorApiError, OperatorOpera use reqwest::StatusCode; use thiserror::Error; +use crate::port_forward::PortForwardError; + pub(crate) type Result = core::result::Result; const GENERAL_HELP: &str = r#" @@ -327,6 +329,10 @@ pub(crate) enum CliError { #[error("mirrord returned a target resource of unknown type: {0}")] #[diagnostic(help("{GENERAL_BUG}"))] OperatorReturnedUnknownTargetType(String), + + #[error("An error occurred in the port-forwarding process: {0}")] + #[diagnostic(help("{GENERAL_BUG}"))] + PortForwardingError(#[from] PortForwardError), } impl From for CliError { diff --git a/mirrord/cli/src/main.rs b/mirrord/cli/src/main.rs index 4de5ea80333..35163fdcfcc 100644 --- a/mirrord/cli/src/main.rs +++ b/mirrord/cli/src/main.rs @@ -8,6 +8,7 @@ use std::time::Duration; use clap::{CommandFactory, Parser}; use clap_complete::generate; use config::*; +use connection::create_and_connect; use container::container_command; use diagnose::diagnose_command; use exec::execvp; @@ -18,7 +19,7 @@ use kube::Client; use kube_resource::KubeResourceSeeker; use miette::JSONReportHandler; use mirrord_analytics::{ - AnalyticsError, AnalyticsReporter, CollectAnalytics, NullReporter, Reporter, + AnalyticsError, AnalyticsReporter, CollectAnalytics, ExecutionKind, NullReporter, Reporter, }; use mirrord_config::{ config::{ConfigContext, MirrordConfig}, @@ -36,6 +37,7 @@ use mirrord_kube::api::{container::SKIP_NAMES, kubernetes::create_kube_config}; use mirrord_operator::client::OperatorApi; use mirrord_progress::{Progress, ProgressTracker}; use operator::operator_command; +use port_forward::PortForwarder; use semver::Version; use serde_json::json; use tracing::{error, info, warn}; @@ -54,6 +56,7 @@ mod extract; mod internal_proxy; mod kube_resource; mod operator; +pub mod port_forward; mod teams; mod util; mod verify_config; @@ -454,6 +457,77 @@ async fn print_targets(args: &ListTargetArgs) -> Result<()> { Ok(()) } +async fn port_forward(args: &PortForwardArgs, watch: drain::Watch) -> Result<()> { + let mut progress = ProgressTracker::from_env("mirrord port-forward"); + progress.warning("Port forwarding is currently an unstable feature and subject to change. See https://github.com/metalbear-co/mirrord/issues/2640 for more info."); + if !args.disable_version_check { + prompt_outdated_version(&progress).await; + } + + for (name, value) in args.target.as_env_vars()? { + std::env::set_var(name, value); + } + + if args.no_telemetry { + std::env::set_var("MIRRORD_TELEMETRY", "false"); + } + + if let Some(namespace) = &args.agent_namespace { + std::env::set_var("MIRRORD_AGENT_NAMESPACE", namespace.clone()); + } + + if let Some(log_level) = &args.agent_log_level { + std::env::set_var("MIRRORD_AGENT_RUST_LOG", log_level.clone()); + } + + if let Some(image) = &args.agent_image { + std::env::set_var("MIRRORD_AGENT_IMAGE", image.clone()); + } + + if let Some(agent_ttl) = &args.agent_ttl { + std::env::set_var("MIRRORD_AGENT_TTL", agent_ttl.to_string()); + } + if let Some(agent_startup_timeout) = &args.agent_startup_timeout { + std::env::set_var( + "MIRRORD_AGENT_STARTUP_TIMEOUT", + agent_startup_timeout.to_string(), + ); + } + + if args.accept_invalid_certificates { + std::env::set_var("MIRRORD_ACCEPT_INVALID_CERTIFICATES", "true"); + warn!("Accepting invalid certificates"); + } + + if args.ephemeral_container { + std::env::set_var("MIRRORD_EPHEMERAL_CONTAINER", "true"); + }; + + if let Some(context) = &args.context { + std::env::set_var("MIRRORD_KUBE_CONTEXT", context); + } + + if let Some(config_file) = &args.config_file { + std::env::set_var("MIRRORD_CONFIG_FILE", config_file); + } + + let (config, mut context) = LayerConfig::from_env_with_warnings()?; + + let mut analytics = AnalyticsReporter::new(config.telemetry, ExecutionKind::PortForward, watch); + (&config).collect_analytics(analytics.get_mut()); + + config.verify(&mut context)?; + for warning in context.get_warnings() { + progress.warning(warning); + } + + let (_connection_info, connection) = + create_and_connect(&config, &mut progress, &mut analytics).await?; + let mut port_forward = PortForwarder::new(connection, args.port_mappings.clone()).await?; + port_forward.run().await?; + Ok(()) +} + const CURRENT_VERSION: &str = env!("CARGO_PKG_VERSION"); fn main() -> miette::Result<()> { @@ -506,6 +580,7 @@ fn main() -> miette::Result<()> { container_command(runtime_args, exec_params, watch).await? } Commands::ExternalProxy => external_proxy::proxy(watch).await?, + Commands::PortForward(args) => port_forward(&args, watch).await?, }; Ok(()) diff --git a/mirrord/cli/src/port_forward.rs b/mirrord/cli/src/port_forward.rs new file mode 100644 index 00000000000..99ebcecae03 --- /dev/null +++ b/mirrord/cli/src/port_forward.rs @@ -0,0 +1,815 @@ +use std::{ + collections::{HashMap, VecDeque}, + net::SocketAddr, + time::{Duration, Instant}, +}; + +use futures::StreamExt; +use mirrord_protocol::{ + outgoing::{ + tcp::{DaemonTcpOutgoing, LayerTcpOutgoing}, + LayerClose, LayerConnect, LayerWrite, SocketAddress, + }, + ClientMessage, ConnectionId, DaemonMessage, LogLevel, CLIENT_READY_FOR_LOGS, +}; +use thiserror::Error; +use tokio::{ + io::AsyncWriteExt, + net::{ + tcp::{OwnedReadHalf, OwnedWriteHalf}, + TcpListener, TcpStream, + }, + select, + sync::{ + mpsc::{self, Receiver, Sender}, + oneshot, + }, +}; +use tokio_stream::{wrappers::TcpListenerStream, StreamMap}; +use tokio_util::io::ReaderStream; +use tracing::Level; + +use crate::{connection::AgentConnection, PortMapping}; + +pub struct PortForwarder { + /// communicates with the agent (only TCP supported) + agent_connection: AgentConnection, + /// associates local ports with destination ports + mappings: HashMap, + /// accepts connections from the user app in the form of a stream + listeners: StreamMap, + /// oneshot channels for sending connection IDs to tasks and the associated local address + oneshots: VecDeque<(SocketAddr, oneshot::Sender)>, + /// identifies a pair of mapped socket addresses by their corresponding connection ID + sockets: HashMap, + /// identifies task senders by their corresponding local socket address + /// for sending data from the remote socket to the local address + task_txs: HashMap>>, + + /// transmit internal messages from tasks to [`PortForwarder`]'s main loop. + internal_msg_tx: Sender, + internal_msg_rx: Receiver, + + /// true if Ping has been sent to agent + waiting_for_pong: bool, + ping_pong_timeout: Instant, +} + +/// Used by tasks for individual forwarding connections to send instructions to [`PortForwarder`]'s +/// main loop. +#[derive(Debug)] +enum PortForwardMessage { + /// A request to make outgoing connection to the remote peer. + /// Sent by the task only after receiving first batch of data from the user. + /// The task waits for [`ConnectionId`] on the other end of the [`oneshot`] channel. + Connect(PortMapping, oneshot::Sender), + + /// Data received from the user in the connection with the given id. + Send(ConnectionId, Vec), + + /// A request to close the remote connection with the given id, if it exists. + Close(PortMapping, Option), +} + +impl PortForwarder { + pub(crate) async fn new( + agent_connection: AgentConnection, + parsed_mappings: Vec, + ) -> Result { + // open tcp listener for local addrs + let mut listeners = StreamMap::with_capacity(parsed_mappings.len()); + let mut mappings = HashMap::with_capacity(parsed_mappings.len()); + + if parsed_mappings.is_empty() { + return Err(PortForwardError::NoMappingsError()); + } + for mapping in parsed_mappings { + if listeners.contains_key(&mapping.local) { + // two mappings shared a key thus keys were not unique + return Err(PortForwardError::PortMapSetupError(mapping.local)); + } + let listener = TcpListener::bind(mapping.local).await; + match listener { + Ok(listener) => { + listeners.insert(mapping.local, TcpListenerStream::new(listener)); + mappings.insert(mapping.local, mapping.remote); + } + Err(error) => return Err(PortForwardError::TcpListenerError(error)), + } + } + + let (internal_msg_tx, internal_msg_rx) = mpsc::channel(1024); + + Ok(Self { + agent_connection, + mappings, + listeners, + oneshots: VecDeque::new(), + sockets: HashMap::new(), + task_txs: HashMap::new(), + internal_msg_tx, + internal_msg_rx, + waiting_for_pong: false, + ping_pong_timeout: Instant::now(), + }) + } + + pub(crate) async fn run(&mut self) -> Result<(), PortForwardError> { + // setup agent connection + self.agent_connection + .sender + .send(ClientMessage::SwitchProtocolVersion( + mirrord_protocol::VERSION.clone(), + )) + .await?; + match self.agent_connection.receiver.recv().await { + Some(DaemonMessage::SwitchProtocolVersionResponse(version)) + if CLIENT_READY_FOR_LOGS.matches(&version) => + { + self.agent_connection + .sender + .send(ClientMessage::ReadyForLogs) + .await?; + } + _ => return Err(PortForwardError::AgentConnectionFailed), + } + tracing::trace!("port forwarding setup complete"); + + loop { + select! { + _ = tokio::time::sleep_until(self.ping_pong_timeout.into()) => { + if self.waiting_for_pong { + // no pong received before timeout + break Err(PortForwardError::AgentError("agent failed to respond to Ping".into())); + } + self.agent_connection.sender.send(ClientMessage::Ping).await?; + self.waiting_for_pong = true; + self.ping_pong_timeout = Instant::now() + Duration::from_secs(30); + }, + + message = self.agent_connection.receiver.recv() => match message { + Some(message) => self.handle_msg_from_agent(message).await?, + None => { + break Err(PortForwardError::AgentError("unexpected end of connection with agent".into())); + }, + }, + + // stream coming from the user app + message = self.listeners.next() => match message { + Some(message) => self.handle_listener_stream(message).await?, + None => unreachable!("created listener sockets are never closed"), + }, + + message = self.internal_msg_rx.recv() => { + self.handle_msg_from_task(message.expect("this channel is never closed")).await?; + }, + } + } + } + + #[tracing::instrument(level = Level::TRACE, skip(self), err)] + async fn handle_msg_from_agent( + &mut self, + message: DaemonMessage, + ) -> Result<(), PortForwardError> { + match message { + DaemonMessage::TcpOutgoing(message) => match message { + DaemonTcpOutgoing::Connect(res) => match res { + Ok(res) => { + let connection_id = res.connection_id; + let SocketAddress::Ip(remote_socket) = res.remote_address else { + return Err(PortForwardError::ConnectionError( + "unexpectedly received Unix address for socket during setup".into(), + )); + }; + let Some((local_socket, channel)) = self.oneshots.pop_front() else { + return Err(PortForwardError::ReadyTaskNotFound( + remote_socket, + connection_id, + )); + }; + let port_map = PortMapping { + local: local_socket, + remote: remote_socket, + }; + self.sockets.insert(connection_id, port_map); + match channel.send(connection_id) { + Ok(_) => (), + Err(_) => { + self.agent_connection + .sender + .send(ClientMessage::TcpOutgoing(LayerTcpOutgoing::Close( + LayerClose { connection_id }, + ))) + .await?; + self.task_txs.remove(&local_socket); + self.sockets.remove(&connection_id); + tracing::warn!("failed to send connection ID {connection_id} to task on oneshot channel"); + } + }; + tracing::trace!("successful connection to remote address {remote_socket}, connection ID is {}", connection_id); + } + Err(error) => { + tracing::error!("failed to connect to a remote address: {error}"); + // LocalConnectionTask will fail when oneshot is dropped and handle cleanup + let _ = self.oneshots.pop_front(); + } + }, + DaemonTcpOutgoing::Read(res) => match res { + Ok(res) => { + let Some(&PortMapping { + local: local_socket, + remote: _, + }) = self.sockets.get(&res.connection_id) + else { + // ignore unknown connection IDs + return Ok(()); + }; + let Some(sender) = self.task_txs.get(&local_socket) else { + unreachable!("sender is always created before this point") + }; + match sender.send(res.bytes).await { + Ok(_) => (), + Err(_) => { + self.task_txs.remove(&local_socket); + self.sockets.remove(&res.connection_id); + self.agent_connection + .sender + .send(ClientMessage::TcpOutgoing(LayerTcpOutgoing::Close( + LayerClose { + connection_id: res.connection_id, + }, + ))) + .await?; + tracing::error!( + "failed to send response from remote to local port" + ); + } + } + } + Err(error) => { + return Err(PortForwardError::AgentError(format!( + "problem receiving DaemonTcpOutgoing::Read {error}" + ))) + } + }, + DaemonTcpOutgoing::Close(connection_id) => { + let Some(PortMapping { + local: local_socket, + remote: remote_socket, + }) = self.sockets.remove(&connection_id) + else { + // ignore unknown connection IDs + return Ok(()); + }; + self.task_txs.remove(&local_socket); + tracing::trace!( + "connection closed for port mapping {local_socket}:{remote_socket}, connection {connection_id}" + ); + } + }, + DaemonMessage::LogMessage(log_message) => match log_message.level { + LogLevel::Warn => tracing::warn!("agent log: {}", log_message.message), + LogLevel::Error => tracing::error!("agent log: {}", log_message.message), + }, + DaemonMessage::Close(error) => { + return Err(PortForwardError::AgentError(error)); + } + DaemonMessage::Pong if self.waiting_for_pong => { + self.waiting_for_pong = false; + } + other => { + // includes unexepcted DaemonMessage::Pong + return Err(PortForwardError::AgentError(format!( + "unexpected message from agent: {other:?}" + ))); + } + } + + Ok(()) + } + + #[tracing::instrument(level = Level::TRACE, skip(self), err)] + async fn handle_listener_stream( + &mut self, + message: (SocketAddr, Result), + ) -> Result<(), PortForwardError> { + let local_socket = message.0; + let stream = match message.1 { + Ok(stream) => stream, + Err(error) => { + // error from TcpStream + tracing::error!( + "error occured while listening to local socket {local_socket}: {error}" + ); + self.listeners.remove(&local_socket); + return Ok(()); + } + }; + + let task_internal_tx = self.internal_msg_tx.clone(); + let Some(remote_socket) = self.mappings.get(&local_socket).cloned() else { + unreachable!("mappings are always created before this point") + }; + let (response_tx, response_rx) = mpsc::channel(256); + self.task_txs.insert(local_socket, response_tx); + + tokio::spawn(async move { + let mut task = LocalConnectionTask::new( + stream, + local_socket, + remote_socket, + task_internal_tx, + response_rx, + ); + task.run().await + }); + + Ok(()) + } + + #[tracing::instrument(level = Level::TRACE, skip(self), err)] + async fn handle_msg_from_task( + &mut self, + message: PortForwardMessage, + ) -> Result<(), PortForwardError> { + match message { + PortForwardMessage::Connect(port_mapping, oneshot) => { + let remote_address = SocketAddress::Ip(port_mapping.remote); + self.oneshots.push_back((port_mapping.local, oneshot)); + self.agent_connection + .sender + .send(ClientMessage::TcpOutgoing(LayerTcpOutgoing::Connect( + LayerConnect { remote_address }, + ))) + .await?; + } + PortForwardMessage::Send(connection_id, bytes) => { + self.agent_connection + .sender + .send(ClientMessage::TcpOutgoing(LayerTcpOutgoing::Write( + LayerWrite { + connection_id, + bytes, + }, + ))) + .await?; + } + PortForwardMessage::Close(port_mapping, connection_id) => { + self.task_txs.remove(&port_mapping.local); + if let Some(connection_id) = connection_id { + self.agent_connection + .sender + .send(ClientMessage::TcpOutgoing(LayerTcpOutgoing::Close( + LayerClose { connection_id }, + ))) + .await?; + self.sockets.remove(&connection_id); + } + } + } + Ok(()) + } +} + +struct LocalConnectionTask { + read_stream: ReaderStream, + write: OwnedWriteHalf, + port_mapping: PortMapping, + task_internal_tx: Sender, + response_rx: Receiver>, +} + +impl LocalConnectionTask { + pub fn new( + stream: TcpStream, + local_socket: SocketAddr, + remote_socket: SocketAddr, + task_internal_tx: Sender, + response_rx: Receiver>, + ) -> Self { + let (read, write) = stream.into_split(); + let read_stream = ReaderStream::with_capacity(read, 64 * 1024); + let port_mapping = PortMapping { + local: local_socket, + remote: remote_socket, + }; + Self { + read_stream, + write, + port_mapping, + task_internal_tx, + response_rx, + } + } + + pub async fn run(&mut self) -> Result<(), PortForwardError> { + let (oneshot_tx, oneshot_rx) = oneshot::channel::(); + + // lazy connection: wait until data starts + let first = match self.read_stream.next().await { + Some(Ok(data)) => data, + Some(Err(error)) => return Err(PortForwardError::TcpListenerError(error)), + None => { + // stream ended without sending data + let _ = self + .task_internal_tx + .send(PortForwardMessage::Close(self.port_mapping.clone(), None)) + .await; + return Ok(()); + } + }; + + match self + .task_internal_tx + .send(PortForwardMessage::Connect( + self.port_mapping.clone(), + oneshot_tx, + )) + .await + { + Ok(_) => (), + Err(error) => { + tracing::warn!( + "failed to send connection request to PortForwarder on internal channel: {error}" + ); + } + }; + let connection_id = match oneshot_rx.await { + Ok(connection_id) => connection_id, + Err(error) => { + tracing::warn!( + "failed to receive connection ID from PortForwarder on internal channel: {error}" + ); + let _ = self + .task_internal_tx + .send(PortForwardMessage::Close(self.port_mapping.clone(), None)) + .await; + return Ok(()); + } + }; + match self + .task_internal_tx + .send(PortForwardMessage::Send(connection_id, first.into())) + .await + { + Ok(_) => (), + Err(error) => { + tracing::warn!("failed to send data to main loop: {error}"); + } + }; + + let result: Result<(), PortForwardError> = loop { + select! { + message = self.read_stream.next() => match message { + Some(Ok(message)) => { + match self.task_internal_tx + .send(PortForwardMessage::Send(connection_id, message.into())) + .await + { + Ok(_) => (), + Err(error) => { + tracing::warn!("failed to send data to main loop: {error}"); + } + }; + }, + Some(Err(error)) => { + tracing::warn!( + %error, + port_mapping = ?self.port_mapping, + "local connection failed", + ); + break Ok(()); + }, + None => { + break Ok(()); + }, + }, + + message = self.response_rx.recv() => match message { + Some(message) => { + match self.write.write_all(message.as_ref()).await { + Ok(_) => continue, + Err(error) => { + tracing::error!( + %error, + port_mapping = ?self.port_mapping, + "local connection failed", + ); + break Ok(()); + }, + } + }, + None => break Ok(()), + } + } + }; + + let _ = self + .task_internal_tx + .send(PortForwardMessage::Close( + self.port_mapping.clone(), + Some(connection_id), + )) + .await; + result + } +} + +#[derive(Debug, Error)] +pub enum PortForwardError { + // setup errors + #[error("multiple port forwarding mappings found for local address `{0}`")] + PortMapSetupError(SocketAddr), + + #[error("no port forwarding mappings were provided")] + NoMappingsError(), + + // running errors + #[error("agent closed connection with error: `{0}`")] + AgentError(String), + + #[error("connection with the agent failed")] + AgentConnectionFailed, + + #[error("failed to send Ping to agent: `{0}`")] + PingError(String), + + #[error("TcpListener operation failed with error: `{0}`")] + TcpListenerError(std::io::Error), + + #[error("no destination address found for local address `{0}`")] + SocketMappingNotFound(SocketAddr), + + #[error("no task for socket {0} ready to receive connection ID: `{1}`")] + ReadyTaskNotFound(SocketAddr, ConnectionId), + + #[error("failed to establish connection with remote process: `{0}`")] + ConnectionError(String), +} + +impl From> for PortForwardError { + fn from(_: mpsc::error::SendError) -> Self { + Self::AgentConnectionFailed + } +} + +#[cfg(test)] +mod test { + use mirrord_protocol::{ + outgoing::{ + tcp::{DaemonTcpOutgoing, LayerTcpOutgoing}, + DaemonConnect, DaemonRead, LayerConnect, LayerWrite, SocketAddress, + }, + ClientMessage, DaemonMessage, + }; + use tokio::{ + io::{AsyncReadExt, AsyncWriteExt}, + net::{TcpListener, TcpStream}, + sync::mpsc, + }; + + use crate::{connection::AgentConnection, port_forward::PortForwarder, PortMapping}; + + #[tokio::test] + async fn test_port_forwarding() { + let listener = TcpListener::bind("127.0.0.1:0").await.unwrap(); + let local_destination = listener.local_addr().unwrap(); + drop(listener); + + let (daemon_msg_tx, daemon_msg_rx) = mpsc::channel::(12); + let (client_msg_tx, mut client_msg_rx) = mpsc::channel::(12); + + let agent_connection = AgentConnection { + sender: client_msg_tx, + receiver: daemon_msg_rx, + }; + let parsed_mappings = vec![PortMapping { + local: local_destination, + remote: "152.37.40.40:3038".parse().unwrap(), + }]; + + let mut port_forwarder = PortForwarder::new(agent_connection, parsed_mappings) + .await + .unwrap(); + tokio::spawn(async move { port_forwarder.run().await.unwrap() }); + + // send data to socket + let mut stream = TcpStream::connect(local_destination).await.unwrap(); + stream.write_all(b"data-my-beloved").await.unwrap(); + + // expect handshake procedure + let expected = Some(ClientMessage::SwitchProtocolVersion( + mirrord_protocol::VERSION.clone(), + )); + assert_eq!(client_msg_rx.recv().await, expected); + daemon_msg_tx + .send(DaemonMessage::SwitchProtocolVersionResponse( + mirrord_protocol::VERSION.clone(), + )) + .await + .unwrap(); + let expected = Some(ClientMessage::ReadyForLogs); + assert_eq!(client_msg_rx.recv().await, expected); + + // expect Connect on client_msg_rx + let remote_address = SocketAddress::Ip("152.37.40.40:3038".parse().unwrap()); + let expected = ClientMessage::TcpOutgoing(LayerTcpOutgoing::Connect(LayerConnect { + remote_address: remote_address.clone(), + })); + let message = match client_msg_rx.recv().await.ok_or(0).unwrap() { + ClientMessage::Ping => client_msg_rx.recv().await.ok_or(0).unwrap(), + other => other, + }; + assert_eq!(message, expected); + + // reply with successful on daemon_msg_tx + daemon_msg_tx + .send(DaemonMessage::TcpOutgoing(DaemonTcpOutgoing::Connect(Ok( + DaemonConnect { + connection_id: 1, + remote_address: remote_address.clone(), + local_address: remote_address, + }, + )))) + .await + .unwrap(); + + let expected = ClientMessage::TcpOutgoing(LayerTcpOutgoing::Write(LayerWrite { + connection_id: 1, + bytes: b"data-my-beloved".to_vec(), + })); + let message = match client_msg_rx.recv().await.ok_or(0).unwrap() { + ClientMessage::Ping => client_msg_rx.recv().await.ok_or(0).unwrap(), + other => other, + }; + assert_eq!(message, expected); + + // send response data from agent on daemon_msg_tx + daemon_msg_tx + .send(DaemonMessage::TcpOutgoing(DaemonTcpOutgoing::Read(Ok( + DaemonRead { + connection_id: 1, + bytes: b"reply-my-beloved".to_vec(), + }, + )))) + .await + .unwrap(); + + // check data arrives at local + let mut buf = [0; 16]; + stream.read_exact(&mut buf).await.unwrap(); + assert_eq!(buf, b"reply-my-beloved".as_ref()); + } + + #[tokio::test] + async fn test_multiple_mappings_forwarding() { + let remote_destination_1 = "152.37.40.40:1018".parse().unwrap(); + let listener = TcpListener::bind("127.0.0.1:0").await.unwrap(); + let local_destination_1 = listener.local_addr().unwrap(); + drop(listener); + + let remote_destination_2 = "152.37.40.40:2028".parse().unwrap(); + let listener = TcpListener::bind("127.0.0.1:0").await.unwrap(); + let local_destination_2 = listener.local_addr().unwrap(); + drop(listener); + + let (daemon_msg_tx, daemon_msg_rx) = mpsc::channel::(12); + let (client_msg_tx, mut client_msg_rx) = mpsc::channel::(12); + + let agent_connection = AgentConnection { + sender: client_msg_tx, + receiver: daemon_msg_rx, + }; + let parsed_mappings = vec![ + PortMapping { + local: local_destination_1, + remote: remote_destination_1, + }, + PortMapping { + local: local_destination_2, + remote: remote_destination_2, + }, + ]; + + let mut port_forwarder = PortForwarder::new(agent_connection, parsed_mappings) + .await + .unwrap(); + tokio::spawn(async move { port_forwarder.run().await.unwrap() }); + + // send data to first socket + let mut stream_1 = TcpStream::connect(local_destination_1).await.unwrap(); + + // expect handshake procedure + let expected = Some(ClientMessage::SwitchProtocolVersion( + mirrord_protocol::VERSION.clone(), + )); + assert_eq!(client_msg_rx.recv().await, expected); + daemon_msg_tx + .send(DaemonMessage::SwitchProtocolVersionResponse( + mirrord_protocol::VERSION.clone(), + )) + .await + .unwrap(); + let expected = Some(ClientMessage::ReadyForLogs); + assert_eq!(client_msg_rx.recv().await, expected); + + // expect each Connect on client_msg_rx with correct mappings when data has been written + // (lazy) + stream_1.write_all(b"data-from-1").await.unwrap(); + let remote_address_1 = SocketAddress::Ip(remote_destination_1); + let expected = ClientMessage::TcpOutgoing(LayerTcpOutgoing::Connect(LayerConnect { + remote_address: remote_address_1.clone(), + })); + let message = match client_msg_rx.recv().await.ok_or(0).unwrap() { + ClientMessage::Ping => client_msg_rx.recv().await.ok_or(0).unwrap(), + other => other, + }; + assert_eq!(message, expected); + + // send data to second socket + let mut stream_2 = TcpStream::connect(local_destination_2).await.unwrap(); + let remote_address_2 = SocketAddress::Ip(remote_destination_2); + stream_2.write_all(b"data-from-2").await.unwrap(); + + let expected = ClientMessage::TcpOutgoing(LayerTcpOutgoing::Connect(LayerConnect { + remote_address: remote_address_2.clone(), + })); + let message = match client_msg_rx.recv().await.ok_or(0).unwrap() { + ClientMessage::Ping => client_msg_rx.recv().await.ok_or(0).unwrap(), + other => other, + }; + assert_eq!(message, expected); + + // reply with successful on each daemon_msg_tx + daemon_msg_tx + .send(DaemonMessage::TcpOutgoing(DaemonTcpOutgoing::Connect(Ok( + DaemonConnect { + connection_id: 1, + remote_address: remote_address_1.clone(), + local_address: remote_address_1, + }, + )))) + .await + .unwrap(); + daemon_msg_tx + .send(DaemonMessage::TcpOutgoing(DaemonTcpOutgoing::Connect(Ok( + DaemonConnect { + connection_id: 2, + remote_address: remote_address_2.clone(), + local_address: remote_address_2, + }, + )))) + .await + .unwrap(); + + // expect data to be received + let expected = ClientMessage::TcpOutgoing(LayerTcpOutgoing::Write(LayerWrite { + connection_id: 1, + bytes: b"data-from-1".to_vec(), + })); + let message = match client_msg_rx.recv().await.ok_or(0).unwrap() { + ClientMessage::Ping => client_msg_rx.recv().await.ok_or(0).unwrap(), + other => other, + }; + assert_eq!(message, expected); + + let expected = ClientMessage::TcpOutgoing(LayerTcpOutgoing::Write(LayerWrite { + connection_id: 2, + bytes: b"data-from-2".to_vec(), + })); + let message = match client_msg_rx.recv().await.ok_or(0).unwrap() { + ClientMessage::Ping => client_msg_rx.recv().await.ok_or(0).unwrap(), + other => other, + }; + assert_eq!(message, expected); + + // send each data response from agent on daemon_msg_tx + daemon_msg_tx + .send(DaemonMessage::TcpOutgoing(DaemonTcpOutgoing::Read(Ok( + DaemonRead { + connection_id: 1, + bytes: b"reply-to-1".to_vec(), + }, + )))) + .await + .unwrap(); + daemon_msg_tx + .send(DaemonMessage::TcpOutgoing(DaemonTcpOutgoing::Read(Ok( + DaemonRead { + connection_id: 2, + bytes: b"reply-to-2".to_vec(), + }, + )))) + .await + .unwrap(); + + // check data arrives at each local addr + let mut buf = [0; 10]; + stream_1.read_exact(&mut buf).await.unwrap(); + assert_eq!(buf, b"reply-to-1".as_ref()); + let mut buf = [0; 10]; + stream_2.read_exact(&mut buf).await.unwrap(); + assert_eq!(buf, b"reply-to-2".as_ref()); + } +}