Skip to content

Commit

Permalink
feature(go): add support for multiple model versions (#1575)
Browse files Browse the repository at this point in the history
  • Loading branch information
hugoaguirre authored Feb 5, 2025
1 parent 698aedf commit 0baad30
Show file tree
Hide file tree
Showing 11 changed files with 303 additions and 114 deletions.
77 changes: 58 additions & 19 deletions go/ai/generate.go
Original file line number Diff line number Diff line change
@@ -1,7 +1,6 @@
// Copyright 2024 Google LLC
// SPDX-License-Identifier: Apache-2.0


package ai

import (
Expand Down Expand Up @@ -35,33 +34,21 @@ type modelAction = core.Action[*ModelRequest, *ModelResponse, *ModelResponseChun
// ModelStreamingCallback is the type for the streaming callback of a model.
type ModelStreamingCallback = func(context.Context, *ModelResponseChunk) error

// ModelCapabilities describes various capabilities of the model.
type ModelCapabilities struct {
Multiturn bool // the model can handle multiple request-response interactions
Media bool // the model supports media as well as text input
Tools bool // the model supports tools
SystemRole bool // the model supports a system prompt or role
}

// ModelMetadata is the metadata of the model, specifying things like nice user-visible label, capabilities, etc.
type ModelMetadata struct {
Label string
Supports ModelCapabilities
}

// DefineModel registers the given generate function as an action, and returns a
// [Model] that runs it.
func DefineModel(
r *registry.Registry,
provider, name string,
metadata *ModelMetadata,
metadata *ModelInfo,
generate func(context.Context, *ModelRequest, ModelStreamingCallback) (*ModelResponse, error),
) Model {
metadataMap := map[string]any{}
if metadata == nil {
// Always make sure there's at least minimal metadata.
metadata = &ModelMetadata{
Label: name,
metadata = &ModelInfo{
Label: name,
Supports: &ModelInfoSupports{},
Versions: []string{},
}
}
if metadata.Label != "" {
Expand All @@ -74,6 +61,7 @@ func DefineModel(
"tools": metadata.Supports.Tools,
}
metadataMap["supports"] = supports
metadataMap["versions"] = metadata.Versions

return (*modelActionDef)(core.DefineStreamingAction(r, provider, name, atype.Model, map[string]any{
"model": metadataMap,
Expand All @@ -100,8 +88,8 @@ type generateParams struct {
Request *ModelRequest
Model Model
Stream ModelStreamingCallback
History []*Message
SystemPrompt *Message
History []*Message
}

// GenerateOption configures params of the Generate call.
Expand Down Expand Up @@ -242,6 +230,19 @@ func Generate(ctx context.Context, r *registry.Registry, opts ...GenerateOption)
if req.Model == nil {
return nil, errors.New("model is required")
}

var modelVersion string
if config, ok := req.Request.Config.(*GenerationCommonConfig); ok {
modelVersion = config.Version
}

if modelVersion != "" {
ok, err := validateModelVersion(r, modelVersion, req)
if !ok {
return nil, err
}
}

if req.History != nil {
prev := req.Request.Messages
req.Request.Messages = req.History
Expand All @@ -256,6 +257,44 @@ func Generate(ctx context.Context, r *registry.Registry, opts ...GenerateOption)
return req.Model.Generate(ctx, r, req.Request, req.Stream)
}

// validateModelVersion checks in the registry the action of the
// given model version and determines whether its supported or not.
func validateModelVersion(r *registry.Registry, v string, req *generateParams) (bool, error) {
parts := strings.Split(req.Model.Name(), "/")
if len(parts) != 2 {
return false, errors.New("wrong model name")
}

m := LookupModel(r, parts[0], parts[1])
if m == nil {
return false, fmt.Errorf("model %s not found", v)
}

// at the end, a Model is an action so type conversion is required
if a, ok := m.(*modelActionDef); ok {
if !(modelVersionSupported(v, (*modelAction)(a).Desc().Metadata)) {
return false, fmt.Errorf("version %s not supported", v)
}
} else {
return false, errors.New("unable to validate model version")
}

return true, nil
}

// modelVersionSupported iterates over model's metadata to find the requested
// supported model version
func modelVersionSupported(modelVersion string, modelMetadata map[string]any) bool {
if md, ok := modelMetadata["model"].(map[string]any); ok {
for _, v := range md["versions"].([]string) {
if modelVersion == v {
return true
}
}
}
return false
}

// GenerateText run generate request for this model. Returns generated text only.
func GenerateText(ctx context.Context, r *registry.Registry, opts ...GenerateOption) (string, error) {
res, err := Generate(ctx, r, opts...)
Expand Down
80 changes: 61 additions & 19 deletions go/ai/generator_test.go
Original file line number Diff line number Diff line change
@@ -1,7 +1,6 @@
// Copyright 2024 Google LLC
// SPDX-License-Identifier: Apache-2.0


package ai

import (
Expand All @@ -23,30 +22,46 @@ type GameCharacter struct {

var r, _ = registry.New()

var echoModel = DefineModel(r, "test", "echo", nil, func(ctx context.Context, gr *ModelRequest, msc ModelStreamingCallback) (*ModelResponse, error) {
if msc != nil {
msc(ctx, &ModelResponseChunk{
Content: []*Part{NewTextPart("stream!")},
})
// echoModel attributes
var (
modelName = "echo"
metadata = ModelInfo{
Label: modelName,
Supports: &ModelInfoSupports{
Multiturn: true,
Tools: true,
SystemRole: true,
Media: false,
},
Versions: []string{"echo-001", "echo-002"},
}
textResponse := ""
for _, m := range gr.Messages {
if m.Role == RoleUser {
textResponse += m.Content[0].Text

echoModel = DefineModel(r, "test", modelName, &metadata, func(ctx context.Context, gr *ModelRequest, msc ModelStreamingCallback) (*ModelResponse, error) {
if msc != nil {
msc(ctx, &ModelResponseChunk{
Content: []*Part{NewTextPart("stream!")},
})
}
}
return &ModelResponse{
Request: gr,
Message: NewUserTextMessage(textResponse),
}, nil
})
textResponse := ""
for _, m := range gr.Messages {
if m.Role == RoleUser {
textResponse += m.Content[0].Text
}
}
return &ModelResponse{
Request: gr,
Message: NewUserTextMessage(textResponse),
}, nil
})
)

// with tools
var gablorkenTool = DefineTool(r, "gablorken", "use when need to calculate a gablorken",
func(ctx context.Context, input struct {
Value float64
Over float64
}) (float64, error) {
},
) (float64, error) {
return math.Pow(input.Value, input.Over), nil
},
)
Expand Down Expand Up @@ -320,9 +335,36 @@ func TestGenerate(t *testing.T) {
})
}

func TestModelVersion(t *testing.T) {
t.Run("valid version", func(t *testing.T) {
_, err := Generate(context.Background(), r,
WithModel(echoModel),
WithConfig(&GenerationCommonConfig{
Temperature: 1,
Version: "echo-001",
}),
WithTextPrompt("tell a joke about batman"))
if err != nil {
t.Errorf("model version should be valid")
}
})
t.Run("invalid version", func(t *testing.T) {
_, err := Generate(context.Background(), r,
WithModel(echoModel),
WithConfig(&GenerationCommonConfig{
Temperature: 1,
Version: "echo-im-not-a-version",
}),
WithTextPrompt("tell a joke about batman"))
if err == nil {
t.Errorf("model version should be invalid: %v", err)
}
})
}

func TestIsDefinedModel(t *testing.T) {
t.Run("should return true", func(t *testing.T) {
if IsDefinedModel(r, "test", "echo") != true {
if IsDefinedModel(r, "test", modelName) != true {
t.Errorf("IsDefinedModel did not return true")
}
})
Expand All @@ -335,7 +377,7 @@ func TestIsDefinedModel(t *testing.T) {

func TestLookupModel(t *testing.T) {
t.Run("should return model", func(t *testing.T) {
if LookupModel(r, "test", "echo") == nil {
if LookupModel(r, "test", modelName) == nil {
t.Errorf("LookupModel did not return model")
}
})
Expand Down
3 changes: 1 addition & 2 deletions go/genkit/genkit.go
Original file line number Diff line number Diff line change
@@ -1,7 +1,6 @@
// Copyright 2024 Google LLC
// SPDX-License-Identifier: Apache-2.0


// Package genkit provides Genkit functionality for application developers.
package genkit

Expand Down Expand Up @@ -157,7 +156,7 @@ func (g *Genkit) Start(ctx context.Context, opts *StartOptions) error {
func DefineModel(
g *Genkit,
provider, name string,
metadata *ai.ModelMetadata,
metadata *ai.ModelInfo,
generate func(context.Context, *ai.ModelRequest, ai.ModelStreamingCallback) (*ai.ModelResponse, error),
) ai.Model {
return ai.DefineModel(g.reg, provider, name, metadata, generate)
Expand Down
13 changes: 7 additions & 6 deletions go/internal/doc-snippets/modelplugin/modelplugin.go
Original file line number Diff line number Diff line change
@@ -1,7 +1,6 @@
// Copyright 2024 Google LLC
// SPDX-License-Identifier: Apache-2.0


package modelplugin

import (
Expand All @@ -17,8 +16,8 @@ const providerID = "mymodels"
// [START cfg]
type MyModelConfig struct {
ai.GenerationCommonConfig
CustomOption int
AnotherCustomOption string
CustomOption int
}

// [END cfg]
Expand All @@ -30,16 +29,18 @@ func Init() error {
}

// [START definemodel]
name := "my-model"
genkit.DefineModel(g,
providerID, "my-model",
&ai.ModelMetadata{
Label: "my-model",
Supports: ai.ModelCapabilities{
providerID, name,
&ai.ModelInfo{
Label: name,
Supports: &ai.ModelInfoSupports{
Multiturn: true, // Does the model support multi-turn chats?
SystemRole: true, // Does the model support syatem messages?
Media: false, // Can the model accept media input?
Tools: false, // Does the model support function calling (tools)?
},
Versions: []string{},
},
func(ctx context.Context,
genRequest *ai.ModelRequest,
Expand Down
16 changes: 10 additions & 6 deletions go/internal/doc-snippets/ollama.go
Original file line number Diff line number Diff line change
Expand Up @@ -29,17 +29,21 @@ func ollamaEx(ctx context.Context) error {
// [END init]

// [START definemodel]
name := "gemma2"
model := ollama.DefineModel(
g,
ollama.ModelDefinition{
Name: "gemma2",
Name: name,
Type: "chat", // "chat" or "generate"
},
&ai.ModelCapabilities{
Multiturn: true,
SystemRole: true,
Tools: false,
Media: false,
&ai.ModelInfo{
Label: name,
Supports: &ai.ModelInfoSupports{
Multiturn: true,
SystemRole: true,
Tools: false,
Media: false,
},
},
)
// [END definemodel]
Expand Down
Loading

0 comments on commit 0baad30

Please sign in to comment.