diff --git a/lib/utils.ts b/lib/utils.ts index 5a7b05a..34a8af8 100644 --- a/lib/utils.ts +++ b/lib/utils.ts @@ -12,3 +12,12 @@ export function errorMessage(error: unknown) { ? error : "An unexpected error occurred"; } + +export function isJSON(val: string): boolean { + try { + JSON.parse(val); + return true; + } catch { + return false; + } +} diff --git a/models/index.ts b/models/index.ts index 4ddb93e..ce4f165 100644 --- a/models/index.ts +++ b/models/index.ts @@ -3,7 +3,7 @@ import { ZombieSurvival } from "../simulators/zombie-survival"; import { claude35sonnet } from "./claude-3-5-sonnet"; import { gemini15pro } from "./gemini-1.5-pro"; import { gpt4o } from "./gpt-4o"; -import { perplexityModel } from "./perplexity-llama"; +import { perplexityLlama } from "./perplexity-llama"; export type ModelHandler = ( prompt: string, @@ -40,7 +40,7 @@ export async function runModel( break; } case "perplexity-llama-3.1": { - result = await perplexityModel(prompt, map); + result = await perplexityLlama(prompt, map); break; } default: { diff --git a/models/perplexity-llama.ts b/models/perplexity-llama.ts index 5153e34..2d4336c 100644 --- a/models/perplexity-llama.ts +++ b/models/perplexity-llama.ts @@ -1,3 +1,4 @@ +import { isJSON } from "../lib/utils"; import { z } from "zod"; import { ModelHandler } from "./index"; @@ -30,18 +31,23 @@ const PerplexityResponseSchema = z.object({ }); const GameResponseSchema = z.object({ - reasoning: z.string(), playerCoordinates: z.array(z.number()), boxCoordinates: z.array(z.array(z.number())), }); -export const perplexityModel: ModelHandler = async ( +export const perplexityLlama: ModelHandler = async ( prompt: string, map: string[][], ) => { + const promptAnswerRequirement = + "Answer only with JSON output and a single paragraph explaining your placement strategy."; + const messages = [ { role: "system", content: "Be precise and concise." }, - { role: "user", content: `${prompt}\n\nMap:\n${JSON.stringify(map)}` }, + { + role: "user", + content: `${prompt}\n\nMap:\n${JSON.stringify(map)}\n\n${promptAnswerRequirement}`, + }, ]; const data = { @@ -49,7 +55,7 @@ export const perplexityModel: ModelHandler = async ( messages, temperature: 0.2, top_p: 0.9, - return_citations: true, + return_citations: false, search_domain_filter: ["perplexity.ai"], return_images: false, return_related_questions: false, @@ -78,14 +84,43 @@ export const perplexityModel: ModelHandler = async ( } const responseData = await response.json(); - const validatedResponse = PerplexityResponseSchema.parse(responseData); - const content = validatedResponse.choices[0].message.content; - const parsedContent = JSON.parse(content); - const gameResponse = GameResponseSchema.parse(parsedContent); + + const validatedResponse = + await PerplexityResponseSchema.safeParseAsync(responseData); + + if (!validatedResponse.success) { + throw new Error(validatedResponse.error.message); + } + + const content = validatedResponse.data.choices[0].message.content; + const jsonContent = content.match(/```json([^`]+)```/)?.[1] ?? ""; + + if (!isJSON(jsonContent)) { + throw new Error("JSON returned by perplexity is malformed"); + } + + const parsedContent = JSON.parse(jsonContent); + const gameResponse = await GameResponseSchema.safeParseAsync(parsedContent); + + if (!gameResponse.success) { + throw new Error(gameResponse.error.message); + } + + const reasoning = content + .replace(/```json([^`]+)```/, "") + .split("\n") + .map((it) => it) + .map((it) => it.replace(/(\*\*|```)/, "").trim()) + .filter((it) => it !== "") + .join(" "); + + if (reasoning === "") { + throw new Error("Answer returned by perplexity doesn't contain reasoning"); + } return { - boxCoordinates: gameResponse.boxCoordinates, - playerCoordinates: gameResponse.playerCoordinates, - reasoning: gameResponse.reasoning, + boxCoordinates: gameResponse.data.boxCoordinates, + playerCoordinates: gameResponse.data.playerCoordinates, + reasoning, }; };