diff --git a/convex/ai/resilience.ts b/convex/ai/resilience.ts index 9cb45025..ed3ab706 100644 --- a/convex/ai/resilience.ts +++ b/convex/ai/resilience.ts @@ -12,21 +12,22 @@ import { COACH_MAX_STEPS } from "./coach"; import { type ProviderId } from "./providers"; import { type AttemptOutcome, runWithPrimaryCircuitBreaker } from "./resilienceCircuitBreaker"; import { type AccumulatorInit, RunAccumulator } from "./runTelemetry"; -import { buildByokErrorMessage, classifyByokError } from "./byokErrors"; +import { classifyByokError } from "./byokErrors"; import { runInRunSpan } from "./otel"; +import { isQuotaError, isTransientError } from "./transientErrors"; import { - buildProviderTransientMessage, - classifyTransientError, - isQuotaError, - isTransientError, -} from "./transientErrors"; + getFinalizeCodeForError, + safeFinalizePending, + safeReportError, + safeTryReportByok, +} from "./resilienceReporting"; // Re-export for backwards compatibility with existing callers/tests. export { buildByokErrorMessage, classifyByokError, withByokErrorSanitization } from "./byokErrors"; export type { ByokErrorCode } from "./byokErrors"; export { isTransientError } from "./transientErrors"; +export { getFinalizeCodeForError } from "./resilienceReporting"; -const AI_ERROR_MESSAGE = "I'm having trouble right now. Please try again in a moment."; const BUDGET_CAP_MESSAGE = "This is getting expensive on your API key, so I'm simplifying here. Ask a narrower follow-up if you want me to keep going."; const MAX_OUTPUT_TOKENS = 4096; @@ -262,6 +263,15 @@ async function attemptStream({ // Telemetry must never fail the LLM turn. } }, + onError: async ({ error }: { error: unknown }) => { + // @convex-dev/agent@0.6.1 does not catch stream error events in its + // internal finalizeMessage mutation — the raw provider error propagates + // as an unhandled exception that Convex reports to Sentry. 
Pre-empting + // with a sanitized finalize code here prevents the agent library from + // encountering the error-event delta when its mutation runs, because + // a message that is already "failed" skips further delta processing. + await safeFinalizePending(ctx, threadId, getFinalizeCodeForError(error)); + }, }, STREAM_OPTIONS, ); @@ -310,91 +320,3 @@ function buildTelemetryConfig(telemetry: TelemetryArgs): TelemetrySettings { metadata, }; } - -interface ErrorReport { - threadId: string; - userId: string; - error: unknown; - isByok: boolean; - provider: ProviderId; -} - -// streamText's abortSignal handler finalizes on clean aborts; provider errors -// thrown from result.text bypass that path and leave a stranded pending row. -async function finalizePendingMessages( - ctx: ActionCtx, - threadId: string, - reason: string, -): Promise { - const result = await ctx.runQuery(components.agent.messages.listMessagesByThreadId, { - threadId, - paginationOpts: { cursor: null, numItems: 50 }, - order: "desc", - }); - for (const message of result.page) { - if (message.status !== "pending") continue; - await ctx.runMutation(components.agent.messages.finalizeMessage, { - messageId: message._id, - result: { status: "failed", error: reason }, - }); - } -} - -// Best-effort wrappers must not leave users with stuck pending messages. -const safeFinalizePending = (ctx: ActionCtx, threadId: string, reason: string) => - finalizePendingMessages(ctx, threadId, reason).catch(() => undefined); -const safeReportError = (ctx: ActionCtx, report: ErrorReport) => - reportError(ctx, report).catch(() => undefined); -const safeTryReportByok = (ctx: ActionCtx, report: ErrorReport) => - tryReportByok(ctx, report).catch(() => false); - -export function getFinalizeCodeForError(error: unknown): string { - const transientKind = classifyTransientError(error); - return transientKind ?? (error instanceof Error ? 
error.name : "unknown_error"); -} - -async function tryReportByok(ctx: ActionCtx, report: ErrorReport): Promise { - if (!report.isByok) return false; - const code = classifyByokError(report.error); - if (code === null) return false; - // Provider bodies can include the decrypted key, so the finalize reason is the code only. - await finalizePendingMessages(ctx, report.threadId, code); - await saveMessage(ctx, components.agent, { - threadId: report.threadId, - userId: report.userId, - message: { role: "assistant", content: buildByokErrorMessage(code, report.provider) }, - }); - await ctx.runAction(internal.discord.notifyError, { - source: "streamWithRetry", - message: `${code} on ${report.provider} (${report.error instanceof Error ? report.error.name : "Unknown"})`, - userId: report.userId, - }); - return true; -} - -async function reportError(ctx: ActionCtx, report: ErrorReport): Promise { - const transientKind = classifyTransientError(report.error); - const content = transientKind - ? buildProviderTransientMessage(transientKind, report.provider, report.isByok) - : AI_ERROR_MESSAGE; - - // Keep provider text out of the agent component's failed-message field. - await finalizePendingMessages(ctx, report.threadId, getFinalizeCodeForError(report.error)); - - await saveMessage(ctx, components.agent, { - threadId: report.threadId, - userId: report.userId, - message: { role: "assistant", content }, - }); - - // Upstream provider outages already surface to the user with an attributed - // message; paging Discord on every Gemini/Claude capacity blip is noise. - if (transientKind) return; - - const reason = report.error instanceof Error ? 
report.error.message : String(report.error);
-  await ctx.runAction(internal.discord.notifyError, {
-    source: "streamWithRetry",
-    message: reason,
-    userId: report.userId,
-  });
-}
diff --git a/convex/ai/resilienceReporting.ts b/convex/ai/resilienceReporting.ts
new file mode 100644
index 00000000..9bd67a14
--- /dev/null
+++ b/convex/ai/resilienceReporting.ts
@@ -0,0 +1,127 @@
+import type { ActionCtx } from "../_generated/server";
+import { components, internal } from "../_generated/api";
+import { saveMessage } from "@convex-dev/agent";
+import { buildByokErrorMessage, classifyByokError } from "./byokErrors";
+import type { ProviderId } from "./providers";
+import { buildProviderTransientMessage, classifyTransientError } from "./transientErrors";
+
+const AI_ERROR_MESSAGE = "I'm having trouble right now. Please try again in a moment.";
+
+/** Describes one failed streaming attempt for the reporting helpers below. */
+export interface ErrorReport {
+  threadId: string;
+  userId: string;
+  error: unknown;
+  isByok: boolean;
+  provider: ProviderId;
+}
+
+/**
+ * Marks every `pending` message on the first page of `threadId` (50 rows, `order: "desc"`) as
+ * `failed`, storing `reason` as the failure text.
+ *
+ * streamText's abortSignal handler finalizes on clean aborts; provider errors
+ * thrown from result.text bypass that path and leave a stranded pending row.
+ *
+ * Rejects as soon as the query or one finalize mutation rejects; rows after that one are not
+ * visited. Use `safeFinalizePending` when a failure here must not propagate.
+ */
+export async function finalizePendingMessages(
+  ctx: ActionCtx,
+  threadId: string,
+  reason: string,
+): Promise<void> {
+  const result = await ctx.runQuery(components.agent.messages.listMessagesByThreadId, {
+    threadId,
+    paginationOpts: { cursor: null, numItems: 50 },
+    order: "desc",
+  });
+  for (const message of result.page) {
+    if (message.status !== "pending") continue;
+    await ctx.runMutation(components.agent.messages.finalizeMessage, {
+      messageId: message._id,
+      result: { status: "failed", error: reason },
+    });
+  }
+}
+
+/**
+ * Reduces an arbitrary thrown value to a short code for the failed-message field: the transient
+ * classification when `classifyTransientError` yields one, otherwise the `Error` name, otherwise
+ * `"unknown_error"`. The error's `message` is never part of the result.
+ */
+export function getFinalizeCodeForError(error: unknown): string {
+  const transientKind = classifyTransientError(error);
+  return transientKind ?? (error instanceof Error ? error.name : "unknown_error");
+}
+
+// Best-effort wrappers must not leave users with stuck pending messages.
+// Each one resolves (with `undefined` or `false`) instead of rejecting.
+export const safeFinalizePending = (ctx: ActionCtx, threadId: string, reason: string) =>
+  finalizePendingMessages(ctx, threadId, reason).catch(() => undefined);
+export const safeReportError = (ctx: ActionCtx, report: ErrorReport) =>
+  reportError(ctx, report).catch(() => undefined);
+export const safeTryReportByok = (ctx: ActionCtx, report: ErrorReport) =>
+  tryReportByok(ctx, report).catch(() => false);
+
+/**
+ * Reports the failure as a BYOK problem when `classifyByokError` recognizes it.
+ *
+ * @returns `true` once pending rows are finalized and the BYOK message is saved; `false` when
+ *   the run is not BYOK or the error has no BYOK classification.
+ */
+async function tryReportByok(ctx: ActionCtx, report: ErrorReport): Promise<boolean> {
+  if (!report.isByok) return false;
+  const code = classifyByokError(report.error);
+  if (code === null) return false;
+  // Provider bodies can include the decrypted key, so the finalize reason is the code only.
+  await finalizePendingMessages(ctx, report.threadId, code);
+  await saveMessage(ctx, components.agent, {
+    threadId: report.threadId,
+    userId: report.userId,
+    message: { role: "assistant", content: buildByokErrorMessage(code, report.provider) },
+  });
+  try {
+    await ctx.runAction(internal.discord.notifyError, {
+      source: "streamWithRetry",
+      message: `${code} on ${report.provider} (${report.error instanceof Error ? report.error.name : "Unknown"})`,
+      userId: report.userId,
+    });
+  } catch {
+    // The user-facing BYOK message is already saved. If a failed alert rejected here,
+    // safeTryReportByok would resolve `false` even though the error was fully reported.
+  }
+  return true;
+}
+
+/**
+ * Generic failure path: finalizes pending rows with a sanitized code, saves an assistant message
+ * (provider-attributed for transient errors, `AI_ERROR_MESSAGE` otherwise) and, for non-transient
+ * errors only, forwards the error text to Discord.
+ */
+async function reportError(ctx: ActionCtx, report: ErrorReport): Promise<void> {
+  const transientKind = classifyTransientError(report.error);
+  const content = transientKind
+    ? buildProviderTransientMessage(transientKind, report.provider, report.isByok)
+    : AI_ERROR_MESSAGE;
+
+  // Keep provider text out of the agent component's failed-message field.
+  await finalizePendingMessages(ctx, report.threadId, getFinalizeCodeForError(report.error));
+
+  await saveMessage(ctx, components.agent, {
+    threadId: report.threadId,
+    userId: report.userId,
+    message: { role: "assistant", content },
+  });
+
+  // Upstream provider outages already surface to the user with an attributed
+  // message; paging Discord on every Gemini/Claude capacity blip is noise.
+  if (transientKind) return;
+
+  const reason = report.error instanceof Error ? report.error.message : String(report.error);
+  await ctx.runAction(internal.discord.notifyError, {
+    source: "streamWithRetry",
+    message: reason,
+    userId: report.userId,
+  });
+}
diff --git a/convex/ai/resilienceStreamFailure.test.ts b/convex/ai/resilienceStreamFailure.test.ts
index 2dcc74bf..77933a7a 100644
--- a/convex/ai/resilienceStreamFailure.test.ts
+++ b/convex/ai/resilienceStreamFailure.test.ts
@@ -46,15 +46,15 @@ function responseFailedStep(): StepResult {
   } as unknown as StepResult;
 }
 
+function makeCircuitBreakerMock() {
+  return async (options: { primaryAgent: Agent; runAttempt: (agent: Agent) => Promise<unknown> }) =>
+    options.runAttempt(options.primaryAgent);
+}
+
 describe("streamWithRetry provider response failures", () => {
   beforeEach(() => {
     vi.clearAllMocks();
-    runWithPrimaryCircuitBreakerMock.mockImplementation(
-      async (options: { primaryAgent: Agent; runAttempt: unknown }) => {
-        const runAttempt = options.runAttempt as (agent: Agent) => Promise<unknown>;
-        return await runAttempt(options.primaryAgent);
-      },
-    );
+    runWithPrimaryCircuitBreakerMock.mockImplementation(makeCircuitBreakerMock());
   });
 
   it("surfaces non-thrown provider finish errors as BYOK messages", async () => {
@@ -113,3 +113,164 @@ describe("streamWithRetry provider response failures", () => {
     expect(recordErrorMock).toHaveBeenCalledWith("byok_unknown_error");
   });
 });
+
+describe("onError pre-emptive finalization", () => {
+  beforeEach(() => {
+    vi.clearAllMocks();
+    runWithPrimaryCircuitBreakerMock.mockImplementation(makeCircuitBreakerMock());
+  });
+
+  it("calls finalizeMessage with provider_overload code for Gemini high-demand stream errors", async () => {
+    const highDemandError = new Error(
+      "This model is currently experiencing high demand. Spikes in demand are usually temporary.",
+    );
+    const streamText = vi.fn(
+      async (options: { onError?: (args: { error: unknown }) => Promise<void> }) => {
+        await options.onError?.({ error: highDemandError });
+        return { text: Promise.reject(highDemandError) };
+      },
+    );
+    const agent = {
+      continueThread: vi.fn(async () => ({ thread: { streamText } })),
+    } as unknown as Agent;
+    const runQuery = vi.fn(async () => ({
+      page: [{ _id: "pending-msg", status: "pending" }],
+    }));
+    const runMutation = vi.fn(async () => undefined);
+
+    await streamWithRetry({ runQuery, runMutation, runAction: vi.fn() } as unknown as ActionCtx, {
+      primaryAgent: agent,
+      fallbackAgent: agent,
+      primaryModelName: "gemini-2.0-flash",
+      threadId: "thread-1",
+      userId: "user-1",
+      prompt: "hello",
+      isByok: false,
+      provider: "gemini",
+      source: "chat",
+      environment: "dev",
+    });
+
+    expect(runMutation).toHaveBeenCalledWith(
+      expect.anything(),
+      expect.objectContaining({
+        messageId: "pending-msg",
+        result: { status: "failed", error: "provider_overload" },
+      }),
+    );
+  });
+
+  it("does not expose raw provider error text in the finalization code", async () => {
+    const rawMessage =
+      "This model is currently experiencing high demand. Spikes in demand are usually temporary.";
+    const streamText = vi.fn(
+      async (options: { onError?: (args: { error: unknown }) => Promise<void> }) => {
+        await options.onError?.({ error: new Error(rawMessage) });
+        return { text: Promise.reject(new Error(rawMessage)) };
+      },
+    );
+    const agent = {
+      continueThread: vi.fn(async () => ({ thread: { streamText } })),
+    } as unknown as Agent;
+    const runQuery = vi.fn(async () => ({
+      page: [{ _id: "msg-1", status: "pending" }],
+    }));
+    const runMutation = vi.fn(async () => undefined);
+
+    await streamWithRetry({ runQuery, runMutation, runAction: vi.fn() } as unknown as ActionCtx, {
+      primaryAgent: agent,
+      fallbackAgent: agent,
+      primaryModelName: "gemini-2.0-flash",
+      threadId: "thread-1",
+      userId: "user-1",
+      prompt: "hi",
+      isByok: false,
+      provider: "gemini",
+      source: "chat",
+      environment: "dev",
+    });
+
+    expect(runMutation).toHaveBeenCalledWith(
+      expect.anything(),
+      expect.objectContaining({
+        result: { status: "failed", error: "provider_overload" },
+      }),
+    );
+    expect(runMutation).not.toHaveBeenCalledWith(
+      expect.anything(),
+      expect.objectContaining({
+        result: expect.objectContaining({ error: expect.stringContaining("high demand") }),
+      }),
+    );
+  });
+
+  it("finalizes with provider_overload for Claude overload errors as well", async () => {
+    const overloadError = new Error("Service overloaded. Please try again later.");
+    const streamText = vi.fn(
+      async (options: { onError?: (args: { error: unknown }) => Promise<void> }) => {
+        await options.onError?.({ error: overloadError });
+        return { text: Promise.reject(overloadError) };
+      },
+    );
+    const agent = {
+      continueThread: vi.fn(async () => ({ thread: { streamText } })),
+    } as unknown as Agent;
+    const runQuery = vi.fn(async () => ({
+      page: [{ _id: "pending-claude", status: "pending" }],
+    }));
+    const runMutation = vi.fn(async () => undefined);
+
+    await streamWithRetry({ runQuery, runMutation, runAction: vi.fn() } as unknown as ActionCtx, {
+      primaryAgent: agent,
+      fallbackAgent: agent,
+      primaryModelName: "claude-sonnet-4-6",
+      threadId: "thread-3",
+      userId: "user-1",
+      prompt: "hi",
+      isByok: false,
+      provider: "claude",
+      source: "chat",
+      environment: "dev",
+    });
+
+    expect(runMutation).toHaveBeenCalledWith(
+      expect.anything(),
+      expect.objectContaining({
+        messageId: "pending-claude",
+        result: { status: "failed", error: "provider_overload" },
+      }),
+    );
+  });
+
+  it("does not finalize any message as failed when the SDK never invokes onError", async () => {
+    // Happy path: the SDK never calls onError, so nothing may be finalized as failed.
+    const streamText = vi.fn(async () => ({ text: Promise.resolve("response text") }));
+    const agent = {
+      continueThread: vi.fn(async () => ({ thread: { streamText } })),
+    } as unknown as Agent;
+    const runQuery = vi.fn(async () => ({ page: [] }));
+    const runMutation = vi.fn(async () => undefined);
+
+    const accumulator = await streamWithRetry(
+      { runQuery, runMutation, runAction: vi.fn() } as unknown as ActionCtx,
+      {
+        primaryAgent: agent,
+        fallbackAgent: agent,
+        primaryModelName: "gemini-2.0-flash",
+        threadId: "thread-4",
+        userId: "user-1",
+        prompt: "hello",
+        isByok: false,
+        provider: "gemini",
+        source: "chat",
+        environment: "dev",
+      },
+    );
+
+    expect(accumulator.toRow().finishReason).not.toBe("error");
+    expect(runMutation).not.toHaveBeenCalledWith(
+      expect.anything(),
+      expect.objectContaining({ result: expect.objectContaining({ status: "failed" }) }),
+    );
+  });
+});