Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
107 changes: 89 additions & 18 deletions src-tauri/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@ use std::{
path::PathBuf,
process::{Child, Command, Stdio},
sync::Mutex,
time::{Duration, Instant},
time::{Duration, Instant, SystemTime, UNIX_EPOCH},
};
use tauri::{AppHandle, Manager};
use thiserror::Error;
Expand Down Expand Up @@ -400,14 +400,15 @@ async fn connect_runtime(
"No local GPU runtime was available and Vast API key is empty.".to_string(),
)
})?;
let status = connect_vast(
let (status, auth_token) = connect_vast(
&state.client,
api_key,
&configured_image,
&config,
local_gpu,
)
.await?;
store_pitch_server_auth_token(&state, auth_token)?;
replace_status(&app, &state, status.clone())?;
Ok(status)
})
Expand Down Expand Up @@ -515,7 +516,7 @@ async fn score_pitch(
.post(format!("{service_url}/score"))
.json(&request)
.timeout(Duration::from_secs(TRIBE_CLIENT_SCORE_TIMEOUT_SECONDS));
if status.mode == "pitchserver" {
if status.mode == "pitchserver" || status.mode == "vast" {
let token = pitch_server_auth_token(&state)?;
request_builder = request_builder.bearer_auth(token);
}
Expand Down Expand Up @@ -1942,7 +1943,7 @@ async fn connect_vast(
image: &str,
config: &RuntimeConfig,
local_gpu: LocalGpuInfo,
) -> RuntimeResult<RuntimeStatus> {
) -> RuntimeResult<(RuntimeStatus, String)> {
let mut offers = search_vast_offers(client, api_key, config).await?;
if let Some(max_price) = config.max_hourly_price {
offers.retain(|offer| offer.dph_total.unwrap_or(f64::MAX) <= max_price);
Expand Down Expand Up @@ -1970,16 +1971,36 @@ async fn connect_vast(
let vast_image = select_vast_image(client, image).await;
let use_bootstrap = vast_image != image;
let mut last_error: Option<String> = None;
let (auth_username, auth_password) = vast_auth_credentials(config);
for offer in offers.iter().take(5) {
match create_vast_instance(client, api_key, offer, &vast_image, config, use_bootstrap).await
match create_vast_instance(
client,
api_key,
offer,
&vast_image,
config,
use_bootstrap,
&auth_username,
&auth_password,
)
.await
{
Ok(instance_id) => {
if let Err(error) = start_vast_instance(client, api_key, instance_id).await {
let _ = destroy_vast_instance(client, api_key, instance_id).await;
last_error = Some(error.to_string());
continue;
}
match wait_for_vast_runtime(client, api_key, instance_id, &vast_image, offer).await
match wait_for_vast_runtime(
client,
api_key,
instance_id,
&vast_image,
offer,
&auth_username,
&auth_password,
)
.await
{
Ok(mut status) => {
status.local_gpu = local_gpu;
Expand Down Expand Up @@ -2197,6 +2218,8 @@ async fn create_vast_instance(
image: &str,
config: &RuntimeConfig,
use_bootstrap: bool,
auth_username: &str,
auth_password: &str,
) -> RuntimeResult<u64> {
let mut env = Map::new();
env.insert("TRIBE_DEVICE".to_string(), json!("cuda"));
Expand Down Expand Up @@ -2243,6 +2266,23 @@ async fn create_vast_instance(
"OPEN_BUTTON_PORT".to_string(),
json!(TRIBE_PORT.to_string()),
);
env.insert("PITCHSERVER_AUTH_REQUIRED".to_string(), json!("1"));
env.insert(
"PITCHSERVER_AUTH_FILE".to_string(),
json!("/tmp/pitchcheck_auth.json"),
);
env.insert(
"PITCHSERVER_AUTH_SEED_USERNAME".to_string(),
json!(auth_username),
);
env.insert(
"PITCHSERVER_AUTH_SEED_PASSWORD".to_string(),
json!(auth_password),
);
env.insert(
"PITCHSERVER_SESSION_TTL_SECONDS".to_string(),
json!("86400"),
);
env.insert(format!("-p {TRIBE_PORT}:{TRIBE_PORT}"), json!("1"));

let onstart = if use_bootstrap {
Expand Down Expand Up @@ -2299,7 +2339,9 @@ async fn wait_for_vast_runtime(
instance_id: u64,
image: &str,
offer: &OfferSummary,
) -> RuntimeResult<RuntimeStatus> {
auth_username: &str,
auth_password: &str,
) -> RuntimeResult<(RuntimeStatus, String)> {
let deadline = Instant::now() + Duration::from_secs(900);

loop {
Expand All @@ -2308,17 +2350,23 @@ async fn wait_for_vast_runtime(
if let Some(service_url) = service_url_from_instance(&instance, TRIBE_PORT) {
match wait_for_health(client, &service_url, Duration::from_secs(15)).await {
Ok(()) => {
return Ok(RuntimeStatus {
mode: "vast".to_string(),
connected: true,
service_url: Some(service_url),
local_gpu: LocalGpuInfo::default(),
container_id: None,
vast_instance_id: Some(instance_id),
offer: Some(offer.clone()),
image: image.to_string(),
last_error: None,
});
let token =
login_pitch_server(client, &service_url, auth_username, auth_password)
.await?;
return Ok((
RuntimeStatus {
mode: "vast".to_string(),
connected: true,
service_url: Some(service_url),
local_gpu: LocalGpuInfo::default(),
container_id: None,
vast_instance_id: Some(instance_id),
offer: Some(offer.clone()),
image: image.to_string(),
last_error: None,
},
token,
));
}
Err(error) => error.to_string(),
}
Expand All @@ -2335,6 +2383,29 @@ async fn wait_for_vast_runtime(
}
}

fn vast_auth_credentials(config: &RuntimeConfig) -> (String, String) {
let username = config
.pitch_server_username
.as_deref()
.map(str::trim)
.filter(|value| !value.is_empty())
.unwrap_or("pitchcheck")
.to_string();
let password = config
.pitch_server_password
.as_deref()
.filter(|value| !value.trim().is_empty())
.map(str::to_string)
.unwrap_or_else(|| {
let nonce = SystemTime::now()
.duration_since(UNIX_EPOCH)
.map(|duration| duration.as_nanos())
.unwrap_or(0);
format!("vast-{nonce:032x}")
});
(username, password)
}

async fn fetch_vast_instance(
client: &Client,
api_key: &str,
Expand Down
Loading