NOTE(review): this patch was reconstructed from a whitespace-mangled paste in
which all newlines were collapsed and generic type parameters were stripped
(e.g. "Result, String>"). Re-inferred generics — Result<Option<String>, String>
for send_chat_completion, Result<Vec<String>, String> for fetch_models,
Vec<PostProcessProvider> for default_post_process_providers — must be verified
against the target files before applying. All hunk line counts have been
re-derived and checked. Additionally, the test hunk now serializes the
env-var-mutating tests behind a static Mutex: they all read/write the same
process-global variable and `cargo test` runs tests in parallel by default,
so the original tests raced with each other (hunk count updated 106 -> 117).

diff --git a/src-tauri/src/llm_client.rs b/src-tauri/src/llm_client.rs
index 57fe67e2b..e60b2010f 100644
--- a/src-tauri/src/llm_client.rs
+++ b/src-tauri/src/llm_client.rs
@@ -1,7 +1,28 @@
-use crate::settings::PostProcessProvider;
+use crate::settings::{PostProcessProvider, CUSTOM_LLM_BASE_URL_ENV};
 use log::debug;
 use reqwest::header::{HeaderMap, HeaderValue, AUTHORIZATION, CONTENT_TYPE, REFERER, USER_AGENT};
 use serde::{Deserialize, Serialize};
+use std::env;
+
+/// Get the effective base URL for a provider.
+/// For the "custom" provider, checks the environment variable first.
+/// This is called fresh on each invocation to pick up runtime changes.
+fn get_effective_base_url(provider: &PostProcessProvider) -> String {
+    if provider.id == "custom" {
+        // Check environment variable for custom provider override
+        if let Ok(env_url) = env::var(CUSTOM_LLM_BASE_URL_ENV) {
+            let trimmed = env_url.trim();
+            if !trimmed.is_empty() {
+                debug!(
+                    "Using base URL from environment variable {}: {}",
+                    CUSTOM_LLM_BASE_URL_ENV, trimmed
+                );
+                return trimmed.trim_end_matches('/').to_string();
+            }
+        }
+    }
+    provider.base_url.trim_end_matches('/').to_string()
+}
 
 #[derive(Debug, Serialize)]
 struct ChatMessage {
@@ -85,7 +106,8 @@ pub async fn send_chat_completion(
     model: &str,
     prompt: String,
 ) -> Result<Option<String>, String> {
-    let base_url = provider.base_url.trim_end_matches('/');
+    // Get effective base URL (checks env var for custom provider on each call)
+    let base_url = get_effective_base_url(provider);
     let url = format!("{}/chat/completions", base_url);
 
     debug!("Sending chat completion request to: {}", url);
@@ -136,7 +158,8 @@ pub async fn fetch_models(
     provider: &PostProcessProvider,
     api_key: String,
 ) -> Result<Vec<String>, String> {
-    let base_url = provider.base_url.trim_end_matches('/');
+    // Get effective base URL (checks env var for custom provider on each call)
+    let base_url = get_effective_base_url(provider);
     let url = format!("{}/models", base_url);
 
     debug!("Fetching models from: {}", url);
@@ -189,3 +212,117 @@ pub async fn fetch_models(
 
     Ok(models)
 }
+
+#[cfg(test)]
+mod tests {
+    use super::*;
+    use crate::settings::PostProcessProvider;
+    use std::sync::Mutex;
+
+    // These tests all mutate the same process-global environment variable and
+    // the default test harness runs tests in parallel, so they must be
+    // serialized. Poisoning is recovered from so one failed assertion does
+    // not cascade into unrelated test failures.
+    static ENV_LOCK: Mutex<()> = Mutex::new(());
+
+    #[test]
+    fn test_get_effective_base_url_custom_provider_with_env() {
+        let _guard = ENV_LOCK.lock().unwrap_or_else(|e| e.into_inner());
+        // Set environment variable
+        std::env::set_var(CUSTOM_LLM_BASE_URL_ENV, "http://custom-server:8080/v1");
+
+        let provider = PostProcessProvider {
+            id: "custom".to_string(),
+            label: "Custom".to_string(),
+            base_url: "http://localhost:11434/v1".to_string(),
+            allow_base_url_edit: true,
+            models_endpoint: Some("/models".to_string()),
+        };
+
+        let result = get_effective_base_url(&provider);
+        assert_eq!(result, "http://custom-server:8080/v1");
+
+        // Clean up
+        std::env::remove_var(CUSTOM_LLM_BASE_URL_ENV);
+    }
+
+    #[test]
+    fn test_get_effective_base_url_custom_provider_without_env() {
+        let _guard = ENV_LOCK.lock().unwrap_or_else(|e| e.into_inner());
+        // Ensure env var is not set
+        std::env::remove_var(CUSTOM_LLM_BASE_URL_ENV);
+
+        let provider = PostProcessProvider {
+            id: "custom".to_string(),
+            label: "Custom".to_string(),
+            base_url: "http://localhost:11434/v1".to_string(),
+            allow_base_url_edit: true,
+            models_endpoint: Some("/models".to_string()),
+        };
+
+        let result = get_effective_base_url(&provider);
+        assert_eq!(result, "http://localhost:11434/v1");
+    }
+
+    #[test]
+    fn test_get_effective_base_url_custom_provider_with_empty_env() {
+        let _guard = ENV_LOCK.lock().unwrap_or_else(|e| e.into_inner());
+        // Set empty environment variable
+        std::env::set_var(CUSTOM_LLM_BASE_URL_ENV, "   ");
+
+        let provider = PostProcessProvider {
+            id: "custom".to_string(),
+            label: "Custom".to_string(),
+            base_url: "http://localhost:11434/v1".to_string(),
+            allow_base_url_edit: true,
+            models_endpoint: Some("/models".to_string()),
+        };
+
+        let result = get_effective_base_url(&provider);
+        assert_eq!(result, "http://localhost:11434/v1");
+
+        // Clean up
+        std::env::remove_var(CUSTOM_LLM_BASE_URL_ENV);
+    }
+
+    #[test]
+    fn test_get_effective_base_url_non_custom_provider() {
+        let _guard = ENV_LOCK.lock().unwrap_or_else(|e| e.into_inner());
+        // Set environment variable (should be ignored for non-custom provider)
+        std::env::set_var(CUSTOM_LLM_BASE_URL_ENV, "http://custom-server:8080/v1");
+
+        let provider = PostProcessProvider {
+            id: "openai".to_string(),
+            label: "OpenAI".to_string(),
+            base_url: "https://api.openai.com/v1".to_string(),
+            allow_base_url_edit: false,
+            models_endpoint: Some("/models".to_string()),
+        };
+
+        let result = get_effective_base_url(&provider);
+        assert_eq!(result, "https://api.openai.com/v1");
+
+        // Clean up
+        std::env::remove_var(CUSTOM_LLM_BASE_URL_ENV);
+    }
+
+    #[test]
+    fn test_get_effective_base_url_strips_trailing_slash() {
+        let _guard = ENV_LOCK.lock().unwrap_or_else(|e| e.into_inner());
+        // Set environment variable with trailing slash
+        std::env::set_var(CUSTOM_LLM_BASE_URL_ENV, "http://custom-server:8080/v1/");
+
+        let provider = PostProcessProvider {
+            id: "custom".to_string(),
+            label: "Custom".to_string(),
+            base_url: "http://localhost:11434/v1/".to_string(),
+            allow_base_url_edit: true,
+            models_endpoint: Some("/models".to_string()),
+        };
+
+        let result = get_effective_base_url(&provider);
+        assert_eq!(result, "http://custom-server:8080/v1");
+
+        // Clean up
+        std::env::remove_var(CUSTOM_LLM_BASE_URL_ENV);
+    }
+}
diff --git a/src-tauri/src/settings.rs b/src-tauri/src/settings.rs
index 8feb07f4d..375b1b9ed 100644
--- a/src-tauri/src/settings.rs
+++ b/src-tauri/src/settings.rs
@@ -8,6 +8,7 @@ use tauri_plugin_store::StoreExt;
 
 pub const APPLE_INTELLIGENCE_PROVIDER_ID: &str = "apple_intelligence";
 pub const APPLE_INTELLIGENCE_DEFAULT_MODEL_ID: &str = "Apple Intelligence";
+pub const CUSTOM_LLM_BASE_URL_ENV: &str = "HANDY_CUSTOM_LLM_BASE_URL";
 
 #[derive(Serialize, Debug, Clone, Copy, PartialEq, Eq, Type)]
 #[serde(rename_all = "lowercase")]
@@ -449,10 +450,15 @@ fn default_post_process_providers() -> Vec<PostProcessProvider> {
     }
 
     // Custom provider always comes last
+    let custom_base_url = std::env::var(CUSTOM_LLM_BASE_URL_ENV)
+        .ok()
+        .filter(|s| !s.trim().is_empty())
+        .unwrap_or_else(|| "http://localhost:11434/v1".to_string());
+
     providers.push(PostProcessProvider {
         id: "custom".to_string(),
         label: "Custom".to_string(),
-        base_url: "http://localhost:11434/v1".to_string(),
+        base_url: custom_base_url,
         allow_base_url_edit: true,
         models_endpoint: Some("/models".to_string()),
     });