diff --git a/src/mememo.ts b/src/mememo.ts index 7c66a92..e64de98 100644 --- a/src/mememo.ts +++ b/src/mememo.ts @@ -190,10 +190,12 @@ export class HNSW { * this value in most cases. We add this parameter for testing purpose. */ insert(key: T, value: number[], maxLevel?: number | undefined) { - // If the key already exists, update the node + // If the key already exists, throw an error if (this.nodes.has(key)) { - // TODO: Update the node - return; + throw Error( + `There is already a node with key ${key} in the index. ` + + 'Use update() to update this node.' + ); } // Randomly determine the max level of this node @@ -306,9 +308,12 @@ export class HNSW { * @param key Key of the element. * @param value The new embedding of the element */ - _update(key: T, value: number[]) { + update(key: T, value: number[]) { if (!this.nodes.has(key)) { - throw Error(`The node with key ${key} does not exist.`); + throw Error( + `The node with key ${key} does not exist. ` + + 'Use insert() to add new node.' + ); } this.nodes.set(key, new Node(key, value)); @@ -455,6 +460,9 @@ export class HNSW { this.efConstruction + 1 ); + // Remove the current node itself as it would be selected (0 distance) + entryPoints = entryPoints.filter(d => d.key !== key); + // Prune the neighbors so we have at most levelM neighbors const selectedNeighbors = this._selectNeighborsHeuristic( entryPoints, diff --git a/test/data/update-50-3-layer-10.json b/test/data/update-50-3-layer-10.json new file mode 100644 index 0000000..8845a86 --- /dev/null +++ b/test/data/update-50-3-layer-10.json @@ -0,0 +1 @@ +[{"1": {"8": 0.1248, "372": 0.3222}, "2": {"4": 0.121, "373": 0.2152, "1342": 0.2451, "372": 0.2455, "145": 0.2455, "377": 0.2766}, "3": {"12": 0.1373, "15": 0.1391, "11": 0.1406}, "4": {"385": 0.2577, "156": 0.2594, "15": 0.3242}, "5": {"386": 0.2163, "1342": 0.2312, "146": 0.2543, "6": 0.3429}, "6": {"8": 0.1348, "15": 0.1661, "145": 0.3142}, "7": {"4": 0.1417, "388": 0.2401, "157": 0.2505, "2": 0.2786}, "8": {"1": 0.1248, "12": 0.1254, "6": 0.1348, "9": 0.1591}, "9": {"8": 0.1591, "11": 0.2278}, "10": {"4": 0.0807, "156": 0.2234, "389": 0.2342}, "11": {"3": 0.1406, "145": 0.2612}, "12": {"8": 0.1254, "3": 0.1373, "14": 0.1435, "150": 0.2693}, "14": {"12": 0.1435}, "15": {"3": 0.1391, "14": 0.1869, "377": 0.3038, "7": 0.3039}, "139": {"142": 0.0244, "146": 0.0317, "153": 0.0369, "145": 0.0753, "4": 0.2622}, "141": {"157": 0.0341, "153": 0.0419}, "142": {"139": 0.0244, "143": 0.0338, "156": 0.0396, "154": 0.0588, "9": 0.3349}, "143": {"155": 0.0326, "142": 0.0338, "153": 0.0347, "150": 0.0429, "148": 0.0945, "10": 0.2426}, "144": {"148": 0.0464, "153": 0.0527, "12": 0.2977}, "145": {"139": 0.0753, "10": 0.2371, "382": 0.2429, "5": 0.2579, "11": 0.2612}, "146": {"139": 0.0317, "149": 0.0567, "384": 0.2512, "5": 0.2543, "14": 0.3656}, "148": {"144": 0.0464, "388": 0.2491, "2": 0.2936}, "149": {"153": 0.0521, "10": 0.3548}, "150": {"1": 0.338, "6": 0.3314, "8": 0.3346, "143": 0.0429, "12": 0.2693, "5": 0.2796, "9": 0.3117}, "151": {"152": 0.0531, "157": 0.0606, "145": 0.0938, "389": 0.2634}, "152": {"153": 0.0381, "151": 0.0531}, "153": {"143": 0.0347, "152": 0.0381, "141": 0.0419, "154": 0.0479, "149": 0.0521, "144": 0.0527, "145": 0.0858}, "154": {"153": 0.0479, "372": 0.3222, "14": 0.37}, "155": {"143": 0.0326, "141": 0.0565, "373": 0.2566, "5": 0.2805}, "156": {"142": 0.0396, "144": 0.0717, "4": 0.2594, "382": 0.2606, "11": 0.2811}, "157": {"141": 0.0341, "148": 0.1041, "373": 0.263, "12": 0.3001}, "372": {"386": 0.1769, "2": 0.2455, "11": 0.292}, "373": {"381": 0.0666, "388": 0.0902, "2": 0.2152, "155": 0.2566}, "374": {"389": 0.0773, "376": 0.11}, "375": {"389": 0.0994, "384": 0.1795}, "376": {"387": 0.1081, "1342": 0.2215, "15": 0.3509}, "377": {"384": 0.2761, "386": 0.2909, "15": 0.3038}, "378": {"7": 0.0868, "4": 0.1214, "389": 0.264, "141": 0.2896}, "379": {"380": 0.0379, "388": 0.0885, "2": 0.2213, "149": 0.2835}, "380": {"379": 0.0379, "387": 0.0428, "381": 0.0633, "8": 0.3691}, "381": {"380": 0.0633, "373": 0.0666, "383": 0.0715, "389": 0.0732, "145": 0.263}, "382": {"387": 0.0714, "385": 0.0919, "145": 0.2429}, "383": {"387": 0.0596, "2": 0.2217, "1342": 0.2349}, "384": {"4": 0.1385, "385": 0.1692, "156": 0.2415, "5": 0.2728, "377": 0.2761, "3": 0.3246}, "385": {"382": 0.0919, "389": 0.0938, "1342": 0.1669, "384": 0.1692}, "386": {"372": 0.1769, "5": 0.2163, "150": 0.2783, "14": 0.3401}, "387": {"380": 0.0428, "383": 0.0596, "382": 0.0714, "373": 0.0724, "389": 0.0845, "376": 0.1081, "4": 0.2669}, "388": {"389": 0.0428, "379": 0.0885, "384": 0.194, "151": 0.2606}, "389": {"388": 0.0428, "381": 0.0732, "374": 0.0773, "385": 0.0938, "375": 0.0994, "10": 0.2342, "14": 0.3789}, "1342": {"385": 0.1669, "5": 0.2312, "148": 0.2637, "377": 0.2777}}, {"7": {"14": 0.3716, "149": 0.2962, "15": 0.3039, "380": 0.2869, "379": 0.2677, "381": 0.2815, "145": 0.2529, "142": 0.2869, "389": 0.2651}, "14": {"381": 0.4001, "379": 0.3948, "149": 0.3837, "142": 0.3802, "380": 0.382, "145": 0.3291, "389": 0.3789, "7": 0.3716, "15": 0.1869}, "15": {"389": 0.3745, "380": 0.3705, "381": 0.3736, "149": 0.3642, "379": 0.3591, "7": 0.3039, "145": 0.3162, "14": 0.1869, "142": 0.3598}, "142": {"14": 0.3802, "15": 0.3598, "381": 0.305, "380": 0.3084, "379": 0.2876, "389": 0.3013, "145": 0.0868, "149": 0.0602, "7": 0.2869}, "145": {"14": 0.3291, "15": 0.3162, "389": 0.2692, "380": 0.2697, "7": 0.2529, "381": 0.263, "149": 0.089, "142": 0.0868, "379": 0.2475}, "149": {"14": 0.3837, "15": 0.3642, "389": 0.3003, "380": 0.3051, "7": 0.2962, "381": 0.2933, "145": 0.089, "142": 0.0602, "379": 0.2835}, "379": {"14": 0.3948, "15": 0.3591, "149": 0.2835, "142": 0.2876, "381": 0.0634, "389": 0.1019, "145": 0.2475, "380": 0.0379, "7": 0.2677}, "380": {"14": 0.382, "15": 0.3705, "149": 0.3051, "142": 0.3084, "381": 0.0633, "389": 0.1035, "145": 0.2697, "379": 0.0379, "7": 0.2869}, "381": {"14": 0.4001, "15": 0.3736, "149": 0.2933, "142": 0.305, "380": 0.0633, "389": 0.0732, "145": 0.263, "379": 0.0634, "7": 0.2815}, "389": {"14": 0.3789, "15": 0.3745, "149": 0.3003, "142": 0.3013, "381": 0.0732, "379": 0.1019, "145": 0.2692, "380": 0.1035, "7": 0.2651}}, {"145": {"149": 0.089, "379": 0.2475}, "149": {"145": 0.089, "379": 0.2835}, "379": {"149": 0.2835, "145": 0.2475}}] \ No newline at end of file diff --git a/test/mememo.test.ts b/test/mememo.test.ts index 63f819c..a00a186 100644 --- a/test/mememo.test.ts +++ b/test/mememo.test.ts @@ -7,6 +7,7 @@ import graph10Layer2JSON from './data/insert-10-2-layer.json'; import graph30Layer3JSON from './data/insert-30-3-layer.json'; import graph100Layer6JSON from './data/insert-100-6-layer.json'; import graph100Layer3M3JSON from './data/insert-100-3-layer-m=3.json'; +import graph50Update10JSON from './data/update-50-3-layer-10.json'; interface EmbeddingData { embeddings: number[][]; @@ -21,6 +22,40 @@ const graph10Layer2 = graph10Layer2JSON as GraphLayer[]; const graph30Layer3 = graph30Layer3JSON as GraphLayer[]; const graph100Layer6 = graph100Layer6JSON as GraphLayer[]; const graph100Layer3M3 = graph100Layer3M3JSON as GraphLayer[]; +const graph50Update10 = graph50Update10JSON as GraphLayer[]; + +/** + * Check if the graphs in HNSW match the expected graph layers from json + * @param reportIDs Report IDs in the hnsw + * @param hnsw HNSW index + * @param expectedGraphs Expected graph layers loaded from json + */ +const _checkGraphLayers = ( + reportIDs: string[], + hnsw: HNSW, + expectedGraphs: GraphLayer[] +) => { + for (const reportID of reportIDs) { + for (const [l, graphLayer] of hnsw.graphLayers.entries()) { + const curNode = graphLayer.graph.get(reportID); + + if (curNode === undefined) { + expect(expectedGraphs[l][reportID]).toBeUndefined(); + } else { + expect(expectedGraphs[l][reportID]).not.to.toBeUndefined(); + // Check the distances + const expectedNeighbors = expectedGraphs[l][reportID]; + for (const [neighborKey, neighborDistance] of curNode.entries()) { + expect(expectedNeighbors[neighborKey]).to.not.toBeUndefined(); + expect(neighborDistance).toBeCloseTo( + expectedNeighbors[neighborKey]!, + 1e-6 + ); + } + } + } + } +}; describe('constructor', () => { it('constructor', () => { @@ -56,22 +91,7 @@ describe('insert()', () => { // There should be only one layer, and all nodes are fully connected expect(hnsw.graphLayers.length).toBe(1); - - for (const reportID of reportIDs) { - const curNode = hnsw.graphLayers[0].graph.get(reportID); - expect(curNode).to.not.toBeUndefined(); - expect(curNode!.size).toBe(9); - - // Check the distances - const expectedNeighbors = graph10Layer1[0][reportID]; - for (const [neighborKey, neighborDistance] of curNode!.entries()) { - expect(expectedNeighbors[neighborKey]).to.not.toBeUndefined(); - expect(neighborDistance).toBeCloseTo( - expectedNeighbors[neighborKey]!, - 1e-6 - ); - } - } + _checkGraphLayers(reportIDs, hnsw, graph10Layer1); }); it('insert() 10 items, 2 layer', () => { @@ -92,27 +112,7 @@ describe('insert()', () => { } expect(hnsw.graphLayers.length).toBe(2); - - for (const reportID of reportIDs) { - for (const [l, graphLayer] of hnsw.graphLayers.entries()) { - const curNode = graphLayer.graph.get(reportID); - - if (curNode === undefined) { - expect(graph10Layer2[l][reportID]).toBeUndefined(); - } else { - expect(graph10Layer2[l][reportID]).not.to.toBeUndefined(); - // Check the distances - const expectedNeighbors = graph10Layer2[l][reportID]; - for (const [neighborKey, neighborDistance] of curNode.entries()) { - expect(expectedNeighbors[neighborKey]).to.not.toBeUndefined(); - expect(neighborDistance).toBeCloseTo( - expectedNeighbors[neighborKey]!, - 1e-6 - ); - } - } - } - } + _checkGraphLayers(reportIDs, hnsw, graph10Layer2); }); it('insert() 30 items, 3 layer', () => { @@ -134,27 +134,7 @@ describe('insert()', () => { } expect(hnsw.graphLayers.length).toBe(3); - - for (const reportID of reportIDs) { - for (const [l, graphLayer] of hnsw.graphLayers.entries()) { - const curNode = graphLayer.graph.get(reportID); - - if (curNode === undefined) { - expect(graph30Layer3[l][reportID]).toBeUndefined(); - } else { - expect(graph30Layer3[l][reportID]).not.to.toBeUndefined(); - // Check the distances - const expectedNeighbors = graph30Layer3[l][reportID]; - for (const [neighborKey, neighborDistance] of curNode.entries()) { - expect(expectedNeighbors[neighborKey]).to.not.toBeUndefined(); - expect(neighborDistance).toBeCloseTo( - expectedNeighbors[neighborKey]!, - 1e-6 - ); - } - } - } - } + _checkGraphLayers(reportIDs, hnsw, graph30Layer3); }); it('insert() 100 items, 6 layer', () => { @@ -179,27 +159,7 @@ describe('insert()', () => { } expect(hnsw.graphLayers.length).toBe(6); - - for (const reportID of reportIDs) { - for (const [l, graphLayer] of hnsw.graphLayers.entries()) { - const curNode = graphLayer.graph.get(reportID); - - if (curNode === undefined) { - expect(graph100Layer6[l][reportID]).toBeUndefined(); - } else { - expect(graph100Layer6[l][reportID]).not.to.toBeUndefined(); - // Check the distances - const expectedNeighbors = graph100Layer6[l][reportID]; - for (const [neighborKey, neighborDistance] of curNode.entries()) { - expect(expectedNeighbors[neighborKey]).to.not.toBeUndefined(); - expect(neighborDistance).toBeCloseTo( - expectedNeighbors[neighborKey]!, - 1e-6 - ); - } - } - } - } + _checkGraphLayers(reportIDs, hnsw, graph100Layer6); }); it('insert() 100 items, 3 layer, m=3', () => { @@ -229,51 +189,81 @@ describe('insert()', () => { } expect(hnsw.graphLayers.length).toBe(3); + _checkGraphLayers(reportIDs, hnsw, graph100Layer3M3); + }); + + it.skip('Find random seeds', () => { + // Find random seed that give a nice level sequence to test + const size = 50; + for (let i = 1; i < 100000; i++) { + const rng = randomLcg(i); + const curLevels: number[] = []; + const ml = 1 / Math.log(16); + + for (let j = 0; j < size; j++) { + const level = Math.floor(-Math.log(rng()) * ml); + curLevels.push(level); + } - for (const reportID of reportIDs) { - for (const [l, graphLayer] of hnsw.graphLayers.entries()) { - const curNode = graphLayer.graph.get(reportID); - - if (curNode === undefined) { - expect(graph100Layer3M3[l][reportID]).toBeUndefined(); - } else { - expect(graph100Layer3M3[l][reportID]).not.to.toBeUndefined(); - // Check the distances - const expectedNeighbors = graph100Layer3M3[l][reportID]; - for (const [neighborKey, neighborDistance] of curNode.entries()) { - expect(expectedNeighbors[neighborKey]).to.not.toBeUndefined(); - expect(neighborDistance).toBeCloseTo( - expectedNeighbors[neighborKey]!, - 1e-6 - ); - } + if (Math.max(...curLevels) < 4) { + const levelSum = curLevels.reduce((sum, value) => sum + value, 0); + if (levelSum > 12) { + console.log('Good seed: ', i); + console.log(curLevels); + break; } } } }); +}); + +//==========================================================================|| +// Update || +//==========================================================================|| + +describe('update()', () => { + it('update() 10 / 50 items', () => { + const hnsw = new HNSW({ + distanceFunction: 'cosine', + seed: 65975 + }); + + // Insert 50 embeddings + const size = 50; + + // The random levels with this seed is [ 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, + // 1, 1, 0, 0, 1, 0, 0, 2, 0, 0, 2, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, + // 0, 0, 2, 1, 1, 0, 0, 0, 0, 0, 0, 0, 1, 0 ] + const reportIDs: string[] = []; + for (let i = 0; i < size; i++) { + const curReportID = String(embeddingData.reportNumbers[i]); + reportIDs.push(curReportID); + hnsw.insert(curReportID, embeddingData.embeddings[i]); + } - // it('Find random seeds', () => { - // // Find random seed that give a nice level sequence to test - // const size = 100; - // for (let i = 1; i < 100000; i++) { - // const rng = randomLcg(i); - // const curLevels: number[] = []; - // const ml = 1 / Math.log(16); - - // for (let j = 0; j < size; j++) { - // const level = Math.floor(-Math.log(rng()) * ml); - // curLevels.push(level); - // } - - // if (Math.max(...curLevels) < 4) { - // const levelSum = curLevels.reduce((sum, value) => sum + value, 0); - // if (levelSum > 20) { - // console.log('Good seed: ', i); - // break; - // } - // } - // } - // }); + // Update 10 nodes + const updateIndexes = [ + [3, 71], + [6, 63], + [36, 82], + [9, 67], + [31, 91], + [1, 55], + [43, 65], + [4, 85], + [37, 61], + [45, 86] + ]; + + for (const pair of updateIndexes) { + const oldKey = String(embeddingData.reportNumbers[pair[0]]); + const newValue = embeddingData.embeddings[pair[1]]; + hnsw.update(oldKey, newValue); + } + + expect(hnsw.graphLayers.length).toBe(3); + _checkGraphLayers(reportIDs, hnsw, graph50Update10); + }); }); //==========================================================================||