Skip to content

Commit

Permalink
Support indexed db for embeddings storage
Browse files Browse the repository at this point in the history
Signed-off-by: Jay Wang <[email protected]>
  • Loading branch information
xiaohk committed Feb 4, 2024
1 parent bfc396d commit d711874
Show file tree
Hide file tree
Showing 4 changed files with 520 additions and 24 deletions.
7 changes: 4 additions & 3 deletions package.json
Original file line number Diff line number Diff line change
Expand Up @@ -34,9 +34,10 @@
},
"homepage": "https://github.com/xiaohk/mememo#readme",
"scripts": {
"test": "vitest",
"test:ui": "vitest --ui",
"test:run": "vitest run",
"test": "vitest test/mememo.test.ts",
"test:browser": "vitest -c vitest.config.browser.ts test/mememo.browser.test.ts",
"test:run": "vitest run test/mememo.test.ts && test:run:browser",
"test:run:browser": "vitest run -c vitest.config.browser.ts test/mememo.browser.test.ts",
"coverage": "vitest run --coverage && c8 report && pnpm run coverage:badge",
"coverage:badge": "pnpx make-coverage-badge --output-path ./imgs/coverage-badge.svg",
"build": "pnpm run clean && vite build",
Expand Down
76 changes: 55 additions & 21 deletions src/mememo.ts
Original file line number Diff line number Diff line change
Expand Up @@ -7,9 +7,10 @@ import { randomLcg, randomUniform } from 'd3-random';
import { MinHeap, MaxHeap, IGetCompareValue } from '@datastructures-js/heap';
import { openDB, IDBPDatabase } from 'idb';

export type IDBValidKey = number | string | Date | BufferSource | IDBValidKey[];
type BuiltInDistanceFunction = 'cosine' | 'cosine-normalized';

interface SearchNodeCandidate<T> {
interface SearchNodeCandidate<T extends IDBValidKey> {
key: T;
distance: number;
}
Expand Down Expand Up @@ -67,12 +68,17 @@ interface HNSWConfig {

/** Optional random seed. */
seed?: number;

/** Whether to use IndexedDB. If this is false, all embeddings are stored
* in memory. Defaults to false so that MeMemo can be used in Node.js.
*/
useIndexedDB?: boolean;
}

/**
* A node in the HNSW graph.
*/
class Node<T> {
class Node<T extends IDBValidKey> {
/** The unique key of an element. */
key: T;

Expand All @@ -92,15 +98,18 @@ class Node<T> {
/**
* An abstraction of a map storing nodes (in memory or in indexedDB)
*/
class Nodes<T> {
class Nodes<T extends IDBValidKey> {
nodesMap: Map<T, Node<T>>;
indexedDBStoreName = '';
dbPromise: Promise<IDBPDatabase<string>> | null;

constructor(indexedDBStoreName?: string) {
this.nodesMap = new Map<T, Node<T>>();
this.dbPromise = null;

if (indexedDBStoreName !== undefined) {
this.indexedDBStoreName = indexedDBStoreName;

// Create a new store
this.dbPromise = openDB<string>('mememo-index-store', 1, {
upgrade(db) {
Expand All @@ -121,47 +130,57 @@ class Nodes<T> {
if (this.dbPromise === null) {
return this.nodesMap.size;
} else {
return 1;
const db = await this.dbPromise;
const keys = await db.getAllKeys(this.indexedDBStoreName);
return keys.length;
}
}

async has(key: T) {
if (this.dbPromise === null) {
return this.nodesMap.has(key);
} else {
return false;
const db = await this.dbPromise;
const result = await db.getKey(this.indexedDBStoreName, key);
return result !== undefined;
}
}

async get(key: T) {
if (this.dbPromise === null) {
return this.nodesMap.get(key);
} else {
return undefined;
const db = await this.dbPromise;
const result = await (db.get(this.indexedDBStoreName, key) as Promise<
Node<T> | undefined
>);
return result;
}
}

async set(key: T, value: Node<T>) {
if (this.dbPromise === null) {
this.nodesMap.set(key, value);
} else {
// pass
const db = await this.dbPromise;
await db.put(this.indexedDBStoreName, value, key);
}
}

async clear() {
if (this.dbPromise === null) {
this.nodesMap = new Map<T, Node<T>>();
} else {
// pass
const db = await this.dbPromise;
await db.clear(this.indexedDBStoreName);
}
}
}

/**
* One graph layer in the HNSW index
*/
class GraphLayer<T> {
class GraphLayer<T extends IDBValidKey> {
/** The graph maps a key to its neighbor and distances */
graph: Map<T, Map<T, number>>;

Expand All @@ -178,7 +197,7 @@ class GraphLayer<T> {
/**
* HNSW (Hierarchical Navigable Small World) class.
*/
export class HNSW<T = string> {
export class HNSW<T extends IDBValidKey = string> {
distanceFunction: (a: number[], b: number[]) => number;

/** The max number of neighbors for each node. */
Expand Down Expand Up @@ -219,14 +238,16 @@ export class HNSW<T = string> {
* have in the zero layer. Default 2 * m.
* @param config.ml - Normalizer parameter. Default 1 / ln(m)
* @param config.seed - Optional random seed.
* @param config.useIndexedDB - Whether to use indexedDB
*/
constructor({
distanceFunction,
m,
efConstruction,
mMax0,
ml,
seed
seed,
useIndexedDB
}: HNSWConfig) {
// Initialize HNSW parameters
this.m = m || 16;
Expand All @@ -251,9 +272,14 @@ export class HNSW<T = string> {
}
}

if (useIndexedDB === undefined || useIndexedDB === false) {
this.nodes = new Nodes();
} else {
this.nodes = new Nodes('mememo-store');
}

// Data structures
this.graphLayers = [];
this.nodes = new Nodes();
}

/**
Expand All @@ -277,8 +303,8 @@ export class HNSW<T = string> {
}

throw Error(
`There is already a node with key ${key} in the index. ` +
'Use update() to update this node.'
`There is already a node with key ${JSON.stringify(key)} in the ` +
'index. Use update() to update this node.'
);
}

Expand Down Expand Up @@ -346,7 +372,9 @@ export class HNSW<T = string> {
for (const neighbor of selectedNeighbors) {
const neighborNode = this.graphLayers[l].graph.get(neighbor.key);
if (neighborNode === undefined) {
throw Error(`Can't find neighbor node ${neighbor.key}`);
throw Error(
`Can't find neighbor node ${JSON.stringify(neighbor.key)}`
);
}

// Add the neighbor's existing neighbors as candidates
Expand Down Expand Up @@ -396,7 +424,7 @@ export class HNSW<T = string> {
async update(key: T, value: number[]) {
if (!(await this.nodes.has(key))) {
throw Error(
`The node with key ${key} does not exist. ` +
`The node with key ${JSON.stringify(key)} does not exist. ` +
'Use insert() to add new node.'
);
}
Expand Down Expand Up @@ -431,7 +459,9 @@ export class HNSW<T = string> {
const firstDegreeNeighborNode =
curGraphLayer.graph.get(firstDegreeNeighbor);
if (firstDegreeNeighborNode === undefined) {
throw Error(`Can't find node with key ${firstDegreeNeighbor}`);
throw Error(
`Can't find node with key ${JSON.stringify(firstDegreeNeighbor)}`
);
}

for (const secondDegreeNeighbor of firstDegreeNeighborNode.keys()) {
Expand Down Expand Up @@ -511,7 +541,7 @@ export class HNSW<T = string> {
*/
async markDeleted(key: T) {
if (!(await this.nodes.has(key))) {
throw Error(`Node with key ${key} does not exist.`);
throw Error(`Node with key ${JSON.stringify(key)} does not exist.`);
}

// Special case: the user is trying to delete the entry point
Expand Down Expand Up @@ -729,7 +759,9 @@ export class HNSW<T = string> {

const curNode = graphLayer.graph.get(curCandidate.key);
if (curNode === undefined) {
throw Error(`Cannot find node with key ${curCandidate.key}`);
throw Error(
`Cannot find node with key ${JSON.stringify(curCandidate.key)}`
);
}

for (const key of curNode.keys()) {
Expand Down Expand Up @@ -802,7 +834,9 @@ export class HNSW<T = string> {
// Update candidates and found nodes using the current node's neighbors
const curNode = graphLayer.graph.get(nearestCandidate.key);
if (curNode === undefined) {
throw Error(`Cannot find node with key ${nearestCandidate.key}`);
throw Error(
`Cannot find node with key ${JSON.stringify(nearestCandidate.key)}`
);
}

for (const neighborKey of curNode.keys()) {
Expand Down Expand Up @@ -927,7 +961,7 @@ export class HNSW<T = string> {
async _getNodeInfo(key: T) {
const node = await this.nodes.get(key);
if (node === undefined) {
throw Error(`Can't find node with key ${key}`);
throw Error(`Can't find node with key ${JSON.stringify(key)}`);
}
return node;
}
Expand Down
Loading

0 comments on commit d711874

Please sign in to comment.