diff --git a/convex/prompts.ts b/convex/prompts.ts index 94d6cfb..67bf9ee 100644 --- a/convex/prompts.ts +++ b/convex/prompts.ts @@ -61,11 +61,7 @@ The 2d Grid is made up of characters, where each character has a meaning. "boxCoordinates": [[ROW, COL], [ROW, COL]], "playerCoordinates": [ROW, COL], "reasoning": "REASONING" -} - -# MOST IMPORTANT RULE - -- DO NOT TRY TO PUT A BLOCK OR A PLAYER IN A LOCATION THAT IS ALREADY OCCUPIED`; +}`; export type Prompt = { _id: Id<"prompts">; diff --git a/models/claude-3-5-sonnet.ts b/models/claude-3-5-sonnet.ts index 268ba51..196423e 100644 --- a/models/claude-3-5-sonnet.ts +++ b/models/claude-3-5-sonnet.ts @@ -1,4 +1,4 @@ -import { type ModelHandler } from "."; +import { type ModelHandler, getValidLocations } from "."; import { Anthropic } from "@anthropic-ai/sdk"; import { z } from "zod"; @@ -22,7 +22,11 @@ export const claude35sonnet: ModelHandler = async (prompt, map, config) => { messages: [ { role: "user", - content: JSON.stringify(map), + content: ` +Grid: ${JSON.stringify(map)} + +Valid Locations: ${JSON.stringify(getValidLocations(map))} +`, }, ], }); @@ -33,7 +37,11 @@ export const claude35sonnet: ModelHandler = async (prompt, map, config) => { throw new Error("Unexpected completion type from Claude"); } - const parsedContent = JSON.parse(content.text); + const jsonStartIndex = content.text.indexOf("{"); + const jsonEndIndex = content.text.lastIndexOf("}") + 1; + const jsonString = content.text.substring(jsonStartIndex, jsonEndIndex); + + const parsedContent = JSON.parse(jsonString); const response = await responseSchema.safeParseAsync(parsedContent); if (!response.success) { diff --git a/models/gemini-1.5-pro.ts b/models/gemini-1.5-pro.ts index 75740ed..cba3561 100644 --- a/models/gemini-1.5-pro.ts +++ b/models/gemini-1.5-pro.ts @@ -1,4 +1,4 @@ -import { type ModelHandler } from "."; +import { type ModelHandler, getValidLocations } from "."; import { GoogleGenerativeAI, SchemaType } from "@google/generative-ai"; interface GeminiResponse { @@ -62,7 +62,11 @@ export const gemini15pro: ModelHandler = async (prompt, map, config) => { }, }); - const result = await model.generateContent(JSON.stringify(map)); + const result = await model.generateContent(` +Grid: ${JSON.stringify(map)} + +Valid Locations: ${JSON.stringify(getValidLocations(map))} +`); const parsedResponse = JSON.parse(result.response.text()) as GeminiResponse; return { diff --git a/models/gpt-4o.ts b/models/gpt-4o.ts index 969431d..b9d2a2b 100644 --- a/models/gpt-4o.ts +++ b/models/gpt-4o.ts @@ -1,4 +1,4 @@ -import { type ModelHandler } from "."; +import { type ModelHandler, getValidLocations } from "."; import OpenAI from "openai"; import { zodResponseFormat } from "openai/helpers/zod"; import { z } from "zod"; @@ -24,7 +24,11 @@ export const gpt4o: ModelHandler = async (prompt, map, config) => { }, { role: "user", - content: JSON.stringify(map), + content: ` +Grid: ${JSON.stringify(map)} + +Valid Locations: ${JSON.stringify(getValidLocations(map))} +`, }, ], response_format: zodResponseFormat(responseSchema, "game_map"), diff --git a/models/index.ts b/models/index.ts index 4cd0dd7..2d25e5a 100644 --- a/models/index.ts +++ b/models/index.ts @@ -12,6 +12,17 @@ export interface ModelHandlerConfig { topP: number; } +export function getValidLocations(map: string[][]) { + return map.flatMap((row, y) => + row.reduce((acc, cell, x) => { + if (cell === " ") { + acc.push([y, x]); + } + return acc; + }, [] as number[][]), + ); +} + export type ModelHandler = ( prompt: string, map: string[][], diff --git a/models/mistral-large-2.ts b/models/mistral-large-2.ts index 58b7ea8..7a67f7b 100644 --- a/models/mistral-large-2.ts +++ b/models/mistral-large-2.ts @@ -1,4 +1,4 @@ -import { type ModelHandler } from "."; +import { type ModelHandler, getValidLocations } from "."; import { isJSON } from "../lib/utils"; import { Mistral } from "@mistralai/mistralai"; import { z } from "zod"; @@ -24,7 +24,11 @@ export const mistralLarge2: ModelHandler = async (prompt, map, config) => { }, { role: "user", - content: JSON.stringify(map), + content: ` +Grid: ${JSON.stringify(map)} + +Valid Locations: ${JSON.stringify(getValidLocations(map))} +`, }, ], responseFormat: { diff --git a/models/perplexity-llama-3.1.ts b/models/perplexity-llama-3.1.ts index 288ffdf..1ea4016 100644 --- a/models/perplexity-llama-3.1.ts +++ b/models/perplexity-llama-3.1.ts @@ -1,6 +1,6 @@ import { isJSON } from "../lib/utils"; import { z } from "zod"; -import { ModelHandler } from "./index"; +import { ModelHandler, getValidLocations } from "./index"; const completionSchema = z.object({ id: z.string(), @@ -49,7 +49,11 @@ export const perplexityLlama31: ModelHandler = async (prompt, map, config) => { { role: "system", content: prompt }, { role: "user", - content: JSON.stringify(map), + content: ` +Grid: ${JSON.stringify(map)} + +Valid Locations: ${JSON.stringify(getValidLocations(map))} +`, }, ], max_tokens: config.maxTokens,