diff --git a/src/acp-agent.ts b/src/acp-agent.ts index b4590779..5585f064 100644 --- a/src/acp-agent.ts +++ b/src/acp-agent.ts @@ -99,11 +99,11 @@ export interface Logger { error: (...args: any[]) => void; } -type AccumulatedUsage = { - inputTokens: number; - outputTokens: number; - cachedReadTokens: number; - cachedWriteTokens: number; +type PromptWaiter = { + uuid: string; + sessionId: string; + resolve: (result: PromptResponse) => void; + reject: (error: Error) => void; }; type Session = { @@ -113,11 +113,9 @@ type Session = { cwd: string; permissionMode: PermissionMode; settingsManager: SettingsManager; - accumulatedUsage: AccumulatedUsage; configOptions: SessionConfigOption[]; - promptRunning: boolean; - pendingMessages: Map void; order: number }>; - nextPendingOrder: number; + activePrompt: PromptWaiter | null; + promptQueue: PromptWaiter[]; }; type BackgroundTerminal = @@ -487,49 +485,48 @@ export class ClaudeAcpAgent implements Agent { } session.cancelled = false; - session.accumulatedUsage = { - inputTokens: 0, - outputTokens: 0, - cachedReadTokens: 0, - cachedWriteTokens: 0, - }; - - let lastAssistantTotalUsage: number | null = null; const userMessage = promptToClaude(params); - const promptUuid = randomUUID(); userMessage.uuid = promptUuid; - let promptReplayed = false; + session.input.push(userMessage); - if (session.promptRunning) { - session.input.push(userMessage); - const order = session.nextPendingOrder++; - const cancelled = await new Promise((resolve) => { - session.pendingMessages.set(promptUuid, { resolve, order }); + return new Promise((resolve, reject) => { + session.promptQueue.push({ + uuid: promptUuid, + sessionId: params.sessionId, + resolve, + reject, }); - if (cancelled) { - return { stopReason: "cancelled" }; - } - // The replay resolved the promise, mark in this loop too, - // so we don't treat the next result as a background task's result. - promptReplayed = true; - } else { - session.input.push(userMessage); - } + }); + } + + /** + * Background loop that continuously processes SDK messages for a session. + * Started when the session is created, runs independently of prompt() calls. + * Resolves/rejects prompt waiters as results and replays arrive. + */ + private async processSessionMessages(sessionId: string): Promise { + const session = this.sessions[sessionId]; + if (!session) return; - session.promptRunning = true; - let handedOff = false; + let lastAssistantTotalUsage: number | null = null; try { while (true) { const { value: message, done } = await session.query.next(); if (done || !message) { - if (session.cancelled) { - return { stopReason: "cancelled" }; + // Query stream ended + if (session.activePrompt) { + session.activePrompt.resolve({ stopReason: "end_turn" }); + session.activePrompt = null; + } + for (const waiter of session.promptQueue) { + waiter.reject(new Error("Session ended")); } + session.promptQueue = []; break; } @@ -589,12 +586,6 @@ export class ClaudeAcpAgent implements Agent { } break; case "result": { - // Accumulate usage from this result - session.accumulatedUsage.inputTokens += message.usage.input_tokens; - session.accumulatedUsage.outputTokens += message.usage.output_tokens; - session.accumulatedUsage.cachedReadTokens += message.usage.cache_read_input_tokens; - session.accumulatedUsage.cachedWriteTokens += message.usage.cache_creation_input_tokens; - // Calculate context window size from modelUsage (minimum across all models used) const contextWindows = Object.values(message.modelUsage).map((m) => m.contextWindow); const contextWindowSize = @@ -603,7 +594,7 @@ export class ClaudeAcpAgent implements Agent { // Send usage_update notification if (lastAssistantTotalUsage !== null) { await this.client.sessionUpdate({ - sessionId: params.sessionId, + sessionId, update: { sessionUpdate: "usage_update", used: lastAssistantTotalUsage, @@ -616,65 +607,80 @@ export class ClaudeAcpAgent implements Agent { }); } - if (!promptReplayed) { - // This result is from a background task that finished after - // the previous prompt loop ended. Consume it and continue - // waiting for our own prompt's result. - this.logger.log(`Session ${params.sessionId}: consuming background task result`); + if (!session.activePrompt) { + // No active prompt — background task result, consume it + this.logger.log(`Session ${sessionId}: consuming background task result`); break; } + const waiter = session.activePrompt; + session.activePrompt = null; + if (session.cancelled) { - return { stopReason: "cancelled" }; + waiter.resolve({ stopReason: "cancelled" }); + break; } - // Build the usage response + // Build usage directly from the result message const usage: PromptResponse["usage"] = { - inputTokens: session.accumulatedUsage.inputTokens, - outputTokens: session.accumulatedUsage.outputTokens, - cachedReadTokens: session.accumulatedUsage.cachedReadTokens, - cachedWriteTokens: session.accumulatedUsage.cachedWriteTokens, + inputTokens: message.usage.input_tokens, + outputTokens: message.usage.output_tokens, + cachedReadTokens: message.usage.cache_read_input_tokens, + cachedWriteTokens: message.usage.cache_creation_input_tokens, totalTokens: - session.accumulatedUsage.inputTokens + - session.accumulatedUsage.outputTokens + - session.accumulatedUsage.cachedReadTokens + - session.accumulatedUsage.cachedWriteTokens, + message.usage.input_tokens + + message.usage.output_tokens + + message.usage.cache_read_input_tokens + + message.usage.cache_creation_input_tokens, }; switch (message.subtype) { case "success": { if (message.result.includes("Please run /login")) { - throw RequestError.authRequired(); + waiter.reject(RequestError.authRequired()); + break; } if (message.stop_reason === "max_tokens") { - return { stopReason: "max_tokens", usage }; + waiter.resolve({ stopReason: "max_tokens", usage }); + break; } if (message.is_error) { - throw RequestError.internalError(undefined, message.result); + waiter.reject(RequestError.internalError(undefined, message.result)); + break; } - return { stopReason: "end_turn", usage }; + waiter.resolve({ stopReason: "end_turn", usage }); + break; } case "error_during_execution": if (message.stop_reason === "max_tokens") { - return { stopReason: "max_tokens", usage }; + waiter.resolve({ stopReason: "max_tokens", usage }); + break; } if (message.is_error) { - throw RequestError.internalError( - undefined, - message.errors.join(", ") || message.subtype, + waiter.reject( + RequestError.internalError( + undefined, + message.errors.join(", ") || message.subtype, + ), ); + break; } - return { stopReason: "end_turn", usage }; + waiter.resolve({ stopReason: "end_turn", usage }); + break; case "error_max_budget_usd": case "error_max_turns": case "error_max_structured_output_retries": if (message.is_error) { - throw RequestError.internalError( - undefined, - message.errors.join(", ") || message.subtype, + waiter.reject( + RequestError.internalError( + undefined, + message.errors.join(", ") || message.subtype, + ), ); + break; } - return { stopReason: "max_turn_requests", usage }; + waiter.resolve({ stopReason: "max_turn_requests", usage }); + break; default: unreachable(message, this.logger); break; @@ -684,7 +690,7 @@ export class ClaudeAcpAgent implements Agent { case "stream_event": { for (const notification of streamEventToAcpNotifications( message, - params.sessionId, + sessionId, this.toolUseCache, this.client, this.logger, @@ -705,21 +711,19 @@ export class ClaudeAcpAgent implements Agent { // Check for prompt replay if (message.type === "user" && "uuid" in message && message.uuid) { - if (message.uuid === promptUuid) { - // Our own prompt was replayed back — we're now processing - // our prompt's response (not a background task's). - promptReplayed = true; + // This should usually be 0, as prompt order is expected to be preserved. + const queueIndex = session.promptQueue.findIndex((w) => w.uuid === message.uuid); + if (queueIndex >= 0) { + // Finish current active prompt (interrupted by new prompt) + if (session.activePrompt) { + session.activePrompt.resolve({ stopReason: "end_turn" }); + } + // Activate the matched prompt + const waiter = session.promptQueue.splice(queueIndex, 1)[0]; + session.activePrompt = waiter; + lastAssistantTotalUsage = null; break; } - const pending = session.pendingMessages.get(message.uuid as string); - if (pending) { - pending.resolve(false); - session.pendingMessages.delete(message.uuid as string); - handedOff = true; - // the current loop stops with end_turn, - // the loop of the next prompt continues running - return { stopReason: "end_turn" }; - } if ("isReplay" in message && message.isReplay) { // not pending or unrelated replay message break; @@ -749,7 +753,7 @@ export class ClaudeAcpAgent implements Agent { .replace("", "") .replace("", ""), "assistant", - params.sessionId, + sessionId, this.toolUseCache, this.client, this.logger, @@ -788,7 +792,13 @@ export class ClaudeAcpAgent implements Agent { message.message.content[0].type === "text" && message.message.content[0].text.includes("Please run /login") ) { - throw RequestError.authRequired(); + // Reject the active prompt if there is one + if (session.activePrompt) { + const waiter = session.activePrompt; + session.activePrompt = null; + waiter.reject(RequestError.authRequired()); + } + break; } const content = @@ -802,7 +812,7 @@ export class ClaudeAcpAgent implements Agent { for (const notification of toAcpNotifications( content, message.message.role, - params.sessionId, + sessionId, this.toolUseCache, this.client, this.logger, @@ -827,45 +837,45 @@ export class ClaudeAcpAgent implements Agent { break; } } - throw new Error("Session did not end in result"); } catch (error) { - if (error instanceof RequestError || !(error instanceof Error)) { - throw error; - } - const message = error.message; - if ( - message.includes("ProcessTransport") || - message.includes("terminated process") || - message.includes("process exited with") || - message.includes("process terminated by signal") || - message.includes("Failed to write to process stdin") - ) { - this.logger.error(`Session ${params.sessionId}: Claude Agent process died: ${message}`); - session.input.end(); - delete this.sessions[params.sessionId]; - throw RequestError.internalError( - undefined, - "The Claude Agent process exited unexpectedly. Please start a new session.", - ); + const rejectError = this.toSessionError(sessionId, session, error); + + if (session.activePrompt) { + const waiter = session.activePrompt; + session.activePrompt = null; + waiter.reject(rejectError); } - throw error; - } finally { - if (!handedOff) { - session.promptRunning = false; - // This usually should not happen, but in case the loop finishes - // without claude sending all message replays, we resolve the - // next pending prompt call to ensure no prompts get stuck. - if (session.pendingMessages.size > 0) { - const next = [...session.pendingMessages.entries()].sort( - (a, b) => a[1].order - b[1].order, - )[0]; - if (next) { - next[1].resolve(false); - session.pendingMessages.delete(next[0]); - } - } + for (const waiter of session.promptQueue) { + waiter.reject(rejectError); } + session.promptQueue = []; + } + } + + private toSessionError(sessionId: string, session: Session, error: unknown): Error { + if (error instanceof RequestError) { + return error; + } + if (!(error instanceof Error)) { + return error as Error; + } + const message = error.message; + if ( + message.includes("ProcessTransport") || + message.includes("terminated process") || + message.includes("process exited with") || + message.includes("process terminated by signal") || + message.includes("Failed to write to process stdin") + ) { + this.logger.error(`Session ${sessionId}: Claude Agent process died: ${message}`); + session.input.end(); + delete this.sessions[sessionId]; + return RequestError.internalError( + undefined, + "The Claude Agent process exited unexpectedly. Please start a new session.", + ); } + return error; } async cancel(params: CancelNotification): Promise { @@ -874,10 +884,17 @@ export class ClaudeAcpAgent implements Agent { throw new Error("Session not found"); } session.cancelled = true; - for (const [, pending] of session.pendingMessages) { - pending.resolve(true); + + // Immediately resolve active prompt and all queued prompts + if (session.activePrompt) { + session.activePrompt.resolve({ stopReason: "cancelled" }); + session.activePrompt = null; + } + for (const waiter of session.promptQueue) { + waiter.resolve({ stopReason: "cancelled" }); } - session.pendingMessages.clear(); + session.promptQueue = []; + await session.query.interrupt(); } @@ -1404,18 +1421,17 @@ export class ClaudeAcpAgent implements Agent { cwd: params.cwd, permissionMode, settingsManager, - accumulatedUsage: { - inputTokens: 0, - outputTokens: 0, - cachedReadTokens: 0, - cachedWriteTokens: 0, - }, configOptions, - promptRunning: false, - pendingMessages: new Map(), - nextPendingOrder: 0, + activePrompt: null, + promptQueue: [], }; + // Start the background message processing loop + // TODO: depending on the error, we might want to recursively call this + this.processSessionMessages(sessionId).catch((err) => { + this.logger.error(`Session ${sessionId}: message processing loop error:`, err); + }); + return { sessionId, models, diff --git a/src/tests/acp-agent.test.ts b/src/tests/acp-agent.test.ts index 50ee83cd..d245600e 100644 --- a/src/tests/acp-agent.test.ts +++ b/src/tests/acp-agent.test.ts @@ -1,4 +1,4 @@ -import { describe, it, expect, beforeAll, afterAll } from "vitest"; +import { describe, it, expect, beforeAll, afterAll, vi } from "vitest"; import { spawn, spawnSync } from "child_process"; import { Agent, @@ -1325,17 +1325,12 @@ describe("stop reason propagation", () => { cwd: "/test", permissionMode: "default", settingsManager: {} as any, - accumulatedUsage: { - inputTokens: 0, - outputTokens: 0, - cachedReadTokens: 0, - cachedWriteTokens: 0, - }, configOptions: [], - promptRunning: false, - pendingMessages: new Map(), - nextPendingOrder: 0, + activePrompt: null, + promptQueue: [], }; + // Start the background message processing loop + (agent as any).processSessionMessages("test-session"); } it("should return max_tokens when success result has stop_reason max_tokens", async () => { @@ -1451,17 +1446,86 @@ describe("stop reason propagation", () => { cancelled: false, permissionMode: "default", settingsManager: {} as any, - accumulatedUsage: { - inputTokens: 0, - outputTokens: 0, - cachedReadTokens: 0, - cachedWriteTokens: 0, - }, configOptions: [], - promptRunning: false, - pendingMessages: new Map(), - nextPendingOrder: 0, + activePrompt: null, + promptQueue: [], + }; + // Start the background message processing loop + (agent as any).processSessionMessages("test-session"); + + const response = await agent.prompt({ + sessionId: "test-session", + prompt: [{ type: "text", text: "test" }], + }); + + expect(response.stopReason).toBe("end_turn"); + // Usage reflects only the prompt's own result, not background tasks + expect(response.usage?.inputTokens).toBe(promptResult.usage.input_tokens); + expect(response.usage?.outputTokens).toBe(promptResult.usage.output_tokens); + }); + + it("should send sessionUpdate notifications for background task stream events", async () => { + const sessionUpdateFn = vi.fn().mockResolvedValue(undefined); + const mockClient = { + sessionUpdate: sessionUpdateFn, + } as unknown as AgentSideConnection; + const agent = new ClaudeAcpAgent(mockClient, { log: () => {}, error: () => {} }); + const input = new Pushable(); + + const promptResult = createResultMessage({ + subtype: "success", + stop_reason: null, + is_error: false, + }); + + const backgroundTaskResult = createResultMessage({ + subtype: "success", + stop_reason: null, + is_error: false, + }); + + async function* messageGenerator() { + // Background task stream event + result arrive before our prompt's replay + yield { + type: "stream_event", + session_id: "test-session", + parent_tool_use_id: null, + event: { + type: "content_block_start", + index: 0, + content_block: { type: "text", text: "background work" }, + }, + }; + yield backgroundTaskResult; + + // Now the prompt's user message replay arrives + const iter = input[Symbol.asyncIterator](); + const { value: userMessage } = await iter.next(); + yield { + type: "user", + message: userMessage.message, + parent_tool_use_id: null, + uuid: userMessage.uuid, + session_id: "test-session", + isReplay: true, + }; + + // Then the prompt's own result + yield promptResult; + } + + agent.sessions["test-session"] = { + query: messageGenerator() as any, + cwd: "/tmp/test", + input, + cancelled: false, + permissionMode: "default", + settingsManager: {} as any, + configOptions: [], + activePrompt: null, + promptQueue: [], }; + (agent as any).processSessionMessages("test-session"); const response = await agent.prompt({ sessionId: "test-session", @@ -1469,13 +1533,15 @@ describe("stop reason propagation", () => { }); expect(response.stopReason).toBe("end_turn"); - // Usage should include both background task and prompt result tokens - expect(response.usage?.inputTokens).toBe( - backgroundTaskResult.usage.input_tokens + promptResult.usage.input_tokens, + // sessionUpdate should have been called for the background stream event + const streamCalls = sessionUpdateFn.mock.calls.filter( + (call: any) => call[0]?.update?.sessionUpdate === "agent_message_chunk", ); - expect(response.usage?.outputTokens).toBe( - backgroundTaskResult.usage.output_tokens + promptResult.usage.output_tokens, + expect(streamCalls.length).toBeGreaterThan(0); + const hasBackgroundWork = streamCalls.some((call: any) => + call[0]?.update?.content?.text?.includes("background work"), ); + expect(hasBackgroundWork).toBe(true); }); it("should throw internal error for success with is_error true and no max_tokens", async () => {