Skip to content
This repository has been archived by the owner on Oct 21, 2024. It is now read-only.

Commit

Permalink
Vector: revert multiple model support and add support for retrive exc…
Browse files Browse the repository at this point in the history
…lude
  • Loading branch information
fwang committed Jan 18, 2024
1 parent 0744e2e commit d11ee5f
Show file tree
Hide file tree
Showing 6 changed files with 87 additions and 90 deletions.
4 changes: 0 additions & 4 deletions examples/playground/functions/vector-example/index.ts
Original file line number Diff line number Diff line change
Expand Up @@ -49,8 +49,6 @@ export const seeder = async () => {
for (const tag of tags) {
console.log("ingesting tag", tag.id);
await client.ingest({
//model: "amazon.titan-embed-image-v1",
model: "text-embedding-ada-002",
text: tag.text,
metadata: { type: "tag", id: tag.id },
});
Expand All @@ -63,7 +61,6 @@ export const seeder = async () => {
const image = imageBuffer.toString("base64");

await client.ingest({
model: "text-embedding-ada-002",
text: movie.summary,
image,
metadata: { type: "movie", id: movie.id },
Expand All @@ -78,7 +75,6 @@ export const seeder = async () => {

export const app = async (event) => {
const ret = await client.retrieve({
model: "text-embedding-ada-002",
text: event.queryStringParameters?.text,
include: { type: "movie" },
exclude: { id: "movie1" },
Expand Down
2 changes: 2 additions & 0 deletions examples/playground/sst.config.ts
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,8 @@ export default $config({
},
async run() {
const vector = new sst.Vector("MyVectorDB", {
model: "text-embedding-ada-002",
//model: "amazon.titan-embed-image-v1",
openAiApiKey: new sst.Secret("OpenAiApiKey").value,
});

Expand Down
82 changes: 17 additions & 65 deletions internal/components/src/components/handlers/vector-handler/index.ts
Original file line number Diff line number Diff line change
Expand Up @@ -9,32 +9,13 @@ import {
import { OpenAI } from "openai";
import { useClient } from "../../helpers/aws/client";

const ModelInfo = {
"amazon.titan-embed-text-v1": {
provider: "bedrock" as const,
shortName: "brtt1",
},
"amazon.titan-embed-image-v1": {
provider: "bedrock" as const,
shortName: "brti1",
},
"text-embedding-ada-002": {
provider: "openai" as const,
shortName: "oata2",
},
};

type Model = keyof typeof ModelInfo;

export type IngestEvent = {
model?: Model;
text?: string;
image?: string;
metadata: Record<string, any>;
};

export type RetrieveEvent = {
model?: Model;
text?: string;
image?: string;
include: Record<string, any>;
Expand All @@ -52,25 +33,24 @@ const {
SECRET_ARN,
DATABASE_NAME,
TABLE_NAME,
MODEL,
MODEL_PROVIDER,
// model provider dependent (optional)
OPENAI_API_KEY,
} = process.env;

export async function ingest(event: IngestEvent) {
const model = normalizeModel(event.model);
const embedding = await generateEmbedding(model, event.text, event.image);
const embedding = await generateEmbedding(event.text, event.image);
const metadata = JSON.stringify(event.metadata);
await storeEmbedding(model, metadata, embedding);
await storeEmbedding(metadata, embedding);
}
export async function retrieve(event: RetrieveEvent) {
const model = normalizeModel(event.model);
const embedding = await generateEmbedding(model, event.text, event.image);
const embedding = await generateEmbedding(event.text, event.image);
const include = JSON.stringify(event.include);
// The return type of JSON.stringify() is always "string".
// This is wrong when "event.exclude" is undefined.
const exclude = JSON.stringify(event.exclude) as string | undefined;
const result = await queryEmbeddings(
model,
include,
exclude,
embedding,
Expand All @@ -86,45 +66,31 @@ export async function remove(event: RemoveEvent) {
await removeEmbedding(include);
}

function normalizeModel(model?: Model) {
model = model ?? "amazon.titan-embed-image-v1";
if (ModelInfo[model].provider === "openai" && !OPENAI_API_KEY) {
throw new Error(
`To use the model "${model}", an OpenAI API key is necessary. Please ensure that "openAiApiKey" has been configured in the Vector component.`
);
}
return model;
}

async function generateEmbedding(model: Model, text?: string, image?: string) {
if (ModelInfo[model].provider === "openai") {
return await generateEmbeddingOpenAI(model, text!);
async function generateEmbedding(text?: string, image?: string) {
if (MODEL_PROVIDER === "openai") {
return await generateEmbeddingOpenAI(text!);
}
return await generateEmbeddingBedrock(model, text, image);
return await generateEmbeddingBedrock(text, image);
}

async function generateEmbeddingOpenAI(model: Model, text: string) {
async function generateEmbeddingOpenAI(text: string) {
const openAi = new OpenAI({ apiKey: OPENAI_API_KEY });
const embeddingResponse = await openAi.embeddings.create({
model,
model: MODEL!,
input: text,
encoding_format: "float",
});
return embeddingResponse.data[0].embedding;
}

async function generateEmbeddingBedrock(
model: Model,
text?: string,
image?: string
) {
async function generateEmbeddingBedrock(text?: string, image?: string) {
const ret = await useClient(BedrockRuntimeClient).send(
new InvokeModelCommand({
body: JSON.stringify({
inputText: text,
inputImage: image,
}),
modelId: model,
modelId: MODEL,
contentType: "application/json",
accept: "*/*",
})
Expand All @@ -133,23 +99,15 @@ async function generateEmbeddingBedrock(
return payload.embedding;
}

async function storeEmbedding(
model: Model,
metadata: string,
embedding: number[]
) {
async function storeEmbedding(metadata: string, embedding: number[]) {
await useClient(RDSDataClient).send(
new ExecuteStatementCommand({
resourceArn: CLUSTER_ARN,
secretArn: SECRET_ARN,
database: DATABASE_NAME,
sql: `INSERT INTO ${TABLE_NAME} (model, embedding, metadata)
VALUES (:model, ARRAY[${embedding.join(",")}], :metadata)`,
sql: `INSERT INTO ${TABLE_NAME} (embedding, metadata)
VALUES (ARRAY[${embedding.join(",")}], :metadata)`,
parameters: [
{
name: "model",
value: { stringValue: ModelInfo[model].shortName },
},
{
name: "metadata",
value: { stringValue: metadata },
Expand All @@ -161,7 +119,6 @@ async function storeEmbedding(
}

async function queryEmbeddings(
model: Model,
include: string,
exclude: string | undefined,
embedding: number[],
Expand All @@ -175,17 +132,12 @@ async function queryEmbeddings(
secretArn: SECRET_ARN,
database: DATABASE_NAME,
sql: `SELECT id, metadata, ${score} AS score FROM ${TABLE_NAME}
WHERE model = :model
AND ${score} < ${1 - threshold}
WHERE ${score} < ${1 - threshold}
AND metadata @> :include
${exclude ? "AND NOT metadata @> :exclude" : ""}
ORDER BY ${score}
LIMIT ${count}`,
parameters: [
{
name: "model",
value: { stringValue: ModelInfo[model].shortName },
},
{
name: "include",
value: { stringValue: include },
Expand Down
19 changes: 17 additions & 2 deletions internal/components/src/components/providers/embeddings-table.ts
Original file line number Diff line number Diff line change
Expand Up @@ -10,13 +10,15 @@ export interface PostgresTableInputs {
secretArn: Input<string>;
databaseName: Input<string>;
tableName: Input<string>;
vectorSize: Input<number>;
}

interface Inputs {
clusterArn: string;
secretArn: string;
databaseName: string;
tableName: string;
vectorSize: number;
}

class Provider implements dynamic.ResourceProvider {
Expand All @@ -41,6 +43,9 @@ class Provider implements dynamic.ResourceProvider {
await this.createDatabase(news);
await this.enablePgvectorExtension(news);
await this.enablePgtrgmExtension(news);
if (olds.vectorSize !== news.vectorSize) {
await this.removeTable(news);
}
await this.createTable(news);
await this.createEmbeddingIndex(news);
await this.createMetadataIndex(news);
Expand Down Expand Up @@ -112,8 +117,7 @@ class Provider implements dynamic.ResourceProvider {
database: inputs.databaseName,
sql: `create table ${inputs.tableName} (
id bigserial primary key,
model char(5),
embedding vector(1536),
embedding vector(${inputs.vectorSize}),
metadata jsonb
);`,
})
Expand All @@ -125,6 +129,17 @@ class Provider implements dynamic.ResourceProvider {
}
}

/**
 * Drops the embeddings table if it exists.
 *
 * Issues a single `drop table if exists` statement against the database
 * identified by `inputs`. In the visible caller this runs when the configured
 * vector size changes, just before the table is created again.
 *
 * @param inputs Cluster ARN, secret ARN, database name and table name to target.
 */
async removeTable(inputs: Inputs) {
  const { clusterArn, secretArn, databaseName, tableName } = inputs;
  const rds = useClient(RDSDataClient);
  const dropStatement = new ExecuteStatementCommand({
    resourceArn: clusterArn,
    secretArn,
    database: databaseName,
    // The table name is an identifier, so it is interpolated rather than bound.
    sql: `drop table if exists ${tableName};`,
  });
  await rds.send(dropStatement);
}

async createEmbeddingIndex(inputs: Inputs) {
try {
await useClient(RDSDataClient).send(
Expand Down
55 changes: 51 additions & 4 deletions internal/components/src/components/vector.ts
Original file line number Diff line number Diff line change
@@ -1,12 +1,30 @@
import path from "path";
import { ComponentResourceOptions, Input, interpolate } from "@pulumi/pulumi";
import {
ComponentResourceOptions,
Input,
all,
interpolate,
output,
} from "@pulumi/pulumi";
import { Component } from "./component.js";
import { Postgres } from "./postgres.js";
import { EmbeddingsTable } from "./providers/embeddings-table.js";
import { Function, FunctionPermissionArgs } from "./function.js";
import { AWSLinkable, Link, Linkable } from "./link.js";
import { VisibleError } from "./error.js";

/**
 * Supported embedding models.
 *
 * Each entry records which provider serves the model and the length (`size`)
 * of the embedding vector it produces.
 */
const ModelInfo = {
  "amazon.titan-embed-text-v1": {
    provider: "bedrock" as const,
    size: 1536,
  },
  "amazon.titan-embed-image-v1": {
    provider: "bedrock" as const,
    size: 1024,
  },
  "text-embedding-ada-002": {
    provider: "openai" as const,
    size: 1536,
  },
};

export interface VectorArgs {
/**
* The embedding model to use for generating vectors
* @default Titan Multimodal Embeddings G1
*/
model?: Input<keyof typeof ModelInfo>;
/**
* Specifies the OpenAI API key
* This key is required for data ingestion and retrieval using an OpenAI model.
Expand Down Expand Up @@ -36,6 +54,9 @@ export class Vector extends Component implements Linkable, AWSLinkable {
super("sst:sst:Vector", name, args, opts);

const parent = this;
const model = normalizeModel();
const vectorSize = normalizeVectorSize();
const openAiApiKey = normalizeOpenAiApiKey();
const databaseName = normalizeDatabaseName();
const tableName = normalizeTableName();

Expand All @@ -49,6 +70,29 @@ export class Vector extends Component implements Linkable, AWSLinkable {
this.retrieveHandler = retrieveHandler;
this.removeHandler = removeHandler;

/**
 * Resolves the embedding model from `args.model`.
 *
 * Throws if a model was supplied that has no entry in `ModelInfo`; falls back
 * to "amazon.titan-embed-image-v1" when no model was supplied.
 */
function normalizeModel() {
  return output(args?.model).apply((requested) => {
    // Only a supplied (truthy) value is validated against the known models.
    const isAcceptable = !requested || Boolean(ModelInfo[requested]);
    if (!isAcceptable) {
      throw new Error(`Invalid model: ${requested}`);
    }
    return requested ?? "amazon.titan-embed-image-v1";
  });
}

/**
 * Resolves `args.openAiApiKey` and checks it against the selected model.
 *
 * Throws a `VisibleError` when the resolved model is served by the "openai"
 * provider and no key was supplied; otherwise returns the key unchanged
 * (it may be undefined).
 */
function normalizeOpenAiApiKey() {
  return all([model, args?.openAiApiKey]).apply(([resolvedModel, apiKey]) => {
    const needsOpenAiKey = ModelInfo[resolvedModel].provider === "openai";
    if (needsOpenAiKey && !apiKey) {
      throw new VisibleError(
        `Please pass in the OPENAI_API_KEY via environment variable to use the ${resolvedModel} model. You can get your API keys here: https://platform.openai.com/api-keys`
      );
    }
    return apiKey;
  });
}

/**
 * Looks up the embedding vector size that `ModelInfo` records for the
 * resolved model.
 */
function normalizeVectorSize() {
  return model.apply((resolvedModel) => {
    const { size } = ModelInfo[resolvedModel];
    return size;
  });
}

/** Returns the current app stage, which is used as the database name. */
function normalizeDatabaseName() {
  const { stage } = $app;
  return stage;
}
Expand All @@ -71,6 +115,7 @@ export class Vector extends Component implements Linkable, AWSLinkable {
secretArn: postgres.nodes.cluster.masterUserSecrets[0].secretArn,
databaseName,
tableName,
vectorSize,
},
{ parent }
);
Expand Down Expand Up @@ -130,13 +175,15 @@ export class Vector extends Component implements Linkable, AWSLinkable {
}

function buildHandlerEnvironment() {
return {
return all([model, openAiApiKey]).apply(([model, openAiApiKey]) => ({
CLUSTER_ARN: postgres.nodes.cluster.arn,
SECRET_ARN: postgres.nodes.cluster.masterUserSecrets[0].secretArn,
DATABASE_NAME: databaseName,
TABLE_NAME: tableName,
...(args?.openAiApiKey ? { OPENAI_API_KEY: args.openAiApiKey } : {}),
};
MODEL: model,
MODEL_PROVIDER: ModelInfo[model].provider,
...(openAiApiKey ? { OPENAI_API_KEY: openAiApiKey } : {}),
}));
}

function buildHandlerPermissions() {
Expand Down
15 changes: 0 additions & 15 deletions sdk/js/src/vector-client.ts
Original file line number Diff line number Diff line change
Expand Up @@ -5,17 +5,7 @@ import {
} from "@aws-sdk/client-lambda";
import { Resource } from "./resource.js";

type Model =
| "amazon.titan-embed-text-v1"
| "amazon.titan-embed-image-v1"
| "text-embedding-ada-002";

export type IngestEvent = {
/**
* The embedding model to use for generating vectors
* @default Titan Multimodal Embeddings G1
*/
model?: Model;
/**
* The text used to generate the embedding vector.
* At least one of `text` or `image` must be provided.
Expand Down Expand Up @@ -56,11 +46,6 @@ export type IngestEvent = {
};

export type RetrieveEvent = {
/**
* The embedding model to use for generating vectors
* @default Titan Multimodal Embeddings G1
*/
model?: Model;
/**
* The text prompt used to retrieve embeddings.
* At least one of `text` or `image` must be provided.
Expand Down

0 comments on commit d11ee5f

Please sign in to comment.