diff --git a/commands/completion/functions.go b/commands/completion/functions.go index 91a44bb9..3c69375b 100644 --- a/commands/completion/functions.go +++ b/commands/completion/functions.go @@ -1,6 +1,8 @@ package completion import ( + "strings" + "github.com/docker/model-cli/desktop" "github.com/spf13/cobra" ) @@ -31,3 +33,56 @@ func ModelNames(desktopClient func() *desktop.Client, limit int) cobra.Completio return names, cobra.ShellCompDirectiveNoFileComp } } + +// ModelNamesAndTags offers completion that matches the base model name along with its tags. +// If the model has multiple tags, match both the base model name and each tag. +func ModelNamesAndTags(desktopClient func() *desktop.Client, limit int) cobra.CompletionFunc { + return func(cmd *cobra.Command, args []string, toComplete string) ([]string, cobra.ShellCompDirective) { + // HACK: Invoke rootCmd's PersistentPreRunE, which is needed for context + // detection and client initialization. This function isn't invoked + // automatically on autocompletion paths. + cmd.Parent().PersistentPreRunE(cmd, args) + + if limit > 0 && len(args) >= limit { + return nil, cobra.ShellCompDirectiveNoFileComp + } + + models, err := desktopClient().List() + if err != nil { + return nil, cobra.ShellCompDirectiveError + } + + var names []string + + modelNames := make(map[string]bool) + modelTags := make(map[string][]string) + + for _, m := range models { + for _, tag := range m.Tags { + // Extract model name (everything before the first colon or the full tag if no colon). + modelName, _, _ := strings.Cut(tag, ":") + modelNames[modelName] = true + modelTags[modelName] = append(modelTags[modelName], tag) + } + } + + for name := range modelNames { + // If model has multiple tags, suggest the base model name and all specific tags. + if len(modelTags[name]) > 1 { + names = append(names, name) + for _, tag := range modelTags[name] { + names = append(names, tag) + // If this model doesn't have a tag, also add the :latest variant. + if tag == name { + names = append(names, tag+":latest") + } + } + } else { + // If only one tag, just suggest that tag to avoid duplication. + names = append(names, modelTags[name][0]) + } + } + + return names, cobra.ShellCompDirectiveNoSpace + } +} diff --git a/commands/list.go b/commands/list.go index fbaaf58c..db873a85 100644 --- a/commands/list.go +++ b/commands/list.go @@ -4,6 +4,8 @@ import ( "bytes" "fmt" "os" + "slices" + "strings" "time" "github.com/docker/go-units" @@ -50,14 +52,18 @@ func newListCmd() *cobra.Command { if _, err := ensureStandaloneRunnerAvailable(cmd.Context(), standaloneInstallPrinter); err != nil { return fmt.Errorf("unable to initialize standalone model runner: %w", err) } - models, err := listModels(openai, backend, desktopClient, quiet, jsonFormat, apiKey) + var modelFilter string + if len(args) > 0 { + modelFilter = args[0] + } + models, err := listModels(openai, backend, desktopClient, quiet, jsonFormat, apiKey, modelFilter) if err != nil { return err } cmd.Print(models) return nil }, - ValidArgsFunction: completion.NoComplete, + ValidArgsFunction: completion.ModelNamesAndTags(getDesktopClient, 1), } c.Flags().BoolVar(&jsonFormat, "json", false, "List models in a JSON format") c.Flags().BoolVar(&openai, "openai", false, "List models in an OpenAI format") @@ -67,7 +73,7 @@ func newListCmd() *cobra.Command { return c } -func listModels(openai bool, backend string, desktopClient *desktop.Client, quiet bool, jsonFormat bool, apiKey string) (string, error) { +func listModels(openai bool, backend string, desktopClient *desktop.Client, quiet bool, jsonFormat bool, apiKey string, modelFilter string) (string, error) { if openai || backend == "openai" { models, err := desktopClient.ListOpenAI(backend, apiKey) if err != nil { @@ -81,6 +87,25 @@ func listModels(openai bool, backend string, desktopClient *desktop.Client, quie err = handleClientError(err, "Failed to list models") return "", handleNotRunningError(err) } + + if modelFilter != "" { + var filteredModels []dmrm.Model + for _, m := range models { + hasMatchingTag := false + for _, tag := range m.Tags { + modelName, _, _ := strings.Cut(tag, ":") + if slices.Contains([]string{modelName, tag + ":latest", tag}, modelFilter) { + hasMatchingTag = true + break + } + } + if hasMatchingTag { + filteredModels = append(filteredModels, m) + } + } + models = filteredModels + } + if jsonFormat { return formatter.ToStandardJSON(models) }