69 changes: 24 additions & 45 deletions src-tauri/src/actions.rs
@@ -8,10 +8,6 @@ use crate::settings::{get_settings, AppSettings, APPLE_INTELLIGENCE_PROVIDER_ID}
use crate::shortcut;
use crate::tray::{change_tray_icon, TrayIconState};
use crate::utils::{self, show_recording_overlay, show_transcribing_overlay};
use async_openai::types::{
ChatCompletionRequestMessage, ChatCompletionRequestUserMessageArgs,
CreateChatCompletionRequestArgs,
};
use ferrous_opencc::{config::BuiltinConfig, OpenCC};
use log::{debug, error};
use once_cell::sync::Lazy;
@@ -30,6 +26,8 @@ pub trait ShortcutAction: Send + Sync {
// Transcribe Action
struct TranscribeAction;



async fn maybe_post_process_transcription(
settings: &AppSettings,
transcription: &str,
@@ -148,45 +146,20 @@ async fn maybe_post_process_transcription(
}
};

// Build the chat completion request
let message = match ChatCompletionRequestUserMessageArgs::default()
.content(processed_prompt)
.build()
{
Ok(msg) => ChatCompletionRequestMessage::User(msg),
Err(e) => {
error!("Failed to build chat message: {}", e);
return None;
}
};

let request = match CreateChatCompletionRequestArgs::default()
.model(&model)
.messages(vec![message])
.build()
{
Ok(req) => req,
Err(e) => {
error!("Failed to build chat completion request: {}", e);
return None;
}
};

// Send the request
match client.chat().create(request).await {
Ok(response) => {
if let Some(choice) = response.choices.first() {
if let Some(content) = &choice.message.content {
debug!(
"LLM post-processing succeeded for provider '{}'. Output length: {} chars",
provider.id,
content.len()
);
return Some(content.clone());
}
}
error!("LLM API response has no content");
None
}
// Send the chat completion request using our custom client
match client.chat_completion(&model, &processed_prompt).await {
Ok(content) => {
if content.trim().is_empty() {
error!("LLM API response has empty content");
None
} else {
debug!(
"LLM post-processing succeeded for provider '{}'. Output length: {} chars",
provider.id,
content.len()
);
Some(content)
}
}
}
Err(e) => {
error!(
@@ -351,17 +324,23 @@ impl ShortcutAction for TranscribeAction {
samples.len()
);

let settings = get_settings(&ah);

let transcription_time = Instant::now();
let samples_clone = samples.clone(); // Clone for history saving
match tm.transcribe(samples) {

// Use local transcription
debug!("Using local model for transcription");
let transcription_result: Result<String, String> = tm.transcribe(samples).map_err(|e| e.to_string());

match transcription_result {
Ok(transcription) => {
debug!(
"Transcription completed in {:?}: '{}'",
transcription_time.elapsed(),
transcription
);
if !transcription.is_empty() {
let settings = get_settings(&ah);
let mut final_text = transcription.clone();
let mut post_processed_text: Option<String> = None;
let mut post_process_prompt: Option<String> = None;
@@ -435,7 +414,7 @@ impl ShortcutAction for TranscribeAction {
}
}
Err(err) => {
debug!("Global Shortcut Transcription error: {}", err);
error!("Transcription error: {}", err);
utils::hide_recording_overlay(&ah);
change_tray_icon(&ah, TrayIconState::Idle);
}
1 change: 1 addition & 0 deletions src-tauri/src/lib.rs
@@ -8,6 +8,7 @@ mod commands;
mod helpers;
mod input;
mod llm_client;
mod llm_types;
mod managers;
mod overlay;
mod settings;
110 changes: 89 additions & 21 deletions src-tauri/src/llm_client.rs
@@ -1,33 +1,101 @@
use crate::llm_types::ChatCompletionResponse;
use crate::settings::PostProcessProvider;
use async_openai::{config::OpenAIConfig, Client};
use reqwest::Client;
use serde::Serialize;

/// Create an OpenAI-compatible client configured for the given provider
#[derive(Serialize)]
struct ChatCompletionRequest {
model: String,
messages: Vec<ChatMessage>,
}

#[derive(Serialize)]
struct ChatMessage {
role: String,
content: String,
}

/// LLM client for making chat completion requests to OpenAI-compatible APIs
pub struct LlmClient {
http_client: Client,
base_url: String,
api_key: String,
}

impl LlmClient {
/// Send a chat completion request and return the response content
pub async fn chat_completion(
&self,
model: &str,
user_message: &str,
) -> Result<String, String> {
let request = ChatCompletionRequest {
model: model.to_string(),
messages: vec![ChatMessage {
role: "user".to_string(),
content: user_message.to_string(),
}],
};

let url = format!("{}/chat/completions", self.base_url);

let response = self
.http_client
.post(&url)
.header("Authorization", format!("Bearer {}", self.api_key))
.header("Content-Type", "application/json")
.json(&request)
.send()
.await
.map_err(|e| format!("HTTP request failed: {}", e))?;

if !response.status().is_success() {
let status = response.status();
let body = response.text().await.unwrap_or_default();
return Err(format!("API request failed with status {}: {}", status, body));
}

let body = response
.text()
.await
.map_err(|e| format!("Failed to read response body: {}", e))?;

let parsed: ChatCompletionResponse = serde_json::from_str(&body)
.map_err(|e| format!("Failed to parse response: {} - body: {}", e, body))?;

parsed
.choices
.first()
.and_then(|c| c.message.content.clone())
.ok_or_else(|| "No content in response".to_string())
}
}

/// Create an LLM client configured for the given provider
pub fn create_client(
provider: &PostProcessProvider,
api_key: String,
) -> Result<Client<OpenAIConfig>, String> {
let base_url = provider.base_url.trim_end_matches('/');
let config = OpenAIConfig::new()
.with_api_base(base_url)
.with_api_key(api_key);

// Create client with Anthropic-specific header if needed
let client = if provider.id == "anthropic" {
let mut headers = reqwest::header::HeaderMap::new();
) -> Result<LlmClient, String> {
let base_url = provider.base_url.trim_end_matches('/').to_string();

let mut headers = reqwest::header::HeaderMap::new();

// Add provider-specific headers
if provider.id == "anthropic" {
headers.insert(
"anthropic-version",
reqwest::header::HeaderValue::from_static("2023-06-01"),
);
}

let http_client = reqwest::Client::builder()
.default_headers(headers)
.build()
.map_err(|e| format!("Failed to build HTTP client: {}", e))?;

Client::with_config(config).with_http_client(http_client)
} else {
Client::with_config(config)
};
let http_client = reqwest::Client::builder()
.default_headers(headers)
.build()
.map_err(|e| format!("Failed to build HTTP client: {}", e))?;

Ok(client)
Ok(LlmClient {
http_client,
base_url,
api_key,
})
}
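Reviewer note: a minimal usage sketch of the new client, assuming PostProcessProvider has exactly the five fields shown in the settings.rs diff below. The provider literal mirrors the Groq entry added in this PR; the API key and model name are placeholders, not values from the PR.

use crate::llm_client::create_client;
use crate::settings::PostProcessProvider;

// Hypothetical caller, not part of this PR.
async fn post_process_example() -> Result<String, String> {
    let provider = PostProcessProvider {
        id: "groq".to_string(),
        label: "Groq".to_string(),
        base_url: "https://api.groq.com/openai/v1".to_string(),
        allow_base_url_edit: false,
        models_endpoint: Some("/models".to_string()),
    };
    // create_client trims trailing slashes and attaches provider-specific headers.
    let client = create_client(&provider, "gsk_placeholder".to_string())?;
    // chat_completion wraps the prompt in a single user message and returns
    // the first choice's content, or an error string on any failure.
    client
        .chat_completion("llama-3.3-70b-versatile", "Clean up this transcript: ...")
        .await
}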
32 changes: 32 additions & 0 deletions src-tauri/src/llm_types.rs
@@ -0,0 +1,32 @@
use serde::Deserialize;

/// Custom response types for OpenAI-compatible APIs that may have
/// non-standard fields (like Groq's `service_tier: "on_demand"`)

#[derive(Debug, Deserialize)]
pub struct ChatCompletionResponse {
pub id: String,
pub object: String,
pub created: u64,
pub model: String,
pub choices: Vec<ChatChoice>,
#[serde(skip)]
pub usage: Option<serde_json::Value>,
#[serde(skip)]
pub service_tier: Option<String>,
}

#[derive(Debug, Deserialize)]
pub struct ChatChoice {
pub index: u32,
pub message: ChatMessage,
pub finish_reason: Option<String>,
#[serde(skip)]
pub logprobs: Option<serde_json::Value>,
}

#[derive(Debug, Deserialize)]
pub struct ChatMessage {
pub role: String,
pub content: Option<String>,
}
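A quick sanity-check sketch of the lenient parsing, assuming only the types above: the JSON body is illustrative and mirrors an OpenAI-style reply plus Groq's non-standard service_tier field, which serde ignores here since unknown fields are skipped by default.

#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn parses_groq_style_response() {
        // Sample body is illustrative; "service_tier" is the non-standard
        // field mentioned above and does not break deserialization.
        let body = r#"{
            "id": "chatcmpl-123",
            "object": "chat.completion",
            "created": 1700000000,
            "model": "example-model",
            "service_tier": "on_demand",
            "choices": [{
                "index": 0,
                "message": { "role": "assistant", "content": "Hello!" },
                "finish_reason": "stop"
            }]
        }"#;
        let parsed: ChatCompletionResponse = serde_json::from_str(body).unwrap();
        assert_eq!(parsed.choices[0].message.content.as_deref(), Some("Hello!"));
    }
}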
22 changes: 19 additions & 3 deletions src-tauri/src/settings.rs
@@ -362,6 +362,7 @@ fn default_post_process_enabled() -> bool {
false
}


fn default_app_language() -> String {
"en".to_string()
}
@@ -371,6 +372,7 @@ fn default_post_process_provider_id() {
}

fn default_post_process_providers() -> Vec<PostProcessProvider> {
#[allow(unused_mut)]
let mut providers = vec![
PostProcessProvider {
id: "openai".to_string(),
@@ -387,9 +389,23 @@ fn default_post_process_providers() -> Vec<PostProcessProvider> {
models_endpoint: Some("/models".to_string()),
},
PostProcessProvider {
id: "anthropic".to_string(),
label: "Anthropic".to_string(),
base_url: "https://api.anthropic.com/v1".to_string(),
id: "gemini".to_string(),
label: "Gemini".to_string(),
base_url: "https://generativelanguage.googleapis.com/v1beta/openai".to_string(),
allow_base_url_edit: false,
models_endpoint: Some("/models".to_string()),
},
PostProcessProvider {
id: "groq".to_string(),
label: "Groq".to_string(),
base_url: "https://api.groq.com/openai/v1".to_string(),
allow_base_url_edit: false,
models_endpoint: Some("/models".to_string()),
},
PostProcessProvider {
id: "cerebras".to_string(),
label: "Cerebras".to_string(),
base_url: "https://api.cerebras.ai/v1".to_string(),
allow_base_url_edit: false,
models_endpoint: Some("/models".to_string()),
},
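Since every entry carries allow_base_url_edit and models_endpoint, a self-hosted OpenAI-compatible server fits the same shape. A hypothetical entry, not part of this PR (the id, label, and URL are illustrative; Ollama is one example of such a server):

// Hypothetical entry for a self-hosted OpenAI-compatible server.
PostProcessProvider {
    id: "local".to_string(),
    label: "Local (Ollama)".to_string(),
    base_url: "http://localhost:11434/v1".to_string(),
    allow_base_url_edit: true,
    models_endpoint: Some("/models".to_string()),
},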
5 changes: 5 additions & 0 deletions src-tauri/src/signal_handle.rs
@@ -1,7 +1,12 @@
#[cfg(unix)]
use crate::actions::ACTION_MAP;
#[cfg(unix)]
use crate::ManagedToggleState;
#[cfg(unix)]
use log::{debug, info, warn};
#[cfg(unix)]
use std::thread;
#[cfg(unix)]
use tauri::{AppHandle, Manager};

#[cfg(unix)]
10 changes: 10 additions & 0 deletions src/App.css
@@ -35,6 +35,16 @@
background-color: var(--color-background);
}

/* Hide scrollbars while keeping scroll functionality */
* {
scrollbar-width: none; /* Firefox */
-ms-overflow-style: none; /* IE and Edge */
}

*::-webkit-scrollbar {
display: none; /* Chrome, Safari, Opera */
}

.container {
margin: 0;
padding-top: 10vh;
2 changes: 1 addition & 1 deletion src/bindings.ts
@@ -698,4 +698,4 @@ function __makeEvents__<T extends Record<string, any>>(
},
},
);
}
}