diff --git a/examples/playground/functions/vector-example/index.ts b/examples/playground/functions/vector-example/index.ts index 69f5e80a8..164292836 100644 --- a/examples/playground/functions/vector-example/index.ts +++ b/examples/playground/functions/vector-example/index.ts @@ -49,8 +49,6 @@ export const seeder = async () => { for (const tag of tags) { console.log("ingesting tag", tag.id); await client.ingest({ - //model: "amazon.titan-embed-image-v1", - model: "text-embedding-ada-002", text: tag.text, metadata: { type: "tag", id: tag.id }, }); @@ -63,7 +61,6 @@ export const seeder = async () => { const image = imageBuffer.toString("base64"); await client.ingest({ - model: "text-embedding-ada-002", text: movie.summary, image, metadata: { type: "movie", id: movie.id }, @@ -78,7 +75,6 @@ export const seeder = async () => { export const app = async (event) => { const ret = await client.retrieve({ - model: "text-embedding-ada-002", text: event.queryStringParameters?.text, include: { type: "movie" }, exclude: { id: "movie1" }, diff --git a/examples/playground/sst.config.ts b/examples/playground/sst.config.ts index 6377c7899..a8f5649ba 100644 --- a/examples/playground/sst.config.ts +++ b/examples/playground/sst.config.ts @@ -17,6 +17,8 @@ export default $config({ }, async run() { const vector = new sst.Vector("MyVectorDB", { + model: "text-embedding-ada-002", + //model: "amazon.titan-embed-image-v1", openAiApiKey: new sst.Secret("OpenAiApiKey").value, }); diff --git a/internal/components/src/components/handlers/vector-handler/index.ts b/internal/components/src/components/handlers/vector-handler/index.ts index 3b6f46a3c..bdbcfc371 100644 --- a/internal/components/src/components/handlers/vector-handler/index.ts +++ b/internal/components/src/components/handlers/vector-handler/index.ts @@ -9,32 +9,13 @@ import { import { OpenAI } from "openai"; import { useClient } from "../../helpers/aws/client"; -const ModelInfo = { - "amazon.titan-embed-text-v1": { - provider: "bedrock" as const, - shortName: "brtt1", - }, - "amazon.titan-embed-image-v1": { - provider: "bedrock" as const, - shortName: "brti1", - }, - "text-embedding-ada-002": { - provider: "openai" as const, - shortName: "oata2", - }, -}; - -type Model = keyof typeof ModelInfo; - export type IngestEvent = { - model?: Model; text?: string; image?: string; metadata: Record; }; export type RetrieveEvent = { - model?: Model; text?: string; image?: string; include: Record; @@ -52,25 +33,24 @@ const { SECRET_ARN, DATABASE_NAME, TABLE_NAME, + MODEL, + MODEL_PROVIDER, // modal provider dependent (optional) OPENAI_API_KEY, } = process.env; export async function ingest(event: IngestEvent) { - const model = normalizeModel(event.model); - const embedding = await generateEmbedding(model, event.text, event.image); + const embedding = await generateEmbedding(event.text, event.image); const metadata = JSON.stringify(event.metadata); - await storeEmbedding(model, metadata, embedding); + await storeEmbedding(metadata, embedding); } export async function retrieve(event: RetrieveEvent) { - const model = normalizeModel(event.model); - const embedding = await generateEmbedding(model, event.text, event.image); + const embedding = await generateEmbedding(event.text, event.image); const include = JSON.stringify(event.include); // The return type of JSON.stringify() is always "string". // This is wrong when "event.exclude" is undefined. const exclude = JSON.stringify(event.exclude) as string | undefined; const result = await queryEmbeddings( - model, include, exclude, embedding, @@ -86,45 +66,31 @@ export async function remove(event: RemoveEvent) { await removeEmbedding(include); } -function normalizeModel(model?: Model) { - model = model ?? "amazon.titan-embed-image-v1"; - if (ModelInfo[model].provider === "openai" && !OPENAI_API_KEY) { - throw new Error( - `To use the model "${model}", an OpenAI API key is necessary. Please ensure that "openAiApiKey" has been configured in the Vector component.` - ); - } - return model; -} - -async function generateEmbedding(model: Model, text?: string, image?: string) { - if (ModelInfo[model].provider === "openai") { - return await generateEmbeddingOpenAI(model, text!); +async function generateEmbedding(text?: string, image?: string) { + if (MODEL_PROVIDER === "openai") { + return await generateEmbeddingOpenAI(text!); } - return await generateEmbeddingBedrock(model, text, image); + return await generateEmbeddingBedrock(text, image); } -async function generateEmbeddingOpenAI(model: Model, text: string) { +async function generateEmbeddingOpenAI(text: string) { const openAi = new OpenAI({ apiKey: OPENAI_API_KEY }); const embeddingResponse = await openAi.embeddings.create({ - model, + model: MODEL!, input: text, encoding_format: "float", }); return embeddingResponse.data[0].embedding; } -async function generateEmbeddingBedrock( - model: Model, - text?: string, - image?: string -) { +async function generateEmbeddingBedrock(text?: string, image?: string) { const ret = await useClient(BedrockRuntimeClient).send( new InvokeModelCommand({ body: JSON.stringify({ inputText: text, inputImage: image, }), - modelId: model, + modelId: MODEL, contentType: "application/json", accept: "*/*", }) @@ -133,23 +99,15 @@ async function generateEmbeddingBedrock( return payload.embedding; } -async function storeEmbedding( - model: Model, - metadata: string, - embedding: number[] -) { +async function storeEmbedding(metadata: string, embedding: number[]) { await useClient(RDSDataClient).send( new ExecuteStatementCommand({ resourceArn: CLUSTER_ARN, secretArn: SECRET_ARN, database: DATABASE_NAME, - sql: `INSERT INTO ${TABLE_NAME} (model, embedding, metadata) - VALUES (:model, ARRAY[${embedding.join(",")}], :metadata)`, + sql: `INSERT INTO ${TABLE_NAME} (embedding, metadata) + VALUES (ARRAY[${embedding.join(",")}], :metadata)`, parameters: [ - { - name: "model", - value: { stringValue: ModelInfo[model].shortName }, - }, { name: "metadata", value: { stringValue: metadata }, @@ -161,7 +119,6 @@ async function storeEmbedding( } async function queryEmbeddings( - model: Model, include: string, exclude: string | undefined, embedding: number[], @@ -175,17 +132,12 @@ async function queryEmbeddings( secretArn: SECRET_ARN, database: DATABASE_NAME, sql: `SELECT id, metadata, ${score} AS score FROM ${TABLE_NAME} - WHERE model = :model - AND ${score} < ${1 - threshold} + WHERE ${score} < ${1 - threshold} AND metadata @> :include ${exclude ? "AND NOT metadata @> :exclude" : ""} ORDER BY ${score} LIMIT ${count}`, parameters: [ - { - name: "model", - value: { stringValue: ModelInfo[model].shortName }, - }, { name: "include", value: { stringValue: include }, diff --git a/internal/components/src/components/providers/embeddings-table.ts b/internal/components/src/components/providers/embeddings-table.ts index 0ffbbc359..a0dd10b1e 100644 --- a/internal/components/src/components/providers/embeddings-table.ts +++ b/internal/components/src/components/providers/embeddings-table.ts @@ -10,6 +10,7 @@ export interface PostgresTableInputs { secretArn: Input; databaseName: Input; tableName: Input; + vectorSize: Input; } interface Inputs { @@ -17,6 +18,7 @@ interface Inputs { secretArn: string; databaseName: string; tableName: string; + vectorSize: number; } class Provider implements dynamic.ResourceProvider { @@ -41,6 +43,9 @@ class Provider implements dynamic.ResourceProvider { await this.createDatabase(news); await this.enablePgvectorExtension(news); await this.enablePgtrgmExtension(news); + if (olds.vectorSize !== news.vectorSize) { + await this.removeTable(news); + } await this.createTable(news); await this.createEmbeddingIndex(news); await this.createMetadataIndex(news); @@ -112,8 +117,7 @@ class Provider implements dynamic.ResourceProvider { database: inputs.databaseName, sql: `create table ${inputs.tableName} ( id bigserial primary key, - model char(5), - embedding vector(1536), + embedding vector(${inputs.vectorSize}), metadata jsonb );`, }) @@ -125,6 +129,17 @@ class Provider implements dynamic.ResourceProvider { } } + async removeTable(inputs: Inputs) { + await useClient(RDSDataClient).send( + new ExecuteStatementCommand({ + resourceArn: inputs.clusterArn, + secretArn: inputs.secretArn, + database: inputs.databaseName, + sql: `drop table if exists ${inputs.tableName};`, + }) + ); + } + async createEmbeddingIndex(inputs: Inputs) { try { await useClient(RDSDataClient).send( diff --git a/internal/components/src/components/vector.ts b/internal/components/src/components/vector.ts index d62a95ca7..9f5341265 100644 --- a/internal/components/src/components/vector.ts +++ b/internal/components/src/components/vector.ts @@ -1,12 +1,30 @@ import path from "path"; -import { ComponentResourceOptions, Input, interpolate } from "@pulumi/pulumi"; +import { + ComponentResourceOptions, + Input, + all, + interpolate, + output, +} from "@pulumi/pulumi"; import { Component } from "./component.js"; import { Postgres } from "./postgres.js"; import { EmbeddingsTable } from "./providers/embeddings-table.js"; import { Function, FunctionPermissionArgs } from "./function.js"; import { AWSLinkable, Link, Linkable } from "./link.js"; +import { VisibleError } from "./error.js"; + +const ModelInfo = { + "amazon.titan-embed-text-v1": { provider: "bedrock" as const, size: 1536 }, + "amazon.titan-embed-image-v1": { provider: "bedrock" as const, size: 1024 }, + "text-embedding-ada-002": { provider: "openai" as const, size: 1536 }, +}; export interface VectorArgs { + /** + * The embedding model to use for generating vectors + * @default Titan Multimodal Embeddings G1 + */ + model?: Input; /** * Specifies the OpenAI API key *.This key is required for datar ingestoig and retrieal usvin an OpenAI modelg. @@ -36,6 +54,9 @@ export class Vector extends Component implements Linkable, AWSLinkable { super("sst:sst:Vector", name, args, opts); const parent = this; + const model = normalizeModel(); + const vectorSize = normalizeVectorSize(); + const openAiApiKey = normalizeOpenAiApiKey(); const databaseName = normalizeDatabaseName(); const tableName = normalizeTableName(); @@ -49,6 +70,29 @@ export class Vector extends Component implements Linkable, AWSLinkable { this.retrieveHandler = retrieveHandler; this.removeHandler = removeHandler; + function normalizeModel() { + return output(args?.model).apply((model) => { + if (model && !ModelInfo[model]) + throw new Error(`Invalid model: ${model}`); + return model ?? "amazon.titan-embed-image-v1"; + }); + } + + function normalizeOpenAiApiKey() { + return all([model, args?.openAiApiKey]).apply(([model, openAiApiKey]) => { + if (ModelInfo[model].provider === "openai" && !openAiApiKey) { + throw new VisibleError( + `Please pass in the OPENAI_API_KEY via environment variable to use the ${model} model. You can get your API keys here: https://platform.openai.com/api-keys` + ); + } + return openAiApiKey; + }); + } + + function normalizeVectorSize() { + return model.apply((model) => ModelInfo[model].size); + } + function normalizeDatabaseName() { return $app.stage; } @@ -71,6 +115,7 @@ export class Vector extends Component implements Linkable, AWSLinkable { secretArn: postgres.nodes.cluster.masterUserSecrets[0].secretArn, databaseName, tableName, + vectorSize, }, { parent } ); @@ -130,13 +175,15 @@ export class Vector extends Component implements Linkable, AWSLinkable { } function buildHandlerEnvironment() { - return { + return all([model, openAiApiKey]).apply(([model, openAiApiKey]) => ({ CLUSTER_ARN: postgres.nodes.cluster.arn, SECRET_ARN: postgres.nodes.cluster.masterUserSecrets[0].secretArn, DATABASE_NAME: databaseName, TABLE_NAME: tableName, - ...(args?.openAiApiKey ? { OPENAI_API_KEY: args.openAiApiKey } : {}), - }; + MODEL: model, + MODEL_PROVIDER: ModelInfo[model].provider, + ...(openAiApiKey ? { OPENAI_API_KEY: openAiApiKey } : {}), + })); } function buildHandlerPermissions() { diff --git a/sdk/js/src/vector-client.ts b/sdk/js/src/vector-client.ts index 8a3825146..9155f4e31 100644 --- a/sdk/js/src/vector-client.ts +++ b/sdk/js/src/vector-client.ts @@ -5,17 +5,7 @@ import { } from "@aws-sdk/client-lambda"; import { Resource } from "./resource.js"; -type Model = - | "amazon.titan-embed-text-v1" - | "amazon.titan-embed-image-v1" - | "text-embedding-ada-002"; - export type IngestEvent = { - /** - * The embedding model to use for generating vectors - * @default Titan Multimodal Embeddings G1 - */ - model?: Model; /** * The text used to generate the embedding vector. * At least one of `text` or `image` must be provided. @@ -56,11 +46,6 @@ export type IngestEvent = { }; export type RetrieveEvent = { - /** - * The embedding model to use for generating vectors - * @default Titan Multimodal Embeddings G1 - */ - model?: Model; /** * The text prompt used to retrieve embeddings. * At least one of `text` or `image` must be provided.