diff --git a/src/actions/callToolAction.ts b/src/actions/callToolAction.ts index 09a8cdc..5ee01ea 100644 --- a/src/actions/callToolAction.ts +++ b/src/actions/callToolAction.ts @@ -1,20 +1,22 @@ -import { - type Action, - type HandlerCallback, - type IAgentRuntime, - type Memory, - type State, - logger, +import type { + Action, + HandlerCallback, + IAgentRuntime, + Memory, + State, } from "@elizaos/core"; -import type { McpService } from "../service"; -import { MCP_SERVICE_NAME } from "../types"; -import { handleMcpError } from "../utils/error"; -import { handleToolResponse, processToolResult } from "../utils/processing"; -import { createToolSelectionArgument, createToolSelectionName } from "../utils/selection"; -import { handleNoToolAvailable } from "../utils/handler"; +import { mcpLogger } from "@/utils/mcp-logger"; +import { handleMcpError } from "@/utils/error"; +import { handleToolResponse, processToolResult } from "@/utils/processing"; +import { createToolSelectionArgument, createToolSelectionName } from "@/utils/selection"; +import { useActionHandler } from "@/utils/use-action"; +import { validateAction } from "@/utils/validation"; +import { handleNoToolAvailable } from "@/utils/handlers"; + +const ACTION_NAME = "CALL_TOOL"; export const callToolAction: Action = { - name: "CALL_TOOL", + name: ACTION_NAME, similes: [ "CALL_MCP_TOOL", "USE_TOOL", @@ -28,101 +30,61 @@ export const callToolAction: Action = { ], description: "Calls a tool from an MCP server to perform a specific task", - validate: async (runtime: IAgentRuntime, _message: Memory, _state?: State): Promise => { - const mcpService = runtime.getService(MCP_SERVICE_NAME); - if (!mcpService) return false; - - const servers = mcpService.getServers(); - return ( - servers.length > 0 && - servers.some( - (server) => server.status === "connected" && server.tools && server.tools.length > 0 - ) - ); + validate: async (runtime: IAgentRuntime, message: Memory, state?: State): Promise => { + return await validateAction(ACTION_NAME, runtime, message, state); }, handler: async ( runtime: IAgentRuntime, message: Memory, - _state?: State, - _options?: { [key: string]: unknown }, + state?: State, + options?: { [key: string]: unknown }, callback?: HandlerCallback ): Promise => { - const composedState = await runtime.composeState(message, ["RECENT_MESSAGES", "MCP"]); - const mcpService = runtime.getService(MCP_SERVICE_NAME); - if (!mcpService) { - throw new Error("MCP service not available"); - } - const mcpProvider = mcpService.getProviderData(); + const context = await useActionHandler({ actionName: ACTION_NAME, runtime, message, state, options, callback }); try { // Select the tool with this servername and toolname - const toolSelectionName = await createToolSelectionName({ - runtime, - state: composedState, - message, - callback, - mcpProvider, - }); + const toolSelectionName = await createToolSelectionName({...context}); if (!toolSelectionName || toolSelectionName.noToolAvailable) { - logger.warn("[NO_TOOL_AVAILABLE] No appropriate tool available for the request"); + mcpLogger.warn("[NO_TOOL_AVAILABLE] No appropriate tool available for the request"); return handleNoToolAvailable(callback, toolSelectionName); } const { serverName, toolName, reasoning } = toolSelectionName; - logger.info( - `[CALLING] Calling tool "${serverName}/${toolName}" on server with reasoning: "${reasoning}"` - ); + mcpLogger.info(`[CALLING] Calling tool "${serverName}/${toolName}" on server with reasoning: "${reasoning}"`); - // Create the tool selection "argument" based on the selected tool name - const toolSelectionArgument = await createToolSelectionArgument({ - runtime, - state: composedState, - message, - callback, - mcpProvider, - toolSelectionName, - }); + const toolSelectionArgument = await createToolSelectionArgument({ ...context, toolSelectionName }); if (!toolSelectionArgument) { - logger.warn( - "[NO_TOOL_SELECTION_ARGUMENT] No appropriate tool selection argument available" - ); + mcpLogger.warn("[NO_TOOL_SELECTION_ARGUMENT] No appropriate tool selection argument available"); return handleNoToolAvailable(callback, toolSelectionName); } - logger.info( - `[SELECTED] Tool Selection result:\n${JSON.stringify(toolSelectionArgument, null, 2)}` - ); + mcpLogger.info(`[SELECTED] Tool Selection result:\n${JSON.stringify(toolSelectionArgument, null, 2)}`); - const result = await mcpService.callTool( - serverName, - toolName, - toolSelectionArgument.toolArguments - ); + const result = await context.mcpService.callTool(serverName, toolName, toolSelectionArgument.toolArguments); + mcpLogger.info(`[CALLED] Tool "${serverName}/${toolName}" result:\n"${JSON.stringify(result, null, 2)}"`); - const { toolOutput, hasAttachments, attachments } = processToolResult( + const { toolOutput, hasAttachments, attachments } = processToolResult({ + ...context, result, serverName, toolName, - runtime, - message.entityId - ); + messageEntityId: context.message.entityId, + }); - await handleToolResponse( - runtime, - message, + mcpLogger.info('[HANDLE] Handling tool response...'); + await handleToolResponse({ + ...context, serverName, toolName, - toolSelectionArgument.toolArguments, + toolArguments: toolSelectionArgument.toolArguments, toolOutput, hasAttachments, attachments, - composedState, - mcpProvider, - callback - ); + }); return true; } catch (error) { - return handleMcpError(composedState, mcpProvider, error, runtime, message, "tool", callback); + return await handleMcpError({ ...context, type: 'tool', error }); } }, diff --git a/src/actions/readResourceAction.ts b/src/actions/readResourceAction.ts index e367a5c..5688a79 100644 --- a/src/actions/readResourceAction.ts +++ b/src/actions/readResourceAction.ts @@ -1,67 +1,26 @@ -import { - type Action, - type HandlerCallback, - type IAgentRuntime, - type Memory, - ModelType, - type State, - composePromptFromState, - logger, +import type { + Action, + HandlerCallback, + IAgentRuntime, + Memory, + State, } from "@elizaos/core"; -import type { McpService } from "../service"; -import { resourceSelectionTemplate } from "../templates/resourceSelectionTemplate"; -import { MCP_SERVICE_NAME } from "../types"; -import { handleMcpError } from "../utils/error"; +import { mcpLogger } from "@/utils/mcp-logger"; +import { handleMcpError } from "@/utils/error"; import { - handleResourceAnalysis, processResourceResult, sendInitialResponse, -} from "../utils/processing"; -import { - createResourceSelectionFeedbackPrompt, - validateResourceSelection, -} from "../utils/validation"; -import type { ResourceSelection } from "../utils/validation"; -import { withModelRetry } from "../utils/wrapper"; +} from "@/utils/processing"; +import { useActionHandler } from "@/utils/use-action"; +import { createResourceSelection } from "@/utils/selection"; +import { handleNoResourceAvailable, handleResourceAnalysis } from "@/utils/handlers"; +import { validateAction } from "@/utils/validation"; -function createResourceSelectionPrompt(composedState: State, userMessage: string): string { - const mcpData = composedState.values.mcp || {}; - const serverNames = Object.keys(mcpData); - let resourcesDescription = ""; - for (const serverName of serverNames) { - const server = mcpData[serverName]; - if (server.status !== "connected") continue; - - const resourceUris = Object.keys(server.resources || {}); - for (const uri of resourceUris) { - const resource = server.resources[uri]; - resourcesDescription += `Resource: ${uri} (Server: ${serverName})\n`; - resourcesDescription += `Name: ${resource.name || "No name available"}\n`; - resourcesDescription += `Description: ${ - resource.description || "No description available" - }\n`; - resourcesDescription += `MIME Type: ${resource.mimeType || "Not specified"}\n\n`; - } - } - - const enhancedState: State = { - ...composedState, - values: { - ...composedState.values, - resourcesDescription, - userMessage, - }, - }; - - return composePromptFromState({ - state: enhancedState, - template: resourceSelectionTemplate, - }); -} +const ACTION_NAME = 'READ_RESOURCE'; export const readResourceAction: Action = { - name: "READ_RESOURCE", + name: ACTION_NAME, similes: [ "READ_MCP_RESOURCE", "GET_RESOURCE", @@ -73,107 +32,45 @@ export const readResourceAction: Action = { ], description: "Reads a resource from an MCP server", - validate: async (runtime: IAgentRuntime, _message: Memory, _state?: State): Promise => { - const mcpService = runtime.getService(MCP_SERVICE_NAME); - if (!mcpService) return false; - - const servers = mcpService.getServers(); - return ( - servers.length > 0 && - servers.some( - (server) => server.status === "connected" && server.resources && server.resources.length > 0 - ) - ); + validate: async (runtime: IAgentRuntime, message: Memory, state?: State): Promise => { + return await validateAction(ACTION_NAME, runtime, message, state); }, handler: async ( runtime: IAgentRuntime, message: Memory, - _state?: State, - _options?: { [key: string]: unknown }, + state?: State, + options?: { [key: string]: unknown }, callback?: HandlerCallback ): Promise => { - const composedState = await runtime.composeState(message, ["RECENT_MESSAGES", "MCP"]); - - const mcpService = runtime.getService(MCP_SERVICE_NAME); - if (!mcpService) { - throw new Error("MCP service not available"); - } - - const mcpProvider = mcpService.getProviderData(); + const context = await useActionHandler({ actionName: ACTION_NAME, runtime, message, state, options, callback }); try { + mcpLogger.info('[INITIAL_RESPONSE] Sending initial response...'); await sendInitialResponse(callback); - const resourceSelectionPrompt = createResourceSelectionPrompt( - composedState, - message.content.text || "" - ); - - const resourceSelection = await runtime.useModel(ModelType.TEXT_SMALL, { - prompt: resourceSelectionPrompt, - }); + const resourceSelection = await createResourceSelection({ ...context }); + mcpLogger.info(`[SELECTED] Resource Selection response:\n${JSON.stringify(resourceSelection, null, 2)}`); - const parsedSelection = await withModelRetry({ - runtime, - state: composedState, - message, - callback, - input: resourceSelection, - validationFn: (data) => validateResourceSelection(data), - createFeedbackPromptFn: (originalResponse, errorMessage, state, userMessage) => - createResourceSelectionFeedbackPrompt( - originalResponse as string, - errorMessage, - state, - userMessage - ), - failureMsg: `I'm having trouble finding the resource you're looking for. Could you provide more details about what you need?`, - retryCount: 0, - }); - - if (!parsedSelection || parsedSelection.noResourceAvailable) { - if (callback && parsedSelection?.noResourceAvailable) { - await callback({ - text: "I don't have a specific resource that contains the information you're looking for. Let me try to assist you directly instead.", - thought: - "No appropriate MCP resource available for this request. Falling back to direct assistance.", - actions: ["REPLY"], - }); - } - return true; + if (!resourceSelection || resourceSelection.noResourceAvailable) { + mcpLogger.info('[NO_RESOURCE_AVAILABLE] No appropriate resource available for the request'); + return handleNoResourceAvailable(callback); } - const { serverName, uri, reasoning } = parsedSelection; - - logger.debug(`Selected resource "${uri}" on server "${serverName}" because: ${reasoning}`); - - const result = await mcpService.readResource(serverName, uri); - logger.debug(`Read resource ${uri} from server ${serverName}`); + const { serverName, uri, reasoning } = resourceSelection; + mcpLogger.info(`[FETCHING] Fetching resource "${serverName}/${uri}" with reasoning: "${reasoning}"`); + const result = await context.mcpService.readResource(serverName, uri); + mcpLogger.info(`[FETCHED] Resource "${serverName}/${uri}" result: \n"${JSON.stringify(result, null, 2)}"`); + const { resourceContent, resourceMeta } = processResourceResult(result, uri); - await handleResourceAnalysis( - runtime, - message, - uri, - serverName, - resourceContent, - resourceMeta, - callback - ); + mcpLogger.info('[HANDLE] Handling resource response...'); + await handleResourceAnalysis({ ...context, serverName, uri, resourceContent, resourceMeta }); return true; } catch (error) { - return handleMcpError( - composedState, - mcpProvider, - error, - runtime, - message, - "resource", - callback - ); + return await handleMcpError({ ...context, error, type: 'resource' }); } }, diff --git a/src/index.ts b/src/index.ts index 95622d9..9787229 100644 --- a/src/index.ts +++ b/src/index.ts @@ -1,4 +1,5 @@ -import { type IAgentRuntime, type Plugin, logger } from "@elizaos/core"; +import type { IAgentRuntime, Plugin } from "@elizaos/core"; +import { mcpLogger } from "@/utils/mcp-logger"; import { callToolAction } from "./actions/callToolAction"; import { readResourceAction } from "./actions/readResourceAction"; import { provider } from "./provider"; @@ -9,7 +10,7 @@ const mcpPlugin: Plugin = { description: "Plugin for connecting to MCP (Model Context Protocol) servers", init: async (_config: Record, _runtime: IAgentRuntime) => { - logger.info("Initializing MCP plugin..."); + mcpLogger.info("Initializing MCP plugin..."); }, services: [McpService], diff --git a/src/service.ts b/src/service.ts index 6c92fbe..f270af8 100644 --- a/src/service.ts +++ b/src/service.ts @@ -1,4 +1,5 @@ -import { type IAgentRuntime, Service, logger } from "@elizaos/core"; +import { type IAgentRuntime, Service, } from "@elizaos/core"; +import { mcpLogger } from "@/utils/mcp-logger"; import { Client } from "@modelcontextprotocol/sdk/client/index.js"; import { SSEClientTransport } from "@modelcontextprotocol/sdk/client/sse.js"; import { StdioClientTransport } from "@modelcontextprotocol/sdk/client/stdio.js"; @@ -73,14 +74,14 @@ export class McpService extends Service { try { const mcpSettings = this.getMcpSettings(); if (!mcpSettings || !mcpSettings.servers) { - logger.info("No MCP servers configured."); + mcpLogger.info("No MCP servers configured."); return; } await this.updateServerConnections(mcpSettings.servers); const servers = this.getServers(); this.mcpProvider = buildMcpProviderData(servers); } catch (error) { - logger.error( + mcpLogger.error( "Failed to initialize MCP servers:", error instanceof Error ? error.message : String(error) ); @@ -100,7 +101,7 @@ export class McpService extends Service { for (const name of currentNames) { if (!newNames.has(name)) { await this.deleteConnection(name); - logger.info(`Deleted MCP server: ${name}`); + mcpLogger.info(`Deleted MCP server: ${name}`); } } @@ -110,7 +111,7 @@ export class McpService extends Service { try { await this.initializeConnection(name, config); } catch (error) { - logger.error( + mcpLogger.error( `Failed to connect to new MCP server ${name}:`, error instanceof Error ? error.message : String(error) ); @@ -119,9 +120,9 @@ export class McpService extends Service { try { await this.deleteConnection(name); await this.initializeConnection(name, config); - logger.info(`Reconnected MCP server with updated config: ${name}`); + mcpLogger.info(`Reconnected MCP server with updated config: ${name}`); } catch (error) { - logger.error( + mcpLogger.error( `Failed to reconnect MCP server ${name}:`, error instanceof Error ? error.message : String(error) ); @@ -170,7 +171,7 @@ export class McpService extends Service { state.reconnectAttempts = 0; state.consecutivePingFailures = 0; this.startPingMonitoring(name); - logger.info(`Successfully connected to MCP server: ${name}`); + mcpLogger.info(`Successfully connected to MCP server: ${name}`); } catch (error) { state.status = "disconnected"; state.lastError = error instanceof Error ? error : new Error(String(error)); @@ -181,7 +182,7 @@ export class McpService extends Service { private setupTransportHandlers(name: string, connection: McpConnection, state: ConnectionState) { connection.transport.onerror = async (error) => { - logger.error(`Transport error for "${name}":`, error); + mcpLogger.error(`Transport error for "${name}":`, error); connection.server.status = "disconnected"; this.appendErrorMessage(connection, error.message); this.handleDisconnection(name, error); @@ -198,7 +199,7 @@ export class McpService extends Service { if (state.pingInterval) clearInterval(state.pingInterval); state.pingInterval = setInterval(() => { this.sendPing(name).catch((err) => { - logger.warn(`Ping failed for ${name}:`, err instanceof Error ? err.message : String(err)); + mcpLogger.warn(`Ping failed for ${name}:`, err instanceof Error ? err.message : String(err)); this.handlePingFailure(name, err); }); }, this.pingConfig.intervalMs); @@ -224,7 +225,7 @@ export class McpService extends Service { if (!state) return; state.consecutivePingFailures++; if (state.consecutivePingFailures >= this.pingConfig.failuresBeforeDisconnect) { - logger.warn(`Ping failures exceeded for ${name}, disconnecting and attempting reconnect.`); + mcpLogger.warn(`Ping failures exceeded for ${name}, disconnecting and attempting reconnect.`); this.handleDisconnection(name, error); } } @@ -237,19 +238,19 @@ export class McpService extends Service { if (state.pingInterval) clearInterval(state.pingInterval); if (state.reconnectTimeout) clearTimeout(state.reconnectTimeout); if (state.reconnectAttempts >= MAX_RECONNECT_ATTEMPTS) { - logger.error(`Max reconnect attempts reached for ${name}. Giving up.`); + mcpLogger.error(`Max reconnect attempts reached for ${name}. Giving up.`); return; } const delay = INITIAL_RETRY_DELAY * Math.pow(BACKOFF_MULTIPLIER, state.reconnectAttempts); state.reconnectTimeout = setTimeout(async () => { state.reconnectAttempts++; - logger.info(`Attempting to reconnect to ${name} (attempt ${state.reconnectAttempts})...`); + mcpLogger.info(`Attempting to reconnect to ${name} (attempt ${state.reconnectAttempts})...`); const config = this.connections.get(name)?.server.config; if (config) { try { await this.initializeConnection(name, JSON.parse(config)); } catch (err) { - logger.error( + mcpLogger.error( `Reconnect attempt failed for ${name}:`, err instanceof Error ? err.message : String(err) ); @@ -266,7 +267,7 @@ export class McpService extends Service { await connection.transport.close(); await connection.client.close(); } catch (error) { - logger.error( + mcpLogger.error( `Failed to close transport for ${name}:`, error instanceof Error ? error.message : String(error) ); @@ -309,7 +310,7 @@ export class McpService extends Service { // Add deprecation warning for legacy "sse" type if (config.type === "sse") { - logger.warn( + mcpLogger.warn( `Server "${name}": "sse" transport type is deprecated. Use "streamable-http" or "http" instead for the modern Streamable HTTP transport.` ); } @@ -345,9 +346,9 @@ export class McpService extends Service { // Apply compatibility transformations automatically processedTool.inputSchema = this.applyToolCompatibility(tool.inputSchema); - logger.debug(`Applied tool compatibility for: ${tool.name} on server: ${serverName}`); + mcpLogger.debug(`Applied tool compatibility for: ${tool.name} on server: ${serverName}`); } catch (error) { - logger.warn(`Tool compatibility failed for ${tool.name} on ${serverName}:`, error); + mcpLogger.warn(`Tool compatibility failed for ${tool.name} on ${serverName}:`, error); // Keep original schema if transformation fails } } @@ -355,14 +356,14 @@ export class McpService extends Service { return processedTool; }); - logger.info(`Fetched ${tools.length} tools for ${serverName}`); + mcpLogger.info(`Fetched ${tools.length} tools for ${serverName}`); for (const tool of tools) { - logger.info(`[${serverName}] ${tool.name}: ${tool.description}`); + mcpLogger.info(`[${serverName}] ${tool.name}: ${tool.description}`); } return tools; } catch (error) { - logger.error( + mcpLogger.error( `Failed to fetch tools for ${serverName}:`, error instanceof Error ? error.message : String(error) ); @@ -380,7 +381,7 @@ export class McpService extends Service { const response = await connection.client.listResources(); return response?.resources || []; } catch (error) { - logger.warn( + mcpLogger.warn( `No resources found for ${serverName}:`, error instanceof Error ? error.message : String(error) ); @@ -398,7 +399,7 @@ export class McpService extends Service { const response = await connection.client.listResourceTemplates(); return response?.resourceTemplates || []; } catch (error) { - logger.warn( + mcpLogger.warn( `No resource templates found for ${serverName}:`, error instanceof Error ? error.message : String(error) ); @@ -433,7 +434,7 @@ export class McpService extends Service { const config = JSON.parse(connection.server.config); timeout = config.timeoutInMillis || DEFAULT_MCP_TIMEOUT_SECONDS; } catch (error) { - logger.error( + mcpLogger.error( `Failed to parse timeout configuration for server ${serverName}:`, error instanceof Error ? error.message : String(error) ); @@ -464,15 +465,15 @@ export class McpService extends Service { const connection = this.connections.get(serverName); const config = connection?.server.config; if (config) { - logger.info(`Restarting ${serverName} MCP server...`); + mcpLogger.info(`Restarting ${serverName} MCP server...`); connection.server.status = "connecting"; connection.server.error = ""; try { await this.deleteConnection(serverName); await this.initializeConnection(serverName, JSON.parse(config)); - logger.info(`${serverName} MCP server connected`); + mcpLogger.info(`${serverName} MCP server connected`); } catch (error) { - logger.error( + mcpLogger.error( `Failed to restart connection for ${serverName}:`, error instanceof Error ? error.message : String(error) ); @@ -488,9 +489,9 @@ export class McpService extends Service { this.compatibilityInitialized = true; if (this.toolCompatibility) { - logger.info(`Tool compatibility enabled`); + mcpLogger.info(`Tool compatibility enabled`); } else { - logger.info(`No tool compatibility needed`); + mcpLogger.info(`No tool compatibility needed`); } } @@ -506,7 +507,7 @@ export class McpService extends Service { try { return this.toolCompatibility.transformToolSchema(toolSchema); } catch (error) { - logger.warn(`Tool compatibility transformation failed:`, error); + mcpLogger.warn(`Tool compatibility transformation failed:`, error); return toolSchema; // Fall back to original schema } } diff --git a/src/types.ts b/src/types.ts index e74a1f2..901bdd7 100644 --- a/src/types.ts +++ b/src/types.ts @@ -123,54 +123,6 @@ export interface McpProviderData { [serverName: string]: McpServerInfo; } -export const ToolSelectionSchema = { - type: "object", - required: ["serverName", "toolName", "arguments"], - properties: { - serverName: { - type: "string", - minLength: 1, - errorMessage: "serverName must not be empty", - }, - toolName: { - type: "string", - minLength: 1, - errorMessage: "toolName must not be empty", - }, - arguments: { - type: "object", - }, - reasoning: { - type: "string", - }, - noToolAvailable: { - type: "boolean", - }, - }, -}; - -export const ResourceSelectionSchema = { - type: "object", - required: ["serverName", "uri"], - properties: { - serverName: { - type: "string", - minLength: 1, - errorMessage: "serverName must not be empty", - }, - uri: { - type: "string", - minLength: 1, - errorMessage: "uri must not be empty", - }, - reasoning: { - type: "string", - }, - noResourceAvailable: { - type: "boolean", - }, - }, -}; export const DEFAULT_PING_CONFIG: PingConfig = { enabled: true, diff --git a/src/utils/error.ts b/src/utils/error.ts index 988ca50..54d15f1 100644 --- a/src/utils/error.ts +++ b/src/utils/error.ts @@ -4,24 +4,34 @@ import { type Memory, ModelType, composePromptFromState, - logger, } from "@elizaos/core"; import type { State } from "@elizaos/core"; -import { errorAnalysisPrompt } from "../templates/errorAnalysisPrompt"; -import type { McpProvider } from "../types"; +import { errorAnalysisPrompt } from "@/templates/errorAnalysisPrompt"; +import type { McpProvider } from "@/types"; +import { mcpLogger } from "./mcp-logger"; -export async function handleMcpError( - state: State, - mcpProvider: McpProvider, - error: unknown, - runtime: IAgentRuntime, - message: Memory, - type: "tool" | "resource", - callback?: HandlerCallback -): Promise { +interface HandleMcpErrorOptions { + state: State; + mcpProvider: McpProvider; + error: unknown; + runtime: IAgentRuntime; + message: Memory; + type: 'tool' | 'resource'; + callback?: HandlerCallback; +} + +export async function handleMcpError({ + state, + mcpProvider, + error, + runtime, + message, + type, + callback, +}: HandleMcpErrorOptions): Promise { const errorMessage = error instanceof Error ? error.message : String(error); - logger.error(`Error executing MCP ${type}: ${errorMessage}`, error); + mcpLogger.error(`Error executing MCP ${type}: ${errorMessage}`, error); if (callback) { const enhancedState: State = { @@ -50,7 +60,7 @@ export async function handleMcpError( actions: ["REPLY"], }); } catch (modelError) { - logger.error( + mcpLogger.error( "Failed to generate error response:", modelError instanceof Error ? modelError.message : String(modelError) ); diff --git a/src/utils/handler.ts b/src/utils/handler.ts deleted file mode 100644 index 225f582..0000000 --- a/src/utils/handler.ts +++ /dev/null @@ -1,18 +0,0 @@ -import type { HandlerCallback } from "@elizaos/core"; - -export function handleNoToolAvailable( - callback?: HandlerCallback, - // biome-ignore lint/suspicious/noExplicitAny: - toolSelection?: Record | null -): boolean { - if (callback && toolSelection?.noToolAvailable) { - callback({ - text: "I don't have a specific tool that can help with that request. Let me try to assist you directly instead.", - thought: - "No appropriate MCP tool available for this request. Falling back to direct assistance.", - actions: ["REPLY"], - }); - } - - return true; -} diff --git a/src/utils/handlers.ts b/src/utils/handlers.ts new file mode 100644 index 0000000..abb8c4d --- /dev/null +++ b/src/utils/handlers.ts @@ -0,0 +1,115 @@ +import { resourceAnalysisTemplate } from '@/templates/resourceAnalysisTemplate'; +import { createMcpMemory } from '@/utils/mcp'; +import { mcpLogger } from '@/utils/mcp-logger'; +import { + type HandlerCallback, + type IAgentRuntime, + type Memory, + ModelType, + type State, + composePromptFromState, +} from '@elizaos/core'; + +interface HandleResourceAnalysisParams { + runtime: IAgentRuntime; + message: Memory; + uri: string; + serverName: string; + resourceContent: string; + resourceMeta: string; + callback?: HandlerCallback; +} + +/** + * Handles the analysis of a resource fetched from an MCP server. + * @param runtime - The agent runtime instance. + * @param message - The message object containing the user's request. + * @param uri - The URI of the resource. + * @param serverName - The name of the MCP server. + * @param resourceContent - The content of the resource. + * @param resourceMeta - The metadata of the resource. + * @param callback - Optional callback function to handle the response. + * @returns A promise that resolves when the analysis is complete. + */ +export async function handleResourceAnalysis({ + runtime, + message, + uri, + serverName, + resourceContent, + resourceMeta, + callback, +}: HandleResourceAnalysisParams): Promise { + // Create a memory entry for the resource + mcpLogger.debug(`[HANDLER] Creating memory entry for resource: ${uri} on server: ${serverName}`); + await createMcpMemory(runtime, message, 'resource', serverName, resourceContent, { + uri, + isResourceAccess: true, + }); + + // Generate a thoughtful response based on the resource content + mcpLogger.debug(`[HANDLER] Generating response based on resource content: ${uri} on server: ${serverName}`); + const analysisPrompt = createAnalysisPrompt(uri, message.content.text || '', resourceContent, resourceMeta); + + mcpLogger.debug(`[HANDLER] Analysis prompt created: ${analysisPrompt}`); + const analyzedResponse = await runtime.useModel(ModelType.TEXT_LARGE, { + prompt: analysisPrompt, + }); + + if (callback) { + await callback({ + text: analyzedResponse, + thought: `I analyzed the content from the ${uri} resource on ${serverName} and crafted a thoughtful response that addresses the user's request while maintaining my conversational style.`, + actions: ['READ_MCP_RESOURCE'], + }); + } +} + +function createAnalysisPrompt(uri: string, userMessage: string, resourceContent: string, resourceMeta: string): string { + const enhancedState: State = { + data: {}, + text: '', + values: { + uri, + userMessage, + resourceContent, + resourceMeta, + }, + }; + + return composePromptFromState({ + state: enhancedState, + template: resourceAnalysisTemplate, + }); +} + +export function handleNoResourceAvailable(callback?: HandlerCallback): boolean { + if (callback) { + callback({ + text: "I don't have a specific resource that contains the information you're looking for. Let me try to assist you directly instead.", + thought: 'No appropriate MCP resource available for this request. Falling back to direct assistance.', + actions: ['REPLY'], + }); + } + + return true; +} + + +export function handleNoToolAvailable( + callback?: HandlerCallback, + // biome-ignore lint/suspicious/noExplicitAny: + toolSelection?: Record | null +): boolean { + if (callback && toolSelection?.noToolAvailable) { + callback({ + text: "I don't have a specific tool that can help with that request. Let me try to assist you directly instead.", + thought: + "No appropriate MCP tool available for this request. Falling back to direct assistance.", + actions: ["REPLY"], + }); + } + + return true; +} + diff --git a/src/utils/mcp-logger.ts b/src/utils/mcp-logger.ts new file mode 100644 index 0000000..80ea579 --- /dev/null +++ b/src/utils/mcp-logger.ts @@ -0,0 +1,9 @@ +import { logger } from '@elizaos/core'; + +export const mcpLogger = { + trace: (message: string, ...args: unknown[]) => logger.trace(`[MCP] ${message}`, ...args), + debug: (message: string, ...args: unknown[]) => logger.debug(`[MCP] ${message}`, ...args), + info: (message: string, ...args: unknown[]) => logger.info(`[MCP] ${message}`, ...args), + warn: (message: string, ...args: unknown[]) => logger.warn(`[MCP] ${message}`, ...args), + error: (message: string, ...args: unknown[]) => logger.error(`[MCP] ${message}`, ...args), +}; diff --git a/src/utils/mcp.ts b/src/utils/mcp.ts index 965c2b8..e21a51c 100644 --- a/src/utils/mcp.ts +++ b/src/utils/mcp.ts @@ -5,7 +5,7 @@ import type { McpResourceInfo, McpServer, McpToolInfo, -} from "../types"; +} from "@/types"; export async function createMcpMemory( runtime: IAgentRuntime, diff --git a/src/utils/processing.ts b/src/utils/processing.ts index f82c800..7eaffbe 100644 --- a/src/utils/processing.ts +++ b/src/utils/processing.ts @@ -7,12 +7,13 @@ import { type Memory, ModelType, createUniqueUuid, - logger, } from '@elizaos/core'; +import { mcpLogger } from "./mcp-logger"; import { type State, composePromptFromState } from '@elizaos/core'; -import { resourceAnalysisTemplate } from '../templates/resourceAnalysisTemplate'; -import { toolReasoningTemplate } from '../templates/toolReasoningTemplate'; +import { toolReasoningTemplate } from '@/templates/toolReasoningTemplate'; import { createMcpMemory } from './mcp'; +import type { CallToolResult } from '@modelcontextprotocol/sdk/types.js'; +import type { McpProvider } from '@/types'; function getMimeTypeToContentType(mimeType?: string): ContentType | undefined { if (!mimeType) return undefined; @@ -55,26 +56,27 @@ export function processResourceResult( return { resourceContent, resourceMeta }; } -export function processToolResult( - result: { - content: Array<{ - type: string; - text?: string; - mimeType?: string; - data?: string; - resource?: { - uri: string; - text?: string; - blob?: string; - }; - }>; - isError?: boolean; - }, - serverName: string, - toolName: string, - runtime: IAgentRuntime, - messageEntityId: string -): { toolOutput: string; hasAttachments: boolean; attachments: Media[] } { +interface ProcessToolResultOptions { + runtime: IAgentRuntime; + result: CallToolResult; + serverName: string; + toolName: string; + messageEntityId: string; +} + +interface ProcessedToolResult { + toolOutput: string; + hasAttachments: boolean; + attachments: Media[]; +} + +export function processToolResult({ + runtime, + result, + serverName, + toolName, + messageEntityId, +}: ProcessToolResultOptions): ProcessedToolResult { let toolOutput = ''; let hasAttachments = false; const attachments: Media[] = []; @@ -106,60 +108,36 @@ export function processToolResult( return { toolOutput, hasAttachments, attachments }; } -export async function handleResourceAnalysis( - runtime: IAgentRuntime, - message: Memory, - uri: string, - serverName: string, - resourceContent: string, - resourceMeta: string, - callback?: HandlerCallback -): Promise { - await createMcpMemory(runtime, message, 'resource', serverName, resourceContent, { - uri, - isResourceAccess: true, - }); - - const analysisPrompt = createAnalysisPrompt( - uri, - message.content.text || '', - resourceContent, - resourceMeta - ); - - const analyzedResponse = await runtime.useModel(ModelType.TEXT_SMALL, { - prompt: analysisPrompt, - }); - - if (callback) { - await callback({ - text: analyzedResponse, - thought: `I analyzed the content from the ${uri} resource on ${serverName} and crafted a thoughtful response that addresses the user's request while maintaining my conversational style.`, - actions: ['READ_MCP_RESOURCE'], - }); - } +interface HandleToolResponse { + runtime: IAgentRuntime; + state: State; + message: Memory; + serverName: string; + toolName: string; + toolArguments: Record; + toolOutput: string; + hasAttachments: boolean; + attachments: Media[]; + mcpProvider: McpProvider; + callback?: HandlerCallback; } -export async function handleToolResponse( - runtime: IAgentRuntime, - message: Memory, - serverName: string, - toolName: string, - toolArgs: Record, - toolOutput: string, - hasAttachments: boolean, - attachments: Media[], - state: State, - mcpProvider: { - values: { mcp: unknown }; - data: { mcp: unknown }; - text: string; - }, - callback?: HandlerCallback -): Promise { +export async function handleToolResponse({ + runtime, + state, + message, + serverName, + toolName, + toolArguments, + toolOutput, + hasAttachments, + attachments, + mcpProvider, + callback, +}: HandleToolResponse): Promise { await createMcpMemory(runtime, message, 'tool', serverName, toolOutput, { toolName, - arguments: toolArgs, + arguments: toolArguments, isToolCall: true, }); @@ -173,7 +151,7 @@ export async function handleToolResponse( hasAttachments ); - logger.info('reasoning prompt: ', reasoningPrompt); + mcpLogger.info('reasoning prompt: ', reasoningPrompt); const reasonedResponse = await runtime.useModel(ModelType.TEXT_SMALL, { prompt: reasoningPrompt, @@ -216,28 +194,6 @@ export async function sendInitialResponse(callback?: HandlerCallback): Promise({ runtime, @@ -89,14 +90,14 @@ export async function createToolSelectionArgument({ toolSelectionName, }: CreateToolSelectionOptions): Promise { if (!toolSelectionName) { - logger.warn( + mcpLogger.warn( "[SELECTION] Tool selection name is not provided. Cannot create tool selection argument." ); return null; } const { serverName, toolName } = toolSelectionName; const toolInputSchema = mcpProvider.data.mcp[serverName].tools[toolName].inputSchema; - logger.trace(`[SELECTION] Tool Input Schema:\n${JSON.stringify({ toolInputSchema }, null, 2)}`); + mcpLogger.trace(`[SELECTION] Tool Input Schema:\n${JSON.stringify({ toolInputSchema }, null, 2)}`); // Create a tool selection argument prompt const toolSelectionArgumentPrompt: string = composePromptFromState({ @@ -110,13 +111,13 @@ export async function createToolSelectionArgument({ }, template: toolSelectionArgumentTemplate, }); - logger.debug(`[SELECTION] Tool Selection Prompt:\n${toolSelectionArgumentPrompt}`); + mcpLogger.debug(`[SELECTION] Tool Selection Prompt:\n${toolSelectionArgumentPrompt}`); // Use the model to generate a tool selection argument stringified json response const toolSelectionArgument: string = await runtime.useModel(ModelType.TEXT_LARGE, { prompt: toolSelectionArgumentPrompt, }); - logger.debug(`[SELECTION] Tool Selection Argument Response:\n${toolSelectionArgument}`); + mcpLogger.debug(`[SELECTION] Tool Selection Argument Response:\n${toolSelectionArgument}`); return await withModelRetry({ runtime, @@ -158,7 +159,7 @@ function createToolSelectionFeedbackPrompt( toolsDescription, userMessage ); - logger.debug(`[SELECTION] Tool Selection Feedback Prompt:\n${feedbackPrompt}`); + mcpLogger.debug(`[SELECTION] Tool Selection Feedback Prompt:\n${feedbackPrompt}`); return feedbackPrompt; } @@ -180,3 +181,112 @@ function createFeedbackPrompt( User request: ${userMessage}`; } + + +interface CreateResourceSelection { + runtime: IAgentRuntime; + state: State; + message: Memory; + callback?: HandlerCallback; +} + +export async function createResourceSelection({ + runtime, + state, + message, + callback, +}: CreateResourceSelection): Promise { + // Select appropriate prompt + mcpLogger.info('[SELECTION] Selecting resource based on the current state...'); + const resourceSelectionPrompt = createResourceSelectionPrompt({ + state, + userMessage: message.content.text || '', + }); + mcpLogger.info(`[SELECTION] Resource Selection Prompt: ${resourceSelectionPrompt}`); + + // Call the model to get the resource selection + mcpLogger.info('[SELECTION] Calling model to get resource selection...'); + const resourceSelection = await runtime.useModel(ModelType.OBJECT_LARGE, { + prompt: resourceSelectionPrompt, + }); + mcpLogger.info(`[SELECTION] Resource Selection Response: ${resourceSelection}`); + + const parsedSelection = await withModelRetry({ + runtime, + state, + message, + callback, + input: resourceSelection, + validationFn: (data) => validateResourceSelection(data), + createFeedbackPromptFn: (originalResponse, errorMessage, state, userMessage) => + createResourceSelectionFeedbackPrompt(originalResponse, errorMessage, state, userMessage), + failureMsg: `I'm having trouble finding the resource you're looking for. Could you provide more details about what you need?`, + retryCount: 0, + }); + mcpLogger.info(`[SELECTION] Parsed Resource Selection: ${JSON.stringify(parsedSelection)}`); + + return parsedSelection; +} + +interface CreateResourceSelectionPromptOptions { + state: State; + userMessage: string; +} + +function createResourceSelectionPrompt({ state, userMessage }: CreateResourceSelectionPromptOptions): string { + const mcpData = state.values.mcp || {}; + const serverNames = Object.keys(mcpData); + + let resourcesDescription = ''; + for (const serverName of serverNames) { + const server = mcpData[serverName]; + if (server.status !== 'connected') continue; + + const resourceUris = Object.keys(server.resources || {}); + for (const uri of resourceUris) { + const resource = server.resources[uri]; + resourcesDescription += `Resource: ${uri} (Server: ${serverName})\n`; + resourcesDescription += `Name: ${resource.name || 'No name available'}\n`; + resourcesDescription += `Description: ${resource.description || 'No description available'}\n`; + resourcesDescription += `MIME Type: ${resource.mimeType || 'Not specified'}\n\n`; + } + } + + const enhancedState: State = { + ...state, + values: { + ...state.values, + resourcesDescription, + userMessage, + }, + }; + + return composePromptFromState({ + state: enhancedState, + template: resourceSelectionTemplate, + }); +} + +function createResourceSelectionFeedbackPrompt( + originalResponse: string | object, + errorMessage: string, + state: State, + userMessage: string, +): string { + let resourcesDescription = ''; + + for (const [serverName, server] of Object.entries(state.values.mcp || {}) as [string, McpProviderData[string]][]) { + if (server.status !== 'connected') continue; + + for (const [uri, resource] of Object.entries(server.resources || {}) as [ + string, + { description?: string; name?: string }, + ][]) { + resourcesDescription += `Resource: ${uri} (Server: ${serverName})\n`; + resourcesDescription += `Name: ${resource.name || 'No name available'}\n`; + resourcesDescription += `Description: ${resource.description || 'No description available'}\n\n`; + } + } + + return createFeedbackPrompt(originalResponse, errorMessage, 'resource', resourcesDescription, userMessage); +} diff --git a/src/utils/use-action.ts b/src/utils/use-action.ts new file mode 100644 index 0000000..c19b00d --- /dev/null +++ b/src/utils/use-action.ts @@ -0,0 +1,59 @@ +import { MCP_SERVICE_NAME, type McpProvider } from '@/types'; +import { mcpLogger } from './mcp-logger'; +import type { McpService } from '@/service'; +import type { HandlerCallback, IAgentRuntime, Memory, State } from '@elizaos/core'; + +export type HandlersOptions = { + [key: string]: unknown; +}; + +export interface UseActionHandlerOptions { + actionName: string; + runtime: IAgentRuntime; + message: Memory; + state?: State; + options?: HandlersOptions; + callback?: HandlerCallback; +} + +export interface ActionHandlerContext { + runtime: IAgentRuntime; + message: Memory; + state: State; + options?: HandlersOptions; + callback?: HandlerCallback; + mcpService: McpService; + mcpProvider: McpProvider; +} + +export async function useActionHandler({ + actionName, + runtime, + message, + state, + options, + callback, +}: UseActionHandlerOptions): Promise { + mcpLogger.info(`[USE-ACTION] [${actionName}] Starting handler with message: "${message.content.text}"`); + + const composedState = await runtime.composeState(message); + + const mcpService = runtime.getService(MCP_SERVICE_NAME); + if (!mcpService) { + throw new Error(`[USE-ACTION] [${actionName}] MCP service not available`); + } + + // Get MCP provider data + const mcpProvider = mcpService.getProviderData(); + mcpLogger.trace(`[USE-ACTION] [${actionName}] Provider Data: ${mcpProvider}`); + + return { + runtime, + message, + state: composedState, + options, + callback, + mcpService, + mcpProvider, + } satisfies ActionHandlerContext; +} diff --git a/src/utils/validation.ts b/src/utils/validation.ts index fe71c75..501349c 100644 --- a/src/utils/validation.ts +++ b/src/utils/validation.ts @@ -1,17 +1,20 @@ -import type { State } from "@elizaos/core"; +import type { IAgentRuntime, Memory, State } from "@elizaos/core"; import { + MCP_SERVICE_NAME, type McpProviderData, type McpServer, - ResourceSelectionSchema, type ValidationResult, -} from "../types"; +} from "@/types"; import { validateJsonSchema } from "./json"; import { + ResourceSelectionSchema, toolSelectionArgumentSchema, toolSelectionNameSchema, type ToolSelectionArgument, type ToolSelectionName, } from "./schemas"; +import type { McpService } from "@/service"; +import { mcpLogger } from "./mcp-logger"; export interface ToolSelection { serverName: string; @@ -182,3 +185,26 @@ ${itemsDescription} User request: ${userMessage}`; } + +export function validateAction(actionName: string, runtime: IAgentRuntime, _message: Memory, _state?: State): boolean { + try { + const mcpService = runtime.getService(MCP_SERVICE_NAME); + if (!mcpService) { + mcpLogger.warn('[VALIDATE] service not available'); + return false; + } + + const servers = mcpService.getServers(); + mcpLogger.info(`[VALIDATE] [${actionName}] Found ${servers.length} servers`); + servers.forEach((s) => mcpLogger.debug(`\t- ${s.name} | ${s.status}`)); + mcpLogger.trace(`[VALIDATE] [${actionName}] servers:\n${JSON.stringify(servers)}`); + + const hasConnectedServersWithTools = servers.some((s) => s.status === 'connected' && s.tools && s.tools.length > 0); + mcpLogger.info(`[VALIDATE] [${actionName}] Connected servers with tools: "${hasConnectedServersWithTools}"`); + + return servers.length > 0 && hasConnectedServersWithTools; + } catch (error) { + mcpLogger.error(`[VALIDATE] [${actionName}] Error in action validation:`, error); + return false; + } +} diff --git a/src/utils/wrapper.ts b/src/utils/wrapper.ts index 46485fb..bf56a07 100644 --- a/src/utils/wrapper.ts +++ b/src/utils/wrapper.ts @@ -5,10 +5,10 @@ import { type Memory, type IAgentRuntime, type State, - logger, ModelType, } from "@elizaos/core"; import { DEFAULT_MAX_RETRIES, type ValidationResult } from "../types"; +import { mcpLogger } from "./mcp-logger"; export type Input = string | object; @@ -54,11 +54,11 @@ export async function withModelRetry({ const maxRetries = getMaxRetries(runtime); try { - logger.info(`[WITH-MODEL-RETRY] Raw selection input:\n${input}`); + mcpLogger.info(`[WITH-MODEL-RETRY] Raw selection input:\n${input}`); // If it's a first retry, input is a string, so we need to parse it const parsedJson = typeof input === "string" ? parseJSON(input) : input; - logger.debug( + mcpLogger.debug( `[WITH-MODEL-RETRY] Parsed selection input:\n${JSON.stringify(parsedJson, null, 2)}` ); @@ -72,10 +72,10 @@ export async function withModelRetry({ } catch (parseError) { const errorMessage = parseError instanceof Error ? parseError.message : "Unknown parsing error"; - logger.error(`[WITH-MODEL-RETRY] Failed to parse response: ${errorMessage}`); + mcpLogger.error(`[WITH-MODEL-RETRY] Failed to parse response: ${errorMessage}`); if (retryCount < maxRetries) { - logger.debug(`[WITH-MODEL-RETRY] Retrying (attempt ${retryCount + 1}/${maxRetries})`); + mcpLogger.debug(`[WITH-MODEL-RETRY] Retrying (attempt ${retryCount + 1}/${maxRetries})`); const feedbackPrompt: string = createFeedbackPromptFn( input, @@ -124,12 +124,12 @@ function getMaxRetries(runtime: IAgentRuntime): number { if (settings && "maxRetries" in settings && settings.maxRetries !== undefined) { const configValue = Number(settings.maxRetries); if (!Number.isNaN(configValue) && configValue >= 0) { - logger.debug(`[WITH-MODEL-RETRY] Using configured selection retries: ${configValue}`); + mcpLogger.debug(`[WITH-MODEL-RETRY] Using configured selection retries: ${configValue}`); return configValue; } } } catch (error) { - logger.debug( + mcpLogger.debug( "[WITH-MODEL-RETRY] Error reading selection retries config:", error instanceof Error ? error.message : String(error) ); diff --git a/tsconfig.json b/tsconfig.json index 9091d6a..5a4f45d 100644 --- a/tsconfig.json +++ b/tsconfig.json @@ -23,6 +23,10 @@ "noPropertyAccessFromIndexSignature": false, "declaration": true, "outDir": "./dist", - "rootDir": "./src" + "rootDir": "./src", + "paths": { + "@/*": ["./src/*"], + "#/*": ["./"] + } } } \ No newline at end of file