diff --git a/app/play/[level]/test-mode.tsx b/app/play/[level]/test-mode.tsx
index 332c770..c86a244 100644
--- a/app/play/[level]/test-mode.tsx
+++ b/app/play/[level]/test-mode.tsx
@@ -13,7 +13,7 @@ import {
 } from "@/components/ui/select";
 import { api } from "@/convex/_generated/api";
 import { AI_MODELS } from "@/convex/constants";
-import { errorMessage } from "@/lib/utils";
+import { useAITesting } from "@/hooks/useAITesting";
 
 interface TestModeProps {
   level: number;
@@ -21,49 +21,30 @@ interface TestModeProps {
 }
 
 export default function TestMode({ level, map }: TestModeProps) {
-  const [playerMap, setPlayerMap] = useState<string[][]>([]);
-  const [isSimulating, setIsSimulating] = useState(false);
-  const [gameResult, setGameResult] = useState<"WON" | "LOST" | null>(null);
   const [selectedModel, setSelectedModel] = useState(AI_MODELS[0].model);
-  const testAIModel = useAction(api.maps.testAIModel);
-  const [aiError, setAiError] = useState<string | null>(null);
-  const [aiReasoning, setAiReasoning] = useState<string | null>(null);
   const [showOriginalMap, setShowOriginalMap] = useState(true);
+  const {
+    isSimulating,
+    gameResult,
+    aiError,
+    aiReasoning,
+    aiPromptTokensUsed,
+    aiOutputTokensUsed,
+    aiTotalTokensUsed,
+    aiTotalRunCost,
+    resultMap,
+    runTest,
+    resetAITest,
+  } = useAITesting({ testingType: "MODEL", level });
 
   const handleAITest = async () => {
-    setIsSimulating(true);
-    setGameResult(null);
-    setAiError(null);
-    setAiReasoning(null);
     setShowOriginalMap(false);
-
-    try {
-      const result = await testAIModel({
-        level,
-        modelId: selectedModel,
-      });
-
-      if (!result.map) {
-        throw new Error("No map found");
-      }
-
-      setPlayerMap(result.map);
-      setGameResult(result.isWin ? "WON" : "LOST");
-      setAiReasoning(result.reasoning);
-    } catch (error) {
-      console.error("Error testing AI model:", error);
-      setAiError(errorMessage(error));
-    } finally {
-      setIsSimulating(false);
-    }
+    await runTest(selectedModel, map);
   };
 
   const handleReset = () => {
     setShowOriginalMap(true);
-    setPlayerMap([]);
-    setGameResult(null);
-    setAiError(null);
-    setAiReasoning(null);
+    resetAITest();
   };
 
   return (
@@ -105,7 +86,21 @@ export default function TestMode({ level, map }: TestModeProps) {
       {aiError && <div className="text-red-500">{aiError}</div>}
       {gameResult && (
         <div className="mt-4">
-          <Visualizer autoReplay autoStart controls={false} map={playerMap} />
+          <Visualizer autoReplay autoStart controls={false} map={resultMap} />
           <p>{aiReasoning}</p>
         </div>
       )}
+      {gameResult && (
+        <div className="mt-4">
+          <h3>Token Usage:</h3>
+          <ul>
+            <li>Prompt Tokens: {aiPromptTokensUsed ?? "N/A"}</li>
+            <li>Output Tokens: {aiOutputTokensUsed ?? "N/A"}</li>
+            <li>Total Tokens: {aiTotalTokensUsed ?? "N/A"}</li>
+            <li>
+              Total Cost:{" "}
+              {aiTotalRunCost ? `$${aiTotalRunCost.toFixed(6)}` : "N/A"}
+            </li>
+          </ul>
+        </div>
+      )}
diff --git a/app/playground/page.tsx b/app/playground/page.tsx
index 1082534..303fad5 100644
--- a/app/playground/page.tsx
+++ b/app/playground/page.tsx
@@ -29,6 +29,7 @@ import { useToast } from "@/components/ui/use-toast";
 import { api } from "@/convex/_generated/api";
 import { Id } from "@/convex/_generated/dataModel";
 import { SIGN_IN_ERROR_MESSAGE } from "@/convex/users";
+import { useAITesting } from "@/hooks/useAITesting";
 import { errorMessage } from "@/lib/utils";
 import { ZombieSurvival } from "@/simulators/zombie-survival";
 
@@ -56,15 +57,26 @@ export default function PlaygroundPage() {
   ]);
   const [model, setModel] = useState("");
   const [error, setError] = useState<string | null>(null);
-  const [solution, setSolution] = useState<string[][] | null>(null);
-  const [reasoning, setReasoning] = useState<string | null>(null);
   const [publishing, setPublishing] = useState(false);
-  const [simulating, setSimulating] = useState(false);
   const [userPlaying, setUserPlaying] = useState(false);
   const [userSolution, setUserSolution] = useState<string[][]>([]);
   const [visualizingUserSolution, setVisualizingUserSolution] = useState(false);
   const [openSignInModal, setOpenSignInModal] = useState(false);
+  const {
+    isSimulating,
+    gameResult: solution,
+    aiError,
+    aiReasoning,
+    aiPromptTokensUsed,
+    aiOutputTokensUsed,
+    aiTotalTokensUsed,
+    aiTotalRunCost,
+    resultMap,
+    runTest,
+    resetAITest,
+  } = useAITesting({ testingType: "MAP" });
+
 
   async function handlePublish() {
     if (!ZombieSurvival.mapHasZombies(map)) {
       alert("Add some zombies to the map first");
       return;
     }
@@ -98,30 +110,12 @@ export default function PlaygroundPage() {
   }
 
   async function handleSimulate() {
-    setError(null);
-    setSolution(null);
-    setReasoning(null);
-
     if (!ZombieSurvival.mapHasZombies(map)) {
       alert("Add some zombies to the map first");
       return;
     }
 
-    setSimulating(true);
-
-    const { error, solution, reasoning } = await testMap({
-      modelId: model,
-      map: map,
-    });
-
-    if (typeof error !== "undefined") {
-      setError(error);
-    } else {
-      setSolution(solution!);
-      setReasoning(reasoning!);
-    }
-
-    setSimulating(false);
+    await runTest(model, map);
   }
 
   function handleChangeMap(value: string[][]) {
@@ -139,8 +133,7 @@ export default function PlaygroundPage() {
   }
 
   function handleEdit() {
-    setSolution(null);
-    setReasoning(null);
+    resetAITest();
     setUserPlaying(false);
     setVisualizingUserSolution(false);
 
@@ -223,7 +216,7 @@ export default function PlaygroundPage() {
     }
   }
 
-  // function handleReset() {
+  // function handleresetAITest() {
   //   handleChangeMap([]);
   //   setSolution(null);
   //   setReasoning(null);
@@ -304,7 +297,7 @@ export default function PlaygroundPage() {
               autoReplay
               autoStart
               controls={false}
-              map={visualizingUserSolution ? userSolution : solution!}
+              map={visualizingUserSolution ? userSolution : resultMap}
             />
           )}
           {!visualizing && (
@@ -485,7 +478,7 @@ export default function PlaygroundPage() {
           </div>
           <Button
             className="w-full"
-            disabled={model === "" || simulating}
+            disabled={model === "" || isSimulating}
             onClick={handleSimulate}
             type="button"
           >
diff --git a/convex/maps.ts b/convex/maps.ts
index 7d04732..7ab8573 100644
--- a/convex/maps.ts
+++ b/convex/maps.ts
@@ -404,16 +404,24 @@ export const testAIModel = action({
       throw new Error("Active prompt not found");
     }
 
-    const { solution, reasoning, error } = await runModel(
-      args.modelId,
-      map.grid,
-      activePrompt.prompt,
-    );
+    const {
+      solution,
+      reasoning,
+      error,
+      promptTokens,
+      outputTokens,
+      totalTokensUsed,
+      totalRunCost,
+    } = await runModel(args.modelId, map.grid, activePrompt.prompt);
 
     return {
       map: solution,
       isWin: error ? false : ZombieSurvival.isWin(solution!),
       reasoning,
+      promptTokens,
+      outputTokens,
+      totalTokensUsed,
+      totalRunCost,
       error,
     };
   },
diff --git a/hooks/useAITesting.ts b/hooks/useAITesting.ts
new file mode 100644
index 0000000..35e79e2
--- /dev/null
+++ b/hooks/useAITesting.ts
@@ -0,0 +1,118 @@
+"use client";
+
+import { useState } from "react";
+import { useAction } from "convex/react";
+import { api } from "@/convex/_generated/api";
+import { errorMessage } from "@/lib/utils";
+
+interface UseAITestingProps {
+  testingType: "MAP" | "MODEL";
+  level?: number;
+}
+
+interface AITestResult {
+  map?: string[][];
+  isWin?: boolean;
+  reasoning: string;
+  promptTokens?: number;
+  outputTokens?: number;
+  totalTokensUsed?: number;
+  totalRunCost?: number;
+}
+
+export function useAITesting({ testingType, level }: UseAITestingProps) {
+  const [isSimulating, setIsSimulating] = useState(false);
+  const [gameResult, setGameResult] = useState<"WON" | "LOST" | null>(null);
+  const [aiError, setAiError] = useState<string | null>(null);
+  const [aiReasoning, setAiReasoning] = useState<string | null>(null);
+  const [aiPromptTokensUsed, setAiPromptTokensUsed] = useState<number | null>(
+    null,
+  );
+  const [aiOutputTokensUsed, setAiOutputTokensUsed] = useState<number | null>(
+    null,
+  );
+  const [aiTotalTokensUsed, setAiTotalTokensUsed] = useState<number | null>(
+    null,
+  );
+  const [aiTotalRunCost, setAiTotalRunCost] = useState<number | null>(null);
+  const [resultMap, setResultMap] = useState<string[][]>([]);
+
+  const testAIModel = useAction(api.maps.testAIModel);
+  const testMap = useAction(api.maps.testMap);
+
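+  // Dispatches to the MODEL or MAP Convex action, mirrors the outcome into
+  // the state above, and re-throws failures so callers can handle them in
+  // addition to reading aiError.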
+  const runTest = async (modelId: string, map: string[][]) => {
+    setIsSimulating(true);
+    setGameResult(null);
+    setAiError(null);
+    setAiReasoning(null);
+
+    try {
+      let result: AITestResult | null = null;
+
+      if (testingType === "MODEL") {
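+        // NOTE: a falsy check also rejects level 0; this assumes level
+        // numbering starts at 1.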
+        if (!level) {
+          console.error(
+            "Error testing AI model:",
+            "Level is required when testing a model",
+          );
+          setAiError("Level is required when testing a model");
+          return;
+        }
+        result = await testAIModel({
+          level,
+          modelId,
+        });
+      } else if (testingType === "MAP") {
+        result = await testMap({
+          modelId,
+          map,
+        });
+      }
+
+      if (!result?.map) {
+        throw new Error("No map found");
+      }
+
+      setResultMap(result.map);
+      setGameResult(result.isWin ? "WON" : "LOST");
+      setAiReasoning(result.reasoning);
+      setAiPromptTokensUsed(result.promptTokens ?? null);
+      setAiOutputTokensUsed(result.outputTokens ?? null);
+      setAiTotalTokensUsed(result.totalTokensUsed ?? null);
+      setAiTotalRunCost(result.totalRunCost ?? null);
+
+      return result as AITestResult;
+    } catch (error) {
+      console.error("Error testing AI model:", error);
+      setAiError(errorMessage(error));
+      throw error;
+    } finally {
+      setIsSimulating(false);
+    }
+  };
+
+  const resetAITest = () => {
+    setResultMap([]);
+    setGameResult(null);
+    setAiError(null);
+    setAiReasoning(null);
+    setAiPromptTokensUsed(null);
+    setAiOutputTokensUsed(null);
+    setAiTotalTokensUsed(null);
+    setAiTotalRunCost(null);
+  };
+
+  return {
+    isSimulating,
+    gameResult,
+    aiError,
+    aiReasoning,
+    aiPromptTokensUsed,
+    aiOutputTokensUsed,
+    aiTotalTokensUsed,
+    aiTotalRunCost,
+    resultMap,
+    runTest,
+    resetAITest,
+  };
+}
diff --git a/models/claude-3-5-sonnet.ts b/models/claude-3-5-sonnet.ts
index c848f13..be2c33c 100644
--- a/models/claude-3-5-sonnet.ts
+++ b/models/claude-3-5-sonnet.ts
@@ -48,9 +48,37 @@ export const claude35sonnet: ModelHandler = async (
     throw new Error(response.error.message);
   }
 
+  const promptTokens = completion.usage.input_tokens;
+  const outputTokens = completion.usage.output_tokens;
+  const totalTokensUsed =
+    completion.usage.input_tokens + completion.usage.output_tokens;
+
+  // https://docs.anthropic.com/en/docs/about-claude/models
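+  // Worked example under the rates hard-coded below ($3.00/M input,
+  // $15.00/M output): 1,000 prompt tokens cost
+  // (3.0 / 1_000_000) * 1000 = $0.003.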
+  const getPriceForInputToken = (tokenCount?: number) => {
+    if (!tokenCount) {
+      return 0;
+    }
+
+    return (3.0 / 1_000_000) * tokenCount;
+  };
+
+  const getPriceForOutputToken = (tokenCount?: number) => {
+    if (!tokenCount) {
+      return 0;
+    }
+
+    return (15.0 / 1_000_000) * tokenCount;
+  };
+
   return {
     boxCoordinates: response.data.boxCoordinates,
     playerCoordinates: response.data.playerCoordinates,
     reasoning: response.data.reasoning,
+    promptTokens: promptTokens,
+    outputTokens: outputTokens,
+    totalTokensUsed: totalTokensUsed,
+    totalRunCost:
+      getPriceForInputToken(promptTokens) +
+      getPriceForOutputToken(outputTokens),
   };
 };
diff --git a/models/gemini-1.5-pro.ts b/models/gemini-1.5-pro.ts
index 81a21cf..a8549b0 100644
--- a/models/gemini-1.5-pro.ts
+++ b/models/gemini-1.5-pro.ts
@@ -69,9 +69,40 @@ export const gemini15pro: ModelHandler = async (
   const result = await model.generateContent(userPrompt);
   const parsedResponse = JSON.parse(result.response.text()) as GeminiResponse;
 
+  const promptTokens = result.response.usageMetadata?.promptTokenCount;
+  const outputTokens = result.response.usageMetadata?.candidatesTokenCount;
+  const totalTokensUsed = result.response.usageMetadata?.totalTokenCount;
+
+  // https://ai.google.dev/pricing#1_5pro
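+  // The rates below mirror the linked price list: past the 128K-token tier
+  // the per-token rate doubles, and the higher rate is applied to the
+  // entire token count, not just the overage.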
+  const getPriceForInputToken = (tokenCount?: number) => {
+    if (!tokenCount) {
+      return 0;
+    }
+    if (tokenCount > 128_000) {
+      return (2.5 / 1_000_000) * tokenCount;
+    }
+    return (1.25 / 1_000_000) * tokenCount;
+  };
+
+  const getPriceForOutputToken = (tokenCount?: number) => {
+    if (!tokenCount) {
+      return 0;
+    }
+    if (tokenCount > 128_000) {
+      return (10.0 / 1_000_000) * tokenCount;
+    }
+    return (5.0 / 1_000_000) * tokenCount;
+  };
+
   return {
     boxCoordinates: parsedResponse.boxCoordinates,
     playerCoordinates: parsedResponse.playerCoordinates,
     reasoning: parsedResponse.reasoning,
+    promptTokens: promptTokens,
+    outputTokens: outputTokens,
+    totalTokensUsed: totalTokensUsed,
+    totalRunCost:
+      getPriceForInputToken(promptTokens) +
+      getPriceForOutputToken(outputTokens),
   };
 };
diff --git a/models/gpt-4o.ts b/models/gpt-4o.ts
index c018648..b2ef2f4 100644
--- a/models/gpt-4o.ts
+++ b/models/gpt-4o.ts
@@ -38,9 +38,36 @@ export const gpt4o: ModelHandler = async (systemPrompt, userPrompt, config) => {
     throw new Error("Failed to run model GPT-4o");
   }
 
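+  // `usage` is optional on the OpenAI completion type, so these may be
+  // undefined; the helpers below then price the run at $0.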
+  const promptTokens = completion.usage?.prompt_tokens;
+  const outputTokens = completion.usage?.completion_tokens;
+  const totalTokensUsed = completion.usage?.total_tokens;
+
+  // https://openai.com/api/pricing/
+  const getPriceForInputToken = (tokenCount?: number) => {
+    if (!tokenCount) {
+      return 0;
+    }
+
+    return (2.5 / 1_000_000) * tokenCount;
+  };
+
+  const getPriceForOutputToken = (tokenCount?: number) => {
+    if (!tokenCount) {
+      return 0;
+    }
+
+    return (10.0 / 1_000_000) * tokenCount;
+  };
+
   return {
     boxCoordinates: response.parsed.boxCoordinates,
     playerCoordinates: response.parsed.playerCoordinates,
     reasoning: response.parsed.reasoning,
+    promptTokens: promptTokens,
+    outputTokens: outputTokens,
+    totalTokensUsed: totalTokensUsed,
+    totalRunCost:
+      getPriceForInputToken(promptTokens) +
+      getPriceForOutputToken(outputTokens),
   };
 };
diff --git a/models/index.ts b/models/index.ts
index 549f068..668de69 100644
--- a/models/index.ts
+++ b/models/index.ts
@@ -20,6 +20,10 @@ export type ModelHandler = (
   boxCoordinates: number[][];
   playerCoordinates: number[];
   reasoning: string;
+  promptTokens?: number;
+  outputTokens?: number;
+  totalTokensUsed?: number;
+  totalRunCost?: number;
 }>;
 
 const MAX_RETRIES = 1;
@@ -33,16 +37,22 @@ const CONFIG: ModelHandlerConfig = {
   topP: 0.95,
 };
 
+export type RunModelResult = {
+  solution?: string[][];
+  reasoning: string;
+  promptTokens?: number;
+  outputTokens?: number;
+  totalTokensUsed?: number;
+  totalRunCost?: number;
+  error?: string;
+};
+
 export async function runModel(
   modelId: string,
   map: string[][],
   prompt: string,
   retry = 1,
-): Promise<{
-  solution?: string[][];
-  reasoning: string;
-  error?: string;
-}> {
+): Promise<RunModelResult> {
   const userPrompt =
     `Grid: ${JSON.stringify(map)}\n\n` +
     `Valid Locations: ${JSON.stringify(ZombieSurvival.validLocations(map))}`;
@@ -101,6 +111,10 @@ export async function runModel(
     return {
       solution: originalMap,
       reasoning: result.reasoning,
+      promptTokens: result.promptTokens,
+      outputTokens: result.outputTokens,
+      totalTokensUsed: result.totalTokensUsed,
+      totalRunCost: result.totalRunCost,
     };
   } catch (error) {
     if (retry === MAX_RETRIES || reasoning === null) {
diff --git a/models/mistral-large-2.ts b/models/mistral-large-2.ts
index 64bb21b..9598f8d 100644
--- a/models/mistral-large-2.ts
+++ b/models/mistral-large-2.ts
@@ -7,6 +7,10 @@ const responseSchema = z.object({
   reasoning: z.string(),
   playerCoordinates: z.array(z.number()),
   boxCoordinates: z.array(z.array(z.number())),
+  promptTokens: z.number().optional(),
+  outputTokens: z.number().optional(),
+  totalTokensUsed: z.number().optional(),
+  totalRunCost: z.number().optional(),
 });
 
 export const mistralLarge2: ModelHandler = async (
@@ -42,7 +46,36 @@ export const mistralLarge2: ModelHandler = async (
     throw new Error("Response received from Mistral is not JSON");
   }
 
-  const response = await responseSchema.safeParseAsync(JSON.parse(content));
+  const promptTokens = completion.usage.promptTokens;
+  const outputTokens = completion.usage.completionTokens;
+  const totalTokensUsed = completion.usage.totalTokens;
+
+  // https://mistral.ai/technology/
+  const getPriceForInputToken = (tokenCount?: number) => {
+    if (!tokenCount) {
+      return 0;
+    }
+
+    return (2.0 / 1_000_000) * tokenCount;
+  };
+
+  const getPriceForOutputToken = (tokenCount?: number) => {
+    if (!tokenCount) {
+      return 0;
+    }
+
+    return (6.0 / 1_000_000) * tokenCount;
+  };
+
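+  // zod strips keys that are not declared in the schema, so totalRunCost is
+  // also declared in responseSchema above; otherwise it would be silently
+  // dropped from response.data.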
+  const response = await responseSchema.safeParseAsync({
+    ...JSON.parse(content),
+    promptTokens: completion.usage.promptTokens,
+    outputTokens: completion.usage.completionTokens,
+    totalTokensUsed: completion.usage.totalTokens,
+    totalRunCost:
+      getPriceForInputToken(promptTokens) +
+      getPriceForOutputToken(outputTokens),
+  });
 
   if (!response.success) {
     throw new Error(response.error.message);
diff --git a/models/perplexity-llama-3.1.ts b/models/perplexity-llama-3.1.ts
index 98cd9d1..17ea955 100644
--- a/models/perplexity-llama-3.1.ts
+++ b/models/perplexity-llama-3.1.ts
@@ -34,6 +34,10 @@ const responseSchema = z.object({
   playerCoordinates: z.array(z.number()),
   boxCoordinates: z.array(z.array(z.number())),
   reasoning: z.string(),
+  promptTokens: z.number().optional(),
+  outputTokens: z.number().optional(),
+  totalTokensUsed: z.number().optional(),
+  totalRunCost: z.number().optional(),
 });
 
 export const perplexityLlama31: ModelHandler = async (
@@ -80,14 +84,48 @@ export const perplexityLlama31: ModelHandler = async (
   }
 
   const content = validatedResponse.data.choices[0].message.content;
+  const promptTokens = validatedResponse.data.usage.prompt_tokens;
+  const outputTokens = validatedResponse.data.usage.completion_tokens;
+  const totalTokensUsed = validatedResponse.data.usage.total_tokens;
+
   const jsonContent = content.match(/```json([^`]+)```/)?.[1] ?? "";
 
   if (!isJSON(jsonContent)) {
     throw new Error("JSON returned by perplexity is malformed");
   }
 
+  // https://docs.perplexity.ai/guides/pricing#perplexity-sonar-models
+  const getPriceForInputToken = (tokenCount?: number) => {
+    if (!tokenCount) {
+      return 0;
+    }
+
+    return (1.0 / 1_000_000) * tokenCount;
+  };
+
+  const getPriceForOutputToken = (tokenCount?: number) => {
+    if (!tokenCount) {
+      return 0;
+    }
+
+    return (1.0 / 1_000_000) * tokenCount;
+  };
+
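+  // Flat request fee from the linked pricing page: $5 per 1,000 requests,
+  // i.e. 5 / 1_000 = $0.005 added to every run on top of the token cost.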
+  const priceForRequest = 5 / 1_000;
+
+  const totalRunCost =
+    getPriceForInputToken(promptTokens) +
+    getPriceForOutputToken(outputTokens) +
+    priceForRequest;
+
   const parsedContent = JSON.parse(jsonContent);
-  const response = await responseSchema.safeParseAsync(parsedContent);
+  const response = await responseSchema.safeParseAsync({
+    ...parsedContent,
+    promptTokens,
+    outputTokens,
+    totalTokensUsed,
+    totalRunCost,
+  });
 
   if (!response.success) {
     throw new Error(response.error.message);