diff --git a/Cargo.toml b/Cargo.toml index 175f9c7..36de1c1 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -9,6 +9,7 @@ license = "MIT OR GPL-2.0" async-trait = "0.1.80" bytes = "1.6.0" clap = { version = "4.5.9", features = ["derive"] } +clap_complete = "4.5.47" futures-core = "0.3.30" futures-util = "0.3.30" lazy_static = "1.4.0" diff --git a/src/cli/chat.rs b/src/cli/chat.rs index 6d078d1..834a43b 100644 --- a/src/cli/chat.rs +++ b/src/cli/chat.rs @@ -18,11 +18,10 @@ use crate::config; use crate::providers::{ChatProvider, ContextManagement, MessageDelta}; use crate::registry::populate::resolve_once; use crate::registry::registry::{self, ModelSpec, Registry}; -use crate::ChatArgs; +use crate::ChatOpts; use prompt::{model_prompt, user_prompt}; use tokio::{select, signal}; - pub(crate) enum Severity { Error, Warn, @@ -155,7 +154,7 @@ pub(crate) async fn chat_cmd( keybindings: config::Keybindings, default_model: Option, registry: Registry, - args: &ChatArgs, + args: &ChatOpts, ) { let in_terminal = io::stdin().is_terminal(); let out_terminal = io::stdout().is_terminal(); @@ -283,7 +282,7 @@ async fn chat<'p>( msg_buf.add_message(Message::user(prompt)); } - + let completion = provider .stream_completion(&model_id, &msg_buf.chat_messages()) .await; @@ -331,7 +330,7 @@ async fn chat<'p>( print!("{}", delta.content); flush_or_die(); } - + msg_builder.add(&delta); } Err(err) => panic!("failed to decode streaming response: {}", err), @@ -340,7 +339,7 @@ async fn chat<'p>( _ = signal::ctrl_c() => { skip_response = true; break; - } + } } } @@ -362,7 +361,7 @@ async fn chat<'p>( if !interactive { break; } - + pending_init_prompt = false; } } diff --git a/src/main.rs b/src/main.rs index 8ca503c..93f7a49 100644 --- a/src/main.rs +++ b/src/main.rs @@ -9,7 +9,7 @@ mod version; use std::path::PathBuf; -use clap::{Parser, Subcommand, ValueEnum}; +use clap::{CommandFactory, Parser, Subcommand, ValueEnum}; use cli::{chat::chat_cmd, list::list_cmd, ColorMode}; use config::read_config; use providers::providers::ProviderIdentifier; @@ -33,25 +33,31 @@ pub(crate) enum RequestedColorMode { author = "Alex ", version = version::VERSION )] -struct Cli { - #[arg(long, default_value_t = RequestedColorMode::default())] +struct Opts { + #[arg(help="Use ANSI color", long, default_value_t = RequestedColorMode::default())] color: RequestedColorMode, - #[arg(long)] + #[arg(help = "Specify alternative config path", long)] config: Option, #[command(subcommand)] command: Option, + #[arg( + help = "generate shell completion and exit", + long = "generate", + exclusive = true + )] + generator: Option, } #[derive(Subcommand)] enum Commands { /// Start a chat - Chat(ChatArgs), + Chat(ChatOpts), /// List available models List(ListArgs), } #[derive(Parser, Default)] -pub(crate) struct ChatArgs { +pub(crate) struct ChatOpts { /// Specifies the model to be used during the chat #[arg(short, long)] model: Option, @@ -121,7 +127,7 @@ fn hook_panics_with_reporting() { async fn main() { hook_panics_with_reporting(); - let cli = Cli::parse(); + let cli = Opts::parse(); let color = ColorMode::resolve_auto(cli.color); @@ -133,6 +139,14 @@ async fn main() { let editor: Option = config.editor.map(|s| s.into()); + if let Some(generator) = cli.generator { + let out_dir = "target/completions"; + let _ = std::fs::create_dir_all(out_dir); + let _ = clap_complete::generate_to(generator, &mut Opts::command(), version::NAME, out_dir); + println!("Generated completions for {} in {}", generator, out_dir); + return; + } + match &cli.command { Some(Commands::Chat(args)) => { chat_cmd( @@ -151,7 +165,7 @@ async fn main() { config.keybindings, config.default_model, registry, - &ChatArgs::default(), + &ChatOpts::default(), ) .await } diff --git a/src/version.rs b/src/version.rs index 95fc0d7..f428de0 100644 --- a/src/version.rs +++ b/src/version.rs @@ -1,2 +1,2 @@ -pub(crate) const VERSION: &'static str = "0.0.1-alpha.3"; -pub(crate) const NAME: &'static str = "xtalk"; \ No newline at end of file +pub(crate) const VERSION: &'static str = "0.0.1-alpha.3"; +pub(crate) const NAME: &'static str = "xtalk";