diff --git a/desktop/desktop.go b/desktop/desktop.go
index 148f7ba2..d6494b93 100644
--- a/desktop/desktop.go
+++ b/desktop/desktop.go
@@ -121,28 +121,18 @@ func (c *Client) List(jsonFormat, openai bool, model string) (string, error) {
 	}
 	if model != "" {
 		if !strings.Contains(strings.Trim(model, "/"), "/") {
-			// We assume a model name is invalid if it does not contain a "/".
-			return "", fmt.Errorf("invalid model name: %s", model)
+			// Do an extra API call to check if the model parameter isn't a model ID.
+			var err error
+			if model, err = c.modelNameFromID(model); err != nil {
+				return "", fmt.Errorf("invalid model name: %s", model)
+			}
 		}
 		modelsRoute += "/" + model
 	}
 
-	resp, err := c.doRequest(http.MethodGet, modelsRoute, nil)
-	if err != nil {
-		return "", c.handleQueryError(err, modelsRoute)
-	}
-	defer resp.Body.Close()
-
-	if resp.StatusCode != http.StatusOK {
-		if model != "" && resp.StatusCode == http.StatusNotFound {
-			return "", errors.Wrap(ErrNotFound, model)
-		}
-		return "", fmt.Errorf("failed to list models: %s", resp.Status)
-	}
-
-	body, err := io.ReadAll(resp.Body)
+	body, err := c.listRaw(modelsRoute, model)
 	if err != nil {
-		return "", fmt.Errorf("failed to read response body: %w", err)
+		return "", err
 	}
 
 	if openai {
@@ -182,6 +172,48 @@ func (c *Client) List(jsonFormat, openai bool, model string) (string, error) {
 	return prettyPrintModels(modelsJson), nil
 }
 
+func (c *Client) listRaw(route string, model string) ([]byte, error) {
+	resp, err := c.doRequest(http.MethodGet, route, nil)
+	if err != nil {
+		return nil, c.handleQueryError(err, route)
+	}
+	defer resp.Body.Close()
+
+	if resp.StatusCode != http.StatusOK {
+		if model != "" && resp.StatusCode == http.StatusNotFound {
+			return nil, errors.Wrap(ErrNotFound, model)
+		}
+		return nil, fmt.Errorf("failed to list models: %s", resp.Status)
+	}
+
+	body, err := io.ReadAll(resp.Body)
+	if err != nil {
+		return nil, fmt.Errorf("failed to read response body: %w", err)
+	}
+	return body, nil
+
+}
+
+func (c *Client) modelNameFromID(id string) (string, error) {
+	bodyResponse, err := c.listRaw(inference.ModelsPrefix, "")
+	if err != nil {
+		return "", err
+	}
+
+	var modelsJson []Model
+	if err := json.Unmarshal(bodyResponse, &modelsJson); err != nil {
+		return "", fmt.Errorf("failed to unmarshal response body: %w", err)
+	}
+
+	for _, m := range modelsJson {
+		if m.ID[7:19] == id || strings.TrimPrefix(m.ID, "sha256:") == id || m.ID == id {
+			return m.Tags[0], nil
+		}
+	}
+
+	return "", fmt.Errorf("model with ID %s not found", id)
+}
+
 func (c *Client) Chat(model, prompt string) error {
 	reqBody := OpenAIChatRequest{
 		Model: model,
@@ -251,6 +283,14 @@ func (c *Client) Chat(model, prompt string) error {
 }
 
 func (c *Client) Remove(model string) (string, error) {
+	// Check if not a model ID passed as parameter.
+	if !strings.Contains(model, "/") {
+		var err error
+		if model, err = c.modelNameFromID(model); err != nil {
+			return "", fmt.Errorf("invalid model name: %s", model)
+		}
+	}
+
 	removePath := inference.ModelsPrefix + "/" + model
 	resp, err := c.doRequest(http.MethodDelete, removePath, nil)
 	if err != nil {
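
Note (not part of the diff above): the `m.ID[7:19] == id` comparison in `modelNameFromID` assumes every stored ID is a `sha256:`-prefixed digest at least 19 characters long, so a shorter or unprefixed ID would make the slice panic. A minimal sketch of a guarded variant of that match is shown below; the `matchesID` helper name is hypothetical, and the 12-character short-ID length is an assumption read off the `[7:19]` slice, not something the diff states.

package desktop

import "strings"

// matchesID reports whether a stored model ID matches a user-supplied ID,
// accepting the full "sha256:..." form, the bare digest, or a 12-character
// short form, without assuming a minimum ID length.
func matchesID(fullID, id string) bool {
	if fullID == id {
		return true
	}
	digest := strings.TrimPrefix(fullID, "sha256:")
	if digest == id {
		return true
	}
	// Short-ID form: the first 12 characters of the digest, guarded so a
	// short or unprefixed ID cannot trigger an out-of-range slice.
	return len(digest) >= 12 && digest[:12] == id
}

Inside the loop this would read `if matchesID(m.ID, id) { return m.Tags[0], nil }`, keeping the three accepted forms but making the length assumption explicit.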