Skip to content

Commit

Permalink
Fixed HTTP filter stuck forever (#3031)
Browse files Browse the repository at this point in the history
* Fixed TcpConnectionStealer and ChannelClosedFuture, added unit tests

* Changelog

* use rstest timeout on cleanup_on_client_closed stealer test
  • Loading branch information
Razz4780 authored Jan 23, 2025
1 parent 0aa4f40 commit afccbc8
Show file tree
Hide file tree
Showing 5 changed files with 263 additions and 40 deletions.
1 change: 1 addition & 0 deletions changelog.d/+http-filter-cleanup.fixed.md
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
Agent now correctly clears incoming port subscriptions of disconnected clients.
46 changes: 44 additions & 2 deletions mirrord/agent/src/sniffer.rs
Original file line number Diff line number Diff line change
Expand Up @@ -139,7 +139,7 @@ pub(crate) struct TcpConnectionSniffer<T> {
sessions: TCPSessionMap,

client_txs: HashMap<ClientId, Sender<SniffedConnection>>,
clients_closed: FuturesUnordered<ChannelClosedFuture<SniffedConnection>>,
clients_closed: FuturesUnordered<ChannelClosedFuture>,
}

impl<T> Drop for TcpConnectionSniffer<T> {
Expand Down Expand Up @@ -432,14 +432,15 @@ mod test {
atomic::{AtomicUsize, Ordering},
Arc,
},
time::Duration,
time::{Duration, Instant},
};

use api::TcpSnifferApi;
use mirrord_protocol::{
tcp::{DaemonTcp, LayerTcp, NewTcpConnection, TcpClose, TcpData},
ConnectionId, LogLevel,
};
use rstest::rstest;
use tcp_capture::test::TcpPacketsChannel;
use tokio::sync::mpsc;

Expand Down Expand Up @@ -856,4 +857,45 @@ mod test {
}),
);
}

/// Checks that dropping a client's [`TcpSnifferApi`] makes the [`TcpConnectionSniffer`]
/// update the packet filter one more time.
///
/// Subscribing to port 80 accounts for the first filter change. After the API is dropped,
/// the test polls the filter change counter every 20ms until it reaches 2.
/// The `rstest` timeout bounds the polling, so a missing cleanup fails the test
/// instead of hanging it.
#[rstest]
#[timeout(Duration::from_secs(5))]
#[tokio::test]
async fn cleanup_on_client_closed() {
    let mut setup = TestSnifferSetup::new();
    let mut api = setup.get_api().await;

    api.handle_client_message(LayerTcp::PortSubscribe(80))
        .await
        .unwrap();
    let subscribe_response = api.recv().await.unwrap();
    assert_eq!(
        subscribe_response,
        (DaemonTcp::SubscribeResult(Ok(80)), None),
    );
    assert_eq!(setup.times_filter_changed(), 1);

    drop(api);
    let dropped_at = Instant::now();

    loop {
        // Read the counter exactly once per iteration, so all checks below see the same value.
        let changes = setup.times_filter_changed();

        if changes == 2 {
            break;
        }

        if changes != 1 {
            panic!("unexpected times filter changed {other}", other = changes);
        }

        println!(
            "filter still not changed {}ms after client closed",
            dropped_at.elapsed().as_millis()
        );
        tokio::time::sleep(Duration::from_millis(20)).await;
    }
}
}
166 changes: 143 additions & 23 deletions mirrord/agent/src/steal/connection.rs
Original file line number Diff line number Diff line change
Expand Up @@ -30,9 +30,9 @@ use tokio::{
use tokio_util::sync::CancellationToken;
use tracing::{warn, Level};

use super::http::HttpResponseFallback;
use super::{http::HttpResponseFallback, subscriptions::PortRedirector};
use crate::{
error::AgentResult,
error::{AgentError, AgentResult},
metrics::HTTP_REQUEST_IN_PROGRESS_COUNT,
steal::{
connections::{
Expand Down Expand Up @@ -292,9 +292,9 @@ struct TcpStealerConfig {
/// run in the same network namespace as the agent's target.
///
/// Enabled by the `steal` feature for incoming traffic.
pub(crate) struct TcpConnectionStealer {
pub(crate) struct TcpConnectionStealer<Redirector: PortRedirector = IpTablesRedirector> {
/// For managing active subscriptions and port redirections.
port_subscriptions: PortSubscriptions<IpTablesRedirector>,
port_subscriptions: PortSubscriptions<Redirector>,

/// For receiving commands.
/// The other end of this channel belongs to [`TcpStealerApi`](super::api::TcpStealerApi).
Expand All @@ -304,7 +304,7 @@ pub(crate) struct TcpConnectionStealer {
clients: HashMap<ClientId, Client>,

/// [`Future`](std::future::Future)s that resolve when stealer clients close.
clients_closed: FuturesUnordered<ChannelClosedFuture<StealerCommand>>,
clients_closed: FuturesUnordered<ChannelClosedFuture>,

/// Set of active connections stolen by [`Self::port_subscriptions`].
connections: StolenConnections,
Expand All @@ -313,7 +313,7 @@ pub(crate) struct TcpConnectionStealer {
support_ipv6: bool,
}

impl TcpConnectionStealer {
impl TcpConnectionStealer<IpTablesRedirector> {
pub const TASK_NAME: &'static str = "Stealer";

/// Initializes a new [`TcpConnectionStealer`], but doesn't start the actual work.
Expand All @@ -327,25 +327,39 @@ impl TcpConnectionStealer {
.from_env::<TcpStealerConfig>()
.unwrap_or_default();

let port_subscriptions = {
let redirector = IpTablesRedirector::new(
config.stealer_flush_connections,
config.pod_ips,
support_ipv6,
)
.await?;
let redirector = IpTablesRedirector::new(
config.stealer_flush_connections,
config.pod_ips,
support_ipv6,
)
.await?;

PortSubscriptions::new(redirector, 4)
};
Ok(Self::with_redirector(command_rx, support_ipv6, redirector))
}
}

Ok(Self {
port_subscriptions,
impl<Redirector> TcpConnectionStealer<Redirector>
where
Redirector: PortRedirector,
Redirector::Error: std::error::Error + Into<AgentError>,
AgentError: From<Redirector::Error>,
{
/// Creates a new stealer.
///
/// Given [`PortRedirector`] will be used to capture incoming connections.
///
/// This function only assembles the struct: the `redirector` is wrapped in
/// [`PortSubscriptions`], `command_rx` and `support_ipv6` are stored as given,
/// and the clients map is created empty. Nothing is read from `command_rx` here.
pub(crate) fn with_redirector(
command_rx: Receiver<StealerCommand>,
support_ipv6: bool,
redirector: Redirector,
) -> Self {
Self {
port_subscriptions: PortSubscriptions::new(redirector, 4),
command_rx,
clients: HashMap::with_capacity(8),
clients_closed: Default::default(),
connections: StolenConnections::with_capacity(8),
support_ipv6,
})
}
}

/// Runs the tcp traffic stealer loop.
Expand Down Expand Up @@ -383,7 +397,7 @@ impl TcpConnectionStealer {
}
Err(error) => {
tracing::error!(?error, "Failed to accept a stolen connection");
break Err(error);
break Err(error.into());
}
},

Expand Down Expand Up @@ -644,6 +658,8 @@ impl TcpConnectionStealer {

match command {
Command::NewClient(daemon_tx, protocol_version) => {
self.clients_closed
.push(ChannelClosedFuture::new(daemon_tx.clone(), client_id));
self.clients.insert(
client_id,
Client {
Expand Down Expand Up @@ -708,7 +724,7 @@ impl TcpConnectionStealer {

#[cfg(test)]
mod test {
use std::net::SocketAddr;
use std::{net::SocketAddr, time::Duration};

use bytes::Bytes;
use futures::{future::BoxFuture, FutureExt};
Expand All @@ -719,18 +735,75 @@ mod test {
service::Service,
};
use hyper_util::rt::TokioIo;
use mirrord_protocol::tcp::{ChunkedRequest, DaemonTcp, InternalHttpBodyFrame};
use mirrord_protocol::{
tcp::{ChunkedRequest, DaemonTcp, Filter, HttpFilter, InternalHttpBodyFrame, StealType},
Port,
};
use rstest::rstest;
use tokio::{
net::{TcpListener, TcpStream},
sync::{
mpsc::{self, Receiver, Sender},
oneshot,
oneshot, watch,
},
};
use tokio_stream::wrappers::ReceiverStream;
use tokio_util::sync::CancellationToken;

use super::AgentError;
use crate::{
steal::{
connection::{Client, MatchedHttpRequest},
subscriptions::PortRedirector,
TcpConnectionStealer, TcpStealerApi,
},
watched_task::TaskStatus,
};

/// Notification about a requested redirection operation.
///
/// Produced by [`NotifyingRedirector`], which sends one notification for each call to its
/// `add_redirection`, `remove_redirection` and `cleanup` methods.
#[derive(Debug, PartialEq, Eq)]
enum RedirectNotification {
/// Sent when `add_redirection` is called. Carries the port passed to that call.
Added(Port),
/// Sent when `remove_redirection` is called. Carries the port passed to that call.
Removed(Port),
/// Sent when `cleanup` is called.
Cleanup,
}

/// Test [`PortRedirector`] that never fails and notifies about requested operations using an
/// [`mpsc::channel`].
///
/// The wrapped [`Sender`] is the producing end of that channel. Each redirection operation
/// sends one [`RedirectNotification`] through it and panics if the send fails.
/// `next_connection` never resolves, so this redirector never produces a connection.
struct NotifyingRedirector(Sender<RedirectNotification>);

#[async_trait::async_trait]
impl PortRedirector for NotifyingRedirector {
type Error = AgentError;

async fn add_redirection(&mut self, port: Port) -> Result<(), Self::Error> {
self.0
.send(RedirectNotification::Added(port))
.await
.unwrap();
Ok(())
}

async fn remove_redirection(&mut self, port: Port) -> Result<(), Self::Error> {
self.0
.send(RedirectNotification::Removed(port))
.await
.unwrap();
Ok(())
}

async fn cleanup(&mut self) -> Result<(), Self::Error> {
self.0.send(RedirectNotification::Cleanup).await.unwrap();
Ok(())
}

async fn next_connection(&mut self) -> Result<(TcpStream, SocketAddr), Self::Error> {
std::future::pending().await
}
}

use crate::steal::connection::{Client, MatchedHttpRequest};
async fn prepare_dummy_service() -> (
SocketAddr,
Receiver<(Request<Incoming>, oneshot::Sender<Response<Empty<Bytes>>>)>,
Expand Down Expand Up @@ -907,4 +980,51 @@ mod test {

let _ = response_tx.send(Response::new(Empty::default()));
}

/// Checks that dropping a client's [`TcpStealerApi`] makes the [`TcpConnectionStealer`]
/// remove that client's port subscription.
///
/// The stealer runs with a [`NotifyingRedirector`], so the test can observe the redirection
/// for port 80 being added after the subscription and removed after the API is dropped.
/// The `rstest` timeout makes the test fail instead of hanging when no removal is reported.
#[rstest]
#[timeout(Duration::from_secs(5))]
#[tokio::test]
async fn cleanup_on_client_closed() {
    let (command_tx, command_rx) = mpsc::channel(8);
    let (redirect_tx, mut redirect_rx) = mpsc::channel(2);

    let stealer = TcpConnectionStealer::with_redirector(
        command_rx,
        false,
        NotifyingRedirector(redirect_tx),
    );
    tokio::spawn(stealer.start(CancellationToken::new()));

    // The sender half is bound to a named variable so that it lives until the end of the test.
    let (_status_tx, status_rx) = watch::channel(None);
    let task_status = TaskStatus::dummy(TcpConnectionStealer::TASK_NAME, status_rx);
    let mut api = TcpStealerApi::new(
        0,
        command_tx.clone(),
        task_status,
        8,
        mirrord_protocol::VERSION.clone(),
    )
    .await
    .unwrap();

    let filter = HttpFilter::Header(Filter::new("user: test".into()).unwrap());
    api.port_subscribe(StealType::FilteredHttpEx(80, filter))
        .await
        .unwrap();

    assert_eq!(
        api.recv().await.unwrap(),
        DaemonTcp::SubscribeResult(Ok(80))
    );
    assert_eq!(
        redirect_rx.recv().await.unwrap(),
        RedirectNotification::Added(80)
    );

    drop(api);

    assert_eq!(
        redirect_rx.recv().await.unwrap(),
        RedirectNotification::Removed(80)
    );
}
}
Loading

0 comments on commit afccbc8

Please sign in to comment.