diff --git a/docs/modules/components/pages/processors/openai_chat_completion.adoc b/docs/modules/components/pages/processors/openai_chat_completion.adoc index aa9983c986..cd28e97e96 100644 --- a/docs/modules/components/pages/processors/openai_chat_completion.adoc +++ b/docs/modules/components/pages/processors/openai_chat_completion.adoc @@ -46,6 +46,10 @@ openai_chat_completion: max_tokens: 0 # No default (optional) temperature: 0 # No default (optional) user: "" # No default (optional) + response_format: text + json_schema: + name: "" # No default (required) + schema: "" # No default (required) ``` -- @@ -65,6 +69,38 @@ openai_chat_completion: max_tokens: 0 # No default (optional) temperature: 0 # No default (optional) user: "" # No default (optional) + response_format: text + json_schema: + name: "" # No default (required) + description: "" # No default (optional) + schema: "" # No default (required) + schema_registry: + url: "" # No default (required) + name_prefix: schema_registry_id_ + subject: "" # No default (required) + refresh_interval: "" # No default (optional) + tls: + skip_cert_verify: false + enable_renegotiation: false + root_cas: "" + root_cas_file: "" + client_certs: [] + oauth: + enabled: false + consumer_key: "" + consumer_secret: "" + access_token: "" + access_token_secret: "" + basic_auth: + enabled: false + username: "" + password: "" + jwt: + enabled: false + private_key_file: "" + signing_method: "" + claims: {} + headers: {} top_p: 0 # No default (optional) frequency_penalty: 0 # No default (optional) presence_penalty: 0 # No default (optional) @@ -167,6 +203,401 @@ This field supports xref:configuration:interpolation.adoc#bloblang-queries[inter *Type*: `string` +=== `response_format` + +Specify the model's output format. If `json_schema` is specified, then additionally a `json_schema` or `schema_registry` must be configured. + + +*Type*: `string` + +*Default*: `"text"` + +Options: +`text` +, `json` +, `json_schema` +. + +=== `json_schema` + +The JSON schema to use when responding in `json_schema` format. To learn more about what JSON schema is supported see the https://platform.openai.com/docs/guides/structured-outputs/supported-schemas[OpenAI documentation^]. + + +*Type*: `object` + + +=== `json_schema.name` + +The name of the schema. + + +*Type*: `string` + + +=== `json_schema.description` + +Additional description of the schema for the LLM. + + +*Type*: `string` + + +=== `json_schema.schema` + +The JSON schema for the LLM to use when generating the output. + + +*Type*: `string` + + +=== `schema_registry` + +The schema registry to dynamically load schemas from when responding in `json_schema` format. Schemas themselves must be in JSON format. To learn more about what JSON schema is supported see the https://platform.openai.com/docs/guides/structured-outputs/supported-schemas[OpenAI documentation^]. + + +*Type*: `object` + + +=== `schema_registry.url` + +The base URL of the schema registry service. + + +*Type*: `string` + + +=== `schema_registry.name_prefix` + +The prefix of the name for this schema, the schema ID is used as a suffix. + + +*Type*: `string` + +*Default*: `"schema_registry_id_"` + +=== `schema_registry.subject` + +The subject name to fetch the schema for. + + +*Type*: `string` + + +=== `schema_registry.refresh_interval` + +The refresh rate for getting the latest schema. If not specified the schema does not refresh. + + +*Type*: `string` + + +=== `schema_registry.tls` + +Custom TLS settings can be used to override system defaults. + + +*Type*: `object` + + +=== `schema_registry.tls.skip_cert_verify` + +Whether to skip server side certificate verification. + + +*Type*: `bool` + +*Default*: `false` + +=== `schema_registry.tls.enable_renegotiation` + +Whether to allow the remote server to repeatedly request renegotiation. Enable this option if you're seeing the error message `local error: tls: no renegotiation`. + + +*Type*: `bool` + +*Default*: `false` +Requires version 3.45.0 or newer + +=== `schema_registry.tls.root_cas` + +An optional root certificate authority to use. This is a string, representing a certificate chain from the parent trusted root certificate, to possible intermediate signing certificates, to the host certificate. +[CAUTION] +==== +This field contains sensitive information that usually shouldn't be added to a config directly, read our xref:configuration:secrets.adoc[secrets page for more info]. +==== + + + +*Type*: `string` + +*Default*: `""` + +```yml +# Examples + +root_cas: |- + -----BEGIN CERTIFICATE----- + ... + -----END CERTIFICATE----- +``` + +=== `schema_registry.tls.root_cas_file` + +An optional path of a root certificate authority file to use. This is a file, often with a .pem extension, containing a certificate chain from the parent trusted root certificate, to possible intermediate signing certificates, to the host certificate. + + +*Type*: `string` + +*Default*: `""` + +```yml +# Examples + +root_cas_file: ./root_cas.pem +``` + +=== `schema_registry.tls.client_certs` + +A list of client certificates to use. For each certificate either the fields `cert` and `key`, or `cert_file` and `key_file` should be specified, but not both. + + +*Type*: `array` + +*Default*: `[]` + +```yml +# Examples + +client_certs: + - cert: foo + key: bar + +client_certs: + - cert_file: ./example.pem + key_file: ./example.key +``` + +=== `schema_registry.tls.client_certs[].cert` + +A plain text certificate to use. + + +*Type*: `string` + +*Default*: `""` + +=== `schema_registry.tls.client_certs[].key` + +A plain text certificate key to use. +[CAUTION] +==== +This field contains sensitive information that usually shouldn't be added to a config directly, read our xref:configuration:secrets.adoc[secrets page for more info]. +==== + + + +*Type*: `string` + +*Default*: `""` + +=== `schema_registry.tls.client_certs[].cert_file` + +The path of a certificate to use. + + +*Type*: `string` + +*Default*: `""` + +=== `schema_registry.tls.client_certs[].key_file` + +The path of a certificate key to use. + + +*Type*: `string` + +*Default*: `""` + +=== `schema_registry.tls.client_certs[].password` + +A plain text password for when the private key is password encrypted in PKCS#1 or PKCS#8 format. The obsolete `pbeWithMD5AndDES-CBC` algorithm is not supported for the PKCS#8 format. + +Because the obsolete pbeWithMD5AndDES-CBC algorithm does not authenticate the ciphertext, it is vulnerable to padding oracle attacks that can let an attacker recover the plaintext. +[CAUTION] +==== +This field contains sensitive information that usually shouldn't be added to a config directly, read our xref:configuration:secrets.adoc[secrets page for more info]. +==== + + + +*Type*: `string` + +*Default*: `""` + +```yml +# Examples + +password: foo + +password: ${KEY_PASSWORD} +``` + +=== `schema_registry.oauth` + +Allows you to specify open authentication via OAuth version 1. + + +*Type*: `object` + + +=== `schema_registry.oauth.enabled` + +Whether to use OAuth version 1 in requests. + + +*Type*: `bool` + +*Default*: `false` + +=== `schema_registry.oauth.consumer_key` + +A value used to identify the client to the service provider. + + +*Type*: `string` + +*Default*: `""` + +=== `schema_registry.oauth.consumer_secret` + +A secret used to establish ownership of the consumer key. +[CAUTION] +==== +This field contains sensitive information that usually shouldn't be added to a config directly, read our xref:configuration:secrets.adoc[secrets page for more info]. +==== + + + +*Type*: `string` + +*Default*: `""` + +=== `schema_registry.oauth.access_token` + +A value used to gain access to the protected resources on behalf of the user. + + +*Type*: `string` + +*Default*: `""` + +=== `schema_registry.oauth.access_token_secret` + +A secret provided in order to establish ownership of a given access token. +[CAUTION] +==== +This field contains sensitive information that usually shouldn't be added to a config directly, read our xref:configuration:secrets.adoc[secrets page for more info]. +==== + + + +*Type*: `string` + +*Default*: `""` + +=== `schema_registry.basic_auth` + +Allows you to specify basic authentication. + + +*Type*: `object` + + +=== `schema_registry.basic_auth.enabled` + +Whether to use basic authentication in requests. + + +*Type*: `bool` + +*Default*: `false` + +=== `schema_registry.basic_auth.username` + +A username to authenticate as. + + +*Type*: `string` + +*Default*: `""` + +=== `schema_registry.basic_auth.password` + +A password to authenticate with. +[CAUTION] +==== +This field contains sensitive information that usually shouldn't be added to a config directly, read our xref:configuration:secrets.adoc[secrets page for more info]. +==== + + + +*Type*: `string` + +*Default*: `""` + +=== `schema_registry.jwt` + +BETA: Allows you to specify JWT authentication. + + +*Type*: `object` + + +=== `schema_registry.jwt.enabled` + +Whether to use JWT authentication in requests. + + +*Type*: `bool` + +*Default*: `false` + +=== `schema_registry.jwt.private_key_file` + +A file with the PEM encoded via PKCS1 or PKCS8 as private key. + + +*Type*: `string` + +*Default*: `""` + +=== `schema_registry.jwt.signing_method` + +A method used to sign the token such as RS256, RS384, RS512 or EdDSA. + + +*Type*: `string` + +*Default*: `""` + +=== `schema_registry.jwt.claims` + +A value used to identify the claims that issued the JWT. + + +*Type*: `object` + +*Default*: `{}` + +=== `schema_registry.jwt.headers` + +Add optional key/value headers to the JWT. + + +*Type*: `object` + +*Default*: `{}` + === `top_p` An alternative to sampling with temperature, called nucleus sampling, where the model considers the results of the tokens with top_p probability mass. So 0.1 means only the tokens comprising the top 10% probability mass are considered. diff --git a/go.mod b/go.mod index 97e6593eb6..ad2a6d9fd1 100644 --- a/go.mod +++ b/go.mod @@ -7,7 +7,6 @@ require ( cloud.google.com/go/pubsub v1.40.0 cloud.google.com/go/storage v1.42.0 cloud.google.com/go/vertexai v0.12.0 - github.com/Azure/azure-sdk-for-go/sdk/ai/azopenai v0.6.0 github.com/Azure/azure-sdk-for-go/sdk/azcore v1.14.0 github.com/Azure/azure-sdk-for-go/sdk/azidentity v1.7.0 github.com/Azure/azure-sdk-for-go/sdk/data/azcosmos v1.0.3 @@ -99,6 +98,7 @@ require ( github.com/redpanda-data/benthos/v4 v4.36.0 github.com/redpanda-data/connect/public/bundle/free/v4 v4.31.0 github.com/rs/xid v1.5.0 + github.com/sashabaranov/go-openai v1.28.3 github.com/sijms/go-ora/v2 v2.8.19 github.com/smira/go-statsd v1.3.3 github.com/snowflakedb/gosnowflake v1.11.0 diff --git a/go.sum b/go.sum index 473ae57f20..237e129783 100644 --- a/go.sum +++ b/go.sum @@ -72,8 +72,6 @@ github.com/AthenZ/athenz v1.10.43/go.mod h1:pEm4lLLcpwxS33OdM8JNCS7GnWBoY/12QD7i github.com/Azure/azure-pipeline-go v0.2.3/go.mod h1:x841ezTBIMG6O3lAcl8ATHnsOPVl2bqk7S3ta6S6u4k= github.com/Azure/azure-sdk-for-go v68.0.0+incompatible h1:fcYLmCpyNYRnvJbPerq7U0hS+6+I79yEDJBqVNcqUzU= github.com/Azure/azure-sdk-for-go v68.0.0+incompatible/go.mod h1:9XXNKU+eRnpl9moKnB4QOLf1HestfXbmab5FXxiDBjc= -github.com/Azure/azure-sdk-for-go/sdk/ai/azopenai v0.6.0 h1:FQOmDxJj1If0D0khZR00MDa2Eb+k9BBsSaK7cEbLwkk= -github.com/Azure/azure-sdk-for-go/sdk/ai/azopenai v0.6.0/go.mod h1:X0+PSrHOZdTjkiEhgv53HS5gplbzVVl2jd6hQRYSS3c= github.com/Azure/azure-sdk-for-go/sdk/azcore v0.19.0/go.mod h1:h6H6c8enJmmocHUbLiiGY6sx7f9i+X3m1CHdd5c6Rdw= github.com/Azure/azure-sdk-for-go/sdk/azcore v1.14.0 h1:nyQWyZvwGTvunIMxi1Y9uXkcyr+I7TeNrr/foo4Kpk8= github.com/Azure/azure-sdk-for-go/sdk/azcore v1.14.0/go.mod h1:l38EPgmsp71HHLq9j7De57JcKOWPyhrsW1Awm1JS6K0= @@ -1052,6 +1050,8 @@ github.com/rs/zerolog v1.15.0/go.mod h1:xYTKnLHcpfU2225ny5qZjxnj9NvkumZYjJHlAThC github.com/russross/blackfriday/v2 v2.1.0 h1:JIOH55/0cWyOuilr9/qlrm0BSXldqnqwMsf35Ld67mk= github.com/russross/blackfriday/v2 v2.1.0/go.mod h1:+Rmxgy9KzJVeS9/2gXHxylqXiyQDYRxCVz55jmeOWTM= github.com/ruudk/golang-pdf417 v0.0.0-20181029194003-1af4ab5afa58/go.mod h1:6lfFZQK844Gfx8o5WFuvpxWRwnSoipWe/p622j1v06w= +github.com/sashabaranov/go-openai v1.28.3 h1:9ZjKWwFOO8RRgHarUC8rTPSLBZgkNzjyf18O9/8+jto= +github.com/sashabaranov/go-openai v1.28.3/go.mod h1:lj5b/K+zjTSFxVLijLSTDZuP7adOgerWeFyZLUhAKRg= github.com/satori/go.uuid v1.2.0/go.mod h1:dA0hQrYB0VpLJoorglMZABFdXlWrHn1NEOzdhQKdks0= github.com/segmentio/asm v1.2.0 h1:9BQrFxC+YOHJlTlHGkTrFWf59nbL3XnCoFLTwDCI7ys= github.com/segmentio/asm v1.2.0/go.mod h1:BqMnlJP91P8d+4ibuonYZw9mfnzI9HfxselHZr5aAcs= diff --git a/internal/impl/confluent/processor_schema_registry_decode.go b/internal/impl/confluent/processor_schema_registry_decode.go index 0884976e8a..c5f554db1a 100644 --- a/internal/impl/confluent/processor_schema_registry_decode.go +++ b/internal/impl/confluent/processor_schema_registry_decode.go @@ -29,6 +29,7 @@ import ( "github.com/Jeffail/shutdown" "github.com/redpanda-data/benthos/v4/public/service" + "github.com/redpanda-data/connect/v4/internal/impl/confluent/sr" ) func schemaRegistryDecoderConfig() *service.ConfigSpec { @@ -87,7 +88,7 @@ func init() { type schemaRegistryDecoder struct { avroRawJSON bool - client *schemaRegistryClient + client *sr.Client schemas map[int]*cachedSchemaDecoder cacheMut sync.RWMutex @@ -133,7 +134,7 @@ func newSchemaRegistryDecoder( mgr: mgr, } var err error - if s.client, err = newSchemaRegistryClient(urlStr, reqSigner, tlsConf, mgr); err != nil { + if s.client, err = sr.NewClient(urlStr, reqSigner, tlsConf, mgr); err != nil { return nil, err } diff --git a/internal/impl/confluent/processor_schema_registry_decode_test.go b/internal/impl/confluent/processor_schema_registry_decode_test.go index 03280cd649..1f479c93e5 100644 --- a/internal/impl/confluent/processor_schema_registry_decode_test.go +++ b/internal/impl/confluent/processor_schema_registry_decode_test.go @@ -74,7 +74,7 @@ basic_auth: e, err := newSchemaRegistryDecoderFromConfig(conf, service.MockResources()) if e != nil { - assert.Equal(t, test.expectedBaseURL, e.client.schemaRegistryBaseURL.String()) + assert.Equal(t, test.expectedBaseURL, e.client.SchemaRegistryBaseURL.String()) } if err == nil { diff --git a/internal/impl/confluent/processor_schema_registry_encode.go b/internal/impl/confluent/processor_schema_registry_encode.go index 48b20d1be5..55a4116581 100644 --- a/internal/impl/confluent/processor_schema_registry_encode.go +++ b/internal/impl/confluent/processor_schema_registry_encode.go @@ -29,6 +29,7 @@ import ( "github.com/Jeffail/shutdown" "github.com/redpanda-data/benthos/v4/public/service" + "github.com/redpanda-data/connect/v4/internal/impl/confluent/sr" ) func schemaRegistryEncoderConfig() *service.ConfigSpec { @@ -107,7 +108,7 @@ func init() { //------------------------------------------------------------------------------ type schemaRegistryEncoder struct { - client *schemaRegistryClient + client *sr.Client subject *service.InterpolatedString avroRawJSON bool schemaRefreshAfter time.Duration @@ -178,7 +179,7 @@ func newSchemaRegistryEncoder( nowFn: time.Now, } var err error - if s.client, err = newSchemaRegistryClient(urlStr, reqSigner, tlsConf, mgr); err != nil { + if s.client, err = sr.NewClient(urlStr, reqSigner, tlsConf, mgr); err != nil { return nil, err } diff --git a/internal/impl/confluent/processor_schema_registry_encode_test.go b/internal/impl/confluent/processor_schema_registry_encode_test.go index 23d7eb40d2..10cf0dc9e9 100644 --- a/internal/impl/confluent/processor_schema_registry_encode_test.go +++ b/internal/impl/confluent/processor_schema_registry_encode_test.go @@ -107,7 +107,7 @@ subject: foo e, err := newSchemaRegistryEncoderFromConfig(conf, service.MockResources()) if e != nil { - assert.Equal(t, test.expectedBaseURL, e.client.schemaRegistryBaseURL.String()) + assert.Equal(t, test.expectedBaseURL, e.client.SchemaRegistryBaseURL.String()) } if err == nil { diff --git a/internal/impl/confluent/serde_avro.go b/internal/impl/confluent/serde_avro.go index 62801af243..03b22598ac 100644 --- a/internal/impl/confluent/serde_avro.go +++ b/internal/impl/confluent/serde_avro.go @@ -22,15 +22,16 @@ import ( "github.com/linkedin/goavro/v2" "github.com/redpanda-data/benthos/v4/public/service" + "github.com/redpanda-data/connect/v4/internal/impl/confluent/sr" ) -func resolveAvroReferences(ctx context.Context, client *schemaRegistryClient, info schemaInfo) (string, error) { +func resolveAvroReferences(ctx context.Context, client *sr.Client, info sr.SchemaInfo) (string, error) { if len(info.References) == 0 { return info.Schema, nil } refsMap := map[string]string{} - if err := client.WalkReferences(ctx, info.References, func(ctx context.Context, name string, info schemaInfo) error { + if err := client.WalkReferences(ctx, info.References, func(ctx context.Context, name string, info sr.SchemaInfo) error { refsMap[name] = info.Schema return nil }); err != nil { @@ -59,7 +60,7 @@ func resolveAvroReferences(ctx context.Context, client *schemaRegistryClient, in return string(schemaHydratedBytes), nil } -func (s *schemaRegistryEncoder) getAvroEncoder(ctx context.Context, info schemaInfo) (schemaEncoder, error) { +func (s *schemaRegistryEncoder) getAvroEncoder(ctx context.Context, info sr.SchemaInfo) (schemaEncoder, error) { schema, err := resolveAvroReferences(ctx, s.client, info) if err != nil { return nil, err @@ -97,7 +98,7 @@ func (s *schemaRegistryEncoder) getAvroEncoder(ctx context.Context, info schemaI }, nil } -func (s *schemaRegistryDecoder) getAvroDecoder(ctx context.Context, info schemaInfo) (schemaDecoder, error) { +func (s *schemaRegistryDecoder) getAvroDecoder(ctx context.Context, info sr.SchemaInfo) (schemaDecoder, error) { schema, err := resolveAvroReferences(ctx, s.client, info) if err != nil { return nil, err diff --git a/internal/impl/confluent/serde_json.go b/internal/impl/confluent/serde_json.go index 2149b81b3c..97949ef443 100644 --- a/internal/impl/confluent/serde_json.go +++ b/internal/impl/confluent/serde_json.go @@ -21,9 +21,10 @@ import ( "github.com/xeipuuv/gojsonschema" "github.com/redpanda-data/benthos/v4/public/service" + "github.com/redpanda-data/connect/v4/internal/impl/confluent/sr" ) -func resolveJSONSchema(ctx context.Context, client *schemaRegistryClient, info schemaInfo) (*gojsonschema.Schema, error) { +func resolveJSONSchema(ctx context.Context, client *sr.Client, info sr.SchemaInfo) (*gojsonschema.Schema, error) { sl := gojsonschema.NewSchemaLoader() if len(info.References) == 0 { @@ -34,7 +35,7 @@ func resolveJSONSchema(ctx context.Context, client *schemaRegistryClient, info s return sl.Compile(gojsonschema.NewStringLoader(info.Schema)) } - if err := client.WalkReferences(ctx, info.References, func(ctx context.Context, name string, info schemaInfo) error { + if err := client.WalkReferences(ctx, info.References, func(ctx context.Context, name string, info sr.SchemaInfo) error { return sl.AddSchemas(gojsonschema.NewStringLoader(info.Schema)) }); err != nil { return nil, err @@ -43,15 +44,15 @@ func resolveJSONSchema(ctx context.Context, client *schemaRegistryClient, info s return sl.Compile(gojsonschema.NewStringLoader(info.Schema)) } -func (s *schemaRegistryEncoder) getJSONEncoder(ctx context.Context, info schemaInfo) (schemaEncoder, error) { +func (s *schemaRegistryEncoder) getJSONEncoder(ctx context.Context, info sr.SchemaInfo) (schemaEncoder, error) { return getJSONTranscoder(ctx, s.client, info) } -func (s *schemaRegistryDecoder) getJSONDecoder(ctx context.Context, info schemaInfo) (schemaDecoder, error) { +func (s *schemaRegistryDecoder) getJSONDecoder(ctx context.Context, info sr.SchemaInfo) (schemaDecoder, error) { return getJSONTranscoder(ctx, s.client, info) } -func getJSONTranscoder(ctx context.Context, cl *schemaRegistryClient, info schemaInfo) (func(m *service.Message) error, error) { +func getJSONTranscoder(ctx context.Context, cl *sr.Client, info sr.SchemaInfo) (func(m *service.Message) error, error) { sch, err := resolveJSONSchema(ctx, cl, info) if err != nil { return nil, err diff --git a/internal/impl/confluent/serde_protobuf.go b/internal/impl/confluent/serde_protobuf.go index 820e11a03c..fc7696a132 100644 --- a/internal/impl/confluent/serde_protobuf.go +++ b/internal/impl/confluent/serde_protobuf.go @@ -29,14 +29,15 @@ import ( "github.com/redpanda-data/benthos/v4/public/service" + "github.com/redpanda-data/connect/v4/internal/impl/confluent/sr" "github.com/redpanda-data/connect/v4/internal/impl/protobuf" ) -func (s *schemaRegistryDecoder) getProtobufDecoder(ctx context.Context, info schemaInfo) (schemaDecoder, error) { +func (s *schemaRegistryDecoder) getProtobufDecoder(ctx context.Context, info sr.SchemaInfo) (schemaDecoder, error) { regMap := map[string]string{ ".": info.Schema, } - if err := s.client.WalkReferences(ctx, info.References, func(ctx context.Context, name string, si schemaInfo) error { + if err := s.client.WalkReferences(ctx, info.References, func(ctx context.Context, name string, si sr.SchemaInfo) error { regMap[name] = si.Schema return nil }); err != nil { @@ -96,11 +97,11 @@ func (s *schemaRegistryDecoder) getProtobufDecoder(ctx context.Context, info sch }, nil } -func (s *schemaRegistryEncoder) getProtobufEncoder(ctx context.Context, info schemaInfo) (schemaEncoder, error) { +func (s *schemaRegistryEncoder) getProtobufEncoder(ctx context.Context, info sr.SchemaInfo) (schemaEncoder, error) { regMap := map[string]string{ ".": info.Schema, } - if err := s.client.WalkReferences(ctx, info.References, func(ctx context.Context, name string, si schemaInfo) error { + if err := s.client.WalkReferences(ctx, info.References, func(ctx context.Context, name string, si sr.SchemaInfo) error { regMap[name] = si.Schema return nil }); err != nil { diff --git a/internal/impl/confluent/client.go b/internal/impl/confluent/sr/client.go similarity index 76% rename from internal/impl/confluent/client.go rename to internal/impl/confluent/sr/client.go index 84de132221..71ebe8af3d 100644 --- a/internal/impl/confluent/client.go +++ b/internal/impl/confluent/sr/client.go @@ -12,7 +12,7 @@ // See the License for the specific language governing permissions and // limitations under the License. -package confluent +package sr import ( "bytes" @@ -29,19 +29,21 @@ import ( "github.com/redpanda-data/benthos/v4/public/service" ) -type schemaRegistryClient struct { +// Client is used to make requests to a schema registry. +type Client struct { + SchemaRegistryBaseURL *url.URL client *http.Client - schemaRegistryBaseURL *url.URL requestSigner func(f fs.FS, req *http.Request) error mgr *service.Resources } -func newSchemaRegistryClient( +// NewClient creates a new schema registry client. +func NewClient( urlStr string, reqSigner func(f fs.FS, req *http.Request) error, tlsConf *tls.Config, mgr *service.Resources, -) (*schemaRegistryClient, error) { +) (*Client, error) { u, err := url.Parse(urlStr) if err != nil { return nil, fmt.Errorf("failed to parse url: %w", err) @@ -61,30 +63,33 @@ func newSchemaRegistryClient( } } - return &schemaRegistryClient{ + return &Client{ client: hClient, - schemaRegistryBaseURL: u, + SchemaRegistryBaseURL: u, requestSigner: reqSigner, mgr: mgr, }, nil } -type schemaInfo struct { +// SchemaInfo is the information about a schema stored in the registry. +type SchemaInfo struct { ID int `json:"id"` Type string `json:"schemaType"` Schema string `json:"schema"` - References []schemaReference `json:"references"` + References []SchemaReference `json:"references"` } -// TODO: Further reading: -// https://www.confluent.io/blog/multiple-event-types-in-the-same-kafka-topic/ -type schemaReference struct { +// SchemaReference is a reference to another schema within the registry. +// +// TODO: further reading https://www.confluent.io/blog/multiple-event-types-in-the-same-kafka-topic/ +type SchemaReference struct { Name string `json:"name"` Subject string `json:"subject"` Version int `json:"version"` } -func (c *schemaRegistryClient) GetSchemaByID(ctx context.Context, id int) (resPayload schemaInfo, err error) { +// GetSchemaByID gets a schema by it's global identifier. +func (c *Client) GetSchemaByID(ctx context.Context, id int) (resPayload SchemaInfo, err error) { var resCode int var resBody []byte if resCode, resBody, err = c.doRequest(ctx, "GET", fmt.Sprintf("/schemas/ids/%v", id)); err != nil { @@ -112,7 +117,8 @@ func (c *schemaRegistryClient) GetSchemaByID(ctx context.Context, id int) (resPa return } -func (c *schemaRegistryClient) GetSchemaBySubjectAndVersion(ctx context.Context, subject string, version *int) (resPayload schemaInfo, err error) { +// GetSchemaBySubjectAndVersion returns the schema by it's subject and optional version. A `nil` version returns the latest schema. +func (c *Client) GetSchemaBySubjectAndVersion(ctx context.Context, subject string, version *int) (resPayload SchemaInfo, err error) { var path string if version != nil { path = fmt.Sprintf("/subjects/%s/versions/%v", url.PathEscape(subject), *version) @@ -147,19 +153,19 @@ func (c *schemaRegistryClient) GetSchemaBySubjectAndVersion(ctx context.Context, return } -type refWalkFn func(ctx context.Context, name string, info schemaInfo) error +type refWalkFn func(ctx context.Context, name string, info SchemaInfo) error -// For each reference provided the schema info is obtained and the provided -// closure is called recursively, which means each reference obtained will also -// be walked. +// WalkReferences goes through the provided schema info and for each reference +// the provided closure is called recursively, which means each reference obtained +// will also be walked. // // If a reference of a given subject but differing version is detected an error // is returned as this would put us in an invalid state. -func (c *schemaRegistryClient) WalkReferences(ctx context.Context, refs []schemaReference, fn refWalkFn) error { +func (c *Client) WalkReferences(ctx context.Context, refs []SchemaReference, fn refWalkFn) error { return c.walkReferencesTracked(ctx, map[string]int{}, refs, fn) } -func (c *schemaRegistryClient) walkReferencesTracked(ctx context.Context, seen map[string]int, refs []schemaReference, fn refWalkFn) error { +func (c *Client) walkReferencesTracked(ctx context.Context, seen map[string]int, refs []SchemaReference, fn refWalkFn) error { for _, ref := range refs { if i, exists := seen[ref.Name]; exists { if i != ref.Version { @@ -182,8 +188,8 @@ func (c *schemaRegistryClient) walkReferencesTracked(ctx context.Context, seen m return nil } -func (c *schemaRegistryClient) doRequest(ctx context.Context, verb, reqPath string) (resCode int, resBody []byte, err error) { - reqURL := *c.schemaRegistryBaseURL +func (c *Client) doRequest(ctx context.Context, verb, reqPath string) (resCode int, resBody []byte, err error) { + reqURL := *c.SchemaRegistryBaseURL if reqURL.Path, err = url.JoinPath(reqURL.Path, reqPath); err != nil { return } diff --git a/internal/impl/openai/base_processor.go b/internal/impl/openai/base_processor.go index da0490479f..846472e6a4 100644 --- a/internal/impl/openai/base_processor.go +++ b/internal/impl/openai/base_processor.go @@ -11,9 +11,8 @@ package openai import ( "context" - "github.com/Azure/azure-sdk-for-go/sdk/ai/azopenai" - "github.com/Azure/azure-sdk-for-go/sdk/azcore" "github.com/redpanda-data/benthos/v4/public/service" + oai "github.com/sashabaranov/go-openai" ) const ( @@ -54,11 +53,9 @@ func newBaseProcessor(conf *service.ParsedConfig) (*baseProcessor, error) { if err != nil { return nil, err } - kc := azcore.NewKeyCredential(k) - c, err := azopenai.NewClientForOpenAI(sa, kc, nil) - if err != nil { - return nil, err - } + cfg := oai.DefaultConfig(k) + cfg.BaseURL = sa + c := oai.NewClientWithConfig(cfg) m, err := conf.FieldString(opFieldModel) if err != nil { return nil, err diff --git a/internal/impl/openai/chat_processor.go b/internal/impl/openai/chat_processor.go index 9edd3e980b..7a6469d9a1 100644 --- a/internal/impl/openai/chat_processor.go +++ b/internal/impl/openai/chat_processor.go @@ -10,12 +10,15 @@ package openai import ( "context" - "errors" "fmt" + "math" + "slices" + "time" - oai "github.com/Azure/azure-sdk-for-go/sdk/ai/azopenai" "github.com/redpanda-data/benthos/v4/public/bloblang" "github.com/redpanda-data/benthos/v4/public/service" + "github.com/redpanda-data/connect/v4/internal/impl/confluent/sr" + oai "github.com/sashabaranov/go-openai" ) const ( @@ -29,6 +32,19 @@ const ( ocpFieldStop = "stop" ocpFieldPresencePenalty = "presence_penalty" ocpFieldFrequencyPenalty = "frequency_penalty" + ocpFieldResponseFormat = "response_format" + // JSON schema fields + ocpFieldJSONSchema = "json_schema" + ocpFieldJSONSchemaName = "name" + ocpFieldJSONSchemaDesc = "description" + ocpFieldJSONSchemaSchema = "schema" + // Schema registry fields + ocpFieldSchemaRegistry = "schema_registry" + ocpFieldSchemaRegistrySubject = "subject" + ocpFieldSchemaRegistryRefreshInterval = "refresh_interval" + ocpFieldSchemaRegistryNamePrefix = "name_prefix" + ocpFieldSchemaRegistryURL = "url" + ocpFieldSchemaRegistryTLS = "tls" ) func init() { @@ -78,6 +94,37 @@ We generally recommend altering this or top_p but not both.`). service.NewInterpolatedStringField(ocpFieldUser). Optional(). Description("A unique identifier representing your end-user, which can help OpenAI to monitor and detect abuse."), + service.NewStringEnumField(ocpFieldResponseFormat, "text", "json", "json_schema"). + Default("text"). + Description("Specify the model's output format. If `json_schema` is specified, then additionally a `json_schema` or `schema_registry` must be configured."), + service.NewObjectField(ocpFieldJSONSchema, + service.NewStringField(ocpFieldJSONSchemaName).Description("The name of the schema."), + service.NewStringField(ocpFieldJSONSchemaDesc).Optional().Advanced().Description("Additional description of the schema for the LLM."), + service.NewStringField(ocpFieldJSONSchemaSchema).Description("The JSON schema for the LLM to use when generating the output."), + ). + Optional(). + Description("The JSON schema to use when responding in `json_schema` format. To learn more about what JSON schema is supported see the https://platform.openai.com/docs/guides/structured-outputs/supported-schemas[OpenAI documentation^]."), + service.NewObjectField( + ocpFieldSchemaRegistry, + slices.Concat( + []*service.ConfigField{ + service.NewURLField(ocpFieldSchemaRegistryURL).Description("The base URL of the schema registry service."), + service.NewStringField(ocpFieldSchemaRegistryNamePrefix). + Default("schema_registry_id_"). + Description("The prefix of the name for this schema, the schema ID is used as a suffix."), + service.NewStringField(ocpFieldSchemaRegistrySubject). + Description("The subject name to fetch the schema for."), + service.NewDurationField(ocpFieldSchemaRegistryRefreshInterval). + Optional(). + Description("The refresh rate for getting the latest schema. If not specified the schema does not refresh."), + service.NewTLSField(ocpFieldSchemaRegistryTLS), + }, + service.NewHTTPRequestAuthSignerFields(), + )..., + ). + Description("The schema registry to dynamically load schemas from when responding in `json_schema` format. Schemas themselves must be in JSON format. To learn more about what JSON schema is supported see the https://platform.openai.com/docs/guides/structured-outputs/supported-schemas[OpenAI documentation^]."). + Optional(). + Advanced(), service.NewFloatField(ocpFieldTopP). Optional(). Advanced(). @@ -103,7 +150,12 @@ We generally recommend altering this or temperature but not both.`). Optional(). Advanced(). Description("Up to 4 sequences where the API will stop generating further tokens."), - ) + ).LintRule(` + root = match { + this.exists("` + ocpFieldJSONSchema + `") && this.exists("` + ocpFieldSchemaRegistry + `") => ["cannot set both ` + "`" + ocpFieldJSONSchema + "`" + ` and ` + "`" + ocpFieldSchemaRegistry + "`" + `"] + this.response_format == "json_schema" && !this.exists("` + ocpFieldJSONSchema + `") && !this.exists("` + ocpFieldSchemaRegistry + `") => ["schema must be specified using either ` + "`" + ocpFieldJSONSchema + "`" + ` or ` + "`" + ocpFieldSchemaRegistry + "`" + `"] + } + `) } func makeChatProcessor(conf *service.ParsedConfig, mgr *service.Resources) (service.Processor, error) { @@ -125,14 +177,13 @@ func makeChatProcessor(conf *service.ParsedConfig, mgr *service.Resources) (serv return nil, err } } - var maxTokens *int32 + var maxTokens *int if conf.Contains(ocpFieldMaxTokens) { mt, err := conf.FieldInt(ocpFieldMaxTokens) if err != nil { return nil, err } - m := int32(mt) - maxTokens = &m + maxTokens = &mt } var temp *float32 if conf.Contains(ocpFieldTemp) { @@ -177,14 +228,13 @@ func makeChatProcessor(conf *service.ParsedConfig, mgr *service.Resources) (serv pp := float32(v) presencePenalty = &pp } - var seed *int64 + var seed *int if conf.Contains(ocpFieldSeed) { intSeed, err := conf.FieldInt(ocpFieldSeed) if err != nil { return nil, err } - s := int64(intSeed) - seed = &s + seed = &intSeed } var stop []string if conf.Contains(ocpFieldStop) { @@ -193,7 +243,92 @@ func makeChatProcessor(conf *service.ParsedConfig, mgr *service.Resources) (serv return nil, err } } - return &chatProcessor{b, up, sp, maxTokens, temp, user, topP, frequencyPenalty, presencePenalty, seed, stop}, nil + v, err := conf.FieldString(ocpFieldResponseFormat) + if err != nil { + return nil, err + } + var responseFormat oai.ChatCompletionResponseFormatType + var schemaProvider jsonSchemaProvider + switch v { + case "json": + fallthrough + case "json_object": + responseFormat = oai.ChatCompletionResponseFormatTypeJSONObject + case "json_schema": + responseFormat = oai.ChatCompletionResponseFormatTypeJSONSchema + if conf.Contains(ocpFieldJSONSchema) { + schemaProvider, err = newFixedSchemaProvider(conf.Namespace(ocpFieldJSONSchema)) + if err != nil { + return nil, err + } + } else if conf.Contains(ocpFieldSchemaRegistry) { + schemaProvider, err = newDynamicSchemaProvider(conf.Namespace(ocpFieldSchemaRegistry), mgr) + if err != nil { + return nil, err + } + } else { + return nil, fmt.Errorf("using %s %q, but did not specify %s or %s", ocpFieldResponseFormat, v, ocpFieldJSONSchema, ocpFieldSchemaRegistry) + } + case "text": + responseFormat = oai.ChatCompletionResponseFormatTypeText + default: + return nil, fmt.Errorf("unknown %s: %q", ocpFieldResponseFormat, v) + } + return &chatProcessor{b, up, sp, maxTokens, temp, user, topP, frequencyPenalty, presencePenalty, seed, stop, responseFormat, schemaProvider}, nil +} + +func newFixedSchemaProvider(conf *service.ParsedConfig) (jsonSchemaProvider, error) { + name, err := conf.FieldString(ocpFieldJSONSchemaName) + if err != nil { + return nil, err + } + description := "" + if conf.Contains(ocpFieldJSONSchemaDesc) { + description, err = conf.FieldString(ocpFieldJSONSchemaDesc) + if err != nil { + return nil, err + } + } + schema, err := conf.FieldString(ocpFieldJSONSchemaSchema) + if err != nil { + return nil, err + } + return newFixedSchema(name, description, schema) +} + +func newDynamicSchemaProvider(conf *service.ParsedConfig, mgr *service.Resources) (jsonSchemaProvider, error) { + url, err := conf.FieldString(ocpFieldSchemaRegistryURL) + if err != nil { + return nil, err + } + reqSigner, err := conf.HTTPRequestAuthSignerFromParsed() + if err != nil { + return nil, err + } + tlsConfig, err := conf.FieldTLS(ocpFieldSchemaRegistryTLS) + if err != nil { + return nil, err + } + client, err := sr.NewClient(url, reqSigner, tlsConfig, mgr) + if err != nil { + return nil, fmt.Errorf("unable to create schema registry client: %w", err) + } + subject, err := conf.FieldString(ocpFieldSchemaRegistrySubject) + if err != nil { + return nil, err + } + var refreshInterval time.Duration = math.MaxInt64 + if conf.Contains(ocpFieldSchemaRegistryRefreshInterval) { + refreshInterval, err = conf.FieldDuration(ocpFieldSchemaRegistryRefreshInterval) + if err != nil { + return nil, err + } + } + namePrefix, err := conf.FieldString(ocpFieldSchemaRegistryNamePrefix) + if err != nil { + return nil, err + } + return newDynamicSchema(client, subject, namePrefix, refreshInterval), nil } type chatProcessor struct { @@ -201,40 +336,63 @@ type chatProcessor struct { userPrompt *bloblang.Executor systemPrompt *service.InterpolatedString - maxTokens *int32 + maxTokens *int temperature *float32 user *service.InterpolatedString topP *float32 frequencyPenalty *float32 presencePenalty *float32 - seed *int64 + seed *int stop []string + responseFormat oai.ChatCompletionResponseFormatType + schemaProvider jsonSchemaProvider } func (p *chatProcessor) Process(ctx context.Context, msg *service.Message) (service.MessageBatch, error) { - var body oai.ChatCompletionsOptions - body.DeploymentName = &p.model - body.MaxTokens = p.maxTokens - body.Temperature = p.temperature - body.TopP = p.topP + var body oai.ChatCompletionRequest + body.Model = p.model + if p.maxTokens != nil { + body.MaxTokens = *p.maxTokens + } + if p.temperature != nil { + body.Temperature = *p.temperature + } + if p.topP != nil { + body.TopP = *p.topP + } body.Seed = p.seed - body.FrequencyPenalty = p.frequencyPenalty - body.PresencePenalty = p.presencePenalty + if p.frequencyPenalty != nil { + body.FrequencyPenalty = *p.frequencyPenalty + } + if p.presencePenalty != nil { + body.PresencePenalty = *p.presencePenalty + } + if p.responseFormat != oai.ChatCompletionResponseFormatTypeText { + body.ResponseFormat = &oai.ChatCompletionResponseFormat{Type: p.responseFormat} + if p.schemaProvider != nil { + s, err := p.schemaProvider.GetJSONSchema(ctx) + if err != nil { + return nil, err + } + body.ResponseFormat.JSONSchema = s + } + } body.Stop = p.stop if p.user != nil { u, err := p.user.TryString(msg) if err != nil { return nil, fmt.Errorf("%s interpolation error: %w", ocpFieldUser, err) } - body.User = &u + body.User = u } if p.systemPrompt != nil { s, err := p.systemPrompt.TryString(msg) if err != nil { return nil, fmt.Errorf("%s interpolation error: %w", ocpFieldSystemPrompt, err) } - body.Messages = append(body.Messages, &oai.ChatRequestSystemMessage{ - Content: &s, + body.Messages = append(body.Messages, oai.ChatCompletionMessage{ + Role: "system", + Content: s, }) } if p.userPrompt != nil { @@ -242,33 +400,28 @@ func (p *chatProcessor) Process(ctx context.Context, msg *service.Message) (serv if err != nil { return nil, fmt.Errorf("%s execution error: %w", ocpFieldUserPrompt, err) } - body.Messages = append(body.Messages, &oai.ChatRequestUserMessage{ - Content: oai.NewChatRequestUserMessageContent(bloblang.ValueToString(s)), + body.Messages = append(body.Messages, oai.ChatCompletionMessage{ + Role: "user", + Content: bloblang.ValueToString(s), }) } else { b, err := msg.AsBytes() if err != nil { return nil, err } - body.Messages = append(body.Messages, &oai.ChatRequestUserMessage{ - Content: oai.NewChatRequestUserMessageContent(string(b)), + body.Messages = append(body.Messages, oai.ChatCompletionMessage{ + Role: "user", + Content: string(b), }) } - var opts oai.GetChatCompletionsOptions - resp, err := p.client.GetChatCompletions(ctx, body, &opts) + resp, err := p.client.CreateChatCompletion(ctx, body) if err != nil { return nil, err } if len(resp.Choices) != 1 { return nil, fmt.Errorf("invalid number of choices in response: %d", len(resp.Choices)) } - if resp.Choices[0].Message == nil { - return nil, errors.New("invalid missing message in chat response") - } - if resp.Choices[0].Message.Content == nil { - return nil, errors.New("invalid missing message content in chat response") - } msg = msg.Copy() - msg.SetBytes([]byte(*resp.Choices[0].Message.Content)) + msg.SetBytes([]byte(resp.Choices[0].Message.Content)) return service.MessageBatch{msg}, nil } diff --git a/internal/impl/openai/chat_processor_test.go b/internal/impl/openai/chat_processor_test.go index 446a0dd815..d8f8d7a1cc 100644 --- a/internal/impl/openai/chat_processor_test.go +++ b/internal/impl/openai/chat_processor_test.go @@ -12,10 +12,10 @@ import ( "context" "testing" - oai "github.com/Azure/azure-sdk-for-go/sdk/ai/azopenai" "github.com/go-faker/faker/v4" "github.com/redpanda-data/benthos/v4/public/bloblang" "github.com/redpanda-data/benthos/v4/public/service" + oai "github.com/sashabaranov/go-openai" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" ) @@ -24,15 +24,14 @@ type mockChatClient struct { stubClient } -func (m *mockChatClient) GetChatCompletions(ctx context.Context, body oai.ChatCompletionsOptions, options *oai.GetChatCompletionsOptions) (resp oai.GetChatCompletionsResponse, err error) { - id := faker.UUIDHyphenated() - resp.ID = &id - resp.Model = body.DeploymentName - content := faker.Paragraph() - resp.Choices = []oai.ChatChoice{ +func (m *mockChatClient) CreateChatCompletion(ctx context.Context, body oai.ChatCompletionRequest) (resp oai.ChatCompletionResponse, err error) { + resp.ID = faker.UUIDHyphenated() + resp.Model = body.Model + resp.Choices = []oai.ChatCompletionChoice{ { - Message: &oai.ChatResponseMessage{ - Content: &content, + Message: oai.ChatCompletionMessage{ + Role: "assistant", + Content: faker.Paragraph(), }, }, } diff --git a/internal/impl/openai/client.go b/internal/impl/openai/client.go index 32e6e9ece6..8ea720066a 100644 --- a/internal/impl/openai/client.go +++ b/internal/impl/openai/client.go @@ -11,15 +11,15 @@ package openai import ( "context" - oai "github.com/Azure/azure-sdk-for-go/sdk/ai/azopenai" + oai "github.com/sashabaranov/go-openai" ) // A mockable client for unit testing type client interface { - GetEmbeddings(ctx context.Context, body oai.EmbeddingsOptions, options *oai.GetEmbeddingsOptions) (oai.GetEmbeddingsResponse, error) - GetChatCompletions(ctx context.Context, body oai.ChatCompletionsOptions, options *oai.GetChatCompletionsOptions) (oai.GetChatCompletionsResponse, error) - GenerateSpeechFromText(ctx context.Context, body oai.SpeechGenerationOptions, options *oai.GenerateSpeechFromTextOptions) (oai.GenerateSpeechFromTextResponse, error) - GetAudioTranscription(ctx context.Context, body oai.AudioTranscriptionOptions, options *oai.GetAudioTranscriptionOptions) (oai.GetAudioTranscriptionResponse, error) - GetAudioTranslation(ctx context.Context, body oai.AudioTranslationOptions, options *oai.GetAudioTranslationOptions) (oai.GetAudioTranslationResponse, error) - GetImageGenerations(ctx context.Context, body oai.ImageGenerationOptions, options *oai.GetImageGenerationsOptions) (oai.GetImageGenerationsResponse, error) + CreateChatCompletion(ctx context.Context, body oai.ChatCompletionRequest) (oai.ChatCompletionResponse, error) + CreateEmbeddings(ctx context.Context, body oai.EmbeddingRequestConverter) (oai.EmbeddingResponse, error) + CreateSpeech(ctx context.Context, body oai.CreateSpeechRequest) (oai.RawResponse, error) + CreateTranscription(ctx context.Context, body oai.AudioRequest) (oai.AudioResponse, error) + CreateTranslation(ctx context.Context, body oai.AudioRequest) (oai.AudioResponse, error) + CreateImage(ctx context.Context, body oai.ImageRequest) (oai.ImageResponse, error) } diff --git a/internal/impl/openai/client_test.go b/internal/impl/openai/client_test.go index 77027de024..935793ee04 100644 --- a/internal/impl/openai/client_test.go +++ b/internal/impl/openai/client_test.go @@ -12,37 +12,37 @@ import ( "context" "errors" - oai "github.com/Azure/azure-sdk-for-go/sdk/ai/azopenai" + oai "github.com/sashabaranov/go-openai" ) type stubClient struct{} -func (*stubClient) GetEmbeddings(ctx context.Context, body oai.EmbeddingsOptions, options *oai.GetEmbeddingsOptions) (r oai.GetEmbeddingsResponse, err error) { +func (*stubClient) CreateEmbeddings(ctx context.Context, body oai.EmbeddingRequestConverter) (r oai.EmbeddingResponse, err error) { err = errors.New("unimplemented") return } -func (*stubClient) GetChatCompletions(ctx context.Context, body oai.ChatCompletionsOptions, options *oai.GetChatCompletionsOptions) (r oai.GetChatCompletionsResponse, err error) { +func (*stubClient) CreateChatCompletion(ctx context.Context, body oai.ChatCompletionRequest) (r oai.ChatCompletionResponse, err error) { err = errors.New("unimplemented") return } -func (*stubClient) GenerateSpeechFromText(ctx context.Context, body oai.SpeechGenerationOptions, options *oai.GenerateSpeechFromTextOptions) (r oai.GenerateSpeechFromTextResponse, err error) { +func (*stubClient) CreateSpeech(ctx context.Context, body oai.CreateSpeechRequest) (r oai.RawResponse, err error) { err = errors.New("unimplemented") return } -func (*stubClient) GetAudioTranscription(ctx context.Context, body oai.AudioTranscriptionOptions, options *oai.GetAudioTranscriptionOptions) (r oai.GetAudioTranscriptionResponse, err error) { +func (*stubClient) CreateTranscription(ctx context.Context, body oai.AudioRequest) (r oai.AudioResponse, err error) { err = errors.New("unimplemented") return } -func (*stubClient) GetAudioTranslation(ctx context.Context, body oai.AudioTranslationOptions, options *oai.GetAudioTranslationOptions) (r oai.GetAudioTranslationResponse, err error) { +func (*stubClient) CreateTranslation(ctx context.Context, body oai.AudioRequest) (r oai.AudioResponse, err error) { err = errors.New("unimplemented") return } -func (*stubClient) GetImageGenerations(ctx context.Context, body oai.ImageGenerationOptions, options *oai.GetImageGenerationsOptions) (r oai.GetImageGenerationsResponse, err error) { +func (*stubClient) CreateImage(ctx context.Context, body oai.ImageRequest) (r oai.ImageResponse, err error) { err = errors.New("unimplemented") return } diff --git a/internal/impl/openai/embeddings_processor.go b/internal/impl/openai/embeddings_processor.go index 6e901249d6..0166f144aa 100644 --- a/internal/impl/openai/embeddings_processor.go +++ b/internal/impl/openai/embeddings_processor.go @@ -12,9 +12,9 @@ import ( "context" "fmt" - oai "github.com/Azure/azure-sdk-for-go/sdk/ai/azopenai" "github.com/redpanda-data/benthos/v4/public/bloblang" "github.com/redpanda-data/benthos/v4/public/service" + oai "github.com/sashabaranov/go-openai" ) const ( @@ -71,14 +71,13 @@ func makeEmbeddingsProcessor(conf *service.ParsedConfig, mgr *service.Resources) return nil, err } } - var dims *int32 + var dims *int if conf.Contains(oepFieldDims) { v, err := conf.FieldInt(oepFieldDims) if err != nil { return nil, err } - d := int32(v) - dims = &d + dims = &v } return &embeddingsProcessor{b, t, dims}, nil } @@ -87,13 +86,15 @@ type embeddingsProcessor struct { *baseProcessor text *bloblang.Executor - dimensions *int32 + dimensions *int } func (p *embeddingsProcessor) Process(ctx context.Context, msg *service.Message) (service.MessageBatch, error) { - var body oai.EmbeddingsOptions - body.DeploymentName = &p.model - body.Dimensions = p.dimensions + var body oai.EmbeddingRequestStrings + body.Model = oai.EmbeddingModel(p.model) + if p.dimensions != nil { + body.Dimensions = *p.dimensions + } if p.text != nil { s, err := msg.BloblangQueryValue(p.text) if err != nil { @@ -107,7 +108,7 @@ func (p *embeddingsProcessor) Process(ctx context.Context, msg *service.Message) } body.Input = append(body.Input, string(b)) } - resp, err := p.client.GetEmbeddings(ctx, body, nil) + resp, err := p.client.CreateEmbeddings(ctx, body) if err != nil { return nil, err } diff --git a/internal/impl/openai/embeddings_processor_test.go b/internal/impl/openai/embeddings_processor_test.go index 2f9aca1c1c..6cf8161a79 100644 --- a/internal/impl/openai/embeddings_processor_test.go +++ b/internal/impl/openai/embeddings_processor_test.go @@ -12,11 +12,11 @@ import ( "context" "testing" - oai "github.com/Azure/azure-sdk-for-go/sdk/ai/azopenai" "github.com/go-faker/faker/v4" "github.com/go-faker/faker/v4/pkg/options" "github.com/redpanda-data/benthos/v4/public/bloblang" "github.com/redpanda-data/benthos/v4/public/service" + oai "github.com/sashabaranov/go-openai" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" ) @@ -33,12 +33,12 @@ func mockEmbeddings(text string) []float32 { return embd } -func (m *mockEmbeddingsClient) GetEmbeddings(ctx context.Context, body oai.EmbeddingsOptions, options *oai.GetEmbeddingsOptions) (resp oai.GetEmbeddingsResponse, err error) { +func (m *mockEmbeddingsClient) CreateEmbeddings(ctx context.Context, genericBody oai.EmbeddingRequestConverter) (resp oai.EmbeddingResponse, err error) { + body := genericBody.(oai.EmbeddingRequestStrings) for i, text := range body.Input { - idx := int32(i) - resp.Data = append(resp.Data, oai.EmbeddingItem{ + resp.Data = append(resp.Data, oai.Embedding{ Embedding: mockEmbeddings(text), - Index: &idx, + Index: i, }) } return diff --git a/internal/impl/openai/image_processor.go b/internal/impl/openai/image_processor.go index cba9ace730..892d0798d5 100644 --- a/internal/impl/openai/image_processor.go +++ b/internal/impl/openai/image_processor.go @@ -13,11 +13,10 @@ import ( "encoding/base64" "errors" "fmt" - "slices" - oai "github.com/Azure/azure-sdk-for-go/sdk/ai/azopenai" "github.com/redpanda-data/benthos/v4/public/bloblang" "github.com/redpanda-data/benthos/v4/public/service" + oai "github.com/sashabaranov/go-openai" ) const ( @@ -121,70 +120,56 @@ type moderationProcessor struct { } func (p *moderationProcessor) Process(ctx context.Context, msg *service.Message) (service.MessageBatch, error) { - var body oai.ImageGenerationOptions - body.DeploymentName = &p.model - format := oai.ImageGenerationResponseFormatBase64 - body.ResponseFormat = &format + var body oai.ImageRequest + body.Model = p.model + body.ResponseFormat = "b64_json" if p.input != nil { v, err := msg.BloblangQueryValue(p.input) if err != nil { return nil, fmt.Errorf("%s execution error: %w", oipFieldPrompt, err) } s := bloblang.ValueToString(v) - body.Prompt = &s + body.Prompt = s } else { b, err := msg.AsBytes() if err != nil { return nil, err } s := string(b) - body.Prompt = &s + body.Prompt = s } if p.quality != nil { r, err := p.quality.TryString(msg) if err != nil { return nil, fmt.Errorf("%s interpolation error: %w", oipFieldQuality, err) } - q := oai.ImageGenerationQuality(r) - if !slices.Contains(oai.PossibleImageGenerationQualityValues(), q) { - return nil, fmt.Errorf("invalid image quality: %q", q) - } - body.Quality = &q + body.Quality = r } if p.style != nil { r, err := p.style.TryString(msg) if err != nil { return nil, fmt.Errorf("%s interpolation error: %w", oipFieldStyle, err) } - s := oai.ImageGenerationStyle(r) - if !slices.Contains(oai.PossibleImageGenerationStyleValues(), s) { - return nil, fmt.Errorf("invalid image style: %q", s) - } - body.Style = &s + body.Style = r } if p.size != nil { r, err := p.size.TryString(msg) if err != nil { return nil, fmt.Errorf("%s interpolation error: %w", oipFieldSize, err) } - s := oai.ImageSize(r) - if !slices.Contains(oai.PossibleImageSizeValues(), s) { - return nil, fmt.Errorf("invalid image style: %q", s) - } - body.Size = &s + body.Size = r } - var opts oai.GetImageGenerationsOptions - resp, err := p.client.GetImageGenerations(ctx, body, &opts) + resp, err := p.client.CreateImage(ctx, body) if err != nil { return nil, err } if len(resp.Data) != 1 { return nil, fmt.Errorf("expected single generated image in response, got: %d", len(resp.Data)) } - if resp.Data[0].Base64Data == nil { + if resp.Data[0].B64JSON == "" { return nil, errors.New("missing generated image data in response") } - b, err := base64.StdEncoding.DecodeString(*resp.Data[0].Base64Data) + b, err := base64.StdEncoding.DecodeString(resp.Data[0].B64JSON) if err != nil { return nil, err } diff --git a/internal/impl/openai/json_schema_provider.go b/internal/impl/openai/json_schema_provider.go new file mode 100644 index 0000000000..7b6f45d21d --- /dev/null +++ b/internal/impl/openai/json_schema_provider.go @@ -0,0 +1,94 @@ +// Copyright 2024 Redpanda Data, Inc. +// +// Licensed as a Redpanda Enterprise file under the Redpanda Community +// License (the "License"); you may not use this file except in compliance with +// the License. You may obtain a copy of the License at +// +// https://github.com/redpanda-data/connect/blob/main/licenses/rcl.md + +package openai + +import ( + "context" + "encoding/json" + "fmt" + "sync" + "time" + + "github.com/redpanda-data/connect/v4/internal/impl/confluent/sr" + oai "github.com/sashabaranov/go-openai" + "github.com/sashabaranov/go-openai/jsonschema" +) + +type jsonSchemaProvider interface { + GetJSONSchema(context.Context) (*oai.ChatCompletionResponseFormatJSONSchema, error) +} + +type fixedSchemaProvider struct { + oai.ChatCompletionResponseFormatJSONSchema +} + +func (s *fixedSchemaProvider) GetJSONSchema(context.Context) (*oai.ChatCompletionResponseFormatJSONSchema, error) { + return &s.ChatCompletionResponseFormatJSONSchema, nil +} + +func newFixedSchema(name, description, raw string) (jsonSchemaProvider, error) { + p := &fixedSchemaProvider{} + p.Name = name + p.Description = description + if err := json.Unmarshal([]byte(raw), &p.Schema); err != nil { + return nil, fmt.Errorf("invalid JSON schema: %w", err) + } + p.Strict = true + return p, nil +} + +type dynamicSchemaProvider struct { + cached *oai.ChatCompletionResponseFormatJSONSchema + nextRefreshTime time.Time + refreshInterval time.Duration + mu sync.Mutex + + client *sr.Client + subject string + namePrefix string +} + +func (p *dynamicSchemaProvider) GetJSONSchema(ctx context.Context) (*oai.ChatCompletionResponseFormatJSONSchema, error) { + if time.Now().Before(p.nextRefreshTime) { + return p.cached, nil + } + p.mu.Lock() + defer p.mu.Unlock() + // Double check since we now have the lock that we didn't race with other requests + if time.Now().Before(p.nextRefreshTime) { + return p.cached, nil + } + info, err := p.client.GetSchemaBySubjectAndVersion(ctx, p.subject, nil) + if err != nil { + return nil, fmt.Errorf("unable to load latest schema for subject %q: %w", p.subject, err) + } + var schema jsonschema.Definition + if err := json.Unmarshal([]byte(info.Schema), &schema); err != nil { + return nil, fmt.Errorf("unable to parse json schema from schema with ID=%d", info.ID) + } + name := fmt.Sprintf("%s%d", p.namePrefix, info.ID) + p.cached = &oai.ChatCompletionResponseFormatJSONSchema{ + Name: name, + Schema: schema, + Strict: true, + } + p.nextRefreshTime = time.Now().Add(p.refreshInterval) + return p.cached, nil +} + +func newDynamicSchema(client *sr.Client, subject, namePrefix string, refreshInterval time.Duration) jsonSchemaProvider { + return &dynamicSchemaProvider{ + cached: nil, + nextRefreshTime: time.UnixMilli(0), + refreshInterval: refreshInterval, + client: client, + subject: subject, + namePrefix: namePrefix, + } +} diff --git a/internal/impl/openai/speech_processor.go b/internal/impl/openai/speech_processor.go index 4725dd92df..fa527ab7bd 100644 --- a/internal/impl/openai/speech_processor.go +++ b/internal/impl/openai/speech_processor.go @@ -12,11 +12,10 @@ import ( "context" "fmt" "io" - "slices" - oai "github.com/Azure/azure-sdk-for-go/sdk/ai/azopenai" "github.com/redpanda-data/benthos/v4/public/bloblang" "github.com/redpanda-data/benthos/v4/public/service" + oai "github.com/sashabaranov/go-openai" ) const ( @@ -101,50 +100,39 @@ type speechProcessor struct { } func (p *speechProcessor) Process(ctx context.Context, msg *service.Message) (service.MessageBatch, error) { - var body oai.SpeechGenerationOptions - body.DeploymentName = &p.model + var body oai.CreateSpeechRequest + body.Model = oai.SpeechModel(p.model) v, err := p.voice.TryString(msg) if err != nil { return nil, fmt.Errorf("%s interpolation error: %w", ospFieldVoice, err) } - voice := oai.SpeechVoice(v) - if !slices.Contains(oai.PossibleSpeechVoiceValues(), voice) { - return nil, fmt.Errorf("unknown speech voice value: %q", voice) - } - body.Voice = &voice + body.Voice = oai.SpeechVoice(v) if p.input != nil { v, err := msg.BloblangQueryValue(p.input) if err != nil { return nil, fmt.Errorf("%s execution error: %w", ospFieldInput, err) } - s := bloblang.ValueToString(v) - body.Input = &s + body.Input = bloblang.ValueToString(v) } else { b, err := msg.AsBytes() if err != nil { return nil, err } - s := string(b) - body.Input = &s + body.Input = string(b) } if p.responseFormat != nil { rf, err := p.responseFormat.TryString(msg) if err != nil { return nil, fmt.Errorf("%s interpolation error: %w", ospFieldResponseFormat, err) } - format := oai.SpeechGenerationResponseFormat(rf) - if !slices.Contains(oai.PossibleSpeechGenerationResponseFormatValues(), format) { - return nil, fmt.Errorf("unknown speech generation format value: %q", format) - } - body.ResponseFormat = &format + body.ResponseFormat = oai.SpeechResponseFormat(rf) } - var opts oai.GenerateSpeechFromTextOptions - resp, err := p.client.GenerateSpeechFromText(ctx, body, &opts) + resp, err := p.client.CreateSpeech(ctx, body) if err != nil { return nil, err } - defer resp.Body.Close() - b, err := io.ReadAll(resp.Body) + defer resp.Close() + b, err := io.ReadAll(resp) if err != nil { return nil, err } diff --git a/internal/impl/openai/transcription_processor.go b/internal/impl/openai/transcription_processor.go index 130a561bdd..725f2e8837 100644 --- a/internal/impl/openai/transcription_processor.go +++ b/internal/impl/openai/transcription_processor.go @@ -9,13 +9,13 @@ package openai import ( + "bytes" "context" - "errors" "fmt" - oai "github.com/Azure/azure-sdk-for-go/sdk/ai/azopenai" "github.com/redpanda-data/benthos/v4/public/bloblang" "github.com/redpanda-data/benthos/v4/public/service" + oai "github.com/sashabaranov/go-openai" ) const ( @@ -99,39 +99,36 @@ type transcriptionProcessor struct { } func (p *transcriptionProcessor) Process(ctx context.Context, msg *service.Message) (service.MessageBatch, error) { - var body oai.AudioTranscriptionOptions - body.DeploymentName = &p.model + var body oai.AudioRequest + body.Model = p.model f, err := msg.BloblangQueryValue(p.file) if err != nil { return nil, fmt.Errorf("%s execution error: %w", otspFieldFile, err) } - body.File, err = bloblang.ValueAsBytes(f) + b, err := bloblang.ValueAsBytes(f) if err != nil { return nil, err } + body.Reader = bytes.NewReader(b) if p.lang != nil { l, err := p.lang.TryString(msg) if err != nil { return nil, fmt.Errorf("%s interpolation error: %w", otspFieldLang, err) } - body.Language = &l + body.Language = l } if p.prompt != nil { pr, err := p.prompt.TryString(msg) if err != nil { return nil, fmt.Errorf("%s interpolation error: %w", otspFieldPrompt, err) } - body.Prompt = &pr + body.Prompt = pr } - var opts oai.GetAudioTranscriptionOptions - resp, err := p.client.GetAudioTranscription(ctx, body, &opts) + resp, err := p.client.CreateTranscription(ctx, body) if err != nil { return nil, err } - if resp.Text == nil { - return nil, errors.New("missing text in transcription response") - } msg = msg.Copy() - msg.SetBytes([]byte(*resp.Text)) + msg.SetBytes([]byte(resp.Text)) return service.MessageBatch{msg}, nil } diff --git a/internal/impl/openai/translation_processor.go b/internal/impl/openai/translation_processor.go index 78390d3674..e63ffe4571 100644 --- a/internal/impl/openai/translation_processor.go +++ b/internal/impl/openai/translation_processor.go @@ -9,13 +9,13 @@ package openai import ( + "bytes" "context" - "errors" "fmt" - oai "github.com/Azure/azure-sdk-for-go/sdk/ai/azopenai" "github.com/redpanda-data/benthos/v4/public/bloblang" "github.com/redpanda-data/benthos/v4/public/service" + oai "github.com/sashabaranov/go-openai" ) const ( @@ -89,40 +89,37 @@ type translationProcessor struct { } func (p *translationProcessor) Process(ctx context.Context, msg *service.Message) (service.MessageBatch, error) { - var body oai.AudioTranslationOptions - body.DeploymentName = &p.model + var body oai.AudioRequest + body.Model = p.model if p.file != nil { f, err := msg.BloblangQueryValue(p.file) if err != nil { return nil, fmt.Errorf("%s execution error: %w", otlpFieldFile, err) } - body.File, err = bloblang.ValueAsBytes(f) + b, err := bloblang.ValueAsBytes(f) if err != nil { return nil, fmt.Errorf("%s conversion error: %w", otlpFieldFile, err) } + body.Reader = bytes.NewReader(b) } else { f, err := msg.AsBytes() if err != nil { return nil, err } - body.File = f + body.Reader = bytes.NewReader(f) } if p.prompt != nil { pr, err := p.prompt.TryString(msg) if err != nil { return nil, fmt.Errorf("%s interpolation error: %w", otlpFieldPrompt, err) } - body.Prompt = &pr + body.Prompt = pr } - var opts oai.GetAudioTranslationOptions - resp, err := p.client.GetAudioTranslation(ctx, body, &opts) + resp, err := p.client.CreateTranslation(ctx, body) if err != nil { return nil, err } - if resp.Text == nil { - return nil, errors.New("missing text in translation response") - } msg = msg.Copy() - msg.SetBytes([]byte(*resp.Text)) + msg.SetBytes([]byte(resp.Text)) return service.MessageBatch{msg}, nil }