diff --git a/commands/compose.go b/commands/compose.go index 4c5e45cf..efa78032 100644 --- a/commands/compose.go +++ b/commands/compose.go @@ -33,7 +33,7 @@ func newUpCommand(desktopClient *desktop.Client) *cobra.Command { return err } - _, err := desktopClient.Pull(model, func(s string) { + _, _, err := desktopClient.Pull(model, func(s string) { sendInfo(s) }) if err != nil { diff --git a/commands/pull.go b/commands/pull.go index 8f37e386..9f1da718 100644 --- a/commands/pull.go +++ b/commands/pull.go @@ -22,19 +22,28 @@ func newPullCmd(desktopClient *desktop.Client) *cobra.Command { return nil }, RunE: func(cmd *cobra.Command, args []string) error { - model := args[0] - response, err := desktopClient.Pull(model, TUIProgress) - if err != nil { - err = handleClientError(err, "Failed to pull model") - return handleNotRunningError(err) - } - cmd.Println(response) - return nil + return pullModel(cmd, desktopClient, args[0]) }, } return c } -func TUIProgress(line string) { - fmt.Print("\r\033[K", line) +func pullModel(cmd *cobra.Command, desktopClient *desktop.Client, model string) error { + response, progressShown, err := desktopClient.Pull(model, TUIProgress) + + // Add a newline before any output (success or error) if progress was shown. + if progressShown { + cmd.Println() + } + + if err != nil { + return handleNotRunningError(handleClientError(err, "Failed to pull model")) + } + + cmd.Println(response) + return nil +} + +func TUIProgress(message string) { + fmt.Print("\r\033[K", message) } diff --git a/commands/run.go b/commands/run.go index 2e0d03d4..ebe46099 100644 --- a/commands/run.go +++ b/commands/run.go @@ -37,11 +37,9 @@ func newRunCmd(desktopClient *desktop.Client) *cobra.Command { return handleNotRunningError(handleClientError(err, "Failed to list models")) } cmd.Println("Unable to find model '" + model + "' locally. Pulling from the server.") - response, err := desktopClient.Pull(model, TUIProgress) - if err != nil { - return handleNotRunningError(handleClientError(err, "Failed to pull model")) + if err := pullModel(cmd, desktopClient, model); err != nil { + return err } - cmd.Println(response) } if prompt != "" { diff --git a/desktop/api.go b/desktop/api.go index b7064167..46f8de12 100644 --- a/desktop/api.go +++ b/desktop/api.go @@ -1,5 +1,11 @@ package desktop +// ProgressMessage represents a message sent during model pull operations +type ProgressMessage struct { + Type string `json:"type"` // "progress", "success", or "error" + Message string `json:"message"` // Human-readable message +} + type OpenAIChatMessage struct { Role string `json:"role"` Content string `json:"content"` diff --git a/desktop/desktop.go b/desktop/desktop.go index 148f7ba2..1477e9e1 100644 --- a/desktop/desktop.go +++ b/desktop/desktop.go @@ -5,6 +5,7 @@ import ( "bytes" "encoding/json" "fmt" + "html" "io" "net/http" "os" @@ -79,10 +80,10 @@ func (c *Client) Status() Status { } } -func (c *Client) Pull(model string, progress func(string)) (string, error) { +func (c *Client) Pull(model string, progress func(string)) (string, bool, error) { jsonData, err := json.Marshal(models.ModelCreateRequest{From: model}) if err != nil { - return "", fmt.Errorf("error marshaling request: %w", err) + return "", false, fmt.Errorf("error marshaling request: %w", err) } createPath := inference.ModelsPrefix + "/create" @@ -92,26 +93,46 @@ func (c *Client) Pull(model string, progress func(string)) (string, error) { bytes.NewReader(jsonData), ) if err != nil { - return "", c.handleQueryError(err, createPath) + return "", false, c.handleQueryError(err, createPath) } defer resp.Body.Close() if resp.StatusCode != http.StatusOK { body, _ := io.ReadAll(resp.Body) - return "", fmt.Errorf("pulling %s failed with status %s: %s", model, resp.Status, string(body)) + return "", false, fmt.Errorf("pulling %s failed with status %s: %s", model, resp.Status, string(body)) } + progressShown := false + scanner := bufio.NewScanner(resp.Body) for scanner.Scan() { progressLine := scanner.Text() - if progressLine != "" { - progress(progressLine) + if progressLine == "" { + continue } - } - fmt.Println() + // Parse the progress message + var progressMsg ProgressMessage + if err := json.Unmarshal([]byte(html.UnescapeString(progressLine)), &progressMsg); err != nil { + return "", progressShown, fmt.Errorf("error parsing progress message: %w", err) + } + + // Handle different message types + switch progressMsg.Type { + case "progress": + progress(progressMsg.Message) + progressShown = true + case "error": + return "", progressShown, fmt.Errorf("error pulling model: %s", progressMsg.Message) + case "success": + return progressMsg.Message, progressShown, nil + default: + return "", progressShown, fmt.Errorf("unknown message type: %s", progressMsg.Type) + } + } - return fmt.Sprintf("Model %s pulled successfully", model), nil + // If we get here, something went wrong + return "", progressShown, fmt.Errorf("unexpected end of stream while pulling model %s", model) } func (c *Client) List(jsonFormat, openai bool, model string) (string, error) {