Skip to content

Commit 9838d17

Browse files
fix(go/ai): fixed bad stream message format parsing (#3573)
1 parent dbdc5fc commit 9838d17

File tree

5 files changed

+509
-64
lines changed

5 files changed

+509
-64
lines changed

go/ai/format_array.go

Lines changed: 23 additions & 16 deletions
Original file line number | Diff line number | Diff line change
@@ -18,6 +18,7 @@ import (
1818
"encoding/json"
1919
"errors"
2020
"fmt"
21+
"strings"
2122

2223
"github.com/firebase/genkit/go/internal/base"
2324
)
@@ -78,27 +79,33 @@ func (a arrayHandler) ParseMessage(m *Message) (*Message, error) {
7879
return nil, errors.New("message has no content")
7980
}
8081

81-
var newParts []*Part
82+
var nonTextParts []*Part
83+
accumulatedText := strings.Builder{}
84+
8285
for _, part := range m.Content {
8386
if !part.IsText() {
84-
newParts = append(newParts, part)
87+
nonTextParts = append(nonTextParts, part)
8588
} else {
86-
lines := base.GetJsonObjectLines(part.Text)
87-
for _, line := range lines {
88-
var schemaBytes []byte
89-
schemaBytes, err := json.Marshal(a.config.Schema["items"])
90-
if err != nil {
91-
return nil, fmt.Errorf("expected schema is not valid: %w", err)
92-
}
93-
if err = base.ValidateRaw([]byte(line), schemaBytes); err != nil {
94-
return nil, err
95-
}
96-
97-
newParts = append(newParts, NewJSONPart(line))
98-
}
89+
accumulatedText.WriteString(part.Text)
9990
}
10091
}
101-
m.Content = newParts
92+
93+
var newParts []*Part
94+
lines := base.GetJsonObjectLines(accumulatedText.String())
95+
for _, line := range lines {
96+
var schemaBytes []byte
97+
schemaBytes, err := json.Marshal(a.config.Schema["items"])
98+
if err != nil {
99+
return nil, fmt.Errorf("expected schema is not valid: %w", err)
100+
}
101+
if err = base.ValidateRaw([]byte(line), schemaBytes); err != nil {
102+
return nil, err
103+
}
104+
105+
newParts = append(newParts, NewJSONPart(line))
106+
}
107+
108+
m.Content = append(newParts, nonTextParts...)
102109
}
103110

104111
return m, nil

go/ai/format_enum.go

Lines changed: 19 additions & 12 deletions
Original file line number | Diff line number | Diff line change
@@ -77,24 +77,31 @@ func (e enumHandler) ParseMessage(m *Message) (*Message, error) {
7777
return nil, errors.New("message has no content")
7878
}
7979

80-
for i, part := range m.Content {
80+
var nonTextParts []*Part
81+
accumulatedText := strings.Builder{}
82+
for _, part := range m.Content {
8183
if !part.IsText() {
82-
continue
84+
nonTextParts = append(nonTextParts, part)
85+
} else {
86+
accumulatedText.WriteString(part.Text)
8387
}
88+
}
8489

85-
// replace single and double quotes
86-
re := regexp.MustCompile(`['"]`)
87-
clean := re.ReplaceAllString(part.Text, "")
88-
89-
// trim whitespace
90-
trimmed := strings.TrimSpace(clean)
90+
// replace single and double quotes
91+
re := regexp.MustCompile(`['"]`)
92+
clean := re.ReplaceAllString(accumulatedText.String(), "")
9193

92-
if !slices.Contains(e.enums, trimmed) {
93-
return nil, fmt.Errorf("message %s not in list of valid enums: %s", trimmed, strings.Join(e.enums, ", "))
94-
}
94+
// trim whitespace
95+
trimmed := strings.TrimSpace(clean)
9596

96-
m.Content[i] = NewTextPart(trimmed)
97+
if !slices.Contains(e.enums, trimmed) {
98+
return nil, fmt.Errorf("message %s not in list of valid enums: %s", trimmed, strings.Join(e.enums, ", "))
9799
}
100+
101+
newParts := []*Part{NewTextPart(trimmed)}
102+
newParts = append(newParts, nonTextParts...)
103+
104+
m.Content = newParts
98105
}
99106

100107
return m, nil

go/ai/format_json.go

Lines changed: 28 additions & 19 deletions
Original file line number | Diff line number | Diff line change
@@ -18,6 +18,7 @@ import (
1818
"encoding/json"
1919
"errors"
2020
"fmt"
21+
"strings"
2122

2223
"github.com/firebase/genkit/go/internal/base"
2324
)
@@ -79,30 +80,38 @@ func (j jsonHandler) ParseMessage(m *Message) (*Message, error) {
7980
return nil, errors.New("message has no content")
8081
}
8182

82-
for i, part := range m.Content {
83-
if !part.IsText() {
84-
continue
85-
}
83+
var nonTextParts []*Part
84+
accumulatedText := strings.Builder{}
8685

87-
text := base.ExtractJSONFromMarkdown(part.Text)
88-
89-
if j.config.Schema != nil {
90-
var schemaBytes []byte
91-
schemaBytes, err := json.Marshal(j.config.Schema)
92-
if err != nil {
93-
return nil, fmt.Errorf("expected schema is not valid: %w", err)
94-
}
95-
if err = base.ValidateRaw([]byte(text), schemaBytes); err != nil {
96-
return nil, err
97-
}
86+
for _, part := range m.Content {
87+
if !part.IsText() {
88+
nonTextParts = append(nonTextParts, part)
9889
} else {
99-
if !base.ValidJSON(text) {
100-
return nil, errors.New("message is not a valid JSON")
101-
}
90+
accumulatedText.WriteString(part.Text)
10291
}
92+
}
93+
94+
text := base.ExtractJSONFromMarkdown(accumulatedText.String())
10395

104-
m.Content[i] = NewJSONPart(text)
96+
if j.config.Schema != nil {
97+
var schemaBytes []byte
98+
schemaBytes, err := json.Marshal(j.config.Schema)
99+
if err != nil {
100+
return nil, fmt.Errorf("expected schema is not valid: %w", err)
101+
}
102+
if err = base.ValidateRaw([]byte(text), schemaBytes); err != nil {
103+
return nil, err
104+
}
105+
} else {
106+
if !base.ValidJSON(text) {
107+
return nil, errors.New("message is not a valid JSON")
108+
}
105109
}
110+
111+
newParts := []*Part{NewJSONPart(text)}
112+
newParts = append(newParts, nonTextParts...)
113+
114+
m.Content = newParts
106115
}
107116

108117
return m, nil

go/ai/format_jsonl.go

Lines changed: 24 additions & 17 deletions
Original file line number | Diff line number | Diff line change
@@ -18,6 +18,7 @@ import (
1818
"encoding/json"
1919
"errors"
2020
"fmt"
21+
"strings"
2122

2223
"github.com/firebase/genkit/go/internal/base"
2324
)
@@ -79,29 +80,35 @@ func (j jsonlHandler) ParseMessage(m *Message) (*Message, error) {
7980
return nil, errors.New("message has no content")
8081
}
8182

82-
var newParts []*Part
83+
var nonTextParts []*Part
84+
accumulatedText := strings.Builder{}
85+
8386
for _, part := range m.Content {
8487
if !part.IsText() {
85-
newParts = append(newParts, part)
88+
nonTextParts = append(nonTextParts, part)
8689
} else {
87-
lines := base.GetJsonObjectLines(part.Text)
88-
for _, line := range lines {
89-
if j.config.Schema != nil {
90-
var schemaBytes []byte
91-
schemaBytes, err := json.Marshal(j.config.Schema["items"])
92-
if err != nil {
93-
return nil, fmt.Errorf("expected schema is not valid: %w", err)
94-
}
95-
if err = base.ValidateRaw([]byte(line), schemaBytes); err != nil {
96-
return nil, err
97-
}
98-
}
99-
100-
newParts = append(newParts, NewJSONPart(line))
90+
accumulatedText.WriteString(part.Text)
91+
}
92+
}
93+
94+
var newParts []*Part
95+
lines := base.GetJsonObjectLines(accumulatedText.String())
96+
for _, line := range lines {
97+
if j.config.Schema != nil {
98+
var schemaBytes []byte
99+
schemaBytes, err := json.Marshal(j.config.Schema["items"])
100+
if err != nil {
101+
return nil, fmt.Errorf("expected schema is not valid: %w", err)
102+
}
103+
if err = base.ValidateRaw([]byte(line), schemaBytes); err != nil {
104+
return nil, err
101105
}
102106
}
107+
108+
newParts = append(newParts, NewJSONPart(line))
103109
}
104-
m.Content = newParts
110+
111+
m.Content = append(newParts, nonTextParts...)
105112
}
106113

107114
return m, nil

0 commit comments

Comments
 (0)