diff --git a/Cargo.lock b/Cargo.lock index cab089525e1..911b74ec041 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -4002,6 +4002,7 @@ name = "mirrord-intproxy" version = "3.106.0" dependencies = [ "bytes", + "futures", "http-body-util", "hyper 1.3.1", "hyper-util", @@ -4180,7 +4181,7 @@ dependencies = [ [[package]] name = "mirrord-protocol" -version = "1.6.1" +version = "1.7.0" dependencies = [ "actix-codec", "bincode", @@ -4197,6 +4198,8 @@ dependencies = [ "serde", "socket2", "thiserror", + "tokio", + "tokio-stream", "tracing", ] diff --git a/changelog.d/2478.added.md b/changelog.d/2478.added.md new file mode 100644 index 00000000000..a6af4bfcff9 --- /dev/null +++ b/changelog.d/2478.added.md @@ -0,0 +1 @@ +Added support for intercepting streaming HTTP requests with an HTTP filter. \ No newline at end of file diff --git a/mirrord/agent/src/steal/connection.rs b/mirrord/agent/src/steal/connection.rs index 1d7aac28a0b..4eb754be4c0 100644 --- a/mirrord/agent/src/steal/connection.rs +++ b/mirrord/agent/src/steal/connection.rs @@ -4,16 +4,18 @@ use std::{ }; use fancy_regex::Regex; +use http::Request; use http_body_util::BodyExt; use hyper::{ body::Incoming, http::{header::UPGRADE, request::Parts}, - Request, }; use mirrord_protocol::{ tcp::{ - DaemonTcp, HttpRequest, HttpResponseFallback, InternalHttpBody, InternalHttpRequest, - StealType, TcpClose, TcpData, HTTP_FILTERED_UPGRADE_VERSION, HTTP_FRAMED_VERSION, + ChunkedRequest, ChunkedRequestBody, ChunkedRequestError, DaemonTcp, HttpRequest, + HttpResponseFallback, InternalHttpBody, InternalHttpBodyFrame, InternalHttpRequest, + StealType, TcpClose, TcpData, HTTP_CHUNKED_VERSION, HTTP_FILTERED_UPGRADE_VERSION, + HTTP_FRAMED_VERSION, }, ConnectionId, Port, RemoteError::{BadHttpFilterExRegex, BadHttpFilterRegex}, @@ -31,7 +33,7 @@ use crate::{ connections::{ ConnectionMessageIn, ConnectionMessageOut, StolenConnection, StolenConnections, }, - http::HttpFilter, + http::{Frames, HttpFilter, IncomingExt}, orig_dst, subscriptions::{IpTablesRedirector, PortSubscriptions}, Command, StealerCommand, @@ -124,8 +126,8 @@ struct Client { impl Client { /// Attempts to spawn a new [`tokio::task`] to transform the given [`MatchedHttpRequest`] into - /// [`DaemonTcp::HttpRequest`] or [`DaemonTcp::HttpRequestFramed`] and send it via cloned - /// [`Client::tx`]. + /// [`DaemonTcp::HttpRequest`], [`DaemonTcp::HttpRequestFramed`] or + /// [`DaemonTcp::HttpRequestChunked`] and send it via cloned [`Client::tx`]. /// /// Inspects [`Client::protocol_version`] to pick between [`DaemonTcp`] variants and check for /// upgrade requests. @@ -147,10 +149,86 @@ impl Client { } let framed = HTTP_FRAMED_VERSION.matches(&self.protocol_version); + let chunked = HTTP_CHUNKED_VERSION.matches(&self.protocol_version); let tx = self.tx.clone(); tokio::spawn(async move { - if framed { + // Chunked data is preferred over framed data + if chunked { + // Send headers + let connection_id = request.connection_id; + let request_id = request.request_id; + let ( + Parts { + method, + uri, + version, + headers, + .. + }, + mut body, + ) = request.request.into_parts(); + match body.next_frames(true).await { + Err(..) => return, + Ok(Frames { frames, is_last }) => { + let frames = frames + .into_iter() + .map(InternalHttpBodyFrame::try_from) + .filter_map(Result::ok) + .collect(); + let message = + DaemonTcp::HttpRequestChunked(ChunkedRequest::Start(HttpRequest { + internal_request: InternalHttpRequest { + method, + uri, + headers, + version, + body: frames, + }, + connection_id, + request_id, + port: request.port, + })); + if tx.send(message).await.is_err() || is_last { + return; + } + } + } + + loop { + match body.next_frames(false).await { + Ok(Frames { frames, is_last }) => { + let frames = frames + .into_iter() + .map(InternalHttpBodyFrame::try_from) + .filter_map(Result::ok) + .collect(); + let message = DaemonTcp::HttpRequestChunked(ChunkedRequest::Body( + ChunkedRequestBody { + frames, + is_last, + connection_id, + request_id, + }, + )); + if tx.send(message).await.is_err() || is_last { + return; + } + } + Err(_) => { + let _ = tx + .send(DaemonTcp::HttpRequestChunked(ChunkedRequest::Error( + ChunkedRequestError { + connection_id, + request_id, + }, + ))) + .await; + return; + } + } + } + } else if framed { let Ok(request) = request.into_serializable().await else { return; }; @@ -581,3 +659,151 @@ impl TcpConnectionStealer { Ok(()) } } + +#[cfg(test)] +mod test { + use std::net::SocketAddr; + + use bytes::Bytes; + use futures::{future::BoxFuture, FutureExt}; + use http::{Method, Request, Response, Version}; + use http_body_util::{Empty, StreamBody}; + use hyper::{ + body::{Frame, Incoming}, + service::Service, + }; + use hyper_util::rt::TokioIo; + use mirrord_protocol::tcp::{ChunkedRequest, DaemonTcp, InternalHttpBodyFrame}; + use tokio::{ + net::{TcpListener, TcpStream}, + sync::{ + mpsc::{self, Receiver, Sender}, + oneshot, + }, + }; + use tokio_stream::wrappers::ReceiverStream; + + use crate::steal::connection::{Client, MatchedHttpRequest}; + async fn prepare_dummy_service() -> ( + SocketAddr, + Receiver<(Request, oneshot::Sender>>)>, + ) { + type ReqSender = Sender<(Request, oneshot::Sender>>)>; + struct DummyService { + tx: ReqSender, + } + + impl Service> for DummyService { + type Response = Response>; + + type Error = hyper::Error; + + type Future = BoxFuture<'static, Result>; + + fn call(&self, req: Request) -> Self::Future { + let tx = self.tx.clone(); + async move { + let (res_tx, res_rx) = oneshot::channel(); + tx.send((req, res_tx)).await.unwrap(); + Ok(res_rx.await.unwrap()) + } + .boxed() + } + } + + let listener = TcpListener::bind("127.0.0.1:0").await.unwrap(); + let server_address = listener.local_addr().unwrap(); + let (tx, rx) = mpsc::channel(4); + + tokio::spawn(async move { + loop { + let (conn, _) = listener.accept().await.unwrap(); + let tx = tx.clone(); + tokio::spawn( + hyper::server::conn::http1::Builder::new() + .serve_connection(TokioIo::new(conn), DummyService { tx }), + ); + } + }); + + (server_address, rx) + } + + #[tokio::test] + async fn test_streaming_response() { + let (addr, mut request_rx) = prepare_dummy_service().await; + let conn = TcpStream::connect(addr).await.unwrap(); + let (mut sender, conn) = hyper::client::conn::http1::handshake(TokioIo::new(conn)) + .await + .unwrap(); + tokio::spawn(conn); + + let (body_tx, body_rx) = mpsc::channel::>>(12); + let body = StreamBody::new(ReceiverStream::new(body_rx)); + + // Send a frame to be ready in ChunkedRequest::Start before hyper sender is used + body_tx + .send(Ok(Frame::data(b"string".to_vec().into()))) + .await + .unwrap(); + + tokio::spawn( + sender.send_request( + Request::builder() + .method(Method::POST) + .uri("/") + .version(Version::HTTP_11) + .body(body) + .unwrap(), + ), + ); + + let (client_tx, mut client_rx) = mpsc::channel::(4); + let client = Client { + tx: client_tx, + protocol_version: "1.7.0".parse().unwrap(), + subscribed_connections: Default::default(), + }; + + let (request, response_tx) = request_rx.recv().await.unwrap(); + client.send_request_async(MatchedHttpRequest { + connection_id: 0, + port: 80, + request_id: 0, + request, + }); + + // Verify that single-framed ChunkedRequest::Start requests are as expected, containing any + // ready frames that were sent before Request was first sent + let msg = client_rx.recv().await.unwrap(); + let DaemonTcp::HttpRequestChunked(ChunkedRequest::Start(x)) = msg else { + panic!("unexpected type received: {msg:?}") + }; + assert_eq!( + x.internal_request.body, + vec![InternalHttpBodyFrame::Data(b"string".to_vec().into())] + ); + let x = client_rx.recv().now_or_never(); + assert!(x.is_none()); + + // Verify that single-framed ChunkedRequest::Body requests are as expected + body_tx + .send(Ok(Frame::data(b"another_string".to_vec().into()))) + .await + .unwrap(); + let msg = client_rx.recv().await.unwrap(); + let DaemonTcp::HttpRequestChunked(ChunkedRequest::Body(x)) = msg else { + panic!("unexpected type received: {msg:?}") + }; + assert_eq!( + x.frames, + vec![InternalHttpBodyFrame::Data( + b"another_string".to_vec().into() + )] + ); + let x = client_rx.recv().now_or_never(); + assert!(x.is_none()); + + let _ = response_tx.send(Response::new(Empty::default())); + } +} diff --git a/mirrord/agent/src/steal/connections/filtered.rs b/mirrord/agent/src/steal/connections/filtered.rs index b3ec9c29656..900c4ed5877 100644 --- a/mirrord/agent/src/steal/connections/filtered.rs +++ b/mirrord/agent/src/steal/connections/filtered.rs @@ -802,6 +802,7 @@ where #[cfg(test)] mod test { + use bytes::BytesMut; use http::{ header::{CONNECTION, UPGRADE}, diff --git a/mirrord/agent/src/steal/http.rs b/mirrord/agent/src/steal/http.rs index 159d9c9aac8..53da34db228 100644 --- a/mirrord/agent/src/steal/http.rs +++ b/mirrord/agent/src/steal/http.rs @@ -2,12 +2,16 @@ use crate::http::HttpVersion; +mod body_chunks; mod filter; mod reversible_stream; pub use filter::HttpFilter; -pub(crate) use self::reversible_stream::ReversibleStream; +pub(crate) use self::{ + body_chunks::{Frames, IncomingExt}, + reversible_stream::ReversibleStream, +}; /// Handy alias due to [`ReversibleStream`] being generic, avoiding value mismatches. pub(crate) type DefaultReversibleStream = ReversibleStream<{ HttpVersion::MINIMAL_HEADER_SIZE }>; diff --git a/mirrord/agent/src/steal/http/body_chunks.rs b/mirrord/agent/src/steal/http/body_chunks.rs new file mode 100644 index 00000000000..83dfe3a1758 --- /dev/null +++ b/mirrord/agent/src/steal/http/body_chunks.rs @@ -0,0 +1,65 @@ +use std::{ + future::Future, + pin::Pin, + task::{Context, Poll}, +}; + +use bytes::Bytes; +use hyper::body::{Body, Frame, Incoming}; + +pub trait IncomingExt { + fn next_frames(&mut self, no_wait: bool) -> FramesFut<'_>; +} + +impl IncomingExt for Incoming { + fn next_frames(&mut self, no_wait: bool) -> FramesFut<'_> { + FramesFut { + body: self, + no_wait, + } + } +} + +pub struct FramesFut<'a> { + body: &'a mut Incoming, + no_wait: bool, +} + +impl<'a> Future for FramesFut<'a> { + type Output = hyper::Result; + + fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll { + let mut frames = vec![]; + + loop { + let result = match Pin::new(&mut self.as_mut().body).poll_frame(cx) { + Poll::Ready(Some(Err(error))) => Poll::Ready(Err(error)), + Poll::Ready(Some(Ok(frame))) => { + frames.push(frame); + continue; + } + Poll::Ready(None) => Poll::Ready(Ok(Frames { + frames, + is_last: true, + })), + Poll::Pending => { + if frames.is_empty() && !self.no_wait { + Poll::Pending + } else { + Poll::Ready(Ok(Frames { + frames, + is_last: false, + })) + } + } + }; + + break result; + } + } +} + +pub struct Frames { + pub frames: Vec>, + pub is_last: bool, +} diff --git a/mirrord/intproxy/Cargo.toml b/mirrord/intproxy/Cargo.toml index 6f67f3d8b4c..b5deb0bf71b 100644 --- a/mirrord/intproxy/Cargo.toml +++ b/mirrord/intproxy/Cargo.toml @@ -24,7 +24,7 @@ mirrord-kube = { path = "../kube" } mirrord-operator = { path = "../operator", features = ["client"] } mirrord-protocol = { path = "../protocol" } mirrord-intproxy-protocol = { path = "./protocol", features = ["codec-async"] } -mirrord-analytics = { path = "../analytics"} +mirrord-analytics = { path = "../analytics" } serde.workspace = true thiserror.workspace = true @@ -37,3 +37,6 @@ http-body-util.workspace = true bytes.workspace = true rand = "0.8" + +[dev-dependencies] +futures.workspace = true diff --git a/mirrord/intproxy/src/proxies/incoming.rs b/mirrord/intproxy/src/proxies/incoming.rs index bcdb2d359f1..6a52a3ba8e8 100644 --- a/mirrord/intproxy/src/proxies/incoming.rs +++ b/mirrord/intproxy/src/proxies/incoming.rs @@ -11,11 +11,17 @@ use mirrord_intproxy_protocol::{ MessageId, PortSubscribe, PortSubscription, PortUnsubscribe, ProxyToLayerMessage, }; use mirrord_protocol::{ - tcp::{DaemonTcp, HttpRequestFallback, NewTcpConnection}, - ConnectionId, ResponseError, + tcp::{ + ChunkedRequest, DaemonTcp, HttpRequest, HttpRequestFallback, InternalHttpBodyFrame, + InternalHttpRequest, NewTcpConnection, StreamingBody, + }, + ConnectionId, RequestId, ResponseError, }; use thiserror::Error; -use tokio::net::TcpSocket; +use tokio::{ + net::TcpSocket, + sync::mpsc::{self, Sender}, +}; use self::{ interceptor::{Interceptor, InterceptorError, MessageOut}, @@ -140,7 +146,7 @@ impl MetadataStore { /// data. /// /// Incoming connections are created by the agent either explicitly ([`NewTcpConnection`] message) -/// or implicitly ([`HttpRequest`](mirrord_protocol::tcp::HttpRequest)). +/// or implicitly ([`HttpRequest`]). #[derive(Default)] pub struct IncomingProxy { /// Active port subscriptions for all layers. @@ -151,6 +157,8 @@ pub struct IncomingProxy { background_tasks: BackgroundTasks, /// For managing intercepted connections metadata. metadata_store: MetadataStore, + /// For managing streamed [`DaemonTcp::HttpRequestChunked`] request channels. + request_body_txs: HashMap<(ConnectionId, RequestId), Sender>, } impl IncomingProxy { @@ -244,6 +252,8 @@ impl IncomingProxy { DaemonTcp::Close(close) => { self.interceptors .remove(&InterceptorId(close.connection_id)); + self.request_body_txs + .retain(|(connection_id, _), _| *connection_id != close.connection_id) } DaemonTcp::Data(data) => { if let Some(interceptor) = self.interceptors.get(&InterceptorId(data.connection_id)) @@ -270,6 +280,62 @@ impl IncomingProxy { interceptor.send(req).await; } } + DaemonTcp::HttpRequestChunked(req) => { + match req { + ChunkedRequest::Start(req) => { + let (tx, rx) = mpsc::channel::(12); + let http_stream = StreamingBody::new(rx); + let http_req = HttpRequest { + internal_request: InternalHttpRequest { + method: req.internal_request.method, + uri: req.internal_request.uri, + headers: req.internal_request.headers, + version: req.internal_request.version, + body: http_stream, + }, + connection_id: req.connection_id, + request_id: req.request_id, + port: req.port, + }; + let key = (http_req.connection_id, http_req.request_id); + + self.request_body_txs.insert(key, tx.clone()); + + let http_req = HttpRequestFallback::Streamed(http_req); + let interceptor = self.get_interceptor_for_http_request(&http_req)?; + if let Some(interceptor) = interceptor { + interceptor.send(http_req).await; + } + + for frame in req.internal_request.body { + if let Err(err) = tx.send(frame).await { + self.request_body_txs.remove(&key); + tracing::trace!(?err, "error while sending"); + } + } + } + ChunkedRequest::Body(body) => { + let key = &(body.connection_id, body.request_id); + let mut send_err = false; + if let Some(tx) = self.request_body_txs.get(key) { + for frame in body.frames { + if let Err(err) = tx.send(frame).await { + send_err = true; + tracing::trace!(?err, "error while sending"); + } + } + } + if send_err || body.is_last { + self.request_body_txs.remove(key); + } + } + ChunkedRequest::Error(err) => { + self.request_body_txs + .remove(&(err.connection_id, err.request_id)); + tracing::trace!(?err, "ChunkedRequest error received"); + } + }; + } DaemonTcp::NewConnection(NewTcpConnection { connection_id, remote_address, @@ -385,6 +451,8 @@ impl BackgroundTask for IncomingProxy { if let Some(msg) = msg { message_bus.send(msg).await; } + + self.request_body_txs.retain(|(connection_id, _), _| *connection_id != id.0); }, (id, TaskUpdate::Message(msg)) => { diff --git a/mirrord/intproxy/src/proxies/incoming/interceptor.rs b/mirrord/intproxy/src/proxies/incoming/interceptor.rs index bff2356f7b5..7ccfc87a4a4 100644 --- a/mirrord/intproxy/src/proxies/incoming/interceptor.rs +++ b/mirrord/intproxy/src/proxies/incoming/interceptor.rs @@ -197,7 +197,6 @@ impl HttpConnection { Err(InterceptorError::ConnectionClosedTooSoon(request)) } - Err(InterceptorError::Hyper(e)) if e.is_parse() => { tracing::warn!( "Could not parse HTTP response to filtered HTTP request, got error: {e:?}." @@ -258,6 +257,18 @@ impl HttpConnection { .await .map(HttpResponseFallback::Fallback) } + HttpRequestFallback::Streamed(..) => { + // Returning `HttpResponseFallback::Framed` variant is safe - streaming + // requests require a strictly higher mirrord-protocol version + HttpResponse::::from_hyper_response( + res, + self.peer.port(), + request.connection_id(), + request.request_id(), + ) + .await + .map(HttpResponseFallback::Framed) + } }; Ok(result @@ -426,10 +437,15 @@ impl RawConnection { #[cfg(test)] mod test { - use std::convert::Infallible; + use std::{ + convert::Infallible, + pin::Pin, + sync::{Arc, Mutex}, + }; use bytes::Bytes; - use http_body_util::Empty; + use futures::future::FutureExt; + use http_body_util::{BodyExt, Empty}; use hyper::{ body::Incoming, header::{HeaderValue, CONNECTION, UPGRADE}, @@ -439,16 +455,25 @@ mod test { Method, Request, Response, }; use hyper_util::rt::TokioIo; - use mirrord_protocol::tcp::{HttpRequest, InternalHttpRequest}; + use mirrord_intproxy_protocol::{IncomingRequest, LayerId, PortSubscribe, PortSubscription}; + use mirrord_protocol::{ + tcp::{Filter, HttpRequest, InternalHttpRequest, LayerTcpSteal, StealType, StreamingBody}, + ClientMessage, + }; use tokio::{ io::{AsyncReadExt, AsyncWriteExt}, net::TcpListener, - sync::watch, + sync::{watch, Notify}, task, }; use super::*; - use crate::background_tasks::{BackgroundTasks, TaskUpdate}; + use crate::{ + background_tasks::{BackgroundTasks, TaskUpdate}, + main_tasks::ProxyMessage, + proxies::incoming::{IncomingProxy, IncomingProxyError, IncomingProxyMessage}, + IntProxy, + }; /// Binary protocol over TCP. /// Server first sends bytes [`INITIAL_MESSAGE`], then echoes back all received data. @@ -618,4 +643,97 @@ mod test { let _ = shutdown_tx.send(true); server_task.await.expect("dummy echo server panicked"); } + + /// Ensure that [`HttpRequestFallback::Streamed`] are received frame by frame + #[tokio::test] + async fn receive_request_as_frames() { + let listener = TcpListener::bind("127.0.0.1:0").await.unwrap(); + let local_destination = listener.local_addr().unwrap(); + + let mut tasks: BackgroundTasks<(), MessageOut, InterceptorError> = Default::default(); + let socket = TcpSocket::new_v4().unwrap(); + socket.bind("127.0.0.1:0".parse().unwrap()).unwrap(); + let interceptor = Interceptor::new(socket, local_destination); + let sender = tasks.register(interceptor, (), 8); + + let (tx, rx) = tokio::sync::mpsc::channel(12); + sender + .send(MessageIn::Http(HttpRequestFallback::Streamed( + HttpRequest { + internal_request: InternalHttpRequest { + method: Method::POST, + uri: "/".parse().unwrap(), + headers: Default::default(), + version: Version::HTTP_11, + body: StreamingBody::new(rx), + }, + connection_id: 1, + request_id: 2, + port: 3, + }, + ))) + .await; + let (connection, _peer_addr) = listener.accept().await.unwrap(); + + let tx = Mutex::new(Some(tx)); + let notifier = Arc::new(Notify::default()); + let finished = notifier.notified(); + + let service = service_fn(|mut req: Request| { + let tx = tx.lock().unwrap().take().unwrap(); + let notifier = notifier.clone(); + async move { + let x = req.body_mut().frame().now_or_never(); + assert!(x.is_none()); + tx.send(mirrord_protocol::tcp::InternalHttpBodyFrame::Data( + b"string".to_vec(), + )) + .await + .unwrap(); + let x = req + .body_mut() + .frame() + .await + .unwrap() + .unwrap() + .into_data() + .unwrap(); + assert_eq!(x, b"string".to_vec()); + let x = req.body_mut().frame().now_or_never(); + assert!(x.is_none()); + + tx.send(mirrord_protocol::tcp::InternalHttpBodyFrame::Data( + b"another_string".to_vec(), + )) + .await + .unwrap(); + let x = req + .body_mut() + .frame() + .await + .unwrap() + .unwrap() + .into_data() + .unwrap(); + assert_eq!(x, b"another_string".to_vec()); + + drop(tx); + let x = req.body_mut().frame().await; + assert!(x.is_none()); + + notifier.notify_waiters(); + Ok::<_, hyper::Error>(Response::new(Empty::::new())) + } + }); + let conn = http1::Builder::new().serve_connection(TokioIo::new(connection), service); + + tokio::select! { + result = conn => { + result.unwrap() + } + _ = finished => { + + } + } + } } diff --git a/mirrord/protocol/Cargo.toml b/mirrord/protocol/Cargo.toml index 1cc6a35d52a..43c44b00006 100644 --- a/mirrord/protocol/Cargo.toml +++ b/mirrord/protocol/Cargo.toml @@ -1,6 +1,6 @@ [package] name = "mirrord-protocol" -version = "1.6.1" +version = "1.7.0" authors.workspace = true description.workspace = true documentation.workspace = true @@ -31,6 +31,8 @@ fancy-regex = { workspace = true } libc.workspace = true socket2.workspace = true semver = { workspace = true, features = ["serde"] } +tokio-stream.workspace = true +tokio.workspace = true mirrord-macros = { path = "../macros" } diff --git a/mirrord/protocol/src/tcp.rs b/mirrord/protocol/src/tcp.rs index 3316f94391f..170d3d2dbf2 100644 --- a/mirrord/protocol/src/tcp.rs +++ b/mirrord/protocol/src/tcp.rs @@ -5,7 +5,7 @@ use std::{ fmt, net::IpAddr, pin::Pin, - sync::LazyLock, + sync::{Arc, LazyLock, Mutex}, task::{Context, Poll}, }; @@ -21,6 +21,7 @@ use hyper::{ use mirrord_macros::protocol_break; use semver::VersionReq; use serde::{Deserialize, Serialize}; +use tokio::sync::mpsc::Receiver; use tracing::error; use crate::{ConnectionId, Port, RemoteResult, RequestId}; @@ -73,6 +74,41 @@ pub enum DaemonTcp { SubscribeResult(RemoteResult), HttpRequest(HttpRequest>), HttpRequestFramed(HttpRequest), + HttpRequestChunked(ChunkedRequest), +} + +/// Contents of a chunked message from server. +#[derive(Encode, Decode, Debug, PartialEq, Eq, Clone)] +pub enum ChunkedRequest { + Start(HttpRequest>), + Body(ChunkedRequestBody), + Error(ChunkedRequestError), +} + +/// Contents of a chunked message body frame from server. +#[derive(Encode, Decode, Debug, PartialEq, Eq, Clone)] +pub struct ChunkedRequestBody { + #[bincode(with_serde)] + pub frames: Vec, + pub is_last: bool, + pub connection_id: ConnectionId, + pub request_id: RequestId, +} + +impl From for Frame { + fn from(value: InternalHttpBodyFrame) -> Self { + match value { + InternalHttpBodyFrame::Data(data) => Frame::data(data.into()), + InternalHttpBodyFrame::Trailers(map) => Frame::trailers(map), + } + } +} + +/// An error occurred while processing chunked data from server. +#[derive(Encode, Decode, Debug, PartialEq, Eq, Clone)] +pub struct ChunkedRequestError { + pub connection_id: ConnectionId, + pub request_id: RequestId, } /// Wraps the string that will become a [`fancy_regex::Regex`], providing a nice API in @@ -221,10 +257,90 @@ where } } -#[derive(Debug, Clone)] +impl From> for Request> +where + E: From, +{ + fn from(value: InternalHttpRequest) -> Self { + let InternalHttpRequest { + method, + uri, + headers, + version, + body, + } = value; + let mut request = Request::new(BoxBody::new(body.map_err(|e| e.into()))); + *request.method_mut() = method; + *request.uri_mut() = uri; + *request.version_mut() = version; + *request.headers_mut() = headers; + + request + } +} + +#[derive(Clone, Debug)] pub enum HttpRequestFallback { Framed(HttpRequest), Fallback(HttpRequest>), + Streamed(HttpRequest), +} + +#[derive(Debug)] +pub struct StreamingBody { + /// Shared with instances acquired via [`Clone`]. + /// Allows the clones to receive a copy of the data. + origin: Arc, Vec)>>, + /// Index of the next frame to return from the buffer. + /// If outside of the buffer, we need to poll the stream to get the next frame. + /// Local state of this instance, zeroed when cloning. + idx: usize, +} + +impl StreamingBody { + pub fn new(rx: Receiver) -> Self { + Self { + origin: Arc::new(Mutex::new((rx, vec![]))), + idx: 0, + } + } +} + +impl Clone for StreamingBody { + fn clone(&self) -> Self { + Self { + origin: self.origin.clone(), + idx: 0, + } + } +} + +impl Body for StreamingBody { + type Data = Bytes; + + type Error = Infallible; + + fn poll_frame( + self: Pin<&mut Self>, + cx: &mut Context<'_>, + ) -> Poll, Self::Error>>> { + let this = self.get_mut(); + let mut guard = this.origin.lock().unwrap(); + + if let Some(frame) = guard.1.get(this.idx) { + this.idx += 1; + return Poll::Ready(Some(Ok(frame.clone().into()))); + } + + match std::task::ready!(guard.0.poll_recv(cx)) { + None => Poll::Ready(None), + Some(frame) => { + guard.1.push(frame.clone()); + this.idx += 1; + Poll::Ready(Some(Ok(frame.into()))) + } + } + } } impl HttpRequestFallback { @@ -232,6 +348,7 @@ impl HttpRequestFallback { match self { HttpRequestFallback::Framed(req) => req.connection_id, HttpRequestFallback::Fallback(req) => req.connection_id, + HttpRequestFallback::Streamed(req) => req.connection_id, } } @@ -239,6 +356,7 @@ impl HttpRequestFallback { match self { HttpRequestFallback::Framed(req) => req.port, HttpRequestFallback::Fallback(req) => req.port, + HttpRequestFallback::Streamed(req) => req.port, } } @@ -246,6 +364,7 @@ impl HttpRequestFallback { match self { HttpRequestFallback::Framed(req) => req.request_id, HttpRequestFallback::Fallback(req) => req.request_id, + HttpRequestFallback::Streamed(req) => req.request_id, } } @@ -253,6 +372,7 @@ impl HttpRequestFallback { match self { HttpRequestFallback::Framed(req) => req.version(), HttpRequestFallback::Fallback(req) => req.version(), + HttpRequestFallback::Streamed(req) => req.version(), } } @@ -263,6 +383,7 @@ impl HttpRequestFallback { match self { HttpRequestFallback::Framed(req) => req.internal_request.into(), HttpRequestFallback::Fallback(req) => req.internal_request.into(), + HttpRequestFallback::Streamed(req) => req.internal_request.into(), } } } @@ -272,6 +393,11 @@ impl HttpRequestFallback { pub static HTTP_FRAMED_VERSION: LazyLock = LazyLock::new(|| ">=1.3.0".parse().expect("Bad Identifier")); +/// Minimal mirrord-protocol version that allows [`DaemonTcp::HttpRequestChunked`] instead of +/// [`DaemonTcp::HttpRequest`]. +pub static HTTP_CHUNKED_VERSION: LazyLock = + LazyLock::new(|| ">=1.7.0".parse().expect("Bad Identifier")); + /// Minimal mirrord-protocol version that allows [`DaemonTcp::Data`] to be sent in the same /// connection as [`DaemonTcp::HttpRequestFramed`] and [`DaemonTcp::HttpRequest`]. pub static HTTP_FILTERED_UPGRADE_VERSION: LazyLock = @@ -282,10 +408,7 @@ pub static HTTP_FILTERED_UPGRADE_VERSION: LazyLock = #[derive(Encode, Decode, Debug, PartialEq, Eq, Clone)] #[protocol_break(2)] #[bincode(bounds = "for<'de> Body: Serialize + Deserialize<'de>")] -pub struct HttpRequest -where - for<'de> Body: Serialize + Deserialize<'de>, -{ +pub struct HttpRequest { #[bincode(with_serde)] pub internal_request: InternalHttpRequest, pub connection_id: ConnectionId, @@ -295,10 +418,7 @@ where pub port: Port, } -impl HttpRequest -where - for<'de> B: Serialize + Deserialize<'de>, -{ +impl HttpRequest { /// Gets this request's HTTP version. pub fn version(&self) -> Version { self.internal_request.version @@ -402,15 +522,6 @@ impl From> for InternalHttpBodyFrame { } } -impl From for Frame { - fn from(frame: InternalHttpBodyFrame) -> Self { - match frame { - InternalHttpBodyFrame::Data(data) => Frame::data(Bytes::from(data)), - InternalHttpBodyFrame::Trailers(map) => Frame::trailers(map), - } - } -} - impl fmt::Debug for InternalHttpBodyFrame { fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { match self { @@ -465,6 +576,9 @@ impl HttpResponseFallback { HttpRequestFallback::Fallback(request) => HttpResponseFallback::Fallback( HttpResponse::>::response_from_request(request, status, message), ), + HttpRequestFallback::Streamed(request) => HttpResponseFallback::Framed( + HttpResponse::::response_from_request(request, status, message), + ), } } } @@ -522,8 +636,8 @@ impl HttpResponse { }) } - pub fn response_from_request( - request: HttpRequest, + pub fn response_from_request( + request: HttpRequest, status: StatusCode, message: &str, ) -> Self {