diff --git a/mirrord/intproxy/src/agent_conn.rs b/mirrord/intproxy/src/agent_conn.rs
index 7e448f8e5ff..99621a23a20 100644
--- a/mirrord/intproxy/src/agent_conn.rs
+++ b/mirrord/intproxy/src/agent_conn.rs
@@ -38,7 +38,7 @@ use tracing::Level;
 
 use crate::{
     background_tasks::{BackgroundTask, MessageBus, RestartableBackgroundTask},
-    ProxyMessage,
+    main_tasks::{ConnectionRefresh, ProxyMessage},
 };
 
 mod portforward;
@@ -283,6 +283,10 @@ impl RestartableBackgroundTask for AgentConnection {
                 config,
                 connect_info,
             } => {
+                message_bus
+                    .send(ProxyMessage::ConnectionRefresh(ConnectionRefresh::Start))
+                    .await;
+
                 let retry_strategy = ExponentialBackoff::from_millis(50).map(jitter).take(10);
 
                 let connection = Retry::spawn(retry_strategy, || async move {
@@ -301,7 +305,9 @@ impl RestartableBackgroundTask for AgentConnection {
                 match connection {
                     Ok(connection) => {
                         *self = connection;
-                        message_bus.send(ProxyMessage::ConnectionRefresh).await;
+                        message_bus
+                            .send(ProxyMessage::ConnectionRefresh(ConnectionRefresh::End))
+                            .await;
 
                         ControlFlow::Continue(())
                     }
diff --git a/mirrord/intproxy/src/lib.rs b/mirrord/intproxy/src/lib.rs
index b639e0aeacb..1162db0ccbb 100644
--- a/mirrord/intproxy/src/lib.rs
+++ b/mirrord/intproxy/src/lib.rs
@@ -2,7 +2,10 @@
 #![warn(clippy::indexing_slicing)]
 #![deny(unused_crate_dependencies)]
 
-use std::{collections::HashMap, time::Duration};
+use std::{
+    collections::{HashMap, VecDeque},
+    time::Duration,
+};
 
 use background_tasks::{BackgroundTasks, TaskSender, TaskUpdate};
 use error::UnexpectedAgentMessage;
@@ -26,7 +29,7 @@ use crate::{
     agent_conn::AgentConnection,
     background_tasks::{RestartableBackgroundTaskWrapper, TaskError},
     error::IntProxyError,
-    main_tasks::LayerClosed,
+    main_tasks::{ConnectionRefresh, LayerClosed},
 };
 
 pub mod agent_conn;
@@ -64,6 +67,10 @@ pub struct IntProxy {
 
     /// [`mirrord_protocol`] version negotiated with the agent.
     protocol_version: Option<Version>,
+
+    /// Temporary message queue for any [`ProxyMessage`] from layer or to agent that are sent
+    /// during reconnection state.
+    reconnect_task_queue: Option<VecDeque<ProxyMessage>>,
 }
 
 impl IntProxy {
@@ -133,6 +140,7 @@ impl IntProxy {
                 files,
             },
             protocol_version: None,
+            reconnect_task_queue: Default::default(),
         }
     }
 
@@ -186,6 +194,15 @@ impl IntProxy {
     /// [`ProxyMessage::NewLayer`] is handled here, as an exception.
     async fn handle(&mut self, msg: ProxyMessage) -> Result<(), IntProxyError> {
         match msg {
+            ProxyMessage::NewLayer(_) | ProxyMessage::FromLayer(_) | ProxyMessage::ToAgent(_)
+                if self.reconnect_task_queue.is_some() =>
+            {
+                // We are in reconnect state so should queue this message.
+                self.reconnect_task_queue
+                    .as_mut()
+                    .expect("reconnect_task_queue should contain value when in reconnect state")
+                    .push_back(msg);
+            }
             ProxyMessage::NewLayer(new_layer) => {
                 self.any_connection_accepted = true;
 
@@ -230,7 +247,7 @@ impl IntProxy {
                         .await;
                 }
             }
-            ProxyMessage::ConnectionRefresh => self.handle_connection_refresh().await?,
+            ProxyMessage::ConnectionRefresh(kind) => self.handle_connection_refresh(kind).await?,
         }
 
         Ok(())
@@ -418,26 +435,68 @@ impl IntProxy {
     }
 
+    /// Handles a [`ConnectionRefresh`] notification sent by the agent connection task.
+    ///
+    /// * [`ConnectionRefresh::Start`] - creates the reconnect queue if it does not exist, so
+    ///   that [`Self::handle`] starts queueing layer and agent-bound messages.
+    /// * [`ConnectionRefresh::End`] - sends the protocol version to the agent, notifies the
+    ///   files and incoming proxies, then replays the queued messages in arrival order.
+    ///
+    /// # Errors
+    ///
+    /// Returns [`IntProxyError::AgentFailed`] if `End` arrives while no queue exists, and
+    /// propagates the first error hit while replaying queued messages.
     #[tracing::instrument(level = Level::TRACE, skip(self), err)]
-    async fn handle_connection_refresh(&self) -> Result<(), IntProxyError> {
-        self.task_txs
-            .agent
-            .send(ClientMessage::SwitchProtocolVersion(
-                self.protocol_version
-                    .as_ref()
-                    .unwrap_or(&mirrord_protocol::VERSION)
-                    .clone(),
-            ))
-            .await;
+    async fn handle_connection_refresh(
+        &mut self,
+        kind: ConnectionRefresh,
+    ) -> Result<(), IntProxyError> {
+        match kind {
+            ConnectionRefresh::Start => {
+                // Initialise default reconnect message queue
+                self.reconnect_task_queue.get_or_insert_default();
+            }
+            ConnectionRefresh::End => {
+                let Some(task_queue) = self.reconnect_task_queue.take() else {
+                    return Err(IntProxyError::AgentFailed(
+                        "unexpected state: agent reconnect finished without correctly initializing a reconnect"
+                            .into(),
+                    ));
+                };
 
-        self.task_txs
-            .files
-            .send(FilesProxyMessage::ConnectionRefresh)
-            .await;
+                self.task_txs
+                    .agent
+                    .send(ClientMessage::SwitchProtocolVersion(
+                        self.protocol_version
+                            .as_ref()
+                            .unwrap_or(&mirrord_protocol::VERSION)
+                            .clone(),
+                    ))
+                    .await;
 
-        self.task_txs
-            .incoming
-            .send(IncomingProxyMessage::ConnectionRefresh)
-            .await;
+                self.task_txs
+                    .files
+                    .send(FilesProxyMessage::ConnectionRefresh)
+                    .await;
+
+                self.task_txs
+                    .incoming
+                    .send(IncomingProxyMessage::ConnectionRefresh)
+                    .await;
+
+                // Replay through `handle`, which can recurse into this function, hence the boxed
+                // future.
+                Box::pin(async {
+                    for msg in task_queue {
+                        tracing::debug!(?msg, "dequeueing message for reconnect");
+
+                        self.handle(msg).await?
+                    }
+
+                    Ok::<(), IntProxyError>(())
+                })
+                .await?
+            }
+        }
 
         Ok(())
     }
diff --git a/mirrord/intproxy/src/main_tasks.rs b/mirrord/intproxy/src/main_tasks.rs
index 7b7810c3a9c..bb61c8ec90f 100644
--- a/mirrord/intproxy/src/main_tasks.rs
+++ b/mirrord/intproxy/src/main_tasks.rs
@@ -20,7 +20,7 @@ pub enum ProxyMessage {
     /// New layer instance to serve.
     NewLayer(NewLayer),
     /// Connection to agent was dropped and needs reload.
-    ConnectionRefresh,
+    ConnectionRefresh(ConnectionRefresh),
 }
 
 #[cfg(test)]
@@ -138,3 +138,13 @@ pub struct LayerForked {
 pub struct LayerClosed {
     pub id: LayerId,
 }
+
+/// Notification about start and end of reconnection to agent.
+#[derive(Debug, Clone, Copy)]
+#[cfg_attr(test, derive(PartialEq, Eq))]
+pub enum ConnectionRefresh {
+    /// Reconnection to the agent has started.
+    Start,
+    /// Reconnection to the agent has finished successfully.
+    End,
+}