diff --git a/changelog.d/+http-filter-cleanup.fixed.md b/changelog.d/+http-filter-cleanup.fixed.md
new file mode 100644
index 00000000000..92adcb9d93c
--- /dev/null
+++ b/changelog.d/+http-filter-cleanup.fixed.md
@@ -0,0 +1 @@
+Agent now correctly clears incoming port subscriptions of disconnected clients.
diff --git a/mirrord/agent/src/sniffer.rs b/mirrord/agent/src/sniffer.rs
index 0d5cccfb584..94c40e0ce67 100644
--- a/mirrord/agent/src/sniffer.rs
+++ b/mirrord/agent/src/sniffer.rs
@@ -139,7 +139,7 @@ pub(crate) struct TcpConnectionSniffer<R> {
     sessions: TCPSessionMap,
     client_txs: HashMap<ClientId, Sender<SniffedConnection>>,
-    clients_closed: FuturesUnordered<ChannelClosedFuture<SniffedConnection>>,
+    clients_closed: FuturesUnordered<ChannelClosedFuture>,
 }
 
 impl<R> Drop for TcpConnectionSniffer<R> {
@@ -432,7 +432,7 @@ mod test {
             atomic::{AtomicUsize, Ordering},
             Arc,
         },
-        time::Duration,
+        time::{Duration, Instant},
     };
 
     use api::TcpSnifferApi;
@@ -440,6 +440,7 @@ mod test {
         tcp::{DaemonTcp, LayerTcp, NewTcpConnection, TcpClose, TcpData},
         ConnectionId, LogLevel,
     };
+    use rstest::rstest;
     use tcp_capture::test::TcpPacketsChannel;
     use tokio::sync::mpsc;
@@ -856,4 +857,45 @@ mod test {
             }),
         );
     }
+
+    /// Verifies that [`TcpConnectionSniffer`] reacts to [`TcpSnifferApi`] being dropped
+    /// and clears the packet filter.
+    #[rstest]
+    #[timeout(Duration::from_secs(5))]
+    #[tokio::test]
+    async fn cleanup_on_client_closed() {
+        let mut setup = TestSnifferSetup::new();
+
+        let mut api = setup.get_api().await;
+
+        api.handle_client_message(LayerTcp::PortSubscribe(80))
+            .await
+            .unwrap();
+        assert_eq!(
+            api.recv().await.unwrap(),
+            (DaemonTcp::SubscribeResult(Ok(80)), None),
+        );
+        assert_eq!(setup.times_filter_changed(), 1);
+
+        std::mem::drop(api);
+        let dropped_at = Instant::now();
+
+        loop {
+            match setup.times_filter_changed() {
+                1 => {
+                    println!(
+                        "filter still not changed {}ms after client closed",
+                        dropped_at.elapsed().as_millis()
+                    );
+                    tokio::time::sleep(Duration::from_millis(20)).await;
+                }
+
+                2 => {
+                    break;
+                }
+
+                other => panic!("unexpected times filter changed {other}"),
+            }
+        }
+    }
 }
diff --git a/mirrord/agent/src/steal/connection.rs b/mirrord/agent/src/steal/connection.rs
index 6515f2ecfc9..f6b4a9f2b7b 100644
--- a/mirrord/agent/src/steal/connection.rs
+++ b/mirrord/agent/src/steal/connection.rs
@@ -30,9 +30,9 @@ use tokio::{
 use tokio_util::sync::CancellationToken;
 use tracing::{warn, Level};
 
-use super::http::HttpResponseFallback;
+use super::{http::HttpResponseFallback, subscriptions::PortRedirector};
 use crate::{
-    error::AgentResult,
+    error::{AgentError, AgentResult},
     metrics::HTTP_REQUEST_IN_PROGRESS_COUNT,
     steal::{
         connections::{
@@ -292,9 +292,9 @@ struct TcpStealerConfig {
 /// run in the same network namespace as the agent's target.
 ///
 /// Enabled by the `steal` feature for incoming traffic.
-pub(crate) struct TcpConnectionStealer {
+pub(crate) struct TcpConnectionStealer<Redirector = IpTablesRedirector> {
     /// For managing active subscriptions and port redirections.
-    port_subscriptions: PortSubscriptions<IpTablesRedirector>,
+    port_subscriptions: PortSubscriptions<Redirector>,
 
     /// For receiving commands.
     /// The other end of this channel belongs to [`TcpStealerApi`](super::api::TcpStealerApi).
@@ -304,7 +304,7 @@ pub(crate) struct TcpConnectionStealer {
     clients: HashMap<ClientId, Client>,
 
     /// [`Future`](std::future::Future)s that resolve when stealer clients close.
-    clients_closed: FuturesUnordered<ChannelClosedFuture<DaemonTcp>>,
+    clients_closed: FuturesUnordered<ChannelClosedFuture>,
 
     /// Set of active connections stolen by [`Self::port_subscriptions`].
     connections: StolenConnections,
@@ -313,7 +313,7 @@ pub(crate) struct TcpConnectionStealer {
     support_ipv6: bool,
 }
 
-impl TcpConnectionStealer {
+impl TcpConnectionStealer<IpTablesRedirector> {
     pub const TASK_NAME: &'static str = "Stealer";
 
     /// Initializes a new [`TcpConnectionStealer`], but doesn't start the actual work.
@@ -327,25 +327,39 @@ impl TcpConnectionStealer {
             .from_env::<TcpStealerConfig>()
             .unwrap_or_default();
 
-        let port_subscriptions = {
-            let redirector = IpTablesRedirector::new(
-                config.stealer_flush_connections,
-                config.pod_ips,
-                support_ipv6,
-            )
-            .await?;
+        let redirector = IpTablesRedirector::new(
+            config.stealer_flush_connections,
+            config.pod_ips,
+            support_ipv6,
+        )
+        .await?;
 
-            PortSubscriptions::new(redirector, 4)
-        };
+        Ok(Self::with_redirector(command_rx, support_ipv6, redirector))
+    }
+}
 
-        Ok(Self {
-            port_subscriptions,
+impl<Redirector> TcpConnectionStealer<Redirector>
+where
+    Redirector: PortRedirector,
+    Redirector::Error: std::error::Error + Into<AgentError>,
+    AgentError: From<Redirector::Error>,
+{
+    /// Creates a new stealer.
+    ///
+    /// The given [`PortRedirector`] will be used to capture incoming connections.
+    pub(crate) fn with_redirector(
+        command_rx: Receiver<StealerCommand>,
+        support_ipv6: bool,
+        redirector: Redirector,
+    ) -> Self {
+        Self {
+            port_subscriptions: PortSubscriptions::new(redirector, 4),
             command_rx,
             clients: HashMap::with_capacity(8),
             clients_closed: Default::default(),
             connections: StolenConnections::with_capacity(8),
             support_ipv6,
-        })
+        }
     }
 
     /// Runs the tcp traffic stealer loop.
@@ -383,7 +397,7 @@ impl TcpConnectionStealer {
                 }
                 Err(error) => {
                     tracing::error!(?error, "Failed to accept a stolen connection");
-                    break Err(error);
+                    break Err(error.into());
                 }
             },
@@ -644,6 +658,8 @@ impl TcpConnectionStealer {
         match command {
             Command::NewClient(daemon_tx, protocol_version) => {
+                self.clients_closed
+                    .push(ChannelClosedFuture::new(daemon_tx.clone(), client_id));
                 self.clients.insert(
                     client_id,
                     Client {
@@ -708,7 +724,7 @@ impl TcpConnectionStealer {
 
 #[cfg(test)]
 mod test {
-    use std::net::SocketAddr;
+    use std::{net::SocketAddr, time::Duration};
 
     use bytes::Bytes;
     use futures::{future::BoxFuture, FutureExt};
@@ -719,18 +735,75 @@ mod test {
         service::Service,
     };
     use hyper_util::rt::TokioIo;
-    use mirrord_protocol::tcp::{ChunkedRequest, DaemonTcp, InternalHttpBodyFrame};
+    use mirrord_protocol::{
+        tcp::{ChunkedRequest, DaemonTcp, Filter, HttpFilter, InternalHttpBodyFrame, StealType},
+        Port,
+    };
     use rstest::rstest;
     use tokio::{
         net::{TcpListener, TcpStream},
         sync::{
             mpsc::{self, Receiver, Sender},
-            oneshot,
+            oneshot, watch,
         },
     };
     use tokio_stream::wrappers::ReceiverStream;
+    use tokio_util::sync::CancellationToken;
+
+    use super::AgentError;
+    use crate::{
+        steal::{
+            connection::{Client, MatchedHttpRequest},
+            subscriptions::PortRedirector,
+            TcpConnectionStealer, TcpStealerApi,
+        },
+        watched_task::TaskStatus,
+    };
+
+    /// Notification about a requested redirection operation.
+    ///
+    /// Produced by [`NotifyingRedirector`].
+    #[derive(Debug, PartialEq, Eq)]
+    enum RedirectNotification {
+        Added(Port),
+        Removed(Port),
+        Cleanup,
+    }
+
+    /// Test [`PortRedirector`] that never fails and notifies about requested operations using an
+    /// [`mpsc::channel`].
+    struct NotifyingRedirector(Sender<RedirectNotification>);
+
+    #[async_trait::async_trait]
+    impl PortRedirector for NotifyingRedirector {
+        type Error = AgentError;
+
+        async fn add_redirection(&mut self, port: Port) -> Result<(), Self::Error> {
+            self.0
+                .send(RedirectNotification::Added(port))
+                .await
+                .unwrap();
+            Ok(())
+        }
+
+        async fn remove_redirection(&mut self, port: Port) -> Result<(), Self::Error> {
+            self.0
+                .send(RedirectNotification::Removed(port))
+                .await
+                .unwrap();
+            Ok(())
+        }
+
+        async fn cleanup(&mut self) -> Result<(), Self::Error> {
+            self.0.send(RedirectNotification::Cleanup).await.unwrap();
+            Ok(())
+        }
+
+        async fn next_connection(&mut self) -> Result<(TcpStream, SocketAddr), Self::Error> {
+            std::future::pending().await
+        }
+    }
 
-    use crate::steal::connection::{Client, MatchedHttpRequest};
     async fn prepare_dummy_service() -> (
         SocketAddr,
         Receiver<(Request<Incoming>, oneshot::Sender<Response<Empty<Bytes>>>)>,
@@ -907,4 +980,51 @@ mod test {
 
         let _ = response_tx.send(Response::new(Empty::default()));
     }
+
+    /// Verifies that [`TcpConnectionStealer`] removes a client's port subscriptions
+    /// when the client's [`TcpStealerApi`] is dropped.
+    #[rstest]
+    #[timeout(Duration::from_secs(5))]
+    #[tokio::test]
+    async fn cleanup_on_client_closed() {
+        let (command_tx, command_rx) = mpsc::channel(8);
+        let (redirect_tx, mut redirect_rx) = mpsc::channel(2);
+        let stealer = TcpConnectionStealer::with_redirector(
+            command_rx,
+            false,
+            NotifyingRedirector(redirect_tx),
+        );
+
+        tokio::spawn(stealer.start(CancellationToken::new()));
+
+        let (_dummy_tx, dummy_rx) = watch::channel(None);
+        let task_status = TaskStatus::dummy(TcpConnectionStealer::TASK_NAME, dummy_rx);
+        let mut api = TcpStealerApi::new(
+            0,
+            command_tx.clone(),
+            task_status,
+            8,
+            mirrord_protocol::VERSION.clone(),
+        )
+        .await
+        .unwrap();
+
+        api.port_subscribe(StealType::FilteredHttpEx(
+            80,
+            HttpFilter::Header(Filter::new("user: test".into()).unwrap()),
+        ))
+        .await
+        .unwrap();
+
+        let response = api.recv().await.unwrap();
+        assert_eq!(response, DaemonTcp::SubscribeResult(Ok(80)));
+
+        let notification = redirect_rx.recv().await.unwrap();
+        assert_eq!(notification, RedirectNotification::Added(80));
+
+        std::mem::drop(api);
+
+        let notification = redirect_rx.recv().await.unwrap();
+        assert_eq!(notification, RedirectNotification::Removed(80));
+    }
 }
diff --git a/mirrord/agent/src/util.rs b/mirrord/agent/src/util.rs
index 0c72cddd82f..c5a002979e9 100644
--- a/mirrord/agent/src/util.rs
+++ b/mirrord/agent/src/util.rs
@@ -8,6 +8,7 @@ use std::{
     thread::JoinHandle,
 };
 
+use futures::{future::BoxFuture, FutureExt};
 use tokio::sync::mpsc;
 use tracing::error;
 
@@ -162,27 +163,25 @@ pub(crate) fn enter_namespace(pid: Option<u64>, namespace: &str) -> AgentResult<()> {
 }
 
 /// [`Future`] that resolves to [`ClientId`] when the client drops their [`mpsc::Receiver`].
-pub(crate) struct ChannelClosedFuture<T> {
-    tx: mpsc::Sender<T>,
-    client_id: ClientId,
-}
+pub(crate) struct ChannelClosedFuture(BoxFuture<'static, ClientId>);
+
+impl ChannelClosedFuture {
+    pub(crate) fn new<T: Send + 'static>(tx: mpsc::Sender<T>, client_id: ClientId) -> Self {
+        let future = async move {
+            tx.closed().await;
+            client_id
+        }
+        .boxed();
 
-impl<T> ChannelClosedFuture<T> {
-    pub(crate) fn new(tx: mpsc::Sender<T>, client_id: ClientId) -> Self {
-        Self { tx, client_id }
+        Self(future)
     }
 }
 
-impl<T> Future for ChannelClosedFuture<T> {
+impl Future for ChannelClosedFuture {
     type Output = ClientId;
 
     fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
-        let client_id = self.client_id;
-
-        let future = std::pin::pin!(self.get_mut().tx.closed());
-        std::task::ready!(future.poll(cx));
-
-        Poll::Ready(client_id)
+        self.get_mut().0.as_mut().poll(cx)
     }
 }
 
@@ -264,3 +263,52 @@ mod subscription_tests {
         assert_eq!(subscriptions.get_subscribed_topics(), Vec::<u32>::new());
    }
 }
+
+#[cfg(test)]
+mod channel_closed_tests {
+    use std::time::Duration;
+
+    use futures::{stream::FuturesUnordered, FutureExt, StreamExt};
+    use rstest::rstest;
+
+    use super::*;
+
+    /// Verifies that [`ChannelClosedFuture`] resolves when the related [`mpsc::Receiver`] is
+    /// dropped.
+    #[rstest]
+    #[timeout(Duration::from_secs(5))]
+    #[tokio::test]
+    async fn channel_closed_resolves() {
+        let (tx, rx) = mpsc::channel::<()>(1);
+        let future = ChannelClosedFuture::new(tx, 0);
+        std::mem::drop(rx);
+        assert_eq!(future.await, 0);
+    }
+
+    /// Verifies that [`ChannelClosedFuture`] works fine when used in [`FuturesUnordered`].
+    ///
+    /// The future used to hold the [`mpsc::Sender`] and poll [`mpsc::Sender::closed`] in its
+    /// [`Future::poll`] implementation. This worked fine when the future was used in a simple
+    /// way (the [`channel_closed_resolves`] test was passing).
+    ///
+    /// However, [`FuturesUnordered::next`] was hanging forever due to [`mpsc::Sender::closed`]
+    /// implementation details.
+    ///
+    /// The new implementation of [`ChannelClosedFuture`] uses a [`BoxFuture`] internally, which
+    /// works fine.
+    #[rstest]
+    #[timeout(Duration::from_secs(5))]
+    #[tokio::test]
+    async fn channel_closed_works_in_futures_unordered() {
+        let mut unordered: FuturesUnordered<ChannelClosedFuture> = FuturesUnordered::new();
+
+        let (tx, rx) = mpsc::channel::<()>(1);
+        let future = ChannelClosedFuture::new(tx, 0);
+
+        unordered.push(future);
+
+        assert!(unordered.next().now_or_never().is_none());
+        std::mem::drop(rx);
+        assert_eq!(unordered.next().await.unwrap(), 0);
+    }
+}
diff --git a/mirrord/agent/src/watched_task.rs b/mirrord/agent/src/watched_task.rs
index ad06bb238ee..2e7370b262c 100644
--- a/mirrord/agent/src/watched_task.rs
+++ b/mirrord/agent/src/watched_task.rs
@@ -94,9 +94,21 @@ where
 }
 
 #[cfg(test)]
-mod test {
+pub(crate) mod test {
    use super::*;
 
+    impl TaskStatus {
+        pub fn dummy(
+            task_name: &'static str,
+            result_rx: Receiver<Option<AgentResult<()>>>,
+        ) -> Self {
+            Self {
+                task_name,
+                result_rx,
+            }
+        }
+    }
+
     #[tokio::test]
     async fn simple_successful() {
         let task = WatchedTask::new("task", async move { Ok(()) });
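
A note on why the `ChannelClosedFuture` rewrite in `mirrord/agent/src/util.rs` matters: the old `poll` built a brand-new `Sender::closed()` future on every call (via `std::pin::pin!`) and dropped it before returning. Tokio's notification primitives deregister a pending waiter when its future is dropped, so no waker registration survived past the end of each `poll`. A plain `.await` still appeared to work whenever the receiver was already gone before the first poll, but `FuturesUnordered` only re-polls an entry after its waker fires, so the entry was never polled again and `next()` hung forever. The sketch below contrasts the two shapes; `PollsFreshFuture` and `PollsStoredFuture` are illustrative names, not types from this codebase:

```rust
use std::{
    future::Future,
    pin::Pin,
    task::{Context, Poll},
};

use futures::{future::BoxFuture, FutureExt};
use tokio::sync::mpsc;

/// Shape of the OLD implementation (illustrative): a fresh `closed()` future
/// is created on every poll and dropped when `poll` returns, so the waker it
/// registered is torn down with it and nothing ever wakes the task again.
struct PollsFreshFuture {
    tx: mpsc::Sender<()>,
}

impl Future for PollsFreshFuture {
    type Output = ();

    fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<()> {
        // This future lives only for the duration of this single call.
        let fresh = std::pin::pin!(self.get_mut().tx.closed());
        fresh.poll(cx)
    }
}

/// Shape of the NEW implementation (illustrative): the future is created
/// once, stored, and the *same* instance is polled every time, so its waker
/// registration survives between polls.
struct PollsStoredFuture(BoxFuture<'static, ()>);

impl PollsStoredFuture {
    fn new(tx: mpsc::Sender<()>) -> Self {
        Self(async move { tx.closed().await }.boxed())
    }
}

impl Future for PollsStoredFuture {
    type Output = ();

    fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<()> {
        self.get_mut().0.as_mut().poll(cx)
    }
}

#[tokio::main]
async fn main() {
    let (tx, rx) = mpsc::channel::<()>(1);
    let stored = PollsStoredFuture::new(tx);
    drop(rx);
    stored.await; // Resolves: the stored `closed()` future keeps its waker alive.
}
```

Storing one boxed future and polling that same instance each time, as the new `ChannelClosedFuture` does, is what makes the drop of the receiver reliably wake the task.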
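The other half of the fix is registering clients for cleanup: on `Command::NewClient` the stealer now pushes a `ChannelClosedFuture` into `clients_closed`, so its main loop can treat a client disconnect as just another event and drop that client's port subscriptions. Below is a minimal, self-contained sketch of the pattern, assuming only the `tokio` and `futures` crates; `client_closed` is a stand-in for `ChannelClosedFuture`, and the cleanup comment stands in for the real `PortSubscriptions` bookkeeping:

```rust
use futures::{stream::FuturesUnordered, StreamExt};
use tokio::sync::mpsc;

type ClientId = u32;

// Stand-in for ChannelClosedFuture: resolves to the client id once the
// receiver half of the client's channel is dropped. Because the `closed()`
// future lives inside this async fn's state machine, its waker registration
// survives across polls (the same property the fixed ChannelClosedFuture
// provides via a stored BoxFuture).
async fn client_closed(tx: mpsc::Sender<()>, client_id: ClientId) -> ClientId {
    tx.closed().await;
    client_id
}

#[tokio::main]
async fn main() {
    let mut clients_closed = FuturesUnordered::new();

    // Register a client, like the stealer does on Command::NewClient.
    let (tx, rx) = mpsc::channel::<()>(1);
    clients_closed.push(client_closed(tx, 0));

    // Simulate the client disconnecting by dropping its receiver.
    drop(rx);

    // A select loop would learn about the disconnection here and remove the
    // client's port subscriptions, reverting redirections nobody uses anymore.
    let closed_id = clients_closed.next().await.unwrap();
    assert_eq!(closed_id, 0);
    println!("client {closed_id} disconnected, cleaning up subscriptions");
}
```

In the real stealer this stream is one branch of the main `tokio::select!` loop, and the `Redirector` type parameter introduced above is what lets the new test observe the cleanup through `NotifyingRedirector` without touching iptables.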