diff --git a/models/multiplayer/claude-3-5-sonnet.ts b/models/multiplayer/claude-3-5-sonnet.ts new file mode 100644 index 0000000..fe29c09 --- /dev/null +++ b/models/multiplayer/claude-3-5-sonnet.ts @@ -0,0 +1,58 @@ +import { type MultiplayerModelHandler } from "."; +import { calculateTotalCost } from "../pricing"; +import { Anthropic } from "@anthropic-ai/sdk"; +import { z } from "zod"; + +const responseSchema = z.object({ + moveDirection: z.string(), + zombieToShoot: z.array(z.number()), +}); + +export const claude35sonnet: MultiplayerModelHandler = async ( + systemPrompt, + userPrompt, + config, +) => { + const anthropic = new Anthropic({ + apiKey: process.env.ANTHROPIC_API_KEY, + }); + + const completion = await anthropic.messages.create({ + model: "claude-3-5-sonnet-20241022", + max_tokens: config.maxTokens, + temperature: config.temperature, + top_p: config.topP, + system: systemPrompt, + messages: [ + { + role: "user", + content: userPrompt, + }, + ], + }); + + const content = completion.content[0]; + if (content.type !== "text") { + throw new Error("Unexpected completion type from Claude"); + } + + const jsonStartIndex = content.text.indexOf("{"); + const jsonEndIndex = content.text.lastIndexOf("}") + 1; + const jsonString = content.text.substring(jsonStartIndex, jsonEndIndex); + + const parsedContent = JSON.parse(jsonString); + const response = await responseSchema.safeParseAsync(parsedContent); + + if (!response.success) { + throw new Error(response.error.message); + } + + const promptTokens = completion.usage.input_tokens; + const outputTokens = completion.usage.output_tokens; + + return { + moveDirection: response.data.moveDirection, + zombieToShoot: response.data.zombieToShoot, + cost: calculateTotalCost("claude-3.5-sonnet", promptTokens, outputTokens), + }; +}; diff --git a/models/multiplayer/gemini-1.5-pro.ts b/models/multiplayer/gemini-1.5-pro.ts new file mode 100644 index 0000000..a0098cf --- /dev/null +++ b/models/multiplayer/gemini-1.5-pro.ts @@ -0,0 +1,53 @@ +import { type MultiplayerModelHandler } from "."; +import { calculateTotalCost } from "../pricing"; +import { GoogleGenerativeAI, SchemaType } from "@google/generative-ai"; + +const responseSchema = { + type: SchemaType.OBJECT, + properties: { + moveDirection: { + type: SchemaType.STRING, + description: "The direction to move: UP, DOWN, LEFT, RIGHT, or STAY", + }, + zombieToShoot: { + type: SchemaType.ARRAY, + items: { + type: SchemaType.NUMBER, + }, + description: "The coordinates of the zombie to shoot [row, col]", + }, + }, + required: ["moveDirection", "zombieToShoot"], +}; + +export const gemini15pro: MultiplayerModelHandler = async ( + systemPrompt, + userPrompt, + config, +) => { + const genAI = new GoogleGenerativeAI(process.env.GEMINI_API_KEY!); + + const model = genAI.getGenerativeModel({ + model: "gemini-1.5-pro", + systemInstruction: systemPrompt, + generationConfig: { + responseMimeType: "application/json", + responseSchema, + maxOutputTokens: config.maxTokens, + temperature: config.temperature, + topP: config.topP, + }, + }); + + const result = await model.generateContent(userPrompt); + const parsedResponse = JSON.parse(result.response.text()); + + const promptTokens = result.response.usageMetadata?.promptTokenCount; + const outputTokens = result.response.usageMetadata?.candidatesTokenCount; + + return { + moveDirection: parsedResponse.moveDirection, + zombieToShoot: parsedResponse.zombieToShoot, + cost: calculateTotalCost("gemini-1.5-pro", promptTokens, outputTokens), + }; +}; diff --git a/models/multiplayer/index.ts b/models/multiplayer/index.ts index ca30c50..49cd118 100644 --- a/models/multiplayer/index.ts +++ b/models/multiplayer/index.ts @@ -1,4 +1,8 @@ +import { claude35sonnet } from "./claude-3-5-sonnet"; +import { gemini15pro } from "./gemini-1.5-pro"; import { gpt4o } from "./gpt-4o"; +import { mistralLarge2 } from "./mistral-large-2"; +import { perplexityLlama31 } from "./mp-perplexity-llama-3.1"; import { ModelSlug } from "@/convex/constants"; import { errorMessage } from "@/lib/utils"; import { ZombieSurvival } from "@/simulator"; @@ -109,6 +113,22 @@ export async function runMultiplayerModel( result = await gpt4o(SYSTEM_PROMPT, userPrompt, CONFIG); break; } + case "claude-3.5-sonnet": { + result = await claude35sonnet(SYSTEM_PROMPT, userPrompt, CONFIG); + break; + } + case "gemini-1.5-pro": { + result = await gemini15pro(SYSTEM_PROMPT, userPrompt, CONFIG); + break; + } + case "mistral-large-2": { + result = await mistralLarge2(SYSTEM_PROMPT, userPrompt, CONFIG); + break; + } + case "perplexity-llama-3.1": { + result = await perplexityLlama31(SYSTEM_PROMPT, userPrompt, CONFIG); + break; + } default: { throw new Error(`Tried running unknown model '${modelSlug}'`); } @@ -120,6 +140,7 @@ export async function runMultiplayerModel( cost: result.cost, }; } catch (error) { + console.error(error); if (retry === MAX_RETRIES || reasoning === null) { return { error: errorMessage(error), diff --git a/models/multiplayer/mistral-large-2.ts b/models/multiplayer/mistral-large-2.ts new file mode 100644 index 0000000..0df387a --- /dev/null +++ b/models/multiplayer/mistral-large-2.ts @@ -0,0 +1,54 @@ +import { type MultiplayerModelHandler } from "."; +import { calculateTotalCost } from "../pricing"; +import { Mistral } from "@mistralai/mistralai"; +import { z } from "zod"; + +const responseSchema = z.object({ + moveDirection: z.string(), + zombieToShoot: z.array(z.number()), +}); + +export const mistralLarge2: MultiplayerModelHandler = async ( + systemPrompt, + userPrompt, + config, +) => { + const client = new Mistral(); + + const completion = await client.chat.complete({ + model: "mistral-large-2407", + maxTokens: config.maxTokens, + temperature: config.temperature, + topP: config.topP, + messages: [ + { + role: "system", + content: systemPrompt, + }, + { + role: "user", + content: userPrompt, + }, + ], + responseFormat: { + type: "json_object", + }, + }); + + const content = completion.choices?.[0].message.content ?? ""; + const parsedContent = JSON.parse(content); + const response = await responseSchema.safeParseAsync(parsedContent); + + if (!response.success) { + throw new Error(response.error.message); + } + + const promptTokens = completion.usage.promptTokens; + const outputTokens = completion.usage.completionTokens; + + return { + moveDirection: response.data.moveDirection, + zombieToShoot: response.data.zombieToShoot, + cost: calculateTotalCost("mistral-large-2", promptTokens, outputTokens), + }; +}; diff --git a/models/multiplayer/mp-perplexity-llama-3.1.ts b/models/multiplayer/mp-perplexity-llama-3.1.ts new file mode 100644 index 0000000..1aa1e61 --- /dev/null +++ b/models/multiplayer/mp-perplexity-llama-3.1.ts @@ -0,0 +1,99 @@ +import { type MultiplayerModelHandler } from "."; +import { isJSON } from "../../lib/utils"; +import { calculateTotalCost } from "../pricing"; +import { z } from "zod"; + +const completionSchema = z.object({ + id: z.string(), + model: z.string(), + object: z.string(), + created: z.number(), + choices: z.array( + z.object({ + index: z.number(), + finish_reason: z.string(), + message: z.object({ + role: z.string(), + content: z.string(), + }), + }), + ), + usage: z.object({ + prompt_tokens: z.number(), + completion_tokens: z.number(), + total_tokens: z.number(), + }), +}); + +const responseSchema = z.object({ + moveDirection: z.string(), + zombieToShoot: z.array(z.number()), +}); + +export const perplexityLlama31: MultiplayerModelHandler = async ( + systemPrompt, + userPrompt, + config, +) => { + const completion = await fetch("https://api.perplexity.ai/chat/completions", { + method: "POST", + headers: { + Authorization: `Bearer ${process.env.PERPLEXITY_API_KEY}`, + "Content-Type": "application/json", + }, + body: JSON.stringify({ + model: "llama-3.1-sonar-large-128k-online", + messages: [ + { role: "system", content: systemPrompt }, + { role: "user", content: userPrompt }, + ], + max_tokens: config.maxTokens, + temperature: config.temperature, + top_p: config.topP, + }), + }); + + if (!completion.ok) { + const errorData = await completion.json(); + throw new Error( + `HTTP error! status: ${completion.status}, message: ${JSON.stringify(errorData)}`, + ); + } + + const data = await completion.json(); + const validatedResponse = await completionSchema.safeParseAsync(data); + + if (!validatedResponse.success) { + throw new Error(validatedResponse.error.message); + } + + const content = validatedResponse.data.choices[0].message.content; + const promptTokens = validatedResponse.data.usage.prompt_tokens; + const outputTokens = validatedResponse.data.usage.completion_tokens; + + // Extract JSON from markdown code block if present + const jsonContent = content.match(/```json([^`]+)```/)?.[1] ?? ""; + + if (!isJSON(jsonContent)) { + throw new Error("JSON returned by perplexity is malformed"); + } + + const parsedContent = JSON.parse(jsonContent); + const response = await responseSchema.safeParseAsync(parsedContent); + + if (!response.success) { + throw new Error(response.error.message); + } + + const cost = calculateTotalCost( + "perplexity-llama-3.1", + promptTokens, + outputTokens, + ); + + return { + moveDirection: response.data.moveDirection, + zombieToShoot: response.data.zombieToShoot, + cost, + }; +};