diff --git a/desktop/desktop.go b/desktop/desktop.go index 3fff3bfd..ba5ed137 100644 --- a/desktop/desktop.go +++ b/desktop/desktop.go @@ -9,9 +9,13 @@ import ( "html" "io" "net/http" + "os" + "runtime" "strconv" "strings" + "syscall" "time" + "unsafe" "github.com/docker/go-units" "github.com/docker/model-distribution/distribution" @@ -106,6 +110,147 @@ func (c *Client) Status() Status { } } +func humanReadableSize(size float64) string { + return units.CustomSize("%.2f%s", float64(size), 1000.0, []string{"B", "kB", "MB", "GB", "TB", "PB", "EB", "ZB", "YB"}) +} + +func humanReadableSizePad(size float64, width int) string { + return fmt.Sprintf("%*s", width, humanReadableSize(size)) +} + +func humanReadableTimePad(seconds int64, width int) string { + var s string + if seconds < 60 { + s = fmt.Sprintf("%ds", seconds) + } else if seconds < 3600 { + s = fmt.Sprintf("%dm %02ds", seconds/60, seconds%60) + } else { + s = fmt.Sprintf("%dh %02dm %02ds", seconds/3600, (seconds%3600)/60, seconds%60) + } + return fmt.Sprintf("%*s", width, s) +} + +// ProgressBarState tracks the running totals and timing for speed/ETA +type ProgressBarState struct { + LastTime time.Time + StartTime time.Time + UpdateInterval time.Duration // New: interval between updates + LastPrint time.Time // New: last time the progress bar was printed +} + +// fmtBar calculates the bar width and filled bar string. +func (pbs *ProgressBarState) fmtBar(percent float64, termWidth int, prefix, suffix string) string { + barWidth := termWidth - len(prefix) - len(suffix) - 4 + if barWidth < 10 { + barWidth = 10 + } + + filled := int(percent / 100 * float64(barWidth)) + if filled > barWidth { + filled = barWidth + } + + bar := strings.Repeat("█", filled) + strings.Repeat(" ", barWidth-filled) + + return bar +} + +// calcSpeed calculates the current download speed. +func (pbs *ProgressBarState) calcSpeed(current uint64, now time.Time) float64 { + elapsed := now.Sub(pbs.StartTime).Seconds() + if elapsed <= 0 { + return 0 + } + + speed := float64(current) / elapsed + pbs.LastTime = now + + return speed +} + +// fmtSuffix returns the suffix string showing human readable sizes, speed, and ETA. +func (pbs *ProgressBarState) fmtSuffix(current, total uint64, speed float64, eta int64) string { + return fmt.Sprintf("%s/%s %s/s %s", + humanReadableSizePad(float64(current), 10), + humanReadableSize(float64(total)), + humanReadableSizePad(speed, 10), + humanReadableTimePad(eta, 16), + ) +} + +// calcETA calculates the estimated time remaining. +func (pbs *ProgressBarState) calcETA(current, total uint64, speed float64) int64 { + if speed <= 0 { + return 0 + } + + return int64(float64(total-current) / speed) +} + +// fmtProgressBar returns a progress bar update string +func (pbs *ProgressBarState) fmtProgressBar(current, total uint64) string { + if pbs.StartTime.IsZero() { + pbs.StartTime = time.Now() + pbs.LastTime = pbs.StartTime + pbs.LastPrint = pbs.StartTime + } + + now := time.Now() + + // Update display if enough time passed, or always if interval=0 + if pbs.UpdateInterval > 0 && now.Sub(pbs.LastPrint) < pbs.UpdateInterval && current != total { + return "" + } + + pbs.LastPrint = now + termWidth := getTerminalWidth() + percent := float64(current) / float64(total) * 100 + prefix := fmt.Sprintf("%3.0f%% |", percent) + speed := pbs.calcSpeed(current, now) + eta := pbs.calcETA(current, total, speed) + suffix := pbs.fmtSuffix(current, total, speed, eta) + bar := pbs.fmtBar(percent, termWidth, prefix, suffix) + return fmt.Sprintf("%s%s| %s", prefix, bar, suffix) +} + +func getTerminalWidthUnix() (int, error) { + type winsize struct { + Row uint16 + Col uint16 + Xpixel uint16 + Ypixel uint16 + } + ws := &winsize{} + retCode, _, errno := syscall.Syscall6( + syscall.SYS_IOCTL, + uintptr(os.Stdout.Fd()), + uintptr(syscall.TIOCGWINSZ), + uintptr(unsafe.Pointer(ws)), + 0, 0, 0, + ) + if int(retCode) == -1 { + return 0, errno + } + return int(ws.Col), nil +} + +// getTerminalWidth tries to get the terminal width (default 80 if fails) +func getTerminalWidth() int { + var width int + var err error + default_width := 80 + if runtime.GOOS == "windows" { // to be implemented + return default_width + } + + width, err = getTerminalWidthUnix() + if width == 0 || err != nil { + return default_width + } + + return width +} + func (c *Client) Pull(model string, ignoreRuntimeMemoryCheck bool, progress func(string)) (string, bool, error) { model = normalizeHuggingFaceModelName(model) jsonData, err := json.Marshal(dmrm.ModelCreateRequest{From: model, IgnoreRuntimeMemoryCheck: ignoreRuntimeMemoryCheck}) @@ -134,6 +279,9 @@ func (c *Client) Pull(model string, ignoreRuntimeMemoryCheck bool, progress func layerProgress := make(map[string]uint64) // Track progress per layer ID scanner := bufio.NewScanner(resp.Body) + pbs := &ProgressBarState{ + UpdateInterval: time.Millisecond * 100, + } for scanner.Scan() { progressLine := scanner.Text() if progressLine == "" { @@ -159,8 +307,12 @@ func (c *Client) Pull(model string, ignoreRuntimeMemoryCheck bool, progress func current += layerCurrent } - progress(fmt.Sprintf("Downloaded %s of %s", units.CustomSize("%.2f%s", float64(current), 1000.0, []string{"B", "kB", "MB", "GB", "TB", "PB", "EB", "ZB", "YB"}), units.CustomSize("%.2f%s", float64(progressMsg.Total), 1000.0, []string{"B", "kB", "MB", "GB", "TB", "PB", "EB", "ZB", "YB"}))) - progressShown = true + progressBar := pbs.fmtProgressBar(current, progressMsg.Total) + if progressBar != "" { + progress(progressBar) + progressShown = true + } + case "error": return "", progressShown, fmt.Errorf("error pulling model: %s", progressMsg.Message) case "success":