Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fixed HTTP filter stuck forever #3031

Merged
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions changelog.d/+http-filter-cleanup.fixed.md
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
Agent now correctly clears incoming port subscriptions of disconnected clients.
46 changes: 44 additions & 2 deletions mirrord/agent/src/sniffer.rs
Original file line number Diff line number Diff line change
Expand Up @@ -139,7 +139,7 @@ pub(crate) struct TcpConnectionSniffer<T> {
sessions: TCPSessionMap,

client_txs: HashMap<ClientId, Sender<SniffedConnection>>,
clients_closed: FuturesUnordered<ChannelClosedFuture<SniffedConnection>>,
clients_closed: FuturesUnordered<ChannelClosedFuture>,
}

impl<T> Drop for TcpConnectionSniffer<T> {
Expand Down Expand Up @@ -432,14 +432,15 @@ mod test {
atomic::{AtomicUsize, Ordering},
Arc,
},
time::Duration,
time::{Duration, Instant},
};

use api::TcpSnifferApi;
use mirrord_protocol::{
tcp::{DaemonTcp, LayerTcp, NewTcpConnection, TcpClose, TcpData},
ConnectionId, LogLevel,
};
use rstest::rstest;
use tcp_capture::test::TcpPacketsChannel;
use tokio::sync::mpsc;

Expand Down Expand Up @@ -856,4 +857,45 @@ mod test {
}),
);
}

/// Verifies that [`TcpConnectionSniffer`] reacts to [`TcpSnifferApi`] being dropped
/// and clears the packet filter.
#[rstest]
#[timeout(Duration::from_secs(5))]
#[tokio::test]
async fn cleanup_on_client_closed() {
let mut setup = TestSnifferSetup::new();

let mut api = setup.get_api().await;

api.handle_client_message(LayerTcp::PortSubscribe(80))
.await
.unwrap();
assert_eq!(
api.recv().await.unwrap(),
(DaemonTcp::SubscribeResult(Ok(80)), None),
);
assert_eq!(setup.times_filter_changed(), 1);

std::mem::drop(api);
let dropped_at = Instant::now();

loop {
match setup.times_filter_changed() {
1 => {
println!(
"filter still not changed {}ms after client closed",
dropped_at.elapsed().as_millis()
);
tokio::time::sleep(Duration::from_millis(20)).await;
}

2 => {
break;
}

other => panic!("unexpected times filter changed {other}"),
}
}
}
}
166 changes: 143 additions & 23 deletions mirrord/agent/src/steal/connection.rs
Original file line number Diff line number Diff line change
Expand Up @@ -30,9 +30,9 @@ use tokio::{
use tokio_util::sync::CancellationToken;
use tracing::{warn, Level};

use super::http::HttpResponseFallback;
use super::{http::HttpResponseFallback, subscriptions::PortRedirector};
use crate::{
error::AgentResult,
error::{AgentError, AgentResult},
metrics::HTTP_REQUEST_IN_PROGRESS_COUNT,
steal::{
connections::{
Expand Down Expand Up @@ -292,9 +292,9 @@ struct TcpStealerConfig {
/// run in the same network namespace as the agent's target.
///
/// Enabled by the `steal` feature for incoming traffic.
pub(crate) struct TcpConnectionStealer {
pub(crate) struct TcpConnectionStealer<Redirector: PortRedirector = IpTablesRedirector> {
/// For managing active subscriptions and port redirections.
port_subscriptions: PortSubscriptions<IpTablesRedirector>,
port_subscriptions: PortSubscriptions<Redirector>,

/// For receiving commands.
/// The other end of this channel belongs to [`TcpStealerApi`](super::api::TcpStealerApi).
Expand All @@ -304,7 +304,7 @@ pub(crate) struct TcpConnectionStealer {
clients: HashMap<ClientId, Client>,

/// [`Future`](std::future::Future)s that resolve when stealer clients close.
clients_closed: FuturesUnordered<ChannelClosedFuture<StealerCommand>>,
clients_closed: FuturesUnordered<ChannelClosedFuture>,

/// Set of active connections stolen by [`Self::port_subscriptions`].
connections: StolenConnections,
Expand All @@ -313,7 +313,7 @@ pub(crate) struct TcpConnectionStealer {
support_ipv6: bool,
}

impl TcpConnectionStealer {
impl TcpConnectionStealer<IpTablesRedirector> {
pub const TASK_NAME: &'static str = "Stealer";

/// Initializes a new [`TcpConnectionStealer`], but doesn't start the actual work.
Expand All @@ -327,25 +327,39 @@ impl TcpConnectionStealer {
.from_env::<TcpStealerConfig>()
.unwrap_or_default();

let port_subscriptions = {
let redirector = IpTablesRedirector::new(
config.stealer_flush_connections,
config.pod_ips,
support_ipv6,
)
.await?;
let redirector = IpTablesRedirector::new(
config.stealer_flush_connections,
config.pod_ips,
support_ipv6,
)
.await?;

PortSubscriptions::new(redirector, 4)
};
Ok(Self::with_redirector(command_rx, support_ipv6, redirector))
}
}

Ok(Self {
port_subscriptions,
impl<Redirector> TcpConnectionStealer<Redirector>
where
Redirector: PortRedirector,
Redirector::Error: std::error::Error + Into<AgentError>,
AgentError: From<Redirector::Error>,
{
/// Creates a new stealer.
///
/// Given [`PortRedirector`] will be used to capture incoming connections.
pub(crate) fn with_redirector(
command_rx: Receiver<StealerCommand>,
support_ipv6: bool,
redirector: Redirector,
) -> Self {
Self {
port_subscriptions: PortSubscriptions::new(redirector, 4),
command_rx,
clients: HashMap::with_capacity(8),
clients_closed: Default::default(),
connections: StolenConnections::with_capacity(8),
support_ipv6,
})
}
}

/// Runs the tcp traffic stealer loop.
Expand Down Expand Up @@ -383,7 +397,7 @@ impl TcpConnectionStealer {
}
Err(error) => {
tracing::error!(?error, "Failed to accept a stolen connection");
break Err(error);
break Err(error.into());
}
},

Expand Down Expand Up @@ -644,6 +658,8 @@ impl TcpConnectionStealer {

match command {
Command::NewClient(daemon_tx, protocol_version) => {
self.clients_closed
.push(ChannelClosedFuture::new(daemon_tx.clone(), client_id));
self.clients.insert(
client_id,
Client {
Expand Down Expand Up @@ -708,7 +724,7 @@ impl TcpConnectionStealer {

#[cfg(test)]
mod test {
use std::net::SocketAddr;
use std::{net::SocketAddr, time::Duration};

use bytes::Bytes;
use futures::{future::BoxFuture, FutureExt};
Expand All @@ -719,18 +735,75 @@ mod test {
service::Service,
};
use hyper_util::rt::TokioIo;
use mirrord_protocol::tcp::{ChunkedRequest, DaemonTcp, InternalHttpBodyFrame};
use mirrord_protocol::{
tcp::{ChunkedRequest, DaemonTcp, Filter, HttpFilter, InternalHttpBodyFrame, StealType},
Port,
};
use rstest::rstest;
use tokio::{
net::{TcpListener, TcpStream},
sync::{
mpsc::{self, Receiver, Sender},
oneshot,
oneshot, watch,
},
};
use tokio_stream::wrappers::ReceiverStream;
use tokio_util::sync::CancellationToken;

use super::AgentError;
use crate::{
steal::{
connection::{Client, MatchedHttpRequest},
subscriptions::PortRedirector,
TcpConnectionStealer, TcpStealerApi,
},
watched_task::TaskStatus,
};

/// Notification about a requested redirection operation.
///
/// Produced by [`NotifyingRedirector`].
#[derive(Debug, PartialEq, Eq)]
enum RedirectNotification {
Razz4780 marked this conversation as resolved.
Show resolved Hide resolved
Added(Port),
Removed(Port),
Cleanup,
}

/// Test [`PortRedirector`] that never fails and notifies about requested operations using an
/// [`mpsc::channel`].
struct NotifyingRedirector(Sender<RedirectNotification>);

#[async_trait::async_trait]
impl PortRedirector for NotifyingRedirector {
type Error = AgentError;

async fn add_redirection(&mut self, port: Port) -> Result<(), Self::Error> {
self.0
.send(RedirectNotification::Added(port))
.await
.unwrap();
Ok(())
}

async fn remove_redirection(&mut self, port: Port) -> Result<(), Self::Error> {
self.0
.send(RedirectNotification::Removed(port))
.await
.unwrap();
Ok(())
}

async fn cleanup(&mut self) -> Result<(), Self::Error> {
self.0.send(RedirectNotification::Cleanup).await.unwrap();
Ok(())
}

async fn next_connection(&mut self) -> Result<(TcpStream, SocketAddr), Self::Error> {
std::future::pending().await
}
}

use crate::steal::connection::{Client, MatchedHttpRequest};
async fn prepare_dummy_service() -> (
SocketAddr,
Receiver<(Request<Incoming>, oneshot::Sender<Response<Empty<Bytes>>>)>,
Expand Down Expand Up @@ -907,4 +980,51 @@ mod test {

let _ = response_tx.send(Response::new(Empty::default()));
}

/// Verifies that [`TcpConnectionStealer`] removes client's port subscriptions
/// when client's [`TcpStealerApi`] is dropped.
#[rstest]
#[timeout(Duration::from_secs(5))]
#[tokio::test]
async fn cleanup_on_client_closed() {
let (command_tx, command_rx) = mpsc::channel(8);
let (redirect_tx, mut redirect_rx) = mpsc::channel(2);
let stealer = TcpConnectionStealer::with_redirector(
command_rx,
false,
NotifyingRedirector(redirect_tx),
);

tokio::spawn(stealer.start(CancellationToken::new()));

let (_dummy_tx, dummy_rx) = watch::channel(None);
let task_status = TaskStatus::dummy(TcpConnectionStealer::TASK_NAME, dummy_rx);
let mut api = TcpStealerApi::new(
0,
command_tx.clone(),
task_status,
8,
mirrord_protocol::VERSION.clone(),
)
.await
.unwrap();

api.port_subscribe(StealType::FilteredHttpEx(
80,
HttpFilter::Header(Filter::new("user: test".into()).unwrap()),
))
.await
.unwrap();

let response = api.recv().await.unwrap();
assert_eq!(response, DaemonTcp::SubscribeResult(Ok(80)));

let notification = redirect_rx.recv().await.unwrap();
assert_eq!(notification, RedirectNotification::Added(80));

std::mem::drop(api);

let notification = redirect_rx.recv().await.unwrap();
assert_eq!(notification, RedirectNotification::Removed(80));
}
}
Loading
Loading