From 0baad300e0e84942af7997b1aaaba7778c4e7ed5 Mon Sep 17 00:00:00 2001 From: Hugo Aguirre Date: Wed, 5 Feb 2025 16:37:19 -0600 Subject: [PATCH] feature(go): add support for multiple model versions (#1575) --- go/ai/generate.go | 77 +++++++++++++----- go/ai/generator_test.go | 80 ++++++++++++++----- go/genkit/genkit.go | 3 +- .../doc-snippets/modelplugin/modelplugin.go | 13 +-- go/internal/doc-snippets/ollama.go | 16 ++-- go/plugins/googleai/googleai.go | 59 +++++++++----- go/plugins/internal/gemini/gemini.go | 4 +- go/plugins/ollama/ollama.go | 47 ++++++----- go/plugins/vertexai/vertexai.go | 47 ++++++----- go/samples/basic-gemini/main.go | 69 ++++++++++++++++ js/testapps/basic-gemini/src/index.ts | 2 + 11 files changed, 303 insertions(+), 114 deletions(-) create mode 100644 go/samples/basic-gemini/main.go diff --git a/go/ai/generate.go b/go/ai/generate.go index a577751af..7356ceeef 100644 --- a/go/ai/generate.go +++ b/go/ai/generate.go @@ -1,7 +1,6 @@ // Copyright 2024 Google LLC // SPDX-License-Identifier: Apache-2.0 - package ai import ( @@ -35,33 +34,21 @@ type modelAction = core.Action[*ModelRequest, *ModelResponse, *ModelResponseChun // ModelStreamingCallback is the type for the streaming callback of a model. type ModelStreamingCallback = func(context.Context, *ModelResponseChunk) error -// ModelCapabilities describes various capabilities of the model. -type ModelCapabilities struct { - Multiturn bool // the model can handle multiple request-response interactions - Media bool // the model supports media as well as text input - Tools bool // the model supports tools - SystemRole bool // the model supports a system prompt or role -} - -// ModelMetadata is the metadata of the model, specifying things like nice user-visible label, capabilities, etc. -type ModelMetadata struct { - Label string - Supports ModelCapabilities -} - // DefineModel registers the given generate function as an action, and returns a // [Model] that runs it. func DefineModel( r *registry.Registry, provider, name string, - metadata *ModelMetadata, + metadata *ModelInfo, generate func(context.Context, *ModelRequest, ModelStreamingCallback) (*ModelResponse, error), ) Model { metadataMap := map[string]any{} if metadata == nil { // Always make sure there's at least minimal metadata. - metadata = &ModelMetadata{ - Label: name, + metadata = &ModelInfo{ + Label: name, + Supports: &ModelInfoSupports{}, + Versions: []string{}, } } if metadata.Label != "" { @@ -74,6 +61,7 @@ func DefineModel( "tools": metadata.Supports.Tools, } metadataMap["supports"] = supports + metadataMap["versions"] = metadata.Versions return (*modelActionDef)(core.DefineStreamingAction(r, provider, name, atype.Model, map[string]any{ "model": metadataMap, @@ -100,8 +88,8 @@ type generateParams struct { Request *ModelRequest Model Model Stream ModelStreamingCallback - History []*Message SystemPrompt *Message + History []*Message } // GenerateOption configures params of the Generate call. @@ -242,6 +230,19 @@ func Generate(ctx context.Context, r *registry.Registry, opts ...GenerateOption) if req.Model == nil { return nil, errors.New("model is required") } + + var modelVersion string + if config, ok := req.Request.Config.(*GenerationCommonConfig); ok { + modelVersion = config.Version + } + + if modelVersion != "" { + ok, err := validateModelVersion(r, modelVersion, req) + if !ok { + return nil, err + } + } + if req.History != nil { prev := req.Request.Messages req.Request.Messages = req.History @@ -256,6 +257,44 @@ func Generate(ctx context.Context, r *registry.Registry, opts ...GenerateOption) return req.Model.Generate(ctx, r, req.Request, req.Stream) } +// validateModelVersion checks in the registry the action of the +// given model version and determines whether its supported or not. +func validateModelVersion(r *registry.Registry, v string, req *generateParams) (bool, error) { + parts := strings.Split(req.Model.Name(), "/") + if len(parts) != 2 { + return false, errors.New("wrong model name") + } + + m := LookupModel(r, parts[0], parts[1]) + if m == nil { + return false, fmt.Errorf("model %s not found", v) + } + + // at the end, a Model is an action so type conversion is required + if a, ok := m.(*modelActionDef); ok { + if !(modelVersionSupported(v, (*modelAction)(a).Desc().Metadata)) { + return false, fmt.Errorf("version %s not supported", v) + } + } else { + return false, errors.New("unable to validate model version") + } + + return true, nil +} + +// modelVersionSupported iterates over model's metadata to find the requested +// supported model version +func modelVersionSupported(modelVersion string, modelMetadata map[string]any) bool { + if md, ok := modelMetadata["model"].(map[string]any); ok { + for _, v := range md["versions"].([]string) { + if modelVersion == v { + return true + } + } + } + return false +} + // GenerateText run generate request for this model. Returns generated text only. func GenerateText(ctx context.Context, r *registry.Registry, opts ...GenerateOption) (string, error) { res, err := Generate(ctx, r, opts...) diff --git a/go/ai/generator_test.go b/go/ai/generator_test.go index b82bda62d..eb515e22b 100644 --- a/go/ai/generator_test.go +++ b/go/ai/generator_test.go @@ -1,7 +1,6 @@ // Copyright 2024 Google LLC // SPDX-License-Identifier: Apache-2.0 - package ai import ( @@ -23,30 +22,46 @@ type GameCharacter struct { var r, _ = registry.New() -var echoModel = DefineModel(r, "test", "echo", nil, func(ctx context.Context, gr *ModelRequest, msc ModelStreamingCallback) (*ModelResponse, error) { - if msc != nil { - msc(ctx, &ModelResponseChunk{ - Content: []*Part{NewTextPart("stream!")}, - }) +// echoModel attributes +var ( + modelName = "echo" + metadata = ModelInfo{ + Label: modelName, + Supports: &ModelInfoSupports{ + Multiturn: true, + Tools: true, + SystemRole: true, + Media: false, + }, + Versions: []string{"echo-001", "echo-002"}, } - textResponse := "" - for _, m := range gr.Messages { - if m.Role == RoleUser { - textResponse += m.Content[0].Text + + echoModel = DefineModel(r, "test", modelName, &metadata, func(ctx context.Context, gr *ModelRequest, msc ModelStreamingCallback) (*ModelResponse, error) { + if msc != nil { + msc(ctx, &ModelResponseChunk{ + Content: []*Part{NewTextPart("stream!")}, + }) } - } - return &ModelResponse{ - Request: gr, - Message: NewUserTextMessage(textResponse), - }, nil -}) + textResponse := "" + for _, m := range gr.Messages { + if m.Role == RoleUser { + textResponse += m.Content[0].Text + } + } + return &ModelResponse{ + Request: gr, + Message: NewUserTextMessage(textResponse), + }, nil + }) +) // with tools var gablorkenTool = DefineTool(r, "gablorken", "use when need to calculate a gablorken", func(ctx context.Context, input struct { Value float64 Over float64 - }) (float64, error) { + }, + ) (float64, error) { return math.Pow(input.Value, input.Over), nil }, ) @@ -320,9 +335,36 @@ func TestGenerate(t *testing.T) { }) } +func TestModelVersion(t *testing.T) { + t.Run("valid version", func(t *testing.T) { + _, err := Generate(context.Background(), r, + WithModel(echoModel), + WithConfig(&GenerationCommonConfig{ + Temperature: 1, + Version: "echo-001", + }), + WithTextPrompt("tell a joke about batman")) + if err != nil { + t.Errorf("model version should be valid") + } + }) + t.Run("invalid version", func(t *testing.T) { + _, err := Generate(context.Background(), r, + WithModel(echoModel), + WithConfig(&GenerationCommonConfig{ + Temperature: 1, + Version: "echo-im-not-a-version", + }), + WithTextPrompt("tell a joke about batman")) + if err == nil { + t.Errorf("model version should be invalid: %v", err) + } + }) +} + func TestIsDefinedModel(t *testing.T) { t.Run("should return true", func(t *testing.T) { - if IsDefinedModel(r, "test", "echo") != true { + if IsDefinedModel(r, "test", modelName) != true { t.Errorf("IsDefinedModel did not return true") } }) @@ -335,7 +377,7 @@ func TestIsDefinedModel(t *testing.T) { func TestLookupModel(t *testing.T) { t.Run("should return model", func(t *testing.T) { - if LookupModel(r, "test", "echo") == nil { + if LookupModel(r, "test", modelName) == nil { t.Errorf("LookupModel did not return model") } }) diff --git a/go/genkit/genkit.go b/go/genkit/genkit.go index 68c5147e7..dc0e8aac8 100644 --- a/go/genkit/genkit.go +++ b/go/genkit/genkit.go @@ -1,7 +1,6 @@ // Copyright 2024 Google LLC // SPDX-License-Identifier: Apache-2.0 - // Package genkit provides Genkit functionality for application developers. package genkit @@ -157,7 +156,7 @@ func (g *Genkit) Start(ctx context.Context, opts *StartOptions) error { func DefineModel( g *Genkit, provider, name string, - metadata *ai.ModelMetadata, + metadata *ai.ModelInfo, generate func(context.Context, *ai.ModelRequest, ai.ModelStreamingCallback) (*ai.ModelResponse, error), ) ai.Model { return ai.DefineModel(g.reg, provider, name, metadata, generate) diff --git a/go/internal/doc-snippets/modelplugin/modelplugin.go b/go/internal/doc-snippets/modelplugin/modelplugin.go index 657651674..669ef9f8b 100644 --- a/go/internal/doc-snippets/modelplugin/modelplugin.go +++ b/go/internal/doc-snippets/modelplugin/modelplugin.go @@ -1,7 +1,6 @@ // Copyright 2024 Google LLC // SPDX-License-Identifier: Apache-2.0 - package modelplugin import ( @@ -17,8 +16,8 @@ const providerID = "mymodels" // [START cfg] type MyModelConfig struct { ai.GenerationCommonConfig - CustomOption int AnotherCustomOption string + CustomOption int } // [END cfg] @@ -30,16 +29,18 @@ func Init() error { } // [START definemodel] + name := "my-model" genkit.DefineModel(g, - providerID, "my-model", - &ai.ModelMetadata{ - Label: "my-model", - Supports: ai.ModelCapabilities{ + providerID, name, + &ai.ModelInfo{ + Label: name, + Supports: &ai.ModelInfoSupports{ Multiturn: true, // Does the model support multi-turn chats? SystemRole: true, // Does the model support syatem messages? Media: false, // Can the model accept media input? Tools: false, // Does the model support function calling (tools)? }, + Versions: []string{}, }, func(ctx context.Context, genRequest *ai.ModelRequest, diff --git a/go/internal/doc-snippets/ollama.go b/go/internal/doc-snippets/ollama.go index 8665137a7..148fe8735 100644 --- a/go/internal/doc-snippets/ollama.go +++ b/go/internal/doc-snippets/ollama.go @@ -29,17 +29,21 @@ func ollamaEx(ctx context.Context) error { // [END init] // [START definemodel] + name := "gemma2" model := ollama.DefineModel( g, ollama.ModelDefinition{ - Name: "gemma2", + Name: name, Type: "chat", // "chat" or "generate" }, - &ai.ModelCapabilities{ - Multiturn: true, - SystemRole: true, - Tools: false, - Media: false, + &ai.ModelInfo{ + Label: name, + Supports: &ai.ModelInfoSupports{ + Multiturn: true, + SystemRole: true, + Tools: false, + Media: false, + }, }, ) // [END definemodel] diff --git a/go/plugins/googleai/googleai.go b/go/plugins/googleai/googleai.go index 8a51ce56e..7076399ec 100644 --- a/go/plugins/googleai/googleai.go +++ b/go/plugins/googleai/googleai.go @@ -1,7 +1,6 @@ // Copyright 2024 Google LLC // SPDX-License-Identifier: Apache-2.0 - // Parts of this file are copied into vertexai, because the code is identical // except for the import path of the Gemini SDK. //go:generate go run ../../internal/cmd/copy -dest ../vertexai googleai.go @@ -31,17 +30,33 @@ const ( ) var state struct { - mu sync.Mutex - initted bool // These happen to be the same. gclient, pclient *genai.Client + mu sync.Mutex + initted bool } var ( - knownCaps = map[string]ai.ModelCapabilities{ - "gemini-1.0-pro": gemini.BasicText, - "gemini-1.5-pro": gemini.Multimodal, - "gemini-1.5-flash": gemini.Multimodal, + supportedModels = map[string]ai.ModelInfo{ + "gemini-1.0-pro": { + Versions: []string{"gemini-pro", "gemini-1.0-pro-latest", "gemini-1.0-pro-001"}, + Supports: &gemini.BasicText, + }, + + "gemini-1.5-flash": { + Versions: []string{"gemini-1.5-flash-latest", "gemini-1.5-flash-001", "gemini-1.5-flash-002"}, + Supports: &gemini.Multimodal, + }, + + "gemini-1.5-pro": { + Versions: []string{"gemini-1.5-pro-latest", "gemini-1.5-pro-001", "gemini-1.5-pro-002"}, + Supports: &gemini.Multimodal, + }, + + "gemini-1.5-flash-8b": { + Versions: []string{"gemini-1.5-flash-8b-latest", "gemini-1.5-flash-8b-001"}, + Supports: &gemini.Multimodal, + }, } knownEmbedders = []string{"text-embedding-004", "embedding-001"} @@ -88,7 +103,8 @@ func Init(ctx context.Context, g *genkit.Genkit, cfg *Config) (err error) { opts := append([]option.ClientOption{ option.WithAPIKey(apiKey), - genai.WithClientInfo("genkit-go", internal.Version)}, + genai.WithClientInfo("genkit-go", internal.Version), + }, cfg.ClientOptions..., ) client, err := genai.NewClient(ctx, opts...) @@ -98,8 +114,8 @@ func Init(ctx context.Context, g *genkit.Genkit, cfg *Config) (err error) { state.gclient = client state.pclient = client state.initted = true - for model, caps := range knownCaps { - defineModel(g, model, caps) + for model, details := range supportedModels { + defineModel(g, model, details) } for _, e := range knownEmbedders { defineEmbedder(g, e) @@ -113,30 +129,32 @@ func Init(ctx context.Context, g *genkit.Genkit, cfg *Config) (err error) { // The second argument describes the capability of the model. // Use [IsDefinedModel] to determine if a model is already defined. // After [Init] is called, only the known models are defined. -func DefineModel(g *genkit.Genkit, name string, caps *ai.ModelCapabilities) (ai.Model, error) { +func DefineModel(g *genkit.Genkit, name string, info *ai.ModelInfo) (ai.Model, error) { state.mu.Lock() defer state.mu.Unlock() if !state.initted { panic(provider + ".Init not called") } - var mc ai.ModelCapabilities - if caps == nil { + var mi ai.ModelInfo + if info == nil { var ok bool - mc, ok = knownCaps[name] + mi, ok = supportedModels[name] if !ok { - return nil, fmt.Errorf("%s.DefineModel: called with unknown model %q and nil ModelCapabilities", provider, name) + return nil, fmt.Errorf("%s.DefineModel: called with unknown model %q and nil ModelInfo", provider, name) } } else { - mc = *caps + // TODO: unknown models could also specify versions? + mi = *info } - return defineModel(g, name, mc), nil + return defineModel(g, name, mi), nil } // requires state.mu -func defineModel(g *genkit.Genkit, name string, caps ai.ModelCapabilities) ai.Model { - meta := &ai.ModelMetadata{ +func defineModel(g *genkit.Genkit, name string, info ai.ModelInfo) ai.Model { + meta := &ai.ModelInfo{ Label: labelPrefix + " - " + name, - Supports: caps, + Supports: info.Supports, + Versions: info.Versions, } return genkit.DefineModel(g, provider, name, meta, func( ctx context.Context, @@ -317,7 +335,6 @@ func newModel(client *genai.Client, model string, input *ai.ModelRequest) (*gena systemParts, err := convertParts(m.Content) if err != nil { return nil, err - } // system prompts go into GenerativeModel.SystemInstruction field. if m.Role == ai.RoleSystem { diff --git a/go/plugins/internal/gemini/gemini.go b/go/plugins/internal/gemini/gemini.go index 578b41643..bfec6c671 100644 --- a/go/plugins/internal/gemini/gemini.go +++ b/go/plugins/internal/gemini/gemini.go @@ -10,7 +10,7 @@ import "github.com/firebase/genkit/go/ai" var ( // BasicText describes model capabilities for text-only Gemini models. - BasicText = ai.ModelCapabilities{ + BasicText = ai.ModelInfoSupports{ Multiturn: true, Tools: true, SystemRole: true, @@ -18,7 +18,7 @@ var ( } // Multimodal describes model capabilities for multimodal Gemini models. - Multimodal = ai.ModelCapabilities{ + Multimodal = ai.ModelInfoSupports{ Multiturn: true, Tools: true, SystemRole: true, diff --git a/go/plugins/ollama/ollama.go b/go/plugins/ollama/ollama.go index 3b3b51624..ed0f9b093 100644 --- a/go/plugins/ollama/ollama.go +++ b/go/plugins/ollama/ollama.go @@ -1,7 +1,6 @@ // Copyright 2024 Google LLC // SPDX-License-Identifier: Apache-2.0 - package ollama import ( @@ -26,41 +25,48 @@ import ( const provider = "ollama" -var mediaSupportedModels = []string{"llava"} -var roleMapping = map[ai.Role]string{ - ai.RoleUser: "user", - ai.RoleModel: "assistant", - ai.RoleSystem: "system", -} +var ( + mediaSupportedModels = []string{"llava"} + roleMapping = map[ai.Role]string{ + ai.RoleUser: "user", + ai.RoleModel: "assistant", + ai.RoleSystem: "system", + } +) + var state struct { - mu sync.Mutex - initted bool serverAddress string + initted bool + mu sync.Mutex } -func DefineModel(g *genkit.Genkit, model ModelDefinition, caps *ai.ModelCapabilities) ai.Model { +func DefineModel(g *genkit.Genkit, model ModelDefinition, info *ai.ModelInfo) ai.Model { state.mu.Lock() defer state.mu.Unlock() if !state.initted { panic("ollama.Init not called") } - var mc ai.ModelCapabilities - if caps != nil { - mc = *caps + var mi ai.ModelInfo + if info != nil { + mi = *info } else { - mc = ai.ModelCapabilities{ - Multiturn: true, - SystemRole: true, - Media: slices.Contains(mediaSupportedModels, model.Name), + mi = ai.ModelInfo{ + Label: model.Name, + Supports: &ai.ModelInfoSupports{ + Multiturn: true, + SystemRole: true, + Media: slices.Contains(mediaSupportedModels, model.Name), + }, + Versions: []string{}, } } - meta := &ai.ModelMetadata{ + meta := &ai.ModelInfo{ Label: "Ollama - " + model.Name, - Supports: mc, + Supports: mi.Supports, + Versions: []string{}, } gen := &generator{model: model, serverAddress: state.serverAddress} return genkit.DefineModel(g, provider, model.Name, meta, gen.generate) - } // IsDefinedModel reports whether a model is defined. @@ -160,7 +166,6 @@ func Init(ctx context.Context, cfg *Config) (err error) { // Generate makes a request to the Ollama API and processes the response. func (g *generator) generate(ctx context.Context, input *ai.ModelRequest, cb func(context.Context, *ai.ModelResponseChunk) error) (*ai.ModelResponse, error) { - stream := cb != nil var payload any isChatModel := g.model.Type == "chat" diff --git a/go/plugins/vertexai/vertexai.go b/go/plugins/vertexai/vertexai.go index 49101f03d..762f70661 100644 --- a/go/plugins/vertexai/vertexai.go +++ b/go/plugins/vertexai/vertexai.go @@ -1,7 +1,6 @@ // Copyright 2024 Google LLC // SPDX-License-Identifier: Apache-2.0 - package vertexai import ( @@ -29,10 +28,21 @@ const ( ) var ( - knownCaps = map[string]ai.ModelCapabilities{ - "gemini-1.0-pro": gemini.BasicText, - "gemini-1.5-pro": gemini.Multimodal, - "gemini-1.5-flash": gemini.Multimodal, + supportedModels = map[string]ai.ModelInfo{ + "gemini-1.0-pro": { + Versions: []string{"gemini-pro", "gemini-1.0-pro-latest", "gemini-1.0-pro-001"}, + Supports: &gemini.BasicText, + }, + + "gemini-1.5-flash": { + Versions: []string{"gemini-1.5-flash-latest", "gemini-1.5-flash-001", "gemini-1.5-flash-002"}, + Supports: &gemini.Multimodal, + }, + + "gemini-1.5-pro": { + Versions: []string{"gemini-1.5-pro-latest", "gemini-1.5-pro-001", "gemini-1.5-pro-002"}, + Supports: &gemini.Multimodal, + }, } knownEmbedders = []string{ @@ -114,8 +124,8 @@ func Init(ctx context.Context, g *genkit.Genkit, cfg *Config) error { return err } state.initted = true - for model, caps := range knownCaps { - defineModel(g, model, caps) + for model, info := range supportedModels { + defineModel(g, model, info) } for _, e := range knownEmbedders { defineEmbedder(g, e) @@ -130,30 +140,32 @@ func Init(ctx context.Context, g *genkit.Genkit, cfg *Config) error { // The second argument describes the capability of the model. // Use [IsDefinedModel] to determine if a model is already defined. // After [Init] is called, only the known models are defined. -func DefineModel(g *genkit.Genkit, name string, caps *ai.ModelCapabilities) (ai.Model, error) { +func DefineModel(g *genkit.Genkit, name string, info *ai.ModelInfo) (ai.Model, error) { state.mu.Lock() defer state.mu.Unlock() if !state.initted { panic(provider + ".Init not called") } - var mc ai.ModelCapabilities - if caps == nil { + var mi ai.ModelInfo + if info == nil { var ok bool - mc, ok = knownCaps[name] + mi, ok = supportedModels[name] if !ok { - return nil, fmt.Errorf("%s.DefineModel: called with unknown model %q and nil ModelCapabilities", provider, name) + return nil, fmt.Errorf("%s.DefineModel: called with unknown model %q and nil ModelInfo", provider, name) } } else { - mc = *caps + // TODO: unknown models could also specify versions? + mi = *info } - return defineModel(g, name, mc), nil + return defineModel(g, name, mi), nil } // requires state.mu -func defineModel(g *genkit.Genkit, name string, caps ai.ModelCapabilities) ai.Model { - meta := &ai.ModelMetadata{ +func defineModel(g *genkit.Genkit, name string, info ai.ModelInfo) ai.Model { + meta := &ai.ModelInfo{ Label: labelPrefix + " - " + name, - Supports: caps, + Supports: info.Supports, + Versions: info.Versions, } return genkit.DefineModel(g, provider, name, meta, func( ctx context.Context, @@ -323,7 +335,6 @@ func newModel(client *genai.Client, model string, input *ai.ModelRequest) (*gena systemParts, err := convertParts(m.Content) if err != nil { return nil, err - } // system prompts go into GenerativeModel.SystemInstruction field. if m.Role == ai.RoleSystem { diff --git a/go/samples/basic-gemini/main.go b/go/samples/basic-gemini/main.go new file mode 100644 index 000000000..daf2a6708 --- /dev/null +++ b/go/samples/basic-gemini/main.go @@ -0,0 +1,69 @@ +// Copyright 2025 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package main + +import ( + "context" + "errors" + "fmt" + "log" + + "github.com/firebase/genkit/go/ai" + "github.com/firebase/genkit/go/genkit" + "github.com/firebase/genkit/go/plugins/googleai" +) + +func main() { + ctx := context.Background() + + g, err := genkit.New(nil) + if err != nil { + log.Fatal(err) + } + + // Initialize the Google AI plugin. When you pass nil for the + // Config parameter, the Google AI plugin will get the API key from the + // GOOGLE_GENAI_API_KEY environment variable, which is the recommended + // practice. + if err := googleai.Init(ctx, g, nil); err != nil { + log.Fatal(err) + } + + // Define a simple flow that generates jokes about a given topic + genkit.DefineFlow(g, "jokesFlow", func(ctx context.Context, input string) (string, error) { + m := googleai.Model(g, "gemini-1.5-flash") + if m == nil { + return "", errors.New("jokesFlow: failed to find model") + } + + resp, err := genkit.Generate(ctx, g, + ai.WithModel(m), + ai.WithConfig(&ai.GenerationCommonConfig{ + Temperature: 1, + Version: "gemini-1.5-flash-002", + }), + ai.WithTextPrompt(fmt.Sprintf(`Tell silly short jokes about %s`, input))) + if err != nil { + return "", err + } + + text := resp.Text() + return text, nil + }) + + if err := g.Start(ctx, nil); err != nil { + log.Fatal(err) + } +} diff --git a/js/testapps/basic-gemini/src/index.ts b/js/testapps/basic-gemini/src/index.ts index 0dfb8b9eb..90ae4da62 100644 --- a/js/testapps/basic-gemini/src/index.ts +++ b/js/testapps/basic-gemini/src/index.ts @@ -43,6 +43,8 @@ export const jokeFlow = ai.defineFlow( model: gemini15Flash, config: { temperature: 2, + // if desired, model versions can be explicitly set + version: 'gemini-1.5-flash-002', }, output: { schema: z.object({ jokeSubject: z.string() }),