diff --git a/docs/openapi.json b/docs/openapi.json
index 85ca3f97792..8454a6a33a3 100644
--- a/docs/openapi.json
+++ b/docs/openapi.json
@@ -1520,6 +1520,10 @@
           "type": "string",
           "nullable": true
         },
+        "id": {
+          "type": "string",
+          "nullable": true
+        },
         "name": {
           "type": "string"
         }
@@ -1883,6 +1887,11 @@
         "role": {
           "type": "string",
           "example": "user"
+        },
+        "tool_call_id": {
+          "type": "string",
+          "example": "10",
+          "nullable": true
         }
       }
     }
diff --git a/router/src/chat.rs b/router/src/chat.rs
index d5824fea014..0aeb868d3dc 100644
--- a/router/src/chat.rs
+++ b/router/src/chat.rs
@@ -49,6 +49,7 @@ pub(crate) fn parse_output(generated_text: &str) -> Result<ChatChoice, InferError> {
                 id: "0".to_string(),
                 r#type: "function".to_string(),
                 function: FunctionDefinition {
+                    id: None,
                     description: None,
                     name,
                     arguments: function,
diff --git a/router/src/infer/chat_template.rs b/router/src/infer/chat_template.rs
--- a/router/src/infer/chat_template.rs
+++ b/router/src/infer/chat_template.rs
@@ -85,15 +85,22 @@ impl ChatTemplate {
 
         let messages: Vec<TextMessage> = messages.into_iter().map(|c| c.into()).collect();
         let final_message = messages.last().cloned();
+        let template_inputs = ChatTemplateInputs {
+            messages,
+            bos_token: self.bos_token.as_deref(),
+            eos_token: self.eos_token.as_deref(),
+            add_generation_prompt: true,
+            tools,
+        };
+
+        // NOTE: initalizing `template_inputs` is helpful when JSON dumping the
+        // `ChatTemplateInputs` struct for debugging
+        // let template_inputs_as_json = serde_json::to_string(&template_inputs).unwrap();
+
         let mut rendered_template = self
             .template
-            .render(ChatTemplateInputs {
-                messages,
-                bos_token: self.bos_token.as_deref(),
-                eos_token: self.eos_token.as_deref(),
-                add_generation_prompt: true,
-                tools,
-            })
+            .render(template_inputs)
             .map_err(InferError::TemplateError)?;
 
         // if the last message is from the assistant, continue the generation prompt
@@ -1175,6 +1181,7 @@ TOOL CALL ID: 0
                         "I'd like to show off how chat templating works!".to_string(),
                     ),
                 },
+                tool_call_id: None,
             },
             Message {
                 name: None,
@@ -1184,6 +1191,7 @@ TOOL CALL ID: 0
                         "Great! How can I help you today?".to_string(),
                     ),
                 },
+                tool_call_id: None,
             },
             Message {
                 name: None,
@@ -1191,6 +1199,7 @@ TOOL CALL ID: 0
                 body: MessageBody::Content {
                     content: MessageContent::SingleText("Just testing".to_string()),
                 },
+                tool_call_id: None,
             },
         ];
         let tools_string = r#"[{"type": "function","function": {"name": "get_current_weather","description": "Get the current weather","parameters": {"type": "object","properties": {"location": {"type": "string","description": "The city and state, e.g. San Francisco, CA"},"format": {"type": "string","enum": ["celsius", "fahrenheit"],"description": "The temperature unit to use. Infer this from the users location."}},"required": ["location", "format"]}}}]"#.to_string();
@@ -1220,6 +1229,7 @@ TOOL CALL ID: 0
                         .to_string(),
                     ),
                 },
+                tool_call_id: None,
             },
             Message {
                 name: None,
@@ -1229,6 +1239,7 @@ TOOL CALL ID: 0
                         "What is the weather like in Brooklyn, New York?".to_string(),
                     ),
                 },
+                tool_call_id: None,
             },
         ];
         let tools_string = r#"[{"type": "function","function": {"name": "get_current_weather","description": "Get the current weather","parameters": {"type": "object","properties": {"location": {"type": "string","description": "The city and state, e.g. San Francisco, CA"},"format": {"type": "string","enum": ["celsius", "fahrenheit"],"description": "The temperature unit to use. Infer this from the users location."}},"required": ["location", "format"]}}}]"#.to_string();
@@ -1299,6 +1310,7 @@ TOOL CALL ID: 0
                         text: "You are a helpful assistant.".to_string(),
                     }]),
                 },
+                tool_call_id: None,
             },
             Message {
                 name: None,
@@ -1326,6 +1338,7 @@ TOOL CALL ID: 0
                         },
                     ]),
                 },
+                tool_call_id: None,
             },
         ];
 
diff --git a/router/src/infer/tool_grammar.rs b/router/src/infer/tool_grammar.rs
index e4e2085983f..168df02cc37 100644
--- a/router/src/infer/tool_grammar.rs
+++ b/router/src/infer/tool_grammar.rs
@@ -34,6 +34,7 @@ impl ToolGrammar {
                 .chain(std::iter::once(Tool {
                     r#type: "function".to_string(),
                     function: FunctionDefinition {
+                        id: None,
                         name: "no_tool".to_string(),
                         description: Some(
                             "Open ended response with no specific tool selected".to_string(),
diff --git a/router/src/lib.rs b/router/src/lib.rs
index e8b8f663240..77b159c3a6d 100644
--- a/router/src/lib.rs
+++ b/router/src/lib.rs
@@ -18,6 +18,7 @@ use crate::infer::{Infer, InferError};
 use pyo3::prelude::*;
 use pyo3::types::IntoPyDict;
 use serde::{Deserialize, Serialize};
+use std::collections::HashMap;
 use tokenizers::Encoding;
 use tracing::warn;
 use utoipa::ToSchema;
@@ -912,7 +913,10 @@ pub(crate) struct ChatRequest {
 }
 
 impl ChatRequest {
-    fn try_into_generate(self, infer: &Infer) -> Result<(GenerateRequest, bool), InferError> {
+    fn try_into_generate(
+        self,
+        infer: &Infer,
+    ) -> Result<(GenerateRequest, Option<HashMap<String, String>>), InferError> {
         let ChatRequest {
             model,
             max_tokens,
@@ -952,7 +956,7 @@
         let (inputs, grammar, using_tools) = match response_format {
             Some(format) => {
                 let inputs = infer.apply_chat_template(messages, None)?;
-                (inputs, Some(format), false)
+                (inputs, Some(format), None)
             }
             None => {
                 if let Some(tools) = tools {
@@ -961,20 +965,31 @@
                             let grammar = GrammarType::Json(serde_json::json!(tool_schema));
                             let inputs: String = infer.apply_chat_template(
                                 messages,
-                                Some((updated_tools, tool_prompt)),
+                                Some((updated_tools.clone(), tool_prompt)),
                             )?;
-                            (inputs, Some(grammar), true)
+                            let tool_name_to_id: HashMap<String, String> = updated_tools
+                                .into_iter()
+                                .map(|tool| {
+                                    (
+                                        tool.function.name,
+                                        tool.function
+                                            .id
+                                            .map_or_else(|| "0".to_string(), |id| id.to_string()),
+                                    )
+                                })
+                                .collect();
+                            (inputs, Some(grammar), Some(tool_name_to_id))
                         }
                         None => {
                             // same as if no response_format or tools are set
                             let inputs = infer.apply_chat_template(messages, None)?;
-                            (inputs, None, false)
+                            (inputs, None, None)
                         }
                     }
                 } else {
                     // if no response_format or tools are set simply apply the chat template to generate inputs
                     let inputs = infer.apply_chat_template(messages, None)?;
-                    (inputs, None, false)
+                    (inputs, None, None)
                 }
             }
         };
@@ -1154,6 +1169,8 @@ pub struct FunctionDefinition {
     #[serde(default)]
     pub description: Option<String>,
     pub name: String,
+    #[serde(default, skip_serializing_if = "Option::is_none")]
+    pub id: Option<String>,
     #[serde(alias = "parameters", serialize_with = "serialize_as_string")]
     pub arguments: serde_json::Value,
 }
@@ -1175,7 +1192,7 @@ pub(crate) struct Tool {
     pub function: FunctionDefinition,
 }
 
-#[derive(Clone, Serialize, Deserialize, Default)]
+#[derive(Debug, Clone, Serialize, Deserialize, Default)]
 pub(crate) struct ChatTemplateInputs<'a> {
     messages: Vec<TextMessage>,
     bos_token: Option<&'a str>,
@@ -1208,6 +1225,9 @@ pub enum MessageChunk {
 pub struct Message {
     #[schema(example = "user")]
     pub role: String,
+    #[serde(default, skip_serializing_if = "Option::is_none")]
+    #[schema(example = "10")]
+    pub tool_call_id: Option<String>,
     #[serde(flatten)]
     #[schema(example = "My name is David and I")]
     pub body: MessageBody,
@@ -1287,7 +1307,7 @@ impl From<Message> for TextMessage {
                     .collect::<Vec<String>>()
                     .join(""),
             },
-            ..Default::default()
+            tool_call_id: value.tool_call_id,
         }
     }
 }
@@ -1624,6 +1644,7 @@ mod tests {
                 body: MessageBody::Content {
                     content: MessageContent::SingleText("What is Deep Learning?".to_string())
                 },
+                tool_call_id: None,
             }
         );
     }
@@ -1683,6 +1704,7 @@ mod tests {
                         MessageChunk::ImageUrl { image_url: Url { url: "https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/transformers/rabbit.png".to_string() }},
                     ]),
                 },
+                tool_call_id: None,
             }
         );
     }
@@ -1697,7 +1719,8 @@ mod tests {
                     MessageChunk::Text { text: "Whats in this image?".to_string() },
                     MessageChunk::ImageUrl { image_url: Url { url: "https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/transformers/rabbit.png".to_string() } }
                 ]),
-            }
+            },
+            tool_call_id: None
         };
         let textmsg: TextMessage = message.into();
         assert_eq!(textmsg.content, "Whats in this image?![](https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/transformers/rabbit.png)");
@@ -1758,6 +1781,7 @@ mod tests {
             id: "0".to_string(),
             r#type: "function".to_string(),
             function: FunctionDefinition {
+                id: None,
                 description: None,
                 name: "myfn".to_string(),
                 arguments: json!({
diff --git a/router/src/server.rs b/router/src/server.rs
index 45d2b9f3c42..e04450b6b97 100644
--- a/router/src/server.rs
+++ b/router/src/server.rs
@@ -1165,8 +1165,7 @@ pub(crate) async fn chat_completions(
 
     tracing::debug!("Got chat_template {:?}", infer.chat_template);
     let id = chat.next_tool_call_id();
-    let (generate_request, using_tools): (GenerateRequest, bool) =
-        chat.clone().try_into_generate(&infer)?;
+    let (generate_request, using_tools) = chat.clone().try_into_generate(&infer)?;
     span.record("parameters", format!("{:?}", generate_request.parameters));
     let logprobs = logprobs.unwrap_or_default();
 
@@ -1188,7 +1187,7 @@
 
     let response_stream = async_stream::stream! {
         let mut response_stream = Box::pin(response_stream);
-        let mut state = ChatState::new(using_tools, stream_options.clone(), system_fingerprint.clone(), model_id.clone(), logprobs, id.clone());
+        let mut state = ChatState::new(using_tools.is_some(), stream_options.clone(), system_fingerprint.clone(), model_id.clone(), logprobs, id.clone());
         while let Some(result) = response_stream.next().await {
             match result{
                 Ok(stream_token) => {
@@ -1197,12 +1196,12 @@
                         ChatEvent::NoTool => {
                             chat.tools = None;
                             chat.response_format = None;
-                            let (generate_request, using_tools): (GenerateRequest, bool) =
+                            let (generate_request, using_tools) =
                                 chat.clone().try_into_generate(&infer).unwrap();
-                            assert!(!using_tools);
+                            assert!(using_tools.is_none());
 
                             let (_headers, response_stream2) = generate_stream_internal(infer.clone(), compute_type.clone(), Json(generate_request), span.clone()).await;
-                            state = ChatState::new(using_tools, stream_options.clone(), system_fingerprint.clone(), model_id.clone(), logprobs, id.clone());
+                            state = ChatState::new(using_tools.is_some(), stream_options.clone(), system_fingerprint.clone(), model_id.clone(), logprobs, id.clone());
                             response_stream = Box::pin(response_stream2);
                         }
                         ChatEvent::Events(events) => {
@@ -1237,14 +1236,13 @@
             .unwrap_or_else(|_| std::time::Duration::from_secs(0))
             .as_secs();
 
-        let (tool_calls, output) = if using_tools {
+        let (tool_calls, output) = if using_tools.is_some() {
             match crate::chat::parse_output(&generation.generated_text)? {
                 ChatChoice::NoTool => {
                     chat.tools = None;
                     chat.response_format = None;
-                    let (generate_request, using_tools): (GenerateRequest, bool) =
-                        chat.clone().try_into_generate(&infer)?;
-                    assert!(!using_tools);
+                    let (generate_request, using_tools) = chat.clone().try_into_generate(&infer)?;
+                    assert!(using_tools.is_none());
                     let (headers_final, input_length_final, Json(generation)) = generate_internal(
                         Extension(infer),
                         compute_type,
@@ -1256,7 +1254,16 @@
                     input_length = input_length_final;
                     (None, Some(generation.generated_text))
                 }
-                ChatChoice::ToolCalls(tool_calls) => (Some(tool_calls), None),
+                ChatChoice::ToolCalls(mut tool_calls) => {
+                    // assign the tool ids based on the tool names
+                    tool_calls.iter_mut().for_each(|tool_call| {
+                        tool_call.id = using_tools
+                            .as_ref()
+                            .and_then(|tools| tools.get(&tool_call.function.name))
+                            .map_or("0".to_string(), |id| id.clone());
+                    });
+                    (Some(tool_calls), None)
+                }
             }
         } else {
             (None, Some(generation.generated_text))
diff --git a/router/src/vertex.rs b/router/src/vertex.rs
index 38695532cc9..c0978d31439 100644
--- a/router/src/vertex.rs
+++ b/router/src/vertex.rs
@@ -104,8 +104,7 @@
             },
         },
         VertexInstance::Chat(instance) => {
-            let (generate_request, _using_tools): (GenerateRequest, bool) =
-                instance.try_into_generate(&infer)?;
+            let (generate_request, _using_tools) = instance.try_into_generate(&infer)?;
             generate_request
         }
     };
@@ -176,6 +175,7 @@ mod tests {
                     "What's Deep Learning?".to_string()
                 )
            },
+            tool_call_id: None,
        },],
        max_tokens: Some(128),
        top_p: Some(0.95),
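Taken together, these changes replace the old `using_tools: bool` with an optional name-to-id map that travels from request parsing (`try_into_generate`) to response assembly (`chat_completions`). Below is a minimal, self-contained sketch of that flow; the trimmed-down structs are illustrative stand-ins, not the router's actual types:

```rust
use std::collections::HashMap;

// Simplified stand-ins for the router's types (illustration only; the real
// structs live in router/src/lib.rs and carry more fields).
struct FunctionDefinition {
    name: String,
    id: Option<String>,
}

struct ToolCall {
    id: String,
    function: FunctionDefinition,
}

fn main() {
    // Request-side tools; one carries an explicit id, one does not.
    let updated_tools = vec![
        FunctionDefinition { name: "get_current_weather".to_string(), id: Some("10".to_string()) },
        FunctionDefinition { name: "no_tool".to_string(), id: None },
    ];

    // Mirrors try_into_generate: map each tool name to its id, defaulting to "0".
    let tool_name_to_id: HashMap<String, String> = updated_tools
        .into_iter()
        .map(|tool| {
            (
                tool.name,
                tool.id.map_or_else(|| "0".to_string(), |id| id.to_string()),
            )
        })
        .collect();

    // Mirrors chat_completions: stamp each parsed tool call with the mapped id,
    // falling back to "0" for names that are not in the map.
    let mut tool_calls = vec![ToolCall {
        id: String::new(),
        function: FunctionDefinition { name: "get_current_weather".to_string(), id: None },
    }];
    for tool_call in tool_calls.iter_mut() {
        tool_call.id = tool_name_to_id
            .get(&tool_call.function.name)
            .map_or("0".to_string(), |id| id.clone());
    }

    assert_eq!(tool_calls[0].id, "10");
    println!("assigned tool call id: {}", tool_calls[0].id);
}
```

Carrying `Option<HashMap<String, String>>` preserves the old boolean's meaning (`Some(..)` means tools are in play, so call sites only change from `using_tools` to `using_tools.is_some()` or `is_none()`) while also transporting the lookup data needed to stamp the ids.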
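On the wire, the two new nullable fields would surface roughly as in this hypothetical request (field names follow the OpenAPI schema changes above; the model name and all values are invented for illustration):

```rust
use serde_json::json;

fn main() {
    // Hypothetical chat completion payload; values are made up.
    let request = json!({
        "model": "tgi",
        "messages": [
            { "role": "user", "content": "What is the weather like in Brooklyn, New York?" },
            {
                "role": "assistant",
                "tool_calls": [{
                    "id": "10",
                    "type": "function",
                    "function": {
                        "name": "get_current_weather",
                        "arguments": "{\"location\": \"Brooklyn, New York\", \"format\": \"fahrenheit\"}"
                    }
                }]
            },
            // New on Message: a tool-result message can point back at the call it answers.
            { "role": "tool", "tool_call_id": "10", "content": "72 and sunny" }
        ],
        "tools": [{
            "type": "function",
            "function": {
                // New on FunctionDefinition: clients may tag a tool with an id so the
                // returned tool_call.id matches it instead of defaulting to "0".
                "id": "10",
                "name": "get_current_weather",
                "description": "Get the current weather",
                "parameters": { "type": "object", "properties": {} }
            }
        }]
    });
    println!("{}", serde_json::to_string_pretty(&request).unwrap());
}
```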