diff --git a/commands/tag.go b/commands/tag.go index a5529024..be3366ca 100644 --- a/commands/tag.go +++ b/commands/tag.go @@ -2,6 +2,7 @@ package commands import ( "fmt" + "strings" "github.com/docker/model-cli/desktop" "github.com/google/go-containerregistry/pkg/name" @@ -33,20 +34,21 @@ func newTagCmd() *cobra.Command { } func tagModel(cmd *cobra.Command, desktopClient *desktop.Client, source, target string) error { - // Parse the target to extract repo and tag + // Ensure tag is valid tag, err := name.NewTag(target) if err != nil { return fmt.Errorf("invalid tag: %w", err) } - targetRepo := tag.Repository.String() - targetTag := tag.TagStr() - - // Make the POST request - resp, err := desktopClient.Tag(source, targetRepo, targetTag) - if err != nil { + // Make tag request with model runner client + if err := desktopClient.Tag(source, parseRepo(tag), tag.TagStr()); err != nil { return fmt.Errorf("failed to tag model: %w", err) } - - cmd.Println(resp) + cmd.Printf("Model %q tagged successfully with %q\n", source, target) return nil } + +// parseRepo returns the repo portion of the original target string. It does not include implicit +// index.docker.io when the registry is omitted. +func parseRepo(tag name.Tag) string { + return strings.TrimSuffix(tag.String(), ":"+tag.TagStr()) +} diff --git a/desktop/desktop.go b/desktop/desktop.go index 6506d308..7a240509 100644 --- a/desktop/desktop.go +++ b/desktop/desktop.go @@ -605,7 +605,7 @@ func (c *Client) handleQueryError(err error, path string) error { return fmt.Errorf("error querying %s: %w", path, err) } -func (c *Client) Tag(source, targetRepo, targetTag string) (string, error) { +func (c *Client) Tag(source, targetRepo, targetTag string) error { source = normalizeHuggingFaceModelName(source) // Check if the source is a model ID, and expand it if necessary if !strings.Contains(strings.Trim(source, "/"), "/") { @@ -625,19 +625,17 @@ func (c *Client) Tag(source, targetRepo, targetTag string) (string, error) { resp, err := c.doRequest(http.MethodPost, tagPath, nil) if err != nil { - return "", c.handleQueryError(err, tagPath) + return c.handleQueryError(err, tagPath) } defer resp.Body.Close() - - if resp.StatusCode != http.StatusCreated { - body, _ := io.ReadAll(resp.Body) - return "", fmt.Errorf("tagging failed with status %s: %s", resp.Status, string(body)) - } - body, err := io.ReadAll(resp.Body) if err != nil { - return "", fmt.Errorf("failed to read response body: %w", err) + return fmt.Errorf("failed to read response body: %w", err) } - return string(body), nil + if resp.StatusCode != http.StatusCreated { + return fmt.Errorf("tagging failed with status %s: %s", resp.Status, string(body)) + } + + return nil } diff --git a/desktop/desktop_test.go b/desktop/desktop_test.go index 71ea9b9f..78c32893 100644 --- a/desktop/desktop_test.go +++ b/desktop/desktop_test.go @@ -193,8 +193,7 @@ func TestTagHuggingFaceModel(t *testing.T) { Body: io.NopCloser(bytes.NewBufferString("Tag created successfully")), }, nil) - _, err := client.Tag(sourceModel, targetRepo, targetTag) - assert.NoError(t, err) + assert.NoError(t, client.Tag(sourceModel, targetRepo, targetTag)) } func TestInspectOpenAIHuggingFaceModel(t *testing.T) {