diff --git a/src-tauri/src/lib.rs b/src-tauri/src/lib.rs index 56d72aa..c483972 100644 --- a/src-tauri/src/lib.rs +++ b/src-tauri/src/lib.rs @@ -9,7 +9,7 @@ use std::{ path::PathBuf, process::{Child, Command, Stdio}, sync::Mutex, - time::{Duration, Instant}, + time::{Duration, Instant, SystemTime, UNIX_EPOCH}, }; use tauri::{AppHandle, Manager}; use thiserror::Error; @@ -400,7 +400,7 @@ async fn connect_runtime( "No local GPU runtime was available and Vast API key is empty.".to_string(), ) })?; - let status = connect_vast( + let (status, auth_token) = connect_vast( &state.client, api_key, &configured_image, @@ -408,6 +408,7 @@ async fn connect_runtime( local_gpu, ) .await?; + store_pitch_server_auth_token(&state, auth_token)?; replace_status(&app, &state, status.clone())?; Ok(status) }) @@ -515,7 +516,7 @@ async fn score_pitch( .post(format!("{service_url}/score")) .json(&request) .timeout(Duration::from_secs(TRIBE_CLIENT_SCORE_TIMEOUT_SECONDS)); - if status.mode == "pitchserver" { + if status.mode == "pitchserver" || status.mode == "vast" { let token = pitch_server_auth_token(&state)?; request_builder = request_builder.bearer_auth(token); } @@ -1942,7 +1943,7 @@ async fn connect_vast( image: &str, config: &RuntimeConfig, local_gpu: LocalGpuInfo, -) -> RuntimeResult { +) -> RuntimeResult<(RuntimeStatus, String)> { let mut offers = search_vast_offers(client, api_key, config).await?; if let Some(max_price) = config.max_hourly_price { offers.retain(|offer| offer.dph_total.unwrap_or(f64::MAX) <= max_price); @@ -1970,8 +1971,19 @@ async fn connect_vast( let vast_image = select_vast_image(client, image).await; let use_bootstrap = vast_image != image; let mut last_error: Option = None; + let (auth_username, auth_password) = vast_auth_credentials(config); for offer in offers.iter().take(5) { - match create_vast_instance(client, api_key, offer, &vast_image, config, use_bootstrap).await + match create_vast_instance( + client, + api_key, + offer, + &vast_image, + config, + use_bootstrap, + &auth_username, + &auth_password, + ) + .await { Ok(instance_id) => { if let Err(error) = start_vast_instance(client, api_key, instance_id).await { @@ -1979,7 +1991,16 @@ async fn connect_vast( last_error = Some(error.to_string()); continue; } - match wait_for_vast_runtime(client, api_key, instance_id, &vast_image, offer).await + match wait_for_vast_runtime( + client, + api_key, + instance_id, + &vast_image, + offer, + &auth_username, + &auth_password, + ) + .await { Ok(mut status) => { status.local_gpu = local_gpu; @@ -2197,6 +2218,8 @@ async fn create_vast_instance( image: &str, config: &RuntimeConfig, use_bootstrap: bool, + auth_username: &str, + auth_password: &str, ) -> RuntimeResult { let mut env = Map::new(); env.insert("TRIBE_DEVICE".to_string(), json!("cuda")); @@ -2243,6 +2266,23 @@ async fn create_vast_instance( "OPEN_BUTTON_PORT".to_string(), json!(TRIBE_PORT.to_string()), ); + env.insert("PITCHSERVER_AUTH_REQUIRED".to_string(), json!("1")); + env.insert( + "PITCHSERVER_AUTH_FILE".to_string(), + json!("/tmp/pitchcheck_auth.json"), + ); + env.insert( + "PITCHSERVER_AUTH_SEED_USERNAME".to_string(), + json!(auth_username), + ); + env.insert( + "PITCHSERVER_AUTH_SEED_PASSWORD".to_string(), + json!(auth_password), + ); + env.insert( + "PITCHSERVER_SESSION_TTL_SECONDS".to_string(), + json!("86400"), + ); env.insert(format!("-p {TRIBE_PORT}:{TRIBE_PORT}"), json!("1")); let onstart = if use_bootstrap { @@ -2299,7 +2339,9 @@ async fn wait_for_vast_runtime( instance_id: u64, image: &str, offer: &OfferSummary, -) -> RuntimeResult { + auth_username: &str, + auth_password: &str, +) -> RuntimeResult<(RuntimeStatus, String)> { let deadline = Instant::now() + Duration::from_secs(900); loop { @@ -2308,17 +2350,23 @@ async fn wait_for_vast_runtime( if let Some(service_url) = service_url_from_instance(&instance, TRIBE_PORT) { match wait_for_health(client, &service_url, Duration::from_secs(15)).await { Ok(()) => { - return Ok(RuntimeStatus { - mode: "vast".to_string(), - connected: true, - service_url: Some(service_url), - local_gpu: LocalGpuInfo::default(), - container_id: None, - vast_instance_id: Some(instance_id), - offer: Some(offer.clone()), - image: image.to_string(), - last_error: None, - }); + let token = + login_pitch_server(client, &service_url, auth_username, auth_password) + .await?; + return Ok(( + RuntimeStatus { + mode: "vast".to_string(), + connected: true, + service_url: Some(service_url), + local_gpu: LocalGpuInfo::default(), + container_id: None, + vast_instance_id: Some(instance_id), + offer: Some(offer.clone()), + image: image.to_string(), + last_error: None, + }, + token, + )); } Err(error) => error.to_string(), } @@ -2335,6 +2383,29 @@ async fn wait_for_vast_runtime( } } +fn vast_auth_credentials(config: &RuntimeConfig) -> (String, String) { + let username = config + .pitch_server_username + .as_deref() + .map(str::trim) + .filter(|value| !value.is_empty()) + .unwrap_or("pitchcheck") + .to_string(); + let password = config + .pitch_server_password + .as_deref() + .filter(|value| !value.trim().is_empty()) + .map(str::to_string) + .unwrap_or_else(|| { + let nonce = SystemTime::now() + .duration_since(UNIX_EPOCH) + .map(|duration| duration.as_nanos()) + .unwrap_or(0); + format!("vast-{nonce:032x}") + }); + (username, password) +} + async fn fetch_vast_instance( client: &Client, api_key: &str,