Skip to content

test(go/plugins/googlegenai): add UT to gemini utils functions #2593

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Closed
wants to merge 6 commits into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion go/ai/model_middleware.go
Original file line number Diff line number Diff line change
Expand Up @@ -132,7 +132,7 @@ func validateSupport(model string, info *ModelInfo) ModelMiddleware {
case ModelStageDeprecated:
logger.FromContext(ctx).Warn("model is deprecated and may be removed in a future release", "model", model)
case ModelStageUnstable:
logger.FromContext(ctx).Warn("model is experimental or unstable", "model", model)
logger.FromContext(ctx).Info("model is experimental or unstable", "model", model)
}
}

Expand Down
8 changes: 4 additions & 4 deletions go/genkit/genkit.go
Original file line number Diff line number Diff line change
Expand Up @@ -528,8 +528,8 @@ func LookupTool(g *Genkit, name string) ai.Tool {
// log.Fatalf("GenerateWithRequest failed: %v", err)
// }
// var out1 GeoOutput
// if err = resp1.UnmarshalOutput(&out1); err != nil {
// log.Fatalf("UnmarshalOutput failed: %v", err)
// if err = resp1.Output(&out1); err != nil {
// log.Fatalf("Output failed: %v", err)
// }
// fmt.Printf("Capital of USA: %s\n", out1.Capital) // Output: Capital of USA: Washington D.C.
//
Expand All @@ -539,8 +539,8 @@ func LookupTool(g *Genkit, name string) ai.Tool {
// log.Fatalf("Execute failed: %v", err)
// }
// var out2 GeoOutput
// if err = resp2.UnmarshalOutput(&out2); err != nil {
// log.Fatalf("UnmarshalOutput failed: %v", err)
// if err = resp2.Output(&out2); err != nil {
// log.Fatalf("Output failed: %v", err)
// }
// fmt.Printf("Capital of France: %s\n", out2.Capital) // Output: Capital of France: Paris
func DefinePrompt(g *Genkit, name string, opts ...ai.PromptOption) (*ai.Prompt, error) {
Expand Down
2 changes: 1 addition & 1 deletion go/genkit/reflection.go
Original file line number Diff line number Diff line change
Expand Up @@ -338,7 +338,7 @@ func handleRunAction(reg *registry.Registry) func(w http.ResponseWriter, r *http
}

errorJSON, _ := json.Marshal(genkitErr)
_, writeErr := fmt.Fprintf(w, "%s\n", errorJSON)
_, writeErr := fmt.Fprintf(w, "%s\n\n", errorJSON)
if writeErr != nil {
return writeErr
}
Expand Down
56 changes: 36 additions & 20 deletions go/plugins/googlegenai/gemini.go
Original file line number Diff line number Diff line change
Expand Up @@ -318,7 +318,7 @@ func generate(
return nil, err
}

gc, err := convertRequest(client, input, cache)
gc, err := convertRequest(model, input, cache)
if err != nil {
return nil, err
}
Expand All @@ -337,14 +337,20 @@ func generate(
Role: string(m.Role),
})
}
if len(contents) == 0 {
return nil, fmt.Errorf("at least one message is required in generate request")
}

// Send out the actual request.
if cb == nil {
resp, err := client.Models.GenerateContent(ctx, model, contents, gc)
if err != nil {
return nil, err
}
r := translateResponse(resp)
r, err := translateResponse(resp)
if err != nil {
return nil, err
}
r.Request = input
if cache != nil {
r.Message.Metadata = setCacheMetadata(r.Message.Metadata, cache)
Expand All @@ -365,8 +371,11 @@ func generate(
return nil, err
}
for i, c := range chunk.Candidates {
tc := translateCandidate(c)
err := cb(ctx, &ai.ModelResponseChunk{
tc, err := translateCandidate(c)
if err != nil {
return nil, err
}
err = cb(ctx, &ai.ModelResponseChunk{
Content: tc.Message.Content,
})
if err != nil {
Expand All @@ -389,7 +398,10 @@ func generate(
},
}
resp.Candidates = merged
r = translateResponse(resp)
r, err = translateResponse(resp)
if err != nil {
return nil, err
}
if r == nil {
// No candidates were returned. Probably rare, but it might avoid a NPE
// to return an empty instead of nil result.
Expand All @@ -405,7 +417,7 @@ func generate(

// convertRequest translates from [*ai.ModelRequest] to
// *genai.GenerateContentParameters
func convertRequest(client *genai.Client, input *ai.ModelRequest, cache *genai.CachedContent) (*genai.GenerateContentConfig, error) {
func convertRequest(model string, input *ai.ModelRequest, cache *genai.CachedContent) (*genai.GenerateContentConfig, error) {
gc := genai.GenerateContentConfig{}
gc.CandidateCount = genai.Ptr[int32](1)
c, err := extractConfigFromInput(input)
Expand Down Expand Up @@ -451,10 +463,11 @@ func convertRequest(client *genai.Client, input *ai.ModelRequest, cache *genai.C
gc.Tools = tools

// Then set up the tool configuration based on ToolChoice
tc := convertToolChoice(input.ToolChoice, input.Tools)
if tc != nil {
gc.ToolConfig = tc
tc, err := convertToolChoice(input.ToolChoice, input.Tools)
if err != nil {
return nil, err
}
gc.ToolConfig = tc
}

var systemParts []*genai.Part
Expand Down Expand Up @@ -582,19 +595,19 @@ func castToStringArray(i []any) []string {
return r
}

func convertToolChoice(toolChoice ai.ToolChoice, tools []*ai.ToolDefinition) *genai.ToolConfig {
func convertToolChoice(toolChoice ai.ToolChoice, tools []*ai.ToolDefinition) (*genai.ToolConfig, error) {
var mode genai.FunctionCallingConfigMode
switch toolChoice {
case "":
return nil
return nil, nil
case ai.ToolChoiceAuto:
mode = genai.FunctionCallingConfigModeAuto
case ai.ToolChoiceRequired:
mode = genai.FunctionCallingConfigModeAny
case ai.ToolChoiceNone:
mode = genai.FunctionCallingConfigModeNone
default:
panic(fmt.Sprintf("tool choice mode %q not supported", toolChoice))
return nil, fmt.Errorf("tool choice mode %q not supported", toolChoice)
}

var toolNames []string
Expand All @@ -609,11 +622,11 @@ func convertToolChoice(toolChoice ai.ToolChoice, tools []*ai.ToolDefinition) *ge
Mode: mode,
AllowedFunctionNames: toolNames,
},
}
}, nil
}

// translateCandidate translates from a genai.GenerateContentResponse to an ai.ModelResponse.
func translateCandidate(cand *genai.Candidate) *ai.ModelResponse {
func translateCandidate(cand *genai.Candidate) (*ai.ModelResponse, error) {
m := &ai.ModelResponse{}
switch cand.FinishReason {
case genai.FinishReasonStop:
Expand Down Expand Up @@ -654,26 +667,29 @@ func translateCandidate(cand *genai.Candidate) *ai.ModelResponse {
})
}
if partFound > 1 {
panic(fmt.Sprintf("expected only 1 content part in response, got %d, part: %#v", partFound, part))
return nil, fmt.Errorf("expected only 1 content part in response, got %d, part: %#v", partFound, part)
}

msg.Content = append(msg.Content, p)
}
m.Message = msg
return m
return m, nil
}

// Translate from a genai.GenerateContentResponse to a ai.ModelResponse.
func translateResponse(resp *genai.GenerateContentResponse) *ai.ModelResponse {
r := translateCandidate(resp.Candidates[0])
func translateResponse(resp *genai.GenerateContentResponse) (*ai.ModelResponse, error) {
r, err := translateCandidate(resp.Candidates[0])
if err != nil {
return nil, err
}

r.Usage = &ai.GenerationUsage{}
if u := resp.UsageMetadata; u != nil {
r.Usage.InputTokens = int(*u.PromptTokenCount)
r.Usage.OutputTokens = int(*u.CandidatesTokenCount)
r.Usage.TotalTokens = int(u.TotalTokenCount)
}
return r
return r, nil
}

// convertParts converts a slice of *ai.Part to a slice of genai.Part.
Expand Down Expand Up @@ -728,6 +744,6 @@ func convertPart(p *ai.Part) (*genai.Part, error) {
fc := genai.NewPartFromFunctionCall(toolReq.Name, input)
return fc, nil
default:
panic("unknown part type in a request")
return nil, fmt.Errorf("unsupported part in the request: %#v", p)
}
}
Loading
Loading