diff --git a/app/api/chat/config/llamacloud/route.ts b/app/api/chat/config/llamacloud/route.ts new file mode 100644 index 0000000..7333947 --- /dev/null +++ b/app/api/chat/config/llamacloud/route.ts @@ -0,0 +1,24 @@ +import { LLamaCloudFileService } from "llamaindex"; +import { NextResponse } from "next/server"; + +/** + * This API is to get config from the backend envs and expose them to the frontend + */ +export async function GET() { + if (!process.env.LLAMA_CLOUD_API_KEY) { + return NextResponse.json( + { + error: "env variable LLAMA_CLOUD_API_KEY is required to use LlamaCloud", + }, + { status: 500 }, + ); + } + const config = { + projects: await LLamaCloudFileService.getAllProjectsWithPipelines(), + pipeline: { + pipeline: process.env.LLAMA_CLOUD_INDEX_NAME, + project: process.env.LLAMA_CLOUD_PROJECT_NAME, + }, + }; + return NextResponse.json(config, { status: 200 }); +} diff --git a/app/api/chat/config/route.ts b/app/api/chat/config/route.ts new file mode 100644 index 0000000..8d875e6 --- /dev/null +++ b/app/api/chat/config/route.ts @@ -0,0 +1,11 @@ +import { NextResponse } from "next/server"; + +/** + * This API is to get config from the backend envs and expose them to the frontend + */ +export async function GET() { + const config = { + starterQuestions: process.env.CONVERSATION_STARTERS?.trim().split("\n"), + }; + return NextResponse.json(config, { status: 200 }); +} diff --git a/app/api/chat/engine/chat.ts b/app/api/chat/engine/chat.ts index ef1dd5b..719d786 100644 --- a/app/api/chat/engine/chat.ts +++ b/app/api/chat/engine/chat.ts @@ -1,21 +1,36 @@ -import { ContextChatEngine, Settings } from "llamaindex"; +import { BaseChatEngine, BaseToolWithCall, LLMAgent } from "llamaindex"; +import fs from "node:fs/promises"; +import path from "node:path"; import { getDataSource } from "./index"; +import { createTools } from "./tools"; +import { createQueryEngineTool } from "./tools/query-engine"; -export async function createChatEngine() { - const index = await 
getDataSource(); - if (!index) { - throw new Error( - `StorageContext is empty - call 'npm run generate' to generate the storage first`, - ); +export async function createChatEngine(documentIds?: string[], params?: any) { + const tools: BaseToolWithCall[] = []; + + // Add a query engine tool if we have a data source + // Delete this code if you don't have a data source + const index = await getDataSource(params); + if (index) { + tools.push(createQueryEngineTool(index, { documentIds })); } - const retriever = index.asRetriever(); - retriever.similarityTopK = process.env.TOP_K - ? parseInt(process.env.TOP_K) - : 3; - return new ContextChatEngine({ - chatModel: Settings.llm, - retriever, + const configFile = path.join("config", "tools.json"); + let toolConfig: any; + try { + // add tools from config file if it exists + toolConfig = JSON.parse(await fs.readFile(configFile, "utf8")); + } catch (e) { + console.info(`Could not read ${configFile} file. Using no tools.`); + } + if (toolConfig) { + tools.push(...(await createTools(toolConfig))); + } + + const agent = new LLMAgent({ + tools, systemPrompt: process.env.SYSTEM_PROMPT, - }); + }) as unknown as BaseChatEngine; + + return agent; } diff --git a/app/api/chat/engine/generate.ts b/app/api/chat/engine/generate.ts index 8c16280..4647361 100644 --- a/app/api/chat/engine/generate.ts +++ b/app/api/chat/engine/generate.ts @@ -5,7 +5,6 @@ import * as dotenv from "dotenv"; import { getDocuments } from "./loader"; import { initSettings } from "./settings"; -import { STORAGE_CACHE_DIR } from "./shared"; // Load environment variables from local .env file dotenv.config(); @@ -20,11 +19,16 @@ async function getRuntime(func: any) { async function generateDatasource() { console.log(`Generating storage context...`); // Split documents, create embeddings and store them in the storage context + const persistDir = process.env.STORAGE_CACHE_DIR; + if (!persistDir) { + throw new Error("STORAGE_CACHE_DIR environment variable is 
required!"); + } const ms = await getRuntime(async () => { const storageContext = await storageContextFromDefaults({ - persistDir: STORAGE_CACHE_DIR, + persistDir, }); const documents = await getDocuments(); + await VectorStoreIndex.fromDocuments(documents, { storageContext, }); diff --git a/app/api/chat/engine/index.ts b/app/api/chat/engine/index.ts index 64b2897..d38ea60 100644 --- a/app/api/chat/engine/index.ts +++ b/app/api/chat/engine/index.ts @@ -1,10 +1,13 @@ import { SimpleDocumentStore, VectorStoreIndex } from "llamaindex"; import { storageContextFromDefaults } from "llamaindex/storage/StorageContext"; -import { STORAGE_CACHE_DIR } from "./shared"; -export async function getDataSource() { +export async function getDataSource(params?: any) { + const persistDir = process.env.STORAGE_CACHE_DIR; + if (!persistDir) { + throw new Error("STORAGE_CACHE_DIR environment variable is required!"); + } const storageContext = await storageContextFromDefaults({ - persistDir: `${STORAGE_CACHE_DIR}`, + persistDir, }); const numberOfDocs = Object.keys( diff --git a/app/api/chat/engine/loader.ts b/app/api/chat/engine/loader.ts index 3039f34..42a2e0a 100644 --- a/app/api/chat/engine/loader.ts +++ b/app/api/chat/engine/loader.ts @@ -1,9 +1,24 @@ -import { SimpleDirectoryReader } from "llamaindex"; +import { + FILE_EXT_TO_READER, + SimpleDirectoryReader, +} from "@llamaindex/readers/directory"; export const DATA_DIR = "./data"; +export function getExtractors() { + return FILE_EXT_TO_READER; +} + export async function getDocuments() { - return await new SimpleDirectoryReader().loadData({ + const documents = await new SimpleDirectoryReader().loadData({ directoryPath: DATA_DIR, }); + // Set private=false to mark the document as public (required for filtering) + for (const document of documents) { + document.metadata = { + ...document.metadata, + private: "false", + }; + } + return documents; } diff --git a/app/api/chat/engine/provider.ts b/app/api/chat/engine/provider.ts new file 
mode 100644 index 0000000..f833ebb --- /dev/null +++ b/app/api/chat/engine/provider.ts @@ -0,0 +1,37 @@ +import { OpenAI, OpenAIEmbedding } from "@llamaindex/openai"; +import { Settings } from "llamaindex"; +import { + DefaultAzureCredential, + getBearerTokenProvider, +} from "@azure/identity"; + +const AZURE_COGNITIVE_SERVICES_SCOPE = + "https://cognitiveservices.azure.com/.default"; + +export function setupProvider() { + const credential = new DefaultAzureCredential(); + const azureADTokenProvider = getBearerTokenProvider( + credential, + AZURE_COGNITIVE_SERVICES_SCOPE, + ); + + const azure = { + azureADTokenProvider, + deployment: process.env.AZURE_DEPLOYMENT_NAME ?? "gpt-35-turbo", + }; + + // configure LLM model + Settings.llm = new OpenAI({ + azure, + }) as any; + + // configure embedding model + azure.deployment = process.env.EMBEDDING_MODEL as string; + Settings.embedModel = new OpenAIEmbedding({ + azure, + model: process.env.EMBEDDING_MODEL, + dimensions: process.env.EMBEDDING_DIM + ? 
parseInt(process.env.EMBEDDING_DIM) + : undefined, + }); +} diff --git a/app/api/chat/engine/queryFilter.ts b/app/api/chat/engine/queryFilter.ts new file mode 100644 index 0000000..ee3bc1e --- /dev/null +++ b/app/api/chat/engine/queryFilter.ts @@ -0,0 +1,25 @@ +import { MetadataFilter, MetadataFilters } from "llamaindex"; + +export function generateFilters(documentIds: string[]): MetadataFilters { + // filter all documents have the private metadata key set to true + const publicDocumentsFilter: MetadataFilter = { + key: "private", + value: "true", + operator: "!=", + }; + + // if no documentIds are provided, only retrieve information from public documents + if (!documentIds.length) return { filters: [publicDocumentsFilter] }; + + const privateDocumentsFilter: MetadataFilter = { + key: "doc_id", + value: documentIds, + operator: "in", + }; + + // if documentIds are provided, retrieve information from public and private documents + return { + filters: [publicDocumentsFilter, privateDocumentsFilter], + condition: "or", + }; +} diff --git a/app/api/chat/engine/settings.ts b/app/api/chat/engine/settings.ts index 4b382de..ce2c3b9 100644 --- a/app/api/chat/engine/settings.ts +++ b/app/api/chat/engine/settings.ts @@ -1,106 +1,18 @@ -import { OpenAI, OpenAIEmbedding, Settings } from "llamaindex"; - -import { - DefaultAzureCredential, - getBearerTokenProvider, -} from "@azure/identity"; -import { OllamaEmbedding } from "llamaindex/embeddings/OllamaEmbedding"; -import { Ollama } from "llamaindex/llm/ollama"; +import { Settings } from "llamaindex"; +import { setupProvider } from "./provider"; const CHUNK_SIZE = 512; const CHUNK_OVERLAP = 20; -const AZURE_COGNITIVE_SERVICES_SCOPE = - "https://cognitiveservices.azure.com/.default"; export const initSettings = async () => { console.log(`Using '${process.env.MODEL_PROVIDER}' model provider`); - // if provider is OpenAI, MODEL must be set - if (process.env.MODEL_PROVIDER === 'openai' && process.env.OPENAI_API_TYPE !== 'AzureOpenAI' 
&& !process.env.MODEL) { - throw new Error("'MODEL' env variable must be set."); - } - - // if provider is Azure OpenAI, AZURE_DEPLOYMENT_NAME must be set - if (process.env.MODEL_PROVIDER === 'openai' && process.env.OPENAI_API_TYPE === 'AzureOpenAI' && !process.env.AZURE_DEPLOYMENT_NAME) { - throw new Error("'AZURE_DEPLOYMENT_NAME' env variables must be set."); + if (!process.env.MODEL || !process.env.EMBEDDING_MODEL) { + throw new Error("'MODEL' and 'EMBEDDING_MODEL' env variables must be set."); } - if (!process.env.EMBEDDING_MODEL) { - throw new Error("'EMBEDDING_MODEL' env variable must be set."); - } - - switch (process.env.MODEL_PROVIDER) { - case "ollama": - initOllama(); - break; - case "openai": - if (process.env.OPENAI_API_TYPE === "AzureOpenAI") { - await initAzureOpenAI(); - } else { - initOpenAI(); - } - break; - default: - throw new Error( - `Model provider '${process.env.MODEL_PROVIDER}' not supported.`, - ); - } Settings.chunkSize = CHUNK_SIZE; Settings.chunkOverlap = CHUNK_OVERLAP; -}; - -function initOpenAI() { - Settings.llm = new OpenAI({ - model: process.env.MODEL ?? "gpt-3.5-turbo", - maxTokens: 512, - }); - Settings.embedModel = new OpenAIEmbedding({ - model: process.env.EMBEDDING_MODEL, - dimensions: process.env.EMBEDDING_DIM - ? parseInt(process.env.EMBEDDING_DIM) - : undefined, - }); -} - -async function initAzureOpenAI() { - const credential = new DefaultAzureCredential(); - const azureADTokenProvider = getBearerTokenProvider( - credential, - AZURE_COGNITIVE_SERVICES_SCOPE, - ); - - const azure = { - azureADTokenProvider, - deployment: process.env.AZURE_OPENAI_DEPLOYMENT ?? "gpt-35-turbo", - }; - - // configure LLM model - Settings.llm = new OpenAI({ - azure, - }) as any; - - // configure embedding model - azure.deployment = process.env.EMBEDDING_MODEL as string; - Settings.embedModel = new OpenAIEmbedding({ - azure, - model: process.env.EMBEDDING_MODEL, - dimensions: process.env.EMBEDDING_DIM - ? 
parseInt(process.env.EMBEDDING_DIM) - : undefined, - }); -} - -function initOllama() { - const config = { - host: process.env.OLLAMA_BASE_URL ?? "http://127.0.0.1:11434", - }; - Settings.llm = new Ollama({ - model: process.env.MODEL ?? "", - config, - }); - Settings.embedModel = new OllamaEmbedding({ - model: process.env.EMBEDDING_MODEL ?? "", - config, - }); -} + setupProvider(); +}; diff --git a/app/api/chat/engine/shared.ts b/app/api/chat/engine/shared.ts deleted file mode 100644 index e7736e5..0000000 --- a/app/api/chat/engine/shared.ts +++ /dev/null @@ -1 +0,0 @@ -export const STORAGE_CACHE_DIR = "./cache"; diff --git a/app/api/chat/engine/tools/code-generator.ts b/app/api/chat/engine/tools/code-generator.ts new file mode 100644 index 0000000..610cf72 --- /dev/null +++ b/app/api/chat/engine/tools/code-generator.ts @@ -0,0 +1,146 @@ +import type { JSONSchemaType } from "ajv"; +import { + BaseTool, + ChatMessage, + JSONValue, + Settings, + ToolMetadata, +} from "llamaindex"; + +// prompt based on https://github.com/e2b-dev/ai-artifacts +const CODE_GENERATION_PROMPT = `You are a skilled software engineer. You do not make mistakes. Generate an artifact. You can install additional dependencies. You can use one of the following templates:\n + +1. code-interpreter-multilang: "Runs code as a Jupyter notebook cell. Strong data analysis angle. Can use complex visualisation to explain results.". File: script.py. Dependencies installed: python, jupyter, numpy, pandas, matplotlib, seaborn, plotly. Port: none. + +2. nextjs-developer: "A Next.js 13+ app that reloads automatically. Using the pages router.". File: pages/index.tsx. Dependencies installed: nextjs@14.2.5, typescript, @types/node, @types/react, @types/react-dom, postcss, tailwindcss, shadcn. Port: 3000. + +3. vue-developer: "A Vue.js 3+ app that reloads automatically. Only when asked specifically for a Vue app.". File: app.vue. Dependencies installed: vue@latest, nuxt@3.13.0, tailwindcss. Port: 3000. + +4. 
streamlit-developer: "A streamlit app that reloads automatically.". File: app.py. Dependencies installed: streamlit, pandas, numpy, matplotlib, request, seaborn, plotly. Port: 8501. + +5. gradio-developer: "A gradio app. Gradio Blocks/Interface should be called demo.". File: app.py. Dependencies installed: gradio, pandas, numpy, matplotlib, request, seaborn, plotly. Port: 7860. + +Provide detail information about the artifact you're about to generate in the following JSON format with the following keys: + +commentary: Describe what you're about to do and the steps you want to take for generating the artifact in great detail. +template: Name of the template used to generate the artifact. +title: Short title of the artifact. Max 3 words. +description: Short description of the artifact. Max 1 sentence. +additional_dependencies: Additional dependencies required by the artifact. Do not include dependencies that are already included in the template. +has_additional_dependencies: Detect if additional dependencies that are not included in the template are required by the artifact. +install_dependencies_command: Command to install additional dependencies required by the artifact. +port: Port number used by the resulted artifact. Null when no ports are exposed. +file_path: Relative path to the file, including the file name. +code: Code generated by the artifact. Only runnable code is allowed. + +Make sure to use the correct syntax for the programming language you're using. Make sure to generate only one code file. If you need to use CSS, make sure to include the CSS in the code file using Tailwind CSS syntax. 
+`; + +// detail information to execute code +export type CodeArtifact = { + commentary: string; + template: string; + title: string; + description: string; + additional_dependencies: string[]; + has_additional_dependencies: boolean; + install_dependencies_command: string; + port: number | null; + file_path: string; + code: string; + files?: string[]; +}; + +export type CodeGeneratorParameter = { + requirement: string; + oldCode?: string; + sandboxFiles?: string[]; +}; + +export type CodeGeneratorToolParams = { + metadata?: ToolMetadata>; +}; + +const DEFAULT_META_DATA: ToolMetadata> = + { + name: "artifact", + description: `Generate a code artifact based on the input. Don't call this tool if the user has not asked for code generation. E.g. if the user asks to write a description or specification, don't call this tool.`, + parameters: { + type: "object", + properties: { + requirement: { + type: "string", + description: "The description of the application you want to build.", + }, + oldCode: { + type: "string", + description: "The existing code to be modified", + nullable: true, + }, + sandboxFiles: { + type: "array", + description: + "A list of sandbox file paths. Include these files if the code requires them.", + items: { + type: "string", + }, + nullable: true, + }, + }, + required: ["requirement"], + }, + }; + +export class CodeGeneratorTool implements BaseTool { + metadata: ToolMetadata>; + + constructor(params?: CodeGeneratorToolParams) { + this.metadata = params?.metadata || DEFAULT_META_DATA; + } + + async call(input: CodeGeneratorParameter) { + try { + const artifact = await this.generateArtifact( + input.requirement, + input.oldCode, + input.sandboxFiles, // help the generated code use exact files + ); + if (input.sandboxFiles) { + artifact.files = input.sandboxFiles; + } + return artifact as JSONValue; + } catch (error) { + return { isError: true }; + } + } + + // Generate artifact (code, environment, dependencies, etc.) 
+ async generateArtifact( + query: string, + oldCode?: string, + attachments?: string[], + ): Promise { + const userMessage = ` + ${query} + ${oldCode ? `The existing code is: \n\`\`\`${oldCode}\`\`\`` : ""} + ${attachments ? `The attachments are: \n${attachments.join("\n")}` : ""} + `; + const messages: ChatMessage[] = [ + { role: "system", content: CODE_GENERATION_PROMPT }, + { role: "user", content: userMessage }, + ]; + try { + const response = await Settings.llm.chat({ messages }); + const content = response.message.content.toString(); + const jsonContent = content + .replace(/^```json\s*|\s*```$/g, "") + .replace(/^`+|`+$/g, "") + .trim(); + const artifact = JSON.parse(jsonContent) as CodeArtifact; + return artifact; + } catch (error) { + console.log("Failed to generate artifact", error); + throw error; + } + } +} diff --git a/app/api/chat/engine/tools/document-generator.ts b/app/api/chat/engine/tools/document-generator.ts new file mode 100644 index 0000000..b630db2 --- /dev/null +++ b/app/api/chat/engine/tools/document-generator.ts @@ -0,0 +1,142 @@ +import { JSONSchemaType } from "ajv"; +import { BaseTool, ToolMetadata } from "llamaindex"; +import { marked } from "marked"; +import path from "node:path"; +import { saveDocument } from "../../llamaindex/documents/helper"; + +const OUTPUT_DIR = "output/tools"; + +type DocumentParameter = { + originalContent: string; + fileName: string; +}; + +const DEFAULT_METADATA: ToolMetadata> = { + name: "document_generator", + description: + "Generate HTML document from markdown content. 
Return a file url to the document", + parameters: { + type: "object", + properties: { + originalContent: { + type: "string", + description: "The original markdown content to convert.", + }, + fileName: { + type: "string", + description: "The name of the document file (without extension).", + }, + }, + required: ["originalContent", "fileName"], + }, +}; + +const COMMON_STYLES = ` + body { + font-family: Arial, sans-serif; + line-height: 1.3; + color: #333; + } + h1, h2, h3, h4, h5, h6 { + margin-top: 1em; + margin-bottom: 0.5em; + } + p { + margin-bottom: 0.7em; + } + code { + background-color: #f4f4f4; + padding: 2px 4px; + border-radius: 4px; + } + pre { + background-color: #f4f4f4; + padding: 10px; + border-radius: 4px; + overflow-x: auto; + } + table { + border-collapse: collapse; + width: 100%; + margin-bottom: 1em; + } + th, td { + border: 1px solid #ddd; + padding: 8px; + text-align: left; + } + th { + background-color: #f2f2f2; + font-weight: bold; + } + img { + max-width: 90%; + height: auto; + display: block; + margin: 1em auto; + border-radius: 10px; + } +`; + +const HTML_SPECIFIC_STYLES = ` + body { + max-width: 800px; + margin: 0 auto; + padding: 20px; + } +`; + +const HTML_TEMPLATE = ` + + + + + + + + + {{content}} + + +`; + +export interface DocumentGeneratorParams { + metadata?: ToolMetadata>; +} + +export class DocumentGenerator implements BaseTool { + metadata: ToolMetadata>; + + constructor(params: DocumentGeneratorParams) { + this.metadata = params.metadata ?? 
DEFAULT_METADATA; + } + + private static async generateHtmlContent( + originalContent: string, + ): Promise { + return await marked(originalContent); + } + + private static generateHtmlDocument(htmlContent: string): string { + return HTML_TEMPLATE.replace("{{content}}", htmlContent); + } + + async call(input: DocumentParameter): Promise { + const { originalContent, fileName } = input; + + const htmlContent = + await DocumentGenerator.generateHtmlContent(originalContent); + const fileContent = DocumentGenerator.generateHtmlDocument(htmlContent); + + const filePath = path.join(OUTPUT_DIR, `${fileName}.html`); + + return `URL: ${await saveDocument(filePath, fileContent)}`; + } +} + +export function getTools(): BaseTool[] { + return [new DocumentGenerator({})]; +} diff --git a/app/api/chat/engine/tools/duckduckgo.ts b/app/api/chat/engine/tools/duckduckgo.ts new file mode 100644 index 0000000..419ff90 --- /dev/null +++ b/app/api/chat/engine/tools/duckduckgo.ts @@ -0,0 +1,78 @@ +import { JSONSchemaType } from "ajv"; +import { search } from "duck-duck-scrape"; +import { BaseTool, ToolMetadata } from "llamaindex"; + +export type DuckDuckGoParameter = { + query: string; + region?: string; + maxResults?: number; +}; + +export type DuckDuckGoToolParams = { + metadata?: ToolMetadata>; +}; + +const DEFAULT_SEARCH_METADATA: ToolMetadata< + JSONSchemaType +> = { + name: "duckduckgo_search", + description: + "Use this function to search for information (only text) in the internet using DuckDuckGo.", + parameters: { + type: "object", + properties: { + query: { + type: "string", + description: "The query to search in DuckDuckGo.", + }, + region: { + type: "string", + description: + "Optional, The region to be used for the search in [country-language] convention, ex us-en, uk-en, ru-ru, etc...", + nullable: true, + }, + maxResults: { + type: "number", + description: + "Optional, The maximum number of results to be returned. 
Default is 10.", + nullable: true, + }, + }, + required: ["query"], + }, +}; + +type DuckDuckGoSearchResult = { + title: string; + description: string; + url: string; +}; + +export class DuckDuckGoSearchTool implements BaseTool { + metadata: ToolMetadata>; + + constructor(params: DuckDuckGoToolParams) { + this.metadata = params.metadata ?? DEFAULT_SEARCH_METADATA; + } + + async call(input: DuckDuckGoParameter) { + const { query, region, maxResults = 10 } = input; + const options = region ? { region } : {}; + // Temporarily sleep to reduce overloading the DuckDuckGo + await new Promise((resolve) => setTimeout(resolve, 1000)); + + const searchResults = await search(query, options); + + return searchResults.results.slice(0, maxResults).map((result) => { + return { + title: result.title, + description: result.description, + url: result.url, + } as DuckDuckGoSearchResult; + }); + } +} + +export function getTools() { + return [new DuckDuckGoSearchTool({})]; +} diff --git a/app/api/chat/engine/tools/form-filling.ts b/app/api/chat/engine/tools/form-filling.ts new file mode 100644 index 0000000..6ac0a52 --- /dev/null +++ b/app/api/chat/engine/tools/form-filling.ts @@ -0,0 +1,296 @@ +import { JSONSchemaType } from "ajv"; +import fs from "fs"; +import { BaseTool, Settings, ToolMetadata } from "llamaindex"; +import Papa from "papaparse"; +import path from "path"; +import { saveDocument } from "../../llamaindex/documents/helper"; + +type ExtractMissingCellsParameter = { + filePath: string; +}; + +export type MissingCell = { + rowIndex: number; + columnIndex: number; + question: string; +}; + +const CSV_EXTRACTION_PROMPT = `You are a data analyst. You are given a table with missing cells. +Your task is to identify the missing cells and the questions needed to fill them. +IMPORTANT: Column indices should be 0-based + +# Instructions: +- Understand the entire content of the table and the topics of the table. +- Identify the missing cells and the meaning of the data in the cells. 
+- For each missing cell, provide the row index and the correct column index (remember: first data column is 1). +- For each missing cell, provide the question needed to fill the cell (it's important to provide the question that is relevant to the topic of the table). +- Since the cell's value should be concise, the question should request a numerical answer or a specific value. +- Finally, only return the answer in JSON format with the following schema: +{ + "missing_cells": [ + { + "rowIndex": number, + "columnIndex": number, + "question": string + } + ] +} +- If there are no missing cells, return an empty array. +- The answer is only the JSON object, nothing else and don't wrap it inside markdown code block. + +# Example: +# | | Name | Age | City | +# |----|------|-----|------| +# | 0 | John | | Paris| +# | 1 | Mary | | | +# | 2 | | 30 | | +# +# Your thoughts: +# - The table is about people's names, ages, and cities. +# - Row: 1, Column: 2 (Age column), Question: "How old is Mary? Please provide only the numerical answer." +# - Row: 1, Column: 3 (City column), Question: "In which city does Mary live? Please provide only the city name." +# Your answer: +# { +# "missing_cells": [ +# { +# "rowIndex": 1, +# "columnIndex": 2, +# "question": "How old is Mary? Please provide only the numerical answer." +# }, +# { +# "rowIndex": 1, +# "columnIndex": 3, +# "question": "In which city does Mary live? Please provide only the city name." +# } +# ] +# } + + +# Here is your task: + +- Table content: +{table_content} + +- Your answer: +`; + +const DEFAULT_METADATA: ToolMetadata< + JSONSchemaType +> = { + name: "extract_missing_cells", + description: `Use this tool to extract missing cells in a CSV file and generate questions to fill them. 
This tool only works with local file path.`, + parameters: { + type: "object", + properties: { + filePath: { + type: "string", + description: "The local file path to the CSV file.", + }, + }, + required: ["filePath"], + }, +}; + +export interface ExtractMissingCellsParams { + metadata?: ToolMetadata>; +} + +export class ExtractMissingCellsTool + implements BaseTool +{ + metadata: ToolMetadata>; + defaultExtractionPrompt: string; + + constructor(params: ExtractMissingCellsParams) { + this.metadata = params.metadata ?? DEFAULT_METADATA; + this.defaultExtractionPrompt = CSV_EXTRACTION_PROMPT; + } + + private readCsvFile(filePath: string): Promise { + return new Promise((resolve, reject) => { + fs.readFile(filePath, "utf8", (err, data) => { + if (err) { + reject(err); + return; + } + + const parsedData = Papa.parse(data, { + skipEmptyLines: false, + }); + + if (parsedData.errors.length) { + reject(parsedData.errors); + return; + } + + // Ensure all rows have the same number of columns as the header + const maxColumns = parsedData.data[0].length; + const paddedRows = parsedData.data.map((row) => { + return [...row, ...Array(maxColumns - row.length).fill("")]; + }); + + resolve(paddedRows); + }); + }); + } + + private formatToMarkdownTable(data: string[][]): string { + if (data.length === 0) return ""; + + const maxColumns = data[0].length; + + const headerRow = `| ${data[0].join(" | ")} |`; + const separatorRow = `| ${Array(maxColumns).fill("---").join(" | ")} |`; + + const dataRows = data.slice(1).map((row) => { + return `| ${row.join(" | ")} |`; + }); + + return [headerRow, separatorRow, ...dataRows].join("\n"); + } + + async call(input: ExtractMissingCellsParameter): Promise { + const { filePath } = input; + let tableContent: string[][]; + try { + tableContent = await this.readCsvFile(filePath); + } catch (error) { + throw new Error( + `Failed to read CSV file. 
Make sure that you are reading a local file path (not a sandbox path).`, + ); + } + + const prompt = this.defaultExtractionPrompt.replace( + "{table_content}", + this.formatToMarkdownTable(tableContent), + ); + + const llm = Settings.llm; + const response = await llm.complete({ + prompt, + }); + const rawAnswer = response.text; + const parsedResponse = JSON.parse(rawAnswer) as { + missing_cells: MissingCell[]; + }; + if (!parsedResponse.missing_cells) { + throw new Error( + "The answer is not in the correct format. There should be a missing_cells array.", + ); + } + const answer = parsedResponse.missing_cells; + + return answer; + } +} + +type FillMissingCellsParameter = { + filePath: string; + cells: { + rowIndex: number; + columnIndex: number; + answer: string; + }[]; +}; + +const FILL_CELLS_METADATA: ToolMetadata< + JSONSchemaType +> = { + name: "fill_missing_cells", + description: `Use this tool to fill missing cells in a CSV file with provided answers. This tool only works with local file path.`, + parameters: { + type: "object", + properties: { + filePath: { + type: "string", + description: "The local file path to the CSV file.", + }, + cells: { + type: "array", + items: { + type: "object", + properties: { + rowIndex: { type: "number" }, + columnIndex: { type: "number" }, + answer: { type: "string" }, + }, + required: ["rowIndex", "columnIndex", "answer"], + }, + description: "Array of cells to fill with their answers", + }, + }, + required: ["filePath", "cells"], + }, +}; + +export interface FillMissingCellsParams { + metadata?: ToolMetadata>; +} + +export class FillMissingCellsTool + implements BaseTool +{ + metadata: ToolMetadata>; + + constructor(params: FillMissingCellsParams = {}) { + this.metadata = params.metadata ?? 
FILL_CELLS_METADATA; + } + + async call(input: FillMissingCellsParameter): Promise { + const { filePath, cells } = input; + + // Read the CSV file + const fileContent = await new Promise((resolve, reject) => { + fs.readFile(filePath, "utf8", (err, data) => { + if (err) { + reject(err); + } else { + resolve(data); + } + }); + }); + + // Parse CSV with PapaParse + const parseResult = Papa.parse(fileContent, { + header: false, // Ensure the header is not treated as a separate object + skipEmptyLines: false, // Ensure empty lines are not skipped + }); + + if (parseResult.errors.length) { + throw new Error( + "Failed to parse CSV file: " + parseResult.errors[0].message, + ); + } + + const rows = parseResult.data; + + // Fill the cells with answers + for (const cell of cells) { + // Adjust rowIndex to start from 1 for data rows + const adjustedRowIndex = cell.rowIndex + 1; + if ( + adjustedRowIndex < rows.length && + cell.columnIndex < rows[adjustedRowIndex].length + ) { + rows[adjustedRowIndex][cell.columnIndex] = cell.answer; + } + } + + // Convert back to CSV format + const updatedContent = Papa.unparse(rows, { + delimiter: parseResult.meta.delimiter, + }); + + // Use the helper function to write the file + const parsedPath = path.parse(filePath); + const newFileName = `${parsedPath.name}-filled${parsedPath.ext}`; + const newFilePath = path.join("output/tools", newFileName); + + const newFileUrl = await saveDocument(newFilePath, updatedContent); + + return ( + "Successfully filled missing cells in the CSV file. 
File URL to show to the user: " + + newFileUrl + ); + } +} diff --git a/app/api/chat/engine/tools/img-gen.ts b/app/api/chat/engine/tools/img-gen.ts new file mode 100644 index 0000000..d24d556 --- /dev/null +++ b/app/api/chat/engine/tools/img-gen.ts @@ -0,0 +1,112 @@ +import type { JSONSchemaType } from "ajv"; +import { FormData } from "formdata-node"; +import fs from "fs"; +import got from "got"; +import { BaseTool, ToolMetadata } from "llamaindex"; +import path from "node:path"; +import { Readable } from "stream"; + +export type ImgGeneratorParameter = { + prompt: string; +}; + +export type ImgGeneratorToolParams = { + metadata?: ToolMetadata>; +}; + +export type ImgGeneratorToolOutput = { + isSuccess: boolean; + imageUrl?: string; + errorMessage?: string; +}; + +const DEFAULT_META_DATA: ToolMetadata> = { + name: "image_generator", + description: `Use this function to generate an image based on the prompt.`, + parameters: { + type: "object", + properties: { + prompt: { + type: "string", + description: "The prompt to generate the image", + }, + }, + required: ["prompt"], + }, +}; + +export class ImgGeneratorTool implements BaseTool { + readonly IMG_OUTPUT_FORMAT = "webp"; + readonly IMG_OUTPUT_DIR = "output/tools"; + readonly IMG_GEN_API = + "https://api.stability.ai/v2beta/stable-image/generate/core"; + + metadata: ToolMetadata>; + + constructor(params?: ImgGeneratorToolParams) { + this.checkRequiredEnvVars(); + this.metadata = params?.metadata || DEFAULT_META_DATA; + } + + async call(input: ImgGeneratorParameter): Promise { + return await this.generateImage(input.prompt); + } + + private generateImage = async ( + prompt: string, + ): Promise => { + try { + const buffer = await this.promptToImgBuffer(prompt); + const imageUrl = this.saveImage(buffer); + return { isSuccess: true, imageUrl }; + } catch (error) { + console.error(error); + return { + isSuccess: false, + errorMessage: "Failed to generate image. 
Please try again.", + }; + } + }; + + private promptToImgBuffer = async (prompt: string) => { + const form = new FormData(); + form.append("prompt", prompt); + form.append("output_format", this.IMG_OUTPUT_FORMAT); + const buffer = await got + .post(this.IMG_GEN_API, { + // Not sure why it shows a type error when passing form to body + // Although I follow document: https://github.com/sindresorhus/got/blob/main/documentation/2-options.md#body + // It still works fine, so I make casting to unknown to avoid the typescript warning + // Found a similar issue: https://github.com/sindresorhus/got/discussions/1877 + body: form as unknown as Buffer | Readable | string, + headers: { + Authorization: `Bearer ${process.env.STABILITY_API_KEY}`, + Accept: "image/*", + }, + }) + .buffer(); + return buffer; + }; + + private saveImage = (buffer: Buffer) => { + const filename = `${crypto.randomUUID()}.${this.IMG_OUTPUT_FORMAT}`; + const outputPath = path.join(this.IMG_OUTPUT_DIR, filename); + fs.writeFileSync(outputPath, buffer); + const url = `${process.env.FILESERVER_URL_PREFIX}/${this.IMG_OUTPUT_DIR}/${filename}`; + console.log(`Saved image to ${outputPath}.\nURL: ${url}`); + return url; + }; + + private checkRequiredEnvVars = () => { + if (!process.env.STABILITY_API_KEY) { + throw new Error( + "STABILITY_API_KEY key is required to run image generator. 
Get it here: https://platform.stability.ai/account/keys", + ); + } + if (!process.env.FILESERVER_URL_PREFIX) { + throw new Error( + "FILESERVER_URL_PREFIX is required to display file output after generation", + ); + } + }; +} diff --git a/app/api/chat/engine/tools/index.ts b/app/api/chat/engine/tools/index.ts new file mode 100644 index 0000000..edc4d62 --- /dev/null +++ b/app/api/chat/engine/tools/index.ts @@ -0,0 +1,103 @@ +import { BaseToolWithCall } from "llamaindex"; +import fs from "node:fs/promises"; +import path from "node:path"; +import { CodeGeneratorTool, CodeGeneratorToolParams } from "./code-generator"; +import { + DocumentGenerator, + DocumentGeneratorParams, +} from "./document-generator"; +import { DuckDuckGoSearchTool, DuckDuckGoToolParams } from "./duckduckgo"; +import { + ExtractMissingCellsParams, + ExtractMissingCellsTool, + FillMissingCellsParams, + FillMissingCellsTool, +} from "./form-filling"; +import { ImgGeneratorTool, ImgGeneratorToolParams } from "./img-gen"; +import { InterpreterTool, InterpreterToolParams } from "./interpreter"; +import { OpenAPIActionTool } from "./openapi-action"; +import { WeatherTool, WeatherToolParams } from "./weather"; +import { WikipediaTool, WikipediaToolParams } from "./wikipedia"; + +type ToolCreator = (config: unknown) => Promise; + +export async function createTools(toolConfig: { + local: Record; + llamahub: any; +}): Promise { + // add local tools from the 'tools' folder (if configured) + const tools = await createLocalTools(toolConfig.local); + return tools; +} + +const toolFactory: Record = { + "wikipedia.WikipediaToolSpec": async (config: unknown) => { + return [new WikipediaTool(config as WikipediaToolParams)]; + }, + weather: async (config: unknown) => { + return [new WeatherTool(config as WeatherToolParams)]; + }, + interpreter: async (config: unknown) => { + return [new InterpreterTool(config as InterpreterToolParams)]; + }, + "openapi_action.OpenAPIActionToolSpec": async (config: unknown) => { + 
const { openapi_uri, domain_headers } = config as { + openapi_uri: string; + domain_headers: Record<string, Record<string, string>>; + }; + const openAPIActionTool = new OpenAPIActionTool( + openapi_uri, + domain_headers, + ); + return await openAPIActionTool.toToolFunctions(); + }, + duckduckgo: async (config: unknown) => { + return [new DuckDuckGoSearchTool(config as DuckDuckGoToolParams)]; + }, + img_gen: async (config: unknown) => { + return [new ImgGeneratorTool(config as ImgGeneratorToolParams)]; + }, + artifact: async (config: unknown) => { + return [new CodeGeneratorTool(config as CodeGeneratorToolParams)]; + }, + document_generator: async (config: unknown) => { + return [new DocumentGenerator(config as DocumentGeneratorParams)]; + }, + form_filling: async (config: unknown) => { + return [ + new ExtractMissingCellsTool(config as ExtractMissingCellsParams), + new FillMissingCellsTool(config as FillMissingCellsParams), + ]; + }, +}; + +async function createLocalTools( + localConfig: Record<string, unknown>, +): Promise<BaseToolWithCall[]> { + const tools: BaseToolWithCall[] = []; + + for (const [key, toolConfig] of Object.entries(localConfig)) { + if (key in toolFactory) { + const newTools = await toolFactory[key](toolConfig); + tools.push(...newTools); + } + } + + return tools; +} + +export async function getConfiguredTools( + configPath?: string, +): Promise<BaseToolWithCall[]> { + const configFile = path.join(configPath ??
"config", "tools.json"); + const toolConfig = JSON.parse(await fs.readFile(configFile, "utf8")); + const tools = await createTools(toolConfig); + return tools; +} + +export async function getTool( + toolName: string, +): Promise { + const tools = await getConfiguredTools(); + return tools.find((tool) => tool.metadata.name === toolName); +} diff --git a/app/api/chat/engine/tools/interpreter.ts b/app/api/chat/engine/tools/interpreter.ts new file mode 100644 index 0000000..9cea944 --- /dev/null +++ b/app/api/chat/engine/tools/interpreter.ts @@ -0,0 +1,248 @@ +import { Logs, Result, Sandbox } from "@e2b/code-interpreter"; +import type { JSONSchemaType } from "ajv"; +import fs from "fs"; +import { BaseTool, ToolMetadata } from "llamaindex"; +import crypto from "node:crypto"; +import path from "node:path"; + +export type InterpreterParameter = { + code: string; + sandboxFiles?: string[]; + retryCount?: number; +}; + +export type InterpreterToolParams = { + metadata?: ToolMetadata>; + apiKey?: string; + fileServerURLPrefix?: string; +}; + +export type InterpreterToolOutput = { + isError: boolean; + logs: Logs; + text?: string; + extraResult: InterpreterExtraResult[]; + retryCount?: number; +}; + +type InterpreterExtraType = + | "html" + | "markdown" + | "svg" + | "png" + | "jpeg" + | "pdf" + | "latex" + | "json" + | "javascript"; + +export type InterpreterExtraResult = { + type: InterpreterExtraType; + content?: string; + filename?: string; + url?: string; +}; + +const DEFAULT_META_DATA: ToolMetadata> = { + name: "interpreter", + description: `Execute python code in a Jupyter notebook cell and return any result, stdout, stderr, display_data, and error. +If the code needs to use a file, ALWAYS pass the file path in the sandbox_files argument. +You have a maximum of 3 retries to get the code to run successfully. 
+`, + parameters: { + type: "object", + properties: { + code: { + type: "string", + description: "The python code to execute in a single cell.", + }, + sandboxFiles: { + type: "array", + description: + "List of local file paths to be used by the code. The tool will throw an error if a file is not found.", + items: { + type: "string", + }, + nullable: true, + }, + retryCount: { + type: "number", + description: "The number of times the tool has been retried", + default: 0, + nullable: true, + }, + }, + required: ["code"], + }, +}; + +export class InterpreterTool implements BaseTool { + private readonly outputDir = "output/tools"; + private readonly uploadedFilesDir = "output/uploaded"; + private apiKey?: string; + private fileServerURLPrefix?: string; + metadata: ToolMetadata>; + codeInterpreter?: Sandbox; + + constructor(params?: InterpreterToolParams) { + this.metadata = params?.metadata || DEFAULT_META_DATA; + this.apiKey = params?.apiKey || process.env.E2B_API_KEY; + this.fileServerURLPrefix = + params?.fileServerURLPrefix || process.env.FILESERVER_URL_PREFIX; + + if (!this.apiKey) { + throw new Error( + "E2B_API_KEY key is required to run code interpreter. 
Get it here: https://e2b.dev/docs/getting-started/api-key", + ); + } + if (!this.fileServerURLPrefix) { + throw new Error( + "FILESERVER_URL_PREFIX is required to display file output from sandbox", + ); + } + } + + public async initInterpreter(input: InterpreterParameter) { + if (!this.codeInterpreter) { + this.codeInterpreter = await Sandbox.create({ + apiKey: this.apiKey, + }); + // upload files to sandbox when it's initialized + if (input.sandboxFiles) { + console.log(`Uploading ${input.sandboxFiles.length} files to sandbox`); + try { + for (const filePath of input.sandboxFiles) { + const fileName = path.basename(filePath); + const localFilePath = path.join(this.uploadedFilesDir, fileName); + const content = fs.readFileSync(localFilePath); + + const arrayBuffer = new Uint8Array(content).buffer; + await this.codeInterpreter?.files.write(filePath, arrayBuffer); + } + } catch (error) { + console.error("Got error when uploading files to sandbox", error); + } + } + } + + return this.codeInterpreter; + } + + public async codeInterpret( + input: InterpreterParameter, + ): Promise { + console.log( + `Sandbox files: ${input.sandboxFiles}. Retry count: ${input.retryCount}`, + ); + + if (input.retryCount && input.retryCount >= 3) { + return { + isError: true, + logs: { + stdout: [], + stderr: [], + }, + text: "Max retries reached", + extraResult: [], + }; + } + + console.log( + `\n${"=".repeat(50)}\n> Running following AI-generated code:\n${input.code}\n${"=".repeat(50)}`, + ); + const interpreter = await this.initInterpreter(input); + const exec = await interpreter.runCode(input.code); + if (exec.error) console.error("[Code Interpreter error]", exec.error); + const extraResult = await this.getExtraResult(exec.results[0]); + const result: InterpreterToolOutput = { + isError: !!exec.error, + logs: exec.logs, + text: exec.text, + extraResult, + retryCount: input.retryCount ? 
input.retryCount + 1 : 1, + }; + return result; + } + + async call(input: InterpreterParameter): Promise { + const result = await this.codeInterpret(input); + return result; + } + + async close() { + await this.codeInterpreter?.kill(); + } + + private async getExtraResult( + res?: Result, + ): Promise { + if (!res) return []; + const output: InterpreterExtraResult[] = []; + + try { + const formats = res.formats(); // formats available for the result. Eg: ['png', ...] + const results = formats.map((f) => res[f as keyof Result]); // get base64 data for each format + + // save base64 data to file and return the url + for (let i = 0; i < formats.length; i++) { + const ext = formats[i]; + const data = results[i]; + switch (ext) { + case "png": + case "jpeg": + case "svg": + case "pdf": + const { filename } = this.saveToDisk(data, ext); + output.push({ + type: ext as InterpreterExtraType, + filename, + url: this.getFileUrl(filename), + }); + break; + default: + output.push({ + type: ext as InterpreterExtraType, + content: data, + }); + break; + } + } + } catch (error) { + console.error("Error when parsing e2b response", error); + } + + return output; + } + + // Consider saving to cloud storage instead but it may cost more for you + // See: https://e2b.dev/docs/sandbox/api/filesystem#write-to-file + private saveToDisk( + base64Data: string, + ext: string, + ): { + outputPath: string; + filename: string; + } { + const filename = `${crypto.randomUUID()}.${ext}`; // generate a unique filename + const buffer = Buffer.from(base64Data, "base64"); + const outputPath = this.getOutputPath(filename); + fs.writeFileSync(outputPath, buffer); + console.log(`Saved file to ${outputPath}`); + return { + outputPath, + filename, + }; + } + + private getOutputPath(filename: string): string { + // if outputDir doesn't exist, create it + if (!fs.existsSync(this.outputDir)) { + fs.mkdirSync(this.outputDir, { recursive: true }); + } + return path.join(this.outputDir, filename); + } + + private 
getFileUrl(filename: string): string { + return `${this.fileServerURLPrefix}/${this.outputDir}/${filename}`; + } +} diff --git a/app/api/chat/engine/tools/openapi-action.ts b/app/api/chat/engine/tools/openapi-action.ts new file mode 100644 index 0000000..74bb5bd --- /dev/null +++ b/app/api/chat/engine/tools/openapi-action.ts @@ -0,0 +1,164 @@ +import SwaggerParser from "@apidevtools/swagger-parser"; +import { JSONSchemaType } from "ajv"; +import got from "got"; +import { FunctionTool, JSONValue, ToolMetadata } from "llamaindex"; + +interface DomainHeaders { + [key: string]: { [header: string]: string }; +} + +type Input = { + url: string; + params: object; +}; + +type APIInfo = { + description: string; + title: string; +}; + +export class OpenAPIActionTool { + // cache the loaded specs by URL + private static specs: Record<string, any> = {}; + + private readonly INVALID_URL_PROMPT = + "This url did not include a hostname or scheme. Please determine the complete URL and try again."; + + private createLoadSpecMetaData = (info: APIInfo) => { + return { + name: "load_openapi_spec", + description: `Use this to retrieve the OpenAPI spec for the API named ${info.title} with the following description: ${info.description}. Call it before making any requests to the API.`, + }; + }; + + private readonly createMethodCallMetaData = ( + method: "POST" | "PATCH" | "GET", + info: APIInfo, + ) => { + return { + name: `${method.toLowerCase()}_request`, + description: `Use this to call the ${method} method on the API named ${info.title}`, + parameters: { + type: "object", + properties: { + url: { + type: "string", + description: `The url to make the ${method} request against`, + }, + params: { + type: "object", + description: + method === "GET" + ?
"the URL parameters to provide with the get request" + : `the key-value pairs to provide with the ${method} request`, + }, + }, + required: ["url"], + }, + } as ToolMetadata>; + }; + + constructor( + public openapi_uri: string, + public domainHeaders: DomainHeaders = {}, + ) {} + + async loadOpenapiSpec(url: string): Promise { + const api = await SwaggerParser.validate(url); + return { + servers: "servers" in api ? api.servers : "", + info: { description: api.info.description, title: api.info.title }, + endpoints: api.paths, + }; + } + + async getRequest(input: Input): Promise { + if (!this.validUrl(input.url)) { + return this.INVALID_URL_PROMPT; + } + try { + const data = await got + .get(input.url, { + headers: this.getHeadersForUrl(input.url), + searchParams: input.params as URLSearchParams, + }) + .json(); + return data as JSONValue; + } catch (error) { + return error as JSONValue; + } + } + + async postRequest(input: Input): Promise { + if (!this.validUrl(input.url)) { + return this.INVALID_URL_PROMPT; + } + try { + const res = await got.post(input.url, { + headers: this.getHeadersForUrl(input.url), + json: input.params, + }); + return res.body as JSONValue; + } catch (error) { + return error as JSONValue; + } + } + + async patchRequest(input: Input): Promise { + if (!this.validUrl(input.url)) { + return this.INVALID_URL_PROMPT; + } + try { + const res = await got.patch(input.url, { + headers: this.getHeadersForUrl(input.url), + json: input.params, + }); + return res.body as JSONValue; + } catch (error) { + return error as JSONValue; + } + } + + public async toToolFunctions() { + if (!OpenAPIActionTool.specs[this.openapi_uri]) { + console.log(`Loading spec for URL: ${this.openapi_uri}`); + const spec = await this.loadOpenapiSpec(this.openapi_uri); + OpenAPIActionTool.specs[this.openapi_uri] = spec; + } + const spec = OpenAPIActionTool.specs[this.openapi_uri]; + // TODO: read endpoints with parameters from spec and create one tool for each endpoint + // For 
now, we just create a tool for each HTTP method which does not work well for passing parameters + return [ + FunctionTool.from(() => { + return spec; + }, this.createLoadSpecMetaData(spec.info)), + FunctionTool.from( + this.getRequest.bind(this), + this.createMethodCallMetaData("GET", spec.info), + ), + FunctionTool.from( + this.postRequest.bind(this), + this.createMethodCallMetaData("POST", spec.info), + ), + FunctionTool.from( + this.patchRequest.bind(this), + this.createMethodCallMetaData("PATCH", spec.info), + ), + ]; + } + + private validUrl(url: string): boolean { + const parsed = new URL(url); + return !!parsed.protocol && !!parsed.hostname; + } + + private getDomain(url: string): string { + const parsed = new URL(url); + return parsed.hostname; + } + + private getHeadersForUrl(url: string): { [header: string]: string } { + const domain = this.getDomain(url); + return this.domainHeaders[domain] || {}; + } +} diff --git a/app/api/chat/engine/tools/query-engine.ts b/app/api/chat/engine/tools/query-engine.ts new file mode 100644 index 0000000..48e1e9c --- /dev/null +++ b/app/api/chat/engine/tools/query-engine.ts @@ -0,0 +1,57 @@ +import { + BaseQueryEngine, + CloudRetrieveParams, + LlamaCloudIndex, + MetadataFilters, + QueryEngineTool, + VectorStoreIndex, +} from "llamaindex"; +import { generateFilters } from "../queryFilter"; + +interface QueryEngineParams { + documentIds?: string[]; + topK?: number; +} + +export function createQueryEngineTool( + index: VectorStoreIndex | LlamaCloudIndex, + params?: QueryEngineParams, + name?: string, + description?: string, +): QueryEngineTool { + return new QueryEngineTool({ + queryEngine: createQueryEngine(index, params), + metadata: { + name: name || "query_engine", + description: + description || + `Use this tool to retrieve information about the text corpus from an index.`, + }, + }); +} + +function createQueryEngine( + index: VectorStoreIndex | LlamaCloudIndex, + params?: QueryEngineParams, +): BaseQueryEngine { + const 
baseQueryParams = { + similarityTopK: + params?.topK ?? + (process.env.TOP_K ? parseInt(process.env.TOP_K) : undefined), + }; + + if (index instanceof LlamaCloudIndex) { + return index.asQueryEngine({ + ...baseQueryParams, + retrieval_mode: "auto_routed", + preFilters: generateFilters( + params?.documentIds || [], + ) as CloudRetrieveParams["filters"], + }); + } + + return index.asQueryEngine({ + ...baseQueryParams, + preFilters: generateFilters(params?.documentIds || []) as MetadataFilters, + }); +} diff --git a/app/api/chat/engine/tools/weather.ts b/app/api/chat/engine/tools/weather.ts new file mode 100644 index 0000000..c1f6014 --- /dev/null +++ b/app/api/chat/engine/tools/weather.ts @@ -0,0 +1,81 @@ +import type { JSONSchemaType } from "ajv"; +import { BaseTool, ToolMetadata } from "llamaindex"; + +interface GeoLocation { + id: string; + name: string; + latitude: number; + longitude: number; +} + +export type WeatherParameter = { + location: string; +}; + +export type WeatherToolParams = { + metadata?: ToolMetadata<JSONSchemaType<WeatherParameter>>; +}; + +const DEFAULT_META_DATA: ToolMetadata<JSONSchemaType<WeatherParameter>> = { + name: "get_weather_information", + description: ` + Use this function to get the weather of any given location.
+ Note that the weather code should follow WMO Weather interpretation codes (WW): + 0: Clear sky + 1, 2, 3: Mainly clear, partly cloudy, and overcast + 45, 48: Fog and depositing rime fog + 51, 53, 55: Drizzle: Light, moderate, and dense intensity + 56, 57: Freezing Drizzle: Light and dense intensity + 61, 63, 65: Rain: Slight, moderate and heavy intensity + 66, 67: Freezing Rain: Light and heavy intensity + 71, 73, 75: Snow fall: Slight, moderate, and heavy intensity + 77: Snow grains + 80, 81, 82: Rain showers: Slight, moderate, and violent + 85, 86: Snow showers slight and heavy + 95: Thunderstorm: Slight or moderate + 96, 99: Thunderstorm with slight and heavy hail + `, + parameters: { + type: "object", + properties: { + location: { + type: "string", + description: "The location to get the weather information", + }, + }, + required: ["location"], + }, +}; + +export class WeatherTool implements BaseTool<WeatherParameter> { + metadata: ToolMetadata<JSONSchemaType<WeatherParameter>>; + + private getGeoLocation = async (location: string): Promise<GeoLocation> => { + const apiUrl = `https://geocoding-api.open-meteo.com/v1/search?name=${location}&count=10&language=en&format=json`; + const response = await fetch(apiUrl); + const data = await response.json(); + const { id, name, latitude, longitude } = data.results[0]; + return { id, name, latitude, longitude }; + }; + + private getWeatherByLocation = async (location: string) => { + console.log( + "Calling open-meteo api to get weather information of location:", + location, + ); + const { latitude, longitude } = await this.getGeoLocation(location); + const timezone = Intl.DateTimeFormat().resolvedOptions().timeZone; + const apiUrl = `https://api.open-meteo.com/v1/forecast?latitude=${latitude}&longitude=${longitude}&current=temperature_2m,weather_code&hourly=temperature_2m,weather_code&daily=weather_code&timezone=${timezone}`; + const response = await fetch(apiUrl); + const data = await response.json(); + return data; + }; + + constructor(params?: WeatherToolParams) { + this.metadata =
params?.metadata || DEFAULT_META_DATA; + } + + async call(input: WeatherParameter) { + return await this.getWeatherByLocation(input.location); + } +} diff --git a/app/api/chat/engine/tools/wikipedia.ts b/app/api/chat/engine/tools/wikipedia.ts new file mode 100644 index 0000000..5e171e9 --- /dev/null +++ b/app/api/chat/engine/tools/wikipedia.ts @@ -0,0 +1,60 @@ +import type { JSONSchemaType } from "ajv"; +import type { BaseTool, ToolMetadata } from "llamaindex"; +import { default as wiki } from "wikipedia"; + +type WikipediaParameter = { + query: string; + lang?: string; +}; + +export type WikipediaToolParams = { + metadata?: ToolMetadata<JSONSchemaType<WikipediaParameter>>; +}; + +const DEFAULT_META_DATA: ToolMetadata<JSONSchemaType<WikipediaParameter>> = { + name: "wikipedia_tool", + description: "A tool that uses a query engine to search Wikipedia.", + parameters: { + type: "object", + properties: { + query: { + type: "string", + description: "The query to search for", + }, + lang: { + type: "string", + description: "The language to search in", + nullable: true, + }, + }, + required: ["query"], + }, +}; + +export class WikipediaTool implements BaseTool<WikipediaParameter> { + private readonly DEFAULT_LANG = "en"; + metadata: ToolMetadata<JSONSchemaType<WikipediaParameter>>; + + constructor(params?: WikipediaToolParams) { + this.metadata = params?.metadata || DEFAULT_META_DATA; + } + + async loadData( + page: string, + lang: string = this.DEFAULT_LANG, + ): Promise<string> { + wiki.setLang(lang); + const pageResult = await wiki.page(page, { autoSuggest: false }); + const content = await pageResult.content(); + return content; + } + + async call({ + query, + lang = this.DEFAULT_LANG, + }: WikipediaParameter): Promise<string> { + const searchResult = await wiki.search(query); + if (searchResult.results.length === 0) return "No search results."; + return await this.loadData(searchResult.results[0].title, lang); + } +} diff --git a/app/api/chat/llamaindex-stream.ts b/app/api/chat/llamaindex-stream.ts deleted file mode 100644 index 6ffb32f..0000000 --- a/app/api/chat/llamaindex-stream.ts +++ /dev/null @@
-1,78 +0,0 @@ -import { - StreamData, - createCallbacksTransformer, - createStreamDataTransformer, - trimStartOfStreamHelper, - type AIStreamCallbacksAndOptions, -} from "ai"; - -import { - Metadata, - NodeWithScore, - EngineResponse, - ToolCallLLMMessageOptions, -} from "llamaindex"; - -import { appendImageData, appendSourceData } from "./stream-helper"; - -type LlamaIndexResponse = - | ReadableStream - | EngineResponse; - -type ParserOptions = { - image_url?: string; -}; - -function createParser( - res: AsyncIterable, - data: StreamData, - opts?: ParserOptions, -) { - const it = res[Symbol.asyncIterator](); - const trimStartOfStream = trimStartOfStreamHelper(); - - let sourceNodes: NodeWithScore[] | undefined; - return new ReadableStream({ - start() { - appendImageData(data, opts?.image_url); - }, - async pull(controller): Promise { - const { value, done } = await it.next(); - if (done) { - if (sourceNodes) { - appendSourceData(data, sourceNodes); - } - controller.close(); - data.close(); - return; - } - - let delta; - if (value instanceof EngineResponse) { - // handle Response type - if (value.sourceNodes) { - // get source nodes from the first response - sourceNodes = value.sourceNodes; - } - delta = value.response ?? ""; - } - const text = trimStartOfStream(delta ?? 
""); - if (text) { - controller.enqueue(text); - } - }, - }); -} - -export function LlamaIndexStream( - response: AsyncIterable, - data: StreamData, - opts?: { - callbacks?: AIStreamCallbacksAndOptions; - parserOptions?: ParserOptions; - }, -): ReadableStream { - return createParser(response, data, opts?.parserOptions) - .pipeThrough(createCallbacksTransformer(opts?.callbacks)) - .pipeThrough(createStreamDataTransformer()); -} diff --git a/app/api/chat/llamaindex/documents/helper.ts b/app/api/chat/llamaindex/documents/helper.ts new file mode 100644 index 0000000..a6d18c9 --- /dev/null +++ b/app/api/chat/llamaindex/documents/helper.ts @@ -0,0 +1,105 @@ +import { Document } from "llamaindex"; +import crypto from "node:crypto"; +import fs from "node:fs"; +import path from "node:path"; +import { getExtractors } from "../../engine/loader"; +import { DocumentFile } from "../streaming/annotations"; + +const MIME_TYPE_TO_EXT: Record = { + "application/pdf": "pdf", + "text/plain": "txt", + "text/csv": "csv", + "application/vnd.openxmlformats-officedocument.wordprocessingml.document": + "docx", +}; + +export const UPLOADED_FOLDER = "output/uploaded"; + +export async function storeAndParseFile( + name: string, + fileBuffer: Buffer, + mimeType: string, +): Promise { + const file = await storeFile(name, fileBuffer, mimeType); + const documents: Document[] = await parseFile(fileBuffer, name, mimeType); + // Update document IDs in the file metadata + file.refs = documents.map((document) => document.id_ as string); + return file; +} + +export async function storeFile( + name: string, + fileBuffer: Buffer, + mimeType: string, +) { + const fileExt = MIME_TYPE_TO_EXT[mimeType]; + if (!fileExt) throw new Error(`Unsupported document type: ${mimeType}`); + + const fileId = crypto.randomUUID(); + const newFilename = `${sanitizeFileName(name)}_${fileId}.${fileExt}`; + const filepath = path.join(UPLOADED_FOLDER, newFilename); + const fileUrl = await saveDocument(filepath, fileBuffer); + 
return { + id: fileId, + name: newFilename, + size: fileBuffer.length, + type: fileExt, + url: fileUrl, + refs: [] as string[], + } as DocumentFile; +} + +export async function parseFile( + fileBuffer: Buffer, + filename: string, + mimeType: string, +) { + const documents = await loadDocuments(fileBuffer, mimeType); + for (const document of documents) { + document.metadata = { + ...document.metadata, + file_name: filename, + private: "true", // to separate private uploads from public documents + }; + } + return documents; +} + +async function loadDocuments(fileBuffer: Buffer, mimeType: string) { + const extractors = getExtractors(); + const reader = extractors[MIME_TYPE_TO_EXT[mimeType]]; + + if (!reader) { + throw new Error(`Unsupported document type: ${mimeType}`); + } + console.log(`Processing uploaded document of type: ${mimeType}`); + return await reader.loadDataAsContent(fileBuffer); +} + +// Save document to file server and return the file url +export async function saveDocument(filepath: string, content: string | Buffer) { + if (path.isAbsolute(filepath)) { + throw new Error("Absolute file paths are not allowed."); + } + if (!process.env.FILESERVER_URL_PREFIX) { + throw new Error("FILESERVER_URL_PREFIX environment variable is not set."); + } + + const dirPath = path.dirname(filepath); + await fs.promises.mkdir(dirPath, { recursive: true }); + + if (typeof content === "string") { + await fs.promises.writeFile(filepath, content, "utf-8"); + } else { + await fs.promises.writeFile(filepath, content); + } + + const fileurl = `${process.env.FILESERVER_URL_PREFIX}/${filepath}`; + console.log(`Saved document to ${filepath}. 
Reachable at URL: ${fileurl}`); + return fileurl; +} + +function sanitizeFileName(fileName: string) { + // Remove file extension and sanitize + return fileName.split(".")[0].replace(/[^a-zA-Z0-9_-]/g, "_"); +} diff --git a/app/api/chat/llamaindex/documents/pipeline.ts b/app/api/chat/llamaindex/documents/pipeline.ts new file mode 100644 index 0000000..cd4d6d0 --- /dev/null +++ b/app/api/chat/llamaindex/documents/pipeline.ts @@ -0,0 +1,48 @@ +import { + Document, + IngestionPipeline, + Settings, + SimpleNodeParser, + storageContextFromDefaults, + VectorStoreIndex, +} from "llamaindex"; + +export async function runPipeline( + currentIndex: VectorStoreIndex | null, + documents: Document[], +) { + // Use ingestion pipeline to process the documents into nodes and add them to the vector store + const pipeline = new IngestionPipeline({ + transformations: [ + new SimpleNodeParser({ + chunkSize: Settings.chunkSize, + chunkOverlap: Settings.chunkOverlap, + }), + Settings.embedModel, + ], + }); + const nodes = await pipeline.run({ documents }); + if (currentIndex) { + await currentIndex.insertNodes(nodes); + currentIndex.storageContext.docStore.persist(); + console.log("Added nodes to the vector store."); + return documents.map((document) => document.id_); + } else { + // Initialize a new index with the documents + console.log( + "Got empty index, created new index with the uploaded documents", + ); + const persistDir = process.env.STORAGE_CACHE_DIR; + if (!persistDir) { + throw new Error("STORAGE_CACHE_DIR environment variable is required!"); + } + const storageContext = await storageContextFromDefaults({ + persistDir, + }); + const newIndex = await VectorStoreIndex.fromDocuments(documents, { + storageContext, + }); + await newIndex.storageContext.docStore.persist(); + return documents.map((document) => document.id_); + } +} diff --git a/app/api/chat/llamaindex/documents/upload.ts b/app/api/chat/llamaindex/documents/upload.ts new file mode 100644 index 0000000..b3786a3 --- 
/dev/null +++ b/app/api/chat/llamaindex/documents/upload.ts @@ -0,0 +1,61 @@ +import { Document, LLamaCloudFileService, VectorStoreIndex } from "llamaindex"; +import { LlamaCloudIndex } from "llamaindex/cloud/LlamaCloudIndex"; +import { DocumentFile } from "../streaming/annotations"; +import { parseFile, storeFile } from "./helper"; +import { runPipeline } from "./pipeline"; + +export async function uploadDocument( + index: VectorStoreIndex | LlamaCloudIndex | null, + name: string, + raw: string, +): Promise { + const [header, content] = raw.split(","); + const mimeType = header.replace("data:", "").replace(";base64", ""); + const fileBuffer = Buffer.from(content, "base64"); + + // Store file + const fileMetadata = await storeFile(name, fileBuffer, mimeType); + + // Do not index csv files + if (mimeType === "text/csv") { + return fileMetadata; + } + let documentIds: string[] = []; + if (index instanceof LlamaCloudIndex) { + // trigger LlamaCloudIndex API to upload the file and run the pipeline + const projectId = await index.getProjectId(); + const pipelineId = await index.getPipelineId(); + try { + documentIds = [ + await LLamaCloudFileService.addFileToPipeline( + projectId, + pipelineId, + new File([fileBuffer], name, { type: mimeType }), + { private: "true" }, + ), + ]; + } catch (error) { + if ( + error instanceof ReferenceError && + error.message.includes("File is not defined") + ) { + throw new Error( + "File class is not supported in the current Node.js version. 
Please use Node.js 20 or higher.", + ); + } + throw error; + } + } else { + // run the pipeline for other vector store indexes + const documents: Document[] = await parseFile( + fileBuffer, + fileMetadata.name, + mimeType, + ); + documentIds = await runPipeline(index, documents); + } + + // Update file metadata with document IDs + fileMetadata.refs = documentIds; + return fileMetadata; +} diff --git a/app/api/chat/llamaindex/streaming/annotations.ts b/app/api/chat/llamaindex/streaming/annotations.ts new file mode 100644 index 0000000..164ddca --- /dev/null +++ b/app/api/chat/llamaindex/streaming/annotations.ts @@ -0,0 +1,251 @@ +import { JSONValue, Message } from "ai"; +import { + ChatMessage, + MessageContent, + MessageContentDetail, + MessageType, +} from "llamaindex"; +import { UPLOADED_FOLDER } from "../documents/helper"; + +export type DocumentFileType = "csv" | "pdf" | "txt" | "docx"; + +export type DocumentFile = { + id: string; + name: string; + size: number; + type: string; + url: string; + refs?: string[]; +}; + +type Annotation = { + type: string; + data: object; +}; + +export function isValidMessages(messages: Message[]): boolean { + const lastMessage = + messages && messages.length > 0 ? 
messages[messages.length - 1] : null; + return lastMessage !== null && lastMessage.role === "user"; +} + +export function retrieveDocumentIds(messages: Message[]): string[] { + // retrieve document Ids from the annotations of all messages (if any) + const documentFiles = retrieveDocumentFiles(messages); + return documentFiles.map((file) => file.refs || []).flat(); +} + +export function retrieveDocumentFiles(messages: Message[]): DocumentFile[] { + const annotations = getAllAnnotations(messages); + if (annotations.length === 0) return []; + + const files: DocumentFile[] = []; + for (const { type, data } of annotations) { + if ( + type === "document_file" && + "files" in data && + Array.isArray(data.files) + ) { + files.push(...data.files); + } + } + return files; +} + +export function retrieveMessageContent(messages: Message[]): MessageContent { + const userMessage = messages[messages.length - 1]; + return [ + { + type: "text", + text: userMessage.content, + }, + ...retrieveLatestArtifact(messages), + ...convertAnnotations(messages), + ]; +} + +export function convertToChatHistory(messages: Message[]): ChatMessage[] { + if (!messages || !Array.isArray(messages)) { + return []; + } + const agentHistory = retrieveAgentHistoryMessage(messages); + if (agentHistory) { + const previousMessages = messages.slice(0, -1); + return [...previousMessages, agentHistory].map((msg) => ({ + role: msg.role as MessageType, + content: msg.content, + })); + } + return messages.map((msg) => ({ + role: msg.role as MessageType, + content: msg.content, + })); +} + +function retrieveAgentHistoryMessage( + messages: Message[], + maxAgentMessages = 10, +): ChatMessage | null { + const agentAnnotations = getAnnotations<{ agent: string; text: string }>( + messages, + { role: "assistant", type: "agent" }, + ).slice(-maxAgentMessages); + + if (agentAnnotations.length > 0) { + const messageContent = + "Here is the previous conversation of agents:\n" + + agentAnnotations.map((annotation) => 
annotation.data.text).join("\n"); + return { + role: "assistant", + content: messageContent, + }; + } + return null; +} + +function getFileContent(file: DocumentFile): string { + let defaultContent = `=====File: ${file.name}=====\n`; + // Include file URL if it's available + const urlPrefix = process.env.FILESERVER_URL_PREFIX; + let urlContent = ""; + if (urlPrefix) { + if (file.url) { + urlContent = `File URL: ${file.url}\n`; + } else { + urlContent = `File URL (instruction: do not update this file URL yourself): ${urlPrefix}/output/uploaded/${file.name}\n`; + } + } else { + console.warn( + "Warning: FILESERVER_URL_PREFIX not set in environment variables. Can't use file server", + ); + } + defaultContent += urlContent; + + // Include document IDs if it's available + if (file.refs) { + defaultContent += `Document IDs: ${file.refs}\n`; + } + // Include sandbox file paths + const sandboxFilePath = `/tmp/${file.name}`; + defaultContent += `Sandbox file path (instruction: only use sandbox path for artifact or code interpreter tool): ${sandboxFilePath}\n`; + + // Include local file path + const localFilePath = `${UPLOADED_FOLDER}/${file.name}`; + defaultContent += `Local file path (instruction: use for local tool that requires a local path): ${localFilePath}\n`; + + return defaultContent; +} + +function getAllAnnotations(messages: Message[]): Annotation[] { + return messages.flatMap((message) => + (message.annotations ?? 
[]).map((annotation) => + getValidAnnotation(annotation), + ), + ); +} + +// get latest artifact from annotations to append to the user message +function retrieveLatestArtifact(messages: Message[]): MessageContentDetail[] { + const annotations = getAllAnnotations(messages); + if (annotations.length === 0) return []; + + for (const { type, data } of annotations.reverse()) { + if ( + type === "tools" && + "toolCall" in data && + "toolOutput" in data && + typeof data.toolCall === "object" && + typeof data.toolOutput === "object" && + data.toolCall !== null && + data.toolOutput !== null && + "name" in data.toolCall && + data.toolCall.name === "artifact" + ) { + const toolOutput = data.toolOutput as { output?: { code?: string } }; + if (toolOutput.output?.code) { + return [ + { + type: "text", + text: `The existing code is:\n\`\`\`\n${toolOutput.output.code}\n\`\`\``, + }, + ]; + } + } + } + return []; +} + +function convertAnnotations(messages: Message[]): MessageContentDetail[] { + // get all annotations from user messages + const annotations: Annotation[] = messages + .filter((message) => message.role === "user" && message.annotations) + .flatMap((message) => message.annotations?.map(getValidAnnotation) || []); + if (annotations.length === 0) return []; + + const content: MessageContentDetail[] = []; + annotations.forEach(({ type, data }) => { + // convert image + if (type === "image" && "url" in data && typeof data.url === "string") { + content.push({ + type: "image_url", + image_url: { + url: data.url, + }, + }); + } + // convert the content of files to a text message + if ( + type === "document_file" && + "files" in data && + Array.isArray(data.files) + ) { + const fileContent = data.files.map(getFileContent).join("\n"); + content.push({ + type: "text", + text: fileContent, + }); + } + }); + + return content; +} + +function getValidAnnotation(annotation: JSONValue): Annotation { + if ( + !( + annotation && + typeof annotation === "object" && + "type" in annotation 
&& + typeof annotation.type === "string" && + "data" in annotation && + annotation.data && + typeof annotation.data === "object" + ) + ) { + throw new Error("Client sent invalid annotation. Missing data and type"); + } + return { type: annotation.type, data: annotation.data }; +} + +// validate and get all annotations of a specific type or role from the frontend messages +export function getAnnotations< + T extends Annotation["data"] = Annotation["data"], +>( + messages: Message[], + options?: { + role?: Message["role"]; // message role + type?: Annotation["type"]; // annotation type + }, +): { + type: string; + data: T; +}[] { + const messagesByRole = options?.role + ? messages.filter((msg) => msg.role === options?.role) + : messages; + const annotations = getAllAnnotations(messagesByRole); + const annotationsByType = options?.type + ? annotations.filter((a) => a.type === options.type) + : annotations; + return annotationsByType as { type: string; data: T }[]; +} diff --git a/app/api/chat/llamaindex/streaming/events.ts b/app/api/chat/llamaindex/streaming/events.ts new file mode 100644 index 0000000..538e001 --- /dev/null +++ b/app/api/chat/llamaindex/streaming/events.ts @@ -0,0 +1,182 @@ +import { StreamData } from "ai"; +import { + CallbackManager, + LLamaCloudFileService, + Metadata, + MetadataMode, + NodeWithScore, + ToolCall, + ToolOutput, +} from "llamaindex"; +import path from "node:path"; +import { DATA_DIR } from "../../engine/loader"; +import { downloadFile } from "./file"; + +const LLAMA_CLOUD_DOWNLOAD_FOLDER = "output/llamacloud"; + +export function appendSourceData( + data: StreamData, + sourceNodes?: NodeWithScore[], +) { + if (!sourceNodes?.length) return; + try { + const nodes = sourceNodes.map((node) => ({ + metadata: node.node.metadata, + id: node.node.id_, + score: node.score ?? 
null, + url: getNodeUrl(node.node.metadata), + text: node.node.getContent(MetadataMode.NONE), + })); + data.appendMessageAnnotation({ + type: "sources", + data: { + nodes, + }, + }); + } catch (error) { + console.error("Error appending source data:", error); + } +} + +export function appendEventData(data: StreamData, title?: string) { + if (!title) return; + data.appendMessageAnnotation({ + type: "events", + data: { + title, + }, + }); +} + +export function appendToolData( + data: StreamData, + toolCall: ToolCall, + toolOutput: ToolOutput, +) { + data.appendMessageAnnotation({ + type: "tools", + data: { + toolCall: { + id: toolCall.id, + name: toolCall.name, + input: toolCall.input, + }, + toolOutput: { + output: toolOutput.output, + isError: toolOutput.isError, + }, + }, + }); +} + +export function createCallbackManager(stream: StreamData) { + const callbackManager = new CallbackManager(); + + callbackManager.on("retrieve-end", (data) => { + const { nodes, query } = data.detail; + appendSourceData(stream, nodes); + appendEventData(stream, `Retrieving context for query: '${query.query}'`); + appendEventData( + stream, + `Retrieved ${nodes.length} sources to use as context for the query`, + ); + downloadFilesFromNodes(nodes); // don't await to avoid blocking chat streaming + }); + + callbackManager.on("llm-tool-call", (event) => { + const { name, input } = event.detail.toolCall; + const inputString = Object.entries(input) + .map(([key, value]) => `${key}: ${value}`) + .join(", "); + appendEventData( + stream, + `Using tool: '${name}' with inputs: '${inputString}'`, + ); + }); + + callbackManager.on("llm-tool-result", (event) => { + const { toolCall, toolResult } = event.detail; + appendToolData(stream, toolCall, toolResult); + }); + + return callbackManager; +} + +function getNodeUrl(metadata: Metadata) { + if (!process.env.FILESERVER_URL_PREFIX) { + console.warn( + "FILESERVER_URL_PREFIX is not set. 
File URLs will not be generated.", + ); + } + const fileName = metadata["file_name"]; + if (fileName && process.env.FILESERVER_URL_PREFIX) { + // file_name exists and file server is configured + const pipelineId = metadata["pipeline_id"]; + if (pipelineId) { + const name = toDownloadedName(pipelineId, fileName); + return `${process.env.FILESERVER_URL_PREFIX}/${LLAMA_CLOUD_DOWNLOAD_FOLDER}/${name}`; + } + const isPrivate = metadata["private"] === "true"; + if (isPrivate) { + return `${process.env.FILESERVER_URL_PREFIX}/output/uploaded/${fileName}`; + } + const filePath = metadata["file_path"]; + const dataDir = path.resolve(DATA_DIR); + + if (filePath && dataDir) { + const relativePath = path.relative(dataDir, filePath); + return `${process.env.FILESERVER_URL_PREFIX}/data/${relativePath}`; + } + } + // fallback to URL in metadata (e.g. for websites) + return metadata["URL"]; +} + +async function downloadFilesFromNodes(nodes: NodeWithScore[]) { + try { + const files = nodesToLlamaCloudFiles(nodes); + for (const { pipelineId, fileName, downloadedName } of files) { + const downloadUrl = await LLamaCloudFileService.getFileUrl( + pipelineId, + fileName, + ); + if (downloadUrl) { + await downloadFile( + downloadUrl, + downloadedName, + LLAMA_CLOUD_DOWNLOAD_FOLDER, + ); + } + } + } catch (error) { + console.error("Error downloading files from nodes:", error); + } +} + +function nodesToLlamaCloudFiles(nodes: NodeWithScore[]) { + const files: Array<{ + pipelineId: string; + fileName: string; + downloadedName: string; + }> = []; + for (const node of nodes) { + const pipelineId = node.node.metadata["pipeline_id"]; + const fileName = node.node.metadata["file_name"]; + if (!pipelineId || !fileName) continue; + const isDuplicate = files.some( + (f) => f.pipelineId === pipelineId && f.fileName === fileName, + ); + if (!isDuplicate) { + files.push({ + pipelineId, + fileName, + downloadedName: toDownloadedName(pipelineId, fileName), + }); + } + } + return files; +} + +function 
toDownloadedName(pipelineId: string, fileName: string) { + return `${pipelineId}$${fileName}`; +} diff --git a/app/api/chat/llamaindex/streaming/file.ts b/app/api/chat/llamaindex/streaming/file.ts new file mode 100644 index 0000000..b5d1caa --- /dev/null +++ b/app/api/chat/llamaindex/streaming/file.ts @@ -0,0 +1,35 @@ +import fs from "node:fs"; +import https from "node:https"; +import path from "node:path"; + +export async function downloadFile( + urlToDownload: string, + filename: string, + folder = "output/uploaded", +) { + try { + const downloadedPath = path.join(folder, filename); + + // Check if file already exists + if (fs.existsSync(downloadedPath)) return; + + const file = fs.createWriteStream(downloadedPath); + https + .get(urlToDownload, (response) => { + response.pipe(file); + file.on("finish", () => { + file.close(() => { + console.log("File downloaded successfully"); + }); + }); + }) + .on("error", (err) => { + fs.unlink(downloadedPath, () => { + console.error("Error downloading file:", err); + throw err; + }); + }); + } catch (error) { + throw new Error(`Error downloading file: ${error}`); + } +} diff --git a/app/api/chat/llamaindex/streaming/suggestion.ts b/app/api/chat/llamaindex/streaming/suggestion.ts new file mode 100644 index 0000000..d8949cc --- /dev/null +++ b/app/api/chat/llamaindex/streaming/suggestion.ts @@ -0,0 +1,43 @@ +import { ChatMessage, Settings } from "llamaindex"; + +export async function generateNextQuestions(conversation: ChatMessage[]) { + const llm = Settings.llm; + const NEXT_QUESTION_PROMPT = process.env.NEXT_QUESTION_PROMPT; + if (!NEXT_QUESTION_PROMPT) { + return []; + } + + // Format conversation + const conversationText = conversation + .map((message) => `${message.role}: ${message.content}`) + .join("\n"); + const message = NEXT_QUESTION_PROMPT.replace( + "{conversation}", + conversationText, + ); + + try { + const response = await llm.complete({ prompt: message }); + const questions = extractQuestions(response.text); + 
return questions; + } catch (error) { + console.error("Error when generating the next questions: ", error); + return []; + } +} + +// TODO: instead of parsing the LLM's result we can use structured predict, once LITS supports it +function extractQuestions(text: string): string[] { + // Extract the text inside the triple backticks + // @ts-ignore + const contentMatch = text.match(/```(.*?)```/s); + const content = contentMatch ? contentMatch[1] : ""; + + // Split the content by newlines to get each question + const questions = content + .split("\n") + .map((question) => question.trim()) + .filter((question) => question !== ""); + + return questions; +} diff --git a/app/api/chat/route.ts b/app/api/chat/route.ts index 86ab494..5c0c0c6 100644 --- a/app/api/chat/route.ts +++ b/app/api/chat/route.ts @@ -1,11 +1,16 @@ import { initObservability } from "@/app/observability"; -import { Message, StreamData, StreamingTextResponse } from "ai"; -import { ChatMessage, MessageContent, Settings } from "llamaindex"; +import { LlamaIndexAdapter, Message, StreamData } from "ai"; +import { ChatMessage, Settings } from "llamaindex"; import { NextRequest, NextResponse } from "next/server"; import { createChatEngine } from "./engine/chat"; import { initSettings } from "./engine/settings"; -import { LlamaIndexStream } from "./llamaindex-stream"; -import { createCallbackManager } from "./stream-helper"; +import { + isValidMessages, + retrieveDocumentIds, + retrieveMessageContent, +} from "./llamaindex/streaming/annotations"; +import { createCallbackManager } from "./llamaindex/streaming/events"; +import { generateNextQuestions } from "./llamaindex/streaming/suggestion"; initObservability(); initSettings(); @@ -13,31 +18,14 @@ initSettings(); export const runtime = "nodejs"; export const dynamic = "force-dynamic"; -const convertMessageContent = ( - textMessage: string, - imageUrl: string | undefined, -): MessageContent => { - if (!imageUrl) return textMessage; - return [ - { - type: "text", 
- text: textMessage, - }, - { - type: "image_url", - image_url: { - url: imageUrl, - }, - }, - ]; -}; - export async function POST(request: NextRequest) { + // Init Vercel AI StreamData and timeout + const vercelStreamData = new StreamData(); + try { const body = await request.json(); - const { messages, data }: { messages: Message[]; data: any } = body; - const userMessage = messages.pop(); - if (!messages || !userMessage || userMessage.role !== "user") { + const { messages, data }: { messages: Message[]; data?: any } = body; + if (!isValidMessages(messages)) { return NextResponse.json( { error: @@ -47,38 +35,47 @@ export async function POST(request: NextRequest) { ); } - const chatEngine = await createChatEngine(); - - // Convert message content from Vercel/AI format to LlamaIndex/OpenAI format - const userMessageContent = convertMessageContent( - userMessage.content, - data?.imageUrl, - ); + // retrieve document ids from the annotations of all messages (if any) + const ids = retrieveDocumentIds(messages); + // create chat engine with index using the document ids + const chatEngine = await createChatEngine(ids, data); - // Init Vercel AI StreamData - const vercelStreamData = new StreamData(); + // retrieve user message content from Vercel/AI format + const userMessageContent = retrieveMessageContent(messages); // Setup callbacks const callbackManager = createCallbackManager(vercelStreamData); + const chatHistory: ChatMessage[] = messages.slice(0, -1) as ChatMessage[]; // Calling LlamaIndex's ChatEngine to get a streamed response const response = await Settings.withCallbackManager(callbackManager, () => { return chatEngine.chat({ message: userMessageContent, - chatHistory: messages as ChatMessage[], + chatHistory, stream: true, }); }); - // Transform LlamaIndex stream to Vercel/AI format - const stream = LlamaIndexStream(response, vercelStreamData, { - parserOptions: { - image_url: data?.imageUrl, - }, - }); + const onCompletion = (content: string) => { + 
chatHistory.push({ role: "assistant", content: content }); + generateNextQuestions(chatHistory) + .then((questions: string[]) => { + if (questions.length > 0) { + vercelStreamData.appendMessageAnnotation({ + type: "suggested_questions", + data: questions, + }); + } + }) + .finally(() => { + vercelStreamData.close(); + }); + }; - // Return a StreamingTextResponse, which can be consumed by the Vercel/AI client - return new StreamingTextResponse(stream, {}, vercelStreamData); + return LlamaIndexAdapter.toDataStreamResponse(response, { + data: vercelStreamData, + callbacks: { onCompletion }, + }); } catch (error) { console.error("[LlamaIndex]", error); return NextResponse.json( diff --git a/app/api/chat/stream-helper.ts b/app/api/chat/stream-helper.ts deleted file mode 100644 index 9f1a886..0000000 --- a/app/api/chat/stream-helper.ts +++ /dev/null @@ -1,97 +0,0 @@ -import { StreamData } from "ai"; -import { - CallbackManager, - Metadata, - NodeWithScore, - ToolCall, - ToolOutput, -} from "llamaindex"; - -export function appendImageData(data: StreamData, imageUrl?: string) { - if (!imageUrl) return; - data.appendMessageAnnotation({ - type: "image", - data: { - url: imageUrl, - }, - }); -} - -export function appendSourceData( - data: StreamData, - sourceNodes?: NodeWithScore[], -) { - if (!sourceNodes?.length) return; - data.appendMessageAnnotation({ - type: "sources", - data: { - nodes: sourceNodes.map((node) => ({ - ...node.node.toMutableJSON(), - id: node.node.id_, - score: node.score ?? 
null, - })), - }, - }); -} - -export function appendEventData(data: StreamData, title?: string) { - if (!title) return; - data.appendMessageAnnotation({ - type: "events", - data: { - title, - }, - }); -} - -export function appendToolData( - data: StreamData, - toolCall: ToolCall, - toolOutput: ToolOutput, -) { - data.appendMessageAnnotation({ - type: "tools", - data: { - toolCall: { - id: toolCall.id, - name: toolCall.name, - input: toolCall.input, - }, - toolOutput: { - output: toolOutput.output, - isError: toolOutput.isError, - }, - }, - }); -} - -export function createCallbackManager(stream: StreamData) { - const callbackManager = new CallbackManager(); - - callbackManager.on("retrieve", (data) => { - const { nodes, query } = data.detail; - appendEventData(stream, `Retrieving context for query: '${query}'`); - appendEventData( - stream, - `Retrieved ${nodes.length} sources to use as context for the query`, - ); - }); - - callbackManager.on("llm-tool-call", (event) => { - const { name, input } = event.detail.payload.toolCall; - const inputString = Object.entries(input) - .map(([key, value]) => `${key}: ${value}`) - .join(", "); - appendEventData( - stream, - `Using tool: '${name}' with inputs: '${inputString}'`, - ); - }); - - callbackManager.on("llm-tool-result", (event) => { - const { toolCall, toolResult } = event.detail.payload; - appendToolData(stream, toolCall, toolResult); - }); - - return callbackManager; -} diff --git a/app/api/chat/upload/route.ts b/app/api/chat/upload/route.ts new file mode 100644 index 0000000..05939e3 --- /dev/null +++ b/app/api/chat/upload/route.ts @@ -0,0 +1,38 @@ +import { NextRequest, NextResponse } from "next/server"; +import { getDataSource } from "../engine"; +import { initSettings } from "../engine/settings"; +import { uploadDocument } from "../llamaindex/documents/upload"; + +initSettings(); + +export const runtime = "nodejs"; +export const dynamic = "force-dynamic"; + +export async function POST(request: NextRequest) { + 
try { + const { + name, + base64, + params, + }: { + name: string; + base64: string; + params?: any; + } = await request.json(); + if (!base64 || !name) { + return NextResponse.json( + { error: "base64 and name is required in the request body" }, + { status: 400 }, + ); + } + const index = await getDataSource(params); + const documentFile = await uploadDocument(index, name, base64); + return NextResponse.json(documentFile); + } catch (error) { + console.error("[Upload API]", error); + return NextResponse.json( + { error: (error as Error).message }, + { status: 500 }, + ); + } +} diff --git a/app/api/files/[...slug]/route.ts b/app/api/files/[...slug]/route.ts index 03c285e..7ccaeda 100644 --- a/app/api/files/[...slug]/route.ts +++ b/app/api/files/[...slug]/route.ts @@ -1,6 +1,7 @@ import { readFile } from "fs/promises"; import { NextRequest, NextResponse } from "next/server"; import path from "path"; +import { DATA_DIR } from "../../chat/engine/loader"; /** * This API is to get file data from allowed folders @@ -8,9 +9,9 @@ import path from "path"; */ export async function GET( _request: NextRequest, - { params }: { params: { slug: string[] } }, + { params }: { params: Promise<{ slug: string[] }> }, ) { - const slug = params.slug; + const slug = (await params).slug; if (!slug) { return NextResponse.json({ detail: "Missing file slug" }, { status: 400 }); @@ -20,15 +21,19 @@ export async function GET( return NextResponse.json({ detail: "Invalid file path" }, { status: 400 }); } - const [folder, ...pathTofile] = params.slug; // data, file.pdf - const allowedFolders = ["data", "tool-output"]; + const [folder, ...pathTofile] = slug; // data, file.pdf + const allowedFolders = ["data", "output"]; if (!allowedFolders.includes(folder)) { return NextResponse.json({ detail: "No permission" }, { status: 400 }); } try { - const filePath = path.join(process.cwd(), folder, path.join(...pathTofile)); + const filePath = path.join( + process.cwd(), + folder === "data" ? 
DATA_DIR : folder, + path.join(...pathTofile), + ); const blob = await readFile(filePath); return new NextResponse(blob, { diff --git a/app/api/sandbox/route.ts b/app/api/sandbox/route.ts new file mode 100644 index 0000000..07336cc --- /dev/null +++ b/app/api/sandbox/route.ts @@ -0,0 +1,155 @@ +/* + * Copyright 2023 FoundryLabs, Inc. + * Portions of this file are copied from the e2b project (https://github.com/e2b-dev/ai-artifacts) + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +import { ExecutionError, Result, Sandbox } from "@e2b/code-interpreter"; +import fs from "node:fs/promises"; +import path from "node:path"; +import { saveDocument } from "../chat/llamaindex/documents/helper"; + +type CodeArtifact = { + commentary: string; + template: string; + title: string; + description: string; + additional_dependencies: string[]; + has_additional_dependencies: boolean; + install_dependencies_command: string; + port: number | null; + file_path: string; + code: string; + files?: string[]; +}; + +const sandboxTimeout = 10 * 60 * 1000; // 10 minute in ms + +export const maxDuration = 60; + +const OUTPUT_DIR = path.join("output", "tools"); + +export type ExecutionResult = { + template: string; + stdout: string[]; + stderr: string[]; + runtimeError?: ExecutionError; + outputUrls: Array<{ url: string; filename: string }>; + url: string; +}; + +// see https://github.com/e2b-dev/fragments/tree/main/sandbox-templates +const SUPPORTED_TEMPLATES = [ + "nextjs-developer", 
+ "vue-developer", + "streamlit-developer", + "gradio-developer", +]; + +export async function POST(req: Request) { + const { artifact }: { artifact: CodeArtifact } = await req.json(); + + let sbx: Sandbox; + const sandboxOpts = { + metadata: { template: artifact.template, userID: "default" }, + timeoutMs: sandboxTimeout, + }; + if (SUPPORTED_TEMPLATES.includes(artifact.template)) { + sbx = await Sandbox.create(artifact.template, sandboxOpts); + } else { + sbx = await Sandbox.create(sandboxOpts); + } + console.log("Created sandbox", sbx.sandboxId); + + // Install packages + if (artifact.has_additional_dependencies) { + await sbx.commands.run(artifact.install_dependencies_command); + console.log( + `Installed dependencies: ${artifact.additional_dependencies.join(", ")} in sandbox ${sbx.sandboxId}`, + ); + } + + // Copy files + if (artifact.files) { + artifact.files.forEach(async (sandboxFilePath) => { + const fileName = path.basename(sandboxFilePath); + const localFilePath = path.join("output", "uploaded", fileName); + const fileContent = await fs.readFile(localFilePath); + + const arrayBuffer = new Uint8Array(fileContent).buffer; + await sbx.files.write(sandboxFilePath, arrayBuffer); + console.log(`Copied file to ${sandboxFilePath} in ${sbx.sandboxId}`); + }); + } + + // Copy code to fs + if (artifact.code && Array.isArray(artifact.code)) { + artifact.code.forEach(async (file) => { + await sbx.files.write(file.file_path, file.file_content); + console.log(`Copied file to ${file.file_path} in ${sbx.sandboxId}`); + }); + } else { + await sbx.files.write(artifact.file_path, artifact.code); + console.log(`Copied file to ${artifact.file_path} in ${sbx.sandboxId}`); + } + + // Execute code or return a URL to the running sandbox + if (artifact.template === "code-interpreter-multilang") { + const result = await sbx.runCode(artifact.code || ""); + await sbx.kill(); + const outputUrls = await downloadCellResults(result.results); + return new Response( + JSON.stringify({ + 
template: artifact.template, + stdout: result.logs.stdout, + stderr: result.logs.stderr, + runtimeError: result.error, + outputUrls: outputUrls, + }), + ); + } else { + return new Response( + JSON.stringify({ + template: artifact.template, + url: `https://${sbx?.getHost(artifact.port || 80)}`, + }), + ); + } +} + +async function downloadCellResults( + cellResults?: Result[], +): Promise> { + if (!cellResults) return []; + + const results = await Promise.all( + cellResults.map(async (res) => { + const formats = res.formats(); // available formats in the result + const formatResults = await Promise.all( + formats + .filter((ext) => ["png", "svg", "jpeg", "pdf"].includes(ext)) + .map(async (ext) => { + const filename = `${crypto.randomUUID()}.${ext}`; + const base64 = res[ext as keyof Result]; + const buffer = Buffer.from(base64, "base64"); + const fileurl = await saveDocument( + path.join(OUTPUT_DIR, filename), + buffer, + ); + return { url: fileurl, filename }; + }), + ); + return formatResults; + }), + ); + return results.flat(); +} diff --git a/app/components/chat-section.tsx b/app/components/chat-section.tsx index 4f88322..edf7f33 100644 --- a/app/components/chat-section.tsx +++ b/app/components/chat-section.tsx @@ -1,44 +1,32 @@ "use client"; +import { ChatSection as ChatSectionUI } from "@llamaindex/chat-ui"; +import "@llamaindex/chat-ui/styles/markdown.css"; +import "@llamaindex/chat-ui/styles/pdf.css"; import { useChat } from "ai/react"; -import { ChatInput, ChatMessages } from "./ui/chat"; +import CustomChatInput from "./ui/chat/chat-input"; +import CustomChatMessages from "./ui/chat/chat-messages"; +import { useClientConfig } from "./ui/chat/hooks/use-config"; export default function ChatSection() { - const { - messages, - input, - isLoading, - handleSubmit, - handleInputChange, - reload, - stop, - } = useChat({ - api: process.env.NEXT_PUBLIC_CHAT_API, - headers: { - "Content-Type": "application/json", // using JSON because of vercel/ai 2.2.26 - }, + const 
{ backend } = useClientConfig(); + const handler = useChat({ + api: `${backend}/api/chat`, onError: (error: unknown) => { if (!(error instanceof Error)) throw error; - const message = JSON.parse(error.message); - alert(message.detail); + let errorMessage: string; + try { + errorMessage = JSON.parse(error.message).detail; + } catch (e) { + errorMessage = error.message; + } + alert(errorMessage); }, }); - return ( -
- - -
+ + + + ); } diff --git a/app/components/header.tsx b/app/components/header.tsx index 2b0e488..f02ce73 100644 --- a/app/components/header.tsx +++ b/app/components/header.tsx @@ -7,7 +7,7 @@ export default function Header() { Get started by editing  app/page.tsx

-
+
, + React.ComponentPropsWithoutRef +>(({ className, ...props }, ref) => ( + +)); +AccordionItem.displayName = "AccordionItem"; + +const AccordionTrigger = React.forwardRef< + React.ElementRef, + React.ComponentPropsWithoutRef +>(({ className, children, ...props }, ref) => ( + + svg]:rotate-180", + className, + )} + {...props} + > + {children} + + + +)); +AccordionTrigger.displayName = AccordionPrimitive.Trigger.displayName; + +const AccordionContent = React.forwardRef< + React.ElementRef, + React.ComponentPropsWithoutRef +>(({ className, children, ...props }, ref) => ( + +
{children}
+
+)); +AccordionContent.displayName = AccordionPrimitive.Content.displayName; + +export { Accordion, AccordionContent, AccordionItem, AccordionTrigger }; diff --git a/app/components/ui/card.tsx b/app/components/ui/card.tsx new file mode 100644 index 0000000..da14564 --- /dev/null +++ b/app/components/ui/card.tsx @@ -0,0 +1,82 @@ +import * as React from "react"; +import { cn } from "./lib/utils"; + +const Card = React.forwardRef< + HTMLDivElement, + React.HTMLAttributes +>(({ className, ...props }, ref) => ( +
+)); +Card.displayName = "Card"; + +const CardHeader = React.forwardRef< + HTMLDivElement, + React.HTMLAttributes +>(({ className, ...props }, ref) => ( +
+)); +CardHeader.displayName = "CardHeader"; + +const CardTitle = React.forwardRef< + HTMLDivElement, + React.HTMLAttributes +>(({ className, ...props }, ref) => ( +
+)); +CardTitle.displayName = "CardTitle"; + +const CardDescription = React.forwardRef< + HTMLDivElement, + React.HTMLAttributes +>(({ className, ...props }, ref) => ( +
+)); +CardDescription.displayName = "CardDescription"; + +const CardContent = React.forwardRef< + HTMLDivElement, + React.HTMLAttributes +>(({ className, ...props }, ref) => ( +
+)); +CardContent.displayName = "CardContent"; + +const CardFooter = React.forwardRef< + HTMLDivElement, + React.HTMLAttributes +>(({ className, ...props }, ref) => ( +
+)); +CardFooter.displayName = "CardFooter"; + +export { + Card, + CardContent, + CardDescription, + CardFooter, + CardHeader, + CardTitle, +}; diff --git a/app/components/ui/chat/chat-actions.tsx b/app/components/ui/chat/chat-actions.tsx deleted file mode 100644 index 151ef61..0000000 --- a/app/components/ui/chat/chat-actions.tsx +++ /dev/null @@ -1,28 +0,0 @@ -import { PauseCircle, RefreshCw } from "lucide-react"; - -import { Button } from "../button"; -import { ChatHandler } from "./chat.interface"; - -export default function ChatActions( - props: Pick & { - showReload?: boolean; - showStop?: boolean; - }, -) { - return ( -
- {props.showStop && ( - - )} - {props.showReload && ( - - )} -
- ); -} diff --git a/app/components/ui/chat/chat-avatar.tsx b/app/components/ui/chat/chat-avatar.tsx index ce04e30..cfa307c 100644 --- a/app/components/ui/chat/chat-avatar.tsx +++ b/app/components/ui/chat/chat-avatar.tsx @@ -1,8 +1,10 @@ +import { useChatMessage } from "@llamaindex/chat-ui"; import { User2 } from "lucide-react"; import Image from "next/image"; -export default function ChatAvatar({ role }: { role: string }) { - if (role === "user") { +export function ChatMessageAvatar() { + const { message } = useChatMessage(); + if (message.role === "user") { return (
diff --git a/app/components/ui/chat/chat-events.tsx b/app/components/ui/chat/chat-events.tsx deleted file mode 100644 index af30676..0000000 --- a/app/components/ui/chat/chat-events.tsx +++ /dev/null @@ -1,50 +0,0 @@ -import { ChevronDown, ChevronRight, Loader2 } from "lucide-react"; -import { useState } from "react"; -import { Button } from "../button"; -import { - Collapsible, - CollapsibleContent, - CollapsibleTrigger, -} from "../collapsible"; -import { EventData } from "./index"; - -export function ChatEvents({ - data, - isLoading, -}: { - data: EventData[]; - isLoading: boolean; -}) { - const [isOpen, setIsOpen] = useState(false); - - const buttonLabel = isOpen ? "Hide events" : "Show events"; - - const EventIcon = isOpen ? ( - - ) : ( - - ); - - return ( -
- - - - - -
- {data.map((eventItem, index) => ( -
- {eventItem.title} -
- ))} -
-
-
-
- ); -} diff --git a/app/components/ui/chat/chat-image.tsx b/app/components/ui/chat/chat-image.tsx deleted file mode 100644 index 0560921..0000000 --- a/app/components/ui/chat/chat-image.tsx +++ /dev/null @@ -1,17 +0,0 @@ -import Image from "next/image"; -import { type ImageData } from "./index"; - -export function ChatImage({ data }: { data: ImageData }) { - return ( -
- -
- ); -} diff --git a/app/components/ui/chat/chat-input.tsx b/app/components/ui/chat/chat-input.tsx index 435637e..7b11168 100644 --- a/app/components/ui/chat/chat-input.tsx +++ b/app/components/ui/chat/chat-input.tsx @@ -1,84 +1,81 @@ -import { useState } from "react"; -import { Button } from "../button"; -import FileUploader from "../file-uploader"; -import { Input } from "../input"; -import UploadImagePreview from "../upload-image-preview"; -import { ChatHandler } from "./chat.interface"; +"use client"; -export default function ChatInput( - props: Pick< - ChatHandler, - | "isLoading" - | "input" - | "onFileUpload" - | "onFileError" - | "handleSubmit" - | "handleInputChange" - > & { - multiModal?: boolean; - }, -) { - const [imageUrl, setImageUrl] = useState(null); +import { ChatInput, useChatUI, useFile } from "@llamaindex/chat-ui"; +import { DocumentInfo, ImagePreview } from "@llamaindex/chat-ui/widgets"; +import { LlamaCloudSelector } from "./custom/llama-cloud-selector"; +import { useClientConfig } from "./hooks/use-config"; - const onSubmit = (e: React.FormEvent) => { +export default function CustomChatInput() { + const { requestData, isLoading, input } = useChatUI(); + const { backend } = useClientConfig(); + const { + imageUrl, + setImageUrl, + uploadFile, + files, + removeDoc, + reset, + getAnnotations, + } = useFile({ uploadAPI: `${backend}/api/chat/upload` }); + + /** + * Handles file uploads. Overwrite to hook into the file upload behavior. 
+ * @param file The file to upload + */ + const handleUploadFile = async (file: File) => { + // There's already an image uploaded, only allow one image at a time if (imageUrl) { - props.handleSubmit(e, { - data: { imageUrl: imageUrl }, - }); - setImageUrl(null); + alert("You can only upload one image at a time."); return; } - props.handleSubmit(e); - }; - - const onRemovePreviewImage = () => setImageUrl(null); - - const handleUploadImageFile = async (file: File) => { - const base64 = await new Promise((resolve, reject) => { - const reader = new FileReader(); - reader.readAsDataURL(file); - reader.onload = () => resolve(reader.result as string); - reader.onerror = (error) => reject(error); - }); - setImageUrl(base64); - }; - const handleUploadFile = async (file: File) => { try { - if (props.multiModal && file.type.startsWith("image/")) { - return await handleUploadImageFile(file); - } - props.onFileUpload?.(file); + // Upload the file and send with it the current request data + await uploadFile(file, requestData); } catch (error: any) { - props.onFileError?.(error.message); + // Show error message if upload fails + alert(error.message); } }; + // Get references to the upload files in message annotations format, see https://github.com/run-llama/chat-ui/blob/main/packages/chat-ui/src/hook/use-file.tsx#L56 + const annotations = getAnnotations(); + return ( -
- {imageUrl && ( - - )} -
- - - +
+ {/* Image preview section */} + {imageUrl && ( + setImageUrl(null)} /> + )} + {/* Document previews section */} + {files.length > 0 && ( +
+ {files.map((file) => ( + removeDoc(file)} + /> + ))} +
+ )}
- + + + + + + + ); } diff --git a/app/components/ui/chat/chat-message-content.tsx b/app/components/ui/chat/chat-message-content.tsx new file mode 100644 index 0000000..fa32f6f --- /dev/null +++ b/app/components/ui/chat/chat-message-content.tsx @@ -0,0 +1,44 @@ +import { + ChatMessage, + ContentPosition, + getSourceAnnotationData, + useChatMessage, + useChatUI, +} from "@llamaindex/chat-ui"; +import { DeepResearchCard } from "./custom/deep-research-card"; +import { Markdown } from "./custom/markdown"; +import { ToolAnnotations } from "./tools/chat-tools"; + +export function ChatMessageContent() { + const { isLoading, append } = useChatUI(); + const { message } = useChatMessage(); + const customContent = [ + { + // override the default markdown component + position: ContentPosition.MARKDOWN, + component: ( + + ), + }, + // add the deep research card + { + position: ContentPosition.CHAT_EVENTS, + component: , + }, + { + // add the tool annotations after events + position: ContentPosition.AFTER_EVENTS, + component: , + }, + ]; + return ( + + ); +} diff --git a/app/components/ui/chat/chat-message.tsx b/app/components/ui/chat/chat-message.tsx deleted file mode 100644 index da1d92e..0000000 --- a/app/components/ui/chat/chat-message.tsx +++ /dev/null @@ -1,127 +0,0 @@ -import { Check, Copy } from "lucide-react"; - -import { Message } from "ai"; -import { Fragment } from "react"; -import { Button } from "../button"; -import ChatAvatar from "./chat-avatar"; -import { ChatEvents } from "./chat-events"; -import { ChatImage } from "./chat-image"; -import { ChatSources } from "./chat-sources"; -import ChatTools from "./chat-tools"; -import { - AnnotationData, - EventData, - ImageData, - MessageAnnotation, - MessageAnnotationType, - SourceData, - ToolData, -} from "./index"; -import Markdown from "./markdown"; -import { useCopyToClipboard } from "./use-copy-to-clipboard"; - -type ContentDisplayConfig = { - order: number; - component: JSX.Element | null; -}; - -function 
getAnnotationData( - annotations: MessageAnnotation[], - type: MessageAnnotationType, -): T[] { - return annotations.filter((a) => a.type === type).map((a) => a.data as T); -} - -function ChatMessageContent({ - message, - isLoading, -}: { - message: Message; - isLoading: boolean; -}) { - const annotations = message.annotations as MessageAnnotation[] | undefined; - if (!annotations?.length) return ; - - const imageData = getAnnotationData( - annotations, - MessageAnnotationType.IMAGE, - ); - const eventData = getAnnotationData( - annotations, - MessageAnnotationType.EVENTS, - ); - const sourceData = getAnnotationData( - annotations, - MessageAnnotationType.SOURCES, - ); - const toolData = getAnnotationData( - annotations, - MessageAnnotationType.TOOLS, - ); - - const contents: ContentDisplayConfig[] = [ - { - order: -3, - component: imageData[0] ? : null, - }, - { - order: -2, - component: - eventData.length > 0 ? ( - - ) : null, - }, - { - order: -1, - component: toolData[0] ? : null, - }, - { - order: 0, - component: , - }, - { - order: 1, - component: sourceData[0] ? : null, - }, - ]; - - return ( -
- {contents - .sort((a, b) => a.order - b.order) - .map((content, index) => ( - {content.component} - ))} -
- ); -} - -export default function ChatMessage({ - chatMessage, - isLoading, -}: { - chatMessage: Message; - isLoading: boolean; -}) { - const { isCopied, copyToClipboard } = useCopyToClipboard({ timeout: 2000 }); - return ( -
- -
- - -
-
- ); -} diff --git a/app/components/ui/chat/chat-messages.tsx b/app/components/ui/chat/chat-messages.tsx index 55cfd9e..99e151a 100644 --- a/app/components/ui/chat/chat-messages.tsx +++ b/app/components/ui/chat/chat-messages.tsx @@ -1,69 +1,30 @@ -import { Loader2 } from "lucide-react"; -import { useEffect, useRef } from "react"; +"use client"; -import ChatActions from "./chat-actions"; -import ChatMessage from "./chat-message"; -import { ChatHandler } from "./chat.interface"; - -export default function ChatMessages( - props: Pick, -) { - const scrollableChatContainerRef = useRef(null); - const messageLength = props.messages.length; - const lastMessage = props.messages[messageLength - 1]; - - const scrollToBottom = () => { - if (scrollableChatContainerRef.current) { - scrollableChatContainerRef.current.scrollTop = - scrollableChatContainerRef.current.scrollHeight; - } - }; - - const isLastMessageFromAssistant = - messageLength > 0 && lastMessage?.role !== "user"; - const showReload = - props.reload && !props.isLoading && isLastMessageFromAssistant; - const showStop = props.stop && props.isLoading; - - // `isPending` indicate - // that stream response is not yet received from the server, - // so we show a loading indicator to give a better UX. - const isPending = props.isLoading && !isLastMessageFromAssistant; - - useEffect(() => { - scrollToBottom(); - }, [messageLength, lastMessage]); +import { ChatMessage, ChatMessages, useChatUI } from "@llamaindex/chat-ui"; +import { ChatMessageAvatar } from "./chat-avatar"; +import { ChatMessageContent } from "./chat-message-content"; +import { ChatStarter } from "./chat-starter"; +export default function CustomChatMessages() { + const { messages } = useChatUI(); return ( -
-
- {props.messages.map((m, i) => { - const isLoadingMessage = i === messageLength - 1 && props.isLoading; - return ( - - ); - })} - {isPending && ( -
- -
- )} -
-
- -
-
+ + + {messages.map((message, index) => ( + + + + + + ))} + + + + + ); } diff --git a/app/components/ui/chat/chat-sources.tsx b/app/components/ui/chat/chat-sources.tsx deleted file mode 100644 index 893541b..0000000 --- a/app/components/ui/chat/chat-sources.tsx +++ /dev/null @@ -1,150 +0,0 @@ -import { Check, Copy } from "lucide-react"; -import { useMemo } from "react"; -import { Button } from "../button"; -import { HoverCard, HoverCardContent, HoverCardTrigger } from "../hover-card"; -import { getStaticFileDataUrl } from "../lib/url"; -import { SourceData, SourceNode } from "./index"; -import { useCopyToClipboard } from "./use-copy-to-clipboard"; -import PdfDialog from "./widgets/PdfDialog"; - -const DATA_SOURCE_FOLDER = "data"; -const SCORE_THRESHOLD = 0.3; - -function SourceNumberButton({ index }: { index: number }) { - return ( -
- {index + 1} -
- ); -} - -enum NODE_TYPE { - URL, - FILE, - UNKNOWN, -} - -type NodeInfo = { - id: string; - type: NODE_TYPE; - path?: string; - url?: string; -}; - -function getNodeInfo(node: SourceNode): NodeInfo { - if (typeof node.metadata["URL"] === "string") { - const url = node.metadata["URL"]; - return { - id: node.id, - type: NODE_TYPE.URL, - path: url, - url, - }; - } - if (typeof node.metadata["file_path"] === "string") { - const fileName = node.metadata["file_name"] as string; - const filePath = `${DATA_SOURCE_FOLDER}/${fileName}`; - return { - id: node.id, - type: NODE_TYPE.FILE, - path: node.metadata["file_path"], - url: getStaticFileDataUrl(filePath), - }; - } - - return { - id: node.id, - type: NODE_TYPE.UNKNOWN, - }; -} - -export function ChatSources({ data }: { data: SourceData }) { - const sources: NodeInfo[] = useMemo(() => { - // aggregate nodes by url or file_path (get the highest one by score) - const nodesByPath: { [path: string]: NodeInfo } = {}; - - data.nodes - .filter((node) => (node.score ?? 1) > SCORE_THRESHOLD) - .sort((a, b) => (b.score ?? 1) - (a.score ?? 1)) - .forEach((node) => { - const nodeInfo = getNodeInfo(node); - const key = nodeInfo.path ?? nodeInfo.id; // use id as key for UNKNOWN type - if (!nodesByPath[key]) { - nodesByPath[key] = nodeInfo; - } - }); - - return Object.values(nodesByPath); - }, [data.nodes]); - - if (sources.length === 0) return null; - - return ( -
- Sources: -
- {sources.map((nodeInfo: NodeInfo, index: number) => { - if (nodeInfo.path?.endsWith(".pdf")) { - return ( - } - /> - ); - } - return ( -
- - - - - - - - -
- ); - })} -
-
- ); -} - -function NodeInfo({ nodeInfo }: { nodeInfo: NodeInfo }) { - const { isCopied, copyToClipboard } = useCopyToClipboard({ timeout: 1000 }); - - if (nodeInfo.type !== NODE_TYPE.UNKNOWN) { - // this is a node generated by the web loader or file loader, - // add a link to view its URL and a button to copy the URL to the clipboard - return ( -
- - {nodeInfo.path} - - -
- ); - } - - // node generated by unknown loader, implement renderer by analyzing logged out metadata - return ( -

- Sorry, unknown node type. Please add a new renderer in the NodeInfo - component. -

- ); -} diff --git a/app/components/ui/chat/chat-starter.tsx b/app/components/ui/chat/chat-starter.tsx new file mode 100644 index 0000000..0f45531 --- /dev/null +++ b/app/components/ui/chat/chat-starter.tsx @@ -0,0 +1,26 @@ +import { useChatUI } from "@llamaindex/chat-ui"; +import { StarterQuestions } from "@llamaindex/chat-ui/widgets"; +import { useEffect, useState } from "react"; +import { useClientConfig } from "./hooks/use-config"; + +export function ChatStarter() { + const { append } = useChatUI(); + const { backend } = useClientConfig(); + const [starterQuestions, setStarterQuestions] = useState(); + + useEffect(() => { + if (!starterQuestions) { + fetch(`${backend}/api/chat/config`) + .then((response) => response.json()) + .then((data) => { + if (data?.starterQuestions) { + setStarterQuestions(data.starterQuestions); + } + }) + .catch((error) => console.error("Error fetching config", error)); + } + }, [starterQuestions, backend]); + + if (!starterQuestions?.length) return null; + return ; +} diff --git a/app/components/ui/chat/chat-tools.tsx b/app/components/ui/chat/chat-tools.tsx deleted file mode 100644 index 268b436..0000000 --- a/app/components/ui/chat/chat-tools.tsx +++ /dev/null @@ -1,26 +0,0 @@ -import { ToolData } from "./index"; -import { WeatherCard, WeatherData } from "./widgets/WeatherCard"; - -// TODO: If needed, add displaying more tool outputs here -export default function ChatTools({ data }: { data: ToolData }) { - if (!data) return null; - const { toolCall, toolOutput } = data; - - if (toolOutput.isError) { - return ( -
- There was an error when calling the tool {toolCall.name} with input:{" "} -
- {JSON.stringify(toolCall.input)} -
- ); - } - - switch (toolCall.name) { - case "get_weather_information": - const weatherData = toolOutput.output as unknown as WeatherData; - return ; - default: - return null; - } -} diff --git a/app/components/ui/chat/chat.interface.ts b/app/components/ui/chat/chat.interface.ts deleted file mode 100644 index 5b9f225..0000000 --- a/app/components/ui/chat/chat.interface.ts +++ /dev/null @@ -1,18 +0,0 @@ -import { Message } from "ai"; - -export interface ChatHandler { - messages: Message[]; - input: string; - isLoading: boolean; - handleSubmit: ( - e: React.FormEvent, - ops?: { - data?: any; - }, - ) => void; - handleInputChange: (e: React.ChangeEvent) => void; - reload?: () => void; - stop?: () => void; - onFileUpload?: (file: File) => Promise; - onFileError?: (errMsg: string) => void; -} diff --git a/app/components/ui/chat/codeblock.tsx b/app/components/ui/chat/codeblock.tsx deleted file mode 100644 index 014a0fc..0000000 --- a/app/components/ui/chat/codeblock.tsx +++ /dev/null @@ -1,139 +0,0 @@ -"use client"; - -import { Check, Copy, Download } from "lucide-react"; -import { FC, memo } from "react"; -import { Prism, SyntaxHighlighterProps } from "react-syntax-highlighter"; -import { coldarkDark } from "react-syntax-highlighter/dist/cjs/styles/prism"; - -import { Button } from "../button"; -import { useCopyToClipboard } from "./use-copy-to-clipboard"; - -// TODO: Remove this when @type/react-syntax-highlighter is updated -const SyntaxHighlighter = Prism as unknown as FC; - -interface Props { - language: string; - value: string; -} - -interface languageMap { - [key: string]: string | undefined; -} - -export const programmingLanguages: languageMap = { - javascript: ".js", - python: ".py", - java: ".java", - c: ".c", - cpp: ".cpp", - "c++": ".cpp", - "c#": ".cs", - ruby: ".rb", - php: ".php", - swift: ".swift", - "objective-c": ".m", - kotlin: ".kt", - typescript: ".ts", - go: ".go", - perl: ".pl", - rust: ".rs", - scala: ".scala", - haskell: ".hs", - lua: ".lua", - 
shell: ".sh", - sql: ".sql", - html: ".html", - css: ".css", - // add more file extensions here, make sure the key is same as language prop in CodeBlock.tsx component -}; - -export const generateRandomString = (length: number, lowercase = false) => { - const chars = "ABCDEFGHJKLMNPQRSTUVWXY3456789"; // excluding similar looking characters like Z, 2, I, 1, O, 0 - let result = ""; - for (let i = 0; i < length; i++) { - result += chars.charAt(Math.floor(Math.random() * chars.length)); - } - return lowercase ? result.toLowerCase() : result; -}; - -const CodeBlock: FC = memo(({ language, value }) => { - const { isCopied, copyToClipboard } = useCopyToClipboard({ timeout: 2000 }); - - const downloadAsFile = () => { - if (typeof window === "undefined") { - return; - } - const fileExtension = programmingLanguages[language] || ".file"; - const suggestedFileName = `file-${generateRandomString( - 3, - true, - )}${fileExtension}`; - const fileName = window.prompt("Enter file name" || "", suggestedFileName); - - if (!fileName) { - // User pressed cancel on prompt. - return; - } - - const blob = new Blob([value], { type: "text/plain" }); - const url = URL.createObjectURL(blob); - const link = document.createElement("a"); - link.download = fileName; - link.href = url; - link.style.display = "none"; - document.body.appendChild(link); - link.click(); - document.body.removeChild(link); - URL.revokeObjectURL(url); - }; - - const onCopy = () => { - if (isCopied) return; - copyToClipboard(value); - }; - - return ( -
-
- {language} -
- - -
-
- - {value} - -
- ); -}); -CodeBlock.displayName = "CodeBlock"; - -export { CodeBlock }; diff --git a/app/components/ui/chat/custom/deep-research-card.tsx b/app/components/ui/chat/custom/deep-research-card.tsx new file mode 100644 index 0000000..03d0bd8 --- /dev/null +++ b/app/components/ui/chat/custom/deep-research-card.tsx @@ -0,0 +1,216 @@ +"use client"; + +import { Message } from "@llamaindex/chat-ui"; +import { + AlertCircle, + CheckCircle2, + CircleDashed, + Clock, + NotebookPen, + Search, +} from "lucide-react"; +import { useMemo } from "react"; +import { + Accordion, + AccordionContent, + AccordionItem, + AccordionTrigger, +} from "../../accordion"; +import { Card, CardContent, CardHeader, CardTitle } from "../../card"; +import { cn } from "../../lib/utils"; +import { Markdown } from "./markdown"; + +// Streaming event types +type EventState = "pending" | "inprogress" | "done" | "error"; + +type DeepResearchEvent = { + type: "deep_research_event"; + data: { + event: "retrieve" | "analyze" | "answer"; + state: EventState; + id?: string; + question?: string; + answer?: string | null; + }; +}; + +// UI state types +type QuestionState = { + id: string; + question: string; + answer: string | null; + state: EventState; + isOpen: boolean; +}; + +type DeepResearchCardState = { + retrieve: { + state: EventState | null; + }; + analyze: { + state: EventState | null; + questions: QuestionState[]; + }; +}; + +interface DeepResearchCardProps { + message: Message; + className?: string; +} + +const stateIcon: Record = { + pending: , + inprogress: , + done: , + error: , +}; + +// Transform the state based on the event without mutations +const transformState = ( + state: DeepResearchCardState, + event: DeepResearchEvent, +): DeepResearchCardState => { + switch (event.data.event) { + case "answer": { + const { id, question, answer } = event.data; + if (!id || !question) return state; + + const updatedQuestions = state.analyze.questions.map((q) => { + if (q.id !== id) return q; + return { + 
...q, + state: event.data.state, + answer: answer ?? q.answer, + }; + }); + + const newQuestion = !state.analyze.questions.some((q) => q.id === id) + ? [ + { + id, + question, + answer: answer ?? null, + state: event.data.state, + isOpen: false, + }, + ] + : []; + + return { + ...state, + analyze: { + ...state.analyze, + questions: [...updatedQuestions, ...newQuestion], + }, + }; + } + + case "retrieve": + case "analyze": + return { + ...state, + [event.data.event]: { + ...state[event.data.event], + state: event.data.state, + }, + }; + + default: + return state; + } +}; + +// Convert deep research events to state +const deepResearchEventsToState = ( + events: DeepResearchEvent[] | undefined, +): DeepResearchCardState => { + if (!events?.length) { + return { + retrieve: { state: null }, + analyze: { state: null, questions: [] }, + }; + } + + const initialState: DeepResearchCardState = { + retrieve: { state: null }, + analyze: { state: null, questions: [] }, + }; + + return events.reduce( + (acc: DeepResearchCardState, event: DeepResearchEvent) => + transformState(acc, event), + initialState, + ); +}; + +export function DeepResearchCard({ + message, + className, +}: DeepResearchCardProps) { + const deepResearchEvents = message.annotations as + | DeepResearchEvent[] + | undefined; + const hasDeepResearchEvents = deepResearchEvents?.some( + (event) => event.type === "deep_research_event", + ); + + const state = useMemo( + () => deepResearchEventsToState(deepResearchEvents), + [deepResearchEvents], + ); + + if (!hasDeepResearchEvents) { + return null; + } + + return ( + + + {state.retrieve.state !== null && ( + + + {state.retrieve.state === "inprogress" + ? "Searching..." + : "Search completed"} + + )} + {state.analyze.state !== null && ( + + + {state.analyze.state === "inprogress" ? "Analyzing..." : "Analysis"} + + )} + + + + {state.analyze.questions.length > 0 && ( + + {state.analyze.questions.map((question: QuestionState) => ( + + +
+
+ {stateIcon[question.state]} +
+ + {question.question} + +
+
+ {question.answer && ( + + + + )} +
+ ))} +
+ )} +
+
+ ); +} diff --git a/app/components/ui/chat/custom/llama-cloud-selector.tsx b/app/components/ui/chat/custom/llama-cloud-selector.tsx new file mode 100644 index 0000000..f40a33a --- /dev/null +++ b/app/components/ui/chat/custom/llama-cloud-selector.tsx @@ -0,0 +1,188 @@ +import { useChatUI } from "@llamaindex/chat-ui"; +import { Loader2 } from "lucide-react"; +import { useCallback, useEffect, useState } from "react"; +import { + Select, + SelectContent, + SelectGroup, + SelectItem, + SelectLabel, + SelectTrigger, + SelectValue, +} from "../../select"; +import { useClientConfig } from "../hooks/use-config"; + +type LLamaCloudPipeline = { + id: string; + name: string; +}; + +type LLamaCloudProject = { + id: string; + organization_id: string; + name: string; + is_default: boolean; + pipelines: Array; +}; + +type PipelineConfig = { + project: string; // project name + pipeline: string; // pipeline name +}; + +type LlamaCloudConfig = { + projects?: LLamaCloudProject[]; + pipeline?: PipelineConfig; +}; + +export interface LlamaCloudSelectorProps { + onSelect?: (pipeline: PipelineConfig | undefined) => void; + defaultPipeline?: PipelineConfig; + shouldCheckValid?: boolean; +} + +export function LlamaCloudSelector({ + onSelect, + defaultPipeline, + shouldCheckValid = false, +}: LlamaCloudSelectorProps) { + const { backend } = useClientConfig(); + const { setRequestData } = useChatUI(); + const [config, setConfig] = useState(); + + const updateRequestParams = useCallback( + (pipeline?: PipelineConfig) => { + if (setRequestData) { + setRequestData({ + llamaCloudPipeline: pipeline, + }); + } else { + onSelect?.(pipeline); + } + }, + [onSelect, setRequestData], + ); + + useEffect(() => { + if (process.env.NEXT_PUBLIC_USE_LLAMACLOUD === "true" && !config) { + fetch(`${backend}/api/chat/config/llamacloud`) + .then((response) => { + if (!response.ok) { + return response.json().then((errorData) => { + window.alert( + `Error: ${JSON.stringify(errorData) || "Unknown error 
occurred"}`, + ); + }); + } + return response.json(); + }) + .then((data) => { + const pipeline = defaultPipeline ?? data.pipeline; // defaultPipeline will override pipeline in .env + setConfig({ ...data, pipeline }); + updateRequestParams(pipeline); + }) + .catch((error) => console.error("Error fetching config", error)); + } + }, [backend, config, defaultPipeline, updateRequestParams]); + + const setPipeline = (pipelineConfig?: PipelineConfig) => { + setConfig((prevConfig: any) => ({ + ...prevConfig, + pipeline: pipelineConfig, + })); + updateRequestParams(pipelineConfig); + }; + + const handlePipelineSelect = async (value: string) => { + setPipeline(JSON.parse(value) as PipelineConfig); + }; + + if (process.env.NEXT_PUBLIC_USE_LLAMACLOUD !== "true") { + return null; + } + + if (!config) { + return ( +
+ +
+ ); + } + + if (shouldCheckValid && !isValid(config.projects, config.pipeline)) { + return ( +

+ Invalid LlamaCloud configuration. Check console logs. +

+ ); + } + const { projects, pipeline } = config; + + return ( + + ); +} + +function isValid( + projects: LLamaCloudProject[] | undefined, + pipeline: PipelineConfig | undefined, + logErrors: boolean = true, +): boolean { + if (!projects?.length) return false; + if (!pipeline) return false; + const matchedProject = projects.find( + (project: LLamaCloudProject) => project.name === pipeline.project, + ); + if (!matchedProject) { + if (logErrors) { + console.error( + `LlamaCloud project ${pipeline.project} not found. Check LLAMA_CLOUD_PROJECT_NAME variable`, + ); + } + return false; + } + const pipelineExists = matchedProject.pipelines.some( + (p) => p.name === pipeline.pipeline, + ); + if (!pipelineExists) { + if (logErrors) { + console.error( + `LlamaCloud pipeline ${pipeline.pipeline} not found. Check LLAMA_CLOUD_INDEX_NAME variable`, + ); + } + return false; + } + return true; +} diff --git a/app/components/ui/chat/custom/markdown.tsx b/app/components/ui/chat/custom/markdown.tsx new file mode 100644 index 0000000..88925e8 --- /dev/null +++ b/app/components/ui/chat/custom/markdown.tsx @@ -0,0 +1,27 @@ +import { SourceData } from "@llamaindex/chat-ui"; +import { Markdown as MarkdownUI } from "@llamaindex/chat-ui/widgets"; +import { useClientConfig } from "../hooks/use-config"; + +const preprocessMedia = (content: string) => { + // Remove `sandbox:` from the beginning of the URL before rendering markdown + // OpenAI models sometimes prepend `sandbox:` to relative URLs - this fixes it + return content.replace(/(sandbox|attachment|snt):/g, ""); +}; + +export function Markdown({ + content, + sources, +}: { + content: string; + sources?: SourceData; +}) { + const { backend } = useClientConfig(); + const processedContent = preprocessMedia(content); + return ( + + ); +} diff --git a/app/components/ui/chat/hooks/use-config.ts b/app/components/ui/chat/hooks/use-config.ts new file mode 100644 index 0000000..b344c25 --- /dev/null +++ 
b/app/components/ui/chat/hooks/use-config.ts @@ -0,0 +1,24 @@ +"use client"; + +export interface ChatConfig { + backend?: string; +} + +function getBackendOrigin(): string { + const chatAPI = process.env.NEXT_PUBLIC_CHAT_API; + if (chatAPI) { + return new URL(chatAPI).origin; + } else { + if (typeof window !== "undefined") { + // Use BASE_URL from window.ENV + return (window as any).ENV?.BASE_URL || ""; + } + return ""; + } +} + +export function useClientConfig(): ChatConfig { + return { + backend: getBackendOrigin(), + }; +} diff --git a/app/components/ui/chat/use-copy-to-clipboard.tsx b/app/components/ui/chat/hooks/use-copy-to-clipboard.tsx similarity index 100% rename from app/components/ui/chat/use-copy-to-clipboard.tsx rename to app/components/ui/chat/hooks/use-copy-to-clipboard.tsx diff --git a/app/components/ui/chat/index.ts b/app/components/ui/chat/index.ts deleted file mode 100644 index 106f629..0000000 --- a/app/components/ui/chat/index.ts +++ /dev/null @@ -1,54 +0,0 @@ -import { JSONValue } from "ai"; -import ChatInput from "./chat-input"; -import ChatMessages from "./chat-messages"; - -export { type ChatHandler } from "./chat.interface"; -export { ChatInput, ChatMessages }; - -export enum MessageAnnotationType { - IMAGE = "image", - SOURCES = "sources", - EVENTS = "events", - TOOLS = "tools", -} - -export type ImageData = { - url: string; -}; - -export type SourceNode = { - id: string; - metadata: Record; - score?: number; - text: string; -}; - -export type SourceData = { - nodes: SourceNode[]; -}; - -export type EventData = { - title: string; - isCollapsed: boolean; -}; - -export type ToolData = { - toolCall: { - id: string; - name: string; - input: { - [key: string]: JSONValue; - }; - }; - toolOutput: { - output: JSONValue; - isError: boolean; - }; -}; - -export type AnnotationData = ImageData | SourceData | EventData | ToolData; - -export type MessageAnnotation = { - type: MessageAnnotationType; - data: AnnotationData; -}; diff --git 
a/app/components/ui/chat/markdown.tsx b/app/components/ui/chat/markdown.tsx deleted file mode 100644 index 801a7b3..0000000 --- a/app/components/ui/chat/markdown.tsx +++ /dev/null @@ -1,77 +0,0 @@ -import "katex/dist/katex.min.css"; -import { FC, memo } from "react"; -import ReactMarkdown, { Options } from "react-markdown"; -import rehypeKatex from "rehype-katex"; -import remarkGfm from "remark-gfm"; -import remarkMath from "remark-math"; - -import { CodeBlock } from "./codeblock"; - -const MemoizedReactMarkdown: FC = memo( - ReactMarkdown, - (prevProps, nextProps) => - prevProps.children === nextProps.children && - prevProps.className === nextProps.className, -); - -const preprocessLaTeX = (content: string) => { - // Replace block-level LaTeX delimiters \[ \] with $$ $$ - const blockProcessedContent = content.replace( - /\\\[(.*?)\\\]/g, - (_, equation) => `$$${equation}$$`, - ); - // Replace inline LaTeX delimiters \( \) with $ $ - const inlineProcessedContent = blockProcessedContent.replace( - /\\\((.*?)\\\)/g, - (_, equation) => `$${equation}$`, - ); - return inlineProcessedContent; -}; - -export default function Markdown({ content }: { content: string }) { - const processedContent = preprocessLaTeX(content); - return ( - {children}

; - }, - code({ node, inline, className, children, ...props }) { - if (children.length) { - if (children[0] == "▍") { - return ( - - ); - } - - children[0] = (children[0] as string).replace("`▍`", "▍"); - } - - const match = /language-(\w+)/.exec(className || ""); - - if (inline) { - return ( - - {children} - - ); - } - - return ( - - ); - }, - }} - > - {processedContent} -
- ); -} diff --git a/app/components/ui/chat/tools/artifact.tsx b/app/components/ui/chat/tools/artifact.tsx new file mode 100644 index 0000000..fe6e819 --- /dev/null +++ b/app/components/ui/chat/tools/artifact.tsx @@ -0,0 +1,388 @@ +"use client"; + +import { Check, ChevronDown, Code, Copy, Loader2 } from "lucide-react"; +import { useEffect, useRef, useState } from "react"; +import { Button, buttonVariants } from "../../button"; +import { + Collapsible, + CollapsibleContent, + CollapsibleTrigger, +} from "../../collapsible"; +import { cn } from "../../lib/utils"; +import { Tabs, TabsContent, TabsList, TabsTrigger } from "../../tabs"; +import { Markdown } from "../custom/markdown"; +import { useClientConfig } from "../hooks/use-config"; +import { useCopyToClipboard } from "../hooks/use-copy-to-clipboard"; + +// detail information to execute code +export type CodeArtifact = { + commentary: string; + template: string; + title: string; + description: string; + additional_dependencies: string[]; + has_additional_dependencies: boolean; + install_dependencies_command: string; + port: number | null; + file_path: string; + code: string; + files?: string[]; +}; + +type OutputUrl = { + url: string; + filename: string; +}; + +type ArtifactResult = { + template: string; + stdout: string[]; + stderr: string[]; + runtimeError?: { name: string; value: string; tracebackRaw: string[] }; + outputUrls: OutputUrl[]; + url: string; +}; + +export function Artifact({ + artifact, + version, +}: { + artifact: CodeArtifact | null; + version?: number; +}) { + const [result, setResult] = useState(null); + const [sandboxCreationError, setSandboxCreationError] = useState(); + const [sandboxCreating, setSandboxCreating] = useState(false); + const [openOutputPanel, setOpenOutputPanel] = useState(false); + const panelRef = useRef(null); + const { backend } = useClientConfig(); + + const handleOpenOutput = async () => { + setOpenOutputPanel(true); + openPanel(); + 
panelRef.current?.classList.remove("hidden"); + }; + + const fetchArtifactResult = async () => { + try { + setSandboxCreating(true); + + const response = await fetch(`${backend}/api/sandbox`, { + method: "POST", + headers: { + "Content-Type": "application/json", + }, + body: JSON.stringify({ artifact }), + }); + + if (!response.ok) { + throw new Error("Failure running code artifact"); + } + + const fetchedResult = await response.json(); + + setResult(fetchedResult); + } catch (error) { + console.error("Error fetching artifact result:", error); + setSandboxCreationError( + error instanceof Error + ? error.message + : "An unknown error occurred when executing code", + ); + } finally { + setSandboxCreating(false); + } + }; + + useEffect(() => { + // auto trigger code execution + !result && fetchArtifactResult(); + // eslint-disable-next-line react-hooks/exhaustive-deps + }, []); + + if (!artifact || version === undefined) return null; + + return ( +
+
+ +
+

+ {artifact.title} v{version} +

+ Click to open code +
+
+ + {openOutputPanel && ( +
+
+
+

{artifact?.title}

+ Version: v{version} +
+ +
+ + {sandboxCreating && ( +
+ +
+ )} + {sandboxCreationError && ( +
+

+ Error when creating Sandbox: +

+

{sandboxCreationError}

+
+ )} + {result && ( + + )} +
+ )} +
+ ); +} + +function ArtifactOutput({ + artifact, + result, + version, +}: { + artifact: CodeArtifact; + result: ArtifactResult; + version: number; +}) { + const fileExtension = artifact.file_path.split(".").pop() || ""; + const markdownCode = `\`\`\`${fileExtension}\n${artifact.code}\n\`\`\``; + const { url: sandboxUrl, outputUrls, runtimeError, stderr, stdout } = result; + + return ( + + + Code + Preview + + +
+ +
+
+ + {runtimeError && } + + {sandboxUrl && } + {outputUrls && } + +
+ ); +} + +function RunTimeError({ + runtimeError, +}: { + runtimeError: { name: string; value: string; tracebackRaw?: string[] }; +}) { + const { isCopied, copyToClipboard } = useCopyToClipboard({ timeout: 1000 }); + const contentToCopy = `Fix this error:\n${runtimeError.name}\n${runtimeError.value}\n${runtimeError.tracebackRaw?.join("\n")}`; + return ( + + + Runtime Error: + + + +
+

{runtimeError.name}

+

{runtimeError.value}

+ {runtimeError.tracebackRaw?.map((trace, index) => ( +
+              {trace}
+            
+ ))} +
+ +
+
+ ); +} + +function CodeSandboxPreview({ url }: { url: string }) { + const [loading, setLoading] = useState(true); + const iframeRef = useRef(null); + + useEffect(() => { + if (!loading && iframeRef.current) { + iframeRef.current.focus(); + } + }, [loading]); + + return ( + <> +