Skip to content

Commit

Permalink
improve lisibility of the relevancy benchmark
Browse files Browse the repository at this point in the history
  • Loading branch information
irevoire committed Jul 10, 2024
1 parent 7f004e8 commit b6ea560
Showing 1 changed file with 69 additions and 19 deletions.
88 changes: 69 additions & 19 deletions examples/relevancy.rs
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
use std::fmt;

use rand::seq::SliceRandom;

use arroy::distances::{
Expand All @@ -11,50 +13,97 @@ use rand::{Rng, SeedableRng};

const TWENTY_HUNDRED_MIB: usize = 2 * 1024 * 1024 * 1024;

const NUMBER_VECTORS: usize = 10_000;
// The openAI dimensions
// const VECTOR_DIMENSIONS: usize = 256;
// const VECTOR_DIMENSIONS: usize = 512;
// const VECTOR_DIMENSIONS: usize = 1024;
const VECTOR_DIMENSIONS: usize = 1536;
// const VECTOR_DIMENSIONS: usize = 3072;
const NUMBER_VECTORS: usize = 4_000;

fn main() {
let dimensions_tested = [256, 512, 1024, 1536, 3072];
let recall_tested = [1, 10, 50, 100];

println!("Testing the following dimensions: @{dimensions_tested:?}");
println!("Testing the following recall: @{recall_tested:?}");
println!("Starting...");
println!();

for (distance_name, func) in &[
(Angular::name(), &measure_distance::<Angular, Angular> as &dyn Fn(usize)),
(Euclidean::name(), &measure_distance::<Euclidean, Euclidean> as &dyn Fn(usize)),
(Manhattan::name(), &measure_distance::<Manhattan, Manhattan> as &dyn Fn(usize)),
(DotProduct::name(), &measure_distance::<DotProduct, DotProduct> as &dyn Fn(usize)),
(Angular::name(), &measure_distance::<Angular, Angular> as &dyn Fn(usize, usize) -> f32),
(
Euclidean::name(),
&measure_distance::<Euclidean, Euclidean> as &dyn Fn(usize, usize) -> f32,
),
(
Manhattan::name(),
&measure_distance::<Manhattan, Manhattan> as &dyn Fn(usize, usize) -> f32,
),
(
DotProduct::name(),
&measure_distance::<DotProduct, DotProduct> as &dyn Fn(usize, usize) -> f32,
),
(
BinaryQuantizedEuclidean::name(),
&measure_distance::<BinaryQuantizedEuclidean, Euclidean> as &dyn Fn(usize),
&measure_distance::<BinaryQuantizedEuclidean, Euclidean>
as &dyn Fn(usize, usize) -> f32,
),
(
BinaryQuantizedManhattan::name(),
&measure_distance::<BinaryQuantizedManhattan, Manhattan> as &dyn Fn(usize),
&measure_distance::<BinaryQuantizedManhattan, Manhattan>
as &dyn Fn(usize, usize) -> f32,
),
] {
let now = std::time::Instant::now();
println!("{distance_name}");
for number_fetched in [1, 10, 50, 100] {
(func)(number_fetched);
// The openAI dimensions
for dimensions in [256, 512, 1024, 1536, 3072] {
let mut recall = Vec::new();
for number_fetched in recall_tested {
let rec = (func)(number_fetched, dimensions);
recall.push(Recall(rec));
}
println!("For {dimensions:4} dim, recall: {recall:3?}");
}
println!("Took {:?}", now.elapsed());
println!();
}
}

fn measure_distance<ArroyDistance: Distance, PerfectDistance: Distance>(number_fetched: usize) {
struct Recall(f32);

impl fmt::Debug for Recall {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
match self.0 {
// red
f32::NEG_INFINITY..=0.25 => write!(f, "\x1b[1;31m")?,
// yellow
0.25..=0.5 => write!(f, "\x1b[1;33m")?,
// green
0.5..=0.75 => write!(f, "\x1b[1;32m")?,
// blue
0.75..=0.90 => write!(f, "\x1b[1;34m")?,
// cyan
0.90..=0.999 => write!(f, "\x1b[1;36m")?,
// underlined cyan
0.999..=f32::INFINITY => write!(f, "\x1b[1;4;36m")?,
_ => (),
}
write!(f, "{:.2}\x1b[0m", self.0)
}
}

fn measure_distance<ArroyDistance: Distance, PerfectDistance: Distance>(
number_fetched: usize,
dimensions: usize,
) -> f32 {
let dir = tempfile::tempdir().unwrap();
let env =
unsafe { EnvOpenOptions::new().map_size(TWENTY_HUNDRED_MIB).open(dir.path()) }.unwrap();

let mut rng = StdRng::seed_from_u64(13);
let points = generate_points(&mut rng, NUMBER_VECTORS, VECTOR_DIMENSIONS);
let points = generate_points(&mut rng, NUMBER_VECTORS, dimensions);
let mut wtxn = env.write_txn().unwrap();

let database = env
.create_database::<internals::KeyCodec, NodeCodec<ArroyDistance>>(&mut wtxn, None)
.unwrap();
load_into_arroy(&mut rng, &mut wtxn, database, VECTOR_DIMENSIONS, &points).unwrap();
load_into_arroy(&mut rng, &mut wtxn, database, dimensions, &points).unwrap();

let reader = arroy::Reader::open(&wtxn, 0, database).unwrap();

Expand All @@ -75,7 +124,8 @@ fn measure_distance<ArroyDistance: Distance, PerfectDistance: Distance>(number_f
}
}

println!("recall@{number_fetched}: {}", correctly_retrieved as f32 / relevant.len() as f32);
// println!("recall@{number_fetched}: {}", correctly_retrieved as f32 / relevant.len() as f32);
correctly_retrieved as f32 / relevant.len() as f32
}

fn partial_sort_by<'a, D: Distance>(
Expand Down

0 comments on commit b6ea560

Please sign in to comment.