Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
43 changes: 33 additions & 10 deletions crates/mofa-foundation/src/agent/tools/builtin.rs
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@
//! let mut registry = SimpleToolRegistry::new();
//! registry.register(as_tool(DateTimeTool)).unwrap();
//! registry.register(as_tool(FileReadTool)).unwrap();
//! registry.register(as_tool(HttpTool)).unwrap();
//! registry.register(as_tool(HttpTool::new())).unwrap();
//! ```

use async_trait::async_trait;
Expand All @@ -48,8 +48,29 @@ use crate::agent::components::tool::{SimpleTool, ToolCategory};
/// | `method` | string | no | `"GET"` (default) or `"POST"` etc. |
/// | `body` | string | no | Request body for POST/PUT/PATCH |
/// | `headers` | object | no | Additional headers as key→value pairs |
#[derive(Debug)]
pub struct HttpTool;
/// `HttpTool` holds a shared [`reqwest::Client`] so that the underlying
/// connection pool and TLS session cache are reused across all invocations.
/// Construct once and register with the tool registry; do not create a new
/// instance per request.
#[derive(Debug, Clone)]
pub struct HttpTool {
    // Shared client handle; cloning `HttpTool` clones the handle, not the
    // pool — all clones keep reusing the same connections.
    client: reqwest::Client,
}

impl HttpTool {
    /// Create a new `HttpTool` backed by its own shared `reqwest::Client`.
    ///
    /// The client — and therefore its connection pool — lives as long as
    /// the tool, so repeated invocations reuse connections.
    pub fn new() -> Self {
        let client = reqwest::Client::new();
        Self { client }
    }
}

/// Delegates to [`HttpTool::new`]; see its docs for client-sharing notes.
impl Default for HttpTool {
    fn default() -> Self {
        Self::new()
    }
}

#[async_trait]
impl SimpleTool for HttpTool {
Expand Down Expand Up @@ -96,14 +117,13 @@ impl SimpleTool for HttpTool {
};

let method = input.get_str("method").unwrap_or("GET").to_uppercase();
let client = reqwest::Client::new();

let mut builder = match method.as_str() {
"GET" => client.get(&url),
"POST" => client.post(&url),
"PUT" => client.put(&url),
"DELETE" => client.delete(&url),
"PATCH" => client.patch(&url),
"GET" => self.client.get(&url),
"POST" => self.client.post(&url),
"PUT" => self.client.put(&url),
"DELETE" => self.client.delete(&url),
"PATCH" => self.client.patch(&url),
other => return ToolResult::failure(format!("unsupported HTTP method: {other}")),
};

Expand Down Expand Up @@ -901,7 +921,10 @@ mod tests {
.execute(ToolInput::from_json(json!({"path": path})))
.await;
assert!(read.success);
content = read.output["content"].as_str().unwrap_or_default().to_string();
content = read.output["content"]
.as_str()
.unwrap_or_default()
.to_string();
if content.contains("line2") {
ok = true;
break;
Expand Down
164 changes: 162 additions & 2 deletions crates/mofa-runtime/src/agent/execution.rs
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@ use serde::{Deserialize, Serialize};
use std::collections::HashMap;
use std::sync::Arc;
use std::time::Duration;
use tokio::sync::RwLock;
use tokio::sync::{RwLock, Semaphore};
use tokio::time::timeout;
use tracing::Instrument;

Expand Down Expand Up @@ -271,6 +271,9 @@ impl ExecutionResult {
/// info!("Output: {:?}", result.output);
/// }
/// ```
/// Default concurrency limit for parallel execution.
const DEFAULT_MAX_CONCURRENT_TASKS: usize = 10;

#[derive(Clone)]
pub struct ExecutionEngine {
/// Agent 注册中心
Expand All @@ -281,6 +284,8 @@ pub struct ExecutionEngine {
plugin_executor: PluginExecutor,
/// Fallback strategy invoked after all retries are exhausted.
fallback: Arc<dyn FallbackStrategy>,
/// Maximum number of concurrent tasks for parallel execution.
max_concurrent_tasks: usize,
}

impl ExecutionEngine {
Expand All @@ -291,6 +296,7 @@ impl ExecutionEngine {
registry,
plugin_executor: PluginExecutor::new(Arc::new(SimplePluginRegistry::new())),
fallback: Arc::new(NoFallback),
max_concurrent_tasks: DEFAULT_MAX_CONCURRENT_TASKS,
}
}

Expand All @@ -300,6 +306,15 @@ impl ExecutionEngine {
self
}

/// Set the maximum number of concurrent tasks for parallel execution.
///
/// This limit is enforced by `execute_parallel` via a `tokio::sync::Semaphore`.
/// The default is `10`, matching `RuntimeConfig::max_concurrent_tasks`.
/// A value of `0` is clamped to `1` so the semaphore always has at least
/// one permit and parallel execution can never deadlock.
pub fn with_max_concurrent_tasks(mut self, max: usize) -> Self {
    self.max_concurrent_tasks = max.max(1);
    self
}

/// 创建带有自定义插件注册中心的执行引擎
/// Create execution engine with custom plugin registry
pub fn with_plugin_registry(
Expand All @@ -310,6 +325,7 @@ impl ExecutionEngine {
registry,
plugin_executor: PluginExecutor::new(plugin_registry),
fallback: Arc::new(NoFallback),
max_concurrent_tasks: DEFAULT_MAX_CONCURRENT_TASKS,
}
}

Expand Down Expand Up @@ -566,20 +582,31 @@ impl ExecutionEngine {

/// 并行执行多个 Agent
/// Execute multiple agents in parallel
///
/// Concurrency is bounded by `max_concurrent_tasks` (default: 10).
/// Use [`with_max_concurrent_tasks`] to override.
pub async fn execute_parallel(
&self,
executions: Vec<(String, AgentInput)>,
options: ExecutionOptions,
) -> Vec<AgentResult<ExecutionResult>> {
let semaphore = Arc::new(Semaphore::new(self.max_concurrent_tasks));
let mut handles = Vec::new();

for (agent_id, input) in executions {
let engine = self.clone();
let opts = options.clone();
let sem = semaphore.clone();

let span = tracing::info_span!("agent.parallel", agent_id = %agent_id);
let permit = sem.acquire_owned().await.expect("concurrency semaphore closed");
let handle = tokio::spawn(
async move { engine.execute(&agent_id, input, opts).await }.instrument(span),
async move {
let result = engine.execute(&agent_id, input, opts).await;
drop(permit);
result
}
.instrument(span),
);

handles.push(handle);
Expand Down Expand Up @@ -797,6 +824,7 @@ mod tests {
use crate::agent::context::AgentContext;
use crate::agent::core::MoFAAgent;
use crate::agent::types::AgentState;
use std::sync::atomic::{AtomicUsize, Ordering};

// 测试用 Agent (内联实现,不依赖 BaseAgent)
// Agent for testing (inline implementation, no BaseAgent dependency)
Expand Down Expand Up @@ -927,4 +955,136 @@ mod tests {
options_from_serde.retry_delay_ms
);
}

// ========================================================================
// Bounded concurrency tests
// ========================================================================

/// Agent that tracks the peak number of concurrent executions.
struct ConcurrencyTrackingAgent {
    // Identifier returned by both `id()` and `name()`.
    id: String,
    // Default capabilities; not consulted by the test logic.
    capabilities: AgentCapabilities,
    // Lifecycle state toggled by `initialize`/`shutdown`.
    state: AgentState,
    // Count of `execute` calls currently in flight (shared with the test).
    active: Arc<AtomicUsize>,
    // High-water mark of `active` (shared with the test).
    peak: Arc<AtomicUsize>,
}

impl ConcurrencyTrackingAgent {
fn new(id: &str, active: Arc<AtomicUsize>, peak: Arc<AtomicUsize>) -> Self {
Self {
id: id.to_string(),
capabilities: AgentCapabilities::default(),
state: AgentState::Created,
active,
peak,
}
}
}

#[async_trait::async_trait]
impl MoFAAgent for ConcurrencyTrackingAgent {
    fn id(&self) -> &str {
        &self.id
    }

    // The id doubles as the display name for this test agent.
    fn name(&self) -> &str {
        &self.id
    }

    fn capabilities(&self) -> &AgentCapabilities {
        &self.capabilities
    }

    fn state(&self) -> AgentState {
        self.state.clone()
    }

    async fn initialize(&mut self, _ctx: &AgentContext) -> AgentResult<()> {
        self.state = AgentState::Ready;
        Ok(())
    }

    /// Records how many `execute` calls overlap in time, then sleeps so
    /// other tasks have a chance to pile up behind the semaphore.
    ///
    /// NOTE(review): `execute` takes `&mut self`; if the engine holds a
    /// write lock on the registered agent for the duration of the call,
    /// executions on the same instance serialize and the observed peak will
    /// be 1 regardless of the semaphore limit — confirm against the
    /// engine's locking strategy.
    async fn execute(
        &mut self,
        _input: AgentInput,
        _ctx: &AgentContext,
    ) -> AgentResult<AgentOutput> {
        // Increment active count and update peak.
        let prev = self.active.fetch_add(1, Ordering::SeqCst);
        let current = prev + 1;
        // fetch_max atomically records the high-water mark of `active`.
        self.peak.fetch_max(current, Ordering::SeqCst);

        // Hold the slot long enough for other tasks to pile up.
        tokio::time::sleep(tokio::time::Duration::from_millis(50)).await;

        self.active.fetch_sub(1, Ordering::SeqCst);
        Ok(AgentOutput::text("done"))
    }

    async fn shutdown(&mut self) -> AgentResult<()> {
        self.state = AgentState::Shutdown;
        Ok(())
    }
}

#[tokio::test]
async fn test_execute_parallel_respects_concurrency_limit() {
    // Counters shared with the tracking agent: `in_flight` counts live
    // executions, `high_water` remembers the largest value it reached.
    let in_flight = Arc::new(AtomicUsize::new(0));
    let high_water = Arc::new(AtomicUsize::new(0));

    let registry = Arc::new(AgentRegistry::new());
    let tracker =
        ConcurrencyTrackingAgent::new("conc-agent", in_flight.clone(), high_water.clone());
    registry
        .register(Arc::new(RwLock::new(tracker)))
        .await
        .unwrap();

    let concurrency_limit = 3;
    let engine = ExecutionEngine::new(registry).with_max_concurrent_tasks(concurrency_limit);

    // Queue four times as many tasks as the limit allows at once.
    let inputs: Vec<(String, AgentInput)> = (0..12)
        .map(|i| ("conc-agent".to_string(), AgentInput::text(format!("t{i}"))))
        .collect();

    let results = engine
        .execute_parallel(inputs, ExecutionOptions::default())
        .await;

    assert_eq!(results.len(), 12);
    for r in &results {
        assert!(r.is_ok(), "expected Ok, got {:?}", r);
    }

    // The agent-side high-water mark must never exceed the semaphore limit.
    let observed_peak = high_water.load(Ordering::SeqCst);
    assert!(
        observed_peak <= concurrency_limit,
        "peak concurrency {} exceeded limit {}",
        observed_peak,
        concurrency_limit,
    );
}

#[test]
fn test_execute_parallel_default_concurrency() {
    // A freshly built engine must pick up the module-level default limit.
    let engine = ExecutionEngine::new(Arc::new(AgentRegistry::new()));
    assert_eq!(engine.max_concurrent_tasks, DEFAULT_MAX_CONCURRENT_TASKS);
}

#[test]
fn test_with_max_concurrent_tasks_sets_limit() {
    // The builder stores any positive limit verbatim.
    let engine =
        ExecutionEngine::new(Arc::new(AgentRegistry::new())).with_max_concurrent_tasks(5);
    assert_eq!(engine.max_concurrent_tasks, 5);
}

#[test]
fn test_with_max_concurrent_tasks_rejects_zero() {
    let engine =
        ExecutionEngine::new(Arc::new(AgentRegistry::new())).with_max_concurrent_tasks(0);
    // Zero should be clamped to 1, not allowed to create a dead semaphore.
    assert_eq!(engine.max_concurrent_tasks, 1);
}
}
Loading