Skip to content
Merged
32 changes: 18 additions & 14 deletions pre-compute/src/compute/app_runner.rs
Original file line number Diff line number Diff line change
Expand Up @@ -42,35 +42,33 @@ pub fn start_with_app<A: PreComputeAppTrait>(
pre_compute_app: &mut A,
chain_task_id: &str,
) -> ExitMode {
let exit_cause = match pre_compute_app.run() {
let exit_causes = match pre_compute_app.run() {
Ok(_) => {
info!("TEE pre-compute completed");
return ExitMode::Success;
}
Err(exit_cause) => {
error!("TEE pre-compute failed with known exit cause [{exit_cause:?}]");
exit_cause
Err(exit_causes) => {
error!("TEE pre-compute failed with known exit cause [{exit_causes:?}]");
exit_causes
}
};

let authorization = match get_challenge(chain_task_id) {
Ok(auth) => auth,
Err(_) => {
error!("Failed to sign exitCause message [{exit_cause:?}]");
error!("Failed to sign exitCause message [{exit_causes:?}]");
return ExitMode::UnreportedFailure;
}
};

let exit_causes = vec![exit_cause.clone()];

match WorkerApiClient::from_env().send_exit_causes_for_pre_compute_stage(
&authorization,
chain_task_id,
&exit_causes,
) {
Ok(_) => ExitMode::ReportedFailure,
Err(_) => {
error!("Failed to report exitCause [{exit_cause:?}]");
error!("Failed to report exitCause [{exit_causes:?}]");
ExitMode::UnreportedFailure
}
}
Expand Down Expand Up @@ -150,7 +148,7 @@ mod pre_compute_start_with_app_tests {

let mut mock = MockPreComputeAppTrait::new();
mock.expect_run()
.returning(|| Err(ReplicateStatusCause::PreComputeWorkerAddressMissing));
.returning(|| Err(vec![ReplicateStatusCause::PreComputeWorkerAddressMissing]));

temp_env::with_vars(env_vars_to_set, || {
temp_env::with_vars_unset(env_vars_to_unset, || {
Expand All @@ -172,8 +170,11 @@ mod pre_compute_start_with_app_tests {
let env_vars_to_unset = vec![ENV_SIGN_TEE_CHALLENGE_PRIVATE_KEY];

let mut mock = MockPreComputeAppTrait::new();
mock.expect_run()
.returning(|| Err(ReplicateStatusCause::PreComputeTeeChallengePrivateKeyMissing));
mock.expect_run().returning(|| {
Err(vec![
ReplicateStatusCause::PreComputeTeeChallengePrivateKeyMissing,
])
});

temp_env::with_vars(env_vars_to_set, || {
temp_env::with_vars_unset(env_vars_to_unset, || {
Expand All @@ -199,8 +200,11 @@ mod pre_compute_start_with_app_tests {
let mock_server_addr_string = mock_server.address().to_string();

let mut mock = MockPreComputeAppTrait::new();
mock.expect_run()
.returning(|| Err(ReplicateStatusCause::PreComputeTeeChallengePrivateKeyMissing));
mock.expect_run().returning(|| {
Err(vec![
ReplicateStatusCause::PreComputeTeeChallengePrivateKeyMissing,
])
});

let result_code = tokio::task::spawn_blocking(move || {
let env_vars = vec![
Expand Down Expand Up @@ -248,7 +252,7 @@ mod pre_compute_start_with_app_tests {
let mut mock = MockPreComputeAppTrait::new();
mock.expect_run()
.times(1)
.returning(|| Err(ReplicateStatusCause::PreComputeOutputFolderNotFound));
.returning(|| Err(vec![ReplicateStatusCause::PreComputeOutputFolderNotFound]));

// Move the blocking operations into spawn_blocking
let result_code = tokio::task::spawn_blocking(move || {
Expand Down
4 changes: 2 additions & 2 deletions pre-compute/src/compute/errors.rs
Original file line number Diff line number Diff line change
Expand Up @@ -27,8 +27,8 @@ pub enum ReplicateStatusCause {
PreComputeInvalidTeeSignature,
#[error("IS_DATASET_REQUIRED environment variable is missing")]
PreComputeIsDatasetRequiredMissing,
#[error("Input files download failed")]
PreComputeInputFileDownloadFailed,
#[error("Input file download failed for input {0}")]
PreComputeInputFileDownloadFailed(String),
#[error("Input files number related environment variable is missing")]
PreComputeInputFilesNumberMissing,
#[error("Invalid dataset checksum for dataset {0}")]
Expand Down
115 changes: 81 additions & 34 deletions pre-compute/src/compute/pre_compute_app.rs
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
use crate::compute::errors::ReplicateStatusCause;
use crate::compute::pre_compute_args::PreComputeArgs;
use crate::compute::utils::env_utils::{TeeSessionEnvironmentVariable, get_env_var_or_error};
use crate::compute::utils::file_utils::{download_file, write_file};
use crate::compute::utils::hash_utils::sha256;
use log::{error, info};
Expand All @@ -9,9 +10,9 @@ use std::path::{Path, PathBuf};

#[cfg_attr(test, automock)]
pub trait PreComputeAppTrait {
fn run(&mut self) -> Result<(), ReplicateStatusCause>;
fn run(&mut self) -> Result<(), Vec<ReplicateStatusCause>>;
fn check_output_folder(&self) -> Result<(), ReplicateStatusCause>;
fn download_input_files(&self) -> Result<(), ReplicateStatusCause>;
fn download_input_files(&self) -> Result<(), Vec<ReplicateStatusCause>>;
fn save_plain_dataset_file(
&self,
plain_content: &[u8],
Expand All @@ -37,15 +38,19 @@ impl PreComputeAppTrait for PreComputeApp {
/// Runs the complete pre-compute pipeline.
///
/// This method orchestrates the entire pre-compute process:
/// 1. Reads configuration arguments
/// 2. Validates the output folder exists
/// 3. Downloads and decrypts the dataset (if required)
/// 4. Downloads all input files
/// 1. Reads the output directory from environment variable `IEXEC_PRE_COMPUTE_OUT`
/// 2. Reads and validates configuration arguments from environment variables
/// 3. Validates the output folder exists
/// 4. Downloads and decrypts all datasets (if required)
/// 5. Downloads all input files
///
/// The method collects all errors encountered during execution and returns them together,
/// allowing partial completion when possible (e.g., if one dataset fails, others are still processed).
///
/// # Returns
///
/// - `Ok(())` if all operations completed successfully
/// - `Err(ReplicateStatusCause)` if any step failed
/// - `Err(Vec<ReplicateStatusCause>)` containing all errors encountered during execution
///
/// # Example
///
Expand All @@ -55,17 +60,46 @@ impl PreComputeAppTrait for PreComputeApp {
/// let mut app = PreComputeApp::new("task_id".to_string());
/// app.run();
/// ```
fn run(&mut self) -> Result<(), ReplicateStatusCause> {
// TODO: Collect all errors instead of propagating immediately, and return the list of errors
self.pre_compute_args = PreComputeArgs::read_args()?;
self.check_output_folder()?;
fn run(&mut self) -> Result<(), Vec<ReplicateStatusCause>> {
let (mut args, mut exit_causes): (PreComputeArgs, Vec<ReplicateStatusCause>);
match get_env_var_or_error(
TeeSessionEnvironmentVariable::IexecPreComputeOut,
ReplicateStatusCause::PreComputeOutputPathMissing,
) {
Ok(output_dir) => {
(args, exit_causes) = PreComputeArgs::read_args();
args.output_dir = output_dir;
}
Err(e) => {
error!("Failed to read output directory: {e:?}");
return Err(vec![e]);
}
};
self.pre_compute_args = args;

if let Err(exit_cause) = self.check_output_folder() {
return Err(vec![exit_cause]);
}

for dataset in self.pre_compute_args.datasets.iter() {
let encrypted_content = dataset.download_encrypted_dataset(&self.chain_task_id)?;
let plain_content = dataset.decrypt_dataset(&encrypted_content)?;
self.save_plain_dataset_file(&plain_content, &dataset.filename)?;
if let Err(exit_cause) = dataset
.download_encrypted_dataset(&self.chain_task_id)
.and_then(|encrypted_content| dataset.decrypt_dataset(&encrypted_content))
.and_then(|plain_content| {
self.save_plain_dataset_file(&plain_content, &dataset.filename)
})
{
exit_causes.push(exit_cause);
};
}
if let Err(exit_cause) = self.download_input_files() {
exit_causes.extend(exit_cause);
};
if !exit_causes.is_empty() {
Err(exit_causes)
} else {
Ok(())
}
self.download_input_files()?;
Ok(())
}

/// Checks whether the output folder specified in `pre_compute_args` exists.
Expand Down Expand Up @@ -93,31 +127,40 @@ impl PreComputeAppTrait for PreComputeApp {
/// Downloads the input files listed in `pre_compute_args.input_files` to the specified `output_dir`.
///
/// Each URL is hashed (SHA-256) to generate a unique local filename.
/// If any download fails, the function returns an error.
/// The method continues downloading all files even if some downloads fail.
///
/// # Returns
/// # Behavior
///
/// - `Ok(())` if all files are downloaded successfully.
/// - `Err(ReplicateStatusCause::PreComputeInputFileDownloadFailed)` if any file fails to download.
/// - Downloads continue even when individual files fail
/// - Successfully downloaded files are saved with SHA-256 hashed filenames
/// - All download failures are collected and returned together
///
/// # Panics
/// # Returns
///
/// This function panics if:
/// - `pre_compute_args` is `None`.
/// - `chain_task_id` is `None`.
fn download_input_files(&self) -> Result<(), ReplicateStatusCause> {
/// - `Ok(())` if all files are downloaded successfully
/// - `Err(Vec<ReplicateStatusCause>)` containing a `PreComputeInputFileDownloadFailed` error
/// for each file that failed to download
fn download_input_files(&self) -> Result<(), Vec<ReplicateStatusCause>> {
let mut exit_causes: Vec<ReplicateStatusCause> = Vec::new();
let args = &self.pre_compute_args;
let chain_task_id: &str = &self.chain_task_id;

for url in &args.input_files {
for url in args.input_files.iter() {
info!("Downloading input file [chainTaskId:{chain_task_id}, url:{url}]");

let filename = sha256(url.to_string());
if download_file(url, &args.output_dir, &filename).is_none() {
return Err(ReplicateStatusCause::PreComputeInputFileDownloadFailed);
exit_causes.push(ReplicateStatusCause::PreComputeInputFileDownloadFailed(
url.to_string(),
));
}
}
Ok(())

if !exit_causes.is_empty() {
Err(exit_causes)
} else {
Ok(())
}
}

/// Saves the decrypted (plain) dataset to disk in the configured output directory.
Expand Down Expand Up @@ -293,12 +336,14 @@ mod tests {
let result = app.download_input_files();
assert_eq!(
result.unwrap_err(),
ReplicateStatusCause::PreComputeInputFileDownloadFailed
vec![ReplicateStatusCause::PreComputeInputFileDownloadFailed(
"https://invalid-url-that-should-fail.com/file.txt".to_string()
)]
);
}

#[test]
fn test_partial_failure_stops_on_first_error() {
fn test_partial_failure_dont_stops_on_first_error() {
let (_container, json_url, xml_url) = start_container();

let temp_dir = TempDir::new().unwrap();
Expand All @@ -307,24 +352,26 @@ mod tests {
vec![
&json_url, // This should succeed
"https://invalid-url-that-should-fail.com/file.txt", // This should fail
&xml_url, // This shouldn't be reached
&xml_url, // This should succeed
],
temp_dir.path().to_str().unwrap(),
);

let result = app.download_input_files();
assert_eq!(
result.unwrap_err(),
ReplicateStatusCause::PreComputeInputFileDownloadFailed
vec![ReplicateStatusCause::PreComputeInputFileDownloadFailed(
"https://invalid-url-that-should-fail.com/file.txt".to_string()
)]
);

// First file should be downloaded with SHA256 filename
let json_hash = sha256(json_url);
assert!(temp_dir.path().join(json_hash).exists());

// Third file should NOT be downloaded (stopped on second failure)
// Third file should be downloaded (not stopped on second failure)
let xml_hash = sha256(xml_url);
assert!(!temp_dir.path().join(xml_hash).exists());
assert!(temp_dir.path().join(xml_hash).exists());
}
// endregion

Expand Down
Loading