Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion go.mod
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@ go 1.23.7
require (
github.com/containerd/containerd/v2 v2.0.4
github.com/containerd/platforms v1.0.0-rc.1
github.com/docker/model-distribution v0.0.0-20250618082521-fb5c8332c857
github.com/docker/model-distribution v0.0.0-20250627163720-aff34abcf3e0
github.com/google/go-containerregistry v0.20.3
github.com/jaypipes/ghw v0.16.0
github.com/mattn/go-shellwords v1.0.12
Expand Down
4 changes: 2 additions & 2 deletions go.sum
Original file line number Diff line number Diff line change
Expand Up @@ -38,8 +38,8 @@ github.com/docker/distribution v2.8.3+incompatible h1:AtKxIZ36LoNK51+Z6RpzLpddBi
github.com/docker/distribution v2.8.3+incompatible/go.mod h1:J2gT2udsDAN96Uj4KfcMRqY0/ypR+oyYUYmja8H+y+w=
github.com/docker/docker-credential-helpers v0.8.2 h1:bX3YxiGzFP5sOXWc3bTPEXdEaZSeVMrFgOr3T+zrFAo=
github.com/docker/docker-credential-helpers v0.8.2/go.mod h1:P3ci7E3lwkZg6XiHdRKft1KckHiO9a2rNtyFbZ/ry9M=
github.com/docker/model-distribution v0.0.0-20250618082521-fb5c8332c857 h1:2IvvpdPZvpNn06+RUh5DC5O64dnrKjdsBKCMrzR5QTk=
github.com/docker/model-distribution v0.0.0-20250618082521-fb5c8332c857/go.mod h1:dThpO9JoG5Px3i+rTluAeZcqLGw8C0qepuEL4gL2o/c=
github.com/docker/model-distribution v0.0.0-20250627163720-aff34abcf3e0 h1:bve4JZI06Admw+NewtPfrpJXsvRnGKTQvBOEICNC1C0=
github.com/docker/model-distribution v0.0.0-20250627163720-aff34abcf3e0/go.mod h1:dThpO9JoG5Px3i+rTluAeZcqLGw8C0qepuEL4gL2o/c=
github.com/felixge/httpsnoop v1.0.4 h1:NFTV2Zj1bL4mc9sqWACXbQFVBBg2W3GPvqp8/ESS2Wg=
github.com/felixge/httpsnoop v1.0.4/go.mod h1:m8KPJKqk1gH5J9DgRY2ASl2lWCfGKXixSwevea8zH2U=
github.com/go-logr/logr v1.2.2/go.mod h1:jdQByPbusPIv2/zmleS9BjJVeZ6kBagPoEUsqbVz/1A=
Expand Down
16 changes: 5 additions & 11 deletions pkg/inference/backends/llamacpp/llamacpp.go
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,6 @@ import (
"os/exec"
"path/filepath"
"runtime"
"strconv"
"strings"

"github.com/docker/model-runner/pkg/diskusage"
Expand Down Expand Up @@ -122,10 +121,9 @@ func (l *llamaCpp) Install(ctx context.Context, httpClient *http.Client) error {

// Run implements inference.Backend.Run.
func (l *llamaCpp) Run(ctx context.Context, socket, model string, mode inference.BackendMode, config *inference.BackendConfiguration) error {
modelPath, err := l.modelManager.GetModelPath(model)
l.log.Infof("Model path: %s", modelPath)
mdl, err := l.modelManager.GetModel(model)
if err != nil {
return fmt.Errorf("failed to get model path: %w", err)
return fmt.Errorf("failed to get model: %w", err)
}

if err := os.RemoveAll(socket); err != nil && !errors.Is(err, fs.ErrNotExist) {
Expand All @@ -138,13 +136,9 @@ func (l *llamaCpp) Run(ctx context.Context, socket, model string, mode inference
binPath = l.updatedServerStoragePath
}

args := l.config.GetArgs(modelPath, socket, mode)

if config != nil {
if config.ContextSize >= 0 {
args = append(args, "--ctx-size", strconv.Itoa(int(config.ContextSize)))
}
args = append(args, config.RuntimeFlags...)
args, err := l.config.GetArgs(mdl, socket, mode, config)
if err != nil {
return fmt.Errorf("failed to get args for llama.cpp: %w", err)
}

l.log.Infof("llamaCppArgs: %v", args)
Expand Down
29 changes: 27 additions & 2 deletions pkg/inference/backends/llamacpp/llamacpp_config.go
Original file line number Diff line number Diff line change
@@ -1,9 +1,11 @@
package llamacpp

import (
"fmt"
"runtime"
"strconv"

"github.com/docker/model-distribution/types"
"github.com/docker/model-runner/pkg/inference"
)

Expand Down Expand Up @@ -33,10 +35,20 @@ func NewDefaultLlamaCppConfig() *Config {
}

// GetArgs implements BackendConfig.GetArgs.
func (c *Config) GetArgs(modelPath, socket string, mode inference.BackendMode) []string {
func (c *Config) GetArgs(model types.Model, socket string, mode inference.BackendMode, config *inference.BackendConfiguration) ([]string, error) {
// Start with the arguments from LlamaCppConfig
args := append([]string{}, c.Args...)

modelPath, err := model.GGUFPath()
if err != nil {
return nil, fmt.Errorf("get gguf path: %w", err)
}

modelCfg, err := model.Config()
if err != nil {
return nil, fmt.Errorf("get model config: %w", err)
}

// Add model and socket arguments
args = append(args, "--model", modelPath, "--host", socket)

Expand All @@ -45,7 +57,20 @@ func (c *Config) GetArgs(modelPath, socket string, mode inference.BackendMode) [
args = append(args, "--embeddings")
}

return args
// Add arguments from model config
if modelCfg.ContextSize != nil {
args = append(args, "--ctx-size", strconv.FormatUint(*modelCfg.ContextSize, 10))
}

// Add arguments from backend config
if config != nil {
if config.ContextSize > 0 && !containsArg(args, "--ctx-size") {
args = append(args, "--ctx-size", strconv.FormatInt(config.ContextSize, 10))
}
args = append(args, config.RuntimeFlags...)
}

return args, nil
}

// containsArg checks if the given argument is already in the args slice.
Expand Down
105 changes: 104 additions & 1 deletion pkg/inference/backends/llamacpp/llamacpp_config_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@ import (
"strconv"
"testing"

"github.com/docker/model-distribution/types"
"github.com/docker/model-runner/pkg/inference"
)

Expand Down Expand Up @@ -72,12 +73,17 @@ func TestGetArgs(t *testing.T) {

tests := []struct {
name string
model types.Model
mode inference.BackendMode
config *inference.BackendConfiguration
expected []string
}{
{
name: "completion mode",
mode: inference.BackendModeCompletion,
model: &fakeModel{
ggufPath: modelPath,
},
expected: []string{
"--jinja",
"-ngl", "100",
Expand All @@ -89,20 +95,86 @@ func TestGetArgs(t *testing.T) {
{
name: "embedding mode",
mode: inference.BackendModeEmbedding,
model: &fakeModel{
ggufPath: modelPath,
},
expected: []string{
"--jinja",
"-ngl", "100",
"--metrics",
"--model", modelPath,
"--host", socket,
"--embeddings",
},
},
{
name: "context size from backend config",
mode: inference.BackendModeEmbedding,
model: &fakeModel{
ggufPath: modelPath,
},
config: &inference.BackendConfiguration{
ContextSize: 1234,
},
expected: []string{
"--jinja",
"-ngl", "100",
"--metrics",
"--model", modelPath,
"--host", socket,
"--embeddings",
"--ctx-size", "1234", // should add this flag
},
},
{
name: "context size from model config",
mode: inference.BackendModeEmbedding,
model: &fakeModel{
ggufPath: modelPath,
config: types.Config{
ContextSize: uint64ptr(2096),
},
},
config: &inference.BackendConfiguration{
ContextSize: 1234,
},
expected: []string{
"--jinja",
"-ngl", "100",
"--metrics",
"--model", modelPath,
"--host", socket,
"--embeddings",
"--ctx-size", "2096", // model config takes precedence
},
},
{
name: "raw flags from backend config",
mode: inference.BackendModeEmbedding,
model: &fakeModel{
ggufPath: modelPath,
},
config: &inference.BackendConfiguration{
RuntimeFlags: []string{"--some", "flag"},
},
expected: []string{
"--jinja",
"-ngl", "100",
"--metrics",
"--model", modelPath,
"--host", socket,
"--embeddings",
"--some", "flag", // model config takes precedence
},
},
}

for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
args := config.GetArgs(modelPath, socket, tt.mode)
args, err := config.GetArgs(tt.model, socket, tt.mode, tt.config)
if err != nil {
t.Errorf("GetArgs() error = %v", err)
}

// Check that all expected arguments are present and in the correct order
expectedIndex := 0
Expand Down Expand Up @@ -171,3 +243,34 @@ func TestContainsArg(t *testing.T) {
})
}
}

var _ types.Model = &fakeModel{}

type fakeModel struct {
ggufPath string
config types.Config
}

func (f *fakeModel) ID() (string, error) {
panic("shouldn't be called")
}

func (f *fakeModel) GGUFPath() (string, error) {
return f.ggufPath, nil
}

func (f *fakeModel) Config() (types.Config, error) {
return f.config, nil
}

func (f *fakeModel) Tags() []string {
panic("shouldn't be called")
}

func (f fakeModel) Descriptor() (types.Descriptor, error) {
panic("shouldn't be called")
}

func uint64ptr(n uint64) *uint64 {
return &n
}
3 changes: 2 additions & 1 deletion pkg/inference/config/config.go
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
package config

import (
"github.com/docker/model-distribution/types"
"github.com/docker/model-runner/pkg/inference"
)

Expand All @@ -11,5 +12,5 @@ type BackendConfig interface {
// GetArgs returns the command-line arguments for the backend.
// It takes the model path, socket, and mode as input and returns
// the appropriate arguments for the backend.
GetArgs(modelPath, socket string, mode inference.BackendMode) []string
GetArgs(model types.Model, socket string, mode inference.BackendMode, config *inference.BackendConfiguration) ([]string, error)
}