Skip to content

Commit

Permalink
Add llm integration
Browse files Browse the repository at this point in the history
Signed-off-by: Jay Wang <[email protected]>
  • Loading branch information
xiaohk committed Feb 7, 2024
1 parent 459860e commit 94eb7a3
Show file tree
Hide file tree
Showing 16 changed files with 1,532 additions and 7 deletions.
3 changes: 3 additions & 0 deletions examples/rag-playground/package.json
Original file line number Diff line number Diff line change
Expand Up @@ -13,12 +13,14 @@
},
"devDependencies": {
"@floating-ui/dom": "^1.6.1",
"@mlc-ai/web-llm": "^0.2.18",
"@types/d3-array": "^3.2.1",
"@types/d3-format": "^3.0.4",
"@types/d3-random": "^3.0.3",
"@types/d3-time-format": "^4.0.3",
"@types/flexsearch": "^0.7.6",
"@typescript-eslint/eslint-plugin": "^6.20.0",
"@webgpu/types": "^0.1.40",
"@xenova/transformers": "^2.14.2",
"@xiaohk/utils": "^0.0.6",
"d3-array": "^3.2.4",
Expand All @@ -34,6 +36,7 @@
"flexsearch": "^0.7.43",
"gh-pages": "^6.1.1",
"gpt-tokenizer": "^2.1.2",
"idb-keyval": "^6.2.1",
"lit": "^3.1.2",
"prettier": "^3.2.4",
"typescript": "^5.3.3",
Expand Down
169 changes: 168 additions & 1 deletion examples/rag-playground/src/components/playground/playground.ts
Original file line number Diff line number Diff line change
Expand Up @@ -2,8 +2,23 @@ import { LitElement, css, unsafeCSS, html, PropertyValues } from 'lit';
import { customElement, property, state, query } from 'lit/decorators.js';
import { unsafeHTML } from 'lit/directives/unsafe-html.js';
import { EmbeddingModel } from '../../workers/embedding';
import {
UserConfigManager,
UserConfig,
SupportedRemoteModel,
SupportedLocalModel,
supportedModelReverseLookup,
ModelFamily
} from './user-config';
import { textGenGpt } from '../../llms/gpt';
import { textGenMememo } from '../../llms/mememo-gen';
import { textGenGemini } from '../../llms/gemini';
import TextGenLocalWorkerInline from '../../llms/web-llm?worker&inline';

import type { TextGenMessage } from '../../llms/gpt';
import type { EmbeddingWorkerMessage } from '../../workers/embedding';
import type { MememoTextViewer } from '../text-viewer/text-viewer';
import type { TextGenLocalWorkerMessage } from '../../llms/web-llm';

import '../query-box/query-box';
import '../prompt-box/prompt-box';
Expand Down Expand Up @@ -35,6 +50,9 @@ const datasets: Record<Dataset, DatasetInfo> = {
}
};

// True when running under the Vite dev server
const DEV_MODE = import.meta.env.DEV;
// Cache LLM responses only in development; the leading `true` is presumably a
// manual toggle to disable caching without touching DEV_MODE — confirm intent
const USE_CACHE = true && DEV_MODE;

/**
* Playground element.
*
Expand Down Expand Up @@ -63,6 +81,18 @@ export class MememoPlayground extends LitElement {
@query('mememo-text-viewer')
textViewerComponent: MememoTextViewer | undefined | null;

  // Manager that persists user configuration and broadcasts changes back
  @state()
  userConfigManager: UserConfigManager;

  // Latest snapshot of the user config, assigned via the manager's callback
  @state()
  userConfig!: UserConfig;

  // Web worker that runs the local LLM off the main thread
  @property({ attribute: false })
  textGenLocalWorker: Worker;

  // Resolver for the pending local text-gen request; replaced each time a new
  // local inference starts (no-op placeholder until then)
  textGenLocalWorkerResolve = (
    value: TextGenMessage | PromiseLike<TextGenMessage>
  ) => {};

//==========================================================================||
// Lifecycle Methods ||
//==========================================================================||
Expand All @@ -76,6 +106,15 @@ export class MememoPlayground extends LitElement {
this.embeddingWorkerMessageHandler(e);
}
);

// Initialize the local llm worker
this.textGenLocalWorker = new TextGenLocalWorkerInline();

// Set up the user config store
const updateUserConfig = (userConfig: UserConfig) => {
this.userConfig = userConfig;
};
this.userConfigManager = new UserConfigManager(updateUserConfig);
}

/**
Expand Down Expand Up @@ -129,13 +168,28 @@ export class MememoPlayground extends LitElement {
//==========================================================================||
// Event Handlers ||
//==========================================================================||
  /**
   * Start extracting embeddings from the user query
   * @param e Event carrying the query text entered by the user
   */
  userQueryRunClickHandler(e: CustomEvent<string>) {
    this.userQuery = e.detail;

    // Extract embeddings for the user query
    this.getEmbedding([this.userQuery]);
  }

/**
* Run the prompt using external AI services or local LLM
* @param e Event
*/
promptRunClickHandler(e: CustomEvent<string>) {
const prompt = e.detail;

// Run the prompt
this._runPrompt(prompt);
}

  /**
   * Store the documents returned by the semantic search so they can be used
   * as context for the prompt.
   * @param e Event carrying the relevant document texts
   */
  semanticSearchFinishedHandler(e: CustomEvent<string[]>) {
    this.relevantDocuments = e.detail;
  }
Expand Down Expand Up @@ -164,6 +218,117 @@ export class MememoPlayground extends LitElement {
//==========================================================================||
// Private Helpers ||
//==========================================================================||
/**
* Run the given prompt using the preferred model
* @returns A promise of the prompt inference
*/
_runPrompt(curPrompt: string, temperature = 0.2) {
let runRequest: Promise<TextGenMessage>;

switch (this.userConfig.preferredLLM) {
case SupportedRemoteModel['gpt-3.5']: {
runRequest = textGenGpt(
this.userConfig.llmAPIKeys[ModelFamily.openAI],
'text-gen',
curPrompt,
temperature,
'gpt-3.5-turbo',
USE_CACHE
);
break;
}

case SupportedRemoteModel['gpt-4']: {
runRequest = textGenGpt(
this.userConfig.llmAPIKeys[ModelFamily.openAI],
'text-gen',
curPrompt,
temperature,
'gpt-4-1106-preview',
USE_CACHE
);
break;
}

case SupportedRemoteModel['gemini-pro']: {
runRequest = textGenGemini(
this.userConfig.llmAPIKeys[ModelFamily.google],
'text-gen',
curPrompt,
temperature,
USE_CACHE
);
break;
}

// case SupportedLocalModel['mistral-7b-v0.2']:
// case SupportedLocalModel['gpt-2']:
case SupportedLocalModel['phi-2']:
case SupportedLocalModel['llama-2-7b']:
case SupportedLocalModel['tinyllama-1.1b']: {
runRequest = new Promise<TextGenMessage>(resolve => {
this.textGenLocalWorkerResolve = resolve;
});
const message: TextGenLocalWorkerMessage = {
command: 'startTextGen',
payload: {
apiKey: '',
prompt: curPrompt,
requestID: '',
temperature: temperature
}
};
this.textGenLocalWorker.postMessage(message);
break;
}

case SupportedRemoteModel['gpt-3.5-free']: {
runRequest = textGenMememo(
'text-gen',
curPrompt,
temperature,
'gpt-3.5-free',
USE_CACHE
);
break;
}

default: {
console.error('Unknown case ', this.userConfig.preferredLLM);
runRequest = textGenMememo(
'text-gen',
curPrompt,
temperature,
'gpt-3.5-free',
USE_CACHE
);
}
}

runRequest.then(
message => {
switch (message.command) {
case 'finishTextGen': {
// Success
if (DEV_MODE) {
console.info(
`Finished running prompt with [${this.userConfig.preferredLLM}]`
);
console.info(message.payload.result);
}

const output = message.payload.result;
break;
}

case 'error': {
console.error(message.payload.message);
}
}
},
() => {}
);
}

//==========================================================================||
// Templates and Styles ||
Expand Down Expand Up @@ -198,7 +363,9 @@ export class MememoPlayground extends LitElement {
template=${promptTemplate[Dataset.Arxiv]}
userQuery=${this.userQuery}
.relevantDocuments=${this.relevantDocuments}
@runButtonClicked=${(e: CustomEvent<string>) => {}}
@runButtonClicked=${(e: CustomEvent<string>) => {
this.promptRunClickHandler(e);
}}
></mememo-prompt-box>
</div>
Expand Down
149 changes: 149 additions & 0 deletions examples/rag-playground/src/components/playground/user-config.ts
Original file line number Diff line number Diff line change
@@ -0,0 +1,149 @@
import { get, set, del, clear } from 'idb-keyval';

const PREFIX = 'user-config';

/**
 * Local (in-browser) LLMs that can be run through the local text-gen worker.
 * Enum values are the human-readable display names.
 */
export enum SupportedLocalModel {
  'llama-2-7b' = 'Llama 2 (7B)',
  // 'mistral-7b-v0.2' = 'Mistral (7B)',
  'phi-2' = 'Phi 2 (2.7B)',
  'tinyllama-1.1b' = 'TinyLlama (1.1B)'
  // 'gpt-2' = 'GPT 2 (124M)'
}

/**
 * Remote (API-backed) LLMs available in the playground.
 * Enum values are the human-readable display names.
 */
export enum SupportedRemoteModel {
  'gpt-3.5-free' = 'GPT 3.5 (free)',
  'gpt-3.5' = 'GPT 3.5',
  'gpt-4' = 'GPT 4',
  'gemini-pro' = 'Gemini Pro'
}

/**
 * Maps a model's display name (the enum value) back to its enum key, e.g.
 * 'GPT 3.5' -> 'gpt-3.5'. Must stay in sync with the enums above.
 */
export const supportedModelReverseLookup: Record<
  SupportedRemoteModel | SupportedLocalModel,
  keyof typeof SupportedRemoteModel | keyof typeof SupportedLocalModel
> = {
  [SupportedRemoteModel['gpt-3.5-free']]: 'gpt-3.5-free',
  [SupportedRemoteModel['gpt-3.5']]: 'gpt-3.5',
  [SupportedRemoteModel['gpt-4']]: 'gpt-4',
  [SupportedRemoteModel['gemini-pro']]: 'gemini-pro',
  [SupportedLocalModel['tinyllama-1.1b']]: 'tinyllama-1.1b',
  [SupportedLocalModel['llama-2-7b']]: 'llama-2-7b',
  [SupportedLocalModel['phi-2']]: 'phi-2'
  // [SupportedLocalModel['gpt-2']]: 'gpt-2'
  // [SupportedLocalModel['mistral-7b-v0.2']]: 'mistral-7b-v0.2'
};

/**
 * Provider families that models belong to; each family has its own API key
 * slot in the user config.
 */
export enum ModelFamily {
  google = 'Google',
  openAI = 'Open AI',
  local = 'Local'
}

/**
 * Maps each supported model to its provider family, used to pick the right
 * API key for a request. Must cover every enum member above.
 */
export const modelFamilyMap: Record<
  SupportedRemoteModel | SupportedLocalModel,
  ModelFamily
> = {
  [SupportedRemoteModel['gpt-3.5']]: ModelFamily.openAI,
  [SupportedRemoteModel['gpt-3.5-free']]: ModelFamily.openAI,
  [SupportedRemoteModel['gpt-4']]: ModelFamily.openAI,
  [SupportedRemoteModel['gemini-pro']]: ModelFamily.google,
  [SupportedLocalModel['tinyllama-1.1b']]: ModelFamily.local,
  [SupportedLocalModel['llama-2-7b']]: ModelFamily.local,
  // [SupportedLocalModel['gpt-2']]: ModelFamily.local
  // [SupportedLocalModel['mistral-7b-v0.2']]: ModelFamily.local
  [SupportedLocalModel['phi-2']]: ModelFamily.local
};

/**
 * Persisted user configuration: one API key per model family and the
 * currently selected model.
 */
export interface UserConfig {
  // API key for each provider family (empty string when unset)
  llmAPIKeys: Record<ModelFamily, string>;
  // The model used for prompt inference
  preferredLLM: SupportedRemoteModel | SupportedLocalModel;
}

/**
 * Manages the user configuration (API keys and preferred model), persists it
 * to IndexedDB via idb-keyval, and broadcasts every change to the owner
 * component through the provided callback.
 */
export class UserConfigManager {
  /** Resolves once the persisted config (if any) has been restored. */
  restoreFinished: Promise<void>;

  /** Callback invoked with a fresh config snapshot on every change. */
  updateUserConfig: (userConfig: UserConfig) => void;

  // Private state; exposed to consumers only as copies via _constructConfig
  #llmAPIKeys: Record<ModelFamily, string>;
  #preferredLLM: SupportedRemoteModel | SupportedLocalModel;

  constructor(updateUserConfig: (userConfig: UserConfig) => void) {
    this.updateUserConfig = updateUserConfig;

    // Start from defaults (no keys, free model) and broadcast immediately so
    // the UI has a valid config before storage is read
    this.#llmAPIKeys = {
      [ModelFamily.openAI]: '',
      [ModelFamily.google]: '',
      [ModelFamily.local]: ''
    };
    this.#preferredLLM = SupportedRemoteModel['gpt-3.5-free'];
    this._broadcastUserConfig();

    this.restoreFinished = this._restoreFromStorage();

    // this._cleanStorage();
  }

  /**
   * Set the API key for a model family, persist it, and broadcast the change.
   * @param modelFamily Provider family the key belongs to
   * @param key The API key value
   */
  setAPIKey(modelFamily: ModelFamily, key: string) {
    this.#llmAPIKeys[modelFamily] = key;
    // Persistence is best-effort: log failures rather than crashing the UI
    void this._syncStorage().catch((error: unknown) => {
      console.error('Failed to persist user config', error);
    });
    this._broadcastUserConfig();
  }

  /**
   * Set the preferred model, persist it, and broadcast the change.
   * @param model The model to use for prompt inference
   */
  setPreferredLLM(model: SupportedRemoteModel | SupportedLocalModel) {
    this.#preferredLLM = model;
    // Persistence is best-effort: log failures rather than crashing the UI
    void this._syncStorage().catch((error: unknown) => {
      console.error('Failed to persist user config', error);
    });
    this._broadcastUserConfig();
  }

  /**
   * Reconstruct the config from local storage (if one was saved before) and
   * broadcast the result.
   */
  async _restoreFromStorage() {
    // Restore the local prompts
    const config = (await get(PREFIX)) as UserConfig | undefined;
    if (config) {
      this.#llmAPIKeys = config.llmAPIKeys;
      this.#preferredLLM = config.preferredLLM;
    }
    this._broadcastUserConfig();
  }

  /**
   * Store the current config to local storage
   */
  async _syncStorage() {
    const config = this._constructConfig();
    await set(PREFIX, config);
  }

  /**
   * Create a copy of the user config. The API key record is shallow-copied so
   * callers cannot mutate the manager's private state through the snapshot.
   * @returns User config
   */
  _constructConfig(): UserConfig {
    const config: UserConfig = {
      llmAPIKeys: { ...this.#llmAPIKeys },
      preferredLLM: this.#preferredLLM
    };
    return config;
  }

  /**
   * Clean the local storage
   */
  async _cleanStorage() {
    await del(PREFIX);
  }

  /**
   * Update the public user config
   */
  _broadcastUserConfig() {
    const newConfig = this._constructConfig();
    this.updateUserConfig(newConfig);
  }
}
Loading

0 comments on commit 94eb7a3

Please sign in to comment.