diff --git a/services/services.py b/services/services.py
index aa54991..2057e6a 100644
--- a/services/services.py
+++ b/services/services.py
@@ -11,7 +11,7 @@ class BaseClient(ABC):
     """Base class for all clients"""
 
     api_type: str = None
-    system_prompt = "You are a zsh shell expert, please help me complete the following command, you should only output the completed command, no need to include any other explanation. Do not put completed command in a code block."
+    system_prompt = "You are a zsh shell expert, please help me complete the following command, you should only output the completed command, do not include any explanation, comments or any other text that is not part of the command. Do not put completed command in a code block."
 
     @abstractmethod
     def get_completion(self, full_command: str) -> str:
@@ -54,9 +54,14 @@ def get_completion(self, full_command: str) -> str:
             model=self.config["model"],
             messages=[
                 {"role": "system", "content": self.system_prompt},
+                {"role": "user", "content": "# list all files with all attributes in current folder"},
+                {"role": "assistant", "content": "ls -alhi"},
+                {"role": "user", "content": "# go one directory up"},
+                {"role": "assistant", "content": "cd .."},
                 {"role": "user", "content": full_command},
             ],
-            temperature=float(self.config.get("temperature", 1.0)),
+
+            temperature=float(self.config.get("temperature", 0.0)),
         )
         return response.choices[0].message.content
 
@@ -87,9 +92,25 @@ def __init__(self, config: dict):
         self.model = genai.GenerativeModel(self.config["model"])
 
     def get_completion(self, full_command: str) -> str:
-        chat = self.model.start_chat(history=[])
-        prompt = f"{self.system_prompt}\n\n{full_command}"
-        response = chat.send_message(prompt)
+        chat = self.model.start_chat(history=[
+            {
+                "role": "user",
+                "parts": [f"{self.system_prompt}\n\n# list all files with all attributes in current folder"]
+            },
+            {
+                "role": "model",
+                "parts": ["ls -alhi"]
+            },
+            {
+                "role": "user",
+                "parts": ["# go one directory up"]
+            },
+            {
+                "role": "model",
+                "parts": ["cd .."]
+            }
+        ])
+        response = chat.send_message(full_command)
         return response.text
 
 
@@ -125,9 +146,13 @@ def get_completion(self, full_command: str) -> str:
             model=self.config["model"],
             messages=[
                 {"role": "system", "content": self.system_prompt},
+                {"role": "user", "content": "# list all files with all attributes in current folder"},
+                {"role": "assistant", "content": "ls -alhi"},
+                {"role": "user", "content": "# go one directory up"},
+                {"role": "assistant", "content": "cd .."},
                 {"role": "user", "content": full_command},
             ],
-            temperature=float(self.config.get("temperature", 1.0)),
+            temperature=float(self.config.get("temperature", 0.0)),
         )
         return response.choices[0].message.content
 
@@ -164,9 +189,13 @@ def get_completion(self, full_command: str) -> str:
             model=self.config["model"],
             messages=[
                 {"role": "system", "content": self.system_prompt},
+                {"role": "user", "content": "# list all files with all attributes in current folder"},
+                {"role": "assistant", "content": "ls -alhi"},
+                {"role": "user", "content": "# go one directory up"},
+                {"role": "assistant", "content": "cd .."},
                 {"role": "user", "content": full_command},
            ],
-            temperature=float(self.config.get("temperature", 1.0)),
+            temperature=float(self.config.get("temperature", 0.0)),
         )
         return response.choices[0].message.content
 
@@ -213,6 +242,10 @@ def get_completion(self, full_command: str) -> str:
         import json
 
         messages = [
+            {"role": "user", "content": "# list all files with all attributes in current folder"},
+            {"role": "assistant", "content": "ls -alhi"},
"ls -alhi"}, + {"role": "user", "content": "# go one directory up"}, + {"role": "assistant", "content": "cd .."}, {"role": "user", "content": full_command} ] @@ -223,7 +256,7 @@ def get_completion(self, full_command: str) -> str: "max_tokens": 1000, "system": self.system_prompt, "messages": messages, - "temperature": float(self.config.get("temperature", 1.0)) + "temperature": float(self.config.get("temperature", 0.0)) } else: raise ValueError(f"Unsupported model: {self.config['model']}") @@ -237,9 +270,60 @@ def get_completion(self, full_command: str) -> str: return response_body["content"][0]["text"] +class OllamaClient(BaseClient): + """ + config keys: + - api_type="ollama" + - base_url (optional): defaults to "http://localhost:11434" + - model (optional): defaults to "llama3.2" or environment variable OLLAMA_DEFAULT_MODEL + - temperature (optional): defaults to 1.0. + """ + + api_type = "ollama" + default_model = os.getenv("OLLAMA_DEFAULT_MODEL", "llama3.2") + + def __init__(self, config: dict): + try: + import ollama + except ImportError: + print( + "Ollama library is not installed. Please install it using 'pip install ollama'" + ) + sys.exit(1) + + self.config = config + self.config["model"] = self.config.get("model", self.default_model) + + # Create ollama client with custom host if specified + if "base_url" in self.config: + self.client = ollama.Client(host=self.config["base_url"]) + else: + self.client = ollama.Client() + + def get_completion(self, full_command: str) -> str: + messages = [ + {"role": "system", "content": self.system_prompt}, + {"role": "user", "content": "# list all files with all attributes in current folder"}, + {"role": "assistant", "content": "ls -alhi"}, + {"role": "user", "content": "# go one directory up"}, + {"role": "assistant", "content": "cd .."}, + {"role": "user", "content": full_command} + ] + + response = self.client.chat( + model=self.config["model"], + messages=messages, + options={ + "temperature": float(self.config.get("temperature", 0.0)) + }, + think=True + ) + + return response["message"]["content"] + class ClientFactory: - api_types = [OpenAIClient.api_type, GoogleGenAIClient.api_type, GroqClient.api_type, MistralClient.api_type, AmazonBedrock.api_type] + api_types = [OpenAIClient.api_type, GoogleGenAIClient.api_type, GroqClient.api_type, MistralClient.api_type, AmazonBedrock.api_type, OllamaClient.api_type] @classmethod def create(cls): @@ -263,6 +347,8 @@ def create(cls): return MistralClient(config) case AmazonBedrock.api_type: return AmazonBedrock(config) + case OllamaClient.api_type: + return OllamaClient(config) case _: raise KeyError( f"Specified API type {api_type} is not one of the supported services {cls.api_types}"