diff --git a/commands/inspect.go b/commands/inspect.go index 474981c1..090b676c 100644 --- a/commands/inspect.go +++ b/commands/inspect.go @@ -30,7 +30,8 @@ func newInspectCmd() *cobra.Command { } model, err = client.List(false, openai, model) if err != nil { - return fmt.Errorf("Failed to list models: %v\n", err) + err = handleClientError(err, "Failed to list models") + return handleNotRunningError(err) } cmd.Println(model) return nil diff --git a/commands/list.go b/commands/list.go index 12aed24f..00ca9c72 100644 --- a/commands/list.go +++ b/commands/list.go @@ -20,7 +20,8 @@ func newListCmd() *cobra.Command { } models, err := client.List(jsonFormat, openai, "") if err != nil { - return fmt.Errorf("Failed to list models: %v\n", err) + err = handleClientError(err, "Failed to list models") + return handleNotRunningError(err) } cmd.Println(models) return nil diff --git a/commands/pull.go b/commands/pull.go index 43b48fe7..1e1725a9 100644 --- a/commands/pull.go +++ b/commands/pull.go @@ -29,7 +29,8 @@ func newPullCmd() *cobra.Command { } response, err := client.Pull(model) if err != nil { - return fmt.Errorf("Failed to pull model: %v\n", err) + err = handleClientError(err, "Failed to pull model") + return handleNotRunningError(err) } cmd.Println(response) return nil diff --git a/commands/rm.go b/commands/rm.go index bb71ca46..3393e87b 100644 --- a/commands/rm.go +++ b/commands/rm.go @@ -29,7 +29,8 @@ func newRemoveCmd() *cobra.Command { } response, err := client.Remove(model) if err != nil { - return fmt.Errorf("Failed to remove model: %v\n", err) + err = handleClientError(err, "Failed to remove model") + return handleNotRunningError(err) } cmd.Println(response) return nil diff --git a/commands/run.go b/commands/run.go index 001733f3..8b31859d 100644 --- a/commands/run.go +++ b/commands/run.go @@ -39,19 +39,19 @@ func newRunCmd() *cobra.Command { if _, err := client.List(false, false, model); err != nil { if !errors.Is(err, desktop.ErrNotFound) { - return fmt.Errorf("Failed to list model: %v\n", err) + return handleNotRunningError(handleClientError(err, "Failed to list models")) } cmd.Println("Unable to find model '" + model + "' locally. Pulling from the server.") response, err := client.Pull(model) if err != nil { - return fmt.Errorf("Failed to pull model: %v\n", err) + return handleNotRunningError(handleClientError(err, "Failed to pull model")) } cmd.Println(response) } if prompt != "" { if err := client.Chat(model, prompt); err != nil { - return fmt.Errorf("Failed to generate a response: %v\n", err) + return handleClientError(err, "Failed to generate a response") } cmd.Println() return nil @@ -75,7 +75,7 @@ func newRunCmd() *cobra.Command { } if err := client.Chat(model, userInput); err != nil { - cmd.PrintErrf("Failed to generate a response: %v\n", err) + cmd.PrintErr(handleClientError(err, "Failed to generate a response")) cmd.Print("> ") continue } diff --git a/commands/status.go b/commands/status.go index cb6fc39a..3a4a6258 100644 --- a/commands/status.go +++ b/commands/status.go @@ -2,9 +2,11 @@ package commands import ( "fmt" + "os" + + "github.com/docker/cli/cli-plugins/hooks" "github.com/docker/model-cli/desktop" "github.com/spf13/cobra" - "os" ) func newStatusCmd() *cobra.Command { @@ -24,6 +26,7 @@ func newStatusCmd() *cobra.Command { cmd.Println("Docker Model Runner is running") } else { cmd.Println("Docker Model Runner is not running") + hooks.PrintNextSteps(os.Stdout, []string{enableViaCLI, enableViaGUI}) os.Exit(1) } diff --git a/commands/utils.go b/commands/utils.go new file mode 100644 index 00000000..6ddb33ba --- /dev/null +++ b/commands/utils.go @@ -0,0 +1,34 @@ +package commands + +import ( + "bytes" + "fmt" + "strings" + + "github.com/docker/cli/cli-plugins/hooks" + "github.com/docker/model-cli/desktop" + "github.com/pkg/errors" +) + +const ( + enableViaCLI = "Enable Docker Model Runner via the CLI → docker desktop enable model-runner" + enableViaGUI = "Enable Docker Model Runner via the GUI → Go to Settings->Features in development->Enable Docker Model Runner" +) + +var notRunningErr = fmt.Errorf("Docker Model Runner is not running. Please start it and try again.\n") + +func handleClientError(err error, message string) error { + if errors.Is(err, desktop.ErrServiceUnavailable) { + return notRunningErr + } + return errors.Wrap(err, message) +} + +func handleNotRunningError(err error) error { + if errors.Is(err, notRunningErr) { + var buf bytes.Buffer + hooks.PrintNextSteps(&buf, []string{enableViaCLI, enableViaGUI}) + return fmt.Errorf("%w\n%s", err, strings.TrimRight(buf.String(), "\n")) + } + return err +} diff --git a/desktop/desktop.go b/desktop/desktop.go index d0344fdd..683653fc 100644 --- a/desktop/desktop.go +++ b/desktop/desktop.go @@ -22,7 +22,10 @@ import ( "go.opentelemetry.io/otel" ) -var ErrNotFound = errors.New("model not found") +var ( + ErrNotFound = errors.New("model not found") + ErrServiceUnavailable = errors.New("service unavailable") +) type otelErrorSilencer struct{} @@ -79,13 +82,14 @@ func (c *Client) Pull(model string) (string, error) { return "", fmt.Errorf("error marshaling request: %w", err) } - resp, err := c.dockerClient.HTTPClient().Post( - url(inference.ModelsPrefix+"/create"), - "application/json", + createPath := inference.ModelsPrefix + "/create" + resp, err := c.doRequest( + http.MethodPost, + createPath, bytes.NewReader(jsonData), ) if err != nil { - return "", fmt.Errorf("error querying %s: %w", inference.ModelsPrefix+"/create", err) + return "", c.handleQueryError(err, createPath) } defer resp.Body.Close() @@ -118,17 +122,20 @@ func (c *Client) List(jsonFormat, openai bool, model string) (string, error) { } modelsRoute += "/" + model } - resp, err := c.dockerClient.HTTPClient().Get(url(modelsRoute)) + + resp, err := c.doRequest(http.MethodGet, modelsRoute, nil) if err != nil { - return "", err + return "", c.handleQueryError(err, modelsRoute) } defer resp.Body.Close() + if resp.StatusCode != http.StatusOK { if model != "" && resp.StatusCode == http.StatusNotFound { return "", errors.Wrap(ErrNotFound, model) } return "", fmt.Errorf("failed to list models: %s", resp.Status) } + body, err := io.ReadAll(resp.Body) if err != nil { return "", fmt.Errorf("failed to read response body: %w", err) @@ -188,13 +195,14 @@ func (c *Client) Chat(model, prompt string) error { return fmt.Errorf("error marshaling request: %w", err) } - resp, err := c.dockerClient.HTTPClient().Post( - url(inference.InferencePrefix+"/v1/chat/completions"), - "application/json", + chatCompletionsPath := inference.InferencePrefix + "/v1/chat/completions" + resp, err := c.doRequest( + http.MethodPost, + chatCompletionsPath, bytes.NewReader(jsonData), ) if err != nil { - return fmt.Errorf("error querying %s: %w", inference.InferencePrefix+"/v1/chat/completions", err) + return c.handleQueryError(err, chatCompletionsPath) } defer resp.Body.Close() @@ -239,18 +247,14 @@ func (c *Client) Chat(model, prompt string) error { } func (c *Client) Remove(model string) (string, error) { - req, err := http.NewRequest(http.MethodDelete, url(inference.ModelsPrefix+"/"+model), nil) + removePath := inference.ModelsPrefix + "/" + model + resp, err := c.doRequest(http.MethodDelete, removePath, nil) if err != nil { - return "", fmt.Errorf("error creating request: %w", err) - } - - resp, err := c.dockerClient.HTTPClient().Do(req) - if err != nil { - return "", fmt.Errorf("error querying %s: %w", inference.ModelsPrefix+"/"+model, err) + return "", c.handleQueryError(err, removePath) } defer resp.Body.Close() - if resp.StatusCode != http.StatusOK { // from common/pkg/inference/models/manager.go + if resp.StatusCode != http.StatusOK { var bodyStr string body, err := io.ReadAll(resp.Body) if err != nil { @@ -268,6 +272,36 @@ func url(path string) string { return fmt.Sprintf("http://localhost" + inference.ExperimentalEndpointsPrefix + path) } +// doRequest is a helper function that performs HTTP requests and handles 503 responses +func (c *Client) doRequest(method, path string, body io.Reader) (*http.Response, error) { + req, err := http.NewRequest(method, url(path), body) + if err != nil { + return nil, fmt.Errorf("error creating request: %w", err) + } + if body != nil { + req.Header.Set("Content-Type", "application/json") + } + + resp, err := c.dockerClient.HTTPClient().Do(req) + if err != nil { + return nil, err + } + + if resp.StatusCode == http.StatusServiceUnavailable { + resp.Body.Close() + return nil, ErrServiceUnavailable + } + + return resp, nil +} + +func (c *Client) handleQueryError(err error, path string) error { + if errors.Is(err, ErrServiceUnavailable) { + return ErrServiceUnavailable + } + return fmt.Errorf("error querying %s: %w", path, err) +} + func prettyPrintModels(models []Model) string { var buf bytes.Buffer table := tablewriter.NewWriter(&buf)