12 changes: 9 additions & 3 deletions cmd/cli/commands/compose.go
@@ -37,6 +37,7 @@ func newComposeCmd() *cobra.Command {
func newUpCommand() *cobra.Command {
var models []string
var ctxSize int64
var rawRuntimeFlags string
var backend string
var draftModel string
var numTokens int
@@ -69,6 +70,9 @@ func newUpCommand() *cobra.Command {
if ctxSize > 0 {
sendInfo(fmt.Sprintf("Setting context size to %d", ctxSize))
}
if rawRuntimeFlags != "" {
sendInfo("Setting raw runtime flags to " + rawRuntimeFlags)
}

// Build speculative config if any speculative flags are set
var speculativeConfig *inference.SpeculativeDecodingConfig
@@ -89,10 +93,11 @@ func newUpCommand() *cobra.Command {
ContextSize: &size,
Speculative: speculativeConfig,
},
RawRuntimeFlags: rawRuntimeFlags,
}); err != nil {
configErrFmtString := "failed to configure backend for model %s with context-size %d"
_ = sendErrorf(configErrFmtString+": %v", model, ctxSize, err)
return fmt.Errorf(configErrFmtString+": %w", model, ctxSize, err)
configErrFmtString := "failed to configure backend for model %s with context-size %d and runtime-flags %s"
_ = sendErrorf(configErrFmtString+": %v", model, rawRuntimeFlags, ctxSize, err)
return fmt.Errorf(configErrFmtString+": %w", model, ctxSize, rawRuntimeFlags, err)
}
sendInfo("Successfully configured backend for model " + model)
}
@@ -114,6 +119,7 @@ func newUpCommand() *cobra.Command {
}
c.Flags().StringArrayVar(&models, "model", nil, "model to use")
c.Flags().Int64Var(&ctxSize, "context-size", -1, "context size for the model")
c.Flags().StringVar(&rawRuntimeFlags, "runtime-flags", "", "raw runtime flags to pass to the inference engine")
c.Flags().StringVar(&backend, "backend", llamacpp.Name, "inference backend to use")
c.Flags().StringVar(&draftModel, "speculative-draft-model", "", "draft model for speculative decoding")
c.Flags().IntVar(&numTokens, "speculative-num-tokens", 0, "number of tokens to predict speculatively")
25 changes: 19 additions & 6 deletions cmd/cli/commands/configure.go
@@ -11,15 +11,27 @@ func newConfigureCmd() *cobra.Command {
var flags ConfigureFlags

c := &cobra.Command{
Use: "configure [--context-size=<n>] [--speculative-draft-model=<model>] [--hf_overrides=<json>] [--gpu-memory-utilization=<float>] [--mode=<mode>] [--think] MODEL",
Use: "configure [--context-size=<n>] [--speculative-draft-model=<model>] [--hf_overrides=<json>] [--gpu-memory-utilization=<float>] [--mode=<mode>] [--think] MODEL [-- <runtime-flags...>]",
Short: "Configure runtime options for a model",
Hidden: true,
Args: func(cmd *cobra.Command, args []string) error {
if len(args) != 1 {
return fmt.Errorf(
"Exactly one model must be specified, got %d: %v\n\n"+
"See 'docker model configure --help' for more information",
len(args), args)
argsBeforeDash := cmd.ArgsLenAtDash()
if argsBeforeDash == -1 {
// No "--" used, so we need exactly 1 total argument.
if len(args) != 1 {
return fmt.Errorf(
"Exactly one model must be specified, got %d: %v\n\n"+
"See 'docker model configure --help' for more information",
len(args), args)
}
} else {
// Has "--", so we need exactly 1 argument before it.
if argsBeforeDash != 1 {
return fmt.Errorf(
"Exactly one model must be specified before --, got %d\n\n"+
"See 'docker model configure --help' for more information",
argsBeforeDash)
}
}
return nil
},
@@ -29,6 +41,7 @@
if err != nil {
return err
}
opts.RuntimeFlags = args[1:]
Review comment (Contributor):
issue (bug_risk): Runtime flags slice includes the model name and likely the literal "--"

Here opts.RuntimeFlags = args[1:] will still include the literal "--" when present (e.g. docker model configure foo -- --embeddings yields RuntimeFlags = ["--", "--embeddings"]). To avoid depending on positional slicing and to drop "--" explicitly, consider mirroring the Args logic and using cmd.ArgsLenAtDash(): treat args[:argsBeforeDash] as the model (length 1) and args[argsBeforeDash+1:] as runtime flags.
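
A minimal sketch of the split that comment suggests, assuming cobra's cmd.ArgsLenAtDash() semantics; the helper name and the defensive "--" check are illustrative and not part of this PR:

// splitModelAndRuntimeFlags is an illustrative helper mirroring the Args validation:
// the single argument before "--" is the model, everything captured after it is
// treated as runtime flags. Assumes cobra's ArgsLenAtDash semantics.
func splitModelAndRuntimeFlags(cmd *cobra.Command, args []string) (string, []string) {
	at := cmd.ArgsLenAtDash()
	if at < 0 {
		// No "--" was given, so the single positional argument is the model.
		return args[0], nil
	}
	flags := args[at:]
	// Drop a leading literal "--" defensively in case the flag parser passes it through.
	if len(flags) > 0 && flags[0] == "--" {
		flags = flags[1:]
	}
	return args[0], flags
}

With a split like this, RunE could take the runtime flags from the helper's second return value instead of slicing args positionally.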

return desktopClient.ConfigureBackend(opts)
},
ValidArgsFunction: completion.ModelNames(getDesktopClient, -1),
9 changes: 9 additions & 0 deletions cmd/cli/docs/reference/docker_model_compose_up.yaml
@@ -33,6 +33,15 @@ options:
experimentalcli: false
kubernetes: false
swarm: false
- option: runtime-flags
value_type: string
description: raw runtime flags to pass to the inference engine
deprecated: false
hidden: false
experimental: false
experimentalcli: false
kubernetes: false
swarm: false
- option: speculative-draft-model
value_type: string
description: draft model for speculative decoding
2 changes: 1 addition & 1 deletion cmd/cli/docs/reference/docker_model_configure.yaml
@@ -1,7 +1,7 @@
command: docker model configure
short: Configure runtime options for a model
long: Configure runtime options for a model
usage: docker model configure [--context-size=<n>] [--speculative-draft-model=<model>] [--hf_overrides=<json>] [--gpu-memory-utilization=<float>] [--mode=<mode>] [--think] MODEL
usage: docker model configure [--context-size=<n>] [--speculative-draft-model=<model>] [--hf_overrides=<json>] [--gpu-memory-utilization=<float>] [--mode=<mode>] [--think] MODEL [-- <runtime-flags...>]
pname: docker model
plink: docker_model.yaml
options:
5 changes: 3 additions & 2 deletions pkg/inference/backend.go
@@ -82,8 +82,9 @@ type LlamaCppConfig struct {

type BackendConfiguration struct {
// Shared configuration across all backends
ContextSize *int32 `json:"context-size,omitempty"`
Speculative *SpeculativeDecodingConfig `json:"speculative,omitempty"`
ContextSize *int32 `json:"context-size,omitempty"`
RuntimeFlags []string `json:"runtime-flags,omitempty"`
Speculative *SpeculativeDecodingConfig `json:"speculative,omitempty"`

// Backend-specific configuration
VLLM *VLLMConfig `json:"vllm,omitempty"`
5 changes: 5 additions & 0 deletions pkg/inference/backends/llamacpp/llamacpp_config.go
@@ -79,6 +79,11 @@ func (c *Config) GetArgs(bundle types.ModelBundle, socket string, mode inference
args = append(args, "--ctx-size", strconv.FormatInt(int64(*contextSize), 10))
}

// Add arguments from backend config
if config != nil {
args = append(args, config.RuntimeFlags...)
}

// Add arguments for Multimodal projector or jinja (they are mutually exclusive)
if path := bundle.MMPROJPath(); path != "" {
args = append(args, "--mmproj", path)
17 changes: 17 additions & 0 deletions pkg/inference/backends/llamacpp/llamacpp_config_test.go
@@ -225,6 +225,23 @@ func TestGetArgs(t *testing.T) {
"--jinja",
),
},
{
name: "raw flags from backend config",
mode: inference.BackendModeEmbedding,
bundle: &fakeBundle{
ggufPath: modelPath,
},
config: &inference.BackendConfiguration{
RuntimeFlags: []string{"--some", "flag"},
},
expected: append(slices.Clone(baseArgs),
"--model", modelPath,
"--host", socket,
"--embeddings",
"--some", "flag",
"--jinja",
),
},
{
name: "multimodal projector removes jinja",
mode: inference.BackendModeCompletion,
6 changes: 5 additions & 1 deletion pkg/inference/backends/vllm/vllm_config.go
@@ -56,7 +56,11 @@ func (c *Config) GetArgs(bundle types.ModelBundle, socket string, mode inference
if maxLen := GetMaxModelLen(bundle.RuntimeConfig(), config); maxLen != nil {
args = append(args, "--max-model-len", strconv.FormatInt(int64(*maxLen), 10))
}
// If nil, vLLM will automatically derive from the model config

// Add runtime flags from backend config
if config != nil {
args = append(args, config.RuntimeFlags...)
}

// Add vLLM-specific arguments from backend config
if config != nil && config.VLLM != nil {
17 changes: 17 additions & 0 deletions pkg/inference/backends/vllm/vllm_config_test.go
@@ -83,6 +83,23 @@ func TestGetArgs(t *testing.T) {
"8192",
},
},
{
name: "with runtime flags",
bundle: &mockModelBundle{
safetensorsPath: "/path/to/model",
},
config: &inference.BackendConfiguration{
RuntimeFlags: []string{"--gpu-memory-utilization", "0.9"},
},
expected: []string{
"serve",
"/path/to",
"--uds",
"/tmp/socket",
"--gpu-memory-utilization",
"0.9",
},
},
{
name: "with model context size (takes precedence)",
bundle: &mockModelBundle{
5 changes: 3 additions & 2 deletions pkg/inference/scheduling/api.go
@@ -93,7 +93,8 @@ type UnloadResponse struct {

// ConfigureRequest specifies per-model runtime configuration options.
type ConfigureRequest struct {
Model string `json:"model"`
Mode *inference.BackendMode `json:"mode,omitempty"`
Model string `json:"model"`
Mode *inference.BackendMode `json:"mode,omitempty"`
RawRuntimeFlags string `json:"raw-runtime-flags,omitempty"`
inference.BackendConfiguration
}
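
For orientation, a hedged sketch of how a client might populate this request; the model name and flag values are made up, and the exact wire shape depends on the other omitempty fields of the embedded BackendConfiguration:

package main

import (
	"encoding/json"
	"fmt"

	"github.com/docker/model-runner/pkg/inference"
	"github.com/docker/model-runner/pkg/inference/scheduling"
)

func main() {
	// Illustrative only: explicit runtime flags passed as an array.
	req := scheduling.ConfigureRequest{
		Model: "ai/example-model",
		BackendConfiguration: inference.BackendConfiguration{
			RuntimeFlags: []string{"--threads", "8"},
		},
	}
	body, _ := json.Marshal(req)
	// Embedded BackendConfiguration fields marshal at the top level, so this prints
	// roughly: {"model":"ai/example-model","runtime-flags":["--threads","8"]}
	fmt.Println(string(body))
}

A caller that only has a single shell-style string can set RawRuntimeFlags instead; the scheduler change below prefers the RuntimeFlags array when both are present.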
18 changes: 18 additions & 0 deletions pkg/inference/scheduling/scheduler.go
@@ -3,7 +3,9 @@ package scheduling
import (
"context"
"errors"
"fmt"
"net/http"
"slices"
"time"

"github.com/docker/model-runner/pkg/distribution/types"
@@ -14,6 +16,7 @@ import (
"github.com/docker/model-runner/pkg/internal/utils"
"github.com/docker/model-runner/pkg/logging"
"github.com/docker/model-runner/pkg/metrics"
"github.com/mattn/go-shellwords"
"golang.org/x/sync/errgroup"
)

@@ -225,10 +228,23 @@ func (s *Scheduler) ConfigureRunner(ctx context.Context, backend inference.Backe
backend = s.defaultBackend
}

// Parse runtime flags from either array or raw string
var runtimeFlags []string
if len(req.RuntimeFlags) > 0 {
runtimeFlags = req.RuntimeFlags
} else if req.RawRuntimeFlags != "" {
var err error
runtimeFlags, err = shellwords.Parse(req.RawRuntimeFlags)
if err != nil {
return nil, fmt.Errorf("invalid runtime flags: %w", err)
}
}

// Build runner configuration with shared settings
var runnerConfig inference.BackendConfiguration
runnerConfig.ContextSize = req.ContextSize
runnerConfig.Speculative = req.Speculative
runnerConfig.RuntimeFlags = runtimeFlags

// Set vLLM-specific configuration if provided
if req.VLLM != nil {
@@ -255,6 +271,8 @@
mode := inference.BackendModeCompletion
if req.Mode != nil {
mode = *req.Mode
} else if slices.Contains(runnerConfig.RuntimeFlags, "--embeddings") {
mode = inference.BackendModeEmbedding
}

// Get model, track usage, and select appropriate backend
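
As a closing note on the go-shellwords import added above: the raw-string path is parsed shell-style rather than split on whitespace, so quoted values survive as single arguments. A small, self-contained illustration with made-up flag values:

package main

import (
	"fmt"
	"log"

	"github.com/mattn/go-shellwords"
)

func main() {
	// Shell-style splitting keeps the quoted template as one argument,
	// unlike a naive strings.Fields on the raw flag string.
	flags, err := shellwords.Parse(`--ctx-size 4096 --chat-template "{{ messages }}"`)
	if err != nil {
		// An unbalanced quote in the raw string surfaces here, matching the
		// "invalid runtime flags" error path in ConfigureRunner above.
		log.Fatal(err)
	}
	fmt.Println(flags) // [--ctx-size 4096 --chat-template {{ messages }}]
}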