From 2b8dc05d6b9d113ee66b92a2e8325f10dcfb765d Mon Sep 17 00:00:00 2001 From: Xavier Lau Date: Tue, 9 Jul 2024 01:25:26 +0800 Subject: [PATCH] Refactor chat --- README.md | 6 ++- src/air.rs | 14 +++--- src/component.rs | 20 +------- src/component/keyboard.rs | 2 +- src/component/openai.rs | 22 +++------ src/component/setting.rs | 6 +-- src/service.rs | 36 ++++++++++++-- src/service/chat.rs | 100 ++++++++++++++++++++++++++++++++++++++ src/service/hotkey.rs | 74 ++++++---------------------- src/service/keyboard.rs | 17 +++---- src/service/quoter.rs | 2 +- src/state.rs | 8 --- src/ui.rs | 8 ++- src/ui/panel/chat.rs | 9 ++-- src/ui/panel/setting.rs | 14 ++++-- 15 files changed, 198 insertions(+), 140 deletions(-) create mode 100644 src/service/chat.rs diff --git a/README.md b/README.md index 4c8c688..902f84f 100644 --- a/README.md +++ b/README.md @@ -16,8 +16,7 @@ Built upon [egui](https://github.com/emilk/egui), a fast and cross-platform GUI toolkit written in pure Rust. ### Components -These items either have their own `refresh` logic or do not require frequent refreshing. -They are not time-sensitive, and their `refresh` method will be called at specific intervals (e.g., every 15 seconds). +These items are static and they used to be called by other stuffs. ### OS Provides wrapped APIs to interact with the operating system. @@ -26,5 +25,8 @@ Provides wrapped APIs to interact with the operating system. These items are time-sensitive and require frequent checking or updating. They will be spawned as separate threads and run in the background. +### State +Mutable version of the components. Usually, they are `Arc>` in order to sync the state between service and UI. + ### UI The user interface components. diff --git a/src/air.rs b/src/air.rs index d02a00d..b59de53 100644 --- a/src/air.rs +++ b/src/air.rs @@ -16,17 +16,17 @@ struct AiR { uis: Uis, } impl AiR { - fn init(ctx: &Context) -> Result { + fn new(ctx: &Context) -> Result { Self::set_fonts(ctx); // To enable SVG. egui_extras::install_image_loaders(ctx); let once = Once::new(); - let components = Components::init()?; + let components = Components::new()?; let state = Default::default(); - let services = Services::init(ctx, &components, &state)?; - let uis = Uis::init(); + let services = Services::new(ctx, &components, &state)?; + let uis = Uis::new(); Ok(Self { once, components, state, services, uis }) } @@ -62,7 +62,7 @@ impl App for AiR { egui_ctx: ctx, components: &mut self.components, state: &self.state, - services: &self.services, + services: &mut self.services, }; self.uis.draw(air_ctx); @@ -108,7 +108,7 @@ pub struct AiRContext<'a> { pub egui_ctx: &'a Context, pub components: &'a mut Components, pub state: &'a State, - pub services: &'a Services, + pub services: &'a mut Services, } pub fn launch() -> Result<()> { @@ -125,7 +125,7 @@ pub fn launch() -> Result<()> { .with_transparent(true), ..Default::default() }, - Box::new(|c| Ok(Box::new(AiR::init(&c.egui_ctx).unwrap()))), + Box::new(|c| Ok(Box::new(AiR::new(&c.egui_ctx).unwrap()))), )?; Ok(()) diff --git a/src/component.rs b/src/component.rs index 02e2b1f..63f37cb 100644 --- a/src/component.rs +++ b/src/component.rs @@ -8,7 +8,6 @@ pub mod keyboard; pub mod net; pub mod openai; -use openai::OpenAi; pub mod quote; @@ -17,46 +16,29 @@ use setting::Setting; pub mod util; -// std -use std::sync::Arc; -// crates.io -use tokio::sync::Mutex; // self use crate::prelude::*; #[derive(Debug)] pub struct Components { pub setting: Setting, - // Keyboard didn't implement `Send`, can't use it between threads. - // pub keyboard: Arc>, - // TODO?: move the lock to somewhere else. - pub openai: Arc>, #[cfg(feature = "tokenizer")] pub tokenizer: Tokenizer, } impl Components { - pub fn init() -> Result { + pub fn new() -> Result { let setting = Setting::load()?; // TODO: https://github.com/emilk/egui/discussions/4670. debug_assert_eq!(setting.ai.temperature, setting.ai.temperature * 10. / 10.); - let openai = Arc::new(Mutex::new(OpenAi::new(setting.ai.clone()))); #[cfg(feature = "tokenizer")] let tokenizer = Tokenizer::new(setting.ai.model.as_str()); Ok(Self { setting, - openai, #[cfg(feature = "tokenizer")] tokenizer, }) } - - // TODO?: move to somewhere else. - pub fn reload_openai(&self) { - tracing::info!("reloading openai component"); - - self.openai.blocking_lock().reload(self.setting.ai.clone()); - } } diff --git a/src/component/keyboard.rs b/src/component/keyboard.rs index ca2b940..9d41719 100644 --- a/src/component/keyboard.rs +++ b/src/component/keyboard.rs @@ -7,7 +7,7 @@ use crate::prelude::*; #[derive(Debug)] pub struct Keyboard(pub Enigo); impl Keyboard { - pub fn init() -> Result { + pub fn new() -> Result { Ok(Self(Enigo::new(&Settings::default()).map_err(EnigoError::NewCon)?)) } diff --git a/src/component/openai.rs b/src/component/openai.rs index 72451ae..f5f52ea 100644 --- a/src/component/openai.rs +++ b/src/component/openai.rs @@ -16,22 +16,16 @@ use crate::prelude::*; #[derive(Debug)] pub struct OpenAi { pub client: Client, - pub setting: Ai, + pub model: Model, + pub temperature: f32, } impl OpenAi { pub fn new(setting: Ai) -> Self { - let client = Client::with_config( - OpenAIConfig::new().with_api_base(&setting.api_base).with_api_key(&setting.api_key), - ); + let Ai { api_base, api_key, model, temperature } = setting; + let client = + Client::with_config(OpenAIConfig::new().with_api_base(api_base).with_api_key(api_key)); - Self { client, setting } - } - - pub fn reload(&mut self, setting: Ai) { - self.client = Client::with_config( - OpenAIConfig::new().with_api_base(&setting.api_base).with_api_key(&setting.api_key), - ); - self.setting = setting; + Self { client, model, temperature } } pub async fn chat(&self, prompt: &str, content: &str) -> Result { @@ -40,8 +34,8 @@ impl OpenAi { ChatCompletionRequestUserMessageArgs::default().content(content).build()?.into(), ]; let req = CreateChatCompletionRequestArgs::default() - .model(self.setting.model.as_str()) - .temperature(self.setting.temperature) + .model(self.model.as_str()) + .temperature(self.temperature) .max_tokens(4_096_u16) .messages(&msg) .build()?; diff --git a/src/component/setting.rs b/src/component/setting.rs index a1eafc5..a64df21 100644 --- a/src/component/setting.rs +++ b/src/component/setting.rs @@ -102,7 +102,7 @@ impl Default for Rewrite { Self { prompt: "As language professor, assist me in refining this text. \ Amend any grammatical errors and enhance the language to sound more like a native speaker.\ - Just provide the refined text only, without any other things." + Just provide the refined text only, without any other things:" .into(), } } @@ -128,7 +128,7 @@ impl Default for Translation { fn default() -> Self { Self { prompt: "As a language professor, amend any grammatical errors and enhance the language to sound more like a native speaker. \ - Provide the translated text only, without any other things.".into(), + Provide the translated text only, without any other things:".into(), a: Language::ZhCn, b: Language::EnGb, } @@ -144,7 +144,7 @@ pub enum Language { EnGb, } -#[derive(Debug, Serialize, Deserialize)] +#[derive(Clone, Debug, Serialize, Deserialize)] #[serde(rename_all = "kebab-case")] pub struct Hotkeys { pub rewrite: HotKey, diff --git a/src/service.rs b/src/service.rs index e1610bb..59eeac2 100644 --- a/src/service.rs +++ b/src/service.rs @@ -1,3 +1,6 @@ +mod chat; +use chat::Chat; + mod hotkey; use hotkey::Hotkey; @@ -7,6 +10,11 @@ use keyboard::Keyboard; mod quoter; use quoter::Quoter; +// std +use std::sync::{ + atomic::{AtomicBool, Ordering}, + Arc, +}; // crates.io use eframe::egui::Context; use tokio::runtime::Runtime; @@ -18,21 +26,39 @@ pub struct Services { pub keyboard: Keyboard, pub rt: Option, pub quoter: Quoter, + pub is_chatting: Arc, + pub chat: Chat, pub hotkey: Hotkey, } impl Services { - pub fn init(ctx: &Context, components: &Components, state: &State) -> Result { - let keyboard = Keyboard::init(); + pub fn new(ctx: &Context, components: &Components, state: &State) -> Result { + let keyboard = Keyboard::new(); let rt = Runtime::new()?; - let quoter = Quoter::init(&rt, state.chat.quote.clone()); - let hotkey = Hotkey::init(ctx, keyboard.clone(), &rt, components, state)?; + let quoter = Quoter::new(&rt, state.chat.quote.clone()); + let is_chatting = Arc::new(AtomicBool::new(false)); + let chat = Chat::new( + keyboard.clone(), + &rt, + is_chatting.clone(), + components.setting.ai.clone(), + components.setting.chat.clone(), + state.chat.input.clone(), + state.chat.output.clone(), + ); + let hotkey = + Hotkey::new(ctx, keyboard.clone(), &rt, &components.setting.hotkeys, chat.tx.clone())?; + + Ok(Self { keyboard, rt: Some(rt), quoter, is_chatting, chat, hotkey }) + } - Ok(Self { keyboard, rt: Some(rt), quoter, hotkey }) + pub fn is_chatting(&self) -> bool { + self.is_chatting.load(Ordering::SeqCst) } pub fn abort(&mut self) { self.keyboard.abort(); self.quoter.abort(); + self.chat.abort(); self.hotkey.abort(); if let Some(rt) = self.rt.take() { diff --git a/src/service/chat.rs b/src/service/chat.rs new file mode 100644 index 0000000..e23f1ad --- /dev/null +++ b/src/service/chat.rs @@ -0,0 +1,100 @@ +// std +use std::{ + sync::{ + atomic::{AtomicBool, Ordering}, + mpsc::{self, Sender}, + Arc, + }, + time::Duration, +}; +// crates.io +use futures::StreamExt; +use parking_lot::RwLock; +use tokio::{runtime::Runtime, task::AbortHandle, time}; +// self +use super::keyboard::Keyboard; +use crate::component::{ + function::Function, + openai::OpenAi, + setting::{Ai, Chat as ChatSetting}, +}; + +pub type ChatArgs = (Function, String, bool); + +#[derive(Debug)] +pub struct Chat { + pub tx: Sender, + abort_handle: AbortHandle, +} +impl Chat { + pub fn new( + keyboard: Keyboard, + rt: &Runtime, + is_chatting: Arc, + ai_setting: Ai, + chat_setting: ChatSetting, + input: Arc>, + output: Arc>, + ) -> Self { + let openai = OpenAi::new(ai_setting); + let (tx, rx) = mpsc::channel(); + // TODO: handle the error. + let abort_handle = rt + .spawn(async move { + loop { + let (func, content, type_in): ChatArgs = rx.recv().unwrap(); + + is_chatting.store(true, Ordering::SeqCst); + + tracing::info!("func: {func:?}"); + tracing::debug!("content: {content}"); + + input.write().clone_from(&content); + output.write().clear(); + + let mut stream = + openai.chat(&func.prompt(&chat_setting), &content).await.unwrap(); + + while let Some(r) = stream.next().await { + for s in r.unwrap().choices.into_iter().filter_map(|c| c.delta.content) { + output.write().push_str(&s); + + // TODO?: move to outside of the loop. + if type_in { + keyboard.text(s); + } + } + } + + // Allow the UI a moment to refresh the content. + time::sleep(Duration::from_millis(50)).await; + + is_chatting.store(false, Ordering::SeqCst); + } + }) + .abort_handle(); + + Self { abort_handle, tx } + } + + pub fn abort(&self) { + self.abort_handle.abort(); + } + + // TODO: fix clippy. + #[allow(clippy::too_many_arguments)] + pub fn renew( + &mut self, + keyboard: Keyboard, + rt: &Runtime, + is_chatting: Arc, + ai_setting: Ai, + chat_setting: ChatSetting, + input: Arc>, + output: Arc>, + ) { + self.abort(); + + *self = Self::new(keyboard, rt, is_chatting, ai_setting, chat_setting, input, output); + } +} diff --git a/src/service/hotkey.rs b/src/service/hotkey.rs index d4f7633..14af19d 100644 --- a/src/service/hotkey.rs +++ b/src/service/hotkey.rs @@ -1,49 +1,30 @@ // std -use std::{ - sync::{ - atomic::{AtomicBool, Ordering}, - Arc, - }, - time::Duration, -}; +use std::{sync::mpsc::Sender, time::Duration}; // crates.io use arboard::Clipboard; use eframe::egui::{Context, ViewportCommand}; -use futures::StreamExt; use global_hotkey::{GlobalHotKeyEvent, GlobalHotKeyManager, HotKeyState}; use tokio::{runtime::Runtime, task::AbortHandle, time}; // self +use super::{chat::ChatArgs, keyboard::Keyboard}; use crate::{ - component::{function::Function, setting::Hotkeys, Components}, + component::{function::Function, setting::Hotkeys}, os::*, prelude::*, - service::keyboard::Keyboard, - state::State, }; #[derive(Debug)] -pub struct Hotkey { - abort_handle: AbortHandle, - is_running: Arc, -} +pub struct Hotkey(AbortHandle); impl Hotkey { - // TODO: optimize parameters. - pub fn init( + pub fn new( ctx: &Context, keyboard: Keyboard, rt: &Runtime, - components: &Components, - state: &State, + hotkeys: &Hotkeys, + tx: Sender, ) -> Result { let ctx = ctx.to_owned(); - // TODO: use `state.setting.hotkeys`. - let manager = Manager::init(&components.setting.hotkeys)?; - let openai = components.openai.clone(); - let chat_input = state.chat.input.clone(); - let chat_output = state.chat.output.clone(); - let chat_setting = state.setting.chat.clone(); - let is_running = Arc::new(AtomicBool::new(false)); - let is_running_ = is_running.clone(); + let manager = Manager::new(hotkeys)?; let receiver = GlobalHotKeyEvent::receiver(); let mut clipboard = Clipboard::new()?; // TODO: handle the error. @@ -53,8 +34,6 @@ impl Hotkey { let manager = manager; loop { - is_running_.store(false, Ordering::SeqCst); - // Block the thread until a hotkey event is received. let e = receiver.recv().unwrap(); @@ -63,8 +42,6 @@ impl Hotkey { // TODO: reset the hotkey state so that we don't need to wait for the user // to release the keys. - is_running_.store(true, Ordering::SeqCst); - let func = manager.match_func(e.id); let to_unhide = !func.is_directly(); @@ -86,49 +63,26 @@ impl Hotkey { _ => continue, }; + tx.send((func, content, !to_unhide)).unwrap(); + if to_unhide { // Generally, this needs some time to wait the window available // first, but the previous sleep in get selected text is enough. ctx.send_viewport_cmd(ViewportCommand::Focus); } - - chat_input.write().clone_from(&content); - chat_output.write().clear(); - - let chat_setting = chat_setting.read().to_owned(); - let mut stream = openai - .lock() - .await - .chat(&func.prompt(&chat_setting), &content) - .await - .unwrap(); - - while let Some(r) = stream.next().await { - for s in r.unwrap().choices.into_iter().filter_map(|c| c.delta.content) - { - chat_output.write().push_str(&s); - - // TODO: move to outside of the loop. - if !to_unhide { - keyboard.text(s); - } - } - } } } }) .abort_handle(); - Ok(Self { abort_handle, is_running }) + Ok(Self(abort_handle)) } pub fn abort(&self) { - self.abort_handle.abort(); + self.0.abort(); } - pub fn is_running(&self) -> bool { - self.is_running.load(Ordering::SeqCst) - } + // TODO: fn renew. } struct Manager { @@ -137,7 +91,7 @@ struct Manager { ids: [u32; 4], } impl Manager { - fn init(hotkeys: &Hotkeys) -> Result { + fn new(hotkeys: &Hotkeys) -> Result { let _inner = GlobalHotKeyManager::new()?; let hotkeys = [ hotkeys.rewrite, diff --git a/src/service/keyboard.rs b/src/service/keyboard.rs index 85bd38c..c507664 100644 --- a/src/service/keyboard.rs +++ b/src/service/keyboard.rs @@ -9,21 +9,18 @@ use crate::component::keyboard::Keyboard as Kb; #[derive(Clone, Debug)] pub struct Keyboard(Sender); impl Keyboard { - pub fn init() -> Self { + pub fn new() -> Self { let (tx, rx) = mpsc::channel::(); - // TODO: handle the error. + // [`enigo::Enigo`] can't be sent between threads safely. + // So, we spawn a new thread to handle the keyboard action here. thread::spawn(move || { - let mut kb = Kb::init().unwrap(); + let mut kb = Kb::new().expect("keyboard action must succeed"); loop { - let act = rx.recv().unwrap(); - - tracing::info!("receive action: {act:?}"); - - match act { - Action::Copy => kb.copy().unwrap(), - Action::Text(text) => kb.text(&text).unwrap(), + match rx.recv().expect("receive must succeed") { + Action::Copy => kb.copy().expect("keyboard action must succeed"), + Action::Text(text) => kb.text(&text).expect("keyboard action must succeed"), Action::Abort => return, } } diff --git a/src/service/quoter.rs b/src/service/quoter.rs index 62ebcc8..e226f44 100644 --- a/src/service/quoter.rs +++ b/src/service/quoter.rs @@ -9,7 +9,7 @@ use crate::component::quote::Quoter as QuoterC; #[derive(Debug)] pub struct Quoter(AbortHandle); impl Quoter { - pub fn init(rt: &Runtime, quote: Arc>) -> Self { + pub fn new(rt: &Runtime, quote: Arc>) -> Self { let quoter = QuoterC; let abort_handle = rt .spawn(async move { diff --git a/src/state.rs b/src/state.rs index 20bf81a..ade746d 100644 --- a/src/state.rs +++ b/src/state.rs @@ -2,13 +2,10 @@ use std::sync::Arc; // crates.io use parking_lot::RwLock; -// self -use crate::component::setting::Chat as ChatSetting; #[derive(Debug, Default)] pub struct State { pub chat: Chat, - pub setting: Setting, } #[derive(Debug, Default)] @@ -17,8 +14,3 @@ pub struct Chat { pub input: Arc>, pub output: Arc>, } - -#[derive(Debug, Default)] -pub struct Setting { - pub chat: Arc>, -} diff --git a/src/ui.rs b/src/ui.rs index d26067d..fae082e 100644 --- a/src/ui.rs +++ b/src/ui.rs @@ -19,7 +19,7 @@ pub struct Uis { pub setting: Setting, } impl Uis { - pub fn init() -> Self { + pub fn new() -> Self { Default::default() } @@ -29,7 +29,11 @@ impl Uis { .frame(util::transparent_frame(ctx.egui_ctx)) .show(ctx.egui_ctx, |ui| { ui.horizontal(|ui| { - ui.selectable_value(&mut self.focused_panel, Panel::Chat, Panel::Chat.name()); + ui.selectable_value( + &mut self.focused_panel, + Panel::Chat, + Panel::Chat.name(), + ); ui.separator(); ui.selectable_value( &mut self.focused_panel, diff --git a/src/ui/panel/chat.rs b/src/ui/panel/chat.rs index 7624ef4..ef0a0ee 100644 --- a/src/ui/panel/chat.rs +++ b/src/ui/panel/chat.rs @@ -14,14 +14,14 @@ pub struct Chat { impl UiT for Chat { fn draw(&mut self, ui: &mut Ui, ctx: &mut AiRContext) { // TODO: other running cases. - let is_running = ctx.services.hotkey.is_running(); + let ic_chatting = ctx.services.is_chatting(); let size = ui.available_size(); ScrollArea::vertical().id_source("Input").max_height((size.y - 50.) / 2.).show(ui, |ui| { let input = ui.add_sized( (size.x, ui.available_height()), TextEdit::multiline({ - if is_running { + if ic_chatting { if let Some(i) = ctx.state.chat.input.try_read() { i.clone_into(&mut self.input); } @@ -79,7 +79,7 @@ impl UiT for Chat { // Shortcuts. ui.horizontal(|ui| { ui.with_layout(Layout::right_to_left(Align::Center), |ui| { - if is_running { + if ic_chatting { ui.spinner(); } else { // TODO: retry. @@ -97,8 +97,7 @@ impl UiT for Chat { ScrollArea::vertical().id_source("Output").show(ui, |ui| { ui.label({ - // FIXME: `is_running` is conflict with `try_read`. - if is_running { + if ic_chatting { if let Some(o) = ctx.state.chat.output.try_read() { o.clone_into(&mut self.output); } diff --git a/src/ui/panel/setting.rs b/src/ui/panel/setting.rs index e11afe4..63d50b9 100644 --- a/src/ui/panel/setting.rs +++ b/src/ui/panel/setting.rs @@ -82,7 +82,7 @@ impl UiT for Setting { }); ui.end_row(); - // TODO: we might not need to reload the client if only the model changed. + // TODO: we might not need to renew the client if only the model changed. ui.label("Model"); ComboBox::from_id_source("Model") .selected_text(&ctx.components.setting.ai.model) @@ -99,7 +99,7 @@ impl UiT for Setting { }); ui.end_row(); - // TODO: we might not need to reload the client if only the temperature changed. + // TODO: we might not need to renew the client if only the temperature changed. ui.label("Temperature"); ui.spacing_mut().slider_width = size.x; changed |= ui @@ -113,7 +113,15 @@ impl UiT for Setting { }); if changed { - ctx.components.reload_openai(); + ctx.services.chat.renew( + ctx.services.keyboard.clone(), + ctx.services.rt.as_ref().expect("runtime must exist"), + ctx.services.is_chatting.clone(), + ctx.components.setting.ai.clone(), + ctx.components.setting.chat.clone(), + ctx.state.chat.input.clone(), + ctx.state.chat.output.clone(), + ); } });