@@ -2,55 +2,56 @@ package gpuscheduler
22
33import (
44 "encoding/json"
5- "errors"
6- "io/ioutil"
7- "net/http"
85 "strconv"
6+ "strings"
97 "sync"
108
11- "github.com/mayooot/gpu-docker-api/internal/config"
9+ "github.com/commander-cli/cmd"
10+ "github.com/pkg/errors"
11+
1212 "github.com/mayooot/gpu-docker-api/internal/etcd"
13- "github.com/mayooot/gpu-docker-api/internal/model"
1413 "github.com/mayooot/gpu-docker-api/internal/xerrors"
1514)
1615
const (
	// Shell command that lists the index and UUID of every GPU.
	allGpuUUIDCommand = "nvidia-smi --query-gpu=index,uuid --format=csv,noheader,nounits"

	// Key under which the gpuScheduler is stored in etcd.
	gpuStatusMapKey = "gpuStatusMapKey"
)
2423
// Scheduler is the package-level scheduler instance; Init assigns it from
// the value returned by initFormEtcd.
var Scheduler *scheduler
2625
// gpu describes a single GPU as parsed from the output of allGpuUUIDCommand.
type gpu struct {
	// Index is the integer parsed from the first CSV field of a line.
	Index int `json:"index"`
	// UUID points to the string taken from the second CSV field of a line.
	UUID *string `json:"uuid"`
}
30+
// scheduler holds the GPU bookkeeping state. Init fills it in when the value
// loaded from etcd has no GPUs recorded yet.
type scheduler struct {
	// Embedded lock; presumably guards the fields below — confirm against
	// the methods that use it.
	sync.RWMutex

	// AvailableGpuNums is set by Init to the number of detected GPUs.
	AvailableGpuNums int
	// GpuStatusMap is keyed by GPU UUID; Init stores 0 for every detected GPU.
	// NOTE(review): 0 presumably means "not in use" — the meaning of other
	// values is not visible here.
	GpuStatusMap map[string]byte
}
3337
34- func Init (cfg * config. Config ) error {
38+ func Init () error {
3539 var err error
3640 Scheduler , err = initFormEtcd ()
3741 if err != nil {
38- return err
42+ return errors . Wrap ( err , "initFormEtcd failed" )
3943 }
4044
4145 if Scheduler .AvailableGpuNums == 0 || len (Scheduler .GpuStatusMap ) == 0 {
4246 // 如果没有初始化过
43- Scheduler .AvailableGpuNums = defaultAvailableGpuNums
44- if cfg .AvailableGpuNums >= 0 {
45- Scheduler .AvailableGpuNums = cfg .AvailableGpuNums
46- }
47-
48- gpus , err := getDetectGpus (cfg .DetectGPUAddr )
47+ gpus , err := getAllGpuUUID ()
4948 if err != nil {
50- return err
49+ return errors . Wrap ( err , "getAllGpuUUID failed" )
5150 }
51+
52+ Scheduler .AvailableGpuNums = len (gpus )
5253 for i := 0 ; i < len (gpus ); i ++ {
53- Scheduler .GpuStatusMap [gpus [i ].UUID ] = 0
54+ Scheduler .GpuStatusMap [* gpus [i ].UUID ] = 0
5455 }
5556 }
5657 return nil
@@ -139,20 +140,40 @@ func initFormEtcd() (s *scheduler, err error) {
139140 return s , err
140141}
141142
142- func getDetectGpus (addr string ) (gpus []model.GpuInfo , err error ) {
143- resp , err := http .Get (addr )
143+ func getAllGpuUUID () ([]* gpu , error ) {
144+ c := cmd .NewCommand (allGpuUUIDCommand )
145+ err := c .Execute ()
144146 if err != nil {
145- return gpus , err
147+ return nil , errors . Wrap ( err , "cmd.Execute failed" )
146148 }
147- defer resp .Body .Close ()
148149
149- body , err := ioutil . ReadAll ( resp . Body )
150+ gpuList , err := parseOutput ( c . Stdout () )
150151 if err != nil {
151- return gpus , err
152+ return nil , errors . Wrap ( err , "parseOutput failed" )
152153 }
154+ return gpuList , nil
155+ }
156+
157+ func parseOutput (output string ) (gpuList []* gpu , err error ) {
158+ lines := strings .Split (output , "\n " )
159+ gpuList = make ([]* gpu , 0 , len (lines ))
160+ for _ , line := range lines {
161+ if line == "" {
162+ continue
163+ }
153164
154- if err = json .Unmarshal (body , & gpus ); err != nil {
155- return gpus , err
165+ fields := strings .Split (line , ", " )
166+ if len (fields ) == 2 {
167+ index , err := strconv .Atoi (fields [0 ])
168+ if err != nil {
169+ return gpuList , errors .Wrapf (err , "strconv.Atoi failed, index: %s" , fields [0 ])
170+ }
171+ uuid := fields [1 ]
172+ gpuList = append (gpuList , & gpu {
173+ Index : index ,
174+ UUID : & uuid ,
175+ })
176+ }
156177 }
157- return gpus , err
178+ return
158179}
0 commit comments