104 changes: 53 additions & 51 deletions create_completion.py
@@ -5,7 +5,7 @@
import configparser
import argparse

# Conditionally import OpenAI and Google Generative AI
# Check for required libraries
try:
from openai import OpenAI
except ImportError:
@@ -16,121 +16,123 @@
except ImportError:
genai = None

# Get config dir from environment or default to ~/.config
try:
from groq import Groq
except ImportError:
Groq = None

CONFIG_DIR = os.getenv('XDG_CONFIG_HOME', os.path.expanduser('~/.config'))
OPENAI_API_KEYS_LOCATION = os.path.join(CONFIG_DIR, 'openaiapirc')
GEMINI_API_KEYS_LOCATION = os.path.join(CONFIG_DIR, 'geminiapirc')
GROQ_API_KEYS_LOCATION = os.path.join(CONFIG_DIR, 'groqapirc')

def create_template_ini_file(api_type):
"""
If the ini file does not exist create it and add the api_key placeholder
"""
if api_type == 'openai':
file_path = OPENAI_API_KEYS_LOCATION
content = '[openai]\nsecret_key=\n'
url = 'https://platform.openai.com/api-keys'
else: # gemini
file_path = GEMINI_API_KEYS_LOCATION
content = '[gemini]\napi_key=\n'
url = 'Google AI Studio'

file_info = {
'openai': (OPENAI_API_KEYS_LOCATION, '[openai]\nsecret_key=\n', 'https://platform.openai.com/api-keys'),
'gemini': (GEMINI_API_KEYS_LOCATION, '[gemini]\napi_key=\n', 'Google AI Studio'),
'groq': (GROQ_API_KEYS_LOCATION, '[groq]\napi_key=\n', 'Groq API dashboard')
}

file_path, content, url = file_info[api_type]

if not os.path.isfile(file_path):
with open(file_path, 'w') as f:
f.write(content)

print(f'{api_type.capitalize()} API config file created at {file_path}')
print('Please edit it and add your API key')
print(f'If you do not yet have an API key, you can get it from: {url}')
sys.exit(1)

def initialize_api(api_type):
"""
Initialize the specified API
"""
create_template_ini_file(api_type)
config = configparser.ConfigParser()
config.read(os.path.join(CONFIG_DIR, f'{api_type}apirc'))
api_config = {k: v.strip("\"'") for k, v in config[api_type].items()}

if api_type == 'openai':
config.read(OPENAI_API_KEYS_LOCATION)
api_config = {k: v.strip("\"'") for k, v in config["openai"].items()}
client = OpenAI(
api_key=api_config["secret_key"],
base_url=api_config.get("base_url", "https://api.openai.com/v1"),
organization=api_config.get("organization")
)
api_config.setdefault("model", "gpt-3.5-turbo-0613")
return client, api_config
else: # gemini
config.read(GEMINI_API_KEYS_LOCATION)
api_config = {k: v.strip("\"'") for k, v in config["gemini"].items()}
api_config["model"] = api_config.get("model", "gpt-3.5-turbo-0613")
elif api_type == 'gemini':
genai.configure(api_key=api_config["api_key"])
api_config.setdefault("model", "gemini-1.5-pro-latest")
return genai, api_config
client = genai
api_config["model"] = api_config.get("model", "gemini-1.5-pro-latest")
else: # groq
client = Groq(api_key=api_config["api_key"])
api_config["model"] = api_config.get("model", "llama3-8b-8192")

return client, api_config

def get_completion(api_type, client, config, full_command):
system_message = "You are a zsh shell expert, please help me complete the following command, you should only output the completed command, no need to include any other explanation. Do not put completed command in a code block."

if api_type == 'openai':
response = client.chat.completions.create(
model=config["model"],
messages=[
{
"role": 'system',
"content": "You are a zsh shell expert, please help me complete the following command, you should only output the completed command, no need to include any other explanation. Do not put completed command in a code block.",
},
{
"role": 'user',
"content": full_command,
}
{"role": 'system', "content": system_message},
{"role": 'user', "content": full_command},
],
temperature=float(config.get("temperature", 1.0))
)
return response.choices[0].message.content
else: # gemini
elif api_type == 'gemini':
model = client.GenerativeModel(config["model"])
chat = model.start_chat(history=[])
prompt = "You are a zsh shell expert, please help me complete the following command. Only output the completed command, no need for any other explanation. Do not put the completed command in a code block.\n\n" + full_command
prompt = f"{system_message}\n\n{full_command}"
response = chat.send_message(prompt)
return response.text
else: # groq
response = client.chat.completions.create(
model=config["model"],
messages=[
{"role": 'system', "content": system_message},
{"role": 'user', "content": full_command},
],
temperature=float(config.get("temperature", 0.5)),
max_tokens=int(config.get("max_tokens", 1024)),
top_p=float(config.get("top_p", 0.65)),
stream=False,
)
return response.choices[0].message.content

def main():
parser = argparse.ArgumentParser(description="Generate command completions using AI.")
parser.add_argument('--api', choices=['openai', 'gemini'], default='openai', help="Choose the API to use (default: openai)")
parser.add_argument('--api', choices=['openai', 'gemini', 'groq'], default='openai', help="Choose the API to use (default: openai)")
parser.add_argument('cursor_position', type=int, help="Cursor position in the input buffer")
args = parser.parse_args()

if args.api == 'openai' and OpenAI is None:
print("OpenAI library is not installed. Please install it using 'pip install openai'")
sys.exit(1)
elif args.api == 'gemini' and genai is None:
print("Google Generative AI library is not installed. Please install it using 'pip install google-generativeai'")
api_libs = {'openai': OpenAI, 'gemini': genai, 'groq': Groq}
if api_libs[args.api] is None:
print(f"{args.api.capitalize()} library is not installed. Please install it using 'pip install {args.api}'")
sys.exit(1)

client, config = initialize_api(args.api)

# Read the input prompt from stdin.
buffer = sys.stdin.read()
zsh_prefix = '#!/bin/zsh\n\n'
buffer_prefix = buffer[:args.cursor_position]
buffer_suffix = buffer[args.cursor_position:]
full_command = zsh_prefix + buffer_prefix + buffer_suffix
full_command = f"{zsh_prefix}{buffer_prefix}{buffer_suffix}"

completion = get_completion(args.api, client, config, full_command)

if completion.startswith(zsh_prefix):
completion = completion[len(zsh_prefix):]

line_prefix = buffer_prefix.rsplit("\n", 1)[-1]
# Handle all the different ways the command can be returned
for prefix in [buffer_prefix, line_prefix]:
if completion.startswith(prefix):
completion = completion[len(prefix):]
break

if buffer_suffix and completion.endswith(buffer_suffix):
completion = completion[:-len(buffer_suffix)]

completion = completion.strip("\n")
completion = completion.rstrip(buffer_suffix).strip("\n")
if line_prefix.strip().startswith("#"):
completion = "\n" + completion
completion = f"\n{completion}"

sys.stdout.write(completion)

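The Groq path introduced above reads its credentials from the new groqapirc file. A minimal sketch of creating that file from zsh, using the section and key names read by initialize_api (the key value is a placeholder; model, temperature, max_tokens and top_p are optional and fall back to llama3-8b-8192, 0.5, 1024 and 0.65):

# Sketch: write a placeholder Groq config to the location the script expects.
cat > "${XDG_CONFIG_HOME:-$HOME/.config}/groqapirc" <<'EOF'
[groq]
api_key=YOUR_GROQ_API_KEY
model=llama3-8b-8192
EOF

The script can then be exercised by hand; a hypothetical invocation that pipes a 14-character buffer with the cursor at its end:

# Sketch: ask for a completion of "list all files" via the Groq backend.
printf 'list all files' | python3 create_completion.py --api groq 14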
2 changes: 1 addition & 1 deletion zsh_codex.plugin.zsh
@@ -2,7 +2,7 @@

# This ZSH plugin reads the text from the current buffer
# and uses a Python script to complete the text.
api="openai"
api="${ZSH_CODEX_AI_SERVICE:-groq}" # Default to OpenAI if not set

create_completion() {
# Get the text typed until now.
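A usage note for the new ZSH_CODEX_AI_SERVICE switch above: it can be set in ~/.zshrc before the plugin is sourced. A minimal sketch, assuming the plugin file shown here is the one being loaded:

# Sketch: choose the completion backend; the plugin falls back to "groq" when unset.
export ZSH_CODEX_AI_SERVICE="openai"   # or "gemini" / "groq"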