diff --git a/commands/compose.go b/commands/compose.go new file mode 100644 index 00000000..b2aee891 --- /dev/null +++ b/commands/compose.go @@ -0,0 +1,109 @@ +package commands + +import ( + "encoding/json" + "errors" + "fmt" + + "github.com/docker/model-cli/desktop" + "github.com/spf13/cobra" +) + +func newComposeCmd(desktopClient *desktop.Client) *cobra.Command { + + c := &cobra.Command{ + Use: "compose EVENT", + } + c.AddCommand(newUpCommand(desktopClient)) + c.AddCommand(newDownCommand()) + c.Hidden = true + c.Flags().String("project-name", "", "compose project name") // unused by model + + return c +} + +func newUpCommand(desktopClient *desktop.Client) *cobra.Command { + var model string + c := &cobra.Command{ + Use: "up", + RunE: func(cmd *cobra.Command, args []string) error { + if model == "" { + err := errors.New("options.model is required") + sendError(err.Error()) + return err + } + + _, err := desktopClient.Pull(model, func(s string) { + sendInfo(s) + }) + if err != nil { + sendErrorf("Failed to pull model", err) + return fmt.Errorf("Failed to pull model: %v\n", err) + } + + // FIXME get actual URL from Docker Desktop + setenv("URL", "http://model-runner.docker.internal/engines/v1/") + setenv("MODEL", model) + + return nil + }, + } + c.Flags().StringVar(&model, "model", "", "model to use") + return c +} + +func newDownCommand() *cobra.Command { + c := &cobra.Command{ + Use: "down", + RunE: func(cmd *cobra.Command, args []string) error { + // No required cleanup on down + return nil + }, + } + return c +} + +type jsonMessage struct { + Type string `json:"type"` + Message string `json:"message"` +} + +func setenv(k, v string) error { + marshal, err := json.Marshal(jsonMessage{ + Type: "setenv", + Message: fmt.Sprintf("%v=%v", k, v), + }) + if err != nil { + return err + } + _, err = fmt.Println(string(marshal)) + return err +} + +func sendErrorf(message string, args ...any) error { + return sendError(fmt.Sprintf(message, args...)) +} + +func sendError(message string) error { + marshal, err := json.Marshal(jsonMessage{ + Type: "error", + Message: message, + }) + if err != nil { + return err + } + _, err = fmt.Println(string(marshal)) + return err +} + +func sendInfo(s string) error { + marshal, err := json.Marshal(jsonMessage{ + Type: "info", + Message: s, + }) + if err != nil { + return err + } + _, err = fmt.Println(string(marshal)) + return err +} diff --git a/commands/pull.go b/commands/pull.go index f96c732a..8f37e386 100644 --- a/commands/pull.go +++ b/commands/pull.go @@ -23,7 +23,7 @@ func newPullCmd(desktopClient *desktop.Client) *cobra.Command { }, RunE: func(cmd *cobra.Command, args []string) error { model := args[0] - response, err := desktopClient.Pull(model) + response, err := desktopClient.Pull(model, TUIProgress) if err != nil { err = handleClientError(err, "Failed to pull model") return handleNotRunningError(err) @@ -34,3 +34,7 @@ func newPullCmd(desktopClient *desktop.Client) *cobra.Command { } return c } + +func TUIProgress(line string) { + fmt.Print("\r\033[K", line) +} diff --git a/commands/root.go b/commands/root.go index e546ae90..294e7c53 100644 --- a/commands/root.go +++ b/commands/root.go @@ -33,6 +33,7 @@ func NewRootCmd() *cobra.Command { newRunCmd(desktopClient), newRemoveCmd(desktopClient), newInspectCmd(desktopClient), + newComposeCmd(desktopClient), ) return rootCmd } diff --git a/commands/run.go b/commands/run.go index f426a84c..2e0d03d4 100644 --- a/commands/run.go +++ b/commands/run.go @@ -37,7 +37,7 @@ func newRunCmd(desktopClient *desktop.Client) *cobra.Command { return handleNotRunningError(handleClientError(err, "Failed to list models")) } cmd.Println("Unable to find model '" + model + "' locally. Pulling from the server.") - response, err := desktopClient.Pull(model) + response, err := desktopClient.Pull(model, TUIProgress) if err != nil { return handleNotRunningError(handleClientError(err, "Failed to pull model")) } diff --git a/desktop/desktop.go b/desktop/desktop.go index 7c986d96..44d02956 100644 --- a/desktop/desktop.go +++ b/desktop/desktop.go @@ -79,7 +79,7 @@ func (c *Client) Status() Status { } } -func (c *Client) Pull(model string) (string, error) { +func (c *Client) Pull(model string, progress func(string)) (string, error) { jsonData, err := json.Marshal(models.ModelCreateRequest{From: model}) if err != nil { return "", fmt.Errorf("error marshaling request: %w", err) @@ -105,7 +105,7 @@ func (c *Client) Pull(model string) (string, error) { for scanner.Scan() { progressLine := scanner.Text() if progressLine != "" { - fmt.Print("\r\033[K", progressLine) + progress(progressLine) } }