diff --git a/cmd/cli/commands/integration_test.go b/cmd/cli/commands/integration_test.go index 437ad33e7..85d37efc9 100644 --- a/cmd/cli/commands/integration_test.go +++ b/cmd/cli/commands/integration_test.go @@ -264,7 +264,7 @@ func createAndPushTestModel(t *testing.T, registryURL, modelRef string, contextS // Create a builder from the GGUF file t.Logf("Creating test model %s from %s", modelRef, absPath) - pkg, err := builder.FromGGUF(absPath) + pkg, err := builder.FromPath(absPath) require.NoError(t, err) // Set context size if specified diff --git a/cmd/cli/commands/package.go b/cmd/cli/commands/package.go index 83352aa88..1a8d23ad7 100644 --- a/cmd/cli/commands/package.go +++ b/cmd/cli/commands/package.go @@ -241,7 +241,7 @@ func initializeBuilder(cmd *cobra.Command, opts packageOptions) (*builderInitRes } } else if opts.ggufPath != "" { cmd.PrintErrf("Adding GGUF file from %q\n", opts.ggufPath) - pkg, err := builder.FromGGUF(opts.ggufPath) + pkg, err := builder.FromPath(opts.ggufPath) if err != nil { return nil, fmt.Errorf("add gguf file: %w", err) } @@ -262,7 +262,7 @@ func initializeBuilder(cmd *cobra.Command, opts packageOptions) (*builderInitRes } cmd.PrintErrf("Found %d safetensors file(s)\n", len(safetensorsPaths)) - pkg, err := builder.FromSafetensors(safetensorsPaths) + pkg, err := builder.FromPaths(safetensorsPaths) if err != nil { return nil, fmt.Errorf("create safetensors model: %w", err) } diff --git a/cmd/mdltool/main.go b/cmd/mdltool/main.go index 787b9c49c..3c2fc2eb1 100644 --- a/cmd/mdltool/main.go +++ b/cmd/mdltool/main.go @@ -286,7 +286,7 @@ func cmdPackage(args []string) int { var b *builder.Builder if isSafetensors { fmt.Println("Creating safetensors model") - b, err = builder.FromSafetensors(safetensorsPaths) + b, err = builder.FromPaths(safetensorsPaths) if err != nil { fmt.Fprintf(os.Stderr, "Error creating model from safetensors: %v\n", err) return 1 @@ -302,7 +302,7 @@ func cmdPackage(args []string) int { } } } else { - b, err = 
builder.FromGGUF(source) + b, err = builder.FromPath(source) if err != nil { fmt.Fprintf(os.Stderr, "Error creating model from gguf: %v\n", err) return 1 diff --git a/pkg/distribution/builder/builder.go b/pkg/distribution/builder/builder.go index 04ed6465d..1b566c611 100644 --- a/pkg/distribution/builder/builder.go +++ b/pkg/distribution/builder/builder.go @@ -4,11 +4,11 @@ import ( "context" "fmt" "io" + "time" - "github.com/docker/model-runner/pkg/distribution/internal/gguf" + "github.com/docker/model-runner/pkg/distribution/format" "github.com/docker/model-runner/pkg/distribution/internal/mutate" "github.com/docker/model-runner/pkg/distribution/internal/partial" - "github.com/docker/model-runner/pkg/distribution/internal/safetensors" "github.com/docker/model-runner/pkg/distribution/oci" "github.com/docker/model-runner/pkg/distribution/types" ) @@ -19,23 +19,86 @@ type Builder struct { originalLayers []oci.Layer // Snapshot of layers when created from existing model } -// FromGGUF returns a *Builder that builds a model artifacts from a GGUF file -func FromGGUF(path string) (*Builder, error) { - mdl, err := gguf.NewModel(path) +// FromPath returns a *Builder that builds model artifacts from a file path. +// It auto-detects the model format (GGUF or Safetensors) and discovers any shards. +// This is the preferred entry point for creating models from local files. +func FromPath(path string) (*Builder, error) { + // Auto-detect format from file extension + f, err := format.DetectFromPath(path) if err != nil { - return nil, err + return nil, fmt.Errorf("detect format: %w", err) } - return &Builder{ - model: mdl, - }, nil + + // Discover all shards if this is a sharded model + paths, err := f.DiscoverShards(path) + if err != nil { + return nil, fmt.Errorf("discover shards: %w", err) + } + + // Create model using the format abstraction + return fromFormat(f, paths) +} + +// FromPaths returns a *Builder that builds model artifacts from multiple file paths. 
+// All paths must be of the same format. Use this when you already have the list of files. +func FromPaths(paths []string) (*Builder, error) { + if len(paths) == 0 { + return nil, fmt.Errorf("at least one path is required") + } + + // Detect and verify format from all paths + f, err := format.DetectFromPaths(paths) + if err != nil { + return nil, fmt.Errorf("detect format: %w", err) + } + + // Create model using the format abstraction + return fromFormat(f, paths) } -// FromSafetensors returns a *Builder that builds model artifacts from safetensors files -func FromSafetensors(safetensorsPaths []string) (*Builder, error) { - mdl, err := safetensors.NewModel(safetensorsPaths) +// fromFormat creates a Builder using the unified format abstraction. +// This is the internal implementation that creates layers and config. +func fromFormat(f format.Format, paths []string) (*Builder, error) { + // Create layers from paths + layers := make([]oci.Layer, len(paths)) + diffIDs := make([]oci.Hash, len(paths)) + + mediaType := f.MediaType() + for i, path := range paths { + layer, err := partial.NewLayer(path, mediaType) + if err != nil { + return nil, fmt.Errorf("create layer from %q: %w", path, err) + } + diffID, err := layer.DiffID() + if err != nil { + return nil, fmt.Errorf("get diffID for %q: %w", path, err) + } + layers[i] = layer + diffIDs[i] = diffID + } + + // Extract config metadata using format-specific logic + config, err := f.ExtractConfig(paths) if err != nil { - return nil, err + return nil, fmt.Errorf("extract config: %w", err) } + + // Build the model + created := time.Now() + mdl := &partial.BaseModel{ + ModelConfigFile: types.ConfigFile{ + Config: config, + Descriptor: types.Descriptor{ + Created: &created, + }, + RootFS: oci.RootFS{ + Type: "rootfs", + DiffIDs: diffIDs, + }, + }, + LayerList: layers, + } + return &Builder{ model: mdl, }, nil diff --git a/pkg/distribution/builder/builder_test.go b/pkg/distribution/builder/builder_test.go index 
7b0d90e94..eac0a0804 100644 --- a/pkg/distribution/builder/builder_test.go +++ b/pkg/distribution/builder/builder_test.go @@ -15,7 +15,7 @@ import ( func TestBuilder(t *testing.T) { // Create a builder from a GGUF file - b, err := builder.FromGGUF(filepath.Join("..", "assets", "dummy.gguf")) + b, err := builder.FromPath(filepath.Join("..", "assets", "dummy.gguf")) if err != nil { t.Fatalf("Failed to create builder from GGUF: %v", err) } @@ -63,7 +63,7 @@ func TestBuilder(t *testing.T) { func TestWithMultimodalProjectorInvalidPath(t *testing.T) { // Create a builder from a GGUF file - b, err := builder.FromGGUF(filepath.Join("..", "assets", "dummy.gguf")) + b, err := builder.FromPath(filepath.Join("..", "assets", "dummy.gguf")) if err != nil { t.Fatalf("Failed to create builder from GGUF: %v", err) } @@ -77,7 +77,7 @@ func TestWithMultimodalProjectorInvalidPath(t *testing.T) { func TestWithMultimodalProjectorChaining(t *testing.T) { // Create a builder from a GGUF file - b, err := builder.FromGGUF(filepath.Join("..", "assets", "dummy.gguf")) + b, err := builder.FromPath(filepath.Join("..", "assets", "dummy.gguf")) if err != nil { t.Fatalf("Failed to create builder from GGUF: %v", err) } @@ -147,7 +147,7 @@ func TestWithMultimodalProjectorChaining(t *testing.T) { func TestFromModel(t *testing.T) { // Step 1: Create an initial model from GGUF with context size 2048 - initialBuilder, err := builder.FromGGUF(filepath.Join("..", "assets", "dummy.gguf")) + initialBuilder, err := builder.FromPath(filepath.Join("..", "assets", "dummy.gguf")) if err != nil { t.Fatalf("Failed to create initial builder from GGUF: %v", err) } @@ -230,7 +230,7 @@ func TestFromModel(t *testing.T) { func TestFromModelWithAdditionalLayers(t *testing.T) { // Create an initial model from GGUF - initialBuilder, err := builder.FromGGUF(filepath.Join("..", "assets", "dummy.gguf")) + initialBuilder, err := builder.FromPath(filepath.Join("..", "assets", "dummy.gguf")) if err != nil { t.Fatalf("Failed to 
create initial builder from GGUF: %v", err) } diff --git a/pkg/distribution/distribution/bundle_test.go b/pkg/distribution/distribution/bundle_test.go index b7c51fe0f..cecb3eeed 100644 --- a/pkg/distribution/distribution/bundle_test.go +++ b/pkg/distribution/distribution/bundle_test.go @@ -6,7 +6,7 @@ import ( "path/filepath" "testing" - "github.com/docker/model-runner/pkg/distribution/internal/gguf" + "github.com/docker/model-runner/pkg/distribution/builder" "github.com/docker/model-runner/pkg/distribution/internal/mutate" "github.com/docker/model-runner/pkg/distribution/internal/partial" "github.com/docker/model-runner/pkg/distribution/types" @@ -23,10 +23,11 @@ func TestBundle(t *testing.T) { } // Load dummy model from assets directory - mdl, err := gguf.NewModel(filepath.Join("..", "assets", "dummy.gguf")) + b, err := builder.FromPath(filepath.Join("..", "assets", "dummy.gguf")) if err != nil { t.Fatalf("Failed to create model: %v", err) } + mdl := b.Model() singleGGUFID, err := mdl.ID() if err != nil { t.Fatalf("Failed to get model ID: %v", err) @@ -64,10 +65,11 @@ func TestBundle(t *testing.T) { } // Load sharded dummy model from asset directory - shardedMdl, err := gguf.NewModel(filepath.Join("..", "assets", "dummy-00001-of-00002.gguf")) + shardedB, err := builder.FromPath(filepath.Join("..", "assets", "dummy-00001-of-00002.gguf")) if err != nil { t.Fatalf("Failed to create model: %v", err) } + shardedMdl := shardedB.Model() shardedGGUFID, err := shardedMdl.ID() if err != nil { t.Fatalf("Failed to get model ID: %v", err) diff --git a/pkg/distribution/distribution/client.go b/pkg/distribution/distribution/client.go index 6fed2157e..ae8aea2a0 100644 --- a/pkg/distribution/distribution/client.go +++ b/pkg/distribution/distribution/client.go @@ -149,7 +149,7 @@ func NewClient(opts ...Option) (*Client, error) { } // normalizeModelName adds the default organization prefix (ai/) and tag (:latest) if missing. 
-// It also converts Hugging Face model names to lowercase and resolves IDs to full IDs. +// It also resolves IDs to full IDs. // This is a private method used internally by the Client. func (c *Client) normalizeModelName(model string) string { const ( @@ -158,8 +158,6 @@ func (c *Client) normalizeModelName(model string) string { ) model = strings.TrimSpace(model) - - // If the model is empty, return as-is if model == "" { return model } @@ -169,44 +167,37 @@ func (c *Client) normalizeModelName(model string) string { if fullID := c.resolveID(model); fullID != "" { return fullID } - // If not found, return as-is return model } - // Normalize HuggingFace model names - if strings.HasPrefix(model, "hf.co/") { - // Replace hf.co with huggingface.co to avoid losing the Authorization header on redirect. - // Lowercase for OCI compatibility (repository names must be lowercase) - model = "huggingface.co" + strings.ToLower(strings.TrimPrefix(model, "hf.co")) - } + // Split name vs tag, where ':' is a tag separator only if it's after the last '/' + lastSlash := strings.LastIndex(model, "/") + lastColon := strings.LastIndex(model, ":") + + name := model + tag := defaultTag + hasTag := lastColon > lastSlash - // Check if model contains a registry (domain with dot before first slash) - firstSlash := strings.Index(model, "/") - if firstSlash > 0 && strings.Contains(model[:firstSlash], ".") { - // Has a registry, just ensure tag - // Check for tag separator after the last "/" (to avoid matching port like :5000) - lastSlash := strings.LastIndex(model, "/") - afterLastSlash := model[lastSlash+1:] - if !strings.Contains(afterLastSlash, ":") { - return model + ":" + defaultTag + if hasTag { + name = model[:lastColon] + // Preserve tag as-is; if empty, fall back to defaultTag + if t := model[lastColon+1:]; t != "" { + tag = t } - return model } - // Split by colon to check for tag - parts := strings.SplitN(model, ":", 2) - nameWithOrg := parts[0] - tag := defaultTag - if len(parts) == 
2 && parts[1] != "" { - tag = parts[1] - } + // If name has no registry (domain with dot before first slash), apply default org if missing slash + firstSlash := strings.Index(name, "/") + hasRegistry := firstSlash > 0 && strings.Contains(name[:firstSlash], ".") - // If name doesn't contain a slash, add the default org - if !strings.Contains(nameWithOrg, "/") { - nameWithOrg = defaultOrg + "/" + nameWithOrg + if !hasRegistry && !strings.Contains(name, "/") { + name = defaultOrg + "/" + name } - return nameWithOrg + ":" + tag + // Lowercase ONLY the name part (registry/org/repo). Tag stays unchanged. + name = strings.ToLower(name) + + return name + ":" + tag } // looksLikeID returns true for short & long hex IDs (12 or 64 chars) @@ -282,25 +273,30 @@ func (c *Client) PullModel(ctx context.Context, reference string, progressWriter reference = c.normalizeModelName(reference) c.log.Infoln("Starting model pull:", utils.SanitizeForLog(reference)) - // Use the client's registry, or create a temporary one if bearer token is provided - registryClient := c.registry + // Handle bearer token for registry authentication var token string if len(bearerToken) > 0 && bearerToken[0] != "" { token = bearerToken[0] + } + + // HuggingFace references always use native pull (download raw files from HF Hub) + if isHuggingFaceReference(originalReference) { + c.log.Infoln("Using native HuggingFace pull for:", utils.SanitizeForLog(reference)) + // Pass original reference to preserve case-sensitivity for HuggingFace API + return c.pullNativeHuggingFace(ctx, originalReference, progressWriter, token) + } + + // For non-HF references, use OCI registry + registryClient := c.registry + if token != "" { // Create a temporary registry client with bearer token authentication auth := authn.NewBearer(token) registryClient = registry.FromClient(c.registry, registry.WithAuth(auth)) } - // First, fetch the remote model to get the manifest + // Fetch the remote model to get the manifest remoteModel, err := 
registryClient.Model(ctx, reference) if err != nil { - // Check if this is a HuggingFace reference and the error indicates no OCI manifest - if isHuggingFaceReference(reference) && isNotOCIError(err) { - c.log.Infoln("No OCI manifest found, attempting native HuggingFace pull") - // Pass original reference to preserve case-sensitivity for HuggingFace API - return c.pullNativeHuggingFace(ctx, originalReference, progressWriter, token) - } // Check if the error should be converted to registry.ErrModelNotFound for API compatibility // If the error already matches ErrModelNotFound, return it directly to preserve errors.Is compatibility if errors.Is(err, registry.ErrModelNotFound) { @@ -679,68 +675,16 @@ func checkCompat(image types.ModelArtifact, log *logrus.Entry, reference string, // isHuggingFaceReference checks if a reference is a HuggingFace model reference func isHuggingFaceReference(reference string) bool { - return strings.HasPrefix(reference, "huggingface.co/") + return strings.HasPrefix(reference, "huggingface.co/") || + strings.HasPrefix(reference, "hf.co/") } -// isNotOCIError checks if the error indicates the model is not OCI-formatted -// This happens when the HuggingFace repository doesn't have an OCI manifest -func isNotOCIError(err error) bool { - if err == nil { - return false - } - - // Check for registry errors indicating no manifest - var regErr *registry.Error - if errors.As(err, &regErr) { - if regErr.Code == "MANIFEST_UNKNOWN" || regErr.Code == "NAME_UNKNOWN" { - return true - } - } - - // Note: We intentionally don't treat ErrInvalidReference as "not OCI" - that's a format error - // that should be reported to the user, not interpreted as a native HF model. - // The model name is lowercased during normalization to ensure OCI compatibility. 
- - // Also check error message for common patterns - errStr := err.Error() - errStrLower := strings.ToLower(errStr) - return strings.Contains(errStr, "MANIFEST_UNKNOWN") || - strings.Contains(errStr, "NAME_UNKNOWN") || - strings.Contains(errStrLower, "manifest unknown") || - strings.Contains(errStrLower, "name unknown") || - // HuggingFace returns this error for non-GGUF repositories - strings.Contains(errStr, "Repository is not GGUF") || - strings.Contains(errStr, "not compatible with llama.cpp") || - // Additional patterns that might indicate non-OCI format from registry - strings.Contains(errStrLower, "blob unknown") || - strings.Contains(errStrLower, "tag unknown") || - // Containerd resolver specific error patterns - strings.Contains(errStrLower, "not found") || - strings.Contains(errStrLower, "status 404") || - strings.Contains(errStrLower, "status code 404") || - strings.Contains(errStrLower, "response status code") || - strings.Contains(errStrLower, "no such host") || - strings.Contains(errStrLower, "connection refused") || - // Additional OCI-related patterns - strings.Contains(errStrLower, "no manifest found") || - strings.Contains(errStrLower, "no image found") || - strings.Contains(errStrLower, "image not found") || - strings.Contains(errStrLower, "artifact not found") || - // Additional HuggingFace-specific error patterns - strings.Contains(errStrLower, "repository not found") || - strings.Contains(errStrLower, "resource not found") || - strings.Contains(errStrLower, "endpoint not found") || - strings.Contains(errStrLower, "model not found") || - // More specific HuggingFace error patterns - strings.Contains(errStr, "401") || - strings.Contains(errStr, "403") -} - -// parseHFReference extracts repo and revision from a HF reference -// e.g., "huggingface.co/org/model:revision" -> ("org/model", "revision") -// e.g., "hf.co/org/model:latest" -> ("org/model", "main") -// Note: This preserves the original case of the repo name for HuggingFace API 
compatibility -func parseHFReference(reference string) (repo, revision string) { +// parseHFReference extracts repo, revision, and tag from a HF reference +// e.g., "huggingface.co/org/model:revision" -> ("org/model", "main", "revision") +// e.g., "hf.co/org/model:latest" -> ("org/model", "main", "latest") +// e.g., "hf.co/org/model:Q4_K_M" -> ("org/model", "main", "Q4_K_M") +// The tag is used for GGUF quantization selection, while revision is always "main" for HuggingFace +func parseHFReference(reference string) (repo, revision, tag string) { // Remove registry prefix (handle both hf.co and huggingface.co) ref := strings.TrimPrefix(reference, "huggingface.co/") ref = strings.TrimPrefix(ref, "hf.co/") @@ -749,19 +693,24 @@ func parseHFReference(reference string) (repo, revision string) { parts := strings.SplitN(ref, ":", 2) repo = parts[0] - revision = "main" - if len(parts) == 2 && parts[1] != "" && parts[1] != "latest" { - revision = parts[1] + // Default tag is "latest" + tag = "latest" + if len(parts) == 2 && parts[1] != "" { + tag = parts[1] } - return repo, revision + // Revision is always "main" for HuggingFace repos + // (the tag is used for quantization selection, not git revision) + revision = "main" + + return repo, revision, tag } // pullNativeHuggingFace pulls a native HuggingFace repository (non-OCI format) // This is used when the model is stored as raw files (safetensors) on HuggingFace Hub func (c *Client) pullNativeHuggingFace(ctx context.Context, reference string, progressWriter io.Writer, token string) error { - repo, revision := parseHFReference(reference) - c.log.Infof("Pulling native HuggingFace model: repo=%s, revision=%s", utils.SanitizeForLog(repo), utils.SanitizeForLog(revision)) + repo, revision, tag := parseHFReference(reference) + c.log.Infof("Pulling native HuggingFace model: repo=%s, revision=%s, tag=%s", utils.SanitizeForLog(repo), utils.SanitizeForLog(revision), utils.SanitizeForLog(tag)) // Create HuggingFace client hfOpts := 
[]huggingface.ClientOption{ @@ -780,7 +729,8 @@ func (c *Client) pullNativeHuggingFace(ctx context.Context, reference string, pr defer os.RemoveAll(tempDir) // Build model from HuggingFace repository - model, err := huggingface.BuildModel(ctx, hfClient, repo, revision, tempDir, progressWriter) + // The tag is used for GGUF quantization selection (e.g., "Q4_K_M", "Q8_0") + model, err := huggingface.BuildModel(ctx, hfClient, repo, revision, tag, tempDir, progressWriter) if err != nil { // Convert HuggingFace errors to registry errors for consistent handling var authErr *huggingface.AuthError @@ -797,9 +747,8 @@ func (c *Client) pullNativeHuggingFace(ctx context.Context, reference string, pr return fmt.Errorf("build model from HuggingFace: %w", err) } - // Write model to store - // Lowercase the reference for storage since OCI tags don't allow uppercase - storageTag := strings.ToLower(reference) + // Write model to store with normalized tag + storageTag := c.normalizeModelName(reference) c.log.Infof("Writing model to store with tag: %s", utils.SanitizeForLog(storageTag)) if err := c.store.Write(model, []string{storageTag}, progressWriter); err != nil { if writeErr := progress.WriteError(progressWriter, fmt.Sprintf("Error: %s", err.Error())); writeErr != nil { diff --git a/pkg/distribution/distribution/delete_test.go b/pkg/distribution/distribution/delete_test.go index 5878e6dc0..1585409b4 100644 --- a/pkg/distribution/distribution/delete_test.go +++ b/pkg/distribution/distribution/delete_test.go @@ -5,7 +5,7 @@ import ( "os" "testing" - "github.com/docker/model-runner/pkg/distribution/internal/gguf" + "github.com/docker/model-runner/pkg/distribution/builder" ) func TestDeleteModel(t *testing.T) { @@ -23,10 +23,11 @@ func TestDeleteModel(t *testing.T) { } // Use the dummy.gguf file from assets directory - mdl, err := gguf.NewModel(testGGUFFile) + b, err := builder.FromPath(testGGUFFile) if err != nil { t.Fatalf("Failed to create model: %v", err) } + mdl := b.Model() 
id, err := mdl.ID() if err != nil { t.Fatalf("Failed to get model ID: %v", err) diff --git a/pkg/distribution/distribution/ecr_test.go b/pkg/distribution/distribution/ecr_test.go index afb72a4fd..30d686392 100644 --- a/pkg/distribution/distribution/ecr_test.go +++ b/pkg/distribution/distribution/ecr_test.go @@ -5,7 +5,7 @@ import ( "os" "testing" - "github.com/docker/model-runner/pkg/distribution/internal/gguf" + "github.com/docker/model-runner/pkg/distribution/builder" ) func TestECRIntegration(t *testing.T) { @@ -41,10 +41,11 @@ func TestECRIntegration(t *testing.T) { } t.Run("Push", func(t *testing.T) { - mdl, err := gguf.NewModel(testGGUFFile) + b, err := builder.FromPath(testGGUFFile) if err != nil { t.Fatalf("Failed to create model: %v", err) } + mdl := b.Model() if err := client.store.Write(mdl, []string{ecrTag}, nil); err != nil { t.Fatalf("Failed to write model to store: %v", err) } diff --git a/pkg/distribution/distribution/gar_test.go b/pkg/distribution/distribution/gar_test.go index 2f3be0482..05f36482d 100644 --- a/pkg/distribution/distribution/gar_test.go +++ b/pkg/distribution/distribution/gar_test.go @@ -5,7 +5,7 @@ import ( "os" "testing" - "github.com/docker/model-runner/pkg/distribution/internal/gguf" + "github.com/docker/model-runner/pkg/distribution/builder" ) func TestGARIntegration(t *testing.T) { @@ -42,10 +42,11 @@ func TestGARIntegration(t *testing.T) { // Test push to GAR t.Run("Push", func(t *testing.T) { - mdl, err := gguf.NewModel(testGGUFFile) + b, err := builder.FromPath(testGGUFFile) if err != nil { t.Fatalf("Failed to create model: %v", err) } + mdl := b.Model() if err := client.store.Write(mdl, []string{garTag}, nil); err != nil { t.Fatalf("Failed to write model to store: %v", err) } diff --git a/pkg/distribution/distribution/load_test.go b/pkg/distribution/distribution/load_test.go index 4d804bb56..6c22fb927 100644 --- a/pkg/distribution/distribution/load_test.go +++ b/pkg/distribution/distribution/load_test.go @@ -36,7 +36,7 @@ 
func TestLoadModel(t *testing.T) { id, err = client.LoadModel(pr, nil) done <- err }() - bldr, err := builder.FromGGUF(testGGUFFile) + bldr, err := builder.FromPath(testGGUFFile) if err != nil { t.Fatalf("Failed to create builder: %v", err) } diff --git a/pkg/distribution/distribution/normalize_test.go b/pkg/distribution/distribution/normalize_test.go index a149c151c..11754708d 100644 --- a/pkg/distribution/distribution/normalize_test.go +++ b/pkg/distribution/distribution/normalize_test.go @@ -2,14 +2,12 @@ package distribution import ( "context" - "errors" "io" "path/filepath" "strings" "testing" "github.com/docker/model-runner/pkg/distribution/builder" - "github.com/docker/model-runner/pkg/distribution/registry" "github.com/docker/model-runner/pkg/distribution/tarball" "github.com/sirupsen/logrus" ) @@ -68,28 +66,6 @@ func TestNormalizeModelName(t *testing.T) { expected: "registry.example.com/myorg/model:v1", }, - // HuggingFace cases (lowercased for OCI reference compatibility) - { - name: "huggingface short form lowercase", - input: "hf.co/model", - expected: "huggingface.co/model:latest", - }, - { - name: "huggingface short form uppercase", - input: "hf.co/Model", - expected: "huggingface.co/model:latest", // lowercased for OCI compatibility - }, - { - name: "huggingface short form with org", - input: "hf.co/MyOrg/MyModel", - expected: "huggingface.co/myorg/mymodel:latest", // lowercased for OCI compatibility - }, - { - name: "huggingface with tag", - input: "hf.co/model:v1", - expected: "huggingface.co/model:v1", - }, - // ID cases - without store lookup (IDs not in store) { name: "short ID (12 hex chars) not in store", @@ -146,7 +122,7 @@ func TestNormalizeModelName(t *testing.T) { { name: "name with uppercase (not huggingface)", input: "MyModel", - expected: "ai/MyModel:latest", + expected: "ai/mymodel:latest", }, } @@ -367,7 +343,8 @@ func TestIsHuggingFaceReference(t *testing.T) { {"huggingface.co without tag", "huggingface.co/org/model", true}, {"not 
huggingface", "registry.example.com/model:latest", false}, {"docker hub", "ai/gemma3:latest", false}, - {"hf.co prefix (not normalized)", "hf.co/org/model", false}, // This is the un-normalized form + {"hf.co prefix (short form)", "hf.co/org/model", true}, // Short form is also recognized + {"hf.co with quantization", "hf.co/bartowski/Llama-3.2-1B-Instruct-GGUF:Q4_K_M", true}, {"empty", "", false}, } @@ -387,81 +364,63 @@ func TestParseHFReference(t *testing.T) { input string expectedRepo string expectedRev string + expectedTag string }{ { name: "basic with latest tag", input: "huggingface.co/org/model:latest", expectedRepo: "org/model", - expectedRev: "main", // latest maps to main + expectedRev: "main", // revision is always main + expectedTag: "latest", }, { - name: "with explicit revision", - input: "huggingface.co/org/model:v1.0", + name: "with quantization tag", + input: "huggingface.co/org/model:Q4_K_M", expectedRepo: "org/model", - expectedRev: "v1.0", + expectedRev: "main", + expectedTag: "Q4_K_M", }, { name: "without tag", input: "huggingface.co/org/model", expectedRepo: "org/model", expectedRev: "main", + expectedTag: "latest", }, { name: "with commit hash as tag", input: "huggingface.co/HuggingFaceTB/SmolLM2-135M-Instruct:abc123", expectedRepo: "HuggingFaceTB/SmolLM2-135M-Instruct", - expectedRev: "abc123", + expectedRev: "main", + expectedTag: "abc123", }, { name: "single name (no org)", input: "huggingface.co/model:latest", expectedRepo: "model", expectedRev: "main", + expectedTag: "latest", + }, + { + name: "hf.co prefix with quantization", + input: "hf.co/bartowski/Llama-3.2-1B-Instruct-GGUF:Q8_0", + expectedRepo: "bartowski/Llama-3.2-1B-Instruct-GGUF", + expectedRev: "main", + expectedTag: "Q8_0", }, } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { - repo, rev := parseHFReference(tt.input) + repo, rev, tag := parseHFReference(tt.input) if repo != tt.expectedRepo { t.Errorf("parseHFReference(%q) repo = %q, want %q", tt.input, repo, 
tt.expectedRepo) } if rev != tt.expectedRev { t.Errorf("parseHFReference(%q) rev = %q, want %q", tt.input, rev, tt.expectedRev) } - }) - } -} - -func TestIsNotOCIError(t *testing.T) { - tests := []struct { - name string - err error - expected bool - }{ - {"nil error", nil, false}, - {"generic error", errors.New("some error"), false}, - {"manifest unknown in message", errors.New("MANIFEST_UNKNOWN: manifest not found"), true}, - {"name unknown in message", errors.New("NAME_UNKNOWN: repository not found"), true}, - {"manifest unknown lowercase", errors.New("manifest unknown"), true}, - {"unrelated error", errors.New("network timeout"), false}, - {"HuggingFace not GGUF error", errors.New("Repository is not GGUF or is not compatible with llama.cpp"), true}, - {"HuggingFace llama.cpp incompatible", errors.New("not compatible with llama.cpp"), true}, - // registry.Error typed error cases - {"registry error MANIFEST_UNKNOWN", &registry.Error{Code: "MANIFEST_UNKNOWN"}, true}, - {"registry error NAME_UNKNOWN", &registry.Error{Code: "NAME_UNKNOWN"}, true}, - {"registry error other code", &registry.Error{Code: "UNAUTHORIZED"}, false}, - // ErrInvalidReference is NOT treated as "not OCI" - it's a format error - // that should be reported to the user. Model names are lowercased during - // normalization to ensure OCI compatibility. 
- {"invalid reference error", registry.ErrInvalidReference, false}, - } - - for _, tt := range tests { - t.Run(tt.name, func(t *testing.T) { - result := isNotOCIError(tt.err) - if result != tt.expected { - t.Errorf("isNotOCIError(%v) = %v, want %v", tt.err, result, tt.expected) + if tag != tt.expectedTag { + t.Errorf("parseHFReference(%q) tag = %q, want %q", tt.input, tag, tt.expectedTag) } }) } @@ -486,7 +445,7 @@ func loadTestModel(t *testing.T, client *Client, ggufPath string) string { done <- err }() - bldr, err := builder.FromGGUF(ggufPath) + bldr, err := builder.FromPath(ggufPath) if err != nil { t.Fatalf("Failed to create builder from GGUF: %v", err) } diff --git a/pkg/distribution/files/classify.go b/pkg/distribution/files/classify.go new file mode 100644 index 000000000..9e4b24b66 --- /dev/null +++ b/pkg/distribution/files/classify.go @@ -0,0 +1,111 @@ +// Package files provides utilities for classifying and working with model files. +// This package consolidates file classification logic used across the distribution system. +package files + +import ( + "path/filepath" + "strings" +) + +// FileType represents the type of file for model packaging +type FileType int + +const ( + // FileTypeUnknown is an unrecognized file type + FileTypeUnknown FileType = iota + // FileTypeGGUF is a GGUF model weight file + FileTypeGGUF + // FileTypeSafetensors is a safetensors model weight file + FileTypeSafetensors + // FileTypeConfig is a configuration file (json, txt, etc.) 
+ FileTypeConfig + // FileTypeLicense is a license file + FileTypeLicense + // FileTypeChatTemplate is a Jinja chat template file + FileTypeChatTemplate +) + +// String returns a string representation of the file type +func (ft FileType) String() string { + switch ft { + case FileTypeGGUF: + return "gguf" + case FileTypeSafetensors: + return "safetensors" + case FileTypeConfig: + return "config" + case FileTypeLicense: + return "license" + case FileTypeChatTemplate: + return "chat_template" + case FileTypeUnknown: + return "unknown" + } + return "unknown" +} + +var ( + // ConfigExtensions defines the file extensions that should be treated as config files + ConfigExtensions = []string{".md", ".txt", ".json", ".vocab"} + + // SpecialConfigFiles are specific filenames treated as config files + SpecialConfigFiles = []string{"tokenizer.model"} + + // ChatTemplateExtensions defines extensions for chat template files + ChatTemplateExtensions = []string{".jinja"} + + // LicensePatterns defines patterns for license files (case-insensitive) + LicensePatterns = []string{"license", "licence", "copying", "notice"} +) + +// Classify determines the file type based on the filename. +// It examines the file extension and name patterns to classify the file. 
+func Classify(path string) FileType { + filename := filepath.Base(path) + lower := strings.ToLower(filename) + + // Check for GGUF files first (highest priority for model files) + if strings.HasSuffix(lower, ".gguf") { + return FileTypeGGUF + } + + // Check for safetensors files + if strings.HasSuffix(lower, ".safetensors") { + return FileTypeSafetensors + } + + // Check for chat template files (before generic config check) + for _, ext := range ChatTemplateExtensions { + if strings.HasSuffix(lower, ext) { + return FileTypeChatTemplate + } + } + + // Also check for files containing "chat_template" in the name + if strings.Contains(lower, "chat_template") { + return FileTypeChatTemplate + } + + // Check for license files + for _, pattern := range LicensePatterns { + if strings.Contains(lower, pattern) { + return FileTypeLicense + } + } + + // Check for config file extensions + for _, ext := range ConfigExtensions { + if strings.HasSuffix(lower, ext) { + return FileTypeConfig + } + } + + // Check for special config files + for _, special := range SpecialConfigFiles { + if strings.EqualFold(lower, special) { + return FileTypeConfig + } + } + + return FileTypeUnknown +} diff --git a/pkg/distribution/files/classify_test.go b/pkg/distribution/files/classify_test.go new file mode 100644 index 000000000..0c955560c --- /dev/null +++ b/pkg/distribution/files/classify_test.go @@ -0,0 +1,86 @@ +package files + +import ( + "testing" +) + +func TestClassify(t *testing.T) { + tests := []struct { + name string + filename string + want FileType + }{ + // GGUF files + {"gguf file", "model.gguf", FileTypeGGUF}, + {"gguf uppercase", "MODEL.GGUF", FileTypeGGUF}, + {"gguf with path", "/path/to/model.gguf", FileTypeGGUF}, + {"gguf shard", "model-00001-of-00015.gguf", FileTypeGGUF}, + + // Safetensors files + {"safetensors file", "model.safetensors", FileTypeSafetensors}, + {"safetensors uppercase", "MODEL.SAFETENSORS", FileTypeSafetensors}, + {"safetensors with path", 
"/path/to/model.safetensors", FileTypeSafetensors}, + {"safetensors shard", "model-00001-of-00003.safetensors", FileTypeSafetensors}, + + // Chat template files + {"jinja template", "template.jinja", FileTypeChatTemplate}, + {"jinja uppercase", "TEMPLATE.JINJA", FileTypeChatTemplate}, + {"chat_template file", "chat_template.txt", FileTypeChatTemplate}, + {"chat_template json", "chat_template.json", FileTypeChatTemplate}, + + // Config files + {"json config", "config.json", FileTypeConfig}, + {"txt config", "readme.txt", FileTypeConfig}, + {"md config", "README.md", FileTypeConfig}, + {"vocab file", "vocab.vocab", FileTypeConfig}, + {"tokenizer model", "tokenizer.model", FileTypeConfig}, + {"tokenizer model uppercase", "TOKENIZER.MODEL", FileTypeConfig}, + {"generation config", "generation_config.json", FileTypeConfig}, + {"tokenizer config", "tokenizer_config.json", FileTypeConfig}, + + // License files + {"license file", "LICENSE", FileTypeLicense}, + {"license md", "LICENSE.md", FileTypeLicense}, + {"license txt", "license.txt", FileTypeLicense}, + {"licence uk", "LICENCE", FileTypeLicense}, + {"copying", "COPYING", FileTypeLicense}, + {"notice", "NOTICE", FileTypeLicense}, + + // Unknown files + {"unknown bin", "model.bin", FileTypeUnknown}, + {"unknown py", "script.py", FileTypeUnknown}, + {"unknown empty", "", FileTypeUnknown}, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + got := Classify(tt.filename) + if got != tt.want { + t.Errorf("Classify(%q) = %v, want %v", tt.filename, got, tt.want) + } + }) + } +} + +func TestFileTypeString(t *testing.T) { + tests := []struct { + ft FileType + want string + }{ + {FileTypeGGUF, "gguf"}, + {FileTypeSafetensors, "safetensors"}, + {FileTypeConfig, "config"}, + {FileTypeLicense, "license"}, + {FileTypeChatTemplate, "chat_template"}, + {FileTypeUnknown, "unknown"}, + } + + for _, tt := range tests { + t.Run(tt.want, func(t *testing.T) { + got := tt.ft.String() + if got != tt.want { + 
t.Errorf("FileType.String() = %q, want %q", got, tt.want) + } + }) + } +} diff --git a/pkg/distribution/format/format.go b/pkg/distribution/format/format.go new file mode 100644 index 000000000..6b4b5c40e --- /dev/null +++ b/pkg/distribution/format/format.go @@ -0,0 +1,95 @@ +// Package format provides a unified interface for handling different model formats. +// It uses the Strategy pattern to encapsulate format-specific behavior while providing +// a common interface for model creation and metadata extraction. +package format + +import ( + "fmt" + + "github.com/docker/model-runner/pkg/distribution/files" + "github.com/docker/model-runner/pkg/distribution/oci" + "github.com/docker/model-runner/pkg/distribution/types" + "github.com/docker/model-runner/pkg/internal/utils" +) + +// Format defines the interface for model format-specific operations. +// Implementations handle the differences between GGUF and Safetensors formats. +type Format interface { + // Name returns the format identifier (e.g., "gguf" or "safetensors") + Name() types.Format + + // MediaType returns the OCI media type for weight layers of this format + MediaType() oci.MediaType + + // DiscoverShards finds all shard files for a sharded model given a starting path. + // For single-file models, it returns a slice containing only the input path. + // Returns an error if shards are incomplete or cannot be found. + DiscoverShards(path string) ([]string, error) + + // ExtractConfig parses the weight files and extracts model configuration metadata. + // This includes parameters count, quantization type, architecture, etc. + ExtractConfig(paths []string) (types.Config, error) +} + +// registry holds all registered format implementations +var registry = make(map[types.Format]Format) + +// Register adds a format implementation to the global registry. +// This should be called in init() functions by format implementations. 
+func Register(f Format) { + registry[f.Name()] = f +} + +// Get returns the format implementation for the given format type. +// Returns an error if the format is not registered. +func Get(name types.Format) (Format, error) { + f, ok := registry[name] + if !ok { + return nil, fmt.Errorf("unknown format: %s", name) + } + return f, nil +} + +// DetectFromPath determines the model format based on file extension. +// Returns the appropriate Format implementation or an error if unrecognized. +func DetectFromPath(path string) (Format, error) { + ft := files.Classify(path) + + switch ft { + case files.FileTypeGGUF: + return Get(types.FormatGGUF) + case files.FileTypeSafetensors: + return Get(types.FormatSafetensors) + case files.FileTypeUnknown, files.FileTypeConfig, files.FileTypeLicense, files.FileTypeChatTemplate: + return nil, fmt.Errorf("unable to detect format from path: %s (file type: %s)", utils.SanitizeForLog(path), ft) + } + return nil, fmt.Errorf("unable to detect format from path: %s", utils.SanitizeForLog(path)) +} + +// DetectFromPaths determines the model format based on a list of file paths. +// All paths must be of the same format. Returns an error if formats are mixed. 
+func DetectFromPaths(paths []string) (Format, error) { + if len(paths) == 0 { + return nil, fmt.Errorf("no paths provided") + } + + // Detect format from first path + format, err := DetectFromPath(paths[0]) + if err != nil { + return nil, err + } + + // Verify all paths are the same format + expectedName := format.Name() + for _, p := range paths[1:] { + f, err := DetectFromPath(p) + if err != nil { + return nil, err + } + if f.Name() != expectedName { + return nil, fmt.Errorf("mixed formats detected: %s and %s", expectedName, f.Name()) + } + } + + return format, nil +} diff --git a/pkg/distribution/format/format_test.go b/pkg/distribution/format/format_test.go new file mode 100644 index 000000000..c248f840b --- /dev/null +++ b/pkg/distribution/format/format_test.go @@ -0,0 +1,240 @@ +package format + +import ( + "testing" + + "github.com/docker/model-runner/pkg/distribution/types" +) + +func TestGetFormat(t *testing.T) { + tests := []struct { + name string + format types.Format + wantError bool + }{ + {"get gguf", types.FormatGGUF, false}, + {"get safetensors", types.FormatSafetensors, false}, + {"get unknown", types.Format("unknown"), true}, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + f, err := Get(tt.format) + if tt.wantError { + if err == nil { + t.Error("Expected error, got nil") + } + return + } + if err != nil { + t.Errorf("Unexpected error: %v", err) + return + } + if f.Name() != tt.format { + t.Errorf("Got format %s, want %s", f.Name(), tt.format) + } + }) + } +} + +func TestDetectFromPath(t *testing.T) { + tests := []struct { + name string + path string + wantFormat types.Format + wantError bool + }{ + {"gguf file", "model.gguf", types.FormatGGUF, false}, + {"gguf uppercase", "MODEL.GGUF", types.FormatGGUF, false}, + {"gguf with path", "/path/to/model.gguf", types.FormatGGUF, false}, + {"gguf shard", "model-00001-of-00015.gguf", types.FormatGGUF, false}, + + {"safetensors file", "model.safetensors", types.FormatSafetensors, 
false}, + {"safetensors uppercase", "MODEL.SAFETENSORS", types.FormatSafetensors, false}, + {"safetensors with path", "/path/to/model.safetensors", types.FormatSafetensors, false}, + {"safetensors shard", "model-00001-of-00003.safetensors", types.FormatSafetensors, false}, + + {"unknown extension", "model.bin", types.Format(""), true}, + {"config file", "config.json", types.Format(""), true}, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + f, err := DetectFromPath(tt.path) + if tt.wantError { + if err == nil { + t.Error("Expected error, got nil") + } + return + } + if err != nil { + t.Errorf("Unexpected error: %v", err) + return + } + if f.Name() != tt.wantFormat { + t.Errorf("Got format %s, want %s", f.Name(), tt.wantFormat) + } + }) + } +} + +func TestDetectFromPaths(t *testing.T) { + tests := []struct { + name string + paths []string + wantFormat types.Format + wantError bool + }{ + { + name: "single gguf", + paths: []string{"model.gguf"}, + wantFormat: types.FormatGGUF, + wantError: false, + }, + { + name: "multiple gguf", + paths: []string{"model-00001.gguf", "model-00002.gguf"}, + wantFormat: types.FormatGGUF, + wantError: false, + }, + { + name: "single safetensors", + paths: []string{"model.safetensors"}, + wantFormat: types.FormatSafetensors, + wantError: false, + }, + { + name: "multiple safetensors", + paths: []string{"model-00001-of-00002.safetensors", "model-00002-of-00002.safetensors"}, + wantFormat: types.FormatSafetensors, + wantError: false, + }, + { + name: "mixed formats", + paths: []string{"model.gguf", "model.safetensors"}, + wantError: true, + }, + { + name: "empty paths", + paths: []string{}, + wantError: true, + }, + { + name: "unknown file", + paths: []string{"config.json"}, + wantError: true, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + f, err := DetectFromPaths(tt.paths) + if tt.wantError { + if err == nil { + t.Error("Expected error, got nil") + } + return + } + if err != nil 
{ + t.Errorf("Unexpected error: %v", err) + return + } + if f.Name() != tt.wantFormat { + t.Errorf("Got format %s, want %s", f.Name(), tt.wantFormat) + } + }) + } +} + +func TestGGUFFormat_Name(t *testing.T) { + f := &GGUFFormat{} + if f.Name() != types.FormatGGUF { + t.Errorf("Expected %s, got %s", types.FormatGGUF, f.Name()) + } +} + +func TestGGUFFormat_MediaType(t *testing.T) { + f := &GGUFFormat{} + if f.MediaType() != types.MediaTypeGGUF { + t.Errorf("Expected %s, got %s", types.MediaTypeGGUF, f.MediaType()) + } +} + +func TestSafetensorsFormat_Name(t *testing.T) { + f := &SafetensorsFormat{} + if f.Name() != types.FormatSafetensors { + t.Errorf("Expected %s, got %s", types.FormatSafetensors, f.Name()) + } +} + +func TestSafetensorsFormat_MediaType(t *testing.T) { + f := &SafetensorsFormat{} + if f.MediaType() != types.MediaTypeSafetensors { + t.Errorf("Expected %s, got %s", types.MediaTypeSafetensors, f.MediaType()) + } +} + +func TestNormalizeUnitString(t *testing.T) { + tests := []struct { + input string + want string + }{ + {"16.78 M", "16.78M"}, + {"256.35 MiB", "256.35MiB"}, + {"409M", "409M"}, + {"1.5 B", "1.5B"}, + {" 100 KB ", "100KB"}, + {"", ""}, + } + + for _, tt := range tests { + t.Run(tt.input, func(t *testing.T) { + got := normalizeUnitString(tt.input) + if got != tt.want { + t.Errorf("normalizeUnitString(%q) = %q, want %q", tt.input, got, tt.want) + } + }) + } +} + +func TestFormatParameters(t *testing.T) { + tests := []struct { + params int64 + want string + }{ + {1000, "1.00K"}, + {1000000, "1.00M"}, + {1000000000, "1.00B"}, + {7000000000, "7.00B"}, + } + + for _, tt := range tests { + t.Run(tt.want, func(t *testing.T) { + got := formatParameters(tt.params) + if got != tt.want { + t.Errorf("formatParameters(%d) = %q, want %q", tt.params, got, tt.want) + } + }) + } +} + +func TestFormatSize(t *testing.T) { + tests := []struct { + bytes int64 + want string + }{ + {1000, "1.00kB"}, + {1000000, "1.00MB"}, + {1000000000, "1.00GB"}, + 
{5000000000, "5.00GB"}, + } + + for _, tt := range tests { + t.Run(tt.want, func(t *testing.T) { + got := formatSize(tt.bytes) + if got != tt.want { + t.Errorf("formatSize(%d) = %q, want %q", tt.bytes, got, tt.want) + } + }) + } +} diff --git a/pkg/distribution/internal/gguf/metadata.go b/pkg/distribution/format/gguf.go similarity index 53% rename from pkg/distribution/internal/gguf/metadata.go rename to pkg/distribution/format/gguf.go index c9958a751..2d8260bb2 100644 --- a/pkg/distribution/internal/gguf/metadata.go +++ b/pkg/distribution/format/gguf.go @@ -1,12 +1,85 @@ -package gguf +package format import ( "fmt" + "regexp" "strings" + "github.com/docker/model-runner/pkg/distribution/oci" + "github.com/docker/model-runner/pkg/distribution/types" parser "github.com/gpustack/gguf-parser-go" ) +// GGUFFormat implements the Format interface for GGUF model files. +type GGUFFormat struct{} + +// init registers the GGUF format implementation. +func init() { + Register(&GGUFFormat{}) +} + +// Name returns the format identifier for GGUF. +func (g *GGUFFormat) Name() types.Format { + return types.FormatGGUF +} + +// MediaType returns the OCI media type for GGUF layers. +func (g *GGUFFormat) MediaType() oci.MediaType { + return types.MediaTypeGGUF +} + +// DiscoverShards finds all GGUF shard files for a sharded model. +// GGUF shards follow the pattern: -00001-of-00015.gguf +// For single-file models, returns a slice containing only the input path. +func (g *GGUFFormat) DiscoverShards(path string) ([]string, error) { + // Use the external GGUF parser's shard discovery + shards := parser.CompleteShardGGUFFilename(path) + if len(shards) == 0 { + // Single file, not sharded + return []string{path}, nil + } + return shards, nil +} + +// ExtractConfig parses GGUF file(s) and extracts model configuration metadata. 
+func (g *GGUFFormat) ExtractConfig(paths []string) (types.Config, error) { + if len(paths) == 0 { + return types.Config{Format: types.FormatGGUF}, nil + } + + // Parse the first shard/file to get metadata + gguf, err := parser.ParseGGUFFile(paths[0]) + if err != nil { + // Return empty config if parsing fails, continue without metadata + return types.Config{Format: types.FormatGGUF}, nil + } + + return types.Config{ + Format: types.FormatGGUF, + Parameters: normalizeUnitString(gguf.Metadata().Parameters.String()), + Architecture: strings.TrimSpace(gguf.Metadata().Architecture), + Quantization: strings.TrimSpace(gguf.Metadata().FileType.String()), + Size: normalizeUnitString(gguf.Metadata().Size.String()), + GGUF: extractGGUFMetadata(&gguf.Header), + }, nil +} + +var ( + // spaceBeforeUnitRegex matches one or more spaces between a valid number and a letter (unit) + // Used to remove spaces between numbers and units (e.g., "16.78 M" -> "16.78M") + spaceBeforeUnitRegex = regexp.MustCompile(`([0-9]+(?:\.[0-9]+)?)\s+([A-Za-z]+)`) +) + +// normalizeUnitString removes spaces between numbers and units for consistent formatting +// Examples: "16.78 M" -> "16.78M", "256.35 MiB" -> "256.35MiB", "409M" -> "409M" +func normalizeUnitString(s string) string { + s = strings.TrimSpace(s) + if s == "" { + return s + } + return spaceBeforeUnitRegex.ReplaceAllString(s, "$1$2") +} + const maxArraySize = 50 // extractGGUFMetadata converts the GGUF header metadata into a string map. 
@@ -47,7 +120,7 @@ func extractGGUFMetadata(header *parser.GGUFHeader) map[string]string { case parser.GGUFMetadataValueTypeString: value = kv.ValueString() case parser.GGUFMetadataValueTypeArray: - value = handleArray(kv.ValueArray()) + value = handleGGUFArray(kv.ValueArray()) default: value = fmt.Sprintf("[unknown type %d]", kv.ValueType) } @@ -57,8 +130,8 @@ func extractGGUFMetadata(header *parser.GGUFHeader) map[string]string { return metadata } -// handleArray processes an array value and returns its string representation -func handleArray(arrayValue parser.GGUFMetadataKVArrayValue) string { +// handleGGUFArray processes an array value and returns its string representation. +func handleGGUFArray(arrayValue parser.GGUFMetadataKVArrayValue) string { var values []string for _, v := range arrayValue.Array { switch arrayValue.Type { diff --git a/pkg/distribution/format/safetensors.go b/pkg/distribution/format/safetensors.go new file mode 100644 index 000000000..231625770 --- /dev/null +++ b/pkg/distribution/format/safetensors.go @@ -0,0 +1,305 @@ +package format + +import ( + "encoding/binary" + "encoding/json" + "fmt" + "io" + "os" + "path/filepath" + "regexp" + "sort" + "strconv" + + "github.com/docker/go-units" + "github.com/docker/model-runner/pkg/distribution/oci" + "github.com/docker/model-runner/pkg/distribution/types" +) + +// SafetensorsFormat implements the Format interface for Safetensors model files. +type SafetensorsFormat struct{} + +// init registers the Safetensors format implementation. +func init() { + Register(&SafetensorsFormat{}) +} + +// Name returns the format identifier for Safetensors. +func (s *SafetensorsFormat) Name() types.Format { + return types.FormatSafetensors +} + +// MediaType returns the OCI media type for Safetensors layers. 
+func (s *SafetensorsFormat) MediaType() oci.MediaType { + return types.MediaTypeSafetensors +} + +var ( + // shardPattern matches safetensors shard filenames like "model-00001-of-00003.safetensors" + // This pattern assumes 5-digit zero-padded numbering (e.g., 00001-of-00003), which is + // the most common format used by popular model repositories. + shardPattern = regexp.MustCompile(`^(.+)-(\d{5})-of-(\d{5})\.safetensors$`) +) + +// DiscoverShards finds all Safetensors shard files for a sharded model. +// Safetensors shards follow the pattern: -00001-of-00003.safetensors +// For single-file models, returns a slice containing only the input path. +func (s *SafetensorsFormat) DiscoverShards(path string) ([]string, error) { + baseName := filepath.Base(path) + matches := shardPattern.FindStringSubmatch(baseName) + + if len(matches) != 4 { + // Not a sharded file, return single path + return []string{path}, nil + } + + prefix := matches[1] + totalShards, err := strconv.Atoi(matches[3]) + if err != nil { + return nil, fmt.Errorf("parse shard count: %w", err) + } + + dir := filepath.Dir(path) + var shards []string + + // Look for all shards in the same directory + for i := 1; i <= totalShards; i++ { + shardName := fmt.Sprintf("%s-%05d-of-%05d.safetensors", prefix, i, totalShards) + shardPath := filepath.Join(dir, shardName) + + // Check if the file exists + if _, err := os.Stat(shardPath); err == nil { + shards = append(shards, shardPath) + } + } + + // Return error if we didn't find all expected shards + if len(shards) != totalShards { + return nil, fmt.Errorf("incomplete shard set: found %d of %d shards for %s", len(shards), totalShards, baseName) + } + + // Sort to ensure consistent ordering + sort.Strings(shards) + + return shards, nil +} + +// ExtractConfig parses Safetensors file(s) and extracts model configuration metadata. 
+func (s *SafetensorsFormat) ExtractConfig(paths []string) (types.Config, error) { + if len(paths) == 0 { + return types.Config{Format: types.FormatSafetensors}, nil + } + + // Parse the first safetensors file to extract metadata + header, err := parseSafetensorsHeader(paths[0]) + if err != nil { + // Continue without metadata if parsing fails + return types.Config{Format: types.FormatSafetensors}, nil + } + + // Calculate total size across all files + var totalSize int64 + for _, path := range paths { + info, err := os.Stat(path) + if err != nil { + return types.Config{}, fmt.Errorf("failed to stat file %s: %w", path, err) + } + totalSize += info.Size() + } + + // Calculate parameters + params := header.calculateParameters() + + // Extract architecture from metadata if available + architecture := "" + if arch, ok := header.Metadata["architecture"]; ok { + architecture = fmt.Sprintf("%v", arch) + } + + return types.Config{ + Format: types.FormatSafetensors, + Parameters: formatParameters(params), + Quantization: header.getQuantization(), + Size: formatSize(totalSize), + Architecture: architecture, + Safetensors: header.extractMetadata(), + }, nil +} + +const ( + quantizationUnknown = "unknown" + quantizationMixed = "mixed" +) + +// safetensorsHeader represents the JSON header in a safetensors file +type safetensorsHeader struct { + Metadata map[string]interface{} + Tensors map[string]tensorInfo +} + +// tensorInfo contains information about a tensor +type tensorInfo struct { + Dtype string + Shape []int64 + DataOffsets [2]int64 +} + +// parseSafetensorsHeader reads only the header from a safetensors file without loading the entire file. 
+func parseSafetensorsHeader(path string) (*safetensorsHeader, error) { + file, err := os.Open(path) + if err != nil { + return nil, fmt.Errorf("open file: %w", err) + } + defer file.Close() + + // Read the first 8 bytes to get the header length + var headerLen uint64 + if err := binary.Read(file, binary.LittleEndian, &headerLen); err != nil { + return nil, fmt.Errorf("read header length: %w", err) + } + + // Sanity check: header shouldn't be larger than 100MB + if headerLen > 100*1024*1024 { + return nil, fmt.Errorf("header length too large: %d bytes", headerLen) + } + + // Read only the header JSON (not the entire file!) + headerBytes := make([]byte, headerLen) + if _, err := io.ReadFull(file, headerBytes); err != nil { + return nil, fmt.Errorf("read header: %w", err) + } + + // Parse the JSON header + var rawHeader map[string]interface{} + if err := json.Unmarshal(headerBytes, &rawHeader); err != nil { + return nil, fmt.Errorf("parse JSON header: %w", err) + } + + // Extract metadata (stored under "__metadata__" key) + var metadata map[string]interface{} + if rawMetadata, ok := rawHeader["__metadata__"].(map[string]interface{}); ok { + metadata = rawMetadata + delete(rawHeader, "__metadata__") + } + + // Parse tensor info from remaining keys + tensors := make(map[string]tensorInfo) + for name, value := range rawHeader { + tensorMap, ok := value.(map[string]interface{}) + if !ok { + continue + } + + // Parse dtype + dtype, _ := tensorMap["dtype"].(string) + + // Parse shape + var shape []int64 + if shapeArray, ok := tensorMap["shape"].([]interface{}); ok { + for index, v := range shapeArray { + floatVal, ok := v.(float64) + if !ok { + return nil, fmt.Errorf("invalid shape value for tensor %q at index %d: expected number, got %T", name, index, v) + } + shape = append(shape, int64(floatVal)) + } + } + + // Parse data_offsets + var dataOffsets [2]int64 + if offsetsArray, ok := tensorMap["data_offsets"].([]interface{}); ok { + if len(offsetsArray) != 2 { + return 
nil, fmt.Errorf("invalid data_offsets for tensor %q: expected 2 elements, got %d", name, len(offsetsArray)) + } + for index, offset := range offsetsArray { + floatVal, ok := offset.(float64) + if !ok { + return nil, fmt.Errorf("invalid data_offsets value for tensor %q at index %d: expected number, got %T", name, index, offset) + } + dataOffsets[index] = int64(floatVal) + } + } + + tensors[name] = tensorInfo{ + Dtype: dtype, + Shape: shape, + DataOffsets: dataOffsets, + } + } + + return &safetensorsHeader{ + Metadata: metadata, + Tensors: tensors, + }, nil +} + +// calculateParameters sums up all tensor parameters +func (h *safetensorsHeader) calculateParameters() int64 { + var total int64 + for _, tensor := range h.Tensors { + params := int64(1) + for _, dim := range tensor.Shape { + params *= dim + } + total += params + } + return total +} + +// getQuantization determines the quantization type from tensor dtypes +func (h *safetensorsHeader) getQuantization() string { + if len(h.Tensors) == 0 { + return quantizationUnknown + } + + // Count dtype occurrences (skip empty dtypes) + dtypeCounts := make(map[string]int) + for _, tensor := range h.Tensors { + if tensor.Dtype != "" { + dtypeCounts[tensor.Dtype]++ + } + } + + // No valid dtypes found + if len(dtypeCounts) == 0 { + return quantizationUnknown + } + + // If all tensors have the same dtype, return it + if len(dtypeCounts) == 1 { + for dtype := range dtypeCounts { + return dtype + } + } + + return quantizationMixed +} + +// extractMetadata converts header to string map (similar to GGUF) +func (h *safetensorsHeader) extractMetadata() map[string]string { + metadata := make(map[string]string) + + // Add metadata from __metadata__ section + if h.Metadata != nil { + for k, v := range h.Metadata { + metadata[k] = fmt.Sprintf("%v", v) + } + } + + // Add tensor count + metadata["tensor_count"] = fmt.Sprintf("%d", len(h.Tensors)) + + return metadata +} + +// formatParameters converts parameter count to human-readable 
format +// Returns format like "361.82M" or "1.5B" (no space before unit, base 1000, where B = Billion) +func formatParameters(params int64) string { + return units.CustomSize("%.2f%s", float64(params), 1000.0, []string{"", "K", "M", "B", "T"}) +} + +// formatSize converts bytes to human-readable format matching Docker's style +// Returns format like "256MB" (decimal units, no space, matching `docker images`) +func formatSize(bytes int64) string { + return units.CustomSize("%.2f%s", float64(bytes), 1000.0, []string{"B", "kB", "MB", "GB", "TB", "PB", "EB"}) +} diff --git a/pkg/distribution/huggingface/model.go b/pkg/distribution/huggingface/model.go index a20601fc9..37835230e 100644 --- a/pkg/distribution/huggingface/model.go +++ b/pkg/distribution/huggingface/model.go @@ -17,8 +17,9 @@ import ( // BuildModel downloads files from a HuggingFace repository and constructs an OCI model artifact // This is the main entry point for pulling native HuggingFace models -func BuildModel(ctx context.Context, client *Client, repo, revision string, tempDir string, progressWriter io.Writer) (types.ModelArtifact, error) { - // Step 1: List files in the repository +// The tag parameter is used for GGUF repos to select the requested quantization (e.g., "Q4_K_M") +func BuildModel(ctx context.Context, client *Client, repo, revision, tag string, tempDir string, progressWriter io.Writer) (types.ModelArtifact, error) { + // List files in the repository if progressWriter != nil { _ = progress.WriteProgress(progressWriter, "Fetching file list...", 0, 0, 0, "") } @@ -28,15 +29,36 @@ func BuildModel(ctx context.Context, client *Client, repo, revision string, temp return nil, fmt.Errorf("list files: %w", err) } - // Step 2: Filter to model files (safetensors + configs) - safetensorsFiles, configFiles := FilterModelFiles(files) + // Filter to model files (weights + configs) + weightFiles, configFiles := FilterModelFiles(files) - if len(safetensorsFiles) == 0 { - return nil, fmt.Errorf("no 
safetensors files found in repository %s", repo) + if len(weightFiles) == 0 { + return nil, fmt.Errorf("no model weight files (GGUF or SafeTensors) found in repository %s", repo) + } + + // For GGUF repos with multiple quantizations, select the appropriate files + var mmprojFile *RepoFile + if isGGUFModel(weightFiles) && len(weightFiles) > 1 { + // Use the tag as quantization hint (e.g., "Q4_K_M", "Q8_0", or "latest") + weightFiles, mmprojFile = SelectGGUFFiles(weightFiles, tag) + if len(weightFiles) == 0 { + return nil, fmt.Errorf("no GGUF files found matching quantization %q in repository %s", tag, repo) + } + + if progressWriter != nil { + if tag == "" || tag == "latest" || tag == "main" { + _ = progress.WriteProgress(progressWriter, fmt.Sprintf("Selected %s quantization (default)", DefaultGGUFQuantization), 0, 0, 0, "") + } else { + _ = progress.WriteProgress(progressWriter, fmt.Sprintf("Selected %s quantization", tag), 0, 0, 0, "") + } + } } // Combine all files to download - allFiles := append(safetensorsFiles, configFiles...) + allFiles := append(weightFiles, configFiles...) 
+ if mmprojFile != nil { + allFiles = append(allFiles, *mmprojFile) + } if progressWriter != nil { totalSize := TotalSize(allFiles) @@ -57,7 +79,7 @@ func BuildModel(ctx context.Context, client *Client, repo, revision string, temp _ = progress.WriteProgress(progressWriter, "Building model artifact...", 0, 0, 0, "") } - model, err := buildModelFromFiles(result.LocalPaths, safetensorsFiles, configFiles, tempDir) + model, err := buildModelFromFiles(result.LocalPaths, weightFiles, configFiles, tempDir) if err != nil { return nil, fmt.Errorf("build model: %w", err) } @@ -66,20 +88,20 @@ func BuildModel(ctx context.Context, client *Client, repo, revision string, temp } // buildModelFromFiles constructs an OCI model artifact from downloaded files -func buildModelFromFiles(localPaths map[string]string, safetensorsFiles, configFiles []RepoFile, tempDir string) (types.ModelArtifact, error) { - // Collect safetensors paths (sorted for reproducibility) - var safetensorsPaths []string - for _, f := range safetensorsFiles { +func buildModelFromFiles(localPaths map[string]string, weightFiles, configFiles []RepoFile, tempDir string) (types.ModelArtifact, error) { + // Collect weight file paths (sorted for reproducibility) + var weightPaths []string + for _, f := range weightFiles { localPath, ok := localPaths[f.Path] if !ok { return nil, fmt.Errorf("missing local path for %s", f.Path) } - safetensorsPaths = append(safetensorsPaths, localPath) + weightPaths = append(weightPaths, localPath) } - sort.Strings(safetensorsPaths) + sort.Strings(weightPaths) - // Create builder from safetensors files - b, err := builder.FromSafetensors(safetensorsPaths) + // Create builder from weight files - auto-detects format (GGUF or SafeTensors) + b, err := builder.FromPaths(weightPaths) if err != nil { return nil, fmt.Errorf("create builder: %w", err) } diff --git a/pkg/distribution/huggingface/repository.go b/pkg/distribution/huggingface/repository.go index 79a6890e7..c0d69a959 100644 --- 
a/pkg/distribution/huggingface/repository.go +++ b/pkg/distribution/huggingface/repository.go @@ -2,9 +2,15 @@ package huggingface import ( "path" + "sort" "strings" - "github.com/docker/model-runner/pkg/distribution/packaging" + "github.com/docker/model-runner/pkg/distribution/files" +) + +const ( + // DefaultGGUFQuantization is the preferred quantization when "latest" is requested + DefaultGGUFQuantization = "Q4_K_M" ) // RepoFile represents a file in a HuggingFace repository @@ -36,79 +42,276 @@ func (f *RepoFile) Filename() string { return path.Base(f.Path) } -// fileType represents the type of file for model packaging -type fileType int +// FilterModelFiles filters repository files to only include files needed for model-runner +// Returns weight files (GGUF or SafeTensors) and config files separately +func FilterModelFiles(repoFiles []RepoFile) (weights []RepoFile, configs []RepoFile) { + for _, f := range repoFiles { + if f.Type != "file" { + continue + } -const ( - // fileTypeUnknown is an unrecognized file type - fileTypeUnknown fileType = iota - // fileTypeSafetensors is a safetensors model weight file - fileTypeSafetensors - // fileTypeConfig is a configuration file (json, txt, etc.) 
- fileTypeConfig -) + switch ft := files.Classify(f.Filename()); ft { + case files.FileTypeSafetensors, files.FileTypeGGUF: + weights = append(weights, f) + case files.FileTypeConfig, files.FileTypeChatTemplate: + configs = append(configs, f) + case files.FileTypeUnknown, files.FileTypeLicense: + // Skip these file types + } + } + return weights, configs +} -// classifyFile determines the file type based on filename -func classifyFile(filename string) fileType { - lower := strings.ToLower(filename) +// TotalSize calculates the total size of files +func TotalSize(repoFiles []RepoFile) int64 { + var total int64 + for _, f := range repoFiles { + total += f.ActualSize() + } + return total +} - // Check for safetensors files - if strings.HasSuffix(lower, ".safetensors") { - return fileTypeSafetensors +// isSafetensorsModel checks if the files contain at least one safetensors file +func isSafetensorsModel(repoFiles []RepoFile) bool { + for _, f := range repoFiles { + if f.Type == "file" && files.Classify(f.Filename()) == files.FileTypeSafetensors { + return true + } } + return false +} - // Check for config file extensions - for _, ext := range packaging.ConfigExtensions { - if strings.HasSuffix(lower, ext) { - return fileTypeConfig +// isGGUFModel checks if the files contain at least one GGUF file +func isGGUFModel(repoFiles []RepoFile) bool { + for _, f := range repoFiles { + if f.Type == "file" && files.Classify(f.Filename()) == files.FileTypeGGUF { + return true } } + return false +} + +// SelectGGUFFiles selects GGUF files based on the requested quantization. 
+// For GGUF repos with multiple quantization variants: +// - If requestedQuant matches a known quantization (e.g., "Q4_K_M"), select files with that quantization +// - If requestedQuant is empty, "latest", or "main", prefer Q4_K_M, then fall back to first GGUF +// - Handles sharded GGUF files (selects all shards of the chosen quantization) +// - Also selects mmproj files for multimodal models (prefers F16) +func SelectGGUFFiles(ggufFiles []RepoFile, requestedQuant string) (selected []RepoFile, mmproj *RepoFile) { + if len(ggufFiles) == 0 { + return nil, nil + } - // Check for special config files - for _, special := range packaging.SpecialConfigFiles { - if strings.EqualFold(filename, special) { - return fileTypeConfig + // Separate mmproj files from model files + var modelFiles []RepoFile + var mmprojFiles []RepoFile + + for _, f := range ggufFiles { + filename := f.Filename() + if isMMProjFile(filename) { + mmprojFiles = append(mmprojFiles, f) + } else { + modelFiles = append(modelFiles, f) } } - return fileTypeUnknown + // Select mmproj file (prefer F16) + mmproj = selectMMProj(mmprojFiles) + + // If only one model file, return it + if len(modelFiles) == 1 { + return modelFiles, mmproj + } + + // Normalize requested quantization + quant := normalizeQuantization(requestedQuant) + + // Try to find files matching the requested quantization + if quant != "" { + matching := filterByQuantization(modelFiles, quant) + if len(matching) > 0 { + return matching, mmproj + } + } + + // Fall back to default quantization (Q4_K_M) + defaultMatching := filterByQuantization(modelFiles, DefaultGGUFQuantization) + if len(defaultMatching) > 0 { + return defaultMatching, mmproj + } + + // No specific quantization found - return the first file (or sharded set) + // Sort by filename to ensure consistent selection + first := selectFirstGGUF(modelFiles) + return first, mmproj } -// FilterModelFiles filters repository files to only include files needed for model-runner -// Returns 
safetensors files and config files separately -func FilterModelFiles(files []RepoFile) (safetensors []RepoFile, configs []RepoFile) { - for _, f := range files { - if f.Type != "file" { - continue +// normalizeQuantization normalizes the quantization string +// Returns empty string for "latest" or "main" (meaning use default) +func normalizeQuantization(quant string) string { + if quant == "" || quant == "latest" || quant == "main" { + return "" + } + return quant +} + +// filterByQuantization filters GGUF files by quantization type +// Handles both single files and sharded files +func filterByQuantization(modelFiles []RepoFile, quant string) []RepoFile { + var matching []RepoFile + + for _, f := range modelFiles { + filename := f.Filename() + if containsQuantization(filename, quant) { + matching = append(matching, f) } + } - switch classifyFile(f.Filename()) { - case fileTypeSafetensors: - safetensors = append(safetensors, f) - case fileTypeConfig: - configs = append(configs, f) - case fileTypeUnknown: - // Skip unknown file types + return matching +} + +// containsQuantization checks if a filename contains the specified quantization +// Matches patterns like "model-Q4_K_M.gguf" or "model-Q4_K_M-00001-of-00003.gguf" +func containsQuantization(filename, quant string) bool { + // Case-insensitive comparison + filenameLower := strings.ToLower(filename) + quantLower := strings.ToLower(quant) + + // Remove .gguf extension for cleaner matching + if hasSuffix(filenameLower, ".gguf") { + filenameLower = filenameLower[:len(filenameLower)-5] + } + + // Common patterns: + // - "model-Q4_K_M" -> ends with "-Q4_K_M" or "-Q4_K_M-00001-of-00003" + // - "model.Q4_K_M" -> ends with ".Q4_K_M" + // - "Llama-3.2-1B-Instruct-Q4_K_M" -> ends with "-Q4_K_M" + + // Check if the quantization appears after a separator (-, ., _) and is followed by + // either end of string or another separator + separators := []string{"-", ".", "_"} + for _, sep := range separators { + pattern := sep + 
quantLower + idx := strings.Index(filenameLower, pattern) + if idx >= 0 { + // Check what comes after the quantization + afterIdx := idx + len(pattern) + if afterIdx == len(filenameLower) { + // Quantization is at end of filename (after removing .gguf) + return true + } + // Check if followed by a separator (e.g., "-00001-of-00003") + if afterIdx < len(filenameLower) { + nextChar := filenameLower[afterIdx] + if nextChar == '-' || nextChar == '.' || nextChar == '_' { + return true + } + } } } - return safetensors, configs + + return false } -// TotalSize calculates the total size of files -func TotalSize(files []RepoFile) int64 { - var total int64 - for _, f := range files { - total += f.ActualSize() +func hasSuffix(s, suffix string) bool { + return len(s) >= len(suffix) && s[len(s)-len(suffix):] == suffix +} + +// selectFirstGGUF selects the first GGUF file (handling sharded files) +func selectFirstGGUF(modelFiles []RepoFile) []RepoFile { + if len(modelFiles) == 0 { + return nil } - return total + + // Sort by filename for consistent ordering + sorted := make([]RepoFile, len(modelFiles)) + copy(sorted, modelFiles) + sort.Slice(sorted, func(i, j int) bool { return sorted[i].Filename() < sorted[j].Filename() }) + + // Get the first file + first := sorted[0] + + // Check if it's a sharded file + if isShardedFile(first.Filename()) { + // Find all shards with the same prefix + return findAllShards(sorted, first.Filename()) + } + + return []RepoFile{first} } -// isSafetensorsModel checks if the files contain at least one safetensors file -func isSafetensorsModel(files []RepoFile) bool { +// isShardedFile checks if a filename follows the sharded pattern +// e.g., "model-00001-of-00003.gguf" +func isShardedFile(filename string) bool { + // Delegate to indexOfShardPattern so shard detection is precise and consistent + return indexOfShardPattern(filename) >= 0 +} + +// findAllShards finds all shards that belong to the same model +func findAllShards(files []RepoFile, 
firstShard string) []RepoFile { + // Extract the base prefix (everything before the shard number) + // e.g., "model-00001-of-00003.gguf" -> "model" + prefix := extractShardPrefix(firstShard) + + var shards []RepoFile for _, f := range files { - if f.Type == "file" && classifyFile(f.Filename()) == fileTypeSafetensors { - return true + if strings.HasPrefix(f.Filename(), prefix) && isShardedFile(f.Filename()) { + shards = append(shards, f) } } - return false + + return shards +} + +// extractShardPrefix extracts the model name prefix from a sharded filename +func extractShardPrefix(filename string) string { + // Find "-00001-of-" or similar pattern and return everything before it + idx := indexOfShardPattern(filename) + if idx > 0 { + return filename[:idx] + } + return filename +} + +// isMMProjFile checks if a file is a multimodal projector file +func isMMProjFile(filename string) bool { + lower := strings.ToLower(filename) + return strings.Contains(lower, "mmproj") +} + +// selectMMProj selects the best mmproj file, preferring F16 +func selectMMProj(mmprojFiles []RepoFile) *RepoFile { + if len(mmprojFiles) == 0 { + return nil + } + + // Prefer F16 over other formats + for i := range mmprojFiles { + filename := strings.ToLower(mmprojFiles[i].Filename()) + if strings.Contains(filename, "f16") { + return &mmprojFiles[i] + } + } + + // Fall back to first mmproj file + return &mmprojFiles[0] +} + +func indexOfShardPattern(filename string) int { + // Look for pattern like "-00001-of-" or "-00002-of-" + for i := 0; i < len(filename)-10; i++ { + if filename[i] == '-' && + filename[i+1] >= '0' && filename[i+1] <= '9' && + filename[i+2] >= '0' && filename[i+2] <= '9' && + filename[i+3] >= '0' && filename[i+3] <= '9' && + filename[i+4] >= '0' && filename[i+4] <= '9' && + filename[i+5] >= '0' && filename[i+5] <= '9' && + filename[i+6] == '-' && + filename[i+7] == 'o' && + filename[i+8] == 'f' && + filename[i+9] == '-' { + return i + } + } + return -1 } diff --git 
a/pkg/distribution/huggingface/repository_gguf_test.go b/pkg/distribution/huggingface/repository_gguf_test.go new file mode 100644 index 000000000..4e3761571 --- /dev/null +++ b/pkg/distribution/huggingface/repository_gguf_test.go @@ -0,0 +1,215 @@ +package huggingface + +import ( + "testing" +) + +func TestSelectGGUFFiles(t *testing.T) { + tests := []struct { + name string + files []RepoFile + requestedQuant string + expectedFiles []string // filenames + expectedMMProj string // mmproj filename (or "") + }{ + { + name: "single file, no selection needed", + files: []RepoFile{ + {Type: "file", Path: "model.gguf"}, + }, + requestedQuant: "latest", + expectedFiles: []string{"model.gguf"}, + expectedMMProj: "", + }, + { + name: "multiple quantizations, select Q4_K_M explicitly", + files: []RepoFile{ + {Type: "file", Path: "model-Q2_K.gguf"}, + {Type: "file", Path: "model-Q4_K_M.gguf"}, + {Type: "file", Path: "model-Q8_0.gguf"}, + }, + requestedQuant: "Q4_K_M", + expectedFiles: []string{"model-Q4_K_M.gguf"}, + expectedMMProj: "", + }, + { + name: "multiple quantizations, select Q8_0", + files: []RepoFile{ + {Type: "file", Path: "model-Q2_K.gguf"}, + {Type: "file", Path: "model-Q4_K_M.gguf"}, + {Type: "file", Path: "model-Q8_0.gguf"}, + }, + requestedQuant: "Q8_0", + expectedFiles: []string{"model-Q8_0.gguf"}, + expectedMMProj: "", + }, + { + name: "latest tag defaults to Q4_K_M", + files: []RepoFile{ + {Type: "file", Path: "model-Q2_K.gguf"}, + {Type: "file", Path: "model-Q4_K_M.gguf"}, + {Type: "file", Path: "model-Q8_0.gguf"}, + }, + requestedQuant: "latest", + expectedFiles: []string{"model-Q4_K_M.gguf"}, + expectedMMProj: "", + }, + { + name: "empty tag defaults to Q4_K_M", + files: []RepoFile{ + {Type: "file", Path: "model-Q2_K.gguf"}, + {Type: "file", Path: "model-Q4_K_M.gguf"}, + {Type: "file", Path: "model-Q8_0.gguf"}, + }, + requestedQuant: "", + expectedFiles: []string{"model-Q4_K_M.gguf"}, + expectedMMProj: "", + }, + { + name: "no Q4_K_M, fallback to first 
file", + files: []RepoFile{ + {Type: "file", Path: "model-Q2_K.gguf"}, + {Type: "file", Path: "model-Q8_0.gguf"}, + }, + requestedQuant: "latest", + expectedFiles: []string{"model-Q2_K.gguf"}, + expectedMMProj: "", + }, + { + name: "case insensitive matching", + files: []RepoFile{ + {Type: "file", Path: "model-q4_k_m.gguf"}, + {Type: "file", Path: "model-q8_0.gguf"}, + }, + requestedQuant: "Q4_K_M", + expectedFiles: []string{"model-q4_k_m.gguf"}, + expectedMMProj: "", + }, + { + name: "multimodal with mmproj, prefers F16", + files: []RepoFile{ + {Type: "file", Path: "model-Q4_K_M.gguf"}, + {Type: "file", Path: "mmproj-model-f32.gguf"}, + {Type: "file", Path: "mmproj-model-f16.gguf"}, + }, + requestedQuant: "Q4_K_M", + expectedFiles: []string{"model-Q4_K_M.gguf"}, + expectedMMProj: "mmproj-model-f16.gguf", + }, + { + name: "multimodal with only f32 mmproj", + files: []RepoFile{ + {Type: "file", Path: "model-Q4_K_M.gguf"}, + {Type: "file", Path: "mmproj-model-f32.gguf"}, + }, + requestedQuant: "Q4_K_M", + expectedFiles: []string{"model-Q4_K_M.gguf"}, + expectedMMProj: "mmproj-model-f32.gguf", + }, + { + name: "bartowski style naming", + files: []RepoFile{ + {Type: "file", Path: "Llama-3.2-1B-Instruct-Q2_K.gguf"}, + {Type: "file", Path: "Llama-3.2-1B-Instruct-Q4_K_M.gguf"}, + {Type: "file", Path: "Llama-3.2-1B-Instruct-Q5_K_M.gguf"}, + {Type: "file", Path: "Llama-3.2-1B-Instruct-Q6_K.gguf"}, + {Type: "file", Path: "Llama-3.2-1B-Instruct-Q8_0.gguf"}, + {Type: "file", Path: "Llama-3.2-1B-Instruct-IQ4_XS.gguf"}, + }, + requestedQuant: "Q5_K_M", + expectedFiles: []string{"Llama-3.2-1B-Instruct-Q5_K_M.gguf"}, + expectedMMProj: "", + }, + { + name: "IQ quantization", + files: []RepoFile{ + {Type: "file", Path: "model-IQ2_XXS.gguf"}, + {Type: "file", Path: "model-IQ4_XS.gguf"}, + {Type: "file", Path: "model-Q4_K_M.gguf"}, + }, + requestedQuant: "IQ4_XS", + expectedFiles: []string{"model-IQ4_XS.gguf"}, + expectedMMProj: "", + }, + } + + for _, tt := range tests { + 
t.Run(tt.name, func(t *testing.T) { + selected, mmproj := SelectGGUFFiles(tt.files, tt.requestedQuant) + + // Check selected files + if len(selected) != len(tt.expectedFiles) { + t.Errorf("SelectGGUFFiles() returned %d files, want %d", len(selected), len(tt.expectedFiles)) + return + } + + for i, f := range selected { + if f.Filename() != tt.expectedFiles[i] { + t.Errorf("SelectGGUFFiles() file[%d] = %q, want %q", i, f.Filename(), tt.expectedFiles[i]) + } + } + + // Check mmproj + if tt.expectedMMProj == "" { + if mmproj != nil { + t.Errorf("SelectGGUFFiles() mmproj = %q, want nil", mmproj.Filename()) + } + } else { + if mmproj == nil { + t.Errorf("SelectGGUFFiles() mmproj = nil, want %q", tt.expectedMMProj) + } else if mmproj.Filename() != tt.expectedMMProj { + t.Errorf("SelectGGUFFiles() mmproj = %q, want %q", mmproj.Filename(), tt.expectedMMProj) + } + } + }) + } +} + +func TestContainsQuantization(t *testing.T) { + tests := []struct { + filename string + quant string + expected bool + }{ + {"model-Q4_K_M.gguf", "Q4_K_M", true}, + {"model-Q4_K_M.gguf", "q4_k_m", true}, // case insensitive + {"model-Q4_K_M.gguf", "Q8_0", false}, + {"model.Q4_K_M.gguf", "Q4_K_M", true}, // dot separator + {"model_Q4_K_M.gguf", "Q4_K_M", true}, // underscore separator + {"Llama-3.2-1B-Instruct-Q4_K_M.gguf", "Q4_K_M", true}, + {"model-Q4_K_M-00001-of-00003.gguf", "Q4_K_M", true}, // sharded + {"model-IQ4_XS.gguf", "IQ4_XS", true}, + {"model.gguf", "Q4_K_M", false}, + } + + for _, tt := range tests { + t.Run(tt.filename+"_"+tt.quant, func(t *testing.T) { + result := containsQuantization(tt.filename, tt.quant) + if result != tt.expected { + t.Errorf("containsQuantization(%q, %q) = %v, want %v", tt.filename, tt.quant, result, tt.expected) + } + }) + } +} + +func TestIsMMProjFile(t *testing.T) { + tests := []struct { + filename string + expected bool + }{ + {"mmproj-model-f16.gguf", true}, + {"mmproj-model-f32.gguf", true}, + {"MMPROJ-model.gguf", true}, // case insensitive + 
{"model-Q4_K_M.gguf", false}, + {"model.gguf", false}, + } + + for _, tt := range tests { + t.Run(tt.filename, func(t *testing.T) { + result := isMMProjFile(tt.filename) + if result != tt.expected { + t.Errorf("isMMProjFile(%q) = %v, want %v", tt.filename, result, tt.expected) + } + }) + } +} diff --git a/pkg/distribution/huggingface/repository_test.go b/pkg/distribution/huggingface/repository_test.go index b2326e076..16e3a159b 100644 --- a/pkg/distribution/huggingface/repository_test.go +++ b/pkg/distribution/huggingface/repository_test.go @@ -4,42 +4,8 @@ import ( "testing" ) -func TestClassifyFile(t *testing.T) { - tests := []struct { - name string - filename string - want fileType - }{ - {"safetensors file", "model.safetensors", fileTypeSafetensors}, - {"safetensors uppercase", "model.SAFETENSORS", fileTypeSafetensors}, - {"safetensors mixed case", "Model.SafeTensors", fileTypeSafetensors}, - {"sharded safetensors", "model-00001-of-00003.safetensors", fileTypeSafetensors}, - - {"json config", "config.json", fileTypeConfig}, - {"tokenizer json", "tokenizer.json", fileTypeConfig}, - {"tokenizer config", "tokenizer_config.json", fileTypeConfig}, - {"txt file", "README.txt", fileTypeConfig}, - {"markdown file", "README.md", fileTypeConfig}, - {"vocab file", "vocab.vocab", fileTypeConfig}, - {"jinja template", "chat_template.jinja", fileTypeConfig}, - {"tokenizer model", "tokenizer.model", fileTypeConfig}, - - {"unknown extension", "model.bin", fileTypeUnknown}, - {"python file", "model.py", fileTypeUnknown}, - {"pytorch model", "pytorch_model.bin", fileTypeUnknown}, - } - - for _, tt := range tests { - t.Run(tt.name, func(t *testing.T) { - if got := classifyFile(tt.filename); got != tt.want { - t.Errorf("classifyFile(%q) = %v, want %v", tt.filename, got, tt.want) - } - }) - } -} - func TestFilterModelFiles(t *testing.T) { - files := []RepoFile{ + repoFiles := []RepoFile{ {Type: "file", Path: "model.safetensors", Size: 1000}, {Type: "file", Path: "config.json", 
Size: 100}, {Type: "file", Path: "tokenizer.json", Size: 200}, @@ -50,7 +16,7 @@ func TestFilterModelFiles(t *testing.T) { {Type: "file", Path: "model-00002-of-00002.safetensors", Size: 2000}, } - safetensors, configs := FilterModelFiles(files) + safetensors, configs := FilterModelFiles(repoFiles) if len(safetensors) != 3 { t.Errorf("Expected 3 safetensors files, got %d", len(safetensors)) @@ -61,12 +27,12 @@ func TestFilterModelFiles(t *testing.T) { } func TestTotalSize(t *testing.T) { - files := []RepoFile{ + repoFiles := []RepoFile{ {Type: "file", Path: "a.safetensors", Size: 1000}, {Type: "file", Path: "b.safetensors", Size: 2000, LFS: &LFSInfo{Size: 5000}}, } - total := TotalSize(files) + total := TotalSize(repoFiles) if total != 6000 { // 1000 + 5000 (LFS size takes precedence) t.Errorf("TotalSize() = %d, want 6000", total) } diff --git a/pkg/distribution/internal/gguf/create.go b/pkg/distribution/internal/gguf/create.go index acdd1ba4e..f368ca0d1 100644 --- a/pkg/distribution/internal/gguf/create.go +++ b/pkg/distribution/internal/gguf/create.go @@ -2,84 +2,27 @@ package gguf import ( "fmt" - "regexp" - "strings" - "time" + "github.com/docker/model-runner/pkg/distribution/builder" "github.com/docker/model-runner/pkg/distribution/internal/partial" - "github.com/docker/model-runner/pkg/distribution/oci" - "github.com/docker/model-runner/pkg/distribution/types" - parser "github.com/gpustack/gguf-parser-go" ) +// NewModel creates a new GGUF model from a file path. +// It delegates to the unified builder package for model creation. 
func NewModel(path string) (*Model, error) { - shards := parser.CompleteShardGGUFFilename(path) - if len(shards) == 0 { - shards = []string{path} // single file + // Delegate to builder which handles format detection, shard discovery, and config extraction + b, err := builder.FromPath(path) + if err != nil { + return nil, fmt.Errorf("create model from path: %w", err) } - layers := make([]oci.Layer, len(shards)) - diffIDs := make([]oci.Hash, len(shards)) - for i, shard := range shards { - layer, err := partial.NewLayer(shard, types.MediaTypeGGUF) - if err != nil { - return nil, fmt.Errorf("create gguf layer: %w", err) - } - diffID, err := layer.DiffID() - if err != nil { - return nil, fmt.Errorf("get gguf layer diffID: %w", err) - } - layers[i] = layer - diffIDs[i] = diffID + + // Get the underlying model and wrap it in our type + baseModel, ok := b.Model().(*partial.BaseModel) + if !ok { + return nil, fmt.Errorf("unexpected model type: %T", b.Model()) } - created := time.Now() return &Model{ - BaseModel: partial.BaseModel{ - ModelConfigFile: types.ConfigFile{ - Config: configFromFile(path), - Descriptor: types.Descriptor{ - Created: &created, - }, - RootFS: oci.RootFS{ - Type: "rootfs", - DiffIDs: diffIDs, - }, - }, - LayerList: layers, - }, + BaseModel: *baseModel, }, nil } - -func configFromFile(path string) types.Config { - gguf, err := parser.ParseGGUFFile(path) - if err != nil { - return types.Config{} // continue without metadata - } - return types.Config{ - Format: types.FormatGGUF, - Parameters: normalizeUnitString(gguf.Metadata().Parameters.String()), - Architecture: strings.TrimSpace(gguf.Metadata().Architecture), - Quantization: strings.TrimSpace(gguf.Metadata().FileType.String()), - Size: normalizeUnitString(gguf.Metadata().Size.String()), - GGUF: extractGGUFMetadata(&gguf.Header), - } -} - -var ( - // spaceBeforeUnitRegex matches one or more spaces between a valid number and a letter (unit) - // Used to remove spaces between numbers and units (e.g., 
"16.78 M" -> "16.78M") - // Pattern: integer or decimal number, then whitespace, then letters (unit) - spaceBeforeUnitRegex = regexp.MustCompile(`([0-9]+(?:\.[0-9]+)?)\s+([A-Za-z]+)`) -) - -// normalizeUnitString removes spaces between numbers and units for consistent formatting -// Examples: "16.78 M" -> "16.78M", "256.35 MiB" -> "256.35MiB", "409M" -> "409M" -func normalizeUnitString(s string) string { - s = strings.TrimSpace(s) - if s == "" { - return s - } - // Remove space(s) between numbers/decimals and unit letters using regex - // Pattern matches: number(s) or decimal, then whitespace, then letters (unit) - return spaceBeforeUnitRegex.ReplaceAllString(s, "$1$2") -} diff --git a/pkg/distribution/internal/safetensors/create.go b/pkg/distribution/internal/safetensors/create.go index 500e407cc..e37ecff3f 100644 --- a/pkg/distribution/internal/safetensors/create.go +++ b/pkg/distribution/internal/safetensors/create.go @@ -2,161 +2,40 @@ package safetensors import ( "fmt" - "os" - "path/filepath" - "regexp" - "strconv" - "time" + "github.com/docker/model-runner/pkg/distribution/builder" "github.com/docker/model-runner/pkg/distribution/internal/partial" - "github.com/docker/model-runner/pkg/distribution/oci" - "github.com/docker/model-runner/pkg/distribution/types" ) -var ( - // shardPattern matches safetensors shard filenames like "model-00001-of-00003.safetensors" - // This pattern assumes 5-digit zero-padded numbering (e.g., 00001-of-00003), which is - // the most common format used by popular model repositories. - // The pattern enforces consistent padding width for both the shard number and total count. 
- shardPattern = regexp.MustCompile(`^(.+)-(\d{5})-of-(\d{5})\.safetensors$`) -) - -// NewModel creates a new safetensors model from one or more safetensors files -// If a sharded model pattern is detected (e.g., model-00001-of-00002.safetensors), -// it will auto-discover all related shards +// NewModel creates a new safetensors model from one or more safetensors files. +// It delegates to the unified builder package for model creation. func NewModel(paths []string) (*Model, error) { if len(paths) == 0 { return nil, fmt.Errorf("at least one safetensors file is required") } - // Auto-discover shards if the first path matches the shard pattern - allPaths, err := discoverSafetensorsShards(paths[0]) - if err != nil { - return nil, fmt.Errorf("discover safetensors shards: %w", err) - } - if len(allPaths) == 0 { - // Not a sharded file, use provided paths as-is - allPaths = paths - } + // Delegate to builder which handles format detection, shard discovery, and config extraction + // Use FromPath for single path (will auto-discover shards) + // Use FromPaths for multiple explicit paths + var b *builder.Builder + var err error - layers := make([]oci.Layer, len(allPaths)) - diffIDs := make([]oci.Hash, len(allPaths)) - - for i, path := range allPaths { - layer, layerErr := partial.NewLayer(path, types.MediaTypeSafetensors) - if layerErr != nil { - return nil, fmt.Errorf("create safetensors layer from %q: %w", path, layerErr) - } - diffID, diffIDErr := layer.DiffID() - if diffIDErr != nil { - return nil, fmt.Errorf("get safetensors layer diffID: %w", diffIDErr) - } - layers[i] = layer - diffIDs[i] = diffID + if len(paths) == 1 { + b, err = builder.FromPath(paths[0]) + } else { + b, err = builder.FromPaths(paths) } - - config, err := configFromFiles(allPaths) if err != nil { - return nil, fmt.Errorf("create config from files: %w", err) - } - - created := time.Now() - return &Model{ - BaseModel: partial.BaseModel{ - ModelConfigFile: types.ConfigFile{ - Config: config, - 
Descriptor: types.Descriptor{ - Created: &created, - }, - RootFS: oci.RootFS{ - Type: "rootfs", - DiffIDs: diffIDs, - }, - }, - LayerList: layers, - }, - }, nil -} - -// discoverSafetensorsShards attempts to auto-discover all shards for a given safetensors file -// It looks for the pattern: -XXXXX-of-YYYYY.safetensors -// Returns (nil, nil) for single-file models, (paths, nil) for complete shard sets, -// or (nil, error) for incomplete shard sets -func discoverSafetensorsShards(path string) ([]string, error) { - baseName := filepath.Base(path) - matches := shardPattern.FindStringSubmatch(baseName) - - if len(matches) != 4 { - // Not a sharded file, return empty slice with no error - return nil, nil + return nil, fmt.Errorf("create model from paths: %w", err) } - prefix := matches[1] - totalShards, err := strconv.Atoi(matches[3]) - if err != nil { - return nil, fmt.Errorf("parse shard count: %w", err) + // Get the underlying model and wrap it in our type + baseModel, ok := b.Model().(*partial.BaseModel) + if !ok { + return nil, fmt.Errorf("unexpected model type: %T", b.Model()) } - dir := filepath.Dir(path) - var shards []string - - // Look for all shards in the same directory - for i := 1; i <= totalShards; i++ { - shardName := fmt.Sprintf("%s-%05d-of-%05d.safetensors", prefix, i, totalShards) - shardPath := filepath.Join(dir, shardName) - - // Check if the file exists - if _, err := os.Stat(shardPath); err == nil { - shards = append(shards, shardPath) - } - } - - // Return error if we didn't find all expected shards - if len(shards) != totalShards { - return nil, fmt.Errorf("incomplete shard set: found %d of %d shards for %s", len(shards), totalShards, baseName) - } - - // Shards are already in order due to sequential loop - return shards, nil -} - -func configFromFiles(paths []string) (types.Config, error) { - // Parse the first safetensors file to extract metadata - if len(paths) == 0 { - return types.Config{Format: types.FormatSafetensors}, nil - } - - header, 
err := ParseSafetensorsHeader(paths[0]) - if err != nil { - // Continue without metadata if parsing fails - return types.Config{Format: types.FormatSafetensors}, nil - } - - // Calculate total size across all files - var totalSize int64 - for _, path := range paths { - info, err := os.Stat(path) - if err != nil { - return types.Config{}, fmt.Errorf("failed to stat file %s: %w", path, err) - } - totalSize += info.Size() - } - - // Calculate parameters - params := header.CalculateParameters() - - // Extract architecture from metadata if available - architecture := "" - if arch, ok := header.Metadata["architecture"]; ok { - architecture = fmt.Sprintf("%v", arch) - } - - return types.Config{ - Format: types.FormatSafetensors, - Parameters: formatParameters(params), - Quantization: header.GetQuantization(), - Size: formatSize(totalSize), - Architecture: architecture, - Safetensors: header.ExtractMetadata(), + return &Model{ + BaseModel: *baseModel, }, nil } diff --git a/pkg/distribution/internal/safetensors/metadata.go b/pkg/distribution/internal/safetensors/metadata.go index 9fea9e585..e6d2de104 100644 --- a/pkg/distribution/internal/safetensors/metadata.go +++ b/pkg/distribution/internal/safetensors/metadata.go @@ -6,8 +6,6 @@ import ( "fmt" "io" "os" - - "github.com/docker/go-units" ) const ( @@ -182,15 +180,3 @@ func (h *Header) ExtractMetadata() map[string]string { return metadata } - -// formatParameters converts parameter count to human-readable format -// Returns format like "361.82M" or "1.5B" (no space before unit, base 1000, where B = Billion) -func formatParameters(params int64) string { - return units.CustomSize("%.2f%s", float64(params), 1000.0, []string{"", "K", "M", "B", "T"}) -} - -// formatSize converts bytes to human-readable format matching Docker's style -// Returns format like "256MB" (decimal units, no space, matching `docker images`) -func formatSize(bytes int64) string { - return units.CustomSize("%.2f%s", float64(bytes), 1000.0, []string{"B", 
"kB", "MB", "GB", "TB", "PB", "EB"}) -} diff --git a/pkg/distribution/oci/remote/remote.go b/pkg/distribution/oci/remote/remote.go index 4122f4da4..9b58d9910 100644 --- a/pkg/distribution/oci/remote/remote.go +++ b/pkg/distribution/oci/remote/remote.go @@ -299,111 +299,6 @@ type remoteImage struct { store content.Store ctx context.Context mu sync.Mutex - httpClient *http.Client - authorizer docker.Authorizer - plainHTTP bool -} - -// manifestFetcher wraps a fetcher to handle manifest fetches specially. -// Some registries (like HuggingFace) don't serve manifests via /blobs/ endpoint, -// only via /manifests/ endpoint. This fetcher detects manifest media types and -// fetches them from the correct endpoint. -type manifestFetcher struct { - underlying remotes.Fetcher - ref reference.Reference - httpClient *http.Client - authorizer docker.Authorizer - plainHTTP bool -} - -// isManifestMediaType returns true if the media type indicates a manifest. -func isManifestMediaType(mediaType string) bool { - switch mediaType { - case "application/vnd.oci.image.manifest.v1+json", - "application/vnd.oci.image.index.v1+json", - "application/vnd.docker.distribution.manifest.v2+json", - "application/vnd.docker.distribution.manifest.list.v2+json", - "application/vnd.docker.distribution.manifest.v1+json", - "application/vnd.docker.distribution.manifest.v1+prettyjws": - return true - } - return false -} - -// isHuggingFaceRegistry returns true if the host is a HuggingFace registry. -// HuggingFace doesn't serve manifests via /blobs/ endpoint, only via /manifests/. -func isHuggingFaceRegistry(host string) bool { - return strings.Contains(host, "huggingface.co") || strings.Contains(host, "hf.co") -} - -// Fetch fetches content by descriptor. For manifests, it uses /manifests/ endpoint -// to support registries like HuggingFace that don't serve manifests via /blobs/. 
-// For HuggingFace, we try /manifests/ first for ALL content types since they don't -// serve any manifest-like content via /blobs/. -func (f *manifestFetcher) Fetch(ctx context.Context, desc v1.Descriptor) (io.ReadCloser, error) { - registry := f.ref.Context().Registry - isHF := isHuggingFaceRegistry(registry.RegistryStr()) - - // For HuggingFace, try /manifests/ first for any JSON-like content - // since they don't serve manifests via /blobs/ at all - shouldUseManifestEndpoint := isHF && isManifestMediaType(desc.MediaType) - - // For non-manifest content on non-HF registries, use the underlying fetcher - if !shouldUseManifestEndpoint { - return f.underlying.Fetch(ctx, desc) - } - - // For manifests, fetch via /manifests/ endpoint to support HuggingFace - // Build the manifest URL: /v2//manifests/ - repo := f.ref.Context().RepositoryStr() - - // Determine scheme based on plainHTTP flag or registry's default scheme - scheme := registry.Scheme() - if f.plainHTTP { - scheme = "http" - } - - // For HuggingFace, use tag instead of digest because HF doesn't support - // fetching manifests by digest, only by tag - manifestRef := f.ref.Identifier() - if manifestRef == "" { - manifestRef = "latest" - } - - url := fmt.Sprintf("%s://%s/v2/%s/manifests/%s", - scheme, - registry.RegistryStr(), - repo, - manifestRef) - - req, err := http.NewRequestWithContext(ctx, http.MethodGet, url, http.NoBody) - if err != nil { - return nil, fmt.Errorf("creating manifest request: %w", err) - } - - // Set Accept header for the manifest media type - req.Header.Set("Accept", desc.MediaType) - - // Add authorization if available - if f.authorizer != nil { - if err := f.authorizer.Authorize(ctx, req); err != nil { - return nil, fmt.Errorf("authorizing manifest request: %w", err) - } - } - - resp, err := f.httpClient.Do(req) - if err != nil { - return nil, fmt.Errorf("fetching manifest: %w", err) - } - - if resp.StatusCode != http.StatusOK { - resp.Body.Close() - // If manifest endpoint fails, 
fall back to underlying fetcher (which uses /blobs/) - // This handles registries that do serve manifests via /blobs/ - return f.underlying.Fetch(ctx, desc) - } - - return resp.Body, nil } // resolverComponents holds the components created for a resolver. @@ -487,14 +382,11 @@ func Image(ref reference.Reference, opts ...Option) (oci.Image, error) { } return &remoteImage{ - ref: ref, - resolver: components.resolver, - desc: desc, - store: store, - ctx: o.ctx, - httpClient: components.httpClient, - authorizer: components.authorizer, - plainHTTP: components.plainHTTP, + ref: ref, + resolver: components.resolver, + desc: desc, + store: store, + ctx: o.ctx, }, nil } @@ -507,21 +399,11 @@ func (i *remoteImage) fetchManifest() error { return nil } - underlyingFetcher, err := i.resolver.Fetcher(i.ctx, i.ref.String()) + fetcher, err := i.resolver.Fetcher(i.ctx, i.ref.String()) if err != nil { return fmt.Errorf("getting fetcher: %w", err) } - // Wrap with manifest-aware fetcher to handle registries like HuggingFace - // that don't serve manifests via /blobs/ endpoint - fetcher := &manifestFetcher{ - underlying: underlyingFetcher, - ref: i.ref, - httpClient: i.httpClient, - authorizer: i.authorizer, - plainHTTP: i.plainHTTP, - } - // Fetch manifest rc, err := fetcher.Fetch(i.ctx, i.desc) if err != nil { diff --git a/pkg/distribution/packaging/safetensors.go b/pkg/distribution/packaging/safetensors.go index 40cac9843..2a5be421d 100644 --- a/pkg/distribution/packaging/safetensors.go +++ b/pkg/distribution/packaging/safetensors.go @@ -7,14 +7,9 @@ import ( "os" "path/filepath" "sort" - "strings" -) - -// ConfigExtensions defines the file extensions that should be treated as config files -var ConfigExtensions = []string{".md", ".txt", ".json", ".vocab", ".jinja"} -// SpecialConfigFiles are specific filenames treated as config files -var SpecialConfigFiles = []string{"tokenizer.model"} + "github.com/docker/model-runner/pkg/distribution/files" +) // PackageFromDirectory scans a 
directory for safetensors files and config files, // creating a temporary tar archive of the config files. @@ -37,15 +32,16 @@ func PackageFromDirectory(dirPath string) (safetensorsPaths []string, tempConfig name := entry.Name() fullPath := filepath.Join(dirPath, name) - // Collect safetensors files - lower := strings.ToLower(name) - if strings.HasSuffix(lower, ".safetensors") { - safetensorsPaths = append(safetensorsPaths, fullPath) - } + // Classify file using centralized files package + fileType := files.Classify(name) - // Collect config files - if isConfigFile(name) { + switch fileType { + case files.FileTypeSafetensors: + safetensorsPaths = append(safetensorsPaths, fullPath) + case files.FileTypeConfig, files.FileTypeChatTemplate: configFiles = append(configFiles, fullPath) + case files.FileTypeUnknown, files.FileTypeGGUF, files.FileTypeLicense: + // Skip these file types } } @@ -159,22 +155,3 @@ func addFileToTar(tw *tar.Writer, filePath string) error { return nil } - -// isConfigFile checks if a file should be included as a config file based on its name. -// It checks for extensions listed in ConfigExtensions and the special case of the tokenizer.model file. 
-func isConfigFile(name string) bool { - lower := strings.ToLower(name) - for _, ext := range ConfigExtensions { - if strings.HasSuffix(lower, ext) { - return true - } - } - - for _, special := range SpecialConfigFiles { - if strings.EqualFold(name, special) { - return true - } - } - - return false -} diff --git a/pkg/distribution/tarball/target_test.go b/pkg/distribution/tarball/target_test.go index 0474d478f..449a2e885 100644 --- a/pkg/distribution/tarball/target_test.go +++ b/pkg/distribution/tarball/target_test.go @@ -8,7 +8,7 @@ import ( "path/filepath" "testing" - "github.com/docker/model-runner/pkg/distribution/internal/gguf" + "github.com/docker/model-runner/pkg/distribution/builder" "github.com/docker/model-runner/pkg/distribution/oci" "github.com/docker/model-runner/pkg/distribution/tarball" ) @@ -27,10 +27,11 @@ func TestTarget(t *testing.T) { t.Fatalf("Failed to create tar target: %v", err) } - mdl, err := gguf.NewModel(filepath.Join("..", "assets", "dummy.gguf")) + b, err := builder.FromPath(filepath.Join("..", "assets", "dummy.gguf")) if err != nil { t.Fatalf("Failed to create model: %v", err) } + mdl := b.Model() blobContents, err := os.ReadFile(filepath.Join("..", "assets", "dummy.gguf")) if err != nil { diff --git a/pkg/inference/models/handler_test.go b/pkg/inference/models/handler_test.go index 1cf9e2f0e..0f520a6f9 100644 --- a/pkg/inference/models/handler_test.go +++ b/pkg/inference/models/handler_test.go @@ -61,7 +61,7 @@ func TestPullModel(t *testing.T) { // Prepare the OCI model artifact projectRoot := getProjectRoot(t) - model, err := builder.FromGGUF(filepath.Join(projectRoot, "assets", "dummy.gguf")) + model, err := builder.FromPath(filepath.Join(projectRoot, "assets", "dummy.gguf")) if err != nil { t.Fatalf("Failed to create model builder: %v", err) } @@ -159,7 +159,7 @@ func TestHandleGetModel(t *testing.T) { // Prepare the OCI model artifact projectRoot := getProjectRoot(t) - model, err := builder.FromGGUF(filepath.Join(projectRoot, 
"assets", "dummy.gguf")) + model, err := builder.FromPath(filepath.Join(projectRoot, "assets", "dummy.gguf")) if err != nil { t.Fatalf("Failed to create model builder: %v", err) } diff --git a/pkg/inference/models/manager.go b/pkg/inference/models/manager.go index 529a93e2e..1eb6c9ff6 100644 --- a/pkg/inference/models/manager.go +++ b/pkg/inference/models/manager.go @@ -81,7 +81,6 @@ func (m *Manager) GetLocal(ref string) (types.Model, error) { return nil, fmt.Errorf("model distribution service unavailable") } - // Query the model - first try without normalization (as ID), then with normalization model, err := m.distributionClient.GetModel(ref) if err != nil { return nil, fmt.Errorf("error while getting model: %w", err)