diff --git a/Cargo.lock b/Cargo.lock index a371188366a..459f8880fb2 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -4181,7 +4181,7 @@ dependencies = [ [[package]] name = "mirrord-protocol" -version = "1.7.0" +version = "1.8.0" dependencies = [ "actix-codec", "bincode", diff --git a/changelog.d/2557.added.md b/changelog.d/2557.added.md new file mode 100644 index 00000000000..e98ea1eca97 --- /dev/null +++ b/changelog.d/2557.added.md @@ -0,0 +1 @@ +Added support for streaming HTTP responses. \ No newline at end of file diff --git a/mirrord/agent/src/steal/api.rs b/mirrord/agent/src/steal/api.rs index 170e1442a37..44e9d10b564 100644 --- a/mirrord/agent/src/steal/api.rs +++ b/mirrord/agent/src/steal/api.rs @@ -1,5 +1,16 @@ -use mirrord_protocol::tcp::{DaemonTcp, HttpResponseFallback, LayerTcpSteal, TcpData}; +use std::collections::HashMap; + +use bytes::Bytes; +use hyper::body::Frame; +use mirrord_protocol::{ + tcp::{ + ChunkedResponse, DaemonTcp, HttpResponse, HttpResponseFallback, InternalHttpResponse, + LayerTcpSteal, ReceiverStreamBody, TcpData, + }, + RequestId, +}; use tokio::sync::mpsc::{self, OwnedPermit, Receiver, Sender}; +use tokio_stream::wrappers::ReceiverStream; use super::*; use crate::{ @@ -31,6 +42,8 @@ pub(crate) struct TcpStealerApi { /// View on the stealer task's status. task_status: TaskStatus, + + response_body_txs: HashMap<(ConnectionId, RequestId), Sender>>>, } impl TcpStealerApi { @@ -65,6 +78,7 @@ impl TcpStealerApi { close_permit: Some(close_permit), daemon_rx, task_status, + response_body_txs: HashMap::new(), }) } @@ -89,7 +103,13 @@ impl TcpStealerApi { #[tracing::instrument(level = "trace", skip(self))] pub(crate) async fn recv(&mut self) -> Result { match self.daemon_rx.recv().await { - Some(msg) => Ok(msg), + Some(msg) => { + if let DaemonTcp::Close(close) = &msg { + self.response_body_txs + .retain(|(key_id, _), _| *key_id != close.connection_id); + } + Ok(msg) + } None => Err(self.task_status.unwrap_err().await), } } @@ -153,6 +173,8 @@ impl TcpStealerApi { match message { LayerTcpSteal::PortSubscribe(port_steal) => self.port_subscribe(port_steal).await, LayerTcpSteal::ConnectionUnsubscribe(connection_id) => { + self.response_body_txs + .retain(|(key_id, _), _| *key_id != connection_id); self.connection_unsubscribe(connection_id).await } LayerTcpSteal::PortUnsubscribe(port) => self.port_unsubscribe(port).await, @@ -165,6 +187,63 @@ impl TcpStealerApi { self.http_response(HttpResponseFallback::Framed(response)) .await } + LayerTcpSteal::HttpResponseChunked(inner) => match inner { + ChunkedResponse::Start(response) => { + let (tx, rx) = mpsc::channel(12); + let body = ReceiverStreamBody::new(ReceiverStream::from(rx)); + let http_response: HttpResponse = HttpResponse { + port: response.port, + connection_id: response.connection_id, + request_id: response.request_id, + internal_response: InternalHttpResponse { + status: response.internal_response.status, + version: response.internal_response.version, + headers: response.internal_response.headers, + body, + }, + }; + + let key = (response.connection_id, response.request_id); + self.response_body_txs.insert(key, tx.clone()); + + self.http_response(HttpResponseFallback::Streamed(http_response)) + .await?; + + for frame in response.internal_response.body { + if let Err(err) = tx.send(Ok(frame.into())).await { + self.response_body_txs.remove(&key); + tracing::trace!(?err, "error while sending streaming response frame"); + } + } + Ok(()) + } + ChunkedResponse::Body(body) => { + let key = &(body.connection_id, body.request_id); + let mut send_err = false; + if let Some(tx) = self.response_body_txs.get(key) { + for frame in body.frames { + if let Err(err) = tx.send(Ok(frame.into())).await { + send_err = true; + tracing::trace!( + ?err, + "error while sending streaming response body" + ); + break; + } + } + } + if send_err || body.is_last { + self.response_body_txs.remove(key); + }; + Ok(()) + } + ChunkedResponse::Error(err) => { + self.response_body_txs + .remove(&(err.connection_id, err.request_id)); + tracing::trace!(?err, "ChunkedResponse error received"); + Ok(()) + } + }, } } } diff --git a/mirrord/agent/src/steal/connection.rs b/mirrord/agent/src/steal/connection.rs index 4eb754be4c0..393f6c0e7a6 100644 --- a/mirrord/agent/src/steal/connection.rs +++ b/mirrord/agent/src/steal/connection.rs @@ -11,8 +11,9 @@ use hyper::{ http::{header::UPGRADE, request::Parts}, }; use mirrord_protocol::{ + body_chunks::{BodyExt as _, Frames}, tcp::{ - ChunkedRequest, ChunkedRequestBody, ChunkedRequestError, DaemonTcp, HttpRequest, + ChunkedHttpBody, ChunkedHttpError, ChunkedRequest, DaemonTcp, HttpRequest, HttpResponseFallback, InternalHttpBody, InternalHttpBodyFrame, InternalHttpRequest, StealType, TcpClose, TcpData, HTTP_CHUNKED_VERSION, HTTP_FILTERED_UPGRADE_VERSION, HTTP_FRAMED_VERSION, @@ -33,7 +34,7 @@ use crate::{ connections::{ ConnectionMessageIn, ConnectionMessageOut, StolenConnection, StolenConnections, }, - http::{Frames, HttpFilter, IncomingExt}, + http::HttpFilter, orig_dst, subscriptions::{IpTablesRedirector, PortSubscriptions}, Command, StealerCommand, @@ -204,7 +205,7 @@ impl Client { .filter_map(Result::ok) .collect(); let message = DaemonTcp::HttpRequestChunked(ChunkedRequest::Body( - ChunkedRequestBody { + ChunkedHttpBody { frames, is_last, connection_id, @@ -218,7 +219,7 @@ impl Client { Err(_) => { let _ = tx .send(DaemonTcp::HttpRequestChunked(ChunkedRequest::Error( - ChunkedRequestError { + ChunkedHttpError { connection_id, request_id, }, diff --git a/mirrord/agent/src/steal/http.rs b/mirrord/agent/src/steal/http.rs index 53da34db228..159d9c9aac8 100644 --- a/mirrord/agent/src/steal/http.rs +++ b/mirrord/agent/src/steal/http.rs @@ -2,16 +2,12 @@ use crate::http::HttpVersion; -mod body_chunks; mod filter; mod reversible_stream; pub use filter::HttpFilter; -pub(crate) use self::{ - body_chunks::{Frames, IncomingExt}, - reversible_stream::ReversibleStream, -}; +pub(crate) use self::reversible_stream::ReversibleStream; /// Handy alias due to [`ReversibleStream`] being generic, avoiding value mismatches. pub(crate) type DefaultReversibleStream = ReversibleStream<{ HttpVersion::MINIMAL_HEADER_SIZE }>; diff --git a/mirrord/intproxy/Cargo.toml b/mirrord/intproxy/Cargo.toml index b5deb0bf71b..9264b2b04bd 100644 --- a/mirrord/intproxy/Cargo.toml +++ b/mirrord/intproxy/Cargo.toml @@ -35,8 +35,6 @@ hyper = { workspace = true, features = ["client", "http1", "http2"] } hyper-util.workspace = true http-body-util.workspace = true bytes.workspace = true +futures.workspace = true rand = "0.8" - -[dev-dependencies] -futures.workspace = true diff --git a/mirrord/intproxy/src/proxies/incoming.rs b/mirrord/intproxy/src/proxies/incoming.rs index 6a52a3ba8e8..53933e7307d 100644 --- a/mirrord/intproxy/src/proxies/incoming.rs +++ b/mirrord/intproxy/src/proxies/incoming.rs @@ -6,22 +6,28 @@ use std::{ net::{IpAddr, Ipv4Addr, Ipv6Addr, SocketAddr}, }; +use futures::StreamExt; use mirrord_intproxy_protocol::{ ConnMetadataRequest, ConnMetadataResponse, IncomingRequest, IncomingResponse, LayerId, MessageId, PortSubscribe, PortSubscription, PortUnsubscribe, ProxyToLayerMessage, }; use mirrord_protocol::{ + body_chunks::BodyExt as _, tcp::{ - ChunkedRequest, DaemonTcp, HttpRequest, HttpRequestFallback, InternalHttpBodyFrame, - InternalHttpRequest, NewTcpConnection, StreamingBody, + ChunkedHttpBody, ChunkedHttpError, ChunkedRequest, ChunkedResponse, DaemonTcp, HttpRequest, + HttpRequestFallback, HttpResponse, HttpResponseFallback, InternalHttpBodyFrame, + InternalHttpRequest, InternalHttpResponse, LayerTcpSteal, NewTcpConnection, + ReceiverStreamBody, StreamingBody, TcpData, }, - ConnectionId, RequestId, ResponseError, + ClientMessage, ConnectionId, RequestId, ResponseError, }; use thiserror::Error; use tokio::{ net::TcpSocket, sync::mpsc::{self, Sender}, }; +use tokio_stream::{StreamMap, StreamNotifyClose}; +use tracing::debug; use self::{ interceptor::{Interceptor, InterceptorError, MessageOut}, @@ -159,6 +165,8 @@ pub struct IncomingProxy { metadata_store: MetadataStore, /// For managing streamed [`DaemonTcp::HttpRequestChunked`] request channels. request_body_txs: HashMap<(ConnectionId, RequestId), Sender>, + /// For managing streamed [`LayerTcpSteal::HttpResponseChunked`] response streams. + response_body_rxs: StreamMap<(ConnectionId, RequestId), StreamNotifyClose>, } impl IncomingProxy { @@ -253,7 +261,16 @@ impl IncomingProxy { self.interceptors .remove(&InterceptorId(close.connection_id)); self.request_body_txs - .retain(|(connection_id, _), _| *connection_id != close.connection_id) + .retain(|(connection_id, _), _| *connection_id != close.connection_id); + let keys: Vec<(ConnectionId, RequestId)> = self + .response_body_rxs + .keys() + .filter(|key| key.0 == close.connection_id) + .cloned() + .collect(); + for key in keys.iter() { + self.response_body_rxs.remove(key); + } } DaemonTcp::Data(data) => { if let Some(interceptor) = self.interceptors.get(&InterceptorId(data.connection_id)) @@ -418,6 +435,47 @@ impl BackgroundTask for IncomingProxy { async fn run(mut self, message_bus: &mut MessageBus) -> Result<(), Self::Error> { loop { tokio::select! { + Some(((connection_id, request_id), stream_item)) = self.response_body_rxs.next() => match stream_item { + Some(Ok(frame)) => { + let int_frame = InternalHttpBodyFrame::from(frame); + let res = ChunkedResponse::Body(ChunkedHttpBody { + frames: vec![int_frame], + is_last: false, + connection_id, + request_id, + }); + message_bus + .send(ClientMessage::TcpSteal(LayerTcpSteal::HttpResponseChunked( + res, + ))) + .await; + }, + Some(Err(error)) => { + debug!(%error, "Error while reading streamed response body"); + let res = ChunkedResponse::Error(ChunkedHttpError {connection_id, request_id}); + message_bus + .send(ClientMessage::TcpSteal(LayerTcpSteal::HttpResponseChunked( + res, + ))) + .await; + self.response_body_rxs.remove(&(connection_id, request_id)); + }, + None => { + let res = ChunkedResponse::Body(ChunkedHttpBody { + frames: vec![], + is_last: true, + connection_id, + request_id, + }); + message_bus + .send(ClientMessage::TcpSteal(LayerTcpSteal::HttpResponseChunked( + res, + ))) + .await; + self.response_body_rxs.remove(&(connection_id, request_id)); + } + }, + msg = message_bus.recv() => match msg { None => { tracing::trace!("message bus closed, exiting"); @@ -456,10 +514,50 @@ impl BackgroundTask for IncomingProxy { }, (id, TaskUpdate::Message(msg)) => { - let msg = self.get_subscription(id).and_then(|s| s.wrap_response(msg, id.0)); - if let Some(msg) = msg { - message_bus.send(msg).await; - } + let Some(PortSubscription::Steal(_)) = self.get_subscription(id) else { + continue; + }; + let msg = match msg { + MessageOut::Raw(bytes) => { + ClientMessage::TcpSteal(LayerTcpSteal::Data(TcpData { + connection_id: id.0, + bytes, + })) + }, + MessageOut::Http(HttpResponseFallback::Fallback(res)) => { + ClientMessage::TcpSteal(LayerTcpSteal::HttpResponse(res)) + }, + MessageOut::Http(HttpResponseFallback::Framed(res)) => { + ClientMessage::TcpSteal(LayerTcpSteal::HttpResponseFramed(res)) + }, + MessageOut::Http(HttpResponseFallback::Streamed(mut res)) => { + let mut body = vec![]; + let key = (res.connection_id, res.request_id); + + match res.internal_response.body.next_frames(false).await { + Ok(frames) => { + frames.frames.into_iter().map(From::from).for_each(|frame| body.push(frame)); + }, + Err(error) => { + debug!(%error, "Error while receving streamed response frames"); + let res = ChunkedResponse::Error(ChunkedHttpError { connection_id: key.0, request_id: key.1 }); + message_bus.send(ClientMessage::TcpSteal(LayerTcpSteal::HttpResponseChunked(res))).await; + continue; + }, + } + + self.response_body_rxs.insert(key, StreamNotifyClose::new(res.internal_response.body)); + + let internal_response = InternalHttpResponse { + status: res.internal_response.status, version: res.internal_response.version, headers: res.internal_response.headers, body + }; + let res = ChunkedResponse::Start(HttpResponse { + port: res.port , connection_id: res.connection_id, request_id: res.request_id, internal_response + }); + ClientMessage::TcpSteal(LayerTcpSteal::HttpResponseChunked(res)) + }, + }; + message_bus.send(msg).await; }, }, } diff --git a/mirrord/intproxy/src/proxies/incoming/interceptor.rs b/mirrord/intproxy/src/proxies/incoming/interceptor.rs index 6f30645e282..316745ab1a1 100644 --- a/mirrord/intproxy/src/proxies/incoming/interceptor.rs +++ b/mirrord/intproxy/src/proxies/incoming/interceptor.rs @@ -11,7 +11,7 @@ use bytes::BytesMut; use hyper::{upgrade::OnUpgrade, StatusCode, Version}; use hyper_util::rt::TokioIo; use mirrord_protocol::tcp::{ - HttpRequestFallback, HttpResponse, HttpResponseFallback, InternalHttpBody, + HttpRequestFallback, HttpResponse, HttpResponseFallback, InternalHttpBody, ReceiverStreamBody, }; use thiserror::Error; use tokio::{ @@ -258,16 +258,14 @@ impl HttpConnection { .map(HttpResponseFallback::Fallback) } HttpRequestFallback::Streamed(..) => { - // Returning `HttpResponseFallback::Framed` variant is safe - streaming - // requests require a strictly higher mirrord-protocol version - HttpResponse::::from_hyper_response( + HttpResponse::::from_hyper_response( res, self.peer.port(), request.connection_id(), request.request_id(), ) .await - .map(HttpResponseFallback::Framed) + .map(HttpResponseFallback::Streamed) } }; @@ -437,10 +435,7 @@ impl RawConnection { #[cfg(test)] mod test { - use std::{ - convert::Infallible, - sync::{Arc, Mutex}, - }; + use std::sync::{Arc, Mutex}; use bytes::Bytes; use futures::future::FutureExt; @@ -594,7 +589,7 @@ mod test { match update { TaskUpdate::Message(MessageOut::Http(res)) => { let res = res - .into_hyper::() + .into_hyper::() .expect("failed to convert into hyper response"); assert_eq!(res.status(), StatusCode::SWITCHING_PROTOCOLS); println!("{:?}", res.headers()); diff --git a/mirrord/intproxy/src/proxies/incoming/port_subscription_ext.rs b/mirrord/intproxy/src/proxies/incoming/port_subscription_ext.rs index 4c9e4cedb65..e928be69ace 100644 --- a/mirrord/intproxy/src/proxies/incoming/port_subscription_ext.rs +++ b/mirrord/intproxy/src/proxies/incoming/port_subscription_ext.rs @@ -2,12 +2,10 @@ use mirrord_intproxy_protocol::PortSubscription; use mirrord_protocol::{ - tcp::{HttpResponseFallback, LayerTcp, LayerTcpSteal, StealType, TcpData}, + tcp::{LayerTcp, LayerTcpSteal, StealType}, ClientMessage, ConnectionId, Port, }; -use super::interceptor::MessageOut; - /// Retrieves subscribed port from the given [`StealType`]. fn get_port(steal_type: &StealType) -> Port { match steal_type { @@ -31,10 +29,6 @@ pub trait PortSubscriptionExt { /// Returns an unsubscribe connection request to be sent to the agent. fn wrap_agent_unsubscribe_connection(&self, connection_id: ConnectionId) -> ClientMessage; - - /// Returns a message to be sent to the agent in response to data coming from an interceptor. - /// [`None`] means that the data should be discarded. - fn wrap_response(&self, res: MessageOut, connection_id: ConnectionId) -> Option; } impl PortSubscriptionExt for PortSubscription { @@ -74,26 +68,4 @@ impl PortSubscriptionExt for PortSubscription { } } } - - /// Always [`None`] for the `mirror` mode - data coming from the layer is discarded. - /// Corrent [`LayerTcpSteal`] variant for the `steal` mode. - fn wrap_response(&self, res: MessageOut, connection_id: ConnectionId) -> Option { - match self { - Self::Mirror(..) => None, - Self::Steal(..) => match res { - MessageOut::Raw(bytes) => { - Some(ClientMessage::TcpSteal(LayerTcpSteal::Data(TcpData { - connection_id, - bytes, - }))) - } - MessageOut::Http(HttpResponseFallback::Fallback(res)) => { - Some(ClientMessage::TcpSteal(LayerTcpSteal::HttpResponse(res))) - } - MessageOut::Http(HttpResponseFallback::Framed(res)) => Some( - ClientMessage::TcpSteal(LayerTcpSteal::HttpResponseFramed(res)), - ), - }, - } - } } diff --git a/mirrord/protocol/Cargo.toml b/mirrord/protocol/Cargo.toml index 43c44b00006..b191dea08f5 100644 --- a/mirrord/protocol/Cargo.toml +++ b/mirrord/protocol/Cargo.toml @@ -1,6 +1,6 @@ [package] name = "mirrord-protocol" -version = "1.7.0" +version = "1.8.0" authors.workspace = true description.workspace = true documentation.workspace = true diff --git a/mirrord/agent/src/steal/http/body_chunks.rs b/mirrord/protocol/src/body_chunks.rs similarity index 77% rename from mirrord/agent/src/steal/http/body_chunks.rs rename to mirrord/protocol/src/body_chunks.rs index 83dfe3a1758..e9e2a6cc073 100644 --- a/mirrord/agent/src/steal/http/body_chunks.rs +++ b/mirrord/protocol/src/body_chunks.rs @@ -5,14 +5,17 @@ use std::{ }; use bytes::Bytes; -use hyper::body::{Body, Frame, Incoming}; +use hyper::body::{Body, Frame}; -pub trait IncomingExt { - fn next_frames(&mut self, no_wait: bool) -> FramesFut<'_>; +pub trait BodyExt { + fn next_frames(&mut self, no_wait: bool) -> FramesFut<'_, B>; } -impl IncomingExt for Incoming { - fn next_frames(&mut self, no_wait: bool) -> FramesFut<'_> { +impl BodyExt for B +where + B: Body, +{ + fn next_frames(&mut self, no_wait: bool) -> FramesFut<'_, B> { FramesFut { body: self, no_wait, @@ -20,12 +23,15 @@ impl IncomingExt for Incoming { } } -pub struct FramesFut<'a> { - body: &'a mut Incoming, +pub struct FramesFut<'a, B> { + body: &'a mut B, no_wait: bool, } -impl<'a> Future for FramesFut<'a> { +impl<'a, B> Future for FramesFut<'a, B> +where + B: Body + Unpin, +{ type Output = hyper::Result; fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll { diff --git a/mirrord/protocol/src/lib.rs b/mirrord/protocol/src/lib.rs index c01e88ad2bb..afe503d5d2b 100644 --- a/mirrord/protocol/src/lib.rs +++ b/mirrord/protocol/src/lib.rs @@ -3,6 +3,7 @@ #![feature(lazy_cell)] #![warn(clippy::indexing_slicing)] +pub mod body_chunks; pub mod codec; pub mod dns; pub mod error; diff --git a/mirrord/protocol/src/tcp.rs b/mirrord/protocol/src/tcp.rs index 170d3d2dbf2..c8d811b089d 100644 --- a/mirrord/protocol/src/tcp.rs +++ b/mirrord/protocol/src/tcp.rs @@ -11,7 +11,7 @@ use std::{ use bincode::{Decode, Encode}; use bytes::Bytes; -use http_body_util::{combinators::BoxBody, BodyExt, Full}; +use http_body_util::{combinators::BoxBody, BodyExt, Full, StreamBody}; use hyper::{ body::{Body, Frame, Incoming}, http, @@ -22,9 +22,10 @@ use mirrord_macros::protocol_break; use semver::VersionReq; use serde::{Deserialize, Serialize}; use tokio::sync::mpsc::Receiver; +use tokio_stream::wrappers::ReceiverStream; use tracing::error; -use crate::{ConnectionId, Port, RemoteResult, RequestId}; +use crate::{body_chunks::BodyExt as _, ConnectionId, Port, RemoteResult, RequestId}; #[derive(Encode, Decode, Debug, PartialEq, Eq, Clone)] pub struct NewTcpConnection { @@ -81,13 +82,13 @@ pub enum DaemonTcp { #[derive(Encode, Decode, Debug, PartialEq, Eq, Clone)] pub enum ChunkedRequest { Start(HttpRequest>), - Body(ChunkedRequestBody), - Error(ChunkedRequestError), + Body(ChunkedHttpBody), + Error(ChunkedHttpError), } /// Contents of a chunked message body frame from server. #[derive(Encode, Decode, Debug, PartialEq, Eq, Clone)] -pub struct ChunkedRequestBody { +pub struct ChunkedHttpBody { #[bincode(with_serde)] pub frames: Vec, pub is_last: bool, @@ -106,7 +107,7 @@ impl From for Frame { /// An error occurred while processing chunked data from server. #[derive(Encode, Decode, Debug, PartialEq, Eq, Clone)] -pub struct ChunkedRequestError { +pub struct ChunkedHttpError { pub connection_id: ConnectionId, pub request_id: RequestId, } @@ -191,6 +192,14 @@ pub enum LayerTcpSteal { Data(TcpData), HttpResponse(HttpResponse>), HttpResponseFramed(HttpResponse), + HttpResponseChunked(ChunkedResponse), +} + +#[derive(Encode, Decode, Debug, PartialEq, Eq, Clone)] +pub enum ChunkedResponse { + Start(HttpResponse>), + Body(ChunkedHttpBody), + Error(ChunkedHttpError), } /// (De-)Serializable HTTP request. @@ -429,15 +438,15 @@ impl HttpRequest { #[derive(Serialize, Deserialize, Debug, PartialEq, Eq, Clone)] pub struct InternalHttpResponse { #[serde(with = "http_serde::status_code")] - status: StatusCode, + pub status: StatusCode, #[serde(with = "http_serde::version")] - version: Version, + pub version: Version, #[serde(with = "http_serde::header_map")] - headers: HeaderMap, + pub headers: HeaderMap, - body: Body, + pub body: Body, } impl InternalHttpResponse { @@ -536,10 +545,13 @@ impl fmt::Debug for InternalHttpBodyFrame { } } +pub type ReceiverStreamBody = StreamBody>>>; + #[derive(Debug)] pub enum HttpResponseFallback { Framed(HttpResponse), Fallback(HttpResponse>), + Streamed(HttpResponse), } impl HttpResponseFallback { @@ -547,6 +559,7 @@ impl HttpResponseFallback { match self { HttpResponseFallback::Framed(req) => req.connection_id, HttpResponseFallback::Fallback(req) => req.connection_id, + HttpResponseFallback::Streamed(req) => req.connection_id, } } @@ -554,13 +567,18 @@ impl HttpResponseFallback { match self { HttpResponseFallback::Framed(req) => req.request_id, HttpResponseFallback::Fallback(req) => req.request_id, + HttpResponseFallback::Streamed(req) => req.request_id, } } - pub fn into_hyper(self) -> Result>, http::Error> { + pub fn into_hyper(self) -> Result>, http::Error> + where + E: From, + { match self { HttpResponseFallback::Framed(req) => req.internal_response.try_into(), HttpResponseFallback::Fallback(req) => req.internal_response.try_into(), + HttpResponseFallback::Streamed(req) => req.internal_response.try_into(), } } @@ -576,8 +594,8 @@ impl HttpResponseFallback { HttpRequestFallback::Fallback(request) => HttpResponseFallback::Fallback( HttpResponse::>::response_from_request(request, status, message), ), - HttpRequestFallback::Streamed(request) => HttpResponseFallback::Framed( - HttpResponse::::response_from_request(request, status, message), + HttpRequestFallback::Streamed(request) => HttpResponseFallback::Streamed( + HttpResponse::::response_from_request(request, status, message), ), } } @@ -585,10 +603,7 @@ impl HttpResponseFallback { #[derive(Encode, Decode, Debug, PartialEq, Eq, Clone)] #[bincode(bounds = "for<'de> Body: Serialize + Deserialize<'de>")] -pub struct HttpResponse -where - for<'de> Body: Serialize + Deserialize<'de>, -{ +pub struct HttpResponse { /// This is used to make sure the response is sent in its turn, after responses to all earlier /// requests were already sent. pub port: Port, @@ -789,6 +804,88 @@ impl HttpResponse> { } } +impl HttpResponse { + pub async fn from_hyper_response( + response: Response, + port: Port, + connection_id: ConnectionId, + request_id: RequestId, + ) -> Result, hyper::Error> { + let ( + Parts { + status, + version, + headers, + .. + }, + mut body, + ) = response.into_parts(); + + let frames = body.next_frames(true).await?; + let (tx, rx) = tokio::sync::mpsc::channel(frames.frames.len().max(12)); + for frame in frames.frames { + tx.try_send(Ok(frame)) + .expect("Channel is open, capacity sufficient") + } + if !frames.is_last { + tokio::spawn(async move { + while let Some(frame) = body.frame().await { + if tx.send(frame).await.is_err() { + return; + } + } + }); + }; + + let body = StreamBody::new(ReceiverStream::from(rx)); + + let internal_response = InternalHttpResponse { + status, + headers, + version, + body, + }; + + Ok(HttpResponse { + request_id, + port, + connection_id, + internal_response, + }) + } + + pub fn response_from_request( + request: HttpRequest, + status: StatusCode, + message: &str, + ) -> Self { + let HttpRequest { + internal_request: InternalHttpRequest { version, .. }, + connection_id, + request_id, + port, + } = request; + + let (tx, rx) = tokio::sync::mpsc::channel(1); + let frame = Frame::data(Bytes::copy_from_slice(message.as_bytes())); + tx.try_send(Ok(frame)) + .expect("channel is open, capacity is sufficient"); + let body = StreamBody::new(ReceiverStream::new(rx)); + + Self { + port, + connection_id, + request_id, + internal_response: InternalHttpResponse { + status, + version, + headers: Default::default(), + body, + }, + } + } +} + impl TryFrom> for Response> { type Error = http::Error; @@ -830,3 +927,26 @@ impl TryFrom>> for Response> { )) } } + +impl TryFrom> for Response> +where + E: From, +{ + type Error = http::Error; + + fn try_from(value: InternalHttpResponse) -> Result { + let InternalHttpResponse { + status, + version, + headers, + body, + } = value; + + let mut builder = Response::builder().status(status).version(version); + if let Some(h) = builder.headers_mut() { + *h = headers; + } + + builder.body(BoxBody::new(body.map_err(|e| e.into()))) + } +}