Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
31 changes: 25 additions & 6 deletions crates/mofa-foundation/src/agent/components/context_compressor.rs
Original file line number Diff line number Diff line change
Expand Up @@ -84,17 +84,36 @@ impl ContextCompressor for SlidingWindowCompressor {
}

// Split system messages from the rest.
let (system_msgs, mut conversation): (Vec<_>, Vec<_>) =
let (system_msgs, conversation): (Vec<_>, Vec<_>) =
messages.into_iter().partition(|m| m.role == "system");

// Keep only the most-recent window_size messages.
if conversation.len() > self.window_size {
let keep_from = conversation.len() - self.window_size;
conversation = conversation.split_off(keep_from);
// First apply window size limit to non-system messages
let mut limited_conversation = conversation;
if limited_conversation.len() > self.window_size {
let keep_from = limited_conversation.len() - self.window_size;
limited_conversation = limited_conversation.split_off(keep_from);
}

// Further refine to fit into max_tokens budget (always keep system messages)
let system_tokens = self.count_tokens(&system_msgs);
let mut budget = max_tokens.saturating_sub(system_tokens);

let mut final_conversation = Vec::new();
// Work backwards to keep the most recent messages first
for msg in limited_conversation.into_iter().rev() {
let tokens = self.count_tokens(std::slice::from_ref(&msg));
if tokens <= budget {
budget = budget.saturating_sub(tokens);
final_conversation.push(msg);
} else {
// If this message doesn't fit, we stop keeping older ones
break;
}
}
final_conversation.reverse();

let mut result = system_msgs;
result.extend(conversation);
result.extend(final_conversation);
Ok(result)
}

Expand Down
289 changes: 273 additions & 16 deletions crates/mofa-foundation/src/react/core.rs
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@ use crate::llm::{LLMAgent, LLMError, LLMResult, Tool};
use serde::{Deserialize, Serialize};
use std::collections::HashMap;
use std::sync::Arc;
use std::time::Duration;
use tokio::sync::RwLock;
use tracing::Instrument;

Expand Down Expand Up @@ -261,6 +262,16 @@ pub struct ReActConfig {
/// 每步最大 token 数
/// Max tokens per step
pub max_tokens_per_step: Option<u32>,
/// Per-tool-call timeout. Default: 30 seconds.
pub tool_timeout: Duration,
/// Max context tokens. If the conversation exceeds this, it will be compressed.
pub max_context_tokens: Option<usize>,
/// Number of recent steps to keep when compressing.
pub keep_recent_steps: usize,
/// Maximum number of retries for tool execution.
pub tool_max_retries: usize,
/// Whether to enable self-correction/reflection.
pub enable_reflection: bool,
}

impl Default for ReActConfig {
Expand All @@ -272,6 +283,11 @@ impl Default for ReActConfig {
system_prompt: None,
verbose: true,
max_tokens_per_step: Some(2048),
tool_timeout: Duration::from_secs(30),
max_context_tokens: Some(4096),
keep_recent_steps: 4,
tool_max_retries: 2,
enable_reflection: true,
}
}
}
Expand Down Expand Up @@ -352,6 +368,67 @@ impl ReActAgent {
tools.values().cloned().collect()
}

/// Count tokens with a simple heuristic: roughly 4 characters per token,
/// always at least 1 so even an empty string has a non-zero cost.
fn count_tokens(&self, text: &str) -> usize {
    // Coarse approximation; good enough for budget bookkeeping,
    // not for exact model-side token accounting.
    const CHARS_PER_TOKEN: usize = 4;
    text.len() / CHARS_PER_TOKEN + 1
}

/// Compress the conversation history in place so its estimated token count
/// fits within `config.max_context_tokens` (no-op when the limit is unset
/// or the conversation is already within budget).
///
/// Strategy: always keep the first message (the task description), then fill
/// the remaining budget with the most recent messages, scanning backwards
/// from the end. When older context is dropped, a placeholder line is
/// inserted so the model knows history was elided.
async fn compress_conversation(&self, conversation: &mut Vec<String>) {
    let max_tokens = match self.config.max_context_tokens {
        Some(m) => m,
        None => return,
    };

    // Guard: an empty conversation has nothing to anchor on. The previous
    // version indexed `conversation[0]` unconditionally, which panics here.
    let task = match conversation.first() {
        Some(t) => t.clone(),
        None => return,
    };

    let current_tokens: usize = conversation.iter().map(|s| self.count_tokens(s)).sum();
    if current_tokens <= max_tokens {
        return;
    }

    if self.config.verbose {
        tracing::warn!(
            "Conversation too long ({} tokens > {}), compressing...",
            current_tokens,
            max_tokens
        );
    }

    let task_tokens = self.count_tokens(&task);
    // Reserve headroom for the placeholder line and a small safety margin.
    let mut budget = max_tokens.saturating_sub(task_tokens + 20);

    let mut recent_messages = Vec::new();
    let mut dropped_any = false;

    // Walk backwards so the most recent context is retained first.
    for msg in conversation[1..].iter().rev() {
        let tokens = self.count_tokens(msg);
        if tokens <= budget {
            recent_messages.push(msg.clone());
            budget = budget.saturating_sub(tokens);
        } else {
            // First message that does not fit ends the scan; everything
            // older than it is dropped as well.
            dropped_any = true;
            break;
        }
    }
    recent_messages.reverse();

    let mut new_conversation = vec![task];
    if dropped_any {
        new_conversation.push("... [Older context compressed to save space] ...".to_string());
    }
    new_conversation.extend(recent_messages);
    *conversation = new_conversation;
}

/// 执行任务
/// Execute task
pub async fn run(&self, task: impl Into<String>) -> LLMResult<ReActResult> {
Expand All @@ -373,6 +450,10 @@ impl ReActAgent {
for iteration in 0..self.config.max_iterations {
step_number += 1;

// 压缩对话历史 (如果需要)
// Compress conversation history (if needed)
self.compress_conversation(&mut conversation).await;

// 获取 LLM 响应
// Get LLM response
let prompt = self.build_prompt(&system_prompt, &conversation).await;
Expand Down Expand Up @@ -451,6 +532,109 @@ impl ReActAgent {
))
}

/// Execute a task and stream each ReAct step to `step_tx` in real time.
///
/// Mirrors `run`, but pushes every step (thought / action / observation /
/// final answer) through the channel as soon as it is produced, so callers
/// can render progress live. Send failures are deliberately ignored
/// (`let _ =`): a dropped receiver must not abort the task, and the full
/// `ReActResult` (including all steps) is still returned at the end.
///
/// Terminates on a final answer, a parse error, or after
/// `config.max_iterations` LLM round-trips (reported as a failure).
pub async fn run_streaming(
    &self,
    task: impl Into<String>,
    step_tx: tokio::sync::mpsc::Sender<ReActStep>,
) -> LLMResult<ReActResult> {
    let task = task.into();
    let task_id = uuid::Uuid::now_v7().to_string();
    let start_time = std::time::Instant::now();

    let mut steps = Vec::new();
    let mut step_number = 0;

    let system_prompt = self.build_system_prompt().await;
    let mut conversation = vec![format!("Task: {}", task)];

    for iteration in 0..self.config.max_iterations {
        step_number += 1;

        // Compress conversation history (if needed) so the prompt stays
        // within the configured context budget; no-op when under the limit.
        self.compress_conversation(&mut conversation).await;

        let prompt = self.build_prompt(&system_prompt, &conversation).await;
        let response = self.llm.ask(&prompt).await?;
        let parsed = self.parse_response(&response);

        match parsed {
            ParsedResponse::Thought(thought) => {
                let step = ReActStep::thought(&thought, step_number);
                // Stream the step immediately as it is produced.
                let _ = step_tx.send(step.clone()).await;
                steps.push(step);
                conversation.push(format!("Thought: {}", thought));

                if self.config.verbose {
                    tracing::info!("Thought: {}", thought);
                }
            }
            ParsedResponse::Action { tool, input } => {
                let step = ReActStep::action(&tool, &input, step_number);
                let _ = step_tx.send(step.clone()).await;
                steps.push(step);
                conversation.push(format!("Action: {}[{}]", tool, input));

                if self.config.verbose {
                    tracing::info!("Action: {}[{}]", tool, input);
                }

                // Execute the tool; the observation gets its own step number.
                step_number += 1;
                let observation = self.execute_tool(&tool, &input).await;
                let obs_step = ReActStep::observation(&observation, step_number);
                let _ = step_tx.send(obs_step.clone()).await;
                steps.push(obs_step);
                conversation.push(format!("Observation: {}", observation));

                if self.config.verbose {
                    tracing::info!("Observation: {}", observation);
                }
            }
            ParsedResponse::FinalAnswer(answer) => {
                let step = ReActStep::final_answer(&answer, step_number);
                let _ = step_tx.send(step.clone()).await;
                steps.push(step);

                if self.config.verbose {
                    tracing::info!("Final Answer: {}", answer);
                }

                return Ok(ReActResult::success(
                    task_id,
                    &task,
                    answer,
                    steps,
                    iteration + 1,
                    start_time.elapsed().as_millis() as u64,
                ));
            }
            ParsedResponse::Error(err) => {
                // Unparseable/errored response: surface it as a failed
                // result rather than an Err, preserving collected steps.
                return Ok(ReActResult::failed(
                    task_id,
                    &task,
                    err,
                    steps,
                    iteration + 1,
                    start_time.elapsed().as_millis() as u64,
                ));
            }
        }
    }

    // Loop exhausted without a final answer.
    Ok(ReActResult::failed(
        task_id,
        &task,
        format!("Max iterations ({}) exceeded", self.config.max_iterations),
        steps,
        self.config.max_iterations,
        start_time.elapsed().as_millis() as u64,
    ))
}

/// 构建系统提示词
/// Build system prompt
async fn build_system_prompt(&self) -> String {
Expand Down Expand Up @@ -486,15 +670,40 @@ Rules:
- Use tools when you need external information
- Be concise and focused
- Provide a Final Answer when you have enough information
- If a tool returns an error, think about alternatives"#,
- If a tool returns an error, think about alternatives
- If you find yourself repeating the same action, reflect on why and try a different approach"#,
tool_descriptions.join("\n")
)
}

/// Detect whether the two most recent "Action:" lines in the conversation
/// are identical, which indicates the agent is stuck repeating itself.
fn detect_repeated_action(&self, conversation: &[String]) -> bool {
    // Scan from the end: only the last two actions matter, so there is no
    // need to materialize every action line.
    let mut actions = conversation
        .iter()
        .rev()
        .filter(|line| line.starts_with("Action:"));

    match (actions.next(), actions.next()) {
        (Some(latest), Some(previous)) => latest == previous,
        // Fewer than two actions recorded: nothing can be repeated yet.
        _ => false,
    }
}

/// 构建完整提示词
/// Build complete prompt
async fn build_prompt(&self, system_prompt: &str, conversation: &[String]) -> String {
format!("{}\n\n{}", system_prompt, conversation.join("\n"))
let mut final_conversation = conversation.to_vec();

// 注入反思警告 (如果检测到重复)
// Inject reflection warning (if repeated action detected)
if self.config.enable_reflection && self.detect_repeated_action(conversation) {
final_conversation.push(
"Observation: [Warning: You are repeating the same action. Please reflect on why it might be failing and try a different approach.]"
.to_string(),
);
}

format!("{}\n\n{}", system_prompt, final_conversation.join("\n"))
}

/// 解析 LLM 响应
Expand Down Expand Up @@ -566,25 +775,73 @@ Rules:
ParsedResponse::Thought(response.to_string())
}

/// 执行工具
/// Execute tool
/// 执行工具(带超时保护和重试机制)
/// Execute tool (with timeout protection and retry mechanism)
async fn execute_tool(&self, tool_name: &str, input: &str) -> String {
let span = tracing::info_span!("react.tool_call", tool = %tool_name);
async {
let timeout_dur = self.config.tool_timeout;
let max_retries = self.config.tool_max_retries;

// Clone the Arc while holding the read lock
let maybe_tool: Option<Arc<dyn ReActTool>> = {
let tools = self.tools.read().await;
tools.get(tool_name).cloned()
};

let mut last_error = String::new();

for attempt in 0..=max_retries {
if attempt > 0 {
if self.config.verbose {
tracing::warn!(
"Retrying tool '{}' (attempt {}/{})",
tool_name,
attempt,
max_retries
);
}
// 指数退避 (Exponential backoff)
let delay = Duration::from_millis(500 * (2u64.pow(attempt as u32 - 1)));
tokio::time::sleep(delay).await;
}

match tools.get(tool_name) {
Some(tool) => match tool.execute(input).await {
Ok(result) => result,
Err(e) => format!("Tool error: {}", e),
},
None => format!(
"Tool '{}' not found. Available tools: {:?}",
tool_name,
tools.keys().collect::<Vec<_>>()
),
let result = async {
match maybe_tool.as_ref() {
Some(tool) => {
match tokio::time::timeout(timeout_dur, tool.execute(input)).await {
Ok(Ok(result)) => Ok(result),
Ok(Err(e)) => Err(format!("Tool error: {}", e)),
Err(_) => {
Err(format!("Tool '{}' timed out after {:?}", tool_name, timeout_dur))
}
}
}
None => {
let tools = self.tools.read().await;
Err(format!(
"Tool '{}' not found. Available tools: {:?}",
tool_name,
tools.keys().collect::<Vec<_>>()
))
}
}
}
}.instrument(span).await
.instrument(span.clone())
.await;

match result {
Ok(res) => return res,
Err(e) => {
last_error = e;
// 如果工具未找到,直接终止重试
if last_error.contains("not found") {
break;
}
}
}
}

last_error
}
}

Expand Down
Loading