From a0c317aa389e1a35c7ab8789b793c4476b906fba Mon Sep 17 00:00:00 2001 From: Razz4780 Date: Tue, 14 Jan 2025 10:27:12 +0100 Subject: [PATCH 01/60] Integration test --- changelog.d/3013.fixed.md | 1 + .../src/proxies/incoming/interceptor.rs | 8 +- mirrord/layer/tests/apps/app_chunked.py | 33 ++++ mirrord/layer/tests/common/mod.rs | 17 +- mirrord/layer/tests/issue3013.rs | 184 ++++++++++++++++++ tests/src/traffic/steal.rs | 28 ++- 6 files changed, 247 insertions(+), 24 deletions(-) create mode 100644 changelog.d/3013.fixed.md create mode 100644 mirrord/layer/tests/apps/app_chunked.py create mode 100644 mirrord/layer/tests/issue3013.rs diff --git a/changelog.d/3013.fixed.md b/changelog.d/3013.fixed.md new file mode 100644 index 00000000000..794d3f2c5f3 --- /dev/null +++ b/changelog.d/3013.fixed.md @@ -0,0 +1 @@ +Fixed an issue where HTTP requests stolen with a filter would hang with a single-threaded local HTTP server. diff --git a/mirrord/intproxy/src/proxies/incoming/interceptor.rs b/mirrord/intproxy/src/proxies/incoming/interceptor.rs index 2d6486d709f..0cae486df0e 100644 --- a/mirrord/intproxy/src/proxies/incoming/interceptor.rs +++ b/mirrord/intproxy/src/proxies/incoming/interceptor.rs @@ -327,7 +327,7 @@ impl HttpConnection { HttpRequestFallback::Framed(..) => { HttpResponse::::from_hyper_response( res, - self.peer.port(), + request.port(), request.connection_id(), request.request_id(), ) @@ -337,7 +337,7 @@ impl HttpConnection { HttpRequestFallback::Fallback(..) => { HttpResponse::>::from_hyper_response( res, - self.peer.port(), + request.port(), request.connection_id(), request.request_id(), ) @@ -349,7 +349,7 @@ impl HttpConnection { { HttpResponse::::from_hyper_response( res, - self.peer.port(), + request.port(), request.connection_id(), request.request_id(), ) @@ -361,7 +361,7 @@ impl HttpConnection { HttpRequestFallback::Streamed { .. 
} => { HttpResponse::::from_hyper_response( res, - self.peer.port(), + request.port(), request.connection_id(), request.request_id(), ) diff --git a/mirrord/layer/tests/apps/app_chunked.py b/mirrord/layer/tests/apps/app_chunked.py new file mode 100644 index 00000000000..04ffde714b7 --- /dev/null +++ b/mirrord/layer/tests/apps/app_chunked.py @@ -0,0 +1,33 @@ +from http.server import HTTPServer, BaseHTTPRequestHandler +import time; + +class ChunkedHTTPHandler(BaseHTTPRequestHandler): + protocol_version = "HTTP/1.1" + + def do_GET(self): + # Send response headers + self.send_response(200) + self.send_header("Content-Type", "text/plain") + self.send_header("Transfer-Encoding", "chunked") + self.end_headers() + + # Send the response in chunks + chunks = [ + "This is the first chunk.\n"*8000, + "This is the second chunk.\n" + ] + for chunk in chunks: + time.sleep(3.0) + # Write the chunk size in hexadecimal followed by the chunk data + self.wfile.write(f"{len(chunk):X}\r\n".encode('utf-8')) + self.wfile.write(chunk.encode('utf-8')) + self.wfile.write(b"\r\n") + + # Signal the end of the response + self.wfile.write(b"0\r\n\r\n") + +if __name__ == "__main__": + port = 80 + print(f"Starting server on port {port}") + server = HTTPServer(("0.0.0.0", port), ChunkedHTTPHandler) + server.serve_forever() diff --git a/mirrord/layer/tests/common/mod.rs b/mirrord/layer/tests/common/mod.rs index 22945790e4b..dbcb92adc90 100644 --- a/mirrord/layer/tests/common/mod.rs +++ b/mirrord/layer/tests/common/mod.rs @@ -782,10 +782,13 @@ pub enum Application { /// Mode to use when opening the file, accepted as `-m` param. mode: u32, }, - // For running applications with the executable and arguments determined at runtime. + /// For running applications with the executable and arguments determined at runtime. DynamicApp(String, Vec), - // Go app that only checks whether Linux pidfd syscalls are supported. + /// Go app that only checks whether Linux pidfd syscalls are supported. 
Go23Issue2988, + /// Python HTTP server that returns large (200kb) chunked responses + /// and processes one request at a time. + PythonHTTPChunked, } impl Application { @@ -814,7 +817,8 @@ impl Application { Application::PythonFlaskHTTP | Application::PythonSelfConnect | Application::PythonDontLoad - | Application::PythonListen => Self::get_python3_executable().await, + | Application::PythonListen + | Application::PythonHTTPChunked => Self::get_python3_executable().await, Application::PythonFastApiHTTP | Application::PythonIssue864 => String::from("uvicorn"), Application::Fork => String::from("tests/apps/fork/out.c_test_app"), Application::ReadLink => String::from("tests/apps/readlink/out.c_test_app"), @@ -1105,6 +1109,10 @@ impl Application { ] } Application::DynamicApp(_, args) => args.to_owned(), + Application::PythonHTTPChunked => { + app_path.push("app_chunked.py"); + vec![String::from("-u"), app_path.to_string_lossy().to_string()] + } } } @@ -1118,7 +1126,8 @@ impl Application { | Application::Go23FileOps | Application::NodeHTTP | Application::RustIssue1054 - | Application::PythonFlaskHTTP => 80, + | Application::PythonFlaskHTTP + | Application::PythonHTTPChunked => 80, // mapped from 9999 in `configs/port_mapping.json` Application::PythonFastApiHTTP | Application::PythonIssue864 => 1234, Application::RustIssue1123 => 41222, diff --git a/mirrord/layer/tests/issue3013.rs b/mirrord/layer/tests/issue3013.rs new file mode 100644 index 00000000000..3e3870d3f3e --- /dev/null +++ b/mirrord/layer/tests/issue3013.rs @@ -0,0 +1,184 @@ +#![feature(assert_matches)] +#![warn(clippy::indexing_slicing)] +use std::{ + path::{Path, PathBuf}, + time::Duration, +}; + +use hyper::{ + header::{HeaderName, HeaderValue}, + Method, StatusCode, Version, +}; +use mirrord_protocol::{ + self, + file::{ + CloseFileRequest, OpenFileRequest, OpenFileResponse, ReadFileRequest, ReadFileResponse, + }, + tcp::{ + DaemonTcp, HttpRequest, HttpResponse, InternalHttpRequest, InternalHttpResponse, 
+ LayerTcpSteal, StealType, + }, + ClientMessage, DaemonMessage, FileRequest, FileResponse, +}; +use rstest::rstest; + +mod common; + +pub use common::*; + +/// Verifies that [issue 3013](https://github.com/metalbear-co/mirrord/issues/3013) is resolved. +/// +/// The issue was that the first request was leaving behind a lingering HTTP connection, that was in +/// turn blocking the local application. The lingering connection was not a bug on our side, but +/// still we can handle this case more smoothly. +#[rstest] +#[tokio::test] +#[timeout(Duration::from_secs(60))] +async fn issue_3013(dylib_path: &Path) { + let config_file = tempfile::tempdir().unwrap(); + let config = serde_json::json!( + { + "feature": { + "network": { + "outgoing": false, + "dns": false, + "incoming": { + "mode": "steal", + "http_filter": { + "header_filter": "x-filter: yes", + }, + } + }, + "fs": { + "mode": "local", + } + }, + } + ); + let config_path = config_file.path().join("config.json"); + tokio::fs::write(&config_path, serde_json::to_string_pretty(&config).unwrap()) + .await + .unwrap(); + + let (test_process, mut test_intproxy) = Application::PythonHTTPChunked + .start_process_with_layer(dylib_path, vec![], Some(&config_path.to_string_lossy())) + .await; + + match test_intproxy.recv().await { + ClientMessage::FileRequest(FileRequest::Open(OpenFileRequest { path, .. })) => { + assert_eq!(path, PathBuf::from("/etc/hostname")); + } + other => panic!("unexpected message from intproxy: {other:?}"), + } + test_intproxy + .send(DaemonMessage::File(FileResponse::Open(Ok( + OpenFileResponse { fd: 2137 }, + )))) + .await; + match test_intproxy.recv().await { + ClientMessage::FileRequest(FileRequest::Read(ReadFileRequest { + remote_fd: 2137, .. 
+ })) => {} + other => panic!("unexpected message from intproxy: {other:?}"), + } + test_intproxy + .send(DaemonMessage::File(FileResponse::Read(Ok( + ReadFileResponse { + bytes: "test-hostname".as_bytes().into(), + read_amount: 13, + }, + )))) + .await; + match test_intproxy.recv().await { + ClientMessage::FileRequest(FileRequest::Close(CloseFileRequest { fd: 2137 })) => {} + other => panic!("unexpected message from intproxy: {other:?}"), + } + + match test_intproxy.recv().await { + ClientMessage::TcpSteal(LayerTcpSteal::PortSubscribe(StealType::FilteredHttpEx(..))) => {} + other => panic!("unexpected message from intproxy: {other:?}"), + } + test_intproxy + .send(DaemonMessage::TcpSteal(DaemonTcp::SubscribeResult(Ok(80)))) + .await; + test_process + .wait_for_line(Duration::from_secs(40), "daemon subscribed") + .await; + println!("The application subscribed the port"); + + println!("Sending the first request to the intproxy"); + test_intproxy + .send(DaemonMessage::TcpSteal(DaemonTcp::HttpRequestFramed( + HttpRequest { + internal_request: InternalHttpRequest { + method: Method::GET, + uri: "/some/path".parse().unwrap(), + headers: [( + HeaderName::from_static("connection"), + HeaderValue::from_static("keep-alive"), + )] + .into_iter() + .collect(), + version: Version::HTTP_11, + body: Default::default(), + }, + connection_id: 0, + request_id: 0, + port: 80, + }, + ))) + .await; + match test_intproxy.recv().await { + ClientMessage::TcpSteal(LayerTcpSteal::HttpResponseFramed(HttpResponse { + port: 80, + connection_id: 0, + request_id: 0, + internal_response: + InternalHttpResponse { + status: StatusCode::OK, + version: Version::HTTP_11, + .. 
+ }, + })) => {} + other => panic!("unexpected message from intproxy: {other:?}"), + } + println!("Received the first response from the intproxy"); + + println!("Sending the second request to the intproxy, without closing the first connection"); + test_intproxy + .send(DaemonMessage::TcpSteal(DaemonTcp::HttpRequestFramed( + HttpRequest { + internal_request: InternalHttpRequest { + method: Method::GET, + uri: "/some/path".parse().unwrap(), + headers: [( + HeaderName::from_static("connection"), + HeaderValue::from_static("keep-alive"), + )] + .into_iter() + .collect(), + version: Version::HTTP_11, + body: Default::default(), + }, + connection_id: 1, + request_id: 0, + port: 80, + }, + ))) + .await; + match test_intproxy.recv().await { + ClientMessage::TcpSteal(LayerTcpSteal::HttpResponseFramed(HttpResponse { + port: 80, + connection_id: 1, + request_id: 0, + internal_response: + InternalHttpResponse { + status: StatusCode::OK, + version: Version::HTTP_11, + .. + }, + })) => {} + other => panic!("unexpected message from intproxy: {other:?}"), + } + println!("Received the second response from the intproxy"); +} diff --git a/tests/src/traffic/steal.rs b/tests/src/traffic/steal.rs index 31b5f668a2b..bf417813cf0 100644 --- a/tests/src/traffic/steal.rs +++ b/tests/src/traffic/steal.rs @@ -1,21 +1,17 @@ -#![allow(dead_code, unused)] #[cfg(test)] mod steal_tests { - use std::{ - io::{BufRead, BufReader, Read, Write}, - net::{SocketAddr, TcpStream}, - path::Path, - time::Duration, - }; + use std::{net::SocketAddr, path::Path, time::Duration}; use futures_util::{SinkExt, StreamExt}; - use hyper::{body, client::conn, Request, StatusCode}; - use hyper_util::rt::TokioIo; use k8s_openapi::api::core::v1::Pod; use kube::{Api, Client}; use reqwest::{header::HeaderMap, Url}; use rstest::*; - use tokio::time::sleep; + use tokio::{ + io::{AsyncBufReadExt, AsyncReadExt, AsyncWriteExt, BufReader}, + net::TcpStream, + time::sleep, + }; use tokio_tungstenite::{ connect_async, 
tungstenite::{client::IntoClientRequest, Message}, @@ -280,7 +276,7 @@ mod steal_tests { .wait_for_line(Duration::from_secs(40), "daemon subscribed") .await; - let mut tcp_stream = TcpStream::connect((addr, port as u16)).unwrap(); + let mut tcp_stream = TcpStream::connect((addr, port as u16)).await.unwrap(); // Wait for the test app to close the socket and tell us about it. process @@ -289,10 +285,10 @@ mod steal_tests { const DATA: &[u8; 16] = b"upper me please\n"; - tcp_stream.write_all(DATA).unwrap(); + tcp_stream.write_all(DATA).await.unwrap(); let mut response = [0u8; DATA.len()]; - tcp_stream.read_exact(&mut response).unwrap(); + tcp_stream.read_exact(&mut response).await.unwrap(); process .write_to_stdin(b"Hey test app, please stop running and just exit successfuly.\n") @@ -625,11 +621,11 @@ mod steal_tests { .await; let addr = SocketAddr::new(host.trim().parse().unwrap(), port as u16); - let mut stream = TcpStream::connect(addr).unwrap(); - stream.write_all(tcp_data.as_bytes()).unwrap(); + let mut stream = TcpStream::connect(addr).await.unwrap(); + stream.write_all(tcp_data.as_bytes()).await.unwrap(); let mut reader = BufReader::new(stream); let mut buf = String::new(); - reader.read_line(&mut buf).unwrap(); + reader.read_line(&mut buf).await.unwrap(); println!("Got response: {buf}"); // replace "remote: " with empty string, since the response can be split into frames // and we just need assert the final response From f7359bc1e89752ec65df8c80928b861d2be01e11 Mon Sep 17 00:00:00 2001 From: Razz4780 Date: Tue, 14 Jan 2025 10:39:10 +0100 Subject: [PATCH 02/60] Moved MetadataStore to a separate module --- mirrord/intproxy/src/proxies/incoming.rs | 36 ++------------ .../src/proxies/incoming/metadata_store.rs | 48 +++++++++++++++++++ 2 files changed, 52 insertions(+), 32 deletions(-) create mode 100644 mirrord/intproxy/src/proxies/incoming/metadata_store.rs diff --git a/mirrord/intproxy/src/proxies/incoming.rs b/mirrord/intproxy/src/proxies/incoming.rs index 
966d1175acb..ed4cf0bf23e 100644 --- a/mirrord/intproxy/src/proxies/incoming.rs +++ b/mirrord/intproxy/src/proxies/incoming.rs @@ -11,6 +11,7 @@ use futures::StreamExt; use http::RETRY_ON_RESET_ATTEMPTS; use http_body_util::StreamBody; use hyper::body::Frame; +use metadata_store::MetadataStore; use mirrord_intproxy_protocol::{ ConnMetadataRequest, ConnMetadataResponse, IncomingRequest, IncomingResponse, LayerId, MessageId, PortSubscribe, PortSubscription, PortUnsubscribe, ProxyToLayerMessage, @@ -46,6 +47,7 @@ use crate::{ mod http; mod interceptor; +mod metadata_store; pub mod port_subscription_ext; mod subscriptions; @@ -120,36 +122,6 @@ struct InterceptorHandle { subscription: PortSubscription, } -/// Store for mapping [`Interceptor`] socket addresses to addresses of the original peers. -#[derive(Default)] -struct MetadataStore { - prepared_responses: HashMap, - expected_requests: HashMap, -} - -impl MetadataStore { - fn get(&mut self, req: ConnMetadataRequest) -> ConnMetadataResponse { - self.prepared_responses - .remove(&req) - .unwrap_or_else(|| ConnMetadataResponse { - remote_source: req.peer_address, - local_address: req.listener_address.ip(), - }) - } - - fn expect(&mut self, req: ConnMetadataRequest, from: InterceptorId, res: ConnMetadataResponse) { - self.expected_requests.insert(from, req.clone()); - self.prepared_responses.insert(req, res); - } - - fn no_longer_expect(&mut self, from: InterceptorId) { - let Some(req) = self.expected_requests.remove(&from) else { - return; - }; - self.prepared_responses.remove(&req); - } -} - /// Handles logic and state of the `incoming` feature. /// Run as a [`BackgroundTask`]. 
/// @@ -390,7 +362,7 @@ impl IncomingProxy { listener_address: subscription.listening_on, peer_address: interceptor_socket.local_addr()?, }, - id, + id.0, ConnMetadataResponse { remote_source: SocketAddr::new(remote_address, source_port), local_address, @@ -527,7 +499,7 @@ impl BackgroundTask for IncomingProxy { (id, TaskUpdate::Finished(res)) => { tracing::trace!("{id} finished: {res:?}"); - self.metadata_store.no_longer_expect(id); + self.metadata_store.no_longer_expect(id.0); let msg = self.get_subscription(id).map(|s| s.wrap_agent_unsubscribe_connection(id.0)); if let Some(msg) = msg { diff --git a/mirrord/intproxy/src/proxies/incoming/metadata_store.rs b/mirrord/intproxy/src/proxies/incoming/metadata_store.rs new file mode 100644 index 00000000000..fc15d7e9d4b --- /dev/null +++ b/mirrord/intproxy/src/proxies/incoming/metadata_store.rs @@ -0,0 +1,48 @@ +use std::collections::HashMap; + +use mirrord_intproxy_protocol::{ConnMetadataRequest, ConnMetadataResponse}; +use mirrord_protocol::ConnectionId; + +/// Maps local socket address pairs to remote. +/// +/// Allows for extracting the original socket addresses of peers of a remote connection. +#[derive(Default)] +pub struct MetadataStore { + prepared_responses: HashMap, + expected_requests: HashMap, +} + +impl MetadataStore { + /// Retrieves remote addresses for the given pair of local addresses. + /// + /// If the mapping is not found, returns the local addresses unchanged. + pub fn get(&mut self, req: ConnMetadataRequest) -> ConnMetadataResponse { + self.prepared_responses + .remove(&req) + .unwrap_or_else(|| ConnMetadataResponse { + remote_source: req.peer_address, + local_address: req.listener_address.ip(), + }) + } + + /// Adds a new `req`->`res` mapping to this struct. + /// + /// Marks that the mapping is related to the remote connection with the given id. 
+ pub fn expect( + &mut self, + req: ConnMetadataRequest, + connection: ConnectionId, + res: ConnMetadataResponse, + ) { + self.expected_requests.insert(connection, req.clone()); + self.prepared_responses.insert(req, res); + } + + /// Clears mapping related to the remote connection with the given id. + pub fn no_longer_expect(&mut self, connection: ConnectionId) { + let Some(req) = self.expected_requests.remove(&connection) else { + return; + }; + self.prepared_responses.remove(&req); + } +} From 48bac2722994310307417e423f42eb1bdda11740 Mon Sep 17 00:00:00 2001 From: Razz4780 Date: Tue, 14 Jan 2025 10:53:47 +0100 Subject: [PATCH 03/60] bind_similar -> BoundTcpSocket --- mirrord/intproxy/src/proxies/incoming.rs | 47 ++++--------------- .../src/proxies/incoming/bound_socket.rs | 46 ++++++++++++++++++ .../src/proxies/incoming/interceptor.rs | 19 ++++---- 3 files changed, 64 insertions(+), 48 deletions(-) create mode 100644 mirrord/intproxy/src/proxies/incoming/bound_socket.rs diff --git a/mirrord/intproxy/src/proxies/incoming.rs b/mirrord/intproxy/src/proxies/incoming.rs index ed4cf0bf23e..cb3093e03a4 100644 --- a/mirrord/intproxy/src/proxies/incoming.rs +++ b/mirrord/intproxy/src/proxies/incoming.rs @@ -3,9 +3,10 @@ use std::{ collections::{hash_map::Entry, HashMap}, fmt, io, - net::{IpAddr, Ipv4Addr, Ipv6Addr, SocketAddr}, + net::SocketAddr, }; +use bound_socket::BoundTcpSocket; use bytes::Bytes; use futures::StreamExt; use http::RETRY_ON_RESET_ATTEMPTS; @@ -27,10 +28,7 @@ use mirrord_protocol::{ ClientMessage, ConnectionId, RequestId, ResponseError, }; use thiserror::Error; -use tokio::{ - net::TcpSocket, - sync::mpsc::{self, Sender}, -}; +use tokio::sync::mpsc::{self, Sender}; use tokio_stream::{wrappers::ReceiverStream, StreamMap, StreamNotifyClose}; use tracing::{debug, Level}; @@ -45,44 +43,13 @@ use crate::{ ProxyMessage, }; +mod bound_socket; mod http; mod interceptor; mod metadata_store; pub mod port_subscription_ext; mod subscriptions; -/// Creates and 
binds a new [`TcpSocket`]. -/// The socket has the same IP version and address as the given `addr`. -/// -/// # Exception -/// -/// If the given `addr` is unspecified, this function binds to localhost. -#[tracing::instrument(level = Level::TRACE, ret, err)] -fn bind_similar(addr: SocketAddr) -> io::Result { - match addr.ip() { - IpAddr::V4(Ipv4Addr::UNSPECIFIED) => { - let socket = TcpSocket::new_v4()?; - socket.bind(SocketAddr::new(Ipv4Addr::LOCALHOST.into(), 0))?; - Ok(socket) - } - IpAddr::V6(Ipv6Addr::UNSPECIFIED) => { - let socket = TcpSocket::new_v6()?; - socket.bind(SocketAddr::new(Ipv6Addr::LOCALHOST.into(), 0))?; - Ok(socket) - } - addr @ IpAddr::V4(..) => { - let socket = TcpSocket::new_v4()?; - socket.bind(SocketAddr::new(addr, 0))?; - Ok(socket) - } - addr @ IpAddr::V6(..) => { - let socket = TcpSocket::new_v6()?; - socket.bind(SocketAddr::new(addr, 0))?; - Ok(socket) - } - } -} - /// Id of a single [`Interceptor`] task. Used to manage interceptor tasks with the /// [`BackgroundTasks`] struct. 
#[derive(Debug, Clone, Copy, Hash, PartialEq, Eq)] @@ -212,7 +179,8 @@ impl IncomingProxy { return Ok(None); }; - let interceptor_socket = bind_similar(subscription.listening_on)?; + let interceptor_socket = + BoundTcpSocket::bind_specified_or_localhost(subscription.listening_on.ip())?; let interceptor = self.background_tasks.register( Interceptor::new( @@ -353,7 +321,8 @@ impl IncomingProxy { return Ok(()); }; - let interceptor_socket = bind_similar(subscription.listening_on)?; + let interceptor_socket = + BoundTcpSocket::bind_specified_or_localhost(subscription.listening_on.ip())?; let id = InterceptorId(connection_id); diff --git a/mirrord/intproxy/src/proxies/incoming/bound_socket.rs b/mirrord/intproxy/src/proxies/incoming/bound_socket.rs new file mode 100644 index 00000000000..1c6cbef385a --- /dev/null +++ b/mirrord/intproxy/src/proxies/incoming/bound_socket.rs @@ -0,0 +1,46 @@ +use std::{ + fmt, io, + net::{IpAddr, Ipv4Addr, Ipv6Addr, SocketAddr}, +}; + +use tokio::net::{TcpSocket, TcpStream}; +use tracing::Level; + +/// A TCP socket that is already bound. +/// +/// Provides a nicer [`fmt::Debug`] implementation than [`TcpSocket`]. +pub struct BoundTcpSocket(TcpSocket); + +impl BoundTcpSocket { + /// Opens a new TCP socket and binds it to the given IP address and a random port. + /// If the given IP address is not specified, binds the socket to localhost instead. + #[tracing::instrument(level = Level::TRACE, ret, err)] + pub fn bind_specified_or_localhost(ip: IpAddr) -> io::Result { + let (socket, ip) = match ip { + IpAddr::V4(Ipv4Addr::UNSPECIFIED) => (TcpSocket::new_v4()?, Ipv4Addr::LOCALHOST.into()), + IpAddr::V6(Ipv6Addr::UNSPECIFIED) => (TcpSocket::new_v6()?, Ipv6Addr::LOCALHOST.into()), + addr @ IpAddr::V4(..) => (TcpSocket::new_v4()?, addr), + addr @ IpAddr::V6(..) => (TcpSocket::new_v6()?, addr), + }; + + socket.bind(SocketAddr::new(ip, 0))?; + + Ok(Self(socket)) + } + + /// Returns the address to which this socket is bound. 
+ pub fn local_addr(&self) -> io::Result { + self.0.local_addr() + } + + /// Makes a connection to the given peer. + pub async fn connect(self, peer: SocketAddr) -> io::Result { + self.0.connect(peer).await + } +} + +impl fmt::Debug for BoundTcpSocket { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + self.0.local_addr().fmt(f) + } +} diff --git a/mirrord/intproxy/src/proxies/incoming/interceptor.rs b/mirrord/intproxy/src/proxies/incoming/interceptor.rs index 0cae486df0e..bcf68aef2fe 100644 --- a/mirrord/intproxy/src/proxies/incoming/interceptor.rs +++ b/mirrord/intproxy/src/proxies/incoming/interceptor.rs @@ -19,7 +19,7 @@ use mirrord_protocol::tcp::{ use thiserror::Error; use tokio::{ io::{AsyncReadExt, AsyncWriteExt}, - net::{TcpSocket, TcpStream}, + net::TcpStream, time::{self, sleep}, }; use tracing::Level; @@ -27,7 +27,7 @@ use tracing::Level; use super::http::HttpSender; use crate::{ background_tasks::{BackgroundTask, MessageBus}, - proxies::incoming::http::RETRY_ON_RESET_ATTEMPTS, + proxies::incoming::{bound_socket::BoundTcpSocket, http::RETRY_ON_RESET_ATTEMPTS}, }; /// Messages consumed by the [`Interceptor`] when it runs as a [`BackgroundTask`]. @@ -126,7 +126,7 @@ pub type InterceptorResult = core::result::Result /// When it received [`MessageIn::Http`], it starts acting as an HTTP gateway. pub struct Interceptor { /// Socket that should be used to make the first connection (should already be bound). - socket: TcpSocket, + socket: BoundTcpSocket, /// Address of user app's listener. peer: SocketAddr, /// Version of [`mirrord_protocol`] negotiated with the agent. @@ -141,7 +141,7 @@ impl Interceptor { /// /// The socket can be replaced when retrying HTTP requests. pub fn new( - socket: TcpSocket, + socket: BoundTcpSocket, peer: SocketAddr, agent_protocol_version: Option, ) -> Self { @@ -418,7 +418,7 @@ impl HttpConnection { sleep(backoff).await; // Create a new connection for the next attempt. 
- let socket = super::bind_similar(self.peer)?; + let socket = BoundTcpSocket::bind_specified_or_localhost(self.peer.ip())?; let stream = socket.connect(self.peer).await?; let new_sender = super::http::handshake(request.version(), stream).await?; self.sender = new_sender; @@ -554,6 +554,7 @@ impl RawConnection { mod test { use std::{ convert::Infallible, + net::Ipv4Addr, sync::{Arc, Mutex}, }; @@ -681,8 +682,8 @@ mod test { let mut tasks: BackgroundTasks<(), MessageOut, InterceptorError> = Default::default(); let interceptor = { - let socket = TcpSocket::new_v4().unwrap(); - socket.bind("127.0.0.1:0".parse().unwrap()).unwrap(); + let socket = + BoundTcpSocket::bind_specified_or_localhost(Ipv4Addr::LOCALHOST.into()).unwrap(); tasks.register( Interceptor::new( socket, @@ -765,8 +766,8 @@ mod test { let local_destination = listener.local_addr().unwrap(); let mut tasks: BackgroundTasks<(), MessageOut, InterceptorError> = Default::default(); - let socket = TcpSocket::new_v4().unwrap(); - socket.bind("127.0.0.1:0".parse().unwrap()).unwrap(); + let socket = + BoundTcpSocket::bind_specified_or_localhost(Ipv4Addr::LOCALHOST.into()).unwrap(); let interceptor = Interceptor::new( socket, local_destination, From 5798219b8d1b27d8e1d482db4a3cfbee5d834692 Mon Sep 17 00:00:00 2001 From: Razz4780 Date: Tue, 14 Jan 2025 11:03:11 +0100 Subject: [PATCH 04/60] BodyExt -> BatchedBody, reworked trait --- Cargo.lock | 1 + mirrord/agent/src/steal/connection.rs | 6 +- mirrord/intproxy/src/proxies/incoming.rs | 5 +- mirrord/protocol/Cargo.toml | 1 + mirrord/protocol/src/batched_body.rs | 86 ++++++++++++++++++++++++ mirrord/protocol/src/body_chunks.rs | 71 ------------------- mirrord/protocol/src/lib.rs | 2 +- mirrord/protocol/src/tcp.rs | 4 +- 8 files changed, 96 insertions(+), 80 deletions(-) create mode 100644 mirrord/protocol/src/batched_body.rs delete mode 100644 mirrord/protocol/src/body_chunks.rs diff --git a/Cargo.lock b/Cargo.lock index 89528a03ca1..466053b1533 100644 --- a/Cargo.lock 
+++ b/Cargo.lock @@ -4485,6 +4485,7 @@ dependencies = [ "bincode", "bytes", "fancy-regex", + "futures", "hickory-proto", "hickory-resolver", "http-body-util", diff --git a/mirrord/agent/src/steal/connection.rs b/mirrord/agent/src/steal/connection.rs index 37435176b8b..fa94556c353 100644 --- a/mirrord/agent/src/steal/connection.rs +++ b/mirrord/agent/src/steal/connection.rs @@ -12,7 +12,7 @@ use hyper::{ http::{header::UPGRADE, request::Parts}, }; use mirrord_protocol::{ - body_chunks::{BodyExt as _, Frames}, + batched_body::{BatchedBody, Frames}, tcp::{ ChunkedHttpBody, ChunkedHttpError, ChunkedRequest, DaemonTcp, HttpRequest, HttpResponseFallback, InternalHttpBody, InternalHttpBodyFrame, InternalHttpRequest, @@ -173,7 +173,7 @@ impl Client { }, mut body, ) = request.request.into_parts(); - match body.next_frames(true).await { + match body.ready_frames() { Err(..) => return, // We don't check is_last here since loop will finish when body.next_frames() // returns None @@ -205,7 +205,7 @@ impl Client { } loop { - match body.next_frames(false).await { + match body.next_frames().await { Ok(Frames { frames, is_last }) => { let frames = frames .into_iter() diff --git a/mirrord/intproxy/src/proxies/incoming.rs b/mirrord/intproxy/src/proxies/incoming.rs index cb3093e03a4..406bdc31df8 100644 --- a/mirrord/intproxy/src/proxies/incoming.rs +++ b/mirrord/intproxy/src/proxies/incoming.rs @@ -18,7 +18,7 @@ use mirrord_intproxy_protocol::{ MessageId, PortSubscribe, PortSubscription, PortUnsubscribe, ProxyToLayerMessage, }; use mirrord_protocol::{ - body_chunks::BodyExt, + batched_body::BatchedBody, tcp::{ ChunkedHttpBody, ChunkedHttpError, ChunkedRequest, ChunkedResponse, DaemonTcp, HttpRequest, HttpRequestFallback, HttpResponse, HttpResponseFallback, InternalHttpBodyFrame, @@ -529,8 +529,7 @@ impl IncomingProxy { match response .internal_response .body - .next_frames(true) - .await + .ready_frames() .map_err(InterceptorError::from) { Ok(frames) => { diff --git 
a/mirrord/protocol/Cargo.toml b/mirrord/protocol/Cargo.toml index 7d787e1e5c6..8f56636c15e 100644 --- a/mirrord/protocol/Cargo.toml +++ b/mirrord/protocol/Cargo.toml @@ -21,6 +21,7 @@ actix-codec.workspace = true bincode.workspace = true bytes.workspace = true thiserror.workspace = true +futures.workspace = true hickory-resolver.workspace = true hickory-proto.workspace = true serde.workspace = true diff --git a/mirrord/protocol/src/batched_body.rs b/mirrord/protocol/src/batched_body.rs new file mode 100644 index 00000000000..479eb1af61f --- /dev/null +++ b/mirrord/protocol/src/batched_body.rs @@ -0,0 +1,86 @@ +use std::future::Future; + +use futures::FutureExt; +use http_body_util::BodyExt; +use hyper::body::{Body, Frame}; + +/// Utility extension trait for [`Body`]. +/// +/// Contains methods that allow for reading [`Frame`]s in batches. +pub trait BatchedBody: Body { + /// Reads all [`Frame`]s that are available without blocking. + fn ready_frames(&mut self) -> Result, Self::Error>; + + /// Waits for the next [`Frame`] then reads all [`Frame`]s that are available without blocking. 
+ fn next_frames(&mut self) -> impl Future, Self::Error>>; +} + +impl BatchedBody for B +where + B: Body + Unpin, +{ + fn ready_frames(&mut self) -> Result, Self::Error> { + let mut frames = Frames { + frames: vec![], + is_last: false, + }; + + loop { + match self.frame().now_or_never() { + None => { + frames.is_last = false; + break; + } + Some(None) => { + frames.is_last = true; + break; + } + Some(Some(result)) => { + frames.frames.push(result?); + } + } + } + + Ok(frames) + } + + async fn next_frames(&mut self) -> Result, Self::Error> { + let mut frames = Frames { + frames: vec![], + is_last: false, + }; + + match self.frame().await { + None => { + frames.is_last = true; + return Ok(frames); + } + Some(result) => { + frames.frames.push(result?); + } + } + + loop { + match self.frame().now_or_never() { + None => { + frames.is_last = false; + break; + } + Some(None) => { + frames.is_last = true; + break; + } + Some(Some(result)) => { + frames.frames.push(result?); + } + } + } + + Ok(frames) + } +} + +pub struct Frames { + pub frames: Vec>, + pub is_last: bool, +} diff --git a/mirrord/protocol/src/body_chunks.rs b/mirrord/protocol/src/body_chunks.rs deleted file mode 100644 index 19a42c78ae4..00000000000 --- a/mirrord/protocol/src/body_chunks.rs +++ /dev/null @@ -1,71 +0,0 @@ -use std::{ - future::Future, - pin::Pin, - task::{Context, Poll}, -}; - -use bytes::Bytes; -use hyper::body::{Body, Frame}; - -pub trait BodyExt { - fn next_frames(&mut self, no_wait: bool) -> FramesFut<'_, B>; -} - -impl BodyExt for B -where - B: Body, -{ - fn next_frames(&mut self, no_wait: bool) -> FramesFut<'_, B> { - FramesFut { - body: self, - no_wait, - } - } -} - -pub struct FramesFut<'a, B> { - body: &'a mut B, - no_wait: bool, -} - -impl Future for FramesFut<'_, B> -where - B: Body + Unpin, -{ - type Output = hyper::Result; - - fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll { - let mut frames = vec![]; - - loop { - let result = match Pin::new(&mut 
self.as_mut().body).poll_frame(cx) { - Poll::Ready(Some(Err(error))) => Poll::Ready(Err(error)), - Poll::Ready(Some(Ok(frame))) => { - frames.push(frame); - continue; - } - Poll::Ready(None) => Poll::Ready(Ok(Frames { - frames, - is_last: true, - })), - Poll::Pending => { - if frames.is_empty() && !self.no_wait { - Poll::Pending - } else { - Poll::Ready(Ok(Frames { - frames, - is_last: false, - })) - } - } - }; - - break result; - } - } -} - -pub struct Frames { - pub frames: Vec>, - pub is_last: bool, -} diff --git a/mirrord/protocol/src/lib.rs b/mirrord/protocol/src/lib.rs index 94a555ce0da..983fcd3536b 100644 --- a/mirrord/protocol/src/lib.rs +++ b/mirrord/protocol/src/lib.rs @@ -3,7 +3,7 @@ #![warn(clippy::indexing_slicing)] #![deny(unused_crate_dependencies)] -pub mod body_chunks; +pub mod batched_body; pub mod codec; pub mod dns; pub mod error; diff --git a/mirrord/protocol/src/tcp.rs b/mirrord/protocol/src/tcp.rs index 023369129ad..efc3ae98905 100644 --- a/mirrord/protocol/src/tcp.rs +++ b/mirrord/protocol/src/tcp.rs @@ -25,7 +25,7 @@ use tokio::sync::mpsc::Receiver; use tokio_stream::wrappers::ReceiverStream; use tracing::{error, Level}; -use crate::{body_chunks::BodyExt as _, ConnectionId, Port, RemoteResult, RequestId}; +use crate::{batched_body::BatchedBody, ConnectionId, Port, RemoteResult, RequestId}; #[derive(Encode, Decode, Debug, PartialEq, Eq, Clone)] pub struct NewTcpConnection { @@ -921,7 +921,7 @@ impl HttpResponse { mut body, ) = response.into_parts(); - let frames = body.next_frames(true).await?; + let frames = body.ready_frames()?; let (tx, rx) = tokio::sync::mpsc::channel(frames.frames.len().max(12)); for frame in frames.frames { tx.try_send(Ok(frame)) From 6e2d3bc6c9555a0c233c9abb1a76b54cd83eb1e5 Mon Sep 17 00:00:00 2001 From: Razz4780 Date: Tue, 14 Jan 2025 15:19:09 +0100 Subject: [PATCH 05/60] local HTTP handling rework --- mirrord/intproxy/src/background_tasks.rs | 52 ++ mirrord/intproxy/src/proxies/incoming.rs | 422 ++++++----- 
mirrord/intproxy/src/proxies/incoming/http.rs | 436 +++++++++-- .../src/proxies/incoming/interceptor.rs | 699 ++++++------------ mirrord/protocol/src/tcp.rs | 170 +++-- 5 files changed, 984 insertions(+), 795 deletions(-) diff --git a/mirrord/intproxy/src/background_tasks.rs b/mirrord/intproxy/src/background_tasks.rs index e43c8f306b8..50593f47ddd 100644 --- a/mirrord/intproxy/src/background_tasks.rs +++ b/mirrord/intproxy/src/background_tasks.rs @@ -22,6 +22,14 @@ pub struct MessageBus { } impl MessageBus { + /// Wraps this message bus into a struct that allows for peeking the next incoming message. + pub fn peekable(&mut self) -> PeekableMessageBus<'_, T> { + PeekableMessageBus { + peeked: None, + message_bus: self, + } + } + /// Attempts to send a message to this task's parent. pub async fn send>(&self, msg: M) { let _ = self.tx.send(msg.into()).await; @@ -37,6 +45,50 @@ impl MessageBus { } } +/// Wrapper over [`MessageBus`]. +/// +/// Allows for peeking the next incoming message. +pub struct PeekableMessageBus<'a, T: BackgroundTask> { + peeked: Option, + message_bus: &'a mut MessageBus, +} + +impl<'a, T: BackgroundTask> PeekableMessageBus<'a, T> { + /// Attempts to send a message to this task's parent. + pub async fn send>(&self, msg: M) { + let _ = self.message_bus.tx.send(msg.into()).await; + } + + /// Receives a message from this task's parent. + /// [`None`] means that the channel is closed and there will be no more messages. + pub async fn recv(&mut self) -> Option { + if self.peeked.is_some() { + return self.peeked.take(); + } + + tokio::select! { + _ = self.message_bus.tx.closed() => None, + msg = self.message_bus.rx.recv() => msg, + } + } + + /// Peeks the next message from this task's parent. + /// [`None`] means that the channel is closed and there will be no more messages. + pub async fn peek(&mut self) -> Option<&T::MessageIn> { + if self.peeked.is_some() { + return self.peeked.as_ref(); + } + + tokio::select! 
{ + _ = self.message_bus.tx.closed() => None, + msg = self.message_bus.rx.recv() => { + self.peeked = msg; + self.peeked.as_ref() + }, + } + } +} + /// Common trait for all background tasks in the internal proxy. pub trait BackgroundTask: Sized { /// Type of errors that can occur during the execution. diff --git a/mirrord/intproxy/src/proxies/incoming.rs b/mirrord/intproxy/src/proxies/incoming.rs index 406bdc31df8..74175e2eaeb 100644 --- a/mirrord/intproxy/src/proxies/incoming.rs +++ b/mirrord/intproxy/src/proxies/incoming.rs @@ -1,35 +1,33 @@ //! Handles the logic of the `incoming` feature. use std::{ - collections::{hash_map::Entry, HashMap}, + collections::{hash_map::Entry, HashMap, VecDeque}, fmt, io, net::SocketAddr, }; use bound_socket::BoundTcpSocket; -use bytes::Bytes; use futures::StreamExt; -use http::RETRY_ON_RESET_ATTEMPTS; -use http_body_util::StreamBody; -use hyper::body::Frame; +use http::PeekedBody; +use http_body_util::BodyStream; +use hyper::body::{Frame, Incoming}; use metadata_store::MetadataStore; use mirrord_intproxy_protocol::{ ConnMetadataRequest, ConnMetadataResponse, IncomingRequest, IncomingResponse, LayerId, MessageId, PortSubscribe, PortSubscription, PortUnsubscribe, ProxyToLayerMessage, }; use mirrord_protocol::{ - batched_body::BatchedBody, tcp::{ - ChunkedHttpBody, ChunkedHttpError, ChunkedRequest, ChunkedResponse, DaemonTcp, HttpRequest, - HttpRequestFallback, HttpResponse, HttpResponseFallback, InternalHttpBodyFrame, - InternalHttpRequest, InternalHttpResponse, LayerTcpSteal, NewTcpConnection, - ReceiverStreamBody, StreamingBody, TcpData, + ChunkedHttpBody, ChunkedHttpError, ChunkedRequest, ChunkedResponse, DaemonTcp, + HttpResponse, InternalHttpBody, InternalHttpBodyFrame, LayerTcp, LayerTcpSteal, + NewTcpConnection, StreamingBody, TcpData, HTTP_CHUNKED_RESPONSE_VERSION, + HTTP_FRAMED_VERSION, }, - ClientMessage, ConnectionId, RequestId, ResponseError, + ClientMessage, ConnectionId, Port, RequestId, ResponseError, }; use 
thiserror::Error; use tokio::sync::mpsc::{self, Sender}; -use tokio_stream::{wrappers::ReceiverStream, StreamMap, StreamNotifyClose}; +use tokio_stream::{StreamMap, StreamNotifyClose}; use tracing::{debug, Level}; use self::{ @@ -112,7 +110,8 @@ pub struct IncomingProxy { /// For managing streamed [`DaemonTcp::HttpRequestChunked`] request channels. request_body_txs: HashMap<(ConnectionId, RequestId), Sender>, /// For managing streamed [`LayerTcpSteal::HttpResponseChunked`] response streams. - response_body_rxs: StreamMap<(ConnectionId, RequestId), StreamNotifyClose>, + response_body_rxs: + StreamMap<(ConnectionId, RequestId), StreamNotifyClose>>, /// Version of [`mirrord_protocol`] negotiated with the agent. agent_protocol_version: Option, } @@ -159,23 +158,33 @@ impl IncomingProxy { /// Retrieves or creates an [`Interceptor`] for the given [`HttpRequestFallback`]. /// The request may or may not belong to an existing connection (when stealing with an http /// filter, connections are created implicitly). 
- #[tracing::instrument(level = Level::TRACE, skip(self))] - fn get_interceptor_for_http_request( + #[tracing::instrument(level = Level::TRACE, skip(self, message_bus))] + async fn get_or_create_http_interceptor( &mut self, - request: &HttpRequestFallback, + connection_id: ConnectionId, + port: Port, + message_bus: &MessageBus, ) -> Result>, IncomingProxyError> { - let id: InterceptorId = InterceptorId(request.connection_id()); + let id: InterceptorId = InterceptorId(connection_id); let interceptor = match self.interceptors.entry(id) { Entry::Occupied(e) => e.into_mut(), Entry::Vacant(e) => { - let Some(subscription) = self.subscriptions.get(request.port()) else { - tracing::trace!( - "received a new connection for port {} that is no longer mirrored", - request.port(), + let Some(subscription) = self.subscriptions.get(port) else { + tracing::debug!( + port, + connection_id, + "Received a new connection for a port that is no longer subscribed, \ + sending an unsubscribe request.", ); + message_bus + .send(ClientMessage::TcpSteal( + LayerTcpSteal::ConnectionUnsubscribe(connection_id), + )) + .await; + return Ok(None); }; @@ -183,11 +192,7 @@ impl IncomingProxy { BoundTcpSocket::bind_specified_or_localhost(subscription.listening_on.ip())?; let interceptor = self.background_tasks.register( - Interceptor::new( - interceptor_socket, - subscription.listening_on, - self.agent_protocol_version.clone(), - ), + Interceptor::new(interceptor_socket, subscription.listening_on), id, Self::CHANNEL_SIZE, ); @@ -207,6 +212,7 @@ impl IncomingProxy { async fn handle_agent_message( &mut self, message: DaemonTcp, + is_steal: bool, message_bus: &mut MessageBus, ) -> Result<(), IncomingProxyError> { match message { @@ -225,90 +231,113 @@ impl IncomingProxy { self.response_body_rxs.remove(key); } } + DaemonTcp::Data(data) => { if let Some(interceptor) = self.interceptors.get(&InterceptorId(data.connection_id)) { interceptor.tx.send(data.bytes).await; } else { - tracing::trace!( - "received 
new data for connection {} that is already closed", - data.connection_id + tracing::debug!( + connection_id = data.connection_id, + "Received new data for a connection that is already closed", ); } } - DaemonTcp::HttpRequest(req) => { - let req = HttpRequestFallback::Fallback(req); - let interceptor = self.get_interceptor_for_http_request(&req)?; + + DaemonTcp::HttpRequest(request) => { + let interceptor = self + .get_or_create_http_interceptor( + request.connection_id, + request.port, + message_bus, + ) + .await?; + if let Some(interceptor) = interceptor { - interceptor.send(req).await; + interceptor + .send(request.map_body(StreamingBody::from)) + .await; } } - DaemonTcp::HttpRequestFramed(req) => { - let req = HttpRequestFallback::Framed(req); - let interceptor = self.get_interceptor_for_http_request(&req)?; + + DaemonTcp::HttpRequestFramed(request) => { + let interceptor = self + .get_or_create_http_interceptor( + request.connection_id, + request.port, + message_bus, + ) + .await?; + if let Some(interceptor) = interceptor { - interceptor.send(req).await; + interceptor + .send(request.map_body(StreamingBody::from)) + .await; } } - DaemonTcp::HttpRequestChunked(req) => { - match req { - ChunkedRequest::Start(req) => { - let (tx, rx) = mpsc::channel::(128); - let http_stream = StreamingBody::new(rx); - let http_req = HttpRequest { - internal_request: InternalHttpRequest { - method: req.internal_request.method, - uri: req.internal_request.uri, - headers: req.internal_request.headers, - version: req.internal_request.version, - body: http_stream, - }, - connection_id: req.connection_id, - request_id: req.request_id, - port: req.port, - }; - let key = (http_req.connection_id, http_req.request_id); - self.request_body_txs.insert(key, tx.clone()); + DaemonTcp::HttpRequestChunked(request) => { + match request { + ChunkedRequest::Start(request) => { + let interceptor = self + .get_or_create_http_interceptor( + request.connection_id, + request.port, + message_bus, + ) + 
.await?; - let http_req = HttpRequestFallback::Streamed { - request: http_req, - retries: 0, - }; - let interceptor = self.get_interceptor_for_http_request(&http_req)?; if let Some(interceptor) = interceptor { - interceptor.send(http_req).await; - } - - for frame in req.internal_request.body { - if let Err(err) = tx.send(frame).await { - self.request_body_txs.remove(&key); - tracing::trace!(?err, "error while sending"); - } + let (tx, rx) = mpsc::channel::(128); + let request = request.map_body(|frames| StreamingBody::new(rx, frames)); + let key = (request.connection_id, request.request_id); + interceptor.send(request).await; + self.request_body_txs.insert(key, tx); } } - ChunkedRequest::Body(body) => { - let key = &(body.connection_id, body.request_id); - let mut send_err = false; - if let Some(tx) = self.request_body_txs.get(key) { - for frame in body.frames { + + ChunkedRequest::Body(ChunkedHttpBody { + frames, + is_last, + connection_id, + request_id, + }) => { + if let Some(tx) = self.request_body_txs.get(&(connection_id, request_id)) { + let mut send_err = false; + + for frame in frames { if let Err(err) = tx.send(frame).await { send_err = true; - tracing::trace!(?err, "error while sending"); + tracing::debug!( + frame = ?err.0, + connection_id, + request_id, + "Failed to send an HTTP request body frame to the interceptor, channel is closed" + ); + break; } } - } - if send_err || body.is_last { - self.request_body_txs.remove(key); + + if send_err || is_last { + self.request_body_txs.remove(&(connection_id, request_id)); + } } } - ChunkedRequest::Error(err) => { - self.request_body_txs - .remove(&(err.connection_id, err.request_id)); - tracing::trace!(?err, "ChunkedRequest error received"); + + ChunkedRequest::Error(ChunkedHttpError { + connection_id, + request_id, + }) => { + self.request_body_txs.remove(&(connection_id, request_id)); + tracing::debug!( + connection_id, + request_id, + "Received an error in an HTTP request body", + ); } }; } + 
DaemonTcp::NewConnection(NewTcpConnection { connection_id, remote_address, @@ -317,13 +346,25 @@ impl IncomingProxy { local_address, }) => { let Some(subscription) = self.subscriptions.get(destination_port) else { - tracing::trace!("received a new connection for port {destination_port} that is no longer mirrored"); + tracing::debug!( + port = destination_port, + connection_id, + "Received a new connection for a port that is no longer subscribed, \ + sending an unsubscribe request.", + ); + + let message = if is_steal { + ClientMessage::Tcp(LayerTcp::ConnectionUnsubscribe(connection_id)) + } else { + ClientMessage::TcpSteal(LayerTcpSteal::ConnectionUnsubscribe(connection_id)) + }; + message_bus.send(message).await; + return Ok(()); }; let interceptor_socket = BoundTcpSocket::bind_specified_or_localhost(subscription.listening_on.ip())?; - let id = InterceptorId(connection_id); self.metadata_store.expect( @@ -339,11 +380,7 @@ impl IncomingProxy { ); let interceptor = self.background_tasks.register( - Interceptor::new( - interceptor_socket, - subscription.listening_on, - self.agent_protocol_version.clone(), - ), + Interceptor::new(interceptor_socket, subscription.listening_on), id, Self::CHANNEL_SIZE, ); @@ -356,6 +393,7 @@ impl IncomingProxy { }, ); } + DaemonTcp::SubscribeResult(result) => { let msgs = self.subscriptions.agent_responded(result)?; @@ -386,6 +424,99 @@ impl IncomingProxy { .get(&interceptor_id) .map(|handle| &handle.subscription) } + + /// Handles an HTTP response coming from one of the interceptors. + /// + /// If all response frames are already available, sends the response in a single message. + /// Otherwise, starts a response reader to handle the response. 
+ #[tracing::instrument(level = Level::TRACE, skip(self, message_bus))] + async fn handle_http_response( + &mut self, + mut response: HttpResponse, + message_bus: &mut MessageBus, + ) { + let _tail = match response.internal_response.body.tail.take() { + Some(tail) => tail, + + // All frames are already fetched, we don't have to stream the body to the agent. + None => { + let message = if self.agent_handles_framed_responses() { + // We can send just one message to the agent. + let response = response.map_body(|body| { + InternalHttpBody( + body.head + .into_iter() + .map(InternalHttpBodyFrame::from) + .collect::>(), + ) + }); + LayerTcpSteal::HttpResponseFramed(response) + } else { + // Agent does not support `LayerTcpSteal::HttpResponseFramed`. + // We can only use legacy `LayerTcpSteal::HttpResponse`, which drops trailing + // headers. + let connection_id = response.connection_id; + let request_id = response.request_id; + let response = response.map_body(|body| { + let mut new_body = Vec::with_capacity(body.head.iter().filter_map(Frame::data_ref).map(|data| data.len()).sum()); + body.head.into_iter().for_each(|frame| match frame.into_data() { + Ok(data) => new_body.extend(data), + Err(frame) => { + if let Some(headers) = frame.trailers_ref() { + tracing::warn!( + connection_id, + request_id, + agent_protocol_version = ?self.agent_protocol_version, + ?headers, + "Agent uses an outdated version of mirrord protocol, \ + we can't send trailing headers from the local application's HTTP response." 
+ ) + } + } + }); + new_body + }); + LayerTcpSteal::HttpResponse(response) + }; + + message_bus.send(ClientMessage::TcpSteal(message)).await; + + return; + } + }; + + if self.agent_handles_streamed_responses() { + let response = response.map_body(|body| { + body.head + .into_iter() + .map(InternalHttpBodyFrame::from) + .collect::>() + }); + let message = ClientMessage::TcpSteal(LayerTcpSteal::HttpResponseChunked( + ChunkedResponse::Start(response), + )); + message_bus.send(message).await; + todo!("start response reader") + } else if self.agent_handles_framed_responses() { + todo!("start response reader") + } else { + todo!("start response reader") + } + } + + fn agent_handles_framed_responses(&self) -> bool { + self.agent_protocol_version + .as_ref() + .map(|version| HTTP_FRAMED_VERSION.matches(version)) + .unwrap_or_default() + } + + fn agent_handles_streamed_responses(&self) -> bool { + self.agent_protocol_version + .as_ref() + .map(|version| HTTP_CHUNKED_RESPONSE_VERSION.matches(version)) + .unwrap_or_default() + } } impl BackgroundTask for IncomingProxy { @@ -452,10 +583,10 @@ impl BackgroundTask for IncomingProxy { } }, Some(IncomingProxyMessage::AgentMirror(msg)) => { - self.handle_agent_message(msg, message_bus).await?; + self.handle_agent_message(msg, false, message_bus).await?; } Some(IncomingProxyMessage::AgentSteal(msg)) => { - self.handle_agent_message(msg, message_bus).await?; + self.handle_agent_message(msg, true, message_bus).await?; } Some(IncomingProxyMessage::LayerClosed(msg)) => self.handle_layer_close(msg, message_bus).await, Some(IncomingProxyMessage::LayerForked(msg)) => self.handle_layer_fork(msg), @@ -482,115 +613,24 @@ impl BackgroundTask for IncomingProxy { let Some(PortSubscription::Steal(_)) = self.get_subscription(id) else { continue; }; - let msg = match msg { + + match msg { MessageOut::Raw(bytes) => { - ClientMessage::TcpSteal(LayerTcpSteal::Data(TcpData { + let msg = ClientMessage::TcpSteal(LayerTcpSteal::Data(TcpData { 
connection_id: id.0, bytes, - })) - }, - MessageOut::Http(HttpResponseFallback::Fallback(res)) => { - ClientMessage::TcpSteal(LayerTcpSteal::HttpResponse(res)) - }, - MessageOut::Http(HttpResponseFallback::Framed(res)) => { - ClientMessage::TcpSteal(LayerTcpSteal::HttpResponseFramed(res)) + })); + + message_bus.send(msg).await; }, - MessageOut::Http(HttpResponseFallback::Streamed(response, request)) => { - match self.streamed_http_response(response, request).await { - Some(response) => response, - None => continue, - } + + MessageOut::Http(response) => { + self.handle_http_response(response, message_bus).await; } }; - message_bus.send(msg).await; }, }, } } } } - -impl IncomingProxy { - /// Sends back the streamed http response to the agent. - /// - /// If we cannot get the next frame of the streamed body, then we retry the whole - /// process, by sending the original `request` again through the http `interceptor` to - /// our hyper handler. - #[allow(clippy::type_complexity)] - #[tracing::instrument(level = Level::TRACE, skip(self), ret)] - async fn streamed_http_response( - &mut self, - mut response: HttpResponse, hyper::Error>>>>, - request: Option, - ) -> Option { - let mut body = vec![]; - let key = (response.connection_id, response.request_id); - - match response - .internal_response - .body - .ready_frames() - .map_err(InterceptorError::from) - { - Ok(frames) => { - frames - .frames - .into_iter() - .map(From::from) - .for_each(|frame| body.push(frame)); - - self.response_body_rxs - .insert(key, StreamNotifyClose::new(response.internal_response.body)); - - let internal_response = InternalHttpResponse { - status: response.internal_response.status, - version: response.internal_response.version, - headers: response.internal_response.headers, - body, - }; - let response = ChunkedResponse::Start(HttpResponse { - port: response.port, - connection_id: response.connection_id, - request_id: response.request_id, - internal_response, - }); - 
Some(ClientMessage::TcpSteal(LayerTcpSteal::HttpResponseChunked( - response, - ))) - } - // Retry on known errors. - Err(error @ InterceptorError::Reset) - | Err(error @ InterceptorError::ConnectionClosedTooSoon(..)) - | Err(error @ InterceptorError::IncompleteMessage(..)) => { - tracing::warn!(%error, ?request, "Failed to read first frames of streaming HTTP response"); - - let interceptor = self - .interceptors - .get(&InterceptorId(response.connection_id))?; - - if let Some(HttpRequestFallback::Streamed { request, retries }) = request - && retries < RETRY_ON_RESET_ATTEMPTS - { - tracing::trace!( - ?request, - ?retries, - "`RST_STREAM` from hyper, retrying the request." - ); - interceptor - .tx - .send(HttpRequestFallback::Streamed { - request, - retries: retries + 1, - }) - .await; - } - - None - } - Err(fail) => { - tracing::warn!(?fail, "Something went wrong, skipping this response!"); - None - } - } - } -} diff --git a/mirrord/intproxy/src/proxies/incoming/http.rs b/mirrord/intproxy/src/proxies/incoming/http.rs index d122aab53c5..a2600d57203 100644 --- a/mirrord/intproxy/src/proxies/incoming/http.rs +++ b/mirrord/intproxy/src/proxies/incoming/http.rs @@ -1,87 +1,427 @@ -use std::convert::Infallible; +use std::{error::Error, fmt, io, net::SocketAddr, ops::Not, time::Duration}; use bytes::Bytes; -use http_body_util::combinators::BoxBody; +use exponential_backoff::Backoff; use hyper::{ - body::Incoming, + body::{Frame, Incoming}, client::conn::{http1, http2}, - Response, Version, + Request, Response, Version, }; use hyper_util::rt::{TokioExecutor, TokioIo}; -use mirrord_protocol::tcp::HttpRequestFallback; -use tokio::net::TcpStream; +use mirrord_protocol::{ + batched_body::BatchedBody, + tcp::{HttpRequest, StreamingBody}, +}; +use thiserror::Error; +use tokio::{net::TcpStream, time}; use tracing::Level; -use super::interceptor::{InterceptorError, InterceptorResult}; - -pub(super) const RETRY_ON_RESET_ATTEMPTS: u32 = 10; +use super::bound_socket::BoundTcpSocket; 
-/// Handles the differences between hyper's HTTP/1 and HTTP/2 connections. -pub enum HttpSender { - V1(http1::SendRequest>), - V2(http2::SendRequest>), +/// A retrying HTTP client used to pass requests to the user application. +pub struct LocalHttpClient { + /// Established HTTP connection with the user application. + sender: Option, + /// Established TCP connection with the user application. + stream: Option, + /// Address of the user application's HTTP server. + local_server_address: SocketAddr, } -/// Consumes the given [`TcpStream`] and performs an HTTP handshake, turning it into an HTTP -/// connection. -/// -/// # Returns -/// -/// [`HttpSender`] that can be used to send HTTP requests to the peer. -#[tracing::instrument(level = Level::TRACE, skip(target_stream), err(level = Level::WARN))] -pub async fn handshake( - version: Version, - target_stream: TcpStream, -) -> InterceptorResult { - match version { - Version::HTTP_2 => { - let (sender, connection) = - http2::handshake(TokioExecutor::default(), TokioIo::new(target_stream)).await?; - tokio::spawn(connection); - - Ok(HttpSender::V2(sender)) +impl LocalHttpClient { + /// How many times we attempt to send any given request. + /// + /// See [`LocalHttpError::can_retry`]. + const MAX_SEND_ATTEMPTS: u32 = 10; + const MIN_SEND_BACKOFF: Duration = Duration::from_millis(10); + const MAX_SEND_BACKOFF: Duration = Duration::from_millis(250); + + /// Crates a new client that will initially use the given `stream` (connection with the user + /// application's HTTP server). + pub fn new_for_stream(stream: TcpStream) -> Result { + let local_server_address = stream + .peer_addr() + .map_err(LocalHttpError::SocketSetupFailed)?; + + Ok(Self { + sender: None, + stream: Some(stream), + local_server_address, + }) + } + + /// Reuses or creates a new [`HttpSender`]. 
+ #[tracing::instrument(level = Level::TRACE, err(level = Level::TRACE))] + async fn get_sender(&mut self, version: Version) -> Result { + if let Some(sender) = self.sender.take() { + if sender.version_matches(version) { + return Ok(sender); + } } - Version::HTTP_3 => Err(InterceptorError::UnsupportedHttpVersion(version)), + let stream = match self.stream.take() { + Some(stream) => stream, + None => { + let socket = + BoundTcpSocket::bind_specified_or_localhost(self.local_server_address.ip()) + .map_err(LocalHttpError::SocketSetupFailed)?; + socket + .connect(self.local_server_address) + .await + .map_err(LocalHttpError::ConnectTcpFailed)? + } + }; + + HttpSender::handshake(version, stream).await + } + + /// Tries to send the given `request` to the user application's HTTP server. + /// + /// Checks whether some reponse [`Frame`]s are instantly available. + async fn try_send_request( + &mut self, + request: &HttpRequest, + ) -> Result, LocalHttpError> { + let mut sender = self.get_sender(request.version()).await?; + let response = sender.send_request(request.clone()).await?; + let (parts, mut body) = response.into_parts(); + + let frames = body + .ready_frames() + .map_err(LocalHttpError::PeekBodyFailed)?; + let body = PeekedBody { + head: frames.frames, + tail: frames.is_last.not().then_some(body), + }; - _http_v1 => { - let (sender, connection) = http1::handshake(TokioIo::new(target_stream)).await?; + self.sender.replace(sender); - tokio::spawn(connection.with_upgrades()); + Ok(Response::from_parts(parts, body)) + } + + /// Tries to send the given `request` to the user application's HTTP server. + /// + /// Retries on known errors (see [`LocalHttpError::can_retry`]). 
+ #[tracing::instrument(level = Level::DEBUG, err(level = Level::WARN), ret)] + pub async fn send_request( + &mut self, + request: &HttpRequest, + ) -> Result, LocalHttpError> { + let mut backoffs = Backoff::new( + Self::MAX_SEND_ATTEMPTS, + Self::MIN_SEND_BACKOFF, + Self::MAX_SEND_BACKOFF, + ) + .into_iter() + .flatten(); + + let mut attempt = 0; + loop { + attempt += 1; + tracing::trace!(attempt, "Trying to send the request"); + match (self.try_send_request(&request).await, backoffs.next()) { + (Ok(response), _) => { + tracing::trace!( + attempt, + "Successfully sent the request and peeked first frames" + ); + break Ok(response); + } + + (Err(error), Some(backoff)) if error.can_retry() => { + tracing::warn!( + attempt, + connection_id = request.connection_id, + request_id = request.request_id, + %error, + backoff_s = backoff.as_secs_f32(), + "Failed to send the request to the local application, retrying", + ); + + time::sleep(backoff).await; + } - Ok(HttpSender::V1(sender)) + (Err(error), _) => { + tracing::warn!( + attempts = attempt, + connection_id = request.connection_id, + request_id = request.request_id, + %error, + "Failed to send the request to the local application", + ); + + break Err(error); + } + } } } } +impl fmt::Debug for LocalHttpClient { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + f.debug_struct("LocalHttpClient") + .field("has_sender", &self.sender.is_some()) + .field("has_stream", &self.stream.is_some()) + .field("local_server_address", &self.local_server_address) + .finish() + } +} + +/// Errors that can occur when sending an HTTP request to the user application. 
+#[derive(Error, Debug)] +pub enum LocalHttpError { + #[error("HTTP handshake failed: {0}")] + HandshakeFailed(#[source] hyper::Error), + + #[error("{0:?} is not supported")] + UnsupportedHttpVersion(Version), + + #[error("sending the request failed: {0}")] + SendFailed(#[source] hyper::Error), + + #[error("setting up TCP socket failed: {0}")] + SocketSetupFailed(#[source] io::Error), + + #[error("making a TPC connection failed: {0}")] + ConnectTcpFailed(#[source] io::Error), + + #[error("reading first frames of the response body failed: {0}")] + PeekBodyFailed(#[source] hyper::Error), +} + +impl LocalHttpError { + /// Checks if the given [`hyper::Error`] originates from [`h2::Error`] `RST_STREAM`. + /// + /// This requires that we use the same [`h2`] version as [`hyper`], + /// which is verified in the `hyper_and_h2_versions_in_sync` test below. + pub fn is_h2_reset(error: &hyper::Error) -> bool { + let mut cause = error.source(); + while let Some(err) = cause { + if let Some(typed) = err.downcast_ref::() { + return typed.is_reset(); + }; + + cause = err.source(); + } + + false + } + + /// Checks if we can retry sending the request, given that the previous attempt resulted in this + /// error. + pub fn can_retry(&self) -> bool { + match self { + Self::SocketSetupFailed(..) | Self::UnsupportedHttpVersion(..) => false, + Self::ConnectTcpFailed(..) => true, + Self::HandshakeFailed(err) | Self::SendFailed(err) | Self::PeekBodyFailed(err) => { + err.is_closed() || err.is_incomplete_message() || Self::is_h2_reset(err) + } + } + } +} + +/// Response body returned from [`LocalHttpClient`]. +pub struct PeekedBody { + /// [`Frame`]s that were instantly available. + pub head: Vec>, + /// The rest of the response's body. 
+ pub tail: Option, +} + +impl fmt::Debug for PeekedBody { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + f.debug_struct("PeekedBody") + .field("head", &self.head) + .field("has_tail", &self.tail.is_some()) + .finish() + } +} + +/// Holds either [`http1::SendRequest`] or [`http2::SendRequest`] and exposes a unified interface. +enum HttpSender { + V1(http1::SendRequest), + V2(http2::SendRequest), +} + impl HttpSender { - #[tracing::instrument(level = Level::TRACE, skip(self), err(level = Level::WARN))] - pub async fn send( + /// Performs an HTTP handshake over the given [`TcpStream`]. + #[tracing::instrument(level = Level::DEBUG, skip(target_stream), err(level = Level::WARN))] + async fn handshake(version: Version, target_stream: TcpStream) -> Result { + let local_addr = target_stream + .local_addr() + .map_err(LocalHttpError::SocketSetupFailed)?; + let peer_addr = target_stream + .peer_addr() + .map_err(LocalHttpError::SocketSetupFailed)?; + + match version { + Version::HTTP_2 => { + let (sender, connection) = + http2::handshake(TokioExecutor::default(), TokioIo::new(target_stream)) + .await + .map_err(LocalHttpError::HandshakeFailed)?; + + tokio::spawn(async move { + match connection.await { + Ok(()) => { + tracing::trace!(%local_addr, %peer_addr, "HTTP connection with the local application finished"); + } + Err(error) => { + tracing::warn!(%error, %local_addr, %peer_addr, "HTTP connection with the local application failed"); + } + } + }); + + Ok(HttpSender::V2(sender)) + } + + Version::HTTP_3 => Err(LocalHttpError::UnsupportedHttpVersion(version)), + + _http_v1 => { + let (sender, connection) = http1::handshake(TokioIo::new(target_stream)) + .await + .map_err(LocalHttpError::HandshakeFailed)?; + + tokio::spawn(async move { + match connection.with_upgrades().await { + Ok(()) => { + tracing::trace!(%local_addr, %peer_addr, "HTTP connection with the local application finished"); + } + Err(error) => { + tracing::warn!(%error, %local_addr, %peer_addr, 
"HTTP connection with the local application failed"); + } + } + }); + + Ok(HttpSender::V1(sender)) + } + } + } + + /// Tries to send the given [`HttpRequest`] to the server. + #[tracing::instrument( + level = Level::DEBUG, + skip(self, request), + fields(connection_id = request.connection_id, request_id = request.request_id), + ret, + err(level = Level::WARN), + )] + async fn send_request( &mut self, - req: HttpRequestFallback, - ) -> InterceptorResult, InterceptorError> { + request: HttpRequest, + ) -> Result, LocalHttpError> { match self { Self::V1(sender) => { // Solves a "connection was not ready" client error. // https://rust-lang.github.io/wg-async/vision/submitted_stories/status_quo/barbara_tries_unix_socket.html#the-single-magical-line - sender.ready().await?; + sender.ready().await.map_err(LocalHttpError::SendFailed)?; + sender - .send_request(req.into_hyper()) + .send_request(request.internal_request.into()) .await - .map_err(Into::into) + .map_err(LocalHttpError::SendFailed) } Self::V2(sender) => { - let mut req = req.into_hyper(); + let mut hyper_request: Request<_> = request.internal_request.into(); + // fixes https://github.com/metalbear-co/mirrord/issues/2497 // inspired by https://github.com/linkerd/linkerd2-proxy/blob/c5d9f1c1e7b7dddd9d75c0d1a0dca68188f38f34/linkerd/proxy/http/src/h2.rs#L175 - if req.uri().authority().is_none() { - *req.version_mut() = hyper::http::Version::HTTP_11; + if hyper_request.uri().authority().is_none() + && hyper_request.version() != Version::HTTP_11 + { + tracing::trace!( + original_version = ?hyper_request.version(), + "Request URI has no authority, changing HTTP version to {:?}", + Version::HTTP_11, + ); + + *hyper_request.version_mut() = Version::HTTP_11; } + // Solves a "connection was not ready" client error. 
// https://rust-lang.github.io/wg-async/vision/submitted_stories/status_quo/barbara_tries_unix_socket.html#the-single-magical-line - sender.ready().await?; - sender.send_request(req).await.map_err(Into::into) + sender.ready().await.map_err(LocalHttpError::SendFailed)?; + + sender + .send_request(hyper_request) + .await + .map_err(LocalHttpError::SendFailed) } } } + + /// Returns whether this [`HttpSender`] can handle requests of the given [`Version`]. + fn version_matches(&self, version: Version) -> bool { + match (version, self) { + (Version::HTTP_2, Self::V2(..)) => true, + (Version::HTTP_3, _) => false, + (_, Self::V1(..)) => true, + _ => false, + } + } +} + +#[cfg(test)] +mod test { + use std::{ + convert::Infallible, + error::Error, + net::{Ipv4Addr, SocketAddr}, + }; + + use bytes::Bytes; + use http_body_util::Full; + use hyper::{server::conn::http2, service::service_fn, Response}; + use hyper_util::rt::{TokioExecutor, TokioIo}; + use tokio::net::TcpListener; + + /// Checks that [`hyper`] and [`h2`] crate versions are in sync with each other. + /// + /// In [`LocalHttpError::is_h2_reset`](super::LocalHttpError::is_h2_reset) we use + /// `source.downcast_ref::` to drill down on [`h2`] errors from [`hyper`], we + /// need these two crates to stay in sync, otherwise we could always fail some of our checks + /// that rely on this `downcast` working. + /// + /// Even though we're using [`h2::Error::is_reset`] in intproxy, this test can be + /// for any error, and thus here we do it for [`h2::Error::is_go_away`] which is + /// easier to trigger. 
+ #[tokio::test] + async fn hyper_and_h2_versions_in_sync() { + let listener = TcpListener::bind(SocketAddr::new(Ipv4Addr::LOCALHOST.into(), 0)) + .await + .unwrap(); + let listener_address = listener.local_addr().unwrap(); + + let handle = tokio::spawn(async move { + let stream = listener.accept().await.unwrap().0; + http2::Builder::new(TokioExecutor::default()) + .serve_connection( + TokioIo::new(stream), + service_fn(|_| async move { + Ok::<_, Infallible>(Response::new(Full::new(Bytes::from("Heresy!")))) + }), + ) + .await + }); + + assert!(reqwest::get(format!("https://{listener_address}")) + .await + .is_err()); + + let conn_result = handle.await.unwrap(); + assert!( + conn_result + .as_ref() + .err() + .and_then(Error::source) + .and_then(|source| source.downcast_ref::()) + .is_some_and(h2::Error::is_go_away), + r"The request is supposed to fail with `GO_AWAY`! + Something is wrong if it didn't! + + >> If you're seeing this error, the cause is likely that `hyper` and `h2` + versions are out of sync, and we can't have that due to our use of + `downcast_ref` on some `h2` errors!" + ); + } } diff --git a/mirrord/intproxy/src/proxies/incoming/interceptor.rs b/mirrord/intproxy/src/proxies/incoming/interceptor.rs index bcf68aef2fe..6000199186e 100644 --- a/mirrord/intproxy/src/proxies/incoming/interceptor.rs +++ b/mirrord/intproxy/src/proxies/incoming/interceptor.rs @@ -2,38 +2,33 @@ //! intercepted connection. 
use std::{ - error::Error, io::{self, ErrorKind}, net::SocketAddr, time::Duration, }; -use bytes::BytesMut; -use exponential_backoff::Backoff; -use hyper::{upgrade::OnUpgrade, StatusCode, Version}; +use bytes::{Bytes, BytesMut}; +use hyper::{body::Frame, upgrade::OnUpgrade, Response, StatusCode}; use hyper_util::rt::TokioIo; -use mirrord_protocol::tcp::{ - HttpRequestFallback, HttpResponse, HttpResponseFallback, InternalHttpBody, ReceiverStreamBody, - HTTP_CHUNKED_RESPONSE_VERSION, -}; +use mirrord_protocol::tcp::{HttpRequest, HttpResponse, InternalHttpResponse, StreamingBody}; use thiserror::Error; use tokio::{ io::{AsyncReadExt, AsyncWriteExt}, net::TcpStream, - time::{self, sleep}, + time, }; use tracing::Level; -use super::http::HttpSender; +use super::http::{LocalHttpClient, LocalHttpError, PeekedBody}; use crate::{ - background_tasks::{BackgroundTask, MessageBus}, - proxies::incoming::{bound_socket::BoundTcpSocket, http::RETRY_ON_RESET_ATTEMPTS}, + background_tasks::{BackgroundTask, MessageBus, PeekableMessageBus}, + proxies::incoming::bound_socket::BoundTcpSocket, }; /// Messages consumed by the [`Interceptor`] when it runs as a [`BackgroundTask`]. pub enum MessageIn { /// Request to be sent to the user application. - Http(HttpRequestFallback), + Http(HttpRequest), /// Data to be sent to the user application. Raw(Vec), } @@ -42,13 +37,13 @@ pub enum MessageIn { #[derive(Debug)] pub enum MessageOut { /// Response received from the user application. - Http(HttpResponseFallback), + Http(HttpResponse), /// Data received from the user application. Raw(Vec), } -impl From for MessageIn { - fn from(value: HttpRequestFallback) -> Self { +impl From> for MessageIn { + fn from(value: HttpRequest) -> Self { Self::Http(value) } } @@ -60,59 +55,33 @@ impl From> for MessageIn { } /// Errors that can occur when [`Interceptor`] runs as a [`BackgroundTask`]. +/// +/// All of these are **fatal** for the interceptor and should terminate its main loop +/// ([`Interceptor::run`]). 
+/// +/// HTTP error handling and retries are done in the [`LocalHttpClient`]. #[derive(Error, Debug)] pub enum InterceptorError { - /// IO failed. - #[error("io failed: {0}")] - Io(#[from] io::Error), - /// Hyper failed. - #[error("hyper failed: {0}")] - Hyper(hyper::Error), - /// The layer closed connection too soon to send a request. - #[error("connection closed too soon")] - ConnectionClosedTooSoon(HttpRequestFallback), - - #[error("incomplete message")] - IncompleteMessage(HttpRequestFallback), - - /// Received a request with an unsupported HTTP version. - #[error("{0:?} is not supported")] - UnsupportedHttpVersion(Version), - /// Occurs when [`Interceptor`] receives [`MessageIn::Raw`], but it acts as an HTTP gateway and - /// there was no HTTP upgrade. - #[error("received raw bytes, but expected an HTTP request")] - UnexpectedRawData, - /// Occurs when [`Interceptor`] receives [`MessageIn::Http`], but it acts as a TCP proxy. - #[error("received an HTTP request, but expected raw bytes")] - UnexpectedHttpRequest, - - /// We dig into the [`hyper::Error`] to try and see if it's an [`h2::Error`], checking - /// for [`h2::Error::is_reset`]. - /// - /// [`hyper::Error`] mentions that `source` is not a guaranteed thing we can check for, - /// so if you see any weird behavior, check that the [`h2`] crate is in sync with - /// whatever hyper changed (for errors). - #[error("HTTP2 `RST_STREAM` received")] - Reset, - - /// We have reached the max number of attempts that we can retry our http connection, - /// due to a `RST_STREAM`, or when the connection has been closed too soon. 
- #[error("HTTP2 reached the maximum amount of retries!")] - MaxRetries, -} + #[error("failed to connect to the user application socket: {0}")] + ConnectFailed(#[source] io::Error), -impl From for InterceptorError { - fn from(hyper_fail: hyper::Error) -> Self { - if hyper_fail - .source() - .and_then(|source| source.downcast_ref::()) - .is_some_and(h2::Error::is_reset) - { - Self::Reset - } else { - Self::Hyper(hyper_fail) - } - } + #[error("io on the connection with the user application failed: {0}")] + IoFailed(#[source] io::Error), + + #[error("received an unexpected raw data ({} bytes)", .0.len())] + UnexpectedRawData(Vec), + + #[error("received an unexpected HTTP request: {0:?}")] + UnexpectedHttpRequest(HttpRequest), + + #[error(transparent)] + HttpFailed(#[from] LocalHttpError), + + #[error("failed to set up a TCP socket: {0}")] + SocketSetupFailed(#[source] io::Error), + + #[error("failed to handle an HTTP upgrade: {0}")] + HttpUpgradeFailed(#[source] hyper::Error), } pub type InterceptorResult = core::result::Result; @@ -129,8 +98,6 @@ pub struct Interceptor { socket: BoundTcpSocket, /// Address of user app's listener. peer: SocketAddr, - /// Version of [`mirrord_protocol`] negotiated with the agent. - agent_protocol_version: Option, } impl Interceptor { @@ -140,16 +107,8 @@ impl Interceptor { /// # Note /// /// The socket can be replaced when retrying HTTP requests. 
- pub fn new( - socket: BoundTcpSocket, - peer: SocketAddr, - agent_protocol_version: Option, - ) -> Self { - Self { - socket, - peer, - agent_protocol_version, - } + pub fn new(socket: BoundTcpSocket, peer: SocketAddr) -> Self { + Self { socket, peer } } } @@ -159,309 +118,143 @@ impl BackgroundTask for Interceptor { type MessageOut = MessageOut; #[tracing::instrument(level = Level::TRACE, skip_all, err)] - async fn run(self, message_bus: &mut MessageBus) -> InterceptorResult<(), Self::Error> { - let mut stream = self.socket.connect(self.peer).await?; - - // First, we determine whether this is a raw TCP connection or an HTTP connection. - // If we receive an HTTP request from our parent task, this must be an HTTP connection. - // If we receive raw bytes or our peer starts sending some data, this must be raw TCP. - let request = tokio::select! { - message = message_bus.recv() => match message { - Some(MessageIn::Raw(data)) => { - if data.is_empty() { - tracing::trace!("incoming interceptor -> agent shutdown, shutting down connection with layer"); - stream.shutdown().await?; - } else { - stream.write_all(&data).await?; - } - - return RawConnection { stream }.run(message_bus).await; - } - Some(MessageIn::Http(request)) => request, - None => return Ok(()), - }, + async fn run(self, message_bus: &mut MessageBus) -> Result<(), Self::Error> { + let stream = self + .socket + .connect(self.peer) + .await + .map_err(InterceptorError::ConnectFailed)?; + let mut message_bus = message_bus.peekable(); + tokio::select! { + // If there is some data from the user application before we get anything from the agent, + // then this is not HTTP. + // + // We should not block until the agent has something, we don't know what this protocol looks like. 
result = stream.readable() => { - result?; + result.map_err(InterceptorError::IoFailed)?; return RawConnection { stream }.run(message_bus).await; } - }; - let sender = super::http::handshake(request.version(), stream).await?; - let mut http_conn = HttpConnection { - sender, - peer: self.peer, - agent_protocol_version: self.agent_protocol_version.clone(), - }; - let (response, on_upgrade) = http_conn.send(request).await.inspect_err(|fail| { - tracing::error!(?fail, "Failed getting a filtered http response!") - })?; - message_bus.send(MessageOut::Http(response)).await; - - let raw = if let Some(on_upgrade) = on_upgrade { - let upgraded = on_upgrade.await?; - let parts = upgraded - .downcast::>() - .expect("IO type is known"); - if !parts.read_buf.is_empty() { - message_bus - .send(MessageOut::Raw(parts.read_buf.into())) - .await; + message = message_bus.peek() => match message { + Some(MessageIn::Http(..)) => {} + + Some(MessageIn::Raw(..)) => { + return RawConnection { stream }.run(message_bus).await; + } + + None => return Ok(()), } + } - Some(RawConnection { - stream: parts.io.into_inner(), - }) - } else { - http_conn.run(message_bus).await? + let http_conn = HttpConnection { + local_client: LocalHttpClient::new_for_stream(stream)?, }; - if let Some(raw) = raw { - raw.run(message_bus).await - } else { - Ok(()) - } + http_conn.run(message_bus).await } } /// Utilized by the [`Interceptor`] when it acts as an HTTP gateway. /// See [`HttpConnection::run`] for usage. struct HttpConnection { - /// Server address saved to allow for reconnecting in case a retry is required. - peer: SocketAddr, - /// Handle to the HTTP connection between the [`Interceptor`] the server. - sender: HttpSender, - /// Version of [`mirrord_protocol`] negotiated with the agent. - /// Determines which variant of [`LayerTcpSteal`](mirrord_protocol::tcp::LayerTcpSteal) - /// we use when sending HTTP responses. 
- agent_protocol_version: Option, + local_client: LocalHttpClient, } impl HttpConnection { - /// Returns whether the agent supports - /// [`LayerTcpSteal::HttpResponseChunked`](mirrord_protocol::tcp::LayerTcpSteal::HttpResponseChunked). - pub fn agent_supports_streaming_response(&self) -> bool { - self.agent_protocol_version - .as_ref() - .map(|version| HTTP_CHUNKED_RESPONSE_VERSION.matches(version)) - .unwrap_or(false) - } - /// Handles the result of sending an HTTP request. - /// Returns an [`HttpResponseFallback`] to be returned to the client or an [`InterceptorError`]. - /// - /// See [`HttpResponseFallback::response_from_request`] for notes on picking the correct - /// [`HttpResponseFallback`] variant. - #[tracing::instrument(level = Level::TRACE, skip(self, response), err(level = Level::WARN))] - async fn handle_response( - &self, - request: HttpRequestFallback, - response: InterceptorResult>, - ) -> InterceptorResult<(HttpResponseFallback, Option)> { + /// Returns an [`HttpResponse`] to be returned to the client or an [`InterceptorError`] when the + /// given [`LocalHttpError`] is fatal for the interceptor. Most [`LocalHttpError`]s are not + /// fatal and should be converted to [`StatusCode::BAD_GATEWAY`] responses instead. + #[tracing::instrument(level = Level::TRACE, ret, err(level = Level::WARN))] + fn handle_send_result( + request: HttpRequest, + response: Result, LocalHttpError>, + ) -> InterceptorResult<(HttpResponse, Option)> { match response { - Err(InterceptorError::Hyper(e)) if e.is_closed() => { - tracing::warn!( - "Sending request to local application failed with: {e:?}. \ - Seems like the local application closed the connection too early, so \ - creating a new connection and trying again." 
- ); - tracing::trace!("The request to be retried: {request:?}."); - - Err(InterceptorError::ConnectionClosedTooSoon(request)) + Err(LocalHttpError::SocketSetupFailed(error)) => { + Err(InterceptorError::SocketSetupFailed(error)) } - Err(InterceptorError::Hyper(e)) if e.is_parse() => { - tracing::warn!( - "Could not parse HTTP response to filtered HTTP request, got error: {e:?}." - ); - let body_message = format!( - "mirrord: could not parse HTTP response from local application - {e:?}" - ); - Ok(( - HttpResponseFallback::response_from_request( - request, - StatusCode::BAD_GATEWAY, - &body_message, - self.agent_protocol_version.as_ref(), - ), - None, - )) - } - Err(InterceptorError::Hyper(e)) if e.is_incomplete_message() => { - tracing::warn!( - "Sending request to local application failed with: {e:?}. \ - Connection closed before the message could complete!" - ); - tracing::trace!( - ?request, - "Retrying the request, see \ - [https://github.com/hyperium/hyper/issues/2136] for more info." 
- ); - Err(InterceptorError::IncompleteMessage(request)) + Err(LocalHttpError::UnsupportedHttpVersion(..)) => { + Err(InterceptorError::UnexpectedHttpRequest(request)) } - Err(fail) => { - tracing::warn!(?fail, "Request to local application failed!"); - let body_message = format!( - "mirrord tried to forward the request to the local application and got {fail:?}" - ); - Ok(( - HttpResponseFallback::response_from_request( - request, - StatusCode::BAD_GATEWAY, - &body_message, - self.agent_protocol_version.as_ref(), - ), - None, - )) + Err(error) => { + let message_frame = Frame::data(Bytes::from_owner(format!("mirrord: {error}"))); + let body = PeekedBody { + head: vec![message_frame], + tail: None, + }; + let response = HttpResponse { + port: request.port, + connection_id: request.connection_id, + request_id: request.request_id, + internal_response: InternalHttpResponse { + status: StatusCode::BAD_GATEWAY, + version: request.internal_request.version, + headers: Default::default(), + body, + }, + }; + + Ok((response, None)) } - Ok(mut res) => { - let upgrade = if res.status() == StatusCode::SWITCHING_PROTOCOLS { - Some(hyper::upgrade::on(&mut res)) + Ok(mut response) => { + let upgrade = if response.status() == StatusCode::SWITCHING_PROTOCOLS { + Some(hyper::upgrade::on(&mut response)) } else { None }; - let result = match &request { - HttpRequestFallback::Framed(..) => { - HttpResponse::::from_hyper_response( - res, - request.port(), - request.connection_id(), - request.request_id(), - ) - .await - .map(HttpResponseFallback::Framed) - } - HttpRequestFallback::Fallback(..) => { - HttpResponse::>::from_hyper_response( - res, - request.port(), - request.connection_id(), - request.request_id(), - ) - .await - .map(HttpResponseFallback::Fallback) - } - HttpRequestFallback::Streamed { .. 
} - if self.agent_supports_streaming_response() => - { - HttpResponse::::from_hyper_response( - res, - request.port(), - request.connection_id(), - request.request_id(), - ) - .await - .map(|response| { - HttpResponseFallback::Streamed(response, Some(request.clone())) - }) - } - HttpRequestFallback::Streamed { .. } => { - HttpResponse::::from_hyper_response( - res, - request.port(), - request.connection_id(), - request.request_id(), - ) - .await - .map(HttpResponseFallback::Framed) - } + let (parts, body) = response.into_parts(); + let response = HttpResponse { + port: request.port, + connection_id: request.connection_id, + request_id: request.request_id, + internal_response: InternalHttpResponse { + status: parts.status, + version: parts.version, + headers: parts.headers, + body, + }, }; - Ok(result.map(|response| (response, upgrade))?) - } - } - } - - /// Sends the given [`HttpRequestFallback`] to the server. - /// - /// If we get a `RST_STREAM` error from the server, or the connection was closed too - /// soon starts a new connection and retries using a [`Backoff`] until we reach - /// [`RETRY_ON_RESET_ATTEMPTS`]. - /// - /// Returns [`HttpResponseFallback`] from the server. - #[tracing::instrument(level = Level::TRACE, skip(self), ret, err)] - async fn send( - &mut self, - request: HttpRequestFallback, - ) -> InterceptorResult<(HttpResponseFallback, Option)> { - let min = Duration::from_millis(10); - let max = Duration::from_millis(250); - - let mut backoffs = Backoff::new(RETRY_ON_RESET_ATTEMPTS, min, max) - .into_iter() - .flatten(); - - // Retry to handle this request a few times. 
- loop { - let response = self.sender.send(request.clone()).await; - - match self.handle_response(request.clone(), response).await { - Ok(response) => return Ok(response), - - Err(error @ InterceptorError::Reset) - | Err(error @ InterceptorError::ConnectionClosedTooSoon(_)) - | Err(error @ InterceptorError::IncompleteMessage(_)) => { - tracing::warn!( - ?request, - %error, - "Either the connection closed, the message is incomplete, \ - or we got a reset, retrying!" - ); - - let Some(backoff) = backoffs.next() else { - break; - }; - - sleep(backoff).await; - - // Create a new connection for the next attempt. - let socket = BoundTcpSocket::bind_specified_or_localhost(self.peer.ip())?; - let stream = socket.connect(self.peer).await?; - let new_sender = super::http::handshake(request.version(), stream).await?; - self.sender = new_sender; - } - - Err(fail) => return Err(fail), + Ok((response, upgrade)) } } - - Err(InterceptorError::MaxRetries) } /// Proxies HTTP messages until an HTTP upgrade happens or the [`MessageBus`] closes. /// Support retries (with reconnecting to the HTTP server). /// - /// When an HTTP upgrade happens, the underlying [`TcpStream`] is reclaimed, wrapped - /// in a [`RawConnection`] and returned. When [`MessageBus`] closes, [`None`] is returned. + /// When an HTTP upgrade happens, the underlying [`TcpStream`] is reclaimed and wrapped + /// in a [`RawConnection`], which handles the rest of the connection. #[tracing::instrument(level = Level::TRACE, skip_all, ret, err)] async fn run( mut self, - message_bus: &mut MessageBus, - ) -> InterceptorResult> { + mut message_bus: PeekableMessageBus<'_, Interceptor>, + ) -> InterceptorResult<()> { let upgrade = loop { - let Some(msg) = message_bus.recv().await else { - return Ok(None); - }; + match message_bus.recv().await { + None => return Ok(()), - match msg { - MessageIn::Raw(..) 
=> { + Some(MessageIn::Raw(data)) => { // We should not receive any raw data from the agent before sending a //`101 SWITCHING PROTOCOLS` response. - return Err(InterceptorError::UnexpectedRawData); + return Err(InterceptorError::UnexpectedRawData(data)); } - MessageIn::Http(req) => { - let (res, on_upgrade) = self.send(req).await.inspect_err(|fail| { - tracing::error!(?fail, "Failed getting a filtered http response!") - })?; - tracing::debug!("{} has upgrade: {}", res.request_id(), on_upgrade.is_some()); + Some(MessageIn::Http(request)) => { + let result = self.local_client.send_request(&request).await; + let (res, on_upgrade) = Self::handle_send_result(request, result)?; message_bus.send(MessageOut::Http(res)).await; if let Some(on_upgrade) = on_upgrade { - break on_upgrade.await?; + break on_upgrade + .await + .map_err(InterceptorError::HttpUpgradeFailed)?; } } } @@ -477,7 +270,7 @@ impl HttpConnection { message_bus.send(MessageOut::Raw(read_buf.into())).await; } - Ok(Some(RawConnection { stream })) + RawConnection { stream }.run(message_bus).await } } @@ -502,7 +295,10 @@ impl RawConnection { /// /// 3. This implementation exits only when an error is encountered or the [`MessageBus`] is /// closed. - async fn run(mut self, message_bus: &mut MessageBus) -> InterceptorResult<()> { + async fn run<'a>( + mut self, + mut message_bus: PeekableMessageBus<'a, Interceptor>, + ) -> InterceptorResult<()> { let mut buf = BytesMut::with_capacity(64 * 1024); let mut reading_closed = false; let mut remote_closed = false; @@ -513,7 +309,7 @@ impl RawConnection { res = self.stream.read_buf(&mut buf), if !reading_closed => match res { Err(e) if e.kind() == ErrorKind::WouldBlock => {}, - Err(e) => break Err(e.into()), + Err(e) => break Err(InterceptorError::IoFailed(e)), Ok(..) 
=> { if buf.is_empty() { tracing::trace!("incoming interceptor -> layer shutdown, sending a 0-sized read to inform the agent"); @@ -532,12 +328,12 @@ impl RawConnection { Some(MessageIn::Raw(data)) => { if data.is_empty() { tracing::trace!("incoming interceptor -> agent shutdown, shutting down connection with layer"); - self.stream.shutdown().await?; + self.stream.shutdown().await.map_err(InterceptorError::IoFailed)?; } else { - self.stream.write_all(&data).await?; + self.stream.write_all(&data).await.map_err(InterceptorError::IoFailed)?; } }, - Some(MessageIn::Http(..)) => break Err(InterceptorError::UnexpectedHttpRequest), + Some(MessageIn::Http(request)) => break Err(InterceptorError::UnexpectedHttpRequest(request)), }, _ = time::sleep(Duration::from_secs(1)), if remote_closed => { @@ -552,25 +348,23 @@ impl RawConnection { #[cfg(test)] mod test { - use std::{ - convert::Infallible, - net::Ipv4Addr, - sync::{Arc, Mutex}, - }; + use std::{convert::Infallible, net::Ipv4Addr, sync::Arc}; use bytes::Bytes; use futures::future::FutureExt; - use http_body_util::{BodyExt, Empty, Full}; + use http_body_util::{BodyExt, Empty}; use hyper::{ body::Incoming, header::{HeaderValue, CONNECTION, UPGRADE}, server::conn::http1, service::service_fn, upgrade::Upgraded, - Method, Request, Response, + Method, Request, Response, Version, + }; + use hyper_util::rt::TokioIo; + use mirrord_protocol::tcp::{ + HttpRequest, InternalHttpBodyFrame, InternalHttpRequest, StreamingBody, }; - use hyper_util::rt::{TokioExecutor, TokioIo}; - use mirrord_protocol::tcp::{HttpRequest, InternalHttpRequest, StreamingBody}; use tokio::{ io::{AsyncReadExt, AsyncWriteExt}, net::TcpListener, @@ -684,19 +478,11 @@ mod test { let interceptor = { let socket = BoundTcpSocket::bind_specified_or_localhost(Ipv4Addr::LOCALHOST.into()).unwrap(); - tasks.register( - Interceptor::new( - socket, - local_destination, - Some(mirrord_protocol::VERSION.clone()), - ), - (), - 8, - ) + tasks.register(Interceptor::new(socket, 
local_destination), (), 8) }; interceptor - .send(HttpRequestFallback::Fallback(HttpRequest { + .send(HttpRequest { connection_id: 0, request_id: 0, port: 80, @@ -712,24 +498,26 @@ mod test { version: Version::HTTP_11, body: Default::default(), }, - })) + }) .await; let (_, update) = tasks.next().await.expect("no task result"); match update { TaskUpdate::Message(MessageOut::Http(res)) => { - let res = res - .into_hyper::() - .expect("failed to convert into hyper response"); - assert_eq!(res.status(), StatusCode::SWITCHING_PROTOCOLS); - println!("{:?}", res.headers()); + assert_eq!( + res.internal_response.status, + StatusCode::SWITCHING_PROTOCOLS + ); + println!("Received repsonse from the interceptor: {res:?}"); assert!(res - .headers() + .internal_response + .headers .get(CONNECTION) .filter(|v| *v == "upgrade") .is_some()); assert!(res - .headers() + .internal_response + .headers .get(UPGRADE) .filter(|v| *v == TEST_PROTO) .is_some()); @@ -759,152 +547,87 @@ mod test { server_task.await.expect("dummy echo server panicked"); } - /// Ensure that [`HttpRequestFallback::Streamed`] are received frame by frame + /// Ensure that body of [`MessageOut::Http`] response is received frame by frame #[tokio::test] async fn receive_request_as_frames() { let listener = TcpListener::bind("127.0.0.1:0").await.unwrap(); - let local_destination = listener.local_addr().unwrap(); let mut tasks: BackgroundTasks<(), MessageOut, InterceptorError> = Default::default(); - let socket = - BoundTcpSocket::bind_specified_or_localhost(Ipv4Addr::LOCALHOST.into()).unwrap(); let interceptor = Interceptor::new( - socket, - local_destination, - Some(mirrord_protocol::VERSION.clone()), + BoundTcpSocket::bind_specified_or_localhost(Ipv4Addr::LOCALHOST.into()).unwrap(), + listener.local_addr().unwrap(), ); - let sender = tasks.register(interceptor, (), 8); - - let (tx, rx) = tokio::sync::mpsc::channel(12); - sender - .send(MessageIn::Http(HttpRequestFallback::Streamed { - request: HttpRequest { - 
internal_request: InternalHttpRequest { - method: Method::POST, - uri: "/".parse().unwrap(), - headers: Default::default(), - version: Version::HTTP_11, - body: StreamingBody::new(rx), - }, - connection_id: 1, - request_id: 2, - port: 3, + let interceptor = tasks.register(interceptor, (), 8); + + let (frame_tx, frame_rx) = tokio::sync::mpsc::channel(1); + interceptor + .send(MessageIn::Http(HttpRequest { + internal_request: InternalHttpRequest { + method: Method::POST, + uri: "/".parse().unwrap(), + headers: Default::default(), + version: Version::HTTP_11, + body: StreamingBody::from(frame_rx), }, - retries: 0, + connection_id: 1, + request_id: 2, + port: 3, })) .await; - let (connection, _peer_addr) = listener.accept().await.unwrap(); + let connection = listener.accept().await.unwrap().0; - let tx = Mutex::new(Some(tx)); let notifier = Arc::new(Notify::default()); - let finished = notifier.notified(); - let service = service_fn(|mut req: Request| { - let tx = tx.lock().unwrap().take().unwrap(); + // Task that sends the next frame when notified. + // Sends two frames, then exits. 
+ tokio::spawn({ let notifier = notifier.clone(); async move { - let x = req.body_mut().frame().now_or_never(); - assert!(x.is_none()); - tx.send(mirrord_protocol::tcp::InternalHttpBodyFrame::Data( - b"string".to_vec(), - )) - .await - .unwrap(); - let x = req - .body_mut() - .frame() - .await - .unwrap() - .unwrap() - .into_data() - .unwrap(); - assert_eq!(x, b"string".to_vec()); - let x = req.body_mut().frame().now_or_never(); - assert!(x.is_none()); - - tx.send(mirrord_protocol::tcp::InternalHttpBodyFrame::Data( - b"another_string".to_vec(), - )) - .await - .unwrap(); - let x = req - .body_mut() - .frame() - .await - .unwrap() - .unwrap() - .into_data() - .unwrap(); - assert_eq!(x, b"another_string".to_vec()); - - drop(tx); - let x = req.body_mut().frame().await; - assert!(x.is_none()); - - notifier.notify_waiters(); - Ok::<_, hyper::Error>(Response::new(Empty::::new())) + for _ in 0..2 { + notifier.notified().await; + if frame_tx + .send(InternalHttpBodyFrame::Data(b"some-data".into())) + .await + .is_err() + { + break; + } + } + + // Wait for third notification before dropping the frame sender. + notifier.notified().await; } }); - let conn = http1::Builder::new().serve_connection(TokioIo::new(connection), service); - tokio::select! { - result = conn => { - result.unwrap() - } - _ = finished => { + let service = service_fn(|mut req: Request| { + let notifier = notifier.clone(); + async move { + for _ in 0..2 { + let frame = req.body_mut().frame().now_or_never(); + assert!(frame.is_none()); + + notifier.notify_one(); + let frame = req + .body_mut() + .frame() + .await + .unwrap() + .unwrap() + .into_data() + .unwrap(); + assert_eq!(frame, b"some-data".to_vec()); + let frame = req.body_mut().frame().now_or_never(); + assert!(frame.is_none()); + } - } - } - } + notifier.notify_one(); + let frame = req.body_mut().frame().await; + assert!(frame.is_none()); - /// Checks that [`hyper`] and [`h2`] crate versions are in sync with each other. 
- /// - /// As we use `source.downcast_ref::` to drill down on [`h2`] errors from - /// [`hyper`], we need these two crates to stay in sync, otherwise we could always - /// fail some of our checks that rely on this `downcast` working. - /// - /// Even though we're using [`h2::Error::is_reset`] in intproxy, this test can be - /// for any error, and thus here we do it for [`h2::Error::is_go_away`] which is - /// easier to trigger. - #[tokio::test] - async fn hyper_and_h2_versions_in_sync() { - let notify = Arc::new(Notify::new()); - let wait_notify = notify.clone(); - - tokio::spawn(async move { - let listener = TcpListener::bind("127.0.0.1:6666").await.unwrap(); - - notify.notify_waiters(); - let (io, _) = listener.accept().await.unwrap(); - - if let Err(fail) = hyper::server::conn::http2::Builder::new(TokioExecutor::default()) - .serve_connection( - TokioIo::new(io), - service_fn(|_| async move { - Ok::<_, Infallible>(Response::new(Full::new(Bytes::from("Heresy!")))) - }), - ) - .await - { - assert!(fail - .source() - .and_then(|source| source.downcast_ref::()) - .is_some_and(h2::Error::is_go_away)); - } else { - panic!( - r"The request is supposed to fail with `GO_AWAY`! - Something is wrong if it didn't! - - >> If you're seeing this error, the cause is likely that `hyper` and `h2` - versions are out of sync, and we can't have that due to our use of - `downcast_ref` on some `h2` errors!" - ); + Ok::<_, Infallible>(Response::new(Empty::::new())) } }); - - // Wait for the listener to be ready for our connection. 
- wait_notify.notified().await; - - assert!(reqwest::get("https://127.0.0.1:6666").await.is_err()); + let conn = http1::Builder::new().serve_connection(TokioIo::new(connection), service); + conn.await.unwrap(); } } diff --git a/mirrord/protocol/src/tcp.rs b/mirrord/protocol/src/tcp.rs index efc3ae98905..4fb5fb98c42 100644 --- a/mirrord/protocol/src/tcp.rs +++ b/mirrord/protocol/src/tcp.rs @@ -21,7 +21,7 @@ use hyper::{ use mirrord_macros::protocol_break; use semver::VersionReq; use serde::{Deserialize, Serialize}; -use tokio::sync::mpsc::Receiver; +use tokio::sync::mpsc::{self, Receiver}; use tokio_stream::wrappers::ReceiverStream; use tracing::{error, Level}; @@ -255,57 +255,8 @@ pub struct InternalHttpRequest { pub body: Body, } -impl From> for Request> -where - E: From, -{ - fn from(value: InternalHttpRequest) -> Self { - let InternalHttpRequest { - method, - uri, - headers, - version, - body, - } = value; - let mut request = Request::new(BoxBody::new(body.map_err(|e| e.into()))); - *request.method_mut() = method; - *request.uri_mut() = uri; - *request.version_mut() = version; - *request.headers_mut() = headers; - - request - } -} - -impl From>> for Request> -where - E: From, -{ - fn from(value: InternalHttpRequest>) -> Self { - let InternalHttpRequest { - method, - uri, - headers, - version, - body, - } = value; - let mut request = Request::new(BoxBody::new( - Full::new(Bytes::from(body)).map_err(|e| e.into()), - )); - *request.method_mut() = method; - *request.uri_mut() = uri; - *request.version_mut() = version; - *request.headers_mut() = headers; - - request - } -} - -impl From> for Request> -where - E: From, -{ - fn from(value: InternalHttpRequest) -> Self { +impl From> for Request { + fn from(value: InternalHttpRequest) -> Self { let InternalHttpRequest { method, uri, @@ -313,7 +264,7 @@ where version, body, } = value; - let mut request = Request::new(BoxBody::new(body.map_err(|e| e.into()))); + let mut request = Request::new(body); *request.method_mut() = 
method; *request.uri_mut() = uri; *request.version_mut() = version; @@ -333,11 +284,16 @@ pub enum HttpRequestFallback { }, } -#[derive(Debug)] +/// [`Body`] implementation that reads [`Frame`]s from an [`mpsc::channel`] and caches them +/// internally in a shared vector. +/// +/// This struct maintains its position in the shared vector. +/// When cloned, it resets the index. This allows for replaying the body even though it is streamed +/// from a channel. pub struct StreamingBody { /// Shared with instances acquired via [`Clone`]. /// Allows the clones to receive a copy of the data. - origin: Arc, Vec)>>, + shared_state: Arc, Vec)>>, /// Index of the next frame to return from the buffer. /// If outside of the buffer, we need to poll the stream to get the next frame. /// Local state of this instance, zeroed when cloning. @@ -345,9 +301,16 @@ pub struct StreamingBody { } impl StreamingBody { - pub fn new(rx: Receiver) -> Self { + /// Creates a new instance of this [`Body`]. + /// + /// It will first read all frames from the vector given as `first_frames`. + /// Following frames will be fetched from the given `rx`. + pub fn new( + rx: Receiver, + first_frames: Vec, + ) -> Self { Self { - origin: Arc::new(Mutex::new((rx, vec![]))), + shared_state: Arc::new(Mutex::new((rx, first_frames))), idx: 0, } } @@ -356,7 +319,8 @@ impl StreamingBody { impl Clone for StreamingBody { fn clone(&self) -> Self { Self { - origin: self.origin.clone(), + shared_state: self.shared_state.clone(), + // Setting idx to 0 in order to replay the previous frames. 
idx: 0, } } @@ -372,7 +336,7 @@ impl Body for StreamingBody { cx: &mut Context<'_>, ) -> Poll, Self::Error>>> { let this = self.get_mut(); - let mut guard = this.origin.lock().unwrap(); + let mut guard = this.shared_state.lock().unwrap(); if let Some(frame) = guard.1.get(this.idx) { this.idx += 1; @@ -390,6 +354,37 @@ impl Body for StreamingBody { } } +impl Default for StreamingBody { + fn default() -> Self { + let (_, dummy_rx) = mpsc::channel(1); // `mpsc::channel` panics on capacity 0 + Self { + shared_state: Arc::new(Mutex::new((dummy_rx, Default::default()))), + idx: 0, + } + } +} + +impl From> for StreamingBody { + fn from(value: Vec) -> Self { + let (_, dummy_rx) = mpsc::channel(1); // `mpsc::channel` panics on capacity 0 + let frames = vec![InternalHttpBodyFrame::Data(value)]; + Self::new(dummy_rx, frames) + } +} + +impl From for StreamingBody { + fn from(value: InternalHttpBody) -> Self { + let (_, dummy_rx) = mpsc::channel(1); // `mpsc::channel` panics on capacity 0 + Self::new(dummy_rx, value.0.into_iter().collect()) + } +} + +impl From> for StreamingBody { + fn from(value: Receiver) -> Self { + Self::new(value, Default::default()) + } +} + impl HttpRequestFallback { pub fn connection_id(&self) -> ConnectionId { match self { @@ -422,16 +417,24 @@ impl HttpRequestFallback { HttpRequestFallback::Streamed { request: req, .. } => req.version(), } } +} - pub fn into_hyper(self) -> Request> - where - E: From, - { - match self { - HttpRequestFallback::Framed(req) => req.internal_request.into(), - HttpRequestFallback::Fallback(req) => req.internal_request.into(), - HttpRequestFallback::Streamed { request: req, .. 
} => req.internal_request.into(), +impl fmt::Debug for StreamingBody { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + let mut s = f.debug_struct("StreamingBody"); + s.field("idx", &self.idx); + + match self.shared_state.try_lock() { + Ok(guard) => { + s.field("frame_rx_closed", &guard.0.is_closed()); + s.field("cached_frames", &guard.1); + } + Err(error) => { + s.field("lock_error", &error); + } } + + s.finish() } } @@ -478,6 +481,21 @@ impl HttpRequest { pub fn version(&self) -> Version { self.internal_request.version } + + pub fn map_body B2>(self, map: F) -> HttpRequest { + HttpRequest { + connection_id: self.connection_id, + request_id: self.request_id, + port: self.port, + internal_request: InternalHttpRequest { + method: self.internal_request.method, + uri: self.internal_request.uri, + headers: self.internal_request.headers, + version: self.internal_request.version, + body: map(self.internal_request.body), + }, + } + } } /// (De-)Serializable HTTP response. @@ -517,7 +535,7 @@ impl InternalHttpResponse { } #[derive(Serialize, Deserialize, Debug, Default, PartialEq, Eq, Clone)] -pub struct InternalHttpBody(VecDeque); +pub struct InternalHttpBody(pub VecDeque); impl InternalHttpBody { pub fn from_bytes(bytes: &[u8]) -> Self { @@ -710,6 +728,22 @@ pub struct HttpResponse { pub internal_response: InternalHttpResponse, } +impl HttpResponse { + pub fn map_body B2>(self, map: F) -> HttpResponse { + HttpResponse { + connection_id: self.connection_id, + request_id: self.request_id, + port: self.port, + internal_response: InternalHttpResponse { + status: self.internal_response.status, + version: self.internal_response.version, + headers: self.internal_response.headers, + body: map(self.internal_response.body), + }, + } + } +} + impl HttpResponse { /// We cannot implement this with the [`From`] trait as it doesn't support `async` conversions, /// and we also need some extra parameters. 
From 6a475856d76da4d56b328ccd05f0fd9d80c1122b Mon Sep 17 00:00:00 2001 From: Razz4780 Date: Tue, 14 Jan 2025 15:50:18 +0100 Subject: [PATCH 06/60] Remove obsolete stuff from mirrord-protocol --- Cargo.lock | 1 - mirrord/agent/src/steal/api.rs | 6 +- mirrord/agent/src/steal/connection.rs | 42 +- mirrord/intproxy/src/proxies/incoming.rs | 5 +- mirrord/intproxy/src/proxies/incoming/http.rs | 7 +- .../src/proxies/incoming/interceptor.rs | 11 +- .../src/proxies/incoming/streaming_body.rs | 132 ++++ mirrord/protocol/Cargo.toml | 1 - mirrord/protocol/src/tcp.rs | 674 +++--------------- 9 files changed, 237 insertions(+), 642 deletions(-) create mode 100644 mirrord/intproxy/src/proxies/incoming/streaming_body.rs diff --git a/Cargo.lock b/Cargo.lock index 466053b1533..9c318674fad 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -4498,7 +4498,6 @@ dependencies = [ "serde", "socket2", "thiserror 2.0.9", - "tokio", "tokio-stream", "tracing", ] diff --git a/mirrord/agent/src/steal/api.rs b/mirrord/agent/src/steal/api.rs index a6ec1d8d1f7..0f610b3e1a1 100644 --- a/mirrord/agent/src/steal/api.rs +++ b/mirrord/agent/src/steal/api.rs @@ -1,4 +1,4 @@ -use std::collections::HashMap; +use std::{collections::HashMap, convert::Infallible}; use bytes::Bytes; use hyper::body::Frame; @@ -40,7 +40,7 @@ pub(crate) struct TcpStealerApi { /// View on the stealer task's status. 
task_status: TaskStatus, - response_body_txs: HashMap<(ConnectionId, RequestId), Sender>>>, + response_body_txs: HashMap<(ConnectionId, RequestId), Sender, Infallible>>>, } impl TcpStealerApi { @@ -196,7 +196,7 @@ impl TcpStealerApi { let key = (response.connection_id, response.request_id); self.response_body_txs.insert(key, tx.clone()); - self.http_response(HttpResponseFallback::Streamed(http_response, None)) + self.http_response(HttpResponseFallback::Streamed(http_response)) .await?; for frame in response.internal_response.body { diff --git a/mirrord/agent/src/steal/connection.rs b/mirrord/agent/src/steal/connection.rs index fa94556c353..4ad9e92816e 100644 --- a/mirrord/agent/src/steal/connection.rs +++ b/mirrord/agent/src/steal/connection.rs @@ -599,40 +599,16 @@ impl TcpConnectionStealer { async fn send_http_response(&mut self, client_id: ClientId, response: HttpResponseFallback) { let connection_id = response.connection_id(); let request_id = response.request_id(); - - match response.into_hyper::() { - Ok(response) => { - self.connections - .send( - connection_id, - ConnectionMessageIn::Response { - client_id, - request_id, - response, - }, - ) - .await; - } - Err(error) => { - tracing::warn!( - ?error, - connection_id, - request_id, + self.connections + .send( + connection_id, + ConnectionMessageIn::Response { client_id, - "Failed to transform client message into a hyper response", - ); - - self.connections - .send( - connection_id, - ConnectionMessageIn::ResponseFailed { - client_id, - request_id, - }, - ) - .await; - } - } + request_id, + response: response.into_hyper::(), + }, + ) + .await; } /// Handles [`Command`]s that were received by [`TcpConnectionStealer::command_rx`]. 
diff --git a/mirrord/intproxy/src/proxies/incoming.rs b/mirrord/intproxy/src/proxies/incoming.rs index 74175e2eaeb..c83e8ff5387 100644 --- a/mirrord/intproxy/src/proxies/incoming.rs +++ b/mirrord/intproxy/src/proxies/incoming.rs @@ -20,11 +20,11 @@ use mirrord_protocol::{ tcp::{ ChunkedHttpBody, ChunkedHttpError, ChunkedRequest, ChunkedResponse, DaemonTcp, HttpResponse, InternalHttpBody, InternalHttpBodyFrame, LayerTcp, LayerTcpSteal, - NewTcpConnection, StreamingBody, TcpData, HTTP_CHUNKED_RESPONSE_VERSION, - HTTP_FRAMED_VERSION, + NewTcpConnection, TcpData, HTTP_CHUNKED_RESPONSE_VERSION, HTTP_FRAMED_VERSION, }, ClientMessage, ConnectionId, Port, RequestId, ResponseError, }; +use streaming_body::StreamingBody; use thiserror::Error; use tokio::sync::mpsc::{self, Sender}; use tokio_stream::{StreamMap, StreamNotifyClose}; @@ -46,6 +46,7 @@ mod http; mod interceptor; mod metadata_store; pub mod port_subscription_ext; +mod streaming_body; mod subscriptions; /// Id of a single [`Interceptor`] task. Used to manage interceptor tasks with the diff --git a/mirrord/intproxy/src/proxies/incoming/http.rs b/mirrord/intproxy/src/proxies/incoming/http.rs index a2600d57203..dd8f2874ed7 100644 --- a/mirrord/intproxy/src/proxies/incoming/http.rs +++ b/mirrord/intproxy/src/proxies/incoming/http.rs @@ -8,15 +8,12 @@ use hyper::{ Request, Response, Version, }; use hyper_util::rt::{TokioExecutor, TokioIo}; -use mirrord_protocol::{ - batched_body::BatchedBody, - tcp::{HttpRequest, StreamingBody}, -}; +use mirrord_protocol::{batched_body::BatchedBody, tcp::HttpRequest}; use thiserror::Error; use tokio::{net::TcpStream, time}; use tracing::Level; -use super::bound_socket::BoundTcpSocket; +use super::{bound_socket::BoundTcpSocket, streaming_body::StreamingBody}; /// A retrying HTTP client used to pass requests to the user application. 
pub struct LocalHttpClient { diff --git a/mirrord/intproxy/src/proxies/incoming/interceptor.rs b/mirrord/intproxy/src/proxies/incoming/interceptor.rs index 6000199186e..7a14e7daf1f 100644 --- a/mirrord/intproxy/src/proxies/incoming/interceptor.rs +++ b/mirrord/intproxy/src/proxies/incoming/interceptor.rs @@ -10,7 +10,7 @@ use std::{ use bytes::{Bytes, BytesMut}; use hyper::{body::Frame, upgrade::OnUpgrade, Response, StatusCode}; use hyper_util::rt::TokioIo; -use mirrord_protocol::tcp::{HttpRequest, HttpResponse, InternalHttpResponse, StreamingBody}; +use mirrord_protocol::tcp::{HttpRequest, HttpResponse, InternalHttpResponse}; use thiserror::Error; use tokio::{ io::{AsyncReadExt, AsyncWriteExt}, @@ -19,7 +19,10 @@ use tokio::{ }; use tracing::Level; -use super::http::{LocalHttpClient, LocalHttpError, PeekedBody}; +use super::{ + http::{LocalHttpClient, LocalHttpError, PeekedBody}, + streaming_body::StreamingBody, +}; use crate::{ background_tasks::{BackgroundTask, MessageBus, PeekableMessageBus}, proxies::incoming::bound_socket::BoundTcpSocket, @@ -362,9 +365,7 @@ mod test { Method, Request, Response, Version, }; use hyper_util::rt::TokioIo; - use mirrord_protocol::tcp::{ - HttpRequest, InternalHttpBodyFrame, InternalHttpRequest, StreamingBody, - }; + use mirrord_protocol::tcp::{HttpRequest, InternalHttpBodyFrame, InternalHttpRequest}; use tokio::{ io::{AsyncReadExt, AsyncWriteExt}, net::TcpListener, diff --git a/mirrord/intproxy/src/proxies/incoming/streaming_body.rs b/mirrord/intproxy/src/proxies/incoming/streaming_body.rs new file mode 100644 index 00000000000..f23adaed761 --- /dev/null +++ b/mirrord/intproxy/src/proxies/incoming/streaming_body.rs @@ -0,0 +1,132 @@ +use std::{ + convert::Infallible, + fmt, + pin::Pin, + sync::{Arc, Mutex}, + task::{Context, Poll}, +}; + +use bytes::Bytes; +use hyper::body::{Body, Frame}; +use mirrord_protocol::tcp::{InternalHttpBody, InternalHttpBodyFrame}; +use tokio::sync::mpsc::{self, Receiver}; + +/// [`Body`] implementation 
that reads [`Frame`]s from an [`mpsc::channel`] and caches them +/// internally in a shared vector. +/// +/// This struct maintains its position in the shared vector. +/// When cloned, it resets the index. This allows for replaying the body even though it is streamed +/// from a channel. +pub struct StreamingBody { + /// Shared with instances acquired via [`Clone`]. + /// Allows the clones to receive a copy of the data. + shared_state: Arc, Vec)>>, + /// Index of the next frame to return from the buffer. + /// If outside of the buffer, we need to poll the stream to get the next frame. + /// Local state of this instance, zeroed when cloning. + idx: usize, +} + +impl StreamingBody { + /// Creates a new instance of this [`Body`]. + /// + /// It will first read all frames from the vector given as `first_frames`. + /// Following frames will be fetched from the given `rx`. + pub fn new( + rx: Receiver, + first_frames: Vec, + ) -> Self { + Self { + shared_state: Arc::new(Mutex::new((rx, first_frames))), + idx: 0, + } + } +} + +impl Clone for StreamingBody { + fn clone(&self) -> Self { + Self { + shared_state: self.shared_state.clone(), + // Setting idx to 0 in order to replay the previous frames. 
+ idx: 0, + } + } +} + +impl Body for StreamingBody { + type Data = Bytes; + + type Error = Infallible; + + fn poll_frame( + self: Pin<&mut Self>, + cx: &mut Context<'_>, + ) -> Poll, Self::Error>>> { + let this = self.get_mut(); + let mut guard = this.shared_state.lock().unwrap(); + + if let Some(frame) = guard.1.get(this.idx) { + this.idx += 1; + return Poll::Ready(Some(Ok(frame.clone().into()))); + } + + match std::task::ready!(guard.0.poll_recv(cx)) { + None => Poll::Ready(None), + Some(frame) => { + guard.1.push(frame.clone()); + this.idx += 1; + Poll::Ready(Some(Ok(frame.into()))) + } + } + } +} + +impl Default for StreamingBody { + fn default() -> Self { + let (_, dummy_rx) = mpsc::channel(1); // `mpsc::channel` panics on capacity 0 + Self { + shared_state: Arc::new(Mutex::new((dummy_rx, Default::default()))), + idx: 0, + } + } +} + +impl From> for StreamingBody { + fn from(value: Vec) -> Self { + let (_, dummy_rx) = mpsc::channel(1); // `mpsc::channel` panics on capacity 0 + let frames = vec![InternalHttpBodyFrame::Data(value)]; + Self::new(dummy_rx, frames) + } +} + +impl From for StreamingBody { + fn from(value: InternalHttpBody) -> Self { + let (_, dummy_rx) = mpsc::channel(1); // `mpsc::channel` panics on capacity 0 + Self::new(dummy_rx, value.0.into_iter().collect()) + } +} + +impl From> for StreamingBody { + fn from(value: Receiver) -> Self { + Self::new(value, Default::default()) + } +} + +impl fmt::Debug for StreamingBody { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + let mut s = f.debug_struct("StreamingBody"); + s.field("idx", &self.idx); + + match self.shared_state.try_lock() { + Ok(guard) => { + s.field("frame_rx_closed", &guard.0.is_closed()); + s.field("cached_frames", &guard.1); + } + Err(error) => { + s.field("lock_error", &error); + } + } + + s.finish() + } +} diff --git a/mirrord/protocol/Cargo.toml b/mirrord/protocol/Cargo.toml index 8f56636c15e..a7f578d482f 100644 --- a/mirrord/protocol/Cargo.toml +++ 
b/mirrord/protocol/Cargo.toml @@ -33,7 +33,6 @@ fancy-regex = { workspace = true } socket2.workspace = true semver = { workspace = true, features = ["serde"] } tokio-stream.workspace = true -tokio.workspace = true mirrord-macros = { path = "../macros" } diff --git a/mirrord/protocol/src/tcp.rs b/mirrord/protocol/src/tcp.rs index 4fb5fb98c42..32ca5fef5f9 100644 --- a/mirrord/protocol/src/tcp.rs +++ b/mirrord/protocol/src/tcp.rs @@ -5,7 +5,7 @@ use std::{ fmt, net::IpAddr, pin::Pin, - sync::{Arc, LazyLock, Mutex}, + sync::LazyLock, task::{Context, Poll}, }; @@ -13,19 +13,15 @@ use bincode::{Decode, Encode}; use bytes::Bytes; use http_body_util::{combinators::BoxBody, BodyExt, Full, StreamBody}; use hyper::{ - body::{Body, Frame, Incoming}, - http, - http::response::Parts, + body::{Body, Frame}, HeaderMap, Method, Request, Response, StatusCode, Uri, Version, }; use mirrord_macros::protocol_break; use semver::VersionReq; use serde::{Deserialize, Serialize}; -use tokio::sync::mpsc::{self, Receiver}; use tokio_stream::wrappers::ReceiverStream; -use tracing::{error, Level}; -use crate::{batched_body::BatchedBody, ConnectionId, Port, RemoteResult, RequestId}; +use crate::{ConnectionId, Port, RemoteResult, RequestId}; #[derive(Encode, Decode, Debug, PartialEq, Eq, Clone)] pub struct NewTcpConnection { @@ -120,7 +116,7 @@ pub struct Filter(String); impl Filter { pub fn new(filter_str: String) -> Result> { let _ = fancy_regex::Regex::new(&filter_str).inspect_err(|fail| { - error!( + tracing::error!( r" Something went wrong while creating a regex for [{filter_str:#?}]! @@ -239,7 +235,7 @@ pub enum ChunkedResponse { /// (De-)Serializable HTTP request. 
#[derive(Serialize, Deserialize, PartialEq, Debug, Eq, Clone)] -pub struct InternalHttpRequest { +pub struct InternalHttpRequest { #[serde(with = "http_serde::method")] pub method: Method, @@ -252,11 +248,34 @@ pub struct InternalHttpRequest { #[serde(with = "http_serde::version")] pub version: Version, - pub body: Body, + pub body: B, +} + +impl InternalHttpRequest { + pub fn map_body(self, cb: F) -> InternalHttpRequest + where + F: FnOnce(B) -> T, + { + let InternalHttpRequest { + version, + headers, + method, + uri, + body, + } = self; + + InternalHttpRequest { + version, + headers, + method, + uri, + body: cb(body), + } + } } -impl From> for Request { - fn from(value: InternalHttpRequest) -> Self { +impl From> for Request { + fn from(value: InternalHttpRequest) -> Self { let InternalHttpRequest { method, uri, @@ -274,170 +293,6 @@ impl From> for Request { } } -#[derive(Clone, Debug)] -pub enum HttpRequestFallback { - Framed(HttpRequest), - Fallback(HttpRequest>), - Streamed { - request: HttpRequest, - retries: u32, - }, -} - -/// [`Body`] implementation that reads [`Frame`]s from an [`mpsc::channel`] and caches them -/// internally in a shared vector. -/// -/// This struct maintains its position in the shared vector. -/// When cloned, it resets the index. This allows for replaying the body even though it is streamed -/// from a channel. -pub struct StreamingBody { - /// Shared with instances acquired via [`Clone`]. - /// Allows the clones to receive a copy of the data. - shared_state: Arc, Vec)>>, - /// Index of the next frame to return from the buffer. - /// If outside of the buffer, we need to poll the stream to get the next frame. - /// Local state of this instance, zeroed when cloning. - idx: usize, -} - -impl StreamingBody { - /// Creates a new instance of this [`Body`]. - /// - /// It will first read all frames from the vector given as `first_frames`. - /// Following frames will be fetched from the given `rx`. 
- pub fn new( - rx: Receiver, - first_frames: Vec, - ) -> Self { - Self { - shared_state: Arc::new(Mutex::new((rx, first_frames))), - idx: 0, - } - } -} - -impl Clone for StreamingBody { - fn clone(&self) -> Self { - Self { - shared_state: self.shared_state.clone(), - // Setting idx to 0 in order to replay the previous frames. - idx: 0, - } - } -} - -impl Body for StreamingBody { - type Data = Bytes; - - type Error = Infallible; - - fn poll_frame( - self: Pin<&mut Self>, - cx: &mut Context<'_>, - ) -> Poll, Self::Error>>> { - let this = self.get_mut(); - let mut guard = this.shared_state.lock().unwrap(); - - if let Some(frame) = guard.1.get(this.idx) { - this.idx += 1; - return Poll::Ready(Some(Ok(frame.clone().into()))); - } - - match std::task::ready!(guard.0.poll_recv(cx)) { - None => Poll::Ready(None), - Some(frame) => { - guard.1.push(frame.clone()); - this.idx += 1; - Poll::Ready(Some(Ok(frame.into()))) - } - } - } -} - -impl Default for StreamingBody { - fn default() -> Self { - let (_, dummy_rx) = mpsc::channel(1); // `mpsc::channel` panics on capacity 0 - Self { - shared_state: Arc::new(Mutex::new((dummy_rx, Default::default()))), - idx: 0, - } - } -} - -impl From> for StreamingBody { - fn from(value: Vec) -> Self { - let (_, dummy_rx) = mpsc::channel(1); // `mpsc::channel` panics on capacity 0 - let frames = vec![InternalHttpBodyFrame::Data(value)]; - Self::new(dummy_rx, frames) - } -} - -impl From for StreamingBody { - fn from(value: InternalHttpBody) -> Self { - let (_, dummy_rx) = mpsc::channel(1); // `mpsc::channel` panics on capacity 0 - Self::new(dummy_rx, value.0.into_iter().collect()) - } -} - -impl From> for StreamingBody { - fn from(value: Receiver) -> Self { - Self::new(value, Default::default()) - } -} - -impl HttpRequestFallback { - pub fn connection_id(&self) -> ConnectionId { - match self { - HttpRequestFallback::Framed(req) => req.connection_id, - HttpRequestFallback::Fallback(req) => req.connection_id, - HttpRequestFallback::Streamed { 
request: req, .. } => req.connection_id, - } - } - - pub fn port(&self) -> Port { - match self { - HttpRequestFallback::Framed(req) => req.port, - HttpRequestFallback::Fallback(req) => req.port, - HttpRequestFallback::Streamed { request: req, .. } => req.port, - } - } - - pub fn request_id(&self) -> RequestId { - match self { - HttpRequestFallback::Framed(req) => req.request_id, - HttpRequestFallback::Fallback(req) => req.request_id, - HttpRequestFallback::Streamed { request: req, .. } => req.request_id, - } - } - - pub fn version(&self) -> Version { - match self { - HttpRequestFallback::Framed(req) => req.version(), - HttpRequestFallback::Fallback(req) => req.version(), - HttpRequestFallback::Streamed { request: req, .. } => req.version(), - } - } -} - -impl fmt::Debug for StreamingBody { - fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { - let mut s = f.debug_struct("StreamingBody"); - s.field("idx", &self.idx); - - match self.shared_state.try_lock() { - Ok(guard) => { - s.field("frame_rx_closed", &guard.0.is_closed()); - s.field("cached_frames", &guard.1); - } - Err(error) => { - s.field("lock_error", &error); - } - } - - s.finish() - } -} - /// Minimal mirrord-protocol version that allows [`DaemonTcp::HttpRequestFramed`] and /// [`LayerTcpSteal::HttpResponseFramed`]. 
pub static HTTP_FRAMED_VERSION: LazyLock = @@ -482,7 +337,10 @@ impl HttpRequest { self.internal_request.version } - pub fn map_body B2>(self, map: F) -> HttpRequest { + pub fn map_body(self, map: F) -> HttpRequest + where + F: FnOnce(B) -> T, + { HttpRequest { connection_id: self.connection_id, request_id: self.request_id, @@ -538,12 +396,6 @@ impl InternalHttpResponse { pub struct InternalHttpBody(pub VecDeque); impl InternalHttpBody { - pub fn from_bytes(bytes: &[u8]) -> Self { - InternalHttpBody(VecDeque::from([InternalHttpBodyFrame::Data( - bytes.to_vec(), - )])) - } - pub async fn from_body(mut body: B) -> Result where B: Body + Unpin, @@ -583,15 +435,11 @@ pub enum InternalHttpBodyFrame { impl From> for InternalHttpBodyFrame { fn from(frame: Frame) -> Self { - if frame.is_data() { - InternalHttpBodyFrame::Data(frame.into_data().expect("Malfromed data frame").to_vec()) - } else if frame.is_trailers() { - InternalHttpBodyFrame::Trailers( - frame.into_trailers().expect("Malfromed trailers frame"), - ) - } else { - panic!("Malfromed frame type") - } + frame + .into_data() + .map(|bytes| Self::Data(bytes.into())) + .or_else(|frame| frame.into_trailers().map(Self::Trailers)) + .expect("malformed frame type") } } @@ -609,22 +457,13 @@ impl fmt::Debug for InternalHttpBodyFrame { } } -pub type ReceiverStreamBody = StreamBody>>>; +pub type ReceiverStreamBody = StreamBody, Infallible>>>; #[derive(Debug)] pub enum HttpResponseFallback { Framed(HttpResponse), Fallback(HttpResponse>), - - /// Holds the [`HttpResponse`] that we're supposed to send back to the agent. - /// - /// It also holds the original http request [`HttpRequestFallback`], so we can retry - /// if our hyper server sent us a - /// [`RST_STREAM`](https://docs.rs/h2/latest/h2/struct.Error.html#method.is_reset). 
- Streamed( - HttpResponse, - Option, - ), + Streamed(HttpResponse), } impl HttpResponseFallback { @@ -632,7 +471,7 @@ impl HttpResponseFallback { match self { HttpResponseFallback::Framed(req) => req.connection_id, HttpResponseFallback::Fallback(req) => req.connection_id, - HttpResponseFallback::Streamed(req, _) => req.connection_id, + HttpResponseFallback::Streamed(req) => req.connection_id, } } @@ -640,391 +479,47 @@ impl HttpResponseFallback { match self { HttpResponseFallback::Framed(req) => req.request_id, HttpResponseFallback::Fallback(req) => req.request_id, - HttpResponseFallback::Streamed(req, _) => req.request_id, + HttpResponseFallback::Streamed(req) => req.request_id, } } - #[tracing::instrument(level = Level::TRACE, err(level = Level::WARN))] - pub fn into_hyper(self) -> Result>, http::Error> - where - E: From, - { + pub fn into_hyper(self) -> Response> { match self { - HttpResponseFallback::Framed(req) => req.internal_response.try_into(), - HttpResponseFallback::Fallback(req) => req.internal_response.try_into(), - HttpResponseFallback::Streamed(req, _) => req.internal_response.try_into(), - } - } - - /// Produces an [`HttpResponseFallback`] to the given [`HttpRequestFallback`]. - /// - /// # Note on picking response variant - /// - /// Variant of returned [`HttpResponseFallback`] is picked based on the variant of given - /// [`HttpRequestFallback`] and agent protocol version. We need to consider both due - /// to: - /// 1. Old agent versions always responding with client's `mirrord_protocol` version to - /// [`ClientMessage::SwitchProtocolVersion`](super::ClientMessage::SwitchProtocolVersion), - /// 2. [`LayerTcpSteal::HttpResponseChunked`] being introduced after - /// [`DaemonTcp::HttpRequestChunked`]. 
- pub fn response_from_request( - request: HttpRequestFallback, - status: StatusCode, - message: &str, - agent_protocol_version: Option<&semver::Version>, - ) -> Self { - let agent_supports_streaming_response = agent_protocol_version - .map(|version| HTTP_CHUNKED_RESPONSE_VERSION.matches(version)) - .unwrap_or(false); - - match request.clone() { - // We received `DaemonTcp::HttpRequestFramed` from the agent, - // so we know it supports `LayerTcpSteal::HttpResponseFramed` (both were introduced in - // the same `mirrord_protocol` version). - HttpRequestFallback::Framed(request) => HttpResponseFallback::Framed( - HttpResponse::::response_from_request(request, status, message), - ), - - // We received `DaemonTcp::HttpRequest` from the agent, so we assume it only supports - // `LayerTcpSteal::HttpResponse`. - HttpRequestFallback::Fallback(request) => HttpResponseFallback::Fallback( - HttpResponse::>::response_from_request(request, status, message), - ), - - // We received `DaemonTcp::HttpRequestChunked` and the agent supports - // `LayerTcpSteal::HttpResponseChunked`. - HttpRequestFallback::Streamed { - request: streamed_request, - .. - } if agent_supports_streaming_response => HttpResponseFallback::Streamed( - HttpResponse::::response_from_request( - streamed_request, - status, - message, - ), - Some(request), - ), - - // We received `DaemonTcp::HttpRequestChunked` from the agent, - // but the agent does not support `LayerTcpSteal::HttpResponseChunked`. - // However, it must support the older `LayerTcpSteal::HttpResponseFramed` - // variant (was introduced before `DaemonTcp::HttpRequestChunked`). - HttpRequestFallback::Streamed { request, .. 
} => HttpResponseFallback::Framed( - HttpResponse::::response_from_request(request, status, message), - ), + HttpResponseFallback::Framed(req) => req.internal_response.into(), + HttpResponseFallback::Fallback(req) => req.internal_response.into(), + HttpResponseFallback::Streamed(req) => req.internal_response.into(), } } } #[derive(Encode, Decode, Debug, PartialEq, Eq, Clone)] -#[bincode(bounds = "for<'de> Body: Serialize + Deserialize<'de>")] -pub struct HttpResponse { +#[bincode(bounds = "for<'de> B: Serialize + Deserialize<'de>")] +pub struct HttpResponse { /// This is used to make sure the response is sent in its turn, after responses to all earlier /// requests were already sent. pub port: Port, pub connection_id: ConnectionId, pub request_id: RequestId, #[bincode(with_serde)] - pub internal_response: InternalHttpResponse, + pub internal_response: InternalHttpResponse, } -impl HttpResponse { - pub fn map_body B2>(self, map: F) -> HttpResponse { +impl HttpResponse { + pub fn map_body(self, cb: F) -> HttpResponse + where + F: FnOnce(B) -> T, + { HttpResponse { connection_id: self.connection_id, request_id: self.request_id, port: self.port, - internal_response: InternalHttpResponse { - status: self.internal_response.status, - version: self.internal_response.version, - headers: self.internal_response.headers, - body: map(self.internal_response.body), - }, - } - } -} - -impl HttpResponse { - /// We cannot implement this with the [`From`] trait as it doesn't support `async` conversions, - /// and we also need some extra parameters. - /// - /// So this is our alternative implementation to `From>`. - #[tracing::instrument(level = Level::TRACE, err(level = Level::WARN))] - pub async fn from_hyper_response( - response: Response, - port: Port, - connection_id: ConnectionId, - request_id: RequestId, - ) -> Result, hyper::Error> { - let ( - Parts { - status, - version, - headers, - .. 
- }, - body, - ) = response.into_parts(); - - let body = InternalHttpBody::from_body(body).await?; - - let internal_response = InternalHttpResponse { - status, - headers, - version, - body, - }; - - Ok(HttpResponse { - request_id, - port, - connection_id, - internal_response, - }) - } - - pub fn response_from_request( - request: HttpRequest, - status: StatusCode, - message: &str, - ) -> Self { - let HttpRequest { - internal_request: InternalHttpRequest { version, .. }, - connection_id, - request_id, - port, - } = request; - - let body = InternalHttpBody::from_bytes( - format!( - "{} {}\n{}\n", - status.as_str(), - status.canonical_reason().unwrap_or_default(), - message - ) - .as_bytes(), - ); - - Self { - port, - connection_id, - request_id, - internal_response: InternalHttpResponse { - status, - version, - headers: Default::default(), - body, - }, - } - } - - pub fn empty_response_from_request( - request: HttpRequest, - status: StatusCode, - ) -> Self { - let HttpRequest { - internal_request: InternalHttpRequest { version, .. }, - connection_id, - request_id, - port, - } = request; - - Self { - port, - connection_id, - request_id, - internal_response: InternalHttpResponse { - status, - version, - headers: Default::default(), - body: Default::default(), - }, + internal_response: self.internal_response.map_body(cb), } } } -impl HttpResponse> { - /// We cannot implement this with the [`From`] trait as it doesn't support `async` conversions, - /// and we also need some extra parameters. - /// - /// So this is our alternative implementation to `From>`. - #[tracing::instrument(level = Level::TRACE, err(level = Level::WARN))] - pub async fn from_hyper_response( - response: Response, - port: Port, - connection_id: ConnectionId, - request_id: RequestId, - ) -> Result>, hyper::Error> { - let ( - Parts { - status, - version, - headers, - .. 
- }, - body, - ) = response.into_parts(); - - let body = body.collect().await?.to_bytes().to_vec(); - - let internal_response = InternalHttpResponse { - status, - headers, - version, - body, - }; - - Ok(HttpResponse { - request_id, - port, - connection_id, - internal_response, - }) - } - - pub fn response_from_request( - request: HttpRequest>, - status: StatusCode, - message: &str, - ) -> Self { - let HttpRequest { - internal_request: InternalHttpRequest { version, .. }, - connection_id, - request_id, - port, - } = request; - - let body = format!( - "{} {}\n{}\n", - status.as_str(), - status.canonical_reason().unwrap_or_default(), - message - ) - .into_bytes(); - - Self { - port, - connection_id, - request_id, - internal_response: InternalHttpResponse { - status, - version, - headers: Default::default(), - body, - }, - } - } - - pub fn empty_response_from_request(request: HttpRequest>, status: StatusCode) -> Self { - let HttpRequest { - internal_request: InternalHttpRequest { version, .. }, - connection_id, - request_id, - port, - } = request; - - Self { - port, - connection_id, - request_id, - internal_response: InternalHttpResponse { - status, - version, - headers: Default::default(), - body: Default::default(), - }, - } - } -} - -impl HttpResponse { - #[tracing::instrument(level = Level::TRACE, err(level = Level::WARN))] - pub async fn from_hyper_response( - response: Response, - port: Port, - connection_id: ConnectionId, - request_id: RequestId, - ) -> Result, hyper::Error> { - let ( - Parts { - status, - version, - headers, - .. 
- }, - mut body, - ) = response.into_parts(); - - let frames = body.ready_frames()?; - let (tx, rx) = tokio::sync::mpsc::channel(frames.frames.len().max(12)); - for frame in frames.frames { - tx.try_send(Ok(frame)) - .expect("Channel is open, capacity sufficient") - } - if !frames.is_last { - tokio::spawn(async move { - while let Some(frame) = body.frame().await { - if tx.send(frame).await.is_err() { - return; - } - } - }); - }; - - let body = StreamBody::new(ReceiverStream::from(rx)); - - let internal_response = InternalHttpResponse { - status, - headers, - version, - body, - }; - - Ok(HttpResponse { - request_id, - port, - connection_id, - internal_response, - }) - } - - #[tracing::instrument(level = Level::TRACE, ret)] - pub fn response_from_request( - request: HttpRequest, - status: StatusCode, - message: &str, - ) -> Self { - let HttpRequest { - internal_request: InternalHttpRequest { version, .. }, - connection_id, - request_id, - port, - } = request; - - let (tx, rx) = tokio::sync::mpsc::channel(1); - let frame = Frame::data(Bytes::copy_from_slice(message.as_bytes())); - tx.try_send(Ok(frame)) - .expect("channel is open, capacity is sufficient"); - let body = StreamBody::new(ReceiverStream::new(rx)); - - Self { - port, - connection_id, - request_id, - internal_response: InternalHttpResponse { - status, - version, - headers: Default::default(), - body, - }, - } - } -} - -impl TryFrom> for Response> { - type Error = http::Error; - - fn try_from(value: InternalHttpResponse) -> Result { +impl From>> for Response> { + fn from(value: InternalHttpResponse>) -> Self { let InternalHttpResponse { status, version, @@ -1032,19 +527,21 @@ impl TryFrom> for Response TryFrom>> for Response> { - type Error = http::Error; - - fn try_from(value: InternalHttpResponse>) -> Result { +impl From> for Response> { + fn from(value: InternalHttpResponse) -> Self { let InternalHttpResponse { status, version, @@ -1052,24 +549,17 @@ impl TryFrom>> for Response> { body, } = value; - let 
mut builder = Response::builder().status(status).version(version); - if let Some(h) = builder.headers_mut() { - *h = headers; - } + let mut response = Response::new(body.map_err(|_| unreachable!()).boxed()); + *response.status_mut() = status; + *response.version_mut() = version; + *response.headers_mut() = headers; - builder.body(BoxBody::new( - Full::new(Bytes::from(body)).map_err(|_| unreachable!()), - )) + response } } -impl TryFrom> for Response> -where - E: From, -{ - type Error = http::Error; - - fn try_from(value: InternalHttpResponse) -> Result { +impl From> for Response> { + fn from(value: InternalHttpResponse) -> Self { let InternalHttpResponse { status, version, @@ -1077,11 +567,11 @@ where body, } = value; - let mut builder = Response::builder().status(status).version(version); - if let Some(h) = builder.headers_mut() { - *h = headers; - } + let mut response = Response::new(body.map_err(|_| unreachable!()).boxed()); + *response.status_mut() = status; + *response.version_mut() = version; + *response.headers_mut() = headers; - builder.body(BoxBody::new(body.map_err(|e| e.into()))) + response } } From 4b79bf737f603d5908ffe2add5f7609ee0c1dfb1 Mon Sep 17 00:00:00 2001 From: Razz4780 Date: Tue, 14 Jan 2025 16:07:14 +0100 Subject: [PATCH 07/60] Moved HttpResponseFallback to the agent --- Cargo.lock | 1 - mirrord/agent/src/steal.rs | 4 +- mirrord/agent/src/steal/api.rs | 7 +- mirrord/agent/src/steal/connection.rs | 6 +- mirrord/agent/src/steal/http.rs | 7 +- .../agent/src/steal/http/response_fallback.rs | 58 +++++++++ mirrord/protocol/Cargo.toml | 1 - mirrord/protocol/src/tcp.rs | 115 +++--------------- 8 files changed, 90 insertions(+), 109 deletions(-) create mode 100644 mirrord/agent/src/steal/http/response_fallback.rs diff --git a/Cargo.lock b/Cargo.lock index 9c318674fad..61932caa8e4 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -4498,7 +4498,6 @@ dependencies = [ "serde", "socket2", "thiserror 2.0.9", - "tokio-stream", "tracing", ] diff --git 
a/mirrord/agent/src/steal.rs b/mirrord/agent/src/steal.rs index 399c0597d4e..a425748a0d8 100644 --- a/mirrord/agent/src/steal.rs +++ b/mirrord/agent/src/steal.rs @@ -1,5 +1,5 @@ use mirrord_protocol::{ - tcp::{DaemonTcp, HttpResponseFallback, StealType, TcpData}, + tcp::{DaemonTcp, StealType, TcpData}, ConnectionId, Port, }; use tokio::sync::mpsc::Sender; @@ -17,6 +17,8 @@ mod subscriptions; pub(crate) use api::TcpStealerApi; pub(crate) use connection::TcpConnectionStealer; +use self::http::HttpResponseFallback; + /// Commands from the agent that are passed down to the stealer worker, through [`TcpStealerApi`]. /// /// These are the operations that the agent receives from the layer to make the _steal_ feature diff --git a/mirrord/agent/src/steal/api.rs b/mirrord/agent/src/steal/api.rs index 0f610b3e1a1..efc9cfd4e33 100644 --- a/mirrord/agent/src/steal/api.rs +++ b/mirrord/agent/src/steal/api.rs @@ -3,16 +3,13 @@ use std::{collections::HashMap, convert::Infallible}; use bytes::Bytes; use hyper::body::Frame; use mirrord_protocol::{ - tcp::{ - ChunkedResponse, DaemonTcp, HttpResponse, HttpResponseFallback, InternalHttpResponse, - LayerTcpSteal, ReceiverStreamBody, TcpData, - }, + tcp::{ChunkedResponse, DaemonTcp, HttpResponse, InternalHttpResponse, LayerTcpSteal, TcpData}, RequestId, }; use tokio::sync::mpsc::{self, Receiver, Sender}; use tokio_stream::wrappers::ReceiverStream; -use super::*; +use super::{http::ReceiverStreamBody, *}; use crate::{ error::{AgentError, Result}, util::ClientId, diff --git a/mirrord/agent/src/steal/connection.rs b/mirrord/agent/src/steal/connection.rs index 4ad9e92816e..5e4b6b1219a 100644 --- a/mirrord/agent/src/steal/connection.rs +++ b/mirrord/agent/src/steal/connection.rs @@ -15,9 +15,8 @@ use mirrord_protocol::{ batched_body::{BatchedBody, Frames}, tcp::{ ChunkedHttpBody, ChunkedHttpError, ChunkedRequest, DaemonTcp, HttpRequest, - HttpResponseFallback, InternalHttpBody, InternalHttpBodyFrame, InternalHttpRequest, - StealType, 
TcpClose, TcpData, HTTP_CHUNKED_REQUEST_VERSION, HTTP_FILTERED_UPGRADE_VERSION, - HTTP_FRAMED_VERSION, + InternalHttpBody, InternalHttpBodyFrame, InternalHttpRequest, StealType, TcpClose, TcpData, + HTTP_CHUNKED_REQUEST_VERSION, HTTP_FILTERED_UPGRADE_VERSION, HTTP_FRAMED_VERSION, }, ConnectionId, Port, RemoteError::{BadHttpFilterExRegex, BadHttpFilterRegex}, @@ -31,6 +30,7 @@ use tokio::{ use tokio_util::sync::CancellationToken; use tracing::warn; +use super::http::HttpResponseFallback; use crate::{ error::{AgentError, Result}, steal::{ diff --git a/mirrord/agent/src/steal/http.rs b/mirrord/agent/src/steal/http.rs index 159d9c9aac8..cad0308bc96 100644 --- a/mirrord/agent/src/steal/http.rs +++ b/mirrord/agent/src/steal/http.rs @@ -3,11 +3,12 @@ use crate::http::HttpVersion; mod filter; +mod response_fallback; mod reversible_stream; -pub use filter::HttpFilter; - -pub(crate) use self::reversible_stream::ReversibleStream; +pub(crate) use filter::HttpFilter; +pub(crate) use response_fallback::{HttpResponseFallback, ReceiverStreamBody}; +pub(crate) use reversible_stream::ReversibleStream; /// Handy alias due to [`ReversibleStream`] being generic, avoiding value mismatches. 
pub(crate) type DefaultReversibleStream = ReversibleStream<{ HttpVersion::MINIMAL_HEADER_SIZE }>; diff --git a/mirrord/agent/src/steal/http/response_fallback.rs b/mirrord/agent/src/steal/http/response_fallback.rs new file mode 100644 index 00000000000..2124ec41a57 --- /dev/null +++ b/mirrord/agent/src/steal/http/response_fallback.rs @@ -0,0 +1,58 @@ +use std::convert::Infallible; + +use bytes::Bytes; +use http_body_util::{combinators::BoxBody, BodyExt, Full, StreamBody}; +use hyper::{body::Frame, Response}; +use mirrord_protocol::{ + tcp::{HttpResponse, InternalHttpBody}, + ConnectionId, RequestId, +}; +use tokio_stream::wrappers::ReceiverStream; + +pub type ReceiverStreamBody = StreamBody, Infallible>>>; + +#[derive(Debug)] +pub enum HttpResponseFallback { + Framed(HttpResponse), + Fallback(HttpResponse>), + Streamed(HttpResponse), +} + +impl HttpResponseFallback { + pub fn connection_id(&self) -> ConnectionId { + match self { + HttpResponseFallback::Framed(req) => req.connection_id, + HttpResponseFallback::Fallback(req) => req.connection_id, + HttpResponseFallback::Streamed(req) => req.connection_id, + } + } + + pub fn request_id(&self) -> RequestId { + match self { + HttpResponseFallback::Framed(req) => req.request_id, + HttpResponseFallback::Fallback(req) => req.request_id, + HttpResponseFallback::Streamed(req) => req.request_id, + } + } + + pub fn into_hyper(self) -> Response> { + match self { + HttpResponseFallback::Framed(req) => req + .internal_response + .map_body(|body| body.map_err(|_| unreachable!()).boxed()) + .into(), + HttpResponseFallback::Fallback(req) => req + .internal_response + .map_body(|body| { + Full::new(Bytes::from_owner(body)) + .map_err(|_| unreachable!()) + .boxed() + }) + .into(), + HttpResponseFallback::Streamed(req) => req + .internal_response + .map_body(|body| body.map_err(|_| unreachable!()).boxed()) + .into(), + } + } +} diff --git a/mirrord/protocol/Cargo.toml b/mirrord/protocol/Cargo.toml index a7f578d482f..841d6b6f776 100644 
--- a/mirrord/protocol/Cargo.toml +++ b/mirrord/protocol/Cargo.toml @@ -32,7 +32,6 @@ http-body-util = { workspace = true } fancy-regex = { workspace = true } socket2.workspace = true semver = { workspace = true, features = ["serde"] } -tokio-stream.workspace = true mirrord-macros = { path = "../macros" } diff --git a/mirrord/protocol/src/tcp.rs b/mirrord/protocol/src/tcp.rs index 32ca5fef5f9..acf3d734121 100644 --- a/mirrord/protocol/src/tcp.rs +++ b/mirrord/protocol/src/tcp.rs @@ -11,7 +11,7 @@ use std::{ use bincode::{Decode, Encode}; use bytes::Bytes; -use http_body_util::{combinators::BoxBody, BodyExt, Full, StreamBody}; +use http_body_util::BodyExt; use hyper::{ body::{Body, Frame}, HeaderMap, Method, Request, Response, StatusCode, Uri, Version, @@ -19,7 +19,6 @@ use hyper::{ use mirrord_macros::protocol_break; use semver::VersionReq; use serde::{Deserialize, Serialize}; -use tokio_stream::wrappers::ReceiverStream; use crate::{ConnectionId, Port, RemoteResult, RequestId}; @@ -283,6 +282,7 @@ impl From> for Request { version, body, } = value; + let mut request = Request::new(body); *request.method_mut() = method; *request.uri_mut() = uri; @@ -392,6 +392,24 @@ impl InternalHttpResponse { } } +impl From> for Response { + fn from(value: InternalHttpResponse) -> Self { + let InternalHttpResponse { + status, + version, + headers, + body, + } = value; + + let mut response = Response::new(body); + *response.status_mut() = status; + *response.version_mut() = version; + *response.headers_mut() = headers; + + response + } +} + #[derive(Serialize, Deserialize, Debug, Default, PartialEq, Eq, Clone)] pub struct InternalHttpBody(pub VecDeque); @@ -457,41 +475,6 @@ impl fmt::Debug for InternalHttpBodyFrame { } } -pub type ReceiverStreamBody = StreamBody, Infallible>>>; - -#[derive(Debug)] -pub enum HttpResponseFallback { - Framed(HttpResponse), - Fallback(HttpResponse>), - Streamed(HttpResponse), -} - -impl HttpResponseFallback { - pub fn connection_id(&self) -> ConnectionId 
{ - match self { - HttpResponseFallback::Framed(req) => req.connection_id, - HttpResponseFallback::Fallback(req) => req.connection_id, - HttpResponseFallback::Streamed(req) => req.connection_id, - } - } - - pub fn request_id(&self) -> RequestId { - match self { - HttpResponseFallback::Framed(req) => req.request_id, - HttpResponseFallback::Fallback(req) => req.request_id, - HttpResponseFallback::Streamed(req) => req.request_id, - } - } - - pub fn into_hyper(self) -> Response> { - match self { - HttpResponseFallback::Framed(req) => req.internal_response.into(), - HttpResponseFallback::Fallback(req) => req.internal_response.into(), - HttpResponseFallback::Streamed(req) => req.internal_response.into(), - } - } -} - #[derive(Encode, Decode, Debug, PartialEq, Eq, Clone)] #[bincode(bounds = "for<'de> B: Serialize + Deserialize<'de>")] pub struct HttpResponse { @@ -517,61 +500,3 @@ impl HttpResponse { } } } - -impl From>> for Response> { - fn from(value: InternalHttpResponse>) -> Self { - let InternalHttpResponse { - status, - version, - headers, - body, - } = value; - - let mut response = Response::new( - Full::new(Bytes::from_owner(body)) - .map_err(|_| unreachable!()) - .boxed(), - ); - *response.status_mut() = status; - *response.version_mut() = version; - *response.headers_mut() = headers; - - response - } -} - -impl From> for Response> { - fn from(value: InternalHttpResponse) -> Self { - let InternalHttpResponse { - status, - version, - headers, - body, - } = value; - - let mut response = Response::new(body.map_err(|_| unreachable!()).boxed()); - *response.status_mut() = status; - *response.version_mut() = version; - *response.headers_mut() = headers; - - response - } -} - -impl From> for Response> { - fn from(value: InternalHttpResponse) -> Self { - let InternalHttpResponse { - status, - version, - headers, - body, - } = value; - - let mut response = Response::new(body.map_err(|_| unreachable!()).boxed()); - *response.status_mut() = status; - *response.version_mut() 
= version; - *response.headers_mut() = headers; - - response - } -} From 3486e0a713010dfd6da77be7890ef64d82e1fedd Mon Sep 17 00:00:00 2001 From: Razz4780 Date: Tue, 14 Jan 2025 16:14:41 +0100 Subject: [PATCH 08/60] Moved frame senders to InterceptorHandle --- mirrord/intproxy/src/proxies/incoming.rs | 71 ++++++++++++++---------- 1 file changed, 43 insertions(+), 28 deletions(-) diff --git a/mirrord/intproxy/src/proxies/incoming.rs b/mirrord/intproxy/src/proxies/incoming.rs index c83e8ff5387..5a3dcf8fbd8 100644 --- a/mirrord/intproxy/src/proxies/incoming.rs +++ b/mirrord/intproxy/src/proxies/incoming.rs @@ -86,6 +86,8 @@ struct InterceptorHandle { tx: TaskSender, /// Port subscription that the intercepted connection belongs to. subscription: PortSubscription, + /// Senders for the bodies of in-progress HTTP requests. + request_body_txs: HashMap>, } /// Handles logic and state of the `incoming` feature. @@ -108,8 +110,6 @@ pub struct IncomingProxy { background_tasks: BackgroundTasks, /// For managing intercepted connections metadata. metadata_store: MetadataStore, - /// For managing streamed [`DaemonTcp::HttpRequestChunked`] request channels. - request_body_txs: HashMap<(ConnectionId, RequestId), Sender>, /// For managing streamed [`LayerTcpSteal::HttpResponseChunked`] response streams. response_body_rxs: StreamMap<(ConnectionId, RequestId), StreamNotifyClose>>, @@ -165,7 +165,7 @@ impl IncomingProxy { connection_id: ConnectionId, port: Port, message_bus: &MessageBus, - ) -> Result>, IncomingProxyError> { + ) -> Result, IncomingProxyError> { let id: InterceptorId = InterceptorId(connection_id); let interceptor = match self.interceptors.entry(id) { @@ -201,11 +201,12 @@ impl IncomingProxy { e.insert(InterceptorHandle { tx: interceptor, subscription: subscription.subscription.clone(), + request_body_txs: Default::default(), }) } }; - Ok(Some(&interceptor.tx)) + Ok(Some(interceptor)) } /// Handles all agent messages. 
@@ -220,8 +221,6 @@ impl IncomingProxy { DaemonTcp::Close(close) => { self.interceptors .remove(&InterceptorId(close.connection_id)); - self.request_body_txs - .retain(|(connection_id, _), _| *connection_id != close.connection_id); let keys: Vec<(ConnectionId, RequestId)> = self .response_body_rxs .keys() @@ -256,6 +255,7 @@ impl IncomingProxy { if let Some(interceptor) = interceptor { interceptor + .tx .send(request.map_body(StreamingBody::from)) .await; } @@ -272,6 +272,7 @@ impl IncomingProxy { if let Some(interceptor) = interceptor { interceptor + .tx .send(request.map_body(StreamingBody::from)) .await; } @@ -291,9 +292,8 @@ impl IncomingProxy { if let Some(interceptor) = interceptor { let (tx, rx) = mpsc::channel::(128); let request = request.map_body(|frames| StreamingBody::new(rx, frames)); - let key = (request.connection_id, request.request_id); - interceptor.send(request).await; - self.request_body_txs.insert(key, tx); + interceptor.request_body_txs.insert(request.request_id, tx); + interceptor.tx.send(request).await; } } @@ -303,33 +303,47 @@ impl IncomingProxy { connection_id, request_id, }) => { - if let Some(tx) = self.request_body_txs.get(&(connection_id, request_id)) { - let mut send_err = false; - - for frame in frames { - if let Err(err) = tx.send(frame).await { - send_err = true; - tracing::debug!( - frame = ?err.0, - connection_id, - request_id, - "Failed to send an HTTP request body frame to the interceptor, channel is closed" - ); - break; - } - } + let Some(interceptor) = + self.interceptors.get_mut(&InterceptorId(connection_id)) + else { + return Ok(()); + }; - if send_err || is_last { - self.request_body_txs.remove(&(connection_id, request_id)); + let Entry::Occupied(tx) = interceptor.request_body_txs.entry(request_id) + else { + return Ok(()); + }; + + let mut send_err = false; + + for frame in frames { + if let Err(err) = tx.get().send(frame).await { + send_err = true; + tracing::debug!( + frame = ?err.0, + connection_id, + request_id, + 
"Failed to send an HTTP request body frame to the interceptor, channel is closed" + ); + break; } } + + if send_err || is_last { + tx.remove(); + } } ChunkedRequest::Error(ChunkedHttpError { connection_id, request_id, }) => { - self.request_body_txs.remove(&(connection_id, request_id)); + if let Some(interceptor) = + self.interceptors.get_mut(&InterceptorId(connection_id)) + { + interceptor.request_body_txs.remove(&request_id); + }; + tracing::debug!( connection_id, request_id, @@ -391,6 +405,7 @@ impl IncomingProxy { InterceptorHandle { tx: interceptor, subscription: subscription.subscription.clone(), + request_body_txs: Default::default(), }, ); } @@ -607,7 +622,7 @@ impl BackgroundTask for IncomingProxy { message_bus.send(msg).await; } - self.request_body_txs.retain(|(connection_id, _), _| *connection_id != id.0); + self.interceptors.remove(&id); }, (id, TaskUpdate::Message(msg)) => { From e5b5d933b2ab36706775f1d0b1600bcb0dd87358 Mon Sep 17 00:00:00 2001 From: Razz4780 Date: Tue, 14 Jan 2025 17:16:11 +0100 Subject: [PATCH 09/60] HttpResponseReaders in IncomingProxy --- mirrord/intproxy/Cargo.toml | 2 +- mirrord/intproxy/src/proxies/incoming.rs | 161 ++++++++-------- mirrord/intproxy/src/proxies/incoming/http.rs | 37 +++- .../src/proxies/incoming/interceptor.rs | 27 ++- .../src/proxies/incoming/response_reader.rs | 172 ++++++++++++++++++ 5 files changed, 292 insertions(+), 107 deletions(-) create mode 100644 mirrord/intproxy/src/proxies/incoming/response_reader.rs diff --git a/mirrord/intproxy/Cargo.toml b/mirrord/intproxy/Cargo.toml index ede6a260c02..f7972fce76a 100644 --- a/mirrord/intproxy/Cargo.toml +++ b/mirrord/intproxy/Cargo.toml @@ -38,7 +38,6 @@ h2 = "0.4" hyper-util.workspace = true http-body-util.workspace = true bytes.workspace = true -futures.workspace = true rand.workspace = true tokio-rustls.workspace = true rustls.workspace = true @@ -46,5 +45,6 @@ rustls-pemfile.workspace = true exponential-backoff = "2" [dev-dependencies] +futures.workspace = 
true reqwest.workspace = true rstest.workspace = true diff --git a/mirrord/intproxy/src/proxies/incoming.rs b/mirrord/intproxy/src/proxies/incoming.rs index 5a3dcf8fbd8..0aa4acdada1 100644 --- a/mirrord/intproxy/src/proxies/incoming.rs +++ b/mirrord/intproxy/src/proxies/incoming.rs @@ -2,15 +2,14 @@ use std::{ collections::{hash_map::Entry, HashMap, VecDeque}, + convert::Infallible, fmt, io, net::SocketAddr, }; use bound_socket::BoundTcpSocket; -use futures::StreamExt; use http::PeekedBody; -use http_body_util::BodyStream; -use hyper::body::{Frame, Incoming}; +use hyper::body::Frame; use metadata_store::MetadataStore; use mirrord_intproxy_protocol::{ ConnMetadataRequest, ConnMetadataResponse, IncomingRequest, IncomingResponse, LayerId, @@ -24,11 +23,11 @@ use mirrord_protocol::{ }, ClientMessage, ConnectionId, Port, RequestId, ResponseError, }; +use response_reader::HttpResponseReader; use streaming_body::StreamingBody; use thiserror::Error; use tokio::sync::mpsc::{self, Sender}; -use tokio_stream::{StreamMap, StreamNotifyClose}; -use tracing::{debug, Level}; +use tracing::Level; use self::{ interceptor::{Interceptor, InterceptorError, MessageOut}, @@ -36,7 +35,9 @@ use self::{ subscriptions::SubscriptionsManager, }; use crate::{ - background_tasks::{BackgroundTask, BackgroundTasks, MessageBus, TaskSender, TaskUpdate}, + background_tasks::{ + BackgroundTask, BackgroundTasks, MessageBus, TaskError, TaskSender, TaskUpdate, + }, main_tasks::{LayerClosed, LayerForked, ToLayer}, ProxyMessage, }; @@ -46,13 +47,14 @@ mod http; mod interceptor; mod metadata_store; pub mod port_subscription_ext; +mod response_reader; mod streaming_body; mod subscriptions; /// Id of a single [`Interceptor`] task. Used to manage interceptor tasks with the /// [`BackgroundTasks`] struct. 
#[derive(Debug, Clone, Copy, Hash, PartialEq, Eq)] -pub struct InterceptorId(pub ConnectionId); +pub struct InterceptorId(ConnectionId); impl fmt::Display for InterceptorId { fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { @@ -60,6 +62,9 @@ impl fmt::Display for InterceptorId { } } +#[derive(Debug, Clone, Copy, Hash, PartialEq, Eq)] +pub struct ReaderId(ConnectionId, RequestId); + /// Errors that can occur when handling the `incoming` feature. #[derive(Error, Debug)] pub enum IncomingProxyError { @@ -104,17 +109,22 @@ struct InterceptorHandle { pub struct IncomingProxy { /// Active port subscriptions for all layers. subscriptions: SubscriptionsManager, - /// [`TaskSender`]s for active [`Interceptor`]s. - interceptors: HashMap, - /// For receiving updates from [`Interceptor`]s. - background_tasks: BackgroundTasks, /// For managing intercepted connections metadata. metadata_store: MetadataStore, - /// For managing streamed [`LayerTcpSteal::HttpResponseChunked`] response streams. - response_body_rxs: - StreamMap<(ConnectionId, RequestId), StreamNotifyClose>>, - /// Version of [`mirrord_protocol`] negotiated with the agent. + /// Determines which version of [`LayerTcpSteal`] we use to send HTTP responses to the agent. agent_protocol_version: Option, + + /// [`TaskSender`]s for active [`Interceptor`]s. + interceptor_handles: HashMap, + /// For receiving updates from [`Interceptor`]s. + interceptors: BackgroundTasks, + + /// [TaskSender]s for active [`HttpResponseReader`]s. + /// + /// Keep the readers alive. + readers_txs: HashMap>, + /// For reading bodies of user app's HTTP responses. 
+ readers: BackgroundTasks, } impl IncomingProxy { @@ -168,7 +178,7 @@ impl IncomingProxy { ) -> Result, IncomingProxyError> { let id: InterceptorId = InterceptorId(connection_id); - let interceptor = match self.interceptors.entry(id) { + let interceptor = match self.interceptor_handles.entry(id) { Entry::Occupied(e) => e.into_mut(), Entry::Vacant(e) => { @@ -192,7 +202,7 @@ impl IncomingProxy { let interceptor_socket = BoundTcpSocket::bind_specified_or_localhost(subscription.listening_on.ip())?; - let interceptor = self.background_tasks.register( + let interceptor = self.interceptors.register( Interceptor::new(interceptor_socket, subscription.listening_on), id, Self::CHANNEL_SIZE, @@ -219,21 +229,15 @@ impl IncomingProxy { ) -> Result<(), IncomingProxyError> { match message { DaemonTcp::Close(close) => { - self.interceptors + self.readers_txs.retain(|id, _| id.0 != close.connection_id); + self.interceptor_handles .remove(&InterceptorId(close.connection_id)); - let keys: Vec<(ConnectionId, RequestId)> = self - .response_body_rxs - .keys() - .filter(|key| key.0 == close.connection_id) - .cloned() - .collect(); - for key in keys.iter() { - self.response_body_rxs.remove(key); - } } DaemonTcp::Data(data) => { - if let Some(interceptor) = self.interceptors.get(&InterceptorId(data.connection_id)) + if let Some(interceptor) = self + .interceptor_handles + .get(&InterceptorId(data.connection_id)) { interceptor.tx.send(data.bytes).await; } else { @@ -303,8 +307,9 @@ impl IncomingProxy { connection_id, request_id, }) => { - let Some(interceptor) = - self.interceptors.get_mut(&InterceptorId(connection_id)) + let Some(interceptor) = self + .interceptor_handles + .get_mut(&InterceptorId(connection_id)) else { return Ok(()); }; @@ -338,8 +343,9 @@ impl IncomingProxy { connection_id, request_id, }) => { - if let Some(interceptor) = - self.interceptors.get_mut(&InterceptorId(connection_id)) + if let Some(interceptor) = self + .interceptor_handles + 
.get_mut(&InterceptorId(connection_id)) { interceptor.request_body_txs.remove(&request_id); }; @@ -394,13 +400,13 @@ impl IncomingProxy { }, ); - let interceptor = self.background_tasks.register( + let interceptor = self.interceptors.register( Interceptor::new(interceptor_socket, subscription.listening_on), id, Self::CHANNEL_SIZE, ); - self.interceptors.insert( + self.interceptor_handles.insert( id, InterceptorHandle { tx: interceptor, @@ -436,7 +442,7 @@ impl IncomingProxy { } fn get_subscription(&self, interceptor_id: InterceptorId) -> Option<&PortSubscription> { - self.interceptors + self.interceptor_handles .get(&interceptor_id) .map(|handle| &handle.subscription) } @@ -451,7 +457,7 @@ impl IncomingProxy { mut response: HttpResponse, message_bus: &mut MessageBus, ) { - let _tail = match response.internal_response.body.tail.take() { + let tail = match response.internal_response.body.tail.take() { Some(tail) => tail, // All frames are already fetched, we don't have to stream the body to the agent. 
@@ -501,23 +507,34 @@ impl IncomingProxy { } }; - if self.agent_handles_streamed_responses() { + let reader_id = ReaderId(response.connection_id, response.request_id); + let response_reader = if self.agent_handles_streamed_responses() { let response = response.map_body(|body| { body.head .into_iter() .map(InternalHttpBodyFrame::from) .collect::>() }); + let connection_id = response.connection_id; + let request_id = response.request_id; let message = ClientMessage::TcpSteal(LayerTcpSteal::HttpResponseChunked( ChunkedResponse::Start(response), )); message_bus.send(message).await; - todo!("start response reader") + HttpResponseReader::Chunked { + connection_id, + request_id, + body: tail, + } } else if self.agent_handles_framed_responses() { - todo!("start response reader") + response.internal_response.body.tail.replace(tail); + HttpResponseReader::Framed(response) } else { - todo!("start response reader") - } + response.internal_response.body.tail.replace(tail); + HttpResponseReader::Legacy(response) + }; + + self.readers.register(response_reader, reader_id, 16); } fn agent_handles_framed_responses(&self) -> bool { @@ -544,47 +561,6 @@ impl BackgroundTask for IncomingProxy { async fn run(mut self, message_bus: &mut MessageBus) -> Result<(), Self::Error> { loop { tokio::select! 
{ - Some(((connection_id, request_id), stream_item)) = self.response_body_rxs.next() => match stream_item { - Some(Ok(frame)) => { - let int_frame = InternalHttpBodyFrame::from(frame); - let res = ChunkedResponse::Body(ChunkedHttpBody { - frames: vec![int_frame], - is_last: false, - connection_id, - request_id, - }); - message_bus - .send(ClientMessage::TcpSteal(LayerTcpSteal::HttpResponseChunked( - res, - ))) - .await; - }, - Some(Err(error)) => { - debug!(%error, "Error while reading streamed response body"); - let res = ChunkedResponse::Error(ChunkedHttpError {connection_id, request_id}); - message_bus - .send(ClientMessage::TcpSteal(LayerTcpSteal::HttpResponseChunked( - res, - ))) - .await; - self.response_body_rxs.remove(&(connection_id, request_id)); - }, - None => { - let res = ChunkedResponse::Body(ChunkedHttpBody { - frames: vec![], - is_last: true, - connection_id, - request_id, - }); - message_bus - .send(ClientMessage::TcpSteal(LayerTcpSteal::HttpResponseChunked( - res, - ))) - .await; - self.response_body_rxs.remove(&(connection_id, request_id)); - } - }, - msg = message_bus.recv() => match msg { None => { tracing::trace!("message bus closed, exiting"); @@ -611,7 +587,7 @@ impl BackgroundTask for IncomingProxy { } }, - Some(task_update) = self.background_tasks.next() => match task_update { + Some(task_update) = self.interceptors.next() => match task_update { (id, TaskUpdate::Finished(res)) => { tracing::trace!("{id} finished: {res:?}"); @@ -622,7 +598,7 @@ impl BackgroundTask for IncomingProxy { message_bus.send(msg).await; } - self.interceptors.remove(&id); + self.interceptor_handles.remove(&id); }, (id, TaskUpdate::Message(msg)) => { @@ -646,6 +622,23 @@ impl BackgroundTask for IncomingProxy { }; }, }, + + Some(task_update) = self.readers.next() => match task_update { + (id, TaskUpdate::Finished(Ok(()))) => { + self.readers_txs.remove(&id); + } + + (id, TaskUpdate::Finished(Err(TaskError::Panic))) => { + tracing::error!(connection_id = id.0, 
request_id = id.1, "HttpResponseReader task panicked"); + + self.interceptor_handles.remove(&InterceptorId(id.0)); + self.readers_txs.remove(&id); + } + + (_, TaskUpdate::Message(msg)) => { + message_bus.send(ClientMessage::TcpSteal(msg)).await; + } + } } } } diff --git a/mirrord/intproxy/src/proxies/incoming/http.rs b/mirrord/intproxy/src/proxies/incoming/http.rs index dd8f2874ed7..25e53b4bfa2 100644 --- a/mirrord/intproxy/src/proxies/incoming/http.rs +++ b/mirrord/intproxy/src/proxies/incoming/http.rs @@ -5,10 +5,14 @@ use exponential_backoff::Backoff; use hyper::{ body::{Frame, Incoming}, client::conn::{http1, http2}, - Request, Response, Version, + Request, Response, StatusCode, Version, }; use hyper_util::rt::{TokioExecutor, TokioIo}; -use mirrord_protocol::{batched_body::BatchedBody, tcp::HttpRequest}; +use mirrord_protocol::{ + batched_body::BatchedBody, + tcp::{HttpRequest, HttpResponse, InternalHttpResponse}, + ConnectionId, Port, RequestId, +}; use thiserror::Error; use tokio::{net::TcpStream, time}; use tracing::Level; @@ -85,7 +89,7 @@ impl LocalHttpClient { let frames = body .ready_frames() - .map_err(LocalHttpError::PeekBodyFailed)?; + .map_err(LocalHttpError::ReadBodyFailed)?; let body = PeekedBody { head: frames.frames, tail: frames.is_last.not().then_some(body), @@ -182,8 +186,8 @@ pub enum LocalHttpError { #[error("making a TPC connection failed: {0}")] ConnectTcpFailed(#[source] io::Error), - #[error("reading first frames of the response body failed: {0}")] - PeekBodyFailed(#[source] hyper::Error), + #[error("reading the response body failed: {0}")] + ReadBodyFailed(#[source] hyper::Error), } impl LocalHttpError { @@ -210,11 +214,32 @@ impl LocalHttpError { match self { Self::SocketSetupFailed(..) | Self::UnsupportedHttpVersion(..) => false, Self::ConnectTcpFailed(..) 
=> true, - Self::HandshakeFailed(err) | Self::SendFailed(err) | Self::PeekBodyFailed(err) => { + Self::HandshakeFailed(err) | Self::SendFailed(err) | Self::ReadBodyFailed(err) => { err.is_closed() || err.is_incomplete_message() || Self::is_h2_reset(err) } } } + + /// Produces a [`StatusCode::BAD_GATEWAY`] response from this error. + pub fn as_error_response( + &self, + version: Version, + request_id: RequestId, + connection_id: ConnectionId, + port: Port, + ) -> HttpResponse> { + HttpResponse { + request_id, + connection_id, + port, + internal_response: InternalHttpResponse { + status: StatusCode::BAD_GATEWAY, + version, + headers: Default::default(), + body: format!("mirrord: {self}").into_bytes(), + }, + } + } } /// Response body returned from [`LocalHttpClient`]. diff --git a/mirrord/intproxy/src/proxies/incoming/interceptor.rs b/mirrord/intproxy/src/proxies/incoming/interceptor.rs index 7a14e7daf1f..f5972351230 100644 --- a/mirrord/intproxy/src/proxies/incoming/interceptor.rs +++ b/mirrord/intproxy/src/proxies/incoming/interceptor.rs @@ -184,22 +184,17 @@ impl HttpConnection { } Err(error) => { - let message_frame = Frame::data(Bytes::from_owner(format!("mirrord: {error}"))); - let body = PeekedBody { - head: vec![message_frame], - tail: None, - }; - let response = HttpResponse { - port: request.port, - connection_id: request.connection_id, - request_id: request.request_id, - internal_response: InternalHttpResponse { - status: StatusCode::BAD_GATEWAY, - version: request.internal_request.version, - headers: Default::default(), - body, - }, - }; + let response = error + .as_error_response( + request.internal_request.version, + request.request_id, + request.connection_id, + request.port, + ) + .map_body(|body| PeekedBody { + head: vec![Frame::data(Bytes::from_owner(body))], + tail: None, + }); Ok((response, None)) } diff --git a/mirrord/intproxy/src/proxies/incoming/response_reader.rs b/mirrord/intproxy/src/proxies/incoming/response_reader.rs new file mode 100644 
index 00000000000..56ef95ac1a4 --- /dev/null +++ b/mirrord/intproxy/src/proxies/incoming/response_reader.rs @@ -0,0 +1,172 @@ +use std::convert::Infallible; + +use http_body_util::BodyExt; +use hyper::body::{Frame, Incoming}; +use mirrord_protocol::{ + batched_body::BatchedBody, + tcp::{ + ChunkedHttpBody, ChunkedHttpError, ChunkedResponse, HttpResponse, InternalHttpBody, + InternalHttpBodyFrame, LayerTcpSteal, + }, + ConnectionId, RequestId, +}; + +use super::http::PeekedBody; +use crate::{ + background_tasks::{BackgroundTask, MessageBus}, + proxies::incoming::http::LocalHttpError, +}; + +pub enum HttpResponseReader { + Legacy(HttpResponse), + Framed(HttpResponse), + Chunked { + connection_id: ConnectionId, + request_id: RequestId, + body: Incoming, + }, +} + +impl BackgroundTask for HttpResponseReader { + type Error = Infallible; + type MessageIn = Infallible; + type MessageOut = LayerTcpSteal; + + async fn run(self, message_bus: &mut MessageBus) -> Result<(), Self::Error> { + match self { + Self::Legacy(mut response) => { + let tail = match response.internal_response.body.tail.take() { + Some(incoming) => { + tokio::select! 
{ + _ = message_bus.recv() => return Ok(()), + + result = incoming.collect() => match result { + Ok(data) => Vec::from(data.to_bytes()), + + Err(error) => { + let response = LocalHttpError::ReadBodyFailed(error) + .as_error_response( + response.internal_response.version, + response.request_id, + response.connection_id, + response.port, + ); + message_bus.send(LayerTcpSteal::HttpResponse(response)).await; + return Ok(()); + } + } + } + } + + None => vec![], + }; + + let response = response.map_body(|body| { + let mut complete = Vec::with_capacity( + body.head + .iter() + .filter_map(|frame| Some(frame.data_ref()?.len())) + .sum::() + + tail.len(), + ); + for frame in body + .head + .into_iter() + .map(Frame::into_data) + .filter_map(Result::ok) + { + complete.extend(frame); + } + complete.extend(tail); + complete + }); + + message_bus + .send(LayerTcpSteal::HttpResponse(response)) + .await; + } + + Self::Framed(mut response) => { + if let Some(mut incoming) = response.internal_response.body.tail.take() { + loop { + tokio::select! { + _ = message_bus.recv() => return Ok(()), + + result = incoming.next_frames() => match result { + Ok(data) => { + response.internal_response.body.head.extend(data.frames); + if data.is_last { + break; + } + }, + + Err(error) => { + let response = LocalHttpError::ReadBodyFailed(error) + .as_error_response( + response.internal_response.version, + response.request_id, + response.connection_id, + response.port, + ); + message_bus.send(LayerTcpSteal::HttpResponse(response)).await; + return Ok(()); + } + } + } + } + }; + + let response = response.map_body(|body| { + InternalHttpBody( + body.head + .into_iter() + .map(InternalHttpBodyFrame::from) + .collect(), + ) + }); + + message_bus + .send(LayerTcpSteal::HttpResponseFramed(response)) + .await; + } + + Self::Chunked { + connection_id, + request_id, + mut body, + } => loop { + tokio::select! 
{ + _ = message_bus.recv() => return Ok(()), + + result = body.next_frames() => match result { + Ok(data) => { + let message = LayerTcpSteal::HttpResponseChunked(ChunkedResponse::Body(ChunkedHttpBody { + frames: data.frames.into_iter().map(InternalHttpBodyFrame::from).collect(), + is_last: data.is_last, + connection_id, + request_id, + })); + message_bus.send(message).await; + + if data.is_last { + break; + } + }, + + Err(..) => { + let message = LayerTcpSteal::HttpResponseChunked(ChunkedResponse::Error(ChunkedHttpError { + connection_id, + request_id, + })); + message_bus.send(message).await; + + return Ok(()); + } + } + } + }, + } + + Ok(()) + } +} From 520f5e9f155cbe66f11e30e8233a71e179fc19d2 Mon Sep 17 00:00:00 2001 From: Razz4780 Date: Tue, 14 Jan 2025 17:17:48 +0100 Subject: [PATCH 10/60] Clippy --- mirrord/agent/src/steal/api.rs | 4 +++- mirrord/intproxy/src/proxies/incoming/http.rs | 2 +- mirrord/intproxy/src/proxies/incoming/interceptor.rs | 4 ++-- 3 files changed, 6 insertions(+), 4 deletions(-) diff --git a/mirrord/agent/src/steal/api.rs b/mirrord/agent/src/steal/api.rs index efc9cfd4e33..489f1ae5380 100644 --- a/mirrord/agent/src/steal/api.rs +++ b/mirrord/agent/src/steal/api.rs @@ -16,6 +16,8 @@ use crate::{ watched_task::TaskStatus, }; +type ResponseBodyTx = Sender, Infallible>>; + /// Bridges the communication between the agent and the [`TcpConnectionStealer`] task. /// There is an API instance for each connected layer ("client"). All API instances send commands /// On the same stealer command channel, where the layer-independent stealer listens to them. @@ -37,7 +39,7 @@ pub(crate) struct TcpStealerApi { /// View on the stealer task's status. 
task_status: TaskStatus, - response_body_txs: HashMap<(ConnectionId, RequestId), Sender, Infallible>>>, + response_body_txs: HashMap<(ConnectionId, RequestId), ResponseBodyTx>, } impl TcpStealerApi { diff --git a/mirrord/intproxy/src/proxies/incoming/http.rs b/mirrord/intproxy/src/proxies/incoming/http.rs index 25e53b4bfa2..8c51228b545 100644 --- a/mirrord/intproxy/src/proxies/incoming/http.rs +++ b/mirrord/intproxy/src/proxies/incoming/http.rs @@ -120,7 +120,7 @@ impl LocalHttpClient { loop { attempt += 1; tracing::trace!(attempt, "Trying to send the request"); - match (self.try_send_request(&request).await, backoffs.next()) { + match (self.try_send_request(request).await, backoffs.next()) { (Ok(response), _) => { tracing::trace!( attempt, diff --git a/mirrord/intproxy/src/proxies/incoming/interceptor.rs b/mirrord/intproxy/src/proxies/incoming/interceptor.rs index f5972351230..b509f74603c 100644 --- a/mirrord/intproxy/src/proxies/incoming/interceptor.rs +++ b/mirrord/intproxy/src/proxies/incoming/interceptor.rs @@ -293,9 +293,9 @@ impl RawConnection { /// /// 3. This implementation exits only when an error is encountered or the [`MessageBus`] is /// closed. 
- async fn run<'a>( + async fn run( mut self, - mut message_bus: PeekableMessageBus<'a, Interceptor>, + mut message_bus: PeekableMessageBus<'_, Interceptor>, ) -> InterceptorResult<()> { let mut buf = BytesMut::with_capacity(64 * 1024); let mut reading_closed = false; From 85940d74253c6a2131559b6116be14d4fa50da2e Mon Sep 17 00:00:00 2001 From: Razz4780 Date: Tue, 14 Jan 2025 18:51:00 +0100 Subject: [PATCH 11/60] Better tracing --- mirrord/intproxy/src/background_tasks.rs | 5 +- mirrord/intproxy/src/proxies/incoming.rs | 366 +++++++++--------- mirrord/intproxy/src/proxies/incoming/http.rs | 57 ++- .../src/proxies/incoming/interceptor.rs | 54 ++- .../src/proxies/incoming/response_reader.rs | 140 +++++-- 5 files changed, 382 insertions(+), 240 deletions(-) diff --git a/mirrord/intproxy/src/background_tasks.rs b/mirrord/intproxy/src/background_tasks.rs index 50593f47ddd..99d27a8ab8d 100644 --- a/mirrord/intproxy/src/background_tasks.rs +++ b/mirrord/intproxy/src/background_tasks.rs @@ -8,6 +8,7 @@ use std::{collections::HashMap, fmt, future::Future, hash::Hash}; +use thiserror::Error; use tokio::{ sync::mpsc::{self, Receiver, Sender}, task::JoinHandle, @@ -217,12 +218,14 @@ where } /// An error that can occur when executing a [`BackgroundTask`]. -#[derive(Debug)] +#[derive(Debug, Error)] #[cfg_attr(test, derive(PartialEq, Eq))] pub enum TaskError { /// An internal task error. + #[error(transparent)] Error(Err), /// A panic. 
+ #[error("task panicked")] Panic, } diff --git a/mirrord/intproxy/src/proxies/incoming.rs b/mirrord/intproxy/src/proxies/incoming.rs index 0aa4acdada1..77b1a6ea54c 100644 --- a/mirrord/intproxy/src/proxies/incoming.rs +++ b/mirrord/intproxy/src/proxies/incoming.rs @@ -3,7 +3,7 @@ use std::{ collections::{hash_map::Entry, HashMap, VecDeque}, convert::Infallible, - fmt, io, + io, net::SocketAddr, }; @@ -13,7 +13,7 @@ use hyper::body::Frame; use metadata_store::MetadataStore; use mirrord_intproxy_protocol::{ ConnMetadataRequest, ConnMetadataResponse, IncomingRequest, IncomingResponse, LayerId, - MessageId, PortSubscribe, PortSubscription, PortUnsubscribe, ProxyToLayerMessage, + MessageId, PortSubscription, ProxyToLayerMessage, }; use mirrord_protocol::{ tcp::{ @@ -51,19 +51,8 @@ mod response_reader; mod streaming_body; mod subscriptions; -/// Id of a single [`Interceptor`] task. Used to manage interceptor tasks with the -/// [`BackgroundTasks`] struct. -#[derive(Debug, Clone, Copy, Hash, PartialEq, Eq)] -pub struct InterceptorId(ConnectionId); - -impl fmt::Display for InterceptorId { - fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { - write!(f, "incoming interceptor {}", self.0,) - } -} - -#[derive(Debug, Clone, Copy, Hash, PartialEq, Eq)] -pub struct ReaderId(ConnectionId, RequestId); +/// Identifies an [`HttpResponseReader`]. +type ReaderId = (ConnectionId, RequestId); /// Errors that can occur when handling the `incoming` feature. #[derive(Error, Debug)] @@ -75,6 +64,7 @@ pub enum IncomingProxyError { } /// Messages consumed by [`IncomingProxy`] running as a [`BackgroundTask`]. +#[derive(Debug)] pub enum IncomingProxyMessage { LayerRequest(MessageId, LayerId, IncomingRequest), LayerForked(LayerForked), @@ -99,12 +89,15 @@ struct InterceptorHandle { /// Run as a [`BackgroundTask`]. /// /// Handles port subscriptions state of the connected layers. Utilizes multiple background tasks -/// ([`Interceptor`]s) to handle incoming connections. 
Each connection is managed by a single -/// [`Interceptor`], that establishes a TCP connection with the user application's port and proxies -/// data. +/// ([`Interceptor`]s and [`HttpResponseReader`]s) to handle incoming connections. +/// +/// Each connection is managed by a single [`Interceptor`], +/// that establishes a TCP connection with the user application's port and proxies data. +/// +/// Bodies of HTTP responses from the user application are polled by [`HttpResponseReader`]s. /// /// Incoming connections are created by the agent either explicitly ([`NewTcpConnection`] message) -/// or implicitly ([`HttpRequest`]). +/// or implicitly ([`HttpRequest`](mirrord_protocol::tcp::HttpRequest)). #[derive(Default)] pub struct IncomingProxy { /// Active port subscriptions for all layers. @@ -115,9 +108,9 @@ pub struct IncomingProxy { agent_protocol_version: Option, /// [`TaskSender`]s for active [`Interceptor`]s. - interceptor_handles: HashMap, + interceptor_handles: HashMap, /// For receiving updates from [`Interceptor`]s. - interceptors: BackgroundTasks, + interceptors: BackgroundTasks, /// [TaskSender]s for active [`HttpResponseReader`]s. /// @@ -128,57 +121,20 @@ pub struct IncomingProxy { } impl IncomingProxy { - /// Used when registering new `RawInterceptor` and `HttpInterceptor` tasks in the - /// [`BackgroundTasks`] struct. - // TODO: Update outdated documentation. RawInterceptor, HttpInterceptor do not exist + /// Used when registering new [`Interceptor`]s in the internal [`BackgroundTasks`] instance. const CHANNEL_SIZE: usize = 512; - /// Tries to register the new subscription in the [`SubscriptionsManager`]. 
- #[tracing::instrument(level = Level::TRACE, skip(self, message_bus))] - async fn handle_port_subscribe( - &mut self, - message_id: MessageId, - layer_id: LayerId, - subscribe: PortSubscribe, - message_bus: &mut MessageBus, - ) { - let msg = self - .subscriptions - .layer_subscribed(layer_id, message_id, subscribe); - - if let Some(msg) = msg { - message_bus.send(msg).await; - } - } - - /// Tries to unregister the subscription from the [`SubscriptionsManager`]. - #[tracing::instrument(level = Level::TRACE, skip(self, message_bus))] - async fn handle_port_unsubscribe( - &mut self, - layer_id: LayerId, - request: PortUnsubscribe, - message_bus: &mut MessageBus, - ) { - let msg = self.subscriptions.layer_unsubscribed(layer_id, request); - - if let Some(msg) = msg { - message_bus.send(msg).await; - } - } - /// Retrieves or creates an [`Interceptor`] for the given [`HttpRequestFallback`]. /// The request may or may not belong to an existing connection (when stealing with an http /// filter, connections are created implicitly). - #[tracing::instrument(level = Level::TRACE, skip(self, message_bus))] + #[tracing::instrument(level = Level::TRACE, skip(self, message_bus), err)] async fn get_or_create_http_interceptor( &mut self, connection_id: ConnectionId, port: Port, message_bus: &MessageBus, ) -> Result, IncomingProxyError> { - let id: InterceptorId = InterceptorId(connection_id); - - let interceptor = match self.interceptor_handles.entry(id) { + let interceptor = match self.interceptor_handles.entry(connection_id) { Entry::Occupied(e) => e.into_mut(), Entry::Vacant(e) => { @@ -204,7 +160,7 @@ impl IncomingProxy { let interceptor = self.interceptors.register( Interceptor::new(interceptor_socket, subscription.listening_on), - id, + connection_id, Self::CHANNEL_SIZE, ); @@ -220,7 +176,7 @@ impl IncomingProxy { } /// Handles all agent messages. 
- #[tracing::instrument(level = Level::TRACE, skip(self, message_bus))] + #[tracing::instrument(level = Level::TRACE, skip(self, message_bus), err)] async fn handle_agent_message( &mut self, message: DaemonTcp, @@ -230,15 +186,11 @@ impl IncomingProxy { match message { DaemonTcp::Close(close) => { self.readers_txs.retain(|id, _| id.0 != close.connection_id); - self.interceptor_handles - .remove(&InterceptorId(close.connection_id)); + self.interceptor_handles.remove(&close.connection_id); } DaemonTcp::Data(data) => { - if let Some(interceptor) = self - .interceptor_handles - .get(&InterceptorId(data.connection_id)) - { + if let Some(interceptor) = self.interceptor_handles.get(&data.connection_id) { interceptor.tx.send(data.bytes).await; } else { tracing::debug!( @@ -307,9 +259,7 @@ impl IncomingProxy { connection_id, request_id, }) => { - let Some(interceptor) = self - .interceptor_handles - .get_mut(&InterceptorId(connection_id)) + let Some(interceptor) = self.interceptor_handles.get_mut(&connection_id) else { return Ok(()); }; @@ -343,9 +293,7 @@ impl IncomingProxy { connection_id, request_id, }) => { - if let Some(interceptor) = self - .interceptor_handles - .get_mut(&InterceptorId(connection_id)) + if let Some(interceptor) = self.interceptor_handles.get_mut(&connection_id) { interceptor.request_body_txs.remove(&request_id); }; @@ -386,14 +334,13 @@ impl IncomingProxy { let interceptor_socket = BoundTcpSocket::bind_specified_or_localhost(subscription.listening_on.ip())?; - let id = InterceptorId(connection_id); self.metadata_store.expect( ConnMetadataRequest { listener_address: subscription.listening_on, peer_address: interceptor_socket.local_addr()?, }, - id.0, + connection_id, ConnMetadataResponse { remote_source: SocketAddr::new(remote_address, source_port), local_address, @@ -402,12 +349,12 @@ impl IncomingProxy { let interceptor = self.interceptors.register( Interceptor::new(interceptor_socket, subscription.listening_on), - id, + connection_id, 
Self::CHANNEL_SIZE, ); self.interceptor_handles.insert( - id, + connection_id, InterceptorHandle { tx: interceptor, subscription: subscription.subscription.clone(), @@ -428,25 +375,6 @@ impl IncomingProxy { Ok(()) } - fn handle_layer_fork(&mut self, msg: LayerForked) { - let LayerForked { child, parent } = msg; - self.subscriptions.layer_forked(parent, child); - } - - async fn handle_layer_close(&mut self, msg: LayerClosed, message_bus: &MessageBus) { - let msgs = self.subscriptions.layer_closed(msg.id); - - for msg in msgs { - message_bus.send(msg).await; - } - } - - fn get_subscription(&self, interceptor_id: InterceptorId) -> Option<&PortSubscription> { - self.interceptor_handles - .get(&interceptor_id) - .map(|handle| &handle.subscription) - } - /// Handles an HTTP response coming from one of the interceptors. /// /// If all response frames are already available, sends the response in a single message. @@ -460,10 +388,10 @@ impl IncomingProxy { let tail = match response.internal_response.body.tail.take() { Some(tail) => tail, - // All frames are already fetched, we don't have to stream the body to the agent. + // All frames are already fetched, we don't have to wait for the body. + // We can send just one message. None => { let message = if self.agent_handles_framed_responses() { - // We can send just one message to the agent. let response = response.map_body(|body| { InternalHttpBody( body.head @@ -475,8 +403,8 @@ impl IncomingProxy { LayerTcpSteal::HttpResponseFramed(response) } else { // Agent does not support `LayerTcpSteal::HttpResponseFramed`. - // We can only use legacy `LayerTcpSteal::HttpResponse`, which drops trailing - // headers. + // We can only use legacy `LayerTcpSteal::HttpResponse`, + // which drops trailing headers. 
let connection_id = response.connection_id; let request_id = response.request_id; let response = response.map_body(|body| { @@ -507,7 +435,7 @@ impl IncomingProxy { } }; - let reader_id = ReaderId(response.connection_id, response.request_id); + let reader_id = (response.connection_id, response.request_id); let response_reader = if self.agent_handles_streamed_responses() { let response = response.map_body(|body| { body.head @@ -550,96 +478,186 @@ impl IncomingProxy { .map(|version| HTTP_CHUNKED_RESPONSE_VERSION.matches(version)) .unwrap_or_default() } -} -impl BackgroundTask for IncomingProxy { - type Error = IncomingProxyError; - type MessageIn = IncomingProxyMessage; - type MessageOut = ProxyMessage; - - #[tracing::instrument(level = Level::TRACE, skip_all, err)] - async fn run(mut self, message_bus: &mut MessageBus) -> Result<(), Self::Error> { - loop { - tokio::select! { - msg = message_bus.recv() => match msg { - None => { - tracing::trace!("message bus closed, exiting"); - break Ok(()); - }, - Some(IncomingProxyMessage::LayerRequest(message_id, layer_id, req)) => match req { - IncomingRequest::PortSubscribe(subscribe) => self.handle_port_subscribe(message_id, layer_id, subscribe, message_bus).await, - IncomingRequest::PortUnsubscribe(unsubscribe) => self.handle_port_unsubscribe(layer_id, unsubscribe, message_bus).await, - IncomingRequest::ConnMetadata(req) => { - let res = self.metadata_store.get(req); - message_bus.send(ToLayer { message_id, layer_id, message: ProxyToLayerMessage::Incoming(IncomingResponse::ConnMetadata(res)) }).await; - } - }, - Some(IncomingProxyMessage::AgentMirror(msg)) => { - self.handle_agent_message(msg, false, message_bus).await?; - } - Some(IncomingProxyMessage::AgentSteal(msg)) => { - self.handle_agent_message(msg, true, message_bus).await?; + #[tracing::instrument(level = Level::TRACE, skip(self, message_bus), err)] + async fn handle_message( + &mut self, + message: IncomingProxyMessage, + message_bus: &mut MessageBus, + ) -> 
Result<(), IncomingProxyError> { + match message { + IncomingProxyMessage::LayerRequest(message_id, layer_id, req) => match req { + IncomingRequest::PortSubscribe(subscribe) => { + let msg = self + .subscriptions + .layer_subscribed(layer_id, message_id, subscribe); + + if let Some(msg) = msg { + message_bus.send(msg).await; } - Some(IncomingProxyMessage::LayerClosed(msg)) => self.handle_layer_close(msg, message_bus).await, - Some(IncomingProxyMessage::LayerForked(msg)) => self.handle_layer_fork(msg), - Some(IncomingProxyMessage::AgentProtocolVersion(version)) => { - self.agent_protocol_version.replace(version); + } + IncomingRequest::PortUnsubscribe(unsubscribe) => { + let msg = self.subscriptions.layer_unsubscribed(layer_id, unsubscribe); + + if let Some(msg) = msg { + message_bus.send(msg).await; } - }, + } + IncomingRequest::ConnMetadata(req) => { + let res = self.metadata_store.get(req); + message_bus + .send(ToLayer { + message_id, + layer_id, + message: ProxyToLayerMessage::Incoming(IncomingResponse::ConnMetadata( + res, + )), + }) + .await; + } + }, - Some(task_update) = self.interceptors.next() => match task_update { - (id, TaskUpdate::Finished(res)) => { - tracing::trace!("{id} finished: {res:?}"); + IncomingProxyMessage::AgentMirror(msg) => { + self.handle_agent_message(msg, false, message_bus).await?; + } - self.metadata_store.no_longer_expect(id.0); + IncomingProxyMessage::AgentSteal(msg) => { + self.handle_agent_message(msg, true, message_bus).await?; + } - let msg = self.get_subscription(id).map(|s| s.wrap_agent_unsubscribe_connection(id.0)); - if let Some(msg) = msg { - message_bus.send(msg).await; - } + IncomingProxyMessage::LayerClosed(msg) => { + let msgs = self.subscriptions.layer_closed(msg.id); - self.interceptor_handles.remove(&id); - }, + for msg in msgs { + message_bus.send(msg).await; + } + } - (id, TaskUpdate::Message(msg)) => { - let Some(PortSubscription::Steal(_)) = self.get_subscription(id) else { - continue; - }; + 
IncomingProxyMessage::LayerForked(msg) => { + self.subscriptions.layer_forked(msg.parent, msg.child); + } - match msg { - MessageOut::Raw(bytes) => { - let msg = ClientMessage::TcpSteal(LayerTcpSteal::Data(TcpData { - connection_id: id.0, - bytes, - })); + IncomingProxyMessage::AgentProtocolVersion(version) => { + self.agent_protocol_version.replace(version); + } + } - message_bus.send(msg).await; - }, + Ok(()) + } - MessageOut::Http(response) => { - self.handle_http_response(response, message_bus).await; - } - }; - }, - }, + #[tracing::instrument(level = Level::TRACE, skip(self, message_bus))] + async fn handle_interceptor_update( + &mut self, + connection_id: ConnectionId, + update: TaskUpdate, + message_bus: &mut MessageBus, + ) { + match update { + TaskUpdate::Finished(res) => { + if let Err(error) = res { + tracing::warn!(connection_id, %error, "Incoming interceptor failed"); + } - Some(task_update) = self.readers.next() => match task_update { - (id, TaskUpdate::Finished(Ok(()))) => { - self.readers_txs.remove(&id); - } + self.metadata_store.no_longer_expect(connection_id); - (id, TaskUpdate::Finished(Err(TaskError::Panic))) => { - tracing::error!(connection_id = id.0, request_id = id.1, "HttpResponseReader task panicked"); + let msg = self + .interceptor_handles + .get(&connection_id) + .map(|interceptor| { + interceptor + .subscription + .wrap_agent_unsubscribe_connection(connection_id) + }); + if let Some(msg) = msg { + message_bus.send(msg).await; + } - self.interceptor_handles.remove(&InterceptorId(id.0)); - self.readers_txs.remove(&id); + self.interceptor_handles.remove(&connection_id); + } + + TaskUpdate::Message(msg) => { + let Some(PortSubscription::Steal(_)) = self + .interceptor_handles + .get(&connection_id) + .map(|interceptor| &interceptor.subscription) + else { + return; + }; + + match msg { + MessageOut::Raw(bytes) => { + let msg = ClientMessage::TcpSteal(LayerTcpSteal::Data(TcpData { + connection_id, + bytes, + })); + + 
message_bus.send(msg).await; } - (_, TaskUpdate::Message(msg)) => { - message_bus.send(ClientMessage::TcpSteal(msg)).await; + MessageOut::Http(response) => { + self.handle_http_response(response, message_bus).await; } + }; + } + } + } + + #[tracing::instrument(level = Level::TRACE, skip(self, message_bus))] + async fn handle_response_reader_update( + &mut self, + connection_id: ConnectionId, + request_id: RequestId, + update: TaskUpdate, + message_bus: &mut MessageBus, + ) { + match update { + TaskUpdate::Finished(Ok(())) => { + self.readers_txs.remove(&(connection_id, request_id)); + } + + TaskUpdate::Finished(Err(TaskError::Panic)) => { + tracing::error!(connection_id, request_id, "HttpResponseReader panicked"); + + self.readers_txs.remove(&(connection_id, request_id)); + if let Some(interceptor) = self.interceptor_handles.remove(&connection_id) { + message_bus + .send( + interceptor + .subscription + .wrap_agent_unsubscribe_connection(connection_id), + ) + .await; } } + + TaskUpdate::Message(msg) => { + message_bus.send(ClientMessage::TcpSteal(msg)).await; + } + } + } +} + +impl BackgroundTask for IncomingProxy { + type Error = IncomingProxyError; + type MessageIn = IncomingProxyMessage; + type MessageOut = ProxyMessage; + + #[tracing::instrument(level = Level::TRACE, name = "incoming_proxy_main_loop", skip_all, err)] + async fn run(mut self, message_bus: &mut MessageBus) -> Result<(), Self::Error> { + loop { + tokio::select! 
{ + msg = message_bus.recv() => match msg { + None => { + tracing::trace!("message bus closed, exiting"); + break Ok(()); + }, + Some(message) => self.handle_message(message, message_bus).await?, + }, + + Some((id, update)) = self.interceptors.next() => self.handle_interceptor_update(id, update, message_bus).await, + + Some((id, update)) = self.readers.next() => self.handle_response_reader_update(id.0, id.1, update, message_bus).await, + } } } } diff --git a/mirrord/intproxy/src/proxies/incoming/http.rs b/mirrord/intproxy/src/proxies/incoming/http.rs index 8c51228b545..e8687bdbb0d 100644 --- a/mirrord/intproxy/src/proxies/incoming/http.rs +++ b/mirrord/intproxy/src/proxies/incoming/http.rs @@ -1,4 +1,10 @@ -use std::{error::Error, fmt, io, net::SocketAddr, ops::Not, time::Duration}; +use std::{ + error::Error, + fmt, io, + net::SocketAddr, + ops::Not, + time::{Duration, Instant}, +}; use bytes::Bytes; use exponential_backoff::Backoff; @@ -52,28 +58,48 @@ impl LocalHttpClient { } /// Reuses or creates a new [`HttpSender`]. - #[tracing::instrument(level = Level::TRACE, err(level = Level::TRACE))] async fn get_sender(&mut self, version: Version) -> Result { if let Some(sender) = self.sender.take() { if sender.version_matches(version) { + tracing::trace!("Reusing the HTTP connection."); return Ok(sender); + } else { + tracing::trace!("HTTP connection found, but the HTTP version does not match."); } } let stream = match self.stream.take() { - Some(stream) => stream, + Some(stream) => { + tracing::trace!("Reusing the TCP connection."); + stream + } None => { let socket = BoundTcpSocket::bind_specified_or_localhost(self.local_server_address.ip()) .map_err(LocalHttpError::SocketSetupFailed)?; - socket + + let start = Instant::now(); + let socket = socket .connect(self.local_server_address) .await - .map_err(LocalHttpError::ConnectTcpFailed)? 
+ .map_err(LocalHttpError::ConnectTcpFailed)?; + tracing::trace!( + elapsed_s = start.elapsed().as_secs_f32(), + "Made the TCP connection" + ); + + socket } }; - HttpSender::handshake(version, stream).await + let start = Instant::now(); + let sender = HttpSender::handshake(version, stream).await?; + tracing::trace!( + elapsed_s = start.elapsed().as_secs_f32(), + "Made the HTTP connection" + ); + + Ok(sender) } /// Tries to send the given `request` to the user application's HTTP server. @@ -84,7 +110,14 @@ impl LocalHttpClient { request: &HttpRequest, ) -> Result, LocalHttpError> { let mut sender = self.get_sender(request.version()).await?; + + let start = Instant::now(); let response = sender.send_request(request.clone()).await?; + tracing::trace!( + elapsed_s = start.elapsed().as_secs_f32(), + "Sent the HTTP request" + ); + let (parts, mut body) = response.into_parts(); let frames = body @@ -122,10 +155,6 @@ impl LocalHttpClient { tracing::trace!(attempt, "Trying to send the request"); match (self.try_send_request(request).await, backoffs.next()) { (Ok(response), _) => { - tracing::trace!( - attempt, - "Successfully sent the request and peeked first frames" - ); break Ok(response); } @@ -267,7 +296,6 @@ enum HttpSender { impl HttpSender { /// Performs an HTTP handshake over the given [`TcpStream`]. - #[tracing::instrument(level = Level::DEBUG, skip(target_stream), err(level = Level::WARN))] async fn handshake(version: Version, target_stream: TcpStream) -> Result { let local_addr = target_stream .local_addr() @@ -321,13 +349,6 @@ impl HttpSender { } /// Tries to send the given [`HttpRequest`] to the server. 
- #[tracing::instrument( - level = Level::DEBUG, - skip(self, request), - fields(connection_id = request.connection_id, request_id = request.request_id), - ret, - err(level = Level::WARN), - )] async fn send_request( &mut self, request: HttpRequest, diff --git a/mirrord/intproxy/src/proxies/incoming/interceptor.rs b/mirrord/intproxy/src/proxies/incoming/interceptor.rs index b509f74603c..5fb37a44799 100644 --- a/mirrord/intproxy/src/proxies/incoming/interceptor.rs +++ b/mirrord/intproxy/src/proxies/incoming/interceptor.rs @@ -120,7 +120,12 @@ impl BackgroundTask for Interceptor { type MessageIn = MessageIn; type MessageOut = MessageOut; - #[tracing::instrument(level = Level::TRACE, skip_all, err)] + #[tracing::instrument( + level = Level::TRACE, + name = "incoming_interceptor_main_loop", + skip_all, fields(peer_addr = %self.peer), + err(level = Level::WARN) + )] async fn run(self, message_bus: &mut MessageBus) -> Result<(), Self::Error> { let stream = self .socket @@ -136,25 +141,28 @@ impl BackgroundTask for Interceptor { // We should not block until the agent has something, we don't know what this protocol looks like. 
result = stream.readable() => { result.map_err(InterceptorError::IoFailed)?; - return RawConnection { stream }.run(message_bus).await; + tracing::trace!("TCP connection became readable, assuming raw TCP"); + RawConnection { stream }.run(message_bus).await } message = message_bus.peek() => match message { - Some(MessageIn::Http(..)) => {} + Some(MessageIn::Http(..)) => { + tracing::trace!("Next message on the message bus is an HTTP request, running as an HTTP gateway"); + HttpConnection { + local_client: LocalHttpClient::new_for_stream(stream)?, + } + .run(message_bus) + .await + } Some(MessageIn::Raw(..)) => { - return RawConnection { stream }.run(message_bus).await; - } + tracing::trace!("Next message on the message bus is raw TCP data, running as a TCP proxy"); + RawConnection { stream }.run(message_bus).await + }, - None => return Ok(()), + None => Ok(()), } } - - let http_conn = HttpConnection { - local_client: LocalHttpClient::new_for_stream(stream)?, - }; - - http_conn.run(message_bus).await } } @@ -229,7 +237,12 @@ impl HttpConnection { /// /// When an HTTP upgrade happens, the underlying [`TcpStream`] is reclaimed and wrapped /// in a [`RawConnection`], which handles the rest of the connection. - #[tracing::instrument(level = Level::TRACE, skip_all, ret, err)] + #[tracing::instrument( + level = Level::TRACE, + name = "http_connection_main_loop", + skip_all, ret, + err(level = Level::WARN), + )] async fn run( mut self, mut message_bus: PeekableMessageBus<'_, Interceptor>, @@ -250,6 +263,7 @@ impl HttpConnection { message_bus.send(MessageOut::Http(res)).await; if let Some(on_upgrade) = on_upgrade { + tracing::trace!("Detected an HTTP upgrade"); break on_upgrade .await .map_err(InterceptorError::HttpUpgradeFailed)?; @@ -293,6 +307,12 @@ impl RawConnection { /// /// 3. This implementation exits only when an error is encountered or the [`MessageBus`] is /// closed. 
+ #[tracing::instrument( + level = Level::TRACE, + name = "raw_connection_main_loop", + skip_all, ret, + err(level = Level::WARN), + )] async fn run( mut self, mut message_bus: PeekableMessageBus<'_, Interceptor>, @@ -310,7 +330,7 @@ impl RawConnection { Err(e) => break Err(InterceptorError::IoFailed(e)), Ok(..) => { if buf.is_empty() { - tracing::trace!("incoming interceptor -> layer shutdown, sending a 0-sized read to inform the agent"); + tracing::trace!("layer shutdown, sending a 0-sized read to inform the agent"); reading_closed = true; } message_bus.send(MessageOut::Raw(buf.to_vec())).await; @@ -320,12 +340,12 @@ impl RawConnection { msg = message_bus.recv(), if !remote_closed => match msg { None => { - tracing::trace!("incoming interceptor -> message bus closed, waiting 1 second before exiting"); + tracing::trace!("message bus closed, waiting 1 second before exiting"); remote_closed = true; }, Some(MessageIn::Raw(data)) => { if data.is_empty() { - tracing::trace!("incoming interceptor -> agent shutdown, shutting down connection with layer"); + tracing::trace!("agent shutdown, shutting down connection with layer"); self.stream.shutdown().await.map_err(InterceptorError::IoFailed)?; } else { self.stream.write_all(&data).await.map_err(InterceptorError::IoFailed)?; @@ -335,7 +355,7 @@ impl RawConnection { }, _ = time::sleep(Duration::from_secs(1)), if remote_closed => { - tracing::trace!("incoming interceptor -> layer silent for 1 second and message bus is closed, exiting"); + tracing::trace!("layer silent for 1 second and message bus is closed, exiting"); break Ok(()); }, diff --git a/mirrord/intproxy/src/proxies/incoming/response_reader.rs b/mirrord/intproxy/src/proxies/incoming/response_reader.rs index 56ef95ac1a4..d07a1ca723b 100644 --- a/mirrord/intproxy/src/proxies/incoming/response_reader.rs +++ b/mirrord/intproxy/src/proxies/incoming/response_reader.rs @@ -1,4 +1,4 @@ -use std::convert::Infallible; +use std::{convert::Infallible, time::Instant}; use 
http_body_util::BodyExt; use hyper::body::{Frame, Incoming}; @@ -10,6 +10,7 @@ use mirrord_protocol::{ }, ConnectionId, RequestId, }; +use tracing::Level; use super::http::PeekedBody; use crate::{ @@ -27,23 +28,67 @@ pub enum HttpResponseReader { }, } +impl HttpResponseReader { + fn request_id(&self) -> RequestId { + match self { + Self::Legacy(response) => response.request_id, + Self::Framed(response) => response.request_id, + Self::Chunked { request_id, .. } => *request_id, + } + } + + fn connection_id(&self) -> ConnectionId { + match self { + Self::Legacy(response) => response.connection_id, + Self::Framed(response) => response.connection_id, + Self::Chunked { connection_id, .. } => *connection_id, + } + } +} + impl BackgroundTask for HttpResponseReader { type Error = Infallible; type MessageIn = Infallible; type MessageOut = LayerTcpSteal; + #[tracing::instrument( + level = Level::TRACE, + name = "http_response_reader_main_loop", + fields( + connection_id = self.connection_id(), + request_id = self.request_id(), + ), + skip_all, err, + )] async fn run(self, message_bus: &mut MessageBus) -> Result<(), Self::Error> { match self { Self::Legacy(mut response) => { let tail = match response.internal_response.body.tail.take() { Some(incoming) => { + let start = Instant::now(); tokio::select! 
{ - _ = message_bus.recv() => return Ok(()), + _ = message_bus.recv() => { + tracing::trace!("Message bus closed, exiting"); + return Ok(()); + }, result = incoming.collect() => match result { - Ok(data) => Vec::from(data.to_bytes()), + Ok(data) => { + tracing::trace!( + elapsed_s = start.elapsed().as_secs_f32(), + "Collected the whole body.", + ); + Vec::from(data.to_bytes()) + }, Err(error) => { + tracing::warn!( + connection_id = response.connection_id, + request_id = response.request_id, + %error, + "Failed to read the response body.", + ); + let response = LocalHttpError::ReadBodyFailed(error) .as_error_response( response.internal_response.version, @@ -52,6 +97,7 @@ impl BackgroundTask for HttpResponseReader { response.port, ); message_bus.send(LayerTcpSteal::HttpResponse(response)).await; + return Ok(()); } } @@ -88,19 +134,35 @@ impl BackgroundTask for HttpResponseReader { Self::Framed(mut response) => { if let Some(mut incoming) = response.internal_response.body.tail.take() { + let start = Instant::now(); loop { tokio::select! { - _ = message_bus.recv() => return Ok(()), + _ = message_bus.recv() => { + tracing::trace!("Message bus closed, exiting"); + return Ok(()); + }, result = incoming.next_frames() => match result { Ok(data) => { response.internal_response.body.head.extend(data.frames); + if data.is_last { + tracing::trace!( + elapsed_s = start.elapsed().as_secs_f32(), + "Collected the whole response body." 
+ ); break; } }, Err(error) => { + tracing::warn!( + connection_id = response.connection_id, + request_id = response.request_id, + %error, + "Failed to read the response body.", + ); + let response = LocalHttpError::ReadBodyFailed(error) .as_error_response( response.internal_response.version, @@ -109,6 +171,7 @@ impl BackgroundTask for HttpResponseReader { response.port, ); message_bus.send(LayerTcpSteal::HttpResponse(response)).await; + return Ok(()); } } @@ -134,37 +197,54 @@ impl BackgroundTask for HttpResponseReader { connection_id, request_id, mut body, - } => loop { - tokio::select! { - _ = message_bus.recv() => return Ok(()), - - result = body.next_frames() => match result { - Ok(data) => { - let message = LayerTcpSteal::HttpResponseChunked(ChunkedResponse::Body(ChunkedHttpBody { - frames: data.frames.into_iter().map(InternalHttpBodyFrame::from).collect(), - is_last: data.is_last, - connection_id, - request_id, - })); - message_bus.send(message).await; - - if data.is_last { - break; - } + } => { + let start = Instant::now(); + loop { + tokio::select! { + _ = message_bus.recv() => { + tracing::trace!("Message bus closed, exiting"); + return Ok(()) }, - Err(..) => { - let message = LayerTcpSteal::HttpResponseChunked(ChunkedResponse::Error(ChunkedHttpError { - connection_id, - request_id, - })); - message_bus.send(message).await; - - return Ok(()); + result = body.next_frames() => match result { + Ok(data) => { + let message = LayerTcpSteal::HttpResponseChunked(ChunkedResponse::Body(ChunkedHttpBody { + frames: data.frames.into_iter().map(InternalHttpBodyFrame::from).collect(), + is_last: data.is_last, + connection_id, + request_id, + })); + message_bus.send(message).await; + + if data.is_last { + tracing::trace!( + elapsed_s = start.elapsed().as_secs_f32(), + "Collected the whole response body." 
+ ); + break; + } + }, + + Err(error) => { + tracing::warn!( + connection_id, + request_id, + %error, + "Failed to read the response body.", + ); + + let message = LayerTcpSteal::HttpResponseChunked(ChunkedResponse::Error(ChunkedHttpError { + connection_id, + request_id, + })); + message_bus.send(message).await; + + return Ok(()); + } } } } - }, + } } Ok(()) From f59a62ccf327173f06fd83e74e4f153cfa951d84 Mon Sep 17 00:00:00 2001 From: Razz4780 Date: Wed, 15 Jan 2025 10:05:41 +0100 Subject: [PATCH 12/60] HttpResponseReader logic split into methods --- .../src/proxies/incoming/response_reader.rs | 403 ++++++++++-------- 1 file changed, 224 insertions(+), 179 deletions(-) diff --git a/mirrord/intproxy/src/proxies/incoming/response_reader.rs b/mirrord/intproxy/src/proxies/incoming/response_reader.rs index d07a1ca723b..b27c209a8c7 100644 --- a/mirrord/intproxy/src/proxies/incoming/response_reader.rs +++ b/mirrord/intproxy/src/proxies/incoming/response_reader.rs @@ -18,9 +18,16 @@ use crate::{ proxies::incoming::http::LocalHttpError, }; +/// Background task responsible for asynchronous read of an HTTP response body coming from the user +/// application. +/// +/// Meant to be run as a [`BackgroundTask`]. pub enum HttpResponseReader { + /// Produces a [`LayerTcpSteal::HttpResponse`] message. Legacy(HttpResponse), + /// Produces a [`LayerTcpSteal::HttpResponseFramed`] message. Framed(HttpResponse), + /// Produces [`LayerTcpSteal::HttpResponseChunked`] messasages. Chunked { connection_id: ConnectionId, request_id: RequestId, @@ -44,6 +51,218 @@ impl HttpResponseReader { Self::Chunked { connection_id, .. } => *connection_id, } } + + /// Reads the body and produces a [`LayerTcpSteal::HttpResponse`] message. + /// + /// When reading the body fails, produces a [`LayerTcpSteal::HttpResponse`] error response. 
+ async fn run_legacy( + mut response: HttpResponse, + message_bus: &mut MessageBus, + ) { + let tail = match response.internal_response.body.tail.take() { + Some(incoming) => { + let start = Instant::now(); + let result = tokio::select! { + _ = message_bus.recv() => { + tracing::trace!("Message bus closed, exiting"); + return; + }, + + result = incoming.collect() => result, + }; + + match result { + Ok(data) => { + tracing::trace!( + elapsed_s = start.elapsed().as_secs_f32(), + "Collected the whole body.", + ); + Vec::from(data.to_bytes()) + } + + Err(error) => { + tracing::warn!( + connection_id = response.connection_id, + request_id = response.request_id, + %error, + "Failed to read the response body.", + ); + + let response = LocalHttpError::ReadBodyFailed(error).as_error_response( + response.internal_response.version, + response.request_id, + response.connection_id, + response.port, + ); + message_bus + .send(LayerTcpSteal::HttpResponse(response)) + .await; + return; + } + } + } + + None => vec![], + }; + + let response = response.map_body(|body| { + let mut complete = Vec::with_capacity( + body.head + .iter() + .filter_map(|frame| Some(frame.data_ref()?.len())) + .sum::() + + tail.len(), + ); + for frame in body + .head + .into_iter() + .map(Frame::into_data) + .filter_map(Result::ok) + { + complete.extend(frame); + } + complete.extend(tail); + complete + }); + + message_bus + .send(LayerTcpSteal::HttpResponse(response)) + .await; + } + + /// Reads the body and produces a [`LayerTcpSteal::HttpResponseFramed`] message. + /// + /// When reading the body fails, produces a [`LayerTcpSteal::HttpResponse`] error response. + async fn run_framed( + mut response: HttpResponse, + message_bus: &mut MessageBus, + ) { + if let Some(mut incoming) = response.internal_response.body.tail.take() { + let start = Instant::now(); + loop { + let result = tokio::select! 
{ + _ = message_bus.recv() => { + tracing::trace!("Message bus closed, exiting"); + return; + }, + + result = incoming.next_frames() => result, + }; + + match result { + Ok(data) => { + response.internal_response.body.head.extend(data.frames); + + if data.is_last { + tracing::trace!( + elapsed_s = start.elapsed().as_secs_f32(), + "Collected the whole response body." + ); + break; + } + } + + Err(error) => { + tracing::warn!( + connection_id = response.connection_id, + request_id = response.request_id, + %error, + "Failed to read the response body.", + ); + + let response = LocalHttpError::ReadBodyFailed(error).as_error_response( + response.internal_response.version, + response.request_id, + response.connection_id, + response.port, + ); + message_bus + .send(LayerTcpSteal::HttpResponse(response)) + .await; + return; + } + } + } + }; + + let response = response.map_body(|body| { + InternalHttpBody( + body.head + .into_iter() + .map(InternalHttpBodyFrame::from) + .collect(), + ) + }); + + message_bus + .send(LayerTcpSteal::HttpResponseFramed(response)) + .await; + } + + /// Reads the body and produces [`LayerTcpSteal::HttpResponseChunked`] messages. + async fn run_chunked( + connection_id: ConnectionId, + request_id: RequestId, + mut body: Incoming, + message_bus: &mut MessageBus, + ) { + let start = Instant::now(); + loop { + let result = tokio::select! { + _ = message_bus.recv() => { + tracing::trace!("Message bus closed, exiting"); + return; + }, + + result = body.next_frames() => result, + }; + + match result { + Ok(data) => { + let message = LayerTcpSteal::HttpResponseChunked(ChunkedResponse::Body( + ChunkedHttpBody { + frames: data + .frames + .into_iter() + .map(InternalHttpBodyFrame::from) + .collect(), + is_last: data.is_last, + connection_id, + request_id, + }, + )); + message_bus.send(message).await; + + if data.is_last { + tracing::trace!( + elapsed_s = start.elapsed().as_secs_f32(), + "Collected the whole response body." 
+ ); + break; + } + } + + Err(error) => { + tracing::warn!( + connection_id, + request_id, + %error, + "Failed to read the response body.", + ); + + let message = LayerTcpSteal::HttpResponseChunked(ChunkedResponse::Error( + ChunkedHttpError { + connection_id, + request_id, + }, + )); + message_bus.send(message).await; + + return; + } + } + } + } } impl BackgroundTask for HttpResponseReader { @@ -58,193 +277,19 @@ impl BackgroundTask for HttpResponseReader { connection_id = self.connection_id(), request_id = self.request_id(), ), - skip_all, err, + skip_all, )] async fn run(self, message_bus: &mut MessageBus) -> Result<(), Self::Error> { match self { - Self::Legacy(mut response) => { - let tail = match response.internal_response.body.tail.take() { - Some(incoming) => { - let start = Instant::now(); - tokio::select! { - _ = message_bus.recv() => { - tracing::trace!("Message bus closed, exiting"); - return Ok(()); - }, - - result = incoming.collect() => match result { - Ok(data) => { - tracing::trace!( - elapsed_s = start.elapsed().as_secs_f32(), - "Collected the whole body.", - ); - Vec::from(data.to_bytes()) - }, - - Err(error) => { - tracing::warn!( - connection_id = response.connection_id, - request_id = response.request_id, - %error, - "Failed to read the response body.", - ); - - let response = LocalHttpError::ReadBodyFailed(error) - .as_error_response( - response.internal_response.version, - response.request_id, - response.connection_id, - response.port, - ); - message_bus.send(LayerTcpSteal::HttpResponse(response)).await; - - return Ok(()); - } - } - } - } + Self::Legacy(response) => Self::run_legacy(response, message_bus).await, - None => vec![], - }; - - let response = response.map_body(|body| { - let mut complete = Vec::with_capacity( - body.head - .iter() - .filter_map(|frame| Some(frame.data_ref()?.len())) - .sum::() - + tail.len(), - ); - for frame in body - .head - .into_iter() - .map(Frame::into_data) - .filter_map(Result::ok) - { - 
complete.extend(frame); - } - complete.extend(tail); - complete - }); - - message_bus - .send(LayerTcpSteal::HttpResponse(response)) - .await; - } - - Self::Framed(mut response) => { - if let Some(mut incoming) = response.internal_response.body.tail.take() { - let start = Instant::now(); - loop { - tokio::select! { - _ = message_bus.recv() => { - tracing::trace!("Message bus closed, exiting"); - return Ok(()); - }, - - result = incoming.next_frames() => match result { - Ok(data) => { - response.internal_response.body.head.extend(data.frames); - - if data.is_last { - tracing::trace!( - elapsed_s = start.elapsed().as_secs_f32(), - "Collected the whole response body." - ); - break; - } - }, - - Err(error) => { - tracing::warn!( - connection_id = response.connection_id, - request_id = response.request_id, - %error, - "Failed to read the response body.", - ); - - let response = LocalHttpError::ReadBodyFailed(error) - .as_error_response( - response.internal_response.version, - response.request_id, - response.connection_id, - response.port, - ); - message_bus.send(LayerTcpSteal::HttpResponse(response)).await; - - return Ok(()); - } - } - } - } - }; - - let response = response.map_body(|body| { - InternalHttpBody( - body.head - .into_iter() - .map(InternalHttpBodyFrame::from) - .collect(), - ) - }); - - message_bus - .send(LayerTcpSteal::HttpResponseFramed(response)) - .await; - } + Self::Framed(response) => Self::run_framed(response, message_bus).await, Self::Chunked { connection_id, request_id, - mut body, - } => { - let start = Instant::now(); - loop { - tokio::select! 
{ - _ = message_bus.recv() => { - tracing::trace!("Message bus closed, exiting"); - return Ok(()) - }, - - result = body.next_frames() => match result { - Ok(data) => { - let message = LayerTcpSteal::HttpResponseChunked(ChunkedResponse::Body(ChunkedHttpBody { - frames: data.frames.into_iter().map(InternalHttpBodyFrame::from).collect(), - is_last: data.is_last, - connection_id, - request_id, - })); - message_bus.send(message).await; - - if data.is_last { - tracing::trace!( - elapsed_s = start.elapsed().as_secs_f32(), - "Collected the whole response body." - ); - break; - } - }, - - Err(error) => { - tracing::warn!( - connection_id, - request_id, - %error, - "Failed to read the response body.", - ); - - let message = LayerTcpSteal::HttpResponseChunked(ChunkedResponse::Error(ChunkedHttpError { - connection_id, - request_id, - })); - message_bus.send(message).await; - - return Ok(()); - } - } - } - } - } + body, + } => Self::run_chunked(connection_id, request_id, body, message_bus).await, } Ok(()) From 600d958f6147c764cd2a7223dd4349a3794674b3 Mon Sep 17 00:00:00 2001 From: Razz4780 Date: Wed, 15 Jan 2025 10:55:46 +0100 Subject: [PATCH 13/60] unless_bus_closed --- mirrord/intproxy/src/background_tasks.rs | 20 +++++++++ mirrord/intproxy/src/proxies/incoming.rs | 6 +++ .../src/proxies/incoming/interceptor.rs | 41 +++++++++++++------ .../src/proxies/incoming/response_reader.rs | 33 +++++---------- 4 files changed, 65 insertions(+), 35 deletions(-) diff --git a/mirrord/intproxy/src/background_tasks.rs b/mirrord/intproxy/src/background_tasks.rs index 99d27a8ab8d..f6ffc7feea8 100644 --- a/mirrord/intproxy/src/background_tasks.rs +++ b/mirrord/intproxy/src/background_tasks.rs @@ -46,6 +46,21 @@ impl MessageBus { } } +/// Runs the given [`Future`] and returns its output, unless the given [`MessageBus`] gets closed +/// first. In that case, returns [`None`]. +/// +/// Can be used to make sure [`BackgroundTask`]s don't linger when they are no longer needed. 
+pub async fn unless_bus_closed(message_bus: &MessageBus, future: F) -> Option +where + F: Future, + T: BackgroundTask, +{ + tokio::select! { + _ = message_bus.tx.closed() => None, + output = future => Some(output), + } +} + /// Wrapper over [`MessageBus`]. /// /// Allows for peeking the next incoming message. @@ -88,6 +103,11 @@ impl<'a, T: BackgroundTask> PeekableMessageBus<'a, T> { }, } } + + /// Returns an immutable reference to the wrapped [`MessageBus`]. + pub fn inner(&self) -> &MessageBus { + self.message_bus + } } /// Common trait for all background tasks in the internal proxy. diff --git a/mirrord/intproxy/src/proxies/incoming.rs b/mirrord/intproxy/src/proxies/incoming.rs index 77b1a6ea54c..66105aa4176 100644 --- a/mirrord/intproxy/src/proxies/incoming.rs +++ b/mirrord/intproxy/src/proxies/incoming.rs @@ -1,4 +1,10 @@ //! Handles the logic of the `incoming` feature. +//! +//! +//! Background tasks: +//! 1. TcpProxy - always handles remote connection first. Attempts to connect a couple times. Waits +//! until connection becomes readable (is TCP) or receives an http request. +//! 2. 
HttpSender - use std::{ collections::{hash_map::Entry, HashMap, VecDeque}, diff --git a/mirrord/intproxy/src/proxies/incoming/interceptor.rs b/mirrord/intproxy/src/proxies/incoming/interceptor.rs index 5fb37a44799..e58fde21cee 100644 --- a/mirrord/intproxy/src/proxies/incoming/interceptor.rs +++ b/mirrord/intproxy/src/proxies/incoming/interceptor.rs @@ -24,7 +24,7 @@ use super::{ streaming_body::StreamingBody, }; use crate::{ - background_tasks::{BackgroundTask, MessageBus, PeekableMessageBus}, + background_tasks::{unless_bus_closed, BackgroundTask, MessageBus, PeekableMessageBus}, proxies::incoming::bound_socket::BoundTcpSocket, }; @@ -127,11 +127,12 @@ impl BackgroundTask for Interceptor { err(level = Level::WARN) )] async fn run(self, message_bus: &mut MessageBus) -> Result<(), Self::Error> { - let stream = self - .socket - .connect(self.peer) - .await - .map_err(InterceptorError::ConnectFailed)?; + let Some(result) = unless_bus_closed(message_bus, self.socket.connect(self.peer)).await + else { + tracing::trace!("Message bus closed, exiting"); + return Ok(()); + }; + let stream = result.map_err(InterceptorError::ConnectFailed)?; let mut message_bus = message_bus.peekable(); tokio::select! 
{ @@ -258,16 +259,30 @@ impl HttpConnection { } Some(MessageIn::Http(request)) => { - let result = self.local_client.send_request(&request).await; + let Some(result) = unless_bus_closed( + message_bus.inner(), + self.local_client.send_request(&request), + ) + .await + else { + tracing::trace!("Message bus closed, exiting"); + return Ok(()); + }; let (res, on_upgrade) = Self::handle_send_result(request, result)?; message_bus.send(MessageOut::Http(res)).await; - if let Some(on_upgrade) = on_upgrade { - tracing::trace!("Detected an HTTP upgrade"); - break on_upgrade - .await - .map_err(InterceptorError::HttpUpgradeFailed)?; - } + let Some(on_upgrade) = on_upgrade else { + continue; + }; + + tracing::trace!("Detected an HTTP upgrade"); + let Some(result) = unless_bus_closed(message_bus.inner(), on_upgrade).await + else { + tracing::trace!("Message bus closed, exiting"); + return Ok(()); + }; + + break result.map_err(InterceptorError::HttpUpgradeFailed)?; } } }; diff --git a/mirrord/intproxy/src/proxies/incoming/response_reader.rs b/mirrord/intproxy/src/proxies/incoming/response_reader.rs index b27c209a8c7..816319e2036 100644 --- a/mirrord/intproxy/src/proxies/incoming/response_reader.rs +++ b/mirrord/intproxy/src/proxies/incoming/response_reader.rs @@ -14,7 +14,7 @@ use tracing::Level; use super::http::PeekedBody; use crate::{ - background_tasks::{BackgroundTask, MessageBus}, + background_tasks::{unless_bus_closed, BackgroundTask, MessageBus}, proxies::incoming::http::LocalHttpError, }; @@ -62,13 +62,9 @@ impl HttpResponseReader { let tail = match response.internal_response.body.tail.take() { Some(incoming) => { let start = Instant::now(); - let result = tokio::select! 
{ - _ = message_bus.recv() => { - tracing::trace!("Message bus closed, exiting"); - return; - }, - - result = incoming.collect() => result, + let Some(result) = unless_bus_closed(message_bus, incoming.collect()).await else { + tracing::trace!("Message bus closed, exiting"); + return; }; match result { @@ -140,13 +136,10 @@ impl HttpResponseReader { if let Some(mut incoming) = response.internal_response.body.tail.take() { let start = Instant::now(); loop { - let result = tokio::select! { - _ = message_bus.recv() => { - tracing::trace!("Message bus closed, exiting"); - return; - }, - - result = incoming.next_frames() => result, + let Some(result) = unless_bus_closed(message_bus, incoming.next_frames()).await + else { + tracing::trace!("Message bus closed, exiting"); + return; }; match result { @@ -208,13 +201,9 @@ impl HttpResponseReader { ) { let start = Instant::now(); loop { - let result = tokio::select! { - _ = message_bus.recv() => { - tracing::trace!("Message bus closed, exiting"); - return; - }, - - result = body.next_frames() => result, + let Some(result) = unless_bus_closed(message_bus, body.next_frames()).await else { + tracing::trace!("Message bus closed, exiting"); + return; }; match result { From 8b15a86bcc314e92f6058fe37e2d56f5a095a8a9 Mon Sep 17 00:00:00 2001 From: Razz4780 Date: Thu, 16 Jan 2025 18:01:16 +0100 Subject: [PATCH 14/60] I hate this --- Cargo.lock | 1 - mirrord/intproxy/Cargo.toml | 1 - mirrord/intproxy/src/background_tasks.rs | 71 +- mirrord/intproxy/src/proxies/incoming.rs | 648 ++++++++--------- mirrord/intproxy/src/proxies/incoming/http.rs | 250 ++----- .../src/proxies/incoming/http/client_store.rs | 142 ++++ .../proxies/incoming/http/response_mode.rs | 21 + .../incoming/{ => http}/streaming_body.rs | 0 .../src/proxies/incoming/http_gateway.rs | 682 ++++++++++++++++++ .../src/proxies/incoming/interceptor.rs | 664 ----------------- .../src/proxies/incoming/response_reader.rs | 286 -------- .../intproxy/src/proxies/incoming/tasks.rs | 63 
++ .../src/proxies/incoming/tcp_proxy.rs | 120 +++ 13 files changed, 1404 insertions(+), 1545 deletions(-) create mode 100644 mirrord/intproxy/src/proxies/incoming/http/client_store.rs create mode 100644 mirrord/intproxy/src/proxies/incoming/http/response_mode.rs rename mirrord/intproxy/src/proxies/incoming/{ => http}/streaming_body.rs (100%) create mode 100644 mirrord/intproxy/src/proxies/incoming/http_gateway.rs delete mode 100644 mirrord/intproxy/src/proxies/incoming/interceptor.rs delete mode 100644 mirrord/intproxy/src/proxies/incoming/response_reader.rs create mode 100644 mirrord/intproxy/src/proxies/incoming/tasks.rs create mode 100644 mirrord/intproxy/src/proxies/incoming/tcp_proxy.rs diff --git a/Cargo.lock b/Cargo.lock index 61932caa8e4..0a54470c5b1 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -4313,7 +4313,6 @@ version = "3.128.0" dependencies = [ "bytes", "exponential-backoff", - "futures", "h2 0.4.7", "http-body-util", "hyper 1.5.2", diff --git a/mirrord/intproxy/Cargo.toml b/mirrord/intproxy/Cargo.toml index f7972fce76a..b132de1ed37 100644 --- a/mirrord/intproxy/Cargo.toml +++ b/mirrord/intproxy/Cargo.toml @@ -45,6 +45,5 @@ rustls-pemfile.workspace = true exponential-backoff = "2" [dev-dependencies] -futures.workspace = true reqwest.workspace = true rstest.workspace = true diff --git a/mirrord/intproxy/src/background_tasks.rs b/mirrord/intproxy/src/background_tasks.rs index f6ffc7feea8..02092c32a43 100644 --- a/mirrord/intproxy/src/background_tasks.rs +++ b/mirrord/intproxy/src/background_tasks.rs @@ -23,14 +23,6 @@ pub struct MessageBus { } impl MessageBus { - /// Wraps this message bus into a struct that allows for peeking the next incoming message. - pub fn peekable(&mut self) -> PeekableMessageBus<'_, T> { - PeekableMessageBus { - peeked: None, - message_bus: self, - } - } - /// Attempts to send a message to this task's parent. 
pub async fn send>(&self, msg: M) { let _ = self.tx.send(msg.into()).await; @@ -44,70 +36,21 @@ impl MessageBus { msg = self.rx.recv() => msg, } } -} -/// Runs the given [`Future`] and returns its output, unless the given [`MessageBus`] gets closed -/// first. In that case, returns [`None`]. -/// -/// Can be used to make sure [`BackgroundTask`]s don't linger when they are no longer needed. -pub async fn unless_bus_closed(message_bus: &MessageBus, future: F) -> Option -where - F: Future, - T: BackgroundTask, -{ - tokio::select! { - _ = message_bus.tx.closed() => None, - output = future => Some(output), + pub fn closed(&self) -> Closed { + Closed(self.tx.clone()) } } -/// Wrapper over [`MessageBus`]. -/// -/// Allows for peeking the next incoming message. -pub struct PeekableMessageBus<'a, T: BackgroundTask> { - peeked: Option, - message_bus: &'a mut MessageBus, -} - -impl<'a, T: BackgroundTask> PeekableMessageBus<'a, T> { - /// Attempts to send a message to this task's parent. - pub async fn send>(&self, msg: M) { - let _ = self.message_bus.tx.send(msg.into()).await; - } - - /// Receives a message from this task's parent. - /// [`None`] means that the channel is closed and there will be no more messages. - pub async fn recv(&mut self) -> Option { - if self.peeked.is_some() { - return self.peeked.take(); - } +pub struct Closed(Sender); +impl Closed { + pub async fn cancel_on_close(&self, future: F) -> Option { tokio::select! { - _ = self.message_bus.tx.closed() => None, - msg = self.message_bus.rx.recv() => msg, + _ = self.0.closed() => None, + output = future => Some(output) } } - - /// Peeks the next message from this task's parent. - /// [`None`] means that the channel is closed and there will be no more messages. - pub async fn peek(&mut self) -> Option<&T::MessageIn> { - if self.peeked.is_some() { - return self.peeked.as_ref(); - } - - tokio::select! 
{ - _ = self.message_bus.tx.closed() => None, - msg = self.message_bus.rx.recv() => { - self.peeked = msg; - self.peeked.as_ref() - }, - } - } - - /// Returns an immutable reference to the wrapped [`MessageBus`]. - pub fn inner(&self) -> &MessageBus { - self.message_bus - } } /// Common trait for all background tasks in the internal proxy. diff --git a/mirrord/intproxy/src/proxies/incoming.rs b/mirrord/intproxy/src/proxies/incoming.rs index 66105aa4176..8ee466ae11e 100644 --- a/mirrord/intproxy/src/proxies/incoming.rs +++ b/mirrord/intproxy/src/proxies/incoming.rs @@ -7,15 +7,14 @@ //! 2. HttpSender - use std::{ - collections::{hash_map::Entry, HashMap, VecDeque}, - convert::Infallible, + collections::{hash_map::Entry, HashMap}, io, net::SocketAddr, }; use bound_socket::BoundTcpSocket; -use http::PeekedBody; -use hyper::body::Frame; +use http::{ClientStore, ResponseMode, StreamingBody}; +use http_gateway::HttpGatewayTask; use metadata_store::MetadataStore; use mirrord_intproxy_protocol::{ ConnMetadataRequest, ConnMetadataResponse, IncomingRequest, IncomingResponse, LayerId, @@ -23,23 +22,18 @@ use mirrord_intproxy_protocol::{ }; use mirrord_protocol::{ tcp::{ - ChunkedHttpBody, ChunkedHttpError, ChunkedRequest, ChunkedResponse, DaemonTcp, - HttpResponse, InternalHttpBody, InternalHttpBodyFrame, LayerTcp, LayerTcpSteal, - NewTcpConnection, TcpData, HTTP_CHUNKED_RESPONSE_VERSION, HTTP_FRAMED_VERSION, + ChunkedHttpBody, ChunkedHttpError, ChunkedRequest, DaemonTcp, HttpRequest, + InternalHttpBodyFrame, LayerTcp, LayerTcpSteal, NewTcpConnection, StealType, TcpData, }, - ClientMessage, ConnectionId, Port, RequestId, ResponseError, + ClientMessage, ConnectionId, RequestId, ResponseError, }; -use response_reader::HttpResponseReader; -use streaming_body::StreamingBody; +use tasks::{HttpGatewayId, HttpOut, InProxyTask, InProxyTaskError, InProxyTaskMessage}; +use tcp_proxy::{LocalTcpConnection, TcpProxyTask}; use thiserror::Error; -use tokio::sync::mpsc::{self, Sender}; +use 
tokio::sync::mpsc; use tracing::Level; -use self::{ - interceptor::{Interceptor, InterceptorError, MessageOut}, - port_subscription_ext::PortSubscriptionExt, - subscriptions::SubscriptionsManager, -}; +use self::subscriptions::SubscriptionsManager; use crate::{ background_tasks::{ BackgroundTask, BackgroundTasks, MessageBus, TaskError, TaskSender, TaskUpdate, @@ -50,23 +44,20 @@ use crate::{ mod bound_socket; mod http; -mod interceptor; +mod http_gateway; mod metadata_store; pub mod port_subscription_ext; -mod response_reader; -mod streaming_body; mod subscriptions; - -/// Identifies an [`HttpResponseReader`]. -type ReaderId = (ConnectionId, RequestId); +mod tasks; +mod tcp_proxy; /// Errors that can occur when handling the `incoming` feature. #[derive(Error, Debug)] pub enum IncomingProxyError { - #[error(transparent)] - Io(#[from] io::Error), + #[error("failed to prepare a TPC socket: {0}")] + SocketSetupFailed(#[source] io::Error), #[error("subscribing port failed: {0}")] - SubscriptionFailed(ResponseError), + SubscriptionFailed(#[source] ResponseError), } /// Messages consumed by [`IncomingProxy`] running as a [`BackgroundTask`]. @@ -81,14 +72,9 @@ pub enum IncomingProxyMessage { AgentProtocolVersion(semver::Version), } -/// Handle for an [`Interceptor`]. -struct InterceptorHandle { - /// A channel for sending messages to the [`Interceptor`] task. - tx: TaskSender, - /// Port subscription that the intercepted connection belongs to. - subscription: PortSubscription, - /// Senders for the bodies of in-progress HTTP requests. - request_body_txs: HashMap>, +struct HttpGatewayHandle { + _tx: TaskSender, + body_tx: Option>, } /// Handles logic and state of the `incoming` feature. @@ -110,20 +96,16 @@ pub struct IncomingProxy { subscriptions: SubscriptionsManager, /// For managing intercepted connections metadata. metadata_store: MetadataStore, - /// Determines which version of [`LayerTcpSteal`] we use to send HTTP responses to the agent. 
- agent_protocol_version: Option, - - /// [`TaskSender`]s for active [`Interceptor`]s. - interceptor_handles: HashMap, - /// For receiving updates from [`Interceptor`]s. - interceptors: BackgroundTasks, - - /// [TaskSender]s for active [`HttpResponseReader`]s. - /// - /// Keep the readers alive. - readers_txs: HashMap>, - /// For reading bodies of user app's HTTP responses. - readers: BackgroundTasks, + response_mode: ResponseMode, + /// Cache for [`LocalHttpClient`](http::LocalHttpClient)s. + client_store: ClientStore, + /// Each mirrored remote connection is mapped to a [TcpProxyTask] in mirror mode. + mirror_tcp_proxies: HashMap>, + /// Each remote connection stolen in whole is mapped to a [TcpProxyTask] in steal mode. + steal_tcp_proxies: HashMap>, + /// Each remote connection stolen with a filter is mapped to [HttpGatewayTask]s. + http_gateways: HashMap>, + tasks: BackgroundTasks, } impl IncomingProxy { @@ -133,52 +115,93 @@ impl IncomingProxy { /// Retrieves or creates an [`Interceptor`] for the given [`HttpRequestFallback`]. /// The request may or may not belong to an existing connection (when stealing with an http /// filter, connections are created implicitly). 
- #[tracing::instrument(level = Level::TRACE, skip(self, message_bus), err)] - async fn get_or_create_http_interceptor( + #[tracing::instrument(level = Level::TRACE, skip(self, message_bus))] + async fn start_http_gateway( &mut self, - connection_id: ConnectionId, - port: Port, + request: HttpRequest, + body_tx: Option>, message_bus: &MessageBus, - ) -> Result, IncomingProxyError> { - let interceptor = match self.interceptor_handles.entry(connection_id) { - Entry::Occupied(e) => e.into_mut(), - - Entry::Vacant(e) => { - let Some(subscription) = self.subscriptions.get(port) else { - tracing::debug!( - port, - connection_id, - "Received a new connection for a port that is no longer subscribed, \ - sending an unsubscribe request.", - ); - + ) { + let subscription = self.subscriptions.get(request.port).filter(|subscription| { + matches!( + subscription.subscription, + PortSubscription::Steal( + StealType::FilteredHttp(..) | StealType::FilteredHttpEx(..) + ) + ) + }); + let Some(subscription) = subscription else { + tracing::debug!( + port = request.port, + connection_id = request.connection_id, + request_id = request.request_id, + "Received a new request within a stale port subscription, sending an unsubscribe request or an error response." + ); + + match self.http_gateways.entry(request.connection_id) { + // This is a new connection, we can just unsubscribe it. + Entry::Vacant(..) => { message_bus .send(ClientMessage::TcpSteal( - LayerTcpSteal::ConnectionUnsubscribe(connection_id), + LayerTcpSteal::ConnectionUnsubscribe(request.connection_id), )) .await; + } - return Ok(None); - }; - - let interceptor_socket = - BoundTcpSocket::bind_specified_or_localhost(subscription.listening_on.ip())?; - - let interceptor = self.interceptors.register( - Interceptor::new(interceptor_socket, subscription.listening_on), - connection_id, - Self::CHANNEL_SIZE, - ); + // This is not a new connection, but we don't have any requests in progress. + // We can still unsubscribe it. 
+ Entry::Occupied(e) if e.get().is_empty() => { + message_bus + .send(ClientMessage::TcpSteal( + LayerTcpSteal::ConnectionUnsubscribe(request.connection_id), + )) + .await; + e.remove(); + } - e.insert(InterceptorHandle { - tx: interceptor, - subscription: subscription.subscription.clone(), - request_body_txs: Default::default(), - }) + // This is not a new connection, and we have requests in progress. + // We can only send an error response. + Entry::Occupied(..) => { + let response = http::mirrord_error_response( + "port no longer subscribed with an HTTP filter", + request.version(), + request.connection_id, + request.request_id, + request.port, + ); + message_bus + .send(ClientMessage::TcpSteal(LayerTcpSteal::HttpResponse( + response, + ))) + .await; + } } + + return; }; - Ok(Some(interceptor)) + let connection_id = request.connection_id; + let request_id = request.request_id; + let id = HttpGatewayId { + connection_id, + request_id, + port: request.port, + version: request.version(), + }; + let tx = self.tasks.register( + HttpGatewayTask::new( + request, + self.client_store.clone(), + self.response_mode, + subscription.listening_on, + ), + InProxyTask::HttpGateway(id), + Self::CHANNEL_SIZE, + ); + self.http_gateways + .entry(connection_id) + .or_default() + .insert(request_id, HttpGatewayHandle { _tx: tx, body_tx }); } /// Handles all agent messages. 
@@ -191,72 +214,49 @@ impl IncomingProxy { ) -> Result<(), IncomingProxyError> { match message { DaemonTcp::Close(close) => { - self.readers_txs.retain(|id, _| id.0 != close.connection_id); - self.interceptor_handles.remove(&close.connection_id); + if is_steal { + self.steal_tcp_proxies.remove(&close.connection_id); + self.http_gateways.remove(&close.connection_id); + } else { + self.mirror_tcp_proxies.remove(&close.connection_id); + } } DaemonTcp::Data(data) => { - if let Some(interceptor) = self.interceptor_handles.get(&data.connection_id) { - interceptor.tx.send(data.bytes).await; + let tx: Option<&TaskSender> = if is_steal { + self.steal_tcp_proxies.get(&data.connection_id) + } else { + self.mirror_tcp_proxies.get(&data.connection_id) + }; + + if let Some(tx) = tx { + tx.send(data.bytes).await; } else { tracing::debug!( connection_id = data.connection_id, - "Received new data for a connection that is already closed", + "Received new data for a connection that does not belong to any TcpProxy task", ); } } DaemonTcp::HttpRequest(request) => { - let interceptor = self - .get_or_create_http_interceptor( - request.connection_id, - request.port, - message_bus, - ) - .await?; - - if let Some(interceptor) = interceptor { - interceptor - .tx - .send(request.map_body(StreamingBody::from)) - .await; - } + self.start_http_gateway(request.map_body(From::from), None, message_bus) + .await; } DaemonTcp::HttpRequestFramed(request) => { - let interceptor = self - .get_or_create_http_interceptor( - request.connection_id, - request.port, - message_bus, - ) - .await?; - - if let Some(interceptor) = interceptor { - interceptor - .tx - .send(request.map_body(StreamingBody::from)) - .await; - } + self.start_http_gateway(request.map_body(From::from), None, message_bus) + .await; } DaemonTcp::HttpRequestChunked(request) => { match request { ChunkedRequest::Start(request) => { - let interceptor = self - .get_or_create_http_interceptor( - request.connection_id, - request.port, - 
message_bus, - ) - .await?; - - if let Some(interceptor) = interceptor { - let (tx, rx) = mpsc::channel::(128); - let request = request.map_body(|frames| StreamingBody::new(rx, frames)); - interceptor.request_body_txs.insert(request.request_id, tx); - interceptor.tx.send(request).await; - } + let (body_tx, body_rx) = mpsc::channel(128); + let request = + request.map_body(|frames| StreamingBody::new(body_rx, frames)); + self.start_http_gateway(request, Some(body_tx), message_bus) + .await; } ChunkedRequest::Body(ChunkedHttpBody { @@ -265,33 +265,32 @@ impl IncomingProxy { connection_id, request_id, }) => { - let Some(interceptor) = self.interceptor_handles.get_mut(&connection_id) - else { + let gateway = self + .http_gateways + .get_mut(&connection_id) + .and_then(|gateways| gateways.get_mut(&request_id)); + let Some(gateway) = gateway else { return Ok(()); }; - let Entry::Occupied(tx) = interceptor.request_body_txs.entry(request_id) - else { + let Some(tx) = gateway.body_tx.as_ref() else { return Ok(()); }; - let mut send_err = false; - for frame in frames { - if let Err(err) = tx.get().send(frame).await { - send_err = true; + if let Err(err) = tx.send(frame).await { tracing::debug!( frame = ?err.0, connection_id, request_id, - "Failed to send an HTTP request body frame to the interceptor, channel is closed" + "Failed to send an HTTP request body frame to the HttpGateway task, channel is closed" ); break; } } - if send_err || is_last { - tx.remove(); + if is_last { + gateway.body_tx = None; } } @@ -299,16 +298,15 @@ impl IncomingProxy { connection_id, request_id, }) => { - if let Some(interceptor) = self.interceptor_handles.get_mut(&connection_id) - { - interceptor.request_body_txs.remove(&request_id); - }; - tracing::debug!( connection_id, request_id, "Received an error in an HTTP request body", ); + + if let Some(gateways) = self.http_gateways.get_mut(&connection_id) { + gateways.remove(&request_id); + }; } }; } @@ -320,12 +318,19 @@ impl IncomingProxy { 
source_port, local_address, }) => { - let Some(subscription) = self.subscriptions.get(destination_port) else { + let subscription = + self.subscriptions + .get(destination_port) + .filter(|subscription| match &subscription.subscription { + PortSubscription::Mirror(..) if !is_steal => true, + PortSubscription::Steal(StealType::All(..)) if is_steal => true, + _ => false, + }); + let Some(subscription) = subscription else { tracing::debug!( port = destination_port, connection_id, - "Received a new connection for a port that is no longer subscribed, \ - sending an unsubscribe request.", + "Received a new connection within a stale port subscription, sending an unsubscribe request.", ); let message = if is_steal { @@ -338,13 +343,16 @@ impl IncomingProxy { return Ok(()); }; - let interceptor_socket = - BoundTcpSocket::bind_specified_or_localhost(subscription.listening_on.ip())?; + let socket = + BoundTcpSocket::bind_specified_or_localhost(subscription.listening_on.ip()) + .map_err(IncomingProxyError::SocketSetupFailed)?; self.metadata_store.expect( ConnMetadataRequest { listener_address: subscription.listening_on, - peer_address: interceptor_socket.local_addr()?, + peer_address: socket + .local_addr() + .map_err(IncomingProxyError::SocketSetupFailed)?, }, connection_id, ConnMetadataResponse { @@ -353,20 +361,28 @@ impl IncomingProxy { }, ); - let interceptor = self.interceptors.register( - Interceptor::new(interceptor_socket, subscription.listening_on), - connection_id, + let id = if is_steal { + InProxyTask::StealTcpProxy(connection_id) + } else { + InProxyTask::MirrorTcpProxy(connection_id) + }; + let tx = self.tasks.register( + TcpProxyTask::new( + LocalTcpConnection::FromTheStart { + socket, + peer: subscription.listening_on, + }, + !is_steal, + ), + id, Self::CHANNEL_SIZE, ); - self.interceptor_handles.insert( - connection_id, - InterceptorHandle { - tx: interceptor, - subscription: subscription.subscription.clone(), - request_body_txs: Default::default(), - }, - ); 
+ if is_steal { + self.steal_tcp_proxies.insert(connection_id, tx); + } else { + self.mirror_tcp_proxies.insert(connection_id, tx); + } } DaemonTcp::SubscribeResult(result) => { @@ -381,110 +397,6 @@ impl IncomingProxy { Ok(()) } - /// Handles an HTTP response coming from one of the interceptors. - /// - /// If all response frames are already available, sends the response in a single message. - /// Otherwise, starts a response reader to handle the response. - #[tracing::instrument(level = Level::TRACE, skip(self, message_bus))] - async fn handle_http_response( - &mut self, - mut response: HttpResponse, - message_bus: &mut MessageBus, - ) { - let tail = match response.internal_response.body.tail.take() { - Some(tail) => tail, - - // All frames are already fetched, we don't have to wait for the body. - // We can send just one message. - None => { - let message = if self.agent_handles_framed_responses() { - let response = response.map_body(|body| { - InternalHttpBody( - body.head - .into_iter() - .map(InternalHttpBodyFrame::from) - .collect::>(), - ) - }); - LayerTcpSteal::HttpResponseFramed(response) - } else { - // Agent does not support `LayerTcpSteal::HttpResponseFramed`. - // We can only use legacy `LayerTcpSteal::HttpResponse`, - // which drops trailing headers. - let connection_id = response.connection_id; - let request_id = response.request_id; - let response = response.map_body(|body| { - let mut new_body = Vec::with_capacity(body.head.iter().filter_map(Frame::data_ref).map(|data| data.len()).sum()); - body.head.into_iter().for_each(|frame| match frame.into_data() { - Ok(data) => new_body.extend(data), - Err(frame) => { - if let Some(headers) = frame.trailers_ref() { - tracing::warn!( - connection_id, - request_id, - agent_protocol_version = ?self.agent_protocol_version, - ?headers, - "Agent uses an outdated version of mirrord protocol, \ - we can't send trailing headers from the local application's HTTP response." 
- ) - } - } - }); - new_body - }); - LayerTcpSteal::HttpResponse(response) - }; - - message_bus.send(ClientMessage::TcpSteal(message)).await; - - return; - } - }; - - let reader_id = (response.connection_id, response.request_id); - let response_reader = if self.agent_handles_streamed_responses() { - let response = response.map_body(|body| { - body.head - .into_iter() - .map(InternalHttpBodyFrame::from) - .collect::>() - }); - let connection_id = response.connection_id; - let request_id = response.request_id; - let message = ClientMessage::TcpSteal(LayerTcpSteal::HttpResponseChunked( - ChunkedResponse::Start(response), - )); - message_bus.send(message).await; - HttpResponseReader::Chunked { - connection_id, - request_id, - body: tail, - } - } else if self.agent_handles_framed_responses() { - response.internal_response.body.tail.replace(tail); - HttpResponseReader::Framed(response) - } else { - response.internal_response.body.tail.replace(tail); - HttpResponseReader::Legacy(response) - }; - - self.readers.register(response_reader, reader_id, 16); - } - - fn agent_handles_framed_responses(&self) -> bool { - self.agent_protocol_version - .as_ref() - .map(|version| HTTP_FRAMED_VERSION.matches(version)) - .unwrap_or_default() - } - - fn agent_handles_streamed_responses(&self) -> bool { - self.agent_protocol_version - .as_ref() - .map(|version| HTTP_CHUNKED_RESPONSE_VERSION.matches(version)) - .unwrap_or_default() - } - #[tracing::instrument(level = Level::TRACE, skip(self, message_bus), err)] async fn handle_message( &mut self, @@ -544,7 +456,7 @@ impl IncomingProxy { } IncomingProxyMessage::AgentProtocolVersion(version) => { - self.agent_protocol_version.replace(version); + self.response_mode = ResponseMode::from(&version); } } @@ -552,92 +464,162 @@ impl IncomingProxy { } #[tracing::instrument(level = Level::TRACE, skip(self, message_bus))] - async fn handle_interceptor_update( + async fn handle_task_update( &mut self, - connection_id: ConnectionId, - update: 
TaskUpdate, + id: InProxyTask, + update: TaskUpdate, message_bus: &mut MessageBus, ) { - match update { - TaskUpdate::Finished(res) => { - if let Err(error) = res { - tracing::warn!(connection_id, %error, "Incoming interceptor failed"); - } + match (id, update) { + (InProxyTask::MirrorTcpProxy(connection_id), TaskUpdate::Finished(result)) => { + match result { + Err(TaskError::Error(error)) => { + tracing::warn!(connection_id, %error, "MirrorTcpProxy task failed"); + } + Err(TaskError::Panic) => { + tracing::error!(connection_id, "MirrorTcpProxy task panicked"); + } + Ok(()) => {} + }; self.metadata_store.no_longer_expect(connection_id); - let msg = self - .interceptor_handles - .get(&connection_id) - .map(|interceptor| { - interceptor - .subscription - .wrap_agent_unsubscribe_connection(connection_id) - }); - if let Some(msg) = msg { - message_bus.send(msg).await; + if self.mirror_tcp_proxies.remove(&connection_id).is_some() { + message_bus + .send(ClientMessage::Tcp(LayerTcp::ConnectionUnsubscribe( + connection_id, + ))) + .await; } - - self.interceptor_handles.remove(&connection_id); } - TaskUpdate::Message(msg) => { - let Some(PortSubscription::Steal(_)) = self - .interceptor_handles - .get(&connection_id) - .map(|interceptor| &interceptor.subscription) - else { - return; + (InProxyTask::MirrorTcpProxy(..), TaskUpdate::Message(..)) => unreachable!(), + + (InProxyTask::StealTcpProxy(connection_id), TaskUpdate::Finished(result)) => { + match result { + Err(TaskError::Error(error)) => { + tracing::warn!(connection_id, %error, "StealTcpProxy task failed"); + } + Err(TaskError::Panic) => { + tracing::error!(connection_id, "StealTcpProxy task panicked"); + } + Ok(()) => {} }; - match msg { - MessageOut::Raw(bytes) => { - let msg = ClientMessage::TcpSteal(LayerTcpSteal::Data(TcpData { + self.metadata_store.no_longer_expect(connection_id); + + if self.steal_tcp_proxies.remove(&connection_id).is_some() { + message_bus + 
.send(ClientMessage::Tcp(LayerTcp::ConnectionUnsubscribe( connection_id, - bytes, - })); + ))) + .await; + } + } - message_bus.send(msg).await; - } + ( + InProxyTask::StealTcpProxy(connection_id), + TaskUpdate::Message(InProxyTaskMessage::Tcp(bytes)), + ) => { + if self.steal_tcp_proxies.contains_key(&connection_id) { + message_bus + .send(ClientMessage::TcpSteal(LayerTcpSteal::Data(TcpData { + connection_id, + bytes, + }))) + .await; + } + } - MessageOut::Http(response) => { - self.handle_http_response(response, message_bus).await; - } - }; + (InProxyTask::StealTcpProxy(..), TaskUpdate::Message(InProxyTaskMessage::Http(..))) => { + unreachable!() } - } - } - #[tracing::instrument(level = Level::TRACE, skip(self, message_bus))] - async fn handle_response_reader_update( - &mut self, - connection_id: ConnectionId, - request_id: RequestId, - update: TaskUpdate, - message_bus: &mut MessageBus, - ) { - match update { - TaskUpdate::Finished(Ok(())) => { - self.readers_txs.remove(&(connection_id, request_id)); + (InProxyTask::HttpGateway(id), TaskUpdate::Finished(result)) => { + let respond_on_panic = self + .http_gateways + .get_mut(&id.connection_id) + .and_then(|gateways| gateways.remove(&id.request_id)) + .is_some(); + + match result { + Ok(()) => {} + Err(TaskError::Error( + InProxyTaskError::IoError(..) 
| InProxyTaskError::UpgradeError(..), + )) => unreachable!(), + Err(TaskError::Panic) => { + tracing::error!( + connection_id = id.connection_id, + request_id = id.request_id, + "HttpGatewayTask panicked", + ); + + if respond_on_panic { + let response = http::mirrord_error_response( + "HTTP gateway task panicked", + id.version, + id.connection_id, + id.request_id, + id.port, + ); + message_bus + .send(ClientMessage::TcpSteal(LayerTcpSteal::HttpResponse( + response, + ))) + .await; + } + } + } } - TaskUpdate::Finished(Err(TaskError::Panic)) => { - tracing::error!(connection_id, request_id, "HttpResponseReader panicked"); + ( + InProxyTask::HttpGateway(id), + TaskUpdate::Message(InProxyTaskMessage::Http(message)), + ) => { + let exists = self + .http_gateways + .get(&id.connection_id) + .and_then(|gateways| gateways.get(&id.request_id)) + .is_some(); + if !exists { + return; + } - self.readers_txs.remove(&(connection_id, request_id)); - if let Some(interceptor) = self.interceptor_handles.remove(&connection_id) { - message_bus - .send( - interceptor - .subscription - .wrap_agent_unsubscribe_connection(connection_id), - ) - .await; + match message { + HttpOut::ResponseBasic(response) => { + message_bus + .send(ClientMessage::TcpSteal(LayerTcpSteal::HttpResponse( + response, + ))) + .await + } + HttpOut::ResponseFramed(response) => { + message_bus + .send(ClientMessage::TcpSteal(LayerTcpSteal::HttpResponseFramed( + response, + ))) + .await + } + HttpOut::ResponseChunked(response) => { + message_bus + .send(ClientMessage::TcpSteal(LayerTcpSteal::HttpResponseChunked( + response, + ))) + .await; + } + HttpOut::Upgraded(on_upgrade) => { + let proxy = self.tasks.register( + TcpProxyTask::new(LocalTcpConnection::AfterUpgrade(on_upgrade), false), + InProxyTask::StealTcpProxy(id.connection_id), + Self::CHANNEL_SIZE, + ); + self.steal_tcp_proxies.insert(id.connection_id, proxy); + } } } - TaskUpdate::Message(msg) => { - message_bus.send(ClientMessage::TcpSteal(msg)).await; + 
(InProxyTask::HttpGateway(..), TaskUpdate::Message(InProxyTaskMessage::Tcp(..))) => { + unreachable!() } } } @@ -660,9 +642,7 @@ impl BackgroundTask for IncomingProxy { Some(message) => self.handle_message(message, message_bus).await?, }, - Some((id, update)) = self.interceptors.next() => self.handle_interceptor_update(id, update, message_bus).await, - - Some((id, update)) = self.readers.next() => self.handle_response_reader_update(id.0, id.1, update, message_bus).await, + Some((id, update)) = self.tasks.next() => self.handle_task_update(id, update, message_bus).await, } } } diff --git a/mirrord/intproxy/src/proxies/incoming/http.rs b/mirrord/intproxy/src/proxies/incoming/http.rs index e8687bdbb0d..b01aee0cbf9 100644 --- a/mirrord/intproxy/src/proxies/incoming/http.rs +++ b/mirrord/intproxy/src/proxies/incoming/http.rs @@ -1,188 +1,77 @@ -use std::{ - error::Error, - fmt, io, - net::SocketAddr, - ops::Not, - time::{Duration, Instant}, -}; +use std::{error::Error, fmt, io, net::SocketAddr}; -use bytes::Bytes; -use exponential_backoff::Backoff; use hyper::{ - body::{Frame, Incoming}, + body::Incoming, client::conn::{http1, http2}, Request, Response, StatusCode, Version, }; use hyper_util::rt::{TokioExecutor, TokioIo}; use mirrord_protocol::{ - batched_body::BatchedBody, tcp::{HttpRequest, HttpResponse, InternalHttpResponse}, ConnectionId, Port, RequestId, }; use thiserror::Error; -use tokio::{net::TcpStream, time}; +use tokio::net::TcpStream; use tracing::Level; -use super::{bound_socket::BoundTcpSocket, streaming_body::StreamingBody}; +mod client_store; +mod response_mode; +mod streaming_body; + +pub use client_store::ClientStore; +pub use response_mode::ResponseMode; +pub use streaming_body::StreamingBody; -/// A retrying HTTP client used to pass requests to the user application. +/// An HTTP client used to pass requests to the user application. pub struct LocalHttpClient { /// Established HTTP connection with the user application. 
- sender: Option, - /// Established TCP connection with the user application. - stream: Option, + sender: HttpSender, /// Address of the user application's HTTP server. local_server_address: SocketAddr, } impl LocalHttpClient { - /// How many times we attempt to send any given request. - /// - /// See [`LocalHttpError::can_retry`]. - const MAX_SEND_ATTEMPTS: u32 = 10; - const MIN_SEND_BACKOFF: Duration = Duration::from_millis(10); - const MAX_SEND_BACKOFF: Duration = Duration::from_millis(250); - - /// Crates a new client that will initially use the given `stream` (connection with the user - /// application's HTTP server). - pub fn new_for_stream(stream: TcpStream) -> Result { + /// Makes an HTTP connection with the given server and creates a new client. + pub async fn new( + local_server_address: SocketAddr, + version: Version, + ) -> Result { + let stream = TcpStream::connect(local_server_address) + .await + .map_err(LocalHttpError::ConnectTcpFailed)?; let local_server_address = stream .peer_addr() .map_err(LocalHttpError::SocketSetupFailed)?; + let sender = HttpSender::handshake(version, stream).await?; Ok(Self { - sender: None, - stream: Some(stream), + sender, local_server_address, }) } - /// Reuses or creates a new [`HttpSender`]. 
- async fn get_sender(&mut self, version: Version) -> Result { - if let Some(sender) = self.sender.take() { - if sender.version_matches(version) { - tracing::trace!("Reusing the HTTP connection."); - return Ok(sender); - } else { - tracing::trace!("HTTP connection found, but the HTTP version does not match."); - } - } - - let stream = match self.stream.take() { - Some(stream) => { - tracing::trace!("Reusing the TCP connection."); - stream - } - None => { - let socket = - BoundTcpSocket::bind_specified_or_localhost(self.local_server_address.ip()) - .map_err(LocalHttpError::SocketSetupFailed)?; - - let start = Instant::now(); - let socket = socket - .connect(self.local_server_address) - .await - .map_err(LocalHttpError::ConnectTcpFailed)?; - tracing::trace!( - elapsed_s = start.elapsed().as_secs_f32(), - "Made the TCP connection" - ); - - socket - } - }; - - let start = Instant::now(); - let sender = HttpSender::handshake(version, stream).await?; - tracing::trace!( - elapsed_s = start.elapsed().as_secs_f32(), - "Made the HTTP connection" - ); - - Ok(sender) - } - - /// Tries to send the given `request` to the user application's HTTP server. - /// - /// Checks whether some reponse [`Frame`]s are instantly available. - async fn try_send_request( - &mut self, - request: &HttpRequest, - ) -> Result, LocalHttpError> { - let mut sender = self.get_sender(request.version()).await?; - - let start = Instant::now(); - let response = sender.send_request(request.clone()).await?; - tracing::trace!( - elapsed_s = start.elapsed().as_secs_f32(), - "Sent the HTTP request" - ); - - let (parts, mut body) = response.into_parts(); - - let frames = body - .ready_frames() - .map_err(LocalHttpError::ReadBodyFailed)?; - let body = PeekedBody { - head: frames.frames, - tail: frames.is_last.not().then_some(body), - }; - - self.sender.replace(sender); - - Ok(Response::from_parts(parts, body)) - } - /// Tries to send the given `request` to the user application's HTTP server. 
/// /// Retries on known errors (see [`LocalHttpError::can_retry`]). #[tracing::instrument(level = Level::DEBUG, err(level = Level::WARN), ret)] pub async fn send_request( &mut self, - request: &HttpRequest, - ) -> Result, LocalHttpError> { - let mut backoffs = Backoff::new( - Self::MAX_SEND_ATTEMPTS, - Self::MIN_SEND_BACKOFF, - Self::MAX_SEND_BACKOFF, - ) - .into_iter() - .flatten(); - - let mut attempt = 0; - loop { - attempt += 1; - tracing::trace!(attempt, "Trying to send the request"); - match (self.try_send_request(request).await, backoffs.next()) { - (Ok(response), _) => { - break Ok(response); - } - - (Err(error), Some(backoff)) if error.can_retry() => { - tracing::warn!( - attempt, - connection_id = request.connection_id, - request_id = request.request_id, - %error, - backoff_s = backoff.as_secs_f32(), - "Failed to send the request to the local application, retrying", - ); - - time::sleep(backoff).await; - } + request: HttpRequest, + ) -> Result, LocalHttpError> { + self.sender.send_request(request).await + } - (Err(error), _) => { - tracing::warn!( - attempts = attempt, - connection_id = request.connection_id, - request_id = request.request_id, - %error, - "Failed to send the request to the local application", - ); + /// Returns the address of the local server to which this client is connected. 
+ pub fn local_server_address(&self) -> SocketAddr { + self.local_server_address + } - break Err(error); - } - } + pub fn handles_version(&self, version: Version) -> bool { + match (&self.sender, version) { + (_, Version::HTTP_3) => false, + (HttpSender::V2(..), Version::HTTP_2) => true, + (HttpSender::V1(..), _) => true, + (HttpSender::V2(..), _) => false, } } } @@ -190,8 +79,6 @@ impl LocalHttpClient { impl fmt::Debug for LocalHttpClient { fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { f.debug_struct("LocalHttpClient") - .field("has_sender", &self.sender.is_some()) - .field("has_stream", &self.stream.is_some()) .field("local_server_address", &self.local_server_address) .finish() } @@ -200,7 +87,7 @@ impl fmt::Debug for LocalHttpClient { /// Errors that can occur when sending an HTTP request to the user application. #[derive(Error, Debug)] pub enum LocalHttpError { - #[error("HTTP handshake failed: {0}")] + #[error("handshake failed: {0}")] HandshakeFailed(#[source] hyper::Error), #[error("{0:?} is not supported")] @@ -248,43 +135,26 @@ impl LocalHttpError { } } } - - /// Produces a [`StatusCode::BAD_GATEWAY`] response from this error. - pub fn as_error_response( - &self, - version: Version, - request_id: RequestId, - connection_id: ConnectionId, - port: Port, - ) -> HttpResponse> { - HttpResponse { - request_id, - connection_id, - port, - internal_response: InternalHttpResponse { - status: StatusCode::BAD_GATEWAY, - version, - headers: Default::default(), - body: format!("mirrord: {self}").into_bytes(), - }, - } - } } -/// Response body returned from [`LocalHttpClient`]. -pub struct PeekedBody { - /// [`Frame`]s that were instantly available. - pub head: Vec>, - /// The rest of the response's body. 
- pub tail: Option, -} - -impl fmt::Debug for PeekedBody { - fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { - f.debug_struct("PeekedBody") - .field("head", &self.head) - .field("has_tail", &self.tail.is_some()) - .finish() +/// Produces a mirrord-specific [`StatusCode::BAD_GATEWAY`] response. +pub fn mirrord_error_response( + message: M, + version: Version, + connection_id: ConnectionId, + request_id: RequestId, + port: Port, +) -> HttpResponse> { + HttpResponse { + connection_id, + port, + request_id, + internal_response: InternalHttpResponse { + status: StatusCode::BAD_GATEWAY, + version, + headers: Default::default(), + body: format!("mirrord: {message}\n").into_bytes(), + }, } } @@ -392,16 +262,6 @@ impl HttpSender { } } } - - /// Returns whether this [`HttpSender`] can handle requests of the given [`Version`]. - fn version_matches(&self, version: Version) -> bool { - match (version, self) { - (Version::HTTP_2, Self::V2(..)) => true, - (Version::HTTP_3, _) => false, - (_, Self::V1(..)) => true, - _ => false, - } - } } #[cfg(test)] diff --git a/mirrord/intproxy/src/proxies/incoming/http/client_store.rs b/mirrord/intproxy/src/proxies/incoming/http/client_store.rs new file mode 100644 index 00000000000..b22c26ed06c --- /dev/null +++ b/mirrord/intproxy/src/proxies/incoming/http/client_store.rs @@ -0,0 +1,142 @@ +use std::{cmp, net::SocketAddr, time::Duration}; + +use hyper::Version; +use tokio::{ + sync::watch, + time::{self, Instant}, +}; + +use super::{LocalHttpClient, LocalHttpError}; + +struct IdleLocalClient { + client: LocalHttpClient, + last_used: Instant, +} + +#[derive(Clone)] +pub struct ClientStore(watch::Sender>); + +impl Default for ClientStore { + fn default() -> Self { + let (tx, _) = watch::channel(Default::default()); + + tokio::spawn(cleanup_task(tx.clone())); + + Self(tx) + } +} + +impl ClientStore { + const IDLE_CLIENT_TIMEOUT: Duration = Duration::from_secs(3); + + pub async fn get( + &self, + server_addr: SocketAddr, + version: 
Version, + ) -> Result { + let mut ready = None; + + self.0.send_if_modified(|clients| { + let position = clients.iter().position(|idle| { + idle.client.handles_version(version) + && idle.client.local_server_address() == server_addr + }); + + let Some(position) = position else { + return false; + }; + + let client = clients.swap_remove(position).client; + ready.replace(client); + true + }); + + if let Some(ready) = ready { + return Ok(ready); + } + + let connect_task = tokio::spawn(LocalHttpClient::new(server_addr, version)); + + tokio::select! { + result = connect_task => result.expect("this task should not panic"), + ready = self.wait_for_ready(server_addr, version) => Ok(ready), + } + } + + pub fn push_idle(&self, client: LocalHttpClient) { + self.0.send_modify(|clients| { + clients.push(IdleLocalClient { + client, + last_used: Instant::now(), + }) + }); + } + + async fn wait_for_ready(&self, server_addr: SocketAddr, version: Version) -> LocalHttpClient { + let mut recevier = self.0.subscribe(); + + loop { + let mut ready = None; + + self.0.send_if_modified(|clients| { + let position = clients.iter().position(|idle| { + idle.client.handles_version(version) + && idle.client.local_server_address() == server_addr + }); + let Some(position) = position else { + return false; + }; + + let client = clients.swap_remove(position).client; + ready.replace(client); + + true + }); + + if let Some(ready) = ready { + break ready; + } + + recevier + .changed() + .await + .expect("sender alive in this struct"); + } + } +} + +async fn cleanup_task(clients: watch::Sender>) { + loop { + let now = Instant::now(); + let mut min_last_used = None; + + clients.send_if_modified(|clients| { + let mut removed = false; + + clients.retain(|client| { + if client.last_used + ClientStore::IDLE_CLIENT_TIMEOUT > now { + min_last_used = min_last_used + .map(|previous| cmp::min(previous, client.last_used)) + .or(Some(client.last_used)); + + true + } else { + removed = true; + false + } + }); + + 
removed + }); + + if let Some(min_last_used) = min_last_used { + time::sleep_until(min_last_used + ClientStore::IDLE_CLIENT_TIMEOUT).await; + } else { + clients + .subscribe() + .changed() + .await + .expect("sender alive in this function"); + } + } +} diff --git a/mirrord/intproxy/src/proxies/incoming/http/response_mode.rs b/mirrord/intproxy/src/proxies/incoming/http/response_mode.rs new file mode 100644 index 00000000000..4c9140f1beb --- /dev/null +++ b/mirrord/intproxy/src/proxies/incoming/http/response_mode.rs @@ -0,0 +1,21 @@ +use mirrord_protocol::tcp::{HTTP_CHUNKED_RESPONSE_VERSION, HTTP_FRAMED_VERSION}; + +#[derive(Debug, Clone, Copy, Default)] +pub enum ResponseMode { + Chunked, + Framed, + #[default] + Basic, +} + +impl From<&semver::Version> for ResponseMode { + fn from(value: &semver::Version) -> Self { + if HTTP_CHUNKED_RESPONSE_VERSION.matches(value) { + Self::Chunked + } else if HTTP_FRAMED_VERSION.matches(value) { + Self::Framed + } else { + Self::Basic + } + } +} diff --git a/mirrord/intproxy/src/proxies/incoming/streaming_body.rs b/mirrord/intproxy/src/proxies/incoming/http/streaming_body.rs similarity index 100% rename from mirrord/intproxy/src/proxies/incoming/streaming_body.rs rename to mirrord/intproxy/src/proxies/incoming/http/streaming_body.rs diff --git a/mirrord/intproxy/src/proxies/incoming/http_gateway.rs b/mirrord/intproxy/src/proxies/incoming/http_gateway.rs new file mode 100644 index 00000000000..02d8cdac9a6 --- /dev/null +++ b/mirrord/intproxy/src/proxies/incoming/http_gateway.rs @@ -0,0 +1,682 @@ +use std::{convert::Infallible, net::SocketAddr, time::Duration}; + +use exponential_backoff::Backoff; +use http_body_util::BodyExt; +use hyper::StatusCode; +use mirrord_protocol::{ + batched_body::BatchedBody, + tcp::{ + ChunkedHttpBody, ChunkedHttpError, ChunkedResponse, HttpRequest, HttpResponse, + InternalHttpBody, InternalHttpBodyFrame, InternalHttpResponse, + }, +}; +use tokio::time; + +use super::{ + http::{mirrord_error_response, 
ClientStore, LocalHttpError, ResponseMode, StreamingBody}, + tasks::{HttpOut, InProxyTaskMessage}, +}; +use crate::background_tasks::{BackgroundTask, MessageBus}; + +pub struct HttpGatewayTask { + request: HttpRequest, + client_store: ClientStore, + response_mode: ResponseMode, + server_addr: SocketAddr, +} + +impl HttpGatewayTask { + pub fn new( + request: HttpRequest, + client_store: ClientStore, + response_mode: ResponseMode, + server_addr: SocketAddr, + ) -> Self { + Self { + request, + client_store, + response_mode, + server_addr, + } + } + + async fn send_attempt(&self, message_bus: &mut MessageBus) -> Result<(), LocalHttpError> { + let mut client = self + .client_store + .get(self.server_addr, self.request.version()) + .await?; + let mut response = client.send_request(self.request.clone()).await?; + let on_upgrade = (response.status() == StatusCode::SWITCHING_PROTOCOLS) + .then(|| hyper::upgrade::on(&mut response)); + let (parts, mut body) = response.into_parts(); + + match self.response_mode { + ResponseMode::Basic => { + let body: Vec = body + .collect() + .await + .map_err(LocalHttpError::ReadBodyFailed)? 
+ .to_bytes() + .into(); + let response = HttpResponse { + port: self.request.port, + connection_id: self.request.connection_id, + request_id: self.request.request_id, + internal_response: InternalHttpResponse { + status: parts.status, + version: parts.version, + headers: parts.headers, + body, + }, + }; + message_bus.send(HttpOut::ResponseBasic(response)).await + } + ResponseMode::Framed => { + let body = InternalHttpBody::from_body(body) + .await + .map_err(LocalHttpError::ReadBodyFailed)?; + let response = HttpResponse { + port: self.request.port, + connection_id: self.request.connection_id, + request_id: self.request.request_id, + internal_response: InternalHttpResponse { + status: parts.status, + version: parts.version, + headers: parts.headers, + body, + }, + }; + message_bus.send(HttpOut::ResponseFramed(response)).await + } + ResponseMode::Chunked => { + let ready_frames = body + .ready_frames() + .map_err(LocalHttpError::ReadBodyFailed)? + .frames + .into_iter() + .map(InternalHttpBodyFrame::from) + .collect(); + let response = HttpResponse { + port: self.request.port, + connection_id: self.request.connection_id, + request_id: self.request.request_id, + internal_response: InternalHttpResponse { + status: parts.status, + version: parts.version, + headers: parts.headers, + body: ready_frames, + }, + }; + message_bus + .send(HttpOut::ResponseChunked(ChunkedResponse::Start(response))) + .await; + + loop { + match body.next_frames().await { + Ok(frames) => { + message_bus + .send(HttpOut::ResponseChunked(ChunkedResponse::Body( + ChunkedHttpBody { + frames: frames + .frames + .into_iter() + .map(InternalHttpBodyFrame::from) + .collect(), + is_last: frames.is_last, + connection_id: self.request.connection_id, + request_id: self.request.request_id, + }, + ))) + .await; + if frames.is_last { + break; + } + } + // Do not return any error here, + // as it would be transformed into an error response by the caller. + // We already send the request head to the agent. 
+ Err(..) => { + message_bus + .send(HttpOut::ResponseChunked(ChunkedResponse::Error( + ChunkedHttpError { + connection_id: self.request.connection_id, + request_id: self.request.request_id, + }, + ))) + .await; + + return Ok(()); + } + } + } + } + } + + if let Some(on_upgrade) = on_upgrade { + message_bus.send(HttpOut::Upgraded(on_upgrade)).await; + } + + self.client_store.push_idle(client); + + Ok(()) + } +} + +impl BackgroundTask for HttpGatewayTask { + type Error = Infallible; + type MessageIn = Infallible; + type MessageOut = InProxyTaskMessage; + + async fn run(self, message_bus: &mut MessageBus) -> Result<(), Self::Error> { + let mut backoffs = + Backoff::new(10, Duration::from_millis(50), Duration::from_millis(500)).into_iter(); + let guard = message_bus.closed(); + + let error = loop { + match guard.cancel_on_close(self.send_attempt(message_bus)).await { + None | Some(Ok(())) => return Ok(()), + Some(Err(error)) => { + let backoff = error + .can_retry() + .then(|| backoffs.next()) + .flatten() + .flatten(); + let Some(backoff) = backoff else { + break error; + }; + + if guard.cancel_on_close(time::sleep(backoff)).await.is_none() { + return Ok(()); + } + } + } + }; + + let response = mirrord_error_response( + error, + self.request.version(), + self.request.connection_id, + self.request.request_id, + self.request.port, + ); + message_bus.send(HttpOut::ResponseBasic(response)).await; + + Ok(()) + } +} + +#[cfg(test)] +mod test { + use std::{io, sync::Arc}; + + use bytes::Bytes; + use http_body_util::{Empty, StreamBody}; + use hyper::{ + body::{Frame, Incoming}, + header::{self, HeaderValue, CONNECTION, UPGRADE}, + server::conn::http1, + service::service_fn, + upgrade::Upgraded, + Method, Request, Response, StatusCode, Version, + }; + use hyper_util::rt::TokioIo; + use mirrord_protocol::tcp::{HttpRequest, InternalHttpRequest}; + use rstest::rstest; + use tokio::{ + io::{AsyncReadExt, AsyncWriteExt}, + net::TcpListener, + sync::{mpsc, watch, Semaphore}, + task, 
+ }; + use tokio_stream::wrappers::ReceiverStream; + + use super::*; + use crate::{ + background_tasks::{BackgroundTasks, TaskUpdate}, + proxies::incoming::{ + tcp_proxy::{LocalTcpConnection, TcpProxyTask}, + InProxyTaskError, + }, + }; + + /// Binary protocol over TCP. + /// Server first sends bytes [`INITIAL_MESSAGE`], then echoes back all received data. + const TEST_PROTO: &str = "dummyecho"; + + const INITIAL_MESSAGE: &[u8] = &[0x4a, 0x50, 0x32, 0x47, 0x4d, 0x44]; + + /// Handles requests upgrading to the [`TEST_PROTO`] protocol. + async fn upgrade_req_handler( + mut req: Request, + ) -> hyper::Result>> { + async fn dummy_echo(upgraded: Upgraded) -> io::Result<()> { + let mut upgraded = TokioIo::new(upgraded); + let mut buf = [0_u8; 64]; + + upgraded.write_all(INITIAL_MESSAGE).await?; + + loop { + let bytes_read = upgraded.read(&mut buf[..]).await?; + if bytes_read == 0 { + break; + } + + let echo_back = buf.get(0..bytes_read).unwrap(); + upgraded.write_all(echo_back).await?; + } + + Ok(()) + } + + let mut res = Response::new(Empty::new()); + + let contains_expected_upgrade = req + .headers() + .get(UPGRADE) + .filter(|proto| *proto == TEST_PROTO) + .is_some(); + if !contains_expected_upgrade { + *res.status_mut() = StatusCode::BAD_REQUEST; + return Ok(res); + } + + task::spawn(async move { + match hyper::upgrade::on(&mut req).await { + Ok(upgraded) => { + if let Err(e) = dummy_echo(upgraded).await { + eprintln!("server foobar io error: {}", e) + }; + } + Err(e) => eprintln!("upgrade error: {}", e), + } + }); + + *res.status_mut() = StatusCode::SWITCHING_PROTOCOLS; + res.headers_mut() + .insert(UPGRADE, HeaderValue::from_static(TEST_PROTO)); + res.headers_mut() + .insert(CONNECTION, HeaderValue::from_static("upgrade")); + Ok(res) + } + + /// Runs a [`hyper`] server that accepts only requests upgrading to the [`TEST_PROTO`] protocol. + async fn dummy_echo_server(listener: TcpListener, mut shutdown: watch::Receiver) { + loop { + tokio::select! 
{ + res = listener.accept() => { + let (stream, _) = res.expect("dummy echo server failed to accept connection"); + + let mut shutdown = shutdown.clone(); + + task::spawn(async move { + let conn = http1::Builder::new().serve_connection(TokioIo::new(stream), service_fn(upgrade_req_handler)); + let mut conn = conn.with_upgrades(); + let mut conn = Pin::new(&mut conn); + + tokio::select! { + res = &mut conn => { + res.expect("dummy echo server failed to serve connection"); + } + + _ = shutdown.changed() => { + conn.graceful_shutdown(); + } + } + }); + } + + _ = shutdown.changed() => break, + } + } + } + + #[tokio::test] + async fn upgrade_test() { + let listener = TcpListener::bind("127.0.0.1:0").await.unwrap(); + let local_destination = listener.local_addr().unwrap(); + + let (shutdown_tx, shutdown_rx) = watch::channel(false); + let server_task = task::spawn(dummy_echo_server(listener, shutdown_rx)); + + let mut tasks: BackgroundTasks = + Default::default(); + let _gateway = { + let request = HttpRequest { + connection_id: 0, + request_id: 0, + port: 80, + internal_request: InternalHttpRequest { + method: Method::GET, + uri: "dummyecho://www.mirrord.dev/".parse().unwrap(), + headers: [ + (CONNECTION, HeaderValue::from_static("upgrade")), + (UPGRADE, HeaderValue::from_static(TEST_PROTO)), + ] + .into_iter() + .collect(), + version: Version::HTTP_11, + body: Default::default(), + }, + }; + let gateway = HttpGatewayTask::new( + request, + Default::default(), + ResponseMode::Basic, + local_destination, + ); + tasks.register(gateway, 0, 8) + }; + + let message = tasks + .next() + .await + .expect("no task result") + .1 + .unwrap_message(); + match message { + InProxyTaskMessage::Http(HttpOut::ResponseBasic(res)) => { + assert_eq!( + res.internal_response.status, + StatusCode::SWITCHING_PROTOCOLS + ); + println!("Received response from the gateway: {res:?}"); + assert!(res + .internal_response + .headers + .get(CONNECTION) + .filter(|v| *v == "upgrade") + .is_some()); + 
assert!(res + .internal_response + .headers + .get(UPGRADE) + .filter(|v| *v == TEST_PROTO) + .is_some()); + } + other => panic!("unexpected task update: {other:?}"), + } + + let message = tasks + .next() + .await + .expect("not task result") + .1 + .unwrap_message(); + let on_upgrade = match message { + InProxyTaskMessage::Http(HttpOut::Upgraded(on_upgrade)) => on_upgrade, + other => panic!("unexpected task update: {other:?}"), + }; + let update = tasks.next().await.expect("no task result").1; + match update { + TaskUpdate::Finished(Ok(())) => {} + other => panic!("unexpected task update: {other:?}"), + } + + let proxy = tasks.register( + TcpProxyTask::new(LocalTcpConnection::AfterUpgrade(on_upgrade), false), + 1, + 8, + ); + + proxy.send(b"test test test".to_vec()).await; + + let message = tasks + .next() + .await + .expect("no task result") + .1 + .unwrap_message(); + match message { + InProxyTaskMessage::Tcp(bytes) => { + assert_eq!(bytes, INITIAL_MESSAGE); + } + _ => panic!("unexpected task update: {update:?}"), + } + + let message = tasks + .next() + .await + .expect("no task result") + .1 + .unwrap_message(); + match message { + InProxyTaskMessage::Tcp(bytes) => { + assert_eq!(bytes, b"test test test"); + } + _ => panic!("unexpected task update: {update:?}"), + } + + let _ = shutdown_tx.send(true); + server_task.await.expect("dummy echo server panicked"); + } + + #[rstest] + #[case::basic(ResponseMode::Basic)] + #[case::framed(ResponseMode::Framed)] + #[case::chunked(ResponseMode::Chunked)] + #[tokio::test] + async fn receive_correct_response_variant(#[case] response_mode: ResponseMode) { + let listener = TcpListener::bind("127.0.0.1:0").await.unwrap(); + let addr = listener.local_addr().unwrap(); + let semaphore: Arc = Arc::new(Semaphore::const_new(0)); + let semaphore_clone = semaphore.clone(); + + let conn_task = tokio::spawn(async move { + let service = service_fn(|_req: Request| { + let semaphore = semaphore_clone.clone(); + async move { + let 
(frame_tx, frame_rx) = mpsc::channel::>>(1); + + tokio::spawn(async move { + for _ in 0..2 { + semaphore.acquire().await.unwrap().forget(); + let _ = frame_tx + .send(Ok(Frame::data(Bytes::from_static(b"hello\n")))) + .await; + } + }); + + let body = StreamBody::new(ReceiverStream::new(frame_rx)); + let mut response = Response::new(body); + response + .headers_mut() + .insert(header::CONTENT_LENGTH, HeaderValue::from_static("12")); + + Ok::<_, Infallible>(response) + } + }); + + let (connection, _) = listener.accept().await.unwrap(); + http1::Builder::new() + .serve_connection(TokioIo::new(connection), service) + .await + .unwrap() + }); + + let request = HttpRequest { + connection_id: 0, + request_id: 0, + port: 80, + internal_request: InternalHttpRequest { + method: Method::GET, + uri: "/".parse().unwrap(), + headers: Default::default(), + version: Version::HTTP_11, + body: StreamingBody::from(vec![]), + }, + }; + + let mut tasks: BackgroundTasks<(), InProxyTaskMessage, Infallible> = Default::default(); + let client_store = ClientStore::default(); + let _gateway = tasks.register( + HttpGatewayTask::new(request, client_store.clone(), response_mode, addr), + (), + 8, + ); + + match response_mode { + ResponseMode::Basic => { + semaphore.add_permits(2); + match tasks.next().await.unwrap().1.unwrap_message() { + InProxyTaskMessage::Http(HttpOut::ResponseBasic(response)) => { + assert_eq!(response.internal_response.body, b"hello\nhello\n"); + } + other => panic!("unexpected task message: {other:?}"), + } + } + + ResponseMode::Framed => { + semaphore.add_permits(2); + match tasks.next().await.unwrap().1.unwrap_message() { + InProxyTaskMessage::Http(HttpOut::ResponseFramed(response)) => { + let mut collected = vec![]; + for frame in response.internal_response.body.0 { + match frame { + InternalHttpBodyFrame::Data(data) => collected.extend(data), + InternalHttpBodyFrame::Trailers(trailers) => { + panic!("unexpected trailing headers: {trailers:?}"); + } + } + } + + 
assert_eq!(collected, b"hello\nhello\n"); + } + other => panic!("unexpected task message: {other:?}"), + } + } + + ResponseMode::Chunked => { + match tasks.next().await.unwrap().1.unwrap_message() { + InProxyTaskMessage::Http(HttpOut::ResponseChunked(ChunkedResponse::Start( + response, + ))) => { + assert!(response.internal_response.body.is_empty()); + } + other => panic!("unexpected task message: {other:?}"), + } + + semaphore.add_permits(1); + match tasks.next().await.unwrap().1.unwrap_message() { + InProxyTaskMessage::Http(HttpOut::ResponseChunked(ChunkedResponse::Body( + body, + ))) => { + assert_eq!( + body.frames, + vec![InternalHttpBodyFrame::Data(b"hello\n".into())], + ); + assert!(!body.is_last); + } + other => panic!("unexpected task message: {other:?}"), + } + + semaphore.add_permits(1); + match tasks.next().await.unwrap().1.unwrap_message() { + InProxyTaskMessage::Http(HttpOut::ResponseChunked(ChunkedResponse::Body( + body, + ))) => { + assert_eq!( + body.frames, + vec![InternalHttpBodyFrame::Data(b"hello\n".into())], + ); + assert!(body.is_last); + } + other => panic!("unexpected task message: {other:?}"), + } + } + } + + match tasks.next().await.unwrap().1 { + TaskUpdate::Finished(Ok(())) => {} + other => panic!("unexpected task update: {other:?}"), + } + + conn_task.await.unwrap(); + } + + #[tokio::test] + async fn streams_request_frames() { + let listener = TcpListener::bind("127.0.0.1:0").await.unwrap(); + let addr = listener.local_addr().unwrap(); + + let semaphore: Arc = Arc::new(Semaphore::const_new(0)); + let semaphore_clone = semaphore.clone(); + + let conn_task = tokio::spawn(async move { + let service = service_fn(|mut req: Request| { + let semaphore = semaphore_clone.clone(); + async move { + for _ in 0..2 { + semaphore.add_permits(1); + let frame = req + .body_mut() + .frame() + .await + .unwrap() + .unwrap() + .into_data() + .unwrap(); + assert_eq!(frame, "hello\n"); + } + + Ok::<_, Infallible>(Response::new(Empty::::new())) + } + }); + + 
let (connection, _) = listener.accept().await.unwrap(); + http1::Builder::new() + .serve_connection(TokioIo::new(connection), service) + .await + .unwrap() + }); + + let (frame_tx, frame_rx) = mpsc::channel(1); + let body = StreamingBody::new(frame_rx, vec![]); + let mut request = HttpRequest { + connection_id: 0, + request_id: 0, + port: 80, + internal_request: InternalHttpRequest { + method: Method::GET, + uri: "/".parse().unwrap(), + headers: Default::default(), + version: Version::HTTP_11, + body, + }, + }; + request + .internal_request + .headers + .insert(header::CONTENT_LENGTH, HeaderValue::from_static("12")); + + let mut tasks: BackgroundTasks<(), InProxyTaskMessage, Infallible> = Default::default(); + let client_store = ClientStore::default(); + let _gateway = tasks.register( + HttpGatewayTask::new(request, client_store.clone(), ResponseMode::Basic, addr), + (), + 8, + ); + + for _ in 0..2 { + semaphore.acquire().await.unwrap().forget(); + frame_tx + .send(InternalHttpBodyFrame::Data(b"hello\n".into())) + .await + .unwrap(); + } + std::mem::drop(frame_tx); + + match tasks.next().await.unwrap().1.unwrap_message() { + InProxyTaskMessage::Http(HttpOut::ResponseBasic(response)) => { + assert_eq!(response.internal_response.status, StatusCode::OK); + } + other => panic!("unexpected message: {other:?}"), + } + + match tasks.next().await.unwrap().1 { + TaskUpdate::Finished(Ok(())) => {} + other => panic!("unexpected task update: {other:?}"), + } + + conn_task.await.unwrap(); + } +} diff --git a/mirrord/intproxy/src/proxies/incoming/interceptor.rs b/mirrord/intproxy/src/proxies/incoming/interceptor.rs deleted file mode 100644 index e58fde21cee..00000000000 --- a/mirrord/intproxy/src/proxies/incoming/interceptor.rs +++ /dev/null @@ -1,664 +0,0 @@ -//! [`BackgroundTask`] used by [`Incoming`](super::IncomingProxy) to manage a single -//! intercepted connection. 
- -use std::{ - io::{self, ErrorKind}, - net::SocketAddr, - time::Duration, -}; - -use bytes::{Bytes, BytesMut}; -use hyper::{body::Frame, upgrade::OnUpgrade, Response, StatusCode}; -use hyper_util::rt::TokioIo; -use mirrord_protocol::tcp::{HttpRequest, HttpResponse, InternalHttpResponse}; -use thiserror::Error; -use tokio::{ - io::{AsyncReadExt, AsyncWriteExt}, - net::TcpStream, - time, -}; -use tracing::Level; - -use super::{ - http::{LocalHttpClient, LocalHttpError, PeekedBody}, - streaming_body::StreamingBody, -}; -use crate::{ - background_tasks::{unless_bus_closed, BackgroundTask, MessageBus, PeekableMessageBus}, - proxies::incoming::bound_socket::BoundTcpSocket, -}; - -/// Messages consumed by the [`Interceptor`] when it runs as a [`BackgroundTask`]. -pub enum MessageIn { - /// Request to be sent to the user application. - Http(HttpRequest), - /// Data to be sent to the user application. - Raw(Vec), -} - -/// Messages produced by the [`Interceptor`] when it runs as a [`BackgroundTask`]. -#[derive(Debug)] -pub enum MessageOut { - /// Response received from the user application. - Http(HttpResponse), - /// Data received from the user application. - Raw(Vec), -} - -impl From> for MessageIn { - fn from(value: HttpRequest) -> Self { - Self::Http(value) - } -} - -impl From> for MessageIn { - fn from(value: Vec) -> Self { - Self::Raw(value) - } -} - -/// Errors that can occur when [`Interceptor`] runs as a [`BackgroundTask`]. -/// -/// All of these are **fatal** for the interceptor and should terminate its main loop -/// ([`Interceptor::run`]). -/// -/// HTTP error handling and retries are done in the [`LocalHttpClient`]. 
-#[derive(Error, Debug)] -pub enum InterceptorError { - #[error("failed to connect to the user application socket: {0}")] - ConnectFailed(#[source] io::Error), - - #[error("io on the connection with the user application failed: {0}")] - IoFailed(#[source] io::Error), - - #[error("received an unexpected raw data ({} bytes)", .0.len())] - UnexpectedRawData(Vec), - - #[error("received an unexpected HTTP request: {0:?}")] - UnexpectedHttpRequest(HttpRequest), - - #[error(transparent)] - HttpFailed(#[from] LocalHttpError), - - #[error("failed to set up a TCP socket: {0}")] - SocketSetupFailed(#[source] io::Error), - - #[error("failed to handle an HTTP upgrade: {0}")] - HttpUpgradeFailed(#[source] hyper::Error), -} - -pub type InterceptorResult = core::result::Result; - -/// Manages a single intercepted connection. -/// Multiple instances are run as [`BackgroundTask`]s by one [`IncomingProxy`](super::IncomingProxy) -/// to manage individual connections. -/// -/// This interceptor can proxy both raw TCP data and HTTP messages in the same TCP connection. -/// When it receives [`MessageIn::Raw`], it starts acting as a simple proxy. -/// When it received [`MessageIn::Http`], it starts acting as an HTTP gateway. -pub struct Interceptor { - /// Socket that should be used to make the first connection (should already be bound). - socket: BoundTcpSocket, - /// Address of user app's listener. - peer: SocketAddr, -} - -impl Interceptor { - /// Creates a new instance. When run, this instance will use the given `socket` (must be already - /// bound) to communicate with the given `peer`. - /// - /// # Note - /// - /// The socket can be replaced when retrying HTTP requests. 
- pub fn new(socket: BoundTcpSocket, peer: SocketAddr) -> Self { - Self { socket, peer } - } -} - -impl BackgroundTask for Interceptor { - type Error = InterceptorError; - type MessageIn = MessageIn; - type MessageOut = MessageOut; - - #[tracing::instrument( - level = Level::TRACE, - name = "incoming_interceptor_main_loop", - skip_all, fields(peer_addr = %self.peer), - err(level = Level::WARN) - )] - async fn run(self, message_bus: &mut MessageBus) -> Result<(), Self::Error> { - let Some(result) = unless_bus_closed(message_bus, self.socket.connect(self.peer)).await - else { - tracing::trace!("Message bus closed, exiting"); - return Ok(()); - }; - let stream = result.map_err(InterceptorError::ConnectFailed)?; - let mut message_bus = message_bus.peekable(); - - tokio::select! { - // If there is some data from the user application before we get anything from the agent, - // then this is not HTTP. - // - // We should not block until the agent has something, we don't know what this protocol looks like. - result = stream.readable() => { - result.map_err(InterceptorError::IoFailed)?; - tracing::trace!("TCP connection became readable, assuming raw TCP"); - RawConnection { stream }.run(message_bus).await - } - - message = message_bus.peek() => match message { - Some(MessageIn::Http(..)) => { - tracing::trace!("Next message on the message bus is an HTTP request, running as an HTTP gateway"); - HttpConnection { - local_client: LocalHttpClient::new_for_stream(stream)?, - } - .run(message_bus) - .await - } - - Some(MessageIn::Raw(..)) => { - tracing::trace!("Next message on the message bus is raw TCP data, running as a TCP proxy"); - RawConnection { stream }.run(message_bus).await - }, - - None => Ok(()), - } - } - } -} - -/// Utilized by the [`Interceptor`] when it acts as an HTTP gateway. -/// See [`HttpConnection::run`] for usage. -struct HttpConnection { - local_client: LocalHttpClient, -} - -impl HttpConnection { - /// Handles the result of sending an HTTP request. 
- /// Returns an [`HttpResponse`] to be returned to the client or an [`InterceptorError`] when the - /// given [`LocalHttpError`] is fatal for the interceptor. Most [`LocalHttpError`]s are not - /// fatal and should be converted to [`StatusCode::BAD_GATEWAY`] responses instead. - #[tracing::instrument(level = Level::TRACE, ret, err(level = Level::WARN))] - fn handle_send_result( - request: HttpRequest, - response: Result, LocalHttpError>, - ) -> InterceptorResult<(HttpResponse, Option)> { - match response { - Err(LocalHttpError::SocketSetupFailed(error)) => { - Err(InterceptorError::SocketSetupFailed(error)) - } - - Err(LocalHttpError::UnsupportedHttpVersion(..)) => { - Err(InterceptorError::UnexpectedHttpRequest(request)) - } - - Err(error) => { - let response = error - .as_error_response( - request.internal_request.version, - request.request_id, - request.connection_id, - request.port, - ) - .map_body(|body| PeekedBody { - head: vec![Frame::data(Bytes::from_owner(body))], - tail: None, - }); - - Ok((response, None)) - } - - Ok(mut response) => { - let upgrade = if response.status() == StatusCode::SWITCHING_PROTOCOLS { - Some(hyper::upgrade::on(&mut response)) - } else { - None - }; - - let (parts, body) = response.into_parts(); - let response = HttpResponse { - port: request.port, - connection_id: request.connection_id, - request_id: request.request_id, - internal_response: InternalHttpResponse { - status: parts.status, - version: parts.version, - headers: parts.headers, - body, - }, - }; - - Ok((response, upgrade)) - } - } - } - - /// Proxies HTTP messages until an HTTP upgrade happens or the [`MessageBus`] closes. - /// Support retries (with reconnecting to the HTTP server). - /// - /// When an HTTP upgrade happens, the underlying [`TcpStream`] is reclaimed and wrapped - /// in a [`RawConnection`], which handles the rest of the connection. 
- #[tracing::instrument( - level = Level::TRACE, - name = "http_connection_main_loop", - skip_all, ret, - err(level = Level::WARN), - )] - async fn run( - mut self, - mut message_bus: PeekableMessageBus<'_, Interceptor>, - ) -> InterceptorResult<()> { - let upgrade = loop { - match message_bus.recv().await { - None => return Ok(()), - - Some(MessageIn::Raw(data)) => { - // We should not receive any raw data from the agent before sending a - //`101 SWITCHING PROTOCOLS` response. - return Err(InterceptorError::UnexpectedRawData(data)); - } - - Some(MessageIn::Http(request)) => { - let Some(result) = unless_bus_closed( - message_bus.inner(), - self.local_client.send_request(&request), - ) - .await - else { - tracing::trace!("Message bus closed, exiting"); - return Ok(()); - }; - let (res, on_upgrade) = Self::handle_send_result(request, result)?; - message_bus.send(MessageOut::Http(res)).await; - - let Some(on_upgrade) = on_upgrade else { - continue; - }; - - tracing::trace!("Detected an HTTP upgrade"); - let Some(result) = unless_bus_closed(message_bus.inner(), on_upgrade).await - else { - tracing::trace!("Message bus closed, exiting"); - return Ok(()); - }; - - break result.map_err(InterceptorError::HttpUpgradeFailed)?; - } - } - }; - - let parts = upgrade - .downcast::>() - .expect("IO type is known"); - let stream = parts.io.into_inner(); - let read_buf = parts.read_buf; - - if !read_buf.is_empty() { - message_bus.send(MessageOut::Raw(read_buf.into())).await; - } - - RawConnection { stream }.run(message_bus).await - } -} - -/// Utilized by the [`Interceptor`] when it acts as a TCP proxy. -/// See [`RawConnection::run`] for usage. -#[derive(Debug)] -struct RawConnection { - /// Connection between the [`Interceptor`] and the server. - stream: TcpStream, -} - -impl RawConnection { - /// Proxies raw TCP data until the [`MessageBus`] closes. - /// - /// # Notes - /// - /// 1. 
When the peer shuts down writing, a single 0-sized read is sent through the - /// [`MessageBus`]. This is to notify the agent about the shutdown condition. - /// - /// 2. A 0-sized read received from the [`MessageBus`] is treated as a shutdown on the agent - /// side. Connection with the peer is shut down as well. - /// - /// 3. This implementation exits only when an error is encountered or the [`MessageBus`] is - /// closed. - #[tracing::instrument( - level = Level::TRACE, - name = "raw_connection_main_loop", - skip_all, ret, - err(level = Level::WARN), - )] - async fn run( - mut self, - mut message_bus: PeekableMessageBus<'_, Interceptor>, - ) -> InterceptorResult<()> { - let mut buf = BytesMut::with_capacity(64 * 1024); - let mut reading_closed = false; - let mut remote_closed = false; - - loop { - tokio::select! { - biased; - - res = self.stream.read_buf(&mut buf), if !reading_closed => match res { - Err(e) if e.kind() == ErrorKind::WouldBlock => {}, - Err(e) => break Err(InterceptorError::IoFailed(e)), - Ok(..) 
=> { - if buf.is_empty() { - tracing::trace!("layer shutdown, sending a 0-sized read to inform the agent"); - reading_closed = true; - } - message_bus.send(MessageOut::Raw(buf.to_vec())).await; - buf.clear(); - } - }, - - msg = message_bus.recv(), if !remote_closed => match msg { - None => { - tracing::trace!("message bus closed, waiting 1 second before exiting"); - remote_closed = true; - }, - Some(MessageIn::Raw(data)) => { - if data.is_empty() { - tracing::trace!("agent shutdown, shutting down connection with layer"); - self.stream.shutdown().await.map_err(InterceptorError::IoFailed)?; - } else { - self.stream.write_all(&data).await.map_err(InterceptorError::IoFailed)?; - } - }, - Some(MessageIn::Http(request)) => break Err(InterceptorError::UnexpectedHttpRequest(request)), - }, - - _ = time::sleep(Duration::from_secs(1)), if remote_closed => { - tracing::trace!("layer silent for 1 second and message bus is closed, exiting"); - - break Ok(()); - }, - } - } - } -} - -#[cfg(test)] -mod test { - use std::{convert::Infallible, net::Ipv4Addr, sync::Arc}; - - use bytes::Bytes; - use futures::future::FutureExt; - use http_body_util::{BodyExt, Empty}; - use hyper::{ - body::Incoming, - header::{HeaderValue, CONNECTION, UPGRADE}, - server::conn::http1, - service::service_fn, - upgrade::Upgraded, - Method, Request, Response, Version, - }; - use hyper_util::rt::TokioIo; - use mirrord_protocol::tcp::{HttpRequest, InternalHttpBodyFrame, InternalHttpRequest}; - use tokio::{ - io::{AsyncReadExt, AsyncWriteExt}, - net::TcpListener, - sync::{watch, Notify}, - task, - }; - - use super::*; - use crate::background_tasks::{BackgroundTasks, TaskUpdate}; - - /// Binary protocol over TCP. - /// Server first sends bytes [`INITIAL_MESSAGE`], then echoes back all received data. - const TEST_PROTO: &str = "dummyecho"; - - const INITIAL_MESSAGE: &[u8] = &[0x4a, 0x50, 0x32, 0x47, 0x4d, 0x44]; - - /// Handles requests upgrading to the [`TEST_PROTO`] protocol. 
- async fn upgrade_req_handler( - mut req: Request, - ) -> hyper::Result>> { - async fn dummy_echo(upgraded: Upgraded) -> io::Result<()> { - let mut upgraded = TokioIo::new(upgraded); - let mut buf = [0_u8; 64]; - - upgraded.write_all(INITIAL_MESSAGE).await?; - - loop { - let bytes_read = upgraded.read(&mut buf[..]).await?; - if bytes_read == 0 { - break; - } - - let echo_back = buf.get(0..bytes_read).unwrap(); - upgraded.write_all(echo_back).await?; - } - - Ok(()) - } - - let mut res = Response::new(Empty::new()); - - let contains_expected_upgrade = req - .headers() - .get(UPGRADE) - .filter(|proto| *proto == TEST_PROTO) - .is_some(); - if !contains_expected_upgrade { - *res.status_mut() = StatusCode::BAD_REQUEST; - return Ok(res); - } - - task::spawn(async move { - match hyper::upgrade::on(&mut req).await { - Ok(upgraded) => { - if let Err(e) = dummy_echo(upgraded).await { - eprintln!("server foobar io error: {}", e) - }; - } - Err(e) => eprintln!("upgrade error: {}", e), - } - }); - - *res.status_mut() = StatusCode::SWITCHING_PROTOCOLS; - res.headers_mut() - .insert(UPGRADE, HeaderValue::from_static(TEST_PROTO)); - res.headers_mut() - .insert(CONNECTION, HeaderValue::from_static("upgrade")); - Ok(res) - } - - /// Runs a [`hyper`] server that accepts only requests upgrading to the [`TEST_PROTO`] protocol. - async fn dummy_echo_server(listener: TcpListener, mut shutdown: watch::Receiver) { - loop { - tokio::select! { - res = listener.accept() => { - let (stream, _) = res.expect("dummy echo server failed to accept connection"); - - let mut shutdown = shutdown.clone(); - - task::spawn(async move { - let conn = http1::Builder::new().serve_connection(TokioIo::new(stream), service_fn(upgrade_req_handler)); - let mut conn = conn.with_upgrades(); - let mut conn = Pin::new(&mut conn); - - tokio::select! 
{ - res = &mut conn => { - res.expect("dummy echo server failed to serve connection"); - } - - _ = shutdown.changed() => { - conn.graceful_shutdown(); - } - } - }); - } - - _ = shutdown.changed() => break, - } - } - } - - #[tokio::test] - async fn upgrade_test() { - let listener = TcpListener::bind("127.0.0.1:0").await.unwrap(); - let local_destination = listener.local_addr().unwrap(); - - let (shutdown_tx, shutdown_rx) = watch::channel(false); - let server_task = task::spawn(dummy_echo_server(listener, shutdown_rx)); - - let mut tasks: BackgroundTasks<(), MessageOut, InterceptorError> = Default::default(); - let interceptor = { - let socket = - BoundTcpSocket::bind_specified_or_localhost(Ipv4Addr::LOCALHOST.into()).unwrap(); - tasks.register(Interceptor::new(socket, local_destination), (), 8) - }; - - interceptor - .send(HttpRequest { - connection_id: 0, - request_id: 0, - port: 80, - internal_request: InternalHttpRequest { - method: Method::GET, - uri: "dummyecho://www.mirrord.dev/".parse().unwrap(), - headers: [ - (CONNECTION, HeaderValue::from_static("upgrade")), - (UPGRADE, HeaderValue::from_static(TEST_PROTO)), - ] - .into_iter() - .collect(), - version: Version::HTTP_11, - body: Default::default(), - }, - }) - .await; - - let (_, update) = tasks.next().await.expect("no task result"); - match update { - TaskUpdate::Message(MessageOut::Http(res)) => { - assert_eq!( - res.internal_response.status, - StatusCode::SWITCHING_PROTOCOLS - ); - println!("Received repsonse from the interceptor: {res:?}"); - assert!(res - .internal_response - .headers - .get(CONNECTION) - .filter(|v| *v == "upgrade") - .is_some()); - assert!(res - .internal_response - .headers - .get(UPGRADE) - .filter(|v| *v == TEST_PROTO) - .is_some()); - } - _ => panic!("unexpected task update: {update:?}"), - } - - interceptor.send(b"test test test".to_vec()).await; - - let (_, update) = tasks.next().await.expect("no task result"); - match update { - TaskUpdate::Message(MessageOut::Raw(bytes)) => { 
- assert_eq!(bytes, INITIAL_MESSAGE); - } - _ => panic!("unexpected task update: {update:?}"), - } - - let (_, update) = tasks.next().await.expect("no task result"); - match update { - TaskUpdate::Message(MessageOut::Raw(bytes)) => { - assert_eq!(bytes, b"test test test"); - } - _ => panic!("unexpected task update: {update:?}"), - } - - let _ = shutdown_tx.send(true); - server_task.await.expect("dummy echo server panicked"); - } - - /// Ensure that body of [`MessageOut::Http`] response is received frame by frame - #[tokio::test] - async fn receive_request_as_frames() { - let listener = TcpListener::bind("127.0.0.1:0").await.unwrap(); - - let mut tasks: BackgroundTasks<(), MessageOut, InterceptorError> = Default::default(); - let interceptor = Interceptor::new( - BoundTcpSocket::bind_specified_or_localhost(Ipv4Addr::LOCALHOST.into()).unwrap(), - listener.local_addr().unwrap(), - ); - let interceptor = tasks.register(interceptor, (), 8); - - let (frame_tx, frame_rx) = tokio::sync::mpsc::channel(1); - interceptor - .send(MessageIn::Http(HttpRequest { - internal_request: InternalHttpRequest { - method: Method::POST, - uri: "/".parse().unwrap(), - headers: Default::default(), - version: Version::HTTP_11, - body: StreamingBody::from(frame_rx), - }, - connection_id: 1, - request_id: 2, - port: 3, - })) - .await; - let connection = listener.accept().await.unwrap().0; - - let notifier = Arc::new(Notify::default()); - - // Task that sends the next frame when notified. - // Sends two frames, then exits. - tokio::spawn({ - let notifier = notifier.clone(); - async move { - for _ in 0..2 { - notifier.notified().await; - if frame_tx - .send(InternalHttpBodyFrame::Data(b"some-data".into())) - .await - .is_err() - { - break; - } - } - - // Wait for third notification before dropping the frame sender. 
- notifier.notified().await; - } - }); - - let service = service_fn(|mut req: Request| { - let notifier = notifier.clone(); - async move { - for _ in 0..2 { - let frame = req.body_mut().frame().now_or_never(); - assert!(frame.is_none()); - - notifier.notify_one(); - let frame = req - .body_mut() - .frame() - .await - .unwrap() - .unwrap() - .into_data() - .unwrap(); - assert_eq!(frame, b"some-data".to_vec()); - let frame = req.body_mut().frame().now_or_never(); - assert!(frame.is_none()); - } - - notifier.notify_one(); - let frame = req.body_mut().frame().await; - assert!(frame.is_none()); - - Ok::<_, Infallible>(Response::new(Empty::::new())) - } - }); - let conn = http1::Builder::new().serve_connection(TokioIo::new(connection), service); - conn.await.unwrap(); - } -} diff --git a/mirrord/intproxy/src/proxies/incoming/response_reader.rs b/mirrord/intproxy/src/proxies/incoming/response_reader.rs deleted file mode 100644 index 816319e2036..00000000000 --- a/mirrord/intproxy/src/proxies/incoming/response_reader.rs +++ /dev/null @@ -1,286 +0,0 @@ -use std::{convert::Infallible, time::Instant}; - -use http_body_util::BodyExt; -use hyper::body::{Frame, Incoming}; -use mirrord_protocol::{ - batched_body::BatchedBody, - tcp::{ - ChunkedHttpBody, ChunkedHttpError, ChunkedResponse, HttpResponse, InternalHttpBody, - InternalHttpBodyFrame, LayerTcpSteal, - }, - ConnectionId, RequestId, -}; -use tracing::Level; - -use super::http::PeekedBody; -use crate::{ - background_tasks::{unless_bus_closed, BackgroundTask, MessageBus}, - proxies::incoming::http::LocalHttpError, -}; - -/// Background task responsible for asynchronous read of an HTTP response body coming from the user -/// application. -/// -/// Meant to be run as a [`BackgroundTask`]. -pub enum HttpResponseReader { - /// Produces a [`LayerTcpSteal::HttpResponse`] message. - Legacy(HttpResponse), - /// Produces a [`LayerTcpSteal::HttpResponseFramed`] message. 
- Framed(HttpResponse), - /// Produces [`LayerTcpSteal::HttpResponseChunked`] messasages. - Chunked { - connection_id: ConnectionId, - request_id: RequestId, - body: Incoming, - }, -} - -impl HttpResponseReader { - fn request_id(&self) -> RequestId { - match self { - Self::Legacy(response) => response.request_id, - Self::Framed(response) => response.request_id, - Self::Chunked { request_id, .. } => *request_id, - } - } - - fn connection_id(&self) -> ConnectionId { - match self { - Self::Legacy(response) => response.connection_id, - Self::Framed(response) => response.connection_id, - Self::Chunked { connection_id, .. } => *connection_id, - } - } - - /// Reads the body and produces a [`LayerTcpSteal::HttpResponse`] message. - /// - /// When reading the body fails, produces a [`LayerTcpSteal::HttpResponse`] error response. - async fn run_legacy( - mut response: HttpResponse, - message_bus: &mut MessageBus, - ) { - let tail = match response.internal_response.body.tail.take() { - Some(incoming) => { - let start = Instant::now(); - let Some(result) = unless_bus_closed(message_bus, incoming.collect()).await else { - tracing::trace!("Message bus closed, exiting"); - return; - }; - - match result { - Ok(data) => { - tracing::trace!( - elapsed_s = start.elapsed().as_secs_f32(), - "Collected the whole body.", - ); - Vec::from(data.to_bytes()) - } - - Err(error) => { - tracing::warn!( - connection_id = response.connection_id, - request_id = response.request_id, - %error, - "Failed to read the response body.", - ); - - let response = LocalHttpError::ReadBodyFailed(error).as_error_response( - response.internal_response.version, - response.request_id, - response.connection_id, - response.port, - ); - message_bus - .send(LayerTcpSteal::HttpResponse(response)) - .await; - return; - } - } - } - - None => vec![], - }; - - let response = response.map_body(|body| { - let mut complete = Vec::with_capacity( - body.head - .iter() - .filter_map(|frame| Some(frame.data_ref()?.len())) - 
.sum::() - + tail.len(), - ); - for frame in body - .head - .into_iter() - .map(Frame::into_data) - .filter_map(Result::ok) - { - complete.extend(frame); - } - complete.extend(tail); - complete - }); - - message_bus - .send(LayerTcpSteal::HttpResponse(response)) - .await; - } - - /// Reads the body and produces a [`LayerTcpSteal::HttpResponseFramed`] message. - /// - /// When reading the body fails, produces a [`LayerTcpSteal::HttpResponse`] error response. - async fn run_framed( - mut response: HttpResponse, - message_bus: &mut MessageBus, - ) { - if let Some(mut incoming) = response.internal_response.body.tail.take() { - let start = Instant::now(); - loop { - let Some(result) = unless_bus_closed(message_bus, incoming.next_frames()).await - else { - tracing::trace!("Message bus closed, exiting"); - return; - }; - - match result { - Ok(data) => { - response.internal_response.body.head.extend(data.frames); - - if data.is_last { - tracing::trace!( - elapsed_s = start.elapsed().as_secs_f32(), - "Collected the whole response body." - ); - break; - } - } - - Err(error) => { - tracing::warn!( - connection_id = response.connection_id, - request_id = response.request_id, - %error, - "Failed to read the response body.", - ); - - let response = LocalHttpError::ReadBodyFailed(error).as_error_response( - response.internal_response.version, - response.request_id, - response.connection_id, - response.port, - ); - message_bus - .send(LayerTcpSteal::HttpResponse(response)) - .await; - return; - } - } - } - }; - - let response = response.map_body(|body| { - InternalHttpBody( - body.head - .into_iter() - .map(InternalHttpBodyFrame::from) - .collect(), - ) - }); - - message_bus - .send(LayerTcpSteal::HttpResponseFramed(response)) - .await; - } - - /// Reads the body and produces [`LayerTcpSteal::HttpResponseChunked`] messages. 
- async fn run_chunked( - connection_id: ConnectionId, - request_id: RequestId, - mut body: Incoming, - message_bus: &mut MessageBus, - ) { - let start = Instant::now(); - loop { - let Some(result) = unless_bus_closed(message_bus, body.next_frames()).await else { - tracing::trace!("Message bus closed, exiting"); - return; - }; - - match result { - Ok(data) => { - let message = LayerTcpSteal::HttpResponseChunked(ChunkedResponse::Body( - ChunkedHttpBody { - frames: data - .frames - .into_iter() - .map(InternalHttpBodyFrame::from) - .collect(), - is_last: data.is_last, - connection_id, - request_id, - }, - )); - message_bus.send(message).await; - - if data.is_last { - tracing::trace!( - elapsed_s = start.elapsed().as_secs_f32(), - "Collected the whole response body." - ); - break; - } - } - - Err(error) => { - tracing::warn!( - connection_id, - request_id, - %error, - "Failed to read the response body.", - ); - - let message = LayerTcpSteal::HttpResponseChunked(ChunkedResponse::Error( - ChunkedHttpError { - connection_id, - request_id, - }, - )); - message_bus.send(message).await; - - return; - } - } - } - } -} - -impl BackgroundTask for HttpResponseReader { - type Error = Infallible; - type MessageIn = Infallible; - type MessageOut = LayerTcpSteal; - - #[tracing::instrument( - level = Level::TRACE, - name = "http_response_reader_main_loop", - fields( - connection_id = self.connection_id(), - request_id = self.request_id(), - ), - skip_all, - )] - async fn run(self, message_bus: &mut MessageBus) -> Result<(), Self::Error> { - match self { - Self::Legacy(response) => Self::run_legacy(response, message_bus).await, - - Self::Framed(response) => Self::run_framed(response, message_bus).await, - - Self::Chunked { - connection_id, - request_id, - body, - } => Self::run_chunked(connection_id, request_id, body, message_bus).await, - } - - Ok(()) - } -} diff --git a/mirrord/intproxy/src/proxies/incoming/tasks.rs b/mirrord/intproxy/src/proxies/incoming/tasks.rs new file mode 
100644 index 00000000000..bb35a882eab --- /dev/null +++ b/mirrord/intproxy/src/proxies/incoming/tasks.rs @@ -0,0 +1,63 @@ +use std::{convert::Infallible, io}; + +use hyper::{upgrade::OnUpgrade, Version}; +use mirrord_protocol::{ + tcp::{ChunkedResponse, HttpResponse, InternalHttpBody}, + ConnectionId, Port, RequestId, +}; +use thiserror::Error; + +#[derive(Debug)] +pub enum InProxyTaskMessage { + Tcp(Vec), + Http(HttpOut), +} + +#[derive(Debug)] +pub enum HttpOut { + ResponseBasic(HttpResponse>), + ResponseFramed(HttpResponse), + ResponseChunked(ChunkedResponse), + Upgraded(OnUpgrade), +} + +impl From> for InProxyTaskMessage { + fn from(value: Vec) -> Self { + Self::Tcp(value) + } +} + +impl From for InProxyTaskMessage { + fn from(value: HttpOut) -> Self { + Self::Http(value) + } +} + +#[derive(Error, Debug)] +pub enum InProxyTaskError { + #[error("io failed: {0}")] + IoError(#[from] io::Error), + #[error("local HTTP upgrade failed: {0}")] + UpgradeError(#[source] hyper::Error), +} + +impl From for InProxyTaskError { + fn from(_: Infallible) -> Self { + unreachable!() + } +} + +#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)] +pub enum InProxyTask { + MirrorTcpProxy(ConnectionId), + StealTcpProxy(ConnectionId), + HttpGateway(HttpGatewayId), +} + +#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)] +pub struct HttpGatewayId { + pub connection_id: ConnectionId, + pub request_id: RequestId, + pub port: Port, + pub version: Version, +} diff --git a/mirrord/intproxy/src/proxies/incoming/tcp_proxy.rs b/mirrord/intproxy/src/proxies/incoming/tcp_proxy.rs new file mode 100644 index 00000000000..98304e155be --- /dev/null +++ b/mirrord/intproxy/src/proxies/incoming/tcp_proxy.rs @@ -0,0 +1,120 @@ +use std::{io::ErrorKind, net::SocketAddr, time::Duration}; + +use bytes::BytesMut; +use hyper::upgrade::OnUpgrade; +use hyper_util::rt::TokioIo; +use tokio::{ + io::{AsyncReadExt, AsyncWriteExt}, + net::TcpStream, + time, +}; + +use super::{ + bound_socket::BoundTcpSocket, + 
tasks::{InProxyTaskError, InProxyTaskMessage}, +}; +use crate::background_tasks::{BackgroundTask, MessageBus}; + +pub enum LocalTcpConnection { + FromTheStart { + socket: BoundTcpSocket, + peer: SocketAddr, + }, + AfterUpgrade(OnUpgrade), +} + +pub struct TcpProxyTask { + connection: LocalTcpConnection, + discard_data: bool, +} + +impl TcpProxyTask { + pub fn new(connection: LocalTcpConnection, discard_data: bool) -> Self { + Self { + connection, + discard_data, + } + } +} + +impl BackgroundTask for TcpProxyTask { + type Error = InProxyTaskError; + type MessageIn = Vec; + type MessageOut = InProxyTaskMessage; + + async fn run(self, message_bus: &mut MessageBus) -> Result<(), Self::Error> { + let mut stream = match self.connection { + LocalTcpConnection::FromTheStart { socket, peer } => { + let Some(stream) = message_bus + .closed() + .cancel_on_close(socket.connect(peer)) + .await + else { + return Ok(()); + }; + + stream? + } + + LocalTcpConnection::AfterUpgrade(on_upgrade) => { + let upgraded = on_upgrade.await.map_err(InProxyTaskError::UpgradeError)?; + let parts = upgraded + .downcast::>() + .expect("IO type is known"); + let stream = parts.io.into_inner(); + let read_buf = parts.read_buf; + + if !self.discard_data { + message_bus.send(Vec::from(read_buf)).await; + } + + stream + } + }; + + let mut buf = BytesMut::with_capacity(64 * 1024); + let mut reading_closed = false; + let mut is_lingering = false; + + loop { + tokio::select! { + res = stream.read_buf(&mut buf), if !reading_closed => match res { + Err(e) if e.kind() == ErrorKind::WouldBlock => {}, + Err(e) => break Err(e.into()), + Ok(..) 
=> { + if buf.is_empty() { + reading_closed = true; + } + + if !self.discard_data { + message_bus.send(buf.to_vec()).await; + } + + buf.clear(); + } + }, + + msg = message_bus.recv(), if !is_lingering => match msg { + None => { + if self.discard_data { + break Ok(()); + } + + is_lingering = true; + } + Some(data) => { + if data.is_empty() { + stream.shutdown().await?; + } else { + stream.write_all(&data).await?; + } + }, + }, + + _ = time::sleep(Duration::from_secs(1)), if is_lingering => { + break Ok(()); + } + } + } + } +} From ab4fdb5a1a246a7360b9175bf7f5e9b7450570ca Mon Sep 17 00:00:00 2001 From: Razz4780 Date: Thu, 16 Jan 2025 18:45:30 +0100 Subject: [PATCH 15/60] HttpGateway tests --- Cargo.lock | 2 - mirrord/intproxy/Cargo.toml | 3 - mirrord/intproxy/src/proxies/incoming/http.rs | 89 +-------------- .../src/proxies/incoming/http/client_store.rs | 24 +++-- .../src/proxies/incoming/http_gateway.rs | 102 +++++++++++++++++- 5 files changed, 118 insertions(+), 102 deletions(-) diff --git a/Cargo.lock b/Cargo.lock index 0a54470c5b1..0b1111a777d 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -4313,7 +4313,6 @@ version = "3.128.0" dependencies = [ "bytes", "exponential-backoff", - "h2 0.4.7", "http-body-util", "hyper 1.5.2", "hyper-util", @@ -4324,7 +4323,6 @@ dependencies = [ "mirrord-operator", "mirrord-protocol", "rand", - "reqwest 0.12.12", "rstest", "rustls 0.23.20", "rustls-pemfile 2.2.0", diff --git a/mirrord/intproxy/Cargo.toml b/mirrord/intproxy/Cargo.toml index b132de1ed37..8b39387ec59 100644 --- a/mirrord/intproxy/Cargo.toml +++ b/mirrord/intproxy/Cargo.toml @@ -33,8 +33,6 @@ tokio.workspace = true tracing.workspace = true tokio-stream.workspace = true hyper = { workspace = true, features = ["client", "http1", "http2"] } -# For checking the `RST_STREAM` error from HTTP2 stealer + filter. 
-h2 = "0.4" hyper-util.workspace = true http-body-util.workspace = true bytes.workspace = true @@ -45,5 +43,4 @@ rustls-pemfile.workspace = true exponential-backoff = "2" [dev-dependencies] -reqwest.workspace = true rstest.workspace = true diff --git a/mirrord/intproxy/src/proxies/incoming/http.rs b/mirrord/intproxy/src/proxies/incoming/http.rs index b01aee0cbf9..5b9a7a0012b 100644 --- a/mirrord/intproxy/src/proxies/incoming/http.rs +++ b/mirrord/intproxy/src/proxies/incoming/http.rs @@ -1,4 +1,4 @@ -use std::{error::Error, fmt, io, net::SocketAddr}; +use std::{fmt, io, net::SocketAddr}; use hyper::{ body::Incoming, @@ -107,23 +107,6 @@ pub enum LocalHttpError { } impl LocalHttpError { - /// Checks if the given [`hyper::Error`] originates from [`h2::Error`] `RST_STREAM`. - /// - /// This requires that we use the same [`h2`] version as [`hyper`], - /// which is verified in the `hyper_and_h2_versions_in_sync` test below. - pub fn is_h2_reset(error: &hyper::Error) -> bool { - let mut cause = error.source(); - while let Some(err) = cause { - if let Some(typed) = err.downcast_ref::() { - return typed.is_reset(); - }; - - cause = err.source(); - } - - false - } - /// Checks if we can retry sending the request, given that the previous attempt resulted in this /// error. pub fn can_retry(&self) -> bool { @@ -131,7 +114,10 @@ impl LocalHttpError { Self::SocketSetupFailed(..) | Self::UnsupportedHttpVersion(..) => false, Self::ConnectTcpFailed(..) 
=> true, Self::HandshakeFailed(err) | Self::SendFailed(err) | Self::ReadBodyFailed(err) => { - err.is_closed() || err.is_incomplete_message() || Self::is_h2_reset(err) + !(err.is_parse() + || err.is_parse_status() + || err.is_parse_too_large() + || err.is_user()) } } } @@ -263,68 +249,3 @@ impl HttpSender { } } } - -#[cfg(test)] -mod test { - use std::{ - convert::Infallible, - error::Error, - net::{Ipv4Addr, SocketAddr}, - }; - - use bytes::Bytes; - use http_body_util::Full; - use hyper::{server::conn::http2, service::service_fn, Response}; - use hyper_util::rt::{TokioExecutor, TokioIo}; - use tokio::net::TcpListener; - - /// Checks that [`hyper`] and [`h2`] crate versions are in sync with each other. - /// - /// In [`LocalHttpError::is_h2_reset`](super::LocalHttpError::is_h2_reset) we use - /// `source.downcast_ref::` to drill down on [`h2`] errors from [`hyper`], we - /// need these two crates to stay in sync, otherwise we could always fail some of our checks - /// that rely on this `downcast` working. - /// - /// Even though we're using [`h2::Error::is_reset`] in intproxy, this test can be - /// for any error, and thus here we do it for [`h2::Error::is_go_away`] which is - /// easier to trigger. 
- #[tokio::test] - async fn hyper_and_h2_versions_in_sync() { - let listener = TcpListener::bind(SocketAddr::new(Ipv4Addr::LOCALHOST.into(), 0)) - .await - .unwrap(); - let listener_address = listener.local_addr().unwrap(); - - let handle = tokio::spawn(async move { - let stream = listener.accept().await.unwrap().0; - http2::Builder::new(TokioExecutor::default()) - .serve_connection( - TokioIo::new(stream), - service_fn(|_| async move { - Ok::<_, Infallible>(Response::new(Full::new(Bytes::from("Heresy!")))) - }), - ) - .await - }); - - assert!(reqwest::get(format!("https://{listener_address}")) - .await - .is_err()); - - let conn_result = handle.await.unwrap(); - assert!( - conn_result - .as_ref() - .err() - .and_then(Error::source) - .and_then(|source| source.downcast_ref::()) - .is_some_and(h2::Error::is_go_away), - r"The request is supposed to fail with `GO_AWAY`! - Something is wrong if it didn't! - - >> If you're seeing this error, the cause is likely that `hyper` and `h2` - versions are out of sync, and we can't have that due to our use of - `downcast_ref` on some `h2` errors!" 
- ); - } -} diff --git a/mirrord/intproxy/src/proxies/incoming/http/client_store.rs b/mirrord/intproxy/src/proxies/incoming/http/client_store.rs index b22c26ed06c..4440cb1b1f6 100644 --- a/mirrord/intproxy/src/proxies/incoming/http/client_store.rs +++ b/mirrord/intproxy/src/proxies/incoming/http/client_store.rs @@ -8,6 +8,7 @@ use tokio::{ use super::{LocalHttpClient, LocalHttpError}; +#[derive(Debug)] struct IdleLocalClient { client: LocalHttpClient, last_used: Instant, @@ -18,17 +19,21 @@ pub struct ClientStore(watch::Sender>); impl Default for ClientStore { fn default() -> Self { - let (tx, _) = watch::channel(Default::default()); - - tokio::spawn(cleanup_task(tx.clone())); - - Self(tx) + Self::new_with_timeout(Self::IDLE_CLIENT_TIMEOUT) } } impl ClientStore { const IDLE_CLIENT_TIMEOUT: Duration = Duration::from_secs(3); + pub fn new_with_timeout(timeout: Duration) -> Self { + let (tx, _) = watch::channel(Default::default()); + + tokio::spawn(cleanup_task(tx.clone(), timeout)); + + Self(tx) + } + pub async fn get( &self, server_addr: SocketAddr, @@ -37,6 +42,7 @@ impl ClientStore { let mut ready = None; self.0.send_if_modified(|clients| { + println!("ready clients: {clients:?}"); let position = clients.iter().position(|idle| { idle.client.handles_version(version) && idle.client.local_server_address() == server_addr @@ -52,6 +58,7 @@ impl ClientStore { }); if let Some(ready) = ready { + println!("found ready client"); return Ok(ready); } @@ -64,6 +71,7 @@ impl ClientStore { } pub fn push_idle(&self, client: LocalHttpClient) { + println!("storing idle client {client:?}"); self.0.send_modify(|clients| { clients.push(IdleLocalClient { client, @@ -105,7 +113,7 @@ impl ClientStore { } } -async fn cleanup_task(clients: watch::Sender>) { +async fn cleanup_task(clients: watch::Sender>, idle_client_timeout: Duration) { loop { let now = Instant::now(); let mut min_last_used = None; @@ -114,7 +122,7 @@ async fn cleanup_task(clients: watch::Sender>) { let mut removed = 
false; clients.retain(|client| { - if client.last_used + ClientStore::IDLE_CLIENT_TIMEOUT > now { + if client.last_used + idle_client_timeout > now { min_last_used = min_last_used .map(|previous| cmp::min(previous, client.last_used)) .or(Some(client.last_used)); @@ -130,7 +138,7 @@ async fn cleanup_task(clients: watch::Sender>) { }); if let Some(min_last_used) = min_last_used { - time::sleep_until(min_last_used + ClientStore::IDLE_CLIENT_TIMEOUT).await; + time::sleep_until(min_last_used + idle_client_timeout).await; } else { clients .subscribe() diff --git a/mirrord/intproxy/src/proxies/incoming/http_gateway.rs b/mirrord/intproxy/src/proxies/incoming/http_gateway.rs index 02d8cdac9a6..e684274b0a0 100644 --- a/mirrord/intproxy/src/proxies/incoming/http_gateway.rs +++ b/mirrord/intproxy/src/proxies/incoming/http_gateway.rs @@ -41,10 +41,12 @@ impl HttpGatewayTask { } async fn send_attempt(&self, message_bus: &mut MessageBus) -> Result<(), LocalHttpError> { + println!("making send attempt"); let mut client = self .client_store .get(self.server_addr, self.request.version()) .await?; + println!("got client"); let mut response = client.send_request(self.request.clone()).await?; let on_upgrade = (response.status() == StatusCode::SWITCHING_PROTOCOLS) .then(|| hyper::upgrade::on(&mut response)); @@ -181,6 +183,7 @@ impl BackgroundTask for HttpGatewayTask { .then(|| backoffs.next()) .flatten() .flatten(); + println!("send attempt failed with {error}, backoff={backoff:?}"); let Some(backoff) = backoff else { break error; }; @@ -330,8 +333,10 @@ mod test { } } + /// Verifies that [`HttpGatewayTask`] and [`TcpProxyTask`] together correctly handle HTTP + /// upgrades. 
#[tokio::test] - async fn upgrade_test() { + async fn handles_http_upgrades() { let listener = TcpListener::bind("127.0.0.1:0").await.unwrap(); let local_destination = listener.local_addr().unwrap(); @@ -450,12 +455,17 @@ mod test { server_task.await.expect("dummy echo server panicked"); } + /// Verifies that [`HttpGatewayTask`] produces correct variant of the [`HttpResponse`]. + /// + /// Verifies that body of + /// [`LayerTcpSteal::HttpResponseChunked`](mirrord_protocol::tcp::LayerTcpSteal::HttpResponseChunked) + /// is streamed. #[rstest] #[case::basic(ResponseMode::Basic)] #[case::framed(ResponseMode::Framed)] #[case::chunked(ResponseMode::Chunked)] #[tokio::test] - async fn receive_correct_response_variant(#[case] response_mode: ResponseMode) { + async fn produces_correct_response_variant(#[case] response_mode: ResponseMode) { let listener = TcpListener::bind("127.0.0.1:0").await.unwrap(); let addr = listener.local_addr().unwrap(); let semaphore: Arc = Arc::new(Semaphore::const_new(0)); @@ -507,9 +517,8 @@ mod test { }; let mut tasks: BackgroundTasks<(), InProxyTaskMessage, Infallible> = Default::default(); - let client_store = ClientStore::default(); let _gateway = tasks.register( - HttpGatewayTask::new(request, client_store.clone(), response_mode, addr), + HttpGatewayTask::new(request, ClientStore::default(), response_mode, addr), (), 8, ); @@ -593,8 +602,10 @@ mod test { conn_task.await.unwrap(); } + /// Verifies that [`HttpGateway`] sends request body frames to the server as soon as they are + /// available. #[tokio::test] - async fn streams_request_frames() { + async fn streams_request_body_frames() { let listener = TcpListener::bind("127.0.0.1:0").await.unwrap(); let addr = listener.local_addr().unwrap(); @@ -679,4 +690,85 @@ mod test { conn_task.await.unwrap(); } + + /// Verifies that [`HttpGateway`] reuses already established HTTP connections. 
+ #[tokio::test] + async fn reuses_client_connections() { + let listener = TcpListener::bind("127.0.0.1:0").await.unwrap(); + let addr = listener.local_addr().unwrap(); + + tokio::spawn(async move { + let service = service_fn(|_req: Request| { + std::future::ready(Ok::<_, Infallible>(Response::new(Empty::::new()))) + }); + + let (connection, _) = listener.accept().await.unwrap(); + std::mem::drop(listener); + http1::Builder::new() + .serve_connection(TokioIo::new(connection), service) + .await + .unwrap() + }); + + let mut request = HttpRequest { + connection_id: 0, + request_id: 0, + port: 80, + internal_request: InternalHttpRequest { + method: Method::GET, + uri: "/".parse().unwrap(), + headers: Default::default(), + version: Version::HTTP_11, + body: Default::default(), + }, + }; + request + .internal_request + .headers + .insert(header::CONNECTION, HeaderValue::from_static("keep-alive")); + + let mut tasks: BackgroundTasks = Default::default(); + let client_store = ClientStore::new_with_timeout(Duration::from_secs(1337 * 21 * 37)); + let _gateway_1 = tasks.register( + HttpGatewayTask::new( + request.clone(), + client_store.clone(), + ResponseMode::Basic, + addr, + ), + 0, + 8, + ); + let _gateway_2 = tasks.register( + HttpGatewayTask::new( + request.clone(), + client_store.clone(), + ResponseMode::Basic, + addr, + ), + 1, + 8, + ); + + let mut finished = 0; + let mut responses = 0; + + while finished < 2 && responses < 2 { + match tasks.next().await.unwrap() { + (id, TaskUpdate::Finished(Ok(()))) => { + println!("gateway {id} finished"); + finished += 1; + } + ( + id, + TaskUpdate::Message(InProxyTaskMessage::Http(HttpOut::ResponseBasic(response))), + ) => { + println!("gateway {id} returned a response"); + assert_eq!(response.internal_response.status, StatusCode::OK); + responses += 1; + } + other => panic!("unexpected task update: {other:?}"), + } + } + } } From d729c3b391012a87312ffa995743310792ecae56 Mon Sep 17 00:00:00 2001 From: Razz4780 Date: Thu, 16 Jan 
2025 18:50:33 +0100 Subject: [PATCH 16/60] ClientStore test --- .../src/proxies/incoming/http/client_store.rs | 41 +++++++++++++++++++ 1 file changed, 41 insertions(+) diff --git a/mirrord/intproxy/src/proxies/incoming/http/client_store.rs b/mirrord/intproxy/src/proxies/incoming/http/client_store.rs index 4440cb1b1f6..1e53ccec6dc 100644 --- a/mirrord/intproxy/src/proxies/incoming/http/client_store.rs +++ b/mirrord/intproxy/src/proxies/incoming/http/client_store.rs @@ -148,3 +148,44 @@ async fn cleanup_task(clients: watch::Sender>, idle_client_ } } } + +#[cfg(test)] +mod test { + use std::{convert::Infallible, time::Duration}; + + use bytes::Bytes; + use http_body_util::Empty; + use hyper::{body::Incoming, server::conn::http1, service::service_fn, Request, Response, Version}; + use hyper_util::rt::TokioIo; + use tokio::{net::TcpListener, time}; + + use super::ClientStore; + + /// Verifies that [`ClientStore`] cleans up unused connections. + #[tokio::test] + async fn cleans_up_unused_connections() { + let listener = TcpListener::bind("127.0.0.1:0").await.unwrap(); + let addr = listener.local_addr().unwrap(); + + tokio::spawn(async move { + let service = service_fn(|_req: Request| { + std::future::ready(Ok::<_, Infallible>(Response::new(Empty::::new()))) + }); + + let (connection, _) = listener.accept().await.unwrap(); + std::mem::drop(listener); + http1::Builder::new() + .serve_connection(TokioIo::new(connection), service) + .await + .unwrap() + }); + + let client_store = ClientStore::new_with_timeout(Duration::from_millis(10)); + let client = client_store.get(addr, Version::HTTP_11).await.unwrap(); + client_store.push_idle(client); + + time::sleep(Duration::from_millis(100)).await; + + assert!(client_store.0.borrow().is_empty()); + } +} From 1a08649226be6a72ae36032c2964058ee6e4a64e Mon Sep 17 00:00:00 2001 From: Razz4780 Date: Thu, 16 Jan 2025 19:14:04 +0100 Subject: [PATCH 17/60] Changed implementation of ClientStore shared state --- 
.../src/proxies/incoming/http/client_store.rs | 143 ++++++++++-------- 1 file changed, 76 insertions(+), 67 deletions(-) diff --git a/mirrord/intproxy/src/proxies/incoming/http/client_store.rs b/mirrord/intproxy/src/proxies/incoming/http/client_store.rs index 1e53ccec6dc..3f50c7304e6 100644 --- a/mirrord/intproxy/src/proxies/incoming/http/client_store.rs +++ b/mirrord/intproxy/src/proxies/incoming/http/client_store.rs @@ -1,21 +1,38 @@ -use std::{cmp, net::SocketAddr, time::Duration}; +use std::{ + cmp, + net::SocketAddr, + sync::{Arc, Mutex}, + time::Duration, +}; use hyper::Version; use tokio::{ - sync::watch, + sync::Notify, time::{self, Instant}, }; use super::{LocalHttpClient, LocalHttpError}; +/// Idle [`LocalHttpClient`] caches in [`ClientStore`]. #[derive(Debug)] struct IdleLocalClient { client: LocalHttpClient, last_used: Instant, } +/// Cache for unused [`LocalHttpClient`]s. +/// +/// [`LocalHttpClient`] that have not been used for some time are dropped in the background by a +/// dedicated [`tokio::task`]. #[derive(Clone)] -pub struct ClientStore(watch::Sender>); +pub struct ClientStore { + clients: Arc>>, + /// Used to notify other tasks when there is a new client in the store. + /// + /// Make sure to only call [`Notify::notify_waiters`] and [`Notify::notified`] when holding a + /// lock on [`Self::clients`]. Otherwise you'll have a race condition. + notify: Arc, +} impl Default for ClientStore { fn default() -> Self { @@ -26,40 +43,40 @@ impl Default for ClientStore { impl ClientStore { const IDLE_CLIENT_TIMEOUT: Duration = Duration::from_secs(3); + /// Creates a new store. + /// + /// The store will keep unused clients alive for at least the given time. 
pub fn new_with_timeout(timeout: Duration) -> Self { - let (tx, _) = watch::channel(Default::default()); + let store = Self { + clients: Default::default(), + notify: Default::default(), + }; - tokio::spawn(cleanup_task(tx.clone(), timeout)); + tokio::spawn(cleanup_task(store.clone(), timeout)); - Self(tx) + store } + /// Reuses or creates a new [`LocalHttpClient`]. pub async fn get( &self, server_addr: SocketAddr, version: Version, ) -> Result { - let mut ready = None; - - self.0.send_if_modified(|clients| { - println!("ready clients: {clients:?}"); - let position = clients.iter().position(|idle| { + let ready = { + let mut guard = self.clients.lock().unwrap(); + let position = guard.iter().position(|idle| { idle.client.handles_version(version) && idle.client.local_server_address() == server_addr }); - - let Some(position) = position else { - return false; - }; - - let client = clients.swap_remove(position).client; - ready.replace(client); - true - }); + match position { + Some(position) => Some(guard.swap_remove(position)), + None => None, + } + }; if let Some(ready) = ready { - println!("found ready client"); - return Ok(ready); + return Ok(ready.client); } let connect_task = tokio::spawn(LocalHttpClient::new(server_addr, version)); @@ -70,58 +87,54 @@ impl ClientStore { } } + /// Stores an unused [`LocalHttpClient`], so that it can be reused later. pub fn push_idle(&self, client: LocalHttpClient) { - println!("storing idle client {client:?}"); - self.0.send_modify(|clients| { - clients.push(IdleLocalClient { - client, - last_used: Instant::now(), - }) + let mut guard = self.clients.lock().unwrap(); + guard.push(IdleLocalClient { + client, + last_used: Instant::now(), }); + self.notify.notify_waiters(); } + /// Waits until there is a ready unused client. 
async fn wait_for_ready(&self, server_addr: SocketAddr, version: Version) -> LocalHttpClient { - let mut recevier = self.0.subscribe(); - loop { - let mut ready = None; - - self.0.send_if_modified(|clients| { - let position = clients.iter().position(|idle| { + let notified = { + let mut guard = self.clients.lock().unwrap(); + let position = guard.iter().position(|idle| { idle.client.handles_version(version) && idle.client.local_server_address() == server_addr }); - let Some(position) = position else { - return false; - }; - - let client = clients.swap_remove(position).client; - ready.replace(client); - - true - }); - if let Some(ready) = ready { - break ready; - } + match position { + Some(position) => return guard.swap_remove(position).client, + None => self.notify.notified(), + } + }; - recevier - .changed() - .await - .expect("sender alive in this struct"); + notified.await; } } } -async fn cleanup_task(clients: watch::Sender>, idle_client_timeout: Duration) { +/// Cleans up stale [`LocalHttpClient`]s from the [`ClientStore`]. 
+async fn cleanup_task(store: ClientStore, idle_client_timeout: Duration) { + let clients = Arc::downgrade(&store.clients); + let notify = store.notify.clone(); + std::mem::drop(store); + loop { + let Some(clients) = clients.upgrade() else { + break; + }; + let now = Instant::now(); let mut min_last_used = None; - - clients.send_if_modified(|clients| { - let mut removed = false; - - clients.retain(|client| { + let notified = { + let mut guard = clients.lock().unwrap(); + let notified = notify.notified(); + guard.retain(|client| { if client.last_used + idle_client_timeout > now { min_last_used = min_last_used .map(|previous| cmp::min(previous, client.last_used)) @@ -129,22 +142,16 @@ async fn cleanup_task(clients: watch::Sender>, idle_client_ true } else { - removed = true; false } }); - - removed - }); + notified + }; if let Some(min_last_used) = min_last_used { time::sleep_until(min_last_used + idle_client_timeout).await; } else { - clients - .subscribe() - .changed() - .await - .expect("sender alive in this function"); + notified.await; } } } @@ -155,7 +162,9 @@ mod test { use bytes::Bytes; use http_body_util::Empty; - use hyper::{body::Incoming, server::conn::http1, service::service_fn, Request, Response, Version}; + use hyper::{ + body::Incoming, server::conn::http1, service::service_fn, Request, Response, Version, + }; use hyper_util::rt::TokioIo; use tokio::{net::TcpListener, time}; @@ -186,6 +195,6 @@ mod test { time::sleep(Duration::from_millis(100)).await; - assert!(client_store.0.borrow().is_empty()); + assert!(client_store.clients.lock().unwrap().is_empty()); } } From c25b15e661dd9e38c79a668473e1e764db39e8c7 Mon Sep 17 00:00:00 2001 From: Razz4780 Date: Thu, 16 Jan 2025 19:41:40 +0100 Subject: [PATCH 18/60] Some docs --- .../src/proxies/incoming/http/response_mode.rs | 5 +++++ .../src/proxies/incoming/http_gateway.rs | 16 +++++++++++++--- mirrord/intproxy/src/proxies/incoming/tasks.rs | 4 ++++ 3 files changed, 22 insertions(+), 3 deletions(-) diff --git 
a/mirrord/intproxy/src/proxies/incoming/http/response_mode.rs b/mirrord/intproxy/src/proxies/incoming/http/response_mode.rs index 4c9140f1beb..c0a65a2072d 100644 --- a/mirrord/intproxy/src/proxies/incoming/http/response_mode.rs +++ b/mirrord/intproxy/src/proxies/incoming/http/response_mode.rs @@ -1,9 +1,14 @@ use mirrord_protocol::tcp::{HTTP_CHUNKED_RESPONSE_VERSION, HTTP_FRAMED_VERSION}; +/// Determines how [`IncomingProxy`](crate::proxies::incoming::IncomingProxy) should send HTTP +/// responses. #[derive(Debug, Clone, Copy, Default)] pub enum ResponseMode { + /// [`LayerTcpSteal::HttpResponseChunked`](mirrord_protocol::tcp::LayerTcpSteal::HttpResponseChunked) Chunked, + /// [`LayerTcpSteal::HttpResponseFramed`](mirrord_protocol::tcp::LayerTcpSteal::HttpResponseFramed) Framed, + /// [`LayerTcpSteal::HttpResponse`](mirrord_protocol::tcp::LayerTcpSteal::HttpResponse) #[default] Basic, } diff --git a/mirrord/intproxy/src/proxies/incoming/http_gateway.rs b/mirrord/intproxy/src/proxies/incoming/http_gateway.rs index e684274b0a0..8a4fc255264 100644 --- a/mirrord/intproxy/src/proxies/incoming/http_gateway.rs +++ b/mirrord/intproxy/src/proxies/incoming/http_gateway.rs @@ -18,14 +18,22 @@ use super::{ }; use crate::background_tasks::{BackgroundTask, MessageBus}; +/// [`BackgroundTask`] used by the [`IncomingProxy`](super::IncomingProxy). +/// +/// Responsible for delivering a single HTTP request to the user application. pub struct HttpGatewayTask { + /// Request to deliver. request: HttpRequest, + /// Shared cache of [`LocalHttpClient`](super::http::LocalHttpClient)s. client_store: ClientStore, + /// Determines response variant. response_mode: ResponseMode, + /// Address of the HTTP server in the user application. server_addr: SocketAddr, } impl HttpGatewayTask { + /// Creates a new gateway task. 
pub fn new( request: HttpRequest, client_store: ClientStore, @@ -40,13 +48,16 @@ impl HttpGatewayTask { } } + /// Makes an attempt to send the request and read the whole response. + /// + /// [`Err`] is handled in the caller and, if we run out of send attempts, converted to an error + /// response. Because of this, this function should not return any error that happened after + /// sending [`ChunkedResponse::Start`]. The agent would get a duplicated response. async fn send_attempt(&self, message_bus: &mut MessageBus) -> Result<(), LocalHttpError> { - println!("making send attempt"); let mut client = self .client_store .get(self.server_addr, self.request.version()) .await?; - println!("got client"); let mut response = client.send_request(self.request.clone()).await?; let on_upgrade = (response.status() == StatusCode::SWITCHING_PROTOCOLS) .then(|| hyper::upgrade::on(&mut response)); @@ -183,7 +194,6 @@ impl BackgroundTask for HttpGatewayTask { .then(|| backoffs.next()) .flatten() .flatten(); - println!("send attempt failed with {error}, backoff={backoff:?}"); let Some(backoff) = backoff else { break error; }; diff --git a/mirrord/intproxy/src/proxies/incoming/tasks.rs b/mirrord/intproxy/src/proxies/incoming/tasks.rs index bb35a882eab..7fd5c72b661 100644 --- a/mirrord/intproxy/src/proxies/incoming/tasks.rs +++ b/mirrord/intproxy/src/proxies/incoming/tasks.rs @@ -54,6 +54,10 @@ pub enum InProxyTask { HttpGateway(HttpGatewayId), } +/// Identifies a [`HttpGatewayTask`](super::http_gateway::HttpGatewayTask). +/// +/// ([`ConnectionId`], [`RequestId`]) would suffice, but storing extra data allows us to produce an +/// error response in case the task somehow panics. 
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)] pub struct HttpGatewayId { pub connection_id: ConnectionId, From 295099f6d842e528b8abe8cddcab55cb2031c432 Mon Sep 17 00:00:00 2001 From: Razz4780 Date: Thu, 16 Jan 2025 22:25:50 +0100 Subject: [PATCH 19/60] Docs --- mirrord/intproxy/src/proxies/incoming.rs | 512 ++++++++++-------- .../src/proxies/incoming/tcp_proxy.rs | 13 + 2 files changed, 291 insertions(+), 234 deletions(-) diff --git a/mirrord/intproxy/src/proxies/incoming.rs b/mirrord/intproxy/src/proxies/incoming.rs index 8ee466ae11e..5d17dfe4d39 100644 --- a/mirrord/intproxy/src/proxies/incoming.rs +++ b/mirrord/intproxy/src/proxies/incoming.rs @@ -6,11 +6,7 @@ //! until connection becomes readable (is TCP) or receives an http request. //! 2. HttpSender - -use std::{ - collections::{hash_map::Entry, HashMap}, - io, - net::SocketAddr, -}; +use std::{collections::HashMap, io, net::SocketAddr}; use bound_socket::BoundTcpSocket; use http::{ClientStore, ResponseMode, StreamingBody}; @@ -72,8 +68,13 @@ pub enum IncomingProxyMessage { AgentProtocolVersion(semver::Version), } +/// Handle to a running [`HttpGatewayTask`]. struct HttpGatewayHandle { + /// Only keeps the [`HttpGatewayTask`] alive. _tx: TaskSender, + /// For sending request body [`Frame`](hyper::body::Frame)s. + /// + /// [`None`] if all frames were already sent. body_tx: Option>, } @@ -96,26 +97,34 @@ pub struct IncomingProxy { subscriptions: SubscriptionsManager, /// For managing intercepted connections metadata. metadata_store: MetadataStore, + /// What HTTP response flavor we produce. response_mode: ResponseMode, /// Cache for [`LocalHttpClient`](http::LocalHttpClient)s. client_store: ClientStore, /// Each mirrored remote connection is mapped to a [TcpProxyTask] in mirror mode. + /// + /// Each entry here maps to a connection that is in progress both locally and remotely. mirror_tcp_proxies: HashMap>, /// Each remote connection stolen in whole is mapped to a [TcpProxyTask] in steal mode. 
+ /// + /// Each entry here maps to a connection that is in progress both locally and remotely. steal_tcp_proxies: HashMap>, - /// Each remote connection stolen with a filter is mapped to [HttpGatewayTask]s. + /// Each remote HTTP request stolen with a filter is mapped to a [HttpGatewayTask]. + /// + /// Each entry here maps to a request that is in progress both locally and remotely. http_gateways: HashMap>, + /// Running [`BackgroundTask`]s utilized by this proxy. tasks: BackgroundTasks, } impl IncomingProxy { - /// Used when registering new [`Interceptor`]s in the internal [`BackgroundTasks`] instance. + /// Used when registering new tasks in the internal [`BackgroundTasks`] instance. const CHANNEL_SIZE: usize = 512; - /// Retrieves or creates an [`Interceptor`] for the given [`HttpRequestFallback`]. - /// The request may or may not belong to an existing connection (when stealing with an http - /// filter, connections are created implicitly). - #[tracing::instrument(level = Level::TRACE, skip(self, message_bus))] + /// Starts a new [`HttpGatewayTask`] to handle the given request. + /// + /// If we don't have a [`PortSubscription`] for the port, the task is not started. + /// Instead, we respond immediately to the agent. async fn start_http_gateway( &mut self, request: HttpRequest, @@ -132,49 +141,35 @@ impl IncomingProxy { }); let Some(subscription) = subscription else { tracing::debug!( - port = request.port, - connection_id = request.connection_id, - request_id = request.request_id, - "Received a new request within a stale port subscription, sending an unsubscribe request or an error response." + ?request, + "Received a new HTTP request within a stale port subscription, \ + sending an unsubscribe request or an error response." ); - match self.http_gateways.entry(request.connection_id) { - // This is a new connection, we can just unsubscribe it. - Entry::Vacant(..) 
=> { - message_bus - .send(ClientMessage::TcpSteal( - LayerTcpSteal::ConnectionUnsubscribe(request.connection_id), - )) - .await; - } - - // This is not a new connection, but we don't have any requests in progress. - // We can still unsubscribe it. - Entry::Occupied(e) if e.get().is_empty() => { - message_bus - .send(ClientMessage::TcpSteal( - LayerTcpSteal::ConnectionUnsubscribe(request.connection_id), - )) - .await; - e.remove(); - } - - // This is not a new connection, and we have requests in progress. - // We can only send an error response. - Entry::Occupied(..) => { - let response = http::mirrord_error_response( - "port no longer subscribed with an HTTP filter", - request.version(), - request.connection_id, - request.request_id, - request.port, - ); - message_bus - .send(ClientMessage::TcpSteal(LayerTcpSteal::HttpResponse( - response, - ))) - .await; - } + let no_other_requests = self + .http_gateways + .get(&request.connection_id) + .map(|gateways| gateways.is_empty()) + .unwrap_or(true); + if no_other_requests { + message_bus + .send(ClientMessage::TcpSteal( + LayerTcpSteal::ConnectionUnsubscribe(request.connection_id), + )) + .await; + } else { + let response = http::mirrord_error_response( + "port no longer subscribed with an HTTP filter", + request.version(), + request.connection_id, + request.request_id, + request.port, + ); + message_bus + .send(ClientMessage::TcpSteal(LayerTcpSteal::HttpResponse( + response, + ))) + .await; } return; @@ -204,8 +199,175 @@ impl IncomingProxy { .insert(request_id, HttpGatewayHandle { _tx: tx, body_tx }); } + /// Handles [`NewTcpConnection`] message from the agent, starting a new [`TcpProxyTask`]. + /// + /// If we don't have a [`PortSubscription`] for the port, the task is not started. + /// Instead, we respond immediately to the agent. 
+ async fn handle_new_connection( + &mut self, + connection: NewTcpConnection, + is_steal: bool, + message_bus: &mut MessageBus, + ) -> Result<(), IncomingProxyError> { + let NewTcpConnection { + connection_id, + remote_address, + destination_port, + source_port, + local_address, + } = connection; + + let subscription = self + .subscriptions + .get(destination_port) + .filter(|subscription| match &subscription.subscription { + PortSubscription::Mirror(..) if !is_steal => true, + PortSubscription::Steal(StealType::All(..)) if is_steal => true, + _ => false, + }); + let Some(subscription) = subscription else { + tracing::debug!( + port = destination_port, + connection_id, + "Received a new connection within a stale port subscription, sending an unsubscribe request.", + ); + + let message = if is_steal { + ClientMessage::Tcp(LayerTcp::ConnectionUnsubscribe(connection_id)) + } else { + ClientMessage::TcpSteal(LayerTcpSteal::ConnectionUnsubscribe(connection_id)) + }; + message_bus.send(message).await; + + return Ok(()); + }; + + let socket = BoundTcpSocket::bind_specified_or_localhost(subscription.listening_on.ip()) + .map_err(IncomingProxyError::SocketSetupFailed)?; + + self.metadata_store.expect( + ConnMetadataRequest { + listener_address: subscription.listening_on, + peer_address: socket + .local_addr() + .map_err(IncomingProxyError::SocketSetupFailed)?, + }, + connection_id, + ConnMetadataResponse { + remote_source: SocketAddr::new(remote_address, source_port), + local_address, + }, + ); + + let id = if is_steal { + InProxyTask::StealTcpProxy(connection_id) + } else { + InProxyTask::MirrorTcpProxy(connection_id) + }; + let tx = self.tasks.register( + TcpProxyTask::new( + LocalTcpConnection::FromTheStart { + socket, + peer: subscription.listening_on, + }, + !is_steal, + ), + id, + Self::CHANNEL_SIZE, + ); + + if is_steal { + self.steal_tcp_proxies.insert(connection_id, tx); + } else { + self.mirror_tcp_proxies.insert(connection_id, tx); + } + + Ok(()) + } + + /// 
Handles [`ChunkedRequest`] message from the agent. + async fn handle_chunked_request( + &mut self, + request: ChunkedRequest, + message_bus: &mut MessageBus, + ) { + match request { + ChunkedRequest::Start(request) => { + let (body_tx, body_rx) = mpsc::channel(128); + let request = request.map_body(|frames| StreamingBody::new(body_rx, frames)); + self.start_http_gateway(request, Some(body_tx), message_bus) + .await; + } + + ChunkedRequest::Body(ChunkedHttpBody { + frames, + is_last, + connection_id, + request_id, + }) => { + let gateway = self + .http_gateways + .get_mut(&connection_id) + .and_then(|gateways| gateways.get_mut(&request_id)); + let Some(gateway) = gateway else { + tracing::debug!( + connection_id, + request_id, + frames = ?frames, + last_body_chunk = is_last, + "Received a body chunk for a request that is no longer alive locally" + ); + + return; + }; + + let Some(tx) = gateway.body_tx.as_ref() else { + tracing::debug!( + connection_id, + request_id, + frames = ?frames, + last_body_chunk = is_last, + "Received a body chunk for a request with a closed body" + ); + + return; + }; + + for frame in frames { + if let Err(err) = tx.send(frame).await { + tracing::debug!( + frame = ?err.0, + connection_id, + request_id, + "Failed to send an HTTP request body frame to the HttpGatewayTask, channel is closed" + ); + break; + } + } + + if is_last { + gateway.body_tx = None; + } + } + + ChunkedRequest::Error(ChunkedHttpError { + connection_id, + request_id, + }) => { + tracing::debug!( + connection_id, + request_id, + "Received an error in an HTTP request body", + ); + + if let Some(gateways) = self.http_gateways.get_mut(&connection_id) { + gateways.remove(&request_id); + }; + } + } + } + /// Handles all agent messages. 
- #[tracing::instrument(level = Level::TRACE, skip(self, message_bus), err)] async fn handle_agent_message( &mut self, message: DaemonTcp, @@ -223,7 +385,7 @@ impl IncomingProxy { } DaemonTcp::Data(data) => { - let tx: Option<&TaskSender> = if is_steal { + let tx = if is_steal { self.steal_tcp_proxies.get(&data.connection_id) } else { self.mirror_tcp_proxies.get(&data.connection_id) @@ -234,6 +396,7 @@ impl IncomingProxy { } else { tracing::debug!( connection_id = data.connection_id, + bytes = data.bytes.len(), "Received new data for a connection that does not belong to any TcpProxy task", ); } @@ -250,139 +413,12 @@ impl IncomingProxy { } DaemonTcp::HttpRequestChunked(request) => { - match request { - ChunkedRequest::Start(request) => { - let (body_tx, body_rx) = mpsc::channel(128); - let request = - request.map_body(|frames| StreamingBody::new(body_rx, frames)); - self.start_http_gateway(request, Some(body_tx), message_bus) - .await; - } - - ChunkedRequest::Body(ChunkedHttpBody { - frames, - is_last, - connection_id, - request_id, - }) => { - let gateway = self - .http_gateways - .get_mut(&connection_id) - .and_then(|gateways| gateways.get_mut(&request_id)); - let Some(gateway) = gateway else { - return Ok(()); - }; - - let Some(tx) = gateway.body_tx.as_ref() else { - return Ok(()); - }; - - for frame in frames { - if let Err(err) = tx.send(frame).await { - tracing::debug!( - frame = ?err.0, - connection_id, - request_id, - "Failed to send an HTTP request body frame to the HttpGateway task, channel is closed" - ); - break; - } - } - - if is_last { - gateway.body_tx = None; - } - } - - ChunkedRequest::Error(ChunkedHttpError { - connection_id, - request_id, - }) => { - tracing::debug!( - connection_id, - request_id, - "Received an error in an HTTP request body", - ); - - if let Some(gateways) = self.http_gateways.get_mut(&connection_id) { - gateways.remove(&request_id); - }; - } - }; + self.handle_chunked_request(request, message_bus).await; } - 
DaemonTcp::NewConnection(NewTcpConnection { - connection_id, - remote_address, - destination_port, - source_port, - local_address, - }) => { - let subscription = - self.subscriptions - .get(destination_port) - .filter(|subscription| match &subscription.subscription { - PortSubscription::Mirror(..) if !is_steal => true, - PortSubscription::Steal(StealType::All(..)) if is_steal => true, - _ => false, - }); - let Some(subscription) = subscription else { - tracing::debug!( - port = destination_port, - connection_id, - "Received a new connection within a stale port subscription, sending an unsubscribe request.", - ); - - let message = if is_steal { - ClientMessage::Tcp(LayerTcp::ConnectionUnsubscribe(connection_id)) - } else { - ClientMessage::TcpSteal(LayerTcpSteal::ConnectionUnsubscribe(connection_id)) - }; - message_bus.send(message).await; - - return Ok(()); - }; - - let socket = - BoundTcpSocket::bind_specified_or_localhost(subscription.listening_on.ip()) - .map_err(IncomingProxyError::SocketSetupFailed)?; - - self.metadata_store.expect( - ConnMetadataRequest { - listener_address: subscription.listening_on, - peer_address: socket - .local_addr() - .map_err(IncomingProxyError::SocketSetupFailed)?, - }, - connection_id, - ConnMetadataResponse { - remote_source: SocketAddr::new(remote_address, source_port), - local_address, - }, - ); - - let id = if is_steal { - InProxyTask::StealTcpProxy(connection_id) - } else { - InProxyTask::MirrorTcpProxy(connection_id) - }; - let tx = self.tasks.register( - TcpProxyTask::new( - LocalTcpConnection::FromTheStart { - socket, - peer: subscription.listening_on, - }, - !is_steal, - ), - id, - Self::CHANNEL_SIZE, - ); - - if is_steal { - self.steal_tcp_proxies.insert(connection_id, tx); - } else { - self.mirror_tcp_proxies.insert(connection_id, tx); - } + DaemonTcp::NewConnection(connection) => { + self.handle_new_connection(connection, is_steal, message_bus) + .await?; } DaemonTcp::SubscribeResult(result) => { @@ -397,6 +433,7 @@ impl 
IncomingProxy { Ok(()) } + /// Handles all messages from this task's [`MessageBus`]. #[tracing::instrument(level = Level::TRACE, skip(self, message_bus), err)] async fn handle_message( &mut self, @@ -463,64 +500,53 @@ impl IncomingProxy { Ok(()) } + /// Handles all updates from [`TcpProxyTask`]s. #[tracing::instrument(level = Level::TRACE, skip(self, message_bus))] - async fn handle_task_update( + async fn handle_tcp_proxy_update( &mut self, - id: InProxyTask, + connection_id: ConnectionId, + is_steal: bool, update: TaskUpdate, message_bus: &mut MessageBus, ) { - match (id, update) { - (InProxyTask::MirrorTcpProxy(connection_id), TaskUpdate::Finished(result)) => { + match update { + TaskUpdate::Finished(result) => { match result { Err(TaskError::Error(error)) => { - tracing::warn!(connection_id, %error, "MirrorTcpProxy task failed"); + tracing::warn!(connection_id, %error, is_steal, "TcpProxyTask failed"); } Err(TaskError::Panic) => { - tracing::error!(connection_id, "MirrorTcpProxy task panicked"); + tracing::error!(connection_id, is_steal, "TcpProxyTask task panicked"); } Ok(()) => {} }; self.metadata_store.no_longer_expect(connection_id); - if self.mirror_tcp_proxies.remove(&connection_id).is_some() { - message_bus - .send(ClientMessage::Tcp(LayerTcp::ConnectionUnsubscribe( - connection_id, - ))) - .await; - } - } - - (InProxyTask::MirrorTcpProxy(..), TaskUpdate::Message(..)) => unreachable!(), - - (InProxyTask::StealTcpProxy(connection_id), TaskUpdate::Finished(result)) => { - match result { - Err(TaskError::Error(error)) => { - tracing::warn!(connection_id, %error, "StealTcpProxy task failed"); + if is_steal { + if self.steal_tcp_proxies.remove(&connection_id).is_some() { + message_bus + .send(ClientMessage::TcpSteal( + LayerTcpSteal::ConnectionUnsubscribe(connection_id), + )) + .await; } - Err(TaskError::Panic) => { - tracing::error!(connection_id, "StealTcpProxy task panicked"); + } else { + if self.mirror_tcp_proxies.remove(&connection_id).is_some() { + 
message_bus + .send(ClientMessage::Tcp(LayerTcp::ConnectionUnsubscribe( + connection_id, + ))) + .await; } - Ok(()) => {} - }; - - self.metadata_store.no_longer_expect(connection_id); - - if self.steal_tcp_proxies.remove(&connection_id).is_some() { - message_bus - .send(ClientMessage::Tcp(LayerTcp::ConnectionUnsubscribe( - connection_id, - ))) - .await; } } - ( - InProxyTask::StealTcpProxy(connection_id), - TaskUpdate::Message(InProxyTaskMessage::Tcp(bytes)), - ) => { + TaskUpdate::Message(..) if !is_steal => { + unreachable!("TcpProxyTask does not produce messages in mirror mode") + } + + TaskUpdate::Message(InProxyTaskMessage::Tcp(bytes)) => { if self.steal_tcp_proxies.contains_key(&connection_id) { message_bus .send(ClientMessage::TcpSteal(LayerTcpSteal::Data(TcpData { @@ -531,11 +557,22 @@ impl IncomingProxy { } } - (InProxyTask::StealTcpProxy(..), TaskUpdate::Message(InProxyTaskMessage::Http(..))) => { - unreachable!() + TaskUpdate::Message(InProxyTaskMessage::Http(..)) => { + unreachable!("TcpProxyTask does not produce HTTP messages") } + } + } - (InProxyTask::HttpGateway(id), TaskUpdate::Finished(result)) => { + /// Handles all updates from [`HttpGatewayTask`]s. + #[tracing::instrument(level = Level::TRACE, skip(self, message_bus))] + async fn handle_http_gateway_update( + &mut self, + id: HttpGatewayId, + update: TaskUpdate, + message_bus: &mut MessageBus, + ) { + match update { + TaskUpdate::Finished(result) => { let respond_on_panic = self .http_gateways .get_mut(&id.connection_id) @@ -546,7 +583,7 @@ impl IncomingProxy { Ok(()) => {} Err(TaskError::Error( InProxyTaskError::IoError(..) 
| InProxyTaskError::UpgradeError(..), - )) => unreachable!(), + )) => unreachable!("HttpGatewayTask does not return any errors"), Err(TaskError::Panic) => { tracing::error!( connection_id = id.connection_id, @@ -572,10 +609,7 @@ impl IncomingProxy { } } - ( - InProxyTask::HttpGateway(id), - TaskUpdate::Message(InProxyTaskMessage::Http(message)), - ) => { + TaskUpdate::Message(InProxyTaskMessage::Http(message)) => { let exists = self .http_gateways .get(&id.connection_id) @@ -618,8 +652,8 @@ impl IncomingProxy { } } - (InProxyTask::HttpGateway(..), TaskUpdate::Message(InProxyTaskMessage::Tcp(..))) => { - unreachable!() + TaskUpdate::Message(InProxyTaskMessage::Tcp(..)) => { + unreachable!("HttpGatewayTask does not produce TCP messages") } } } @@ -642,7 +676,17 @@ impl BackgroundTask for IncomingProxy { Some(message) => self.handle_message(message, message_bus).await?, }, - Some((id, update)) = self.tasks.next() => self.handle_task_update(id, update, message_bus).await, + Some((id, update)) = self.tasks.next() => match id { + InProxyTask::MirrorTcpProxy(connection_id) => { + self.handle_tcp_proxy_update(connection_id, false, update, message_bus).await; + } + InProxyTask::StealTcpProxy(connection_id) => { + self.handle_tcp_proxy_update(connection_id, true, update, message_bus).await; + } + InProxyTask::HttpGateway(id) => { + self.handle_http_gateway_update(id, update, message_bus).await; + } + }, } } } diff --git a/mirrord/intproxy/src/proxies/incoming/tcp_proxy.rs b/mirrord/intproxy/src/proxies/incoming/tcp_proxy.rs index 98304e155be..1850c1a11d3 100644 --- a/mirrord/intproxy/src/proxies/incoming/tcp_proxy.rs +++ b/mirrord/intproxy/src/proxies/incoming/tcp_proxy.rs @@ -15,20 +15,33 @@ use super::{ }; use crate::background_tasks::{BackgroundTask, MessageBus}; +/// Local TCP connections between the [`TcpProxyTask`] and the user application. pub enum LocalTcpConnection { + /// Not yet established. 
Should be made by the [`TcpProxyTask`] from the given + /// [`BoundTcpSocket`]. FromTheStart { socket: BoundTcpSocket, peer: SocketAddr, }, + /// Upgraded HTTP connection from a previously stolen HTTP request. AfterUpgrade(OnUpgrade), } +/// [`BackgroundTask`] of [`IncomingProxy`](super::IncomingProxy) that handles a remote +/// stolen/mirrored TCP connection. pub struct TcpProxyTask { + /// The local connection between this task and the user application. connection: LocalTcpConnection, + /// Whether this task should silently discard data coming from the user application. discard_data: bool, } impl TcpProxyTask { + /// Creates a new task. + /// + /// * This task will talk with the user application using the given [`LocalTcpConnection`]. + /// * If `discard_data` is set, this task will silently discard all data coming from the user + /// application. pub fn new(connection: LocalTcpConnection, discard_data: bool) -> Self { Self { connection, From 750890257e7df33fd29ca7e63299d48ad744c92b Mon Sep 17 00:00:00 2001 From: Razz4780 Date: Thu, 16 Jan 2025 22:28:16 +0100 Subject: [PATCH 20/60] Removed obsolete integration test - replaced before with a unit test --- mirrord/layer/tests/apps/app_chunked.py | 33 ----- mirrord/layer/tests/common/mod.rs | 13 +- mirrord/layer/tests/issue3013.rs | 184 ------------------------ 3 files changed, 2 insertions(+), 228 deletions(-) delete mode 100644 mirrord/layer/tests/apps/app_chunked.py delete mode 100644 mirrord/layer/tests/issue3013.rs diff --git a/mirrord/layer/tests/apps/app_chunked.py b/mirrord/layer/tests/apps/app_chunked.py deleted file mode 100644 index 04ffde714b7..00000000000 --- a/mirrord/layer/tests/apps/app_chunked.py +++ /dev/null @@ -1,33 +0,0 @@ -from http.server import HTTPServer, BaseHTTPRequestHandler -import time; - -class ChunkedHTTPHandler(BaseHTTPRequestHandler): - protocol_version = "HTTP/1.1" - - def do_GET(self): - # Send response headers - self.send_response(200) - self.send_header("Content-Type", 
"text/plain") - self.send_header("Transfer-Encoding", "chunked") - self.end_headers() - - # Send the response in chunks - chunks = [ - "This is the first chunk.\n"*8000, - "This is the second chunk.\n" - ] - for chunk in chunks: - time.sleep(3.0) - # Write the chunk size in hexadecimal followed by the chunk data - self.wfile.write(f"{len(chunk):X}\r\n".encode('utf-8')) - self.wfile.write(chunk.encode('utf-8')) - self.wfile.write(b"\r\n") - - # Signal the end of the response - self.wfile.write(b"0\r\n\r\n") - -if __name__ == "__main__": - port = 80 - print(f"Starting server on port {port}") - server = HTTPServer(("0.0.0.0", port), ChunkedHTTPHandler) - server.serve_forever() diff --git a/mirrord/layer/tests/common/mod.rs b/mirrord/layer/tests/common/mod.rs index dbcb92adc90..303231cf3da 100644 --- a/mirrord/layer/tests/common/mod.rs +++ b/mirrord/layer/tests/common/mod.rs @@ -786,9 +786,6 @@ pub enum Application { DynamicApp(String, Vec), /// Go app that only checks whether Linux pidfd syscalls are supported. Go23Issue2988, - /// Python HTTP server that returns large (200kb) chunked responses - /// and processes one request at a time. 
- PythonHTTPChunked, } impl Application { @@ -817,8 +814,7 @@ impl Application { Application::PythonFlaskHTTP | Application::PythonSelfConnect | Application::PythonDontLoad - | Application::PythonListen - | Application::PythonHTTPChunked => Self::get_python3_executable().await, + | Application::PythonListen => Self::get_python3_executable().await, Application::PythonFastApiHTTP | Application::PythonIssue864 => String::from("uvicorn"), Application::Fork => String::from("tests/apps/fork/out.c_test_app"), Application::ReadLink => String::from("tests/apps/readlink/out.c_test_app"), @@ -1109,10 +1105,6 @@ impl Application { ] } Application::DynamicApp(_, args) => args.to_owned(), - Application::PythonHTTPChunked => { - app_path.push("app_chunked.py"); - vec![String::from("-u"), app_path.to_string_lossy().to_string()] - } } } @@ -1126,8 +1118,7 @@ impl Application { | Application::Go23FileOps | Application::NodeHTTP | Application::RustIssue1054 - | Application::PythonFlaskHTTP - | Application::PythonHTTPChunked => 80, + | Application::PythonFlaskHTTP => 80, // mapped from 9999 in `configs/port_mapping.json` Application::PythonFastApiHTTP | Application::PythonIssue864 => 1234, Application::RustIssue1123 => 41222, diff --git a/mirrord/layer/tests/issue3013.rs b/mirrord/layer/tests/issue3013.rs deleted file mode 100644 index 3e3870d3f3e..00000000000 --- a/mirrord/layer/tests/issue3013.rs +++ /dev/null @@ -1,184 +0,0 @@ -#![feature(assert_matches)] -#![warn(clippy::indexing_slicing)] -use std::{ - path::{Path, PathBuf}, - time::Duration, -}; - -use hyper::{ - header::{HeaderName, HeaderValue}, - Method, StatusCode, Version, -}; -use mirrord_protocol::{ - self, - file::{ - CloseFileRequest, OpenFileRequest, OpenFileResponse, ReadFileRequest, ReadFileResponse, - }, - tcp::{ - DaemonTcp, HttpRequest, HttpResponse, InternalHttpRequest, InternalHttpResponse, - LayerTcpSteal, StealType, - }, - ClientMessage, DaemonMessage, FileRequest, FileResponse, -}; -use rstest::rstest; - -mod 
common; - -pub use common::*; - -/// Verifies that [issue 3013](https://github.com/metalbear-co/mirrord/issues/3013) is resolved. -/// -/// The issue was that the first request was leaving behind a lingering HTTP connection, that was in -/// turn blocking the local application. The lingering connection was not a bug on our side, but -/// still we can handle this case more smoothly. -#[rstest] -#[tokio::test] -#[timeout(Duration::from_secs(60))] -async fn issue_3013(dylib_path: &Path) { - let config_file = tempfile::tempdir().unwrap(); - let config = serde_json::json!( - { - "feature": { - "network": { - "outgoing": false, - "dns": false, - "incoming": { - "mode": "steal", - "http_filter": { - "header_filter": "x-filter: yes", - }, - } - }, - "fs": { - "mode": "local", - } - }, - } - ); - let config_path = config_file.path().join("config.json"); - tokio::fs::write(&config_path, serde_json::to_string_pretty(&config).unwrap()) - .await - .unwrap(); - - let (test_process, mut test_intproxy) = Application::PythonHTTPChunked - .start_process_with_layer(dylib_path, vec![], Some(&config_path.to_string_lossy())) - .await; - - match test_intproxy.recv().await { - ClientMessage::FileRequest(FileRequest::Open(OpenFileRequest { path, .. })) => { - assert_eq!(path, PathBuf::from("/etc/hostname")); - } - other => panic!("unexpected message from intproxy: {other:?}"), - } - test_intproxy - .send(DaemonMessage::File(FileResponse::Open(Ok( - OpenFileResponse { fd: 2137 }, - )))) - .await; - match test_intproxy.recv().await { - ClientMessage::FileRequest(FileRequest::Read(ReadFileRequest { - remote_fd: 2137, .. 
- })) => {} - other => panic!("unexpected message from intproxy: {other:?}"), - } - test_intproxy - .send(DaemonMessage::File(FileResponse::Read(Ok( - ReadFileResponse { - bytes: "test-hostname".as_bytes().into(), - read_amount: 13, - }, - )))) - .await; - match test_intproxy.recv().await { - ClientMessage::FileRequest(FileRequest::Close(CloseFileRequest { fd: 2137 })) => {} - other => panic!("unexpected message from intproxy: {other:?}"), - } - - match test_intproxy.recv().await { - ClientMessage::TcpSteal(LayerTcpSteal::PortSubscribe(StealType::FilteredHttpEx(..))) => {} - other => panic!("unexpected message from intproxy: {other:?}"), - } - test_intproxy - .send(DaemonMessage::TcpSteal(DaemonTcp::SubscribeResult(Ok(80)))) - .await; - test_process - .wait_for_line(Duration::from_secs(40), "daemon subscribed") - .await; - println!("The application subscribed the port"); - - println!("Sending the first request to the intproxy"); - test_intproxy - .send(DaemonMessage::TcpSteal(DaemonTcp::HttpRequestFramed( - HttpRequest { - internal_request: InternalHttpRequest { - method: Method::GET, - uri: "/some/path".parse().unwrap(), - headers: [( - HeaderName::from_static("connection"), - HeaderValue::from_static("keep-alive"), - )] - .into_iter() - .collect(), - version: Version::HTTP_11, - body: Default::default(), - }, - connection_id: 0, - request_id: 0, - port: 80, - }, - ))) - .await; - match test_intproxy.recv().await { - ClientMessage::TcpSteal(LayerTcpSteal::HttpResponseFramed(HttpResponse { - port: 80, - connection_id: 0, - request_id: 0, - internal_response: - InternalHttpResponse { - status: StatusCode::OK, - version: Version::HTTP_11, - .. 
- }, - })) => {} - other => panic!("unexpected message from intproxy: {other:?}"), - } - println!("Received the first response from the intproxy"); - - println!("Sending the second request to the intproxy, without closing the first connection"); - test_intproxy - .send(DaemonMessage::TcpSteal(DaemonTcp::HttpRequestFramed( - HttpRequest { - internal_request: InternalHttpRequest { - method: Method::GET, - uri: "/some/path".parse().unwrap(), - headers: [( - HeaderName::from_static("connection"), - HeaderValue::from_static("keep-alive"), - )] - .into_iter() - .collect(), - version: Version::HTTP_11, - body: Default::default(), - }, - connection_id: 1, - request_id: 0, - port: 80, - }, - ))) - .await; - match test_intproxy.recv().await { - ClientMessage::TcpSteal(LayerTcpSteal::HttpResponseFramed(HttpResponse { - port: 80, - connection_id: 1, - request_id: 0, - internal_response: - InternalHttpResponse { - status: StatusCode::OK, - version: Version::HTTP_11, - .. - }, - })) => {} - other => panic!("unexpected message from intproxy: {other:?}"), - } - println!("Received the second response from the intproxy"); -} From 203af979e4ae96af1b9d009ed0d66256b51e2900 Mon Sep 17 00:00:00 2001 From: Razz4780 Date: Thu, 16 Jan 2025 22:43:13 +0100 Subject: [PATCH 21/60] More ClientStore tracing --- mirrord/intproxy/src/proxies/incoming/http.rs | 8 ++++++ .../src/proxies/incoming/http/client_store.rs | 28 +++++++++++++++---- 2 files changed, 30 insertions(+), 6 deletions(-) diff --git a/mirrord/intproxy/src/proxies/incoming/http.rs b/mirrord/intproxy/src/proxies/incoming/http.rs index 5b9a7a0012b..50d4b0ff06c 100644 --- a/mirrord/intproxy/src/proxies/incoming/http.rs +++ b/mirrord/intproxy/src/proxies/incoming/http.rs @@ -28,6 +28,8 @@ pub struct LocalHttpClient { sender: HttpSender, /// Address of the user application's HTTP server. local_server_address: SocketAddr, + /// Address of this client's TCP socket. 
+ self_address: SocketAddr, } impl LocalHttpClient { @@ -42,11 +44,15 @@ impl LocalHttpClient { let local_server_address = stream .peer_addr() .map_err(LocalHttpError::SocketSetupFailed)?; + let self_address = stream + .local_addr() + .map_err(LocalHttpError::SocketSetupFailed)?; let sender = HttpSender::handshake(version, stream).await?; Ok(Self { sender, local_server_address, + self_address, }) } @@ -80,6 +86,8 @@ impl fmt::Debug for LocalHttpClient { fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { f.debug_struct("LocalHttpClient") .field("local_server_address", &self.local_server_address) + .field("self_address", &self.self_address) + .field("is_http_1", &matches!(self.sender, HttpSender::V1(..))) .finish() } } diff --git a/mirrord/intproxy/src/proxies/incoming/http/client_store.rs b/mirrord/intproxy/src/proxies/incoming/http/client_store.rs index 3f50c7304e6..4857667ab53 100644 --- a/mirrord/intproxy/src/proxies/incoming/http/client_store.rs +++ b/mirrord/intproxy/src/proxies/incoming/http/client_store.rs @@ -1,5 +1,5 @@ use std::{ - cmp, + cmp, fmt, net::SocketAddr, sync::{Arc, Mutex}, time::Duration, @@ -10,20 +10,29 @@ use tokio::{ sync::Notify, time::{self, Instant}, }; +use tracing::Level; use super::{LocalHttpClient, LocalHttpError}; /// Idle [`LocalHttpClient`] caches in [`ClientStore`]. -#[derive(Debug)] struct IdleLocalClient { client: LocalHttpClient, last_used: Instant, } +impl fmt::Debug for IdleLocalClient { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + f.debug_struct("IdleLocalClient") + .field("client", &self.client) + .field("idle_for_s", &self.last_used.elapsed().as_secs_f32()) + .finish() + } +} + /// Cache for unused [`LocalHttpClient`]s. /// /// [`LocalHttpClient`] that have not been used for some time are dropped in the background by a -/// dedicated [`tokio::task`]. +/// dedicated [`tokio::task`]. This timeout defaults to [`Self::IDLE_CLIENT_DEFAULT_TIMEOUT`]. 
#[derive(Clone)] pub struct ClientStore { clients: Arc>>, @@ -36,12 +45,12 @@ pub struct ClientStore { impl Default for ClientStore { fn default() -> Self { - Self::new_with_timeout(Self::IDLE_CLIENT_TIMEOUT) + Self::new_with_timeout(Self::IDLE_CLIENT_DEFAULT_TIMEOUT) } } impl ClientStore { - const IDLE_CLIENT_TIMEOUT: Duration = Duration::from_secs(3); + pub const IDLE_CLIENT_DEFAULT_TIMEOUT: Duration = Duration::from_secs(3); /// Creates a new store. /// @@ -58,6 +67,7 @@ impl ClientStore { } /// Reuses or creates a new [`LocalHttpClient`]. + #[tracing::instrument(level = Level::TRACE, skip(self), ret, err(level = Level::WARN))] pub async fn get( &self, server_addr: SocketAddr, @@ -76,6 +86,7 @@ impl ClientStore { }; if let Some(ready) = ready { + tracing::trace!(?ready, "Reused an idle client"); return Ok(ready.client); } @@ -83,11 +94,15 @@ impl ClientStore { tokio::select! { result = connect_task => result.expect("this task should not panic"), - ready = self.wait_for_ready(server_addr, version) => Ok(ready), + ready = self.wait_for_ready(server_addr, version) => { + tracing::trace!(?ready, "Reused an idle client"); + Ok(ready) + }, } } /// Stores an unused [`LocalHttpClient`], so that it can be reused later. 
+ #[tracing::instrument(level = Level::TRACE, skip(self))] pub fn push_idle(&self, client: LocalHttpClient) { let mut guard = self.clients.lock().unwrap(); guard.push(IdleLocalClient { @@ -142,6 +157,7 @@ async fn cleanup_task(store: ClientStore, idle_client_timeout: Duration) { true } else { + tracing::trace!(?client, "Dropping an idle client"); false } }); From 39dee15b510cad7fe18d02d22ea5845a6d2294f9 Mon Sep 17 00:00:00 2001 From: Razz4780 Date: Thu, 16 Jan 2025 22:49:20 +0100 Subject: [PATCH 22/60] Less spammy debug for InProxyTaskMessage --- mirrord/intproxy/src/proxies/incoming/tasks.rs | 15 +++++++++++++-- 1 file changed, 13 insertions(+), 2 deletions(-) diff --git a/mirrord/intproxy/src/proxies/incoming/tasks.rs b/mirrord/intproxy/src/proxies/incoming/tasks.rs index 7fd5c72b661..d913715adba 100644 --- a/mirrord/intproxy/src/proxies/incoming/tasks.rs +++ b/mirrord/intproxy/src/proxies/incoming/tasks.rs @@ -1,4 +1,4 @@ -use std::{convert::Infallible, io}; +use std::{convert::Infallible, fmt, io}; use hyper::{upgrade::OnUpgrade, Version}; use mirrord_protocol::{ @@ -7,12 +7,23 @@ use mirrord_protocol::{ }; use thiserror::Error; -#[derive(Debug)] pub enum InProxyTaskMessage { Tcp(Vec), Http(HttpOut), } +impl fmt::Debug for InProxyTaskMessage { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + match self { + Self::Tcp(data) => f + .debug_tuple("Tcp") + .field(&format_args!("{} bytes", data.len())) + .finish(), + Self::Http(msg) => f.debug_tuple("Http").field(msg).finish(), + } + } +} + #[derive(Debug)] pub enum HttpOut { ResponseBasic(HttpResponse>), From 609ecd19fab66369a7a5d9a7beb1a48b3cf0b37f Mon Sep 17 00:00:00 2001 From: Razz4780 Date: Thu, 16 Jan 2025 22:56:18 +0100 Subject: [PATCH 23/60] Clippy and docs --- mirrord/intproxy/src/proxies/incoming.rs | 29 +++++++++---------- .../src/proxies/incoming/http/client_store.rs | 5 +--- 2 files changed, 15 insertions(+), 19 deletions(-) diff --git a/mirrord/intproxy/src/proxies/incoming.rs 
b/mirrord/intproxy/src/proxies/incoming.rs index 5d17dfe4d39..4f3338a03f0 100644 --- a/mirrord/intproxy/src/proxies/incoming.rs +++ b/mirrord/intproxy/src/proxies/incoming.rs @@ -81,16 +81,17 @@ struct HttpGatewayHandle { /// Handles logic and state of the `incoming` feature. /// Run as a [`BackgroundTask`]. /// -/// Handles port subscriptions state of the connected layers. Utilizes multiple background tasks -/// ([`Interceptor`]s and [`HttpResponseReader`]s) to handle incoming connections. +/// Handles port subscriptions state of the connected layers. +/// Utilizes multiple background tasks ([`TcpProxyTask`]s and [`HttpGatewayTask`]s) to handle +/// incoming connections and requests. /// -/// Each connection is managed by a single [`Interceptor`], -/// that establishes a TCP connection with the user application's port and proxies data. +/// Each connection stolen/mirrored in whole is managed by a single [`TcpProxyTask`]. /// -/// Bodies of HTTP responses from the user application are polled by [`HttpResponseReader`]s. +/// Each request stolen with a filter is managed by a single [`HttpGatewayTask`]. /// -/// Incoming connections are created by the agent either explicitly ([`NewTcpConnection`] message) -/// or implicitly ([`HttpRequest`](mirrord_protocol::tcp::HttpRequest)). +/// Incoming connections are created by the agent either explicitly ([`NewTcpConnection`] message, +/// connections stolen/mirrord in whole) or implicitly ([`HttpRequest`] message, requests stolen +/// with a filter). #[derive(Default)] pub struct IncomingProxy { /// Active port subscriptions for all layers. 
@@ -531,14 +532,12 @@ impl IncomingProxy { )) .await; } - } else { - if self.mirror_tcp_proxies.remove(&connection_id).is_some() { - message_bus - .send(ClientMessage::Tcp(LayerTcp::ConnectionUnsubscribe( - connection_id, - ))) - .await; - } + } else if self.mirror_tcp_proxies.remove(&connection_id).is_some() { + message_bus + .send(ClientMessage::Tcp(LayerTcp::ConnectionUnsubscribe( + connection_id, + ))) + .await; } } diff --git a/mirrord/intproxy/src/proxies/incoming/http/client_store.rs b/mirrord/intproxy/src/proxies/incoming/http/client_store.rs index 4857667ab53..6bafb0e619b 100644 --- a/mirrord/intproxy/src/proxies/incoming/http/client_store.rs +++ b/mirrord/intproxy/src/proxies/incoming/http/client_store.rs @@ -79,10 +79,7 @@ impl ClientStore { idle.client.handles_version(version) && idle.client.local_server_address() == server_addr }); - match position { - Some(position) => Some(guard.swap_remove(position)), - None => None, - } + position.map(|position| guard.swap_remove(position)) }; if let Some(ready) = ready { From 20c6fc2890ee687440170efe1f732a46dd3fd0cf Mon Sep 17 00:00:00 2001 From: Razz4780 Date: Thu, 16 Jan 2025 23:24:27 +0100 Subject: [PATCH 24/60] More tracing --- mirrord/intproxy/src/proxies/incoming/http.rs | 4 +- .../src/proxies/incoming/http_gateway.rs | 112 ++++++++++++++++-- 2 files changed, 104 insertions(+), 12 deletions(-) diff --git a/mirrord/intproxy/src/proxies/incoming/http.rs b/mirrord/intproxy/src/proxies/incoming/http.rs index 50d4b0ff06c..814fb5ebe0c 100644 --- a/mirrord/intproxy/src/proxies/incoming/http.rs +++ b/mirrord/intproxy/src/proxies/incoming/http.rs @@ -56,9 +56,7 @@ impl LocalHttpClient { }) } - /// Tries to send the given `request` to the user application's HTTP server. - /// - /// Retries on known errors (see [`LocalHttpError::can_retry`]). + /// Send the given `request` to the user application's HTTP server. 
#[tracing::instrument(level = Level::DEBUG, err(level = Level::WARN), ret)] pub async fn send_request( &mut self, diff --git a/mirrord/intproxy/src/proxies/incoming/http_gateway.rs b/mirrord/intproxy/src/proxies/incoming/http_gateway.rs index 8a4fc255264..5f5a8b3407f 100644 --- a/mirrord/intproxy/src/proxies/incoming/http_gateway.rs +++ b/mirrord/intproxy/src/proxies/incoming/http_gateway.rs @@ -1,4 +1,10 @@ -use std::{convert::Infallible, net::SocketAddr, time::Duration}; +use std::{ + convert::Infallible, + error::Error, + fmt, + net::SocketAddr, + time::{Duration, Instant}, +}; use exponential_backoff::Backoff; use http_body_util::BodyExt; @@ -11,6 +17,7 @@ use mirrord_protocol::{ }, }; use tokio::time; +use tracing::Level; use super::{ http::{mirrord_error_response, ClientStore, LocalHttpError, ResponseMode, StreamingBody}, @@ -32,6 +39,16 @@ pub struct HttpGatewayTask { server_addr: SocketAddr, } +impl fmt::Debug for HttpGatewayTask { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + f.debug_struct("HttpGatewayTask") + .field("request", &self.request) + .field("response_mode", &self.response_mode) + .field("server_addr", &self.server_addr) + .finish() + } +} + impl HttpGatewayTask { /// Creates a new gateway task. pub fn new( @@ -53,6 +70,7 @@ impl HttpGatewayTask { /// [`Err`] is handled in the caller and, if we run out of send attempts, converted to an error /// response. Because of this, this function should not return any error that happened after /// sending [`ChunkedResponse::Start`]. The agent would get a duplicated response. + #[tracing::instrument(level = Level::TRACE, skip_all, err(level = Level::WARN))] async fn send_attempt(&self, message_bus: &mut MessageBus) -> Result<(), LocalHttpError> { let mut client = self .client_store @@ -65,12 +83,19 @@ impl HttpGatewayTask { match self.response_mode { ResponseMode::Basic => { + let start = Instant::now(); let body: Vec = body .collect() .await .map_err(LocalHttpError::ReadBodyFailed)? 
.to_bytes() .into(); + tracing::trace!( + body_len = body.len(), + elapsed_ms = start.elapsed().as_millis(), + "Collected the whole response body", + ); + let response = HttpResponse { port: self.request.port, connection_id: self.request.connection_id, @@ -85,9 +110,16 @@ impl HttpGatewayTask { message_bus.send(HttpOut::ResponseBasic(response)).await } ResponseMode::Framed => { + let start = Instant::now(); let body = InternalHttpBody::from_body(body) .await .map_err(LocalHttpError::ReadBodyFailed)?; + tracing::trace!( + ?body, + elapsed_ms = start.elapsed().as_millis(), + "Collected the whole response body", + ); + let response = HttpResponse { port: self.request.port, connection_id: self.request.connection_id, @@ -109,6 +141,11 @@ impl HttpGatewayTask { .into_iter() .map(InternalHttpBodyFrame::from) .collect(); + tracing::trace!( + ?ready_frames, + "Some response body frames were instantly ready" + ); + let response = HttpResponse { port: self.request.port, connection_id: self.request.connection_id, @@ -125,30 +162,48 @@ impl HttpGatewayTask { .await; loop { + let start = Instant::now(); match body.next_frames().await { Ok(frames) => { + let body_finished = frames.is_last; + let frames = frames + .frames + .into_iter() + .map(InternalHttpBodyFrame::from) + .collect::>(); + tracing::trace!( + ?frames, + body_finished, + elapsed_ms = start.elapsed().as_millis(), + "Received next response body frames", + ); + message_bus .send(HttpOut::ResponseChunked(ChunkedResponse::Body( ChunkedHttpBody { - frames: frames - .frames - .into_iter() - .map(InternalHttpBodyFrame::from) - .collect(), - is_last: frames.is_last, + frames: frames, + is_last: body_finished, connection_id: self.request.connection_id, request_id: self.request.request_id, }, ))) .await; - if frames.is_last { + + if body_finished { break; } } // Do not return any error here, // as it would be transformed into an error response by the caller. // We already send the request head to the agent. - Err(..) 
=> { + Err(error) => { + tracing::warn!( + error = ?ErrorWithSources(&error), + elapsed_ms = start.elapsed().as_millis(), + gateway = ?self, + "Failed to read next response body frames", + ); + message_bus .send(HttpOut::ResponseChunked(ChunkedResponse::Error( ChunkedHttpError { @@ -180,12 +235,16 @@ impl BackgroundTask for HttpGatewayTask { type MessageIn = Infallible; type MessageOut = InProxyTaskMessage; + #[tracing::instrument(level = Level::TRACE, name = "http_gateway_task_main_loop", skip(message_bus))] async fn run(self, message_bus: &mut MessageBus) -> Result<(), Self::Error> { let mut backoffs = Backoff::new(10, Duration::from_millis(50), Duration::from_millis(500)).into_iter(); let guard = message_bus.closed(); + let mut attempt = 0; let error = loop { + attempt += 1; + tracing::trace!(attempt, "Starting send attempt"); match guard.cancel_on_close(self.send_attempt(message_bus)).await { None | Some(Ok(())) => return Ok(()), Some(Err(error)) => { @@ -195,9 +254,23 @@ impl BackgroundTask for HttpGatewayTask { .flatten() .flatten(); let Some(backoff) = backoff else { + tracing::warn!( + gateway = ?self, + failed_attempts = attempt, + error = ?ErrorWithSources(&error), + "Failed to send an HTTP request", + ); + break error; }; + tracing::trace!( + backoff_ms = backoff.as_millis(), + failed_attempts = attempt, + error = ?ErrorWithSources(&error), + "Trying again after backoff", + ); + if guard.cancel_on_close(time::sleep(backoff)).await.is_none() { return Ok(()); } @@ -218,6 +291,27 @@ impl BackgroundTask for HttpGatewayTask { } } +/// Helper struct for tracing an [`Error`] along with all its sources, +/// down to the root cause. +/// +/// Might help when inspecting [`hyper`] errors. 
+struct ErrorWithSources<'a>(&'a dyn Error); + +impl fmt::Debug for ErrorWithSources<'_> { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + let mut list = f.debug_list(); + list.entry(&self.0); + + let mut source = self.0.source(); + while let Some(error) = source { + list.entry(&error); + source = error.source(); + } + + list.finish() + } +} + #[cfg(test)] mod test { use std::{io, sync::Arc}; From 053599c6d505a12b758e171bae9691434d57aa90 Mon Sep 17 00:00:00 2001 From: Razz4780 Date: Fri, 17 Jan 2025 13:24:30 +0100 Subject: [PATCH 25/60] Clippy --- .../intproxy/src/proxies/incoming/http_gateway.rs | 12 ++++++------ 1 file changed, 6 insertions(+), 6 deletions(-) diff --git a/mirrord/intproxy/src/proxies/incoming/http_gateway.rs b/mirrord/intproxy/src/proxies/incoming/http_gateway.rs index 5f5a8b3407f..1ee3ec5eaf3 100644 --- a/mirrord/intproxy/src/proxies/incoming/http_gateway.rs +++ b/mirrord/intproxy/src/proxies/incoming/http_gateway.rs @@ -165,7 +165,7 @@ impl HttpGatewayTask { let start = Instant::now(); match body.next_frames().await { Ok(frames) => { - let body_finished = frames.is_last; + let is_last = frames.is_last; let frames = frames .frames .into_iter() @@ -173,23 +173,23 @@ impl HttpGatewayTask { .collect::>(); tracing::trace!( ?frames, - body_finished, + is_last, elapsed_ms = start.elapsed().as_millis(), - "Received next response body frames", + "Received a next batch of response body frames", ); message_bus .send(HttpOut::ResponseChunked(ChunkedResponse::Body( ChunkedHttpBody { - frames: frames, - is_last: body_finished, + frames, + is_last, connection_id: self.request.connection_id, request_id: self.request.request_id, }, ))) .await; - if body_finished { + if is_last { break; } } From 5c343fb942b6da66dc597bfb46dacb62f34e3207 Mon Sep 17 00:00:00 2001 From: Razz4780 Date: Fri, 17 Jan 2025 13:58:05 +0100 Subject: [PATCH 26/60] Fixed TcpProxyTask --- .../src/proxies/incoming/tcp_proxy.rs | 58 +++++++++++++++++-- 
mirrord/layer/tests/common/mod.rs | 28 ++++----- mirrord/layer/tests/fileops.rs | 20 +++---- mirrord/layer/tests/http_mirroring.rs | 2 + 4 files changed, 81 insertions(+), 27 deletions(-) diff --git a/mirrord/intproxy/src/proxies/incoming/tcp_proxy.rs b/mirrord/intproxy/src/proxies/incoming/tcp_proxy.rs index 1850c1a11d3..0eeb36433f3 100644 --- a/mirrord/intproxy/src/proxies/incoming/tcp_proxy.rs +++ b/mirrord/intproxy/src/proxies/incoming/tcp_proxy.rs @@ -8,6 +8,7 @@ use tokio::{ net::TcpStream, time, }; +use tracing::Level; use super::{ bound_socket::BoundTcpSocket, @@ -16,6 +17,7 @@ use super::{ use crate::background_tasks::{BackgroundTask, MessageBus}; /// Local TCP connections between the [`TcpProxyTask`] and the user application. +#[derive(Debug)] pub enum LocalTcpConnection { /// Not yet established. Should be made by the [`TcpProxyTask`] from the given /// [`BoundTcpSocket`]. @@ -29,6 +31,7 @@ pub enum LocalTcpConnection { /// [`BackgroundTask`] of [`IncomingProxy`](super::IncomingProxy) that handles a remote /// stolen/mirrored TCP connection. +#[derive(Debug)] pub struct TcpProxyTask { /// The local connection between this task and the user application. connection: LocalTcpConnection, @@ -55,6 +58,7 @@ impl BackgroundTask for TcpProxyTask { type MessageIn = Vec; type MessageOut = InProxyTaskMessage; + #[tracing::instrument(level = Level::TRACE, name = "tcp_proxy_task_main_loop", skip(message_bus), err(level = Level::WARN))] async fn run(self, message_bus: &mut MessageBus) -> Result<(), Self::Error> { let mut stream = match self.connection { LocalTcpConnection::FromTheStart { socket, peer } => { @@ -85,6 +89,9 @@ impl BackgroundTask for TcpProxyTask { } }; + let peer_addr = stream.peer_addr()?; + let self_addr = stream.local_addr()?; + let mut buf = BytesMut::with_capacity(64 * 1024); let mut reading_closed = false; let mut is_lingering = false; @@ -97,6 +104,19 @@ impl BackgroundTask for TcpProxyTask { Ok(..) 
=> { if buf.is_empty() { reading_closed = true; + + tracing::trace!( + peer_addr = %peer_addr, + self_addr = %self_addr, + "The user application shut down its side of the connection", + ) + } else { + tracing::trace!( + data_len = buf.len(), + peer_addr = %peer_addr, + self_addr = %self_addr, + "Received some data from the user application", + ); } if !self.discard_data { @@ -108,23 +128,53 @@ impl BackgroundTask for TcpProxyTask { }, msg = message_bus.recv(), if !is_lingering => match msg { - None => { - if self.discard_data { - break Ok(()); - } + None if self.discard_data => { + tracing::trace!( + peer_addr = %peer_addr, + self_addr = %self_addr, + "Message bus closed, waiting until the connection is silent", + ); is_lingering = true; } + None => { + tracing::trace!( + peer_addr = %peer_addr, + self_addr = %self_addr, + "Message bus closed, exiting", + ); + + break Ok(()); + } Some(data) => { if data.is_empty() { + tracing::trace!( + peer_addr = %peer_addr, + self_addr = %self_addr, + "The agent shut down its side of the connection", + ); + stream.shutdown().await?; } else { + tracing::trace!( + data_len = data.len(), + peer_addr = %peer_addr, + self_addr = %self_addr, + "Received some data from the agent", + ); + stream.write_all(&data).await?; } }, }, _ = time::sleep(Duration::from_secs(1)), if is_lingering => { + tracing::trace!( + peer_addr = %peer_addr, + self_addr = %self_addr, + "Message bus is closed and the connection is silent, exiting", + ); + break Ok(()); } } diff --git a/mirrord/layer/tests/common/mod.rs b/mirrord/layer/tests/common/mod.rs index 303231cf3da..d304642d28b 100644 --- a/mirrord/layer/tests/common/mod.rs +++ b/mirrord/layer/tests/common/mod.rs @@ -44,7 +44,7 @@ pub const RUST_OUTGOING_LOCAL: &str = "4.4.4.4:4444"; /// /// We take advantage of how Rust's thread naming scheme for tests to create the log files, /// and if we have no thread name, then we just write the logs to `stderr`. 
-pub fn init_tracing() -> Result> { +pub fn init_tracing() -> DefaultGuard { let subscriber = tracing_subscriber::fmt() .with_env_filter(EnvFilter::new("mirrord=trace")) .without_time() @@ -61,7 +61,7 @@ pub fn init_tracing() -> Result> { .map(|name| name.replace(':', "_")) { Some(test_name) => { - let mut logs_file = PathBuf::from_str("/tmp/intproxy_logs")?; + let mut logs_file = PathBuf::from("/tmp/intproxy_logs"); #[cfg(target_os = "macos")] logs_file.push("macos"); @@ -71,26 +71,28 @@ pub fn init_tracing() -> Result> { let _ = std::fs::create_dir_all(&logs_file).ok(); logs_file.push(&test_name); - match File::create(logs_file) { + match File::create(&logs_file) { + // Writes the logs to the file. Ok(file) => { + println!("Created intproxy log file: {}", logs_file.display()); let subscriber = subscriber.with_writer(Arc::new(file)).finish(); - - // Writes the logs to a file. - Ok(tracing::subscriber::set_default(subscriber)) + tracing::subscriber::set_default(subscriber) } - Err(_) => { + // File creation failure makes the output go to `stderr`. + Err(error) => { + println!("Failed to create intproxy log file at {}: {error}. Intproxy logs will be flushed to stderr", logs_file.display()); let subscriber = subscriber.with_writer(io::stderr).finish(); - - // File creation failure makes the output go to `stderr`. - Ok(tracing::subscriber::set_default(subscriber)) + tracing::subscriber::set_default(subscriber) } } } + // No thread name makes the output go to `stderr`. None => { + println!( + "Failed to obtain current thread name, intproxy logs will be flushed to stderr" + ); let subscriber = subscriber.with_writer(io::stderr).finish(); - - // No thread name makes the output go to `stderr`. 
- Ok(tracing::subscriber::set_default(subscriber)) + tracing::subscriber::set_default(subscriber) } } } diff --git a/mirrord/layer/tests/fileops.rs b/mirrord/layer/tests/fileops.rs index 5a9ee96f726..b2fd714a557 100644 --- a/mirrord/layer/tests/fileops.rs +++ b/mirrord/layer/tests/fileops.rs @@ -44,7 +44,7 @@ async fn self_open( application: Application, dylib_path: &Path, ) { - let _tracing = init_tracing().unwrap(); + let _tracing = init_tracing(); let (mut test_process, mut intproxy) = application .start_process_with_layer(dylib_path, vec![], None) @@ -108,7 +108,7 @@ async fn read_from_mirrord_bin(dylib_path: &Path) { #[tokio::test] #[timeout(Duration::from_secs(60))] async fn pwrite(#[values(Application::RustFileOps)] application: Application, dylib_path: &Path) { - let _tracing = init_tracing().unwrap(); + let _tracing = init_tracing(); // add rw override for the specific path let (mut test_process, mut intproxy) = application @@ -228,7 +228,7 @@ async fn node_close( #[values(Application::NodeFileOps)] application: Application, dylib_path: &Path, ) { - let _tracing = init_tracing().unwrap(); + let _tracing = init_tracing(); let (mut test_process, mut intproxy) = application .start_process_with_layer( @@ -295,7 +295,7 @@ async fn go_stat( application: Application, dylib_path: &Path, ) { - let _tracing = init_tracing().unwrap(); + let _tracing = init_tracing(); // add rw override for the specific path let (mut test_process, mut intproxy) = application @@ -478,7 +478,7 @@ async fn go_dir_on_linux( application: Application, dylib_path: &Path, ) { - let _tracing = init_tracing().unwrap(); + let _tracing = init_tracing(); let (mut test_process, mut intproxy) = application .start_process_with_layer( @@ -575,7 +575,7 @@ async fn go_dir_bypass( application: Application, dylib_path: &Path, ) { - let _tracing = init_tracing().unwrap(); + let _tracing = init_tracing(); let tmp_dir = temp_dir().join("go_dir_bypass_test"); std::fs::create_dir_all(tmp_dir.clone()).unwrap(); 
@@ -616,7 +616,7 @@ async fn read_go( application: Application, dylib_path: &Path, ) { - let _tracing = init_tracing().unwrap(); + let _tracing = init_tracing(); let (mut test_process, mut intproxy) = application .start_process_with_layer(dylib_path, vec![("MIRRORD_FILE_MODE", "read")], None) @@ -658,7 +658,7 @@ async fn write_go( application: Application, dylib_path: &Path, ) { - let _tracing = init_tracing().unwrap(); + let _tracing = init_tracing(); let (mut test_process, mut layer_connection) = application .start_process_with_layer(dylib_path, get_rw_test_file_env_vars(), None) @@ -687,7 +687,7 @@ async fn lseek_go( application: Application, dylib_path: &Path, ) { - let _tracing = init_tracing().unwrap(); + let _tracing = init_tracing(); let (mut test_process, mut intproxy) = application .start_process_with_layer(dylib_path, get_rw_test_file_env_vars(), None) @@ -722,7 +722,7 @@ async fn faccessat_go( application: Application, dylib_path: &Path, ) { - let _tracing = init_tracing().unwrap(); + let _tracing = init_tracing(); let (mut test_process, mut intproxy) = application .start_process_with_layer(dylib_path, get_rw_test_file_env_vars(), None) diff --git a/mirrord/layer/tests/http_mirroring.rs b/mirrord/layer/tests/http_mirroring.rs index e37d433a0e3..96f2230ed18 100644 --- a/mirrord/layer/tests/http_mirroring.rs +++ b/mirrord/layer/tests/http_mirroring.rs @@ -30,6 +30,8 @@ async fn mirroring_with_http( dylib_path: &Path, config_dir: &Path, ) { + let _guard = init_tracing(); + let (mut test_process, mut intproxy) = application .start_process_with_layer_and_port( dylib_path, From 0a346a41b043389dc138f2ac0af350aa322d8fe0 Mon Sep 17 00:00:00 2001 From: Razz4780 Date: Fri, 17 Jan 2025 14:08:36 +0100 Subject: [PATCH 27/60] More IncomingProxy docs --- mirrord/intproxy/src/proxies/incoming.rs | 33 ++++++++++++++++++++---- 1 file changed, 28 insertions(+), 5 deletions(-) diff --git a/mirrord/intproxy/src/proxies/incoming.rs b/mirrord/intproxy/src/proxies/incoming.rs 
index 4f3338a03f0..d0d3fc88859 100644 --- a/mirrord/intproxy/src/proxies/incoming.rs +++ b/mirrord/intproxy/src/proxies/incoming.rs @@ -85,13 +85,36 @@ struct HttpGatewayHandle { /// Utilizes multiple background tasks ([`TcpProxyTask`]s and [`HttpGatewayTask`]s) to handle /// incoming connections and requests. /// -/// Each connection stolen/mirrored in whole is managed by a single [`TcpProxyTask`]. +/// # Connections stolen/mirrored in whole /// -/// Each request stolen with a filter is managed by a single [`HttpGatewayTask`]. +/// Each such connection exists in two places: /// -/// Incoming connections are created by the agent either explicitly ([`NewTcpConnection`] message, -/// connections stolen/mirrord in whole) or implicitly ([`HttpRequest`] message, requests stolen -/// with a filter). +/// 1. Here, between the intproxy and the user application. Managed by a single [`TcpProxyTask`]. +/// 2. In the cluster, between the agent and the original TCP client. +/// +/// We are notified about such connections with the [`NewTcpConnection`] message. +/// +/// # Requests stolen with a filter +/// +/// In the cluster, we have a real persistent connection between the agent and the original HTTP +/// client. From this connection, intproxy receives a subset of requests. +/// +/// Locally, we don't have a concept of a filtered connection. +/// Each request is handled independently by a single [`HttpGatewayTask`]. +/// Also: +/// 1. Local HTTP connections are reused when possible. +/// 2. Unless the error is fatal, each request is retried a couple of times. +/// 3. We never send [`LayerTcpSteal::ConnectionUnsubscribe`] (due to requests being handled +/// independently). If a request fails locally, we send a +/// [`StatusCode::BAD_GATEWAY`](hyper::http::StatusCode::BAD_GATEWAY) response. +/// +/// We are notified about stolen requests with the [`HttpRequest`] messages. +/// +/// # HTTP upgrades +/// +/// An HTTP request stolen with a filter can result in an HTTP upgrade.
+/// When this happens, the TCP connection is recovered and passed to a new [`TcpProxyTask`]. +/// The TCP connection is then treated as stolen in whole. #[derive(Default)] pub struct IncomingProxy { /// Active port subscriptions for all layers. From 7458a3840c9e114a392af0d9fd8672ea0d50e8c1 Mon Sep 17 00:00:00 2001 From: Razz4780 Date: Fri, 17 Jan 2025 14:09:43 +0100 Subject: [PATCH 28/60] macos tests fixed --- mirrord/layer/tests/fileops.rs | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/mirrord/layer/tests/fileops.rs b/mirrord/layer/tests/fileops.rs index b2fd714a557..5daacdadc50 100644 --- a/mirrord/layer/tests/fileops.rs +++ b/mirrord/layer/tests/fileops.rs @@ -65,7 +65,7 @@ async fn self_open( #[tokio::test] #[timeout(Duration::from_secs(20))] async fn read_from_mirrord_bin(dylib_path: &Path) { - let _tracing = init_tracing().unwrap(); + let _tracing = init_tracing(); let contents = "please don't flake"; let temp_dir = env::temp_dir(); @@ -358,7 +358,7 @@ async fn go_dir( application: Application, dylib_path: &Path, ) { - let _tracing = init_tracing().unwrap(); + let _tracing = init_tracing(); let (mut test_process, mut intproxy) = application .start_process_with_layer( From 7b8a14fda564921e225dfbc28dc224b5149594df Mon Sep 17 00:00:00 2001 From: Razz4780 Date: Fri, 17 Jan 2025 16:28:13 +0100 Subject: [PATCH 29/60] Fixed reverseportforwarder and its tests --- mirrord/cli/src/port_forward.rs | 667 +++++++----------- mirrord/intproxy/src/proxies/incoming.rs | 2 +- .../proxies/incoming/http/response_mode.rs | 5 + .../src/proxies/incoming/http_gateway.rs | 39 +- .../proxies/incoming/port_subscription_ext.rs | 15 +- 5 files changed, 305 insertions(+), 423 deletions(-) diff --git a/mirrord/cli/src/port_forward.rs b/mirrord/cli/src/port_forward.rs index 002a904ae1b..9196d200663 100644 --- a/mirrord/cli/src/port_forward.rs +++ b/mirrord/cli/src/port_forward.rs @@ -1,6 +1,6 @@ use std::{ collections::{HashMap, HashSet, VecDeque}, - net::{IpAddr,
SocketAddr}, + net::{IpAddr, Ipv4Addr, SocketAddr}, time::{Duration, Instant}, }; @@ -11,12 +11,8 @@ use mirrord_config::feature::network::incoming::{ }; use mirrord_intproxy::{ background_tasks::{BackgroundTasks, TaskError, TaskSender, TaskUpdate}, - error::IntProxyError, - main_tasks::{MainTaskId, ProxyMessage, ToLayer}, - proxies::incoming::{ - port_subscription_ext::PortSubscriptionExt, IncomingProxy, IncomingProxyError, - IncomingProxyMessage, - }, + main_tasks::{ProxyMessage, ToLayer}, + proxies::incoming::{IncomingProxy, IncomingProxyError, IncomingProxyMessage}, }; use mirrord_intproxy_protocol::{ IncomingRequest, IncomingResponse, LayerId, PortSubscribe, PortSubscription, @@ -28,7 +24,7 @@ use mirrord_protocol::{ tcp::{DaemonTcpOutgoing, LayerTcpOutgoing}, LayerClose, LayerConnect, LayerWrite, SocketAddress, }, - tcp::{Filter, HttpFilter, LayerTcp, LayerTcpSteal, StealType}, + tcp::{Filter, HttpFilter, StealType}, ClientMessage, ConnectionId, DaemonMessage, LogLevel, Port, CLIENT_READY_FOR_LOGS, }; use thiserror::Error; @@ -426,44 +422,57 @@ impl PortForwarder { } pub struct ReversePortForwarder { - /// details for traffic mirroring or stealing - incoming_mode: IncomingMode, /// communicates with the agent (only TCP supported). agent_connection: AgentConnection, - /// associates destination ports with local ports. - mappings: HashMap, - /// background task (uses IncomingProxy to communicate with layer) - background_tasks: BackgroundTasks, + /// background task (uses [`IncomingProxy`] to communicate with layer) + background_tasks: BackgroundTasks<(), ProxyMessage, IncomingProxyError>, /// incoming proxy background task tx incoming_proxy: TaskSender, - - /// true if Ping has been sent to agent. 
+ /// `true` if [`ClientMessage::Ping`] has been sent to agent and we're waiting for the + /// [`DaemonMessage::Pong`] waiting_for_pong: bool, ping_pong_timeout: Instant, } impl ReversePortForwarder { pub(crate) async fn new( - agent_connection: AgentConnection, + mut agent_connection: AgentConnection, mappings: HashMap, network_config: IncomingConfig, ) -> Result { - // setup IncomingProxy - let mut background_tasks: BackgroundTasks = + let mut background_tasks: BackgroundTasks<(), ProxyMessage, IncomingProxyError> = Default::default(); - let incoming = - background_tasks.register(IncomingProxy::default(), MainTaskId::IncomingProxy, 512); - // construct IncomingMode from config file + let incoming = background_tasks.register(IncomingProxy::default(), (), 512); + + agent_connection + .sender + .send(ClientMessage::SwitchProtocolVersion( + mirrord_protocol::VERSION.clone(), + )) + .await?; + let protocol_version = match agent_connection.receiver.recv().await { + Some(DaemonMessage::SwitchProtocolVersionResponse(version)) => version, + _ => return Err(PortForwardError::AgentConnectionFailed), + }; + + if CLIENT_READY_FOR_LOGS.matches(&protocol_version) { + agent_connection + .sender + .send(ClientMessage::ReadyForLogs) + .await?; + } + + incoming + .send(IncomingProxyMessage::AgentProtocolVersion(protocol_version)) + .await; + let incoming_mode = IncomingMode::new(&network_config); for (i, (&remote, &local)) in mappings.iter().enumerate() { - // send subscription to incoming proxy let subscription = incoming_mode.subscription(remote); let message_id = i as u64; let layer_id = LayerId(1); let req = IncomingRequest::PortSubscribe(PortSubscribe { - listening_on: format!("127.0.0.1:{local}") - .parse() - .expect("Error parsing socket address"), + listening_on: SocketAddr::new(Ipv4Addr::LOCALHOST.into(), local), subscription, }); incoming @@ -474,9 +483,7 @@ impl ReversePortForwarder { } Ok(Self { - incoming_mode, agent_connection, - mappings, background_tasks,
incoming_proxy: incoming, waiting_for_pong: false, @@ -485,31 +492,6 @@ impl ReversePortForwarder { } pub(crate) async fn run(&mut self) -> Result<(), PortForwardError> { - // setup agent connection - self.agent_connection - .sender - .send(ClientMessage::SwitchProtocolVersion( - mirrord_protocol::VERSION.clone(), - )) - .await?; - match self.agent_connection.receiver.recv().await { - Some(DaemonMessage::SwitchProtocolVersionResponse(version)) - if CLIENT_READY_FOR_LOGS.matches(&version) => - { - self.agent_connection - .sender - .send(ClientMessage::ReadyForLogs) - .await?; - } - _ => return Err(PortForwardError::AgentConnectionFailed), - } - - for remote_port in self.mappings.keys() { - let subscription = self.incoming_mode.subscription(*remote_port); - let msg = subscription.agent_subscribe(); - self.agent_connection.sender.send(msg).await? - } - loop { select! { _ = tokio::time::sleep_until(self.ping_pong_timeout.into()) => { @@ -530,8 +512,8 @@ impl ReversePortForwarder { }, }, - Some((task_id, update)) = self.background_tasks.next() => { - self.handle_msg_from_local(task_id, update).await? + Some((_, update)) = self.background_tasks.next() => { + self.handle_msg_from_local(update).await? 
}, } } @@ -563,8 +545,8 @@ impl ReversePortForwarder { DaemonMessage::Pong if self.waiting_for_pong => { self.waiting_for_pong = false; } + // Includes unexpected DaemonMessage::Pong other => { - // includes unexepcted DaemonMessage::Pong return Err(PortForwardError::AgentError(format!( "unexpected message from agent: {other:?}" ))); @@ -577,20 +559,11 @@ impl ReversePortForwarder { #[tracing::instrument(level = Level::TRACE, skip(self), err)] async fn handle_msg_from_local( &mut self, - task_id: MainTaskId, - update: TaskUpdate, + update: TaskUpdate, ) -> Result<(), PortForwardError> { - match (task_id, update) { - (MainTaskId::IncomingProxy, TaskUpdate::Message(message)) => match message { + match update { + TaskUpdate::Message(message) => match message { ProxyMessage::ToAgent(message) => { - if matches!( - message, - ClientMessage::TcpSteal(LayerTcpSteal::PortSubscribe(_)) - | ClientMessage::Tcp(LayerTcp::PortSubscribe(_)) - ) { - // suppress additional subscription requests - return Ok(()); - } self.agent_connection.sender.send(message).await?; } ProxyMessage::ToLayer(ToLayer { @@ -598,9 +571,7 @@ impl ReversePortForwarder { .. 
}) => { if let Err(error) = res { - return Err(PortForwardError::from(IntProxyError::from( - IncomingProxyError::SubscriptionFailed(error), - ))); + return Err(IncomingProxyError::SubscriptionFailed(error).into()); } } other => { @@ -609,21 +580,18 @@ impl ReversePortForwarder { ) } }, - (MainTaskId::IncomingProxy, TaskUpdate::Finished(result)) => match result { + + TaskUpdate::Finished(result) => match result { Ok(()) => { - tracing::error!("incoming proxy task finished unexpectedly"); - return Err(IntProxyError::TaskExit(task_id).into()); + return Err(PortForwardError::IncomingProxyExited); } Err(TaskError::Error(e)) => { - tracing::error!("incoming proxy task failed: {e}"); return Err(e.into()); } Err(TaskError::Panic) => { - tracing::error!("incoming proxy task panicked"); - return Err(IntProxyError::TaskPanic(task_id).into()); + return Err(PortForwardError::IncomingProxyPanicked); } }, - _ => unreachable!("other task types are never used in port forwarding"), } Ok(()) @@ -974,15 +942,20 @@ pub enum PortForwardError { #[error("multiple port forwarding mappings found for desination port `{0:?}`")] ReversePortMapSetupError(RemotePort), - // running errors #[error("agent closed connection with error: `{0}`")] AgentError(String), #[error("connection with the agent failed")] AgentConnectionFailed, - #[error("error from Incoming Proxy task")] - IncomingProxyError(IntProxyError), + #[error("error from the IncomingProxy task: {0}")] + IncomingProxyError(#[from] IncomingProxyError), + + #[error("IncomingProxy task unexpectedly exited")] + IncomingProxyExited, + + #[error("IncomingProxy task panicked")] + IncomingProxyPanicked, #[error("TcpListener operation failed with error: `{0}`")] TcpListenerError(std::io::Error), @@ -1003,12 +976,6 @@ impl From> for PortForwardError { } } -impl From for PortForwardError { - fn from(value: IntProxyError) -> Self { - Self::IncomingProxyError(value) - } -} - #[cfg(test)] mod test { use std::{ @@ -1024,9 +991,9 @@ mod test { 
DaemonConnect, DaemonRead, LayerConnect, LayerWrite, SocketAddress, }, tcp::{ - DaemonTcp, Filter, HttpRequest, HttpResponse, InternalHttpRequest, - InternalHttpResponse, LayerTcp, LayerTcpSteal, NewTcpConnection, StealType, TcpClose, - TcpData, + DaemonTcp, Filter, HttpRequest, HttpResponse, InternalHttpBody, InternalHttpBodyFrame, + InternalHttpRequest, InternalHttpResponse, LayerTcp, LayerTcpSteal, NewTcpConnection, + StealType, TcpClose, TcpData, }, ClientMessage, DaemonMessage, }; @@ -1044,90 +1011,148 @@ mod test { RemoteAddr, }; + /// Connects [`ReversePortForwarder`] with test code with [`ClientMessage`] and + /// [`DaemonMessage`] channels. Runs a background [`tokio::task`] that auto responds to + /// standard [`mirrord_protocol`] messages (e.g [`ClientMessage::Ping`]). + struct TestAgentConnection { + daemon_msg_tx: mpsc::Sender, + client_msg_rx: mpsc::Receiver, + } + + impl TestAgentConnection { + fn new() -> (Self, AgentConnection) { + let (daemon_to_forwarder, daemon_from_forwarder) = mpsc::channel::(8); + let (client_task_to_test, client_task_from_test) = mpsc::channel::(8); + let (client_forwarder_to_task, client_task_from_forwarder) = + mpsc::channel::(8); + + tokio::spawn(Self::auto_responder( + client_task_from_forwarder, + client_task_to_test, + daemon_to_forwarder.clone(), + )); + + ( + Self { + daemon_msg_tx: daemon_to_forwarder, + client_msg_rx: client_task_from_test, + }, + AgentConnection { + sender: client_forwarder_to_task, + receiver: daemon_from_forwarder, + }, + ) + } + + /// Sends the [`DaemonMessage`] to the [`ReversePortForwarder`]. + async fn send(&self, message: DaemonMessage) { + self.daemon_msg_tx.send(message).await.unwrap(); + } + + /// Receives a [`ClientMessage`] from the [`ReversePortForwarder`]. + /// + /// Some standard messages are handled internally and are never returned: + /// 1. [`ClientMessage::Ping`] + /// 2. [`ClientMessage::SwitchProtocolVersion`] + /// 3. 
[`ClientMessage::ReadyForLogs`] + async fn recv(&mut self) -> ClientMessage { + self.client_msg_rx.recv().await.unwrap() + } + + async fn auto_responder( + mut rx: mpsc::Receiver, + tx_to_test_code: mpsc::Sender, + tx_to_port_forwarder: mpsc::Sender, + ) { + loop { + let Some(message) = rx.recv().await else { + break; + }; + + match message { + ClientMessage::Ping => { + if tx_to_port_forwarder + .send(DaemonMessage::Pong) + .await + .is_err() + { + break; + } + } + ClientMessage::ReadyForLogs => {} + ClientMessage::SwitchProtocolVersion(version) => { + if tx_to_port_forwarder + .send(DaemonMessage::SwitchProtocolVersionResponse( + std::cmp::min(&version, &*mirrord_protocol::VERSION).clone(), + )) + .await + .is_err() + { + break; + } + } + other => tx_to_test_code.send(other).await.unwrap(), + } + } + } + } + #[tokio::test] async fn single_port_forwarding() { let listener = TcpListener::bind("127.0.0.1:0").await.unwrap(); let local_destination = listener.local_addr().unwrap(); drop(listener); - let (daemon_msg_tx, daemon_msg_rx) = mpsc::channel::(12); - let (client_msg_tx, mut client_msg_rx) = mpsc::channel::(12); + let (mut test_connection, agent_connection) = TestAgentConnection::new(); - let agent_connection = AgentConnection { - sender: client_msg_tx, - receiver: daemon_msg_rx, - }; - let remote_destination = (RemoteAddr::Ip("152.37.40.40".parse().unwrap()), 3038); + let remote_ip = "152.37.40.40".parse::().unwrap(); + let remote_destination = (RemoteAddr::Ip(remote_ip), 3038); let mappings = HashMap::from([(local_destination, remote_destination.clone())]); - tokio::spawn(async move { - let mut port_forwarder = PortForwarder::new(agent_connection, mappings) - .await - .unwrap(); - port_forwarder.run().await.unwrap() - }); - - // expect handshake procedure - let expected = Some(ClientMessage::SwitchProtocolVersion( - mirrord_protocol::VERSION.clone(), - )); - assert_eq!(client_msg_rx.recv().await, expected); - daemon_msg_tx - 
.send(DaemonMessage::SwitchProtocolVersionResponse( - mirrord_protocol::VERSION.clone(), - )) + // Prepare listeners before sending work to the background task. + let mut port_forwarder = PortForwarder::new(agent_connection, mappings) .await .unwrap(); - let expected = Some(ClientMessage::ReadyForLogs); - assert_eq!(client_msg_rx.recv().await, expected); + tokio::spawn(async move { port_forwarder.run().await.unwrap() }); - // send data to socket + // Connect to PortForwarders listener and send some data to trigger remote connection + // request. let mut stream = TcpStream::connect(local_destination).await.unwrap(); stream.write_all(b"data-my-beloved").await.unwrap(); - // expect Connect on client_msg_rx - let remote_address = SocketAddress::Ip("152.37.40.40:3038".parse().unwrap()); + // Expect a connection request + let remote_address = SocketAddress::Ip(SocketAddr::new(remote_ip.into(), 3038)); let expected = ClientMessage::TcpOutgoing(LayerTcpOutgoing::Connect(LayerConnect { remote_address: remote_address.clone(), })); - let message = match client_msg_rx.recv().await.ok_or(0).unwrap() { - ClientMessage::Ping => client_msg_rx.recv().await.ok_or(0).unwrap(), - other => other, - }; - assert_eq!(message, expected); + assert_eq!(test_connection.recv().await, expected,); // reply with successful on daemon_msg_tx - daemon_msg_tx + test_connection .send(DaemonMessage::TcpOutgoing(DaemonTcpOutgoing::Connect(Ok( DaemonConnect { connection_id: 1, - remote_address: remote_address.clone(), - local_address: remote_address, + remote_address, + local_address: "1.2.3.4:2137".parse::().unwrap().into(), }, )))) - .await - .unwrap(); + .await; let expected = ClientMessage::TcpOutgoing(LayerTcpOutgoing::Write(LayerWrite { connection_id: 1, bytes: b"data-my-beloved".to_vec(), })); - let message = match client_msg_rx.recv().await.ok_or(0).unwrap() { - ClientMessage::Ping => client_msg_rx.recv().await.ok_or(0).unwrap(), - other => other, - }; - assert_eq!(message, expected); + 
assert_eq!(test_connection.recv().await, expected); // send response data from agent on daemon_msg_tx - daemon_msg_tx + test_connection .send(DaemonMessage::TcpOutgoing(DaemonTcpOutgoing::Read(Ok( DaemonRead { connection_id: 1, bytes: b"reply-my-beloved".to_vec(), }, )))) - .await - .unwrap(); + .await; // check data arrives at local let mut buf = [0; 16]; @@ -1147,38 +1172,17 @@ mod test { let local_destination_2 = listener.local_addr().unwrap(); drop(listener); - let (daemon_msg_tx, daemon_msg_rx) = mpsc::channel::(12); - let (client_msg_tx, mut client_msg_rx) = mpsc::channel::(12); - - let agent_connection = AgentConnection { - sender: client_msg_tx, - receiver: daemon_msg_rx, - }; + let (mut test_connection, agent_connection) = TestAgentConnection::new(); let mappings = HashMap::from([ (local_destination_1, remote_destination_1.clone()), (local_destination_2, remote_destination_2.clone()), ]); - tokio::spawn(async move { - let mut port_forwarder = PortForwarder::new(agent_connection, mappings) - .await - .unwrap(); - port_forwarder.run().await.unwrap() - }); - - // expect handshake procedure - let expected = Some(ClientMessage::SwitchProtocolVersion( - mirrord_protocol::VERSION.clone(), - )); - assert_eq!(client_msg_rx.recv().await, expected); - daemon_msg_tx - .send(DaemonMessage::SwitchProtocolVersionResponse( - mirrord_protocol::VERSION.clone(), - )) + // Prepare listeners before sending work to the background task. 
+ let mut port_forwarder = PortForwarder::new(agent_connection, mappings) .await .unwrap(); - let expected = Some(ClientMessage::ReadyForLogs); - assert_eq!(client_msg_rx.recv().await, expected); + tokio::spawn(async move { port_forwarder.run().await.unwrap() }); // send data to first socket let mut stream_1 = TcpStream::connect(local_destination_1).await.unwrap(); @@ -1194,11 +1198,7 @@ mod test { let expected = ClientMessage::TcpOutgoing(LayerTcpOutgoing::Connect(LayerConnect { remote_address: remote_address_1.clone(), })); - let message = match client_msg_rx.recv().await.ok_or(0).unwrap() { - ClientMessage::Ping => client_msg_rx.recv().await.ok_or(0).unwrap(), - other => other, - }; - assert_eq!(message, expected); + assert_eq!(test_connection.recv().await, expected); // send data to second socket let mut stream_2 = TcpStream::connect(local_destination_2).await.unwrap(); @@ -1212,14 +1212,10 @@ mod test { let expected = ClientMessage::TcpOutgoing(LayerTcpOutgoing::Connect(LayerConnect { remote_address: remote_address_2.clone(), })); - let message = match client_msg_rx.recv().await.ok_or(0).unwrap() { - ClientMessage::Ping => client_msg_rx.recv().await.ok_or(0).unwrap(), - other => other, - }; - assert_eq!(message, expected); + assert_eq!(test_connection.recv().await, expected); // reply with successful on each daemon_msg_tx - daemon_msg_tx + test_connection .send(DaemonMessage::TcpOutgoing(DaemonTcpOutgoing::Connect(Ok( DaemonConnect { connection_id: 1, @@ -1227,9 +1223,8 @@ mod test { local_address: remote_address_1, }, )))) - .await - .unwrap(); - daemon_msg_tx + .await; + test_connection .send(DaemonMessage::TcpOutgoing(DaemonTcpOutgoing::Connect(Ok( DaemonConnect { connection_id: 2, @@ -1237,49 +1232,38 @@ mod test { local_address: remote_address_2, }, )))) - .await - .unwrap(); + .await; // expect data to be received let expected = ClientMessage::TcpOutgoing(LayerTcpOutgoing::Write(LayerWrite { connection_id: 1, bytes: b"data-from-1".to_vec(), })); - let 
message = match client_msg_rx.recv().await.ok_or(0).unwrap() { - ClientMessage::Ping => client_msg_rx.recv().await.ok_or(0).unwrap(), - other => other, - }; - assert_eq!(message, expected); + assert_eq!(test_connection.recv().await, expected); let expected = ClientMessage::TcpOutgoing(LayerTcpOutgoing::Write(LayerWrite { connection_id: 2, bytes: b"data-from-2".to_vec(), })); - let message = match client_msg_rx.recv().await.ok_or(0).unwrap() { - ClientMessage::Ping => client_msg_rx.recv().await.ok_or(0).unwrap(), - other => other, - }; - assert_eq!(message, expected); + assert_eq!(test_connection.recv().await, expected); // send each data response from agent on daemon_msg_tx - daemon_msg_tx + test_connection .send(DaemonMessage::TcpOutgoing(DaemonTcpOutgoing::Read(Ok( DaemonRead { connection_id: 1, bytes: b"reply-to-1".to_vec(), }, )))) - .await - .unwrap(); - daemon_msg_tx + .await; + test_connection .send(DaemonMessage::TcpOutgoing(DaemonTcpOutgoing::Read(Ok( DaemonRead { connection_id: 2, bytes: b"reply-to-2".to_vec(), }, )))) - .await - .unwrap(); + .await; // check data arrives at each local addr let mut buf = [0; 10]; @@ -1297,54 +1281,33 @@ mod test { let listener = TcpListener::bind("127.0.0.1:0").await.unwrap(); let local_destination = listener.local_addr().unwrap(); - let (daemon_msg_tx, daemon_msg_rx) = mpsc::channel::(12); - let (client_msg_tx, mut client_msg_rx) = mpsc::channel::(12); - - let agent_connection = AgentConnection { - sender: client_msg_tx, - receiver: daemon_msg_rx, - }; let remote_address = IpAddr::from("152.37.40.40".parse::().unwrap()); let destination_port = 3038; let mappings = HashMap::from([(destination_port, local_destination.port())]); let network_config = IncomingConfig::default(); + let (mut test_connection, agent_connection) = TestAgentConnection::new(); + tokio::spawn(async move { - let mut port_forwarder = - ReversePortForwarder::new(agent_connection, mappings, network_config) - .await - .unwrap(); - 
port_forwarder.run().await.unwrap() + ReversePortForwarder::new(agent_connection, mappings, network_config) + .await + .unwrap() + .run() + .await + .unwrap() }); - // expect handshake procedure - let expected = Some(ClientMessage::SwitchProtocolVersion( - mirrord_protocol::VERSION.clone(), - )); - assert_eq!(client_msg_rx.recv().await, expected); - daemon_msg_tx - .send(DaemonMessage::SwitchProtocolVersionResponse( - mirrord_protocol::VERSION.clone(), - )) - .await - .unwrap(); - let expected = Some(ClientMessage::ReadyForLogs); - assert_eq!(client_msg_rx.recv().await, expected); - // expect port subscription for remote port and send subscribe result - let expected = Some(ClientMessage::Tcp(LayerTcp::PortSubscribe( - destination_port, - ))); - assert_eq!(client_msg_rx.recv().await, expected); - daemon_msg_tx + let expected = ClientMessage::Tcp(LayerTcp::PortSubscribe(destination_port)); + assert_eq!(test_connection.recv().await, expected); + test_connection .send(DaemonMessage::Tcp(DaemonTcp::SubscribeResult(Ok( destination_port, )))) - .await - .unwrap(); + .await; // send new connection from agent and some data - daemon_msg_tx + test_connection .send(DaemonMessage::Tcp(DaemonTcp::NewConnection( NewTcpConnection { connection_id: 1, @@ -1354,17 +1317,15 @@ mod test { local_address: local_destination.ip(), }, ))) - .await - .unwrap(); + .await; let mut stream = listener.accept().await.unwrap().0; - daemon_msg_tx + test_connection .send(DaemonMessage::Tcp(DaemonTcp::Data(TcpData { connection_id: 1, bytes: b"data-my-beloved".to_vec(), }))) - .await - .unwrap(); + .await; // check data arrives at local let mut buf = [0; 15]; @@ -1372,12 +1333,11 @@ mod test { assert_eq!(buf, b"data-my-beloved".as_ref()); // ensure graceful behaviour on close - daemon_msg_tx + test_connection .send(DaemonMessage::Tcp(DaemonTcp::Close(TcpClose { connection_id: 1, }))) - .await - .unwrap(); + .await; } #[rstest] @@ -1387,13 +1347,6 @@ mod test { let listener = 
TcpListener::bind("127.0.0.1:0").await.unwrap(); let local_destination = listener.local_addr().unwrap(); - let (daemon_msg_tx, daemon_msg_rx) = mpsc::channel::(12); - let (client_msg_tx, mut client_msg_rx) = mpsc::channel::(12); - - let agent_connection = AgentConnection { - sender: client_msg_tx, - receiver: daemon_msg_rx, - }; let remote_address = IpAddr::from("152.37.40.40".parse::().unwrap()); let destination_port = 3038; let mappings = HashMap::from([(destination_port, local_destination.port())]); @@ -1402,62 +1355,47 @@ mod test { ..Default::default() }; + let (mut test_connection, agent_connection) = TestAgentConnection::new(); tokio::spawn(async move { - let mut port_forwarder = - ReversePortForwarder::new(agent_connection, mappings, network_config) - .await - .unwrap(); - port_forwarder.run().await.unwrap() + ReversePortForwarder::new(agent_connection, mappings, network_config) + .await + .unwrap() + .run() + .await + .unwrap() }); - // expect handshake procedure - let expected = Some(ClientMessage::SwitchProtocolVersion( - mirrord_protocol::VERSION.clone(), - )); - assert_eq!(client_msg_rx.recv().await, expected); - daemon_msg_tx - .send(DaemonMessage::SwitchProtocolVersionResponse( - mirrord_protocol::VERSION.clone(), - )) - .await - .unwrap(); - let expected = Some(ClientMessage::ReadyForLogs); - assert_eq!(client_msg_rx.recv().await, expected); - // expect port subscription for remote port and send subscribe result - let expected = Some(ClientMessage::TcpSteal(LayerTcpSteal::PortSubscribe( - StealType::All(destination_port), + let expected = ClientMessage::TcpSteal(LayerTcpSteal::PortSubscribe(StealType::All( + destination_port, ))); - assert_eq!(client_msg_rx.recv().await, expected); - daemon_msg_tx - .send(DaemonMessage::Tcp(DaemonTcp::SubscribeResult(Ok( + assert_eq!(test_connection.recv().await, expected); + test_connection + .send(DaemonMessage::TcpSteal(DaemonTcp::SubscribeResult(Ok( destination_port, )))) - .await - .unwrap(); + .await; // send 
new connection from agent and some data - daemon_msg_tx - .send(DaemonMessage::Tcp(DaemonTcp::NewConnection( + test_connection + .send(DaemonMessage::TcpSteal(DaemonTcp::NewConnection( NewTcpConnection { connection_id: 1, remote_address, destination_port, - source_port: local_destination.port(), - local_address: local_destination.ip(), + source_port: 2137, + local_address: "1.2.3.4".parse().unwrap(), }, ))) - .await - .unwrap(); + .await; let mut stream = listener.accept().await.unwrap().0; - daemon_msg_tx + test_connection .send(DaemonMessage::TcpSteal(DaemonTcp::Data(TcpData { connection_id: 1, bytes: b"data-my-beloved".to_vec(), }))) - .await - .unwrap(); + .await; // check data arrives at local let mut buf = [0; 15]; @@ -1466,12 +1404,8 @@ mod test { // check for response from local stream.write_all(b"reply-my-beloved").await.unwrap(); - let message = match client_msg_rx.recv().await.ok_or(0).unwrap() { - ClientMessage::Ping => client_msg_rx.recv().await.ok_or(0).unwrap(), - other => other, - }; assert_eq!( - message, + test_connection.recv().await, ClientMessage::TcpSteal(LayerTcpSteal::Data(TcpData { connection_id: 1, bytes: b"reply-my-beloved".to_vec() @@ -1479,12 +1413,11 @@ mod test { ); // ensure graceful behaviour on close - daemon_msg_tx + test_connection .send(DaemonMessage::Tcp(DaemonTcp::Close(TcpClose { connection_id: 1, }))) - .await - .unwrap(); + .await; } #[rstest] @@ -1497,13 +1430,6 @@ mod test { let local_destination_1 = listener_1.local_addr().unwrap(); let local_destination_2 = listener_2.local_addr().unwrap(); - let (daemon_msg_tx, daemon_msg_rx) = mpsc::channel::(12); - let (client_msg_tx, mut client_msg_rx) = mpsc::channel::(12); - - let agent_connection = AgentConnection { - sender: client_msg_tx, - receiver: daemon_msg_rx, - }; let remote_address = IpAddr::from("152.37.40.40".parse::().unwrap()); let destination_port_1 = 3038; let destination_port_2 = 4048; @@ -1513,6 +1439,7 @@ mod test { ]); let network_config = 
IncomingConfig::default(); + let (mut test_connection, agent_connection) = TestAgentConnection::new(); tokio::spawn(async move { let mut port_forwarder = ReversePortForwarder::new(agent_connection, mappings, network_config) @@ -1521,48 +1448,29 @@ mod test { port_forwarder.run().await.unwrap() }); - // expect handshake procedure - let expected = Some(ClientMessage::SwitchProtocolVersion( - mirrord_protocol::VERSION.clone(), - )); - assert_eq!(client_msg_rx.recv().await, expected); - daemon_msg_tx - .send(DaemonMessage::SwitchProtocolVersionResponse( - mirrord_protocol::VERSION.clone(), - )) - .await - .unwrap(); - let expected = Some(ClientMessage::ReadyForLogs); - assert_eq!(client_msg_rx.recv().await, expected); - // expect port subscription for each remote port and send subscribe result // matches! used because order may be random for _ in 0..2 { - let message = match client_msg_rx.recv().await.ok_or(0).unwrap() { - ClientMessage::Ping => client_msg_rx.recv().await.ok_or(0).unwrap(), - other => other, - }; + let message = test_connection.recv().await; assert!( matches!(message, ClientMessage::Tcp(LayerTcp::PortSubscribe(_))), "expected ClientMessage::Tcp(LayerTcp::PortSubscribe(_), received {message:?}" ); } - daemon_msg_tx + test_connection .send(DaemonMessage::Tcp(DaemonTcp::SubscribeResult(Ok( destination_port_1, )))) - .await - .unwrap(); - daemon_msg_tx + .await; + test_connection .send(DaemonMessage::Tcp(DaemonTcp::SubscribeResult(Ok( destination_port_2, )))) - .await - .unwrap(); + .await; // send new connections from agent and some data - daemon_msg_tx + test_connection .send(DaemonMessage::Tcp(DaemonTcp::NewConnection( NewTcpConnection { connection_id: 1, @@ -1572,11 +1480,10 @@ mod test { local_address: local_destination_1.ip(), }, ))) - .await - .unwrap(); + .await; let mut stream_1 = listener_1.accept().await.unwrap().0; - daemon_msg_tx + test_connection .send(DaemonMessage::Tcp(DaemonTcp::NewConnection( NewTcpConnection { connection_id: 2, @@ 
-1586,25 +1493,22 @@ mod test { local_address: local_destination_2.ip(), }, ))) - .await - .unwrap(); + .await; let mut stream_2 = listener_2.accept().await.unwrap().0; - daemon_msg_tx + test_connection .send(DaemonMessage::Tcp(DaemonTcp::Data(TcpData { connection_id: 1, bytes: b"connection-1-my-beloved".to_vec(), }))) - .await - .unwrap(); + .await; - daemon_msg_tx + test_connection .send(DaemonMessage::Tcp(DaemonTcp::Data(TcpData { connection_id: 2, bytes: b"connection-2-my-beloved".to_vec(), }))) - .await - .unwrap(); + .await; // check data arrives at local let mut buf = [0; 23]; @@ -1616,19 +1520,17 @@ mod test { assert_eq!(buf, b"connection-2-my-beloved".as_ref()); // ensure graceful behaviour on close - daemon_msg_tx + test_connection .send(DaemonMessage::Tcp(DaemonTcp::Close(TcpClose { connection_id: 1, }))) - .await - .unwrap(); + .await; - daemon_msg_tx + test_connection .send(DaemonMessage::Tcp(DaemonTcp::Close(TcpClose { connection_id: 2, }))) - .await - .unwrap(); + .await; } #[rstest] @@ -1640,14 +1542,6 @@ mod test { let listener = TcpListener::bind("127.0.0.1:0").await.unwrap(); let local_destination = listener.local_addr().unwrap(); - let (daemon_msg_tx, daemon_msg_rx) = mpsc::channel::(12); - let (client_msg_tx, mut client_msg_rx) = mpsc::channel::(12); - - let agent_connection = AgentConnection { - sender: client_msg_tx, - receiver: daemon_msg_rx, - }; - let remote_address = IpAddr::from("152.37.40.40".parse::().unwrap()); let destination_port = 8080; let mappings = HashMap::from([(destination_port, local_destination.port())]); let mut network_config = IncomingConfig { @@ -1656,6 +1550,8 @@ mod test { }; network_config.http_filter.header_filter = Some("header: value".to_string()); + let (mut test_connection, agent_connection) = TestAgentConnection::new(); + tokio::spawn(async move { let mut port_forwarder = ReversePortForwarder::new(agent_connection, mappings, network_config) @@ -1664,27 +1560,8 @@ mod test { port_forwarder.run().await.unwrap() 
}); - // expect handshake procedure - let expected = Some(ClientMessage::SwitchProtocolVersion( - mirrord_protocol::VERSION.clone(), - )); - assert_eq!(client_msg_rx.recv().await, expected); - daemon_msg_tx - .send(DaemonMessage::SwitchProtocolVersionResponse( - mirrord_protocol::VERSION.clone(), - )) - .await - .unwrap(); - let expected = Some(ClientMessage::ReadyForLogs); - assert_eq!(client_msg_rx.recv().await, expected); - - // expect port subscription for remote port and send subscribe result - let message = match client_msg_rx.recv().await.ok_or(0).unwrap() { - ClientMessage::Ping => client_msg_rx.recv().await.ok_or(0).unwrap(), - other => other, - }; assert_eq!( - message, + test_connection.recv().await, ClientMessage::TcpSteal(LayerTcpSteal::PortSubscribe(StealType::FilteredHttpEx( destination_port, mirrord_protocol::tcp::HttpFilter::Header( @@ -1692,27 +1569,11 @@ mod test { ) ),)) ); - daemon_msg_tx - .send(DaemonMessage::Tcp(DaemonTcp::SubscribeResult(Ok( + test_connection + .send(DaemonMessage::TcpSteal(DaemonTcp::SubscribeResult(Ok( destination_port, )))) - .await - .unwrap(); - - // send new connection from agent and some data - daemon_msg_tx - .send(DaemonMessage::TcpSteal(DaemonTcp::NewConnection( - NewTcpConnection { - connection_id: 1, - remote_address, - destination_port, - source_port: local_destination.port(), - local_address: local_destination.ip(), - }, - ))) - .await - .unwrap(); - let mut stream = listener.accept().await.unwrap().0; + .await; // send data from agent with correct header let mut headers = HeaderMap::new(); @@ -1724,23 +1585,22 @@ mod test { version: Version::HTTP_11, body: vec![], }; - daemon_msg_tx + test_connection .send(DaemonMessage::TcpSteal(DaemonTcp::HttpRequest( HttpRequest { internal_request, - connection_id: 1, - request_id: 1, - port: local_destination.port(), + connection_id: 0, + request_id: 0, + port: destination_port, }, ))) - .await - .unwrap(); + .await; + let mut stream = listener.accept().await.unwrap().0; 
// check data is read from stream let mut buf = [0; 15]; assert_eq!(buf, [0; 15]); stream.read_exact(&mut buf).await.unwrap(); - assert_ne!(buf, [0; 15]); // check for response from local stream @@ -1751,31 +1611,30 @@ mod test { let mut headers = HeaderMap::new(); headers.insert("content-length", "3".parse().unwrap()); let internal_response = InternalHttpResponse { - status: StatusCode::from_u16(200).unwrap(), + status: StatusCode::OK, version: Version::HTTP_11, headers, - body: b"yay".to_vec(), + body: InternalHttpBody( + [InternalHttpBodyFrame::Data(b"yay".into())] + .into_iter() + .collect(), + ), }; let expected_response = - ClientMessage::TcpSteal(LayerTcpSteal::HttpResponse(HttpResponse { - connection_id: 1, - request_id: 1, - port: local_destination.port(), + ClientMessage::TcpSteal(LayerTcpSteal::HttpResponseFramed(HttpResponse { + connection_id: 0, + request_id: 0, + port: destination_port, internal_response, })); - let message = match client_msg_rx.recv().await.ok_or(0).unwrap() { - ClientMessage::Ping => client_msg_rx.recv().await.ok_or(0).unwrap(), - other => other, - }; - assert_eq!(message, expected_response); + assert_eq!(test_connection.recv().await, expected_response); // ensure graceful behaviour on close - daemon_msg_tx - .send(DaemonMessage::Tcp(DaemonTcp::Close(TcpClose { - connection_id: 1, + test_connection + .send(DaemonMessage::TcpSteal(DaemonTcp::Close(TcpClose { + connection_id: 0, }))) - .await - .unwrap(); + .await; } } diff --git a/mirrord/intproxy/src/proxies/incoming.rs b/mirrord/intproxy/src/proxies/incoming.rs index d0d3fc88859..3992028b971 100644 --- a/mirrord/intproxy/src/proxies/incoming.rs +++ b/mirrord/intproxy/src/proxies/incoming.rs @@ -42,7 +42,7 @@ mod bound_socket; mod http; mod http_gateway; mod metadata_store; -pub mod port_subscription_ext; +mod port_subscription_ext; mod subscriptions; mod tasks; mod tcp_proxy; diff --git a/mirrord/intproxy/src/proxies/incoming/http/response_mode.rs 
b/mirrord/intproxy/src/proxies/incoming/http/response_mode.rs index c0a65a2072d..c6f4eb5a583 100644 --- a/mirrord/intproxy/src/proxies/incoming/http/response_mode.rs +++ b/mirrord/intproxy/src/proxies/incoming/http/response_mode.rs @@ -4,10 +4,15 @@ use mirrord_protocol::tcp::{HTTP_CHUNKED_RESPONSE_VERSION, HTTP_FRAMED_VERSION}; /// responses. #[derive(Debug, Clone, Copy, Default)] pub enum ResponseMode { + /// Agent supports /// [`LayerTcpSteal::HttpResponseChunked`](mirrord_protocol::tcp::LayerTcpSteal::HttpResponseChunked) + /// and the previous variants. Chunked, + /// Agent supports /// [`LayerTcpSteal::HttpResponseFramed`](mirrord_protocol::tcp::LayerTcpSteal::HttpResponseFramed) + /// and the previous variant. Framed, + /// Agent supports only /// [`LayerTcpSteal::HttpResponse`](mirrord_protocol::tcp::LayerTcpSteal::HttpResponse) #[default] Basic, diff --git a/mirrord/intproxy/src/proxies/incoming/http_gateway.rs b/mirrord/intproxy/src/proxies/incoming/http_gateway.rs index 1ee3ec5eaf3..63cca2124f4 100644 --- a/mirrord/intproxy/src/proxies/incoming/http_gateway.rs +++ b/mirrord/intproxy/src/proxies/incoming/http_gateway.rs @@ -1,4 +1,5 @@ use std::{ + collections::VecDeque, convert::Infallible, error::Error, fmt, @@ -134,16 +135,46 @@ impl HttpGatewayTask { message_bus.send(HttpOut::ResponseFramed(response)).await } ResponseMode::Chunked => { - let ready_frames = body + let frames = body .ready_frames() - .map_err(LocalHttpError::ReadBodyFailed)? 
+ .map_err(LocalHttpError::ReadBodyFailed)?; + + if frames.is_last { + let ready_frames = frames + .frames + .into_iter() + .map(InternalHttpBodyFrame::from) + .collect::>(); + + tracing::trace!( + ?ready_frames, + "All response body frames were instantly ready, sending full response" + ); + let response = HttpResponse { + port: self.request.port, + connection_id: self.request.connection_id, + request_id: self.request.request_id, + internal_response: InternalHttpResponse { + status: parts.status, + version: parts.version, + headers: parts.headers, + body: InternalHttpBody(ready_frames), + }, + }; + message_bus.send(HttpOut::ResponseFramed(response)).await; + + return Ok(()); + } + + let ready_frames = frames .frames .into_iter() .map(InternalHttpBodyFrame::from) - .collect(); + .collect::>(); tracing::trace!( ?ready_frames, - "Some response body frames were instantly ready" + "Some response body frames were instantly ready, \ + but response body may not be finished yet" ); let response = HttpResponse { diff --git a/mirrord/intproxy/src/proxies/incoming/port_subscription_ext.rs b/mirrord/intproxy/src/proxies/incoming/port_subscription_ext.rs index e928be69ace..58d92011076 100644 --- a/mirrord/intproxy/src/proxies/incoming/port_subscription_ext.rs +++ b/mirrord/intproxy/src/proxies/incoming/port_subscription_ext.rs @@ -3,7 +3,7 @@ use mirrord_intproxy_protocol::PortSubscription; use mirrord_protocol::{ tcp::{LayerTcp, LayerTcpSteal, StealType}, - ClientMessage, ConnectionId, Port, + ClientMessage, Port, }; /// Retrieves subscribed port from the given [`StealType`]. @@ -26,9 +26,6 @@ pub trait PortSubscriptionExt { /// Returns an unsubscribe request to be sent to the agent. fn wrap_agent_unsubscribe(&self) -> ClientMessage; - - /// Returns an unsubscribe connection request to be sent to the agent. 
- fn wrap_agent_unsubscribe_connection(&self, connection_id: ConnectionId) -> ClientMessage; } impl PortSubscriptionExt for PortSubscription { @@ -58,14 +55,4 @@ impl PortSubscriptionExt for PortSubscription { } } } - - /// [`LayerTcp::ConnectionUnsubscribe`] or [`LayerTcpSteal::ConnectionUnsubscribe`]. - fn wrap_agent_unsubscribe_connection(&self, connection_id: ConnectionId) -> ClientMessage { - match self { - Self::Mirror(..) => ClientMessage::Tcp(LayerTcp::ConnectionUnsubscribe(connection_id)), - Self::Steal(..) => { - ClientMessage::TcpSteal(LayerTcpSteal::ConnectionUnsubscribe(connection_id)) - } - } - } } From 78f47e890ef7c44f32586c8c473f2c70b3168f89 Mon Sep 17 00:00:00 2001 From: Razz4780 Date: Fri, 17 Jan 2025 17:19:22 +0100 Subject: [PATCH 30/60] Upgrade fixed --- .../src/proxies/incoming/http_gateway.rs | 276 ++++++++++-------- .../src/proxies/incoming/tcp_proxy.rs | 4 +- 2 files changed, 157 insertions(+), 123 deletions(-) diff --git a/mirrord/intproxy/src/proxies/incoming/http_gateway.rs b/mirrord/intproxy/src/proxies/incoming/http_gateway.rs index 63cca2124f4..329d03217e5 100644 --- a/mirrord/intproxy/src/proxies/incoming/http_gateway.rs +++ b/mirrord/intproxy/src/proxies/incoming/http_gateway.rs @@ -4,12 +4,13 @@ use std::{ error::Error, fmt, net::SocketAddr, + ops::ControlFlow, time::{Duration, Instant}, }; use exponential_backoff::Backoff; use http_body_util::BodyExt; -use hyper::StatusCode; +use hyper::{body::Incoming, http::response::Parts, StatusCode}; use mirrord_protocol::{ batched_body::BatchedBody, tcp::{ @@ -66,6 +67,137 @@ impl HttpGatewayTask { } } + /// Handles the response if we operate in [`ResponseMode::Chunked`]. 
+ /// + /// # Returns + /// + /// * An error if we failed before sending the [`ChunkedResponse::Start`] message through the + /// [`MessageBus`] (we can still retry the request) + /// * [`ControlFlow::Break`] if we failed after sending the [`ChunkedResponse::Start`] message + /// * [`ControlFlow::Continue`] if we succeeded + async fn handle_response_chunked( + &self, + parts: Parts, + mut body: Incoming, + message_bus: &mut MessageBus, + ) -> Result, LocalHttpError> { + let frames = body + .ready_frames() + .map_err(LocalHttpError::ReadBodyFailed)?; + + if frames.is_last { + let ready_frames = frames + .frames + .into_iter() + .map(InternalHttpBodyFrame::from) + .collect::>(); + + tracing::trace!( + ?ready_frames, + "All response body frames were instantly ready, sending full response" + ); + let response = HttpResponse { + port: self.request.port, + connection_id: self.request.connection_id, + request_id: self.request.request_id, + internal_response: InternalHttpResponse { + status: parts.status, + version: parts.version, + headers: parts.headers, + body: InternalHttpBody(ready_frames), + }, + }; + message_bus.send(HttpOut::ResponseFramed(response)).await; + + return Ok(ControlFlow::Continue(())); + } + + let ready_frames = frames + .frames + .into_iter() + .map(InternalHttpBodyFrame::from) + .collect::>(); + tracing::trace!( + ?ready_frames, + "Some response body frames were instantly ready, \ + but response body may not be finished yet" + ); + + let response = HttpResponse { + port: self.request.port, + connection_id: self.request.connection_id, + request_id: self.request.request_id, + internal_response: InternalHttpResponse { + status: parts.status, + version: parts.version, + headers: parts.headers, + body: ready_frames, + }, + }; + message_bus + .send(HttpOut::ResponseChunked(ChunkedResponse::Start(response))) + .await; + + loop { + let start = Instant::now(); + match body.next_frames().await { + Ok(frames) => { + let is_last = frames.is_last; + let frames = 
frames + .frames + .into_iter() + .map(InternalHttpBodyFrame::from) + .collect::>(); + tracing::trace!( + ?frames, + is_last, + elapsed_ms = start.elapsed().as_millis(), + "Received a next batch of response body frames", + ); + + message_bus + .send(HttpOut::ResponseChunked(ChunkedResponse::Body( + ChunkedHttpBody { + frames, + is_last, + connection_id: self.request.connection_id, + request_id: self.request.request_id, + }, + ))) + .await; + + if is_last { + break; + } + } + + // Do not return any error here, as it would later be transformed into an error + // response. We already send the request head to the agent. + Err(error) => { + tracing::warn!( + error = ?ErrorWithSources(&error), + elapsed_ms = start.elapsed().as_millis(), + gateway = ?self, + "Failed to read next response body frames", + ); + + message_bus + .send(HttpOut::ResponseChunked(ChunkedResponse::Error( + ChunkedHttpError { + connection_id: self.request.connection_id, + request_id: self.request.request_id, + }, + ))) + .await; + + return Ok(ControlFlow::Break(())); + } + } + } + + Ok(ControlFlow::Continue(())) + } + /// Makes an attempt to send the request and read the whole response. 
/// /// [`Err`] is handled in the caller and, if we run out of send attempts, converted to an error @@ -78,11 +210,13 @@ impl HttpGatewayTask { .get(self.server_addr, self.request.version()) .await?; let mut response = client.send_request(self.request.clone()).await?; - let on_upgrade = (response.status() == StatusCode::SWITCHING_PROTOCOLS) - .then(|| hyper::upgrade::on(&mut response)); - let (parts, mut body) = response.into_parts(); + let on_upgrade = (response.status() == StatusCode::SWITCHING_PROTOCOLS).then(|| { + tracing::trace!("Detected an HTTP upgrade"); + hyper::upgrade::on(&mut response) + }); + let (parts, body) = response.into_parts(); - match self.response_mode { + let flow = match self.response_mode { ResponseMode::Basic => { let start = Instant::now(); let body: Vec = body @@ -108,7 +242,9 @@ impl HttpGatewayTask { body, }, }; - message_bus.send(HttpOut::ResponseBasic(response)).await + message_bus.send(HttpOut::ResponseBasic(response)).await; + + ControlFlow::Continue(()) } ResponseMode::Framed => { let start = Instant::now(); @@ -132,131 +268,27 @@ impl HttpGatewayTask { body, }, }; - message_bus.send(HttpOut::ResponseFramed(response)).await + message_bus.send(HttpOut::ResponseFramed(response)).await; + + ControlFlow::Continue(()) } ResponseMode::Chunked => { - let frames = body - .ready_frames() - .map_err(LocalHttpError::ReadBodyFailed)?; - - if frames.is_last { - let ready_frames = frames - .frames - .into_iter() - .map(InternalHttpBodyFrame::from) - .collect::>(); - - tracing::trace!( - ?ready_frames, - "All response body frames were instantly ready, sending full response" - ); - let response = HttpResponse { - port: self.request.port, - connection_id: self.request.connection_id, - request_id: self.request.request_id, - internal_response: InternalHttpResponse { - status: parts.status, - version: parts.version, - headers: parts.headers, - body: InternalHttpBody(ready_frames), - }, - }; - 
message_bus.send(HttpOut::ResponseFramed(response)).await; - - return Ok(()); - } - - let ready_frames = frames - .frames - .into_iter() - .map(InternalHttpBodyFrame::from) - .collect::>(); - tracing::trace!( - ?ready_frames, - "Some response body frames were instantly ready, \ - but response body may not be finished yet" - ); - - let response = HttpResponse { - port: self.request.port, - connection_id: self.request.connection_id, - request_id: self.request.request_id, - internal_response: InternalHttpResponse { - status: parts.status, - version: parts.version, - headers: parts.headers, - body: ready_frames, - }, - }; - message_bus - .send(HttpOut::ResponseChunked(ChunkedResponse::Start(response))) - .await; - - loop { - let start = Instant::now(); - match body.next_frames().await { - Ok(frames) => { - let is_last = frames.is_last; - let frames = frames - .frames - .into_iter() - .map(InternalHttpBodyFrame::from) - .collect::>(); - tracing::trace!( - ?frames, - is_last, - elapsed_ms = start.elapsed().as_millis(), - "Received a next batch of response body frames", - ); - - message_bus - .send(HttpOut::ResponseChunked(ChunkedResponse::Body( - ChunkedHttpBody { - frames, - is_last, - connection_id: self.request.connection_id, - request_id: self.request.request_id, - }, - ))) - .await; - - if is_last { - break; - } - } - // Do not return any error here, - // as it would be transformed into an error response by the caller. - // We already send the request head to the agent. - Err(error) => { - tracing::warn!( - error = ?ErrorWithSources(&error), - elapsed_ms = start.elapsed().as_millis(), - gateway = ?self, - "Failed to read next response body frames", - ); - - message_bus - .send(HttpOut::ResponseChunked(ChunkedResponse::Error( - ChunkedHttpError { - connection_id: self.request.connection_id, - request_id: self.request.request_id, - }, - ))) - .await; - - return Ok(()); - } - } - } + self.handle_response_chunked(parts, body, message_bus) + .await? 
} + }; + + if flow.is_break() { + return Ok(()); } if let Some(on_upgrade) = on_upgrade { message_bus.send(HttpOut::Upgraded(on_upgrade)).await; + } else { + // If there was no upgrade and no error, the client can be reused. + self.client_store.push_idle(client); } - self.client_store.push_idle(client); - Ok(()) } } diff --git a/mirrord/intproxy/src/proxies/incoming/tcp_proxy.rs b/mirrord/intproxy/src/proxies/incoming/tcp_proxy.rs index 0eeb36433f3..c2199b36555 100644 --- a/mirrord/intproxy/src/proxies/incoming/tcp_proxy.rs +++ b/mirrord/intproxy/src/proxies/incoming/tcp_proxy.rs @@ -81,7 +81,9 @@ impl BackgroundTask for TcpProxyTask { let stream = parts.io.into_inner(); let read_buf = parts.read_buf; - if !self.discard_data { + if !self.discard_data && !read_buf.is_empty() { + // We don't send empty data, + // because the agent recognizes it as a shutdown from the user application. message_bus.send(Vec::from(read_buf)).await; } From cfa95ffdc9982aabd6267501e189664f7ac3e5c6 Mon Sep 17 00:00:00 2001 From: Razz4780 Date: Sat, 18 Jan 2025 14:38:44 +0100 Subject: [PATCH 31/60] Extended changelog --- changelog.d/3013.fixed.md | 1 + 1 file changed, 1 insertion(+) diff --git a/changelog.d/3013.fixed.md b/changelog.d/3013.fixed.md index 794d3f2c5f3..811ab816b8c 100644 --- a/changelog.d/3013.fixed.md +++ b/changelog.d/3013.fixed.md @@ -1 +1,2 @@ Fixed an issue where HTTP requests stolen with a filter would hang with a single-threaded local HTTP server. +Improved handling of incoming connections on the local machine (e.g introduces reuse of local HTTP connections). 
From c612152cb65db0e106fc656d55879a5bf75b95f5 Mon Sep 17 00:00:00 2001 From: Razz4780 Date: Sat, 18 Jan 2025 14:46:55 +0100 Subject: [PATCH 32/60] Frames doc --- mirrord/protocol/src/batched_body.rs | 5 +++++ 1 file changed, 5 insertions(+) diff --git a/mirrord/protocol/src/batched_body.rs b/mirrord/protocol/src/batched_body.rs index 479eb1af61f..f6973f3fe30 100644 --- a/mirrord/protocol/src/batched_body.rs +++ b/mirrord/protocol/src/batched_body.rs @@ -80,7 +80,12 @@ where } } +/// A batch of body [`Frame`]s. +/// +/// `D` parameter determines [`Body::Data`] type. pub struct Frames { + /// A batch of consecutive [`Frames`]. pub frames: Vec>, + /// Whether the [`Body`] has finished and this is the last batch. pub is_last: bool, } From 91a18c446ae3a168c9101d3d7451a1f24d40c730 Mon Sep 17 00:00:00 2001 From: Razz4780 Date: Sat, 18 Jan 2025 14:50:32 +0100 Subject: [PATCH 33/60] Helper function for BatchedBody --- mirrord/protocol/src/batched_body.rs | 59 +++++++++++++--------------- 1 file changed, 27 insertions(+), 32 deletions(-) diff --git a/mirrord/protocol/src/batched_body.rs b/mirrord/protocol/src/batched_body.rs index f6973f3fe30..9f5780cf495 100644 --- a/mirrord/protocol/src/batched_body.rs +++ b/mirrord/protocol/src/batched_body.rs @@ -24,23 +24,7 @@ where frames: vec![], is_last: false, }; - - loop { - match self.frame().now_or_never() { - None => { - frames.is_last = false; - break; - } - Some(None) => { - frames.is_last = true; - break; - } - Some(Some(result)) => { - frames.frames.push(result?); - } - } - } - + extend_with_ready(self, &mut frames)?; Ok(frames) } @@ -60,26 +44,37 @@ where } } - loop { - match self.frame().now_or_never() { - None => { - frames.is_last = false; - break; - } - Some(None) => { - frames.is_last = true; - break; - } - Some(Some(result)) => { - frames.frames.push(result?); - } - } - } + extend_with_ready(self, &mut frames)?; Ok(frames) } } +/// Extends the given [`Frames`] instance with [`Frame`]s that are available without 
blocking. +fn extend_with_ready( + body: &mut B, + frames: &mut Frames, +) -> Result<(), B::Error> { + loop { + match body.frame().now_or_never() { + None => { + frames.is_last = false; + break; + } + Some(None) => { + frames.is_last = true; + break; + } + Some(Some(result)) => { + frames.frames.push(result?); + frames.is_last = false; + } + } + } + + Ok(()) +} + /// A batch of body [`Frame`]s. /// /// `D` parameter determines [`Body::Data`] type. From 16a4da1d6bf6ee624e67d5652e53b4a51f138c82 Mon Sep 17 00:00:00 2001 From: Razz4780 Date: Sat, 18 Jan 2025 14:54:23 +0100 Subject: [PATCH 34/60] auto_responder -> unwrap instead of is_err + break --- mirrord/cli/src/port_forward.rs | 14 ++++---------- 1 file changed, 4 insertions(+), 10 deletions(-) diff --git a/mirrord/cli/src/port_forward.rs b/mirrord/cli/src/port_forward.rs index 9196d200663..02e0e7b5771 100644 --- a/mirrord/cli/src/port_forward.rs +++ b/mirrord/cli/src/port_forward.rs @@ -1071,25 +1071,19 @@ mod test { match message { ClientMessage::Ping => { - if tx_to_port_forwarder + tx_to_port_forwarder .send(DaemonMessage::Pong) .await - .is_err() - { - break; - } + .unwrap(); } ClientMessage::ReadyForLogs => {} ClientMessage::SwitchProtocolVersion(version) => { - if tx_to_port_forwarder + tx_to_port_forwarder .send(DaemonMessage::SwitchProtocolVersionResponse( std::cmp::min(&version, &*mirrord_protocol::VERSION).clone(), )) .await - .is_err() - { - break; - } + .unwrap(); } other => tx_to_test_code.send(other).await.unwrap(), } From ec0fd2c95d72565edfdd57d9b03a7099183521b3 Mon Sep 17 00:00:00 2001 From: Razz4780 Date: Sat, 18 Jan 2025 15:01:02 +0100 Subject: [PATCH 35/60] ClientStore unwrap -> expect --- .../src/proxies/incoming/http/client_store.rs | 19 +++++++++++++++---- 1 file changed, 15 insertions(+), 4 deletions(-) diff --git a/mirrord/intproxy/src/proxies/incoming/http/client_store.rs b/mirrord/intproxy/src/proxies/incoming/http/client_store.rs index 6bafb0e619b..19053c5e29e 100644 --- 
a/mirrord/intproxy/src/proxies/incoming/http/client_store.rs +++ b/mirrord/intproxy/src/proxies/incoming/http/client_store.rs @@ -74,7 +74,10 @@ impl ClientStore { version: Version, ) -> Result { let ready = { - let mut guard = self.clients.lock().unwrap(); + let mut guard = self + .clients + .lock() + .expect("ClientStore mutex is poisoned, this is a bug"); let position = guard.iter().position(|idle| { idle.client.handles_version(version) && idle.client.local_server_address() == server_addr @@ -101,7 +104,10 @@ impl ClientStore { /// Stores an unused [`LocalHttpClient`], so that it can be reused later. #[tracing::instrument(level = Level::TRACE, skip(self))] pub fn push_idle(&self, client: LocalHttpClient) { - let mut guard = self.clients.lock().unwrap(); + let mut guard = self + .clients + .lock() + .expect("ClientStore mutex is poisoned, this is a bug"); guard.push(IdleLocalClient { client, last_used: Instant::now(), @@ -113,7 +119,10 @@ impl ClientStore { async fn wait_for_ready(&self, server_addr: SocketAddr, version: Version) -> LocalHttpClient { loop { let notified = { - let mut guard = self.clients.lock().unwrap(); + let mut guard = self + .clients + .lock() + .expect("ClientStore mutex is poisoned, this is a bug"); let position = guard.iter().position(|idle| { idle.client.handles_version(version) && idle.client.local_server_address() == server_addr @@ -144,7 +153,9 @@ async fn cleanup_task(store: ClientStore, idle_client_timeout: Duration) { let now = Instant::now(); let mut min_last_used = None; let notified = { - let mut guard = clients.lock().unwrap(); + let mut guard = clients + .lock() + .expect("ClientStore mutex is poisoned, this is a bug"); let notified = notify.notified(); guard.retain(|client| { if client.last_used + idle_client_timeout > now { From b09ea4a64d31d116fd3ecaaf288871a4a1fa7a54 Mon Sep 17 00:00:00 2001 From: Razz4780 Date: Sat, 18 Jan 2025 15:06:14 +0100 Subject: [PATCH 36/60] Comments for ClientStore cleanup_task --- 
mirrord/intproxy/src/proxies/incoming/http/client_store.rs | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/mirrord/intproxy/src/proxies/incoming/http/client_store.rs b/mirrord/intproxy/src/proxies/incoming/http/client_store.rs index 19053c5e29e..65264b9edc9 100644 --- a/mirrord/intproxy/src/proxies/incoming/http/client_store.rs +++ b/mirrord/intproxy/src/proxies/incoming/http/client_store.rs @@ -147,6 +147,8 @@ async fn cleanup_task(store: ClientStore, idle_client_timeout: Duration) { loop { let Some(clients) = clients.upgrade() else { + // Failed `upgrade` means that all `ClientStore` instances were dropped. + // This task is no longer needed. break; }; @@ -159,12 +161,14 @@ async fn cleanup_task(store: ClientStore, idle_client_timeout: Duration) { let notified = notify.notified(); guard.retain(|client| { if client.last_used + idle_client_timeout > now { + // We determine how long to sleep before cleaning the store again. min_last_used = min_last_used .map(|previous| cmp::min(previous, client.last_used)) .or(Some(client.last_used)); true } else { + // We drop the idle clients that have gone beyond the timeout. tracing::trace!(?client, "Dropping an idle client"); false } From 7e91864c731bd1b395f1a6fef2b425808d57d1c5 Mon Sep 17 00:00:00 2001 From: Razz4780 Date: Sun, 19 Jan 2025 11:39:58 +0100 Subject: [PATCH 37/60] Closed doc --- mirrord/intproxy/src/background_tasks.rs | 46 ++++++++++++++++++++++++ 1 file changed, 46 insertions(+) diff --git a/mirrord/intproxy/src/background_tasks.rs b/mirrord/intproxy/src/background_tasks.rs index 02092c32a43..8c33d893612 100644 --- a/mirrord/intproxy/src/background_tasks.rs +++ b/mirrord/intproxy/src/background_tasks.rs @@ -37,14 +37,60 @@ impl MessageBus { } } + /// Returns a [`Closed`] instance for this [`MessageBus`]. pub fn closed(&self) -> Closed { Closed(self.tx.clone()) } } +/// A helper struct bound to some [`MessageBus`] instance. 
+/// +/// Used in [`BackgroundTask`]s to `.await` on [`Future`]s without lingering after their +/// [`MessageBus`] is closed. +/// +/// It's lifetime does not depend on the origin [`MessageBus`] and it does not hold any references +/// to it, so that you can use it **and** the [`MessageBus`] at the same time. +/// +/// # Usage example +/// +/// ```rust +/// use std::convert::Infallible; +/// +/// use mirrord_intproxy::background_tasks::{BackgroundTask, Closed, MessageBus}; +/// +/// struct ExampleTask; +/// +/// impl ExampleTask { +/// /// Thanks to the usage of [`Closed`] in [`Self::run`], +/// /// this function can freely resolve [`Future`]s and use the [`MessageBus`]. +/// /// When the [`MessageBus`] is closed, the whole task will exit. +/// /// +/// /// To achieve the same without [`Closed`], you'd need to wrap each +/// /// [`Future`] resolution with [`tokio::select`]. +/// async fn do_work(&self, message_bus: &mut MessageBus) {} +/// } +/// +/// impl BackgroundTask for ExampleTask { +/// type MessageIn = Infallible; +/// type MessageOut = Infallible; +/// type Error = Infallible; +/// +/// async fn run(self, message_bus: &mut MessageBus) -> Result<(), Self::Error> { +/// let closed: Closed = message_bus.closed(); +/// closed.cancel_on_close(self.do_work(message_bus)).await; +/// Ok(()) +/// } +/// } +/// ``` pub struct Closed(Sender); impl Closed { + /// Resolves the given [`Future`], unless the origin [`MessageBus`] closes first. + /// + /// # Returns + /// + /// * [`Some`] holding the future output - if the future resolved first + /// * [`None`] - if the [`MessageBus`] closed first pub async fn cancel_on_close(&self, future: F) -> Option { tokio::select! 
{ _ = self.0.closed() => None, From 3ffd3d243deab89505794c0745bec1fa9d7afd01 Mon Sep 17 00:00:00 2001 From: Razz4780 Date: Sun, 19 Jan 2025 11:46:05 +0100 Subject: [PATCH 38/60] TcpStealApi::response_body_tx doc --- mirrord/agent/src/steal/api.rs | 9 +++++++++ 1 file changed, 9 insertions(+) diff --git a/mirrord/agent/src/steal/api.rs b/mirrord/agent/src/steal/api.rs index 489f1ae5380..c359b9ba831 100644 --- a/mirrord/agent/src/steal/api.rs +++ b/mirrord/agent/src/steal/api.rs @@ -39,6 +39,15 @@ pub(crate) struct TcpStealerApi { /// View on the stealer task's status. task_status: TaskStatus, + /// [`Sender`]s that allow us to provide body [`Frame`]s of responses to filtered HTTP + /// requests. + /// + /// With [`LayerTcpSteal::HttpResponseChunked`](mirrord_protocol::tcp::LayerTcpSteal::HttpResponseChunked) + /// Response bodies come from the client in a series of + /// [`ChunkedResponse::Body`](mirrord_protocol::tcp::ChunkedResponse::Body) messages. + /// + /// Thus, we use [`ReceiverStreamBody`] for [`Response`](hyper::Response)'s body type and + /// pipe the [`Frame`]s through an [`mpsc::channel`]. 
response_body_txs: HashMap<(ConnectionId, RequestId), ResponseBodyTx>, } From d6fe56623f4c01804cf2410438bc5eb980fcad72 Mon Sep 17 00:00:00 2001 From: Razz4780 Date: Sun, 19 Jan 2025 11:48:31 +0100 Subject: [PATCH 39/60] Removed expect from client_store::cleanup_task --- .../src/proxies/incoming/http/client_store.rs | 14 +++++++++----- 1 file changed, 9 insertions(+), 5 deletions(-) diff --git a/mirrord/intproxy/src/proxies/incoming/http/client_store.rs b/mirrord/intproxy/src/proxies/incoming/http/client_store.rs index 65264b9edc9..69ee3ce512c 100644 --- a/mirrord/intproxy/src/proxies/incoming/http/client_store.rs +++ b/mirrord/intproxy/src/proxies/incoming/http/client_store.rs @@ -155,10 +155,11 @@ async fn cleanup_task(store: ClientStore, idle_client_timeout: Duration) { let now = Instant::now(); let mut min_last_used = None; let notified = { - let mut guard = clients - .lock() - .expect("ClientStore mutex is poisoned, this is a bug"); - let notified = notify.notified(); + let Ok(mut guard) = clients.lock() else { + tracing::error!("ClientStore mutex is poisoned, this is a bug"); + return; + }; + guard.retain(|client| { if client.last_used + idle_client_timeout > now { // We determine how long to sleep before cleaning the store again. @@ -173,7 +174,10 @@ async fn cleanup_task(store: ClientStore, idle_client_timeout: Duration) { false } }); - notified + + // Acquire [`Notified`] while still holding the lock. + // Prevents missed updates. 
+ notify.notified() }; if let Some(min_last_used) = min_last_used { From e43712459cb2240f4a3f1a93eeea78bff2501cc6 Mon Sep 17 00:00:00 2001 From: Razz4780 Date: Mon, 20 Jan 2025 10:19:59 +0100 Subject: [PATCH 40/60] Doc lint --- mirrord/agent/src/steal/api.rs | 5 ++--- 1 file changed, 2 insertions(+), 3 deletions(-) diff --git a/mirrord/agent/src/steal/api.rs b/mirrord/agent/src/steal/api.rs index c359b9ba831..15d2f265ba7 100644 --- a/mirrord/agent/src/steal/api.rs +++ b/mirrord/agent/src/steal/api.rs @@ -42,9 +42,8 @@ pub(crate) struct TcpStealerApi { /// [`Sender`]s that allow us to provide body [`Frame`]s of responses to filtered HTTP /// requests. /// - /// With [`LayerTcpSteal::HttpResponseChunked`](mirrord_protocol::tcp::LayerTcpSteal::HttpResponseChunked) - /// Response bodies come from the client in a series of - /// [`ChunkedResponse::Body`](mirrord_protocol::tcp::ChunkedResponse::Body) messages. + /// With [`LayerTcpSteal::HttpResponseChunked`], response bodies come from the client + /// in a series of [`ChunkedResponse::Body`] messages. /// /// Thus, we use [`ReceiverStreamBody`] for [`Response`](hyper::Response)'s body type and /// pipe the [`Frame`]s through an [`mpsc::channel`]. 
From 4fd0f6aa611e343349e871c888cb4535656c7fdf Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Micha=C5=82=20Smolarek?= <34063647+Razz4780@users.noreply.github.com> Date: Mon, 20 Jan 2025 15:45:37 +0100 Subject: [PATCH 41/60] Update mirrord/intproxy/src/background_tasks.rs Co-authored-by: meowjesty <43983236+meowjesty@users.noreply.github.com> --- mirrord/intproxy/src/background_tasks.rs | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/mirrord/intproxy/src/background_tasks.rs b/mirrord/intproxy/src/background_tasks.rs index 8c33d893612..d885375e7b0 100644 --- a/mirrord/intproxy/src/background_tasks.rs +++ b/mirrord/intproxy/src/background_tasks.rs @@ -48,7 +48,7 @@ impl MessageBus { /// Used in [`BackgroundTask`]s to `.await` on [`Future`]s without lingering after their /// [`MessageBus`] is closed. /// -/// It's lifetime does not depend on the origin [`MessageBus`] and it does not hold any references +/// Its lifetime does not depend on the origin [`MessageBus`] and it does not hold any references /// to it, so that you can use it **and** the [`MessageBus`] at the same time. /// /// # Usage example From ea247c9a4af6136fd47ba52b322e5d5ad768f692 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Micha=C5=82=20Smolarek?= <34063647+Razz4780@users.noreply.github.com> Date: Mon, 20 Jan 2025 21:43:21 +0100 Subject: [PATCH 42/60] Update mirrord/intproxy/src/proxies/incoming.rs Co-authored-by: meowjesty <43983236+meowjesty@users.noreply.github.com> --- mirrord/intproxy/src/proxies/incoming.rs | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/mirrord/intproxy/src/proxies/incoming.rs b/mirrord/intproxy/src/proxies/incoming.rs index 3992028b971..24450970fbe 100644 --- a/mirrord/intproxy/src/proxies/incoming.rs +++ b/mirrord/intproxy/src/proxies/incoming.rs @@ -103,7 +103,7 @@ struct HttpGatewayHandle { /// Each request is handled independently by a single [`HttpGatewayTask`]. /// Also: /// 1. Local HTTP connections are reused when possible. -/// 2. 
Unless the error is fatal, each requests are retried a couple times. +/// 2. Unless the error is fatal, each request is retried a couple of times. /// 3. We never send [`LayerTcpSteal::ConnectionUnsubscribe`] (due to requests being handled /// independently). If a request fails locally, we send a /// [`StatusCode::BAD_GATEWAY`](hyper::http::StatusCode::BAD_GATEWAY) response. From f6fad541a4cb976913a1bc11355cb74c57cbe57b Mon Sep 17 00:00:00 2001 From: Razz4780 Date: Mon, 20 Jan 2025 21:45:17 +0100 Subject: [PATCH 43/60] error -> unreachable --- mirrord/cli/src/port_forward.rs | 7 +++---- 1 file changed, 3 insertions(+), 4 deletions(-) diff --git a/mirrord/cli/src/port_forward.rs b/mirrord/cli/src/port_forward.rs index 02e0e7b5771..617be65f354 100644 --- a/mirrord/cli/src/port_forward.rs +++ b/mirrord/cli/src/port_forward.rs @@ -583,7 +583,9 @@ impl ReversePortForwarder { TaskUpdate::Finished(result) => match result { Ok(()) => { - return Err(PortForwardError::IncomingProxyExited); + unreachable!( + "IncomingProxy should not finish, task sender is alive in this struct" + ); } Err(TaskError::Error(e)) => { return Err(e.into()); @@ -951,9 +953,6 @@ pub enum PortForwardError { #[error("error from the IncomingProxy task: {0}")] IncomingProxyError(#[from] IncomingProxyError), - #[error("IncomingProxy task unexpectedly exited")] - IncomingProxyExited, - #[error("IncomingProxy task panicked")] IncomingProxyPanicked, From 5ac91bc4aa6a9a35a9e91f463dfb9a9b3a1888a7 Mon Sep 17 00:00:00 2001 From: Razz4780 Date: Mon, 20 Jan 2025 21:46:42 +0100 Subject: [PATCH 44/60] pub(crate) for Closed --- mirrord/intproxy/src/background_tasks.rs | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/mirrord/intproxy/src/background_tasks.rs b/mirrord/intproxy/src/background_tasks.rs index d885375e7b0..c3f2dcb4dca 100644 --- a/mirrord/intproxy/src/background_tasks.rs +++ b/mirrord/intproxy/src/background_tasks.rs @@ -38,7 +38,7 @@ impl MessageBus { } /// Returns a [`Closed`] 
instance for this [`MessageBus`]. - pub fn closed(&self) -> Closed { + pub(crate) fn closed(&self) -> Closed { Closed(self.tx.clone()) } } @@ -82,7 +82,7 @@ impl MessageBus { /// } /// } /// ``` -pub struct Closed(Sender); +pub(crate) struct Closed(Sender); impl Closed { /// Resolves the given [`Future`], unless the origin [`MessageBus`] closes first. @@ -91,7 +91,7 @@ impl Closed { /// /// * [`Some`] holding the future output - if the future resolved first /// * [`None`] - if the [`MessageBus`] closed first - pub async fn cancel_on_close(&self, future: F) -> Option { + pub(crate) async fn cancel_on_close(&self, future: F) -> Option { tokio::select! { _ = self.0.closed() => None, output = future => Some(output) From a02a3392ba2870656338da75daecc55a3adfb032 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Micha=C5=82=20Smolarek?= <34063647+Razz4780@users.noreply.github.com> Date: Mon, 20 Jan 2025 21:48:29 +0100 Subject: [PATCH 45/60] Update mirrord/intproxy/src/proxies/incoming.rs Co-authored-by: meowjesty <43983236+meowjesty@users.noreply.github.com> --- mirrord/intproxy/src/proxies/incoming.rs | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/mirrord/intproxy/src/proxies/incoming.rs b/mirrord/intproxy/src/proxies/incoming.rs index 24450970fbe..7ba64ffed1a 100644 --- a/mirrord/intproxy/src/proxies/incoming.rs +++ b/mirrord/intproxy/src/proxies/incoming.rs @@ -113,7 +113,7 @@ struct HttpGatewayHandle { /// # HTTP upgrades /// /// An HTTP request stolen with a filter can result in an HTTP upgrade. -/// When this happens, the TPC connection is recovered and passed to a new [`TcpProxyTask`]. +/// When this happens, the TCP connection is recovered and passed to a new [`TcpProxyTask`]. /// The TCP connection is then treated as stolen in whole. 
#[derive(Default)] pub struct IncomingProxy { From 4a012328b0125d9915404ae5ba570bd26630d631 Mon Sep 17 00:00:00 2001 From: Razz4780 Date: Mon, 20 Jan 2025 21:50:50 +0100 Subject: [PATCH 46/60] not war --- mirrord/intproxy/src/proxies/incoming/http.rs | 14 +++++++------- 1 file changed, 7 insertions(+), 7 deletions(-) diff --git a/mirrord/intproxy/src/proxies/incoming/http.rs b/mirrord/intproxy/src/proxies/incoming/http.rs index 814fb5ebe0c..8fc0c18fa1e 100644 --- a/mirrord/intproxy/src/proxies/incoming/http.rs +++ b/mirrord/intproxy/src/proxies/incoming/http.rs @@ -1,4 +1,4 @@ -use std::{fmt, io, net::SocketAddr}; +use std::{fmt, io, net::SocketAddr, ops::Not}; use hyper::{ body::Incoming, @@ -119,12 +119,12 @@ impl LocalHttpError { match self { Self::SocketSetupFailed(..) | Self::UnsupportedHttpVersion(..) => false, Self::ConnectTcpFailed(..) => true, - Self::HandshakeFailed(err) | Self::SendFailed(err) | Self::ReadBodyFailed(err) => { - !(err.is_parse() - || err.is_parse_status() - || err.is_parse_too_large() - || err.is_user()) - } + Self::HandshakeFailed(err) | Self::SendFailed(err) | Self::ReadBodyFailed(err) => (err + .is_parse() + || err.is_parse_status() + || err.is_parse_too_large() + || err.is_user()) + .not(), } } } From 93bbeee426160613b15920eab1c2ddc487981807 Mon Sep 17 00:00:00 2001 From: Razz4780 Date: Mon, 20 Jan 2025 21:52:04 +0100 Subject: [PATCH 47/60] More doccc --- mirrord/intproxy/src/proxies/incoming/tcp_proxy.rs | 2 ++ 1 file changed, 2 insertions(+) diff --git a/mirrord/intproxy/src/proxies/incoming/tcp_proxy.rs b/mirrord/intproxy/src/proxies/incoming/tcp_proxy.rs index c2199b36555..1862c07776c 100644 --- a/mirrord/intproxy/src/proxies/incoming/tcp_proxy.rs +++ b/mirrord/intproxy/src/proxies/incoming/tcp_proxy.rs @@ -36,6 +36,8 @@ pub struct TcpProxyTask { /// The local connection between this task and the user application. connection: LocalTcpConnection, /// Whether this task should silently discard data coming from the user application. 
+ /// + /// The data is discarded only when the remote connection is mirrored. discard_data: bool, } From d3d289853f5d9fe0a0b1bdcf2c4e0516b46152df Mon Sep 17 00:00:00 2001 From: Razz4780 Date: Mon, 20 Jan 2025 21:53:24 +0100 Subject: [PATCH 48/60] self_address -> self --- mirrord/intproxy/src/proxies/incoming/http.rs | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/mirrord/intproxy/src/proxies/incoming/http.rs b/mirrord/intproxy/src/proxies/incoming/http.rs index 8fc0c18fa1e..7e099575c9c 100644 --- a/mirrord/intproxy/src/proxies/incoming/http.rs +++ b/mirrord/intproxy/src/proxies/incoming/http.rs @@ -29,7 +29,7 @@ pub struct LocalHttpClient { /// Address of the user application's HTTP server. local_server_address: SocketAddr, /// Address of this client's TCP socket. - self_address: SocketAddr, + address: SocketAddr, } impl LocalHttpClient { @@ -44,7 +44,7 @@ impl LocalHttpClient { let local_server_address = stream .peer_addr() .map_err(LocalHttpError::SocketSetupFailed)?; - let self_address = stream + let address = stream .local_addr() .map_err(LocalHttpError::SocketSetupFailed)?; let sender = HttpSender::handshake(version, stream).await?; @@ -52,7 +52,7 @@ impl LocalHttpClient { Ok(Self { sender, local_server_address, - self_address, + address, }) } @@ -84,7 +84,7 @@ impl fmt::Debug for LocalHttpClient { fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { f.debug_struct("LocalHttpClient") .field("local_server_address", &self.local_server_address) - .field("self_address", &self.self_address) + .field("address", &self.address) .field("is_http_1", &matches!(self.sender, HttpSender::V1(..))) .finish() } From 6342a43384fbb15ed079b03e0aeea45a8534fc7b Mon Sep 17 00:00:00 2001 From: Razz4780 Date: Mon, 20 Jan 2025 22:01:02 +0100 Subject: [PATCH 49/60] rephrased error messages --- mirrord/intproxy/src/proxies/incoming/http.rs | 12 ++++++------ 1 file changed, 6 insertions(+), 6 deletions(-) diff --git 
a/mirrord/intproxy/src/proxies/incoming/http.rs b/mirrord/intproxy/src/proxies/incoming/http.rs index 7e099575c9c..e61817625b2 100644 --- a/mirrord/intproxy/src/proxies/incoming/http.rs +++ b/mirrord/intproxy/src/proxies/incoming/http.rs @@ -93,22 +93,22 @@ impl fmt::Debug for LocalHttpClient { /// Errors that can occur when sending an HTTP request to the user application. #[derive(Error, Debug)] pub enum LocalHttpError { - #[error("handshake failed: {0}")] + #[error("failed to make an HTTP handshake with the local application's HTTP server: {0}")] HandshakeFailed(#[source] hyper::Error), - #[error("{0:?} is not supported")] + #[error("{0:?} is not supported in the local HTTP proxy")] UnsupportedHttpVersion(Version), - #[error("sending the request failed: {0}")] + #[error("failed to send the request to the local application's HTTP server: {0}")] SendFailed(#[source] hyper::Error), - #[error("setting up TCP socket failed: {0}")] + #[error("failed to prepare a local TCP socket: {0}")] SocketSetupFailed(#[source] io::Error), - #[error("making a TPC connection failed: {0}")] + #[error("failed to make a TPC connection with the local application's HTTP server: {0}")] ConnectTcpFailed(#[source] io::Error), - #[error("reading the response body failed: {0}")] + #[error("failed to read the body of the local application's HTTP server response: {0}")] ReadBodyFailed(#[source] hyper::Error), } From b7c1c31142999f4a3f09fe251d77b3dbed4586a6 Mon Sep 17 00:00:00 2001 From: Razz4780 Date: Mon, 20 Jan 2025 22:04:39 +0100 Subject: [PATCH 50/60] instrument on LocalHttpClient::new --- mirrord/intproxy/src/proxies/incoming/http.rs | 1 + 1 file changed, 1 insertion(+) diff --git a/mirrord/intproxy/src/proxies/incoming/http.rs b/mirrord/intproxy/src/proxies/incoming/http.rs index e61817625b2..e0ed5f03a83 100644 --- a/mirrord/intproxy/src/proxies/incoming/http.rs +++ b/mirrord/intproxy/src/proxies/incoming/http.rs @@ -34,6 +34,7 @@ pub struct LocalHttpClient { impl LocalHttpClient { /// 
Makes an HTTP connection with the given server and creates a new client. + #[tracing::instrument(level = Level::TRACE, err(level = Level::WARN), ret)] pub async fn new( local_server_address: SocketAddr, version: Version, From dbe44184a24502ffd743c883c3c62fd959a1b64d Mon Sep 17 00:00:00 2001 From: Razz4780 Date: Mon, 20 Jan 2025 22:11:31 +0100 Subject: [PATCH 51/60] More docs on clone for StreamingBody --- .../src/proxies/incoming/http/streaming_body.rs | 16 +++++++++++----- 1 file changed, 11 insertions(+), 5 deletions(-) diff --git a/mirrord/intproxy/src/proxies/incoming/http/streaming_body.rs b/mirrord/intproxy/src/proxies/incoming/http/streaming_body.rs index f23adaed761..ecc0544d27d 100644 --- a/mirrord/intproxy/src/proxies/incoming/http/streaming_body.rs +++ b/mirrord/intproxy/src/proxies/incoming/http/streaming_body.rs @@ -11,12 +11,18 @@ use hyper::body::{Body, Frame}; use mirrord_protocol::tcp::{InternalHttpBody, InternalHttpBodyFrame}; use tokio::sync::mpsc::{self, Receiver}; -/// [`Body`] implementation that reads [`Frame`]s from an [`mpsc::channel`] and caches them -/// internally in a shared vector. +/// Cheaply cloneable [`Body`] implementation that reads [`Frame`]s from an [`mpsc::channel`]. /// -/// This struct maintains its position in the shared vector. -/// When cloned, it resets the index. This allows for replaying the body even though it is streamed -/// from a channel. +/// # Clone behavior +/// +/// All instances acquired via [`Clone`] share the [`mpsc::Receiver`] and a vector of previously +/// read frames. Each instance maintains its own position in the shared vector, and a new clone +/// starts at 0. +/// +/// When polled with [`Body::poll_frame`], an instance tries to return a cached frame. +/// +/// Thanks to this, each clone returns all frames from the start when polled with +/// [`Body::poll_frame`]. pub struct StreamingBody { /// Shared with instances acquired via [`Clone`]. /// Allows the clones to receive a copy of the data. 
From 123f310ac2e029d256fbd8a27ea0286b8c1be21a Mon Sep 17 00:00:00 2001 From: Razz4780 Date: Mon, 20 Jan 2025 22:13:20 +0100 Subject: [PATCH 52/60] docc --- .../src/proxies/incoming/http/streaming_body.rs | 10 ++++++---- 1 file changed, 6 insertions(+), 4 deletions(-) diff --git a/mirrord/intproxy/src/proxies/incoming/http/streaming_body.rs b/mirrord/intproxy/src/proxies/incoming/http/streaming_body.rs index ecc0544d27d..06a614f1ba4 100644 --- a/mirrord/intproxy/src/proxies/incoming/http/streaming_body.rs +++ b/mirrord/intproxy/src/proxies/incoming/http/streaming_body.rs @@ -22,14 +22,16 @@ use tokio::sync::mpsc::{self, Receiver}; /// When polled with [`Body::poll_frame`], an instance tries to return a cached frame. /// /// Thanks to this, each clone returns all frames from the start when polled with -/// [`Body::poll_frame`]. +/// [`Body::poll_frame`]. As you'd expect from a cloneable [`Body`] implementation. pub struct StreamingBody { /// Shared with instances acquired via [`Clone`]. - /// Allows the clones to receive a copy of the data. + /// + /// Allows the clones to access previously fetched [`Frame`]s. shared_state: Arc, Vec)>>, - /// Index of the next frame to return from the buffer. + /// Index of the next frame to return from the buffer, not shared with other instances acquired + /// via [`Clone`]. + /// /// If outside of the buffer, we need to poll the stream to get the next frame. - /// Local state of this instance, zeroed when cloning. 
idx: usize, } From ff9d3171958638882b9cbd960ea5daf2e6514273 Mon Sep 17 00:00:00 2001 From: Razz4780 Date: Mon, 20 Jan 2025 22:25:03 +0100 Subject: [PATCH 53/60] docsss --- .../intproxy/src/proxies/incoming/tasks.rs | 39 ++++++++++++++++++- 1 file changed, 37 insertions(+), 2 deletions(-) diff --git a/mirrord/intproxy/src/proxies/incoming/tasks.rs b/mirrord/intproxy/src/proxies/incoming/tasks.rs index d913715adba..49a315636c5 100644 --- a/mirrord/intproxy/src/proxies/incoming/tasks.rs +++ b/mirrord/intproxy/src/proxies/incoming/tasks.rs @@ -7,9 +7,19 @@ use mirrord_protocol::{ }; use thiserror::Error; +/// Messages produced by the [`BackgroundTask`](crate::background_tasks::BackgroundTask)s used in +/// the [`IncomingProxy`](super::IncomingProxy). pub enum InProxyTaskMessage { - Tcp(Vec), - Http(HttpOut), + /// Produced by the [`TcpProxyTask`](super::tcp_proxy::TcpProxyTask) in steal mode. + Tcp( + /// Data received from the local application. + Vec, + ), + /// Produced by the [`HttpGatewayTask`](super::http_gateway::HttpGatewayTask). + Http( + /// HTTP specific message. + HttpOut, + ), } impl fmt::Debug for InProxyTaskMessage { @@ -24,11 +34,16 @@ impl fmt::Debug for InProxyTaskMessage { } } +/// Messages produced by the [`HttpGatewayTask`](super::http_gateway::HttpGatewayTask). #[derive(Debug)] pub enum HttpOut { + /// Response from the local application's HTTP server. ResponseBasic(HttpResponse>), + /// Response from the local application's HTTP server. ResponseFramed(HttpResponse), + /// Response from the local application's HTTP server. ResponseChunked(ChunkedResponse), + /// Upgraded HTTP connection, to be handled as a remote connection stolen without any filter. Upgraded(OnUpgrade), } @@ -44,6 +59,17 @@ impl From for InProxyTaskMessage { } } +/// Errors that can occur in the [`BackgroundTask`](crate::background_tasks::BackgroundTask)s used +/// in the [`IncomingProxy`](super::IncomingProxy).
+/// +/// All of these can occur only in the [`TcpProxyTask`](super::tcp_proxy::TcpProxyTask) +/// and mean that the local connection is irreversibly broken. +/// The [`HttpGatewayTask`](super::http_gateway::HttpGatewayTask) produces no errors +/// and instead responds with an error HTTP response to the agent. +/// +/// However, due to [`BackgroundTasks`](crate::background_tasks::BackgroundTasks) +/// type constraints, we need a common error type. +/// Thus, this type implements [`From`]. #[derive(Error, Debug)] pub enum InProxyTaskError { #[error("io failed: {0}")] @@ -58,10 +84,15 @@ impl From for InProxyTaskError { } } +/// Types of [`BackgroundTask`](crate::background_tasks::BackgroundTask)s used in the +/// [`IncomingProxy`](super::IncomingProxy). #[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)] pub enum InProxyTask { + /// [`TcpProxyTask`](super::tcp_proxy::TcpProxyTask) handling a mirrored connection. MirrorTcpProxy(ConnectionId), + /// [`TcpProxyTask`](super::tcp_proxy::TcpProxyTask) handling a stolen connection. StealTcpProxy(ConnectionId), + /// [`HttpGatewayTask`](super::http_gateway::HttpGatewayTask) handling a stolen HTTP request. HttpGateway(HttpGatewayId), } @@ -71,8 +102,12 @@ pub enum InProxyTask { /// error response in case the task somehow panics. #[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)] pub struct HttpGatewayId { + /// Id of the remote connection. pub connection_id: ConnectionId, + /// Id of the stolen request. pub request_id: RequestId, + /// Remote port from which the request was stolen. pub port: Port, + /// HTTP version of the stolen request. 
pub version: Version, } From f2bc71c5b6e267972d389e14a93a6a44d8db0917 Mon Sep 17 00:00:00 2001 From: Razz4780 Date: Mon, 20 Jan 2025 22:26:04 +0100 Subject: [PATCH 54/60] Doc fixed --- mirrord/intproxy/src/proxies/incoming.rs | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/mirrord/intproxy/src/proxies/incoming.rs b/mirrord/intproxy/src/proxies/incoming.rs index 7ba64ffed1a..9d6066102d1 100644 --- a/mirrord/intproxy/src/proxies/incoming.rs +++ b/mirrord/intproxy/src/proxies/incoming.rs @@ -125,11 +125,11 @@ pub struct IncomingProxy { response_mode: ResponseMode, /// Cache for [`LocalHttpClient`](http::LocalHttpClient)s. client_store: ClientStore, - /// Each mirrored remote connection is mapped to a [TcpProxyTask] in mirror mode. + /// Each mirrored remote connection is mapped to a [`TcpProxyTask`] in mirror mode. /// /// Each entry here maps to a connection that is in progress both locally and remotely. mirror_tcp_proxies: HashMap>, - /// Each remote connection stolen in whole is mapped to a [TcpProxyTask] in steal mode. + /// Each remote connection stolen in whole is mapped to a [`TcpProxyTask`] in steal mode. /// /// Each entry here maps to a connection that is in progress both locally and remotely. steal_tcp_proxies: HashMap>, From 93b7baa5a0554f7fdb2f7d32aefc95cded2c4bb7 Mon Sep 17 00:00:00 2001 From: Razz4780 Date: Mon, 20 Jan 2025 22:26:32 +0100 Subject: [PATCH 55/60] Doc fixed --- mirrord/intproxy/src/proxies/incoming.rs | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/mirrord/intproxy/src/proxies/incoming.rs b/mirrord/intproxy/src/proxies/incoming.rs index 9d6066102d1..ff6766c60bc 100644 --- a/mirrord/intproxy/src/proxies/incoming.rs +++ b/mirrord/intproxy/src/proxies/incoming.rs @@ -133,7 +133,7 @@ pub struct IncomingProxy { /// /// Each entry here maps to a connection that is in progress both locally and remotely. 
steal_tcp_proxies: HashMap>, - /// Each remote HTTP request stolen with a filter is mapped to a [HttpGatewayTask]. + /// Each remote HTTP request stolen with a filter is mapped to a [`HttpGatewayTask`]. /// /// Each entry here maps to a request that is in progress both locally and remotely. http_gateways: HashMap>, From 9b22880017e77a8888e175b7b945b4d520da54d5 Mon Sep 17 00:00:00 2001 From: Razz4780 Date: Mon, 20 Jan 2025 22:27:49 +0100 Subject: [PATCH 56/60] more instrument --- mirrord/intproxy/src/proxies/incoming.rs | 2 ++ 1 file changed, 2 insertions(+) diff --git a/mirrord/intproxy/src/proxies/incoming.rs b/mirrord/intproxy/src/proxies/incoming.rs index ff6766c60bc..05399925eae 100644 --- a/mirrord/intproxy/src/proxies/incoming.rs +++ b/mirrord/intproxy/src/proxies/incoming.rs @@ -149,6 +149,7 @@ impl IncomingProxy { /// /// If we don't have a [`PortSubscription`] for the port, the task is not started. /// Instead, we respond immediately to the agent. + #[tracing::instrument(level = Level::TRACE, skip(self, message_bus))] async fn start_http_gateway( &mut self, request: HttpRequest, @@ -227,6 +228,7 @@ impl IncomingProxy { /// /// If we don't have a [`PortSubscription`] for the port, the task is not started. /// Instead, we respond immediately to the agent. 
+ #[tracing::instrument(level = Level::TRACE, skip(self, message_bus))] async fn handle_new_connection( &mut self, connection: NewTcpConnection, From 271df6331c550d66ac8ccfae8e2891dabd62217a Mon Sep 17 00:00:00 2001 From: Razz4780 Date: Tue, 21 Jan 2025 09:24:01 +0100 Subject: [PATCH 57/60] in whole -> without a filter --- mirrord/intproxy/src/proxies/incoming.rs | 7 ++++--- 1 file changed, 4 insertions(+), 3 deletions(-) diff --git a/mirrord/intproxy/src/proxies/incoming.rs b/mirrord/intproxy/src/proxies/incoming.rs index 05399925eae..e348df29577 100644 --- a/mirrord/intproxy/src/proxies/incoming.rs +++ b/mirrord/intproxy/src/proxies/incoming.rs @@ -85,7 +85,7 @@ struct HttpGatewayHandle { /// Utilizes multiple background tasks ([`TcpProxyTask`]s and [`HttpGatewayTask`]s) to handle /// incoming connections and requests. /// -/// # Connections stolen/mirrored in whole +/// # Connections mirrored or stolen without a filter /// /// Each such connection exists in two places: /// @@ -114,7 +114,7 @@ struct HttpGatewayHandle { /// /// An HTTP request stolen with a filter can result in an HTTP upgrade. /// When this happens, the TCP connection is recovered and passed to a new [`TcpProxyTask`]. -/// The TCP connection is then treated as stolen in whole. +/// The TCP connection is then treated as stolen without a filter. #[derive(Default)] pub struct IncomingProxy { /// Active port subscriptions for all layers. @@ -129,7 +129,8 @@ pub struct IncomingProxy { /// /// Each entry here maps to a connection that is in progress both locally and remotely. mirror_tcp_proxies: HashMap>, - /// Each remote connection stolen in whole is mapped to a [`TcpProxyTask`] in steal mode. + /// Each remote connection stolen without a filter is mapped to a [`TcpProxyTask`] in steal + /// mode. /// /// Each entry here maps to a connection that is in progress both locally and remotely. 
steal_tcp_proxies: HashMap>, From f961054d26a84f0cf7a400ad1b2f86a9eb6f8797 Mon Sep 17 00:00:00 2001 From: Razz4780 Date: Tue, 21 Jan 2025 09:51:52 +0100 Subject: [PATCH 58/60] moar doccc --- mirrord/intproxy/src/proxies/incoming.rs | 11 +++++++++++ .../intproxy/src/proxies/incoming/http_gateway.rs | 2 ++ mirrord/intproxy/src/proxies/incoming/tcp_proxy.rs | 13 ++++++++++++- 3 files changed, 25 insertions(+), 1 deletion(-) diff --git a/mirrord/intproxy/src/proxies/incoming.rs b/mirrord/intproxy/src/proxies/incoming.rs index e348df29577..4776434c9a5 100644 --- a/mirrord/intproxy/src/proxies/incoming.rs +++ b/mirrord/intproxy/src/proxies/incoming.rs @@ -94,6 +94,11 @@ struct HttpGatewayHandle { /// /// We are notified about such connections with the [`NewTcpConnection`] message. /// +/// The local connection lives until the agent or the user application closes it, or a local IO +/// error occurs. When we want to close this connection, we simply drop the [`TcpProxyTask`]'s +/// [`TaskSender`]. When a local IO error occurs, the [`TcpProxyTask`] finishes with an +/// [`InProxyTaskError`]. +/// /// # Requests stolen with a filter /// /// In the cluster, we have a real persistent connection between the agent and the original HTTP @@ -110,6 +115,12 @@ struct HttpGatewayHandle { /// /// We are notified about stolen requests with the [`HttpRequest`] messages. /// +/// The request can be cancelled only when one of the following happens: +/// 1. The agent closes the remote connection to which this request belongs +/// 2. The agent informs us that it failed to read request body ([`ChunkedRequest::Error`]) +/// +/// When we want to cancel the request, we drop the [`HttpGatewayTask`]'s [`TaskSender`]. +/// /// # HTTP upgrades /// /// An HTTP request stolen with a filter can result in an HTTP upgrade.
diff --git a/mirrord/intproxy/src/proxies/incoming/http_gateway.rs b/mirrord/intproxy/src/proxies/incoming/http_gateway.rs index 329d03217e5..e7f8a7819b0 100644 --- a/mirrord/intproxy/src/proxies/incoming/http_gateway.rs +++ b/mirrord/intproxy/src/proxies/incoming/http_gateway.rs @@ -30,6 +30,8 @@ use crate::background_tasks::{BackgroundTask, MessageBus}; /// [`BackgroundTask`] used by the [`IncomingProxy`](super::IncomingProxy). /// /// Responsible for delivering a single HTTP request to the user application. +/// +/// Exits immediately when its [`TaskSender`](crate::background_tasks::TaskSender) is dropped. pub struct HttpGatewayTask { /// Request to deliver. request: HttpRequest, diff --git a/mirrord/intproxy/src/proxies/incoming/tcp_proxy.rs b/mirrord/intproxy/src/proxies/incoming/tcp_proxy.rs index 1862c07776c..6929f65b2f0 100644 --- a/mirrord/intproxy/src/proxies/incoming/tcp_proxy.rs +++ b/mirrord/intproxy/src/proxies/incoming/tcp_proxy.rs @@ -31,6 +31,13 @@ pub enum LocalTcpConnection { /// [`BackgroundTask`] of [`IncomingProxy`](super::IncomingProxy) that handles a remote /// stolen/mirrored TCP connection. +/// +/// In steal mode, exits immediately when its [`TaskSender`](crate::background_tasks::TaskSender) +/// is dropped. +/// +/// In mirror mode, when its [`TaskSender`](crate::background_tasks::TaskSender) is dropped, +/// this proxy keeps reading data from the user application and exits after +/// [`Self::MIRROR_MODE_LINGER_TIMEOUT`] of silence. #[derive(Debug)] pub struct TcpProxyTask { /// The local connection between this task and the user application. @@ -42,6 +49,10 @@ pub struct TcpProxyTask { } impl TcpProxyTask { + /// Mirror mode only: how long do we wait before exiting after the [`MessageBus`] is closed + /// and user application doesn't send any data. + pub const MIRROR_MODE_LINGER_TIMEOUT: Duration = Duration::from_secs(1); + /// Creates a new task.
/// /// * This task will talk with the user application using the given [`LocalTcpConnection`]. @@ -172,7 +183,7 @@ impl BackgroundTask for TcpProxyTask { }, }, - _ = time::sleep(Duration::from_secs(1)), if is_lingering => { + _ = time::sleep(Self::MIRROR_MODE_LINGER_TIMEOUT), if is_lingering => { tracing::trace!( peer_addr = %peer_addr, self_addr = %self_addr, From b85878c5fe821e47d5cd05e8d4a61deb217ad0ad Mon Sep 17 00:00:00 2001 From: Razz4780 Date: Tue, 21 Jan 2025 15:51:42 +0100 Subject: [PATCH 59/60] TPC -> TCP --- mirrord/intproxy/src/proxies/incoming.rs | 2 +- mirrord/intproxy/src/proxies/incoming/http.rs | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/mirrord/intproxy/src/proxies/incoming.rs b/mirrord/intproxy/src/proxies/incoming.rs index 4776434c9a5..6ffa446dd4c 100644 --- a/mirrord/intproxy/src/proxies/incoming.rs +++ b/mirrord/intproxy/src/proxies/incoming.rs @@ -50,7 +50,7 @@ mod tcp_proxy; /// Errors that can occur when handling the `incoming` feature. 
#[derive(Error, Debug)] pub enum IncomingProxyError { - #[error("failed to prepare a TPC socket: {0}")] + #[error("failed to prepare a TCP socket: {0}")] SocketSetupFailed(#[source] io::Error), #[error("subscribing port failed: {0}")] SubscriptionFailed(#[source] ResponseError), diff --git a/mirrord/intproxy/src/proxies/incoming/http.rs b/mirrord/intproxy/src/proxies/incoming/http.rs index e0ed5f03a83..a871cebc2c5 100644 --- a/mirrord/intproxy/src/proxies/incoming/http.rs +++ b/mirrord/intproxy/src/proxies/incoming/http.rs @@ -106,7 +106,7 @@ pub enum LocalHttpError { #[error("failed to prepare a local TCP socket: {0}")] SocketSetupFailed(#[source] io::Error), - #[error("failed to make a TPC connection with the local application's HTTP server: {0}")] + #[error("failed to make a TCP connection with the local application's HTTP server: {0}")] ConnectTcpFailed(#[source] io::Error), #[error("failed to read the body of the local application's HTTP server response: {0}")] From 216e2c7301f7deece94a1173b58d1b04464620a1 Mon Sep 17 00:00:00 2001 From: Razz4780 Date: Tue, 21 Jan 2025 15:53:58 +0100 Subject: [PATCH 60/60] added ignore to doctest --- mirrord/intproxy/src/background_tasks.rs | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/mirrord/intproxy/src/background_tasks.rs b/mirrord/intproxy/src/background_tasks.rs index c3f2dcb4dca..82e6865c67e 100644 --- a/mirrord/intproxy/src/background_tasks.rs +++ b/mirrord/intproxy/src/background_tasks.rs @@ -53,7 +53,7 @@ impl MessageBus { /// /// # Usage example /// -/// ```rust +/// ```ignore /// use std::convert::Infallible; /// /// use mirrord_intproxy::background_tasks::{BackgroundTask, Closed, MessageBus};