Skip to content
Merged
32 changes: 18 additions & 14 deletions pre-compute/src/compute/app_runner.rs
Original file line number Diff line number Diff line change
Expand Up @@ -42,35 +42,33 @@ pub fn start_with_app<A: PreComputeAppTrait>(
pre_compute_app: &mut A,
chain_task_id: &str,
) -> ExitMode {
let exit_cause = match pre_compute_app.run() {
let exit_causes = match pre_compute_app.run() {
Ok(_) => {
info!("TEE pre-compute completed");
return ExitMode::Success;
}
Err(exit_cause) => {
error!("TEE pre-compute failed with known exit cause [{exit_cause:?}]");
exit_cause
Err(exit_causes) => {
error!("TEE pre-compute failed with known exit cause [{exit_causes:?}]");
exit_causes
}
};

let authorization = match get_challenge(chain_task_id) {
Ok(auth) => auth,
Err(_) => {
error!("Failed to sign exitCause message [{exit_cause:?}]");
error!("Failed to sign exitCause message [{exit_causes:?}]");
return ExitMode::UnreportedFailure;
}
};

let exit_causes = vec![exit_cause.clone()];

match WorkerApiClient::from_env().send_exit_causes_for_pre_compute_stage(
&authorization,
chain_task_id,
&exit_causes,
) {
Ok(_) => ExitMode::ReportedFailure,
Err(_) => {
error!("Failed to report exitCause [{exit_cause:?}]");
error!("Failed to report exitCause [{exit_causes:?}]");
ExitMode::UnreportedFailure
}
}
Expand Down Expand Up @@ -150,7 +148,7 @@ mod pre_compute_start_with_app_tests {

let mut mock = MockPreComputeAppTrait::new();
mock.expect_run()
.returning(|| Err(ReplicateStatusCause::PreComputeWorkerAddressMissing));
.returning(|| Err(vec![ReplicateStatusCause::PreComputeWorkerAddressMissing]));

temp_env::with_vars(env_vars_to_set, || {
temp_env::with_vars_unset(env_vars_to_unset, || {
Expand All @@ -172,8 +170,11 @@ mod pre_compute_start_with_app_tests {
let env_vars_to_unset = vec![ENV_SIGN_TEE_CHALLENGE_PRIVATE_KEY];

let mut mock = MockPreComputeAppTrait::new();
mock.expect_run()
.returning(|| Err(ReplicateStatusCause::PreComputeTeeChallengePrivateKeyMissing));
mock.expect_run().returning(|| {
Err(vec![
ReplicateStatusCause::PreComputeTeeChallengePrivateKeyMissing,
])
});

temp_env::with_vars(env_vars_to_set, || {
temp_env::with_vars_unset(env_vars_to_unset, || {
Expand All @@ -199,8 +200,11 @@ mod pre_compute_start_with_app_tests {
let mock_server_addr_string = mock_server.address().to_string();

let mut mock = MockPreComputeAppTrait::new();
mock.expect_run()
.returning(|| Err(ReplicateStatusCause::PreComputeTeeChallengePrivateKeyMissing));
mock.expect_run().returning(|| {
Err(vec![
ReplicateStatusCause::PreComputeTeeChallengePrivateKeyMissing,
])
});

let result_code = tokio::task::spawn_blocking(move || {
let env_vars = vec![
Expand Down Expand Up @@ -248,7 +252,7 @@ mod pre_compute_start_with_app_tests {
let mut mock = MockPreComputeAppTrait::new();
mock.expect_run()
.times(1)
.returning(|| Err(ReplicateStatusCause::PreComputeOutputFolderNotFound));
.returning(|| Err(vec![ReplicateStatusCause::PreComputeOutputFolderNotFound]));

// Move the blocking operations into spawn_blocking
let result_code = tokio::task::spawn_blocking(move || {
Expand Down
4 changes: 2 additions & 2 deletions pre-compute/src/compute/errors.rs
Original file line number Diff line number Diff line change
Expand Up @@ -27,8 +27,8 @@ pub enum ReplicateStatusCause {
PreComputeInvalidTeeSignature,
#[error("IS_DATASET_REQUIRED environment variable is missing")]
PreComputeIsDatasetRequiredMissing,
#[error("Input files download failed")]
PreComputeInputFileDownloadFailed,
#[error("Input file download failed for input {0}")]
PreComputeInputFileDownloadFailed(usize),
#[error("Input files number related environment variable is missing")]
PreComputeInputFilesNumberMissing,
#[error("Invalid dataset checksum for dataset {0}")]
Expand Down
66 changes: 45 additions & 21 deletions pre-compute/src/compute/pre_compute_app.rs
Original file line number Diff line number Diff line change
Expand Up @@ -9,9 +9,9 @@ use std::path::{Path, PathBuf};

#[cfg_attr(test, automock)]
pub trait PreComputeAppTrait {
fn run(&mut self) -> Result<(), ReplicateStatusCause>;
fn run(&mut self) -> Result<(), Vec<ReplicateStatusCause>>;
fn check_output_folder(&self) -> Result<(), ReplicateStatusCause>;
fn download_input_files(&self) -> Result<(), ReplicateStatusCause>;
fn download_input_files(&self) -> Result<(), Vec<ReplicateStatusCause>>;
fn save_plain_dataset_file(
&self,
plain_content: &[u8],
Expand Down Expand Up @@ -55,17 +55,33 @@ impl PreComputeAppTrait for PreComputeApp {
/// let mut app = PreComputeApp::new("task_id".to_string());
/// app.run();
/// ```
fn run(&mut self) -> Result<(), ReplicateStatusCause> {
// TODO: Collect all errors instead of propagating immediately, and return the list of errors
self.pre_compute_args = PreComputeArgs::read_args()?;
self.check_output_folder()?;
fn run(&mut self) -> Result<(), Vec<ReplicateStatusCause>> {
let (args, mut exit_causes) = PreComputeArgs::read_args();
self.pre_compute_args = args;

if let Err(exit_cause) = self.check_output_folder() {
return Err(vec![exit_cause]);
}

for dataset in self.pre_compute_args.datasets.iter() {
let encrypted_content = dataset.download_encrypted_dataset(&self.chain_task_id)?;
let plain_content = dataset.decrypt_dataset(&encrypted_content)?;
self.save_plain_dataset_file(&plain_content, &dataset.filename)?;
if let Err(exit_cause) = dataset
.download_encrypted_dataset(&self.chain_task_id)
.and_then(|encrypted_content| dataset.decrypt_dataset(&encrypted_content))
.and_then(|plain_content| {
self.save_plain_dataset_file(&plain_content, &dataset.filename)
})
{
exit_causes.push(exit_cause);
};
}
if let Err(exit_cause) = self.download_input_files() {
exit_causes.extend(exit_cause);
};
if !exit_causes.is_empty() {
Err(exit_causes)
} else {
Ok(())
}
self.download_input_files()?;
Ok(())
}

/// Checks whether the output folder specified in `pre_compute_args` exists.
Expand Down Expand Up @@ -105,19 +121,27 @@ impl PreComputeAppTrait for PreComputeApp {
/// This function panics if:
/// - `pre_compute_args` is `None`.
/// - `chain_task_id` is `None`.
fn download_input_files(&self) -> Result<(), ReplicateStatusCause> {
fn download_input_files(&self) -> Result<(), Vec<ReplicateStatusCause>> {
let mut exit_causes: Vec<ReplicateStatusCause> = vec![];
let args = &self.pre_compute_args;
let chain_task_id: &str = &self.chain_task_id;

for url in &args.input_files {
for (index, url) in args.input_files.iter().enumerate() {
info!("Downloading input file [chainTaskId:{chain_task_id}, url:{url}]");

let filename = sha256(url.to_string());
if download_file(url, &args.output_dir, &filename).is_none() {
return Err(ReplicateStatusCause::PreComputeInputFileDownloadFailed);
exit_causes.push(ReplicateStatusCause::PreComputeInputFileDownloadFailed(
index,
));
}
}
Ok(())

if !exit_causes.is_empty() {
Err(exit_causes)
} else {
Ok(())
}
}

/// Saves the decrypted (plain) dataset to disk in the configured output directory.
Expand Down Expand Up @@ -293,12 +317,12 @@ mod tests {
let result = app.download_input_files();
assert_eq!(
result.unwrap_err(),
ReplicateStatusCause::PreComputeInputFileDownloadFailed
vec![ReplicateStatusCause::PreComputeInputFileDownloadFailed(0)]
);
}

#[test]
fn test_partial_failure_stops_on_first_error() {
fn test_partial_failure_dont_stops_on_first_error() {
let (_container, json_url, xml_url) = start_container();

let temp_dir = TempDir::new().unwrap();
Expand All @@ -307,24 +331,24 @@ mod tests {
vec![
&json_url, // This should succeed
"https://invalid-url-that-should-fail.com/file.txt", // This should fail
&xml_url, // This shouldn't be reached
&xml_url, // This should succeed
],
temp_dir.path().to_str().unwrap(),
);

let result = app.download_input_files();
assert_eq!(
result.unwrap_err(),
ReplicateStatusCause::PreComputeInputFileDownloadFailed
vec![ReplicateStatusCause::PreComputeInputFileDownloadFailed(1)]
);

// First file should be downloaded with SHA256 filename
let json_hash = sha256(json_url);
assert!(temp_dir.path().join(json_hash).exists());

// Third file should NOT be downloaded (stopped on second failure)
// Third file should be downloaded (not stopped on second failure)
let xml_hash = sha256(xml_url);
assert!(!temp_dir.path().join(xml_hash).exists());
assert!(temp_dir.path().join(xml_hash).exists());
}
// endregion

Expand Down
Loading
Loading