diff --git a/mirrord/intproxy/src/background_tasks.rs b/mirrord/intproxy/src/background_tasks.rs index 988ca56f332..a425bbf19a2 100644 --- a/mirrord/intproxy/src/background_tasks.rs +++ b/mirrord/intproxy/src/background_tasks.rs @@ -39,7 +39,7 @@ impl MessageBus { } /// Cast `&mut MessageBus` as `&mut MessageBus` only if they share the same message types - pub fn cast(&mut self) -> &mut MessageBus + pub(crate) fn cast(&mut self) -> &mut MessageBus where R: BackgroundTask, { @@ -267,17 +267,10 @@ where self.register(RestartableBackgroundTaskWrapper { task }, id, channel_size) } - pub fn tasks_ids(&self) -> impl Iterator { - self.handles.keys() - } - - pub async fn kill_task(&mut self, id: Id) { - self.streams.remove(&id); - let Some(task) = self.handles.remove(&id) else { - return; - }; - - task.abort(); + /// Forgets every tracked task: drains `handles` and removes the matching entry from `streams`. + /// + /// NOTE(review): the removed `kill_task` called `abort()` on the handle, while this method only + /// drops it. Confirm that the tasks still terminate on their own once they are forgotten here. + pub fn clear(&mut self) { + for (id, _) in self.handles.drain() { + self.streams.remove(&id); + } } /// Returns the next update from one of registered tasks. diff --git a/mirrord/intproxy/src/lib.rs b/mirrord/intproxy/src/lib.rs index 145be8752db..b904434fc46 100644 --- a/mirrord/intproxy/src/lib.rs +++ b/mirrord/intproxy/src/lib.rs @@ -18,6 +18,7 @@ use proxies::{ outgoing::{OutgoingProxy, OutgoingProxyMessage}, simple::{SimpleProxy, SimpleProxyMessage}, }; +use semver::Version; use tokio::{net::TcpListener, time}; use tracing::Level; @@ -60,6 +61,10 @@ pub struct IntProxy { any_connection_accepted: bool, background_tasks: BackgroundTasks, task_txs: TaskTxs, + /// [`mirrord_protocol`] version negotiated with the agent. + /// Determines whether we can use some messages, like [`FileRequest::ReadDirBatch`] or + /// [`FileRequest::ReadLink`]. 
+ protocol_version: Option, } impl IntProxy { @@ -128,6 +133,7 @@ impl IntProxy { ping_pong, files, }, + protocol_version: None, } } @@ -318,6 +324,8 @@ impl IntProxy { .await } DaemonMessage::SwitchProtocolVersionResponse(protocol_version) => { + let _ = self.protocol_version.insert(protocol_version.clone()); + if CLIENT_READY_FOR_LOGS.matches(&protocol_version) { self.task_txs.agent.send(ClientMessage::ReadyForLogs).await; } @@ -413,8 +421,18 @@ impl IntProxy { #[tracing::instrument(level = Level::TRACE, skip(self), err)] async fn handle_connection_refresh(&self) -> Result<(), IntProxyError> { self.task_txs - .simple - .send(SimpleProxyMessage::ConnectionRefresh) + .agent + .send(ClientMessage::SwitchProtocolVersion( + self.protocol_version + .as_ref() + .unwrap_or(&mirrord_protocol::VERSION) + .clone(), + )) + .await; + + self.task_txs + .files + .send(FilesProxyMessage::ConnectionRefresh) .await; self.task_txs diff --git a/mirrord/intproxy/src/proxies/files.rs b/mirrord/intproxy/src/proxies/files.rs index 0d24d9d41c5..60da47e3eca 100644 --- a/mirrord/intproxy/src/proxies/files.rs +++ b/mirrord/intproxy/src/proxies/files.rs @@ -36,6 +36,8 @@ pub enum FilesProxyMessage { LayerForked(LayerForked), /// Layer instance closed. LayerClosed(LayerClosed), + /// Agent connection was refreshed + ConnectionRefresh, } /// Error that can occur in [`FilesProxy`]. 
@@ -761,6 +763,20 @@ impl FilesProxy { Ok(()) } + + /// Drains every entry from `remote_files` and `remote_dirs`, removing each drained fd from + /// `buffered_files` / `buffered_dirs` respectively. + /// + /// Invoked for [`FilesProxyMessage::ConnectionRefresh`]. `_message_bus` is currently unused. + async fn handle_reconnect(&mut self, _message_bus: &mut MessageBus) { + for (_, fds) in self.remote_files.drain() { + for fd in fds { + self.buffered_files.remove(&fd); + } + } + + for (_, fds) in self.remote_dirs.drain() { + for fd in fds { + self.buffered_dirs.remove(&fd); + } + } + } } impl BackgroundTask for FilesProxy { @@ -785,6 +801,7 @@ impl BackgroundTask for FilesProxy { } FilesProxyMessage::LayerForked(forked) => self.layer_forked(forked), FilesProxyMessage::ProtocolVersion(version) => self.protocol_version(version), + FilesProxyMessage::ConnectionRefresh => self.handle_reconnect(message_bus).await, } } diff --git a/mirrord/intproxy/src/proxies/incoming.rs b/mirrord/intproxy/src/proxies/incoming.rs index 3c57b630b19..f16112e1b0f 100644 --- a/mirrord/intproxy/src/proxies/incoming.rs +++ b/mirrord/intproxy/src/proxies/incoming.rs @@ -536,11 +536,9 @@ impl IncomingProxy { } IncomingProxyMessage::ConnectionRefresh => { - let running_task_ids = self.tasks.tasks_ids().cloned().collect::>(); - - for task in running_task_ids { - self.tasks.kill_task(task).await; - } + self.mirror_tcp_proxies.clear(); + self.steal_tcp_proxies.clear(); + self.tasks.clear(); for subscription in self.subscriptions.iter_mut() { tracing::debug!(?subscription, "resubscribing"); diff --git a/mirrord/intproxy/src/proxies/simple.rs b/mirrord/intproxy/src/proxies/simple.rs index 2141d549417..6efa5416865 100644 --- a/mirrord/intproxy/src/proxies/simple.rs +++ b/mirrord/intproxy/src/proxies/simple.rs @@ -27,8 +27,6 @@ pub enum SimpleProxyMessage { GetEnvRes(RemoteResult>), /// Protocol version was negotiated with the agent. 
ProtocolVersion(Version), - /// Agent connection was refreshed need to negotiate version - ConnectionRefresh, } #[derive(Error, Debug)] @@ -125,13 +123,6 @@ impl BackgroundTask for SimpleProxy { .await } SimpleProxyMessage::ProtocolVersion(version) => self.set_protocol_version(version), - SimpleProxyMessage::ConnectionRefresh => { - if let Some(version) = &self.protocol_version { - message_bus - .send(ClientMessage::SwitchProtocolVersion(version.clone())) - .await - } - } } } diff --git a/mirrord/intproxy/src/remote_resources.rs b/mirrord/intproxy/src/remote_resources.rs index 055455b7818..645ca3fece2 100644 --- a/mirrord/intproxy/src/remote_resources.rs +++ b/mirrord/intproxy/src/remote_resources.rs @@ -123,4 +123,16 @@ where *self.counts.entry(resource).or_default() += 1; } } + + /// Removes all resources held by all layer instances. + /// Returns a lazy [`Iterator`] of layers and the remote files/folders removed for each one: `remove_all` only runs for a layer when its item is pulled, so the caller must consume the iterator fully. + /// + /// Should be used when the remote is lost and there is a need to restart. + #[tracing::instrument(level = Level::TRACE, skip(self))] + pub(crate) fn drain(&mut self) -> impl '_ + Iterator)> { + let ids: Vec<_> = self.by_layer.keys().cloned().collect(); + + ids.into_iter() + .map(|id| (id, self.remove_all(id).collect())) + } }