diff --git a/pkg/inference/backend.go b/pkg/inference/backend.go index 019e52983..0eae5d4cc 100644 --- a/pkg/inference/backend.go +++ b/pkg/inference/backend.go @@ -30,8 +30,8 @@ func (m BackendMode) String() string { } type BackendConfiguration struct { - ContextSize int64 `json:"context_size,omitempty"` - RawFlags []string `json:"flags,omitempty"` + ContextSize int64 `json:"context-size,omitempty"` + RuntimeFlags []string `json:"runtime-flags,omitempty"` } // Backend is the interface implemented by inference engine backends. Backend diff --git a/pkg/inference/backends/llamacpp/llamacpp.go b/pkg/inference/backends/llamacpp/llamacpp.go index 930535daa..00fa57b86 100644 --- a/pkg/inference/backends/llamacpp/llamacpp.go +++ b/pkg/inference/backends/llamacpp/llamacpp.go @@ -144,7 +144,7 @@ func (l *llamaCpp) Run(ctx context.Context, socket, model string, mode inference if config.ContextSize >= 0 { args = append(args, "--ctx-size", strconv.Itoa(int(config.ContextSize))) } - args = append(args, config.RawFlags...) + args = append(args, config.RuntimeFlags...) } l.log.Infof("llamaCppArgs: %v", args) diff --git a/pkg/inference/scheduling/api.go b/pkg/inference/scheduling/api.go index 52c8ec56d..d718df4b4 100644 --- a/pkg/inference/scheduling/api.go +++ b/pkg/inference/scheduling/api.go @@ -86,7 +86,8 @@ type UnloadResponse struct { // ConfigureRequest specifies per-model runtime configuration options. type ConfigureRequest struct { - Model string `json:"model"` - ContextSize int64 `json:"context-size,omitempty"` - RawRuntimeFlags string `json:"raw-runtime-flags,omitempty"` + Model string `json:"model"` + ContextSize int64 `json:"context-size,omitempty"` + RuntimeFlags []string `json:"runtime-flags,omitempty"` + RawRuntimeFlags string `json:"raw-runtime-flags,omitempty"` } diff --git a/pkg/inference/scheduling/scheduler.go b/pkg/inference/scheduling/scheduler.go index 5a9d0e5b1..fa6f794e9 100644 --- a/pkg/inference/scheduling/scheduler.go +++ b/pkg/inference/scheduling/scheduler.go @@ -388,23 +388,27 @@ func (s *Scheduler) Configure(w http.ResponseWriter, r *http.Request) { } configureRequest := ConfigureRequest{ - Model: "", - ContextSize: -1, - RawRuntimeFlags: "", + ContextSize: -1, } if err := json.Unmarshal(body, &configureRequest); err != nil { http.Error(w, "invalid request", http.StatusBadRequest) return } - rawFlags, err := shellwords.Parse(configureRequest.RawRuntimeFlags) - if err != nil { - http.Error(w, "invalid request", http.StatusBadRequest) - return + var runtimeFlags []string + if len(configureRequest.RuntimeFlags) > 0 { + runtimeFlags = configureRequest.RuntimeFlags + } else { + rawFlags, err := shellwords.Parse(configureRequest.RawRuntimeFlags) + if err != nil { + http.Error(w, "invalid request", http.StatusBadRequest) + return + } + runtimeFlags = rawFlags } var runnerConfig inference.BackendConfiguration runnerConfig.ContextSize = configureRequest.ContextSize - runnerConfig.RawFlags = rawFlags + runnerConfig.RuntimeFlags = runtimeFlags if err := s.loader.setRunnerConfig(r.Context(), backend.Name(), configureRequest.Model, inference.BackendModeCompletion, runnerConfig); err != nil { s.log.Warnf("Failed to configure %s runner for %s: %s", backend.Name(), configureRequest.Model, err)