diff --git a/crates/mofa-foundation/src/agent/tools/builtin.rs b/crates/mofa-foundation/src/agent/tools/builtin.rs
index 028cc1796..076697341 100644
--- a/crates/mofa-foundation/src/agent/tools/builtin.rs
+++ b/crates/mofa-foundation/src/agent/tools/builtin.rs
@@ -22,7 +22,7 @@
 //! let mut registry = SimpleToolRegistry::new();
 //! registry.register(as_tool(DateTimeTool)).unwrap();
 //! registry.register(as_tool(FileReadTool)).unwrap();
-//! registry.register(as_tool(HttpTool)).unwrap();
+//! registry.register(as_tool(HttpTool::new())).unwrap();
 //! ```
 
 use async_trait::async_trait;
@@ -48,8 +48,29 @@
 use crate::agent::components::tool::{SimpleTool, ToolCategory};
 /// | `method`  | string | no | `"GET"` (default) or `"POST"` etc. |
 /// | `body`    | string | no | Request body for POST/PUT/PATCH |
 /// | `headers` | object | no | Additional headers as key→value pairs |
-#[derive(Debug)]
-pub struct HttpTool;
+/// `HttpTool` holds a shared [`reqwest::Client`] so that the underlying
+/// connection pool and TLS session cache are reused across all invocations.
+/// Construct once and register with the tool registry; do not create a new
+/// instance per request.
+#[derive(Debug, Clone)]
+pub struct HttpTool {
+    client: reqwest::Client,
+}
+
+impl HttpTool {
+    /// Create a new `HttpTool` with a freshly-constructed, shared `reqwest::Client`.
+    pub fn new() -> Self {
+        Self {
+            client: reqwest::Client::new(),
+        }
+    }
+}
+
+impl Default for HttpTool {
+    fn default() -> Self {
+        Self::new()
+    }
+}
 
 #[async_trait]
 impl SimpleTool for HttpTool {
@@ -96,14 +117,13 @@
         };
 
         let method = input.get_str("method").unwrap_or("GET").to_uppercase();
-        let client = reqwest::Client::new();
 
         let mut builder = match method.as_str() {
-            "GET" => client.get(&url),
-            "POST" => client.post(&url),
-            "PUT" => client.put(&url),
-            "DELETE" => client.delete(&url),
-            "PATCH" => client.patch(&url),
+            "GET" => self.client.get(&url),
+            "POST" => self.client.post(&url),
+            "PUT" => self.client.put(&url),
+            "DELETE" => self.client.delete(&url),
+            "PATCH" => self.client.patch(&url),
             other => return ToolResult::failure(format!("unsupported HTTP method: {other}")),
         };
@@ -901,7 +921,10 @@ mod tests {
             .execute(ToolInput::from_json(json!({"path": path})))
             .await;
         assert!(read.success);
-        content = read.output["content"].as_str().unwrap_or_default().to_string();
+        content = read.output["content"]
+            .as_str()
+            .unwrap_or_default()
+            .to_string();
         if content.contains("line2") {
             ok = true;
             break;
diff --git a/crates/mofa-runtime/src/agent/execution.rs b/crates/mofa-runtime/src/agent/execution.rs
index d63ac894e..6330f1491 100644
--- a/crates/mofa-runtime/src/agent/execution.rs
+++ b/crates/mofa-runtime/src/agent/execution.rs
@@ -16,7 +16,7 @@
 use serde::{Deserialize, Serialize};
 use std::collections::HashMap;
 use std::sync::Arc;
 use std::time::Duration;
-use tokio::sync::RwLock;
+use tokio::sync::{RwLock, Semaphore};
 use tokio::time::timeout;
 use tracing::Instrument;
@@ -271,6 +271,9 @@ impl ExecutionResult {
 ///     info!("Output: {:?}", result.output);
 /// }
 /// ```
+/// Default concurrency limit for parallel execution.
+const DEFAULT_MAX_CONCURRENT_TASKS: usize = 10;
+
 #[derive(Clone)]
 pub struct ExecutionEngine {
     /// Agent 注册中心
     registry: Arc<AgentRegistry>,
@@ -281,6 +284,8 @@
     plugin_executor: PluginExecutor,
     /// Fallback strategy invoked after all retries are exhausted.
     fallback: Arc<dyn FallbackStrategy>,
+    /// Maximum number of concurrent tasks for parallel execution.
+    max_concurrent_tasks: usize,
 }
 
 impl ExecutionEngine {
@@ -291,6 +296,7 @@
             registry,
             plugin_executor: PluginExecutor::new(Arc::new(SimplePluginRegistry::new())),
             fallback: Arc::new(NoFallback),
+            max_concurrent_tasks: DEFAULT_MAX_CONCURRENT_TASKS,
         }
     }
 
@@ -300,6 +306,15 @@
         self
     }
 
+    /// Set the maximum number of concurrent tasks for parallel execution.
+    ///
+    /// This limit is enforced by `execute_parallel` via a `tokio::sync::Semaphore`.
+    /// The default is `10`, matching `RuntimeConfig::max_concurrent_tasks`.
+    pub fn with_max_concurrent_tasks(mut self, max: usize) -> Self {
+        self.max_concurrent_tasks = if max == 0 { 1 } else { max };
+        self
+    }
+
     /// 创建带有自定义插件注册中心的执行引擎
     /// Create execution engine with custom plugin registry
    pub fn with_plugin_registry(
@@ -310,6 +325,7 @@
             registry,
             plugin_executor: PluginExecutor::new(plugin_registry),
             fallback: Arc::new(NoFallback),
+            max_concurrent_tasks: DEFAULT_MAX_CONCURRENT_TASKS,
         }
     }
 
@@ -566,20 +582,31 @@
     /// 并行执行多个 Agent
     /// Execute multiple agents in parallel
+    ///
+    /// Concurrency is bounded by `max_concurrent_tasks` (default: 10).
+    /// Use [`with_max_concurrent_tasks`] to override.
     pub async fn execute_parallel(
         &self,
         executions: Vec<(String, AgentInput)>,
         options: ExecutionOptions,
     ) -> Vec<AgentResult<ExecutionResult>> {
+        let semaphore = Arc::new(Semaphore::new(self.max_concurrent_tasks));
         let mut handles = Vec::new();
 
         for (agent_id, input) in executions {
             let engine = self.clone();
             let opts = options.clone();
+            let sem = semaphore.clone();
             let span = tracing::info_span!("agent.parallel", agent_id = %agent_id);
+            let permit = sem.acquire_owned().await.expect("concurrency semaphore closed");
             let handle = tokio::spawn(
-                async move { engine.execute(&agent_id, input, opts).await }.instrument(span),
+                async move {
+                    let result = engine.execute(&agent_id, input, opts).await;
+                    drop(permit);
+                    result
+                }
+                .instrument(span),
             );
 
             handles.push(handle);
@@ -797,6 +824,7 @@ mod tests {
     use crate::agent::context::AgentContext;
     use crate::agent::core::MoFAAgent;
     use crate::agent::types::AgentState;
+    use std::sync::atomic::{AtomicUsize, Ordering};
 
     // 测试用 Agent (内联实现,不依赖 BaseAgent)
     // Agent for testing (inline implementation, no BaseAgent dependency)
@@ -927,4 +955,136 @@
             options_from_serde.retry_delay_ms
         );
     }
+
+    // ========================================================================
+    // Bounded concurrency tests
+    // ========================================================================
+
+    /// Agent that tracks the peak number of concurrent executions.
+    struct ConcurrencyTrackingAgent {
+        id: String,
+        capabilities: AgentCapabilities,
+        state: AgentState,
+        active: Arc<AtomicUsize>,
+        peak: Arc<AtomicUsize>,
+    }
+
+    impl ConcurrencyTrackingAgent {
+        fn new(id: &str, active: Arc<AtomicUsize>, peak: Arc<AtomicUsize>) -> Self {
+            Self {
+                id: id.to_string(),
+                capabilities: AgentCapabilities::default(),
+                state: AgentState::Created,
+                active,
+                peak,
+            }
+        }
+    }
+
+    #[async_trait::async_trait]
+    impl MoFAAgent for ConcurrencyTrackingAgent {
+        fn id(&self) -> &str {
+            &self.id
+        }
+
+        fn name(&self) -> &str {
+            &self.id
+        }
+
+        fn capabilities(&self) -> &AgentCapabilities {
+            &self.capabilities
+        }
+
+        fn state(&self) -> AgentState {
+            self.state.clone()
+        }
+
+        async fn initialize(&mut self, _ctx: &AgentContext) -> AgentResult<()> {
+            self.state = AgentState::Ready;
+            Ok(())
+        }
+
+        async fn execute(
+            &mut self,
+            _input: AgentInput,
+            _ctx: &AgentContext,
+        ) -> AgentResult<AgentOutput> {
+            // Increment active count and update peak.
+            let prev = self.active.fetch_add(1, Ordering::SeqCst);
+            let current = prev + 1;
+            self.peak.fetch_max(current, Ordering::SeqCst);
+
+            // Hold the slot long enough for other tasks to pile up.
+            tokio::time::sleep(tokio::time::Duration::from_millis(50)).await;
+
+            self.active.fetch_sub(1, Ordering::SeqCst);
+            Ok(AgentOutput::text("done"))
+        }
+
+        async fn shutdown(&mut self) -> AgentResult<()> {
+            self.state = AgentState::Shutdown;
+            Ok(())
+        }
+    }
+
+    #[tokio::test]
+    async fn test_execute_parallel_respects_concurrency_limit() {
+        let active = Arc::new(AtomicUsize::new(0));
+        let peak = Arc::new(AtomicUsize::new(0));
+
+        let registry = Arc::new(AgentRegistry::new());
+        let agent = Arc::new(RwLock::new(ConcurrencyTrackingAgent::new(
+            "conc-agent",
+            active.clone(),
+            peak.clone(),
+        )));
+        registry.register(agent).await.unwrap();
+
+        let concurrency_limit = 3;
+        let engine = ExecutionEngine::new(registry).with_max_concurrent_tasks(concurrency_limit);
+
+        // Spawn more tasks than the concurrency limit.
+        let inputs: Vec<(String, AgentInput)> = (0..12)
+            .map(|i| ("conc-agent".to_string(), AgentInput::text(format!("t{i}"))))
+            .collect();
+
+        let results = engine
+            .execute_parallel(inputs, ExecutionOptions::default())
+            .await;
+
+        assert_eq!(results.len(), 12);
+        for r in &results {
+            assert!(r.is_ok(), "expected Ok, got {:?}", r);
+        }
+
+        let observed_peak = peak.load(Ordering::SeqCst);
+        assert!(
+            observed_peak <= concurrency_limit,
+            "peak concurrency {} exceeded limit {}",
+            observed_peak,
+            concurrency_limit,
+        );
+    }
+
+    #[test]
+    fn test_execute_parallel_default_concurrency() {
+        let registry = Arc::new(AgentRegistry::new());
+        let engine = ExecutionEngine::new(registry);
+        assert_eq!(engine.max_concurrent_tasks, DEFAULT_MAX_CONCURRENT_TASKS);
+    }
+
+    #[test]
+    fn test_with_max_concurrent_tasks_sets_limit() {
+        let registry = Arc::new(AgentRegistry::new());
+        let engine = ExecutionEngine::new(registry).with_max_concurrent_tasks(5);
+        assert_eq!(engine.max_concurrent_tasks, 5);
+    }
+
+    #[test]
+    fn test_with_max_concurrent_tasks_rejects_zero() {
+        let registry = Arc::new(AgentRegistry::new());
+        let engine = ExecutionEngine::new(registry).with_max_concurrent_tasks(0);
+        // Zero should be clamped to 1, not allowed to create a dead semaphore.
+        assert_eq!(engine.max_concurrent_tasks, 1);
+    }
 }