Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
105 changes: 69 additions & 36 deletions extensions/control.ts
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@
* - Retrieve the last assistant message from a session
* - Get AI-generated summaries of session activity
* - Clear/rewind sessions to their initial state
* - Subscribe to turn_end events for async coordination
* - Subscribe to turn_end and agent_end events for async coordination
*
* Once loaded the extension registers a `send_to_session` tool that allows the AI to
* communicate with other pi sessions programmatically.
Expand All @@ -36,7 +36,7 @@
* - { type: "get_summary" }
* - { type: "clear", summarize?: boolean }
* - { type: "abort" }
* - { type: "subscribe", event: "turn_end" }
* - { type: "subscribe", event: "turn_end" | "agent_end" }
*
* Responses are JSON objects with { type: "response", command, success, data?, error? }
* Events are JSON objects with { type: "event", event, data?, subscriptionId? }
Expand All @@ -46,6 +46,7 @@ import type {
ExtensionAPI,
ExtensionContext,
TurnEndEvent,
AgentEndEvent,
MessageRenderer,
ModelRegistry,
} from "@earendil-works/pi-coding-agent";
Expand Down Expand Up @@ -121,7 +122,7 @@ interface RpcAbortCommand {

interface RpcSubscribeCommand {
type: "subscribe";
event: "turn_end";
event: "turn_end" | "agent_end";
id?: string;
}

Expand All @@ -137,7 +138,7 @@ type RpcCommand =
// Subscription Management
// ============================================================================

interface TurnEndSubscription {
interface EventSubscription {
socket: net.Socket;
subscriptionId: string;
}
Expand All @@ -148,7 +149,8 @@ interface SocketState {
context: ExtensionContext | null;
alias: string | null;
aliasTimer: ReturnType<typeof setInterval> | null;
turnEndSubscriptions: TurnEndSubscription[];
turnEndSubscriptions: EventSubscription[];
agentEndSubscriptions: EventSubscription[];
}

// ============================================================================
Expand Down Expand Up @@ -617,20 +619,23 @@ async function handleCommand(
return;
}

// Subscribe to turn_end
// Subscribe to turn_end or agent_end
if (command.type === "subscribe") {
if (command.event === "turn_end") {
if (command.event === "turn_end" || command.event === "agent_end") {
const subscriptionId = id ?? `sub_${Date.now()}_${Math.random().toString(36).slice(2, 8)}`;
state.turnEndSubscriptions.push({ socket, subscriptionId });
const subscriptions = command.event === "turn_end"
? state.turnEndSubscriptions
: state.agentEndSubscriptions;
subscriptions.push({ socket, subscriptionId });

const cleanup = () => {
const idx = state.turnEndSubscriptions.findIndex((s) => s.subscriptionId === subscriptionId);
if (idx !== -1) state.turnEndSubscriptions.splice(idx, 1);
const idx = subscriptions.findIndex((s) => s.subscriptionId === subscriptionId);
if (idx !== -1) subscriptions.splice(idx, 1);
};
socket.once("close", cleanup);
socket.once("error", cleanup);

respond(true, "subscribe", { subscriptionId, event: "turn_end" });
respond(true, "subscribe", { subscriptionId, event: command.event });
return;
}
respond(false, "subscribe", undefined, `Unknown event type: ${command.event}`);
Expand Down Expand Up @@ -823,7 +828,7 @@ async function createServer(pi: ExtensionAPI, state: SocketState, socketPath: st

interface RpcClientOptions {
timeout?: number;
waitForEvent?: "turn_end";
waitForEvent?: "turn_end" | "agent_end";
}

async function sendRpcCommand(
Expand All @@ -850,13 +855,15 @@ async function sendRpcCommand(
};

socket.on("connect", () => {
socket.write(`${JSON.stringify(command)}\n`);

// If waiting for turn_end, also subscribe
if (waitForEvent === "turn_end") {
const subscribeCmd: RpcSubscribeCommand = { type: "subscribe", event: "turn_end" };
// If waiting for event, subscribe FIRST (atomic subscribe-before-send).
// The subscription must be registered before the send triggers the
// agent turn, so that agent_end / turn_end is guaranteed to be caught.
if (waitForEvent) {
const subscribeCmd: RpcSubscribeCommand = { type: "subscribe", event: waitForEvent };
socket.write(`${JSON.stringify(subscribeCmd)}\n`);
}

socket.write(`${JSON.stringify(command)}\n`);
});

socket.on("data", (chunk) => {
Expand Down Expand Up @@ -887,8 +894,8 @@ async function sendRpcCommand(
continue;
}

// Handle turn_end event
if (msg.type === "event" && msg.event === "turn_end" && waitForEvent === "turn_end") {
// Handle events
if (msg.type === "event" && msg.event === waitForEvent) {
cleanup();
socket.end();
if (!response) {
Expand Down Expand Up @@ -944,6 +951,7 @@ async function stopControlServer(state: SocketState): Promise<void> {
const socketPath = state.socketPath;
state.socketPath = null;
state.turnEndSubscriptions = [];
state.agentEndSubscriptions = [];
await new Promise<void>((resolve) => state.server?.close(() => resolve()));
state.server = null;
await removeAliasesForSocket(socketPath);
Expand Down Expand Up @@ -1004,7 +1012,7 @@ export default function (pi: ExtensionAPI) {
default: "steer",
});
pi.registerFlag(CONTROL_SEND_WAIT_FLAG, {
description: "Startup send wait mode: turn_end or message_processed",
description: "Startup send wait mode: turn_end, agent_end, or message_processed",
type: "string",
});
pi.registerFlag(CONTROL_SEND_INCLUDE_SENDER_FLAG, {
Expand All @@ -1021,6 +1029,7 @@ export default function (pi: ExtensionAPI) {
alias: null,
aliasTimer: null,
turnEndSubscriptions: [],
agentEndSubscriptions: [],
};

pi.registerMessageRenderer(SESSION_MESSAGE_TYPE, renderSessionMessage);
Expand Down Expand Up @@ -1093,6 +1102,28 @@ export default function (pi: ExtensionAPI) {
});
}
});

// Fire agent_end events to subscribers
pi.on("agent_end", (_event: AgentEndEvent, ctx: ExtensionContext) => {
if (state.agentEndSubscriptions.length === 0) return;

void syncAlias(state, ctx);
const lastMessage = getLastAssistantMessage(ctx);
const eventData = { message: lastMessage };

// Fire to all subscribers (one-shot)
const subscriptions = [...state.agentEndSubscriptions];
state.agentEndSubscriptions = [];

for (const sub of subscriptions) {
writeEvent(sub.socket, {
type: "event",
event: "agent_end",
data: eventData,
subscriptionId: sub.subscriptionId,
});
}
});
}

// ============================================================================
Expand All @@ -1117,6 +1148,7 @@ Target selection:

Wait behavior (only for action=send):
- wait_until=turn_end: Wait for the turn to complete, returns last assistant message.
- wait_until=agent_end: Wait for the agent loop to complete (may span multiple turns), returns last assistant message.
- wait_until=message_processed: Returns immediately after message is queued.

CLI bridge (for shell scripts/background jobs):
Expand All @@ -1128,7 +1160,7 @@ CLI bridge (for shell scripts/background jobs):
--control-session <session-name|session-id>
--send-session-message <text>
--send-session-mode <steer|follow_up> (optional, default: steer)
--send-session-wait <turn_end|message_processed> (optional)
--send-session-wait <turn_end|agent_end|message_processed> (optional)
--send-session-include-sender-info (optional, advanced; default: off)
- Startup sends are one-way by default (no sender_info), which avoids reply attempts to short-lived 'pi -p' sender sessions.
- If a script needs a response, use --send-session-wait turn_end and read stdout.
Expand Down Expand Up @@ -1157,7 +1189,7 @@ Messages automatically include sender session info for replies. When you want a
}),
),
wait_until: Type.Optional(
StringEnum(["turn_end", "message_processed"] as const, {
StringEnum(["turn_end", "agent_end", "message_processed"] as const, {
description: "Wait behavior for send action",
}),
),
Expand Down Expand Up @@ -1312,11 +1344,11 @@ Messages automatically include sender session info for replies. When you want a
};
}

if (params.wait_until === "turn_end") {
// Send and wait for turn to complete
if (params.wait_until === "turn_end" || params.wait_until === "agent_end") {
// Send and wait for turn/agent to complete
const result = await sendRpcCommand(socketPath, sendCommand, {
timeout: 300000, // 5 minutes
waitForEvent: "turn_end",
waitForEvent: params.wait_until,
});

if (!result.response.success) {
Expand All @@ -1330,14 +1362,14 @@ Messages automatically include sender session info for replies. When you want a
const lastMessage = result.event?.message;
if (!lastMessage) {
return {
content: [{ type: "text", text: "Turn completed but no assistant message found" }],
details: { turnIndex: result.event?.turnIndex },
content: [{ type: "text", text: "Completed but no assistant message found" }],
details: result.event,
};
}

return {
content: [{ type: "text", text: lastMessage.content }],
details: { message: lastMessage, turnIndex: result.event?.turnIndex },
details: { message: lastMessage, ...result.event },
};
}

Expand Down Expand Up @@ -1419,7 +1451,7 @@ Messages automatically include sender session info for replies. When you want a
const hasCleared = details && "cleared" in details;
const hasTurnIndex = details && "turnIndex" in details;

// get_message or turn_end result with message
// get_message, turn_end, or agent_end result with message
if (hasMessage) {
const message = details.message as ExtractedMessage;
const icon = theme.fg("success", "✓");
Expand Down Expand Up @@ -1539,7 +1571,7 @@ type StartupControlSendOptions = {
target: string;
message: string;
mode: "steer" | "follow_up";
waitUntil?: "turn_end" | "message_processed";
waitUntil?: "turn_end" | "agent_end" | "message_processed";
includeSenderInfo: boolean;
};

Expand All @@ -1550,9 +1582,10 @@ function normalizeMode(raw: string): "steer" | "follow_up" | null {
return null;
}

function normalizeWaitUntil(raw: string): "turn_end" | "message_processed" | null {
function normalizeWaitUntil(raw: string): "turn_end" | "agent_end" | "message_processed" | null {
const value = raw.trim().toLowerCase();
if (value === "turn_end" || value === "turn-end") return "turn_end";
if (value === "agent_end" || value === "agent-end") return "agent_end";
if (value === "message_processed" || value === "message-processed") return "message_processed";
return null;
}
Expand Down Expand Up @@ -1585,12 +1618,12 @@ function parseStartupControlSendOptions(pi: ExtensionAPI): { options?: StartupCo
}

const rawWait = getStringFlag(pi, CONTROL_SEND_WAIT_FLAG);
let waitUntil: "turn_end" | "message_processed" | undefined;
let waitUntil: "turn_end" | "agent_end" | "message_processed" | undefined;
if (rawWait) {
const normalized = normalizeWaitUntil(rawWait);
if (!normalized) {
return {
error: `Invalid --${CONTROL_SEND_WAIT_FLAG}: ${rawWait}. Use turn_end|message_processed.`,
error: `Invalid --${CONTROL_SEND_WAIT_FLAG}: ${rawWait}. Use turn_end|agent_end|message_processed.`,
};
}
waitUntil = normalized;
Expand Down Expand Up @@ -1668,18 +1701,18 @@ async function maybeHandleStartupControlSend(pi: ExtensionAPI, ctx: ExtensionCon
};

try {
if (waitUntil === "turn_end") {
if (waitUntil === "turn_end" || waitUntil === "agent_end") {
const result = await sendRpcCommand(socketPath, sendCommand, {
timeout: 300000,
waitForEvent: "turn_end",
waitForEvent: waitUntil,
});
if (!result.response.success) {
reportStartupControlSend(ctx, `Failed to send: ${result.response.error ?? "unknown error"}`, "error");
return;
}
const lastMessage = result.event?.message;
if (!lastMessage?.content) {
reportStartupControlSend(ctx, `Message delivered to ${target}; turn completed without assistant output.`);
reportStartupControlSend(ctx, `Message delivered to ${target}; completed without assistant output.`);
return;
}
if (ctx.hasUI) {
Expand Down