Skip to content

Commit

Permalink
Various model/prompt improvements
Browse files Browse the repository at this point in the history
All models have prompt set via system
Improved perplexity parsing
Improved prompt
All models have same temperature/top_k
  • Loading branch information
delasy committed Oct 23, 2024
1 parent ba4cd2a commit c5c55f5
Show file tree
Hide file tree
Showing 7 changed files with 116 additions and 124 deletions.
22 changes: 11 additions & 11 deletions convex/prompts.ts
Original file line number Diff line number Diff line change
@@ -1,10 +1,9 @@
import { v } from "convex/values";
import { Id } from "./_generated/dataModel";
import { internalMutation, mutation, query } from "./_generated/server";
import { internalMutation, query } from "./_generated/server";
import { adminMutationBuilder } from "./users";

const defaultPrompt = `
Your task is to play a game. We will give you a 2d array of characters that represent the game board. Before the game starts, you have these two tasks:
const defaultPrompt = `Your task is to play a game. We will give you a 2d array of characters that represent the game board. Before the game starts, you have these two tasks:
1. Place two blocks ("B") in locations which maximize the player's survival.
2. Place the player ("P") in a location which maximize the player's survival.
Expand All @@ -20,13 +19,13 @@ The 2d Grid is made up of characters, where each character has a meaning.
" " represents an empty space.
"Z" represents a zombie.
"R" represents rocks which zombies can not pass through and path finding will not allow them to go through.
"P" represents the player, who cannot move. The player's goal is to throw popsickles at zombies before they reach them.
"P" represents the player, who cannot move. The player's goal is to throw popsicle at zombies before they reach them.
"B" represents blocks that can be placed before the round begins to hinder the zombies.
# Game Rules
- The game is turn based.
- At the start of the turn, the player (P) throws a popsickle at the closest zombie (using euclidean distance).
- Popsickles deal 1 damage to zombies.
- At the start of the turn, the player (P) throws a popsicle at the closest zombie (using euclidean distance).
- Popsicle deal 1 damage to zombies.
- A zombie is removed from the game when its health reaches 0.
- When all zombies are removed, the player wins.
- If a zombie reaches a player, the player loses.
Expand All @@ -53,19 +52,20 @@ The 2d Grid is made up of characters, where each character has a meaning.
# Output Format
- Respond only with valid JSON. Do not write an introduction or summary.
- Assume a single paragraph explaining your placement strategy is always represented as REASONING.
- Assume a position on the 2d grid is always represented as [ROW, COL].
- Your output should be a JSON object with the following format:
{
"boxCoordinates": [[ROW, COL], [ROW, COL]],
"playerCoordinates": [ROW, COL]
"playerCoordinates": [ROW, COL],
"reasoning": "REASONING"
}
## MOST IMPORTANT RULE
# MOST IMPORTANT RULE
- DO NOT TRY TO PUT A BLOCK OR PLAYER IN A LOCATION THAT IS ALREADY OCCUPIED
`;
- DO NOT TRY TO PUT A BLOCK OR A PLAYER IN A LOCATION THAT IS ALREADY OCCUPIED`;

export type Prompt = {
_id: Id<"prompts">;
Expand Down
39 changes: 22 additions & 17 deletions models/claude-3-5-sonnet.ts
Original file line number Diff line number Diff line change
@@ -1,15 +1,23 @@
import { type ModelHandler } from ".";
import { Anthropic } from "@anthropic-ai/sdk";
import { z } from "zod";

export const claude35sonnet: ModelHandler = async (prompt, map) => {
const responseSchema = z.object({
playerCoordinates: z.array(z.number()),
boxCoordinates: z.array(z.array(z.number())),
reasoning: z.string(),
});

export const claude35sonnet: ModelHandler = async (prompt, map, config) => {
const anthropic = new Anthropic({
apiKey: process.env.ANTHROPIC_API_KEY,
});

const response = await anthropic.messages.create({
model: "claude-3-5-sonnet-20240620",
max_tokens: 1024,
temperature: 0,
const completion = await anthropic.messages.create({
model: "claude-3-5-sonnet-20241022",
max_tokens: config.maxTokens,
temperature: config.temperature,
top_p: config.topP,
system: prompt,
messages: [
{
Expand All @@ -19,25 +27,22 @@ export const claude35sonnet: ModelHandler = async (prompt, map) => {
],
});

const content = response.content[0];
const content = completion.content[0];

if (content.type !== "text") {
throw new Error("Unexpected response type from Claude");
throw new Error("Unexpected completion type from Claude");
}

const parsedResponse = JSON.parse(content.text);
const parsedContent = JSON.parse(content.text);
const response = await responseSchema.safeParseAsync(parsedContent);

if (
!Array.isArray(parsedResponse.boxCoordinates) ||
!Array.isArray(parsedResponse.playerCoordinates) ||
typeof parsedResponse.reasoning !== "string"
) {
throw new Error("Invalid response structure");
if (!response.success) {
throw new Error(response.error.message);
}

return {
boxCoordinates: parsedResponse.boxCoordinates,
playerCoordinates: parsedResponse.playerCoordinates,
reasoning: parsedResponse.reasoning,
boxCoordinates: response.data.boxCoordinates,
playerCoordinates: response.data.playerCoordinates,
reasoning: response.data.reasoning,
};
};
32 changes: 15 additions & 17 deletions models/gemini-1.5-pro.ts
Original file line number Diff line number Diff line change
@@ -1,7 +1,14 @@
import { type ModelHandler } from ".";
import { GoogleGenerativeAI, SchemaType } from "@google/generative-ai";

const schema = {
interface GeminiResponse {
boxCoordinates: number[][];
map: string[][];
playerCoordinates: number[];
reasoning: string;
}

const responseSchema = {
description: "Game Round Results",
type: SchemaType.OBJECT,
properties: {
Expand Down Expand Up @@ -40,31 +47,22 @@ const schema = {
required: ["map", "reasoning", "playerCoordinates", "boxCoordinates"],
};

interface GeminiResponse {
boxCoordinates: number[][];
map: string[][];
playerCoordinates: number[];
reasoning: string;
}

export const gemini15pro: ModelHandler = async (prompt, map) => {
export const gemini15pro: ModelHandler = async (prompt, map, config) => {
const genAI = new GoogleGenerativeAI(process.env.GEMINI_API_KEY!);

const model = genAI.getGenerativeModel({
model: "gemini-1.5-pro",
systemInstruction: prompt,
generationConfig: {
responseMimeType: "application/json",
responseSchema: schema,
responseSchema,
maxOutputTokens: config.maxTokens,
temperature: config.temperature,
topP: config.topP,
},
});

const result = await model.generateContent(
`${prompt}\n\nGrid: ${JSON.stringify(map)}`,
);

// todo: check if the response is valid acc to types and the player and box coordinates are valid,
// as sometimes the model returns a state that's erroring out in the simulator

const result = await model.generateContent(JSON.stringify(map));
const parsedResponse = JSON.parse(result.response.text()) as GeminiResponse;

return {
Expand Down
9 changes: 6 additions & 3 deletions models/gpt-4o.ts
Original file line number Diff line number Diff line change
Expand Up @@ -3,17 +3,20 @@ import OpenAI from "openai";
import { zodResponseFormat } from "openai/helpers/zod";
import { z } from "zod";

const ResponseSchema = z.object({
const responseSchema = z.object({
reasoning: z.string(),
playerCoordinates: z.array(z.number()),
boxCoordinates: z.array(z.array(z.number())),
});

export const gpt4o: ModelHandler = async (prompt, map) => {
export const gpt4o: ModelHandler = async (prompt, map, config) => {
const openai = new OpenAI();

const completion = await openai.beta.chat.completions.parse({
model: "gpt-4o-2024-08-06",
max_tokens: config.maxTokens,
temperature: config.temperature,
top_p: config.topP,
messages: [
{
role: "system",
Expand All @@ -24,7 +27,7 @@ export const gpt4o: ModelHandler = async (prompt, map) => {
content: JSON.stringify(map),
},
],
response_format: zodResponseFormat(ResponseSchema, "game_map"),
response_format: zodResponseFormat(responseSchema, "game_map"),
});

const response = completion.choices[0].message;
Expand Down
32 changes: 24 additions & 8 deletions models/index.ts
Original file line number Diff line number Diff line change
Expand Up @@ -6,17 +6,33 @@ import { gpt4o } from "./gpt-4o";
import { mistralLarge2 } from "./mistral-large-2";
import { perplexityLlama31 } from "./perplexity-llama-3.1";

const MAX_RETRIES = 3;
export interface ModelHandlerConfig {
maxTokens: number;
temperature: number;
topP: number;
}

export type ModelHandler = (
prompt: string,
map: string[][],
config: ModelHandlerConfig,
) => Promise<{
boxCoordinates: number[][];
playerCoordinates: number[];
reasoning: string;
}>;

const MAX_RETRIES = 1;

// Decision was made based on this research:
// https://discord.com/channels/663478877355507769/1295376750154350654/1298659719636058144

const CONFIG: ModelHandlerConfig = {
maxTokens: 1024,
temperature: 0.5,
topP: 0.95,
};

export async function runModel(
modelId: string,
map: string[][],
Expand All @@ -33,23 +49,23 @@ export async function runModel(
try {
switch (modelId) {
case "gemini-1.5-pro": {
result = await gemini15pro(prompt, map);
result = await gemini15pro(prompt, map, CONFIG);
break;
}
case "gpt-4o": {
result = await gpt4o(prompt, map);
result = await gpt4o(prompt, map, CONFIG);
break;
}
case "claude-3.5-sonnet": {
result = await claude35sonnet(prompt, map);
result = await claude35sonnet(prompt, map, CONFIG);
break;
}
case "perplexity-llama-3.1": {
result = await perplexityLlama31(prompt, map);
result = await perplexityLlama31(prompt, map, CONFIG);
break;
}
case "mistral-large-2": {
result = await mistralLarge2(prompt, map);
result = await mistralLarge2(prompt, map, CONFIG);
break;
}
default: {
Expand All @@ -62,7 +78,7 @@ export async function runModel(
const originalMap = ZombieSurvival.cloneMap(map);
const [playerRow, playerCol] = result.playerCoordinates;

if (originalMap[playerRow][playerCol] !== " ") {
if (originalMap[playerRow]?.[playerCol] !== " ") {
throw new Error("Tried to place player in a non-empty space");
}

Expand All @@ -71,7 +87,7 @@ export async function runModel(
for (const block of result.boxCoordinates) {
const [blockRow, blockCol] = block;

if (originalMap[blockRow][blockCol] !== " ") {
if (originalMap[blockRow]?.[blockCol] !== " ") {
throw new Error("Tried to place block in a non-empty space");
}

Expand Down
15 changes: 9 additions & 6 deletions models/mistral-large-2.ts
Original file line number Diff line number Diff line change
Expand Up @@ -9,19 +9,22 @@ const responseSchema = z.object({
boxCoordinates: z.array(z.array(z.number())),
});

export const mistralLarge2: ModelHandler = async (prompt, map) => {
const promptAnswerRequirement =
'Answer only with JSON containing "playerCoordinates" key, "boxCoordinates" key,' +
'and a single paragraph explaining your placement strategy as "reasoning" key.';

export const mistralLarge2: ModelHandler = async (prompt, map, config) => {
const client = new Mistral();

const completion = await client.chat.complete({
model: "mistral-large-2407",
maxTokens: config.maxTokens,
temperature: config.temperature,
topP: config.topP,
messages: [
{
role: "system",
content: prompt,
},
{
role: "user",
content: `${prompt}\n\nMap: ${JSON.stringify(map)}\n\n${promptAnswerRequirement}`,
content: JSON.stringify(map),
},
],
responseFormat: {
Expand Down
Loading

0 comments on commit c5c55f5

Please sign in to comment.