Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@ interface GatewayStreamEvent extends Omit<AgentStreamEvent, "type"> {
}
import { useQueryClient } from "@tanstack/react-query"
import type { FileUIPart } from "ai"
import { useCallback, useState } from "react"
import { useCallback, useRef, useState } from "react"
import { toast } from "sonner"
import type {
ChatAttachment,
Expand All @@ -40,6 +40,7 @@ type UseChatStreamReturn = {
streamingreasoning: string
streamingartifacts: FileArtifact[]
handlesubmit: (message: SubmitMessage) => Promise<void>
handlecancel: () => void
}

const useChatStream = (
Expand All @@ -51,6 +52,9 @@ const useChatStream = (
const { mutateAsync: saveMessage } = useSaveMessage()
const { data: session } = useSession()

/** Ref to the active WebSocket so handlecancel can reach it across renders. */
const wsRef = useRef<WebSocket | null>(null)

const [streaming, setStreaming] = useState(false)
const [streamingcontent, setStreamingcontent] = useState("")
const [streamingtoolcalls, setStreamingtoolcalls] = useState<StreamToolCall[]>([])
Expand Down Expand Up @@ -138,6 +142,7 @@ const useChatStream = (
try {
const wsurl = `${GATEWAY_URL.replace(/^http/, "ws")}/ws`
const ws = new WebSocket(wsurl)
wsRef.current = ws
let fullcontent = ""
let fullreasoning = ""
const toolcalls = new Map<string, StreamToolCall>()
Expand Down Expand Up @@ -293,6 +298,7 @@ const useChatStream = (
}

ws.onclose = () => {
wsRef.current = null
resolve()
}
})
Expand Down Expand Up @@ -345,6 +351,13 @@ const useChatStream = (
[conversationid, conversation, participants, queryClient, saveMessage, session],
)

const handlecancel = useCallback(() => {
const ws = wsRef.current
if (!ws || ws.readyState !== WebSocket.OPEN) return
ws.send(JSON.stringify({ type: "cancel", sessionId: conversationid }))
ws.close()
}, [conversationid])

return {
streaming,
streamingcontent,
Expand All @@ -353,6 +366,7 @@ const useChatStream = (
streamingreasoning,
streamingartifacts,
handlesubmit,
handlecancel,
}
}

Expand Down
3 changes: 3 additions & 0 deletions apps/web/src/components/organisms/chat-view/index.tsx
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@ const ChatView = () => {
streamingreasoning,
streamingartifacts,
handlesubmit,
handlecancel,
} = useChatStream(conversationid, conversation, participants)

const { artifacts, hasfiles } = useSessionArtifacts(conversationid, streamingartifacts)
Expand Down Expand Up @@ -94,6 +95,7 @@ const ChatView = () => {
<div className="w-full max-w-2xl">
<PromptInput
handlesubmit={handlesubmit}
handlecancel={handlecancel}
hasmessages={hasmessages}
textarearef={textarearef as React.RefObject<HTMLTextAreaElement>}
streaming={streaming}
Expand Down Expand Up @@ -157,6 +159,7 @@ const ChatView = () => {
<div className="max-w-3xl mx-auto w-full px-4 py-3">
<PromptInput
handlesubmit={handlesubmit}
handlecancel={handlecancel}
hasmessages={hasmessages}
textarearef={textarearef as React.RefObject<HTMLTextAreaElement>}
streaming={streaming}
Expand Down
14 changes: 13 additions & 1 deletion apps/web/src/components/organisms/chat-view/prompt-input.tsx
Original file line number Diff line number Diff line change
Expand Up @@ -27,11 +27,13 @@ import { useCallback, useRef, useState } from "react"

const PromptInput = ({
handlesubmit,
handlecancel,
hasmessages,
textarearef,
streaming,
}: {
handlesubmit: (msg: { text: string; files: FileUIPart[] }) => void
handlecancel?: () => void
hasmessages: boolean
textarearef: React.RefObject<HTMLTextAreaElement>
streaming: boolean
Expand Down Expand Up @@ -136,7 +138,17 @@ const PromptInput = ({
</PopoverContent>
</Popover>
</PromptInputTools>
<PromptInputSubmit disabled={streaming} status={streaming ? "streaming" : undefined} />
<PromptInputSubmit
status={streaming ? "streaming" : undefined}
onClick={
streaming && handlecancel
? (e) => {
e.preventDefault()
handlecancel()
}
: undefined
}
/>
</PromptInputFooter>
</PromptInputComponent>
</div>
Expand Down
4 changes: 3 additions & 1 deletion infra/openshell/Dockerfile
Original file line number Diff line number Diff line change
Expand Up @@ -78,7 +78,8 @@ RUN npm install -g \
chart.js \
chartjs-node-canvas \
d3 \
exceljs
exceljs \
pi-interactive-shell

# ---------------------------------------------------------------------------
# Stage 2: Runtime
Expand All @@ -97,6 +98,7 @@ FROM node:22-slim
# tini: PID 1 / signal handling.
# git, curl, wget, jq, less, tree, unzip, zip: general-purpose CLI tools for the agent.
RUN apt-get update && apt-get install -y --no-install-recommends \
build-essential \
curl \
git \
iproute2 \
Expand Down
9 changes: 9 additions & 0 deletions packages/gateway/src/app.ts
Original file line number Diff line number Diff line change
Expand Up @@ -303,6 +303,15 @@ export const createApp = (
return c.json({ role: "assistant", content: text })
})

app.post("/api/v1/sessions/:id/cancel", requirePermission("sessions", "write"), async (c) => {
const session = sessionManager.getSession(c.req.param("id"))
if (!session) {
return c.json({ error: "Session not found" }, 404)
}
const cancelled = await sessionManager.cancelSession(c.req.param("id"))
return c.json({ ok: true, cancelled })
})

app.get("/api/v1/sessions/:id/messages", requirePermission("sessions", "read"), (c) => {
const session = sessionManager.getSession(c.req.param("id"))
if (!session) {
Expand Down
14 changes: 14 additions & 0 deletions packages/gateway/src/session-manager.ts
Original file line number Diff line number Diff line change
Expand Up @@ -249,6 +249,20 @@ export class SessionManager {
return this.sessions.get(id)?.workspaceDir
}

/**
* Cancel the active turn for a session.
*
* In orchestrator mode, delegates to the sandbox-server's cancel endpoint.
* In local mode, cancellation is handled via the AbortSignal passed to
* sendMessage (the WebSocket handler in ws.ts manages those controllers).
*/
async cancelSession(id: string): Promise<boolean> {
if (this.orchestrator) {
return this.orchestrator.cancelSession(id)
}
return false
}

deleteSession(id: string): boolean {
if (this.orchestrator) {
const session = this.sessions.get(id)
Expand Down
111 changes: 75 additions & 36 deletions packages/sandbox-server/src/agent.ts
Original file line number Diff line number Diff line change
Expand Up @@ -98,6 +98,8 @@ export class SandboxAgentManager {
private snapshots = new Map<string, Map<string, FileSnapshot>>()
/** Human-readable folder name per session, derived from the first user message. */
private sessionLabels = new Map<string, string>()
/** Abort controllers for active turns, keyed by sessionId. */
private activeTurnControllers = new Map<string, AbortController>()

/**
* Create a new agent session.
Expand Down Expand Up @@ -163,55 +165,73 @@ export class SandboxAgentManager {
this.sessionLabels.set(sessionId, sanitizeFolderName(content, sessionId))
}

// Create an internal controller for this turn so cancelSession() can abort
// it independently of the SSE stream disconnect signal.
const turnController = new AbortController()
this.activeTurnControllers.set(sessionId, turnController)

// Propagate the external abort signal (SSE client disconnect) into the turn controller.
const onExternalAbort = () => turnController.abort()
if (signal?.aborted) {
turnController.abort()
} else {
signal?.addEventListener("abort", onExternalAbort, { once: true })
}

// Take initial snapshot for artifact detection
let snapshot = this.snapshots.get(sessionId) ?? createSnapshot(WORKSPACE_DIR)

const t0 = Date.now()
let yieldCount = 0
log.info("[DIAG-AGM] starting for-await on session.sendMessage()", { sessionId })

for await (const event of session.sendMessage(content, signal)) {
yieldCount++
if (event.type !== "message_update" && event.type !== "thinking_update" && event.type !== "tool_call_update") {
log.info("[DIAG-AGM] received event from PiAgentSession", {
sessionId,
yieldCount,
type: event.type,
ms: Date.now() - t0,
})
}
yield event

// After a tool call ends, scan for new output files
if (event.type === "tool_call_end") {
const scanStart = Date.now()
const result = this.scanForArtifacts(sessionId, snapshot)
const scanMs = Date.now() - scanStart
if (scanMs > 100) {
log.warn("[DIAG-AGM] slow artifact scan", { sessionId, scanMs })
try {
for await (const event of session.sendMessage(content, turnController.signal)) {
yieldCount++
if (event.type !== "message_update" && event.type !== "thinking_update" && event.type !== "tool_call_update") {
log.info("[DIAG-AGM] received event from PiAgentSession", {
sessionId,
yieldCount,
type: event.type,
ms: Date.now() - t0,
})
}
if (result) {
snapshot = result.newSnapshot
yield { type: "file_output", artifacts: result.artifacts }
yield event

// After a tool call ends, scan for new output files
if (event.type === "tool_call_end") {
const scanStart = Date.now()
const result = this.scanForArtifacts(sessionId, snapshot)
const scanMs = Date.now() - scanStart
if (scanMs > 100) {
log.warn("[DIAG-AGM] slow artifact scan", { sessionId, scanMs })
}
if (result) {
snapshot = result.newSnapshot
yield { type: "file_output", artifacts: result.artifacts }
}
}
}
}

log.info("[DIAG-AGM] for-await loop completed", {
sessionId,
yieldCount,
durationMs: Date.now() - t0,
})
log.info("[DIAG-AGM] for-await loop completed", {
sessionId,
yieldCount,
durationMs: Date.now() - t0,
})

// Final scan after the turn completes to catch stragglers
const finalResult = this.scanForArtifacts(sessionId, snapshot)
if (finalResult) {
snapshot = finalResult.newSnapshot
yield { type: "file_output", artifacts: finalResult.artifacts }
}

// Final scan after the turn completes to catch stragglers
const finalResult = this.scanForArtifacts(sessionId, snapshot)
if (finalResult) {
snapshot = finalResult.newSnapshot
yield { type: "file_output", artifacts: finalResult.artifacts }
// Persist snapshot for next turn
this.snapshots.set(sessionId, snapshot)
} finally {
signal?.removeEventListener("abort", onExternalAbort)
this.activeTurnControllers.delete(sessionId)
}

// Persist snapshot for next turn
this.snapshots.set(sessionId, snapshot)
}

/**
Expand All @@ -237,6 +257,24 @@ export class SandboxAgentManager {
return { newSnapshot, artifacts }
}

/**
* Cancel the active turn for a session.
*
* Aborts the in-flight LLM call / tool execution without destroying the
* session or its message history. The session remains usable for further
* messages after cancellation.
*
* Returns true if there was an active turn to cancel, false if the session
* has no in-progress turn.
*/
cancelSession(sessionId: string): boolean {
const controller = this.activeTurnControllers.get(sessionId)
if (!controller) return false
controller.abort()
this.activeTurnControllers.delete(sessionId)
return true
}

/**
* Check if a session exists.
*/
Expand All @@ -248,6 +286,7 @@ export class SandboxAgentManager {
* Delete a session.
*/
deleteSession(sessionId: string): boolean {
this.cancelSession(sessionId)
this.snapshots.delete(sessionId)
this.sessionLabels.delete(sessionId)
return this.sessions.delete(sessionId)
Expand Down
13 changes: 6 additions & 7 deletions packages/sandbox-server/src/server.ts
Original file line number Diff line number Diff line change
Expand Up @@ -414,20 +414,19 @@ export function createSandboxApp(): Hono {
})

/**
* POST /sessions/:id/cancel -- cancel the current turn.
* POST /sessions/:id/cancel -- cancel the active turn for a session.
*
* This is a placeholder. The actual cancellation happens when the client
* disconnects from the SSE stream (abort signal fires).
* Aborts any in-flight LLM call or tool execution for the session without
* destroying the session or its message history. The session remains usable
* for further messages after cancellation.
*/
app.post("/sessions/:id/cancel", (c) => {
const sessionId = c.req.param("id")
if (!agent.hasSession(sessionId)) {
return c.json({ error: "Session not found" }, 404)
}
// Cancellation is handled by the SSE abort signal in sendMessage.
// A dedicated cancel mechanism would require tracking active generators,
// which will be added if needed.
return c.json({ ok: true })
const cancelled = agent.cancelSession(sessionId)
return c.json({ ok: true, cancelled })
})

// -----------------------------------------------------------------------
Expand Down
Loading