From fc8178c99b1e208d36b2dc653939f8f2957362ec Mon Sep 17 00:00:00 2001 From: Guillaume Lours <705411+glours@users.noreply.github.com> Date: Mon, 7 Apr 2025 17:23:36 +0200 Subject: [PATCH 1/2] add support of model ID to inspect and rm commands Signed-off-by: Guillaume Lours <705411+glours@users.noreply.github.com> --- desktop/desktop.go | 74 +++++++++++++++++++++++++++++++++++----------- 1 file changed, 57 insertions(+), 17 deletions(-) diff --git a/desktop/desktop.go b/desktop/desktop.go index 148f7ba2..d6494b93 100644 --- a/desktop/desktop.go +++ b/desktop/desktop.go @@ -121,28 +121,18 @@ func (c *Client) List(jsonFormat, openai bool, model string) (string, error) { } if model != "" { if !strings.Contains(strings.Trim(model, "/"), "/") { - // We assume a model name is invalid if it does not contain a "/". - return "", fmt.Errorf("invalid model name: %s", model) + // Do an extra API call to check if the model parameter isn't a model ID. + var err error + if model, err = c.modelNameFromID(model); err != nil { + return "", fmt.Errorf("invalid model name: %s", model) + } } modelsRoute += "/" + model } - resp, err := c.doRequest(http.MethodGet, modelsRoute, nil) - if err != nil { - return "", c.handleQueryError(err, modelsRoute) - } - defer resp.Body.Close() - - if resp.StatusCode != http.StatusOK { - if model != "" && resp.StatusCode == http.StatusNotFound { - return "", errors.Wrap(ErrNotFound, model) - } - return "", fmt.Errorf("failed to list models: %s", resp.Status) - } - - body, err := io.ReadAll(resp.Body) + body, err := c.listRaw(modelsRoute, model) if err != nil { - return "", fmt.Errorf("failed to read response body: %w", err) + return "", err } if openai { @@ -182,6 +172,48 @@ func (c *Client) List(jsonFormat, openai bool, model string) (string, error) { return prettyPrintModels(modelsJson), nil } +func (c *Client) listRaw(route string, model string) ([]byte, error) { + resp, err := c.doRequest(http.MethodGet, route, nil) + if err != nil { + return nil, c.handleQueryError(err, route) + } + defer resp.Body.Close() + + if resp.StatusCode != http.StatusOK { + if model != "" && resp.StatusCode == http.StatusNotFound { + return nil, errors.Wrap(ErrNotFound, model) + } + return nil, fmt.Errorf("failed to list models: %s", resp.Status) + } + + body, err := io.ReadAll(resp.Body) + if err != nil { + return nil, fmt.Errorf("failed to read response body: %w", err) + } + return body, nil + +} + +func (c *Client) modelNameFromID(id string) (string, error) { + bodyResponse, err := c.listRaw(inference.ModelsPrefix, "") + if err != nil { + return "", err + } + + var modelsJson []Model + if err := json.Unmarshal(bodyResponse, &modelsJson); err != nil { + return "", fmt.Errorf("failed to unmarshal response body: %w", err) + } + + for _, m := range modelsJson { + if m.ID[7:19] == id || strings.TrimPrefix(m.ID, "sha256:") == id || m.ID == id { + return m.Tags[0], nil + } + } + + return "", fmt.Errorf("model with ID %s not found", id) +} + func (c *Client) Chat(model, prompt string) error { reqBody := OpenAIChatRequest{ Model: model, @@ -251,6 +283,14 @@ func (c *Client) Chat(model, prompt string) error { } func (c *Client) Remove(model string) (string, error) { + // Check if not a model ID passed as parameter. + if !strings.Contains(model, "/") { + var err error + if model, err = c.modelNameFromID(model); err != nil { + return "", fmt.Errorf("invalid model name: %s", model) + } + } + removePath := inference.ModelsPrefix + "/" + model resp, err := c.doRequest(http.MethodDelete, removePath, nil) if err != nil { From eb08c54ef8690c4b78defd7cd661cac89085c956 Mon Sep 17 00:00:00 2001 From: Guillaume Lours <705411+glours@users.noreply.github.com> Date: Mon, 7 Apr 2025 17:52:06 +0200 Subject: [PATCH 2/2] allow to remove more than 1 model at the time Signed-off-by: Guillaume Lours <705411+glours@users.noreply.github.com> --- commands/rm.go | 17 ++++++++-------- desktop/desktop.go | 50 +++++++++++++++++++++++++--------------------- 2 files changed, 36 insertions(+), 31 deletions(-) diff --git a/commands/rm.go b/commands/rm.go index 1a095736..ff9630e3 100644 --- a/commands/rm.go +++ b/commands/rm.go @@ -9,26 +9,27 @@ import ( func newRemoveCmd(desktopClient *desktop.Client) *cobra.Command { c := &cobra.Command{ - Use: "rm MODEL", - Short: "Remove a model downloaded from Docker Hub", + Use: "rm [MODEL...]", + Short: "Remove models downloaded from Docker Hub", Args: func(cmd *cobra.Command, args []string) error { - if len(args) != 1 { + if len(args) < 1 { return fmt.Errorf( - "'docker model rm' requires 1 argument.\n\n" + - "Usage: docker model rm MODEL\n\n" + + "'docker model rm' requires at least 1 argument.\n\n" + + "Usage: docker model rm [MODEL...]\n\n" + "See 'docker model rm --help' for more information", ) } return nil }, RunE: func(cmd *cobra.Command, args []string) error { - model := args[0] - response, err := desktopClient.Remove(model) + response, err := desktopClient.Remove(args) + if response != "" { + cmd.Println(response) + } if err != nil { err = handleClientError(err, "Failed to remove model") return handleNotRunningError(err) } - cmd.Println(response) return nil }, } diff --git a/desktop/desktop.go b/desktop/desktop.go index d6494b93..77d6205f 100644 --- a/desktop/desktop.go +++ b/desktop/desktop.go @@ -282,34 +282,38 @@ func (c *Client) Chat(model, prompt string) error { return nil } -func (c *Client) Remove(model string) (string, error) { - // Check if not a model ID passed as parameter. - if !strings.Contains(model, "/") { - var err error - if model, err = c.modelNameFromID(model); err != nil { - return "", fmt.Errorf("invalid model name: %s", model) +func (c *Client) Remove(models []string) (string, error) { + modelRemoved := "" + for _, model := range models { + // Check if not a model ID passed as parameter. + if !strings.Contains(model, "/") { + var err error + modelID := model + if model, err = c.modelNameFromID(model); err != nil { + return modelRemoved, fmt.Errorf("invalid model name: %s", modelID) + } } - } - - removePath := inference.ModelsPrefix + "/" + model - resp, err := c.doRequest(http.MethodDelete, removePath, nil) - if err != nil { - return "", c.handleQueryError(err, removePath) - } - defer resp.Body.Close() - if resp.StatusCode != http.StatusOK { - var bodyStr string - body, err := io.ReadAll(resp.Body) + removePath := inference.ModelsPrefix + "/" + model + resp, err := c.doRequest(http.MethodDelete, removePath, nil) if err != nil { - bodyStr = fmt.Sprintf("(failed to read response body: %v)", err) - } else { - bodyStr = string(body) + return modelRemoved, c.handleQueryError(err, removePath) } - return "", fmt.Errorf("removing %s failed with status %s: %s", model, resp.Status, bodyStr) + defer resp.Body.Close() + + if resp.StatusCode != http.StatusOK { + var bodyStr string + body, err := io.ReadAll(resp.Body) + if err != nil { + bodyStr = fmt.Sprintf("(failed to read response body: %v)", err) + } else { + bodyStr = string(body) + } + return modelRemoved, fmt.Errorf("removing %s failed with status %s: %s", model, resp.Status, bodyStr) + } + modelRemoved += fmt.Sprintf("Model %s removed successfully\n", model) } - - return fmt.Sprintf("Model %s removed successfully", model), nil + return modelRemoved, nil } func URL(path string) string {