Skip to content
This repository was archived by the owner on Oct 6, 2025. It is now read-only.
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
13 changes: 11 additions & 2 deletions desktop/api.go
Original file line number Diff line number Diff line change
@@ -1,9 +1,18 @@
package desktop

// ProgressMessage represents a message sent during model pull operations
// ProgressMessage represents a structured message for progress reporting
type ProgressMessage struct {
Type string `json:"type"` // "progress", "success", or "error"
Message string `json:"message"` // Human-readable message
Message string `json:"message"` // Deprecated: the message should be defined by clients based on Message.Total and Message.Layer
Total uint64 `json:"total"`
Pulled uint64 `json:"pulled"` // Deprecated: use Layer.Current
Layer Layer `json:"layer"` // Current layer information
}

type Layer struct {
ID string // Layer ID
Size uint64 // Layer size
Current uint64 // Current bytes transferred
}

type OpenAIChatMessage struct {
Expand Down
16 changes: 15 additions & 1 deletion desktop/desktop.go
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,8 @@ import (
"strings"
"time"

"github.com/docker/go-units"

"github.com/docker/model-runner/pkg/inference"
dmrm "github.com/docker/model-runner/pkg/inference/models"
"github.com/docker/model-runner/pkg/inference/scheduling"
Expand Down Expand Up @@ -124,6 +126,8 @@ func (c *Client) Pull(model string, progress func(string)) (string, bool, error)
}

progressShown := false
current := uint64(0) // Track cumulative progress across all layers
layerProgress := make(map[string]uint64) // Track progress per layer ID

scanner := bufio.NewScanner(resp.Body)
for scanner.Scan() {
Expand All @@ -141,7 +145,17 @@ func (c *Client) Pull(model string, progress func(string)) (string, bool, error)
// Handle different message types
switch progressMsg.Type {
case "progress":
progress(progressMsg.Message)
// Update the current progress for this layer
layerID := progressMsg.Layer.ID
layerProgress[layerID] = progressMsg.Layer.Current

// Sum all layer progress values
current = uint64(0)
for _, layerCurrent := range layerProgress {
current += layerCurrent
}

progress(fmt.Sprintf("Downloaded %s of %s", units.HumanSize(float64(current)), units.HumanSize(float64(progressMsg.Total))))
progressShown = true
case "error":
return "", progressShown, fmt.Errorf("error pulling model: %s", progressMsg.Message)
Expand Down