Skip to content
Draft
1 change: 1 addition & 0 deletions Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@ tokio = { version = "1.26.0", features = ["full"] }
anyhow = "1.0.70"
futures-util = "0.3.28"
bytes = "1.4.0"
schemars = "0.8.22"

[dev-dependencies]
dotenvy = "0.15.7"
Expand Down
257 changes: 239 additions & 18 deletions src/chat.rs
Original file line number Diff line number Diff line change
@@ -1,14 +1,19 @@
//! Given a chat conversation, the model will return a chat completion response.
pub mod structured_output;

use super::{openai_post, ApiResponseOrError, Credentials, Usage};
use crate::openai_request_stream;
use derive_builder::Builder;
use futures_util::StreamExt;
use reqwest::Method;
use reqwest_eventsource::{CannotCloneRequestError, Event, EventSource};
use schemars::JsonSchema;
use serde::{Deserialize, Serialize};
use serde_json::Value;
use std::collections::HashMap;
use structured_output::{
ChatCompletionResponseFormatJsonSchema, JsonSchemaStyle, ToolCallFunctionDefinition,
};
use tokio::sync::mpsc::{channel, Receiver, Sender};

/// A full chat completion.
Expand Down Expand Up @@ -60,6 +65,8 @@ pub struct ChatCompletionMessage {
/// The function that ChatGPT called. This should be "None" usually, and is returned by ChatGPT and not provided by the developer
///
/// [API Reference](https://platform.openai.com/docs/api-reference/chat/create#chat/create-function_call)
///
/// Deprecated, use `tool_calls` instead

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This can leverage Rust's #[deprecated] attribute.

#[serde(skip_serializing_if = "Option::is_none")]
pub function_call: Option<ChatCompletionFunctionCall>,
/// Tool call that this message is responding to.
Expand All @@ -86,6 +93,8 @@ pub struct ChatCompletionMessageDelta {
/// The function that ChatGPT called
///
/// [API Reference](https://platform.openai.com/docs/api-reference/chat/create#chat/create-function_call)
///
/// Deprecated, use `tool_calls` instead

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This can leverage Rust's #[deprecated] attribute.

#[serde(skip_serializing_if = "Option::is_none")]
pub function_call: Option<ChatCompletionFunctionCallDelta>,
/// Tool call that this message is responding to.
Expand All @@ -96,15 +105,105 @@ pub struct ChatCompletionMessageDelta {
/// Can only be populated if the role is `Assistant`,
/// otherwise it should be empty.
#[serde(skip_serializing_if = "is_none_or_empty_vec")]
pub tool_calls: Option<Vec<ToolCall>>,
pub tool_calls: Option<Vec<ToolCallDelta>>,
}

/// A tool that can be offered to the model in a chat completion request.
///
/// Serialized as an internally tagged object: the variant name, converted to
/// `snake_case`, is written to the `type` field (so `Function` is emitted as
/// `"type": "function"`).
#[derive(Deserialize, Serialize, Clone, Debug, Eq, PartialEq)]
#[serde(tag = "type", rename_all = "snake_case")]
pub enum ChatCompletionTool {
/// A function tool.
Function {
/// Definition of the function made available to the model.
function: ToolCallFunctionDefinition,
},
}

impl ChatCompletionTool {
    /// Builds a `Function` tool whose definition is generated from the type `T`.
    ///
    /// `strict` is forwarded unchanged to `ToolCallFunctionDefinition::new`.
    pub fn new<T: JsonSchema>(strict: Option<bool>) -> Self {
        Self::Function {
            function: ToolCallFunctionDefinition::new::<T>(strict),
        }
    }
}

#[derive(Deserialize, Serialize, Clone, Debug, Eq, PartialEq)]

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
#[derive(Deserialize, Serialize, Clone, Debug, Eq, PartialEq)]
#[derive(Deserialize, Serialize, Clone, Debug, Eq, PartialEq)]
#[serde(rename_all = "lowercase")]

pub enum ToolChoiceMode {
None,
Auto,
Required,
}

/// Identifies, by name, the function selected in `ToolChoice::Function`.
#[derive(Deserialize, Serialize, Clone, Debug, Eq, PartialEq)]
pub struct FunctionChoice {
/// The name of the function to call.
pub name: String,
}

#[derive(Clone, Debug, Eq, PartialEq)]
pub struct FunctionLiteral;
impl Serialize for FunctionLiteral {
fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
where
S: serde::Serializer,
{
serializer.serialize_str("function")
}
}
impl<'de> Deserialize<'de> for FunctionLiteral {
fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
where
D: serde::Deserializer<'de>,
{
let s: &str = serde::Deserialize::deserialize(deserializer)?;
if s != "function" {
return Err(serde::de::Error::custom("expected function"));
}
Ok(FunctionLiteral)
}
}

/// Value of the `tool_choice` request field.
///
/// The enum is `untagged`: no variant name is written, so a `Mode` is
/// serialized as the bare `ToolChoiceMode` value and a `Function` as an object
/// with `type` and `function` fields.
#[derive(Deserialize, Serialize, Clone, Debug, Eq, PartialEq)]
#[serde(untagged)]
pub enum ToolChoice {
/// `none` means the model will not call any tool and instead generates a message.
/// `auto` means the model can pick between generating a message or calling one or more tools.
/// `required` means the model must call one or more tools.
Mode(ToolChoiceMode),
/// The model will call the function with the given name.
Function {
/// The type of the tool. Currently, only `function` is supported.
r#type: FunctionLiteral,
/// The function that the model called.
function: FunctionChoice,
},
}

impl ToolChoice {
    /// Wraps a [`ToolChoiceMode`] in the `Mode` variant.
    pub fn mode(mode: ToolChoiceMode) -> Self {
        Self::Mode(mode)
    }

    /// Builds the `Function` variant selecting the function called `name`.
    pub fn function(name: String) -> Self {
        let function = FunctionChoice { name };
        Self::Function {
            r#type: FunctionLiteral,
            function,
        }
    }
}

#[derive(Deserialize, Serialize, Clone, Debug, Eq, PartialEq)]
pub struct ToolCall {
/// The ID of the tool call.
pub id: String,
/// The type of the tool. Currently, only `function` is supported.
pub r#type: String,
pub r#type: FunctionLiteral,
/// The function that the model called.
pub function: ToolCallFunction,
}

/// A tool call as carried by `ChatCompletionMessageDelta::tool_calls`.
///
/// Has the same fields as `ToolCall` plus an `index`.
// NOTE(review): every field here is required. If streamed chunks can omit
// `id`, `type` or parts of `function` after the first chunk, deserialization
// of those chunks would fail — confirm against the streaming API reference.
#[derive(Deserialize, Serialize, Clone, Debug, Eq, PartialEq)]
pub struct ToolCallDelta {
/// Index of this tool call; presumably its position in the message's list of
/// tool calls — TODO confirm against the API reference.
pub index: i64,
/// The ID of the tool call.
pub id: String,
/// The type of the tool. Currently, only `function` is supported.
pub r#type: FunctionLiteral,
/// The function that the model called.
pub function: ToolCallFunction,
}
Expand Down Expand Up @@ -165,6 +264,14 @@ pub enum ChatCompletionMessageRole {
Developer,
}

/// Effort level for the `reasoning_effort` field of `ChatCompletionRequest`.
///
/// Variants are (de)serialized in lowercase: `low`, `medium`, `high`.
#[derive(Deserialize, Serialize, Debug, Clone, Copy, Eq, PartialEq)]
#[serde(rename_all = "lowercase")]
pub enum ChatCompletionReasoningEffort {
/// Serialized as `low`.
Low,
/// Serialized as `medium`.
Medium,
/// Serialized as `high`.
High,
}

#[derive(Serialize, Builder, Debug, Clone)]
#[builder(derive(Clone, Debug, PartialEq))]
#[builder(pattern = "owned")]
Expand All @@ -176,6 +283,12 @@ pub struct ChatCompletionRequest {
model: String,
/// The messages to generate chat completions for, in the [chat format](https://platform.openai.com/docs/guides/chat/introduction).
messages: Vec<ChatCompletionMessage>,
/// Constrains effort on reasoning for [reasoning models](https://platform.openai.com/docs/guides/reasoning).
/// Currently supported values are low, medium, and high (Defaults to medium).
/// Reducing reasoning effort can result in faster responses and fewer tokens used on reasoning in a response.
#[builder(default)]
#[serde(skip_serializing_if = "Option::is_none")]
reasoning_effort: Option<ChatCompletionReasoningEffort>,
/// What sampling temperature to use, between 0 and 2. Higher values like 0.8 will make the output more random, while lower values like 0.2 will make it more focused and deterministic.
///
/// We generally recommend altering this or `top_p` but not both.
Expand Down Expand Up @@ -204,6 +317,7 @@ pub struct ChatCompletionRequest {
#[serde(skip_serializing_if = "Option::is_none")]
seed: Option<u64>,
/// The maximum number of tokens allowed for the generated answer. By default, the number of tokens the model can return will be (4096 - prompt tokens).
#[deprecated(note = "Use max_completion_tokens instead")]
#[builder(default)]
#[serde(skip_serializing_if = "Option::is_none")]
max_tokens: Option<u64>,
Expand Down Expand Up @@ -234,12 +348,32 @@ pub struct ChatCompletionRequest {
#[builder(default)]
#[serde(skip_serializing_if = "String::is_empty")]
user: String,
/// A list of tools the model may call. Currently, only functions are supported as a tool. Use this to provide a list of functions the model may generate JSON inputs for. A max of 128 functions are supported.
#[builder(default)]
#[serde(skip_serializing_if = "Vec::is_empty")]
tools: Vec<ChatCompletionTool>,
/// Controls which (if any) tool is called by the model.
/// `none` means the model will not call any tool and instead generates a message.
/// `auto` means the model can pick between generating a message or calling one or more tools.
/// `required` means the model must call one or more tools.
/// Specifying a particular tool via `{"type": "function", "function": {"name": "my_function"}}` forces the model to call that tool.
///
/// `none` is the default when no tools are present. `auto` is the default if tools are present.
#[builder(default)]
#[serde(skip_serializing_if = "Option::is_none")]
tool_choice: Option<ToolChoice>,
/// Whether to enable parallel function calling during tool use.
/// Defaults to true.
#[builder(default)]
#[serde(skip_serializing_if = "Option::is_none")]
parallel_tool_calls: Option<bool>,
/// Describe functions that ChatGPT can call
/// The latest models of ChatGPT support function calling, which allows you to define functions that can be called from the prompt.
/// For example, you can define a function called "get_weather" that returns the weather in a given city
///
/// [Function calling API Reference](https://platform.openai.com/docs/api-reference/chat/create#chat/create-functions)
/// [See more information about function calling in ChatGPT.](https://platform.openai.com/docs/guides/gpt/function-calling)
#[deprecated(note = "Use tools instead")]
#[builder(default)]
#[serde(skip_serializing_if = "Vec::is_empty")]
functions: Vec<ChatCompletionFunctionDefinition>,
Expand All @@ -252,6 +386,7 @@ pub struct ChatCompletionRequest {
/// - Specifying a particular function via {"name": "my_function"} forces the model to call that function.
///
/// "none" is the default when no functions are present. "auto" is the default if functions are present.
#[deprecated(note = "Use tool_choice instead")]
#[builder(default)]
#[serde(skip_serializing_if = "Option::is_none")]
function_call: Option<Value>,
Expand All @@ -278,23 +413,25 @@ pub struct VeniceParameters {
}

#[derive(Serialize, Debug, Clone, Eq, PartialEq)]
pub struct ChatCompletionResponseFormat {
/// Must be one of text or json_object (defaults to text)
#[serde(rename = "type")]
typ: String,
#[serde(tag = "type", rename_all = "snake_case")]
pub enum ChatCompletionResponseFormat {
Text,
JsonObject,
JsonSchema {
json_schema: ChatCompletionResponseFormatJsonSchema,
},
}

impl ChatCompletionResponseFormat {
pub fn text() -> Self {
ChatCompletionResponseFormat::Text
}
pub fn json_object() -> Self {
ChatCompletionResponseFormat {
typ: "json_object".to_string(),
}
ChatCompletionResponseFormat::JsonObject
}

pub fn text() -> Self {
ChatCompletionResponseFormat {
typ: "text".to_string(),
}
pub fn json_schema<T: JsonSchema>(strict: bool, json_style: JsonSchemaStyle) -> Self {
let json_schema = ChatCompletionResponseFormatJsonSchema::new::<T>(strict, json_style);
ChatCompletionResponseFormat::JsonSchema { json_schema }
}
}

Expand Down Expand Up @@ -560,7 +697,7 @@ mod tests {
}],
)
.temperature(0.0)
.response_format(ChatCompletionResponseFormat::text())
.response_format(ChatCompletionResponseFormat::Text)
.credentials(credentials)
.create()
.await
Expand Down Expand Up @@ -746,7 +883,7 @@ mod tests {
)
.temperature(0.0)
.seed(1337u64)
.response_format(ChatCompletionResponseFormat::json_object())
.response_format(ChatCompletionResponseFormat::JsonObject)
.credentials(credentials)
.create()
.await
Expand Down Expand Up @@ -804,6 +941,90 @@ mod tests {
merged.unwrap().into()
}

// Test fixture: payload of `Species::Human`. Mixes unit variants with a
// string-carrying variant. Plain `//` comments are used on purpose so the
// text cannot end up in the derived JSON schema.
#[derive(JsonSchema, Deserialize, Debug, Eq, PartialEq)]
enum Race {
Black,
White,
Asian,
Other(String),
}
// Test fixture: type of `Character::species`. Has one newtype variant and one
// struct variant. Plain `//` comments are used on purpose so the text cannot
// end up in the derived JSON schema.
#[derive(JsonSchema, Deserialize, Debug, PartialEq)]
enum Species {
Human(Race),
Orc { color: String, leader: String },
}
// Test fixture: the type the structured-output and tool-use tests ask the
// model to produce and then deserialize from the response. Plain `//`
// comments are used on purpose so the text cannot end up in the derived JSON
// schema.
#[derive(JsonSchema, Deserialize, Debug, PartialEq)]
struct Character {
pub name: String,
pub age: i64,
pub power: f64,
pub skills: Vec<Skill>,
pub species: Species,
}
// Test fixture: element type of `Character::skills`. Both optional fields may
// be absent; the prompt in `chat_structured_output_completion` tells the model
// not to use `dont_use_this_property`. Plain `//` comments are used on purpose
// so the text cannot end up in the derived JSON schema.
#[derive(JsonSchema, Deserialize, Debug, PartialEq)]
struct Skill {
pub name: String,
pub description: Option<String>,
pub dont_use_this_property: Option<String>,
}

// Requests a completion constrained to the JSON schema derived from
// `Character` and checks that the returned content deserializes into it.
#[tokio::test]
async fn chat_structured_output_completion() {
    dotenv().ok();
    let credentials = Credentials::from_env();

    let response_format =
        ChatCompletionResponseFormat::json_schema::<Character>(true, JsonSchemaStyle::OpenAI);
    let prompt = ChatCompletionMessage {
        role: ChatCompletionMessageRole::User,
        content: Some(
            "Create a DND character, don't use the dont_use_this_property field".to_string(),
        ),
        ..Default::default()
    };

    let response = ChatCompletion::builder("gpt-4o-mini", [prompt])
        .credentials(credentials)
        .response_format(response_format)
        .create()
        .await
        .unwrap();

    // The message content must be JSON that parses as a `Character`.
    let raw_character = response.choices[0].message.content.as_ref().unwrap();
    let _character: Character = serde_json::from_str(raw_character).unwrap();
}

// Offers a single tool derived from `Character`, forces the model to call the
// function named "Character", and checks that the arguments of the first tool
// call deserialize into a `Character`.
#[tokio::test]
async fn chat_tool_use_completion() {
    dotenv().ok();
    let credentials = Credentials::from_env();
    let character_tool = ChatCompletionTool::new::<Character>(None);
    let prompt = ChatCompletionMessage {
        role: ChatCompletionMessageRole::User,
        content: Some("create a random DND character directly with tools".to_string()),
        ..Default::default()
    };

    let response = ChatCompletion::builder("gpt-4o-mini", [prompt])
        .credentials(credentials)
        .tools(vec![character_tool])
        .tool_choice(ToolChoice::function("Character".to_string()))
        .create()
        .await
        .unwrap();

    let message = &response.choices[0].message;
    let first_call: &ToolCall = message.tool_calls.as_ref().unwrap().first().unwrap();
    let _character: Character = serde_json::from_str(&first_call.function.arguments).unwrap();
}

#[tokio::test]
async fn chat_tool_response_completion() {
dotenv().ok();
Expand Down Expand Up @@ -835,7 +1056,7 @@ mod tests {
tool_call_id: None,
tool_calls: Some(vec![ToolCall {
id: "the_tool_call".to_string(),
r#type: "function".to_string(),
r#type: FunctionLiteral,
function: ToolCallFunction {
name: "mul".to_string(),
arguments: "not_required_to_be_valid_here".to_string(),
Expand All @@ -847,7 +1068,7 @@ mod tests {
content: Some("the result is 25903.061423199997".to_string()),
name: None,
function_call: None,
tool_call_id: Some("the_tool_call".to_owned()),
tool_call_id: Some("the_tool_call".to_string()),
tool_calls: Some(Vec::new()),
},
],
Expand Down
Loading