feat: align function id with tool call response #3111

Open · wants to merge 3 commits into main
9 changes: 9 additions & 0 deletions docs/openapi.json
@@ -1520,6 +1520,10 @@
"type": "string",
"nullable": true
},
"id": {
"type": "string",
"nullable": true
},
"name": {
"type": "string"
}
@@ -1883,6 +1887,11 @@
"role": {
"type": "string",
"example": "user"
},
"tool_call_id": {
"type": "string",
"example": "10",
"nullable": true
}
}
}
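For context, a sketch of the wire shape these two schema additions enable: the tool call can carry a function `id`, and a follow-up message can echo it back as `tool_call_id`. The field names come from the schema above; the surrounding payload layout and the values are assumptions for illustration only.

```rust
use serde_json::json;

fn main() {
    // Hypothetical exchange: the function `id` ("10") emitted with the tool
    // call is echoed back by the client as `tool_call_id` on the next message.
    let messages = json!([
        {
            "role": "assistant",
            "tool_calls": [{
                "id": "10",
                "type": "function",
                "function": {
                    "id": "10",
                    "name": "get_current_weather",
                    "arguments": "{\"location\": \"Brooklyn, New York\"}"
                }
            }]
        },
        { "role": "tool", "tool_call_id": "10", "content": "Sunny, 22C" }
    ]);
    println!("{}", serde_json::to_string_pretty(&messages).unwrap());
}
```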
1 change: 1 addition & 0 deletions router/src/chat.rs
@@ -49,6 +49,7 @@ pub(crate) fn parse_output(generated_text: &str) -> Result<ChatChoice, InferError>
id: "0".to_string(),
r#type: "function".to_string(),
function: FunctionDefinition {
id: None,
description: None,
name: name.to_string(),
arguments: serde_json::to_value(call.function.arguments).map_err(|err| {
27 changes: 20 additions & 7 deletions router/src/infer/chat_template.rs
@@ -96,15 +96,21 @@ impl ChatTemplate {

let messages: Vec<TextMessage> = messages.into_iter().map(|c| c.into()).collect();
let final_message = messages.last().cloned();
let template_inputs = ChatTemplateInputs {
messages,
bos_token: self.bos_token.as_deref(),
eos_token: self.eos_token.as_deref(),
add_generation_prompt: true,
tools,
};

// NOTE: initializing `template_inputs` as a separate binding is helpful when
// JSON-dumping the `ChatTemplateInputs` struct for debugging, e.g.:
// let template_inputs_as_json = serde_json::to_string(&template_inputs).unwrap();

let mut rendered_template = self
.template
.render(ChatTemplateInputs {
messages,
bos_token: self.bos_token.as_deref(),
eos_token: self.eos_token.as_deref(),
add_generation_prompt: true,
tools,
})
.render(template_inputs)
.map_err(InferError::TemplateError)?;

// if the last message is from the assistant, continue the generation prompt
@@ -1175,6 +1181,7 @@ TOOL CALL ID: 0
"I'd like to show off how chat templating works!".to_string(),
),
},
tool_call_id: None,
},
Message {
name: None,
@@ -1184,13 +1191,15 @@ TOOL CALL ID: 0
"Great! How can I help you today?".to_string(),
),
},
tool_call_id: None,
},
Message {
name: None,
role: "user".to_string(),
body: MessageBody::Content {
content: MessageContent::SingleText("Just testing".to_string()),
},
tool_call_id: None,
},
];
let tools_string = r#"[{"type": "function","function": {"name": "get_current_weather","description": "Get the current weather","parameters": {"type": "object","properties": {"location": {"type": "string","description": "The city and state, e.g. San Francisco, CA"},"format": {"type": "string","enum": ["celsius", "fahrenheit"],"description": "The temperature unit to use. Infer this from the users location."}},"required": ["location", "format"]}}}]"#.to_string();
@@ -1220,6 +1229,7 @@ TOOL CALL ID: 0
.to_string(),
),
},
tool_call_id: None,
},
Message {
name: None,
@@ -1229,6 +1239,7 @@ TOOL CALL ID: 0
"What is the weather like in Brooklyn, New York?".to_string(),
),
},
tool_call_id: None,
},
];
let tools_string = r#"[{"type": "function","function": {"name": "get_current_weather","description": "Get the current weather","parameters": {"type": "object","properties": {"location": {"type": "string","description": "The city and state, e.g. San Francisco, CA"},"format": {"type": "string","enum": ["celsius", "fahrenheit"],"description": "The temperature unit to use. Infer this from the users location."}},"required": ["location", "format"]}}}]"#.to_string();
@@ -1299,6 +1310,7 @@ TOOL CALL ID: 0
text: "You are a helpful assistant.".to_string(),
}]),
},
tool_call_id: None,
},
Message {
name: None,
@@ -1326,6 +1338,7 @@ TOOL CALL ID: 0
},
]),
},
tool_call_id: None,
},
];

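The NOTE above refers to dumping the fully assembled inputs before rendering; since `ChatTemplateInputs` already derives `Serialize` (and now `Debug`, per the `router/src/lib.rs` change below), the commented-out one-liner works as written. A minimal self-contained sketch of the idea, using a simplified stand-in struct:

```rust
use serde::Serialize;

// Simplified stand-in for the router's `ChatTemplateInputs` struct.
#[derive(Debug, Clone, Serialize, Default)]
struct ChatTemplateInputs<'a> {
    messages: Vec<&'a str>, // the real struct holds Vec<TextMessage>
    bos_token: Option<&'a str>,
    eos_token: Option<&'a str>,
    add_generation_prompt: bool,
}

fn main() {
    let template_inputs = ChatTemplateInputs {
        messages: vec!["user: Just testing"],
        bos_token: Some("<s>"),
        eos_token: Some("</s>"),
        add_generation_prompt: true,
    };
    // The commented-out debugging line from the diff, spelled out:
    let template_inputs_as_json = serde_json::to_string(&template_inputs).unwrap();
    println!("{template_inputs_as_json}");
}
```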
1 change: 1 addition & 0 deletions router/src/infer/tool_grammar.rs
@@ -34,6 +34,7 @@ impl ToolGrammar {
.chain(std::iter::once(Tool {
r#type: "function".to_string(),
function: FunctionDefinition {
id: None,
name: "no_tool".to_string(),
description: Some(
"Open ended response with no specific tool selected".to_string(),
42 changes: 33 additions & 9 deletions router/src/lib.rs
@@ -18,6 +18,7 @@ use crate::infer::{Infer, InferError};
use pyo3::prelude::*;
use pyo3::types::IntoPyDict;
use serde::{Deserialize, Serialize};
use std::collections::HashMap;
use tokenizers::Encoding;
use tracing::warn;
use utoipa::ToSchema;
@@ -912,7 +913,10 @@ pub(crate) struct ChatRequest {
}

impl ChatRequest {
fn try_into_generate(self, infer: &Infer) -> Result<(GenerateRequest, bool), InferError> {
fn try_into_generate(
self,
infer: &Infer,
) -> Result<(GenerateRequest, Option<HashMap<String, String>>), InferError> {
let ChatRequest {
model,
max_tokens,
@@ -952,7 +956,7 @@ impl ChatRequest {
let (inputs, grammar, using_tools) = match response_format {
Some(format) => {
let inputs = infer.apply_chat_template(messages, None)?;
(inputs, Some(format), false)
(inputs, Some(format), None)
}
None => {
if let Some(tools) = tools {
@@ -961,20 +965,31 @@
let grammar = GrammarType::Json(serde_json::json!(tool_schema));
let inputs: String = infer.apply_chat_template(
messages,
Some((updated_tools, tool_prompt)),
Some((updated_tools.clone(), tool_prompt)),
)?;
(inputs, Some(grammar), true)
let tool_name_to_id: HashMap<String, String> = updated_tools
.into_iter()
.map(|tool| {
(
tool.function.name,
tool.function
.id
.map_or_else(|| "0".to_string(), |id| id.to_string()),
)
})
.collect();
(inputs, Some(grammar), Some(tool_name_to_id))
}
None => {
// same as if no response_format or tools are set
let inputs = infer.apply_chat_template(messages, None)?;
(inputs, None, false)
(inputs, None, None)
}
}
} else {
// if no response_format or tools are set, simply apply the chat template to generate inputs
let inputs = infer.apply_chat_template(messages, None)?;
(inputs, None, false)
(inputs, None, None)
}
}
};
@@ -1154,6 +1169,8 @@ pub struct FunctionDefinition {
#[serde(default)]
pub description: Option<String>,
pub name: String,
#[serde(default, skip_serializing_if = "Option::is_none")]
pub id: Option<String>,
#[serde(alias = "parameters", serialize_with = "serialize_as_string")]
pub arguments: serde_json::Value,
}
@@ -1175,7 +1192,7 @@ pub(crate) struct Tool {
pub function: FunctionDefinition,
}

#[derive(Clone, Serialize, Deserialize, Default)]
#[derive(Debug, Clone, Serialize, Deserialize, Default)]
pub(crate) struct ChatTemplateInputs<'a> {
messages: Vec<TextMessage>,
bos_token: Option<&'a str>,
@@ -1208,6 +1225,9 @@ pub enum MessageChunk {
pub struct Message {
#[schema(example = "user")]
pub role: String,
#[serde(default, skip_serializing_if = "Option::is_none")]
#[schema(example = "10")]
pub tool_call_id: Option<String>,
#[serde(flatten)]
#[schema(example = "My name is David and I")]
pub body: MessageBody,
@@ -1287,7 +1307,7 @@ impl From<Message> for TextMessage {
.collect::<Vec<_>>()
.join(""),
},
..Default::default()
tool_call_id: value.tool_call_id,
}
}
}
@@ -1624,6 +1644,7 @@ mod tests {
body: MessageBody::Content {
content: MessageContent::SingleText("What is Deep Learning?".to_string())
},
tool_call_id: None,
}
);
}
@@ -1683,6 +1704,7 @@
MessageChunk::ImageUrl { image_url: Url { url: "https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/transformers/rabbit.png".to_string() }},
]),
},
tool_call_id: None,
}
);
}
@@ -1697,7 +1719,8 @@
MessageChunk::Text { text: "Whats in this image?".to_string() },
MessageChunk::ImageUrl { image_url: Url { url: "https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/transformers/rabbit.png".to_string() } }
]),
}
},
tool_call_id: None
};
let textmsg: TextMessage = message.into();
assert_eq!(textmsg.content, "Whats in this image?![](https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/transformers/rabbit.png)");
@@ -1758,6 +1781,7 @@ mod tests {
id: "0".to_string(),
r#type: "function".to_string(),
function: FunctionDefinition {
id: None,
description: None,
name: "myfn".to_string(),
arguments: json!({
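The return-type change above swaps the old `bool` for `Option<HashMap<String, String>>`: `Some(map)` both signals that tool constraints are active and records which caller-supplied function `id` (defaulting to `"0"`) belongs to each tool name. A minimal sketch of the mapping logic, with hypothetical tool data:

```rust
use std::collections::HashMap;

fn main() {
    // Hypothetical tools: one with a caller-supplied id, one without.
    let tools: Vec<(String, Option<String>)> = vec![
        ("get_current_weather".to_string(), Some("10".to_string())),
        ("no_tool".to_string(), None), // e.g. the fallback added by ToolGrammar
    ];

    // Same shape as the map built in `ChatRequest::try_into_generate`.
    let tool_name_to_id: HashMap<String, String> = tools
        .into_iter()
        .map(|(name, id)| (name, id.map_or_else(|| "0".to_string(), |id| id)))
        .collect();

    assert_eq!(tool_name_to_id["get_current_weather"], "10");
    assert_eq!(tool_name_to_id["no_tool"], "0"); // placeholder default
}
```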
29 changes: 18 additions & 11 deletions router/src/server.rs
@@ -1165,8 +1165,7 @@ pub(crate) async fn chat_completions(

tracing::debug!("Got chat_template {:?}", infer.chat_template);
let id = chat.next_tool_call_id();
let (generate_request, using_tools): (GenerateRequest, bool) =
chat.clone().try_into_generate(&infer)?;
let (generate_request, using_tools) = chat.clone().try_into_generate(&infer)?;
span.record("parameters", format!("{:?}", generate_request.parameters));
let logprobs = logprobs.unwrap_or_default();

@@ -1188,7 +1187,7 @@

let response_stream = async_stream::stream! {
let mut response_stream = Box::pin(response_stream);
let mut state = ChatState::new(using_tools, stream_options.clone(), system_fingerprint.clone(), model_id.clone(), logprobs, id.clone());
let mut state = ChatState::new(using_tools.is_some(), stream_options.clone(), system_fingerprint.clone(), model_id.clone(), logprobs, id.clone());
while let Some(result) = response_stream.next().await {
match result{
Ok(stream_token) => {
@@ -1197,12 +1196,12 @@
ChatEvent::NoTool => {
chat.tools = None;
chat.response_format = None;
let (generate_request, using_tools): (GenerateRequest, bool) =
let (generate_request, using_tools) =
chat.clone().try_into_generate(&infer).unwrap();
assert!(!using_tools);
assert!(using_tools.is_none());
let (_headers, response_stream2) =
generate_stream_internal(infer.clone(), compute_type.clone(), Json(generate_request), span.clone()).await;
state = ChatState::new(using_tools, stream_options.clone(), system_fingerprint.clone(), model_id.clone(), logprobs, id.clone());
state = ChatState::new(using_tools.is_some(), stream_options.clone(), system_fingerprint.clone(), model_id.clone(), logprobs, id.clone());
response_stream = Box::pin(response_stream2);
}
ChatEvent::Events(events) => {
@@ -1237,14 +1236,13 @@
.unwrap_or_else(|_| std::time::Duration::from_secs(0))
.as_secs();

let (tool_calls, output) = if using_tools {
let (tool_calls, output) = if using_tools.is_some() {
match crate::chat::parse_output(&generation.generated_text)? {
ChatChoice::NoTool => {
chat.tools = None;
chat.response_format = None;
let (generate_request, using_tools): (GenerateRequest, bool) =
chat.clone().try_into_generate(&infer)?;
assert!(!using_tools);
let (generate_request, using_tools) = chat.clone().try_into_generate(&infer)?;
assert!(using_tools.is_none());
let (headers_final, input_length_final, Json(generation)) = generate_internal(
Extension(infer),
compute_type,
Expand All @@ -1256,7 +1254,16 @@ pub(crate) async fn chat_completions(
input_length = input_length_final;
(None, Some(generation.generated_text))
}
ChatChoice::ToolCalls(tool_calls) => (Some(tool_calls), None),
ChatChoice::ToolCalls(mut tool_calls) => {
// assign the tool ids based on the tool names
tool_calls.iter_mut().for_each(|tool_call| {
tool_call.id = using_tools
.as_ref()
.and_then(|tools| tools.get(&tool_call.function.name))
.map_or("0".to_string(), |id| id.clone());
});
(Some(tool_calls), None)
}
}
} else {
(None, Some(generation.generated_text))
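Taken together: `parse_output` stamps every parsed call with the placeholder id `"0"`, and the loop above rewrites it from the name-to-id map returned by `try_into_generate`. A self-contained sketch of that rewrite, using hypothetical stand-ins for the router's `ToolCall` and `FunctionDefinition` types:

```rust
use std::collections::HashMap;

// Hypothetical stand-ins for the router's ToolCall / FunctionDefinition types.
struct Function { name: String }
struct ToolCall { id: String, function: Function }

fn main() {
    let using_tools: Option<HashMap<String, String>> =
        Some([("get_current_weather".to_string(), "10".to_string())].into());

    let mut tool_calls = vec![ToolCall {
        id: "0".to_string(), // placeholder assigned by parse_output
        function: Function { name: "get_current_weather".to_string() },
    }];

    // Same remap as in `chat_completions` above: look the id up by function
    // name, falling back to "0" when the tool has no caller-supplied id.
    tool_calls.iter_mut().for_each(|tool_call| {
        tool_call.id = using_tools
            .as_ref()
            .and_then(|tools| tools.get(&tool_call.function.name))
            .map_or("0".to_string(), |id| id.clone());
    });

    assert_eq!(tool_calls[0].id, "10");
}
```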
4 changes: 2 additions & 2 deletions router/src/vertex.rs
@@ -104,8 +104,7 @@ pub(crate) async fn vertex_compatibility(
},
},
VertexInstance::Chat(instance) => {
let (generate_request, _using_tools): (GenerateRequest, bool) =
instance.try_into_generate(&infer)?;
let (generate_request, _using_tools) = instance.try_into_generate(&infer)?;
generate_request
}
};
@@ -176,6 +175,7 @@
"What's Deep Learning?".to_string()
)
},
tool_call_id: None,
},],
max_tokens: Some(128),
top_p: Some(0.95),