diff --git a/commands/completion/functions.go b/commands/completion/functions.go index 91a44bb9..a85b4218 100644 --- a/commands/completion/functions.go +++ b/commands/completion/functions.go @@ -1,6 +1,8 @@ package completion import ( + "strings" + "github.com/docker/model-cli/desktop" "github.com/spf13/cobra" ) @@ -31,3 +33,12 @@ func ModelNames(desktopClient func() *desktop.Client, limit int) cobra.Completio return names, cobra.ShellCompDirectiveNoFileComp } } + +// ensures the model string contains a slash, and if not, prepends "ai/". +func AddDefaultNamespace(model string) string { + if strings.Contains(model, "/") { + return model + } + + return "ai/" + model +} diff --git a/commands/inspect.go b/commands/inspect.go index 8f1e81a6..8c4d9d83 100644 --- a/commands/inspect.go +++ b/commands/inspect.go @@ -23,6 +23,8 @@ func newInspectCmd() *cobra.Command { "See 'docker model inspect --help' for more information", ) } + + args[0] = completion.AddDefaultNamespace(args[0]) return nil }, RunE: func(cmd *cobra.Command, args []string) error { @@ -47,7 +49,7 @@ func newInspectCmd() *cobra.Command { } func inspectModel(args []string, openai bool, remote bool, desktopClient *desktop.Client) (string, error) { - modelName := args[0] + modelName := completion.AddDefaultNamespace(args[0]) if openai { model, err := desktopClient.InspectOpenAI(modelName) if err != nil { diff --git a/commands/list.go b/commands/list.go index fbaaf58c..4e51daea 100644 --- a/commands/list.go +++ b/commands/list.go @@ -4,6 +4,7 @@ import ( "bytes" "fmt" "os" + "strings" "time" "github.com/docker/go-units" @@ -127,6 +128,7 @@ func prettyPrintModels(models []dmrm.Model) string { continue } for _, tag := range m.Tags { + tag = strings.TrimPrefix(tag, "ai/") appendRow(table, tag, m) } } diff --git a/commands/package.go b/commands/package.go index 3e3cec02..580454f0 100644 --- a/commands/package.go +++ b/commands/package.go @@ -63,7 +63,7 @@ func newPackagedCmd() *cobra.Command { return nil }, RunE: func(cmd *cobra.Command, args []string) error { - opts.tag = args[0] + opts.tag = completion.AddDefaultNamespace(args[0]) if err := packageModel(cmd, opts); err != nil { cmd.PrintErrln("Failed to package model") return fmt.Errorf("package model: %w", err) diff --git a/commands/pull.go b/commands/pull.go index a85f2024..7680b82d 100644 --- a/commands/pull.go +++ b/commands/pull.go @@ -41,6 +41,7 @@ func newPullCmd() *cobra.Command { } func pullModel(cmd *cobra.Command, desktopClient *desktop.Client, model string, ignoreRuntimeMemoryCheck bool) error { + model = completion.AddDefaultNamespace(model) var progress func(string) if isatty.IsTerminal(os.Stdout.Fd()) { progress = TUIProgress diff --git a/commands/push.go b/commands/push.go index ed94f1a6..41592b4c 100644 --- a/commands/push.go +++ b/commands/push.go @@ -34,6 +34,7 @@ func newPushCmd() *cobra.Command { } func pushModel(cmd *cobra.Command, desktopClient *desktop.Client, model string) error { + model = completion.AddDefaultNamespace(model) response, progressShown, err := desktopClient.Push(model, TUIProgress) // Add a newline before any output (success or error) if progress was shown. diff --git a/commands/rm.go b/commands/rm.go index 95159fb2..81053bee 100644 --- a/commands/rm.go +++ b/commands/rm.go @@ -21,6 +21,8 @@ func newRemoveCmd() *cobra.Command { "See 'docker model rm --help' for more information", ) } + + args[0] = completion.AddDefaultNamespace(args[0]) return nil }, RunE: func(cmd *cobra.Command, args []string) error { diff --git a/commands/run.go b/commands/run.go index 24c73f6e..4557f0c0 100644 --- a/commands/run.go +++ b/commands/run.go @@ -101,7 +101,7 @@ func newRunCmd() *cobra.Command { return err } - model := args[0] + model := completion.AddDefaultNamespace(args[0]) prompt := "" args_len := len(args) if args_len > 1 { diff --git a/commands/unload.go b/commands/unload.go index 97f3faa8..8207394f 100644 --- a/commands/unload.go +++ b/commands/unload.go @@ -54,6 +54,8 @@ func newUnloadCmd() *cobra.Command { "See 'docker model unload --help' for more information.", ) } + + args[0] = completion.AddDefaultNamespace(args[0]) return nil } c.Flags().BoolVar(&all, "all", false, "Unload all running models")