Skip to content
This repository was archived by the owner on Oct 6, 2025. It is now read-only.
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion commands/compose.go
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,7 @@ func newUpCommand(desktopClient *desktop.Client) *cobra.Command {
return err
}

_, err := desktopClient.Pull(model, func(s string) {
_, _, err := desktopClient.Pull(model, func(s string) {
sendInfo(s)
})
if err != nil {
Expand Down
29 changes: 19 additions & 10 deletions commands/pull.go
Original file line number Diff line number Diff line change
Expand Up @@ -22,19 +22,28 @@ func newPullCmd(desktopClient *desktop.Client) *cobra.Command {
return nil
},
RunE: func(cmd *cobra.Command, args []string) error {
model := args[0]
response, err := desktopClient.Pull(model, TUIProgress)
if err != nil {
err = handleClientError(err, "Failed to pull model")
return handleNotRunningError(err)
}
cmd.Println(response)
return nil
return pullModel(cmd, desktopClient, args[0])
},
}
return c
}

func TUIProgress(line string) {
fmt.Print("\r\033[K", line)
func pullModel(cmd *cobra.Command, desktopClient *desktop.Client, model string) error {
response, progressShown, err := desktopClient.Pull(model, TUIProgress)

// Add a newline before any output (success or error) if progress was shown.
if progressShown {
cmd.Println()
}

if err != nil {
return handleNotRunningError(handleClientError(err, "Failed to pull model"))
}

cmd.Println(response)
return nil
}

func TUIProgress(message string) {
fmt.Print("\r\033[K", message)
}
6 changes: 2 additions & 4 deletions commands/run.go
Original file line number Diff line number Diff line change
Expand Up @@ -37,11 +37,9 @@ func newRunCmd(desktopClient *desktop.Client) *cobra.Command {
return handleNotRunningError(handleClientError(err, "Failed to list models"))
}
cmd.Println("Unable to find model '" + model + "' locally. Pulling from the server.")
response, err := desktopClient.Pull(model, TUIProgress)
if err != nil {
return handleNotRunningError(handleClientError(err, "Failed to pull model"))
if err := pullModel(cmd, desktopClient, model); err != nil {
return err
}
cmd.Println(response)
}

if prompt != "" {
Expand Down
6 changes: 6 additions & 0 deletions desktop/api.go
Original file line number Diff line number Diff line change
@@ -1,5 +1,11 @@
package desktop

// ProgressMessage represents a message sent during model pull operations
type ProgressMessage struct {
Type string `json:"type"` // "progress", "success", or "error"
Message string `json:"message"` // Human-readable message
}

type OpenAIChatMessage struct {
Role string `json:"role"`
Content string `json:"content"`
Expand Down
39 changes: 30 additions & 9 deletions desktop/desktop.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@ import (
"bytes"
"encoding/json"
"fmt"
"html"
"io"
"net/http"
"os"
Expand Down Expand Up @@ -79,10 +80,10 @@ func (c *Client) Status() Status {
}
}

func (c *Client) Pull(model string, progress func(string)) (string, error) {
func (c *Client) Pull(model string, progress func(string)) (string, bool, error) {
jsonData, err := json.Marshal(models.ModelCreateRequest{From: model})
if err != nil {
return "", fmt.Errorf("error marshaling request: %w", err)
return "", false, fmt.Errorf("error marshaling request: %w", err)
}

createPath := inference.ModelsPrefix + "/create"
Expand All @@ -92,26 +93,46 @@ func (c *Client) Pull(model string, progress func(string)) (string, error) {
bytes.NewReader(jsonData),
)
if err != nil {
return "", c.handleQueryError(err, createPath)
return "", false, c.handleQueryError(err, createPath)
}
defer resp.Body.Close()

if resp.StatusCode != http.StatusOK {
body, _ := io.ReadAll(resp.Body)
return "", fmt.Errorf("pulling %s failed with status %s: %s", model, resp.Status, string(body))
return "", false, fmt.Errorf("pulling %s failed with status %s: %s", model, resp.Status, string(body))
}

progressShown := false

scanner := bufio.NewScanner(resp.Body)
for scanner.Scan() {
progressLine := scanner.Text()
if progressLine != "" {
progress(progressLine)
if progressLine == "" {
continue
}
}

fmt.Println()
// Parse the progress message
var progressMsg ProgressMessage
if err := json.Unmarshal([]byte(html.UnescapeString(progressLine)), &progressMsg); err != nil {
return "", progressShown, fmt.Errorf("error parsing progress message: %w", err)
}

// Handle different message types
switch progressMsg.Type {
case "progress":
progress(progressMsg.Message)
progressShown = true
case "error":
return "", progressShown, fmt.Errorf("error pulling model: %s", progressMsg.Message)
case "success":
return progressMsg.Message, progressShown, nil
default:
return "", progressShown, fmt.Errorf("unknown message type: %s", progressMsg.Type)
}
}

return fmt.Sprintf("Model %s pulled successfully", model), nil
// If we get here, something went wrong
return "", progressShown, fmt.Errorf("unexpected end of stream while pulling model %s", model)
}

func (c *Client) List(jsonFormat, openai bool, model string) (string, error) {
Expand Down