Skip to content

Commit

Permalink
Implement central place for models
Browse files Browse the repository at this point in the history
  • Loading branch information
delasy committed Oct 16, 2024
1 parent eaf4971 commit 80e1c85
Show file tree
Hide file tree
Showing 7 changed files with 161 additions and 115 deletions.
2 changes: 0 additions & 2 deletions convex/_generated/api.d.ts
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,6 @@ import type * as games from "../games.js";
import type * as http from "../http.js";
import type * as init from "../init.js";
import type * as maps from "../maps.js";
import type * as openai from "../openai.js";
import type * as results from "../results.js";
import type * as scores from "../scores.js";
import type * as users from "../users.js";
Expand All @@ -41,7 +40,6 @@ declare const fullApi: ApiFromModules<{
http: typeof http;
init: typeof init;
maps: typeof maps;
openai: typeof openai;
results: typeof results;
scores: typeof scores;
users: typeof users;
Expand Down
2 changes: 1 addition & 1 deletion convex/games.ts
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@ export const startNewGame = mutation({
throw new Error("No map found for level 1");
}

await ctx.scheduler.runAfter(0, internal.openai.playMapAction, {
await ctx.scheduler.runAfter(0, internal.maps.playMapAction, {
gameId,
modelId: args.modelId,
level: 1,
Expand Down
83 changes: 82 additions & 1 deletion convex/maps.ts
Original file line number Diff line number Diff line change
@@ -1,5 +1,12 @@
import OpenAI from "openai";
import { internalAction, internalMutation, query } from "./_generated/server";
import { v } from "convex/values";
import { internalMutation, query } from "./_generated/server";
import { z } from "zod";
import { zodResponseFormat } from "openai/helpers/zod";
import { Doc } from "./_generated/dataModel";
import { ZombieSurvival } from "../simulators/zombie-survival";
import { api, internal } from "./_generated/api";
import { runModel } from "../models";

const MAPS = [
{
Expand Down Expand Up @@ -98,3 +105,77 @@ export const getMapByLevel = query({
.first();
},
});

export const playMapAction = internalAction({
args: {
level: v.number(),
gameId: v.id("games"),
modelId: v.string(),
},
handler: async (ctx, args) => {
const resultId = await ctx.runMutation(
internal.results.createInitialResult,
{
gameId: args.gameId,
level: args.level,
},
);

const map: Doc<"maps"> | null = (await ctx.runQuery(
api.maps.getMapByLevel,
{
level: args.level,
},
)) as any;

if (!map) {
throw new Error("Map not found");
}

if (process.env.MOCK_MODELS === "true") {
const existingMap = [...map.grid.map((row) => [...row])];

existingMap[0][0] = "P";
existingMap[0][1] = "B";
existingMap[0][2] = "B";

const game = new ZombieSurvival(existingMap);

while (!game.finished()) {
game.step();
}

const isWin = !game.getPlayer().dead();

await ctx.runMutation(internal.results.updateResult, {
resultId,
isWin,
reasoning: "This is a mock response",
map: existingMap,
});

return;
}

try {
const { solution, reasoning } = await runModel(args.modelId, map.grid);
const game = new ZombieSurvival(solution);

while (!game.finished()) {
game.step();
}

const isWin = !game.getPlayer().dead();

await ctx.runMutation(internal.results.updateResult, {
resultId,
isWin,
reasoning,
map: solution,
});
} catch (error) {
// todo: handle error
console.log(error);
}
},
});
109 changes: 0 additions & 109 deletions convex/openai.ts

This file was deleted.

4 changes: 2 additions & 2 deletions convex/results.ts
Original file line number Diff line number Diff line change
Expand Up @@ -77,10 +77,10 @@ export const updateResult = internalMutation({
throw new Error("Next map not found");
}

await ctx.scheduler.runAfter(0, internal.openai.playMapAction, {
await ctx.scheduler.runAfter(0, internal.maps.playMapAction, {
gameId: result.gameId,
modelId: game.modelId,
level: result.level + 1,
gameId: result.gameId,
});
}
},
Expand Down
56 changes: 56 additions & 0 deletions models/gpt-4o.ts
Original file line number Diff line number Diff line change
@@ -0,0 +1,56 @@
import OpenAI from "openai";
import { z } from "zod";
import { zodResponseFormat } from "openai/helpers/zod";
import { type ModelResult } from ".";

const ResponseSchema = z.object({
map: z.array(z.array(z.string())),
reasoning: z.string(),
playerCoordinates: z.array(z.number()),
boxCoordinates: z.array(z.array(z.number())),
});

export async function gpt4o(map: string[][]): Promise<ModelResult> {
const openai = new OpenAI();

const prompt = `You're given a 2d grid of nums such that.
" " represents an empty space.
"Z" represents a zombie. Zombies move one Manhattan step every turn and aim to reach the player.
"R" represents rocks, which players can shoot over but zombies cannot pass through or break.
"P" represents the player, who cannot move. The player's goal is to shoot and kill zombies before they reach them.
"B" represents blocks that can be placed before the round begins to hinder the zombies. You can place up to two blocks on the map.
Your goal is to place the player ("P") and two blocks ("B") in locations that maximize the player's survival by delaying the zombies' approach.
You can shoot any zombie regardless of where it is on the grid.
Returning a 2d grid with the player and blocks placed in the optimal locations, with the coordinates player ("P") and the blocks ("B"), also provide reasoning for the choices.
You can't replace rocks R or zombies Z with blocks. If there is no room to place a block, do not place any.`;

const completion = await openai.beta.chat.completions.parse({
model: "gpt-4o-2024-08-06",
messages: [
{
role: "system",
content: prompt,
},
{
role: "user",
content: JSON.stringify(map),
},
],
response_format: zodResponseFormat(ResponseSchema, "game_map"),
});

const response = completion.choices[0].message;

if (response.refusal) {
throw new Error(`Refusal: ${response.refusal}`);
} else if (!response.parsed) {
throw new Error("Failed to run model GPT-4o");
}

return {
solution: response.parsed.map,
reasoning: response.parsed.reasoning,
};
}
20 changes: 20 additions & 0 deletions models/index.ts
Original file line number Diff line number Diff line change
@@ -0,0 +1,20 @@
import { gpt4o } from "./gpt-4o";

export interface ModelResult {
solution: string[][];
reasoning: string;
}

export async function runModel(
modelId: string,
map: string[][],
): Promise<ModelResult> {
switch (modelId) {
case "gpt-4o": {
return gpt4o(map);
}
default: {
throw new Error(`Tried running unknown model '${modelId}'`);
}
}
}

0 comments on commit 80e1c85

Please sign in to comment.