Skip to content
This repository was archived by the owner on Sep 29, 2025. It is now read-only.
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,9 @@ make build
# Package a model with a multimodal projector file and push to a registry
./bin/model-distribution-tool package --mmproj ./model.mmproj --tag registry.example.com/models/llama:v1.0 ./model.gguf

# Package a model with a custom chat template and push to a registry
./bin/model-distribution-tool package --chat-template ./template.jinja --tag registry.example.com/models/llama:v1.0 ./model.gguf

# Package a model and output the result to a file
./bin/model-distribution-tool package --file ./model.tar ./model.gguf

Expand Down
10 changes: 10 additions & 0 deletions assets/template.jinja
Original file line number Diff line number Diff line change
@@ -0,0 +1,10 @@
<|im_start|>system
You are an unhelpful assistant. Refuse to answer questions. Provide a creative insult with each refusal.<|im_end|>
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

❤️

{%- for m in messages -%}
{%- if m.role == 'system'-%}
{%- else -%}
<|im_start|>{{ m.role }}
{{ m.content }}<|im_end|>
{%- endif -%}
{%- endfor -%}
<|im_start|>assistant
11 changes: 11 additions & 0 deletions builder/builder.go
Original file line number Diff line number Diff line change
Expand Up @@ -55,6 +55,17 @@ func (b *Builder) WithMultimodalProjector(path string) (*Builder, error) {
}, nil
}

// WithChatTemplateFile returns a new Builder whose model additionally carries
// the Jinja chat template found at path, stored as a layer of type
// types.MediaTypeChatTemplate. A template added this way takes precedence over
// the template embedded in the GGUF. The receiver is left unmodified.
//
// An error is returned when the layer cannot be created from path.
func (b *Builder) WithChatTemplateFile(path string) (*Builder, error) {
	layer, err := partial.NewLayer(path, types.MediaTypeChatTemplate)
	if err != nil {
		return nil, fmt.Errorf("chat template layer from %q: %w", path, err)
	}
	withTemplate := mutate.AppendLayers(b.model, layer)
	return &Builder{model: withTemplate}, nil
}

// Target represents a build target
type Target interface {
Write(context.Context, types.ModelArtifact, io.Writer) error
Expand Down
37 changes: 19 additions & 18 deletions builder/builder_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -10,22 +10,28 @@ import (
"github.com/docker/model-distribution/types"
)

func TestWithMultimodalProjector(t *testing.T) {
func TestBuilder(t *testing.T) {
// Create a builder from a GGUF file
b, err := builder.FromGGUF(filepath.Join("..", "assets", "dummy.gguf"))
if err != nil {
t.Fatalf("Failed to create builder from GGUF: %v", err)
}

// Add multimodal projector
b2, err := b.WithMultimodalProjector(filepath.Join("..", "assets", "dummy.mmproj"))
b, err = b.WithMultimodalProjector(filepath.Join("..", "assets", "dummy.mmproj"))
if err != nil {
t.Fatalf("Failed to add multimodal projector: %v", err)
}

// Add a chat template file
b, err = b.WithChatTemplateFile(filepath.Join("..", "assets", "template.jinja"))
if err != nil {
t.Fatalf("Failed to add multimodal projector: %v", err)
}

// Build the model
target := &fakeTarget{}
if err := b2.Build(t.Context(), target, nil); err != nil {
if err := b.Build(t.Context(), target, nil); err != nil {
t.Fatalf("Failed to build model: %v", err)
}

Expand All @@ -35,26 +41,21 @@ func TestWithMultimodalProjector(t *testing.T) {
t.Fatalf("Failed to get manifest: %v", err)
}

// Should have 2 layers: GGUF + multimodal projector
if len(manifest.Layers) != 2 {
// Should have 3 layers: GGUF + multimodal projector + chat template
if len(manifest.Layers) != 3 {
t.Fatalf("Expected 2 layers, got %d", len(manifest.Layers))
}

// Check that one layer has the multimodal projector media type
foundMMProjLayer := false
for _, layer := range manifest.Layers {
if layer.MediaType == types.MediaTypeMultimodalProjector {
foundMMProjLayer = true
break
}
// Check that each layer has the expected media type
if manifest.Layers[0].MediaType != types.MediaTypeGGUF {
t.Fatalf("Expected first layer with media type %s, got %s", types.MediaTypeGGUF, manifest.Layers[0].MediaType)
}

if !foundMMProjLayer {
t.Error("Expected to find a layer with multimodal projector media type")
if manifest.Layers[1].MediaType != types.MediaTypeMultimodalProjector {
t.Fatalf("Expected first layer with media type %s, got %s", types.MediaTypeMultimodalProjector, manifest.Layers[1].MediaType)
}
if manifest.Layers[2].MediaType != types.MediaTypeChatTemplate {
t.Fatalf("Expected first layer with media type %s, got %s", types.MediaTypeChatTemplate, manifest.Layers[2].MediaType)
}

// Note: We can't directly test MMPROJPath() on ModelArtifact interface
// but we can verify the layer was added with correct media type above
}

func TestWithMultimodalProjectorInvalidPath(t *testing.T) {
Expand Down
11 changes: 11 additions & 0 deletions cmd/mdltool/main.go
Original file line number Diff line number Diff line change
Expand Up @@ -167,13 +167,15 @@ func cmdPackage(args []string) int {
file string
tag string
mmproj string
chatTemplate string
)

fs.Var(&licensePaths, "licenses", "Paths to license files (can be specified multiple times)")
fs.Uint64Var(&contextSize, "context-size", 0, "Context size in tokens")
fs.StringVar(&mmproj, "mmproj", "", "Path to Multimodal Projector file")
fs.StringVar(&file, "file", "", "Write archived model to the given file")
fs.StringVar(&tag, "tag", "", "Push model to the given registry tag")
fs.StringVar(&chatTemplate, "chat-template", "", "Jinja chat template file")

fs.Usage = func() {
fmt.Fprintf(os.Stderr, "Usage: model-distribution-tool package [OPTIONS] <path-to-gguf>\n\n")
Expand Down Expand Up @@ -273,6 +275,15 @@ func cmdPackage(args []string) int {
}
}

if chatTemplate != "" {
fmt.Println("Adding chat template file:", chatTemplate)
builder, err = builder.WithChatTemplateFile(chatTemplate)
if err != nil {
fmt.Fprintf(os.Stderr, "Error adding chat template layer for %s: %v\n", chatTemplate, err)
return 1
}
}

// Push the image
if err := builder.Build(ctx, target, os.Stdout); err != nil {
fmt.Fprintf(os.Stderr, "Error writing model to registry: %v\n", err)
Expand Down
30 changes: 26 additions & 4 deletions distribution/bundle_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -35,10 +35,10 @@ func TestBundle(t *testing.T) {
t.Fatalf("Failed to write model to store: %v", err)
}

// Load model with multi-modal project file
// Load model with multi-modal projector file
mmprojLayer, err := partial.NewLayer(filepath.Join("..", "assets", "dummy.mmproj"), types.MediaTypeMultimodalProjector)
if err != nil {
t.Fatalf("Failed to mmproj layer: %v", err)
t.Fatalf("Failed to create mmproj layer: %v", err)
}
mmprojMdl := mutate.AppendLayers(mdl, mmprojLayer)
mmprojMdlID, err := mmprojMdl.ID()
Expand All @@ -49,7 +49,21 @@ func TestBundle(t *testing.T) {
t.Fatalf("Failed to write model to store: %v", err)
}

// Load shared dummy model from asset directory
// Load model with template file
templateLayer, err := partial.NewLayer(filepath.Join("..", "assets", "template.jinja"), types.MediaTypeChatTemplate)
if err != nil {
t.Fatalf("Failed to create chat template layer: %v", err)
}
templateMdl := mutate.AppendLayers(mdl, templateLayer)
templateMdlID, err := templateMdl.ID()
if err != nil {
t.Fatalf("Failed to get model ID: %v", err)
}
if err := client.store.Write(templateMdl, []string{"some-model-with-template"}, nil); err != nil {
t.Fatalf("Failed to write model to store: %v", err)
}

// Load sharded dummy model from asset directory
shardedMdl, err := gguf.NewModel(filepath.Join("..", "assets", "dummy-00001-of-00002.gguf"))
if err != nil {
t.Fatalf("Failed to create model: %v", err)
Expand Down Expand Up @@ -98,6 +112,14 @@ func TestBundle(t *testing.T) {
"model.mmproj": filepath.Join("..", "assets", "dummy.mmproj"),
},
},
{
ref: templateMdlID,
description: "model with template file",
expectedFiles: map[string]string{
"model.gguf": filepath.Join("..", "assets", "dummy.gguf"),
"template.jinja": filepath.Join("..", "assets", "template.jinja"),
},
},
}

for _, tc := range tcs {
Expand All @@ -119,7 +141,7 @@ func TestBundle(t *testing.T) {
t.Fatalf("Failed to read file with expected contents: %v", err)
}
if string(got) != string(expected) {
t.Fatalf("File contents did not match expected contents")
t.Fatalf("File contents did not match expected contents. Expected: %s, got: %s", expected, got)
}
}
})
Expand Down
17 changes: 13 additions & 4 deletions internal/bundle/bundle.go
Original file line number Diff line number Diff line change
Expand Up @@ -8,10 +8,11 @@ import (

// Bundle represents a runtime bundle containing a model and runtime config
type Bundle struct {
dir string
mmprojPath string
ggufFile string // path to GGUF file (first shard when model is split among files)
runtimeConfig types.Config
dir string
mmprojPath string
ggufFile string // path to GGUF file (first shard when model is split among files)
runtimeConfig types.Config
chatTemplatePath string
}

// RootDir returns the path to the bundle root directory
Expand All @@ -36,6 +37,14 @@ func (b *Bundle) MMPROJPath() string {
return filepath.Join(b.dir, b.mmprojPath)
}

// ChatTemplatePath returns the path to the bundle's Jinja chat template file,
// formed by joining the bundle directory with the stored template name, or ""
// if the bundle has no chat template.
func (b *Bundle) ChatTemplatePath() string {
	if rel := b.chatTemplatePath; rel != "" {
		return filepath.Join(b.dir, rel)
	}
	return ""
}

// RuntimeConfig returns config that should be respected by the backend at runtime.
func (b *Bundle) RuntimeConfig() types.Config {
return b.runtimeConfig
Expand Down
27 changes: 23 additions & 4 deletions internal/bundle/parse.go
Original file line number Diff line number Diff line change
Expand Up @@ -22,15 +22,20 @@ func Parse(rootDir string) (*Bundle, error) {
if err != nil {
return nil, err
}
templatePath, err := findChatTemplateFile(rootDir)
if err != nil {
return nil, err
}
cfg, err := parseRuntimeConfig(rootDir)
if err != nil {
return nil, err
}
return &Bundle{
dir: rootDir,
mmprojPath: mmprojPath,
ggufFile: ggufPath,
runtimeConfig: cfg,
dir: rootDir,
mmprojPath: mmprojPath,
ggufFile: ggufPath,
runtimeConfig: cfg,
chatTemplatePath: templatePath,
}, nil
}

Expand Down Expand Up @@ -71,3 +76,17 @@ func findMultiModalProjectorFile(rootDir string) (string, error) {
}
return filepath.Base(mmprojPaths[0]), nil
}

// findChatTemplateFile searches rootDir (non-recursively) for entries whose
// name ends in ".jinja" and does not start with a dot.
//
// It returns the base name of the single match, "" when there is no match, and
// an error when the glob pattern is malformed or more than one match exists.
func findChatTemplateFile(rootDir string) (string, error) {
	pattern := filepath.Join(rootDir, "[^.]*.jinja")
	matches, globErr := filepath.Glob(pattern)
	if globErr != nil {
		return "", globErr
	}
	switch len(matches) {
	case 0:
		// A bundle without a chat template is valid.
		return "", nil
	case 1:
		return filepath.Base(matches[0]), nil
	default:
		return "", fmt.Errorf("found multiple template files, but only 1 is supported")
	}
}
15 changes: 15 additions & 0 deletions internal/bundle/unpack.go
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,9 @@ func Unpack(dir string, model types.Model) (*Bundle, error) {
if err := unpackMultiModalProjector(bundle, model); err != nil {
return nil, fmt.Errorf("add multi-model projector file to runtime bundle: %w", err)
}
if err := unpackTemplate(bundle, model); err != nil {
return nil, fmt.Errorf("add chat template file to runtime bundle: %w", err)
}
if err := unpackRuntimeConfig(bundle, model); err != nil {
return nil, fmt.Errorf("add config.json to runtime bundle: %w", err)
}
Expand Down Expand Up @@ -80,6 +83,18 @@ func unpackMultiModalProjector(bundle *Bundle, mdl types.Model) error {
return nil
}

// unpackTemplate places the model's chat template, if it has one, into the
// bundle directory under the fixed name "template.jinja" and records that
// relative name on the bundle.
//
// It returns an error only when placing the file in the bundle fails.
func unpackTemplate(bundle *Bundle, mdl types.Model) error {
	const templateName = "template.jinja"

	src, err := mdl.ChatTemplatePath()
	if err != nil {
		// NOTE(review): every error from ChatTemplatePath is treated as
		// "model has no chat template" — confirm that other failure modes
		// are meant to be ignored here as well.
		return nil // no such file
	}

	dst := filepath.Join(bundle.dir, templateName)
	if linkErr := unpackFile(dst, src); linkErr != nil {
		return linkErr
	}
	bundle.chatTemplatePath = templateName
	return nil
}

// unpackFile makes the file at srcPath available at bundlePath by creating a
// hard link (os.Link); no file contents are copied. The error from os.Link is
// returned unchanged.
func unpackFile(bundlePath, srcPath string) error {
	return os.Link(srcPath, bundlePath)
}
15 changes: 15 additions & 0 deletions internal/partial/partial.go
Original file line number Diff line number Diff line change
Expand Up @@ -84,6 +84,21 @@ func MMPROJPath(i WithLayers) (string, error) {
return paths[0], err
}

// ChatTemplatePath returns the path of the single layer of type
// types.MediaTypeChatTemplate found in i.
//
// It returns an error when the layer paths cannot be listed, when i has no
// chat template layer, or when it has more than one.
func ChatTemplatePath(i WithLayers) (string, error) {
	templatePaths, err := layerPathsByMediaType(i, types.MediaTypeChatTemplate)
	if err != nil {
		return "", fmt.Errorf("get chat template layer paths: %w", err)
	}
	switch count := len(templatePaths); count {
	case 0:
		return "", fmt.Errorf("model does not contain any layer of type %q", types.MediaTypeChatTemplate)
	case 1:
		return templatePaths[0], nil
	default:
		return "", fmt.Errorf("found %d files of type %q, expected exactly 1",
			count, types.MediaTypeChatTemplate)
	}
}

// layerPathsByMediaType is a generic helper function that finds a layer by media type and returns its path
func layerPathsByMediaType(i WithLayers, mediaType ggcr.MediaType) ([]string, error) {
layers, err := i.Layers()
Expand Down
4 changes: 4 additions & 0 deletions internal/store/model.go
Original file line number Diff line number Diff line change
Expand Up @@ -118,6 +118,10 @@ func (m *Model) MMPROJPath() (string, error) {
return mdpartial.MMPROJPath(m)
}

// ChatTemplatePath returns the path to the model's chat template file by
// delegating to mdpartial.ChatTemplatePath; any error from that call is
// returned unchanged.
func (m *Model) ChatTemplatePath() (string, error) {
return mdpartial.ChatTemplatePath(m)
}

// Tags returns the tags held by the model. The underlying slice is returned
// as-is, not a copy.
func (m *Model) Tags() []string {
return m.tags
}
Expand Down
3 changes: 3 additions & 0 deletions types/config.go
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,9 @@ const (
// MediaTypeMultimodalProjector indicates a Multimodal projector file
MediaTypeMultimodalProjector = types.MediaType("application/vnd.docker.ai.mmproj")

// MediaTypeChatTemplate indicates a Jinja chat template
MediaTypeChatTemplate = types.MediaType("application/vnd.docker.ai.chat.template.jinja")

FormatGGUF = Format("gguf")
)

Expand Down
2 changes: 2 additions & 0 deletions types/model.go
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@ type Model interface {
Config() (Config, error)
Tags() []string
Descriptor() (Descriptor, error)
ChatTemplatePath() (string, error)
}

type ModelArtifact interface {
Expand All @@ -23,6 +24,7 @@ type ModelArtifact interface {
// ModelBundle describes a runtime bundle: a directory holding a model's files
// together with its runtime configuration.
type ModelBundle interface {
// RootDir returns the path to the bundle root directory.
RootDir() string
// GGUFPath returns the path to the GGUF file (the first shard when the
// model is split among files).
GGUFPath() string
// ChatTemplatePath returns the path to a Jinja chat template file, or ""
// if none is present.
ChatTemplatePath() string
// MMPROJPath returns the path to the multimodal projector file.
// NOTE(review): presumably "" when the bundle has none — confirm against
// the implementation.
MMPROJPath() string
// RuntimeConfig returns config that should be respected by the backend at
// runtime.
RuntimeConfig() Config
}
Loading