Skip to content

Commit 09da727

Browse files
Merge pull request #198 from ibuildthecloud/openai-remote
feat: Add gptscript --list-models [PROVIDER]
2 parents b66246f + 45a7520 commit 09da727

File tree

5 files changed

+28
-14
lines changed

5 files changed

+28
-14
lines changed

pkg/cli/gptscript.go

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -146,7 +146,7 @@ func (r *GPTScript) Run(cmd *cobra.Command, args []string) error {
146146
defer gptScript.Close()
147147

148148
if r.ListModels {
149-
models, err := gptScript.ListModels(cmd.Context())
149+
models, err := gptScript.ListModels(cmd.Context(), args...)
150150
if err != nil {
151151
return err
152152
}

pkg/gptscript/gptscript.go

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -103,6 +103,6 @@ func (g *GPTScript) GetModel() engine.Model {
103103
return g.Registry
104104
}
105105

106-
func (g *GPTScript) ListModels(ctx context.Context) ([]string, error) {
107-
return g.Registry.ListModels(ctx)
106+
func (g *GPTScript) ListModels(ctx context.Context, providers ...string) ([]string, error) {
107+
return g.Registry.ListModels(ctx, providers...)
108108
}

pkg/llm/registry.go

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,7 @@ import (
1111

1212
type Client interface {
1313
Call(ctx context.Context, messageRequest types.CompletionRequest, status chan<- types.CompletionStatus) (*types.CompletionMessage, error)
14-
ListModels(ctx context.Context) (result []string, _ error)
14+
ListModels(ctx context.Context, providers ...string) (result []string, _ error)
1515
Supports(ctx context.Context, modelName string) (bool, error)
1616
}
1717

@@ -28,9 +28,9 @@ func (r *Registry) AddClient(client Client) error {
2828
return nil
2929
}
3030

31-
func (r *Registry) ListModels(ctx context.Context) (result []string, _ error) {
31+
func (r *Registry) ListModels(ctx context.Context, providers ...string) (result []string, _ error) {
3232
for _, v := range r.clients {
33-
models, err := v.ListModels(ctx)
33+
models, err := v.ListModels(ctx, providers...)
3434
if err != nil {
3535
return nil, err
3636
}

pkg/openai/client.go

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -127,7 +127,12 @@ func (c *Client) Supports(ctx context.Context, modelName string) (bool, error) {
127127
return slices.Contains(models, modelName), nil
128128
}
129129

130-
func (c *Client) ListModels(ctx context.Context) (result []string, _ error) {
130+
func (c *Client) ListModels(ctx context.Context, providers ...string) (result []string, _ error) {
131+
// Only serve if providers is empty or "" is in the list
132+
if len(providers) != 0 && !slices.Contains(providers, "") {
133+
return nil, nil
134+
}
135+
131136
models, err := c.c.ListModels(ctx)
132137
if err != nil {
133138
return nil, err

pkg/remote/remote.go

Lines changed: 16 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,6 @@ import (
1515
"github.com/gptscript-ai/gptscript/pkg/openai"
1616
"github.com/gptscript-ai/gptscript/pkg/runner"
1717
"github.com/gptscript-ai/gptscript/pkg/types"
18-
"golang.org/x/exp/maps"
1918
)
2019

2120
type Client struct {
@@ -49,13 +48,23 @@ func (c *Client) Call(ctx context.Context, messageRequest types.CompletionReques
4948
return client.Call(ctx, messageRequest, status)
5049
}
5150

52-
func (c *Client) ListModels(_ context.Context) (result []string, _ error) {
53-
c.clientsLock.Lock()
54-
defer c.clientsLock.Unlock()
51+
func (c *Client) ListModels(ctx context.Context, providers ...string) (result []string, _ error) {
52+
for _, provider := range providers {
53+
client, err := c.load(ctx, provider)
54+
if err != nil {
55+
return nil, err
56+
}
57+
models, err := client.ListModels(ctx, "")
58+
if err != nil {
59+
return nil, err
60+
}
61+
for _, model := range models {
62+
result = append(result, model+" from "+provider)
63+
}
64+
}
5565

56-
keys := maps.Keys(c.models)
57-
sort.Strings(keys)
58-
return keys, nil
66+
sort.Strings(result)
67+
return
5968
}
6069

6170
func (c *Client) Supports(ctx context.Context, modelName string) (bool, error) {

0 commit comments

Comments
 (0)