Skip to content

Add chat tab #75

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 7 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 4 additions & 0 deletions app/package.json
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
"@parcel/compressor-brotli": "^2.8.3",
"@parcel/compressor-gzip": "^2.8.3",
"@parcel/config-default": "^2.8.3",
"@tailwindcss/typography": "^0.5.9",
"@types/chroma-js": "^2.4.0",
"@types/react": "^18.0.27",
"@types/react-dom": "^18.0.10",
Expand Down Expand Up @@ -48,12 +49,15 @@
"match-sorter": "^6.3.1",
"react": "^18.2.0",
"react-dom": "^18.2.0",
"react-hotkeys-hook": "^4.4.0",
"react-markdown": "^8.0.6",
"react-responsive": "^9.0.2",
"react-router-dom": "^6.8.1",
"react-scroll": "^1.8.9",
"react-select": "^5.7.0",
"react-tiny-popover": "^7.2.3",
"react-transition-group": "^4.4.5",
"remark-gfm": "^3.0.1",
"sort-by": "^1.2.0",
"sse.js": "^0.6.1",
"tailwind-merge": "^1.9.0",
Expand Down
112 changes: 64 additions & 48 deletions app/src/app.tsx
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
import React, { useEffect } from "react"
import {Playground, Compare, Settings} from "./pages"
import {Playground, Compare, Chat, Settings} from "./pages"
import {SSE} from "sse.js"
import {
EditorState,
Expand Down Expand Up @@ -56,6 +56,12 @@ const DEFAULT_CONTEXTS = {
showParametersTable: false
}
},
chat:{
history: DEFAULT_HISTORY_STATE,
editor: DEFAULT_EDITOR_STATE,
modelsState: [],
parameters: DEFAULT_PARAMETERS_STATE
},
},
MODELS: [],
}
Expand Down Expand Up @@ -155,7 +161,7 @@ const APIContextWrapper = ({children}) => {

useEffect(() => {
const sse_request = new SSE("/api/notifications")

sse_request.addEventListener("notification", (event: any) => {
const parsedEvent = JSON.parse(event.data);
notificationSubscribers.current.forEach((callback) => {
Expand All @@ -180,15 +186,15 @@ const APIContextWrapper = ({children}) => {
notificationSubscribers.current = notificationSubscribers.current.filter((cb) => cb !== callback);
},
};

const Provider = {
setAPIKey: async (provider, apiKey) => (await fetch(`/api/provider/${provider}/api-key`, {method: "PUT", headers: {"Content-Type": "application/json"},
setAPIKey: async (provider, apiKey) => (await fetch(`/api/provider/${provider}/api-key`, {method: "PUT", headers: {"Content-Type": "application/json"},
body: JSON.stringify({apiKey: apiKey})}
)).json(),
getAll: async () => (await fetch("/api/providers")).json(),
getAllWithModels: async () => (await fetch("/api/providers-with-key-and-models")).json(),
};

const Inference = {
subscribeTextCompletion: (callback) => {
textCompletionSubscribers.current.push(callback);
Expand All @@ -205,14 +211,14 @@ const APIContextWrapper = ({children}) => {
},
chatCompletion: createChatCompletionRequest,
};

const [apiContext, _] = React.useState({
Model,
Notifications,
Provider,
Inference,
});

function createTextCompletionRequest({prompt, models}) {
const url = "/api/inference/text/stream";
const payload = {
Expand All @@ -221,30 +227,30 @@ const APIContextWrapper = ({children}) => {
};
return createCompletionRequest(url, payload, textCompletionSubscribers);
}

function createChatCompletionRequest(prompt, model) {
const url = "/api/inference/chat/stream";
const payload = {prompt, model};
return createCompletionRequest(url, payload, chatCompletionSubscribers);
}

function createCompletionRequest(url, payload, subscribers) {
pendingCompletionRequest.current = true;
let sse_request = null;

function beforeUnloadHandler() {
if (sse_request) sse_request.close();
}

window.addEventListener("beforeunload", beforeUnloadHandler);
const completionsBuffer = createCompletionsBuffer(payload.models);
let error_occured = false;
let request_complete = false;

sse_request = new SSE(url, {payload: JSON.stringify(payload)});

bindSSEEvents(sse_request, completionsBuffer, {error_occured, request_complete}, beforeUnloadHandler, subscribers);

return () => {
if (sse_request) sse_request.close();
};
Expand All @@ -257,52 +263,52 @@ const APIContextWrapper = ({children}) => {
});
return buffer;
}

function bindSSEEvents(sse_request, completionsBuffer, requestState, beforeUnloadHandler, subscribers) {
sse_request.onopen = async () => {
bulkWrite(completionsBuffer, requestState, subscribers);
};

sse_request.addEventListener("infer", (event) => {
let resp = JSON.parse(event.data);
completionsBuffer[resp.modelTag].push(resp);
});

sse_request.addEventListener("status", (event) => {
subscribers.current.forEach((callback) => callback({
event: "status",
data: JSON.parse(event.data)
}));
});

sse_request.addEventListener("error", (event) => {
requestState.error_occured = true;
try {
const message = JSON.parse(event.data);

subscribers.current.forEach((callback) => callback({
"event": "error",
"data": message.status
"data": message.status
}));
} catch (e) {
subscribers.current.forEach((callback) => callback({
"event": "error",
"data": "Unknown error"
}));
}

close_sse(sse_request, requestState, beforeUnloadHandler, subscribers);
});

sse_request.addEventListener("abort", () => {
requestState.error_occured = true;
close_sse(sse_request, requestState, beforeUnloadHandler, subscribers);
});

sse_request.addEventListener("readystatechange", (event) => {
if (event.readyState === 2) close_sse(sse_request, requestState, beforeUnloadHandler, subscribers);
});

sse_request.stream();
}

Expand All @@ -313,27 +319,27 @@ const APIContextWrapper = ({children}) => {
"meta": {error: requestState.error_occured},
}));
window.removeEventListener("beforeunload", beforeUnloadHandler);
}
}

function bulkWrite(completionsBuffer, requestState, subscribers) {
setTimeout(() => {
let newTokens = false;
let batchUpdate = {};

for (let modelTag in completionsBuffer) {
if (completionsBuffer[modelTag].length > 0) {
newTokens = true;
batchUpdate[modelTag] = completionsBuffer[modelTag].splice(0, completionsBuffer[modelTag].length);
}
}

if (newTokens) {
subscribers.current.forEach((callback) => callback({
event: "completion",
data: batchUpdate,
}));
}

if (!requestState.request_complete) bulkWrite(completionsBuffer, requestState, subscribers);
}, 20);
}
Expand All @@ -347,20 +353,19 @@ const APIContextWrapper = ({children}) => {

const PlaygroundContextWrapper = ({page, children}) => {
const apiContext = React.useContext(APIContext)

const [editorContext, _setEditorContext] = React.useState(DEFAULT_CONTEXTS.PAGES[page].editor);
const [parametersContext, _setParametersContext] = React.useState(DEFAULT_CONTEXTS.PAGES[page].parameters);
let [modelsStateContext, _setModelsStateContext] = React.useState(DEFAULT_CONTEXTS.PAGES[page].modelsState);
const [modelsContext, _setModelsContext] = React.useState(DEFAULT_CONTEXTS.MODELS);
const [historyContext, _setHistoryContext] = React.useState(DEFAULT_CONTEXTS.PAGES[page].history);
const [historyContext, setHistoryContext] = React.useState(DEFAULT_CONTEXTS.PAGES[page].history);

/* Temporary fix for models that have been purged remotely but are still cached locally */
for(const {name} of modelsStateContext) {
if (!modelsContext[name]) {
modelsStateContext = modelsStateContext.filter(({name: _name}) => _name !== name)
}
}

const editorContextRef = React.useRef(editorContext);
const historyContextRef = React.useRef(historyContext);

Expand Down Expand Up @@ -399,7 +404,7 @@ const PlaygroundContextWrapper = ({page, children}) => {
break;
}
}

apiContext.Notifications.subscribe(notificationCallback)

return () => {
Expand All @@ -410,9 +415,9 @@ const PlaygroundContextWrapper = ({page, children}) => {
const updateModelsData = async () => {
const json_params = await apiContext.Model.getAllEnabled()
const models = {};

const PAGE_MODELS_STATE = SETTINGS.pages[page].modelsState;

for (const [model_key, modelDetails] of Object.entries(json_params)) {
const existingModelEntry = (PAGE_MODELS_STATE.find((model) => model.name === model_key));

Expand Down Expand Up @@ -448,7 +453,7 @@ const PlaygroundContextWrapper = ({page, children}) => {
provider: modelDetails.provider,
}
}

const SERVER_SIDE_MODELS = Object.keys(json_params);
for (const {name} of PAGE_MODELS_STATE) {
if (!SERVER_SIDE_MODELS.includes(name)) {
Expand All @@ -465,7 +470,7 @@ const PlaygroundContextWrapper = ({page, children}) => {
const setEditorContext = (newEditorContext, immediate=false) => {
SETTINGS.pages[page].editor = {...SETTINGS.pages[page].editor, ...newEditorContext};

const _editor = {...SETTINGS.pages[page].editor, internalState: null };
const _editor = {...SETTINGS.pages[page].editor, internalState: null};

_setEditorContext(_editor);
if (immediate) {
Expand All @@ -485,14 +490,14 @@ const PlaygroundContextWrapper = ({page, children}) => {

const setModelsContext = (newModels) => {
SETTINGS.models = newModels;

debouncedSettingsSave()
_setModelsContext(newModels);
}

const setModelsStateContext = (newModelsState) => {
SETTINGS.pages[page].modelsState = newModelsState;

debouncedSettingsSave()
_setModelsStateContext(newModelsState);
}
Expand All @@ -503,7 +508,7 @@ const PlaygroundContextWrapper = ({page, children}) => {
show: (value === undefined || value === null) ? !SETTINGS.pages[page].history.show : value
}

_setHistoryContext(_newHistory);
setHistoryContext(_newHistory);

SETTINGS.pages[page].history = _newHistory;
debouncedSettingsSave()
Expand All @@ -518,7 +523,7 @@ const PlaygroundContextWrapper = ({page, children}) => {
const year = currentDate.getFullYear();
const month = String(currentDate.getMonth() + 1).padStart(2, '0');
const day = String(currentDate.getDate()).padStart(2, '0');

const newEntry = {
timestamp: currentDate.getTime(),
date: `${year}-${month}-${day}`,
Expand All @@ -536,7 +541,7 @@ const PlaygroundContextWrapper = ({page, children}) => {
current: newEntry
}

_setHistoryContext(_newHistory);
setHistoryContext(_newHistory);

//console.warn("Adding to history", _newHistory)
SETTINGS.pages[page].history = _newHistory;
Expand All @@ -549,7 +554,7 @@ const PlaygroundContextWrapper = ({page, children}) => {
entries: SETTINGS.pages[page].history.entries.filter((historyEntry) => historyEntry !== entry)
}

_setHistoryContext(_newHistory);
setHistoryContext(_newHistory);

SETTINGS.pages[page].history = _newHistory;
debouncedSettingsSave()
Expand All @@ -562,7 +567,7 @@ const PlaygroundContextWrapper = ({page, children}) => {
current: null
}

_setHistoryContext(_newHistory);
setHistoryContext(_newHistory);

SETTINGS.pages[page].history = _newHistory;
debouncedSettingsSave()
Expand All @@ -572,7 +577,7 @@ const PlaygroundContextWrapper = ({page, children}) => {
SETTINGS.pages[page].history.current = entry;
_setEditorContext(entry.editor);

_setHistoryContext(SETTINGS.pages[page].history);
setHistoryContext(SETTINGS.pages[page].history);
setParametersContext(entry.parameters);
setModelsStateContext(entry.modelsState);
}
Expand All @@ -583,8 +588,8 @@ const PlaygroundContextWrapper = ({page, children}) => {

return (
<HistoryContext.Provider value = {{
historyContext, selectHistoryItem,
addHistoryEntry, removeHistoryEntry, clearHistory, toggleShowHistory
historyContext, setHistoryContext, selectHistoryItem,
addHistoryEntry, removeHistoryEntry, clearHistory, toggleShowHistory,
}}>
<EditorContext.Provider value = {{editorContext, setEditorContext}}>
<ParametersContext.Provider value = {{parametersContext, setParametersContext}}>
Expand Down Expand Up @@ -626,6 +631,17 @@ function ProviderWithRoutes() {
}
/>

<Route
path="/chat"
element={
<APIContextWrapper>
<PlaygroundContextWrapper key = "chat" page = "chat">
<Chat/>
</PlaygroundContextWrapper>
</APIContextWrapper>
}
/>

<Route
path="/settings"
element={
Expand All @@ -645,4 +661,4 @@ export default function App() {
<ProviderWithRoutes />
</BrowserRouter>
)
}
}
Loading