diff --git a/go/go.mod b/go/go.mod index bd4ffcf833..fb4d2d0b28 100644 --- a/go/go.mod +++ b/go/go.mod @@ -21,6 +21,7 @@ require ( github.com/jba/slog v0.2.0 github.com/lib/pq v1.10.9 github.com/pgvector/pgvector-go v0.2.0 + github.com/stretchr/testify v1.10.0 github.com/weaviate/weaviate v1.26.0-rc.1 github.com/weaviate/weaviate-go-client/v4 v4.15.0 github.com/xeipuuv/gojsonschema v1.2.0 @@ -37,6 +38,16 @@ require ( gopkg.in/yaml.v3 v3.0.1 ) +require ( + github.com/davecgh/go-spew v1.1.1 // indirect + github.com/pmezard/go-difflib v1.0.0 // indirect + github.com/tidwall/gjson v1.18.0 // indirect + github.com/tidwall/match v1.1.1 // indirect + github.com/tidwall/pretty v1.2.1 // indirect + github.com/tidwall/sjson v1.2.5 // indirect + google.golang.org/protobuf v1.34.2 // indirect +) + require ( cloud.google.com/go v0.116.0 // indirect cloud.google.com/go/auth v0.9.3 // indirect @@ -53,7 +64,7 @@ require ( github.com/PuerkitoBio/urlesc v0.0.0-20170810143723-de5bf2ad4578 // indirect github.com/asaskevich/govalidator v0.0.0-20230301143203-a9d515a09cc2 // indirect github.com/bahlo/generic-list-go v0.2.0 // indirect - github.com/blues/jsonata-go v1.5.4 // indirect + github.com/blues/jsonata-go v1.5.4 github.com/buger/jsonparser v1.1.1 // indirect github.com/felixge/httpsnoop v1.0.4 // indirect github.com/go-logr/logr v1.4.2 // indirect @@ -78,6 +89,7 @@ require ( github.com/mailru/easyjson v0.9.0 // indirect github.com/mitchellh/mapstructure v1.5.0 // indirect github.com/oklog/ulid v1.3.1 // indirect + github.com/openai/openai-go v0.1.0-alpha.65 github.com/pkg/errors v0.9.1 // indirect github.com/wk8/go-ordered-map/v2 v2.1.8 // indirect github.com/xeipuuv/gojsonpointer v0.0.0-20190905194746-02993c407bfb // indirect @@ -98,5 +110,4 @@ require ( google.golang.org/genproto/googleapis/api v0.0.0-20240903143218-8af14fe29dc1 // indirect google.golang.org/genproto/googleapis/rpc v0.0.0-20240903143218-8af14fe29dc1 // indirect google.golang.org/grpc v1.66.2 // indirect - google.golang.org/protobuf v1.34.2 // indirect ) diff --git a/go/go.sum b/go/go.sum index 00a99203a2..9b59899408 100644 --- a/go/go.sum +++ b/go/go.sum @@ -221,6 +221,8 @@ github.com/montanaflynn/stats v0.0.0-20171201202039-1bf9dbcd8cbe/go.mod h1:wL8QJ github.com/niemeyer/pretty v0.0.0-20200227124842-a10e7caefd8e/go.mod h1:zD1mROLANZcx1PVRCS0qkT7pwLkGfwJo4zjcN/Tysno= github.com/oklog/ulid v1.3.1 h1:EGfNDEx6MqHz8B3uNV6QAib1UR2Lm97sHi3ocA6ESJ4= github.com/oklog/ulid v1.3.1/go.mod h1:CirwcVhetQ6Lv90oh/F+FBtV6XMibvdAFo93nm5qn4U= +github.com/openai/openai-go v0.1.0-alpha.65 h1:G12sA6OaL+cVMElMO3m5RVFwKhhg40kmGeGhaYZIoYw= +github.com/openai/openai-go v0.1.0-alpha.65/go.mod h1:3SdE6BffOX9HPEQv8IL/fi3LYZ5TUpRYaqGQZbyk11A= github.com/pelletier/go-toml v1.7.0/go.mod h1:vwGMzjaWMwyfHwgIBhI2YUM4fB6nL6lVAvS1LBMMhTE= github.com/pgvector/pgvector-go v0.2.0 h1:NZdW4NxUxdSCzaev3LVHb9ORf+LdX+uZOQVqQ6s2Zyg= github.com/pgvector/pgvector-go v0.2.0/go.mod h1:OQpvU5QZGQOPI9quIXAyHaRZ5yGk/RGUDbs9C3DPUNE= @@ -256,7 +258,17 @@ github.com/stretchr/testify v1.8.0/go.mod h1:yNjHg4UonilssWZ8iaSj1OCr/vHnekPRkoO github.com/stretchr/testify v1.8.1/go.mod h1:w2LPCIKwWwSfY2zedu0+kehJoqGctiVI29o6fzry7u4= github.com/stretchr/testify v1.10.0 h1:Xv5erBjTwe/5IxqUQTdXv5kgmIvbHo3QQyRwhJsOfJA= github.com/stretchr/testify v1.10.0/go.mod h1:r2ic/lqez/lEtzL7wO/rwa5dbSLXVDPFyf8C91i36aY= +github.com/tidwall/gjson v1.14.2/go.mod h1:/wbyibRr2FHMks5tjHJ5F8dMZh3AcwJEMf5vlfC0lxk= +github.com/tidwall/gjson v1.18.0 h1:FIDeeyB800efLX89e5a8Y0BNH+LOngJyGrIWxG2FKQY= +github.com/tidwall/gjson 
v1.18.0/go.mod h1:/wbyibRr2FHMks5tjHJ5F8dMZh3AcwJEMf5vlfC0lxk= +github.com/tidwall/match v1.1.1 h1:+Ho715JplO36QYgwN9PGYNhgZvoUSc9X2c80KVTi+GA= +github.com/tidwall/match v1.1.1/go.mod h1:eRSPERbgtNPcGhD8UCthc6PmLEQXEWd3PRB5JTxsfmM= github.com/tidwall/pretty v1.0.0/go.mod h1:XNkn88O1ChpSDQmQeStsy+sBenx6DDtFZJxhVysOjyk= +github.com/tidwall/pretty v1.2.0/go.mod h1:ITEVvHYasfjBbM0u2Pg8T2nJnzm8xPwvNhhsoaGGjNU= +github.com/tidwall/pretty v1.2.1 h1:qjsOFOWWQl+N3RsoF5/ssm1pHmJJwhjlSbZ51I6wMl4= +github.com/tidwall/pretty v1.2.1/go.mod h1:ITEVvHYasfjBbM0u2Pg8T2nJnzm8xPwvNhhsoaGGjNU= +github.com/tidwall/sjson v1.2.5 h1:kLy8mja+1c9jlljvWTlSazM7cKDRfJuR/bOJhcY5NcY= +github.com/tidwall/sjson v1.2.5/go.mod h1:Fvgq9kS/6ociJEDnK0Fk1cpYF4FIW6ZF7LAe+6jwd28= github.com/tmthrgd/go-hex v0.0.0-20190904060850-447a3041c3bc h1:9lRDQMhESg+zvGYmW5DyG0UqvY96Bu5QYsTLvCHdrgo= github.com/tmthrgd/go-hex v0.0.0-20190904060850-447a3041c3bc/go.mod h1:bciPuU6GHm1iF1pBvUfxfsH0Wmnc2VbpgvbI9ZWuIRs= github.com/uptrace/bun v1.1.12 h1:sOjDVHxNTuM6dNGaba0wUuz7KvDE1BmNu9Gqs2gJSXQ= diff --git a/go/plugins/compat_oai/README.md b/go/plugins/compat_oai/README.md new file mode 100644 index 0000000000..509145e122 --- /dev/null +++ b/go/plugins/compat_oai/README.md @@ -0,0 +1,77 @@ +# OpenAI-Compatible Plugin Package + +This directory contains a package for building plugins that are compatible with the OpenAI API specification, along with plugins built on top of this package. + +## Package Overview + +The `compat_oai` package provides a base implementation (`OpenAICompatible`) that handles: +- Model and embedder registration +- Message handling +- Tool support +- Configuration management + +## Usage Example + +Here's how to implement a new OpenAI-compatible plugin: + +```go +type MyPlugin struct { + compat_oai.OpenAICompatible + // define other plugin-specific fields +} + +var ( + supportedModels = map[string]ai.ModelInfo{ + // define supported models + } +) + +// Implement required methods +func (p *MyPlugin) Init(ctx context.Context, g *genkit.Genkit) error { + // initialize the plugin with the common compatible package + if err := p.OpenAICompatible.Init(ctx, g); err != nil { + return err + } + + // Define plugin-specific models + for model, info := range supportedModels { + if _, err := p.DefineModel(g, p.Provider, model, info); err != nil { + return err + } + } + + // Define embedders, if applicable + + return nil +} + +func (p *MyPlugin) Name() string { + return p.Provider +} +``` + +See the `openai` and `anthropic` directories for complete implementations. + +## Running Tests + +Set your API keys: +```bash +export OPENAI_API_KEY= +export ANTHROPIC_API_KEY= +``` + +Run all tests: +```bash +go test -v ./... +``` + +Run specific plugin tests: +```bash +# OpenAI tests +go test -v ./openai + +# Anthropic tests +go test -v ./anthropic +``` + +Note: Tests will be skipped if the required API keys are not set. \ No newline at end of file diff --git a/go/plugins/compat_oai/anthropic/README.md b/go/plugins/compat_oai/anthropic/README.md new file mode 100644 index 0000000000..70fb2ecc4e --- /dev/null +++ b/go/plugins/compat_oai/anthropic/README.md @@ -0,0 +1,53 @@ +# Anthropic Plugin + +This plugin provides a simple interface for using Anthropic's services. + +## Prerequisites + +- Go installed on your system +- An Anthropic API key + +## Running Tests + +First, set your Anthropic API key as an environment variable: + +```bash +export ANTHROPIC_API_KEY= +``` + +### Running All Tests +To run all tests in the directory: +```bash +go test -v . 
+``` + +### Running Tests from Specific Files +To run tests from a specific file: +```bash +# Run only generate_live_test.go tests +go test -run "^TestGenerator" + +# Run only anthropic_live_test.go tests +go test -run "^TestPlugin" +``` + +### Running Individual Tests +To run a specific test case: +```bash +# Run only the streaming test from anthropic_live_test.go +go test -run "TestPlugin/streaming" + +# Run only the Complete test from generate_live_test.go +go test -run "TestGenerator_Complete" + +# Run only the Stream test from generate_live_test.go +go test -run "TestGenerator_Stream" +``` + +### Test Output Verbosity +Add the `-v` flag for verbose output: +```bash +go test -v -run "TestPlugin/streaming" +``` + +Note: All live tests require the ANTHROPIC_API_KEY environment variable to be set. Tests will be skipped if the API key is not provided. diff --git a/go/plugins/compat_oai/anthropic/anthropic.go b/go/plugins/compat_oai/anthropic/anthropic.go new file mode 100644 index 0000000000..ed33a60fc3 --- /dev/null +++ b/go/plugins/compat_oai/anthropic/anthropic.go @@ -0,0 +1,123 @@ +// Copyright 2025 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package anthropic + +import ( + "context" + + "github.com/firebase/genkit/go/ai" + "github.com/firebase/genkit/go/genkit" + "github.com/firebase/genkit/go/plugins/compat_oai" + "github.com/openai/openai-go/option" +) + +const ( + provider = "anthropic" + baseURL = "https://api.anthropic.com/v1" +) + +var ( + // Supported models: https://docs.anthropic.com/en/docs/about-claude/models/all-models + supportedModels = map[string]ai.ModelInfo{ + "claude-3-7-sonnet-20250219": { + Label: "Claude 3.7 Sonnet", + Supports: &ai.ModelSupports{ + Multiturn: true, + Tools: false, // NOTE: Anthropic supports tool use, but it's not compatible with the OpenAI API + SystemRole: true, + Media: true, + }, + Versions: []string{"claude-3-7-sonnet-latest", "claude-3-7-sonnet-20250219"}, + }, + "claude-3-5-haiku-20241022": { + Label: "Claude 3.5 Haiku", + Supports: &ai.ModelSupports{ + Multiturn: true, + Tools: false, // NOTE: Anthropic supports tool use, but it's not compatible with the OpenAI API + SystemRole: true, + Media: true, + }, + Versions: []string{"claude-3-5-haiku-latest", "claude-3-5-haiku-20241022"}, + }, + "claude-3-5-sonnet-20240620": { + Label: "Claude 3.5 Sonnet", + Supports: &ai.ModelSupports{ + Multiturn: true, + Tools: false, // NOTE: Anthropic supports tool use, but it's not compatible with the OpenAI API + SystemRole: false, // NOTE: This model does not support system role + Media: true, + }, + Versions: []string{"claude-3-5-sonnet-20240620"}, + }, + "claude-3-opus-20240229": { + Label: "Claude 3 Opus", + Supports: &ai.ModelSupports{ + Multiturn: true, + Tools: false, // NOTE: Anthropic supports tool use, but it's not compatible with the OpenAI API + SystemRole: false, // NOTE: This model does not support system role + Media: true, + }, + Versions: []string{"claude-3-opus-latest", "claude-3-opus-20240229"}, + }, + 
"claude-3-haiku-20240307": { + Label: "Claude 3 Haiku", + Supports: &ai.ModelSupports{ + Multiturn: true, + Tools: false, // NOTE: Anthropic supports tool use, but it's not compatible with the OpenAI API + SystemRole: false, // NOTE: This model does not support system role + Media: true, + }, + Versions: []string{"claude-3-haiku-20240307"}, + }, + } +) + +type Anthropic struct { + Opts []option.RequestOption + openAICompatible compat_oai.OpenAICompatible +} + +// Name implements genkit.Plugin. +func (a *Anthropic) Name() string { + return provider +} + +func (a *Anthropic) Init(ctx context.Context, g *genkit.Genkit) error { + // Set the base URL + a.Opts = append(a.Opts, option.WithBaseURL(baseURL)) + + // initialize OpenAICompatible + a.openAICompatible.Opts = a.Opts + if err := a.openAICompatible.Init(ctx, g); err != nil { + return err + } + + // define default models + for model, info := range supportedModels { + if _, err := a.DefineModel(g, model, info); err != nil { + return err + } + } + + return nil +} + +func (a *Anthropic) Model(g *genkit.Genkit, name string) ai.Model { + return a.openAICompatible.Model(g, name, provider) +} + +func (a *Anthropic) DefineModel(g *genkit.Genkit, name string, info ai.ModelInfo) (ai.Model, error) { + return a.openAICompatible.DefineModel(g, provider, name, info) +} diff --git a/go/plugins/compat_oai/anthropic/anthropic_live_test.go b/go/plugins/compat_oai/anthropic/anthropic_live_test.go new file mode 100644 index 0000000000..d89977bb6b --- /dev/null +++ b/go/plugins/compat_oai/anthropic/anthropic_live_test.go @@ -0,0 +1,124 @@ +// Copyright 2025 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+ +package anthropic_test + +import ( + "context" + "os" + "strings" + "testing" + + "github.com/firebase/genkit/go/ai" + "github.com/firebase/genkit/go/genkit" + "github.com/firebase/genkit/go/plugins/compat_oai/anthropic" + "github.com/openai/openai-go/option" +) + +func TestPlugin(t *testing.T) { + apiKey := os.Getenv("ANTHROPIC_API_KEY") + if apiKey == "" { + t.Skip("Skipping test: ANTHROPIC_API_KEY environment variable not set") + } + + ctx := context.Background() + + // Initialize genkit with claude-3-7-sonnet as default model + g, err := genkit.Init( + ctx, + genkit.WithDefaultModel("anthropic/claude-3-7-sonnet-20250219"), + genkit.WithPlugins(&anthropic.Anthropic{ + Opts: []option.RequestOption{ + option.WithAPIKey(apiKey), + }, + }), + ) + if err != nil { + t.Fatal(err) + } + t.Log("genkit initialized") + + t.Run("basic completion", func(t *testing.T) { + t.Log("generating basic completion response") + resp, err := genkit.Generate(ctx, g, + ai.WithPromptText("What is the capital of France?"), + ) + if err != nil { + t.Fatal("error generating basic completion response: ", err) + } + t.Logf("basic completion response: %+v", resp) + + out := resp.Message.Content[0].Text + if !strings.Contains(strings.ToLower(out), "paris") { + t.Errorf("got %q, expecting it to contain 'Paris'", out) + } + + // Verify usage statistics are present + if resp.Usage == nil || resp.Usage.TotalTokens == 0 { + t.Error("Expected non-zero usage statistics") + } + }) + + t.Run("streaming", func(t *testing.T) { + var streamedOutput string + chunks := 0 + + final, err := genkit.Generate(ctx, g, + ai.WithPromptText("Write a short paragraph about artificial intelligence."), + ai.WithStreaming(func(ctx context.Context, chunk *ai.ModelResponseChunk) error { + chunks++ + for _, content := range chunk.Content { + streamedOutput += content.Text + } + return nil + })) + if err != nil { + t.Fatal(err) + } + + // Verify streaming worked + if chunks <= 1 { + t.Error("Expected multiple chunks for streaming") + } + + // Verify final output matches streamed content + finalOutput := "" + for _, content := range final.Message.Content { + finalOutput += content.Text + } + if streamedOutput != finalOutput { + t.Errorf("Streaming output doesn't match final output\nStreamed: %s\nFinal: %s", + streamedOutput, finalOutput) + } + + t.Logf("streaming response: %+v", finalOutput) + }) + + t.Run("system message", func(t *testing.T) { + resp, err := genkit.Generate(ctx, g, + ai.WithPromptText("What are you?"), + ai.WithSystemText("You are a helpful math tutor who loves numbers."), + ) + if err != nil { + t.Fatal(err) + } + + out := resp.Message.Content[0].Text + if !strings.Contains(strings.ToLower(out), "math") { + t.Errorf("got %q, expecting response to mention being a math tutor", out) + } + + t.Logf("system message response: %+v", out) + }) +} diff --git a/go/plugins/compat_oai/compat_oai.go b/go/plugins/compat_oai/compat_oai.go new file mode 100644 index 0000000000..f13bb89afd --- /dev/null +++ b/go/plugins/compat_oai/compat_oai.go @@ -0,0 +1,186 @@ +// Copyright 2025 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package compat_oai
+
+import (
+	"context"
+	"errors"
+	"strings"
+	"sync"
+
+	"github.com/firebase/genkit/go/ai"
+	"github.com/firebase/genkit/go/genkit"
+	openaiGo "github.com/openai/openai-go"
+	"github.com/openai/openai-go/option"
+)
+
+var (
+	// BasicText describes model capabilities for text-only GPT models.
+	BasicText = ai.ModelInfo{
+		Supports: &ai.ModelSupports{
+			Multiturn:  true,
+			Tools:      true,
+			SystemRole: true,
+			Media:      false,
+		},
+	}
+
+	// Multimodal describes model capabilities for multimodal GPT models.
+	Multimodal = ai.ModelInfo{
+		Supports: &ai.ModelSupports{
+			Multiturn:  true,
+			Tools:      true,
+			SystemRole: true,
+			Media:      true,
+			ToolChoice: true,
+		},
+	}
+)
+
+// OpenAICompatible is a base plugin implementation for APIs that are compatible with the OpenAI API.
+// It allows defining models and embedders that can be used with Genkit.
+type OpenAICompatible struct {
+	// mu protects concurrent access to the client and initialization state
+	mu sync.Mutex
+
+	// initted tracks whether the plugin has been initialized
+	initted bool
+
+	// client is the OpenAI client used for making API requests
+	// see https://github.com/openai/openai-go
+	client *openaiGo.Client
+
+	// Opts contains request options for the OpenAI client.
+	// Required: Must include at least WithAPIKey for authentication.
+	// Optional: Can include other options like WithOrganization, WithBaseURL, etc.
+	Opts []option.RequestOption
+
+	// Provider is a unique identifier for the plugin.
+	// This will be used as a prefix for model names (e.g., "myprovider/model-name").
+	// Should be lowercase and match the plugin's Name() method.
+	Provider string
+}
+
+// Init implements genkit.Plugin.
+func (o *OpenAICompatible) Init(ctx context.Context, g *genkit.Genkit) error {
+	o.mu.Lock()
+	defer o.mu.Unlock()
+	if o.initted {
+		return errors.New("compat_oai.Init already called")
+	}
+
+	// create client
+	client := openaiGo.NewClient(o.Opts...)
+	o.client = client
+	o.initted = true
+
+	return nil
+}
+
+// Name implements genkit.Plugin.
+func (o *OpenAICompatible) Name() string {
+	return o.Provider
+}
+
+// DefineModel defines a model in the registry
+func (o *OpenAICompatible) DefineModel(g *genkit.Genkit, provider, name string, info ai.ModelInfo) (ai.Model, error) {
+	o.mu.Lock()
+	defer o.mu.Unlock()
+	if !o.initted {
+		return nil, errors.New("OpenAICompatible.Init not called")
+	}
+
+	// Strip the provider prefix, if present, so the bare model name is sent to the API
+	modelName := strings.TrimPrefix(name, provider+"/")
+
+	return genkit.DefineModel(g, provider, name, &info, func(
+		ctx context.Context,
+		input *ai.ModelRequest,
+		cb func(context.Context, *ai.ModelResponseChunk) error,
+	) (*ai.ModelResponse, error) {
+
+		// Configure the response generator with the request's messages, config, and tools
+		generator := NewModelGenerator(o.client, modelName).WithMessages(input.Messages).WithConfig(input.Config).WithTools(input.Tools, input.ToolChoice)
+
+		// Generate response
+		resp, err := generator.Generate(ctx, cb)
+		if err != nil {
+			return nil, err
+		}
+
+		return resp, nil
+	}), nil
+}
+
+// DefineEmbedder defines an embedder with a given name.
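+// A plugin built on this package would typically call it from its Init, for
+// example (illustrative sketch; the provider and model names are placeholders):
+//
+//	if _, err := p.DefineEmbedder(g, "myprovider", "my-embedding-model"); err != nil {
+//		return err
+//	}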
+func (o *OpenAICompatible) DefineEmbedder(g *genkit.Genkit, provider, name string) (ai.Embedder, error) { + o.mu.Lock() + defer o.mu.Unlock() + if !o.initted { + return nil, errors.New("OpenAICompatible.Init not called") + } + + return genkit.DefineEmbedder(g, provider, name, func(ctx context.Context, input *ai.EmbedRequest) (*ai.EmbedResponse, error) { + var data openaiGo.EmbeddingNewParamsInputArrayOfStrings + for _, doc := range input.Documents { + for _, p := range doc.Content { + data = append(data, p.Text) + } + } + + params := openaiGo.EmbeddingNewParams{ + Input: openaiGo.F[openaiGo.EmbeddingNewParamsInputUnion](data), + Model: openaiGo.F(name), + EncodingFormat: openaiGo.F(openaiGo.EmbeddingNewParamsEncodingFormatFloat), + } + + embeddingResp, err := o.client.Embeddings.New(ctx, params) + if err != nil { + return nil, err + } + + resp := &ai.EmbedResponse{} + for _, emb := range embeddingResp.Data { + embedding := make([]float32, len(emb.Embedding)) + for i, val := range emb.Embedding { + embedding[i] = float32(val) + } + resp.Embeddings = append(resp.Embeddings, &ai.DocumentEmbedding{Embedding: embedding}) + } + return resp, nil + }), nil +} + +// IsDefinedEmbedder reports whether the named [Embedder] is defined by this plugin. +func (o *OpenAICompatible) IsDefinedEmbedder(g *genkit.Genkit, name string, provider string) bool { + return genkit.LookupEmbedder(g, provider, name) != nil +} + +// Embedder returns the [ai.Embedder] with the given name. +// It returns nil if the embedder was not defined. +func (o *OpenAICompatible) Embedder(g *genkit.Genkit, name string, provider string) ai.Embedder { + return genkit.LookupEmbedder(g, provider, name) +} + +// Model returns the [ai.Model] with the given name. +// It returns nil if the model was not defined. +func (o *OpenAICompatible) Model(g *genkit.Genkit, name string, provider string) ai.Model { + return genkit.LookupModel(g, provider, name) +} + +// IsDefinedModel reports whether the named [Model] is defined by this plugin. +func (o *OpenAICompatible) IsDefinedModel(g *genkit.Genkit, name string, provider string) bool { + return genkit.LookupModel(g, provider, name) != nil +} diff --git a/go/plugins/compat_oai/generate.go b/go/plugins/compat_oai/generate.go new file mode 100644 index 0000000000..e1a1149f45 --- /dev/null +++ b/go/plugins/compat_oai/generate.go @@ -0,0 +1,432 @@ +// Copyright 2025 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package compat_oai + +import ( + "context" + "encoding/json" + "fmt" + + "github.com/firebase/genkit/go/ai" + "github.com/openai/openai-go" + "github.com/openai/openai-go/shared" +) + +// mapToStruct unmarshals a map[string]any to the expected config type. 
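+// For example, {"temperature": 0.7, "max_output_tokens": 100} round-trips
+// through JSON into an OpenAIConfig with Temperature 0.7 and MaxOutputTokens 100.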
+func mapToStruct(m map[string]any, v any) error {
+	jsonData, err := json.Marshal(m)
+	if err != nil {
+		return err
+	}
+	return json.Unmarshal(jsonData, v)
+}
+
+// ModelGenerator handles OpenAI generation requests
+type ModelGenerator struct {
+	client    *openai.Client
+	modelName string
+	request   *openai.ChatCompletionNewParams
+	// Store any errors that occur during building
+	err error
+}
+
+func (g *ModelGenerator) GetRequest() *openai.ChatCompletionNewParams {
+	return g.request
+}
+
+// NewModelGenerator creates a new ModelGenerator instance
+func NewModelGenerator(client *openai.Client, modelName string) *ModelGenerator {
+	return &ModelGenerator{
+		client:    client,
+		modelName: modelName,
+		request: &openai.ChatCompletionNewParams{
+			Model: openai.F(modelName),
+		},
+	}
+}
+
+// WithMessages adds messages to the request
+func (g *ModelGenerator) WithMessages(messages []*ai.Message) *ModelGenerator {
+	// Return early if we already have an error
+	if g.err != nil {
+		return g
+	}
+
+	if messages == nil {
+		return g
+	}
+
+	oaiMessages := make([]openai.ChatCompletionMessageParamUnion, 0, len(messages))
+	for _, msg := range messages {
+		content := g.concatenateContent(msg.Content)
+		switch msg.Role {
+		case ai.RoleSystem:
+			oaiMessages = append(oaiMessages, openai.SystemMessage(content))
+		case ai.RoleModel:
+			// Build a single assistant message carrying both any text
+			// content and any tool calls the model produced.
+			am := openai.ChatCompletionAssistantMessageParam{
+				Role: openai.F(openai.ChatCompletionAssistantMessageParamRoleAssistant),
+			}
+			if msg.Content[0].Text != "" {
+				am.Content = openai.F([]openai.ChatCompletionAssistantMessageParamContentUnion{
+					openai.TextPart(msg.Content[0].Text),
+				})
+			}
+			toolCalls := convertToolCalls(msg.Content)
+			if len(toolCalls) > 0 {
+				am.ToolCalls = openai.F(toolCalls)
+			}
+			oaiMessages = append(oaiMessages, am)
+		case ai.RoleTool:
+			for _, p := range msg.Content {
+				if !p.IsToolResponse() {
+					continue
+				}
+				tm := openai.ToolMessage(
+					// NOTE: Temporarily set its name in place of its ref (i.e. call_xxxxx), which is not defined in the ai.ToolResponse struct.
+					p.ToolResponse.Name,
+					anyToJSONString(p.ToolResponse.Output),
+				)
+				oaiMessages = append(oaiMessages, tm)
+			}
+		default:
+			oaiMessages = append(oaiMessages, openai.UserMessage(content))
+		}
+	}
+	g.request.Messages = openai.F(oaiMessages)
+	return g
+}
+
+// OpenAIConfig mirrors the OpenAI API configuration fields
+type OpenAIConfig struct {
+	// Maximum number of tokens to generate
+	MaxOutputTokens int `json:"max_output_tokens,omitempty"`
+	// Temperature for sampling
+	Temperature float64 `json:"temperature,omitempty"`
+	// Top-p value for nucleus sampling
+	TopP float64 `json:"top_p,omitempty"`
+	// List of sequences where the model will stop generating
+	StopSequences []string `json:"stop_sequences,omitempty"`
+}
+
+// WithConfig adds configuration parameters from the model request.
+// See https://platform.openai.com/docs/api-reference/chat/create
+// for more details on OpenAI's request fields.
+func (g *ModelGenerator) WithConfig(config any) *ModelGenerator {
+	// Return early if we already have an error
+	if g.err != nil {
+		return g
+	}
+
+	if config == nil {
+		return g
+	}
+
+	var openaiConfig OpenAIConfig
+	switch cfg := config.(type) {
+	case OpenAIConfig:
+		openaiConfig = cfg
+	case *OpenAIConfig:
+		openaiConfig = *cfg
+	case map[string]any:
+		if err := mapToStruct(cfg, &openaiConfig); err != nil {
+			g.err = fmt.Errorf("failed to convert config to OpenAIConfig: %w", err)
+			return g
+		}
+	default:
+		g.err = fmt.Errorf("unexpected config type: %T", config)
+		return g
+	}
+
+	// Map fields from OpenAIConfig to the OpenAI request
+	if openaiConfig.MaxOutputTokens != 0 {
+		g.request.MaxCompletionTokens = openai.F(int64(openaiConfig.MaxOutputTokens))
+	}
+	if len(openaiConfig.StopSequences) > 0 {
+		g.request.Stop = openai.F[openai.ChatCompletionNewParamsStopUnion](
+			openai.ChatCompletionNewParamsStopArray(openaiConfig.StopSequences))
+	}
+	if openaiConfig.Temperature != 0 {
+		g.request.Temperature = openai.F(openaiConfig.Temperature)
+	}
+	if openaiConfig.TopP != 0 {
+		g.request.TopP = openai.F(openaiConfig.TopP)
+	}
+
+	return g
+}
+
+// WithTools adds tools to the request
+func (g *ModelGenerator) WithTools(tools []*ai.ToolDefinition, choice ai.ToolChoice) *ModelGenerator {
+	if g.err != nil {
+		return g
+	}
+
+	if tools == nil {
+		return g
+	}
+
+	toolParams := make([]openai.ChatCompletionToolParam, 0, len(tools))
+	for _, tool := range tools {
+		if tool == nil || tool.Name == "" {
+			continue
+		}
+
+		toolParams = append(toolParams, openai.ChatCompletionToolParam{
+			Type: openai.F(openai.ChatCompletionToolTypeFunction),
+			Function: openai.F(shared.FunctionDefinitionParam{
+				Name:        openai.F(tool.Name),
+				Description: openai.F(tool.Description),
+				Parameters:  openai.F(openai.FunctionParameters(tool.InputSchema)),
+				Strict:      openai.F(false), // TODO: implement strict mode
+			}),
+		})
+	}
+
+	// Set the tools in the request only when at least one is defined.
+	// If no tools were provided, the field is left unset (nil).
+	// This is important to avoid sending an empty array in the request,
+	// which is not supported by some vendor APIs.
+	if len(toolParams) > 0 {
+		g.request.Tools = openai.F(toolParams)
+	}
+
+	switch choice {
+	case ai.ToolChoiceAuto:
+		g.request.ToolChoice = openai.F[openai.ChatCompletionToolChoiceOptionUnionParam](openai.ChatCompletionToolChoiceOptionAutoAuto)
+	case ai.ToolChoiceRequired:
+		g.request.ToolChoice = openai.F[openai.ChatCompletionToolChoiceOptionUnionParam](openai.ChatCompletionToolChoiceOptionAutoRequired)
+	case ai.ToolChoiceNone:
+		g.request.ToolChoice =
openai.F[openai.ChatCompletionToolChoiceOptionUnionParam](openai.ChatCompletionToolChoiceOptionAutoNone) + } + + return g +} + +// Generate executes the generation request +func (g *ModelGenerator) Generate(ctx context.Context, handleChunk func(context.Context, *ai.ModelResponseChunk) error) (*ai.ModelResponse, error) { + // Check for any errors that occurred during building + if g.err != nil { + return nil, g.err + } + + // Ensure messages are set + if len(g.request.Messages.Value) == 0 { + return nil, fmt.Errorf("no messages provided") + } + + if handleChunk != nil { + return g.generateStream(ctx, handleChunk) + } + return g.generateComplete(ctx) +} + +// concatenateContent concatenates text content into a single string +func (g *ModelGenerator) concatenateContent(parts []*ai.Part) string { + content := "" + for _, part := range parts { + content += part.Text + } + return content +} + +// generateStream generates a streaming model response +func (g *ModelGenerator) generateStream(ctx context.Context, handleChunk func(context.Context, *ai.ModelResponseChunk) error) (*ai.ModelResponse, error) { + stream := g.client.Chat.Completions.NewStreaming(ctx, *g.request) + defer stream.Close() + + var fullResponse ai.ModelResponse + fullResponse.Message = &ai.Message{ + Role: ai.RoleModel, + Content: make([]*ai.Part, 0), + } + + // Initialize request and usage + fullResponse.Request = &ai.ModelRequest{} + fullResponse.Usage = &ai.GenerationUsage{ + InputTokens: 0, + OutputTokens: 0, + TotalTokens: 0, + } + + var currentToolCall *ai.ToolRequest + var currentArguments string + + for stream.Next() { + chunk := stream.Current() + if len(chunk.Choices) > 0 { + choice := chunk.Choices[0] + + switch choice.FinishReason { + case openai.ChatCompletionChunkChoicesFinishReasonStop, openai.ChatCompletionChunkChoicesFinishReasonToolCalls: + fullResponse.FinishReason = ai.FinishReasonStop + case openai.ChatCompletionChunkChoicesFinishReasonLength: + fullResponse.FinishReason = ai.FinishReasonLength + case openai.ChatCompletionChunkChoicesFinishReasonContentFilter: + fullResponse.FinishReason = ai.FinishReasonBlocked + case openai.ChatCompletionChunkChoicesFinishReasonFunctionCall: + fullResponse.FinishReason = ai.FinishReasonOther + default: + fullResponse.FinishReason = ai.FinishReasonUnknown + } + + // handle tool calls + for _, toolCall := range choice.Delta.ToolCalls { + // first tool call (= current tool call is nil) contains the tool call name + if currentToolCall == nil { + currentToolCall = &ai.ToolRequest{ + Name: toolCall.Function.Name, + } + } + + if toolCall.Function.Arguments != "" { + currentArguments += toolCall.Function.Arguments + } + } + + // when tool call is complete + if choice.FinishReason == openai.ChatCompletionChunkChoicesFinishReasonToolCalls && currentToolCall != nil { + // parse accumulated arguments string + if currentArguments != "" { + currentToolCall.Input = jsonStringToMap(currentArguments) + } + + fullResponse.Message.Content = []*ai.Part{ai.NewToolRequestPart(currentToolCall)} + return &fullResponse, nil + } + + content := chunk.Choices[0].Delta.Content + modelChunk := &ai.ModelResponseChunk{ + Content: []*ai.Part{ai.NewTextPart(content)}, + } + + if err := handleChunk(ctx, modelChunk); err != nil { + return nil, fmt.Errorf("callback error: %w", err) + } + + fullResponse.Message.Content = append(fullResponse.Message.Content, modelChunk.Content...) 
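+			// Accumulate usage as reported on the chunks below. Note that the
+			// OpenAI API only populates usage on streams created with
+			// stream_options.include_usage, so these counters may legitimately
+			// remain zero for a streamed response.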
+ + // Update Usage + fullResponse.Usage.InputTokens += int(chunk.Usage.PromptTokens) + fullResponse.Usage.OutputTokens += int(chunk.Usage.CompletionTokens) + fullResponse.Usage.TotalTokens += int(chunk.Usage.TotalTokens) + } + } + + if err := stream.Err(); err != nil { + return nil, fmt.Errorf("stream error: %w", err) + } + + return &fullResponse, nil +} + +// generateComplete generates a complete model response +func (g *ModelGenerator) generateComplete(ctx context.Context) (*ai.ModelResponse, error) { + completion, err := g.client.Chat.Completions.New(ctx, *g.request) + if err != nil { + return nil, fmt.Errorf("failed to create completion: %w", err) + } + + resp := &ai.ModelResponse{ + Request: &ai.ModelRequest{}, + Usage: &ai.GenerationUsage{ + InputTokens: int(completion.Usage.PromptTokens), + OutputTokens: int(completion.Usage.CompletionTokens), + TotalTokens: int(completion.Usage.TotalTokens), + }, + Message: &ai.Message{ + Role: ai.RoleModel, + }, + } + + choice := completion.Choices[0] + + switch choice.FinishReason { + case openai.ChatCompletionChoicesFinishReasonStop, openai.ChatCompletionChoicesFinishReasonToolCalls: + resp.FinishReason = ai.FinishReasonStop + case openai.ChatCompletionChoicesFinishReasonLength: + resp.FinishReason = ai.FinishReasonLength + case openai.ChatCompletionChoicesFinishReasonContentFilter: + resp.FinishReason = ai.FinishReasonBlocked + case openai.ChatCompletionChoicesFinishReasonFunctionCall: + resp.FinishReason = ai.FinishReasonOther + default: + resp.FinishReason = ai.FinishReasonUnknown + } + + // handle tool calls + var toolRequestParts []*ai.Part + for _, toolCall := range choice.Message.ToolCalls { + toolRequestParts = append(toolRequestParts, ai.NewToolRequestPart(&ai.ToolRequest{ + Name: toolCall.Function.Name, + Input: jsonStringToMap(toolCall.Function.Arguments), + })) + } + if len(toolRequestParts) > 0 { + resp.Message.Content = toolRequestParts + return resp, nil + } + + resp.Message.Content = []*ai.Part{ + ai.NewTextPart(completion.Choices[0].Message.Content), + } + return resp, nil +} + +func convertToolCalls(content []*ai.Part) []openai.ChatCompletionMessageToolCallParam { + var toolCalls []openai.ChatCompletionMessageToolCallParam + for _, p := range content { + if !p.IsToolRequest() { + continue + } + toolCall := convertToolCall(p) + toolCalls = append(toolCalls, toolCall) + } + return toolCalls +} + +func convertToolCall(part *ai.Part) openai.ChatCompletionMessageToolCallParam { + param := openai.ChatCompletionMessageToolCallParam{ + // NOTE: Temporarily set its name instead of its ref (i.e. call_xxxxx) since it's not defined in the ai.ToolRequest struct. 
+ ID: openai.F(part.ToolRequest.Name), + Type: openai.F(openai.ChatCompletionMessageToolCallTypeFunction), + Function: openai.F(openai.ChatCompletionMessageToolCallFunctionParam{ + Name: openai.F(part.ToolRequest.Name), + }), + } + + if part.ToolRequest.Input != nil { + param.Function.Value.Arguments = openai.F(anyToJSONString(part.ToolRequest.Input)) + } + + return param +} + +func jsonStringToMap(jsonString string) map[string]any { + var result map[string]any + if err := json.Unmarshal([]byte(jsonString), &result); err != nil { + panic(fmt.Errorf("unmarshal failed to parse json string %s: %w", jsonString, err)) + } + return result +} + +func anyToJSONString(data any) string { + jsonBytes, err := json.Marshal(data) + if err != nil { + panic(fmt.Errorf("failed to marshal any to JSON string: data, %#v %w", data, err)) + } + return string(jsonBytes) +} diff --git a/go/plugins/compat_oai/generate_live_test.go b/go/plugins/compat_oai/generate_live_test.go new file mode 100644 index 0000000000..b6c0f1fcca --- /dev/null +++ b/go/plugins/compat_oai/generate_live_test.go @@ -0,0 +1,230 @@ +// Copyright 2025 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package compat_oai_test + +import ( + "context" + "fmt" + "os" + "strings" + "testing" + + "github.com/firebase/genkit/go/ai" + "github.com/firebase/genkit/go/plugins/compat_oai" + openaiClient "github.com/openai/openai-go" + "github.com/openai/openai-go/option" + "github.com/stretchr/testify/assert" +) + +const defaultModel = "gpt-4o-mini" + +func setupTestClient(t *testing.T) *compat_oai.ModelGenerator { + t.Helper() + apiKey := os.Getenv("OPENAI_API_KEY") + if apiKey == "" { + t.Skip("Skipping test: OPENAI_API_KEY environment variable not set") + } + + client := openaiClient.NewClient(option.WithAPIKey(apiKey)) + return compat_oai.NewModelGenerator(client, defaultModel) +} + +func TestGenerator_Complete(t *testing.T) { + g := setupTestClient(t) + + // define case with user and model messages + messages := []*ai.Message{ + { + Role: ai.RoleUser, + Content: []*ai.Part{ + ai.NewTextPart("Tell me a joke"), + }, + }, + { + Role: ai.RoleModel, + Content: []*ai.Part{ + ai.NewTextPart("Why did the scarecrow win an award?"), + }, + }, + { + Role: ai.RoleUser, + Content: []*ai.Part{ + ai.NewTextPart("Why?"), + }, + }, + } + + resp, err := g.WithMessages(messages).Generate(context.Background(), nil) + assert.NoError(t, err) + assert.NotEmpty(t, resp.Message.Content) + assert.Equal(t, ai.RoleModel, resp.Message.Role) + + t.Log("\n=== Simple Completion Response ===") + for _, part := range resp.Message.Content { + t.Logf("Content: %s", part.Text) + } +} + +func TestGenerator_Stream(t *testing.T) { + g := setupTestClient(t) + + messages := []*ai.Message{ + { + Role: ai.RoleUser, + Content: []*ai.Part{ + ai.NewTextPart("Count from 1 to 3"), + }, + }, + } + + var chunks []string + handleChunk := func(ctx context.Context, chunk *ai.ModelResponseChunk) error { + for _, part := range chunk.Content { + chunks = append(chunks, part.Text) + } + 
return nil + } + + _, err := g.WithMessages(messages).Generate(context.Background(), handleChunk) + assert.NoError(t, err) + assert.NotEmpty(t, chunks) + + // Verify we got the full response + fullText := strings.Join(chunks, "") + assert.Contains(t, fullText, "1") + assert.Contains(t, fullText, "2") + assert.Contains(t, fullText, "3") + + t.Log("\n=== Full Streaming Response ===") + t.Log(strings.Join(chunks, "")) +} + +func TestWithConfig(t *testing.T) { + tests := []struct { + name string + config any + err error + validate func(*testing.T, *openaiClient.ChatCompletionNewParams) + }{ + { + name: "nil config", + config: nil, + validate: func(t *testing.T, request *openaiClient.ChatCompletionNewParams) { + // For nil config, we expect all fields to be unset (not nil, but with Present=false) + assert.False(t, request.Temperature.Present) + assert.False(t, request.MaxCompletionTokens.Present) + assert.False(t, request.TopP.Present) + assert.False(t, request.Stop.Present) + }, + }, + { + name: "empty openai config", + config: compat_oai.OpenAIConfig{}, + validate: func(t *testing.T, request *openaiClient.ChatCompletionNewParams) { + // For empty config, we expect all fields to be unset + assert.False(t, request.Temperature.Present) + assert.False(t, request.MaxCompletionTokens.Present) + assert.False(t, request.TopP.Present) + assert.False(t, request.Stop.Present) + }, + }, + { + name: "valid config with all supported fields", + config: compat_oai.OpenAIConfig{ + Temperature: 0.7, + MaxOutputTokens: 100, + TopP: 0.9, + StopSequences: []string{"stop1", "stop2"}, + }, + validate: func(t *testing.T, request *openaiClient.ChatCompletionNewParams) { + // Check that fields are present and have correct values + assert.True(t, request.Temperature.Present) + assert.Equal(t, float64(0.7), request.Temperature.Value) + + assert.True(t, request.MaxCompletionTokens.Present) + assert.Equal(t, int64(100), request.MaxCompletionTokens.Value) + + assert.True(t, request.TopP.Present) + assert.Equal(t, float64(0.9), request.TopP.Value) + + assert.True(t, request.Stop.Present) + stopArray, ok := request.Stop.Value.(openaiClient.ChatCompletionNewParamsStopArray) + assert.True(t, ok) + assert.Equal(t, openaiClient.ChatCompletionNewParamsStopArray{"stop1", "stop2"}, stopArray) + }, + }, + { + name: "valid config as map", + config: map[string]any{ + "temperature": 0.7, + "max_output_tokens": 100, + "top_p": 0.9, + "stop_sequences": []string{"stop1", "stop2"}, + }, + validate: func(t *testing.T, request *openaiClient.ChatCompletionNewParams) { + assert.True(t, request.Temperature.Present) + assert.Equal(t, float64(0.7), request.Temperature.Value) + + assert.True(t, request.MaxCompletionTokens.Present) + assert.Equal(t, int64(100), request.MaxCompletionTokens.Value) + + assert.True(t, request.TopP.Present) + assert.Equal(t, float64(0.9), request.TopP.Value) + + assert.True(t, request.Stop.Present) + stopArray, ok := request.Stop.Value.(openaiClient.ChatCompletionNewParamsStopArray) + assert.True(t, ok) + assert.Equal(t, openaiClient.ChatCompletionNewParamsStopArray{"stop1", "stop2"}, stopArray) + }, + }, + { + name: "invalid config type", + config: "not a config", + err: fmt.Errorf("unexpected config type: string"), + }, + } + + // define simple messages for testing + messages := []*ai.Message{ + { + Role: ai.RoleUser, + Content: []*ai.Part{ + ai.NewTextPart("Tell me a joke"), + }, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + generator := setupTestClient(t) + result, err := 
generator.WithMessages(messages).WithConfig(tt.config).Generate(context.Background(), nil)
+
+			if tt.err != nil {
+				assert.Error(t, err)
+				assert.Equal(t, tt.err.Error(), err.Error())
+				return
+			}
+
+			// validate that the response was successful
+			assert.NoError(t, err)
+			assert.NotNil(t, result)
+
+			// validate the input request was transformed correctly
+			if tt.validate != nil {
+				tt.validate(t, generator.GetRequest())
+			}
+		})
+	}
+}
diff --git a/go/plugins/compat_oai/openai/README.md b/go/plugins/compat_oai/openai/README.md
new file mode 100644
index 0000000000..9b227d25fc
--- /dev/null
+++ b/go/plugins/compat_oai/openai/README.md
@@ -0,0 +1,81 @@
+# OpenAI Plugin
+
+This plugin provides a simple interface for using OpenAI's models and embedders.
+
+## Prerequisites
+
+- Go installed on your system
+- An OpenAI API key
+
+## Usage
+
+Here's a simple example of how to use the OpenAI plugin:
+
+```go
+// import "github.com/firebase/genkit/go/plugins/compat_oai" (config type) and ".../compat_oai/openai" (plugin)
+// Initialize the OpenAI plugin with your API key
+oai := &openai.OpenAI{APIKey: apiKey}
+
+// Initialize Genkit with the OpenAI plugin
+g, err := genkit.Init(ctx,
+	genkit.WithDefaultModel("openai/gpt-4o-mini"),
+	genkit.WithPlugins(oai),
+)
+if err != nil {
+	// handle errors
+}
+
+config := &compat_oai.OpenAIConfig{
+	// define optional config fields
+}
+
+resp, err := genkit.Generate(ctx, g,
+	ai.WithPromptText("Write a short sentence about artificial intelligence."),
+	ai.WithConfig(config),
+)
+```
+
+## Running Tests
+
+First, set your OpenAI API key as an environment variable:
+
+```bash
+export OPENAI_API_KEY=
+```
+
+### Running All Tests
+To run all tests in the directory:
+```bash
+go test -v .
+```
+
+### Running Tests from Specific Files
+To run tests from a specific file:
+```bash
+# Run only generate_live_test.go tests
+go test -run "^TestGenerator"
+
+# Run only openai_live_test.go tests
+go test -run "^TestPlugin"
+```
+
+### Running Individual Tests
+To run a specific test case:
+```bash
+# Run only the streaming test from openai_live_test.go
+go test -run "TestPlugin/streaming"
+
+# Run only the Complete test from generate_live_test.go
+go test -run "TestGenerator_Complete"
+
+# Run only the Stream test from generate_live_test.go
+go test -run "TestGenerator_Stream"
+```
+
+### Test Output Verbosity
+Add the `-v` flag for verbose output:
+```bash
+go test -v -run "TestPlugin/streaming"
+```
+
+Note: All live tests require the OPENAI_API_KEY environment variable to be set. Tests will be skipped if the API key is not provided.
diff --git a/go/plugins/compat_oai/openai/openai.go b/go/plugins/compat_oai/openai/openai.go
new file mode 100644
index 0000000000..ebd85502a8
--- /dev/null
+++ b/go/plugins/compat_oai/openai/openai.go
@@ -0,0 +1,205 @@
+// Copyright 2025 Google LLC
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+ +package openai + +import ( + "context" + "fmt" + "os" + + "github.com/firebase/genkit/go/ai" + "github.com/firebase/genkit/go/genkit" + "github.com/firebase/genkit/go/plugins/compat_oai" + openaiGo "github.com/openai/openai-go" + "github.com/openai/openai-go/option" +) + +const provider = "openai" + +var ( + // Supported models: https://platform.openai.com/docs/models + supportedModels = map[string]ai.ModelInfo{ + "gpt-4.1": { + Label: "GPT-4.1", + Supports: compat_oai.Multimodal.Supports, + Versions: []string{"gpt-4.1", "gpt-4.1-2025-04-14"}, + }, + "gpt-4.1-mini": { + Label: "GPT-4.1-mini", + Supports: compat_oai.Multimodal.Supports, + Versions: []string{"gpt-4.1-mini", "gpt-4.1-mini-2025-04-14"}, + }, + "gpt-4.1-nano": { + Label: "GPT-4.1-nano", + Supports: compat_oai.Multimodal.Supports, + Versions: []string{"gpt-4.1-nano", "gpt-4.1-nano-2025-04-14"}, + }, + openaiGo.ChatModelO3Mini: { + Label: "o3-mini", + Supports: compat_oai.BasicText.Supports, + Versions: []string{"o3-mini", "o3-mini-2025-01-31"}, + }, + openaiGo.ChatModelO1: { + Label: "o1", + Supports: compat_oai.BasicText.Supports, + Versions: []string{"o1", "o1-2024-12-17"}, + }, + openaiGo.ChatModelO1Preview: { + Label: "o1-preview", + Supports: &ai.ModelSupports{ + Multiturn: true, + Tools: false, + SystemRole: false, + Media: false, + }, + Versions: []string{"o1-preview", "o1-preview-2024-09-12"}, + }, + openaiGo.ChatModelO1Mini: { + Label: "o1-mini", + Supports: &ai.ModelSupports{ + Multiturn: true, + Tools: false, + SystemRole: false, + Media: false, + }, + Versions: []string{"o1-mini", "o1-mini-2024-09-12"}, + }, + openaiGo.ChatModelGPT4_5Preview: { + Label: "GPT-4.5-preview", + Supports: compat_oai.Multimodal.Supports, + Versions: []string{"gpt-4.5-preview", "gpt-4.5-preview-2025-02-27"}, + }, + openaiGo.ChatModelGPT4o: { + Label: "GPT-4o", + Supports: compat_oai.Multimodal.Supports, + Versions: []string{"gpt-4o", "gpt-4o-2024-11-20", "gpt-4o-2024-08-06", "gpt-4o-2024-05-13"}, + }, + openaiGo.ChatModelGPT4oMini: { + Label: "GPT-4o-mini", + Supports: compat_oai.Multimodal.Supports, + Versions: []string{"gpt-4o-mini", "gpt-4o-mini-2024-07-18"}, + }, + openaiGo.ChatModelGPT4Turbo: { + Label: "GPT-4-turbo", + Supports: compat_oai.Multimodal.Supports, + Versions: []string{"gpt-4-turbo", "gpt-4-turbo-2024-04-09", "gpt-4-turbo-preview", "gpt-4-0125-preview"}, + }, + openaiGo.ChatModelGPT4: { + Label: "GPT-4", + Supports: &ai.ModelSupports{ + Multiturn: true, + Tools: false, + SystemRole: true, + Media: false, + }, + Versions: []string{"gpt-4", "gpt-4-0613", "gpt-4-0314"}, + }, + openaiGo.ChatModelGPT3_5Turbo: { + Label: "GPT-3.5-turbo", + Supports: &ai.ModelSupports{ + Multiturn: true, + Tools: false, + SystemRole: true, + Media: false, + }, + Versions: []string{"gpt-3.5-turbo", "gpt-3.5-turbo-0125", "gpt-3.5-turbo-1106", "gpt-3.5-turbo-instruct"}, + }, + } + + // Known embedders: https://platform.openai.com/docs/guides/embeddings + knownEmbedders = []string{ + openaiGo.EmbeddingModelTextEmbedding3Small, + openaiGo.EmbeddingModelTextEmbedding3Large, + openaiGo.EmbeddingModelTextEmbeddingAda002, + } +) + +type OpenAI struct { + // APIKey is the API key for the OpenAI API. If empty, the values of the environment variable "OPENAI_API_KEY" will be consulted. + // Request a key at https://platform.openai.com/api-keys + APIKey string + // Optional: Opts are additional options for the OpenAI client. + // Can include other options like WithOrganization, WithBaseURL, etc. 
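+	// For example (sketch; the organization ID is a placeholder):
+	//
+	//	oai := &OpenAI{
+	//		APIKey: apiKey,
+	//		Opts:   []option.RequestOption{option.WithOrganization("org-...")},
+	//	}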
+ Opts []option.RequestOption + + openAICompatible *compat_oai.OpenAICompatible +} + +// Name implements genkit.Plugin. +func (o *OpenAI) Name() string { + return provider +} + +// Init implements genkit.Plugin. +func (o *OpenAI) Init(ctx context.Context, g *genkit.Genkit) error { + apiKey := o.APIKey + + // if api key is not set, get it from environment variable + if apiKey == "" { + apiKey = os.Getenv("OPENAI_API_KEY") + } + + if apiKey == "" { + return fmt.Errorf("openai plugin initialization failed: apiKey is required") + } + + if o.openAICompatible == nil { + o.openAICompatible = &compat_oai.OpenAICompatible{} + } + + // set the options + o.openAICompatible.Opts = []option.RequestOption{ + option.WithAPIKey(apiKey), + } + if len(o.Opts) > 0 { + o.openAICompatible.Opts = append(o.openAICompatible.Opts, o.Opts...) + } + + if err := o.openAICompatible.Init(ctx, g); err != nil { + return err + } + + // define default models + for model, info := range supportedModels { + if _, err := o.DefineModel(g, model, info); err != nil { + return err + } + } + + // define default embedders + for _, embedder := range knownEmbedders { + if _, err := o.DefineEmbedder(g, embedder); err != nil { + return err + } + } + + return nil +} + +func (o *OpenAI) Model(g *genkit.Genkit, name string) ai.Model { + return o.openAICompatible.Model(g, name, provider) +} + +func (o *OpenAI) DefineModel(g *genkit.Genkit, name string, info ai.ModelInfo) (ai.Model, error) { + return o.openAICompatible.DefineModel(g, provider, name, info) +} + +func (o *OpenAI) DefineEmbedder(g *genkit.Genkit, name string) (ai.Embedder, error) { + return o.openAICompatible.DefineEmbedder(g, provider, name) +} + +func (o *OpenAI) Embedder(g *genkit.Genkit, name string) ai.Embedder { + return o.openAICompatible.Embedder(g, name, provider) +} diff --git a/go/plugins/compat_oai/openai/openai_live_test.go b/go/plugins/compat_oai/openai/openai_live_test.go new file mode 100644 index 0000000000..52c1fa7bf2 --- /dev/null +++ b/go/plugins/compat_oai/openai/openai_live_test.go @@ -0,0 +1,274 @@ +// Copyright 2025 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
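+// These tests exercise the plugin against the live OpenAI API. They are
+// skipped unless OPENAI_API_KEY is set, e.g.:
+//
+//	export OPENAI_API_KEY=<your-api-key>
+//	go test -v .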
+ +package openai_test + +import ( + "context" + "math" + "os" + "strings" + "testing" + + "github.com/firebase/genkit/go/ai" + "github.com/firebase/genkit/go/genkit" + "github.com/firebase/genkit/go/plugins/compat_oai/openai" +) + +func TestPlugin(t *testing.T) { + apiKey := os.Getenv("OPENAI_API_KEY") + if apiKey == "" { + t.Skip("Skipping test: OPENAI_API_KEY environment variable not set") + } + + ctx := context.Background() + + // Initialize the OpenAI plugin + oai := &openai.OpenAI{ + APIKey: apiKey, + } + g, err := genkit.Init(context.Background(), + genkit.WithDefaultModel("openai/gpt-4o-mini"), + genkit.WithPlugins(oai), + ) + if err != nil { + t.Fatal(err) + } + t.Log("genkit initialized") + + // Define a tool for calculating gablorkens + gablorkenTool := genkit.DefineTool(g, "gablorken", "use when need to calculate a gablorken", + func(ctx *ai.ToolContext, input struct { + Value float64 + Over float64 + }, + ) (float64, error) { + return math.Pow(input.Value, input.Over), nil + }, + ) + + t.Log("openai plugin initialized") + + t.Run("embedder", func(t *testing.T) { + + // define embedder + embedder := oai.Embedder(g, "text-embedding-3-small") + res, err := ai.Embed(ctx, embedder, ai.WithEmbedText("yellow banana")) + if err != nil { + t.Fatal(err) + } + out := res.Embeddings[0].Embedding + // There's not a whole lot we can test about the result. + // Just do a few sanity checks. + if len(out) < 100 { + t.Errorf("embedding vector looks too short: len(out)=%d", len(out)) + } + var normSquared float32 + for _, x := range out { + normSquared += x * x + } + if normSquared < 0.9 || normSquared > 1.1 { + t.Errorf("embedding vector not unit length: %f", normSquared) + } + }) + + t.Run("basic completion", func(t *testing.T) { + t.Log("generating basic completion response") + resp, err := genkit.Generate(ctx, g, + ai.WithPromptText("What is the capital of France?"), + ) + if err != nil { + t.Fatal("error generating basic completion response: ", err) + } + t.Logf("basic completion response: %+v", resp) + + out := resp.Message.Content[0].Text + if !strings.Contains(strings.ToLower(out), "paris") { + t.Errorf("got %q, expecting it to contain 'Paris'", out) + } + + // Verify usage statistics are present + if resp.Usage == nil || resp.Usage.TotalTokens == 0 { + t.Error("Expected non-zero usage statistics") + } + }) + + t.Run("streaming", func(t *testing.T) { + var streamedOutput string + chunks := 0 + + final, err := genkit.Generate(ctx, g, + ai.WithPromptText("Write a short paragraph about artificial intelligence."), + ai.WithStreaming(func(ctx context.Context, chunk *ai.ModelResponseChunk) error { + chunks++ + for _, content := range chunk.Content { + streamedOutput += content.Text + } + return nil + })) + if err != nil { + t.Fatal(err) + } + + // Verify streaming worked + if chunks <= 1 { + t.Error("Expected multiple chunks for streaming") + } + + // Verify final output matches streamed content + finalOutput := "" + for _, content := range final.Message.Content { + finalOutput += content.Text + } + if streamedOutput != finalOutput { + t.Errorf("Streaming output doesn't match final output\nStreamed: %s\nFinal: %s", + streamedOutput, finalOutput) + } + + t.Logf("streaming response: %+v", finalOutput) + }) + + t.Run("tool usage with basic completion", func(t *testing.T) { + resp, err := genkit.Generate(ctx, g, + ai.WithPromptText("what is a gablorken of 2 over 3.5?"), + ai.WithTools(gablorkenTool)) + if err != nil { + t.Fatal(err) + } + + out := resp.Message.Content[0].Text + const want = 
"12.25" + if !strings.Contains(out, want) { + t.Errorf("got %q, expecting it to contain %q", out, want) + } + + t.Logf("tool usage with basic completion response: %+v", out) + }) + + t.Run("tool usage with streaming", func(t *testing.T) { + var streamedOutput string + chunks := 0 + + final, err := genkit.Generate(ctx, g, + ai.WithPromptText("what is a gablorken of 2 over 3.5?"), + ai.WithTools(gablorkenTool), + ai.WithStreaming(func(ctx context.Context, chunk *ai.ModelResponseChunk) error { + chunks++ + for _, content := range chunk.Content { + streamedOutput += content.Text + } + return nil + })) + if err != nil { + t.Fatal(err) + } + + // Verify streaming worked + if chunks <= 1 { + t.Error("Expected multiple chunks for streaming") + } + + // Verify final output matches streamed content + finalOutput := "" + for _, content := range final.Message.Content { + finalOutput += content.Text + } + if streamedOutput != finalOutput { + t.Errorf("Streaming output doesn't match final output\nStreamed: %s\nFinal: %s", + streamedOutput, finalOutput) + } + + const want = "12.25" + if !strings.Contains(finalOutput, want) { + t.Errorf("got %q, expecting it to contain %q", finalOutput, want) + } + + t.Logf("tool usage with streaming response: %+v", finalOutput) + }) + + t.Run("system message", func(t *testing.T) { + resp, err := genkit.Generate(ctx, g, + ai.WithPromptText("What are you?"), + ai.WithSystemText("You are a helpful math tutor who loves numbers."), + ) + if err != nil { + t.Fatal(err) + } + + out := resp.Message.Content[0].Text + if !strings.Contains(strings.ToLower(out), "math") { + t.Errorf("got %q, expecting response to mention being a math tutor", out) + } + + t.Logf("system message response: %+v", out) + }) + + t.Run("generation config", func(t *testing.T) { + // Create a config with specific parameters + config := &ai.GenerationCommonConfig{ + Temperature: 0.2, + MaxOutputTokens: 50, + TopP: 0.5, + StopSequences: []string{".", "!", "?"}, + } + + resp, err := genkit.Generate(ctx, g, + ai.WithPromptText("Write a short sentence about artificial intelligence."), + ai.WithConfig(config), + ) + if err != nil { + t.Fatal(err) + } + out := resp.Message.Content[0].Text + t.Logf("generation config response: %+v", out) + }) + + t.Run("unsupported config field", func(t *testing.T) { + // Create a config with an unsupported TopK parameter + config := &ai.GenerationCommonConfig{ + Temperature: 0.2, + MaxOutputTokens: 50, + TopK: 10, // TopK is not supported in OpenAI's chat completion API + } + + _, err := genkit.Generate(ctx, g, + ai.WithPromptText("Write a short sentence about artificial intelligence."), + ai.WithConfig(config), + ) + if err == nil { + t.Fatal("expected error for unsupported TopK parameter") + } + if !strings.Contains(err.Error(), "TopK is not supported in OpenAI's chat completion API") { + t.Errorf("got error %q, want error containing 'TopK is not supported in OpenAI's chat completion API'", err.Error()) + } + t.Logf("unsupported config error: %v", err) + }) + + t.Run("invalid config type", func(t *testing.T) { + // Try to use a string as config instead of *ai.GenerationCommonConfig + config := "not a config" + + _, err := genkit.Generate(ctx, g, + ai.WithPromptText("Write a short sentence about artificial intelligence."), + ai.WithConfig(config), + ) + if err == nil { + t.Fatal("expected error for invalid config type") + } + if !strings.Contains(err.Error(), "config must be of type *ai.GenerationCommonConfig") { + t.Errorf("got error %q, want error containing 'config must be of 
type *ai.GenerationCommonConfig'", err.Error()) + } + t.Logf("invalid config type error: %v", err) + }) +} diff --git a/go/samples/compat_oai/anthropic/main.go b/go/samples/compat_oai/anthropic/main.go new file mode 100644 index 0000000000..a853aae86d --- /dev/null +++ b/go/samples/compat_oai/anthropic/main.go @@ -0,0 +1,54 @@ +// Copyright 2024 Google LLC +// SPDX-License-Identifier: Apache-2.0 + +// This program can be manually tested like so: +// Start the server listening on port 3100: +// +// genkit start -o -- go run . + +package main + +import ( + "context" + "fmt" + "log" + "net/http" + "os" + + "github.com/firebase/genkit/go/ai" + "github.com/firebase/genkit/go/genkit" + oai "github.com/firebase/genkit/go/plugins/compat_oai/anthropic" + "github.com/firebase/genkit/go/plugins/server" + "github.com/openai/openai-go/option" +) + +func main() { + ctx := context.Background() + + oai := oai.Anthropic{ + Opts: []option.RequestOption{ + option.WithAPIKey(os.Getenv("ANTHROPIC_API_KEY")), + }, + } + g, err := genkit.Init(ctx, genkit.WithPlugins(&oai)) + if err != nil { + log.Fatalf("failed to initialize OpenAICompatible: %v", err) + } + + genkit.DefineFlow(g, "anthropic", func(ctx context.Context, subject string) (string, error) { + sonnet37 := oai.Model(g, "claude-3-7-sonnet-20250219") + + prompt := fmt.Sprintf("tell me a joke about %s", subject) + foo, err := genkit.Generate(ctx, g, ai.WithModel(sonnet37), ai.WithPromptText(prompt)) + if err != nil { + return "", err + } + return fmt.Sprintf("foo: %s", foo.Text()), nil + }) + + mux := http.NewServeMux() + for _, a := range genkit.ListFlows(g) { + mux.HandleFunc("POST /"+a.Name(), genkit.Handler(a)) + } + log.Fatal(server.Start(ctx, "127.0.0.1:8080", mux)) +} diff --git a/go/samples/compat_oai/openai/main.go b/go/samples/compat_oai/openai/main.go new file mode 100644 index 0000000000..5b6b0e7924 --- /dev/null +++ b/go/samples/compat_oai/openai/main.go @@ -0,0 +1,71 @@ +// Copyright 2024 Google LLC +// SPDX-License-Identifier: Apache-2.0 + +// This program can be manually tested like so: +// Start the server listening on port 3100: +// +// genkit start -o -- go run . 
+ +package main + +import ( + "context" + "fmt" + "log" + "net/http" + "os" + + "github.com/firebase/genkit/go/ai" + "github.com/firebase/genkit/go/genkit" + oai "github.com/firebase/genkit/go/plugins/compat_oai/openai" + "github.com/firebase/genkit/go/plugins/server" + "github.com/openai/openai-go" +) + +func main() { + ctx := context.Background() + + apiKey := os.Getenv("OPENAI_API_KEY") + if apiKey == "" { + log.Fatalf("no OPENAI_API_KEY environment variable set") + } + oai := &oai.OpenAI{ + APIKey: apiKey, + } + g, err := genkit.Init(ctx, genkit.WithPlugins(oai)) + if err != nil { + log.Fatalf("failed to create Genkit: %v", err) + } + + genkit.DefineFlow(g, "basic", func(ctx context.Context, subject string) (string, error) { + gpt4o := oai.Model(g, "gpt-4o") + + prompt := fmt.Sprintf("tell me a joke about %s", subject) + config := &openai.ChatCompletionNewParams{Temperature: openai.F(0.5), MaxTokens: openai.F(int64(100))} + foo, err := genkit.Generate(ctx, g, ai.WithModel(gpt4o), ai.WithPromptText(prompt), ai.WithConfig(config)) + if err != nil { + return "", err + } + return fmt.Sprintf("foo: %s", foo.Text()), nil + }) + + genkit.DefineFlow(g, "defined-model", func(ctx context.Context, subject string) (string, error) { + gpt4oMini := oai.Model(g, "gpt-4o-mini") + if err != nil { + return "", err + } + prompt := fmt.Sprintf("tell me a joke about %s", subject) + config := &ai.GenerationCommonConfig{Temperature: 0.5} + foo, err := genkit.Generate(ctx, g, ai.WithModel(gpt4oMini), ai.WithPromptText(prompt), ai.WithConfig(config)) + if err != nil { + return "", err + } + return fmt.Sprintf("foo: %s", foo.Text()), nil + }) + + mux := http.NewServeMux() + for _, a := range genkit.ListFlows(g) { + mux.HandleFunc("POST /"+a.Name(), genkit.Handler(a)) + } + log.Fatal(server.Start(ctx, "127.0.0.1:8080", mux)) +}
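+
+// Once the server is running, each flow is exposed as an HTTP POST endpoint
+// on 127.0.0.1:8080 and can be invoked with curl. A sketch (this assumes
+// genkit's usual {"data": ...} request envelope for flows):
+//
+//	curl -X POST http://127.0.0.1:8080/basic \
+//	  -H "Content-Type: application/json" \
+//	  -d '{"data": "bananas"}'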