Skip to content

Commit

Permalink
improve prompting
Browse files Browse the repository at this point in the history
  • Loading branch information
webdevcody committed Oct 24, 2024
1 parent 7699d6d commit 9373980
Show file tree
Hide file tree
Showing 7 changed files with 47 additions and 16 deletions.
6 changes: 1 addition & 5 deletions convex/prompts.ts
Original file line number Diff line number Diff line change
Expand Up @@ -61,11 +61,7 @@ The 2d Grid is made up of characters, where each character has a meaning.
"boxCoordinates": [[ROW, COL], [ROW, COL]],
"playerCoordinates": [ROW, COL],
"reasoning": "REASONING"
}
# MOST IMPORTANT RULE
- DO NOT TRY TO PUT A BLOCK OR A PLAYER IN A LOCATION THAT IS ALREADY OCCUPIED`;
}`;

export type Prompt = {
_id: Id<"prompts">;
Expand Down
14 changes: 11 additions & 3 deletions models/claude-3-5-sonnet.ts
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
import { type ModelHandler } from ".";
import { type ModelHandler, getValidLocations } from ".";
import { Anthropic } from "@anthropic-ai/sdk";
import { z } from "zod";

Expand All @@ -22,7 +22,11 @@ export const claude35sonnet: ModelHandler = async (prompt, map, config) => {
messages: [
{
role: "user",
content: JSON.stringify(map),
content: `
Grid: ${JSON.stringify(map)}
Valid Locations: ${JSON.stringify(getValidLocations(map))}
`,
},
],
});
Expand All @@ -33,7 +37,11 @@ export const claude35sonnet: ModelHandler = async (prompt, map, config) => {
throw new Error("Unexpected completion type from Claude");
}

const parsedContent = JSON.parse(content.text);
const jsonStartIndex = content.text.indexOf("{");
const jsonEndIndex = content.text.lastIndexOf("}") + 1;
const jsonString = content.text.substring(jsonStartIndex, jsonEndIndex);

const parsedContent = JSON.parse(jsonString);
const response = await responseSchema.safeParseAsync(parsedContent);

if (!response.success) {
Expand Down
8 changes: 6 additions & 2 deletions models/gemini-1.5-pro.ts
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
import { type ModelHandler } from ".";
import { type ModelHandler, getValidLocations } from ".";
import { GoogleGenerativeAI, SchemaType } from "@google/generative-ai";

interface GeminiResponse {
Expand Down Expand Up @@ -62,7 +62,11 @@ export const gemini15pro: ModelHandler = async (prompt, map, config) => {
},
});

const result = await model.generateContent(JSON.stringify(map));
const result = await model.generateContent(`
Grid: ${JSON.stringify(map)}
Valid Locations: ${JSON.stringify(getValidLocations(map))}
`);
const parsedResponse = JSON.parse(result.response.text()) as GeminiResponse;

return {
Expand Down
8 changes: 6 additions & 2 deletions models/gpt-4o.ts
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
import { type ModelHandler } from ".";
import { type ModelHandler, getValidLocations } from ".";
import OpenAI from "openai";
import { zodResponseFormat } from "openai/helpers/zod";
import { z } from "zod";
Expand All @@ -24,7 +24,11 @@ export const gpt4o: ModelHandler = async (prompt, map, config) => {
},
{
role: "user",
content: JSON.stringify(map),
content: `
Grid: ${JSON.stringify(map)}
Valid Locations: ${JSON.stringify(getValidLocations(map))}
`,
},
],
response_format: zodResponseFormat(responseSchema, "game_map"),
Expand Down
11 changes: 11 additions & 0 deletions models/index.ts
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,17 @@ export interface ModelHandlerConfig {
topP: number;
}

export function getValidLocations(map: string[][]) {
return map.flatMap((row, y) =>
row.reduce((acc, cell, x) => {
if (cell === " ") {
acc.push([y, x]);
}
return acc;
}, [] as number[][]),
);
}

export type ModelHandler = (
prompt: string,
map: string[][],
Expand Down
8 changes: 6 additions & 2 deletions models/mistral-large-2.ts
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
import { type ModelHandler } from ".";
import { type ModelHandler, getValidLocations } from ".";
import { isJSON } from "../lib/utils";
import { Mistral } from "@mistralai/mistralai";
import { z } from "zod";
Expand All @@ -24,7 +24,11 @@ export const mistralLarge2: ModelHandler = async (prompt, map, config) => {
},
{
role: "user",
content: JSON.stringify(map),
content: `
Grid: ${JSON.stringify(map)}
Valid Locations: ${JSON.stringify(getValidLocations(map))}
`,
},
],
responseFormat: {
Expand Down
8 changes: 6 additions & 2 deletions models/perplexity-llama-3.1.ts
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
import { isJSON } from "../lib/utils";
import { z } from "zod";
import { ModelHandler } from "./index";
import { ModelHandler, getValidLocations } from "./index";

const completionSchema = z.object({
id: z.string(),
Expand Down Expand Up @@ -49,7 +49,11 @@ export const perplexityLlama31: ModelHandler = async (prompt, map, config) => {
{ role: "system", content: prompt },
{
role: "user",
content: JSON.stringify(map),
content: `
Grid: ${JSON.stringify(map)}
Valid Locations: ${JSON.stringify(getValidLocations(map))}
`,
},
],
max_tokens: config.maxTokens,
Expand Down

0 comments on commit 9373980

Please sign in to comment.