diff --git a/Cargo.toml b/Cargo.toml index 02e32d7..b1f62f6 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -18,6 +18,7 @@ tokio = { version = "1.26.0", features = ["full"] } anyhow = "1.0.70" futures-util = "0.3.28" bytes = "1.4.0" +schemars = "0.8.22" [dev-dependencies] dotenvy = "0.15.7" diff --git a/src/chat.rs b/src/chat.rs index 9b366b6..19a3402 100644 --- a/src/chat.rs +++ b/src/chat.rs @@ -1,4 +1,5 @@ //! Given a chat conversation, the model will return a chat completion response. +pub mod structured_output; use super::{openai_post, ApiResponseOrError, Credentials, Usage}; use crate::openai_request_stream; @@ -6,9 +7,13 @@ use derive_builder::Builder; use futures_util::StreamExt; use reqwest::Method; use reqwest_eventsource::{CannotCloneRequestError, Event, EventSource}; +use schemars::JsonSchema; use serde::{Deserialize, Serialize}; use serde_json::Value; use std::collections::HashMap; +use structured_output::{ + ChatCompletionResponseFormatJsonSchema, JsonSchemaStyle, ToolCallFunctionDefinition, +}; use tokio::sync::mpsc::{channel, Receiver, Sender}; /// A full chat completion. @@ -60,6 +65,7 @@ pub struct ChatCompletionMessage { /// The function that ChatGPT called. This should be "None" usually, and is returned by ChatGPT and not provided by the developer /// /// [API Reference](https://platform.openai.com/docs/api-reference/chat/create#chat/create-function_call) + #[deprecated(note = "Use `tool_calls` instead")] #[serde(skip_serializing_if = "Option::is_none")] pub function_call: Option, /// Tool call that this message is responding to. @@ -86,6 +92,7 @@ pub struct ChatCompletionMessageDelta { /// The function that ChatGPT called /// /// [API Reference](https://platform.openai.com/docs/api-reference/chat/create#chat/create-function_call) + #[deprecated(note = "Use `tool_calls` instead")] #[serde(skip_serializing_if = "Option::is_none")] pub function_call: Option, /// Tool call that this message is responding to. 
@@ -96,7 +103,69 @@ pub struct ChatCompletionMessageDelta { /// Can only be populated if the role is `Assistant`, /// otherwise it should be empty. #[serde(skip_serializing_if = "is_none_or_empty_vec")] - pub tool_calls: Option<Vec<ToolCall>>, + pub tool_calls: Option<Vec<ToolCallDelta>>, +} + +#[derive(Deserialize, Serialize, Clone, Debug, Eq, PartialEq)] +#[serde(tag = "type", rename_all = "snake_case")] +pub enum ChatCompletionTool { + Function { + function: ToolCallFunctionDefinition, + }, +} + +impl ChatCompletionTool { + pub fn new<T: JsonSchema>(strict: Option<bool>) -> Self { + let function = ToolCallFunctionDefinition::new::<T>(strict); + ChatCompletionTool::Function { function } + } +} + +#[derive(Deserialize, Serialize, Clone, Debug, Eq, PartialEq)] +pub enum ToolChoiceMode { + None, + Auto, + Required, +} + +#[derive(Deserialize, Serialize, Clone, Debug, Eq, PartialEq)] +pub struct FunctionChoice { + /// The name of the function to call. + pub name: String, +} + +#[derive(Serialize, Deserialize, Clone, Debug, Eq, PartialEq)] +#[serde(rename_all = "snake_case")] +pub enum FunctionType { + Function, +} + +#[derive(Deserialize, Serialize, Clone, Debug, Eq, PartialEq)] +#[serde(untagged)] +pub enum ToolChoice { + /// `none` means the model will not call any tool and instead generates a message. + /// `auto` means the model can pick between generating a message or calling one or more tools. + /// `required` means the model must call one or more tools. + Mode(ToolChoiceMode), + /// The model will call the function with the given name. + Function { + /// The type of the tool. Currently, only `function` is supported. + r#type: FunctionType, + /// The function that the model called.
+ function: FunctionChoice, + }, +} + +impl ToolChoice { + pub fn mode(mode: ToolChoiceMode) -> Self { + ToolChoice::Mode(mode) + } + pub fn function(name: String) -> Self { + ToolChoice::Function { + r#type: FunctionType::Function, + function: FunctionChoice { name }, + } + } } #[derive(Deserialize, Serialize, Clone, Debug, Eq, PartialEq)] @@ -104,11 +173,22 @@ pub struct ToolCall { /// The ID of the tool call. pub id: String, /// The type of the tool. Currently, only `function` is supported. - pub r#type: String, + pub r#type: FunctionType, /// The function that the model called. pub function: ToolCallFunction, } +#[derive(Deserialize, Serialize, Clone, Debug, Eq, PartialEq)] +pub struct ToolCallDelta { + pub index: i64, + /// The ID of the tool call. + pub id: Option, + /// The type of the tool. Currently, only `function` is supported. + pub r#type: Option, + /// The function that the model called. + pub function: Option, +} + #[derive(Deserialize, Serialize, Clone, Debug, Eq, PartialEq)] pub struct ToolCallFunction { /// The name of the function to call. @@ -165,6 +245,14 @@ pub enum ChatCompletionMessageRole { Developer, } +#[derive(Deserialize, Serialize, Debug, Clone, Copy, Eq, PartialEq)] +#[serde(rename_all = "lowercase")] +pub enum ChatCompletionReasoningEffort { + Low, + Medium, + High, +} + #[derive(Serialize, Builder, Debug, Clone)] #[builder(derive(Clone, Debug, PartialEq))] #[builder(pattern = "owned")] @@ -176,6 +264,12 @@ pub struct ChatCompletionRequest { model: String, /// The messages to generate chat completions for, in the [chat format](https://platform.openai.com/docs/guides/chat/introduction). messages: Vec, + /// Constrains effort on reasoning for (reasoning models)[https://platform.openai.com/docs/guides/reasoning]. + /// Currently supported values are low, medium, and high (Defaults to medium). + /// Reducing reasoning effort can result in faster responses and fewer tokens used on reasoning in a response. 
+ #[builder(default)] + #[serde(skip_serializing_if = "Option::is_none")] + reasoning_effort: Option, /// What sampling temperature to use, between 0 and 2. Higher values like 0.8 will make the output more random, while lower values like 0.2 will make it more focused and deterministic. /// /// We generally recommend altering this or `top_p` but not both. @@ -204,6 +298,7 @@ pub struct ChatCompletionRequest { #[serde(skip_serializing_if = "Option::is_none")] seed: Option, /// The maximum number of tokens allowed for the generated answer. By default, the number of tokens the model can return will be (4096 - prompt tokens). + #[deprecated(note = "Use max_completion_tokens instead")] #[builder(default)] #[serde(skip_serializing_if = "Option::is_none")] max_tokens: Option, @@ -234,12 +329,32 @@ pub struct ChatCompletionRequest { #[builder(default)] #[serde(skip_serializing_if = "String::is_empty")] user: String, + /// A list of tools the model may call. Currently, only functions are supported as a tool. Use this to provide a list of functions the model may generate JSON inputs for. A max of 128 functions are supported. + #[builder(default)] + #[serde(skip_serializing_if = "Vec::is_empty")] + tools: Vec, + /// Controls which (if any) tool is called by the model. + /// `none` means the model will not call any tool and instead generates a message. + /// `auto` means the model can pick between generating a message or calling one or more tools. + /// `required` means the model must call one or more tools. + /// Specifying a particular tool via `{"type": "function", "function": {"name": "my_function"}}` forces the model to call that tool. + /// + /// `none` is the default when no tools are present. `auto` is the default if tools are present. + #[builder(default)] + #[serde(skip_serializing_if = "Option::is_none")] + tool_choice: Option, + /// Whether to enable parallel function calling during tool use. + /// Defaults to true. 
+ #[builder(default)] + #[serde(skip_serializing_if = "Option::is_none")] + parallel_tool_calls: Option, /// Describe functions that ChatGPT can call /// The latest models of ChatGPT support function calling, which allows you to define functions that can be called from the prompt. /// For example, you can define a function called "get_weather" that returns the weather in a given city /// /// [Function calling API Reference](https://platform.openai.com/docs/api-reference/chat/create#chat/create-functions) /// [See more information about function calling in ChatGPT.](https://platform.openai.com/docs/guides/gpt/function-calling) + #[deprecated(note = "Use tools instead")] #[builder(default)] #[serde(skip_serializing_if = "Vec::is_empty")] functions: Vec, @@ -252,6 +367,7 @@ pub struct ChatCompletionRequest { /// - Specifying a particular function via {"name":\ "my_function"} forces the model to call that function. /// /// "none" is the default when no functions are present. "auto" is the default if functions are present. 
+ #[deprecated(note = "Use tool_choice instead")] #[builder(default)] #[serde(skip_serializing_if = "Option::is_none")] function_call: Option, @@ -278,23 +394,25 @@ pub struct VeniceParameters { } #[derive(Serialize, Debug, Clone, Eq, PartialEq)] -pub struct ChatCompletionResponseFormat { - /// Must be one of text or json_object (defaults to text) - #[serde(rename = "type")] - typ: String, +#[serde(tag = "type", rename_all = "snake_case")] +pub enum ChatCompletionResponseFormat { + Text, + JsonObject, + JsonSchema { + json_schema: ChatCompletionResponseFormatJsonSchema, + }, } impl ChatCompletionResponseFormat { + pub fn text() -> Self { + ChatCompletionResponseFormat::Text + } pub fn json_object() -> Self { - ChatCompletionResponseFormat { - typ: "json_object".to_string(), - } + ChatCompletionResponseFormat::JsonObject } - - pub fn text() -> Self { - ChatCompletionResponseFormat { - typ: "text".to_string(), - } + pub fn json_schema(strict: bool, json_style: JsonSchemaStyle) -> Self { + let json_schema = ChatCompletionResponseFormatJsonSchema::new::(strict, json_style); + ChatCompletionResponseFormat::JsonSchema { json_schema } } } @@ -398,6 +516,7 @@ impl ChatCompletionChoiceDelta { // merge function calls // function call names are concatenated // arguments are merged by concatenating them + #[allow(deprecated)] match self.delta.function_call.as_mut() { Some(function_call) => { match &other.delta.function_call { @@ -432,6 +551,7 @@ impl ChatCompletionChoiceDelta { impl From for ChatCompletion { fn from(delta: ChatCompletionDelta) -> Self { + #[allow(deprecated)] ChatCompletion { id: delta.id, object: delta.object, @@ -539,6 +659,7 @@ impl Default for ChatCompletionMessageRole { } #[cfg(test)] +#[allow(deprecated)] mod tests { use super::*; use dotenvy::dotenv; @@ -560,7 +681,7 @@ mod tests { }], ) .temperature(0.0) - .response_format(ChatCompletionResponseFormat::text()) + .response_format(ChatCompletionResponseFormat::Text) .credentials(credentials) .create() 
.await @@ -746,7 +867,7 @@ mod tests { ) .temperature(0.0) .seed(1337u64) - .response_format(ChatCompletionResponseFormat::json_object()) + .response_format(ChatCompletionResponseFormat::JsonObject) .credentials(credentials) .create() .await @@ -804,6 +925,90 @@ mod tests { merged.unwrap().into() } + #[derive(JsonSchema, Deserialize, Debug, Eq, PartialEq)] + enum Race { + Black, + White, + Asian, + Other(String), + } + #[derive(JsonSchema, Deserialize, Debug, PartialEq)] + enum Species { + Human(Race), + Orc { color: String, leader: String }, + } + #[derive(JsonSchema, Deserialize, Debug, PartialEq)] + struct Character { + pub name: String, + pub age: i64, + pub power: f64, + pub skills: Vec<Skill>, + pub species: Species, + } + #[derive(JsonSchema, Deserialize, Debug, PartialEq)] + struct Skill { + pub name: String, + pub description: Option<String>, + pub dont_use_this_property: Option<String>, + } + + #[tokio::test] + async fn chat_structured_output_completion() { + dotenv().ok(); + let credentials = Credentials::from_env(); + + let format = + ChatCompletionResponseFormat::json_schema::<Character>(true, JsonSchemaStyle::OpenAI); + let chat_completion = ChatCompletion::builder( + "gpt-4o-mini", + [ChatCompletionMessage { + role: ChatCompletionMessageRole::User, + content: Some( + "Create a DND character, don't use the dont_use_this_property field" + .to_string(), + ), + ..Default::default() + }], + ) + .credentials(credentials) + .response_format(format) + .create() + .await + .unwrap(); + let character_str = chat_completion.choices[0].message.content.as_ref().unwrap(); + let _character: Character = serde_json::from_str(character_str).unwrap(); + } + + #[tokio::test] + async fn chat_tool_use_completion() { + dotenv().ok(); + let credentials = Credentials::from_env(); + let schema = ChatCompletionTool::new::<Character>(None); + let chat_completion = ChatCompletion::builder( + "gpt-4o-mini", + [ChatCompletionMessage { + role: ChatCompletionMessageRole::User, + content: Some("create a random DND character
directly with tools".to_string()), + ..Default::default() + }], + ) + .credentials(credentials) + .tools(vec![schema]) + .tool_choice(ToolChoice::Function { + r#type: FunctionType::Function, + function: FunctionChoice { + name: "Character".to_string(), + }, + }) + .create() + .await + .unwrap(); + let msg = chat_completion.choices[0].message.clone(); + let tool_calls = msg.tool_calls.as_ref(); + let tool_call: &ToolCall = tool_calls.unwrap().first().unwrap(); + let _character: Character = serde_json::from_str(&tool_call.function.arguments).unwrap(); + } + #[tokio::test] async fn chat_tool_response_completion() { dotenv().ok(); @@ -835,7 +1040,7 @@ mod tests { tool_call_id: None, tool_calls: Some(vec![ToolCall { id: "the_tool_call".to_string(), - r#type: "function".to_string(), + r#type: FunctionType::Function, function: ToolCallFunction { name: "mul".to_string(), arguments: "not_required_to_be_valid_here".to_string(), @@ -847,7 +1052,7 @@ mod tests { content: Some("the result is 25903.061423199997".to_string()), name: None, function_call: None, - tool_call_id: Some("the_tool_call".to_owned()), + tool_call_id: Some("the_tool_call".to_string()), tool_calls: Some(Vec::new()), }, ], diff --git a/src/chat/structured_output.rs b/src/chat/structured_output.rs new file mode 100644 index 0000000..385adaf --- /dev/null +++ b/src/chat/structured_output.rs @@ -0,0 +1,147 @@ +use std::mem::take; + +use schemars::{ + schema::{Schema, SchemaObject}, + schema_for, + visit::{visit_schema_object, Visitor}, + JsonSchema, +}; +use serde::{Deserialize, Serialize}; +use serde_json::{json, Value}; + +#[derive(Debug, Clone, Copy, Eq, PartialEq)] +pub enum JsonSchemaStyle { + OpenAI, + Grok, +} + +#[derive(Serialize, Debug, Clone, Eq, PartialEq)] +pub struct ChatCompletionResponseFormatJsonSchema { + /// The name of the response format. Must be a-z, A-Z, 0-9, or contain underscores and dashes, with a maximum length of 64. 
+ pub name: String, + /// A description of what the response format is for, used by the model to determine how to respond in the format. + #[serde(skip_serializing_if = "Option::is_none")] + pub description: Option<String>, + /// The schema for the response format, described as a JSON Schema object. + #[serde(skip_serializing_if = "Option::is_none")] + pub schema: Option<Value>, + /// Whether to enable strict schema adherence when generating the output. + /// If set to true, the model will always follow the exact schema defined in the schema field. + /// Only a subset of JSON Schema is supported when strict is true. + /// To learn more, read the [Structured Outputs guide](https://platform.openai.com/docs/guides/structured-outputs). + /// + /// defaults to false + #[serde(skip_serializing_if = "Option::is_none")] + pub strict: Option<bool>, +} + +impl ChatCompletionResponseFormatJsonSchema { + pub fn new<T: JsonSchema>(strict: bool, json_style: JsonSchemaStyle) -> Self { + let (schema, description) = generate_json_schema::<T>(json_style); + ChatCompletionResponseFormatJsonSchema { + name: T::schema_name(), + description, + schema: Some(schema), + strict: Some(strict), + } + } +} + +#[derive(Deserialize, Serialize, Clone, Debug, Eq, PartialEq)] +pub struct ToolCallFunctionDefinition { + /// The name of the function to be called. Must be a-z, A-Z, 0-9, or contain underscores and dashes, with a maximum length of 64. + pub name: String, + /// A description of what the function does, used by the model to choose when and how to call the function. + #[serde(skip_serializing_if = "Option::is_none")] + pub description: Option<String>, + /// The parameters the functions accepts, described as a JSON Schema object. + /// See the [guide](https://platform.openai.com/docs/guides/function-calling) for examples, + /// and the [JSON Schema reference](https://json-schema.org/understanding-json-schema/reference) for documentation about the format. + /// Omitting `parameters` defines a function with an empty parameter list.
+ #[serde(skip_serializing_if = "Option::is_none")] + pub parameters: Option, + /// Whether to enable strict schema adherence when generating the function call. + /// If set to true, the model will follow the exact schema defined in the `parameters` field. + /// Only a subset of JSON Schema is supported when `strict` is `true`. + /// Learn more about Structured Outputs in the [function calling guide](https://platform.openai.com/docs/api-reference/chat/docs/guides/function-calling). + /// + /// defaults to false + #[serde(skip_serializing_if = "Option::is_none")] + pub strict: Option, +} + +impl ToolCallFunctionDefinition { + /// Create a new ToolCallFunctionDefinition with the given strictness and JSON Schema style. + /// + /// Note: Grok tools does not support strict schema adherence, need to set `strict` to None. + pub fn new(strict: Option) -> Self { + let schema = schema_for!(T); + let description = if let Some(metadata) = &schema.schema.metadata { + metadata.description.clone() + } else { + None + }; + ToolCallFunctionDefinition { + description, + name: T::schema_name(), + parameters: Some(json!(schema)), + strict, + } + } +} + +/// Generate a JSON Schema with the given style. +/// +/// IMPORTANT: Both OpenAI and Grok do not support the `format` and `minimum` JSON Schema attributes. +/// As a result, numeric type constraints (like `u8`, `i32`, etc) cannot be enforced - all integers +/// will be treated as `i64` and all floating point numbers as `f64`. 
+pub fn generate_json_schema<T: JsonSchema>(json_style: JsonSchemaStyle) -> (Value, Option<String>) { + let mut settings = schemars::r#gen::SchemaSettings::default(); + settings.option_nullable = false; + settings.inline_subschemas = true; + settings.option_add_null_type = match json_style { + JsonSchemaStyle::OpenAI => true, + JsonSchemaStyle::Grok => false, + }; + let mut generator = schemars::SchemaGenerator::new(settings); + let mut schema = T::json_schema(&mut generator).into_object(); + let description = schema.metadata().description.clone(); + let mut processor = SchemaPostProcessor { style: json_style }; + processor.visit_schema_object(&mut schema); + let schema = serde_json::to_value(schema).expect("unreachable"); + (schema, description) +} + +pub struct SchemaPostProcessor { + pub style: JsonSchemaStyle, +} + +impl Visitor for SchemaPostProcessor { + fn visit_schema_object(&mut self, schema: &mut SchemaObject) { + if let Some(sub) = &mut schema.subschemas { + sub.any_of = take(&mut sub.one_of); + } + schema.format = None; + if let Some(sub) = &mut schema.object { + if self.style == JsonSchemaStyle::OpenAI { + if sub.additional_properties.is_none() { + sub.additional_properties = Some(Box::new(Schema::Bool(false))); + } + sub.required = sub.properties.keys().map(|s| s.clone()).collect(); + } + } + if let Some(num) = &mut schema.number { + num.multiple_of = None; + num.exclusive_maximum = None; + num.exclusive_minimum = None; + num.maximum = None; + num.minimum = None; + } + if let Some(str) = &mut schema.string { + str.max_length = None; + str.min_length = None; + str.pattern = None; + } + visit_schema_object(self, schema); + } +}