Skip to content

Commit

Permalink
Refactor models
Browse files Browse the repository at this point in the history
  • Loading branch information
delasy committed Oct 24, 2024
1 parent f8348d5 commit 90cd194
Show file tree
Hide file tree
Showing 7 changed files with 58 additions and 58 deletions.
16 changes: 8 additions & 8 deletions models/claude-3-5-sonnet.ts
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
import { type ModelHandler, getValidLocations } from ".";
import { type ModelHandler } from ".";
import { Anthropic } from "@anthropic-ai/sdk";
import { z } from "zod";

Expand All @@ -8,7 +8,11 @@ const responseSchema = z.object({
reasoning: z.string(),
});

export const claude35sonnet: ModelHandler = async (prompt, map, config) => {
export const claude35sonnet: ModelHandler = async (
systemPrompt,
userPrompt,
config,
) => {
const anthropic = new Anthropic({
apiKey: process.env.ANTHROPIC_API_KEY,
});
Expand All @@ -18,15 +22,11 @@ export const claude35sonnet: ModelHandler = async (prompt, map, config) => {
max_tokens: config.maxTokens,
temperature: config.temperature,
top_p: config.topP,
system: prompt,
system: systemPrompt,
messages: [
{
role: "user",
content: `
Grid: ${JSON.stringify(map)}
Valid Locations: ${JSON.stringify(getValidLocations(map))}
`,
content: userPrompt,
},
],
});
Expand Down
16 changes: 8 additions & 8 deletions models/gemini-1.5-pro.ts
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
import { type ModelHandler, getValidLocations } from ".";
import { type ModelHandler } from ".";
import { GoogleGenerativeAI, SchemaType } from "@google/generative-ai";

interface GeminiResponse {
Expand Down Expand Up @@ -47,12 +47,16 @@ const responseSchema = {
required: ["map", "reasoning", "playerCoordinates", "boxCoordinates"],
};

export const gemini15pro: ModelHandler = async (prompt, map, config) => {
export const gemini15pro: ModelHandler = async (
systemPrompt,
userPrompt,
config,
) => {
const genAI = new GoogleGenerativeAI(process.env.GEMINI_API_KEY!);

const model = genAI.getGenerativeModel({
model: "gemini-1.5-pro",
systemInstruction: prompt,
systemInstruction: systemPrompt,
generationConfig: {
responseMimeType: "application/json",
responseSchema,
Expand All @@ -62,11 +66,7 @@ export const gemini15pro: ModelHandler = async (prompt, map, config) => {
},
});

const result = await model.generateContent(`
Grid: ${JSON.stringify(map)}
Valid Locations: ${JSON.stringify(getValidLocations(map))}
`);
const result = await model.generateContent(userPrompt);
const parsedResponse = JSON.parse(result.response.text()) as GeminiResponse;

return {
Expand Down
12 changes: 4 additions & 8 deletions models/gpt-4o.ts
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
import { type ModelHandler, getValidLocations } from ".";
import { type ModelHandler } from ".";
import OpenAI from "openai";
import { zodResponseFormat } from "openai/helpers/zod";
import { z } from "zod";
Expand All @@ -9,7 +9,7 @@ const responseSchema = z.object({
boxCoordinates: z.array(z.array(z.number())),
});

export const gpt4o: ModelHandler = async (prompt, map, config) => {
export const gpt4o: ModelHandler = async (systemPrompt, userPrompt, config) => {
const openai = new OpenAI();

const completion = await openai.beta.chat.completions.parse({
Expand All @@ -20,15 +20,11 @@ export const gpt4o: ModelHandler = async (prompt, map, config) => {
messages: [
{
role: "system",
content: prompt,
content: systemPrompt,
},
{
role: "user",
content: `
Grid: ${JSON.stringify(map)}
Valid Locations: ${JSON.stringify(getValidLocations(map))}
`,
content: userPrompt,
},
],
response_format: zodResponseFormat(responseSchema, "game_map"),
Expand Down
29 changes: 11 additions & 18 deletions models/index.ts
Original file line number Diff line number Diff line change
Expand Up @@ -12,20 +12,9 @@ export interface ModelHandlerConfig {
topP: number;
}

// Collect the [y, x] coordinates of every empty cell (" ") in the grid,
// scanned row by row, left to right.
export function getValidLocations(map: string[][]) {
  const locations: number[][] = [];
  for (let y = 0; y < map.length; y++) {
    const row = map[y];
    for (let x = 0; x < row.length; x++) {
      if (row[x] === " ") {
        locations.push([y, x]);
      }
    }
  }
  return locations;
}

export type ModelHandler = (
prompt: string,
map: string[][],
systemPrompt: string,
userPrompt: string,
config: ModelHandlerConfig,
) => Promise<{
boxCoordinates: number[][];
Expand Down Expand Up @@ -54,29 +43,33 @@ export async function runModel(
reasoning: string;
error?: string;
}> {
const userPrompt =
`Grid: ${JSON.stringify(map)}\n\n` +
`Valid Locations: ${JSON.stringify(ZombieSurvival.validLocations(map))}`;

let result;
let reasoning: string | null = null;

try {
switch (modelId) {
case "gemini-1.5-pro": {
result = await gemini15pro(prompt, map, CONFIG);
result = await gemini15pro(prompt, userPrompt, CONFIG);
break;
}
case "gpt-4o": {
result = await gpt4o(prompt, map, CONFIG);
result = await gpt4o(prompt, userPrompt, CONFIG);
break;
}
case "claude-3.5-sonnet": {
result = await claude35sonnet(prompt, map, CONFIG);
result = await claude35sonnet(prompt, userPrompt, CONFIG);
break;
}
case "perplexity-llama-3.1": {
result = await perplexityLlama31(prompt, map, CONFIG);
result = await perplexityLlama31(prompt, userPrompt, CONFIG);
break;
}
case "mistral-large-2": {
result = await mistralLarge2(prompt, map, CONFIG);
result = await mistralLarge2(prompt, userPrompt, CONFIG);
break;
}
default: {
Expand Down
16 changes: 8 additions & 8 deletions models/mistral-large-2.ts
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
import { type ModelHandler, getValidLocations } from ".";
import { type ModelHandler } from ".";
import { isJSON } from "../lib/utils";
import { Mistral } from "@mistralai/mistralai";
import { z } from "zod";
Expand All @@ -9,7 +9,11 @@ const responseSchema = z.object({
boxCoordinates: z.array(z.array(z.number())),
});

export const mistralLarge2: ModelHandler = async (prompt, map, config) => {
export const mistralLarge2: ModelHandler = async (
systemPrompt,
userPrompt,
config,
) => {
const client = new Mistral();

const completion = await client.chat.complete({
Expand All @@ -20,15 +24,11 @@ export const mistralLarge2: ModelHandler = async (prompt, map, config) => {
messages: [
{
role: "system",
content: prompt,
content: systemPrompt,
},
{
role: "user",
content: `
Grid: ${JSON.stringify(map)}
Valid Locations: ${JSON.stringify(getValidLocations(map))}
`,
content: userPrompt,
},
],
responseFormat: {
Expand Down
16 changes: 8 additions & 8 deletions models/perplexity-llama-3.1.ts
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
import { isJSON } from "../lib/utils";
import { z } from "zod";
import { ModelHandler, getValidLocations } from "./index";
import { ModelHandler } from "./index";

const completionSchema = z.object({
id: z.string(),
Expand Down Expand Up @@ -36,7 +36,11 @@ const responseSchema = z.object({
reasoning: z.string(),
});

export const perplexityLlama31: ModelHandler = async (prompt, map, config) => {
export const perplexityLlama31: ModelHandler = async (
systemPrompt,
userPrompt,
config,
) => {
const completion = await fetch("https://api.perplexity.ai/chat/completions", {
method: "POST",
headers: {
Expand All @@ -46,14 +50,10 @@ export const perplexityLlama31: ModelHandler = async (prompt, map, config) => {
body: JSON.stringify({
model: "llama-3.1-sonar-large-128k-online",
messages: [
{ role: "system", content: prompt },
{ role: "system", content: systemPrompt },
{
role: "user",
content: `
Grid: ${JSON.stringify(map)}
Valid Locations: ${JSON.stringify(getValidLocations(map))}
`,
content: userPrompt,
},
],
max_tokens: config.maxTokens,
Expand Down
11 changes: 11 additions & 0 deletions simulators/zombie-survival/ZombieSurvival.ts
Original file line number Diff line number Diff line change
Expand Up @@ -62,6 +62,17 @@ export class ZombieSurvival {
return map.length === 0 || map[0].length === 0;
}

public static validLocations(map: string[][]): number[][] {
return map.flatMap((row, y) =>
row.reduce((acc, cell, x) => {
if (cell === " ") {
acc.push([y, x]);
}
return acc;
}, [] as number[][]),
);
}

public constructor(config: string[][]) {
if (ZombieSurvival.mapIsEmpty(config)) {
throw new Error("Config is empty");
Expand Down

0 comments on commit 90cd194

Please sign in to comment.