@@ -9,9 +9,13 @@ import (
99 "html"
1010 "io"
1111 "net/http"
12+ "os"
13+ "runtime"
1214 "strconv"
1315 "strings"
16+ "syscall"
1417 "time"
18+ "unsafe"
1519
1620 "github.com/docker/go-units"
1721 "github.com/docker/model-distribution/distribution"
@@ -106,6 +110,147 @@ func (c *Client) Status() Status {
106110 }
107111}
108112
113+ func humanReadableSize (size float64 ) string {
114+ return units .CustomSize ("%.2f%s" , float64 (size ), 1000.0 , []string {"B" , "kB" , "MB" , "GB" , "TB" , "PB" , "EB" , "ZB" , "YB" })
115+ }
116+
117+ func humanReadableSizePad (size float64 , width int ) string {
118+ return fmt .Sprintf ("%*s" , width , humanReadableSize (size ))
119+ }
120+
121+ func humanReadableTimePad (seconds int64 , width int ) string {
122+ var s string
123+ if seconds < 60 {
124+ s = fmt .Sprintf ("%ds" , seconds )
125+ } else if seconds < 3600 {
126+ s = fmt .Sprintf ("%dm %02ds" , seconds / 60 , seconds % 60 )
127+ } else {
128+ s = fmt .Sprintf ("%dh %02dm %02ds" , seconds / 3600 , (seconds % 3600 )/ 60 , seconds % 60 )
129+ }
130+ return fmt .Sprintf ("%*s" , width , s )
131+ }
132+
133+ // ProgressBarState tracks the running totals and timing for speed/ETA
134+ type ProgressBarState struct {
135+ LastTime time.Time
136+ StartTime time.Time
137+ UpdateInterval time.Duration // New: interval between updates
138+ LastPrint time.Time // New: last time the progress bar was printed
139+ }
140+
141+ // fmtBar calculates the bar width and filled bar string.
142+ func (pbs * ProgressBarState ) fmtBar (percent float64 , termWidth int , prefix , suffix string ) string {
143+ barWidth := termWidth - len (prefix ) - len (suffix ) - 4
144+ if barWidth < 10 {
145+ barWidth = 10
146+ }
147+
148+ filled := int (percent / 100 * float64 (barWidth ))
149+ if filled > barWidth {
150+ filled = barWidth
151+ }
152+
153+ bar := strings .Repeat ("█" , filled ) + strings .Repeat (" " , barWidth - filled )
154+
155+ return bar
156+ }
157+
158+ // calcSpeed calculates the current download speed.
159+ func (pbs * ProgressBarState ) calcSpeed (current uint64 , now time.Time ) float64 {
160+ elapsed := now .Sub (pbs .StartTime ).Seconds ()
161+ if elapsed <= 0 {
162+ return 0
163+ }
164+
165+ speed := float64 (current ) / elapsed
166+ pbs .LastTime = now
167+
168+ return speed
169+ }
170+
171+ // fmtSuffix returns the suffix string showing human readable sizes, speed, and ETA.
172+ func (pbs * ProgressBarState ) fmtSuffix (current , total uint64 , speed float64 , eta int64 ) string {
173+ return fmt .Sprintf ("%s/%s %s/s %s" ,
174+ humanReadableSizePad (float64 (current ), 10 ),
175+ humanReadableSize (float64 (total )),
176+ humanReadableSizePad (speed , 10 ),
177+ humanReadableTimePad (eta , 16 ),
178+ )
179+ }
180+
181+ // calcETA calculates the estimated time remaining.
182+ func (pbs * ProgressBarState ) calcETA (current , total uint64 , speed float64 ) int64 {
183+ if speed <= 0 {
184+ return 0
185+ }
186+
187+ return int64 (float64 (total - current ) / speed )
188+ }
189+
190+ // fmtProgressBar returns a progress bar update string
191+ func (pbs * ProgressBarState ) fmtProgressBar (current , total uint64 ) string {
192+ if pbs .StartTime .IsZero () {
193+ pbs .StartTime = time .Now ()
194+ pbs .LastTime = pbs .StartTime
195+ pbs .LastPrint = pbs .StartTime
196+ }
197+
198+ now := time .Now ()
199+
200+ // Update display if enough time passed, or always if interval=0
201+ if pbs .UpdateInterval > 0 && now .Sub (pbs .LastPrint ) < pbs .UpdateInterval && current != total {
202+ return ""
203+ }
204+
205+ pbs .LastPrint = now
206+ termWidth := getTerminalWidth ()
207+ percent := float64 (current ) / float64 (total ) * 100
208+ prefix := fmt .Sprintf ("%3.0f%% |" , percent )
209+ speed := pbs .calcSpeed (current , now )
210+ eta := pbs .calcETA (current , total , speed )
211+ suffix := pbs .fmtSuffix (current , total , speed , eta )
212+ bar := pbs .fmtBar (percent , termWidth , prefix , suffix )
213+ return fmt .Sprintf ("%s%s| %s" , prefix , bar , suffix )
214+ }
215+
216+ func getTerminalWidthUnix () (int , error ) {
217+ type winsize struct {
218+ Row uint16
219+ Col uint16
220+ Xpixel uint16
221+ Ypixel uint16
222+ }
223+ ws := & winsize {}
224+ retCode , _ , errno := syscall .Syscall6 (
225+ syscall .SYS_IOCTL ,
226+ uintptr (os .Stdout .Fd ()),
227+ uintptr (syscall .TIOCGWINSZ ),
228+ uintptr (unsafe .Pointer (ws )),
229+ 0 , 0 , 0 ,
230+ )
231+ if int (retCode ) == - 1 {
232+ return 0 , errno
233+ }
234+ return int (ws .Col ), nil
235+ }
236+
237+ // getTerminalWidth tries to get the terminal width (default 80 if fails)
238+ func getTerminalWidth () int {
239+ var width int
240+ var err error
241+ default_width := 80
242+ if runtime .GOOS == "windows" { // to be implemented
243+ return default_width
244+ }
245+
246+ width , err = getTerminalWidthUnix ()
247+ if width == 0 || err != nil {
248+ return default_width
249+ }
250+
251+ return width
252+ }
253+
109254func (c * Client ) Pull (model string , ignoreRuntimeMemoryCheck bool , progress func (string )) (string , bool , error ) {
110255 model = normalizeHuggingFaceModelName (model )
111256 jsonData , err := json .Marshal (dmrm.ModelCreateRequest {From : model , IgnoreRuntimeMemoryCheck : ignoreRuntimeMemoryCheck })
@@ -134,6 +279,9 @@ func (c *Client) Pull(model string, ignoreRuntimeMemoryCheck bool, progress func
134279 layerProgress := make (map [string ]uint64 ) // Track progress per layer ID
135280
136281 scanner := bufio .NewScanner (resp .Body )
282+ pbs := & ProgressBarState {
283+ UpdateInterval : time .Millisecond * 100 ,
284+ }
137285 for scanner .Scan () {
138286 progressLine := scanner .Text ()
139287 if progressLine == "" {
@@ -159,8 +307,12 @@ func (c *Client) Pull(model string, ignoreRuntimeMemoryCheck bool, progress func
159307 current += layerCurrent
160308 }
161309
162- progress (fmt .Sprintf ("Downloaded %s of %s" , units .CustomSize ("%.2f%s" , float64 (current ), 1000.0 , []string {"B" , "kB" , "MB" , "GB" , "TB" , "PB" , "EB" , "ZB" , "YB" }), units .CustomSize ("%.2f%s" , float64 (progressMsg .Total ), 1000.0 , []string {"B" , "kB" , "MB" , "GB" , "TB" , "PB" , "EB" , "ZB" , "YB" })))
163- progressShown = true
310+ progressBar := pbs .fmtProgressBar (current , progressMsg .Total )
311+ if progressBar != "" {
312+ progress (progressBar )
313+ progressShown = true
314+ }
315+
164316 case "error" :
165317 return "" , progressShown , fmt .Errorf ("error pulling model: %s" , progressMsg .Message )
166318 case "success" :
0 commit comments