Skip to content

Commit

Permalink
Add support for chunked HTTP data requests (#2484)
Browse files Browse the repository at this point in the history
* Add implementation for chunked HTTP data requests in agent

* Introduce Streamed variant of HttpRequestFallback for chunked requests

* Complete handle agent request for Chunked requests

* Add unit test for streamed http requests in intproxy

* Add unit test for streamed http requests in agent

* Change required version

Co-authored-by: Michał Smolarek <[email protected]>

---------

Co-authored-by: Michał Smolarek <[email protected]>
  • Loading branch information
gememma and Razz4780 authored Jun 24, 2024
1 parent 156b800 commit ff665e1
Show file tree
Hide file tree
Showing 11 changed files with 647 additions and 42 deletions.
5 changes: 4 additions & 1 deletion Cargo.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

1 change: 1 addition & 0 deletions changelog.d/2478.added.md
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
Added support for intercepting streaming HTTP requests with an HTTP filter.
240 changes: 233 additions & 7 deletions mirrord/agent/src/steal/connection.rs
Original file line number Diff line number Diff line change
Expand Up @@ -4,16 +4,18 @@ use std::{
};

use fancy_regex::Regex;
use http::Request;
use http_body_util::BodyExt;
use hyper::{
body::Incoming,
http::{header::UPGRADE, request::Parts},
Request,
};
use mirrord_protocol::{
tcp::{
DaemonTcp, HttpRequest, HttpResponseFallback, InternalHttpBody, InternalHttpRequest,
StealType, TcpClose, TcpData, HTTP_FILTERED_UPGRADE_VERSION, HTTP_FRAMED_VERSION,
ChunkedRequest, ChunkedRequestBody, ChunkedRequestError, DaemonTcp, HttpRequest,
HttpResponseFallback, InternalHttpBody, InternalHttpBodyFrame, InternalHttpRequest,
StealType, TcpClose, TcpData, HTTP_CHUNKED_VERSION, HTTP_FILTERED_UPGRADE_VERSION,
HTTP_FRAMED_VERSION,
},
ConnectionId, Port,
RemoteError::{BadHttpFilterExRegex, BadHttpFilterRegex},
Expand All @@ -31,7 +33,7 @@ use crate::{
connections::{
ConnectionMessageIn, ConnectionMessageOut, StolenConnection, StolenConnections,
},
http::HttpFilter,
http::{Frames, HttpFilter, IncomingExt},
orig_dst,
subscriptions::{IpTablesRedirector, PortSubscriptions},
Command, StealerCommand,
Expand Down Expand Up @@ -124,8 +126,8 @@ struct Client {

impl Client {
/// Attempts to spawn a new [`tokio::task`] to transform the given [`MatchedHttpRequest`] into
/// [`DaemonTcp::HttpRequest`] or [`DaemonTcp::HttpRequestFramed`] and send it via cloned
/// [`Client::tx`].
/// [`DaemonTcp::HttpRequest`], [`DaemonTcp::HttpRequestFramed`] or
/// [`DaemonTcp::HttpRequestChunked`] and send it via cloned [`Client::tx`].
///
/// Inspects [`Client::protocol_version`] to pick between [`DaemonTcp`] variants and check for
/// upgrade requests.
Expand All @@ -147,10 +149,86 @@ impl Client {
}

let framed = HTTP_FRAMED_VERSION.matches(&self.protocol_version);
let chunked = HTTP_CHUNKED_VERSION.matches(&self.protocol_version);
let tx = self.tx.clone();

tokio::spawn(async move {
if framed {
// Chunked data is preferred over framed data
if chunked {
// Send headers
let connection_id = request.connection_id;
let request_id = request.request_id;
let (
Parts {
method,
uri,
version,
headers,
..
},
mut body,
) = request.request.into_parts();
match body.next_frames(true).await {
Err(..) => return,
Ok(Frames { frames, is_last }) => {
let frames = frames
.into_iter()
.map(InternalHttpBodyFrame::try_from)
.filter_map(Result::ok)
.collect();
let message =
DaemonTcp::HttpRequestChunked(ChunkedRequest::Start(HttpRequest {
internal_request: InternalHttpRequest {
method,
uri,
headers,
version,
body: frames,
},
connection_id,
request_id,
port: request.port,
}));
if tx.send(message).await.is_err() || is_last {
return;
}
}
}

loop {
match body.next_frames(false).await {
Ok(Frames { frames, is_last }) => {
let frames = frames
.into_iter()
.map(InternalHttpBodyFrame::try_from)
.filter_map(Result::ok)
.collect();
let message = DaemonTcp::HttpRequestChunked(ChunkedRequest::Body(
ChunkedRequestBody {
frames,
is_last,
connection_id,
request_id,
},
));
if tx.send(message).await.is_err() || is_last {
return;
}
}
Err(_) => {
let _ = tx
.send(DaemonTcp::HttpRequestChunked(ChunkedRequest::Error(
ChunkedRequestError {
connection_id,
request_id,
},
)))
.await;
return;
}
}
}
} else if framed {
let Ok(request) = request.into_serializable().await else {
return;
};
Expand Down Expand Up @@ -581,3 +659,151 @@ impl TcpConnectionStealer {
Ok(())
}
}

#[cfg(test)]
mod test {
use std::net::SocketAddr;

use bytes::Bytes;
use futures::{future::BoxFuture, FutureExt};
use http::{Method, Request, Response, Version};
use http_body_util::{Empty, StreamBody};
use hyper::{
body::{Frame, Incoming},
service::Service,
};
use hyper_util::rt::TokioIo;
use mirrord_protocol::tcp::{ChunkedRequest, DaemonTcp, InternalHttpBodyFrame};
use tokio::{
net::{TcpListener, TcpStream},
sync::{
mpsc::{self, Receiver, Sender},
oneshot,
},
};
use tokio_stream::wrappers::ReceiverStream;

use crate::steal::connection::{Client, MatchedHttpRequest};
async fn prepare_dummy_service() -> (
SocketAddr,
Receiver<(Request<Incoming>, oneshot::Sender<Response<Empty<Bytes>>>)>,
) {
type ReqSender = Sender<(Request<Incoming>, oneshot::Sender<Response<Empty<Bytes>>>)>;
struct DummyService {
tx: ReqSender,
}

impl Service<Request<Incoming>> for DummyService {
type Response = Response<Empty<Bytes>>;

type Error = hyper::Error;

type Future = BoxFuture<'static, Result<Self::Response, Self::Error>>;

fn call(&self, req: Request<Incoming>) -> Self::Future {
let tx = self.tx.clone();
async move {
let (res_tx, res_rx) = oneshot::channel();
tx.send((req, res_tx)).await.unwrap();
Ok(res_rx.await.unwrap())
}
.boxed()
}
}

let listener = TcpListener::bind("127.0.0.1:0").await.unwrap();
let server_address = listener.local_addr().unwrap();
let (tx, rx) = mpsc::channel(4);

tokio::spawn(async move {
loop {
let (conn, _) = listener.accept().await.unwrap();
let tx = tx.clone();
tokio::spawn(
hyper::server::conn::http1::Builder::new()
.serve_connection(TokioIo::new(conn), DummyService { tx }),
);
}
});

(server_address, rx)
}

#[tokio::test]
async fn test_streaming_response() {
let (addr, mut request_rx) = prepare_dummy_service().await;
let conn = TcpStream::connect(addr).await.unwrap();
let (mut sender, conn) = hyper::client::conn::http1::handshake(TokioIo::new(conn))
.await
.unwrap();
tokio::spawn(conn);

let (body_tx, body_rx) = mpsc::channel::<hyper::Result<Frame<Bytes>>>(12);
let body = StreamBody::new(ReceiverStream::new(body_rx));

// Send a frame to be ready in ChunkedRequest::Start before hyper sender is used
body_tx
.send(Ok(Frame::data(b"string".to_vec().into())))
.await
.unwrap();

tokio::spawn(
sender.send_request(
Request::builder()
.method(Method::POST)
.uri("/")
.version(Version::HTTP_11)
.body(body)
.unwrap(),
),
);

let (client_tx, mut client_rx) = mpsc::channel::<DaemonTcp>(4);
let client = Client {
tx: client_tx,
protocol_version: "1.7.0".parse().unwrap(),
subscribed_connections: Default::default(),
};

let (request, response_tx) = request_rx.recv().await.unwrap();
client.send_request_async(MatchedHttpRequest {
connection_id: 0,
port: 80,
request_id: 0,
request,
});

// Verify that single-framed ChunkedRequest::Start requests are as expected, containing any
// ready frames that were sent before Request was first sent
let msg = client_rx.recv().await.unwrap();
let DaemonTcp::HttpRequestChunked(ChunkedRequest::Start(x)) = msg else {
panic!("unexpected type received: {msg:?}")
};
assert_eq!(
x.internal_request.body,
vec![InternalHttpBodyFrame::Data(b"string".to_vec().into())]
);
let x = client_rx.recv().now_or_never();
assert!(x.is_none());

// Verify that single-framed ChunkedRequest::Body requests are as expected
body_tx
.send(Ok(Frame::data(b"another_string".to_vec().into())))
.await
.unwrap();
let msg = client_rx.recv().await.unwrap();
let DaemonTcp::HttpRequestChunked(ChunkedRequest::Body(x)) = msg else {
panic!("unexpected type received: {msg:?}")
};
assert_eq!(
x.frames,
vec![InternalHttpBodyFrame::Data(
b"another_string".to_vec().into()
)]
);
let x = client_rx.recv().now_or_never();
assert!(x.is_none());

let _ = response_tx.send(Response::new(Empty::default()));
}
}
1 change: 1 addition & 0 deletions mirrord/agent/src/steal/connections/filtered.rs
Original file line number Diff line number Diff line change
Expand Up @@ -802,6 +802,7 @@ where

#[cfg(test)]
mod test {

use bytes::BytesMut;
use http::{
header::{CONNECTION, UPGRADE},
Expand Down
6 changes: 5 additions & 1 deletion mirrord/agent/src/steal/http.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,12 +2,16 @@
use crate::http::HttpVersion;

mod body_chunks;
mod filter;
mod reversible_stream;

pub use filter::HttpFilter;

pub(crate) use self::reversible_stream::ReversibleStream;
pub(crate) use self::{
body_chunks::{Frames, IncomingExt},
reversible_stream::ReversibleStream,
};

/// Handy alias due to [`ReversibleStream`] being generic, avoiding value mismatches.
pub(crate) type DefaultReversibleStream = ReversibleStream<{ HttpVersion::MINIMAL_HEADER_SIZE }>;
Loading

0 comments on commit ff665e1

Please sign in to comment.