Skip to content

Commit

Permalink
First working version of the heed based reader
Browse files Browse the repository at this point in the history
  • Loading branch information
Kerollmops committed Nov 11, 2023
1 parent 10c661c commit cf7f551
Show file tree
Hide file tree
Showing 7 changed files with 121 additions and 28 deletions.
2 changes: 2 additions & 0 deletions .gitignore
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
Cargo.lock
/target
/assets/test.tree
*.out
*.tree
11 changes: 11 additions & 0 deletions diff-run.sh
Original file line number Diff line number Diff line change
@@ -0,0 +1,11 @@
#! /bin/bash

# Runs the classic and heed-based versions of Annoy and compares them side-by-side.
# Very useful when you make sure to log the same things in both programs.
#
# The stdout of each binary is redirected to classic.out / heed.out, which are
# overwritten on every run, and the two files are then diffed.
#
# NOTE: `set -v` makes the shell echo each script line it reads from that point on,
# so keep explanatory comments above it to avoid cluttering the echoed output.

set -v

cargo run --bin classic > classic.out
cargo run --bin heed > heed.out

diff --side-by-side classic.out heed.out
19 changes: 15 additions & 4 deletions src/arroy_reader.rs
Original file line number Diff line number Diff line change
Expand Up @@ -71,6 +71,9 @@ impl<'a> ArroyReader<'a> {
}
}

println!("{roots:?}");
println!("{max_descendants:?}");

ArroyReader {
dimension,
distance_type,
Expand Down Expand Up @@ -129,9 +132,13 @@ impl<'a> ArroyReader<'a> {
let top_node_header = top_node.header;
let top_node_offset = top_node.offset;
let n_descendants = top_node_header.get_n_descendant();
println!("top_node_id: {top_node_id_i32:?} {:?}", n_descendants);
if n_descendants == 1 && top_node_id < self.size {
println!("top_node_id: {top_node_id:?}");
nearest_neighbors.push(top_node_id_i32);
} else if n_descendants <= self.max_descendants {
let children_ids = self.descendant_ids(top_node_offset, n_descendants as usize);
children_ids.for_each(|id| println!("id: {id:?}"));
let children_ids = self.descendant_ids(top_node_offset, n_descendants as usize);
nearest_neighbors.extend(children_ids);
} else {
Expand All @@ -150,7 +157,12 @@ impl<'a> ArroyReader<'a> {
}
}
}

println!("{:?}", nearest_neighbors.len());
println!("{nearest_neighbors:?}");

nearest_neighbors.sort_unstable();

let mut sorted_nns = BinaryHeap::with_capacity(nearest_neighbors.len());
let mut nn_id_last = -1;
for nn_id in nearest_neighbors {
Expand All @@ -166,10 +178,9 @@ impl<'a> ArroyReader<'a> {
}

let s = self.node_slice_with_offset(nn_id as usize * self.node_size);
sorted_nns.push(Reverse(BinaryHeapItem {
item: nn_id,
ord: OrderedFloat(self.distance_no_norm(&s, query_vector)),
}));
let distance = self.distance_no_norm(&s, query_vector);
println!("{nn_id}: {distance:?}");
sorted_nns.push(Reverse(BinaryHeapItem { item: nn_id, ord: OrderedFloat(distance) }));
}

let final_result_capacity = n_results.min(sorted_nns.len());
Expand Down
17 changes: 17 additions & 0 deletions src/bin/classic.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,17 @@
use arroy::{ArroyReader, DistanceType};

/// Loads `test.tree` with the classic [`ArroyReader`] and prints the vector of
/// item 0 followed by the result of a nearest-neighbors query for that item.
///
/// The output is meant to be compared against the heed-based binary, so only the
/// vector and the query results are printed from here.
///
/// # Errors
///
/// Returns the underlying I/O error when `test.tree` cannot be read.
///
/// # Panics
///
/// Panics when `item_vector` or `nns_by_item` does not yield a value for item 0.
fn main() -> std::io::Result<()> {
    let dimensions = 40;
    let distance_type = DistanceType::Angular;
    // Propagate a read failure through the `io::Result` return type instead of panicking.
    let tree = std::fs::read("test.tree")?;

    let arroy = ArroyReader::new(&tree[..], dimensions, distance_type);
    let v = arroy.item_vector(0).expect("item 0 must be present in test.tree");
    let results =
        arroy.nns_by_item(0, 3, None).expect("nearest neighbors of item 0 must be available");

    println!("{v:?}");
    println!("{results:?}");

    Ok(())
}
29 changes: 29 additions & 0 deletions src/bin/heed.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,29 @@
use arroy::{DistanceType, HeedReader};
use heed::EnvOpenOptions;

/// Maximum size of the LMDB memory map: two hundred MiB, expressed in bytes.
const TWO_HUNDRED_MIB: usize = 200 * 1024 * 1024;

/// Imports `test.tree` into a temporary heed environment, then prints the vector of
/// item 0 followed by the result of a nearest-neighbors query for that item.
///
/// The output is meant to be compared against the classic binary, so only the
/// vector and the query results are printed from here.
///
/// # Errors
///
/// Returns an error when `test.tree` cannot be read, when the temporary directory
/// cannot be created, or when any heed operation fails.
///
/// # Panics
///
/// Panics when `item_vector` or `nns_by_item` yields no value for item 0.
fn main() -> heed::Result<()> {
    let dimensions = 40;
    let distance_type = DistanceType::Angular;
    // Propagate a read failure like the other I/O error below instead of panicking.
    let tree = std::fs::read("test.tree")?;

    let dir = tempfile::tempdir()?;
    let env = EnvOpenOptions::new().map_size(TWO_HUNDRED_MIB).open(dir.path())?;

    // we will open the default unnamed database
    let mut wtxn = env.write_txn()?;
    let database = env.create_database(&mut wtxn, None)?;
    HeedReader::load_from_tree(&mut wtxn, database, dimensions, distance_type, &tree)?;
    wtxn.commit()?;

    let rtxn = env.read_txn()?;
    let arroy = HeedReader::new(&rtxn, database, dimensions, distance_type)?;
    let v = arroy.item_vector(&rtxn, 0)?.expect("item 0 must be present in the database");
    let results = arroy
        .nns_by_item(&rtxn, 0, 3, None)?
        .expect("nearest neighbors of item 0 must be available");

    println!("{v:?}");
    println!("{results:?}");

    Ok(())
}
40 changes: 29 additions & 11 deletions src/heed_reader.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,8 +2,7 @@ use std::borrow::Cow;
use std::cmp::Reverse;
use std::collections::BinaryHeap;
use std::mem::size_of;
use std::ops::RemAssign;
use std::{io, iter};
use std::{iter, mem};

use bytemuck::{pod_collect_to_vec, Pod, Zeroable};
use byteorder::{ByteOrder, NativeEndian};
Expand All @@ -15,7 +14,6 @@ use crate::distance::{
cosine_distance_no_simd, dot_product_no_simd, euclidean_distance_no_simd,
manhattan_distance_no_simd, minkowski_margin,
};
use crate::node::*;
use crate::priority_queue::BinaryHeapItem;
use crate::DistanceType;

Expand Down Expand Up @@ -89,6 +87,9 @@ impl HeedReader {
}
}

println!("{roots:?}");
println!("{max_descendants:?}");

Ok(HeedReader {
dimension: dimensions,
distance_type,
Expand Down Expand Up @@ -181,9 +182,17 @@ impl HeedReader {
while !pq.is_empty() && nearest_neighbors.len() < search_k {
if let Some(BinaryHeapItem { item: top_node_id, ord: top_node_margin }) = pq.pop() {
let node_bytes = self.database.get(rtxn, &top_node_id)?.unwrap();
match Node::from_bytes(node_bytes, self.max_descendants) {
Node::Leaf(_) => nearest_neighbors.push(top_node_id),
let node = Node::from_bytes(node_bytes, self.max_descendants);
println!("top_node_id: {top_node_id:?} {:?}", node.n_descendants());
match node {
Node::Leaf(_) => {
if (top_node_id as usize) < self.size {
println!("top_node_id: {top_node_id:?}");
nearest_neighbors.push(top_node_id);
}
}
Node::Descendants(descendants) => {
descendants.descendants_ids().for_each(|id| println!("id: {id:?}"));
nearest_neighbors.extend(descendants.descendants_ids())
}
Node::SplitPlaneNormal(normal) => {
Expand All @@ -202,7 +211,12 @@ impl HeedReader {
}
}
}

println!("{:?}", nearest_neighbors.len());
println!("{nearest_neighbors:?}");

nearest_neighbors.sort_unstable();

let mut sorted_nns = BinaryHeap::with_capacity(nearest_neighbors.len());
let mut nn_id_last = None;
for nn_id in nearest_neighbors {
Expand All @@ -213,16 +227,19 @@ impl HeedReader {
let node_bytes = self.database.get(rtxn, &nn_id)?.unwrap();
if let Node::Leaf(node) = Node::from_bytes(node_bytes, self.max_descendants) {
let s = node.vector();
sorted_nns.push(Reverse(BinaryHeapItem {
item: nn_id,
ord: OrderedFloat(DistanceType::Angular.distance_no_norm(&s, query_vector)),
}));
let distance = DistanceType::Angular.distance_no_norm(&s, query_vector);
println!("{nn_id}: {distance:?}");
sorted_nns
.push(Reverse(BinaryHeapItem { item: nn_id, ord: OrderedFloat(distance) }));
}
}

let final_result_capacity = n_results.min(sorted_nns.len());
let mut output = Vec::with_capacity(final_result_capacity);
while let Some(Reverse(heap_item)) = sorted_nns.pop() {
if output.len() == final_result_capacity {
break;
}
let BinaryHeapItem { item, ord: OrderedFloat(dist) } = heap_item;
output.push((item, DistanceType::Angular.normalized_distance(dist)));
}
Expand Down Expand Up @@ -295,8 +312,9 @@ impl<'a> Node<'a> {
let (header, vector_bytes) = NodeHeader::from_bytes(bytes, DistanceType::Angular);
Node::Leaf(Leaf { header, vector_bytes })
} else if n_descendants as usize <= max_descendants {
let offset_before_children = NodeHeaderAngular::offset_before_children();
let descendants_bytes = &bytes[offset_before_children..];
let offset = NodeHeaderAngular::offset_before_children();
let length = n_descendants as usize * size_of::<u32>();
let descendants_bytes = &bytes[offset..offset + length];
Node::Descendants(Descendants { n_descendants, descendants_bytes })
} else {
let (header, normal_bytes) = NodeHeader::from_bytes(bytes, DistanceType::Angular);
Expand Down
31 changes: 18 additions & 13 deletions src/main.rs
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
use annoy_rs::*;
use arroy::{DistanceType, HeedReader};
use arroy::{ArroyReader, DistanceType, HeedReader};
use heed::EnvOpenOptions;

const TWENTY_HUNDRED_MIB: usize = 200 * 1024 * 1024;
Expand All @@ -26,21 +26,26 @@ fn main() -> heed::Result<()> {
let results = arroy.nns_by_item(&rtxn, 0, 3, None)?.unwrap();
println!("{v:?}");

let index = AnnoyIndex::load(40, "test.tree", IndexType::Angular).unwrap();
// dbg!(&index);
let v0 = index.get_item_vector(0);
let results0 = index.get_nearest_to_item(0, 3, -1, true);
println!("{v0:?}");
let arroy = ArroyReader::new(&tree[..], dimensions, distance_type);
// dbg!(&arroy);
let classic_v = arroy.item_vector(0).unwrap();
let classic_results = arroy.nns_by_item(0, 3, None).unwrap();

// let index = AnnoyIndex::load(40, "test.tree", IndexType::Angular).unwrap();
// // dbg!(&index);
// let v0 = index.get_item_vector(0);
// let results0 = index.get_nearest_to_item(0, 3, -1, true);
// println!("{v0:?}");

assert_eq!(v, v0);
assert_eq!(v, classic_v);

assert_eq!(results[0].0, results0.id_list[0] as u32);
assert_eq!(results[1].0, results0.id_list[1] as u32);
assert_eq!(results[2].0, results0.id_list[2] as u32);
assert_eq!(results[0].0, classic_results[0].0 as u32);
assert_eq!(results[1].0, classic_results[1].0 as u32);
assert_eq!(results[2].0, classic_results[2].0 as u32);

assert_eq!(results[0].1, results0.distance_list[0]);
assert_eq!(results[1].1, results0.distance_list[1]);
assert_eq!(results[2].1, results0.distance_list[2]);
assert_eq!(results[0].1, classic_results[0].1);
assert_eq!(results[1].1, classic_results[1].1);
assert_eq!(results[2].1, classic_results[2].1);

Ok(())
}

0 comments on commit cf7f551

Please sign in to comment.