diff --git a/desktop/api.go b/desktop/api.go index 17a06bc8..3833a112 100644 --- a/desktop/api.go +++ b/desktop/api.go @@ -1,9 +1,18 @@ package desktop -// ProgressMessage represents a message sent during model pull operations +// ProgressMessage represents a structured message for progress reporting type ProgressMessage struct { Type string `json:"type"` // "progress", "success", or "error" - Message string `json:"message"` // Human-readable message + Message string `json:"message"` // Deprecated: the message should be defined by clients based on Message.Total and Message.Layer + Total uint64 `json:"total"` + Pulled uint64 `json:"pulled"` // Deprecated: use Layer.Current + Layer Layer `json:"layer"` // Current layer information +} + +type Layer struct { + ID string // Layer ID + Size uint64 // Layer size + Current uint64 // Current bytes transferred } type OpenAIChatMessage struct { diff --git a/desktop/desktop.go b/desktop/desktop.go index 6506d308..2b43c329 100644 --- a/desktop/desktop.go +++ b/desktop/desktop.go @@ -12,6 +12,8 @@ import ( "strings" "time" + "github.com/docker/go-units" + "github.com/docker/model-runner/pkg/inference" dmrm "github.com/docker/model-runner/pkg/inference/models" "github.com/docker/model-runner/pkg/inference/scheduling" @@ -124,6 +126,8 @@ func (c *Client) Pull(model string, progress func(string)) (string, bool, error) } progressShown := false + current := uint64(0) // Track cumulative progress across all layers + layerProgress := make(map[string]uint64) // Track progress per layer ID scanner := bufio.NewScanner(resp.Body) for scanner.Scan() { @@ -141,7 +145,17 @@ func (c *Client) Pull(model string, progress func(string)) (string, bool, error) // Handle different message types switch progressMsg.Type { case "progress": - progress(progressMsg.Message) + // Update the current progress for this layer + layerID := progressMsg.Layer.ID + layerProgress[layerID] = progressMsg.Layer.Current + + // Sum all layer progress values + current = uint64(0) + for _, layerCurrent := range layerProgress { + current += layerCurrent + } + + progress(fmt.Sprintf("Downloaded %s of %s", units.HumanSize(float64(current)), units.HumanSize(float64(progressMsg.Total)))) progressShown = true case "error": return "", progressShown, fmt.Errorf("error pulling model: %s", progressMsg.Message)