Skip to content
This repository was archived by the owner on Oct 6, 2025. It is now read-only.

Commit ede328c

Browse files
ilopezluna, xenoscopic, doringeman
authored
Backend flag support (#126)
Co-authored-by: Jacob Howard <[email protected]> Co-authored-by: Dorin-Andrei Geman <[email protected]>
1 parent 72d2dd1 commit ede328c

File tree

9 files changed

+168
-36
lines changed

9 files changed

+168
-36
lines changed

commands/backend.go

Lines changed: 41 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,41 @@
1+
package commands
2+
3+
import (
4+
"errors"
5+
"fmt"
6+
"maps"
7+
"os"
8+
"slices"
9+
"strings"
10+
)
11+
12+
// ValidBackends is a map of valid backends
13+
var ValidBackends = map[string]bool{
14+
"llama.cpp": true,
15+
"openai": true,
16+
}
17+
18+
// validateBackend checks if the provided backend is valid
19+
func validateBackend(backend string) error {
20+
if !ValidBackends[backend] {
21+
return fmt.Errorf("invalid backend '%s'. Valid backends are: %s",
22+
backend, ValidBackendsKeys())
23+
}
24+
return nil
25+
}
26+
27+
// ensureAPIKey retrieves the API key if needed
28+
func ensureAPIKey(backend string) (string, error) {
29+
if backend == "openai" {
30+
apiKey := os.Getenv("OPENAI_API_KEY")
31+
if apiKey == "" {
32+
return "", errors.New("OPENAI_API_KEY environment variable is required when using --backend=openai")
33+
}
34+
return apiKey, nil
35+
}
36+
return "", nil
37+
}
38+
39+
func ValidBackendsKeys() string {
40+
return strings.Join(slices.Collect(maps.Keys(ValidBackends)), ", ")
41+
}

commands/list.go

Lines changed: 23 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -18,24 +18,39 @@ import (
1818

1919
func newListCmd() *cobra.Command {
2020
var jsonFormat, openai, quiet bool
21+
var backend string
2122
c := &cobra.Command{
2223
Use: "list [OPTIONS]",
2324
Aliases: []string{"ls"},
2425
Short: "List the models pulled to your local environment",
2526
RunE: func(cmd *cobra.Command, args []string) error {
26-
if openai && quiet {
27-
return fmt.Errorf("--quiet flag cannot be used with --openai flag")
27+
// Validate backend if specified
28+
if backend != "" {
29+
if err := validateBackend(backend); err != nil {
30+
return err
31+
}
2832
}
33+
34+
if (backend == "openai" || openai) && quiet {
35+
return fmt.Errorf("--quiet flag cannot be used with --openai flag or OpenAI backend")
36+
}
37+
38+
// Validate API key for OpenAI backend
39+
apiKey, err := ensureAPIKey(backend)
40+
if err != nil {
41+
return err
42+
}
43+
2944
// If we're doing an automatic install, only show the installation
3045
// status if it won't corrupt machine-readable output.
3146
var standaloneInstallPrinter standalone.StatusPrinter
32-
if !jsonFormat && !openai && !quiet {
47+
if !jsonFormat && !openai && !quiet && backend == "" {
3348
standaloneInstallPrinter = cmd
3449
}
3550
if _, err := ensureStandaloneRunnerAvailable(cmd.Context(), standaloneInstallPrinter); err != nil {
3651
return fmt.Errorf("unable to initialize standalone model runner: %w", err)
3752
}
38-
models, err := listModels(openai, desktopClient, quiet, jsonFormat)
53+
models, err := listModels(openai, backend, desktopClient, quiet, jsonFormat, apiKey)
3954
if err != nil {
4055
return err
4156
}
@@ -47,12 +62,13 @@ func newListCmd() *cobra.Command {
4762
c.Flags().BoolVar(&jsonFormat, "json", false, "List models in a JSON format")
4863
c.Flags().BoolVar(&openai, "openai", false, "List models in an OpenAI format")
4964
c.Flags().BoolVarP(&quiet, "quiet", "q", false, "Only show model IDs")
65+
c.Flags().StringVar(&backend, "backend", "", fmt.Sprintf("Specify the backend to use (%s)", ValidBackendsKeys()))
5066
return c
5167
}
5268

53-
func listModels(openai bool, desktopClient *desktop.Client, quiet bool, jsonFormat bool) (string, error) {
54-
if openai {
55-
models, err := desktopClient.ListOpenAI()
69+
func listModels(openai bool, backend string, desktopClient *desktop.Client, quiet bool, jsonFormat bool, apiKey string) (string, error) {
70+
if openai || backend == "openai" {
71+
models, err := desktopClient.ListOpenAI(backend, apiKey)
5672
if err != nil {
5773
err = handleClientError(err, "Failed to list models")
5874
return "", handleNotRunningError(err)

commands/run.go

Lines changed: 28 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -79,12 +79,26 @@ func readMultilineInput(cmd *cobra.Command, scanner *bufio.Scanner) (string, err
7979

8080
func newRunCmd() *cobra.Command {
8181
var debug bool
82+
var backend string
8283

8384
const cmdArgs = "MODEL [PROMPT]"
8485
c := &cobra.Command{
8586
Use: "run " + cmdArgs,
8687
Short: "Run a model and interact with it using a submitted prompt or chat mode",
8788
RunE: func(cmd *cobra.Command, args []string) error {
89+
// Validate backend if specified
90+
if backend != "" {
91+
if err := validateBackend(backend); err != nil {
92+
return err
93+
}
94+
}
95+
96+
// Validate API key for OpenAI backend
97+
apiKey, err := ensureAPIKey(backend)
98+
if err != nil {
99+
return err
100+
}
101+
88102
model := args[0]
89103
prompt := ""
90104
if len(args) == 1 {
@@ -102,19 +116,22 @@ func newRunCmd() *cobra.Command {
102116
return fmt.Errorf("unable to initialize standalone model runner: %w", err)
103117
}
104118

105-
_, err := desktopClient.Inspect(model, false)
106-
if err != nil {
107-
if !errors.Is(err, desktop.ErrNotFound) {
108-
return handleNotRunningError(handleClientError(err, "Failed to inspect model"))
109-
}
110-
cmd.Println("Unable to find model '" + model + "' locally. Pulling from the server.")
111-
if err := pullModel(cmd, desktopClient, model); err != nil {
112-
return err
119+
// Do not validate the model in case of using OpenAI's backend, let OpenAI handle it
120+
if backend != "openai" {
121+
_, err := desktopClient.Inspect(model, false)
122+
if err != nil {
123+
if !errors.Is(err, desktop.ErrNotFound) {
124+
return handleNotRunningError(handleClientError(err, "Failed to inspect model"))
125+
}
126+
cmd.Println("Unable to find model '" + model + "' locally. Pulling from the server.")
127+
if err := pullModel(cmd, desktopClient, model); err != nil {
128+
return err
129+
}
113130
}
114131
}
115132

116133
if prompt != "" {
117-
if err := desktopClient.Chat(model, prompt); err != nil {
134+
if err := desktopClient.Chat(backend, model, prompt, apiKey); err != nil {
118135
return handleClientError(err, "Failed to generate a response")
119136
}
120137
cmd.Println()
@@ -143,7 +160,7 @@ func newRunCmd() *cobra.Command {
143160
continue
144161
}
145162

146-
if err := desktopClient.Chat(model, userInput); err != nil {
163+
if err := desktopClient.Chat(backend, model, userInput, apiKey); err != nil {
147164
cmd.PrintErr(handleClientError(err, "Failed to generate a response"))
148165
continue
149166
}
@@ -169,6 +186,7 @@ func newRunCmd() *cobra.Command {
169186
}
170187

171188
c.Flags().BoolVar(&debug, "debug", false, "Enable debug logging")
189+
c.Flags().StringVar(&backend, "backend", "", fmt.Sprintf("Specify the backend to use (%s)", ValidBackendsKeys()))
172190

173191
return c
174192
}

desktop/desktop.go

Lines changed: 47 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,8 @@ import (
2121
"go.opentelemetry.io/otel"
2222
)
2323

24+
const DefaultBackend = "llama.cpp"
25+
2426
var (
2527
ErrNotFound = errors.New("model not found")
2628
ErrServiceUnavailable = errors.New("service unavailable")
@@ -236,14 +238,30 @@ func (c *Client) List() ([]dmrm.Model, error) {
236238
return modelsJson, nil
237239
}
238240

239-
func (c *Client) ListOpenAI() (dmrm.OpenAIModelList, error) {
240-
modelsRoute := inference.InferencePrefix + "/v1/models"
241-
rawResponse, err := c.listRaw(modelsRoute, "")
241+
func (c *Client) ListOpenAI(backend, apiKey string) (dmrm.OpenAIModelList, error) {
242+
if backend == "" {
243+
backend = DefaultBackend
244+
}
245+
modelsRoute := fmt.Sprintf("%s/%s/v1/models", inference.InferencePrefix, backend)
246+
247+
// Use doRequestWithAuth to support API key authentication
248+
resp, err := c.doRequestWithAuth(http.MethodGet, modelsRoute, nil, "openai", apiKey)
249+
if err != nil {
250+
return dmrm.OpenAIModelList{}, c.handleQueryError(err, modelsRoute)
251+
}
252+
defer resp.Body.Close()
253+
254+
if resp.StatusCode != http.StatusOK {
255+
return dmrm.OpenAIModelList{}, fmt.Errorf("failed to list models: %s", resp.Status)
256+
}
257+
258+
body, err := io.ReadAll(resp.Body)
242259
if err != nil {
243-
return dmrm.OpenAIModelList{}, err
260+
return dmrm.OpenAIModelList{}, fmt.Errorf("failed to read response body: %w", err)
244261
}
262+
245263
var modelsJson dmrm.OpenAIModelList
246-
if err := json.Unmarshal(rawResponse, &modelsJson); err != nil {
264+
if err := json.Unmarshal(body, &modelsJson); err != nil {
247265
return modelsJson, fmt.Errorf("failed to unmarshal response body: %w", err)
248266
}
249267
return modelsJson, nil
@@ -343,7 +361,7 @@ func (c *Client) fullModelID(id string) (string, error) {
343361
return "", fmt.Errorf("model with ID %s not found", id)
344362
}
345363

346-
func (c *Client) Chat(model, prompt string) error {
364+
func (c *Client) Chat(backend, model, prompt, apiKey string) error {
347365
model = normalizeHuggingFaceModelName(model)
348366
if !strings.Contains(strings.Trim(model, "/"), "/") {
349367
// Do an extra API call to check if the model parameter isn't a model ID.
@@ -368,14 +386,22 @@ func (c *Client) Chat(model, prompt string) error {
368386
return fmt.Errorf("error marshaling request: %w", err)
369387
}
370388

371-
chatCompletionsPath := inference.InferencePrefix + "/v1/chat/completions"
372-
resp, err := c.doRequest(
389+
var completionsPath string
390+
if backend != "" {
391+
completionsPath = inference.InferencePrefix + "/" + backend + "/v1/chat/completions"
392+
} else {
393+
completionsPath = inference.InferencePrefix + "/v1/chat/completions"
394+
}
395+
396+
resp, err := c.doRequestWithAuth(
373397
http.MethodPost,
374-
chatCompletionsPath,
398+
completionsPath,
375399
bytes.NewReader(jsonData),
400+
backend,
401+
apiKey,
376402
)
377403
if err != nil {
378-
return c.handleQueryError(err, chatCompletionsPath)
404+
return c.handleQueryError(err, completionsPath)
379405
}
380406
defer resp.Body.Close()
381407

@@ -604,6 +630,11 @@ func (c *Client) ConfigureBackend(request scheduling.ConfigureRequest) error {
604630

605631
// doRequest is a helper function that performs HTTP requests and handles 503 responses
606632
func (c *Client) doRequest(method, path string, body io.Reader) (*http.Response, error) {
633+
return c.doRequestWithAuth(method, path, body, "", "")
634+
}
635+
636+
// doRequestWithAuth is a helper function that performs HTTP requests with optional authentication
637+
func (c *Client) doRequestWithAuth(method, path string, body io.Reader, backend, apiKey string) (*http.Response, error) {
607638
req, err := http.NewRequest(method, c.modelRunner.URL(path), body)
608639
if err != nil {
609640
return nil, fmt.Errorf("error creating request: %w", err)
@@ -613,6 +644,12 @@ func (c *Client) doRequest(method, path string, body io.Reader) (*http.Response,
613644
}
614645

615646
req.Header.Set("User-Agent", "docker-model-cli/"+Version)
647+
648+
// Add Authorization header for OpenAI backend
649+
if apiKey != "" {
650+
req.Header.Set("Authorization", "Bearer "+apiKey)
651+
}
652+
616653
resp, err := c.modelRunner.Client().Do(req)
617654
if err != nil {
618655
return nil, err

desktop/desktop_test.go

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -63,7 +63,7 @@ func TestChatHuggingFaceModel(t *testing.T) {
6363
Body: io.NopCloser(bytes.NewBufferString("data: {\"choices\":[{\"delta\":{\"content\":\"Hello there!\"}}]}\n")),
6464
}, nil)
6565

66-
err := client.Chat(modelName, prompt)
66+
err := client.Chat("", modelName, prompt, "")
6767
assert.NoError(t, err)
6868
}
6969

docs/reference/docker_model_list.yaml

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,15 @@ usage: docker model list [OPTIONS]
66
pname: docker model
77
plink: docker_model.yaml
88
options:
9+
- option: backend
10+
value_type: string
11+
description: Specify the backend to use (llama.cpp, openai)
12+
deprecated: false
13+
hidden: false
14+
experimental: false
15+
experimentalcli: false
16+
kubernetes: false
17+
swarm: false
918
- option: json
1019
value_type: bool
1120
default_value: "false"

docs/reference/docker_model_run.yaml

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,15 @@ usage: docker model run MODEL [PROMPT]
1010
pname: docker model
1111
plink: docker_model.yaml
1212
options:
13+
- option: backend
14+
value_type: string
15+
description: Specify the backend to use (llama.cpp, openai)
16+
deprecated: false
17+
hidden: false
18+
experimental: false
19+
experimentalcli: false
20+
kubernetes: false
21+
swarm: false
1322
- option: debug
1423
value_type: bool
1524
default_value: "false"

docs/reference/model_list.md

Lines changed: 6 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -9,11 +9,12 @@ List the models pulled to your local environment
99

1010
### Options
1111

12-
| Name | Type | Default | Description |
13-
|:----------------|:-------|:--------|:--------------------------------|
14-
| `--json` | `bool` | | List models in a JSON format |
15-
| `--openai` | `bool` | | List models in an OpenAI format |
16-
| `-q`, `--quiet` | `bool` | | Only show model IDs |
12+
| Name | Type | Default | Description |
13+
|:----------------|:---------|:--------|:-----------------------------------------------|
14+
| `--backend` | `string` | | Specify the backend to use (llama.cpp, openai) |
15+
| `--json` | `bool` | | List models in a JSON format |
16+
| `--openai` | `bool` | | List models in an OpenAI format |
17+
| `-q`, `--quiet` | `bool` | | Only show model IDs |
1718

1819

1920
<!---MARKER_GEN_END-->

docs/reference/model_run.md

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -5,9 +5,10 @@ Run a model and interact with it using a submitted prompt or chat mode
55

66
### Options
77

8-
| Name | Type | Default | Description |
9-
|:----------|:-------|:--------|:---------------------|
10-
| `--debug` | `bool` | | Enable debug logging |
8+
| Name | Type | Default | Description |
9+
|:------------|:---------|:--------|:-----------------------------------------------|
10+
| `--backend` | `string` | | Specify the backend to use (llama.cpp, openai) |
11+
| `--debug` | `bool` | | Enable debug logging |
1112

1213

1314
<!---MARKER_GEN_END-->

0 commit comments

Comments
 (0)