diff --git a/docs/modules/components/pages/processors/openai_chat_completion.adoc b/docs/modules/components/pages/processors/openai_chat_completion.adoc index aa9983c986..cd28e97e96 100644 --- a/docs/modules/components/pages/processors/openai_chat_completion.adoc +++ b/docs/modules/components/pages/processors/openai_chat_completion.adoc @@ -46,6 +46,10 @@ openai_chat_completion: max_tokens: 0 # No default (optional) temperature: 0 # No default (optional) user: "" # No default (optional) + response_format: text + json_schema: + name: "" # No default (required) + schema: "" # No default (required) ``` -- @@ -65,6 +69,38 @@ openai_chat_completion: max_tokens: 0 # No default (optional) temperature: 0 # No default (optional) user: "" # No default (optional) + response_format: text + json_schema: + name: "" # No default (required) + description: "" # No default (optional) + schema: "" # No default (required) + schema_registry: + url: "" # No default (required) + name_prefix: schema_registry_id_ + subject: "" # No default (required) + refresh_interval: "" # No default (optional) + tls: + skip_cert_verify: false + enable_renegotiation: false + root_cas: "" + root_cas_file: "" + client_certs: [] + oauth: + enabled: false + consumer_key: "" + consumer_secret: "" + access_token: "" + access_token_secret: "" + basic_auth: + enabled: false + username: "" + password: "" + jwt: + enabled: false + private_key_file: "" + signing_method: "" + claims: {} + headers: {} top_p: 0 # No default (optional) frequency_penalty: 0 # No default (optional) presence_penalty: 0 # No default (optional) @@ -167,6 +203,401 @@ This field supports xref:configuration:interpolation.adoc#bloblang-queries[inter *Type*: `string` +=== `response_format` + +Specify the model's output format. If `json_schema` is specified, then additionally a `json_schema` or `schema_registry` must be configured. + + +*Type*: `string` + +*Default*: `"text"` + +Options: +`text` +, `json` +, `json_schema` +. + +=== `json_schema` + +The JSON schema to use when responding in `json_schema` format. To learn more about what JSON schema is supported see the https://platform.openai.com/docs/guides/structured-outputs/supported-schemas[OpenAI documentation^]. + + +*Type*: `object` + + +=== `json_schema.name` + +The name of the schema. + + +*Type*: `string` + + +=== `json_schema.description` + +Additional description of the schema for the LLM. + + +*Type*: `string` + + +=== `json_schema.schema` + +The JSON schema for the LLM to use when generating the output. + + +*Type*: `string` + + +=== `schema_registry` + +The schema registry to dynamically load schemas from when responding in `json_schema` format. Schemas themselves must be in JSON format. To learn more about what JSON schema is supported see the https://platform.openai.com/docs/guides/structured-outputs/supported-schemas[OpenAI documentation^]. + + +*Type*: `object` + + +=== `schema_registry.url` + +The base URL of the schema registry service. + + +*Type*: `string` + + +=== `schema_registry.name_prefix` + +The prefix of the name for this schema, the schema ID is used as a suffix. + + +*Type*: `string` + +*Default*: `"schema_registry_id_"` + +=== `schema_registry.subject` + +The subject name to fetch the schema for. + + +*Type*: `string` + + +=== `schema_registry.refresh_interval` + +The refresh rate for getting the latest schema. If not specified the schema does not refresh. + + +*Type*: `string` + + +=== `schema_registry.tls` + +Custom TLS settings can be used to override system defaults. + + +*Type*: `object` + + +=== `schema_registry.tls.skip_cert_verify` + +Whether to skip server side certificate verification. + + +*Type*: `bool` + +*Default*: `false` + +=== `schema_registry.tls.enable_renegotiation` + +Whether to allow the remote server to repeatedly request renegotiation. Enable this option if you're seeing the error message `local error: tls: no renegotiation`. + + +*Type*: `bool` + +*Default*: `false` +Requires version 3.45.0 or newer + +=== `schema_registry.tls.root_cas` + +An optional root certificate authority to use. This is a string, representing a certificate chain from the parent trusted root certificate, to possible intermediate signing certificates, to the host certificate. +[CAUTION] +==== +This field contains sensitive information that usually shouldn't be added to a config directly, read our xref:configuration:secrets.adoc[secrets page for more info]. +==== + + + +*Type*: `string` + +*Default*: `""` + +```yml +# Examples + +root_cas: |- + -----BEGIN CERTIFICATE----- + ... + -----END CERTIFICATE----- +``` + +=== `schema_registry.tls.root_cas_file` + +An optional path of a root certificate authority file to use. This is a file, often with a .pem extension, containing a certificate chain from the parent trusted root certificate, to possible intermediate signing certificates, to the host certificate. + + +*Type*: `string` + +*Default*: `""` + +```yml +# Examples + +root_cas_file: ./root_cas.pem +``` + +=== `schema_registry.tls.client_certs` + +A list of client certificates to use. For each certificate either the fields `cert` and `key`, or `cert_file` and `key_file` should be specified, but not both. + + +*Type*: `array` + +*Default*: `[]` + +```yml +# Examples + +client_certs: + - cert: foo + key: bar + +client_certs: + - cert_file: ./example.pem + key_file: ./example.key +``` + +=== `schema_registry.tls.client_certs[].cert` + +A plain text certificate to use. + + +*Type*: `string` + +*Default*: `""` + +=== `schema_registry.tls.client_certs[].key` + +A plain text certificate key to use. +[CAUTION] +==== +This field contains sensitive information that usually shouldn't be added to a config directly, read our xref:configuration:secrets.adoc[secrets page for more info]. +==== + + + +*Type*: `string` + +*Default*: `""` + +=== `schema_registry.tls.client_certs[].cert_file` + +The path of a certificate to use. + + +*Type*: `string` + +*Default*: `""` + +=== `schema_registry.tls.client_certs[].key_file` + +The path of a certificate key to use. + + +*Type*: `string` + +*Default*: `""` + +=== `schema_registry.tls.client_certs[].password` + +A plain text password for when the private key is password encrypted in PKCS#1 or PKCS#8 format. The obsolete `pbeWithMD5AndDES-CBC` algorithm is not supported for the PKCS#8 format. + +Because the obsolete pbeWithMD5AndDES-CBC algorithm does not authenticate the ciphertext, it is vulnerable to padding oracle attacks that can let an attacker recover the plaintext. +[CAUTION] +==== +This field contains sensitive information that usually shouldn't be added to a config directly, read our xref:configuration:secrets.adoc[secrets page for more info]. +==== + + + +*Type*: `string` + +*Default*: `""` + +```yml +# Examples + +password: foo + +password: ${KEY_PASSWORD} +``` + +=== `schema_registry.oauth` + +Allows you to specify open authentication via OAuth version 1. + + +*Type*: `object` + + +=== `schema_registry.oauth.enabled` + +Whether to use OAuth version 1 in requests. + + +*Type*: `bool` + +*Default*: `false` + +=== `schema_registry.oauth.consumer_key` + +A value used to identify the client to the service provider. + + +*Type*: `string` + +*Default*: `""` + +=== `schema_registry.oauth.consumer_secret` + +A secret used to establish ownership of the consumer key. +[CAUTION] +==== +This field contains sensitive information that usually shouldn't be added to a config directly, read our xref:configuration:secrets.adoc[secrets page for more info]. +==== + + + +*Type*: `string` + +*Default*: `""` + +=== `schema_registry.oauth.access_token` + +A value used to gain access to the protected resources on behalf of the user. + + +*Type*: `string` + +*Default*: `""` + +=== `schema_registry.oauth.access_token_secret` + +A secret provided in order to establish ownership of a given access token. +[CAUTION] +==== +This field contains sensitive information that usually shouldn't be added to a config directly, read our xref:configuration:secrets.adoc[secrets page for more info]. +==== + + + +*Type*: `string` + +*Default*: `""` + +=== `schema_registry.basic_auth` + +Allows you to specify basic authentication. + + +*Type*: `object` + + +=== `schema_registry.basic_auth.enabled` + +Whether to use basic authentication in requests. + + +*Type*: `bool` + +*Default*: `false` + +=== `schema_registry.basic_auth.username` + +A username to authenticate as. + + +*Type*: `string` + +*Default*: `""` + +=== `schema_registry.basic_auth.password` + +A password to authenticate with. +[CAUTION] +==== +This field contains sensitive information that usually shouldn't be added to a config directly, read our xref:configuration:secrets.adoc[secrets page for more info]. +==== + + + +*Type*: `string` + +*Default*: `""` + +=== `schema_registry.jwt` + +BETA: Allows you to specify JWT authentication. + + +*Type*: `object` + + +=== `schema_registry.jwt.enabled` + +Whether to use JWT authentication in requests. + + +*Type*: `bool` + +*Default*: `false` + +=== `schema_registry.jwt.private_key_file` + +A file with the PEM encoded via PKCS1 or PKCS8 as private key. + + +*Type*: `string` + +*Default*: `""` + +=== `schema_registry.jwt.signing_method` + +A method used to sign the token such as RS256, RS384, RS512 or EdDSA. + + +*Type*: `string` + +*Default*: `""` + +=== `schema_registry.jwt.claims` + +A value used to identify the claims that issued the JWT. + + +*Type*: `object` + +*Default*: `{}` + +=== `schema_registry.jwt.headers` + +Add optional key/value headers to the JWT. + + +*Type*: `object` + +*Default*: `{}` + === `top_p` An alternative to sampling with temperature, called nucleus sampling, where the model considers the results of the tokens with top_p probability mass. So 0.1 means only the tokens comprising the top 10% probability mass are considered. diff --git a/internal/impl/openai/chat_processor.go b/internal/impl/openai/chat_processor.go index 9e734aed6d..7a6469d9a1 100644 --- a/internal/impl/openai/chat_processor.go +++ b/internal/impl/openai/chat_processor.go @@ -11,9 +11,13 @@ package openai import ( "context" "fmt" + "math" + "slices" + "time" "github.com/redpanda-data/benthos/v4/public/bloblang" "github.com/redpanda-data/benthos/v4/public/service" + "github.com/redpanda-data/connect/v4/internal/impl/confluent/sr" oai "github.com/sashabaranov/go-openai" ) @@ -28,6 +32,19 @@ const ( ocpFieldStop = "stop" ocpFieldPresencePenalty = "presence_penalty" ocpFieldFrequencyPenalty = "frequency_penalty" + ocpFieldResponseFormat = "response_format" + // JSON schema fields + ocpFieldJSONSchema = "json_schema" + ocpFieldJSONSchemaName = "name" + ocpFieldJSONSchemaDesc = "description" + ocpFieldJSONSchemaSchema = "schema" + // Schema registry fields + ocpFieldSchemaRegistry = "schema_registry" + ocpFieldSchemaRegistrySubject = "subject" + ocpFieldSchemaRegistryRefreshInterval = "refresh_interval" + ocpFieldSchemaRegistryNamePrefix = "name_prefix" + ocpFieldSchemaRegistryURL = "url" + ocpFieldSchemaRegistryTLS = "tls" ) func init() { @@ -77,6 +94,37 @@ We generally recommend altering this or top_p but not both.`). service.NewInterpolatedStringField(ocpFieldUser). Optional(). Description("A unique identifier representing your end-user, which can help OpenAI to monitor and detect abuse."), + service.NewStringEnumField(ocpFieldResponseFormat, "text", "json", "json_schema"). + Default("text"). + Description("Specify the model's output format. If `json_schema` is specified, then additionally a `json_schema` or `schema_registry` must be configured."), + service.NewObjectField(ocpFieldJSONSchema, + service.NewStringField(ocpFieldJSONSchemaName).Description("The name of the schema."), + service.NewStringField(ocpFieldJSONSchemaDesc).Optional().Advanced().Description("Additional description of the schema for the LLM."), + service.NewStringField(ocpFieldJSONSchemaSchema).Description("The JSON schema for the LLM to use when generating the output."), + ). + Optional(). + Description("The JSON schema to use when responding in `json_schema` format. To learn more about what JSON schema is supported see the https://platform.openai.com/docs/guides/structured-outputs/supported-schemas[OpenAI documentation^]."), + service.NewObjectField( + ocpFieldSchemaRegistry, + slices.Concat( + []*service.ConfigField{ + service.NewURLField(ocpFieldSchemaRegistryURL).Description("The base URL of the schema registry service."), + service.NewStringField(ocpFieldSchemaRegistryNamePrefix). + Default("schema_registry_id_"). + Description("The prefix of the name for this schema, the schema ID is used as a suffix."), + service.NewStringField(ocpFieldSchemaRegistrySubject). + Description("The subject name to fetch the schema for."), + service.NewDurationField(ocpFieldSchemaRegistryRefreshInterval). + Optional(). + Description("The refresh rate for getting the latest schema. If not specified the schema does not refresh."), + service.NewTLSField(ocpFieldSchemaRegistryTLS), + }, + service.NewHTTPRequestAuthSignerFields(), + )..., + ). + Description("The schema registry to dynamically load schemas from when responding in `json_schema` format. Schemas themselves must be in JSON format. To learn more about what JSON schema is supported see the https://platform.openai.com/docs/guides/structured-outputs/supported-schemas[OpenAI documentation^]."). + Optional(). + Advanced(), service.NewFloatField(ocpFieldTopP). Optional(). Advanced(). @@ -102,7 +150,12 @@ We generally recommend altering this or temperature but not both.`). Optional(). Advanced(). Description("Up to 4 sequences where the API will stop generating further tokens."), - ) + ).LintRule(` + root = match { + this.exists("` + ocpFieldJSONSchema + `") && this.exists("` + ocpFieldSchemaRegistry + `") => ["cannot set both ` + "`" + ocpFieldJSONSchema + "`" + ` and ` + "`" + ocpFieldSchemaRegistry + "`" + `"] + this.response_format == "json_schema" && !this.exists("` + ocpFieldJSONSchema + `") && !this.exists("` + ocpFieldSchemaRegistry + `") => ["schema must be specified using either ` + "`" + ocpFieldJSONSchema + "`" + ` or ` + "`" + ocpFieldSchemaRegistry + "`" + `"] + } + `) } func makeChatProcessor(conf *service.ParsedConfig, mgr *service.Resources) (service.Processor, error) { @@ -190,7 +243,92 @@ func makeChatProcessor(conf *service.ParsedConfig, mgr *service.Resources) (serv return nil, err } } - return &chatProcessor{b, up, sp, maxTokens, temp, user, topP, frequencyPenalty, presencePenalty, seed, stop}, nil + v, err := conf.FieldString(ocpFieldResponseFormat) + if err != nil { + return nil, err + } + var responseFormat oai.ChatCompletionResponseFormatType + var schemaProvider jsonSchemaProvider + switch v { + case "json": + fallthrough + case "json_object": + responseFormat = oai.ChatCompletionResponseFormatTypeJSONObject + case "json_schema": + responseFormat = oai.ChatCompletionResponseFormatTypeJSONSchema + if conf.Contains(ocpFieldJSONSchema) { + schemaProvider, err = newFixedSchemaProvider(conf.Namespace(ocpFieldJSONSchema)) + if err != nil { + return nil, err + } + } else if conf.Contains(ocpFieldSchemaRegistry) { + schemaProvider, err = newDynamicSchemaProvider(conf.Namespace(ocpFieldSchemaRegistry), mgr) + if err != nil { + return nil, err + } + } else { + return nil, fmt.Errorf("using %s %q, but did not specify %s or %s", ocpFieldResponseFormat, v, ocpFieldJSONSchema, ocpFieldSchemaRegistry) + } + case "text": + responseFormat = oai.ChatCompletionResponseFormatTypeText + default: + return nil, fmt.Errorf("unknown %s: %q", ocpFieldResponseFormat, v) + } + return &chatProcessor{b, up, sp, maxTokens, temp, user, topP, frequencyPenalty, presencePenalty, seed, stop, responseFormat, schemaProvider}, nil +} + +func newFixedSchemaProvider(conf *service.ParsedConfig) (jsonSchemaProvider, error) { + name, err := conf.FieldString(ocpFieldJSONSchemaName) + if err != nil { + return nil, err + } + description := "" + if conf.Contains(ocpFieldJSONSchemaDesc) { + description, err = conf.FieldString(ocpFieldJSONSchemaDesc) + if err != nil { + return nil, err + } + } + schema, err := conf.FieldString(ocpFieldJSONSchemaSchema) + if err != nil { + return nil, err + } + return newFixedSchema(name, description, schema) +} + +func newDynamicSchemaProvider(conf *service.ParsedConfig, mgr *service.Resources) (jsonSchemaProvider, error) { + url, err := conf.FieldString(ocpFieldSchemaRegistryURL) + if err != nil { + return nil, err + } + reqSigner, err := conf.HTTPRequestAuthSignerFromParsed() + if err != nil { + return nil, err + } + tlsConfig, err := conf.FieldTLS(ocpFieldSchemaRegistryTLS) + if err != nil { + return nil, err + } + client, err := sr.NewClient(url, reqSigner, tlsConfig, mgr) + if err != nil { + return nil, fmt.Errorf("unable to create schema registry client: %w", err) + } + subject, err := conf.FieldString(ocpFieldSchemaRegistrySubject) + if err != nil { + return nil, err + } + var refreshInterval time.Duration = math.MaxInt64 + if conf.Contains(ocpFieldSchemaRegistryRefreshInterval) { + refreshInterval, err = conf.FieldDuration(ocpFieldSchemaRegistryRefreshInterval) + if err != nil { + return nil, err + } + } + namePrefix, err := conf.FieldString(ocpFieldSchemaRegistryNamePrefix) + if err != nil { + return nil, err + } + return newDynamicSchema(client, subject, namePrefix, refreshInterval), nil } type chatProcessor struct { @@ -206,6 +344,8 @@ type chatProcessor struct { presencePenalty *float32 seed *int stop []string + responseFormat oai.ChatCompletionResponseFormatType + schemaProvider jsonSchemaProvider } func (p *chatProcessor) Process(ctx context.Context, msg *service.Message) (service.MessageBatch, error) { @@ -227,6 +367,16 @@ func (p *chatProcessor) Process(ctx context.Context, msg *service.Message) (serv if p.presencePenalty != nil { body.PresencePenalty = *p.presencePenalty } + if p.responseFormat != oai.ChatCompletionResponseFormatTypeText { + body.ResponseFormat = &oai.ChatCompletionResponseFormat{Type: p.responseFormat} + if p.schemaProvider != nil { + s, err := p.schemaProvider.GetJSONSchema(ctx) + if err != nil { + return nil, err + } + body.ResponseFormat.JSONSchema = s + } + } body.Stop = p.stop if p.user != nil { u, err := p.user.TryString(msg) diff --git a/internal/impl/openai/json_schema_provider.go b/internal/impl/openai/json_schema_provider.go new file mode 100644 index 0000000000..7b6f45d21d --- /dev/null +++ b/internal/impl/openai/json_schema_provider.go @@ -0,0 +1,94 @@ +// Copyright 2024 Redpanda Data, Inc. +// +// Licensed as a Redpanda Enterprise file under the Redpanda Community +// License (the "License"); you may not use this file except in compliance with +// the License. You may obtain a copy of the License at +// +// https://github.com/redpanda-data/connect/blob/main/licenses/rcl.md + +package openai + +import ( + "context" + "encoding/json" + "fmt" + "sync" + "time" + + "github.com/redpanda-data/connect/v4/internal/impl/confluent/sr" + oai "github.com/sashabaranov/go-openai" + "github.com/sashabaranov/go-openai/jsonschema" +) + +type jsonSchemaProvider interface { + GetJSONSchema(context.Context) (*oai.ChatCompletionResponseFormatJSONSchema, error) +} + +type fixedSchemaProvider struct { + oai.ChatCompletionResponseFormatJSONSchema +} + +func (s *fixedSchemaProvider) GetJSONSchema(context.Context) (*oai.ChatCompletionResponseFormatJSONSchema, error) { + return &s.ChatCompletionResponseFormatJSONSchema, nil +} + +func newFixedSchema(name, description, raw string) (jsonSchemaProvider, error) { + p := &fixedSchemaProvider{} + p.Name = name + p.Description = description + if err := json.Unmarshal([]byte(raw), &p.Schema); err != nil { + return nil, fmt.Errorf("invalid JSON schema: %w", err) + } + p.Strict = true + return p, nil +} + +type dynamicSchemaProvider struct { + cached *oai.ChatCompletionResponseFormatJSONSchema + nextRefreshTime time.Time + refreshInterval time.Duration + mu sync.Mutex + + client *sr.Client + subject string + namePrefix string +} + +func (p *dynamicSchemaProvider) GetJSONSchema(ctx context.Context) (*oai.ChatCompletionResponseFormatJSONSchema, error) { + if time.Now().Before(p.nextRefreshTime) { + return p.cached, nil + } + p.mu.Lock() + defer p.mu.Unlock() + // Double check since we now have the lock that we didn't race with other requests + if time.Now().Before(p.nextRefreshTime) { + return p.cached, nil + } + info, err := p.client.GetSchemaBySubjectAndVersion(ctx, p.subject, nil) + if err != nil { + return nil, fmt.Errorf("unable to load latest schema for subject %q: %w", p.subject, err) + } + var schema jsonschema.Definition + if err := json.Unmarshal([]byte(info.Schema), &schema); err != nil { + return nil, fmt.Errorf("unable to parse json schema from schema with ID=%d", info.ID) + } + name := fmt.Sprintf("%s%d", p.namePrefix, info.ID) + p.cached = &oai.ChatCompletionResponseFormatJSONSchema{ + Name: name, + Schema: schema, + Strict: true, + } + p.nextRefreshTime = time.Now().Add(p.refreshInterval) + return p.cached, nil +} + +func newDynamicSchema(client *sr.Client, subject, namePrefix string, refreshInterval time.Duration) jsonSchemaProvider { + return &dynamicSchemaProvider{ + cached: nil, + nextRefreshTime: time.UnixMilli(0), + refreshInterval: refreshInterval, + client: client, + subject: subject, + namePrefix: namePrefix, + } +}