Skip to content

Commit

Permalink
fixing llama
Browse files Browse the repository at this point in the history
  • Loading branch information
webdevcody committed Nov 5, 2024
1 parent 2f7332c commit 7b3976a
Show file tree
Hide file tree
Showing 5 changed files with 285 additions and 0 deletions.
58 changes: 58 additions & 0 deletions models/multiplayer/claude-3-5-sonnet.ts
Original file line number Diff line number Diff line change
@@ -0,0 +1,58 @@
import { type MultiplayerModelHandler } from ".";
import { calculateTotalCost } from "../pricing";
import { Anthropic } from "@anthropic-ai/sdk";
import { z } from "zod";

// Shape of the JSON object the model must return for a single turn.
const responseSchema = z.object({
  // Direction to move; the Gemini handler's schema documents the expected
  // values as UP, DOWN, LEFT, RIGHT, or STAY.
  moveDirection: z.string(),
  // Coordinates of the zombie to shoot — presumably [row, col]; TODO confirm
  // against the simulator's grid convention.
  zombieToShoot: z.array(z.number()),
});

/**
 * Multiplayer move handler backed by Claude 3.5 Sonnet.
 *
 * Sends the prompts to the Anthropic Messages API, extracts the JSON object
 * from the (possibly prose-wrapped) reply, validates it against
 * `responseSchema`, and returns the move plus the request cost.
 *
 * @throws if the reply is not text, contains no JSON object, fails to parse,
 *   or does not match the schema. Callers (the retry loop in index.ts) catch
 *   and retry.
 */
export const claude35sonnet: MultiplayerModelHandler = async (
  systemPrompt,
  userPrompt,
  config,
) => {
  const anthropic = new Anthropic({
    apiKey: process.env.ANTHROPIC_API_KEY,
  });

  const completion = await anthropic.messages.create({
    model: "claude-3-5-sonnet-20241022",
    max_tokens: config.maxTokens,
    temperature: config.temperature,
    top_p: config.topP,
    system: systemPrompt,
    messages: [
      {
        role: "user",
        content: userPrompt,
      },
    ],
  });

  // Guard the first content part: it may be absent or a non-text part.
  const content = completion.content[0];
  if (content === undefined || content.type !== "text") {
    throw new Error("Unexpected completion type from Claude");
  }

  // The model may wrap the JSON in prose; extract the outermost object.
  // Fail with a clear message instead of letting indexOf() === -1 produce a
  // garbage substring and an opaque JSON.parse error.
  const jsonStartIndex = content.text.indexOf("{");
  const jsonEndIndex = content.text.lastIndexOf("}");
  if (jsonStartIndex === -1 || jsonEndIndex < jsonStartIndex) {
    throw new Error("Claude response did not contain a JSON object");
  }
  const jsonString = content.text.substring(jsonStartIndex, jsonEndIndex + 1);

  const parsedContent = JSON.parse(jsonString);
  const response = await responseSchema.safeParseAsync(parsedContent);

  if (!response.success) {
    throw new Error(response.error.message);
  }

  const promptTokens = completion.usage.input_tokens;
  const outputTokens = completion.usage.output_tokens;

  return {
    moveDirection: response.data.moveDirection,
    zombieToShoot: response.data.zombieToShoot,
    cost: calculateTotalCost("claude-3.5-sonnet", promptTokens, outputTokens),
  };
};
53 changes: 53 additions & 0 deletions models/multiplayer/gemini-1.5-pro.ts
Original file line number Diff line number Diff line change
@@ -0,0 +1,53 @@
import { type MultiplayerModelHandler } from ".";
import { calculateTotalCost } from "../pricing";
import { GoogleGenerativeAI, SchemaType } from "@google/generative-ai";

// Structured-output schema handed to the Gemini SDK so the model replies
// with a JSON object matching the other handlers' zod `responseSchema`.
const responseSchema = {
  type: SchemaType.OBJECT,
  properties: {
    moveDirection: {
      type: SchemaType.STRING,
      description: "The direction to move: UP, DOWN, LEFT, RIGHT, or STAY",
    },
    zombieToShoot: {
      type: SchemaType.ARRAY,
      items: {
        type: SchemaType.NUMBER,
      },
      description: "The coordinates of the zombie to shoot [row, col]",
    },
  },
  // Both fields must always be present in the model's reply.
  required: ["moveDirection", "zombieToShoot"],
};

/**
 * Multiplayer move handler backed by Gemini 1.5 Pro.
 *
 * Uses the SDK's structured JSON output (`responseSchema`) so the reply is
 * parseable JSON, then returns the move plus the request cost.
 *
 * @throws if the API call fails or the reply is not valid JSON. Callers
 *   (the retry loop in index.ts) catch and retry.
 */
export const gemini15pro: MultiplayerModelHandler = async (
  systemPrompt,
  userPrompt,
  config,
) => {
  const genAI = new GoogleGenerativeAI(process.env.GEMINI_API_KEY!);

  const model = genAI.getGenerativeModel({
    model: "gemini-1.5-pro",
    systemInstruction: systemPrompt,
    generationConfig: {
      // Force a JSON reply that conforms to `responseSchema`.
      responseMimeType: "application/json",
      responseSchema,
      maxOutputTokens: config.maxTokens,
      temperature: config.temperature,
      topP: config.topP,
    },
  });

  const result = await model.generateContent(userPrompt);
  const parsedResponse = JSON.parse(result.response.text());

  // usageMetadata is optional in the SDK; default to 0 so the cost
  // calculation never receives undefined token counts.
  const promptTokens = result.response.usageMetadata?.promptTokenCount ?? 0;
  const outputTokens = result.response.usageMetadata?.candidatesTokenCount ?? 0;

  return {
    moveDirection: parsedResponse.moveDirection,
    zombieToShoot: parsedResponse.zombieToShoot,
    cost: calculateTotalCost("gemini-1.5-pro", promptTokens, outputTokens),
  };
};
21 changes: 21 additions & 0 deletions models/multiplayer/index.ts
Original file line number Diff line number Diff line change
@@ -1,4 +1,8 @@
import { claude35sonnet } from "./claude-3-5-sonnet";
import { gemini15pro } from "./gemini-1.5-pro";
import { gpt4o } from "./gpt-4o";
import { mistralLarge2 } from "./mistral-large-2";
import { perplexityLlama31 } from "./mp-perplexity-llama-3.1";
import { ModelSlug } from "@/convex/constants";
import { errorMessage } from "@/lib/utils";
import { ZombieSurvival } from "@/simulator";
Expand Down Expand Up @@ -109,6 +113,22 @@ export async function runMultiplayerModel(
result = await gpt4o(SYSTEM_PROMPT, userPrompt, CONFIG);
break;
}
case "claude-3.5-sonnet": {
result = await claude35sonnet(SYSTEM_PROMPT, userPrompt, CONFIG);
break;
}
case "gemini-1.5-pro": {
result = await gemini15pro(SYSTEM_PROMPT, userPrompt, CONFIG);
break;
}
case "mistral-large-2": {
result = await mistralLarge2(SYSTEM_PROMPT, userPrompt, CONFIG);
break;
}
case "perplexity-llama-3.1": {
result = await perplexityLlama31(SYSTEM_PROMPT, userPrompt, CONFIG);
break;
}
default: {
throw new Error(`Tried running unknown model '${modelSlug}'`);
}
Expand All @@ -120,6 +140,7 @@ export async function runMultiplayerModel(
cost: result.cost,
};
} catch (error) {
console.error(error);
if (retry === MAX_RETRIES || reasoning === null) {
return {
error: errorMessage(error),
Expand Down
54 changes: 54 additions & 0 deletions models/multiplayer/mistral-large-2.ts
Original file line number Diff line number Diff line change
@@ -0,0 +1,54 @@
import { type MultiplayerModelHandler } from ".";
import { calculateTotalCost } from "../pricing";
import { Mistral } from "@mistralai/mistralai";
import { z } from "zod";

// Shape of the JSON object the model must return for a single turn.
const responseSchema = z.object({
  // Direction to move; the Gemini handler's schema documents the expected
  // values as UP, DOWN, LEFT, RIGHT, or STAY.
  moveDirection: z.string(),
  // Coordinates of the zombie to shoot — presumably [row, col]; TODO confirm
  // against the simulator's grid convention.
  zombieToShoot: z.array(z.number()),
});

/**
 * Multiplayer move handler backed by Mistral Large 2.
 *
 * Requests a JSON-object response, validates it against `responseSchema`,
 * and returns the move plus the request cost.
 *
 * @throws if the API returns no usable text content, the JSON fails to
 *   parse, or the schema validation fails. Callers (the retry loop in
 *   index.ts) catch and retry.
 */
export const mistralLarge2: MultiplayerModelHandler = async (
  systemPrompt,
  userPrompt,
  config,
) => {
  // The SDK reads MISTRAL_API_KEY from the environment.
  const client = new Mistral();

  const completion = await client.chat.complete({
    model: "mistral-large-2407",
    maxTokens: config.maxTokens,
    temperature: config.temperature,
    topP: config.topP,
    messages: [
      {
        role: "system",
        content: systemPrompt,
      },
      {
        role: "user",
        content: userPrompt,
      },
    ],
    // Ask the API to emit a bare JSON object (no markdown fences).
    responseFormat: {
      type: "json_object",
    },
  });

  // Chain the optional access all the way down: with an empty `choices`
  // array the original `choices?.[0].message` dereferenced undefined and
  // threw a raw TypeError. The SDK also types `content` as a union that may
  // be a chunk array, so accept only a plain string.
  const content = completion.choices?.[0]?.message?.content;
  if (typeof content !== "string" || content === "") {
    throw new Error("Unexpected completion content from Mistral");
  }

  const parsedContent = JSON.parse(content);
  const response = await responseSchema.safeParseAsync(parsedContent);

  if (!response.success) {
    throw new Error(response.error.message);
  }

  const promptTokens = completion.usage.promptTokens;
  const outputTokens = completion.usage.completionTokens;

  return {
    moveDirection: response.data.moveDirection,
    zombieToShoot: response.data.zombieToShoot,
    cost: calculateTotalCost("mistral-large-2", promptTokens, outputTokens),
  };
};
99 changes: 99 additions & 0 deletions models/multiplayer/mp-perplexity-llama-3.1.ts
Original file line number Diff line number Diff line change
@@ -0,0 +1,99 @@
import { type MultiplayerModelHandler } from ".";
import { isJSON } from "../../lib/utils";
import { calculateTotalCost } from "../pricing";
import { z } from "zod";

// Envelope of the Perplexity chat-completions HTTP response (OpenAI-style
// wire format), validated before any field is read.
const completionSchema = z.object({
  id: z.string(),
  model: z.string(),
  object: z.string(),
  created: z.number(),
  choices: z.array(
    z.object({
      index: z.number(),
      finish_reason: z.string(),
      message: z.object({
        role: z.string(),
        // Raw model text; the handler extracts the JSON payload from it.
        content: z.string(),
      }),
    }),
  ),
  // Token counts used for cost accounting.
  usage: z.object({
    prompt_tokens: z.number(),
    completion_tokens: z.number(),
    total_tokens: z.number(),
  }),
});

// Shape of the JSON object the model must return for a single turn.
const responseSchema = z.object({
  // Direction to move; the Gemini handler's schema documents the expected
  // values as UP, DOWN, LEFT, RIGHT, or STAY.
  moveDirection: z.string(),
  // Coordinates of the zombie to shoot — presumably [row, col]; TODO confirm
  // against the simulator's grid convention.
  zombieToShoot: z.array(z.number()),
});

/**
 * Multiplayer move handler backed by Perplexity's Llama 3.1 Sonar model.
 *
 * Calls the Perplexity chat-completions HTTP API directly, validates the
 * response envelope, extracts the JSON move object from the reply text, and
 * returns the move plus the request cost.
 *
 * @throws on a non-OK HTTP status, an envelope that fails validation,
 *   malformed JSON content, or a schema mismatch. Callers (the retry loop
 *   in index.ts) catch and retry.
 */
export const perplexityLlama31: MultiplayerModelHandler = async (
  systemPrompt,
  userPrompt,
  config,
) => {
  const completion = await fetch("https://api.perplexity.ai/chat/completions", {
    method: "POST",
    headers: {
      Authorization: `Bearer ${process.env.PERPLEXITY_API_KEY}`,
      "Content-Type": "application/json",
    },
    body: JSON.stringify({
      model: "llama-3.1-sonar-large-128k-online",
      messages: [
        { role: "system", content: systemPrompt },
        { role: "user", content: userPrompt },
      ],
      max_tokens: config.maxTokens,
      temperature: config.temperature,
      top_p: config.topP,
    }),
  });

  if (!completion.ok) {
    const errorData = await completion.json();
    throw new Error(
      `HTTP error! status: ${completion.status}, message: ${JSON.stringify(errorData)}`,
    );
  }

  const data = await completion.json();
  const validatedResponse = await completionSchema.safeParseAsync(data);

  if (!validatedResponse.success) {
    throw new Error(validatedResponse.error.message);
  }

  const content = validatedResponse.data.choices[0].message.content;
  const promptTokens = validatedResponse.data.usage.prompt_tokens;
  const outputTokens = validatedResponse.data.usage.completion_tokens;

  // Extract JSON from a markdown code block if present; the original
  // rejected fence-less replies even when they were valid JSON, so fall
  // back to the outermost `{...}` in the raw content (same approach as the
  // Claude handler).
  let jsonContent = content.match(/```json([^`]+)```/)?.[1];
  if (jsonContent === undefined) {
    const start = content.indexOf("{");
    const end = content.lastIndexOf("}");
    jsonContent =
      start !== -1 && end > start ? content.substring(start, end + 1) : "";
  }

  if (!isJSON(jsonContent)) {
    throw new Error("JSON returned by perplexity is malformed");
  }

  const parsedContent = JSON.parse(jsonContent);
  const response = await responseSchema.safeParseAsync(parsedContent);

  if (!response.success) {
    throw new Error(response.error.message);
  }

  const cost = calculateTotalCost(
    "perplexity-llama-3.1",
    promptTokens,
    outputTokens,
  );

  return {
    moveDirection: response.data.moveDirection,
    zombieToShoot: response.data.zombieToShoot,
    cost,
  };
};

0 comments on commit 7b3976a

Please sign in to comment.