Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
31 changes: 23 additions & 8 deletions src/api/chat_completion.rs
Original file line number Diff line number Diff line change
Expand Up @@ -61,10 +61,14 @@ pub struct ChatCompletionRequest {
#[builder(default, setter(into))]
#[serde(skip_serializing_if = "Vec::is_empty")]
tools: Vec<Tool>,
/// Controls which (if any) function is called by the model. none means the model will not call a function and instead generates a message. auto means the model can pick between generating a message or calling a function. Specifying a particular function via {"type: "function", "function": {"name": "my_function"}} forces the model to call that function. none is the default when no functions are present. auto is the default if functions are present.
/// Controls which (if any) function is called by the model. none means the model will not call a function and instead generates a message. auto means the model can pick between generating a message or calling a function. none is the default when no functions are present. auto is the default if functions are present. Leave it None if tool_choice_force is Some.
#[builder(default, setter(strip_option))]
#[serde(skip_serializing_if = "Option::is_none")]
tool_choice: Option<ToolChoice>,
/// Just like tool_choice but specify a particular function via {"type: "function", "function": {"name": "my_function"}} to force the model to call that function.
#[builder(default, setter(strip_option))]
#[serde(skip_serializing_if = "Option::is_none", rename = "tool_choice")]
tool_choice_force: Option<SpecifiedToolChoice>,
/// A unique identifier representing your end-user, which can help OpenAI to monitor and detect abuse.
#[builder(default, setter(strip_option, into))]
#[serde(skip_serializing_if = "Option::is_none")]
Expand All @@ -77,10 +81,18 @@ pub enum ToolChoice {
#[default]
None,
Auto,
// TODO: we need something like this: #[serde(tag = "type", content = "function")]
Function {
name: String,
},
}

/// {"type: "function", "function": {"name": "your_function"}}
#[derive(Debug, Clone, Serialize)]
#[serde(rename_all = "snake_case", tag = "type")]
pub enum SpecifiedToolChoice {
Function { function: FunctionName },
}

#[derive(Debug, Clone, Serialize)]
pub struct FunctionName {
name: String,
}

#[derive(Debug, Clone, Serialize)]
Expand Down Expand Up @@ -111,6 +123,7 @@ pub struct ChatResponseFormatObject {
pub enum ChatResponseFormat {
Text,
#[default]
#[serde(rename = "json_object")]
Json,
}

Expand Down Expand Up @@ -374,11 +387,12 @@ mod tests {
}

#[test]
#[ignore]
fn chat_completion_request_tool_choice_function_serialize_should_work() {
let req = ChatCompletionRequestBuilder::default()
.tool_choice(ToolChoice::Function {
name: "my_function".to_string(),
.tool_choice_force(SpecifiedToolChoice::Function {
function: FunctionName {
name: "my_function".to_string(),
},
})
.messages(vec![])
.build()
Expand All @@ -393,6 +407,7 @@ mod tests {
"name": "my_function"
}
},
"model": "gpt-3.5-turbo-1106",
"messages": []
})
);
Expand Down