diff --git a/internal/store/index.go b/internal/store/index.go index 3d20e8b..6cff7c9 100644 --- a/internal/store/index.go +++ b/internal/store/index.go @@ -39,10 +39,10 @@ func (i Index) Tag(reference string, tag string) (Index, error) { return result, nil } -func (i Index) UnTag(tag string) (name.Tag, Index) { +func (i Index) UnTag(tag string) (name.Tag, Index, error) { tagRef, err := name.NewTag(tag) if err != nil { - return name.Tag{}, Index{} + return name.Tag{}, Index{}, err } result := Index{ @@ -52,7 +52,7 @@ func (i Index) UnTag(tag string) (name.Tag, Index) { result.Models = append(result.Models, entry.UnTag(tagRef)) } - return tagRef, result + return tagRef, result, nil } func (i Index) Find(reference string) (IndexEntry, int, bool) { diff --git a/internal/store/index_test.go b/internal/store/index_test.go index b998ff6..99e1149 100644 --- a/internal/store/index_test.go +++ b/internal/store/index_test.go @@ -146,15 +146,26 @@ func TestUntag(t *testing.T) { }, }, } - tag, idx := idx.UnTag("other-tag") - if len(idx.Models) != 2 { - t.Fatalf("Expected 2 models, got %d", len(idx.Models)) - } - if len(idx.Models[0].Tags) != 1 { - t.Fatalf("Expected 1 tag, got %d", len(idx.Models[0].Tags)) - } - if tag.String() != "other-tag" { - t.Fatalf("Expected tag 'other-tag', got '%s'", tag) - } + t.Run("UnTagging existing tag", func(t *testing.T) { + tag, idx, err := idx.UnTag("other-tag") + if err != nil { + t.Fatalf("Error untagging entry: %v", err) + } + if len(idx.Models) != 2 { + t.Fatalf("Expected 2 models, got %d", len(idx.Models)) + } + if len(idx.Models[0].Tags) != 1 { + t.Fatalf("Expected 1 tag, got %d", len(idx.Models[0].Tags)) + } + if tag.String() != "other-tag" { + t.Fatalf("Expected tag 'other-tag', got '%s'", tag) + } + }) + t.Run("UnTagging invalid tag", func(t *testing.T) { + _, _, err := idx.UnTag("!@#$%^&*()") + if err == nil { + t.Fatal("Expected error when untagging non-existing tag, got nil") + } + }) }) } diff --git a/internal/store/store.go b/internal/store/store.go index d9fe797..233608b 100644 --- a/internal/store/store.go +++ b/internal/store/store.go @@ -7,7 +7,6 @@ import ( "path/filepath" "github.com/docker/model-distribution/internal/progress" - "github.com/google/go-containerregistry/pkg/name" v1 "github.com/google/go-containerregistry/pkg/v1" ) @@ -165,11 +164,18 @@ func (s *LocalStore) RemoveTags(tags []string) ([]string, error) { if err != nil { return nil, fmt.Errorf("reading modelss index: %w", err) } - var tagRef name.Tag var tagRefs []string for _, tag := range tags { - tagRef, index = index.UnTag(tag) + tagRef, newIndex, err := index.UnTag(tag) + if err != nil { + // Try to save progress before returning error. + if writeIndexErr := s.writeIndex(newIndex); writeIndexErr != nil { + return tagRefs, fmt.Errorf("untagging model: %w, also failed to save: %w", err, writeIndexErr) + } + return tagRefs, fmt.Errorf("untagging model: %w", err) + } tagRefs = append(tagRefs, tagRef.Name()) + index = newIndex } return tagRefs, s.writeIndex(index) }