From ecc7c88a21b0d0d2478e0af8498fb2624fed8957 Mon Sep 17 00:00:00 2001 From: DayOS <43686351+Day-OS@users.noreply.github.com> Date: Sun, 9 Jul 2023 17:52:47 -0300 Subject: [PATCH] Better Response time, Intervals fixed, Better Replies, Better BAIChat API --- .vscode/launch.json | 16 ++++ Cargo.toml | 1 + readme.md | 10 ++- src/ai.rs | 30 +++++--- src/baichat_rs.rs | 107 +++++++++++++++++++------- src/chat_logs.rs | 181 +++++++++++++++++++++----------------------- src/main.rs | 128 ++++++++++++++++++------------- src/memory_core.rs | 43 ++++++++--- src/utils.rs | 6 +- 9 files changed, 318 insertions(+), 204 deletions(-) create mode 100644 .vscode/launch.json diff --git a/.vscode/launch.json b/.vscode/launch.json new file mode 100644 index 0000000..10efcb2 --- /dev/null +++ b/.vscode/launch.json @@ -0,0 +1,16 @@ +{ + // Use IntelliSense to learn about possible attributes. + // Hover to view descriptions of existing attributes. + // For more information, visit: https://go.microsoft.com/fwlink/?linkid=830387 + "version": "0.2.0", + "configurations": [ + { + "type": "lldb", + "request": "launch", + "name": "Debug", + "program": "${workspaceFolder}/", + "args": [], + "cwd": "${workspaceFolder}" + } + ] +} \ No newline at end of file diff --git a/Cargo.toml b/Cargo.toml index fbc42f2..e8711a6 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -19,3 +19,4 @@ simplelog = "0.12.1" stop-words = { version="0.7.2", features = ["nltk", "iso"] } rand = "0.8.5" async-recursion = "1.0.4" +json-gettext = "4.0.5" diff --git a/readme.md b/readme.md index 568ec24..d0791a9 100644 --- a/readme.md +++ b/readme.md @@ -1,4 +1,8 @@ -- [ ] APAGAR O QUE JÁ FOI LIDO -- [ ] INCLUDE DB - [ ] REORGANIZE -- [ ] DO SO THE BOT REPLIES WHEN SOMEONE IS REFERENCING HIM. \ No newline at end of file +- [ ] MARK AS ALREADY READ +- [x] TYPING STATE +- [X] DO SO THE BOT REPLIES WHEN SOMEONE IS REFERENCING HIM. +- [ ] ALTERNABLE STATES (COUNTDOWN) +-MAKE THE BOT ASNWER THE RIGHT PERSON + +ERROR BAICHAT:104:60 thread 'tokio-runtime-worker' panicked at 'called `Result::unwrap()` on an `Err` value: Error("expected value", line: 1, column: 1)', src/baichat_rs.rs:104:60 \ No newline at end of file diff --git a/src/ai.rs b/src/ai.rs index c7773b7..29e0f2a 100644 --- a/src/ai.rs +++ b/src/ai.rs @@ -2,6 +2,7 @@ use crate::{baichat_rs::{self, ThebAI, Delta}, ai}; use serde::{Serialize, Deserialize}; use std::env; +use lazy_static; #[derive(Debug)] pub enum Error{ @@ -33,16 +34,21 @@ pub async fn reply(mut logs: String, memories: Option) -> Result = match ai.ask(&prompt, Some(env::var("PARENT_MESSAGE_ID_GPT").unwrap())).await { - Ok(message) => message, - Err(_) => {return Err(Error::CouldntGenerateResponseFromAI)} - }; - let answer: String = answer[answer.len() - 1].text.clone(); - println!("PROMPT: {}", answer); - let inputmsg: ai::ResponseMessage = match serde_json::from_str(&answer) { - Ok(memory)=>{memory} - Err(_)=>{return Err(Error::CouldntConvertToJSON)} - }; - Ok(inputmsg) + let mut baichat: baichat_rs::ThebAI = get_ai(); + + for i in 0..3 { + log::info!("GENERATING MESSAGE - try #{}", i); + println!("GENERATING MESSAGE - try #{}", i); + let answer = match baichat.ask_single(&prompt, Some(env::var("PARENT_MESSAGE_ID_GPT").unwrap())).await { + Ok(message) => message, + Err(_) => continue + }; + //let answer: String = baichat_rs::delta_to_string(answer).await; + let inputmsg: ai::ResponseMessage = match serde_json::from_str(&answer.text) { + Ok(memory)=>{memory} + Err(_)=> continue + }; + return Ok(inputmsg) + } + return Err(Error::CouldntGenerateResponseFromAI) } \ No newline at end of file diff --git a/src/baichat_rs.rs b/src/baichat_rs.rs index aa7cfaa..d26bdaf 100644 --- a/src/baichat_rs.rs +++ b/src/baichat_rs.rs @@ -6,6 +6,17 @@ use ratmom::{prelude::*, Request, config::SslOption}; use serde::{Serialize, Deserialize}; use serde_json::json; +#[derive(Serialize, Deserialize, Debug)] +pub struct Input { + pub prompt: String, + pub options: Options, +} +#[derive(Serialize, Deserialize, Debug)] +pub struct Options { + #[serde(rename(serialize = "parentMessageId", deserialize = "parentMessageId"))] + pub parent_message_id: String, +} + #[derive(Serialize, Deserialize, Debug)] pub struct Delta { pub role: String, @@ -41,7 +52,6 @@ pub struct DeltaChoice { pub struct ThebAI { pub parent_message_id: Option, } - impl ThebAI { pub fn new(parent_message_id: Option<&str>) -> ThebAI { if let Some(parent_message_id) = parent_message_id { @@ -55,29 +65,62 @@ impl ThebAI { } } - pub async fn ask(&mut self, prompt: &str, parent_message_id: Option) -> Result, Box> { - - let mut body = String::new(); - body.push_str(r#"{ - "prompt": "#); - body.push_str(&json!(prompt).to_string()); - if let Some(parent_message_id) = parent_message_id { - body.push_str(r#", - "options": { - "parentMessageId": ""#); - body.push_str(parent_message_id.as_str()); - body.push_str(r#"" - } - }"#); + pub async fn ask_single(&mut self, prompt: &str, parent_message_id: Option) -> Result> { + let parent_message_id: String = if let Some(parent_message_id) = parent_message_id { + parent_message_id } else { - body.push_str(r#"", - "options": { - "parentMessageId": ""#); - body.push_str(self.parent_message_id.as_ref().unwrap().as_str()); - body.push_str(r#"" - } - }"#); + self.parent_message_id.clone().unwrap() + }; + let body: String = serde_json::to_string(&Input{ prompt:json!(prompt).to_string(), options: Options { parent_message_id: parent_message_id }}).unwrap(); + + + //println!("{}", body); + let mut request = Request::builder() + .method("POST") + .uri("https://chatbot.theb.ai/api/chat-process") + .header("User-Agent", "Mozilla/5.0 (X11; Linux x86_64; rv:109.0) Gecko/20100101 Firefox/112.0") + .header("Accept-Language", "en-US,en;q=0.5") + .header("Content-Type", "application/json") + .header("Referer", "https://chatbot.theb.ai") + .header("Origin", "https://chatbot.theb.ai") + .ssl_options(SslOption::DANGER_ACCEPT_INVALID_CERTS | SslOption::DANGER_ACCEPT_INVALID_HOSTS | SslOption::DANGER_ACCEPT_REVOKED_CERTS) + .body(body)? + .send()?; + + let result = request.text()?; + + println!("BAICHAT RESULT: {}", result); + //println!("{:?}", result.lines()); + let mut target_line: &str = "".into(); + for line in result.lines() { + if line == "" { + continue; + } + target_line = line; + } + match serde_json::from_str(target_line) { + Ok (delta)=> { + return Ok(delta) + } + Err(err)=>{ + println!("BAICHAT ERRROR: {} \n {}", err, result); + return Err(err.into()) + } } + //println!("{:?}", deltas); + + ; + } + + pub async fn ask(&mut self, prompt: &str, parent_message_id: Option) -> Result, Box> { + let parent_message_id: String = if let Some(parent_message_id) = parent_message_id { + parent_message_id + } else { + self.parent_message_id.clone().unwrap() + }; + let body: String = serde_json::to_string(&Input{ prompt:json!(prompt).to_string(), options: Options { parent_message_id: parent_message_id }}).unwrap(); + + //println!("{}", body); let mut request = Request::builder() .method("POST") @@ -85,7 +128,6 @@ impl ThebAI { .header("User-Agent", "Mozilla/5.0 (X11; Linux x86_64; rv:109.0) Gecko/20100101 Firefox/112.0") .header("Accept-Language", "en-US,en;q=0.5") .header("Content-Type", "application/json") - //.header("Host", "chatbot.theb.ai") .header("Referer", "https://chatbot.theb.ai") .header("Origin", "https://chatbot.theb.ai") .ssl_options(SslOption::DANGER_ACCEPT_INVALID_CERTS | SslOption::DANGER_ACCEPT_INVALID_HOSTS | SslOption::DANGER_ACCEPT_REVOKED_CERTS) @@ -95,22 +137,31 @@ impl ThebAI { let result = request.text()?; let mut deltas: Vec = Vec::new(); + println!("BAICHAT RESULT: {}", result); //println!("{:?}", result.lines()); for line in result.lines() { if line == "" { continue; } - let delta: Delta = serde_json::from_str(line).unwrap(); - deltas.push(delta); + match serde_json::from_str(line) { + Ok (delta)=> { + deltas.push(delta) + } + Err(err)=>{println!("BAICHAT ERRROR: {} | {} ", err, line)} + } + } + match deltas.last() { + Some(delta)=>{self.parent_message_id = Some(delta.id.clone())} + None => {return Err("Error::WrongFormat(result)".into())} } - - self.parent_message_id = Some(deltas.last().unwrap().id.clone()); //println!("{:?}", deltas); return Ok(deltas); } } - +pub async fn deltas_to_string(delta : Vec) -> String{ + delta[delta.len() - 1].text.clone() +} diff --git a/src/chat_logs.rs b/src/chat_logs.rs index add99dc..0c59fc4 100644 --- a/src/chat_logs.rs +++ b/src/chat_logs.rs @@ -1,7 +1,7 @@ use async_recursion::async_recursion; -use futures::FutureExt; +use meilisearch_sdk::documents::DocumentQuery; use serde::{Serialize, Deserialize}; -use serenity::{http::Http, model::prelude::{Message, Channel}}; +use serenity::{http::Http, model::prelude::{Message, ChannelId}}; use crate::utils; @@ -15,12 +15,13 @@ pub struct RawMessage { pub timestamp: i64, pub user_image: String, pub message: String, - pub reference_id: Option + pub reference_id: Option, + pub read: bool, } #[derive(Debug)] -pub struct SavedMessagesFromChannel{pub channel_name:String,pub quantity:usize} +pub struct SavedMessagesFromChannel{pub channel_id:u64,pub quantity:usize} #[derive(Debug)] pub struct SavedMessage(pub String); @@ -29,9 +30,8 @@ pub struct SavedMessage(pub String); #[derive(Serialize, Deserialize, Clone, Debug)] pub struct ChatLogs(pub Vec); impl ChatLogs { - pub async fn to_string(&self) -> String{ + pub async fn build(&self, target_message_id: Option) -> String{ let bot_id = crate::get_bot_id().await; - println!("BOOOOOOOOOOOOOOOT ID!!!!!!!!!!!!!!!1 {}", bot_id); let mut messages: String = "".into(); for raw_message in self.0.clone() { if raw_message.user == bot_id { @@ -41,21 +41,55 @@ impl ChatLogs { ); } else{ - messages += &format!("ID DA MENSAGEM: {} | TEMPO: {} - NOME DE USUÁRIO:'{}' disse = {}\n", - raw_message.id, - serenity::model::Timestamp::from_unix_timestamp(raw_message.timestamp).unwrap().to_string(), - raw_message.user_name, - raw_message.message - ); + if target_message_id.is_some(){ + if raw_message.id == target_message_id.unwrap() { + messages += &format!("ID DA MENSAGEM QUE PRECISA SER RESPONDIDO: {} | TEMPO: {} - NOME DE USUÁRIO:'{}' disse = {}\n", + raw_message.id, + serenity::model::Timestamp::from_unix_timestamp(raw_message.timestamp).unwrap().to_string(), + raw_message.user_name, + raw_message.message); + } + else{ + messages += &format!("MENSAGEM PARA CONTEXTO | TEMPO: {} - NOME DE USUÁRIO:'{}' disse = {}\n", + serenity::model::Timestamp::from_unix_timestamp(raw_message.timestamp).unwrap().to_string(), + raw_message.user_name, + raw_message.message); + } + } + else{ + messages += &format!("ID DA MENSAGEM: {} | TEMPO: {} - NOME DE USUÁRIO:'{}' disse = {}\n", + raw_message.id, + serenity::model::Timestamp::from_unix_timestamp(raw_message.timestamp).unwrap().to_string(), + raw_message.user_name, + raw_message.message); + } } } messages } + pub fn filter_read(&mut self){ + self.0 = self.0.clone().into_iter().filter(|msg| !msg.read).collect(); + } pub fn get_ids(&self) -> Vec{ self.0.clone().into_iter().map(|log|{log.id}).collect() } } + +pub async fn set_read(logs: ChatLogs) -> Result<(), utils::Error>{ + let logs: Vec = logs.0.into_iter().map(|mut message| {message.read = true; message}).collect(); + match utils::get_database_client() + .index(utils::DBIndexes::RawMessage.as_str()) + .add_or_replace(&logs, None) + .await{ + Ok(_)=>{ + log::info!("The following messages were set as read in database: {:?}", logs) + } + Err(e)=>{return Err(utils::Error::MeiliSearchError(e))} + }; + Ok(()) +} + pub async fn delete_logs(logs: ChatLogs) -> Result<(), utils::Error>{ match utils::get_database_client() .index(utils::DBIndexes::RawMessage.as_str()) @@ -67,24 +101,26 @@ pub async fn delete_logs(logs: ChatLogs) -> Result<(), utils::Error>{ Err(e)=>{return Err(utils::Error::MeiliSearchError(e))} }; Ok(()) - } +} + + +pub async fn get_specific_message(target_message:u64) -> Result{ + let raw_message = DocumentQuery::new( &utils::get_database_client().index(utils::DBIndexes::RawMessage.as_str())) + .execute::(&target_message.to_string()) + .await; + match raw_message{ + Ok(raw_message) => {return Ok(raw_message)} + Err(e)=>{ return Err(utils::Error::NoReferencedMessageFound(e.to_string()))} + } +} pub async fn get_conversation_with_user(target_message:u64) -> Result{ - async fn search_message_from_id(target_message: u64) -> Result{ - let raw_messages_searched = utils::get_database_client() - .index(utils::DBIndexes::RawMessage.as_str()) - .search() - .with_query(&target_message.to_string()) - .with_sort(&["timestamp:desc", "username:desc", "message:desc"]) - .with_limit(1) - .execute::() - .await.unwrap(); - Ok(raw_messages_searched.hits.first().unwrap().result.clone()) - } #[async_recursion] async fn get_referenced_message(target_message: u64, mut total_messages: &mut Vec) -> (){ - //println!("{:?}", total_messages); - let new_target: RawMessage = search_message_from_id(target_message).await.unwrap(); + println!("{:?}", total_messages); + let new_target_result = get_specific_message(target_message).await; + if new_target_result.is_err(){return} + let new_target = new_target_result.unwrap(); total_messages.push(new_target.clone()); if new_target.reference_id.is_some() { get_referenced_message(new_target.reference_id.unwrap(), &mut total_messages).await; @@ -99,17 +135,6 @@ pub async fn get_conversation_with_user(target_message:u64) -> Result Result{ - let raw_messages_searched = utils::get_database_client() - .index(utils::DBIndexes::RawMessage.as_str()) - .search() - .with_query(&target_message.to_string()) - .with_sort(&["timestamp:desc", "username:desc", "message:desc"]) - .with_limit(1) - .execute::() - .await.unwrap(); - Ok(raw_messages_searched.hits.first().unwrap().result.clone()) -} pub async fn get_last_n_messages(n: usize) -> Result{ @@ -117,78 +142,43 @@ pub async fn get_last_n_messages(n: usize) -> Result{ .index(utils::DBIndexes::RawMessage.as_str()) .search() .with_filter("message IS NOT EMPTY") - .with_sort(&["timestamp:desc", "username:desc", "message:desc"]) + .with_sort(&["timestamp:desc"]) .with_limit(n) .execute::() .await{ Ok(msg)=>{msg} Err(e)=>{return Err(utils::Error::MeiliSearchError(e))} }; - raw_messages_searched.hits.sort_by(|a,b| {a.result.timestamp.cmp(&b.result.timestamp)}); - let mut logs: ChatLogs = ChatLogs(vec!()); - for raw_message in raw_messages_searched.hits {logs.0.push(raw_message.result);} - Ok(logs) + Ok(ChatLogs( raw_messages_searched.hits.into_iter().map(|msg| msg.result).collect()) ) } -pub async fn save_last_n_messages(http: &Http, chat_id: u64, n: u64) -> Result{ - let channel: Channel = http.get_channel(chat_id).await.unwrap(); - fn add_raw_message(raw_messages: &mut Vec, message: &Message){ - let referenced_message: Option = match message.referenced_message.clone(){ - Some(m)=>{Some(m.id.0)} - None=>{None} - }; - - - let raw_message = RawMessage{ - user_name:message.author.clone().name, - id: message.id.0, - timestamp:message.timestamp.unix_timestamp(), - user: message.author.id.0, - user_image: message.author.clone().avatar_url().unwrap_or("NO AVATAR".into()), - message: message.clone().content, - reference_id: referenced_message - }; - raw_messages.push(raw_message); - } - let channel_name; - let discord_messages: Vec = match channel { - Channel::Guild(channel)=>{ - if channel.nsfw {return Err(utils::Error::ChannelIsNSFW)}; - channel_name = channel.name().into(); - channel.messages(&http, |retriever| retriever.limit(n)).await.unwrap() - } - Channel::Private(channel)=>{ - //PREVENTS SKETCHY DM WEIRDOS FROM... DATING THE AI - let owner: u64 = env!("DISCORD_BOT_OWNER").parse::().unwrap(); - if channel.recipient.id != owner{return Err(utils::Error::PrivateChannelUserIsNotOwner)} - - channel_name = channel.clone().recipient.name + "'s DM"; - channel.messages(&http, |retriever| retriever.limit(n)).await.unwrap() +pub async fn save_last_n_messages(http: &Http, channel_id: u64, n: u64) -> Result{ + let messages: Vec = ChannelId(channel_id).messages(&http, |retriever| retriever.limit(n)).await.unwrap(); + let raw_messages : Vec = messages.iter().map(|message| { + let referenced_message: Option = match message.referenced_message.clone(){ + Some(m)=>{Some(m.id.0)} + None=>{None} + }; + RawMessage{user_name:message.author.clone().name, + id: message.id.0, + timestamp:message.timestamp.unix_timestamp(), + user: message.author.id.0, + user_image: message.author.clone().avatar_url().unwrap_or("NO AVATAR".into()), + message: message.clone().content, + reference_id: referenced_message, + read: false + } } - _=>{return Err(utils::Error::Generic)} - }; - + ).collect(); let db = utils::get_database_client(); let db_messages = db.index(utils::DBIndexes::RawMessage.as_str()); + db_messages.add_documents(&raw_messages, None).await.unwrap(); - let mut raw_messages : Vec = vec!(); - for message in discord_messages{ - let idstr = message.id.0.to_string(); - let id = idstr.as_str(); - match db_messages.search().with_query(id).execute::().await{ - Ok(pages)=>{if pages.hits.len() == 0 {add_raw_message(&mut raw_messages, &message)}}, - Err(_)=>{add_raw_message(&mut raw_messages, &message)} - } - } - - db.index(utils::DBIndexes::RawMessage.as_str()) - .add_documents(&raw_messages, None) - .await.unwrap(); - - Ok(SavedMessagesFromChannel{ channel_name: channel_name, quantity: raw_messages.len() }) + Ok(SavedMessagesFromChannel{ channel_id: channel_id, quantity: raw_messages.len() }) + } pub async fn save_message(message: Message, pre_defined_reference: Option) -> Result{ @@ -211,7 +201,8 @@ pub async fn save_message(message: Message, pre_defined_reference: Option) user: message.author.id.0, user_image: message.author.clone().avatar_url().unwrap_or("NO AVATAR".into()), message: message.clone().content, - reference_id: referenced_message + reference_id: referenced_message, + read:false }; db.index(utils::DBIndexes::RawMessage.as_str()) diff --git a/src/main.rs b/src/main.rs index 45e115e..19f4373 100644 --- a/src/main.rs +++ b/src/main.rs @@ -2,8 +2,9 @@ use std::env; mod baichat_rs; use crate::chat_logs::ChatLogs; +use serenity::model::Timestamp; use serenity::{async_trait}; -use serenity::model::prelude::{Ready}; +use serenity::model::prelude::{Ready, interaction}; use serenity::prelude::*; use serenity::model::channel::Message; use serenity::framework::standard::macros::{command, group}; @@ -20,7 +21,7 @@ mod chat_logs; #[group] -#[commands(populate, talk)] +#[commands(memories)] struct General; @@ -41,7 +42,7 @@ impl BotState { fn get_cap(&self)->u64{ match self { BotState::Standby =>{200} - BotState::Active => {15} + BotState::Active => {100} _=>{panic!()} } } @@ -53,6 +54,7 @@ lazy_static::lazy_static!{ static ref MESSAGES_COUNTER : Mutex = Mutex::new(0); static ref BOT_STATE: Mutex = Mutex::new(BotState::Active); //Mutex::new(BotState::WaitingAnswer(WaitingAnswerArgs { user_id: "236575283984072704".into(), message: 1125966367254913126 })); static ref BOT_ID: Mutex = Mutex::new(0); + static ref LAST_ACTIVITY_TIME: Mutex = Mutex::new(0); } pub async fn get_bot_id() -> u64{ BOT_ID.lock().await.clone() @@ -62,32 +64,46 @@ struct Handler; #[async_trait] impl EventHandler for Handler { async fn ready(&self, ctx: Context, ready: Ready) { + *LAST_ACTIVITY_TIME.lock().await = Timestamp::now().unix_timestamp(); *BOT_ID.lock().await = ctx.http.application_id().unwrap(); + let whitelisted_channel: u64 = env::var("WHITELISTED_CHANNEL").unwrap().parse::().unwrap(); + + match chat_logs::save_last_n_messages(&ctx.http, whitelisted_channel, 25).await { + Ok(saved_message)=>{ + println!("AUTOMATIC SAVING SUCCEED: {} {}", saved_message.channel_id ,saved_message.quantity); + log::info!("AUTOMATIC SAVING SUCCEED: {} {}", saved_message.channel_id ,saved_message.quantity); + } + Err(err)=>{ + println!("MESSAGE NOT SAVED | {:?}", err); + log::error!("MESSAGE NOT SAVED | {:?}", err); + } + }; println!("{} is connected!", ready.user.name); } async fn message(&self, ctx: Context, message: Message) { - let owner: u64 = env!("DISCORD_BOT_OWNER").parse::().unwrap(); //---------------------------------- CHECKS ---------------------------------- + //TEMPORARIO PARA TESTES VVVVVVVVVV - + //let owner: u64 = env!("DISCORD_BOT_OWNER").parse::().unwrap(); //if message.is_private() && (message.author.bot || message.author.id.0 == owner){}else {return}; + let whitelisted_channel: u64 = env::var("WHITELISTED_CHANNEL").unwrap().parse::().unwrap(); if message.channel_id.0 != whitelisted_channel {return}; //---------------------------------- MANAGING STATES ---------------------------------- - let mut state = BOT_STATE.lock().await; - println!("state: {:?}", state); let mut message_counter = MESSAGES_COUNTER.lock().await; let mut message_cap = MESSAGES_CAP.lock().await; let bot_id: u64 = ctx.http.application_id().unwrap(); match message.clone().referenced_message { - Some(reply_message)=>{if reply_message.author.id == bot_id {*state = BotState::WaitingAnswer(WaitingAnswerArgs { user_id: message.author.id.0.to_string(), message: message.id.0})} } + Some(reply_message)=>{if reply_message.author.id == bot_id { + *state = BotState::WaitingAnswer(WaitingAnswerArgs { user_id: message.author.id.0.to_string(), message: message.id.0}) + }} None=>{} }; @@ -96,18 +112,28 @@ impl EventHandler for Handler { None=>{None} }; - // VVVVV - U64 = MAXIMUM CAP + let mut target_message_id: Option = None; + match chat_logs::save_message(message.clone(), referenced_message_id).await { + Ok(saved_message)=>{ + println!("MESSAGE SAVED: {}", saved_message.0); + log::info!("MESSAGE SAVED: {}", saved_message.0); + } + Err(err)=>{ + println!("MESSAGE NOT SAVED | {:?}", err); + log::error!("MESSAGE NOT SAVED | {:?}", err); + } + } + + tokio::time::sleep(tokio::time::Duration::from_millis(1000)).await; + let (logs, maximum_cap): (ChatLogs, u64) = match state.clone() { BotState::Standby| BotState::Active=>{ - let _ = chat_logs::save_message(message.clone(), None).await; - if (message.content.contains(&format!("<@{}>", bot_id)) || message.content.contains(&format!("<@!{}>", bot_id)) || message.content.to_lowercase().contains("chell") || *message_counter > *message_cap ) && message.author.id.0 != bot_id - { - let logs = chat_logs::get_last_n_messages(*message_cap as usize).await.unwrap(); - (logs, state.clone().get_cap()) + { + (chat_logs::get_last_n_messages(10).await.unwrap(), state.clone().get_cap()) } else{ *message_counter += 1; @@ -115,42 +141,33 @@ impl EventHandler for Handler { } } BotState::WaitingAnswer(answer_args)=>{ - let _ = chat_logs::save_message(message.clone(), referenced_message_id).await; + if (LAST_ACTIVITY_TIME.lock().await.clone() - Timestamp::now().unix_timestamp()) > 60 { + *state = BotState::Active; + return + } if message.author.id.0 != answer_args.user_id.parse::().unwrap(){return} - *state = BotState::Active; + target_message_id = Some(message.id.0); (chat_logs::get_conversation_with_user(message.id.0).await.unwrap(), 2) } }; - println!("{:?}", logs.to_string().await); + + //logs.filter_read(); + println!("CHECK FILTER: {:?}", logs.build(target_message_id).await); + //---------------------------------- GENERATING RESPONSE ---------------------------------- + + *LAST_ACTIVITY_TIME.lock().await = Timestamp::now().unix_timestamp(); let typing = message.channel_id.start_typing(&ctx.http).unwrap(); let topics: topics::Topics = topics::Topics::from_logs(&logs).await.unwrap(); - let memories: Option = memory_core::load_memory_from_db(topics).await; - - let mut response: Option = None; - for _ in 0..5 { - response = match ai::reply(logs.to_string().await, memories.clone()).await { - Ok(response)=>{ - println!("{:?}", response); - Some(response) - } - Err(e)=>{ - match e { - ai::Error::CouldntGenerateResponseFromAI | ai::Error::CouldntConvertToJSON =>{ - log::error!("MESSAGE NOT GENERATED, RETRYING! err:{:?}", e); - println!("MESSAGE NOT GENERATED, RETRYING! err:{:?}", e); - continue - - } - } - } - }; - break; - } - let mut response: ai::ResponseMessage = response.unwrap(); + let memories = memory_core::load_memory(topics).await; + let memories: Option = if memories.is_ok(){Some(memories.unwrap().to_string().await)} else{None}; + let mut response = match ai::reply(logs.build(target_message_id).await, memories.clone()).await{ + Ok(res)=>{res}, + Err(_)=>{return} + }; //---------------------------------- SENDING RESPONSE ---------------------------------- @@ -163,8 +180,8 @@ impl EventHandler for Handler { None => None, }; - if response.question { response.message += " | ⤵️" } - let reply_message : Message = match reply_target { + if response.question { response.message += &format!( " | ⤵️ ({})", message.author.name)} + _ = match reply_target { Some(target)=>{ log::info!("Answered {}'s message. | Response: {}", target.author.name, response.message.clone()); target.reply(&ctx.http, response.message.clone()).await.unwrap() @@ -178,17 +195,17 @@ impl EventHandler for Handler { } }; + //chat_logs::set_read(logs).await; + //STATE SET HERE IS USED IN THE NEXT TIME THE EVENT IS FIRED if response.question { *state = BotState::WaitingAnswer(WaitingAnswerArgs { user_id: message.author.id.0.to_string(), message: message.id.0}) } else{ *state = BotState::Active } - _ = memory_core::save_memory_to_db(&response).await; + _ = memory_core::save_memory(&response).await; _ = typing.stop(); - //_ = chat_logs::delete_logs(logs).await; - println!("AAAAAAAAAAAAAA"); *message_counter = 0; *message_cap = rand::thread_rng().gen_range(1..maximum_cap); - + } @@ -200,13 +217,13 @@ async fn main() { CombinedLogger::init( vec![ TermLogger::new(LevelFilter::Warn, Config::default(), TerminalMode::Mixed, ColorChoice::Auto), - WriteLogger::new(LevelFilter::Info, Config::default(), File::create("Chell.log").unwrap()), + WriteLogger::new(LevelFilter::Info, ConfigBuilder::new().add_filter_allow("chell".into()).build(), File::create("Chell.log").unwrap()), ] ).unwrap(); let db = utils::get_database_client().index(utils::DBIndexes::RawMessage.as_str()); - db.set_sortable_attributes(["timestamp", "username", "message"]).await.unwrap(); - db.set_filterable_attributes(["message"]).await.unwrap(); + db.set_sortable_attributes(["timestamp"]).await.unwrap(); + db.set_filterable_attributes(["message", "reference_id"]).await.unwrap(); let framework = StandardFramework::new() .configure(|c| c.prefix("~")) @@ -228,7 +245,7 @@ async fn main() { } - +/* #[command] async fn populate(ctx: &Context, arg_msg: &Message, mut args: Args) -> CommandResult { @@ -246,16 +263,23 @@ async fn populate(ctx: &Context, arg_msg: &Message, mut args: Args) -> CommandRe match arg_msg.reply(&ctx.http, format!("💾 Queried {} messages to DB from channel {}.\n ~Im getting smarter :3", - result.quantity,result.channel_name)).await{ + result.quantity,result.channel_id)).await{ Ok(_)=>{println!("AAAAA")}, Err(e)=>{println!("{}", e)} }; Ok(()) } + */ #[command] -async fn talk(ctx: &Context, arg_msg: &Message) -> CommandResult { +async fn memories(ctx: &Context, arg_msg: &Message) -> CommandResult { + let mems = memory_core::load_last_n_memories(25, None).await.unwrap(); + _ = arg_msg.channel_id.send_message(&ctx.http, |msg| + msg + .embed(|embed| + embed.description("hmmm")) + ).await; //reply_message(ctx, arg_msg.channel_id, chat_logs::get_last_n_messages(28).await.unwrap()).await; Ok(()) } \ No newline at end of file diff --git a/src/memory_core.rs b/src/memory_core.rs index 61da4fe..d315ec8 100644 --- a/src/memory_core.rs +++ b/src/memory_core.rs @@ -1,6 +1,6 @@ use serde::{Serialize, Deserialize}; use crate::{utils::{self, DBIndexes}, topics::Topics, ai}; - +use crate::Timestamp; #[derive(Serialize, Deserialize, Debug, Clone,)] pub struct Memory{ @@ -10,28 +10,47 @@ pub struct Memory{ } #[derive(Debug)] pub struct SavedMemories(pub String); +#[derive(Debug)] +pub struct LoadedMemories(pub Vec); +impl LoadedMemories { + pub async fn to_string(&self) -> String{ + let mut memories :String = "".into(); + for r in self.0.clone() { + memories += &format!("{}\n", r.content); + } + memories + } +} + +pub async fn load_last_n_memories(n: usize, max_timestamp: Option) -> Result{ + let hits = match utils::get_database_client() + .index(DBIndexes::InputMemory.as_str()) + .search() + .with_filter(&format!("timestamp < {}", max_timestamp.unwrap_or(Timestamp::now().unix_timestamp() as u64))) + .with_limit(n) + .execute::().await{ + Ok(pages)=>{pages}, + Err(_)=>{return Err(utils::Error::NoMemoriesFound)} + }.hits; -pub async fn load_memory_from_db(topics: Topics) -> Option{ - let mut memories :String = "".into(); + Ok(LoadedMemories(hits.into_iter().map(|result| result.result).collect())) +} + +pub async fn load_memory(topics: Topics) -> Result{ + let memories :String = "".into(); let hits = match utils::get_database_client() .index(DBIndexes::InputMemory.as_str()) .search().with_query(&topics.to_query()) .with_limit(1) .execute::().await{ Ok(pages)=>{pages}, - Err(_)=>{return None} + Err(_)=>{return Err(utils::Error::NoMemoriesFound)} }.hits; - for r in hits { - memories += &format!("{}\n", r.result.content); - } - if !memories.is_empty() { - return Some(memories) - } - None + Ok(LoadedMemories(hits.into_iter().map(|result| result.result).collect())) } -pub async fn save_memory_to_db(input_memory: &ai::ResponseMessage) -> Result{ +pub async fn save_memory(input_memory: &ai::ResponseMessage) -> Result{ let learned: String = match &input_memory.learned { Some(learned)=>{if learned.to_lowercase() == "null"{return Err(utils::Error::NoMemoriesToBeSaved)} learned.to_string()} None=>{return Err(utils::Error::NoMemoriesToBeSaved)} diff --git a/src/utils.rs b/src/utils.rs index eca65d9..443d272 100644 --- a/src/utils.rs +++ b/src/utils.rs @@ -5,12 +5,14 @@ pub enum Error{ PrivateChannelUserIsNotOwner, Generic, NoTopicsFound, - NoMemoriesToBeSaved + NoMemoriesToBeSaved, + NoMemoriesFound, + NoMessagesFound, + NoReferencedMessageFound(String) } - pub enum DBIndexes{ RawMessage, InputMemory