diff --git a/charla/chat.py b/charla/chat.py index b4fb837..e644e2d 100644 --- a/charla/chat.py +++ b/charla/chat.py @@ -75,17 +75,14 @@ def run(argv: argparse.Namespace) -> None: # Prompt used to give directions to the model at the beginning of the chat. system_prompt = argv.system_prompt.read() if argv.system_prompt else '' - # Determine client class and import corresponding module. - client_cls: Any = None + # Determine which Client class to import. if argv.provider == 'ollama': - from charla.client.ollama import OllamaClient - client_cls = OllamaClient + from charla.client.ollama import OllamaClient as ApiClient elif argv.provider == 'github': - from charla.client.github import AzureClient - client_cls = AzureClient + from charla.client.github import AzureClient as ApiClient # Start model API client before chat REPL in case of model errors. - client = client_cls(argv.model, system=system_prompt, message_limit=argv.message_limit) + client = ApiClient(argv.model, system=system_prompt, message_limit=argv.message_limit) client.set_info() # Prompt history used for auto completion.