Skip to content

Commit

Permalink
Merge pull request #2807 from rockwotj/openai
Browse files Browse the repository at this point in the history
Support JSON Schemas for OpenAI
  • Loading branch information
rockwotj authored Sep 3, 2024
2 parents 4526abd + 09bd331 commit 1c664da
Show file tree
Hide file tree
Showing 23 changed files with 851 additions and 198 deletions.
431 changes: 431 additions & 0 deletions docs/modules/components/pages/processors/openai_chat_completion.adoc

Large diffs are not rendered by default.

2 changes: 1 addition & 1 deletion go.mod
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,6 @@ require (
cloud.google.com/go/pubsub v1.40.0
cloud.google.com/go/storage v1.42.0
cloud.google.com/go/vertexai v0.12.0
github.com/Azure/azure-sdk-for-go/sdk/ai/azopenai v0.6.0
github.com/Azure/azure-sdk-for-go/sdk/azcore v1.14.0
github.com/Azure/azure-sdk-for-go/sdk/azidentity v1.7.0
github.com/Azure/azure-sdk-for-go/sdk/data/azcosmos v1.0.3
Expand Down Expand Up @@ -99,6 +98,7 @@ require (
github.com/redpanda-data/benthos/v4 v4.36.0
github.com/redpanda-data/connect/public/bundle/free/v4 v4.31.0
github.com/rs/xid v1.5.0
github.com/sashabaranov/go-openai v1.28.3
github.com/sijms/go-ora/v2 v2.8.19
github.com/smira/go-statsd v1.3.3
github.com/snowflakedb/gosnowflake v1.11.0
Expand Down
4 changes: 2 additions & 2 deletions go.sum
Original file line number Diff line number Diff line change
Expand Up @@ -72,8 +72,6 @@ github.com/AthenZ/athenz v1.10.43/go.mod h1:pEm4lLLcpwxS33OdM8JNCS7GnWBoY/12QD7i
github.com/Azure/azure-pipeline-go v0.2.3/go.mod h1:x841ezTBIMG6O3lAcl8ATHnsOPVl2bqk7S3ta6S6u4k=
github.com/Azure/azure-sdk-for-go v68.0.0+incompatible h1:fcYLmCpyNYRnvJbPerq7U0hS+6+I79yEDJBqVNcqUzU=
github.com/Azure/azure-sdk-for-go v68.0.0+incompatible/go.mod h1:9XXNKU+eRnpl9moKnB4QOLf1HestfXbmab5FXxiDBjc=
github.com/Azure/azure-sdk-for-go/sdk/ai/azopenai v0.6.0 h1:FQOmDxJj1If0D0khZR00MDa2Eb+k9BBsSaK7cEbLwkk=
github.com/Azure/azure-sdk-for-go/sdk/ai/azopenai v0.6.0/go.mod h1:X0+PSrHOZdTjkiEhgv53HS5gplbzVVl2jd6hQRYSS3c=
github.com/Azure/azure-sdk-for-go/sdk/azcore v0.19.0/go.mod h1:h6H6c8enJmmocHUbLiiGY6sx7f9i+X3m1CHdd5c6Rdw=
github.com/Azure/azure-sdk-for-go/sdk/azcore v1.14.0 h1:nyQWyZvwGTvunIMxi1Y9uXkcyr+I7TeNrr/foo4Kpk8=
github.com/Azure/azure-sdk-for-go/sdk/azcore v1.14.0/go.mod h1:l38EPgmsp71HHLq9j7De57JcKOWPyhrsW1Awm1JS6K0=
Expand Down Expand Up @@ -1052,6 +1050,8 @@ github.com/rs/zerolog v1.15.0/go.mod h1:xYTKnLHcpfU2225ny5qZjxnj9NvkumZYjJHlAThC
github.com/russross/blackfriday/v2 v2.1.0 h1:JIOH55/0cWyOuilr9/qlrm0BSXldqnqwMsf35Ld67mk=
github.com/russross/blackfriday/v2 v2.1.0/go.mod h1:+Rmxgy9KzJVeS9/2gXHxylqXiyQDYRxCVz55jmeOWTM=
github.com/ruudk/golang-pdf417 v0.0.0-20181029194003-1af4ab5afa58/go.mod h1:6lfFZQK844Gfx8o5WFuvpxWRwnSoipWe/p622j1v06w=
github.com/sashabaranov/go-openai v1.28.3 h1:9ZjKWwFOO8RRgHarUC8rTPSLBZgkNzjyf18O9/8+jto=
github.com/sashabaranov/go-openai v1.28.3/go.mod h1:lj5b/K+zjTSFxVLijLSTDZuP7adOgerWeFyZLUhAKRg=
github.com/satori/go.uuid v1.2.0/go.mod h1:dA0hQrYB0VpLJoorglMZABFdXlWrHn1NEOzdhQKdks0=
github.com/segmentio/asm v1.2.0 h1:9BQrFxC+YOHJlTlHGkTrFWf59nbL3XnCoFLTwDCI7ys=
github.com/segmentio/asm v1.2.0/go.mod h1:BqMnlJP91P8d+4ibuonYZw9mfnzI9HfxselHZr5aAcs=
Expand Down
5 changes: 3 additions & 2 deletions internal/impl/confluent/processor_schema_registry_decode.go
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@ import (
"github.com/Jeffail/shutdown"

"github.com/redpanda-data/benthos/v4/public/service"
"github.com/redpanda-data/connect/v4/internal/impl/confluent/sr"
)

func schemaRegistryDecoderConfig() *service.ConfigSpec {
Expand Down Expand Up @@ -87,7 +88,7 @@ func init() {

type schemaRegistryDecoder struct {
avroRawJSON bool
client *schemaRegistryClient
client *sr.Client

schemas map[int]*cachedSchemaDecoder
cacheMut sync.RWMutex
Expand Down Expand Up @@ -133,7 +134,7 @@ func newSchemaRegistryDecoder(
mgr: mgr,
}
var err error
if s.client, err = newSchemaRegistryClient(urlStr, reqSigner, tlsConf, mgr); err != nil {
if s.client, err = sr.NewClient(urlStr, reqSigner, tlsConf, mgr); err != nil {
return nil, err
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -74,7 +74,7 @@ basic_auth:

e, err := newSchemaRegistryDecoderFromConfig(conf, service.MockResources())
if e != nil {
assert.Equal(t, test.expectedBaseURL, e.client.schemaRegistryBaseURL.String())
assert.Equal(t, test.expectedBaseURL, e.client.SchemaRegistryBaseURL.String())
}

if err == nil {
Expand Down
5 changes: 3 additions & 2 deletions internal/impl/confluent/processor_schema_registry_encode.go
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@ import (
"github.com/Jeffail/shutdown"

"github.com/redpanda-data/benthos/v4/public/service"
"github.com/redpanda-data/connect/v4/internal/impl/confluent/sr"
)

func schemaRegistryEncoderConfig() *service.ConfigSpec {
Expand Down Expand Up @@ -107,7 +108,7 @@ func init() {
//------------------------------------------------------------------------------

type schemaRegistryEncoder struct {
client *schemaRegistryClient
client *sr.Client
subject *service.InterpolatedString
avroRawJSON bool
schemaRefreshAfter time.Duration
Expand Down Expand Up @@ -178,7 +179,7 @@ func newSchemaRegistryEncoder(
nowFn: time.Now,
}
var err error
if s.client, err = newSchemaRegistryClient(urlStr, reqSigner, tlsConf, mgr); err != nil {
if s.client, err = sr.NewClient(urlStr, reqSigner, tlsConf, mgr); err != nil {
return nil, err
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -107,7 +107,7 @@ subject: foo

e, err := newSchemaRegistryEncoderFromConfig(conf, service.MockResources())
if e != nil {
assert.Equal(t, test.expectedBaseURL, e.client.schemaRegistryBaseURL.String())
assert.Equal(t, test.expectedBaseURL, e.client.SchemaRegistryBaseURL.String())
}

if err == nil {
Expand Down
9 changes: 5 additions & 4 deletions internal/impl/confluent/serde_avro.go
Original file line number Diff line number Diff line change
Expand Up @@ -22,15 +22,16 @@ import (
"github.com/linkedin/goavro/v2"

"github.com/redpanda-data/benthos/v4/public/service"
"github.com/redpanda-data/connect/v4/internal/impl/confluent/sr"
)

func resolveAvroReferences(ctx context.Context, client *schemaRegistryClient, info schemaInfo) (string, error) {
func resolveAvroReferences(ctx context.Context, client *sr.Client, info sr.SchemaInfo) (string, error) {
if len(info.References) == 0 {
return info.Schema, nil
}

refsMap := map[string]string{}
if err := client.WalkReferences(ctx, info.References, func(ctx context.Context, name string, info schemaInfo) error {
if err := client.WalkReferences(ctx, info.References, func(ctx context.Context, name string, info sr.SchemaInfo) error {
refsMap[name] = info.Schema
return nil
}); err != nil {
Expand Down Expand Up @@ -59,7 +60,7 @@ func resolveAvroReferences(ctx context.Context, client *schemaRegistryClient, in
return string(schemaHydratedBytes), nil
}

func (s *schemaRegistryEncoder) getAvroEncoder(ctx context.Context, info schemaInfo) (schemaEncoder, error) {
func (s *schemaRegistryEncoder) getAvroEncoder(ctx context.Context, info sr.SchemaInfo) (schemaEncoder, error) {
schema, err := resolveAvroReferences(ctx, s.client, info)
if err != nil {
return nil, err
Expand Down Expand Up @@ -97,7 +98,7 @@ func (s *schemaRegistryEncoder) getAvroEncoder(ctx context.Context, info schemaI
}, nil
}

func (s *schemaRegistryDecoder) getAvroDecoder(ctx context.Context, info schemaInfo) (schemaDecoder, error) {
func (s *schemaRegistryDecoder) getAvroDecoder(ctx context.Context, info sr.SchemaInfo) (schemaDecoder, error) {
schema, err := resolveAvroReferences(ctx, s.client, info)
if err != nil {
return nil, err
Expand Down
11 changes: 6 additions & 5 deletions internal/impl/confluent/serde_json.go
Original file line number Diff line number Diff line change
Expand Up @@ -21,9 +21,10 @@ import (
"github.com/xeipuuv/gojsonschema"

"github.com/redpanda-data/benthos/v4/public/service"
"github.com/redpanda-data/connect/v4/internal/impl/confluent/sr"
)

func resolveJSONSchema(ctx context.Context, client *schemaRegistryClient, info schemaInfo) (*gojsonschema.Schema, error) {
func resolveJSONSchema(ctx context.Context, client *sr.Client, info sr.SchemaInfo) (*gojsonschema.Schema, error) {
sl := gojsonschema.NewSchemaLoader()

if len(info.References) == 0 {
Expand All @@ -34,7 +35,7 @@ func resolveJSONSchema(ctx context.Context, client *schemaRegistryClient, info s
return sl.Compile(gojsonschema.NewStringLoader(info.Schema))
}

if err := client.WalkReferences(ctx, info.References, func(ctx context.Context, name string, info schemaInfo) error {
if err := client.WalkReferences(ctx, info.References, func(ctx context.Context, name string, info sr.SchemaInfo) error {
return sl.AddSchemas(gojsonschema.NewStringLoader(info.Schema))
}); err != nil {
return nil, err
Expand All @@ -43,15 +44,15 @@ func resolveJSONSchema(ctx context.Context, client *schemaRegistryClient, info s
return sl.Compile(gojsonschema.NewStringLoader(info.Schema))
}

func (s *schemaRegistryEncoder) getJSONEncoder(ctx context.Context, info schemaInfo) (schemaEncoder, error) {
// getJSONEncoder builds the encoder for the given schema info by delegating to
// getJSONTranscoder with this encoder's schema registry client.
func (s *schemaRegistryEncoder) getJSONEncoder(ctx context.Context, info sr.SchemaInfo) (schemaEncoder, error) {
return getJSONTranscoder(ctx, s.client, info)
}

func (s *schemaRegistryDecoder) getJSONDecoder(ctx context.Context, info schemaInfo) (schemaDecoder, error) {
// getJSONDecoder builds the decoder for the given schema info by delegating to
// getJSONTranscoder with this decoder's schema registry client.
func (s *schemaRegistryDecoder) getJSONDecoder(ctx context.Context, info sr.SchemaInfo) (schemaDecoder, error) {
return getJSONTranscoder(ctx, s.client, info)
}

func getJSONTranscoder(ctx context.Context, cl *schemaRegistryClient, info schemaInfo) (func(m *service.Message) error, error) {
func getJSONTranscoder(ctx context.Context, cl *sr.Client, info sr.SchemaInfo) (func(m *service.Message) error, error) {
sch, err := resolveJSONSchema(ctx, cl, info)
if err != nil {
return nil, err
Expand Down
9 changes: 5 additions & 4 deletions internal/impl/confluent/serde_protobuf.go
Original file line number Diff line number Diff line change
Expand Up @@ -29,14 +29,15 @@ import (

"github.com/redpanda-data/benthos/v4/public/service"

"github.com/redpanda-data/connect/v4/internal/impl/confluent/sr"
"github.com/redpanda-data/connect/v4/internal/impl/protobuf"
)

func (s *schemaRegistryDecoder) getProtobufDecoder(ctx context.Context, info schemaInfo) (schemaDecoder, error) {
func (s *schemaRegistryDecoder) getProtobufDecoder(ctx context.Context, info sr.SchemaInfo) (schemaDecoder, error) {
regMap := map[string]string{
".": info.Schema,
}
if err := s.client.WalkReferences(ctx, info.References, func(ctx context.Context, name string, si schemaInfo) error {
if err := s.client.WalkReferences(ctx, info.References, func(ctx context.Context, name string, si sr.SchemaInfo) error {
regMap[name] = si.Schema
return nil
}); err != nil {
Expand Down Expand Up @@ -96,11 +97,11 @@ func (s *schemaRegistryDecoder) getProtobufDecoder(ctx context.Context, info sch
}, nil
}

func (s *schemaRegistryEncoder) getProtobufEncoder(ctx context.Context, info schemaInfo) (schemaEncoder, error) {
func (s *schemaRegistryEncoder) getProtobufEncoder(ctx context.Context, info sr.SchemaInfo) (schemaEncoder, error) {
regMap := map[string]string{
".": info.Schema,
}
if err := s.client.WalkReferences(ctx, info.References, func(ctx context.Context, name string, si schemaInfo) error {
if err := s.client.WalkReferences(ctx, info.References, func(ctx context.Context, name string, si sr.SchemaInfo) error {
regMap[name] = si.Schema
return nil
}); err != nil {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@
// See the License for the specific language governing permissions and
// limitations under the License.

package confluent
package sr

import (
"bytes"
Expand All @@ -29,19 +29,21 @@ import (
"github.com/redpanda-data/benthos/v4/public/service"
)

type schemaRegistryClient struct {
// Client is used to make requests to a schema registry.
type Client struct {
SchemaRegistryBaseURL *url.URL
client *http.Client
schemaRegistryBaseURL *url.URL
requestSigner func(f fs.FS, req *http.Request) error
mgr *service.Resources
}

func newSchemaRegistryClient(
// NewClient creates a new schema registry client.
func NewClient(
urlStr string,
reqSigner func(f fs.FS, req *http.Request) error,
tlsConf *tls.Config,
mgr *service.Resources,
) (*schemaRegistryClient, error) {
) (*Client, error) {
u, err := url.Parse(urlStr)
if err != nil {
return nil, fmt.Errorf("failed to parse url: %w", err)
Expand All @@ -61,30 +63,33 @@ func newSchemaRegistryClient(
}
}

return &schemaRegistryClient{
return &Client{
client: hClient,
schemaRegistryBaseURL: u,
SchemaRegistryBaseURL: u,
requestSigner: reqSigner,
mgr: mgr,
}, nil
}

type schemaInfo struct {
// SchemaInfo is the information about a schema stored in the registry.
type SchemaInfo struct {
ID int `json:"id"`
Type string `json:"schemaType"`
Schema string `json:"schema"`
References []schemaReference `json:"references"`
References []SchemaReference `json:"references"`
}

// TODO: Further reading:
// https://www.confluent.io/blog/multiple-event-types-in-the-same-kafka-topic/
type schemaReference struct {
// SchemaReference is a reference to another schema within the registry.
//
// TODO: further reading https://www.confluent.io/blog/multiple-event-types-in-the-same-kafka-topic/
type SchemaReference struct {
Name string `json:"name"`
Subject string `json:"subject"`
Version int `json:"version"`
}

func (c *schemaRegistryClient) GetSchemaByID(ctx context.Context, id int) (resPayload schemaInfo, err error) {
// GetSchemaByID gets a schema by its global identifier.
func (c *Client) GetSchemaByID(ctx context.Context, id int) (resPayload SchemaInfo, err error) {
var resCode int
var resBody []byte
if resCode, resBody, err = c.doRequest(ctx, "GET", fmt.Sprintf("/schemas/ids/%v", id)); err != nil {
Expand Down Expand Up @@ -112,7 +117,8 @@ func (c *schemaRegistryClient) GetSchemaByID(ctx context.Context, id int) (resPa
return
}

func (c *schemaRegistryClient) GetSchemaBySubjectAndVersion(ctx context.Context, subject string, version *int) (resPayload schemaInfo, err error) {
// GetSchemaBySubjectAndVersion returns the schema by its subject and optional version. A `nil` version returns the latest schema.
func (c *Client) GetSchemaBySubjectAndVersion(ctx context.Context, subject string, version *int) (resPayload SchemaInfo, err error) {
var path string
if version != nil {
path = fmt.Sprintf("/subjects/%s/versions/%v", url.PathEscape(subject), *version)
Expand Down Expand Up @@ -147,19 +153,19 @@ func (c *schemaRegistryClient) GetSchemaBySubjectAndVersion(ctx context.Context,
return
}

type refWalkFn func(ctx context.Context, name string, info schemaInfo) error
type refWalkFn func(ctx context.Context, name string, info SchemaInfo) error

// For each reference provided the schema info is obtained and the provided
// closure is called recursively, which means each reference obtained will also
// be walked.
// WalkReferences goes through the provided schema info and for each reference
// the provided closure is called recursively, which means each reference obtained
// will also be walked.
//
// If a reference of a given subject but differing version is detected an error
// is returned as this would put us in an invalid state.
func (c *schemaRegistryClient) WalkReferences(ctx context.Context, refs []schemaReference, fn refWalkFn) error {
func (c *Client) WalkReferences(ctx context.Context, refs []SchemaReference, fn refWalkFn) error {
// Begin the walk with an empty record of references seen so far; the
// tracked walker looks references up in it by name and compares versions.
return c.walkReferencesTracked(ctx, map[string]int{}, refs, fn)
}

func (c *schemaRegistryClient) walkReferencesTracked(ctx context.Context, seen map[string]int, refs []schemaReference, fn refWalkFn) error {
func (c *Client) walkReferencesTracked(ctx context.Context, seen map[string]int, refs []SchemaReference, fn refWalkFn) error {
for _, ref := range refs {
if i, exists := seen[ref.Name]; exists {
if i != ref.Version {
Expand All @@ -182,8 +188,8 @@ func (c *schemaRegistryClient) walkReferencesTracked(ctx context.Context, seen m
return nil
}

func (c *schemaRegistryClient) doRequest(ctx context.Context, verb, reqPath string) (resCode int, resBody []byte, err error) {
reqURL := *c.schemaRegistryBaseURL
func (c *Client) doRequest(ctx context.Context, verb, reqPath string) (resCode int, resBody []byte, err error) {
reqURL := *c.SchemaRegistryBaseURL
if reqURL.Path, err = url.JoinPath(reqURL.Path, reqPath); err != nil {
return
}
Expand Down
11 changes: 4 additions & 7 deletions internal/impl/openai/base_processor.go
Original file line number Diff line number Diff line change
Expand Up @@ -11,9 +11,8 @@ package openai
import (
"context"

"github.com/Azure/azure-sdk-for-go/sdk/ai/azopenai"
"github.com/Azure/azure-sdk-for-go/sdk/azcore"
"github.com/redpanda-data/benthos/v4/public/service"
oai "github.com/sashabaranov/go-openai"
)

const (
Expand Down Expand Up @@ -54,11 +53,9 @@ func newBaseProcessor(conf *service.ParsedConfig) (*baseProcessor, error) {
if err != nil {
return nil, err
}
kc := azcore.NewKeyCredential(k)
c, err := azopenai.NewClientForOpenAI(sa, kc, nil)
if err != nil {
return nil, err
}
cfg := oai.DefaultConfig(k)
cfg.BaseURL = sa
c := oai.NewClientWithConfig(cfg)
m, err := conf.FieldString(opFieldModel)
if err != nil {
return nil, err
Expand Down
Loading

0 comments on commit 1c664da

Please sign in to comment.