-
Notifications
You must be signed in to change notification settings - Fork 1
/
Copy pathopenai.js
129 lines (109 loc) · 3.96 KB
/
openai.js
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
// HTTP endpoint that callGPT4 sends the prompt, model, temperature and OpenAI key to
// (as query parameters); presumably a proxy that forwards the request to OpenAI.
const ENDPOINT_URL = 'https://us-central1-musemuse.cloudfunctions.net/openai_complete';
import {DEFAULT_TEMPERATURE} from './constants.js';
// const ENDPOINT_URL = 'http://localhost:8080';
/**
 * Asks GPT-4 to list the entities that appear in `groundingText`.
 *
 * @param {string} groundingText - Text to extract entities from.
 * @param {number} [entityCount=5] - Number of entities requested in the prompt.
 *     The result is not truncated or padded to this length.
 * @returns {Promise<string[]>} One entry per non-blank response line, with its
 *     list marker removed by cleanupEntity.
 * @throws {Error} If the GPT call fails or a line is in an unrecognized format.
 */
export async function extractEntities(groundingText, entityCount = 5) {
  const prompt = generateExtractEntitiesPrompt(groundingText, entityCount);
  const gptRes = await callGPT4(prompt);
  // The response is parsed as one entity per line. Tolerate CRLF line endings
  // and surrounding whitespace, and skip blank lines (e.g. a trailing newline)
  // instead of handing them to cleanupEntity, which would reject them.
  return gptRes
    .split(/\r?\n/)
    .map((line) => line.trim())
    .filter((line) => line !== '')
    .map((line) => cleanupEntity(line));
}
/**
 * Asks GPT-4 whether `groundingText` suggests that more `entity1` causes more
 * `entity2` (or, with `isOpposite`, less `entity2`).
 *
 * @param {string} groundingText - Text the judgement must be based on.
 * @param {string} entity1 - Candidate cause.
 * @param {string} entity2 - Candidate effect.
 * @param {{isOpposite?: boolean}} [options] - Set `isOpposite` to test for a
 *     negative (more -> less) link instead of a positive one.
 * @returns {Promise<boolean>} The model's true/false answer. Links judged true
 *     are also logged to the console.
 * @throws {Error} If the GPT call fails or the answer is neither true nor false.
 */
export async function isCausalLink(groundingText, entity1, entity2, {isOpposite = false} = {}) {
  const prompt = isOpposite
    ? generateCausalNegativeLinkPrompt(groundingText, entity1, entity2)
    : generateCausalLinkPrompt(groundingText, entity1, entity2);
  const gptRes = await callGPT4(prompt);
  // The prompt asks for a bare "true"/"false", but tolerate surrounding
  // whitespace, capitalisation, quotes and a trailing period ("True.", "\"false\"").
  const answer = gptRes.trim().toLowerCase().replace(/^["']+|["'.]+$/g, '');
  let res;
  if (answer === 'true') {
    res = true;
  } else if (answer === 'false') {
    res = false;
  } else {
    throw new Error(`Could not parse response into boolean: "${gptRes}"`);
  }
  if (res) {
    console.log(`isCausalLink${isOpposite ? ' (opposite)' : ''}: ${entity1} --> ${entity2}`);
  }
  return res;
}
/**
 * Asks GPT-4 to explain why `groundingText` suggests that more `entity1`
 * causes more `entity2` (or less `entity2` when `isOpposite` is truthy).
 *
 * @param {string} groundingText - Text the explanation must be based on.
 * @param {string} entity1 - Cause entity.
 * @param {string} entity2 - Effect entity.
 * @param {boolean} [isOpposite=false] - Explain a negative (more -> less) link.
 * @returns {Promise<string>} The model's response, unmodified.
 */
export async function explainCausalLink(groundingText, entity1, entity2, isOpposite = false) {
  return await callGPT4(
    generateExplainLinkPrompt(groundingText, entity1, entity2, isOpposite),
  );
}
/**
 * Sends `prompt` to the GPT-4 endpoint (ENDPOINT_URL) and returns the response body text.
 *
 * The OpenAI key is read from `localStorage` under `'openai_key'`.
 *
 * @param {string} prompt - Prompt text; must be non-empty.
 * @param {{verbose?: boolean}} [options] - When `verbose` is true (the default),
 *     the full prompt and response are logged. Otherwise only the response
 *     length and request latency are logged.
 * @returns {Promise<string>} The response body as text.
 * @throws {Error} If `prompt` is empty, no key is stored (an alert is shown
 *     first), or the endpoint answers with a non-200 status (an alert is shown
 *     first; the status and body are included in the error message).
 */
export async function callGPT4(prompt, {verbose = true} = {}) {
  if (!prompt) {
    throw new Error(`Prompt required.`);
  }
  const openAiKey = localStorage.getItem('openai_key');
  if (!openAiKey) {
    alert(`Please paste your OpenAI key into the first input element.`);
    throw new Error(`No OpenAI key specified.`);
  }
  const request = {
    openai_key: openAiKey,
    model: 'gpt-4',
    temperature: DEFAULT_TEMPERATURE,
    prompt,
  };
  if (verbose) {
    console.log(`[GPT] Request prompt "${prompt}"`);
  }
  // NOTE(review): the API key travels in the query string, where it can end up
  // in server/proxy logs and browser history. Move it to a POST body if the
  // endpoint supports one.
  const url = new URL(ENDPOINT_URL);
  url.search = new URLSearchParams(request).toString();
  const start = performance.now();
  const res = await fetch(url);
  const duration = performance.now() - start;
  if (res.status !== 200) {
    alert(`Error calling GPT. See console.`);
    // `res.text` is a method returning a Promise; the body must be awaited to
    // appear in the message (interpolating the method printed its source).
    const body = await res.text().catch(() => '<unreadable body>');
    throw new Error(`Error calling GPT: ${res.status} ${body}`);
  }
  const text = await res.text();
  if (verbose) {
    console.log(`[GPT] Response text "${text}"`);
  } else {
    console.log(`[GPT] Response: ${text.length} chars. Took ${Math.floor(duration)} ms.`);
  }
  return text;
}
/**
 * Builds the prompt that asks the model to list entities found in a text.
 *
 * The prompt ends with a lone "-", presumably to prime the model into
 * answering with a dash-bulleted list.
 *
 * @param {string} groundingText - Text the entities are extracted from.
 * @param {number} entityCount - Number of entities the prompt asks for.
 * @returns {string} The three-line prompt.
 */
function generateExtractEntitiesPrompt(groundingText, entityCount) {
  const lines = [
    `Text: ${groundingText}`,
    `The following ${entityCount} entities appear in the text above:`,
    '-',
  ];
  return lines.join('\n');
}
/**
 * Builds a true/false prompt asking whether the text suggests that more
 * `entity1` causes more `entity2`.
 *
 * @param {string} groundingText - Text the judgement must be based on.
 * @param {string} entity1 - Candidate cause.
 * @param {string} entity2 - Candidate effect.
 * @returns {string} The two-line prompt.
 */
function generateCausalLinkPrompt(groundingText, entity1, entity2) {
  const header = `Text: ${groundingText}`;
  const question = `The text above suggests that more ${entity1} causes more ${entity2}. Answer one of "true" or "false".`;
  return `${header}\n${question}`;
}
/**
 * Builds a true/false prompt asking whether the text suggests that more
 * `entity1` causes less `entity2` (a negative causal link).
 *
 * @param {string} groundingText - Text the judgement must be based on.
 * @param {string} entity1 - Candidate cause.
 * @param {string} entity2 - Candidate effect.
 * @returns {string} The two-line prompt.
 */
function generateCausalNegativeLinkPrompt(groundingText, entity1, entity2) {
  const header = `Text: ${groundingText}`;
  const question = `The text above suggests that more ${entity1} causes less ${entity2}. Answer one of "true" or "false".`;
  return [header, question].join('\n');
}
/**
 * Builds the prompt asking the model to briefly explain a causal link.
 *
 * @param {string} groundingText - Text the explanation must be based on.
 * @param {string} entity1 - Cause entity.
 * @param {string} entity2 - Effect entity.
 * @param {boolean} isOpposite - Truthy for "more -> less", falsy for "more -> more".
 * @returns {string} The two-line prompt.
 */
function generateExplainLinkPrompt(groundingText, entity1, entity2, isOpposite) {
  let direction = 'more';
  if (isOpposite) {
    direction = 'less';
  }
  const header = `Text: ${groundingText}`;
  const request = `The text above suggests that more ${entity1} causes ${direction} ${entity2}. Explain why in fewer than ten words.`;
  return [header, request].join('\n');
}
/**
 * Strips the list marker from one line of an entity-list response.
 *
 * Accepted formats:
 * - bulleted: "- Bar", "* Bar", "• Bar";
 * - numbered: "30. Foo", "2) Foo";
 * - unprefixed, starting with a capital letter: "Baz".
 *
 * Surrounding whitespace, including a trailing "\r" left by CRLF line endings,
 * is ignored.
 *
 * @param {string} ent - A single line of the model's response.
 * @returns {string} The entity text without its list marker, trimmed.
 * @throws {Error} If the line matches none of the accepted formats.
 */
function cleanupEntity(ent) {
  console.info('cleanupEntity', ent);
  const text = ent.trim();
  // A bullet or "N." / "N)" marker must be followed by whitespace (or end the
  // line), so values such as "3.5 GPA" or "-5 degrees" are not treated as list items.
  const listItem = text.match(/^(?:[-*•]|\d+[.)])(?:\s+|$)(.*)$/);
  if (listItem !== null) {
    return listItem[1].trim();
  }
  // With no prefix, only accept text that starts with a capital letter.
  if (/^[A-Z]/.test(text)) {
    return text;
  }
  throw new Error(`Entity not in recognized format: "${ent}"`);
}