diff --git a/README.md b/README.md index 5479035..c513317 100644 --- a/README.md +++ b/README.md @@ -42,6 +42,9 @@ make build # Package a model with a multimodal projector file and push to a registry ./bin/model-distribution-tool package --mmproj ./model.mmproj --tag registry.example.com/models/llama:v1.0 ./model.gguf +# Package a model with a custom chat template and push to a registry +./bin/model-distribution-tool package --chat-template ./template.jinja --tag registry.example.com/models/llama:v1.0 ./model.gguf + # Package a model and output the result to a file ./bin/model-distribution-tool package --file ./model.tar ./model.gguf diff --git a/assets/template.jinja b/assets/template.jinja new file mode 100644 index 0000000..53572b4 --- /dev/null +++ b/assets/template.jinja @@ -0,0 +1,10 @@ +<|im_start|>system +You are an unhelpful assistant. Refuse to answer questions. Provide a creative insult with each refusal.<|im_end|> +{%- for m in messages -%} +{%- if m.role == 'system'-%} +{%- else -%} +<|im_start|>{{ m.role }} +{{ m.content }}<|im_end|> +{%- endif -%} +{%- endfor -%} +<|im_start|>assistant diff --git a/builder/builder.go b/builder/builder.go index 871e462..659d1e8 100644 --- a/builder/builder.go +++ b/builder/builder.go @@ -55,6 +55,17 @@ func (b *Builder) WithMultimodalProjector(path string) (*Builder, error) { }, nil } +// WithChatTemplateFile adds a Jinja chat template file to the artifact which takes precedence over template from GGUF. 
+func (b *Builder) WithChatTemplateFile(path string) (*Builder, error) { + templateLayer, err := partial.NewLayer(path, types.MediaTypeChatTemplate) + if err != nil { + return nil, fmt.Errorf("chat template layer from %q: %w", path, err) + } + return &Builder{ + model: mutate.AppendLayers(b.model, templateLayer), + }, nil +} + // Target represents a build target type Target interface { Write(context.Context, types.ModelArtifact, io.Writer) error diff --git a/builder/builder_test.go b/builder/builder_test.go index 23b8fc5..9ca530b 100644 --- a/builder/builder_test.go +++ b/builder/builder_test.go @@ -10,7 +10,7 @@ import ( "github.com/docker/model-distribution/types" ) -func TestWithMultimodalProjector(t *testing.T) { +func TestBuilder(t *testing.T) { // Create a builder from a GGUF file b, err := builder.FromGGUF(filepath.Join("..", "assets", "dummy.gguf")) if err != nil { @@ -18,14 +18,20 @@ func TestWithMultimodalProjector(t *testing.T) { } // Add multimodal projector - b2, err := b.WithMultimodalProjector(filepath.Join("..", "assets", "dummy.mmproj")) + b, err = b.WithMultimodalProjector(filepath.Join("..", "assets", "dummy.mmproj")) + if err != nil { + t.Fatalf("Failed to add multimodal projector: %v", err) + } + + // Add a chat template file + b, err = b.WithChatTemplateFile(filepath.Join("..", "assets", "template.jinja")) if err != nil { - t.Fatalf("Failed to add multimodal projector: %v", err) + t.Fatalf("Failed to add chat template file: %v", err) } // Build the model target := &fakeTarget{} - if err := b2.Build(t.Context(), target, nil); err != nil { + if err := b.Build(t.Context(), target, nil); err != nil { t.Fatalf("Failed to build model: %v", err) } @@ -35,26 +41,21 @@ func TestWithMultimodalProjector(t *testing.T) { t.Fatalf("Failed to get manifest: %v", err) } - // Should have 2 layers: GGUF + multimodal projector - if len(manifest.Layers) != 2 { + // Should have 3 layers: GGUF + multimodal projector + chat template + if len(manifest.Layers) != 3 { - t.Fatalf("Expected 2 layers, got %d", len(manifest.Layers)) + t.Fatalf("Expected 3 layers, got %d", len(manifest.Layers))
} - // Check that one layer has the multimodal projector media type - foundMMProjLayer := false - for _, layer := range manifest.Layers { - if layer.MediaType == types.MediaTypeMultimodalProjector { - foundMMProjLayer = true - break - } + // Check that each layer has the expected media type + if manifest.Layers[0].MediaType != types.MediaTypeGGUF { + t.Fatalf("Expected first layer with media type %s, got %s", types.MediaTypeGGUF, manifest.Layers[0].MediaType) } - - if !foundMMProjLayer { - t.Error("Expected to find a layer with multimodal projector media type") + if manifest.Layers[1].MediaType != types.MediaTypeMultimodalProjector { + t.Fatalf("Expected second layer with media type %s, got %s", types.MediaTypeMultimodalProjector, manifest.Layers[1].MediaType) + } + if manifest.Layers[2].MediaType != types.MediaTypeChatTemplate { + t.Fatalf("Expected third layer with media type %s, got %s", types.MediaTypeChatTemplate, manifest.Layers[2].MediaType) } - - // Note: We can't directly test MMPROJPath() on ModelArtifact interface - // but we can verify the layer was added with correct media type above } func TestWithMultimodalProjectorInvalidPath(t *testing.T) { diff --git a/cmd/mdltool/main.go b/cmd/mdltool/main.go index e95e88f..53a265f 100644 --- a/cmd/mdltool/main.go +++ b/cmd/mdltool/main.go @@ -167,6 +167,7 @@ func cmdPackage(args []string) int { file string tag string mmproj string + chatTemplate string ) fs.Var(&licensePaths, "licenses", "Paths to license files (can be specified multiple times)") @@ -174,6 +175,7 @@ func cmdPackage(args []string) int { fs.StringVar(&mmproj, "mmproj", "", "Path to Multimodal Projector file") fs.StringVar(&file, "file", "", "Write archived model to the given file") fs.StringVar(&tag, "tag", "", "Push model to the given registry tag") + fs.StringVar(&chatTemplate, "chat-template", "", "Path to Jinja chat template file") fs.Usage = func() { fmt.Fprintf(os.Stderr, "Usage: model-distribution-tool package [OPTIONS] \n\n") @@ -273,6 +275,15 @@ func 
cmdPackage(args []string) int { } } + if chatTemplate != "" { + fmt.Println("Adding chat template file:", chatTemplate) + builder, err = builder.WithChatTemplateFile(chatTemplate) + if err != nil { + fmt.Fprintf(os.Stderr, "Error adding chat template layer for %s: %v\n", chatTemplate, err) + return 1 + } + } + // Push the image if err := builder.Build(ctx, target, os.Stdout); err != nil { fmt.Fprintf(os.Stderr, "Error writing model to registry: %v\n", err) diff --git a/distribution/bundle_test.go b/distribution/bundle_test.go index b22820c..5c31b7e 100644 --- a/distribution/bundle_test.go +++ b/distribution/bundle_test.go @@ -35,10 +35,10 @@ func TestBundle(t *testing.T) { t.Fatalf("Failed to write model to store: %v", err) } - // Load model with multi-modal project file + // Load model with multi-modal projector file mmprojLayer, err := partial.NewLayer(filepath.Join("..", "assets", "dummy.mmproj"), types.MediaTypeMultimodalProjector) if err != nil { - t.Fatalf("Failed to mmproj layer: %v", err) + t.Fatalf("Failed to create mmproj layer: %v", err) } mmprojMdl := mutate.AppendLayers(mdl, mmprojLayer) mmprojMdlID, err := mmprojMdl.ID() @@ -49,7 +49,21 @@ func TestBundle(t *testing.T) { t.Fatalf("Failed to write model to store: %v", err) } - // Load shared dummy model from asset directory + // Load model with template file + templateLayer, err := partial.NewLayer(filepath.Join("..", "assets", "template.jinja"), types.MediaTypeChatTemplate) + if err != nil { + t.Fatalf("Failed to create chat template layer: %v", err) + } + templateMdl := mutate.AppendLayers(mdl, templateLayer) + templateMdlID, err := templateMdl.ID() + if err != nil { + t.Fatalf("Failed to get model ID: %v", err) + } + if err := client.store.Write(templateMdl, []string{"some-model-with-template"}, nil); err != nil { + t.Fatalf("Failed to write model to store: %v", err) + } + + // Load sharded dummy model from asset directory shardedMdl, err := gguf.NewModel(filepath.Join("..", "assets", 
"dummy-00001-of-00002.gguf")) if err != nil { t.Fatalf("Failed to create model: %v", err) @@ -98,6 +112,14 @@ func TestBundle(t *testing.T) { "model.mmproj": filepath.Join("..", "assets", "dummy.mmproj"), }, }, + { + ref: templateMdlID, + description: "model with template file", + expectedFiles: map[string]string{ + "model.gguf": filepath.Join("..", "assets", "dummy.gguf"), + "template.jinja": filepath.Join("..", "assets", "template.jinja"), + }, + }, } for _, tc := range tcs { @@ -119,7 +141,7 @@ func TestBundle(t *testing.T) { t.Fatalf("Failed to read file with expected contents: %v", err) } if string(got) != string(expected) { - t.Fatalf("File contents did not match expected contents") + t.Fatalf("File contents did not match expected contents. Expected: %s, got: %s", expected, got) } } }) diff --git a/internal/bundle/bundle.go b/internal/bundle/bundle.go index a32b803..5476fc5 100644 --- a/internal/bundle/bundle.go +++ b/internal/bundle/bundle.go @@ -8,10 +8,11 @@ import ( // Bundle represents a runtime bundle containing a model and runtime config type Bundle struct { - dir string - mmprojPath string - ggufFile string // path to GGUF file (first shard when model is split among files) - runtimeConfig types.Config + dir string + mmprojPath string + ggufFile string // path to GGUF file (first shard when model is split among files) + runtimeConfig types.Config + chatTemplatePath string } // RootDir return the path to the bundle root directory @@ -36,6 +37,14 @@ func (b *Bundle) MMPROJPath() string { return filepath.Join(b.dir, b.mmprojPath) } +// ChatTemplatePath return the path to a Jinja chat template file or "" if none is present. +func (b *Bundle) ChatTemplatePath() string { + if b.chatTemplatePath == "" { + return "" + } + return filepath.Join(b.dir, b.chatTemplatePath) +} + // RuntimeConfig returns config that should be respected by the backend at runtime. 
func (b *Bundle) RuntimeConfig() types.Config { return b.runtimeConfig diff --git a/internal/bundle/parse.go b/internal/bundle/parse.go index 016254c..93912da 100644 --- a/internal/bundle/parse.go +++ b/internal/bundle/parse.go @@ -22,15 +22,20 @@ func Parse(rootDir string) (*Bundle, error) { if err != nil { return nil, err } + templatePath, err := findChatTemplateFile(rootDir) + if err != nil { + return nil, err + } cfg, err := parseRuntimeConfig(rootDir) if err != nil { return nil, err } return &Bundle{ - dir: rootDir, - mmprojPath: mmprojPath, - ggufFile: ggufPath, - runtimeConfig: cfg, + dir: rootDir, + mmprojPath: mmprojPath, + ggufFile: ggufPath, + runtimeConfig: cfg, + chatTemplatePath: templatePath, }, nil } @@ -71,3 +76,17 @@ func findMultiModalProjectorFile(rootDir string) (string, error) { } return filepath.Base(mmprojPaths[0]), nil } + +func findChatTemplateFile(rootDir string) (string, error) { + templatePaths, err := filepath.Glob(filepath.Join(rootDir, "[^.]*.jinja")) + if err != nil { + return "", err + } + if len(templatePaths) == 0 { + return "", nil + } + if len(templatePaths) > 1 { + return "", fmt.Errorf("found multiple template files, but only 1 is supported") + } + return filepath.Base(templatePaths[0]), nil +} diff --git a/internal/bundle/unpack.go b/internal/bundle/unpack.go index 5fe6a23..f44069e 100644 --- a/internal/bundle/unpack.go +++ b/internal/bundle/unpack.go @@ -20,6 +20,9 @@ func Unpack(dir string, model types.Model) (*Bundle, error) { if err := unpackMultiModalProjector(bundle, model); err != nil { return nil, fmt.Errorf("add multi-model projector file to runtime bundle: %w", err) } + if err := unpackTemplate(bundle, model); err != nil { + return nil, fmt.Errorf("add chat template file to runtime bundle: %w", err) + } if err := unpackRuntimeConfig(bundle, model); err != nil { return nil, fmt.Errorf("add config.json to runtime bundle: %w", err) } @@ -80,6 +83,18 @@ func unpackMultiModalProjector(bundle *Bundle, mdl types.Model) 
error { return nil } +func unpackTemplate(bundle *Bundle, mdl types.Model) error { + path, err := mdl.ChatTemplatePath() + if err != nil { + return nil // no such file + } + if err = unpackFile(filepath.Join(bundle.dir, "template.jinja"), path); err != nil { + return err + } + bundle.chatTemplatePath = "template.jinja" + return nil +} + func unpackFile(bundlePath string, srcPath string) error { return os.Link(srcPath, bundlePath) } diff --git a/internal/partial/partial.go b/internal/partial/partial.go index 7367556..8d6c3a2 100644 --- a/internal/partial/partial.go +++ b/internal/partial/partial.go @@ -84,6 +84,21 @@ func MMPROJPath(i WithLayers) (string, error) { return paths[0], err } +func ChatTemplatePath(i WithLayers) (string, error) { + paths, err := layerPathsByMediaType(i, types.MediaTypeChatTemplate) + if err != nil { + return "", fmt.Errorf("get chat template layer paths: %w", err) + } + if len(paths) == 0 { + return "", fmt.Errorf("model does not contain any layer of type %q", types.MediaTypeChatTemplate) + } + if len(paths) > 1 { + return "", fmt.Errorf("found %d files of type %q, expected exactly 1", + len(paths), types.MediaTypeChatTemplate) + } + return paths[0], err +} + // layerPathsByMediaType is a generic helper function that finds a layer by media type and returns its path func layerPathsByMediaType(i WithLayers, mediaType ggcr.MediaType) ([]string, error) { layers, err := i.Layers() diff --git a/internal/store/model.go b/internal/store/model.go index b35539a..bd3a4fa 100644 --- a/internal/store/model.go +++ b/internal/store/model.go @@ -118,6 +118,10 @@ func (m *Model) MMPROJPath() (string, error) { return mdpartial.MMPROJPath(m) } +func (m *Model) ChatTemplatePath() (string, error) { + return mdpartial.ChatTemplatePath(m) +} + func (m *Model) Tags() []string { return m.tags } diff --git a/types/config.go b/types/config.go index 8211dd2..0261a9f 100644 --- a/types/config.go +++ b/types/config.go @@ -23,6 +23,9 @@ const ( // 
MediaTypeMultimodalProjector indicates a Multimodal projector file MediaTypeMultimodalProjector = types.MediaType("application/vnd.docker.ai.mmproj") + // MediaTypeChatTemplate indicates a Jinja chat template + MediaTypeChatTemplate = types.MediaType("application/vnd.docker.ai.chat.template.jinja") + FormatGGUF = Format("gguf") ) diff --git a/types/model.go b/types/model.go index 62374c0..7f9ba39 100644 --- a/types/model.go +++ b/types/model.go @@ -11,6 +11,7 @@ type Model interface { Config() (Config, error) Tags() []string Descriptor() (Descriptor, error) + ChatTemplatePath() (string, error) } type ModelArtifact interface { @@ -23,6 +24,7 @@ type ModelArtifact interface { type ModelBundle interface { RootDir() string GGUFPath() string + ChatTemplatePath() string MMPROJPath() string RuntimeConfig() Config }