Add reasoning content for deepseek R1 API and tackle problems of tool_calls #324
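The diff adds a `Reasoning` variant to the streamed output but does not show an end-to-end consumer, so here is a minimal sketch of one. Hedges: `Client::from_env`, the `"deepseek-reasoner"` model id, and `CompletionRequestBuilder::new(model, prompt).build()` are assumed from rig's existing conventions, not taken from this diff.

```rust
use futures::StreamExt;
use rig::completion::CompletionRequestBuilder;
use rig::providers::deepseek;
use rig::streaming::{StreamingChoice, StreamingCompletionModel};

#[tokio::main]
async fn main() -> Result<(), Box<dyn std::error::Error>> {
    // `from_env` and the model id are assumptions based on rig's other
    // providers; adjust to however the client is constructed in your code.
    let client = deepseek::Client::from_env();
    let model = client.completion_model("deepseek-reasoner");

    let request =
        CompletionRequestBuilder::new(model.clone(), "Why is the sky blue?").build();

    let mut stream = model.stream(request).await?;
    while let Some(chunk) = stream.next().await {
        match chunk? {
            // R1 reasoning tokens; this PR brackets them with "<think>" /
            // "</think>" markers exactly once per response.
            StreamingChoice::Reasoning(r) => print!("{r}"),
            // Ordinary assistant text.
            StreamingChoice::Message(m) => print!("{m}"),
            // Tool calls arrive as (function name, call id, parsed JSON args).
            StreamingChoice::ToolCall(name, id, args) => {
                println!("\ntool call {name} ({id}): {args}");
            }
        }
    }
    Ok(())
}
```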

Open · wants to merge 3 commits into base: main
19 changes: 15 additions & 4 deletions Cargo.lock

Some generated files are not rendered by default.

228 changes: 212 additions & 16 deletions rig-core/src/providers/deepseek.rs
@@ -8,15 +8,16 @@
//!
//! let deepseek_chat = client.completion_model(deepseek::DEEPSEEK_CHAT);
//! ```
use std::{future::Future, time::Duration};

use crate::{
completion::{self, CompletionError, CompletionModel, CompletionRequest},
extractor::ExtractorBuilder,
json_utils, message, OneOrMany,
completion::{self, CompletionError, CompletionModel, CompletionRequest, CompletionRequestBuilder},
extractor::ExtractorBuilder,
json_utils, message,
streaming::{StreamingChoice, StreamingCompletion, StreamingCompletionModel, StreamingResult},
OneOrMany,
};
use reqwest::Client as HttpClient;
use schemars::JsonSchema;
use serde::{Deserialize, Serialize};
use serde_json::json;
use futures::StreamExt;

// ================================================================
// Main DeepSeek Client
@@ -123,7 +124,7 @@ pub struct Choice {
pub finish_reason: String,
}

#[derive(Debug, Serialize, Deserialize, PartialEq, Clone)]
#[derive(Clone, Debug, Serialize, Deserialize, PartialEq)]
#[serde(tag = "role", rename_all = "lowercase")]
pub enum Message {
System {
@@ -139,8 +140,10 @@ pub enum Message {
Assistant {
content: String,
#[serde(skip_serializing_if = "Option::is_none")]
reasoning_content: Option<String>,
#[serde(skip_serializing_if = "Option::is_none")]
name: Option<String>,
#[serde(default, deserialize_with = "json_utils::null_or_vec")]
#[serde(default, deserialize_with = "json_utils::null_or_vec", skip_serializing_if = "Vec::is_empty")]
tool_calls: Vec<ToolCall>,
},
#[serde(rename = "Tool")]
@@ -244,6 +247,7 @@ impl TryFrom<message::Message> for Vec<Message> {
if !tool_calls.is_empty() {
messages.push(Message::Assistant {
content: "".to_string(),
reasoning_content: None,
name: None,
tool_calls,
});
@@ -255,6 +259,7 @@ impl TryFrom<message::Message> for Vec<Message> {
.filter_map(|content| match content {
message::AssistantContent::Text(text) => Some(Message::Assistant {
content: text.text,
reasoning_content: None,
name: None,
tool_calls: vec![],
}),
@@ -318,16 +323,22 @@ impl TryFrom<CompletionResponse> for completion::CompletionResponse<CompletionRe
let content = match &choice.message {
Message::Assistant {
content,
reasoning_content,
tool_calls,
..
} => {
let mut content = if content.trim().is_empty() {
vec![]
} else {
vec![completion::AssistantContent::text(content)]
};

content.extend(
let mut content_parts = vec![];
if !content.trim().is_empty() {
if let Some(reasoning) = reasoning_content {
content_parts.push(completion::AssistantContent::text(
format!("<think>\n{}\n</think>\n{}", reasoning, content)
));
} else {
content_parts.push(completion::AssistantContent::text(content));
}
}

content_parts.extend(
tool_calls
.iter()
.map(|call| {
@@ -339,7 +350,7 @@ impl TryFrom<CompletionResponse> for completion::CompletionResponse<CompletionRe
})
.collect::<Vec<_>>(),
);
Ok(content)
Ok(content_parts)
}
_ => Err(CompletionError::ResponseError(
"Response did not contain a valid message or tool call".into(),
@@ -426,13 +437,13 @@ impl CompletionModel for DeepSeekCompletionModel {
request
},
)
.timeout(Duration::from_secs(120))
.send()
.await?;

if response.status().is_success() {
let t = response.text().await?;
tracing::debug!(target: "rig", "DeepSeek completion response: {}", t);

match serde_json::from_str::<ApiResponse<CompletionResponse>>(&t)? {
ApiResponse::Ok(response) => response.try_into(),
ApiResponse::Err(err) => Err(CompletionError::ProviderError(err.message)),
@@ -443,6 +454,190 @@ impl CompletionModel for DeepSeekCompletionModel {
}
}

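// Minimal deserialization targets for DeepSeek's OpenAI-compatible SSE
// chunks; only the fields read by the streaming loop below are modeled.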
#[derive(Debug, Deserialize)]
struct DeepSeekStreamChunk {
choices: Vec<StreamChoice>,
}

#[derive(Debug, Deserialize)]
struct StreamChoice {
delta: StreamDelta,
}

#[derive(Debug, Deserialize)]
struct StreamDelta {
content: Option<String>,
reasoning_content: Option<String>,
#[serde(rename = "role")]
_role: Option<String>,
#[serde(default)]
tool_calls: Vec<ToolCall>,
}
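The loop below parses OpenAI-style SSE frames. An abridged, illustrative exchange (field values made up; real chunks carry more metadata):

```text
data: {"choices":[{"delta":{"role":"assistant","reasoning_content":"Let me think."}}]}

data: {"choices":[{"delta":{"content":"The answer is 42."}}]}

data: [DONE]
```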

impl StreamingCompletion<DeepSeekCompletionModel> for DeepSeekCompletionModel {
fn stream_completion(
&self,
prompt: &str,
_chat_history: Vec<crate::message::Message>,
) -> impl Future<Output = Result<CompletionRequestBuilder<DeepSeekCompletionModel>, CompletionError>> {
let this = self.clone();
async move {
Ok(CompletionRequestBuilder::new(this, prompt))
}
}
}

impl StreamingCompletionModel for DeepSeekCompletionModel {
async fn stream(
&self,
completion_request: CompletionRequest,
) -> Result<StreamingResult, CompletionError> {
// Prepend the preamble to the chat history, if one was provided
let mut full_history: Vec<Message> = match &completion_request.preamble {
Some(preamble) => vec![Message::system(preamble)],
None => vec![],
};

// Convert the prompt (with context) into user messages
let prompt: Vec<Message> = completion_request.prompt_with_context().try_into()?;

// Convert the existing chat history
let chat_history: Vec<Message> = completion_request
.chat_history
.into_iter()
.map(|message| message.try_into())
.collect::<Result<Vec<Vec<Message>>, _>>()?
.into_iter()
.flatten()
.collect();

// Merge everything into a single message history
full_history.extend(chat_history);
full_history.extend(prompt);

let request = if completion_request.tools.is_empty() {
json!({
"model": self.model,
"messages": full_history,
"temperature": completion_request.temperature,
"stream": true
})
} else {
json!({
"model": self.model,
"messages": full_history,
"temperature": completion_request.temperature,
"tools": completion_request.tools.into_iter().map(ToolDefinition::from).collect::<Vec<_>>(),
"tool_choice": "auto",
"stream": true
})
};

let response = self
.client
.post("/chat/completions")
.json(
&if let Some(params) = completion_request.additional_params {
json_utils::merge(request, params)
} else {
request
},
)
.send()
.await?;

if !response.status().is_success() {
return Err(CompletionError::ProviderError(response.text().await?));
}

Ok(Box::pin(async_stream::stream! {
let mut stream = response.bytes_stream();
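// Emit the "<think>" / "</think>" markers exactly once around the
// streamed reasoning tokens; these two flags track where we are.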
let mut reasoning_started = false;
let mut reasoning_ended = false;

while let Some(chunk_result) = stream.next().await {
let chunk = match chunk_result {
Ok(c) => c,
Err(e) => {
yield Err(CompletionError::from(e));
break;
}
};

let text = match String::from_utf8(chunk.to_vec()) {
Ok(t) => t,
Err(e) => {
yield Err(CompletionError::ResponseError(e.to_string()));
break;
}
};
for line in text.lines() {
if let Some(data) = line.strip_prefix("data: ") {
if data.trim() == "[DONE]" {
break;
}

match serde_json::from_str::<DeepSeekStreamChunk>(data) {
Ok(chunk) => {
for choice in chunk.choices {
if let Some(reasoning) = &choice.delta.reasoning_content {
if !reasoning.is_empty() {
if !reasoning_started {
yield Ok(StreamingChoice::Reasoning("<think>\n".to_string()));
reasoning_started = true;
}
yield Ok(StreamingChoice::Reasoning(reasoning.clone()));
}
}

if let Some(content) = &choice.delta.content {
if reasoning_started && !reasoning_ended && !content.is_empty() {
yield Ok(StreamingChoice::Reasoning("\n</think>\n".to_string()));
reasoning_ended = true;
}

if !content.is_empty() {
yield Ok(StreamingChoice::Message(content.clone()));
}
}

if !choice.delta.tool_calls.is_empty() {
if reasoning_started && !reasoning_ended {
yield Ok(StreamingChoice::Message("\n</think>\n".to_string()));
reasoning_ended = true;
}

let tool_call = &choice.delta.tool_calls[0];

match serde_json::from_str(&tool_call.function.arguments.to_string()) {
Ok(json_value) => {
yield Ok(StreamingChoice::ToolCall(
tool_call.function.name.clone(),
tool_call.id.clone(),
json_value,
));
}
Err(e) => {
yield Err(CompletionError::from(e));
}
}
}
}
}
Err(e) => {
if !data.trim().is_empty() && data != "[DONE]" {
yield Err(CompletionError::ResponseError(format!("Failed to parse stream chunk: {}", e)));
}
}
}
} else if line.starts_with(": keep-alive") {
continue;
}
}
}
}))
}
}

// ================================================================
// DeepSeek Completion API
// ================================================================
@@ -510,7 +705,7 @@ mod tests {
"index": 0,
"message": {
"role": "assistant",
"content": "Why dont skeletons fight each other? \nBecause they dont have the guts! 😄"
"content": "Why don't skeletons fight each other? \nBecause they don't have the guts! 😄"
},
"logprobs": null,
"finish_reason": "stop"
@@ -536,7 +731,7 @@
Ok(response) => match &response.choices.first().unwrap().message {
Message::Assistant { content, .. } => assert_eq!(
content,
"Why dont skeletons fight each other? \nBecause they dont have the guts! 😄"
"Why don't skeletons fight each other? \nBecause they don't have the guts! 😄"
),
_ => panic!("Expected assistant message"),
},
@@ -579,6 +774,7 @@ mod tests {
logprobs: None,
message: Message::Assistant {
content: "".to_string(),
reasoning_content: None,
name: None,
tool_calls: vec![ToolCall {
id: "call_0_2b4a85ee-b04a-40ad-a16b-a405caf6e65b".to_string(),
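On the non-streaming path, `reasoning_content` is folded into the returned text with the same markers (see the `format!` call in the `TryFrom<CompletionResponse>` impl above), so a choice carrying, for example, reasoning_content "Recall Rayleigh scattering." and content "The sky is blue." surfaces as one text part:

```text
<think>
Recall Rayleigh scattering.
</think>
The sky is blue.
```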