diff --git a/nativelink-worker/src/local_worker.rs b/nativelink-worker/src/local_worker.rs index 2bc8d2bad..ec475b1c8 100644 --- a/nativelink-worker/src/local_worker.rs +++ b/nativelink-worker/src/local_worker.rs @@ -84,6 +84,35 @@ struct LocalWorkerImpl<'a, T: WorkerApiClientTrait + 'static, U: RunningActionsM metrics: Arc, } +#[derive(Debug)] +struct ActionsInTransitGuard { + counter: Arc, + active: bool, +} + +impl ActionsInTransitGuard { + fn new(counter: Arc) -> Self { + counter.fetch_add(1, Ordering::Release); + Self { + counter, + active: true, + } + } + + fn done(&mut self) { + if self.active { + self.counter.fetch_sub(1, Ordering::Release); + self.active = false; + } + } +} + +impl Drop for ActionsInTransitGuard { + fn drop(&mut self) { + self.done(); + } +} + async fn preconditions_met(precondition_script: Option) -> Result<(), Error> { let Some(precondition_script) = &precondition_script else { // No script means we are always ok to proceed. @@ -254,7 +283,8 @@ impl<'a, T: WorkerApiClientTrait + 'static, U: RunningActionsManager> LocalWorke let start_action_fut = { let precondition_script_cfg = self.config.experimental_precondition_script.clone(); - let actions_in_transit = self.actions_in_transit.clone(); + let mut actions_in_transit_guard = + ActionsInTransitGuard::new(self.actions_in_transit.clone()); let worker_id = self.worker_id.clone(); let running_actions_manager = self.running_actions_manager.clone(); let mut grpc_client = self.grpc_client.clone(); @@ -265,9 +295,7 @@ impl<'a, T: WorkerApiClientTrait + 'static, U: RunningActionsManager> LocalWorke metrics.preconditions.wrap(preconditions_met(precondition_script_cfg)) .and_then(|()| running_actions_manager.create_and_add_action(worker_id, start_execute)) .map(move |r| { - // Now that we either failed or registered our action, we can - // consider the action to no longer be in transit. - actions_in_transit.fetch_sub(1, Ordering::Release); + actions_in_transit_guard.done(); r }) .and_then(|action| { @@ -339,8 +367,6 @@ impl<'a, T: WorkerApiClientTrait + 'static, U: RunningActionsManager> LocalWorke } }; - self.actions_in_transit.fetch_add(1, Ordering::Release); - let add_future_channel = add_future_channel.clone(); info_span!( @@ -683,24 +709,28 @@ impl LocalWorker::Err(err).merge(remove_dir_result) - } else if let Err(err) = remove_dir_result { - error!(%operation_id, ?err, "Error removing working directory"); - Err(err) - } else { - Ok(()) + let remove_dir_result = match fs::remove_dir_all(action_directory).await { + Ok(()) => Ok(()), + Err(err) if err.code == Code::NotFound => Ok(()), + Err(err) => { + Err(err).err_tip(|| format!("Could not remove working directory {action_directory}")) + } + }; + + let cleanup_result = running_actions_manager.cleanup_action(operation_id); + match (cleanup_result, remove_dir_result) { + (Err(err), remove_dir_result) => { + error!(%operation_id, ?err, "Error cleaning up action"); + Err::<(), Error>(err).merge(remove_dir_result) + } + (Ok(()), Err(err)) => { + error!(%operation_id, ?err, "Error removing working directory"); + Err(err) + } + _ => Ok(()), } } @@ -2053,14 +2060,25 @@ impl RunningActionsManagerImpl { }) } + #[allow( + clippy::unnecessary_wraps, + reason = "We keep a Result here to preserve the existing API and future-proof error handling." + )] fn cleanup_action(&self, operation_id: &OperationId) -> Result<(), Error> { let mut running_actions = self.running_actions.lock(); - let result = running_actions.remove(operation_id).err_tip(|| { - format!("Expected operation id '{operation_id}' to exist in RunningActionsManagerImpl") - }); + if running_actions.remove(operation_id).is_none() { + warn!( + %operation_id, + "Cleanup requested for operation that was not tracked" + ); + self.metrics.cleanup_missing_action.inc(); + // No need to copy anything, we just are telling the receivers an event happened. + self.action_done_tx.send_modify(|()| {}); + return Ok(()); + } // No need to copy anything, we just are telling the receivers an event happened. self.action_done_tx.send_modify(|()| {}); - result.map(|_| ()) + Ok(()) } // Note: We do not capture metrics on this call, only `.kill_all()`. @@ -2260,6 +2278,8 @@ pub struct Metrics { get_finished_result: AsyncCounterWrapper, #[metric(help = "Number of times an action waited for cleanup to complete.")] cleanup_waits: CounterWithTime, + #[metric(help = "Number of cleanup calls where the action was already missing.")] + cleanup_missing_action: CounterWithTime, #[metric(help = "Number of stale directories removed during action retries.")] stale_removals: CounterWithTime, #[metric(help = "Number of timeouts while waiting for cleanup to complete.")] diff --git a/nativelink-worker/tests/local_worker_test.rs b/nativelink-worker/tests/local_worker_test.rs index 123cdd9e7..65a66fbc7 100644 --- a/nativelink-worker/tests/local_worker_test.rs +++ b/nativelink-worker/tests/local_worker_test.rs @@ -57,6 +57,7 @@ use pretty_assertions::assert_eq; use prost::Message; use rand::Rng; use tokio::io::AsyncWriteExt; +use tokio::time::timeout; use utils::local_worker_test_utils::{ setup_grpc_stream, setup_local_worker, setup_local_worker_with_config, }; @@ -290,6 +291,81 @@ async fn blake3_digest_function_registered_properly() -> Result<(), Error> { Ok(()) } +#[nativelink_test] +async fn disconnect_with_action_in_transit_does_not_panic() -> Result<(), Error> { + let mut test_context = setup_local_worker(HashMap::new()).await; + let streaming_response = test_context.maybe_streaming_response.take().unwrap(); + + { + let props = test_context + .client + .expect_connect_worker(Ok(streaming_response)) + .await; + assert_eq!(props, ConnectWorkerRequest::default()); + } + + let expected_worker_id = "foobar".to_string(); + let tx_stream = test_context.maybe_tx_stream.take().unwrap(); + { + tx_stream + .send(Frame::data( + encode_stream_proto(&UpdateForWorker { + update: Some(Update::ConnectionResult(ConnectionResult { + worker_id: expected_worker_id.clone(), + })), + }) + .unwrap(), + )) + .await + .map_err(|e| make_input_err!("Could not send : {:?}", e))?; + } + + { + tx_stream + .send(Frame::data( + encode_stream_proto(&UpdateForWorker { + update: Some(Update::StartAction(StartExecute { + execute_request: Some(nativelink_proto::build::bazel::remote::execution::v2::ExecuteRequest { + action_digest: None, + digest_function: nativelink_proto::build::bazel::remote::execution::v2::digest_function::Value::Sha256 as i32, + ..Default::default() + }), + operation_id: "pending-op".to_string(), + queued_timestamp: None, + platform: None, + worker_id: expected_worker_id, + })), + }) + .unwrap(), + )) + .await + .map_err(|e| make_input_err!("Could not send : {:?}", e))?; + } + + // Ensure the start action is in-flight but do not respond so it stays pending. + let (_worker_id, pending_start_execute) = test_context + .actions_manager + .wait_for_create_and_add_action_call() + .await; + assert_eq!(pending_start_execute.operation_id, "pending-op"); + + drop(tx_stream); + + timeout( + Duration::from_secs(2), + test_context.actions_manager.expect_kill_all(), + ) + .await + .expect("kill_all should be called when disconnecting with pending actions"); + + // Unblock any pending create_and_add_action future so the worker can settle. + test_context + .actions_manager + .send_create_and_add_action_response(Err(make_input_err!("Disconnected"))); + + Ok(()) +} + #[nativelink_test] async fn simple_worker_start_action_test() -> Result<(), Error> { let mut test_context = setup_local_worker(HashMap::new()).await; diff --git a/nativelink-worker/tests/running_actions_manager_test.rs b/nativelink-worker/tests/running_actions_manager_test.rs index 64ac8c0f7..a9bb81ef2 100644 --- a/nativelink-worker/tests/running_actions_manager_test.rs +++ b/nativelink-worker/tests/running_actions_manager_test.rs @@ -1387,6 +1387,97 @@ mod tests { Ok(()) } + #[nativelink_test] + async fn cleanup_is_idempotent_after_first_call() -> Result<(), Box> { + const WORKER_ID: &str = "foo_worker_id"; + + fn test_monotonic_clock() -> SystemTime { + static CLOCK: AtomicU64 = AtomicU64::new(0); + monotonic_clock(&CLOCK) + } + + let (_, _, cas_store, ac_store) = setup_stores().await?; + let root_action_directory = make_temp_path("root_action_directory"); + fs::create_dir_all(&root_action_directory).await?; + + let running_actions_manager = Arc::new(RunningActionsManagerImpl::new_with_callbacks( + RunningActionsManagerArgs { + root_action_directory: root_action_directory.clone(), + execution_configuration: ExecutionConfiguration::default(), + cas_store: cas_store.clone(), + ac_store: Some(Store::new(ac_store.clone())), + historical_store: Store::new(cas_store.clone()), + upload_action_result_config: + &nativelink_config::cas_server::UploadActionResultConfig { + upload_ac_results_strategy: + nativelink_config::cas_server::UploadCacheResultsStrategy::Never, + ..Default::default() + }, + max_action_timeout: Duration::MAX, + timeout_handled_externally: false, + directory_cache: None, + }, + Callbacks { + now_fn: test_monotonic_clock, + sleep_fn: |_duration| Box::pin(future::pending()), + }, + )?); + let command = Command { + arguments: vec!["echo".to_string(), "hello".to_string()], + output_paths: vec![], + ..Default::default() + }; + let command_digest = serialize_and_upload_message( + &command, + cas_store.as_pin(), + &mut DigestHasherFunc::Sha256.hasher(), + ) + .await?; + let input_root_digest = serialize_and_upload_message( + &Directory::default(), + cas_store.as_pin(), + &mut DigestHasherFunc::Sha256.hasher(), + ) + .await?; + let action = Action { + command_digest: Some(command_digest.into()), + input_root_digest: Some(input_root_digest.into()), + ..Default::default() + }; + let action_digest = serialize_and_upload_message( + &action, + cas_store.as_pin(), + &mut DigestHasherFunc::Sha256.hasher(), + ) + .await?; + + let queued_timestamp = make_system_time(1000); + let operation_id = OperationId::default().to_string(); + + let running_action = running_actions_manager + .create_and_add_action( + WORKER_ID.to_string(), + StartExecute { + execute_request: Some(ExecuteRequest { + action_digest: Some(action_digest.into()), + digest_function: ProtoDigestFunction::Sha256 as i32, + ..Default::default() + }), + operation_id, + queued_timestamp: Some(queued_timestamp.into()), + platform: action.platform.clone(), + worker_id: WORKER_ID.to_string(), + }, + ) + .await?; + + running_action.clone().cleanup().await?; + running_action.clone().cleanup().await?; + + running_actions_manager.kill_all().await; + Ok(()) + } + #[nativelink_test] async fn kill_ends_action() -> Result<(), Box> { const WORKER_ID: &str = "foo_worker_id"; diff --git a/nativelink-worker/tests/utils/mock_running_actions_manager.rs b/nativelink-worker/tests/utils/mock_running_actions_manager.rs index 4efe50132..d1914d6b4 100644 --- a/nativelink-worker/tests/utils/mock_running_actions_manager.rs +++ b/nativelink-worker/tests/utils/mock_running_actions_manager.rs @@ -94,6 +94,27 @@ impl MockRunningActionsManager { req } + pub(crate) async fn wait_for_create_and_add_action_call(&self) -> (String, StartExecute) { + let mut rx_call_lock = self.rx_call.lock().await; + let RunningActionManagerCalls::CreateAndAddAction(req) = rx_call_lock + .recv() + .await + .expect("Could not receive msg in mpsc") + else { + panic!("Got incorrect call waiting for create_and_add_action") + }; + req + } + + pub(crate) fn send_create_and_add_action_response( + &self, + result: Result, Error>, + ) { + self.tx_resp + .send(RunningActionManagerReturns::CreateAndAddAction(result)) + .expect("Could not send request to mpsc"); + } + pub(crate) async fn expect_cache_action_result( &self, ) -> (DigestInfo, ActionResult, DigestHasherFunc) {