Skip to content

Commit

Permalink
Test update(), fix a bug where reIndexNode() adds self loop edge
Browse files Browse the repository at this point in the history
Signed-off-by: Jay Wang <[email protected]>
  • Loading branch information
xiaohk committed Jan 31, 2024
1 parent 624ace0 commit 0edfdf6
Show file tree
Hide file tree
Showing 3 changed files with 122 additions and 123 deletions.
18 changes: 13 additions & 5 deletions src/mememo.ts
Original file line number Diff line number Diff line change
Expand Up @@ -190,10 +190,12 @@ export class HNSW<T = string> {
* this value in most cases. We add this parameter for testing purpose.
*/
insert(key: T, value: number[], maxLevel?: number | undefined) {
// If the key already exists, update the node
// If the key already exists, throw an error
if (this.nodes.has(key)) {
// TODO: Update the node
return;
throw Error(
`There is already a node with key ${key} in the index. ` +
'Use update() to update this node.'
);
}

// Randomly determine the max level of this node
Expand Down Expand Up @@ -306,9 +308,12 @@ export class HNSW<T = string> {
* @param key Key of the element.
* @param value The new embedding of the element
*/
_update(key: T, value: number[]) {
update(key: T, value: number[]) {
if (!this.nodes.has(key)) {
throw Error(`The node with key ${key} does not exist.`);
throw Error(
`The node with key ${key} does not exist. ` +
'Use insert() to add new node.'
);
}

this.nodes.set(key, new Node(key, value));
Expand Down Expand Up @@ -455,6 +460,9 @@ export class HNSW<T = string> {
this.efConstruction + 1
);

// Remove the current node itself as it would be selected (0 distance)
entryPoints = entryPoints.filter(d => d.key !== key);

// Prune the neighbors so we have at most levelM neighbors
const selectedNeighbors = this._selectNeighborsHeuristic(
entryPoints,
Expand Down
1 change: 1 addition & 0 deletions test/data/update-50-3-layer-10.json
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
[{"1": {"8": 0.1248, "372": 0.3222}, "2": {"4": 0.121, "373": 0.2152, "1342": 0.2451, "372": 0.2455, "145": 0.2455, "377": 0.2766}, "3": {"12": 0.1373, "15": 0.1391, "11": 0.1406}, "4": {"385": 0.2577, "156": 0.2594, "15": 0.3242}, "5": {"386": 0.2163, "1342": 0.2312, "146": 0.2543, "6": 0.3429}, "6": {"8": 0.1348, "15": 0.1661, "145": 0.3142}, "7": {"4": 0.1417, "388": 0.2401, "157": 0.2505, "2": 0.2786}, "8": {"1": 0.1248, "12": 0.1254, "6": 0.1348, "9": 0.1591}, "9": {"8": 0.1591, "11": 0.2278}, "10": {"4": 0.0807, "156": 0.2234, "389": 0.2342}, "11": {"3": 0.1406, "145": 0.2612}, "12": {"8": 0.1254, "3": 0.1373, "14": 0.1435, "150": 0.2693}, "14": {"12": 0.1435}, "15": {"3": 0.1391, "14": 0.1869, "377": 0.3038, "7": 0.3039}, "139": {"142": 0.0244, "146": 0.0317, "153": 0.0369, "145": 0.0753, "4": 0.2622}, "141": {"157": 0.0341, "153": 0.0419}, "142": {"139": 0.0244, "143": 0.0338, "156": 0.0396, "154": 0.0588, "9": 0.3349}, "143": {"155": 0.0326, "142": 0.0338, "153": 0.0347, "150": 0.0429, "148": 0.0945, "10": 0.2426}, "144": {"148": 0.0464, "153": 0.0527, "12": 0.2977}, "145": {"139": 0.0753, "10": 0.2371, "382": 0.2429, "5": 0.2579, "11": 0.2612}, "146": {"139": 0.0317, "149": 0.0567, "384": 0.2512, "5": 0.2543, "14": 0.3656}, "148": {"144": 0.0464, "388": 0.2491, "2": 0.2936}, "149": {"153": 0.0521, "10": 0.3548}, "150": {"1": 0.338, "6": 0.3314, "8": 0.3346, "143": 0.0429, "12": 0.2693, "5": 0.2796, "9": 0.3117}, "151": {"152": 0.0531, "157": 0.0606, "145": 0.0938, "389": 0.2634}, "152": {"153": 0.0381, "151": 0.0531}, "153": {"143": 0.0347, "152": 0.0381, "141": 0.0419, "154": 0.0479, "149": 0.0521, "144": 0.0527, "145": 0.0858}, "154": {"153": 0.0479, "372": 0.3222, "14": 0.37}, "155": {"143": 0.0326, "141": 0.0565, "373": 0.2566, "5": 0.2805}, "156": {"142": 0.0396, "144": 0.0717, "4": 0.2594, "382": 0.2606, "11": 0.2811}, "157": {"141": 0.0341, "148": 0.1041, "373": 0.263, "12": 0.3001}, "372": {"386": 0.1769, "2": 0.2455, "11": 0.292}, "373": {"381": 0.0666, "388": 0.0902, "2": 0.2152, "155": 0.2566}, "374": {"389": 0.0773, "376": 0.11}, "375": {"389": 0.0994, "384": 0.1795}, "376": {"387": 0.1081, "1342": 0.2215, "15": 0.3509}, "377": {"384": 0.2761, "386": 0.2909, "15": 0.3038}, "378": {"7": 0.0868, "4": 0.1214, "389": 0.264, "141": 0.2896}, "379": {"380": 0.0379, "388": 0.0885, "2": 0.2213, "149": 0.2835}, "380": {"379": 0.0379, "387": 0.0428, "381": 0.0633, "8": 0.3691}, "381": {"380": 0.0633, "373": 0.0666, "383": 0.0715, "389": 0.0732, "145": 0.263}, "382": {"387": 0.0714, "385": 0.0919, "145": 0.2429}, "383": {"387": 0.0596, "2": 0.2217, "1342": 0.2349}, "384": {"4": 0.1385, "385": 0.1692, "156": 0.2415, "5": 0.2728, "377": 0.2761, "3": 0.3246}, "385": {"382": 0.0919, "389": 0.0938, "1342": 0.1669, "384": 0.1692}, "386": {"372": 0.1769, "5": 0.2163, "150": 0.2783, "14": 0.3401}, "387": {"380": 0.0428, "383": 0.0596, "382": 0.0714, "373": 0.0724, "389": 0.0845, "376": 0.1081, "4": 0.2669}, "388": {"389": 0.0428, "379": 0.0885, "384": 0.194, "151": 0.2606}, "389": {"388": 0.0428, "381": 0.0732, "374": 0.0773, "385": 0.0938, "375": 0.0994, "10": 0.2342, "14": 0.3789}, "1342": {"385": 0.1669, "5": 0.2312, "148": 0.2637, "377": 0.2777}}, {"7": {"14": 0.3716, "149": 0.2962, "15": 0.3039, "380": 0.2869, "379": 0.2677, "381": 0.2815, "145": 0.2529, "142": 0.2869, "389": 0.2651}, "14": {"381": 0.4001, "379": 0.3948, "149": 0.3837, "142": 0.3802, "380": 0.382, "145": 0.3291, "389": 0.3789, "7": 0.3716, "15": 0.1869}, "15": {"389": 0.3745, "380": 0.3705, "381": 0.3736, "149": 0.3642, "379": 0.3591, "7": 0.3039, "145": 0.3162, "14": 0.1869, "142": 0.3598}, "142": {"14": 0.3802, "15": 0.3598, "381": 0.305, "380": 0.3084, "379": 0.2876, "389": 0.3013, "145": 0.0868, "149": 0.0602, "7": 0.2869}, "145": {"14": 0.3291, "15": 0.3162, "389": 0.2692, "380": 0.2697, "7": 0.2529, "381": 0.263, "149": 0.089, "142": 0.0868, "379": 0.2475}, "149": {"14": 0.3837, "15": 0.3642, "389": 0.3003, "380": 0.3051, "7": 0.2962, "381": 0.2933, "145": 0.089, "142": 0.0602, "379": 0.2835}, "379": {"14": 0.3948, "15": 0.3591, "149": 0.2835, "142": 0.2876, "381": 0.0634, "389": 0.1019, "145": 0.2475, "380": 0.0379, "7": 0.2677}, "380": {"14": 0.382, "15": 0.3705, "149": 0.3051, "142": 0.3084, "381": 0.0633, "389": 0.1035, "145": 0.2697, "379": 0.0379, "7": 0.2869}, "381": {"14": 0.4001, "15": 0.3736, "149": 0.2933, "142": 0.305, "380": 0.0633, "389": 0.0732, "145": 0.263, "379": 0.0634, "7": 0.2815}, "389": {"14": 0.3789, "15": 0.3745, "149": 0.3003, "142": 0.3013, "381": 0.0732, "379": 0.1019, "145": 0.2692, "380": 0.1035, "7": 0.2651}}, {"145": {"149": 0.089, "379": 0.2475}, "149": {"145": 0.089, "379": 0.2835}, "379": {"149": 0.2835, "145": 0.2475}}]
226 changes: 108 additions & 118 deletions test/mememo.test.ts
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@ import graph10Layer2JSON from './data/insert-10-2-layer.json';
import graph30Layer3JSON from './data/insert-30-3-layer.json';
import graph100Layer6JSON from './data/insert-100-6-layer.json';
import graph100Layer3M3JSON from './data/insert-100-3-layer-m=3.json';
import graph50Update10JSON from './data/update-50-3-layer-10.json';

interface EmbeddingData {
embeddings: number[][];
Expand All @@ -21,6 +22,40 @@ const graph10Layer2 = graph10Layer2JSON as GraphLayer[];
const graph30Layer3 = graph30Layer3JSON as GraphLayer[];
const graph100Layer6 = graph100Layer6JSON as GraphLayer[];
const graph100Layer3M3 = graph100Layer3M3JSON as GraphLayer[];
const graph50Update10 = graph50Update10JSON as GraphLayer[];

/**
* Check if the graphs in HNSW match the expected graph layers from json
* @param reportIDs Report IDs in the hnsw
* @param hnsw HNSW index
* @param expectedGraphs Expected graph layers loaded from json
*/
const _checkGraphLayers = (
reportIDs: string[],
hnsw: HNSW,
expectedGraphs: GraphLayer[]
) => {
for (const reportID of reportIDs) {
for (const [l, graphLayer] of hnsw.graphLayers.entries()) {
const curNode = graphLayer.graph.get(reportID);

if (curNode === undefined) {
expect(expectedGraphs[l][reportID]).toBeUndefined();
} else {
expect(expectedGraphs[l][reportID]).not.to.toBeUndefined();
// Check the distances
const expectedNeighbors = expectedGraphs[l][reportID];
for (const [neighborKey, neighborDistance] of curNode.entries()) {
expect(expectedNeighbors[neighborKey]).to.not.toBeUndefined();
expect(neighborDistance).toBeCloseTo(
expectedNeighbors[neighborKey]!,
1e-6
);
}
}
}
}
};

describe('constructor', () => {
it('constructor', () => {
Expand Down Expand Up @@ -56,22 +91,7 @@ describe('insert()', () => {

// There should be only one layer, and all nodes are fully connected
expect(hnsw.graphLayers.length).toBe(1);

for (const reportID of reportIDs) {
const curNode = hnsw.graphLayers[0].graph.get(reportID);
expect(curNode).to.not.toBeUndefined();
expect(curNode!.size).toBe(9);

// Check the distances
const expectedNeighbors = graph10Layer1[0][reportID];
for (const [neighborKey, neighborDistance] of curNode!.entries()) {
expect(expectedNeighbors[neighborKey]).to.not.toBeUndefined();
expect(neighborDistance).toBeCloseTo(
expectedNeighbors[neighborKey]!,
1e-6
);
}
}
_checkGraphLayers(reportIDs, hnsw, graph10Layer1);
});

it('insert() 10 items, 2 layer', () => {
Expand All @@ -92,27 +112,7 @@ describe('insert()', () => {
}

expect(hnsw.graphLayers.length).toBe(2);

for (const reportID of reportIDs) {
for (const [l, graphLayer] of hnsw.graphLayers.entries()) {
const curNode = graphLayer.graph.get(reportID);

if (curNode === undefined) {
expect(graph10Layer2[l][reportID]).toBeUndefined();
} else {
expect(graph10Layer2[l][reportID]).not.to.toBeUndefined();
// Check the distances
const expectedNeighbors = graph10Layer2[l][reportID];
for (const [neighborKey, neighborDistance] of curNode.entries()) {
expect(expectedNeighbors[neighborKey]).to.not.toBeUndefined();
expect(neighborDistance).toBeCloseTo(
expectedNeighbors[neighborKey]!,
1e-6
);
}
}
}
}
_checkGraphLayers(reportIDs, hnsw, graph10Layer2);
});

it('insert() 30 items, 3 layer', () => {
Expand All @@ -134,27 +134,7 @@ describe('insert()', () => {
}

expect(hnsw.graphLayers.length).toBe(3);

for (const reportID of reportIDs) {
for (const [l, graphLayer] of hnsw.graphLayers.entries()) {
const curNode = graphLayer.graph.get(reportID);

if (curNode === undefined) {
expect(graph30Layer3[l][reportID]).toBeUndefined();
} else {
expect(graph30Layer3[l][reportID]).not.to.toBeUndefined();
// Check the distances
const expectedNeighbors = graph30Layer3[l][reportID];
for (const [neighborKey, neighborDistance] of curNode.entries()) {
expect(expectedNeighbors[neighborKey]).to.not.toBeUndefined();
expect(neighborDistance).toBeCloseTo(
expectedNeighbors[neighborKey]!,
1e-6
);
}
}
}
}
_checkGraphLayers(reportIDs, hnsw, graph30Layer3);
});

it('insert() 100 items, 6 layer', () => {
Expand All @@ -179,27 +159,7 @@ describe('insert()', () => {
}

expect(hnsw.graphLayers.length).toBe(6);

for (const reportID of reportIDs) {
for (const [l, graphLayer] of hnsw.graphLayers.entries()) {
const curNode = graphLayer.graph.get(reportID);

if (curNode === undefined) {
expect(graph100Layer6[l][reportID]).toBeUndefined();
} else {
expect(graph100Layer6[l][reportID]).not.to.toBeUndefined();
// Check the distances
const expectedNeighbors = graph100Layer6[l][reportID];
for (const [neighborKey, neighborDistance] of curNode.entries()) {
expect(expectedNeighbors[neighborKey]).to.not.toBeUndefined();
expect(neighborDistance).toBeCloseTo(
expectedNeighbors[neighborKey]!,
1e-6
);
}
}
}
}
_checkGraphLayers(reportIDs, hnsw, graph100Layer6);
});

it('insert() 100 items, 3 layer, m=3', () => {
Expand Down Expand Up @@ -229,51 +189,81 @@ describe('insert()', () => {
}

expect(hnsw.graphLayers.length).toBe(3);
_checkGraphLayers(reportIDs, hnsw, graph100Layer3M3);
});

it.skip('Find random seeds', () => {
// Find random seed that give a nice level sequence to test
const size = 50;
for (let i = 1; i < 100000; i++) {
const rng = randomLcg(i);
const curLevels: number[] = [];
const ml = 1 / Math.log(16);

for (let j = 0; j < size; j++) {
const level = Math.floor(-Math.log(rng()) * ml);
curLevels.push(level);
}

for (const reportID of reportIDs) {
for (const [l, graphLayer] of hnsw.graphLayers.entries()) {
const curNode = graphLayer.graph.get(reportID);

if (curNode === undefined) {
expect(graph100Layer3M3[l][reportID]).toBeUndefined();
} else {
expect(graph100Layer3M3[l][reportID]).not.to.toBeUndefined();
// Check the distances
const expectedNeighbors = graph100Layer3M3[l][reportID];
for (const [neighborKey, neighborDistance] of curNode.entries()) {
expect(expectedNeighbors[neighborKey]).to.not.toBeUndefined();
expect(neighborDistance).toBeCloseTo(
expectedNeighbors[neighborKey]!,
1e-6
);
}
if (Math.max(...curLevels) < 4) {
const levelSum = curLevels.reduce((sum, value) => sum + value, 0);
if (levelSum > 12) {
console.log('Good seed: ', i);
console.log(curLevels);
break;
}
}
}
});
});

//==========================================================================||
// Update ||
//==========================================================================||

describe('update()', () => {
it('update() 10 / 50 items', () => {
const hnsw = new HNSW({
distanceFunction: 'cosine',
seed: 65975
});

// Insert 50 embeddings
const size = 50;

// The random levels with this seed is [ 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0,
// 1, 1, 0, 0, 1, 0, 0, 2, 0, 0, 2, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
// 0, 0, 2, 1, 1, 0, 0, 0, 0, 0, 0, 0, 1, 0 ]
const reportIDs: string[] = [];
for (let i = 0; i < size; i++) {
const curReportID = String(embeddingData.reportNumbers[i]);
reportIDs.push(curReportID);
hnsw.insert(curReportID, embeddingData.embeddings[i]);
}

// it('Find random seeds', () => {
// // Find random seed that give a nice level sequence to test
// const size = 100;
// for (let i = 1; i < 100000; i++) {
// const rng = randomLcg(i);
// const curLevels: number[] = [];
// const ml = 1 / Math.log(16);

// for (let j = 0; j < size; j++) {
// const level = Math.floor(-Math.log(rng()) * ml);
// curLevels.push(level);
// }

// if (Math.max(...curLevels) < 4) {
// const levelSum = curLevels.reduce((sum, value) => sum + value, 0);
// if (levelSum > 20) {
// console.log('Good seed: ', i);
// break;
// }
// }
// }
// });
// Update 10 nodes
const updateIndexes = [
[3, 71],
[6, 63],
[36, 82],
[9, 67],
[31, 91],
[1, 55],
[43, 65],
[4, 85],
[37, 61],
[45, 86]
];

for (const pair of updateIndexes) {
const oldKey = String(embeddingData.reportNumbers[pair[0]]);
const newValue = embeddingData.embeddings[pair[1]];
hnsw.update(oldKey, newValue);
}

expect(hnsw.graphLayers.length).toBe(3);
_checkGraphLayers(reportIDs, hnsw, graph50Update10);
});
});

//==========================================================================||
Expand Down

0 comments on commit 0edfdf6

Please sign in to comment.