From cfac1c469fbb3a1150fba85c0f0166b82162fd1a Mon Sep 17 00:00:00 2001
From: Kerollmops <clement@meilisearch.com>
Date: Tue, 4 Jul 2023 15:18:08 +0200
Subject: [PATCH 1/4] Expose a new SorterBuilder::sort_in_parallel method

---
 .github/workflows/rust.yml |  4 ++++
 Cargo.toml                 |  1 +
 src/lib.rs                 | 20 +++++++++++++++---
 src/sorter.rs              | 43 ++++++++++++++++++++++++++++++++++++--
 4 files changed, 63 insertions(+), 5 deletions(-)

diff --git a/.github/workflows/rust.yml b/.github/workflows/rust.yml
index db13e9e..e1340c0 100644
--- a/.github/workflows/rust.yml
+++ b/.github/workflows/rust.yml
@@ -34,6 +34,10 @@ jobs:
           command: test
           args: --all-features
 
+      - uses: actions-rs/cargo@v1
+        with:
+          command: test
+
   lint:
     runs-on: ubuntu-latest
     steps:
diff --git a/Cargo.toml b/Cargo.toml
index 70a7678..3287d61 100644
--- a/Cargo.toml
+++ b/Cargo.toml
@@ -13,6 +13,7 @@ bytemuck = { version = "1.7.0", features = ["derive"] }
 byteorder = "1.3.4"
 flate2 = { version = "1.0", optional = true }
 lz4_flex = { version = "0.9.2", optional = true }
+rayon = { version = "1.7.0", optional = true }
 snap = { version = "1.0.5", optional = true }
 tempfile = { version = "3.2.0", optional = true }
 zstd = { version = "0.10.0", optional = true }
diff --git a/src/lib.rs b/src/lib.rs
index e38dd9b..56bf543 100644
--- a/src/lib.rs
+++ b/src/lib.rs
@@ -2,7 +2,21 @@
 //! The entries in the grenad files are _immutable_ and the only way to modify them is by _creating
 //! a new file_ with the changes.
 //!
-//! # Example: Use the `Writer` and `Reader` structs
+//! # Features
+//!
+//! You can define which compression schemes to support; there are currently a few
+//! available choices, which determine which types will be available inside the above modules:
+//!
+//! - _Snappy_ with the [`snap`](https://crates.io/crates/snap) crate.
+//! - _Zlib_ with the [`flate2`](https://crates.io/crates/flate2) crate.
+//! - _Lz4_ with the [`lz4_flex`](https://crates.io/crates/lz4_flex) crate.
+//!
+//! If you need more performance you can enable the `rayon` feature, which will enable
+//! a bunch of new settings, like being able to make the `Sorter` sort in parallel.
+//!
+//! # Examples
+//!
+//! ## Use the `Writer` and `Reader` structs
 //!
 //! You can use the [`Writer`] struct to store key-value pairs into the specified
 //! [`std::io::Write`] type. The [`Reader`] type can then be used to read the entries.
@@ -37,7 +51,7 @@
 //! # Ok(()) }
 //! ```
 //!
-//! # Example: Use the `Merger` struct
+//! ## Use the `Merger` struct
 //!
 //! In this example we show how you can merge multiple [`Reader`]s
 //! by using a _merge function_ when a conflict is encountered.
@@ -107,7 +121,7 @@
 //! # Ok(()) }
 //! ```
 //!
-//! # Example: Use the `Sorter` struct
+//! ## Use the `Sorter` struct
 //!
 //! In this example we show how by defining a _merge function_, we can insert
 //! multiple entries with the same key and output them in lexicographic order.
diff --git a/src/sorter.rs b/src/sorter.rs
index 94cf92d..8f022ed 100644
--- a/src/sorter.rs
+++ b/src/sorter.rs
@@ -47,6 +47,7 @@ pub struct SorterBuilder<MF, CC> {
     index_levels: Option<u8>,
     chunk_creator: CC,
     sort_algorithm: SortAlgorithm,
+    sort_in_parallel: bool,
     merge: MF,
 }
 
@@ -65,6 +66,7 @@ impl<MF> SorterBuilder<MF, DefaultChunkCreator> {
             index_levels: None,
             chunk_creator: DefaultChunkCreator::default(),
             sort_algorithm: SortAlgorithm::Stable,
+            sort_in_parallel: false,
             merge,
         }
     }
@@ -142,6 +144,15 @@ impl<MF, CC> SorterBuilder<MF, CC> {
         self
     }
 
+    /// Whether we use [rayon to sort](https://docs.rs/rayon/latest/rayon/slice/trait.ParallelSliceMut.html#method.par_sort_by_key) the entries.
+    ///
+    /// By default we do not sort in parallel; the value is `false`.
+    #[cfg(feature = "rayon")]
+    pub fn sort_in_parallel(&mut self, value: bool) -> &mut Self {
+        self.sort_in_parallel = value;
+        self
+    }
+
     /// The [`ChunkCreator`] struct used to generate the chunks used
     /// by the [`Sorter`] to bufferize when required.
     pub fn chunk_creator<CC2>(self, creation: CC2) -> SorterBuilder<MF, CC2> {
@@ -156,6 +167,7 @@ impl<MF, CC> SorterBuilder<MF, CC> {
             index_levels: self.index_levels,
             chunk_creator: creation,
             sort_algorithm: self.sort_algorithm,
+            sort_in_parallel: self.sort_in_parallel,
             merge: self.merge,
         }
     }
@@ -181,6 +193,7 @@ impl<MF, CC: ChunkCreator> SorterBuilder<MF, CC> {
             index_levels: self.index_levels,
             chunk_creator: self.chunk_creator,
             sort_algorithm: self.sort_algorithm,
+            sort_in_parallel: self.sort_in_parallel,
             merge: self.merge,
         }
     }
@@ -281,6 +294,27 @@ impl Entries {
         sort(bounds, |b: &EntryBound| &tail[tail.len() - b.key_start..][..b.key_length as usize]);
     }
 
+    /// Sorts in **parallel** the entry bounds by the entries' keys;
+    /// after a sort the `iter` method will yield the entries sorted.
+    #[cfg(feature = "rayon")]
+    pub fn par_sort_by_key(&mut self, algorithm: SortAlgorithm) {
+        use rayon::slice::ParallelSliceMut;
+
+        let bounds_end = self.bounds_count * size_of::<EntryBound>();
+        let (bounds, tail) = self.buffer.split_at_mut(bounds_end);
+        let bounds = cast_slice_mut::<_, EntryBound>(bounds);
+        let sort = match algorithm {
+            SortAlgorithm::Stable => <[EntryBound]>::par_sort_by_key,
+            SortAlgorithm::Unstable => <[EntryBound]>::par_sort_unstable_by_key,
+        };
+        sort(bounds, |b: &EntryBound| &tail[tail.len() - b.key_start..][..b.key_length as usize]);
+    }
+
+    #[cfg(not(feature = "rayon"))]
+    pub fn par_sort_by_key(&mut self, algorithm: SortAlgorithm) {
+        self.sort_by_key(algorithm);
+    }
+
     /// Returns an iterator over the keys and datas.
     pub fn iter(&self) -> impl Iterator<Item = (&[u8], &[u8])> + '_ {
         let bounds_end = self.bounds_count * size_of::<EntryBound>();
@@ -399,6 +433,7 @@ pub struct Sorter<MF, CC: ChunkCreator = DefaultChunkCreator> {
     index_levels: Option<u8>,
     chunk_creator: CC,
     sort_algorithm: SortAlgorithm,
+    sort_in_parallel: bool,
     merge: MF,
 }
 
@@ -489,7 +524,11 @@ where
         }
         let mut writer = writer_builder.build(count_write_chunk);
 
-        self.entries.sort_by_key(self.sort_algorithm);
+        if self.sort_in_parallel {
+            self.entries.par_sort_by_key(self.sort_algorithm);
+        } else {
+            self.entries.sort_by_key(self.sort_algorithm);
+        }
 
         let mut current = None;
         for (key, value) in self.entries.iter() {
@@ -509,7 +548,7 @@ where
 
         if let Some((key, vals)) = current.take() {
             let merged_val = (self.merge)(key, &vals).map_err(Error::Merge)?;
-            writer.insert(&key, &merged_val)?;
+            writer.insert(key, &merged_val)?;
         }
 
         // We retrieve the wrapped CountWrite and extract

From eafb6ae795af6078e087edf77e7cd31a26238707 Mon Sep 17 00:00:00 2001
From: Kerollmops <clement@meilisearch.com>
Date: Tue, 4 Jul 2023 17:51:03 +0200
Subject: [PATCH 2/4] WIP Introduce the SorterBuilder::build_in_parallel method

We need to find a way to make it a little less restrictive
---
 Cargo.toml    |   1 +
 src/sorter.rs | 123 ++++++++++++++++++++++++++++++++++++++++++++++++++
 2 files changed, 124 insertions(+)

diff --git a/Cargo.toml b/Cargo.toml
index 3287d61..63be726 100644
--- a/Cargo.toml
+++ b/Cargo.toml
@@ -11,6 +11,7 @@ license = "MIT"
 [dependencies]
 bytemuck = { version = "1.7.0", features = ["derive"] }
 byteorder = "1.3.4"
+crossbeam-channel = "0.5.8"
 flate2 = { version = "1.0", optional = true }
 lz4_flex = { version = "0.9.2", optional = true }
 rayon = { version = "1.7.0", optional = true }
diff --git a/src/sorter.rs b/src/sorter.rs
index 8f022ed..ae85306 100644
--- a/src/sorter.rs
+++ b/src/sorter.rs
@@ -1,14 +1,19 @@
 use std::alloc::{alloc, dealloc, Layout};
 use std::borrow::Cow;
+use std::collections::hash_map::DefaultHasher;
 use std::convert::Infallible;
 #[cfg(feature = "tempfile")]
 use std::fs::File;
+use std::hash::{Hash, Hasher};
 use std::io::{Cursor, Read, Seek, SeekFrom, Write};
+use std::iter::repeat_with;
 use std::mem::{align_of, size_of};
 use std::num::NonZeroUsize;
+use std::thread::{self, JoinHandle};
 use std::{cmp, io, ops, slice};
 
 use bytemuck::{cast_slice, cast_slice_mut, Pod, Zeroable};
+use crossbeam_channel::{unbounded, Sender};
 
 use crate::count_write::CountWrite;
 
@@ -197,6 +202,44 @@ impl<MF, CC: ChunkCreator> SorterBuilder<MF, CC> {
             merge: self.merge,
         }
     }
+
+    /// Creates a [`ParallelSorter`] configured by this builder.
+    ///
+    /// Indicate the number of different [`Sorter`]s you want to use to balance
+    /// the sorting load.
+    pub fn build_in_parallel<U>(self, number: NonZeroUsize) -> ParallelSorter<MF, U, CC>
+    where
+        MF: for<'a> Fn(&[u8], &[Cow<'a, [u8]>]) -> Result<Cow<'a, [u8]>, U>
+            + Clone
+            + Send
+            + 'static,
+        U: Send + 'static,
+        CC: Clone + Send + 'static,
+        CC::Chunk: Send + 'static,
+    {
+        match number.get() {
+            1 => ParallelSorter::Single(self.build()),
+            number => {
+                let (senders, receivers): (Vec<Sender<(usize, Vec<u8>)>>, Vec<_>) =
+                    repeat_with(unbounded).take(number).unzip();
+
+                let mut handles = Vec::new();
+                for receiver in receivers {
+                    let sorter_builder = self.clone();
+                    handles.push(thread::spawn(move || {
+                        let mut sorter = sorter_builder.build();
+                        for (key_length, data) in receiver {
+                            let (key, val) = data.split_at(key_length);
+                            sorter.insert(key, val)?;
+                        }
+                        sorter.into_reader_cursors().map_err(Into::into)
+                    }));
+                }
+
+                ParallelSorter::Multi { senders, handles, merge_function: self.merge }
+            }
+        }
+    }
 }
 
 /// Stores entries memory efficiently in a buffer.
@@ -669,6 +712,86 @@ where
     }
 }
 
+// TODO Make this private by wrapping it
+pub enum ParallelSorter<MF, U, CC: ChunkCreator = DefaultChunkCreator>
+where
+    MF: for<'a> Fn(&[u8], &[Cow<'a, [u8]>]) -> Result<Cow<'a, [u8]>, U>,
+{
+    Single(Sorter<MF, CC>),
+    Multi {
+        // Indicates the length of the key and the bytes assoicated to the key + the data.
+        senders: Vec<Sender<(usize, Vec<u8>)>>,
+        handles: Vec<JoinHandle<Result<Vec<ReaderCursor<CC::Chunk>>, Error<U>>>>,
+        merge_function: MF,
+    },
+}
+
+impl<MF, U, CC> ParallelSorter<MF, U, CC>
+where
+    MF: for<'a> Fn(&[u8], &[Cow<'a, [u8]>]) -> Result<Cow<'a, [u8]>, U>,
+    CC: ChunkCreator,
+{
+    /// Insert an entry into the [`Sorter`] making sure that conflicts
+    /// are resolved by the provided merge function.
+    pub fn insert<K, V>(&mut self, key: K, val: V) -> Result<(), Error<U>>
+    where
+        K: AsRef<[u8]>,
+        V: AsRef<[u8]>,
+    {
+        let key = key.as_ref();
+        let val = val.as_ref();
+        match self {
+            ParallelSorter::Single(sorter) => sorter.insert(key, val),
+            ParallelSorter::Multi { senders, .. } => {
+                let key_length = key.len();
+                let key_hash = compute_hash(key);
+
+                // We put the key and val into the same allocation to speed things up
+                // by reducing the amount of calls to the allocator.
+                //
+                // TODO test that it works for real because having a bigger allocation
+                //      can make it harder to find the space.
+                let mut data = Vec::with_capacity(key.len() + val.len());
+                data.extend_from_slice(key);
+                data.extend_from_slice(val);
+
+                let index = (key_hash % senders.len() as u64) as usize;
+                // TODO remove unwraps
+                senders[index].send((key_length, data)).unwrap();
+
+                Ok(())
+            }
+        }
+    }
+
+    /// Consumes this [`Sorter`] and outputs a stream of the merged entries in key-order.
+    pub fn into_stream_merger_iter(self) -> Result<MergerIter<CC::Chunk, MF>, Error<U>> {
+        match self {
+            ParallelSorter::Single(sorter) => sorter.into_stream_merger_iter(),
+            ParallelSorter::Multi { senders, handles, merge_function } => {
+                drop(senders);
+
+                let mut sources = Vec::new();
+                for handle in handles {
+                    // TODO remove unwraps
+                    sources.extend(handle.join().unwrap()?);
+                }
+
+                let mut builder = Merger::builder(merge_function);
+                builder.extend(sources);
+                builder.build().into_stream_merger_iter().map_err(Error::convert_merge_error)
+            }
+        }
+    }
+}
+
+/// Computes the hash of a key.
+fn compute_hash(key: &[u8]) -> u64 {
+    let mut state = DefaultHasher::new();
+    key.hash(&mut state);
+    state.finish()
+}
+
 /// A trait that represent a `ChunkCreator`.
 pub trait ChunkCreator {
     /// The generated chunk by this `ChunkCreator`.

From 97896cd838b7699527a503f226d2e70a7418a63c Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Cl=C3=A9ment=20Renault?= <clement@meilisearch.com>
Date: Wed, 1 Nov 2023 11:34:18 +0100
Subject: [PATCH 3/4] Expose and wrap the ParallelSorter

---
 src/lib.rs    |  3 ++-
 src/sorter.rs | 32 +++++++++++++++++++++-----------
 2 files changed, 23 insertions(+), 12 deletions(-)

diff --git a/src/lib.rs b/src/lib.rs
index 56bf543..7d2891c 100644
--- a/src/lib.rs
+++ b/src/lib.rs
@@ -203,7 +203,8 @@ pub use self::reader::{PrefixIter, RangeIter, Reader, ReaderCursor, RevPrefixIte
 #[cfg(feature = "tempfile")]
 pub use self::sorter::TempFileChunk;
 pub use self::sorter::{
-    ChunkCreator, CursorVec, DefaultChunkCreator, SortAlgorithm, Sorter, SorterBuilder,
+    ChunkCreator, CursorVec, DefaultChunkCreator, ParallelSorter, SortAlgorithm, Sorter,
+    SorterBuilder,
 };
 pub use self::writer::{Writer, WriterBuilder};
 
diff --git a/src/sorter.rs b/src/sorter.rs
index ae85306..1e01885 100644
--- a/src/sorter.rs
+++ b/src/sorter.rs
@@ -218,7 +218,7 @@ impl<MF, CC: ChunkCreator> SorterBuilder<MF, CC> {
         CC::Chunk: Send + 'static,
     {
         match number.get() {
-            1 => ParallelSorter::Single(self.build()),
+            1 | 2 => ParallelSorter(ParallelSorterInner::Single(self.build())),
             number => {
                 let (senders, receivers): (Vec<Sender<(usize, Vec<u8>)>>, Vec<_>) =
                     repeat_with(unbounded).take(number).unzip();
@@ -227,6 +227,7 @@ impl<MF, CC: ChunkCreator> SorterBuilder<MF, CC> {
                 for receiver in receivers {
                     let sorter_builder = self.clone();
                     handles.push(thread::spawn(move || {
+                        // TODO make sure the max memory is divided by the number of threads
                         let mut sorter = sorter_builder.build();
                         for (key_length, data) in receiver {
                             let (key, val) = data.split_at(key_length);
@@ -236,7 +237,11 @@ impl<MF, CC: ChunkCreator> SorterBuilder<MF, CC> {
                     }));
                 }
 
-                ParallelSorter::Multi { senders, handles, merge_function: self.merge }
+                ParallelSorter(ParallelSorterInner::Multi {
+                    senders,
+                    handles,
+                    merge_function: self.merge,
+                })
             }
         }
     }
@@ -712,14 +717,19 @@ where
     }
 }
 
-// TODO Make this private by wrapping it
-pub enum ParallelSorter<MF, U, CC: ChunkCreator = DefaultChunkCreator>
+pub struct ParallelSorter<MF, U, CC: ChunkCreator = DefaultChunkCreator>(
+    ParallelSorterInner<MF, U, CC>,
+)
+where
+    MF: for<'a> Fn(&[u8], &[Cow<'a, [u8]>]) -> Result<Cow<'a, [u8]>, U>;
+
+enum ParallelSorterInner<MF, U, CC: ChunkCreator = DefaultChunkCreator>
 where
     MF: for<'a> Fn(&[u8], &[Cow<'a, [u8]>]) -> Result<Cow<'a, [u8]>, U>,
 {
     Single(Sorter<MF, CC>),
     Multi {
-        // Indicates the length of the key and the bytes assoicated to the key + the data.
+        // Indicates the length of the key and the bytes associated to the key + the data.
         senders: Vec<Sender<(usize, Vec<u8>)>>,
         handles: Vec<JoinHandle<Result<Vec<ReaderCursor<CC::Chunk>>, Error<U>>>>,
         merge_function: MF,
@@ -740,9 +750,9 @@ where
     {
         let key = key.as_ref();
         let val = val.as_ref();
-        match self {
-            ParallelSorter::Single(sorter) => sorter.insert(key, val),
-            ParallelSorter::Multi { senders, .. } => {
+        match &mut self.0 {
+            ParallelSorterInner::Single(sorter) => sorter.insert(key, val),
+            ParallelSorterInner::Multi { senders, .. } => {
                 let key_length = key.len();
                 let key_hash = compute_hash(key);
 
@@ -766,9 +776,9 @@ where
 
     /// Consumes this [`Sorter`] and outputs a stream of the merged entries in key-order.
     pub fn into_stream_merger_iter(self) -> Result<MergerIter<CC::Chunk, MF>, Error<U>> {
-        match self {
-            ParallelSorter::Single(sorter) => sorter.into_stream_merger_iter(),
-            ParallelSorter::Multi { senders, handles, merge_function } => {
+        match self.0 {
+            ParallelSorterInner::Single(sorter) => sorter.into_stream_merger_iter(),
+            ParallelSorterInner::Multi { senders, handles, merge_function } => {
                 drop(senders);
 
                 let mut sources = Vec::new();

From 5ac4c0b3cae7e56a3b812d516632306fba60c8cb Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Cl=C3=A9ment=20Renault?= <clement@meilisearch.com>
Date: Wed, 1 Nov 2023 12:09:01 +0100
Subject: [PATCH 4/4] Divide the parallel dump threshold by the number of
 running threads

---
 src/sorter.rs | 4 ++--
 1 file changed, 2 insertions(+), 2 deletions(-)

diff --git a/src/sorter.rs b/src/sorter.rs
index 1e01885..94a9996 100644
--- a/src/sorter.rs
+++ b/src/sorter.rs
@@ -225,9 +225,9 @@ impl<MF, CC: ChunkCreator> SorterBuilder<MF, CC> {
 
                 let mut handles = Vec::new();
                 for receiver in receivers {
-                    let sorter_builder = self.clone();
+                    let mut sorter_builder = self.clone();
+                    sorter_builder.dump_threshold(self.dump_threshold / number);
                     handles.push(thread::spawn(move || {
-                        // TODO make sure the max memory is divided by the number of threads
                         let mut sorter = sorter_builder.build();
                         for (key_length, data) in receiver {
                             let (key, val) = data.split_at(key_length);