Skip to content

Commit 586087d

Browse files
authored
fix(go/core): added normalization of input to RunJSONWithTelemetry (#3613)
1 parent 9838d17 commit 586087d

File tree

7 files changed

+218
-40
lines changed

7 files changed

+218
-40
lines changed

go/core/action.go

Lines changed: 17 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -258,25 +258,31 @@ func (a *ActionDef[In, Out, Stream]) RunJSON(ctx context.Context, input json.Raw
258258

259259
// RunJSON runs the action with a JSON input, and returns a JSON result along with telemetry info.
260260
func (a *ActionDef[In, Out, Stream]) RunJSONWithTelemetry(ctx context.Context, input json.RawMessage, cb StreamCallback[json.RawMessage]) (*api.ActionRunResult[json.RawMessage], error) {
261-
// Validate input before unmarshaling it because invalid or unknown fields will be discarded in the process.
262-
if err := base.ValidateJSON(input, a.desc.InputSchema); err != nil {
263-
return nil, NewError(INVALID_ARGUMENT, err.Error())
264-
}
265-
266261
var i In
267262
if len(input) > 0 {
268-
if err := json.Unmarshal(input, &i); err != nil {
269-
return nil, NewError(INVALID_ARGUMENT, "invalid input: %v", err)
263+
// First unmarshal input into a generic value to handle unknown fields and null values which
264+
// would be discard and/or converted into zero values if unmarshaled directly into the In type.
265+
var rawData any
266+
if err := json.Unmarshal(input, &rawData); err != nil {
267+
return nil, NewError(INTERNAL, "failed to unmarshal input: %v", err)
270268
}
271269

272-
// Adhere to the input schema if the number type is ambiguous and the input type is an any.
273-
converted, err := base.ConvertJSONNumbers(i, a.desc.InputSchema)
270+
normalized, err := base.NormalizeInput(rawData, a.desc.InputSchema)
274271
if err != nil {
275272
return nil, NewError(INVALID_ARGUMENT, "invalid input: %v", err)
276273
}
277274

278-
if result, ok := converted.(In); ok {
279-
i = result
275+
if err := base.ValidateValue(normalized, a.desc.InputSchema); err != nil {
276+
return nil, NewError(INVALID_ARGUMENT, err.Error())
277+
}
278+
279+
normalizedBytes, err := json.Marshal(normalized)
280+
if err != nil {
281+
return nil, NewError(INTERNAL, "failed to marshal normalized input: %v", err)
282+
}
283+
284+
if err := json.Unmarshal(normalizedBytes, &i); err != nil {
285+
return nil, NewError(INTERNAL, "failed to unmarshal normalized input: %v", err)
280286
}
281287
}
282288

go/internal/base/json_type_converter.go

Lines changed: 55 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -20,20 +20,21 @@ import (
2020
"fmt"
2121
)
2222

23-
// ConvertJSONNumbers recursively traverses a data structure and a corresponding JSON schema.
24-
// It converts instances of float64 into int64 or float64 based on the schema's "type" property.
25-
func ConvertJSONNumbers(data any, schema map[string]any) (any, error) {
26-
if data == nil || schema == nil {
23+
// NormalizeInput recursively traverses a data structure and performs normalization:
24+
// 1. Removes any fields with null values
25+
// 2. Converts instances of float64 into int64 or float64 based on the schema's "type" property
26+
func NormalizeInput(data any, schema map[string]any) (any, error) {
27+
if data == nil {
2728
return data, nil
2829
}
2930

3031
switch d := data.(type) {
3132
case float64:
3233
return convertFloat64(d, schema)
3334
case map[string]any:
34-
return convertObjectNumbers(d, schema)
35+
return normalizeObjectInput(d, schema)
3536
case []any:
36-
return convertArrayNumbers(d, schema)
37+
return normalizeArrayInput(d, schema)
3738
default:
3839
return data, nil
3940
}
@@ -60,45 +61,78 @@ func convertFloat64(f float64, schema map[string]any) (any, error) {
6061
}
6162
}
6263

63-
// convertObjectNumbers converts any float64s in the map values to int64 or float64 based on the schema's "type" property.
64-
func convertObjectNumbers(obj map[string]any, schema map[string]any) (map[string]any, error) {
65-
props, ok := schema["properties"].(map[string]any)
66-
if !ok {
67-
return obj, nil // No properties to guide conversion
64+
// normalizeObjectInput normalizes map values by removing null fields and converting JSON numbers.
65+
func normalizeObjectInput(obj map[string]any, schema map[string]any) (map[string]any, error) {
66+
var props map[string]any
67+
if schema != nil {
68+
props, _ = schema["properties"].(map[string]any)
69+
}
70+
71+
// If no schema or no properties, just remove null fields and normalize recursively
72+
if schema == nil || props == nil {
73+
newObj := make(map[string]any)
74+
for k, v := range obj {
75+
if v != nil {
76+
normalized, err := NormalizeInput(v, nil)
77+
if err != nil {
78+
return nil, err
79+
}
80+
newObj[k] = normalized
81+
}
82+
}
83+
return newObj, nil
6884
}
6985

70-
newObj := make(map[string]any, len(obj))
86+
newObj := make(map[string]any)
7187
for k, v := range obj {
72-
newObj[k] = v // Copy original value
88+
// Skip null values - this removes the field entirely
89+
if v == nil {
90+
continue
91+
}
7392

7493
propSchema, ok := props[k].(map[string]any)
7594
if !ok {
76-
continue // No schema for this property
95+
// No schema for this property, just keep it if not null
96+
normalized, err := NormalizeInput(v, nil)
97+
if err != nil {
98+
return nil, err
99+
}
100+
newObj[k] = normalized
101+
continue
77102
}
78103

79-
converted, err := ConvertJSONNumbers(v, propSchema)
104+
normalized, err := NormalizeInput(v, propSchema)
80105
if err != nil {
81106
return nil, err
82107
}
83-
newObj[k] = converted
108+
newObj[k] = normalized
84109
}
85110
return newObj, nil
86111
}
87112

88-
// convertArrayNumbers converts any float64s in the array values to int64 or float64 based on the schema's "type" property.
89-
func convertArrayNumbers(arr []any, schema map[string]any) ([]any, error) {
113+
// normalizeArrayInput normalizes array values by converting JSON numbers and handling null elements.
114+
func normalizeArrayInput(arr []any, schema map[string]any) ([]any, error) {
90115
items, ok := schema["items"].(map[string]any)
91116
if !ok {
92-
return arr, nil // No items schema to guide conversion
117+
// No items schema, just normalize each element
118+
newArr := make([]any, len(arr))
119+
for i, v := range arr {
120+
normalized, err := NormalizeInput(v, nil)
121+
if err != nil {
122+
return nil, err
123+
}
124+
newArr[i] = normalized
125+
}
126+
return newArr, nil
93127
}
94128

95129
newArr := make([]any, len(arr))
96130
for i, v := range arr {
97-
converted, err := ConvertJSONNumbers(v, items)
131+
normalized, err := NormalizeInput(v, items)
98132
if err != nil {
99133
return nil, err
100134
}
101-
newArr[i] = converted
135+
newArr[i] = normalized
102136
}
103137
return newArr, nil
104138
}
Lines changed: 138 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,138 @@
1+
// Copyright 2025 Google LLC
2+
//
3+
// Licensed under the Apache License, Version 2.0 (the "License");
4+
// you may not use this file except in compliance with the License.
5+
// You may obtain a copy of the License at
6+
//
7+
// http://www.apache.org/licenses/LICENSE-2.0
8+
//
9+
// Unless required by applicable law or agreed to in writing, software
10+
// distributed under the License is distributed on an "AS IS" BASIS,
11+
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
// See the License for the specific language governing permissions and
13+
// limitations under the License.
14+
//
15+
// SPDX-License-Identifier: Apache-2.0
16+
17+
package base
18+
19+
import (
20+
"reflect"
21+
"testing"
22+
)
23+
24+
func TestNormalizeInput(t *testing.T) {
25+
tests := []struct {
26+
name string
27+
data any
28+
schema map[string]any
29+
expected any
30+
}{
31+
{
32+
name: "removes null fields from object",
33+
data: map[string]any{
34+
"name": "test",
35+
"nullField": nil,
36+
"emptyString": "",
37+
"number": 42.0,
38+
},
39+
schema: map[string]any{
40+
"type": "object",
41+
"properties": map[string]any{
42+
"name": map[string]any{"type": "string"},
43+
"nullField": map[string]any{"type": "string"},
44+
"emptyString": map[string]any{"type": "string"},
45+
"number": map[string]any{"type": "integer"},
46+
},
47+
},
48+
expected: map[string]any{
49+
"name": "test",
50+
"emptyString": "",
51+
"number": int64(42),
52+
},
53+
},
54+
{
55+
name: "removes null fields without schema",
56+
data: map[string]any{
57+
"name": "test",
58+
"nullField": nil,
59+
"value": 123,
60+
},
61+
schema: nil,
62+
expected: map[string]any{
63+
"name": "test",
64+
"value": 123,
65+
},
66+
},
67+
{
68+
name: "handles nested objects with null fields",
69+
data: map[string]any{
70+
"outer": map[string]any{
71+
"inner": "value",
72+
"nullInner": nil,
73+
},
74+
"nullOuter": nil,
75+
},
76+
schema: map[string]any{
77+
"type": "object",
78+
"properties": map[string]any{
79+
"outer": map[string]any{
80+
"type": "object",
81+
"properties": map[string]any{
82+
"inner": map[string]any{"type": "string"},
83+
"nullInner": map[string]any{"type": "string"},
84+
},
85+
},
86+
"nullOuter": map[string]any{"type": "string"},
87+
},
88+
},
89+
expected: map[string]any{
90+
"outer": map[string]any{
91+
"inner": "value",
92+
},
93+
},
94+
},
95+
{
96+
name: "converts numbers correctly",
97+
data: map[string]any{
98+
"intField": 42.0,
99+
"floatField": 3.14,
100+
},
101+
schema: map[string]any{
102+
"type": "object",
103+
"properties": map[string]any{
104+
"intField": map[string]any{"type": "integer"},
105+
"floatField": map[string]any{"type": "number"},
106+
},
107+
},
108+
expected: map[string]any{
109+
"intField": int64(42),
110+
"floatField": 3.14,
111+
},
112+
},
113+
{
114+
name: "handles arrays with null elements",
115+
data: []any{"item1", nil, "item2"},
116+
schema: map[string]any{
117+
"type": "array",
118+
"items": map[string]any{
119+
"type": "string",
120+
},
121+
},
122+
expected: []any{"item1", nil, "item2"}, // Arrays preserve null elements
123+
},
124+
}
125+
126+
for _, test := range tests {
127+
t.Run(test.name, func(t *testing.T) {
128+
result, err := NormalizeInput(test.data, test.schema)
129+
if err != nil {
130+
t.Fatalf("NormalizeInput returned error: %v", err)
131+
}
132+
133+
if !reflect.DeepEqual(result, test.expected) {
134+
t.Errorf("NormalizeInput result mismatch:\nExpected: %+v\nGot: %+v", test.expected, result)
135+
}
136+
})
137+
}
138+
}

go/samples/prompts/main.go

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -34,8 +34,8 @@ import (
3434
func main() {
3535
ctx := context.Background()
3636
g := genkit.Init(ctx,
37-
genkit.WithDefaultModel("vertexai/gemini-2.0-flash"),
38-
genkit.WithPlugins(&googlegenai.VertexAI{}),
37+
genkit.WithDefaultModel("googleai/gemini-2.5-flash"),
38+
genkit.WithPlugins(&googlegenai.GoogleAI{}),
3939
genkit.WithPromptDir("prompts"),
4040
)
4141

@@ -62,7 +62,7 @@ func SimplePrompt(ctx context.Context, g *genkit.Genkit) {
6262
// Define prompt with default model and system text.
6363
helloPrompt := genkit.DefinePrompt(
6464
g, "SimplePrompt",
65-
ai.WithModelName("vertexai/gemini-2.5-pro"), // Override the default model.
65+
ai.WithModelName("googleai/gemini-2.5-pro"), // Override the default model.
6666
ai.WithSystem("You are a helpful AI assistant named Walt. Greet the user."),
6767
ai.WithPrompt("Hello, who are you?"),
6868
)
@@ -272,7 +272,7 @@ func PromptWithExecuteOverrides(ctx context.Context, g *genkit.Genkit) {
272272

273273
// Call the model and add additional messages from the user.
274274
resp, err := helloPrompt.Execute(ctx,
275-
ai.WithModel(googlegenai.VertexAIModel(g, "gemini-2.5-pro")),
275+
ai.WithModel(googlegenai.GoogleAIModel(g, "gemini-2.5-pro")),
276276
ai.WithMessages(ai.NewUserTextMessage("And I like turtles.")),
277277
)
278278
if err != nil {
@@ -319,7 +319,7 @@ func PromptWithMediaType(ctx context.Context, g *genkit.Genkit) {
319319
log.Fatal("empty prompt")
320320
}
321321
resp, err := prompt.Execute(ctx,
322-
ai.WithModelName("vertexai/gemini-2.0-flash"),
322+
ai.WithModelName("googleai/gemini-2.0-flash"),
323323
ai.WithInput(map[string]any{"imageUrl": "data:image/jpg;base64," + img}),
324324
)
325325
if err != nil {

go/samples/prompts/prompts/countries.prompt

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
---
2-
model: vertexai/gemini-2.0-flash
2+
model: googleai/gemini-2.0-flash
33
config:
44
temperature: 0.5
55
output:

go/samples/prompts/prompts/media.prompt

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
---
2-
model: vertexai/gemini-2.5-flash-preview-04-17
2+
model: googleai/gemini-2.5-flash-preview-04-17
33
config:
44
temperature: 0.1
55
input:

go/samples/prompts/prompts/multi-msg.prompt

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
---
2-
model: vertexai/gemini-2.0-flash
2+
model: googleai/gemini-2.0-flash
33
---
44
{{ role "system" }}
55

0 commit comments

Comments
 (0)