-
Notifications
You must be signed in to change notification settings - Fork 26
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
1 parent
2f7332c
commit 7b3976a
Showing
5 changed files
with
285 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,58 @@ | ||
import { type MultiplayerModelHandler } from "."; | ||
import { calculateTotalCost } from "../pricing"; | ||
import { Anthropic } from "@anthropic-ai/sdk"; | ||
import { z } from "zod"; | ||
|
||
const responseSchema = z.object({ | ||
moveDirection: z.string(), | ||
zombieToShoot: z.array(z.number()), | ||
}); | ||
|
||
export const claude35sonnet: MultiplayerModelHandler = async ( | ||
systemPrompt, | ||
userPrompt, | ||
config, | ||
) => { | ||
const anthropic = new Anthropic({ | ||
apiKey: process.env.ANTHROPIC_API_KEY, | ||
}); | ||
|
||
const completion = await anthropic.messages.create({ | ||
model: "claude-3-5-sonnet-20241022", | ||
max_tokens: config.maxTokens, | ||
temperature: config.temperature, | ||
top_p: config.topP, | ||
system: systemPrompt, | ||
messages: [ | ||
{ | ||
role: "user", | ||
content: userPrompt, | ||
}, | ||
], | ||
}); | ||
|
||
const content = completion.content[0]; | ||
if (content.type !== "text") { | ||
throw new Error("Unexpected completion type from Claude"); | ||
} | ||
|
||
const jsonStartIndex = content.text.indexOf("{"); | ||
const jsonEndIndex = content.text.lastIndexOf("}") + 1; | ||
const jsonString = content.text.substring(jsonStartIndex, jsonEndIndex); | ||
|
||
const parsedContent = JSON.parse(jsonString); | ||
const response = await responseSchema.safeParseAsync(parsedContent); | ||
|
||
if (!response.success) { | ||
throw new Error(response.error.message); | ||
} | ||
|
||
const promptTokens = completion.usage.input_tokens; | ||
const outputTokens = completion.usage.output_tokens; | ||
|
||
return { | ||
moveDirection: response.data.moveDirection, | ||
zombieToShoot: response.data.zombieToShoot, | ||
cost: calculateTotalCost("claude-3.5-sonnet", promptTokens, outputTokens), | ||
}; | ||
}; |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,53 @@ | ||
import { type MultiplayerModelHandler } from "."; | ||
import { calculateTotalCost } from "../pricing"; | ||
import { GoogleGenerativeAI, SchemaType } from "@google/generative-ai"; | ||
|
||
const responseSchema = { | ||
type: SchemaType.OBJECT, | ||
properties: { | ||
moveDirection: { | ||
type: SchemaType.STRING, | ||
description: "The direction to move: UP, DOWN, LEFT, RIGHT, or STAY", | ||
}, | ||
zombieToShoot: { | ||
type: SchemaType.ARRAY, | ||
items: { | ||
type: SchemaType.NUMBER, | ||
}, | ||
description: "The coordinates of the zombie to shoot [row, col]", | ||
}, | ||
}, | ||
required: ["moveDirection", "zombieToShoot"], | ||
}; | ||
|
||
export const gemini15pro: MultiplayerModelHandler = async ( | ||
systemPrompt, | ||
userPrompt, | ||
config, | ||
) => { | ||
const genAI = new GoogleGenerativeAI(process.env.GEMINI_API_KEY!); | ||
|
||
const model = genAI.getGenerativeModel({ | ||
model: "gemini-1.5-pro", | ||
systemInstruction: systemPrompt, | ||
generationConfig: { | ||
responseMimeType: "application/json", | ||
responseSchema, | ||
maxOutputTokens: config.maxTokens, | ||
temperature: config.temperature, | ||
topP: config.topP, | ||
}, | ||
}); | ||
|
||
const result = await model.generateContent(userPrompt); | ||
const parsedResponse = JSON.parse(result.response.text()); | ||
|
||
const promptTokens = result.response.usageMetadata?.promptTokenCount; | ||
const outputTokens = result.response.usageMetadata?.candidatesTokenCount; | ||
|
||
return { | ||
moveDirection: parsedResponse.moveDirection, | ||
zombieToShoot: parsedResponse.zombieToShoot, | ||
cost: calculateTotalCost("gemini-1.5-pro", promptTokens, outputTokens), | ||
}; | ||
}; |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,54 @@ | ||
import { type MultiplayerModelHandler } from "."; | ||
import { calculateTotalCost } from "../pricing"; | ||
import { Mistral } from "@mistralai/mistralai"; | ||
import { z } from "zod"; | ||
|
||
const responseSchema = z.object({ | ||
moveDirection: z.string(), | ||
zombieToShoot: z.array(z.number()), | ||
}); | ||
|
||
export const mistralLarge2: MultiplayerModelHandler = async ( | ||
systemPrompt, | ||
userPrompt, | ||
config, | ||
) => { | ||
const client = new Mistral(); | ||
|
||
const completion = await client.chat.complete({ | ||
model: "mistral-large-2407", | ||
maxTokens: config.maxTokens, | ||
temperature: config.temperature, | ||
topP: config.topP, | ||
messages: [ | ||
{ | ||
role: "system", | ||
content: systemPrompt, | ||
}, | ||
{ | ||
role: "user", | ||
content: userPrompt, | ||
}, | ||
], | ||
responseFormat: { | ||
type: "json_object", | ||
}, | ||
}); | ||
|
||
const content = completion.choices?.[0].message.content ?? ""; | ||
const parsedContent = JSON.parse(content); | ||
const response = await responseSchema.safeParseAsync(parsedContent); | ||
|
||
if (!response.success) { | ||
throw new Error(response.error.message); | ||
} | ||
|
||
const promptTokens = completion.usage.promptTokens; | ||
const outputTokens = completion.usage.completionTokens; | ||
|
||
return { | ||
moveDirection: response.data.moveDirection, | ||
zombieToShoot: response.data.zombieToShoot, | ||
cost: calculateTotalCost("mistral-large-2", promptTokens, outputTokens), | ||
}; | ||
}; |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,99 @@ | ||
import { type MultiplayerModelHandler } from "."; | ||
import { isJSON } from "../../lib/utils"; | ||
import { calculateTotalCost } from "../pricing"; | ||
import { z } from "zod"; | ||
|
||
const completionSchema = z.object({ | ||
id: z.string(), | ||
model: z.string(), | ||
object: z.string(), | ||
created: z.number(), | ||
choices: z.array( | ||
z.object({ | ||
index: z.number(), | ||
finish_reason: z.string(), | ||
message: z.object({ | ||
role: z.string(), | ||
content: z.string(), | ||
}), | ||
}), | ||
), | ||
usage: z.object({ | ||
prompt_tokens: z.number(), | ||
completion_tokens: z.number(), | ||
total_tokens: z.number(), | ||
}), | ||
}); | ||
|
||
const responseSchema = z.object({ | ||
moveDirection: z.string(), | ||
zombieToShoot: z.array(z.number()), | ||
}); | ||
|
||
export const perplexityLlama31: MultiplayerModelHandler = async ( | ||
systemPrompt, | ||
userPrompt, | ||
config, | ||
) => { | ||
const completion = await fetch("https://api.perplexity.ai/chat/completions", { | ||
method: "POST", | ||
headers: { | ||
Authorization: `Bearer ${process.env.PERPLEXITY_API_KEY}`, | ||
"Content-Type": "application/json", | ||
}, | ||
body: JSON.stringify({ | ||
model: "llama-3.1-sonar-large-128k-online", | ||
messages: [ | ||
{ role: "system", content: systemPrompt }, | ||
{ role: "user", content: userPrompt }, | ||
], | ||
max_tokens: config.maxTokens, | ||
temperature: config.temperature, | ||
top_p: config.topP, | ||
}), | ||
}); | ||
|
||
if (!completion.ok) { | ||
const errorData = await completion.json(); | ||
throw new Error( | ||
`HTTP error! status: ${completion.status}, message: ${JSON.stringify(errorData)}`, | ||
); | ||
} | ||
|
||
const data = await completion.json(); | ||
const validatedResponse = await completionSchema.safeParseAsync(data); | ||
|
||
if (!validatedResponse.success) { | ||
throw new Error(validatedResponse.error.message); | ||
} | ||
|
||
const content = validatedResponse.data.choices[0].message.content; | ||
const promptTokens = validatedResponse.data.usage.prompt_tokens; | ||
const outputTokens = validatedResponse.data.usage.completion_tokens; | ||
|
||
// Extract JSON from markdown code block if present | ||
const jsonContent = content.match(/```json([^`]+)```/)?.[1] ?? ""; | ||
|
||
if (!isJSON(jsonContent)) { | ||
throw new Error("JSON returned by perplexity is malformed"); | ||
} | ||
|
||
const parsedContent = JSON.parse(jsonContent); | ||
const response = await responseSchema.safeParseAsync(parsedContent); | ||
|
||
if (!response.success) { | ||
throw new Error(response.error.message); | ||
} | ||
|
||
const cost = calculateTotalCost( | ||
"perplexity-llama-3.1", | ||
promptTokens, | ||
outputTokens, | ||
); | ||
|
||
return { | ||
moveDirection: response.data.moveDirection, | ||
zombieToShoot: response.data.zombieToShoot, | ||
cost, | ||
}; | ||
}; |