Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
91 changes: 79 additions & 12 deletions src/acp-agent.ts
Original file line number Diff line number Diff line change
Expand Up @@ -117,6 +117,8 @@ type Session = {
promptRunning: boolean;
pendingMessages: Map<string, { resolve: (cancelled: boolean) => void; order: number }>;
nextPendingOrder: number;
messageIdMap: Map<string, string>; // ACP user messageId -> preceding SDK assistant UUID
lastAssistantUuid?: string;
};

type BackgroundTerminal =
Expand Down Expand Up @@ -375,6 +377,18 @@ export class ClaudeAcpAgent implements Agent {
}

async unstable_forkSession(params: ForkSessionRequest): Promise<ForkSessionResponse> {
// Support forking at a specific message for "edit previous message" UX.
// Zed passes atMessageId via _meta to indicate where to fork.
const atMessageId = (params._meta as any)?.atMessageId as string | undefined;
let resumeSessionAt: string | undefined;
if (atMessageId) {
const sourceSession = this.sessions[params.sessionId];
const precedingAssistantUuid = sourceSession?.messageIdMap.get(atMessageId);
if (precedingAssistantUuid) {
resumeSessionAt = precedingAssistantUuid;
}
}

const response = await this.createSession(
{
cwd: params.cwd,
Expand All @@ -384,6 +398,7 @@ export class ClaudeAcpAgent implements Agent {
{
resume: params.sessionId,
forkSession: true,
...(resumeSessionAt && { resumeSessionAt }),
},
);
// Needs to happen after we return the session
Expand Down Expand Up @@ -478,19 +493,26 @@ export class ClaudeAcpAgent implements Agent {
};

let lastAssistantTotalUsage: number | null = null;
const userMessageId = params.messageId ?? undefined;

const userMessage = promptToClaude(params);

// Track message ID mapping: this user message -> preceding assistant UUID
if (params.messageId && session.lastAssistantUuid) {
session.messageIdMap.set(params.messageId, session.lastAssistantUuid);
}

if (session.promptRunning) {
const uuid = randomUUID();
const uuid =
(params.messageId as `${string}-${string}-${string}-${string}-${string}`) ?? randomUUID();
userMessage.uuid = uuid;
session.input.push(userMessage);
const order = session.nextPendingOrder++;
const cancelled = await new Promise<boolean>((resolve) => {
session.pendingMessages.set(uuid, { resolve, order });
});
if (cancelled) {
return { stopReason: "cancelled" };
return { stopReason: "cancelled", userMessageId };
}
} else {
session.input.push(userMessage);
Expand All @@ -505,7 +527,7 @@ export class ClaudeAcpAgent implements Agent {

if (done || !message) {
if (session.cancelled) {
return { stopReason: "cancelled" };
return { stopReason: "cancelled", userMessageId };
}
break;
}
Expand Down Expand Up @@ -567,7 +589,7 @@ export class ClaudeAcpAgent implements Agent {
break;
case "result": {
if (session.cancelled) {
return { stopReason: "cancelled" };
return { stopReason: "cancelled", userMessageId };
}

// Accumulate usage from this result
Expand Down Expand Up @@ -616,24 +638,40 @@ export class ClaudeAcpAgent implements Agent {
throw RequestError.authRequired();
}
if (message.stop_reason === "max_tokens") {
return { stopReason: "max_tokens", usage };
return {
stopReason: "max_tokens",
usage,
userMessageId,
};
}
if (message.is_error) {
throw RequestError.internalError(undefined, message.result);
}
return { stopReason: "end_turn", usage };
return {
stopReason: "end_turn",
usage,
userMessageId,
};
}
case "error_during_execution":
if (message.stop_reason === "max_tokens") {
return { stopReason: "max_tokens", usage };
return {
stopReason: "max_tokens",
usage,
userMessageId,
};
}
if (message.is_error) {
throw RequestError.internalError(
undefined,
message.errors.join(", ") || message.subtype,
);
}
return { stopReason: "end_turn", usage };
return {
stopReason: "end_turn",
usage,
userMessageId,
};
case "error_max_budget_usd":
case "error_max_turns":
case "error_max_structured_output_retries":
Expand All @@ -643,7 +681,11 @@ export class ClaudeAcpAgent implements Agent {
message.errors.join(", ") || message.subtype,
);
}
return { stopReason: "max_turn_requests", usage };
return {
stopReason: "max_turn_requests",
usage,
userMessageId,
};
default:
unreachable(message, this.logger);
break;
Expand All @@ -660,6 +702,7 @@ export class ClaudeAcpAgent implements Agent {
{
clientCapabilities: this.clientCapabilities,
cwd: session.cwd,
messageId: session.lastAssistantUuid,
},
)) {
await this.client.sessionUpdate(notification);
Expand All @@ -681,14 +724,24 @@ export class ClaudeAcpAgent implements Agent {
handedOff = true;
// the current loop stops with end_turn,
// the loop of the next prompt continues running
return { stopReason: "end_turn" };
return { stopReason: "end_turn", userMessageId };
}
if ("isReplay" in message && message.isReplay) {
// not pending or unrelated replay message
break;
}
}

// Track top-level assistant message UUIDs for message editing (fork-at support)
if (
message.type === "assistant" &&
message.parent_tool_use_id === null &&
"uuid" in message &&
message.uuid
) {
session.lastAssistantUuid = message.uuid as string;
}

// Store latest assistant usage (excluding subagents)
if ((message.message as any).usage && message.parent_tool_use_id === null) {
const messageWithUsage = message.message as unknown as SDKResultMessage;
Expand Down Expand Up @@ -773,6 +826,7 @@ export class ClaudeAcpAgent implements Agent {
clientCapabilities: this.clientCapabilities,
parentToolUseId: message.parent_tool_use_id,
cwd: session.cwd,
messageId: "uuid" in message ? (message.uuid as string) : undefined,
},
)) {
await this.client.sessionUpdate(notification);
Expand Down Expand Up @@ -1139,7 +1193,7 @@ export class ClaudeAcpAgent implements Agent {

private async createSession(
params: NewSessionRequest,
creationOpts: { resume?: string; forkSession?: boolean } = {},
creationOpts: { resume?: string; forkSession?: boolean; resumeSessionAt?: string } = {},
): Promise<NewSessionResponse> {
// We want to create a new session id unless it is resume,
// but not resume + forkSession.
Expand Down Expand Up @@ -1376,6 +1430,7 @@ export class ClaudeAcpAgent implements Agent {
promptRunning: false,
pendingMessages: new Map(),
nextPendingOrder: 0,
messageIdMap: new Map(),
};

return {
Expand Down Expand Up @@ -1653,7 +1708,7 @@ export function promptToClaude(prompt: PromptRequest): SDKUserMessage {

content.push(...context);

return {
const msg: SDKUserMessage = {
type: "user",
message: {
role: "user",
Expand All @@ -1662,6 +1717,10 @@ export function promptToClaude(prompt: PromptRequest): SDKUserMessage {
session_id: prompt.sessionId,
parent_tool_use_id: null,
};
if (prompt.messageId) {
msg.uuid = prompt.messageId as `${string}-${string}-${string}-${string}-${string}`;
}
return msg;
}

/**
Expand All @@ -1680,6 +1739,7 @@ export function toAcpNotifications(
clientCapabilities?: ClientCapabilities;
parentToolUseId?: string | null;
cwd?: string;
messageId?: string;
},
): SessionNotification[] {
const registerHooks = options?.registerHooks !== false;
Expand All @@ -1691,6 +1751,7 @@ export function toAcpNotifications(
type: "text",
text: content,
},
...(options?.messageId && { messageId: options.messageId }),
};

if (options?.parentToolUseId) {
Expand Down Expand Up @@ -1719,6 +1780,7 @@ export function toAcpNotifications(
type: "text",
text: chunk.text,
},
...(options?.messageId && { messageId: options.messageId }),
};
break;
case "image":
Expand All @@ -1730,6 +1792,7 @@ export function toAcpNotifications(
mimeType: chunk.source.type === "base64" ? chunk.source.media_type : "",
uri: chunk.source.type === "url" ? chunk.source.url : undefined,
},
...(options?.messageId && { messageId: options.messageId }),
};
break;
case "thinking":
Expand All @@ -1740,6 +1803,7 @@ export function toAcpNotifications(
type: "text",
text: chunk.thinking,
},
...(options?.messageId && { messageId: options.messageId }),
};
break;
case "tool_use":
Expand Down Expand Up @@ -1935,6 +1999,7 @@ export function streamEventToAcpNotifications(
options?: {
clientCapabilities?: ClientCapabilities;
cwd?: string;
messageId?: string;
},
): SessionNotification[] {
const event = message.event;
Expand All @@ -1951,6 +2016,7 @@ export function streamEventToAcpNotifications(
clientCapabilities: options?.clientCapabilities,
parentToolUseId: message.parent_tool_use_id,
cwd: options?.cwd,
messageId: options?.messageId,
},
);
case "content_block_delta":
Expand All @@ -1965,6 +2031,7 @@ export function streamEventToAcpNotifications(
clientCapabilities: options?.clientCapabilities,
parentToolUseId: message.parent_tool_use_id,
cwd: options?.cwd,
messageId: options?.messageId,
},
);
// No content
Expand Down
1 change: 1 addition & 0 deletions src/tests/acp-agent.test.ts
Original file line number Diff line number Diff line change
Expand Up @@ -1328,6 +1328,7 @@ describe("stop reason propagation", () => {
promptRunning: false,
pendingMessages: new Map(),
nextPendingOrder: 0,
messageIdMap: new Map(),
};
}

Expand Down