From 90cd194310be14d110bcf35ef219966cc61f490f Mon Sep 17 00:00:00 2001 From: Aaron Delasy Date: Thu, 24 Oct 2024 23:35:39 +0300 Subject: [PATCH] Refactor models --- models/claude-3-5-sonnet.ts | 16 +++++------ models/gemini-1.5-pro.ts | 16 +++++------ models/gpt-4o.ts | 12 +++----- models/index.ts | 29 ++++++++------------ models/mistral-large-2.ts | 16 +++++------ models/perplexity-llama-3.1.ts | 16 +++++------ simulators/zombie-survival/ZombieSurvival.ts | 11 ++++++++ 7 files changed, 58 insertions(+), 58 deletions(-) diff --git a/models/claude-3-5-sonnet.ts b/models/claude-3-5-sonnet.ts index 196423e..c848f13 100644 --- a/models/claude-3-5-sonnet.ts +++ b/models/claude-3-5-sonnet.ts @@ -1,4 +1,4 @@ -import { type ModelHandler, getValidLocations } from "."; +import { type ModelHandler } from "."; import { Anthropic } from "@anthropic-ai/sdk"; import { z } from "zod"; @@ -8,7 +8,11 @@ const responseSchema = z.object({ reasoning: z.string(), }); -export const claude35sonnet: ModelHandler = async (prompt, map, config) => { +export const claude35sonnet: ModelHandler = async ( + systemPrompt, + userPrompt, + config, +) => { const anthropic = new Anthropic({ apiKey: process.env.ANTHROPIC_API_KEY, }); @@ -18,15 +22,11 @@ export const claude35sonnet: ModelHandler = async (prompt, map, config) => { max_tokens: config.maxTokens, temperature: config.temperature, top_p: config.topP, - system: prompt, + system: systemPrompt, messages: [ { role: "user", - content: ` -Grid: ${JSON.stringify(map)} - -Valid Locations: ${JSON.stringify(getValidLocations(map))} -`, + content: userPrompt, }, ], }); diff --git a/models/gemini-1.5-pro.ts b/models/gemini-1.5-pro.ts index cba3561..81a21cf 100644 --- a/models/gemini-1.5-pro.ts +++ b/models/gemini-1.5-pro.ts @@ -1,4 +1,4 @@ -import { type ModelHandler, getValidLocations } from "."; +import { type ModelHandler } from "."; import { GoogleGenerativeAI, SchemaType } from "@google/generative-ai"; interface GeminiResponse { @@ -47,12 +47,16 @@ const responseSchema = { required: ["map", "reasoning", "playerCoordinates", "boxCoordinates"], }; -export const gemini15pro: ModelHandler = async (prompt, map, config) => { +export const gemini15pro: ModelHandler = async ( + systemPrompt, + userPrompt, + config, +) => { const genAI = new GoogleGenerativeAI(process.env.GEMINI_API_KEY!); const model = genAI.getGenerativeModel({ model: "gemini-1.5-pro", - systemInstruction: prompt, + systemInstruction: systemPrompt, generationConfig: { responseMimeType: "application/json", responseSchema, @@ -62,11 +66,7 @@ export const gemini15pro: ModelHandler = async (prompt, map, config) => { }, }); - const result = await model.generateContent(` -Grid: ${JSON.stringify(map)} - -Valid Locations: ${JSON.stringify(getValidLocations(map))} -`); + const result = await model.generateContent(userPrompt); const parsedResponse = JSON.parse(result.response.text()) as GeminiResponse; return { diff --git a/models/gpt-4o.ts b/models/gpt-4o.ts index b9d2a2b..c018648 100644 --- a/models/gpt-4o.ts +++ b/models/gpt-4o.ts @@ -1,4 +1,4 @@ -import { type ModelHandler, getValidLocations } from "."; +import { type ModelHandler } from "."; import OpenAI from "openai"; import { zodResponseFormat } from "openai/helpers/zod"; import { z } from "zod"; @@ -9,7 +9,7 @@ const responseSchema = z.object({ boxCoordinates: z.array(z.array(z.number())), }); -export const gpt4o: ModelHandler = async (prompt, map, config) => { +export const gpt4o: ModelHandler = async (systemPrompt, userPrompt, config) => { const openai = new OpenAI(); const completion = await openai.beta.chat.completions.parse({ @@ -20,15 +20,11 @@ export const gpt4o: ModelHandler = async (prompt, map, config) => { messages: [ { role: "system", - content: prompt, + content: systemPrompt, }, { role: "user", - content: ` -Grid: ${JSON.stringify(map)} - -Valid Locations: ${JSON.stringify(getValidLocations(map))} -`, + content: userPrompt, }, ], response_format: zodResponseFormat(responseSchema, "game_map"), diff --git a/models/index.ts b/models/index.ts index 2d25e5a..549f068 100644 --- a/models/index.ts +++ b/models/index.ts @@ -12,20 +12,9 @@ export interface ModelHandlerConfig { topP: number; } -export function getValidLocations(map: string[][]) { - return map.flatMap((row, y) => - row.reduce((acc, cell, x) => { - if (cell === " ") { - acc.push([y, x]); - } - return acc; - }, [] as number[][]), - ); -} - export type ModelHandler = ( - prompt: string, - map: string[][], + systemPrompt: string, + userPrompt: string, config: ModelHandlerConfig, ) => Promise<{ boxCoordinates: number[][]; @@ -54,29 +43,33 @@ export async function runModel( reasoning: string; error?: string; }> { + const userPrompt = + `Grid: ${JSON.stringify(map)}\n\n` + + `Valid Locations: ${JSON.stringify(ZombieSurvival.validLocations(map))}`; + let result; let reasoning: string | null = null; try { switch (modelId) { case "gemini-1.5-pro": { - result = await gemini15pro(prompt, map, CONFIG); + result = await gemini15pro(prompt, userPrompt, CONFIG); break; } case "gpt-4o": { - result = await gpt4o(prompt, map, CONFIG); + result = await gpt4o(prompt, userPrompt, CONFIG); break; } case "claude-3.5-sonnet": { - result = await claude35sonnet(prompt, map, CONFIG); + result = await claude35sonnet(prompt, userPrompt, CONFIG); break; } case "perplexity-llama-3.1": { - result = await perplexityLlama31(prompt, map, CONFIG); + result = await perplexityLlama31(prompt, userPrompt, CONFIG); break; } case "mistral-large-2": { - result = await mistralLarge2(prompt, map, CONFIG); + result = await mistralLarge2(prompt, userPrompt, CONFIG); break; } default: { diff --git a/models/mistral-large-2.ts b/models/mistral-large-2.ts index 7a67f7b..64bb21b 100644 --- a/models/mistral-large-2.ts +++ b/models/mistral-large-2.ts @@ -1,4 +1,4 @@ -import { type ModelHandler, getValidLocations } from "."; +import { type ModelHandler } from "."; import { isJSON } from "../lib/utils"; import { Mistral } from "@mistralai/mistralai"; import { z } from "zod"; @@ -9,7 +9,11 @@ const responseSchema = z.object({ boxCoordinates: z.array(z.array(z.number())), }); -export const mistralLarge2: ModelHandler = async (prompt, map, config) => { +export const mistralLarge2: ModelHandler = async ( + systemPrompt, + userPrompt, + config, +) => { const client = new Mistral(); const completion = await client.chat.complete({ @@ -20,15 +24,11 @@ export const mistralLarge2: ModelHandler = async (prompt, map, config) => { messages: [ { role: "system", - content: prompt, + content: systemPrompt, }, { role: "user", - content: ` -Grid: ${JSON.stringify(map)} - -Valid Locations: ${JSON.stringify(getValidLocations(map))} -`, + content: userPrompt, }, ], responseFormat: { diff --git a/models/perplexity-llama-3.1.ts b/models/perplexity-llama-3.1.ts index 1ea4016..98cd9d1 100644 --- a/models/perplexity-llama-3.1.ts +++ b/models/perplexity-llama-3.1.ts @@ -1,6 +1,6 @@ import { isJSON } from "../lib/utils"; import { z } from "zod"; -import { ModelHandler, getValidLocations } from "./index"; +import { ModelHandler } from "./index"; const completionSchema = z.object({ id: z.string(), @@ -36,7 +36,11 @@ const responseSchema = z.object({ reasoning: z.string(), }); -export const perplexityLlama31: ModelHandler = async (prompt, map, config) => { +export const perplexityLlama31: ModelHandler = async ( + systemPrompt, + userPrompt, + config, +) => { const completion = await fetch("https://api.perplexity.ai/chat/completions", { method: "POST", headers: { @@ -46,14 +50,10 @@ export const perplexityLlama31: ModelHandler = async (prompt, map, config) => { body: JSON.stringify({ model: "llama-3.1-sonar-large-128k-online", messages: [ - { role: "system", content: prompt }, + { role: "system", content: systemPrompt }, { role: "user", - content: ` -Grid: ${JSON.stringify(map)} - -Valid Locations: ${JSON.stringify(getValidLocations(map))} -`, + content: userPrompt, }, ], max_tokens: config.maxTokens, diff --git a/simulators/zombie-survival/ZombieSurvival.ts b/simulators/zombie-survival/ZombieSurvival.ts index cb9329a..5bc40ac 100644 --- a/simulators/zombie-survival/ZombieSurvival.ts +++ b/simulators/zombie-survival/ZombieSurvival.ts @@ -62,6 +62,17 @@ export class ZombieSurvival { return map.length === 0 || map[0].length === 0; } + public static validLocations(map: string[][]): number[][] { + return map.flatMap((row, y) => + row.reduce((acc, cell, x) => { + if (cell === " ") { + acc.push([y, x]); + } + return acc; + }, [] as number[][]), + ); + } + public constructor(config: string[][]) { if (ZombieSurvival.mapIsEmpty(config)) { throw new Error("Config is empty");