2 changes: 1 addition & 1 deletion go.mod
@@ -5,7 +5,7 @@ go 1.23.7
require (
github.com/containerd/containerd/v2 v2.0.4
github.com/containerd/platforms v1.0.0-rc.1
github.com/docker/model-distribution v0.0.0-20250618082521-fb5c8332c857
github.com/docker/model-distribution v0.0.0-20250627152504-c0c68acaabd5
github.com/google/go-containerregistry v0.20.3
github.com/jaypipes/ghw v0.16.0
github.com/mattn/go-shellwords v1.0.12
2 changes: 2 additions & 0 deletions go.sum
@@ -40,6 +40,8 @@ github.com/docker/docker-credential-helpers v0.8.2 h1:bX3YxiGzFP5sOXWc3bTPEXdEaZ
github.com/docker/docker-credential-helpers v0.8.2/go.mod h1:P3ci7E3lwkZg6XiHdRKft1KckHiO9a2rNtyFbZ/ry9M=
github.com/docker/model-distribution v0.0.0-20250618082521-fb5c8332c857 h1:2IvvpdPZvpNn06+RUh5DC5O64dnrKjdsBKCMrzR5QTk=
github.com/docker/model-distribution v0.0.0-20250618082521-fb5c8332c857/go.mod h1:dThpO9JoG5Px3i+rTluAeZcqLGw8C0qepuEL4gL2o/c=
github.com/docker/model-distribution v0.0.0-20250627152504-c0c68acaabd5 h1:4qs7OwrJJQE1de5kbHg83gm5rmCRQzAwAp8gG/3cLY8=
github.com/docker/model-distribution v0.0.0-20250627152504-c0c68acaabd5/go.mod h1:dThpO9JoG5Px3i+rTluAeZcqLGw8C0qepuEL4gL2o/c=
github.com/felixge/httpsnoop v1.0.4 h1:NFTV2Zj1bL4mc9sqWACXbQFVBBg2W3GPvqp8/ESS2Wg=
github.com/felixge/httpsnoop v1.0.4/go.mod h1:m8KPJKqk1gH5J9DgRY2ASl2lWCfGKXixSwevea8zH2U=
github.com/go-logr/logr v1.2.2/go.mod h1:jdQByPbusPIv2/zmleS9BjJVeZ6kBagPoEUsqbVz/1A=
16 changes: 5 additions & 11 deletions pkg/inference/backends/llamacpp/llamacpp.go
@@ -11,7 +11,6 @@ import (
"os/exec"
"path/filepath"
"runtime"
"strconv"
"strings"

"github.com/docker/model-runner/pkg/diskusage"
@@ -122,10 +121,9 @@ func (l *llamaCpp) Install(ctx context.Context, httpClient *http.Client) error {

// Run implements inference.Backend.Run.
func (l *llamaCpp) Run(ctx context.Context, socket, model string, mode inference.BackendMode, config *inference.BackendConfiguration) error {
modelPath, err := l.modelManager.GetModelPath(model)
l.log.Infof("Model path: %s", modelPath)
mdl, err := l.modelManager.GetModel(model)
if err != nil {
return fmt.Errorf("failed to get model path: %w", err)
return fmt.Errorf("failed to get model: %w", err)
}

if err := os.RemoveAll(socket); err != nil && !errors.Is(err, fs.ErrNotExist) {
@@ -138,13 +136,9 @@ func (l *llamaCpp) Run(ctx context.Context, socket, model string, mode inference
binPath = l.updatedServerStoragePath
}

args := l.config.GetArgs(modelPath, socket, mode)

if config != nil {
if config.ContextSize >= 0 {
args = append(args, "--ctx-size", strconv.Itoa(int(config.ContextSize)))
}
args = append(args, config.RuntimeFlags...)
args, err := l.config.GetArgs(mdl, socket, mode, config)
if err != nil {
return fmt.Errorf("failed to get args for llama.cpp: %w", err)
}

l.log.Infof("llamaCppArgs: %v", args)
29 changes: 27 additions & 2 deletions pkg/inference/backends/llamacpp/llamacpp_config.go
@@ -1,9 +1,11 @@
package llamacpp

import (
"fmt"
"runtime"
"strconv"

"github.com/docker/model-distribution/types"
"github.com/docker/model-runner/pkg/inference"
)

@@ -33,10 +35,20 @@ func NewDefaultLlamaCppConfig() *Config {
}

// GetArgs implements BackendConfig.GetArgs.
func (c *Config) GetArgs(modelPath, socket string, mode inference.BackendMode) []string {
func (c *Config) GetArgs(model types.Model, socket string, mode inference.BackendMode, config *inference.BackendConfiguration) ([]string, error) {
// Start with the arguments from LlamaCppConfig
args := append([]string{}, c.Args...)

modelPath, err := model.GGUFPath()
if err != nil {
return nil, fmt.Errorf("get gguf path: %v", err)
}

modelCfg, err := model.Config()
if err != nil {
return nil, fmt.Errorf("get model config: %v", err)
}

// Add model and socket arguments
args = append(args, "--model", modelPath, "--host", socket)

@@ -45,7 +57,20 @@ func (c *Config) GetArgs(modelPath, socket string, mode inference.BackendMode) [
args = append(args, "--embeddings")
}

return args
// Add arguments from model config
if modelCfg.ContextSize != nil {
args = append(args, "--ctx-size", strconv.FormatUint(*modelCfg.ContextSize, 10))
}

// Add arguments from backend config
if config != nil {
if config.ContextSize > 0 && !containsArg(args, "--ctx-size") {
args = append(args, "--ctx-size", fmt.Sprintf("%d", config.ContextSize))
}
args = append(args, config.RuntimeFlags...)
}
Contributor:

I would unify the number-to-string conversion; strconv.Itoa is probably the most efficient (vs. strconv.FormatUint or fmt.Sprintf).

Contributor Author:

I can use strconv.FormatInt and strconv.FormatUint for the backend config and artifact config respectively. strconv.Itoa accepts an int type rather than an int64. I was assuming we chose int64 in the backend config for a reason and wouldn't want to risk losing precision.

Contributor:

Ah, roger that.
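
A minimal standalone sketch of the conversions under discussion; the variable names and values are illustrative only, and the int64/uint64 types follow the thread above:

```go
package main

import (
	"fmt"
	"strconv"
)

func main() {
	var backendCtx int64 = 1234 // BackendConfiguration.ContextSize is an int64, per the thread
	var modelCtx uint64 = 2096  // types.Config.ContextSize is a *uint64

	// strconv.Itoa accepts an int, so these 64-bit values would need a
	// cast that could lose precision; the sized variants avoid that.
	fmt.Println(strconv.FormatInt(backendCtx, 10))  // "1234"
	fmt.Println(strconv.FormatUint(modelCtx, 10))   // "2096"
}
```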


return args, nil
}

// containsArg checks if the given argument is already in the args slice.
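
The body of containsArg is collapsed in this view. A minimal sketch consistent with its call site in GetArgs might look like the following; the actual implementation may differ:

```go
// containsArg reports whether arg already appears in args.
// Hypothetical sketch — the real body is collapsed above.
func containsArg(args []string, arg string) bool {
	for _, a := range args {
		if a == arg {
			return true
		}
	}
	return false
}
```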
105 changes: 104 additions & 1 deletion pkg/inference/backends/llamacpp/llamacpp_config_test.go
@@ -5,6 +5,7 @@ import (
"strconv"
"testing"

"github.com/docker/model-distribution/types"
"github.com/docker/model-runner/pkg/inference"
)

@@ -72,12 +73,17 @@ func TestGetArgs(t *testing.T) {

tests := []struct {
name string
model types.Model
mode inference.BackendMode
config *inference.BackendConfiguration
expected []string
}{
{
name: "completion mode",
mode: inference.BackendModeCompletion,
model: &fakeModel{
ggufPath: modelPath,
},
expected: []string{
"--jinja",
"-ngl", "100",
@@ -89,20 +95,86 @@
{
name: "embedding mode",
mode: inference.BackendModeEmbedding,
model: &fakeModel{
ggufPath: modelPath,
},
expected: []string{
"--jinja",
"-ngl", "100",
"--metrics",
"--model", modelPath,
"--host", socket,
"--embeddings",
},
},
{
name: "context size from backend config",
mode: inference.BackendModeEmbedding,
model: &fakeModel{
ggufPath: modelPath,
},
config: &inference.BackendConfiguration{
ContextSize: 1234,
},
expected: []string{
"--jinja",
"-ngl", "100",
"--metrics",
"--model", modelPath,
"--host", socket,
"--embeddings",
"--ctx-size", "1234", // should add this flag
},
},
{
name: "context size from model config",
mode: inference.BackendModeEmbedding,
model: &fakeModel{
ggufPath: modelPath,
config: types.Config{
ContextSize: uint64ptr(2096),
},
},
config: &inference.BackendConfiguration{
ContextSize: 1234,
},
expected: []string{
"--jinja",
"-ngl", "100",
"--metrics",
"--model", modelPath,
"--host", socket,
"--embeddings",
"--ctx-size", "2096", // model config takes precedence
},
},
{
name: "raw flags from backend config",
mode: inference.BackendModeEmbedding,
model: &fakeModel{
ggufPath: modelPath,
},
config: &inference.BackendConfiguration{
RuntimeFlags: []string{"--some", "flag"},
},
expected: []string{
"--jinja",
"-ngl", "100",
"--metrics",
"--model", modelPath,
"--host", socket,
"--embeddings",
"--some", "flag", // model config takes precedence
},
},
}

for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
args := config.GetArgs(modelPath, socket, tt.mode)
args, err := config.GetArgs(tt.model, socket, tt.mode, tt.config)
if err != nil {
t.Errorf("GetArgs() error = %v", err)
}

// Check that all expected arguments are present and in the correct order
expectedIndex := 0
@@ -171,3 +243,34 @@ func TestContainsArg(t *testing.T) {
})
}
}

var _ types.Model = &fakeModel{}

type fakeModel struct {
ggufPath string
config types.Config
}

func (f *fakeModel) ID() (string, error) {
panic("shouldn't be called")
}

func (f *fakeModel) GGUFPath() (string, error) {
return f.ggufPath, nil
}

func (f *fakeModel) Config() (types.Config, error) {
return f.config, nil
}

func (f *fakeModel) Tags() []string {
panic("shouldn't be called")
}

func (f fakeModel) Descriptor() (types.Descriptor, error) {
panic("shouldn't be called")
}

func uint64ptr(n uint64) *uint64 {
return &n
}
3 changes: 2 additions & 1 deletion pkg/inference/config/config.go
@@ -1,6 +1,7 @@
package config

import (
"github.com/docker/model-distribution/types"
"github.com/docker/model-runner/pkg/inference"
)

@@ -11,5 +12,5 @@ type BackendConfig interface {
// GetArgs returns the command-line arguments for the backend.
// It takes the model, socket, mode, and optional backend configuration
// as input and returns the appropriate arguments for the backend.
GetArgs(modelPath, socket string, mode inference.BackendMode) []string
GetArgs(model types.Model, socket string, mode inference.BackendMode, config *inference.BackendConfiguration) ([]string, error)
}
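
For illustration, a hedged sketch of a type satisfying the updated interface; fakeConfig and its argument choices are hypothetical and not part of this change:

```go
package config

import (
	"github.com/docker/model-distribution/types"
	"github.com/docker/model-runner/pkg/inference"
)

// fakeConfig is a hypothetical BackendConfig implementation for illustration.
type fakeConfig struct{}

func (fakeConfig) GetArgs(model types.Model, socket string, mode inference.BackendMode, config *inference.BackendConfiguration) ([]string, error) {
	// Resolve the on-disk GGUF path from the model artifact.
	modelPath, err := model.GGUFPath()
	if err != nil {
		return nil, err
	}
	args := []string{"--model", modelPath, "--host", socket}
	if mode == inference.BackendModeEmbedding {
		args = append(args, "--embeddings")
	}
	return args, nil
}
```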