diff --git a/convex/_generated/api.d.ts b/convex/_generated/api.d.ts index 4f841d4..4169466 100644 --- a/convex/_generated/api.d.ts +++ b/convex/_generated/api.d.ts @@ -18,6 +18,7 @@ import type { import type * as auth from "../auth.js"; import type * as constants from "../constants.js"; import type * as games from "../games.js"; +import type * as gemini from "../gemini.js"; import type * as http from "../http.js"; import type * as init from "../init.js"; import type * as maps from "../maps.js"; @@ -38,6 +39,7 @@ declare const fullApi: ApiFromModules<{ auth: typeof auth; constants: typeof constants; games: typeof games; + gemini: typeof gemini; http: typeof http; init: typeof init; maps: typeof maps; diff --git a/convex/constants.ts b/convex/constants.ts index b03b955..f3009e7 100644 --- a/convex/constants.ts +++ b/convex/constants.ts @@ -3,6 +3,10 @@ export const AI_MODELS = [ model: "gpt-4o", name: "OpenAI - 4o Mini", }, + { + model: "gemini-1.5-pro", + name: "Gemini - 1.5 Pro", + }, ]; export const AI_MODEL_IDS = AI_MODELS.map((model) => model.model); diff --git a/convex/games.ts b/convex/games.ts index 34db7ed..73b31e9 100644 --- a/convex/games.ts +++ b/convex/games.ts @@ -25,11 +25,19 @@ export const startNewGame = mutation({ throw new Error("No map found for level 1"); } - await ctx.scheduler.runAfter(0, internal.openai.playMapAction, { - gameId, - modelId: args.modelId, - level: 1, - }); + if (args.modelId === "gpt-4o") { + await ctx.scheduler.runAfter(0, internal.openai.playMapAction, { + gameId, + modelId: args.modelId, + level: 1, + }); + } else if (args.modelId === "gemini-1.5-pro") { + await ctx.scheduler.runAfter(0, internal.gemini.playMapAction, { + gameId, + modelId: args.modelId, + level: 1, + }); + } return gameId; }, diff --git a/convex/gemini.ts b/convex/gemini.ts new file mode 100644 index 0000000..a83df75 --- /dev/null +++ b/convex/gemini.ts @@ -0,0 +1,140 @@ +import { GoogleGenerativeAI, SchemaType } from "@google/generative-ai"; +import { action, internalAction } from "./_generated/server"; +import { v } from "convex/values"; +import { api, internal } from "./_generated/api"; +import { Doc } from "./_generated/dataModel"; +import { ZombieSurvival } from "../simulators/zombie-survival"; + +const genAI = new GoogleGenerativeAI(process.env.GEMINI_API_KEY as string); + +const schema = { + description: "Game Round Results", + type: SchemaType.OBJECT, + properties: { + map: { + type: SchemaType.ARRAY, + items: { + type: SchemaType.ARRAY, + items: { + type: SchemaType.STRING, + }, + }, + description: "The resulting map after the placements", + }, + reasoning: { + type: SchemaType.STRING, + description: "The reasoning behind the move", + }, + playerCoordinates: { + type: SchemaType.ARRAY, + items: { + type: SchemaType.NUMBER, + }, + description: "The player's coordinates", + }, + boxCoordinates: { + type: SchemaType.ARRAY, + items: { + type: SchemaType.ARRAY, + items: { + type: SchemaType.NUMBER, + }, + }, + description: "The box coordinates", + }, + }, + required: ["map", "reasoning", "playerCoordinates", "boxCoordinates"], +}; + +const model = genAI.getGenerativeModel({ + model: "gemini-1.5-pro", + generationConfig: { + responseMimeType: "application/json", + responseSchema: schema, + }, +}); + +type playMapActionResponse = { + map: string[][]; + reasoning: string; + playerCoordinates: number[]; + boxCoordinates: number[][]; +}; + +export const playMapAction = internalAction({ + args: { + level: v.number(), + gameId: v.id("games"), + modelId: v.string(), + }, + handler: async (ctx, args) => { + const resultId = await ctx.runMutation( + internal.results.createInitialResult, + { + gameId: args.gameId, + level: args.level, + }, + ); + + const map: Doc<"maps"> | null = (await ctx.runQuery( + api.maps.getMapByLevel, + { + level: args.level, + }, + )) as any; + + if (!map) { + throw new Error("Map not found"); + } + + if (process.env.MOCK_GEMINI === "true") { + const existingMap = [...map.grid.map((row) => [...row])]; + existingMap[0][0] = "P"; + existingMap[0][1] = "B"; + existingMap[0][2] = "B"; + return { + map: existingMap, + reasoning: "This is a mock response", + playerCoordinates: [0, 0], + boxCoordinates: [], + }; + } + + const result = await model.generateContent( + `You're given a 2d grid of nums such that. + " " represents an empty space. + "Z" represents a zombie. Zombies move one Manhattan step every turn and aim to reach the player. + "R" represents rocks, which players can shoot over but zombies cannot pass through or break. + "P" represents the player, who cannot move. The player's goal is to shoot and kill zombies before they reach them. + "B" represents blocks that can be placed before the round begins to hinder the zombies. You can place up to two blocks on the map. + + Your goal is to place the player ("P") and two blocks ("B") in locations that maximize the player's survival by delaying the zombies' approach. + You can shoot any zombie regardless of where it is on the grid. + Returning a 2d grid with the player and blocks placed in the optimal locations, with the coordinates player ("P") and the blocks ("B"), also provide reasoning for the choices. + + You can't replace rocks R or zombies Z with blocks. If there is no room to place a block, do not place any. + + Grid: ${JSON.stringify(map)}`, + ); + + // todo: check if the response is valid acc to types and the player and box coordinates are valid, + // as sometimes the model returns a state that's erroring out in the simulator + + const parsedResponse = JSON.parse( + result.response.text(), + ) as playMapActionResponse; + + const game = new ZombieSurvival(parsedResponse.map); + while (!game.finished()) { + game.step(); + } + const isWin = !game.getPlayer().dead(); + + await ctx.runMutation(internal.results.updateResult, { + resultId, + isWin, + reasoning: parsedResponse.reasoning, + map: parsedResponse.map, + }); + }, +}); diff --git a/convex/results.ts b/convex/results.ts index 912b045..489d114 100644 --- a/convex/results.ts +++ b/convex/results.ts @@ -77,11 +77,19 @@ export const updateResult = internalMutation({ throw new Error("Next map not found"); } - await ctx.scheduler.runAfter(0, internal.openai.playMapAction, { - modelId: game.modelId, - level: result.level + 1, - gameId: result.gameId, - }); + if (game.modelId === "gpt-4o") { + await ctx.scheduler.runAfter(0, internal.openai.playMapAction, { + gameId: result.gameId, + modelId: game.modelId, + level: result.level + 1, + }); + } else if (game.modelId === "gemini-1.5-pro") { + await ctx.scheduler.runAfter(0, internal.gemini.playMapAction, { + gameId: result.gameId, + modelId: game.modelId, + level: result.level + 1, + }); + } } }, }); diff --git a/package.json b/package.json index 8af9657..cd82859 100644 --- a/package.json +++ b/package.json @@ -16,6 +16,7 @@ "dependencies": { "@auth/core": "^0.34.2", "@convex-dev/auth": "^0.0.71", + "@google/generative-ai": "^0.21.0", "@radix-ui/react-dropdown-menu": "^2.1.1", "@radix-ui/react-icons": "^1.3.0", "@radix-ui/react-select": "^2.1.2",