import { describe, it, expect, vi, beforeEach } from "vitest";
import { Effect } from "effect";
import { estimateCost } from "@/api/router/cost-estimator";
import { PricingUnavailableError } from "@/errors";
import * as pricing from "@/api/router/pricing";

/** Parameter object accepted by `estimateCost`. */
type EstimateParams = Parameters<typeof estimateCost>[0];

/** Per-million-token prices handed back by the stubbed `getModelPricing`. */
interface StubPrices {
  input: number;
  output: number;
}

/** One table-driven scenario in which the estimate is expected to succeed. */
interface EstimateScenario {
  name: string;
  prices: StubPrices;
  params: EstimateParams;
  /** Exact input-token count, `"positive"` for a `> 0` check, or omitted to skip the check. */
  inputTokens?: number | "positive";
  outputTokens: number;
  /** When true, additionally assert that the computed cost is `> 0`. */
  positiveCost?: boolean;
}

const OPENAI_PRICES: StubPrices = { input: 10, output: 30 }; // $10 / $30 per million
const ANTHROPIC_PRICES: StubPrices = { input: 15, output: 75 };
const GOOGLE_PRICES: StubPrices = { input: 5, output: 15 };

const OPENAI = { provider: "openai", modelId: "gpt-4" } as const;
const ANTHROPIC = {
  provider: "anthropic",
  modelId: "claude-3-opus-20240229",
} as const;
const GOOGLE = { provider: "google", modelId: "gemini-pro" } as const;

/** Mirrors DEFAULT_OUTPUT_TOKENS_ESTIMATE in the estimator module. */
const DEFAULT_OUTPUT_TOKENS = 1000;

const scenarios: EstimateScenario[] = [
  {
    name: "should estimate cost for OpenAI request with messages",
    prices: OPENAI_PRICES,
    params: {
      ...OPENAI,
      requestBody: {
        model: "gpt-4",
        // 11 content chars + 4 role chars = 15 chars => ~4 tokens
        messages: [{ role: "user", content: "Hello world" }],
        max_tokens: 100,
      },
    },
    inputTokens: "positive",
    outputTokens: 100,
    positiveCost: true,
  },
  {
    name: "should estimate cost for Anthropic request with multimodal content",
    prices: ANTHROPIC_PRICES,
    params: {
      ...ANTHROPIC,
      requestBody: {
        model: "claude-3-opus-20240229",
        messages: [
          {
            role: "user",
            content: [
              { type: "text", text: "What's in this image?" },
              { type: "image", source: { type: "base64", data: "..." } },
            ],
          },
        ],
        max_tokens: 500,
      },
    },
    inputTokens: "positive",
    outputTokens: 500,
    positiveCost: true,
  },
  {
    name: "should estimate cost for Google request with contents",
    prices: GOOGLE_PRICES,
    params: {
      ...GOOGLE,
      requestBody: {
        contents: [{ role: "user", parts: [{ text: "Hello from Google" }] }],
        generationConfig: { maxOutputTokens: 200 },
      },
    },
    inputTokens: "positive",
    outputTokens: 200,
    positiveCost: true,
  },
  {
    name: "should use default output tokens when max_tokens not specified",
    prices: OPENAI_PRICES,
    params: {
      ...OPENAI,
      requestBody: { messages: [{ role: "user", content: "Hello" }] },
    },
    outputTokens: DEFAULT_OUTPUT_TOKENS,
  },
  {
    name: "should handle empty request body",
    prices: OPENAI_PRICES,
    params: { ...OPENAI, requestBody: {} },
    inputTokens: "positive", // Fallback to stringifying
    outputTokens: DEFAULT_OUTPUT_TOKENS,
  },
  {
    name: "should handle null request body",
    prices: OPENAI_PRICES,
    params: { ...OPENAI, requestBody: null },
    inputTokens: 0,
    outputTokens: DEFAULT_OUTPUT_TOKENS,
  },
  {
    name: "should handle Google request with null body",
    prices: GOOGLE_PRICES,
    params: { ...GOOGLE, requestBody: null },
    inputTokens: 0,
    outputTokens: DEFAULT_OUTPUT_TOKENS,
  },
  {
    name: "should handle Google request with no contents or generationConfig",
    prices: GOOGLE_PRICES,
    params: { ...GOOGLE, requestBody: { model: "gemini-pro" } },
    inputTokens: "positive", // Falls back to stringify
    outputTokens: DEFAULT_OUTPUT_TOKENS,
  },
  {
    name: "should handle Anthropic request with null body",
    prices: ANTHROPIC_PRICES,
    params: { ...ANTHROPIC, requestBody: null },
    inputTokens: 0,
    outputTokens: DEFAULT_OUTPUT_TOKENS,
  },
  {
    name: "should handle Anthropic request with no messages",
    prices: ANTHROPIC_PRICES,
    params: {
      ...ANTHROPIC,
      requestBody: { model: "claude-3-opus-20240229" },
    },
    inputTokens: "positive", // Falls back to stringify
    outputTokens: DEFAULT_OUTPUT_TOKENS,
  },
  {
    name: "should handle OpenAI request with no messages",
    prices: OPENAI_PRICES,
    params: { ...OPENAI, requestBody: { model: "gpt-4" } },
    inputTokens: "positive", // Falls back to stringify
    outputTokens: DEFAULT_OUTPUT_TOKENS,
  },
  {
    name: "should handle OpenAI request with non-object body",
    prices: OPENAI_PRICES,
    params: { ...OPENAI, requestBody: "invalid" },
    inputTokens: 0, // Non-object returns 0
    outputTokens: DEFAULT_OUTPUT_TOKENS,
  },
  {
    name: "should handle Anthropic request with non-object body",
    prices: ANTHROPIC_PRICES,
    params: { ...ANTHROPIC, requestBody: 123 },
    inputTokens: 0, // Non-object returns 0
    outputTokens: DEFAULT_OUTPUT_TOKENS,
  },
  {
    name: "should handle Google request with non-object body",
    prices: GOOGLE_PRICES,
    params: { ...GOOGLE, requestBody: true },
    inputTokens: 0, // Non-object returns 0
    outputTokens: DEFAULT_OUTPUT_TOKENS,
  },
  {
    name: "should handle OpenAI request with non-string/non-array content",
    prices: OPENAI_PRICES,
    params: {
      ...OPENAI,
      requestBody: { messages: [{ role: "user", content: 123 }] },
    },
    inputTokens: "positive", // Only the role is counted
    outputTokens: DEFAULT_OUTPUT_TOKENS,
  },
  {
    name: "should handle message without role field",
    prices: OPENAI_PRICES,
    params: {
      ...OPENAI,
      requestBody: { messages: [{ content: "Hello world" }] },
    },
    inputTokens: "positive",
    outputTokens: DEFAULT_OUTPUT_TOKENS,
  },
  {
    name: "should handle Google request with invalid maxOutputTokens",
    prices: GOOGLE_PRICES,
    params: {
      ...GOOGLE,
      requestBody: {
        contents: [{ role: "user", parts: [{ text: "Hello" }] }],
        generationConfig: { maxOutputTokens: 0 }, // Invalid (not > 0)
      },
    },
    inputTokens: "positive",
    outputTokens: DEFAULT_OUTPUT_TOKENS, // Should use default
  },
];

describe("cost-estimator", () => {
  describe("estimateCost", () => {
    beforeEach(() => {
      vi.restoreAllMocks();
    });

    // Successful estimates: every scenario stubs pricing, runs the estimator
    // and checks only the fields that scenario declares.
    for (const scenario of scenarios) {
      it(scenario.name, async () => {
        vi.spyOn(pricing, "getModelPricing").mockReturnValue(
          Effect.succeed(scenario.prices),
        );

        const result = await Effect.runPromise(estimateCost(scenario.params));

        expect(result).toBeDefined();
        if (scenario.inputTokens === "positive") {
          expect(result.inputTokens).toBeGreaterThan(0);
        } else if (typeof scenario.inputTokens === "number") {
          expect(result.inputTokens).toBe(scenario.inputTokens);
        }
        expect(result.outputTokens).toBe(scenario.outputTokens);
        if (scenario.positiveCost) {
          expect(result.cost).toBeGreaterThan(0);
        }
      });
    }

    it("should throw PricingUnavailableError when pricing returns null", async () => {
      vi.spyOn(pricing, "getModelPricing").mockReturnValue(
        Effect.succeed(null),
      );

      // Effect.flip moves the failure into the success channel so it can be inspected.
      const error = await Effect.runPromise(
        estimateCost({
          provider: "openai",
          modelId: "unknown-model",
          requestBody: { messages: [{ role: "user", content: "Hello" }] },
        }).pipe(Effect.flip),
      );

      expect(error).toBeInstanceOf(PricingUnavailableError);
      expect((error as PricingUnavailableError).provider).toBe("openai");
      expect((error as PricingUnavailableError).model).toBe("unknown-model");
    });

    // The estimator does not wrap lookup failures: the original error from
    // getModelPricing reaches the caller unchanged.
    it("should propagate the underlying error when pricing lookup fails", async () => {
      vi.spyOn(pricing, "getModelPricing").mockReturnValue(
        Effect.fail(new Error("Pricing not found")),
      );

      const error = await Effect.runPromise(
        estimateCost({
          provider: "openai",
          modelId: "unknown-model",
          requestBody: { messages: [{ role: "user", content: "Hello" }] },
        }).pipe(Effect.flip),
      );

      expect(error).toBeInstanceOf(Error);
      expect(error).not.toBeInstanceOf(PricingUnavailableError);
      expect((error as Error).message).toBe("Pricing not found");
    });
  });
});
/**
 * @fileoverview Cost estimation for router requests.
 *
 * Provides utilities for estimating the cost of AI provider requests before they're made.
 * This is used by the reservation system to lock sufficient funds for concurrent requests.
 */

import { Effect } from "effect";
import { getModelPricing } from "@/api/router/pricing";
import type { ProviderName } from "@/api/router/providers";
import { PricingUnavailableError } from "@/errors";

/**
 * Input to a cost estimate: which provider/model is being called and the
 * already-parsed request body that will be sent to it.
 */
export interface EstimateCostParams {
  /** Provider name (openai, anthropic, google) */
  provider: ProviderName;
  /** Model ID */
  modelId: string;
  /** Parsed request body */
  requestBody: unknown;
}

/**
 * Outcome of a cost estimate: the dollar amount plus the token counts it was
 * derived from.
 */
export interface EstimatedCost {
  /** Estimated cost in dollars (e.g., 0.05 for $0.05) */
  cost: number;
  /** Estimated input tokens */
  inputTokens: number;
  /** Estimated output tokens */
  outputTokens: number;
}

/**
 * Provider-neutral view of one chat message.
 * Both fields are `unknown` because they come straight from the request body.
 */
interface MessageLike {
  content?: unknown;
  role?: unknown;
}

/**
 * Heuristic ratio used for input token estimation: roughly four characters
 * of text per token.
 */
const CHARS_PER_TOKEN = 4;

/**
 * Output token count assumed when the request does not specify a limit.
 * Chosen conservatively to avoid underestimating costs.
 */
const DEFAULT_OUTPUT_TOKENS_ESTIMATE = 1000;
/**
 * Base cost estimator that handles shared logic for all providers.
 *
 * Subclasses implement provider-specific extraction of messages/contents arrays
 * and max token fields. All token counting logic is centralized in the base class.
 */
export abstract class BaseCostEstimator {
  protected readonly provider: ProviderName;

  constructor(provider: ProviderName) {
    this.provider = provider;
  }

  /**
   * Extracts messages from the request body in provider-specific format.
   *
   * Must be implemented by each provider-specific estimator.
   *
   * @param requestBody - Non-null object (validated by caller)
   * @returns Array of message-like objects, or null if not found
   */
  protected abstract extractMessages(
    requestBody: Record<string, unknown>,
  ): MessageLike[] | null;

  /**
   * Extracts the max output tokens from the request body in provider-specific format.
   *
   * Must be implemented by each provider-specific estimator.
   *
   * @param requestBody - Non-null object (validated by caller)
   * @returns Max output tokens if specified, or null to use default
   */
  protected abstract extractMaxOutputTokens(
    requestBody: Record<string, unknown>,
  ): number | null;

  /**
   * Counts characters from message content.
   *
   * - String content: its length.
   * - Array content (multimodal): the summed length of every block that has a
   *   string `text` property. Blocks without text (e.g. images) add nothing.
   * - Anything else: 0.
   */
  private countContentChars(content: unknown): number {
    if (typeof content === "string") {
      return content.length;
    }
    if (!Array.isArray(content)) {
      return 0;
    }

    let chars = 0;
    for (const block of content as readonly unknown[]) {
      if (typeof block !== "object" || block === null) {
        continue;
      }
      const text = (block as { text?: unknown }).text;
      if (typeof text === "string") {
        chars += text.length;
      }
    }
    return chars;
  }

  /**
   * Converts an array of messages into an estimated token count.
   *
   * Sums content characters plus the length of each string `role`, then
   * divides by CHARS_PER_TOKEN, rounding up.
   *
   * The array originates from an unvalidated request body, so entries that are
   * not objects (for example `null`) are skipped instead of being dereferenced.
   */
  private countMessageTokens(messages: MessageLike[]): number {
    let totalChars = 0;

    for (const message of messages as readonly unknown[]) {
      if (typeof message !== "object" || message === null) {
        continue;
      }
      const { content, role } = message as MessageLike;

      totalChars += this.countContentChars(content);

      // Role overhead
      if (typeof role === "string") {
        totalChars += role.length;
      }
    }

    return Math.ceil(totalChars / CHARS_PER_TOKEN);
  }

  /**
   * Estimates the number of input tokens for a request body.
   *
   * @param requestBody - Parsed request body of any shape
   * @returns 0 for non-object bodies; the message-based count when the
   *   provider-specific extractor finds messages; otherwise the length of the
   *   JSON-serialized body divided by CHARS_PER_TOKEN, rounded up
   */
  protected extractInputTokens(requestBody: unknown): number {
    if (typeof requestBody !== "object" || requestBody === null) {
      return 0;
    }

    const body = requestBody as Record<string, unknown>;
    const messages = this.extractMessages(body);
    if (messages) {
      return this.countMessageTokens(messages);
    }

    // Fallback: stringify and count characters
    return Math.ceil(JSON.stringify(body).length / CHARS_PER_TOKEN);
  }

  /**
   * Estimates the number of output tokens for a request body.
   *
   * @param requestBody - Parsed request body of any shape
   * @returns The provider-specific limit when it is a finite positive number,
   *   otherwise DEFAULT_OUTPUT_TOKENS_ESTIMATE
   */
  protected extractOutputTokens(requestBody: unknown): number {
    if (typeof requestBody !== "object" || requestBody === null) {
      return DEFAULT_OUTPUT_TOKENS_ESTIMATE;
    }

    const maxTokens = this.extractMaxOutputTokens(
      requestBody as Record<string, unknown>,
    );

    // JSON such as `1e999` parses to Infinity; letting that through would
    // produce an infinite cost, so only finite positive limits are honored.
    if (maxTokens !== null && Number.isFinite(maxTokens) && maxTokens > 0) {
      return maxTokens;
    }
    return DEFAULT_OUTPUT_TOKENS_ESTIMATE;
  }

  /**
   * Estimates the cost of a router request before it's made.
   *
   * This function:
   * 1. Estimates input/output token counts from the request body
   * 2. Fetches pricing data for the model
   * 3. Calculates estimated cost (pricing is per million tokens)
   *
   * The output side uses the caller's own limit, or a default when none is
   * given. The input side is a character-count heuristic that ignores
   * non-text content blocks, so it is an approximation rather than a
   * guaranteed upper bound. The actual cost will be charged on settlement.
   *
   * @param model - Model ID
   * @param requestBody - Parsed request body
   * @returns An Effect yielding the estimated cost. It fails with
   *   PricingUnavailableError when the pricing lookup yields `null`; a failure
   *   of the lookup itself is passed through unchanged.
   *
   * @example
   * ```ts
   * const estimator = new OpenAICostEstimator();
   * const estimate = yield* estimator.estimate(
   *   "gpt-4",
   *   { messages: [...], max_tokens: 1024 }
   * );
   * console.log(`Estimated cost: $${estimate.cost.toFixed(4)}`);
   * ```
   */
  public estimate(
    model: string,
    requestBody: unknown,
  ): Effect.Effect<EstimatedCost, PricingUnavailableError | Error> {
    return Effect.gen(this, function* () {
      const inputTokens = this.extractInputTokens(requestBody);
      const outputTokens = this.extractOutputTokens(requestBody);

      // Get pricing data - fail if unavailable
      const pricing = yield* getModelPricing(this.provider, model).pipe(
        Effect.flatMap((p) =>
          p === null
            ? Effect.fail(
                new PricingUnavailableError({
                  message: `Pricing unavailable for ${this.provider}/${model}`,
                  provider: this.provider,
                  model,
                }),
              )
            : Effect.succeed(p),
        ),
      );

      // Calculate cost (convert from per-million to actual tokens)
      const inputCost = (inputTokens / 1_000_000) * pricing.input;
      const outputCost = (outputTokens / 1_000_000) * pricing.output;

      return {
        cost: inputCost + outputCost,
        inputTokens,
        outputTokens,
      };
    });
  }
}
/**
 * OpenAI-specific cost estimator.
 *
 * Handles OpenAI format:
 * - messages array with content field
 * - max_completion_tokens / max_tokens for output estimation
 */
export class OpenAICostEstimator extends BaseCostEstimator {
  constructor() {
    super("openai");
  }

  /**
   * @param requestBody - Non-null object (validated by caller)
   * @returns The `messages` array when present, otherwise null
   */
  protected extractMessages(
    requestBody: Record<string, unknown>,
  ): MessageLike[] | null {
    if (Array.isArray(requestBody.messages)) {
      return requestBody.messages as MessageLike[];
    }
    return null;
  }

  /**
   * Reads the output limit from `max_completion_tokens` (current parameter)
   * and `max_tokens` (its deprecated predecessor).
   *
   * @param requestBody - Non-null object (validated by caller)
   * @returns The largest positive limit found, or null when neither is set.
   *   Taking the larger value keeps the reservation on the safe side if a
   *   request carries both fields.
   */
  protected extractMaxOutputTokens(
    requestBody: Record<string, unknown>,
  ): number | null {
    const limits = [
      requestBody.max_completion_tokens,
      requestBody.max_tokens,
    ].filter(
      (value): value is number => typeof value === "number" && value > 0,
    );
    return limits.length > 0 ? Math.max(...limits) : null;
  }
}

/**
 * Anthropic-specific cost estimator.
 *
 * Handles Anthropic format:
 * - messages array with content field
 * - top-level system prompt (string or array of text blocks)
 * - max_tokens for output estimation
 */
export class AnthropicCostEstimator extends BaseCostEstimator {
  constructor() {
    super("anthropic");
  }

  /**
   * @param requestBody - Non-null object (validated by caller)
   * @returns The `messages` array, preceded by a role-less entry carrying the
   *   top-level `system` prompt when one is present; null when `messages` is
   *   not an array
   */
  protected extractMessages(
    requestBody: Record<string, unknown>,
  ): MessageLike[] | null {
    if (!Array.isArray(requestBody.messages)) {
      return null;
    }
    const messages = requestBody.messages as MessageLike[];

    // The system prompt lives outside `messages` in this format but is still
    // billed as input, so count it alongside the conversation.
    const { system } = requestBody;
    if (typeof system === "string" || Array.isArray(system)) {
      return [{ content: system }, ...messages];
    }
    return messages;
  }

  /**
   * @param requestBody - Non-null object (validated by caller)
   * @returns `max_tokens` when it is a positive number, otherwise null
   */
  protected extractMaxOutputTokens(
    requestBody: Record<string, unknown>,
  ): number | null {
    const { max_tokens: maxTokens } = requestBody;
    return typeof maxTokens === "number" && maxTokens > 0 ? maxTokens : null;
  }
}
/**
 * Google AI-specific cost estimator.
 *
 * Handles Google format:
 * - contents array with parts (maps to message-like structure)
 * - systemInstruction / system_instruction with parts
 * - generationConfig.maxOutputTokens for output estimation
 */
export class GoogleCostEstimator extends BaseCostEstimator {
  constructor() {
    super("google");
  }

  /**
   * @param requestBody - Non-null object (validated by caller)
   * @returns One message-like entry per `contents` element (`parts` becomes
   *   `content`), preceded by the system instruction when present; null when
   *   `contents` is not an array
   */
  protected extractMessages(
    requestBody: Record<string, unknown>,
  ): MessageLike[] | null {
    if (!Array.isArray(requestBody.contents)) {
      return null;
    }

    // Entries come from an unvalidated body: anything that is not an object
    // (e.g. null) maps to an empty message instead of being dereferenced.
    const toMessage = (entry: unknown): MessageLike => {
      if (typeof entry !== "object" || entry === null) {
        return {};
      }
      const record = entry as Record<string, unknown>;
      return { content: record.parts, role: record.role };
    };

    const messages = (requestBody.contents as unknown[]).map(toMessage);

    // The system instruction sits outside `contents` but is part of the prompt.
    const systemInstruction =
      requestBody.systemInstruction ?? requestBody.system_instruction;
    if (typeof systemInstruction === "object" && systemInstruction !== null) {
      messages.unshift(toMessage(systemInstruction));
    }

    return messages;
  }

  /**
   * @param requestBody - Non-null object (validated by caller)
   * @returns `generationConfig.maxOutputTokens` when it is a positive number,
   *   otherwise null
   */
  protected extractMaxOutputTokens(
    requestBody: Record<string, unknown>,
  ): number | null {
    const { generationConfig } = requestBody;
    if (typeof generationConfig !== "object" || generationConfig === null) {
      return null;
    }

    const { maxOutputTokens } = generationConfig as Record<string, unknown>;
    return typeof maxOutputTokens === "number" && maxOutputTokens > 0
      ? maxOutputTokens
      : null;
  }
}

/**
 * Factory function to get the appropriate cost estimator for a provider.
 *
 * @param provider - Provider to build an estimator for
 * @returns A new estimator instance for that provider
 * @throws Error if `provider` is not one of the handled cases (guarded at
 *   compile time by the `never` assignment)
 */
function getCostEstimator(provider: ProviderName): BaseCostEstimator {
  switch (provider) {
    case "openai":
      return new OpenAICostEstimator();
    case "anthropic":
      return new AnthropicCostEstimator();
    case "google":
      return new GoogleCostEstimator();
    /* v8 ignore next 4 */
    default: {
      const exhaustiveCheck: never = provider;
      throw new Error(`Unsupported provider: ${exhaustiveCheck as string}`);
    }
  }
}
/**
 * Estimates the cost of a router request before it's made.
 *
 * Picks the estimator for `provider` and delegates to its `estimate` method.
 *
 * @param params - Cost estimation parameters
 * @returns An Effect yielding the estimated cost; its error channel carries
 *   PricingUnavailableError or Error
 */
export function estimateCost({
  provider,
  modelId,
  requestBody,
}: EstimateCostParams): Effect.Effect<
  EstimatedCost,
  PricingUnavailableError | Error
> {
  return getCostEstimator(provider).estimate(modelId, requestBody);
}

// cloud/errors.ts — appended after ProxyError

/**
 * Error that occurs when pricing data is unavailable for cost estimation.
 *
 * This error is raised when we cannot retrieve pricing information needed to
 * estimate the cost of a request. When this occurs, the request should be
 * rejected rather than proxied, as we cannot lock sufficient funds without
 * a cost estimate.
 *
 * @example
 * ```ts
 * const estimate = yield* estimateCost({ ... }).pipe(
 *   Effect.catchTag("PricingUnavailableError", (error) => {
 *     console.error("Cannot estimate cost:", error.message);
 *     return Effect.fail(new HandlerError({ message: "Pricing unavailable" }));
 *   })
 * );
 * ```
 */
export class PricingUnavailableError extends Schema.TaggedError<PricingUnavailableError>()(
  "PricingUnavailableError",
  {
    message: Schema.String,
    provider: Schema.optional(Schema.String),
    model: Schema.optional(Schema.String),
  },
) {
  static readonly status = 503 as const;
}