Skip to content

Commit

Permalink
Streamed http responses (#2557)
Browse files Browse the repository at this point in the history
* Add Streamed Variant to HttpResponseFallback and add to Interceptor

* Add streamed response handling to IncomingProxy

* Enable streamed responses in agent
  • Loading branch information
gememma authored Jul 1, 2024
1 parent e4a249d commit 6c8249a
Show file tree
Hide file tree
Showing 13 changed files with 355 additions and 88 deletions.
2 changes: 1 addition & 1 deletion Cargo.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

1 change: 1 addition & 0 deletions changelog.d/2557.added.md
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
Added support for streaming HTTP responses.
83 changes: 81 additions & 2 deletions mirrord/agent/src/steal/api.rs
Original file line number Diff line number Diff line change
@@ -1,5 +1,16 @@
use mirrord_protocol::tcp::{DaemonTcp, HttpResponseFallback, LayerTcpSteal, TcpData};
use std::collections::HashMap;

use bytes::Bytes;
use hyper::body::Frame;
use mirrord_protocol::{
tcp::{
ChunkedResponse, DaemonTcp, HttpResponse, HttpResponseFallback, InternalHttpResponse,
LayerTcpSteal, ReceiverStreamBody, TcpData,
},
RequestId,
};
use tokio::sync::mpsc::{self, OwnedPermit, Receiver, Sender};
use tokio_stream::wrappers::ReceiverStream;

use super::*;
use crate::{
Expand Down Expand Up @@ -31,6 +42,8 @@ pub(crate) struct TcpStealerApi {

/// View on the stealer task's status.
task_status: TaskStatus,

response_body_txs: HashMap<(ConnectionId, RequestId), Sender<hyper::Result<Frame<Bytes>>>>,
}

impl TcpStealerApi {
Expand Down Expand Up @@ -65,6 +78,7 @@ impl TcpStealerApi {
close_permit: Some(close_permit),
daemon_rx,
task_status,
response_body_txs: HashMap::new(),
})
}

Expand All @@ -89,7 +103,13 @@ impl TcpStealerApi {
#[tracing::instrument(level = "trace", skip(self))]
pub(crate) async fn recv(&mut self) -> Result<DaemonTcp> {
match self.daemon_rx.recv().await {
Some(msg) => Ok(msg),
Some(msg) => {
if let DaemonTcp::Close(close) = &msg {
self.response_body_txs
.retain(|(key_id, _), _| *key_id != close.connection_id);
}
Ok(msg)
}
None => Err(self.task_status.unwrap_err().await),
}
}
Expand Down Expand Up @@ -153,6 +173,8 @@ impl TcpStealerApi {
match message {
LayerTcpSteal::PortSubscribe(port_steal) => self.port_subscribe(port_steal).await,
LayerTcpSteal::ConnectionUnsubscribe(connection_id) => {
self.response_body_txs
.retain(|(key_id, _), _| *key_id != connection_id);
self.connection_unsubscribe(connection_id).await
}
LayerTcpSteal::PortUnsubscribe(port) => self.port_unsubscribe(port).await,
Expand All @@ -165,6 +187,63 @@ impl TcpStealerApi {
self.http_response(HttpResponseFallback::Framed(response))
.await
}
LayerTcpSteal::HttpResponseChunked(inner) => match inner {
ChunkedResponse::Start(response) => {
let (tx, rx) = mpsc::channel(12);
let body = ReceiverStreamBody::new(ReceiverStream::from(rx));
let http_response: HttpResponse<ReceiverStreamBody> = HttpResponse {
port: response.port,
connection_id: response.connection_id,
request_id: response.request_id,
internal_response: InternalHttpResponse {
status: response.internal_response.status,
version: response.internal_response.version,
headers: response.internal_response.headers,
body,
},
};

let key = (response.connection_id, response.request_id);
self.response_body_txs.insert(key, tx.clone());

self.http_response(HttpResponseFallback::Streamed(http_response))
.await?;

for frame in response.internal_response.body {
if let Err(err) = tx.send(Ok(frame.into())).await {
self.response_body_txs.remove(&key);
tracing::trace!(?err, "error while sending streaming response frame");
}
}
Ok(())
}
ChunkedResponse::Body(body) => {
let key = &(body.connection_id, body.request_id);
let mut send_err = false;
if let Some(tx) = self.response_body_txs.get(key) {
for frame in body.frames {
if let Err(err) = tx.send(Ok(frame.into())).await {
send_err = true;
tracing::trace!(
?err,
"error while sending streaming response body"
);
break;
}
}
}
if send_err || body.is_last {
self.response_body_txs.remove(key);
};
Ok(())
}
ChunkedResponse::Error(err) => {
self.response_body_txs
.remove(&(err.connection_id, err.request_id));
tracing::trace!(?err, "ChunkedResponse error received");
Ok(())
}
},
}
}
}
Expand Down
9 changes: 5 additions & 4 deletions mirrord/agent/src/steal/connection.rs
Original file line number Diff line number Diff line change
Expand Up @@ -11,8 +11,9 @@ use hyper::{
http::{header::UPGRADE, request::Parts},
};
use mirrord_protocol::{
body_chunks::{BodyExt as _, Frames},
tcp::{
ChunkedRequest, ChunkedRequestBody, ChunkedRequestError, DaemonTcp, HttpRequest,
ChunkedHttpBody, ChunkedHttpError, ChunkedRequest, DaemonTcp, HttpRequest,
HttpResponseFallback, InternalHttpBody, InternalHttpBodyFrame, InternalHttpRequest,
StealType, TcpClose, TcpData, HTTP_CHUNKED_VERSION, HTTP_FILTERED_UPGRADE_VERSION,
HTTP_FRAMED_VERSION,
Expand All @@ -33,7 +34,7 @@ use crate::{
connections::{
ConnectionMessageIn, ConnectionMessageOut, StolenConnection, StolenConnections,
},
http::{Frames, HttpFilter, IncomingExt},
http::HttpFilter,
orig_dst,
subscriptions::{IpTablesRedirector, PortSubscriptions},
Command, StealerCommand,
Expand Down Expand Up @@ -204,7 +205,7 @@ impl Client {
.filter_map(Result::ok)
.collect();
let message = DaemonTcp::HttpRequestChunked(ChunkedRequest::Body(
ChunkedRequestBody {
ChunkedHttpBody {
frames,
is_last,
connection_id,
Expand All @@ -218,7 +219,7 @@ impl Client {
Err(_) => {
let _ = tx
.send(DaemonTcp::HttpRequestChunked(ChunkedRequest::Error(
ChunkedRequestError {
ChunkedHttpError {
connection_id,
request_id,
},
Expand Down
6 changes: 1 addition & 5 deletions mirrord/agent/src/steal/http.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,16 +2,12 @@
use crate::http::HttpVersion;

mod body_chunks;
mod filter;
mod reversible_stream;

pub use filter::HttpFilter;

pub(crate) use self::{
body_chunks::{Frames, IncomingExt},
reversible_stream::ReversibleStream,
};
pub(crate) use self::reversible_stream::ReversibleStream;

/// Handy alias due to [`ReversibleStream`] being generic, avoiding value mismatches.
pub(crate) type DefaultReversibleStream = ReversibleStream<{ HttpVersion::MINIMAL_HEADER_SIZE }>;
4 changes: 1 addition & 3 deletions mirrord/intproxy/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -35,8 +35,6 @@ hyper = { workspace = true, features = ["client", "http1", "http2"] }
hyper-util.workspace = true
http-body-util.workspace = true
bytes.workspace = true
futures.workspace = true

rand = "0.8"

[dev-dependencies]
futures.workspace = true
114 changes: 106 additions & 8 deletions mirrord/intproxy/src/proxies/incoming.rs
Original file line number Diff line number Diff line change
Expand Up @@ -6,22 +6,28 @@ use std::{
net::{IpAddr, Ipv4Addr, Ipv6Addr, SocketAddr},
};

use futures::StreamExt;
use mirrord_intproxy_protocol::{
ConnMetadataRequest, ConnMetadataResponse, IncomingRequest, IncomingResponse, LayerId,
MessageId, PortSubscribe, PortSubscription, PortUnsubscribe, ProxyToLayerMessage,
};
use mirrord_protocol::{
body_chunks::BodyExt as _,
tcp::{
ChunkedRequest, DaemonTcp, HttpRequest, HttpRequestFallback, InternalHttpBodyFrame,
InternalHttpRequest, NewTcpConnection, StreamingBody,
ChunkedHttpBody, ChunkedHttpError, ChunkedRequest, ChunkedResponse, DaemonTcp, HttpRequest,
HttpRequestFallback, HttpResponse, HttpResponseFallback, InternalHttpBodyFrame,
InternalHttpRequest, InternalHttpResponse, LayerTcpSteal, NewTcpConnection,
ReceiverStreamBody, StreamingBody, TcpData,
},
ConnectionId, RequestId, ResponseError,
ClientMessage, ConnectionId, RequestId, ResponseError,
};
use thiserror::Error;
use tokio::{
net::TcpSocket,
sync::mpsc::{self, Sender},
};
use tokio_stream::{StreamMap, StreamNotifyClose};
use tracing::debug;

use self::{
interceptor::{Interceptor, InterceptorError, MessageOut},
Expand Down Expand Up @@ -159,6 +165,8 @@ pub struct IncomingProxy {
metadata_store: MetadataStore,
/// For managing streamed [`DaemonTcp::HttpRequestChunked`] request channels.
request_body_txs: HashMap<(ConnectionId, RequestId), Sender<InternalHttpBodyFrame>>,
/// For managing streamed [`LayerTcpSteal::HttpResponseChunked`] response streams.
response_body_rxs: StreamMap<(ConnectionId, RequestId), StreamNotifyClose<ReceiverStreamBody>>,
}

impl IncomingProxy {
Expand Down Expand Up @@ -253,7 +261,16 @@ impl IncomingProxy {
self.interceptors
.remove(&InterceptorId(close.connection_id));
self.request_body_txs
.retain(|(connection_id, _), _| *connection_id != close.connection_id)
.retain(|(connection_id, _), _| *connection_id != close.connection_id);
let keys: Vec<(ConnectionId, RequestId)> = self
.response_body_rxs
.keys()
.filter(|key| key.0 == close.connection_id)
.cloned()
.collect();
for key in keys.iter() {
self.response_body_rxs.remove(key);
}
}
DaemonTcp::Data(data) => {
if let Some(interceptor) = self.interceptors.get(&InterceptorId(data.connection_id))
Expand Down Expand Up @@ -418,6 +435,47 @@ impl BackgroundTask for IncomingProxy {
async fn run(mut self, message_bus: &mut MessageBus<Self>) -> Result<(), Self::Error> {
loop {
tokio::select! {
Some(((connection_id, request_id), stream_item)) = self.response_body_rxs.next() => match stream_item {
Some(Ok(frame)) => {
let int_frame = InternalHttpBodyFrame::from(frame);
let res = ChunkedResponse::Body(ChunkedHttpBody {
frames: vec![int_frame],
is_last: false,
connection_id,
request_id,
});
message_bus
.send(ClientMessage::TcpSteal(LayerTcpSteal::HttpResponseChunked(
res,
)))
.await;
},
Some(Err(error)) => {
debug!(%error, "Error while reading streamed response body");
let res = ChunkedResponse::Error(ChunkedHttpError {connection_id, request_id});
message_bus
.send(ClientMessage::TcpSteal(LayerTcpSteal::HttpResponseChunked(
res,
)))
.await;
self.response_body_rxs.remove(&(connection_id, request_id));
},
None => {
let res = ChunkedResponse::Body(ChunkedHttpBody {
frames: vec![],
is_last: true,
connection_id,
request_id,
});
message_bus
.send(ClientMessage::TcpSteal(LayerTcpSteal::HttpResponseChunked(
res,
)))
.await;
self.response_body_rxs.remove(&(connection_id, request_id));
}
},

msg = message_bus.recv() => match msg {
None => {
tracing::trace!("message bus closed, exiting");
Expand Down Expand Up @@ -456,10 +514,50 @@ impl BackgroundTask for IncomingProxy {
},

(id, TaskUpdate::Message(msg)) => {
let msg = self.get_subscription(id).and_then(|s| s.wrap_response(msg, id.0));
if let Some(msg) = msg {
message_bus.send(msg).await;
}
let Some(PortSubscription::Steal(_)) = self.get_subscription(id) else {
continue;
};
let msg = match msg {
MessageOut::Raw(bytes) => {
ClientMessage::TcpSteal(LayerTcpSteal::Data(TcpData {
connection_id: id.0,
bytes,
}))
},
MessageOut::Http(HttpResponseFallback::Fallback(res)) => {
ClientMessage::TcpSteal(LayerTcpSteal::HttpResponse(res))
},
MessageOut::Http(HttpResponseFallback::Framed(res)) => {
ClientMessage::TcpSteal(LayerTcpSteal::HttpResponseFramed(res))
},
MessageOut::Http(HttpResponseFallback::Streamed(mut res)) => {
let mut body = vec![];
let key = (res.connection_id, res.request_id);

match res.internal_response.body.next_frames(false).await {
Ok(frames) => {
frames.frames.into_iter().map(From::from).for_each(|frame| body.push(frame));
},
Err(error) => {
debug!(%error, "Error while receving streamed response frames");
let res = ChunkedResponse::Error(ChunkedHttpError { connection_id: key.0, request_id: key.1 });
message_bus.send(ClientMessage::TcpSteal(LayerTcpSteal::HttpResponseChunked(res))).await;
continue;
},
}

self.response_body_rxs.insert(key, StreamNotifyClose::new(res.internal_response.body));

let internal_response = InternalHttpResponse {
status: res.internal_response.status, version: res.internal_response.version, headers: res.internal_response.headers, body
};
let res = ChunkedResponse::Start(HttpResponse {
port: res.port , connection_id: res.connection_id, request_id: res.request_id, internal_response
});
ClientMessage::TcpSteal(LayerTcpSteal::HttpResponseChunked(res))
},
};
message_bus.send(msg).await;
},
},
}
Expand Down
Loading

0 comments on commit 6c8249a

Please sign in to comment.