diff --git a/.github/workflows/rust.yml b/.github/workflows/rust.yml
index db13e9e..e1340c0 100644
--- a/.github/workflows/rust.yml
+++ b/.github/workflows/rust.yml
@@ -34,6 +34,10 @@ jobs:
           command: test
           args: --all-features
+      - uses: actions-rs/cargo@v1
+        with:
+          command: test
+
   lint:
     runs-on: ubuntu-latest
     steps:
diff --git a/Cargo.toml b/Cargo.toml
index 70a7678..63be726 100644
--- a/Cargo.toml
+++ b/Cargo.toml
@@ -11,8 +11,10 @@ license = "MIT"
 [dependencies]
 bytemuck = { version = "1.7.0", features = ["derive"] }
 byteorder = "1.3.4"
+crossbeam-channel = "0.5.8"
 flate2 = { version = "1.0", optional = true }
 lz4_flex = { version = "0.9.2", optional = true }
+rayon = { version = "1.7.0", optional = true }
 snap = { version = "1.0.5", optional = true }
 tempfile = { version = "3.2.0", optional = true }
 zstd = { version = "0.10.0", optional = true }
diff --git a/src/lib.rs b/src/lib.rs
index e38dd9b..7d2891c 100644
--- a/src/lib.rs
+++ b/src/lib.rs
@@ -2,7 +2,21 @@
 //! The entries in the grenad files are _immutable_ and the only way to modify them is by _creating
 //! a new file_ with the changes.
 //!
-//! # Example: Use the `Writer` and `Reader` structs
+//! # Features
+//!
+//! You can choose which compression schemes to support. There are currently a few
+//! available choices, and they determine which types are available inside the above modules:
+//!
+//! - _Snappy_ with the [`snap`](https://crates.io/crates/snap) crate.
+//! - _Zlib_ with the [`flate2`](https://crates.io/crates/flate2) crate.
+//! - _Lz4_ with the [`lz4_flex`](https://crates.io/crates/lz4_flex) crate.
+//!
+//! If you need more performance you can enable the `rayon` feature, which unlocks
+//! additional settings such as making the `Sorter` sort in parallel.
+//!
+//! # Examples
+//!
+//! ## Use the `Writer` and `Reader` structs
 //!
 //! You can use the [`Writer`] struct to store key-value pairs into the specified
 //! [`std::io::Write`] type. The [`Reader`] type can then be used to read the entries.
@@ -37,7 +51,7 @@
 //! # Ok(()) }
 //! ```
 //!
-//! # Example: Use the `Merger` struct
+//! ## Use the `Merger` struct
 //!
 //! In this example we show how you can merge multiple [`Reader`]s
 //! by using a _merge function_ when a conflict is encountered.
@@ -107,7 +121,7 @@
 //! # Ok(()) }
 //! ```
 //!
-//! # Example: Use the `Sorter` struct
+//! ## Use the `Sorter` struct
 //!
 //! In this example we show how by defining a _merge function_, we can insert
 //! multiple entries with the same key and output them in lexicographic order.
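Reviewer note: the docs above lean on user-provided merge functions. As a reading aid for the rest of the diff, here is a minimal, hypothetical round-trip through the existing `Sorter` API; the `concat_merge` function and the sample entries are illustrative and not part of this change:

```rust
use std::borrow::Cow;
use std::convert::Infallible;

use grenad::SorterBuilder;

// An illustrative merge function that concatenates all the values of a key.
fn concat_merge<'a>(
    _key: &[u8],
    values: &[Cow<'a, [u8]>],
) -> Result<Cow<'a, [u8]>, Infallible> {
    Ok(Cow::Owned(values.iter().flat_map(|v| v.iter().copied()).collect()))
}

fn main() -> Result<(), Box<dyn std::error::Error>> {
    let mut sorter = SorterBuilder::new(concat_merge).build();

    // Entries can be inserted unordered and with duplicate keys.
    sorter.insert("second", "2")?;
    sorter.insert("first", "1")?;
    sorter.insert("first", "3")?;

    // The stream yields the merged entries in lexicographic key order.
    let mut iter = sorter.into_stream_merger_iter()?;
    while let Some((key, val)) = iter.next()? {
        println!("{:?} => {:?}", key, val);
    }
    Ok(())
}
```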
@@ -189,7 +203,8 @@ pub use self::reader::{PrefixIter, RangeIter, Reader, ReaderCursor, RevPrefixIte
 #[cfg(feature = "tempfile")]
 pub use self::sorter::TempFileChunk;
 pub use self::sorter::{
-    ChunkCreator, CursorVec, DefaultChunkCreator, SortAlgorithm, Sorter, SorterBuilder,
+    ChunkCreator, CursorVec, DefaultChunkCreator, ParallelSorter, SortAlgorithm, Sorter,
+    SorterBuilder,
 };
 pub use self::writer::{Writer, WriterBuilder};
diff --git a/src/sorter.rs b/src/sorter.rs
index 94cf92d..94a9996 100644
--- a/src/sorter.rs
+++ b/src/sorter.rs
@@ -1,14 +1,19 @@
 use std::alloc::{alloc, dealloc, Layout};
 use std::borrow::Cow;
+use std::collections::hash_map::DefaultHasher;
 use std::convert::Infallible;
 #[cfg(feature = "tempfile")]
 use std::fs::File;
+use std::hash::{Hash, Hasher};
 use std::io::{Cursor, Read, Seek, SeekFrom, Write};
+use std::iter::repeat_with;
 use std::mem::{align_of, size_of};
 use std::num::NonZeroUsize;
+use std::thread::{self, JoinHandle};
 use std::{cmp, io, ops, slice};
 
 use bytemuck::{cast_slice, cast_slice_mut, Pod, Zeroable};
+use crossbeam_channel::{unbounded, Sender};
 
 use crate::count_write::CountWrite;
@@ -47,6 +52,7 @@ pub struct SorterBuilder<MF, CC> {
     index_levels: Option<u8>,
     chunk_creator: CC,
     sort_algorithm: SortAlgorithm,
+    sort_in_parallel: bool,
     merge: MF,
 }
@@ -65,6 +71,7 @@ impl<MF> SorterBuilder<MF, DefaultChunkCreator> {
             index_levels: None,
             chunk_creator: DefaultChunkCreator::default(),
             sort_algorithm: SortAlgorithm::Stable,
+            sort_in_parallel: false,
             merge,
         }
     }
@@ -142,6 +149,15 @@ impl<MF, CC> SorterBuilder<MF, CC> {
         self
     }
 
+    /// Whether we use [rayon to sort](https://docs.rs/rayon/latest/rayon/slice/trait.ParallelSliceMut.html#method.par_sort_by_key) the entries.
+    ///
+    /// By default we do not sort in parallel; the value is `false`.
+    #[cfg(feature = "rayon")]
+    pub fn sort_in_parallel(&mut self, value: bool) -> &mut Self {
+        self.sort_in_parallel = value;
+        self
+    }
+
     /// The [`ChunkCreator`] struct used to generate the chunks used
     /// by the [`Sorter`] to bufferize when required.
     pub fn chunk_creator<CC2>(self, creation: CC2) -> SorterBuilder<MF, CC2> {
@@ -156,6 +172,7 @@
             index_levels: self.index_levels,
             chunk_creator: creation,
             sort_algorithm: self.sort_algorithm,
+            sort_in_parallel: self.sort_in_parallel,
             merge: self.merge,
         }
     }
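Reviewer note: with the `rayon` feature enabled, `sort_in_parallel(true)` only changes how a chunk's in-memory entries are sorted right before being dumped; the rest of the `Sorter` behavior is untouched. A sketch of flipping the flag, assuming the crate is compiled with that feature (`keep_first` is an illustrative merge function):

```rust
use std::borrow::Cow;
use std::convert::Infallible;

use grenad::SorterBuilder;

// An illustrative merge function that keeps only the first value of a key.
fn keep_first<'a>(
    _key: &[u8],
    values: &[Cow<'a, [u8]>],
) -> Result<Cow<'a, [u8]>, Infallible> {
    Ok(values[0].clone())
}

fn main() -> Result<(), Box<dyn std::error::Error>> {
    let mut builder = SorterBuilder::new(keep_first);
    // Chunks are now sorted with rayon's par_sort(_unstable)_by_key
    // whenever they are written out.
    builder.sort_in_parallel(true);
    let mut sorter = builder.build();

    sorter.insert("key", "kept")?;
    sorter.insert("key", "ignored")?;

    let mut iter = sorter.into_stream_merger_iter()?;
    while let Some((_key, val)) = iter.next()? {
        assert_eq!(val, &b"kept"[..]);
    }
    Ok(())
}
```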
@@ -181,9 +198,53 @@
             index_levels: self.index_levels,
             chunk_creator: self.chunk_creator,
             sort_algorithm: self.sort_algorithm,
+            sort_in_parallel: self.sort_in_parallel,
             merge: self.merge,
         }
     }
+
+    /// Creates a [`ParallelSorter`] configured by this builder.
+    ///
+    /// Indicates the number of [`Sorter`]s to use to balance
+    /// the sorting load.
+    pub fn build_in_parallel<U>(self, number: NonZeroUsize) -> ParallelSorter<MF, U, CC>
+    where
+        MF: for<'a> Fn(&[u8], &[Cow<'a, [u8]>]) -> Result<Cow<'a, [u8]>, U>
+            + Clone
+            + Send
+            + 'static,
+        U: Send + 'static,
+        CC: Clone + Send + 'static,
+        CC::Chunk: Send + 'static,
+    {
+        match number.get() {
+            1 | 2 => ParallelSorter(ParallelSorterInner::Single(self.build())),
+            number => {
+                let (senders, receivers): (Vec<Sender<(usize, Vec<u8>)>>, Vec<_>) =
+                    repeat_with(unbounded).take(number).unzip();
+
+                let mut handles = Vec::new();
+                for receiver in receivers {
+                    let mut sorter_builder = self.clone();
+                    sorter_builder.dump_threshold(self.dump_threshold / number);
+                    handles.push(thread::spawn(move || {
+                        let mut sorter = sorter_builder.build();
+                        for (key_length, data) in receiver {
+                            let (key, val) = data.split_at(key_length);
+                            sorter.insert(key, val)?;
+                        }
+                        sorter.into_reader_cursors().map_err(Into::into)
+                    }));
+                }
+
+                ParallelSorter(ParallelSorterInner::Multi {
+                    senders,
+                    handles,
+                    merge_function: self.merge,
+                })
+            }
+        }
+    }
 }
 
 /// Stores entries memory efficiently in a buffer.
@@ -281,6 +342,27 @@ impl Entries {
         sort(bounds, |b: &EntryBound| &tail[tail.len() - b.key_start..][..b.key_length as usize]);
     }
 
+    /// Sorts the entry bounds by the entries keys in **parallel**;
+    /// after a sort the `iter` method will yield the entries sorted.
+    #[cfg(feature = "rayon")]
+    pub fn par_sort_by_key(&mut self, algorithm: SortAlgorithm) {
+        use rayon::slice::ParallelSliceMut;
+
+        let bounds_end = self.bounds_count * size_of::<EntryBound>();
+        let (bounds, tail) = self.buffer.split_at_mut(bounds_end);
+        let bounds = cast_slice_mut::<_, EntryBound>(bounds);
+        let sort = match algorithm {
+            SortAlgorithm::Stable => <[EntryBound]>::par_sort_by_key,
+            SortAlgorithm::Unstable => <[EntryBound]>::par_sort_unstable_by_key,
+        };
+        sort(bounds, |b: &EntryBound| &tail[tail.len() - b.key_start..][..b.key_length as usize]);
+    }
+
+    #[cfg(not(feature = "rayon"))]
+    pub fn par_sort_by_key(&mut self, algorithm: SortAlgorithm) {
+        self.sort_by_key(algorithm);
+    }
+
     /// Returns an iterator over the keys and datas.
     pub fn iter(&self) -> impl Iterator<Item = (&[u8], &[u8])> + '_ {
         let bounds_end = self.bounds_count * size_of::<EntryBound>();
@@ -399,6 +481,7 @@ pub struct Sorter<MF, CC: ChunkCreator = DefaultChunkCreator> {
     index_levels: Option<u8>,
     chunk_creator: CC,
     sort_algorithm: SortAlgorithm,
+    sort_in_parallel: bool,
     merge: MF,
 }
@@ -489,7 +572,11 @@ where
         }
 
         let mut writer = writer_builder.build(count_write_chunk);
-        self.entries.sort_by_key(self.sort_algorithm);
+        if self.sort_in_parallel {
+            self.entries.par_sort_by_key(self.sort_algorithm);
+        } else {
+            self.entries.sort_by_key(self.sort_algorithm);
+        }
 
         let mut current = None;
         for (key, value) in self.entries.iter() {
@@ -509,7 +596,7 @@ where
 
         if let Some((key, vals)) = current.take() {
             let merged_val = (self.merge)(key, &vals).map_err(Error::Merge)?;
-            writer.insert(&key, &merged_val)?;
+            writer.insert(key, &merged_val)?;
         }
 
         // We retrieve the wrapped CountWrite and extract
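Reviewer note: a hypothetical end-to-end use of the new `build_in_parallel` entry point; the thread count and entries are arbitrary, the merge function is illustrative, and per the `match` above a request for one or two threads falls back to a plain `Sorter`. This also assumes the default chunk creator satisfies the `Clone + Send` bounds:

```rust
use std::borrow::Cow;
use std::convert::Infallible;
use std::num::NonZeroUsize;

use grenad::SorterBuilder;

// An illustrative merge function that concatenates all the values of a key.
fn concat_merge<'a>(
    _key: &[u8],
    values: &[Cow<'a, [u8]>],
) -> Result<Cow<'a, [u8]>, Infallible> {
    Ok(Cow::Owned(values.iter().flat_map(|v| v.iter().copied()).collect()))
}

fn main() -> Result<(), Box<dyn std::error::Error>> {
    // Spawns four threads, each owning a Sorter with a
    // proportionally reduced dump threshold.
    let threads = NonZeroUsize::new(4).unwrap();
    let mut sorter = SorterBuilder::new(concat_merge).build_in_parallel(threads);

    // Both entries hash to the same thread, so the merge function
    // still sees every value of the key.
    sorter.insert("same-key", "one")?;
    sorter.insert("same-key", "two")?;

    // Closes the channels, joins the threads and merges their chunks.
    let mut iter = sorter.into_stream_merger_iter()?;
    while let Some((key, val)) = iter.next()? {
        println!("{:?} => {:?}", key, val);
    }
    Ok(())
}
```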
@@ -630,6 +717,91 @@ where
     }
 }
 
+pub struct ParallelSorter<MF, U, CC: ChunkCreator = DefaultChunkCreator>(
+    ParallelSorterInner<MF, U, CC>,
+)
+where
+    MF: for<'a> Fn(&[u8], &[Cow<'a, [u8]>]) -> Result<Cow<'a, [u8]>, U>;
+
+enum ParallelSorterInner<MF, U, CC: ChunkCreator = DefaultChunkCreator>
+where
+    MF: for<'a> Fn(&[u8], &[Cow<'a, [u8]>]) -> Result<Cow<'a, [u8]>, U>,
+{
+    Single(Sorter<MF, CC>),
+    Multi {
+        // Each message is the key length plus one buffer holding the key bytes followed by the data bytes.
+        senders: Vec<Sender<(usize, Vec<u8>)>>,
+        handles: Vec<JoinHandle<Result<Vec<ReaderCursor<CC::Chunk>>, Error<U>>>>,
+        merge_function: MF,
+    },
+}
+
+impl<MF, U, CC> ParallelSorter<MF, U, CC>
+where
+    MF: for<'a> Fn(&[u8], &[Cow<'a, [u8]>]) -> Result<Cow<'a, [u8]>, U>,
+    CC: ChunkCreator,
+{
+    /// Insert an entry into the [`ParallelSorter`] making sure that conflicts
+    /// are resolved by the provided merge function.
+    pub fn insert<K, V>(&mut self, key: K, val: V) -> Result<(), Error<U>>
+    where
+        K: AsRef<[u8]>,
+        V: AsRef<[u8]>,
+    {
+        let key = key.as_ref();
+        let val = val.as_ref();
+        match &mut self.0 {
+            ParallelSorterInner::Single(sorter) => sorter.insert(key, val),
+            ParallelSorterInner::Multi { senders, .. } => {
+                let key_length = key.len();
+                let key_hash = compute_hash(key);
+
+                // We put the key and val into the same allocation to speed things up
+                // by reducing the number of calls to the allocator.
+                //
+                // TODO test that it works for real because having a bigger allocation
+                //      can make it harder to find the space.
+                let mut data = Vec::with_capacity(key.len() + val.len());
+                data.extend_from_slice(key);
+                data.extend_from_slice(val);
+
+                let index = (key_hash % senders.len() as u64) as usize;
+                // TODO remove unwraps
+                senders[index].send((key_length, data)).unwrap();
+
+                Ok(())
+            }
+        }
+    }
+
+    /// Consumes this [`ParallelSorter`] and outputs a stream of the merged entries in key-order.
+    pub fn into_stream_merger_iter(self) -> Result<MergerIter<CC::Chunk, MF>, Error<U>> {
+        match self.0 {
+            ParallelSorterInner::Single(sorter) => sorter.into_stream_merger_iter(),
+            ParallelSorterInner::Multi { senders, handles, merge_function } => {
+                // Dropping the senders closes the channels and lets the threads finish.
+                drop(senders);
+
+                let mut sources = Vec::new();
+                for handle in handles {
+                    // TODO remove unwraps
+                    sources.extend(handle.join().unwrap()?);
+                }
+
+                let mut builder = Merger::builder(merge_function);
+                builder.extend(sources);
+                builder.build().into_stream_merger_iter().map_err(Error::convert_merge_error)
+            }
+        }
+    }
+}
+
+/// Computes the hash of a key.
+fn compute_hash(key: &[u8]) -> u64 {
+    let mut state = DefaultHasher::new();
+    key.hash(&mut state);
+    state.finish()
+}
+
 /// A trait that represent a `ChunkCreator`.
 pub trait ChunkCreator {
     /// The generated chunk by this `ChunkCreator`.
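Reviewer note: the `Multi` path stays correct because `compute_hash` gives every key a stable worker index, so a single thread receives all the values inserted for a given key and its `Sorter` can merge them completely. A standalone sketch of that routing invariant (the worker count is arbitrary):

```rust
use std::collections::hash_map::DefaultHasher;
use std::hash::{Hash, Hasher};

// Same scheme as compute_hash above: the standard library's
// default hasher over the raw key bytes.
fn compute_hash(key: &[u8]) -> u64 {
    let mut state = DefaultHasher::new();
    key.hash(&mut state);
    state.finish()
}

fn main() {
    let workers = 4_u64;

    // The same key always maps to the same worker index, so one
    // thread observes every value inserted for that key.
    let first = (compute_hash(b"same-key") % workers) as usize;
    let second = (compute_hash(b"same-key") % workers) as usize;
    assert_eq!(first, second);

    println!("\"same-key\" is always routed to worker {first}");
}
```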