diff --git a/convex/prompts.ts b/convex/prompts.ts index bba5244..94d6cfb 100644 --- a/convex/prompts.ts +++ b/convex/prompts.ts @@ -1,10 +1,9 @@ import { v } from "convex/values"; import { Id } from "./_generated/dataModel"; -import { internalMutation, mutation, query } from "./_generated/server"; +import { internalMutation, query } from "./_generated/server"; import { adminMutationBuilder } from "./users"; -const defaultPrompt = ` -Your task is to play a game. We will give you a 2d array of characters that represent the game board. Before the game starts, you have these two tasks: +const defaultPrompt = `Your task is to play a game. We will give you a 2d array of characters that represent the game board. Before the game starts, you have these two tasks: 1. Place two blocks ("B") in locations which maximize the player's survival. 2. Place the player ("P") in a location which maximize the player's survival. @@ -20,13 +19,13 @@ The 2d Grid is made up of characters, where each character has a meaning. " " represents an empty space. "Z" represents a zombie. "R" represents rocks which zombies can not pass through and path finding will not allow them to go through. -"P" represents the player, who cannot move. The player's goal is to throw popsickles at zombies before they reach them. +"P" represents the player, who cannot move. The player's goal is to throw popsicle at zombies before they reach them. "B" represents blocks that can be placed before the round begins to hinder the zombies. # Game Rules - The game is turn based. -- At the start of the turn, the player (P) throws a popsickle at the closest zombie (using euclidean distance). -- Popsickles deal 1 damage to zombies. +- At the start of the turn, the player (P) throws a popsicle at the closest zombie (using euclidean distance). +- Popsicle deal 1 damage to zombies. - A zombie is removed from the game when its health reaches 0. - When all zombies are removed, the player wins. - If a zombie reaches a player, the player loses. @@ -53,19 +52,20 @@ The 2d Grid is made up of characters, where each character has a meaning. # Output Format +- Respond only with valid JSON. Do not write an introduction or summary. +- Assume a single paragraph explaining your placement strategy is always represented as REASONING. - Assume a position on the 2d grid is always represented as [ROW, COL]. - Your output should be a JSON object with the following format: { "boxCoordinates": [[ROW, COL], [ROW, COL]], - "playerCoordinates": [ROW, COL] + "playerCoordinates": [ROW, COL], + "reasoning": "REASONING" } -## MOST IMPORTANT RULE +# MOST IMPORTANT RULE -- DO NOT TRY TO PUT A BLOCK OR PLAYER IN A LOCATION THAT IS ALREADY OCCUPIED - -`; +- DO NOT TRY TO PUT A BLOCK OR A PLAYER IN A LOCATION THAT IS ALREADY OCCUPIED`; export type Prompt = { _id: Id<"prompts">; diff --git a/models/claude-3-5-sonnet.ts b/models/claude-3-5-sonnet.ts index d41ed6d..268ba51 100644 --- a/models/claude-3-5-sonnet.ts +++ b/models/claude-3-5-sonnet.ts @@ -1,15 +1,23 @@ import { type ModelHandler } from "."; import { Anthropic } from "@anthropic-ai/sdk"; +import { z } from "zod"; -export const claude35sonnet: ModelHandler = async (prompt, map) => { +const responseSchema = z.object({ + playerCoordinates: z.array(z.number()), + boxCoordinates: z.array(z.array(z.number())), + reasoning: z.string(), +}); + +export const claude35sonnet: ModelHandler = async (prompt, map, config) => { const anthropic = new Anthropic({ apiKey: process.env.ANTHROPIC_API_KEY, }); - const response = await anthropic.messages.create({ - model: "claude-3-5-sonnet-20240620", - max_tokens: 1024, - temperature: 0, + const completion = await anthropic.messages.create({ + model: "claude-3-5-sonnet-20241022", + max_tokens: config.maxTokens, + temperature: config.temperature, + top_p: config.topP, system: prompt, messages: [ { @@ -19,25 +27,22 @@ export const claude35sonnet: ModelHandler = async (prompt, map) => { ], }); - const content = response.content[0]; + const content = completion.content[0]; if (content.type !== "text") { - throw new Error("Unexpected response type from Claude"); + throw new Error("Unexpected completion type from Claude"); } - const parsedResponse = JSON.parse(content.text); + const parsedContent = JSON.parse(content.text); + const response = await responseSchema.safeParseAsync(parsedContent); - if ( - !Array.isArray(parsedResponse.boxCoordinates) || - !Array.isArray(parsedResponse.playerCoordinates) || - typeof parsedResponse.reasoning !== "string" - ) { - throw new Error("Invalid response structure"); + if (!response.success) { + throw new Error(response.error.message); } return { - boxCoordinates: parsedResponse.boxCoordinates, - playerCoordinates: parsedResponse.playerCoordinates, - reasoning: parsedResponse.reasoning, + boxCoordinates: response.data.boxCoordinates, + playerCoordinates: response.data.playerCoordinates, + reasoning: response.data.reasoning, }; }; diff --git a/models/gemini-1.5-pro.ts b/models/gemini-1.5-pro.ts index 0a43bca..75740ed 100644 --- a/models/gemini-1.5-pro.ts +++ b/models/gemini-1.5-pro.ts @@ -1,7 +1,14 @@ import { type ModelHandler } from "."; import { GoogleGenerativeAI, SchemaType } from "@google/generative-ai"; -const schema = { +interface GeminiResponse { + boxCoordinates: number[][]; + map: string[][]; + playerCoordinates: number[]; + reasoning: string; +} + +const responseSchema = { description: "Game Round Results", type: SchemaType.OBJECT, properties: { @@ -40,31 +47,22 @@ const schema = { required: ["map", "reasoning", "playerCoordinates", "boxCoordinates"], }; -interface GeminiResponse { - boxCoordinates: number[][]; - map: string[][]; - playerCoordinates: number[]; - reasoning: string; -} - -export const gemini15pro: ModelHandler = async (prompt, map) => { +export const gemini15pro: ModelHandler = async (prompt, map, config) => { const genAI = new GoogleGenerativeAI(process.env.GEMINI_API_KEY!); const model = genAI.getGenerativeModel({ model: "gemini-1.5-pro", + systemInstruction: prompt, generationConfig: { responseMimeType: "application/json", - responseSchema: schema, + responseSchema, + maxOutputTokens: config.maxTokens, + temperature: config.temperature, + topP: config.topP, }, }); - const result = await model.generateContent( - `${prompt}\n\nGrid: ${JSON.stringify(map)}`, - ); - - // todo: check if the response is valid acc to types and the player and box coordinates are valid, - // as sometimes the model returns a state that's erroring out in the simulator - + const result = await model.generateContent(JSON.stringify(map)); const parsedResponse = JSON.parse(result.response.text()) as GeminiResponse; return { diff --git a/models/gpt-4o.ts b/models/gpt-4o.ts index f836c25..969431d 100644 --- a/models/gpt-4o.ts +++ b/models/gpt-4o.ts @@ -3,17 +3,20 @@ import OpenAI from "openai"; import { zodResponseFormat } from "openai/helpers/zod"; import { z } from "zod"; -const ResponseSchema = z.object({ +const responseSchema = z.object({ reasoning: z.string(), playerCoordinates: z.array(z.number()), boxCoordinates: z.array(z.array(z.number())), }); -export const gpt4o: ModelHandler = async (prompt, map) => { +export const gpt4o: ModelHandler = async (prompt, map, config) => { const openai = new OpenAI(); const completion = await openai.beta.chat.completions.parse({ model: "gpt-4o-2024-08-06", + max_tokens: config.maxTokens, + temperature: config.temperature, + top_p: config.topP, messages: [ { role: "system", @@ -24,7 +27,7 @@ export const gpt4o: ModelHandler = async (prompt, map) => { content: JSON.stringify(map), }, ], - response_format: zodResponseFormat(ResponseSchema, "game_map"), + response_format: zodResponseFormat(responseSchema, "game_map"), }); const response = completion.choices[0].message; diff --git a/models/index.ts b/models/index.ts index 40efa63..4cd0dd7 100644 --- a/models/index.ts +++ b/models/index.ts @@ -6,17 +6,33 @@ import { gpt4o } from "./gpt-4o"; import { mistralLarge2 } from "./mistral-large-2"; import { perplexityLlama31 } from "./perplexity-llama-3.1"; -const MAX_RETRIES = 3; +export interface ModelHandlerConfig { + maxTokens: number; + temperature: number; + topP: number; +} export type ModelHandler = ( prompt: string, map: string[][], + config: ModelHandlerConfig, ) => Promise<{ boxCoordinates: number[][]; playerCoordinates: number[]; reasoning: string; }>; +const MAX_RETRIES = 1; + +// Decision was made based on this research: +// https://discord.com/channels/663478877355507769/1295376750154350654/1298659719636058144 + +const CONFIG: ModelHandlerConfig = { + maxTokens: 1024, + temperature: 0.5, + topP: 0.95, +}; + export async function runModel( modelId: string, map: string[][], @@ -33,23 +49,23 @@ export async function runModel( try { switch (modelId) { case "gemini-1.5-pro": { - result = await gemini15pro(prompt, map); + result = await gemini15pro(prompt, map, CONFIG); break; } case "gpt-4o": { - result = await gpt4o(prompt, map); + result = await gpt4o(prompt, map, CONFIG); break; } case "claude-3.5-sonnet": { - result = await claude35sonnet(prompt, map); + result = await claude35sonnet(prompt, map, CONFIG); break; } case "perplexity-llama-3.1": { - result = await perplexityLlama31(prompt, map); + result = await perplexityLlama31(prompt, map, CONFIG); break; } case "mistral-large-2": { - result = await mistralLarge2(prompt, map); + result = await mistralLarge2(prompt, map, CONFIG); break; } default: { @@ -62,7 +78,7 @@ export async function runModel( const originalMap = ZombieSurvival.cloneMap(map); const [playerRow, playerCol] = result.playerCoordinates; - if (originalMap[playerRow][playerCol] !== " ") { + if (originalMap[playerRow]?.[playerCol] !== " ") { throw new Error("Tried to place player in a non-empty space"); } @@ -71,7 +87,7 @@ export async function runModel( for (const block of result.boxCoordinates) { const [blockRow, blockCol] = block; - if (originalMap[blockRow][blockCol] !== " ") { + if (originalMap[blockRow]?.[blockCol] !== " ") { throw new Error("Tried to place block in a non-empty space"); } diff --git a/models/mistral-large-2.ts b/models/mistral-large-2.ts index 15a430b..58b7ea8 100644 --- a/models/mistral-large-2.ts +++ b/models/mistral-large-2.ts @@ -9,19 +9,22 @@ const responseSchema = z.object({ boxCoordinates: z.array(z.array(z.number())), }); -export const mistralLarge2: ModelHandler = async (prompt, map) => { - const promptAnswerRequirement = - 'Answer only with JSON containing "playerCoordinates" key, "boxCoordinates" key,' + - 'and a single paragraph explaining your placement strategy as "reasoning" key.'; - +export const mistralLarge2: ModelHandler = async (prompt, map, config) => { const client = new Mistral(); const completion = await client.chat.complete({ model: "mistral-large-2407", + maxTokens: config.maxTokens, + temperature: config.temperature, + topP: config.topP, messages: [ + { + role: "system", + content: prompt, + }, { role: "user", - content: `${prompt}\n\nMap: ${JSON.stringify(map)}\n\n${promptAnswerRequirement}`, + content: JSON.stringify(map), }, ], responseFormat: { diff --git a/models/perplexity-llama-3.1.ts b/models/perplexity-llama-3.1.ts index 61619dc..288ffdf 100644 --- a/models/perplexity-llama-3.1.ts +++ b/models/perplexity-llama-3.1.ts @@ -2,7 +2,7 @@ import { isJSON } from "../lib/utils"; import { z } from "zod"; import { ModelHandler } from "./index"; -const PerplexityResponseSchema = z.object({ +const completionSchema = z.object({ id: z.string(), model: z.string(), object: z.string(), @@ -30,63 +30,46 @@ const PerplexityResponseSchema = z.object({ }), }); -const GameResponseSchema = z.object({ +const responseSchema = z.object({ playerCoordinates: z.array(z.number()), boxCoordinates: z.array(z.array(z.number())), + reasoning: z.string(), }); -export const perplexityLlama31: ModelHandler = async ( - prompt: string, - map: string[][], -) => { - const promptAnswerRequirement = - "Answer only with JSON output and a single paragraph explaining your placement strategy."; - - const messages = [ - { role: "system", content: "Be precise and concise." }, - { - role: "user", - content: `${prompt}\n\nMap:\n${JSON.stringify(map)}\n\n${promptAnswerRequirement}`, - }, - ]; - - const data = { - model: "llama-3.1-sonar-large-128k-online", - messages, - temperature: 0.2, - top_p: 0.9, - return_citations: false, - search_domain_filter: ["perplexity.ai"], - return_images: false, - return_related_questions: false, - search_recency_filter: "month", - top_k: 0, - stream: false, - presence_penalty: 0, - frequency_penalty: 1, - }; - - const response = await fetch("https://api.perplexity.ai/chat/completions", { +export const perplexityLlama31: ModelHandler = async (prompt, map, config) => { + const completion = await fetch("https://api.perplexity.ai/chat/completions", { method: "POST", headers: { Authorization: `Bearer ${process.env.PERPLEXITY_API_KEY}`, "Content-Type": "application/json", }, - body: JSON.stringify(data), + body: JSON.stringify({ + model: "llama-3.1-sonar-large-128k-online", + messages: [ + { role: "system", content: prompt }, + { + role: "user", + content: JSON.stringify(map), + }, + ], + max_tokens: config.maxTokens, + temperature: config.temperature, + top_p: config.topP, + search_domain_filter: ["perplexity.ai"], + search_recency_filter: "month", + }), }); - if (!response.ok) { - const errorData = await response.json(); + if (!completion.ok) { + const errorData = await completion.json(); throw new Error( - `HTTP error! status: ${response.status}, message: ${JSON.stringify(errorData)}`, + `HTTP error! status: ${completion.status}, message: ${JSON.stringify(errorData)}`, ); } - const responseData = await response.json(); - - const validatedResponse = - await PerplexityResponseSchema.safeParseAsync(responseData); + const data = await completion.json(); + const validatedResponse = await completionSchema.safeParseAsync(data); if (!validatedResponse.success) { throw new Error(validatedResponse.error.message); @@ -100,27 +83,11 @@ export const perplexityLlama31: ModelHandler = async ( } const parsedContent = JSON.parse(jsonContent); - const gameResponse = await GameResponseSchema.safeParseAsync(parsedContent); - - if (!gameResponse.success) { - throw new Error(gameResponse.error.message); - } - - const reasoning = content - .replace(/```json([^`]+)```/, "") - .split("\n") - .map((it) => it) - .map((it) => it.replace(/(\*\*|```)/, "").trim()) - .filter((it) => it !== "") - .join(" "); + const response = await responseSchema.safeParseAsync(parsedContent); - if (reasoning === "") { - throw new Error("Answer returned by perplexity doesn't contain reasoning"); + if (!response.success) { + throw new Error(response.error.message); } - return { - boxCoordinates: gameResponse.data.boxCoordinates, - playerCoordinates: gameResponse.data.playerCoordinates, - reasoning, - }; + return response.data; };