1- import logging
21import os
3- import signal
4- import subprocess
52from pathlib import Path
63import sys
74from typing import Iterator , List , Optional
5+ from llama_cpp import Llama
86
97from gptcli .completion import CompletionProvider , Message
108
11- LLAMA_DIR : Optional [Path ] = None
12- LLAMA_MODELS : Optional [dict [str , Path ]] = None
9+ LLAMA_MODELS : Optional [dict [str , str ]] = None
1310
1411
15- def init_llama_models (llama_cpp_dir : str , model_paths : dict [str , str ]):
12+ def init_llama_models (model_paths : dict [str , str ]):
1613 for name , path in model_paths .items ():
1714 if not os .path .isfile (path ):
1815 print (f"LLaMA model { name } not found at { path } ." )
@@ -21,9 +18,8 @@ def init_llama_models(llama_cpp_dir: str, model_paths: dict[str, str]):
2118 print (f"LLaMA model names must start with `llama`, but got `{ name } `." )
2219 sys .exit (1 )
2320
24- global LLAMA_DIR , LLAMA_MODELS
25- LLAMA_DIR = Path (llama_cpp_dir )
26- LLAMA_MODELS = {name : Path (path ) for name , path in model_paths .items ()}
21+ global LLAMA_MODELS
22+ LLAMA_MODELS = model_paths
2723
2824
2925def role_to_name (role : str ) -> str :
@@ -50,75 +46,51 @@ class LLaMACompletionProvider(CompletionProvider):
5046 def complete (
5147 self , messages : List [Message ], args : dict , stream : bool = False
5248 ) -> Iterator [str ]:
53- assert LLAMA_DIR , "LLaMA models not initialized"
5449 assert LLAMA_MODELS , "LLaMA models not initialized"
5550
51+ with suppress_stderr ():
52+ llm = Llama (
53+ model_path = LLAMA_MODELS [args ["model" ]],
54+ n_ctx = 2048 ,
55+ verbose = False ,
56+ use_mlock = True ,
57+ )
5658 prompt = make_prompt (messages )
5759
58- extra_args = []
60+ extra_args = {}
5961 if "temperature" in args :
60- extra_args += [ "--temp" , str ( args ["temperature" ]) ]
62+ extra_args [ "temperature" ] = args ["temperature" ]
6163 if "top_p" in args :
62- extra_args += ["--top_p" , str (args ["top_p" ])]
63-
64- process = subprocess .Popen (
65- [
66- LLAMA_DIR / "main" ,
67- "--model" ,
68- LLAMA_MODELS [args ["model" ]],
69- "-n" ,
70- "4096" ,
71- "-r" ,
72- "### Human:" ,
73- "-p" ,
74- prompt ,
75- * extra_args ,
76- ],
77- stdin = subprocess .PIPE ,
78- stdout = subprocess .PIPE ,
79- stderr = subprocess .PIPE ,
80- shell = False ,
81- text = True ,
64+ extra_args ["top_p" ] = args ["top_p" ]
65+
66+ gen = llm .create_completion (
67+ prompt ,
68+ max_tokens = 1024 ,
69+ stop = END_SEQ ,
70+ stream = stream ,
71+ echo = False ,
72+ ** extra_args ,
8273 )
83-
8474 if stream :
85- return self ._read_stream (process , prompt )
75+ for x in gen :
76+ yield x ["choices" ][0 ]["text" ]
8677 else :
87- return self ._read (process , prompt )
88-
89- def _read_stream (self , process : subprocess .Popen , prompt : str ) -> Iterator [str ]:
90- assert process .stdout , "LLaMA stdout not set"
91- assert process .stderr , "LLaMA stderr not set"
92-
93- buffer = ""
94- num_read = 0
95- char = process .stdout .read (1 )
96-
97- try :
98- while char := process .stdout .read (1 ):
99- num_read += len (char )
100- if num_read <= len (prompt ):
101- continue
102-
103- buffer += char
104- if not buffer .startswith ("#" ) or (buffer != END_SEQ [: len (buffer )]):
105- yield buffer
106- buffer = ""
107- elif buffer .endswith (END_SEQ ):
108- yield buffer [: - len (END_SEQ )]
109- buffer = ""
110- process .terminate ()
111- break
112- except KeyboardInterrupt :
113- os .kill (process .pid , signal .SIGINT )
114- raise
115- finally :
116- process .wait ()
117- stderr = "" .join (process .stderr .readlines ())
118- logging .debug (f"LLaMA stderr: { stderr } " )
119-
120- def _read (self , process : subprocess .Popen , prompt : str ) -> Iterator [str ]:
121- result = ""
122- for token in self ._read_stream (process , prompt ):
123- result += token
124- yield result
78+ yield gen ["choices" ][0 ]["text" ]
79+
80+
# Adapted from https://stackoverflow.com/a/50438156
class suppress_stderr(object):
    """Context manager that silences stderr at the file-descriptor level.

    Unlike ``contextlib.redirect_stderr`` this redirects the OS-level fd,
    so output written directly by C extensions (e.g. llama.cpp) is also
    suppressed. The original stream object and fd are restored on exit.
    """

    def __enter__(self):
        # Remember the current Python-level stream so we can restore it.
        self.old_stderr = sys.stderr
        # Sink for everything written while the manager is active.
        self.errnull_file = open(os.devnull, "w")
        # Keep both the live stderr fd and a duplicate of it; the
        # duplicate survives the dup2 below and lets us undo it on exit.
        self.old_stderr_fileno_undup = sys.stderr.fileno()
        self.old_stderr_fileno = os.dup(sys.stderr.fileno())
        # Point the real stderr fd at /dev/null, then swap the Python
        # stream too so pure-Python writes are also swallowed.
        os.dup2(self.errnull_file.fileno(), self.old_stderr_fileno_undup)
        sys.stderr = self.errnull_file
        return self

    def __exit__(self, *_):
        # Undo in reverse order: Python stream first, then the fd,
        # then release the duplicate and the /dev/null handle.
        sys.stderr = self.old_stderr
        os.dup2(self.old_stderr_fileno, self.old_stderr_fileno_undup)
        os.close(self.old_stderr_fileno)
        self.errnull_file.close()
0 commit comments