@@ -1,61 +1,70 @@
 import os
-from pathlib import Path
 import sys
-from typing import Iterator, List, Optional
-from llama_cpp import Llama
+from typing import Iterator, List, Optional, TypedDict, cast
+from llama_cpp import Completion, CompletionChunk, Llama
 
 from gptcli.completion import CompletionProvider, Message
 
-LLAMA_MODELS: Optional[dict[str, str]] = None
 
+class LLaMAModelConfig(TypedDict):
+    path: str
+    human_prompt: str
+    assistant_prompt: str
 
-def init_llama_models(model_paths: dict[str, str]):
-    for name, path in model_paths.items():
-        if not os.path.isfile(path):
-            print(f"LLaMA model {name} not found at {path}.")
+
+LLAMA_MODELS: Optional[dict[str, LLaMAModelConfig]] = None
+
+
+def init_llama_models(models: dict[str, LLaMAModelConfig]):
+    for name, model_config in models.items():
+        if not os.path.isfile(model_config["path"]):
+            print(f"LLaMA model {name} not found at {model_config['path']}.")
             sys.exit(1)
         if not name.startswith("llama"):
             print(f"LLaMA model names must start with `llama`, but got `{name}`.")
             sys.exit(1)
 
     global LLAMA_MODELS
-    LLAMA_MODELS = model_paths
+    LLAMA_MODELS = models
 
 
-def role_to_name(role: str) -> str:
+def role_to_name(role: str, model_config: LLaMAModelConfig) -> str:
     if role == "system" or role == "user":
-        return "### Human: "
+        return model_config["human_prompt"]
     elif role == "assistant":
-        return "### Assistant: "
+        return model_config["assistant_prompt"]
     else:
         raise ValueError(f"Unknown role: {role}")
 
 
-def make_prompt(messages: List[Message]) -> str:
+def make_prompt(messages: List[Message], model_config: LLaMAModelConfig) -> str:
     prompt = "\n".join(
-        [f"{role_to_name(message['role'])}{message['content']}" for message in messages]
+        [
+            f"{role_to_name(message['role'], model_config)}{message['content']}"
+            for message in messages
+        ]
     )
-    prompt += "### Assistant: "
+    prompt += f"\n{model_config['assistant_prompt']}"
     return prompt
 
 
-END_SEQ = "### Human:"
-
-
 class LLaMACompletionProvider(CompletionProvider):
     def complete(
         self, messages: List[Message], args: dict, stream: bool = False
     ) -> Iterator[str]:
         assert LLAMA_MODELS, "LLaMA models not initialized"
 
+        model_config = LLAMA_MODELS[args["model"]]
+
         with suppress_stderr():
             llm = Llama(
-                model_path=LLAMA_MODELS[args["model"]],
+                model_path=model_config["path"],
                 n_ctx=2048,
                 verbose=False,
                 use_mlock=True,
             )
-        prompt = make_prompt(messages)
+        prompt = make_prompt(messages, model_config)
+        print(prompt)
 
         extra_args = {}
         if "temperature" in args:
@@ -66,16 +75,16 @@ def complete(
         gen = llm.create_completion(
             prompt,
             max_tokens=1024,
-            stop=END_SEQ,
+            stop=model_config["human_prompt"],
             stream=stream,
             echo=False,
             **extra_args,
         )
         if stream:
-            for x in gen:
+            for x in cast(Iterator[CompletionChunk], gen):
                 yield x["choices"][0]["text"]
         else:
-            yield gen["choices"][0]["text"]
+            yield cast(Completion, gen)["choices"][0]["text"]
 
 
 # https://stackoverflow.com/a/50438156
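For reference, a minimal usage sketch of the new config shape accepted by init_llama_models. The import path, model name, and file path below are hypothetical; the prompt strings simply mirror the previously hard-coded "### Human: " / "### Assistant: " markers that this commit makes configurable.

# Hypothetical example; assumes the module lives at gptcli/llama.py.
from gptcli.llama import LLaMAModelConfig, init_llama_models

models: dict[str, LLaMAModelConfig] = {
    # Model names must start with "llama", per the check in init_llama_models.
    "llama-7b": {
        "path": "/models/llama-7b.bin",        # local model file (hypothetical path)
        "human_prompt": "### Human: ",         # used for system/user turns and as the stop sequence
        "assistant_prompt": "### Assistant: ", # used for assistant turns and as the prompt suffix
    }
}
init_llama_models(models)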