diff --git a/app/play/[level]/test-mode.tsx b/app/play/[level]/test-mode.tsx
index 332c770..c86a244 100644
--- a/app/play/[level]/test-mode.tsx
+++ b/app/play/[level]/test-mode.tsx
@@ -13,7 +13,7 @@ import {
} from "@/components/ui/select";
import { api } from "@/convex/_generated/api";
import { AI_MODELS } from "@/convex/constants";
-import { errorMessage } from "@/lib/utils";
+import { useAITesting } from "@/hooks/useAITesting";
interface TestModeProps {
level: number;
@@ -21,49 +21,30 @@ interface TestModeProps {
}
export default function TestMode({ level, map }: TestModeProps) {
-  const [playerMap, setPlayerMap] = useState<string[][]>([]);
- const [isSimulating, setIsSimulating] = useState(false);
- const [gameResult, setGameResult] = useState<"WON" | "LOST" | null>(null);
const [selectedModel, setSelectedModel] = useState(AI_MODELS[0].model);
- const testAIModel = useAction(api.maps.testAIModel);
-  const [aiError, setAiError] = useState<string | null>(null);
-  const [aiReasoning, setAiReasoning] = useState<string | null>(null);
const [showOriginalMap, setShowOriginalMap] = useState(true);
+ const {
+ isSimulating,
+ gameResult,
+ aiError,
+ aiReasoning,
+ aiPromptTokensUsed,
+ aiOutputTokensUsed,
+ aiTotalTokensUsed,
+ aiTotalRunCost,
+ resultMap,
+ runTest,
+ resetAITest,
+ } = useAITesting({ testingType: "MODEL", level });
const handleAITest = async () => {
- setIsSimulating(true);
- setGameResult(null);
- setAiError(null);
- setAiReasoning(null);
setShowOriginalMap(false);
-
- try {
- const result = await testAIModel({
- level,
- modelId: selectedModel,
- });
-
- if (!result.map) {
- throw new Error("No map found");
- }
-
- setPlayerMap(result.map);
- setGameResult(result.isWin ? "WON" : "LOST");
- setAiReasoning(result.reasoning);
- } catch (error) {
- console.error("Error testing AI model:", error);
- setAiError(errorMessage(error));
- } finally {
- setIsSimulating(false);
- }
+ await runTest(selectedModel, map);
};
const handleReset = () => {
setShowOriginalMap(true);
- setPlayerMap([]);
- setGameResult(null);
- setAiError(null);
- setAiReasoning(null);
+ resetAITest();
};
return (
@@ -105,7 +86,7 @@ export default function TestMode({ level, map }: TestModeProps) {
          {aiError && <p>{aiError}</p>}
          {gameResult && (
            <>
-             <Visualizer autoStart controls={false} map={playerMap} />
+             <Visualizer autoStart controls={false} map={resultMap} />
              <p>{aiReasoning}</p>
+             <div>
+               <h3>Token Usage:</h3>
+               <ul>
+                 <li>Prompt Tokens: {aiPromptTokensUsed ?? "N/A"}</li>
+                 <li>Output Tokens: {aiOutputTokensUsed ?? "N/A"}</li>
+                 <li>Total Tokens: {aiTotalTokensUsed ?? "N/A"}</li>
+                 <li>
+                   Total Cost:{" "}
+                   {aiTotalRunCost ? `$${aiTotalRunCost.toFixed(6)}` : "N/A"}
+                 </li>
+               </ul>
+             </div>
            </>
          )}
diff --git a/app/playground/page.tsx b/app/playground/page.tsx
index 1082534..303fad5 100644
--- a/app/playground/page.tsx
+++ b/app/playground/page.tsx
@@ -29,6 +29,7 @@ import { useToast } from "@/components/ui/use-toast";
import { api } from "@/convex/_generated/api";
import { Id } from "@/convex/_generated/dataModel";
import { SIGN_IN_ERROR_MESSAGE } from "@/convex/users";
+import { useAITesting } from "@/hooks/useAITesting";
import { errorMessage } from "@/lib/utils";
import { ZombieSurvival } from "@/simulators/zombie-survival";
@@ -56,15 +57,26 @@ export default function PlaygroundPage() {
]);
const [model, setModel] = useState("");
  const [error, setError] = useState<string | null>(null);
-  const [solution, setSolution] = useState<string[][] | null>(null);
-  const [reasoning, setReasoning] = useState<string | null>(null);
const [publishing, setPublishing] = useState(false);
- const [simulating, setSimulating] = useState(false);
const [userPlaying, setUserPlaying] = useState(false);
  const [userSolution, setUserSolution] = useState<string[][]>([]);
const [visualizingUserSolution, setVisualizingUserSolution] = useState(false);
const [openSignInModal, setOpenSignInModal] = useState(false);
+ const {
+ isSimulating,
+ gameResult: solution,
+ aiError,
+ aiReasoning,
+ aiPromptTokensUsed,
+ aiOutputTokensUsed,
+ aiTotalTokensUsed,
+ aiTotalRunCost,
+ resultMap,
+ runTest,
+ resetAITest,
+ } = useAITesting({ testingType: "MAP" });
+
async function handlePublish() {
if (!ZombieSurvival.mapHasZombies(map)) {
alert("Add some zombies to the map first");
@@ -98,30 +110,12 @@ export default function PlaygroundPage() {
}
async function handleSimulate() {
- setError(null);
- setSolution(null);
- setReasoning(null);
-
if (!ZombieSurvival.mapHasZombies(map)) {
alert("Add some zombies to the map first");
return;
}
- setSimulating(true);
-
- const { error, solution, reasoning } = await testMap({
- modelId: model,
- map: map,
- });
-
- if (typeof error !== "undefined") {
- setError(error);
- } else {
- setSolution(solution!);
- setReasoning(reasoning!);
- }
-
- setSimulating(false);
+ await runTest(model, map);
}
function handleChangeMap(value: string[][]) {
@@ -139,8 +133,7 @@ export default function PlaygroundPage() {
}
function handleEdit() {
- setSolution(null);
- setReasoning(null);
+ resetAITest();
setUserPlaying(false);
setVisualizingUserSolution(false);
@@ -223,7 +216,7 @@ export default function PlaygroundPage() {
}
}
- // function handleReset() {
+  // function handleResetAITest() {
// handleChangeMap([]);
// setSolution(null);
// setReasoning(null);
@@ -304,7 +297,7 @@ export default function PlaygroundPage() {
autoReplay
autoStart
controls={false}
- map={visualizingUserSolution ? userSolution : solution!}
+ map={visualizingUserSolution ? userSolution : resultMap}
/>
)}
{!visualizing && (
@@ -485,7 +478,7 @@ export default function PlaygroundPage() {
)}
{visualizingUserSolution && }
-         {reasoning !== null && (
-           <div>
-             <p>{reasoning}</p>
-           </div>
-         )}
+         {aiReasoning !== null && (
+           <div>
+             <p>{aiReasoning}</p>
+             <div>
+               <div>
+                 Prompt Tokens:{" "}
+                 {aiPromptTokensUsed?.toLocaleString() ?? "N/A"}
+               </div>
+               <div>
+                 Output Tokens:{" "}
+                 {aiOutputTokensUsed?.toLocaleString() ?? "N/A"}
+               </div>
+               <div>
+                 Total Tokens:{" "}
+                 {aiTotalTokensUsed?.toLocaleString() ?? "N/A"}
+               </div>
+               <div>
+                 Total Cost:{" "}
+                 {aiTotalRunCost ? `$${aiTotalRunCost.toFixed(6)}` : "N/A"}
+               </div>
+             </div>
+           </div>
+         )}
@@ -514,11 +527,11 @@ export default function PlaygroundPage() {
-            disabled={simulating}
+            disabled={isSimulating}
          >
diff --git a/convex/maps.ts b/convex/maps.ts
index 7d04732..7ab8573 100644
--- a/convex/maps.ts
+++ b/convex/maps.ts
@@ -404,16 +404,24 @@ export const testAIModel = action({
throw new Error("Active prompt not found");
}
- const { solution, reasoning, error } = await runModel(
- args.modelId,
- map.grid,
- activePrompt.prompt,
- );
+ const {
+ solution,
+ reasoning,
+ error,
+ promptTokens,
+ outputTokens,
+ totalTokensUsed,
+ totalRunCost,
+ } = await runModel(args.modelId, map.grid, activePrompt.prompt);
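+
+    // Include token usage and cost so clients can display run telemetry.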
return {
map: solution,
isWin: error ? false : ZombieSurvival.isWin(solution!),
reasoning,
+ promptTokens,
+ outputTokens,
+ totalTokensUsed,
+ totalRunCost,
error,
};
},
diff --git a/hooks/useAITesting.ts b/hooks/useAITesting.ts
new file mode 100644
index 0000000..35e79e2
--- /dev/null
+++ b/hooks/useAITesting.ts
@@ -0,0 +1,118 @@
+"use client";
+
+import { useState } from "react";
+import { useAction } from "convex/react";
+import { api } from "@/convex/_generated/api";
+import { errorMessage } from "@/lib/utils";
+
+interface UseAITestingProps {
+ testingType: "MAP" | "MODEL";
+ level?: number;
+}
+
+interface AITestResult {
+ map?: string[][];
+ isWin?: boolean;
+ reasoning: string;
+ promptTokens?: number;
+ outputTokens?: number;
+ totalTokensUsed?: number;
+  totalRunCost?: number;
+  error?: string;
+}
+
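+/**
+ * Centralizes AI test-run state (result map, win/loss, reasoning, token
+ * usage, cost) for both the playground ("MAP") and level test mode
+ * ("MODEL"), so components only wire up `runTest` and `resetAITest`.
+ */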
+export function useAITesting({ testingType, level }: UseAITestingProps) {
+ const [isSimulating, setIsSimulating] = useState(false);
+ const [gameResult, setGameResult] = useState<"WON" | "LOST" | null>(null);
+  const [aiError, setAiError] = useState<string | null>(null);
+  const [aiReasoning, setAiReasoning] = useState<string | null>(null);
+  const [aiPromptTokensUsed, setAiPromptTokensUsed] = useState<number | null>(
+    null,
+  );
+  const [aiOutputTokensUsed, setAiOutputTokensUsed] = useState<number | null>(
+    null,
+  );
+  const [aiTotalTokensUsed, setAiTotalTokensUsed] = useState<number | null>(
+    null,
+  );
+  const [aiTotalRunCost, setAiTotalRunCost] = useState<number | null>(null);
+  const [resultMap, setResultMap] = useState<string[][]>([]);
+
+ const testAIModel = useAction(api.maps.testAIModel);
+ const testMap = useAction(api.maps.testMap);
+
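+  // Runs the selected model, then mirrors the action's result (map,
+  // win/loss, reasoning, token usage, cost) into local state.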
+ const runTest = async (modelId: string, map: string[][]) => {
+ setIsSimulating(true);
+ setGameResult(null);
+ setAiError(null);
+ setAiReasoning(null);
+
+ try {
+ let result: AITestResult | null = null;
+
+ if (testingType === "MODEL") {
+ if (!level) {
+ console.error(
+ "Error testing AI model:",
+ "Level is required when testing a model",
+ );
+ setAiError("Level is required when testing a model");
+ return;
+ }
+ result = await testAIModel({
+ level,
+ modelId,
+ });
+ } else if (testingType === "MAP") {
+ result = await testMap({
+ modelId,
+ map,
+ });
+ }
+
+      if (result?.error) {
+        throw new Error(result.error);
+      }
+
+      if (!result?.map) {
+        throw new Error("No map found");
+      }
+
+ setResultMap(result.map);
+ setGameResult(result.isWin ? "WON" : "LOST");
+ setAiReasoning(result.reasoning);
+ setAiPromptTokensUsed(result.promptTokens ?? null);
+ setAiOutputTokensUsed(result.outputTokens ?? null);
+ setAiTotalTokensUsed(result.totalTokensUsed ?? null);
+ setAiTotalRunCost(result.totalRunCost ?? null);
+
+ return result as AITestResult;
+ } catch (error) {
+ console.error("Error testing AI model:", error);
+ setAiError(errorMessage(error));
+ throw error;
+ } finally {
+ setIsSimulating(false);
+ }
+ };
+
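+  // Returns all result state to its initial values before the next run.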
+ const resetAITest = () => {
+ setResultMap([]);
+ setGameResult(null);
+ setAiError(null);
+ setAiReasoning(null);
+ setAiPromptTokensUsed(null);
+ setAiOutputTokensUsed(null);
+ setAiTotalTokensUsed(null);
+ setAiTotalRunCost(null);
+ };
+
+ return {
+ isSimulating,
+ gameResult,
+ aiError,
+ aiReasoning,
+ aiPromptTokensUsed,
+ aiOutputTokensUsed,
+ aiTotalTokensUsed,
+ aiTotalRunCost,
+ resultMap,
+ runTest,
+ resetAITest,
+ };
+}
diff --git a/models/claude-3-5-sonnet.ts b/models/claude-3-5-sonnet.ts
index c848f13..be2c33c 100644
--- a/models/claude-3-5-sonnet.ts
+++ b/models/claude-3-5-sonnet.ts
@@ -48,9 +48,37 @@ export const claude35sonnet: ModelHandler = async (
throw new Error(response.error.message);
}
+ const promptTokens = completion.usage.input_tokens;
+ const outputTokens = completion.usage.output_tokens;
+  const totalTokensUsed = promptTokens + outputTokens;
+
+ // https://docs.anthropic.com/en/docs/about-claude/models
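+  // Claude 3.5 Sonnet pricing: $3 per 1M input tokens, $15 per 1M output.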
+ const getPriceForInputToken = (tokenCount?: number) => {
+ if (!tokenCount) {
+ return 0;
+ }
+
+ return (3.0 / 1_000_000) * tokenCount;
+ };
+
+ const getPriceForOutputToken = (tokenCount?: number) => {
+ if (!tokenCount) {
+ return 0;
+ }
+
+ return (15.0 / 1_000_000) * tokenCount;
+ };
+
return {
boxCoordinates: response.data.boxCoordinates,
playerCoordinates: response.data.playerCoordinates,
reasoning: response.data.reasoning,
+    promptTokens,
+    outputTokens,
+    totalTokensUsed,
+ totalRunCost:
+ getPriceForInputToken(promptTokens) +
+ getPriceForOutputToken(outputTokens),
};
};
diff --git a/models/gemini-1.5-pro.ts b/models/gemini-1.5-pro.ts
index 81a21cf..a8549b0 100644
--- a/models/gemini-1.5-pro.ts
+++ b/models/gemini-1.5-pro.ts
@@ -69,9 +69,40 @@ export const gemini15pro: ModelHandler = async (
const result = await model.generateContent(userPrompt);
const parsedResponse = JSON.parse(result.response.text()) as GeminiResponse;
+ const promptTokens = result.response.usageMetadata?.promptTokenCount;
+ const outputTokens = result.response.usageMetadata?.candidatesTokenCount;
+ const totalTokensUsed = result.response.usageMetadata?.totalTokenCount;
+
+ // https://ai.google.dev/pricing#1_5pro
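+  // Tiered pricing: token counts above 128K bill at double the base rate
+  // ($1.25 vs $2.50 input, $5 vs $10 output, per 1M tokens).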
+ const getPriceForInputToken = (tokenCount?: number) => {
+ if (!tokenCount) {
+ return 0;
+ }
+ if (tokenCount > 128_000) {
+ return (2.5 / 1_000_000) * tokenCount;
+ }
+ return (1.25 / 1_000_000) * tokenCount;
+ };
+
+ const getPriceForOutputToken = (tokenCount?: number) => {
+ if (!tokenCount) {
+ return 0;
+ }
+ if (tokenCount > 128_000) {
+ return (10.0 / 1_000_000) * tokenCount;
+ }
+ return (5.0 / 1_000_000) * tokenCount;
+ };
+
return {
boxCoordinates: parsedResponse.boxCoordinates,
playerCoordinates: parsedResponse.playerCoordinates,
reasoning: parsedResponse.reasoning,
+    promptTokens,
+    outputTokens,
+    totalTokensUsed,
+ totalRunCost:
+ getPriceForInputToken(promptTokens) +
+ getPriceForOutputToken(outputTokens),
};
};
diff --git a/models/gpt-4o.ts b/models/gpt-4o.ts
index c018648..b2ef2f4 100644
--- a/models/gpt-4o.ts
+++ b/models/gpt-4o.ts
@@ -38,9 +38,36 @@ export const gpt4o: ModelHandler = async (systemPrompt, userPrompt, config) => {
throw new Error("Failed to run model GPT-4o");
}
+ const promptTokens = completion.usage?.prompt_tokens;
+ const outputTokens = completion.usage?.completion_tokens;
+ const totalTokensUsed = completion.usage?.total_tokens;
+
+ // https://openai.com/api/pricing/
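+  // GPT-4o pricing: $2.50 per 1M input tokens, $10 per 1M output tokens.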
+ const getPriceForInputToken = (tokenCount?: number) => {
+ if (!tokenCount) {
+ return 0;
+ }
+
+ return (2.5 / 1_000_000) * tokenCount;
+ };
+
+ const getPriceForOutputToken = (tokenCount?: number) => {
+ if (!tokenCount) {
+ return 0;
+ }
+
+ return (10.0 / 1_000_000) * tokenCount;
+ };
+
return {
boxCoordinates: response.parsed.boxCoordinates,
playerCoordinates: response.parsed.playerCoordinates,
reasoning: response.parsed.reasoning,
+    promptTokens,
+    outputTokens,
+    totalTokensUsed,
+ totalRunCost:
+ getPriceForInputToken(promptTokens) +
+ getPriceForOutputToken(outputTokens),
};
};
diff --git a/models/index.ts b/models/index.ts
index 549f068..668de69 100644
--- a/models/index.ts
+++ b/models/index.ts
@@ -20,6 +20,10 @@ export type ModelHandler = (
boxCoordinates: number[][];
playerCoordinates: number[];
reasoning: string;
+ promptTokens?: number;
+ outputTokens?: number;
+ totalTokensUsed?: number;
+ totalRunCost?: number;
}>;
const MAX_RETRIES = 1;
@@ -33,16 +37,22 @@ const CONFIG: ModelHandlerConfig = {
topP: 0.95,
};
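+// Normalized result of a single model run; token and cost fields are
+// optional because not every model handler reports usage metadata.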
+export type RunModelResult = {
+ solution?: string[][];
+ reasoning: string;
+ promptTokens?: number;
+ outputTokens?: number;
+ totalTokensUsed?: number;
+ totalRunCost?: number;
+ error?: string;
+};
+
export async function runModel(
modelId: string,
map: string[][],
prompt: string,
retry = 1,
-): Promise<{
- solution?: string[][];
- reasoning: string;
- error?: string;
-}> {
+): Promise<RunModelResult> {
const userPrompt =
`Grid: ${JSON.stringify(map)}\n\n` +
`Valid Locations: ${JSON.stringify(ZombieSurvival.validLocations(map))}`;
@@ -101,6 +111,10 @@ export async function runModel(
return {
solution: originalMap,
reasoning: result.reasoning,
+ promptTokens: result.promptTokens,
+ outputTokens: result.outputTokens,
+ totalTokensUsed: result.totalTokensUsed,
+ totalRunCost: result.totalRunCost,
};
} catch (error) {
if (retry === MAX_RETRIES || reasoning === null) {
diff --git a/models/mistral-large-2.ts b/models/mistral-large-2.ts
index 64bb21b..9598f8d 100644
--- a/models/mistral-large-2.ts
+++ b/models/mistral-large-2.ts
@@ -7,6 +7,9 @@ const responseSchema = z.object({
reasoning: z.string(),
playerCoordinates: z.array(z.number()),
boxCoordinates: z.array(z.array(z.number())),
+  promptTokens: z.number().optional(),
+  outputTokens: z.number().optional(),
+  totalTokensUsed: z.number().optional(),
+  totalRunCost: z.number().optional(),
});
export const mistralLarge2: ModelHandler = async (
@@ -42,7 +45,36 @@ export const mistralLarge2: ModelHandler = async (
throw new Error("Response received from Mistral is not JSON");
}
- const response = await responseSchema.safeParseAsync(JSON.parse(content));
+ const promptTokens = completion.usage.promptTokens;
+ const outputTokens = completion.usage.completionTokens;
+ const totalTokensUsed = completion.usage.totalTokens;
+
+ // https://mistral.ai/technology/
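+  // Mistral Large 2 pricing: $2 per 1M input tokens, $6 per 1M output tokens.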
+ const getPriceForInputToken = (tokenCount?: number) => {
+ if (!tokenCount) {
+ return 0;
+ }
+
+ return (2.0 / 1_000_000) * tokenCount;
+ };
+
+ const getPriceForOutputToken = (tokenCount?: number) => {
+ if (!tokenCount) {
+ return 0;
+ }
+
+ return (6.0 / 1_000_000) * tokenCount;
+ };
+
+ const response = await responseSchema.safeParseAsync({
+ ...JSON.parse(content),
+    promptTokens,
+    outputTokens,
+    totalTokensUsed,
+ totalRunCost:
+ getPriceForInputToken(promptTokens) +
+ getPriceForOutputToken(outputTokens),
+ });
if (!response.success) {
throw new Error(response.error.message);
diff --git a/models/perplexity-llama-3.1.ts b/models/perplexity-llama-3.1.ts
index 98cd9d1..17ea955 100644
--- a/models/perplexity-llama-3.1.ts
+++ b/models/perplexity-llama-3.1.ts
@@ -34,6 +34,10 @@ const responseSchema = z.object({
playerCoordinates: z.array(z.number()),
boxCoordinates: z.array(z.array(z.number())),
reasoning: z.string(),
+ promptTokens: z.number().optional(),
+ outputTokens: z.number().optional(),
+ totalTokensUsed: z.number().optional(),
+ totalRunCost: z.number().optional(),
});
export const perplexityLlama31: ModelHandler = async (
@@ -80,14 +84,48 @@ export const perplexityLlama31: ModelHandler = async (
}
const content = validatedResponse.data.choices[0].message.content;
+ const promptTokens = validatedResponse.data.usage.prompt_tokens;
+ const outputTokens = validatedResponse.data.usage.completion_tokens;
+ const totalTokensUsed = validatedResponse.data.usage.total_tokens;
+
const jsonContent = content.match(/```json([^`]+)```/)?.[1] ?? "";
if (!isJSON(jsonContent)) {
throw new Error("JSON returned by perplexity is malformed");
}
+ // https://docs.perplexity.ai/guides/pricing#perplexity-sonar-models
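+  // Sonar pricing: $1 per 1M tokens for both input and output, plus a
+  // flat request fee of $5 per 1,000 requests (added below).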
+ const getPriceForInputToken = (tokenCount?: number) => {
+ if (!tokenCount) {
+ return 0;
+ }
+
+ return (1.0 / 1_000_000) * tokenCount;
+ };
+
+ const getPriceForOutputToken = (tokenCount?: number) => {
+ if (!tokenCount) {
+ return 0;
+ }
+
+ return (1.0 / 1_000_000) * tokenCount;
+ };
+
+ const priceForRequest = 5 / 1_000;
+
+ const totalRunCost =
+ getPriceForInputToken(promptTokens) +
+ getPriceForOutputToken(outputTokens) +
+ priceForRequest;
+
const parsedContent = JSON.parse(jsonContent);
- const response = await responseSchema.safeParseAsync(parsedContent);
+ const response = await responseSchema.safeParseAsync({
+ ...parsedContent,
+ promptTokens,
+ outputTokens,
+ totalTokensUsed,
+ totalRunCost,
+ });
if (!response.success) {
throw new Error(response.error.message);