Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
132 changes: 129 additions & 3 deletions src-tauri/src/llm_client.rs
Original file line number Diff line number Diff line change
@@ -1,7 +1,28 @@
use crate::settings::PostProcessProvider;
use crate::settings::{PostProcessProvider, CUSTOM_LLM_BASE_URL_ENV};
use log::debug;
use reqwest::header::{HeaderMap, HeaderValue, AUTHORIZATION, CONTENT_TYPE, REFERER, USER_AGENT};
use serde::{Deserialize, Serialize};
use std::env;

/// Get the effective base URL for a provider.
/// For the "custom" provider, checks the environment variable first.
/// This is called fresh on each invocation to pick up runtime changes.
fn get_effective_base_url(provider: &PostProcessProvider) -> String {
if provider.id == "custom" {
// Check environment variable for custom provider override
if let Ok(env_url) = env::var(CUSTOM_LLM_BASE_URL_ENV) {
let trimmed = env_url.trim();
if !trimmed.is_empty() {
debug!(
"Using base URL from environment variable {}: {}",
CUSTOM_LLM_BASE_URL_ENV, trimmed
);
return trimmed.trim_end_matches('/').to_string();
}
}
}
provider.base_url.trim_end_matches('/').to_string()
}

#[derive(Debug, Serialize)]
struct ChatMessage {
Expand Down Expand Up @@ -85,7 +106,8 @@ pub async fn send_chat_completion(
model: &str,
prompt: String,
) -> Result<Option<String>, String> {
let base_url = provider.base_url.trim_end_matches('/');
// Get effective base URL (checks env var for custom provider on each call)
let base_url = get_effective_base_url(provider);
let url = format!("{}/chat/completions", base_url);

debug!("Sending chat completion request to: {}", url);
Expand Down Expand Up @@ -136,7 +158,8 @@ pub async fn fetch_models(
provider: &PostProcessProvider,
api_key: String,
) -> Result<Vec<String>, String> {
let base_url = provider.base_url.trim_end_matches('/');
// Get effective base URL (checks env var for custom provider on each call)
let base_url = get_effective_base_url(provider);
let url = format!("{}/models", base_url);

debug!("Fetching models from: {}", url);
Expand Down Expand Up @@ -189,3 +212,106 @@ pub async fn fetch_models(

Ok(models)
}

#[cfg(test)]
mod tests {
use super::*;
use crate::settings::PostProcessProvider;

#[test]
fn test_get_effective_base_url_custom_provider_with_env() {
// Set environment variable
std::env::set_var(CUSTOM_LLM_BASE_URL_ENV, "http://custom-server:8080/v1");

let provider = PostProcessProvider {
id: "custom".to_string(),
label: "Custom".to_string(),
base_url: "http://localhost:11434/v1".to_string(),
allow_base_url_edit: true,
models_endpoint: Some("/models".to_string()),
};

let result = get_effective_base_url(&provider);
assert_eq!(result, "http://custom-server:8080/v1");

// Clean up
std::env::remove_var(CUSTOM_LLM_BASE_URL_ENV);
}

#[test]
fn test_get_effective_base_url_custom_provider_without_env() {
// Ensure env var is not set
std::env::remove_var(CUSTOM_LLM_BASE_URL_ENV);

let provider = PostProcessProvider {
id: "custom".to_string(),
label: "Custom".to_string(),
base_url: "http://localhost:11434/v1".to_string(),
allow_base_url_edit: true,
models_endpoint: Some("/models".to_string()),
};

let result = get_effective_base_url(&provider);
assert_eq!(result, "http://localhost:11434/v1");
}

#[test]
fn test_get_effective_base_url_custom_provider_with_empty_env() {
// Set empty environment variable
std::env::set_var(CUSTOM_LLM_BASE_URL_ENV, " ");

let provider = PostProcessProvider {
id: "custom".to_string(),
label: "Custom".to_string(),
base_url: "http://localhost:11434/v1".to_string(),
allow_base_url_edit: true,
models_endpoint: Some("/models".to_string()),
};

let result = get_effective_base_url(&provider);
assert_eq!(result, "http://localhost:11434/v1");

// Clean up
std::env::remove_var(CUSTOM_LLM_BASE_URL_ENV);
}

#[test]
fn test_get_effective_base_url_non_custom_provider() {
// Set environment variable (should be ignored for non-custom provider)
std::env::set_var(CUSTOM_LLM_BASE_URL_ENV, "http://custom-server:8080/v1");

let provider = PostProcessProvider {
id: "openai".to_string(),
label: "OpenAI".to_string(),
base_url: "https://api.openai.com/v1".to_string(),
allow_base_url_edit: false,
models_endpoint: Some("/models".to_string()),
};

let result = get_effective_base_url(&provider);
assert_eq!(result, "https://api.openai.com/v1");

// Clean up
std::env::remove_var(CUSTOM_LLM_BASE_URL_ENV);
}

#[test]
fn test_get_effective_base_url_strips_trailing_slash() {
// Set environment variable with trailing slash
std::env::set_var(CUSTOM_LLM_BASE_URL_ENV, "http://custom-server:8080/v1/");

let provider = PostProcessProvider {
id: "custom".to_string(),
label: "Custom".to_string(),
base_url: "http://localhost:11434/v1/".to_string(),
allow_base_url_edit: true,
models_endpoint: Some("/models".to_string()),
};

let result = get_effective_base_url(&provider);
assert_eq!(result, "http://custom-server:8080/v1");

// Clean up
std::env::remove_var(CUSTOM_LLM_BASE_URL_ENV);
}
}
8 changes: 7 additions & 1 deletion src-tauri/src/settings.rs
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@ use tauri_plugin_store::StoreExt;

pub const APPLE_INTELLIGENCE_PROVIDER_ID: &str = "apple_intelligence";
pub const APPLE_INTELLIGENCE_DEFAULT_MODEL_ID: &str = "Apple Intelligence";
pub const CUSTOM_LLM_BASE_URL_ENV: &str = "HANDY_CUSTOM_LLM_BASE_URL";

#[derive(Serialize, Debug, Clone, Copy, PartialEq, Eq, Type)]
#[serde(rename_all = "lowercase")]
Expand Down Expand Up @@ -449,10 +450,15 @@ fn default_post_process_providers() -> Vec<PostProcessProvider> {
}

// Custom provider always comes last
let custom_base_url = std::env::var(CUSTOM_LLM_BASE_URL_ENV)
.ok()
.filter(|s| !s.trim().is_empty())
.unwrap_or_else(|| "http://localhost:11434/v1".to_string());

providers.push(PostProcessProvider {
id: "custom".to_string(),
label: "Custom".to_string(),
base_url: "http://localhost:11434/v1".to_string(),
base_url: custom_base_url,
allow_base_url_edit: true,
models_endpoint: Some("/models".to_string()),
});
Expand Down