Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Various Improvements #55

Merged
merged 10 commits into from
Nov 25, 2024
Merged
3 changes: 2 additions & 1 deletion Cargo.toml
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
[package]
name = "grenad"
description = "Tools to sort, merge, write, and read immutable key-value pairs."
version = "0.4.7"
version = "0.5.0"
authors = ["Kerollmops <clement@meilisearch.com>"]
repository = "https://github.com/meilisearch/grenad"
documentation = "https://docs.rs/grenad"
@@ -11,6 +11,7 @@ license = "MIT"
[dependencies]
bytemuck = { version = "1.16.1", features = ["derive"] }
byteorder = "1.5.0"
either = { version = "1.13.0", default-features = false }
flate2 = { version = "1.0", optional = true }
lz4_flex = { version = "0.11.3", optional = true }
rayon = { version = "1.10.0", optional = true }
2 changes: 1 addition & 1 deletion benches/index-levels.rs
Original file line number Diff line number Diff line change
@@ -16,7 +16,7 @@ fn index_levels(bytes: &[u8]) {

for x in (0..NUMBER_OF_ENTRIES).step_by(1_567) {
let num = x.to_be_bytes();
cursor.move_on_key_greater_than_or_equal_to(&num).unwrap().unwrap();
cursor.move_on_key_greater_than_or_equal_to(num).unwrap().unwrap();
}
}

6 changes: 3 additions & 3 deletions src/block_writer.rs
Original file line number Diff line number Diff line change
@@ -94,8 +94,8 @@ impl BlockWriter {
/// Insert a key that must be greater than the previously added one.
pub fn insert(&mut self, key: &[u8], val: &[u8]) {
debug_assert!(self.index_key_counter <= self.index_key_interval.get());
assert!(key.len() <= u32::max_value() as usize);
assert!(val.len() <= u32::max_value() as usize);
assert!(key.len() <= u32::MAX as usize);
assert!(val.len() <= u32::MAX as usize);

if self.index_key_counter == self.index_key_interval.get() {
self.index_offsets.push(self.buffer.len() as u64);
@@ -106,7 +106,7 @@ impl BlockWriter {
// and save the current key to become the last key.
match &mut self.last_key {
Some(last_key) => {
assert!(key > last_key, "{:?} must be greater than {:?}", key, last_key);
assert!(key > last_key.as_slice(), "{:?} must be greater than {:?}", key, last_key);
last_key.clear();
last_key.extend_from_slice(key);
}
12 changes: 5 additions & 7 deletions src/compression.rs
Original file line number Diff line number Diff line change
@@ -4,10 +4,11 @@ use std::str::FromStr;
use std::{fmt, io};

/// The different supported types of compression.
#[derive(Debug, Copy, Clone, PartialEq, Eq, PartialOrd, Ord, Hash)]
#[derive(Default, Debug, Copy, Clone, PartialEq, Eq, PartialOrd, Ord, Hash)]
#[repr(u8)]
pub enum CompressionType {
/// Do not compress the blocks.
#[default]
None = 0,
/// Use the [`snap`] crate to de/compress the blocks.
///
@@ -55,12 +56,6 @@ impl FromStr for CompressionType {
}
}

impl Default for CompressionType {
fn default() -> CompressionType {
CompressionType::None
}
}

/// An invalid compression type has been read and the block can't be de/compressed.
#[derive(Debug, Copy, Clone, PartialEq, Eq, PartialOrd, Ord, Hash)]
pub struct InvalidCompressionType;
@@ -107,6 +102,7 @@ fn zlib_decompress<R: io::Read>(data: R, out: &mut Vec<u8>) -> io::Result<()> {
}

#[cfg(not(feature = "zlib"))]
#[allow(clippy::ptr_arg)] // it doesn't understand that I need the same signature for all functions
fn zlib_decompress<R: io::Read>(_data: R, _out: &mut Vec<u8>) -> io::Result<()> {
Err(io::Error::new(io::ErrorKind::Other, "unsupported zlib decompression"))
}
@@ -186,6 +182,7 @@ fn zstd_decompress<R: io::Read>(data: R, out: &mut Vec<u8>) -> io::Result<()> {
}

#[cfg(not(feature = "zstd"))]
#[allow(clippy::ptr_arg)] // it doesn't understand that I need the same signature for all functions
fn zstd_decompress<R: io::Read>(_data: R, _out: &mut Vec<u8>) -> io::Result<()> {
Err(io::Error::new(io::ErrorKind::Other, "unsupported zstd decompression"))
}
@@ -211,6 +208,7 @@ fn lz4_decompress<R: io::Read>(data: R, out: &mut Vec<u8>) -> io::Result<()> {
}

#[cfg(not(feature = "lz4"))]
#[allow(clippy::ptr_arg)] // it doesn't understand that I need the same signature for all functions
fn lz4_decompress<R: io::Read>(_data: R, _out: &mut Vec<u8>) -> io::Result<()> {
Err(io::Error::new(io::ErrorKind::Other, "unsupported lz4 decompression"))
}
60 changes: 34 additions & 26 deletions src/lib.rs
Original file line number Diff line number Diff line change
@@ -72,23 +72,25 @@
//! use std::convert::TryInto;
//! use std::io::Cursor;
//!
//! use grenad::{MergerBuilder, Reader, Writer};
//! use grenad::{MergerBuilder, MergeFunction, Reader, Writer};
//!
//! // This merge function:
//! // - parses u32s from native-endian bytes,
//! // - wrapping sums them and,
//! // - outputs the result as native-endian bytes.
//! fn wrapping_sum_u32s<'a>(
//! _key: &[u8],
//! values: &[Cow<'a, [u8]>],
//! ) -> Result<Cow<'a, [u8]>, TryFromSliceError>
//! {
//! let mut output: u32 = 0;
//! for bytes in values.iter().map(AsRef::as_ref) {
//! let num = bytes.try_into().map(u32::from_ne_bytes)?;
//! output = output.wrapping_add(num);
//! struct WrappingSumU32s;
//!
//! impl MergeFunction for WrappingSumU32s {
//! type Error = TryFromSliceError;
//!
//! fn merge<'a>(&self, key: &[u8], values: &[Cow<'a, [u8]>]) -> Result<Cow<'a, [u8]>, Self::Error> {
//! let mut output: u32 = 0;
//! for bytes in values.iter().map(AsRef::as_ref) {
//! let num = bytes.try_into().map(u32::from_ne_bytes)?;
//! output = output.wrapping_add(num);
//! }
//! Ok(Cow::Owned(output.to_ne_bytes().to_vec()))
//! }
//! Ok(Cow::Owned(output.to_ne_bytes().to_vec()))
//! }
//!
//! # fn main() -> Result<(), Box<dyn std::error::Error>> {
@@ -115,7 +117,7 @@
//!
//! // We create a merger that will sum our u32s when necessary,
//! // and we add our readers to the list of readers to merge.
//! let merger_builder = MergerBuilder::new(wrapping_sum_u32s);
//! let merger_builder = MergerBuilder::new(WrappingSumU32s);
//! let merger = merger_builder.add(readera).add(readerb).add(readerc).build();
//!
//! // We can iterate over the entries in key-order.
@@ -142,28 +144,30 @@
//! use std::borrow::Cow;
//! use std::convert::TryInto;
//!
//! use grenad::{CursorVec, SorterBuilder};
//! use grenad::{CursorVec, MergeFunction, SorterBuilder};
//!
//! // This merge function:
//! // - parses u32s from native-endian bytes,
//! // - wrapping sums them and,
//! // - outputs the result as native-endian bytes.
//! fn wrapping_sum_u32s<'a>(
//! _key: &[u8],
//! values: &[Cow<'a, [u8]>],
//! ) -> Result<Cow<'a, [u8]>, TryFromSliceError>
//! {
//! let mut output: u32 = 0;
//! for bytes in values.iter().map(AsRef::as_ref) {
//! let num = bytes.try_into().map(u32::from_ne_bytes)?;
//! output = output.wrapping_add(num);
//! struct WrappingSumU32s;
//!
//! impl MergeFunction for WrappingSumU32s {
//! type Error = TryFromSliceError;
//!
//! fn merge<'a>(&self, key: &[u8], values: &[Cow<'a, [u8]>]) -> Result<Cow<'a, [u8]>, Self::Error> {
//! let mut output: u32 = 0;
//! for bytes in values.iter().map(AsRef::as_ref) {
//! let num = bytes.try_into().map(u32::from_ne_bytes)?;
//! output = output.wrapping_add(num);
//! }
//! Ok(Cow::Owned(output.to_ne_bytes().to_vec()))
//! }
//! Ok(Cow::Owned(output.to_ne_bytes().to_vec()))
//! }
//!
//! # fn main() -> Result<(), Box<dyn std::error::Error>> {
//! // We create a sorter that will sum our u32s when necessary.
//! let mut sorter = SorterBuilder::new(wrapping_sum_u32s).chunk_creator(CursorVec).build();
//! let mut sorter = SorterBuilder::new(WrappingSumU32s).chunk_creator(CursorVec).build();
//!
//! // We insert multiple entries with the same key but different values
//! // in arbitrary order, the sorter will take care of merging them for us.
@@ -187,14 +191,15 @@
#[cfg(test)]
#[macro_use]
extern crate quickcheck;

use std::convert::Infallible;
use std::mem;

mod block;
mod block_writer;
mod compression;
mod count_write;
mod error;
mod merge_function;
mod merger;
mod metadata;
mod reader;
@@ -204,6 +209,7 @@ mod writer;

pub use self::compression::CompressionType;
pub use self::error::Error;
pub use self::merge_function::MergeFunction;
pub use self::merger::{Merger, MergerBuilder, MergerIter};
pub use self::metadata::FileVersion;
pub use self::reader::{PrefixIter, RangeIter, Reader, ReaderCursor, RevPrefixIter, RevRangeIter};
@@ -214,10 +220,12 @@ pub use self::sorter::{
};
pub use self::writer::{Writer, WriterBuilder};

pub type Result<T, U = Infallible> = std::result::Result<T, Error<U>>;

/// Sometimes we need to use an unsafe trick to make the compiler happy.
/// You can read more about the issue [on the Rust's Github issues].
///
/// [on the Rust's Github issues]: https://github.com/rust-lang/rust/issues/47680
unsafe fn transmute_entry_to_static(key: &[u8], val: &[u8]) -> (&'static [u8], &'static [u8]) {
(mem::transmute(key), mem::transmute(val))
(mem::transmute::<&[u8], &'static [u8]>(key), mem::transmute::<&[u8], &'static [u8]>(val))
}
46 changes: 46 additions & 0 deletions src/merge_function.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,46 @@
use std::borrow::Cow;
use std::result::Result;

use either::Either;

/// A trait defining the way we merge multiple
/// values sharing the same key.
pub trait MergeFunction {
    /// The error type returned when a merge fails.
    type Error;

    /// Merges the `values` associated with the same `key` into a single value.
    ///
    /// `values` is guaranteed by callers in this crate to be non-empty; the
    /// returned `Cow` may borrow from the input slice or own a fresh buffer.
    fn merge<'a>(&self, key: &[u8], values: &[Cow<'a, [u8]>])
        -> Result<Cow<'a, [u8]>, Self::Error>;
}

/// A shared reference to a merge function is itself a merge function:
/// it simply forwards every call to the referenced implementation.
impl<MF> MergeFunction for &MF
where
    MF: MergeFunction,
{
    type Error = MF::Error;

    fn merge<'a>(
        &self,
        key: &[u8],
        values: &[Cow<'a, [u8]>],
    ) -> Result<Cow<'a, [u8]>, Self::Error> {
        // Delegate explicitly to the underlying implementation; the fully
        // qualified call makes it clear this cannot recurse into this impl.
        MF::merge(*self, key, values)
    }
}

/// An `Either` of two merge functions that share the same error type is a
/// merge function: each call is dispatched to whichever side is held.
impl<MFA, MFB> MergeFunction for Either<MFA, MFB>
where
    MFA: MergeFunction,
    MFB: MergeFunction<Error = MFA::Error>,
{
    // Both sides are constrained to the same error type, so either works here.
    type Error = MFA::Error;

    fn merge<'a>(
        &self,
        key: &[u8],
        values: &[Cow<'a, [u8]>],
    ) -> Result<Cow<'a, [u8]>, Self::Error> {
        // Dispatch on the side currently stored in this `Either`.
        match self {
            Either::Left(left) => left.merge(key, values),
            Either::Right(right) => right.merge(key, values),
        }
    }
}
23 changes: 13 additions & 10 deletions src/merger.rs
Original file line number Diff line number Diff line change
@@ -4,7 +4,7 @@ use std::collections::BinaryHeap;
use std::io;
use std::iter::once;

use crate::{Error, ReaderCursor, Writer};
use crate::{Error, MergeFunction, ReaderCursor, Writer};

/// A struct that is used to configure a [`Merger`] with the sources to merge.
pub struct MergerBuilder<R, MF> {
@@ -20,6 +20,7 @@ impl<R, MF> MergerBuilder<R, MF> {
}

/// Add a source to merge, this function can be chained.
#[allow(clippy::should_implement_trait)] // We return interior references
pub fn add(mut self, source: ReaderCursor<R>) -> Self {
self.push(source);
self
@@ -95,7 +96,7 @@ impl<R: io::Read + io::Seek, MF> Merger<R, MF> {
}

Ok(MergerIter {
merge: self.merge,
merge_function: self.merge,
heap,
current_key: Vec::new(),
merged_value: Vec::new(),
@@ -104,16 +105,16 @@ impl<R: io::Read + io::Seek, MF> Merger<R, MF> {
}
}

impl<R, MF, U> Merger<R, MF>
impl<R, MF> Merger<R, MF>
where
R: io::Read + io::Seek,
MF: for<'a> Fn(&[u8], &[Cow<'a, [u8]>]) -> Result<Cow<'a, [u8]>, U>,
MF: MergeFunction,
{
/// Consumes this [`Merger`] and streams the entries to the [`Writer`] given in parameter.
pub fn write_into_stream_writer<W: io::Write>(
self,
writer: &mut Writer<W>,
) -> Result<(), Error<U>> {
) -> crate::Result<(), MF::Error> {
let mut iter = self.into_stream_merger_iter().map_err(Error::convert_merge_error)?;
while let Some((key, val)) = iter.next()? {
writer.insert(key, val)?;
@@ -124,21 +125,23 @@ where

/// An iterator that yield the merged entries in key-order.
pub struct MergerIter<R, MF> {
merge: MF,
merge_function: MF,
heap: BinaryHeap<Entry<R>>,
current_key: Vec<u8>,
merged_value: Vec<u8>,
/// We keep this buffer to avoid allocating a vec every time.
tmp_entries: Vec<Entry<R>>,
}

impl<R, MF, U> MergerIter<R, MF>
impl<R, MF> MergerIter<R, MF>
where
R: io::Read + io::Seek,
MF: for<'a> Fn(&[u8], &[Cow<'a, [u8]>]) -> Result<Cow<'a, [u8]>, U>,
MF: MergeFunction,
{
/// Yield the entries in key-order where values have been merged when needed.
pub fn next(&mut self) -> Result<Option<(&[u8], &[u8])>, Error<U>> {
#[allow(clippy::should_implement_trait)] // We return interior references
#[allow(clippy::type_complexity)] // Return type is not THAT complex
pub fn next(&mut self) -> crate::Result<Option<(&[u8], &[u8])>, MF::Error> {
let first_entry = match self.heap.pop() {
Some(entry) => entry,
None => return Ok(None),
@@ -167,7 +170,7 @@ where
self.tmp_entries.iter().filter_map(|e| e.cursor.current().map(|(_, v)| v));
let values: Vec<_> = once(first_value).chain(other_values).map(Cow::Borrowed).collect();

match (self.merge)(first_key, &values) {
match self.merge_function.merge(first_key, &values) {
Ok(value) => {
self.current_key.clear();
self.current_key.extend_from_slice(first_key);
12 changes: 7 additions & 5 deletions src/reader/prefix_iter.rs
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
use std::io;

use crate::{Error, ReaderCursor};
use crate::ReaderCursor;

/// An iterator that is able to yield all the entries with
/// a key that starts with a given prefix.
@@ -18,7 +18,8 @@ impl<R: io::Read + io::Seek> PrefixIter<R> {
}

/// Returns the next entry that starts with the given prefix.
pub fn next(&mut self) -> Result<Option<(&[u8], &[u8])>, Error> {
#[allow(clippy::should_implement_trait)] // We return interior references
pub fn next(&mut self) -> crate::Result<Option<(&[u8], &[u8])>> {
let entry = if self.move_on_first_prefix {
self.move_on_first_prefix = false;
self.cursor.move_on_key_greater_than_or_equal_to(&self.prefix)?
@@ -49,7 +50,8 @@ impl<R: io::Read + io::Seek> RevPrefixIter<R> {
}

/// Returns the next entry that starts with the given prefix.
pub fn next(&mut self) -> Result<Option<(&[u8], &[u8])>, Error> {
#[allow(clippy::should_implement_trait)] // We return interior references
pub fn next(&mut self) -> crate::Result<Option<(&[u8], &[u8])>> {
let entry = if self.move_on_last_prefix {
self.move_on_last_prefix = false;
move_on_last_prefix(&mut self.cursor, self.prefix.clone())?
@@ -68,7 +70,7 @@ impl<R: io::Read + io::Seek> RevPrefixIter<R> {
fn move_on_last_prefix<R: io::Read + io::Seek>(
cursor: &mut ReaderCursor<R>,
prefix: Vec<u8>,
) -> Result<Option<(&[u8], &[u8])>, Error> {
) -> crate::Result<Option<(&[u8], &[u8])>> {
match advance_key(prefix) {
Some(next_prefix) => match cursor.move_on_key_lower_than_or_equal_to(&next_prefix)? {
Some((k, _)) if k == next_prefix => cursor.move_on_prev(),
@@ -108,7 +110,7 @@ mod tests {
let mut writer = Writer::memory();
for x in (10..24000u32).step_by(3) {
let x = x.to_be_bytes();
writer.insert(&x, &x).unwrap();
writer.insert(x, x).unwrap();
}

let bytes = writer.into_inner().unwrap();
20 changes: 11 additions & 9 deletions src/reader/range_iter.rs
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
use std::io;
use std::ops::{Bound, RangeBounds};

use crate::{Error, ReaderCursor};
use crate::ReaderCursor;

/// An iterator that is able to yield all the entries lying in a specified range.
#[derive(Clone)]
@@ -24,7 +24,8 @@ impl<R: io::Read + io::Seek> RangeIter<R> {
}

/// Returns the next entry that is inside of the given range.
pub fn next(&mut self) -> Result<Option<(&[u8], &[u8])>, Error> {
#[allow(clippy::should_implement_trait)] // We return interior references
pub fn next(&mut self) -> crate::Result<Option<(&[u8], &[u8])>> {
let entry = if self.move_on_start {
self.move_on_start = false;
match self.range.start_bound() {
@@ -75,7 +76,8 @@ impl<R: io::Read + io::Seek> RevRangeIter<R> {
}

/// Returns the next entry that is inside of the given range.
pub fn next(&mut self) -> Result<Option<(&[u8], &[u8])>, Error> {
#[allow(clippy::should_implement_trait)] // We return interior references
pub fn next(&mut self) -> crate::Result<Option<(&[u8], &[u8])>> {
let entry = if self.move_on_start {
self.move_on_start = false;
match self.range.end_bound() {
@@ -116,17 +118,17 @@ fn map_bound<T, U, F: FnOnce(T) -> U>(bound: Bound<T>, f: F) -> Bound<U> {
fn end_contains(end: Bound<&Vec<u8>>, key: &[u8]) -> bool {
match end {
Bound::Unbounded => true,
Bound::Included(end) => key <= end,
Bound::Excluded(end) => key < end,
Bound::Included(end) => key <= end.as_slice(),
Bound::Excluded(end) => key < end.as_slice(),
}
}

/// Returns whether the provided key does not exceed this start bound.
fn start_contains(end: Bound<&Vec<u8>>, key: &[u8]) -> bool {
match end {
Bound::Unbounded => true,
Bound::Included(end) => key >= end,
Bound::Excluded(end) => key > end,
Bound::Included(end) => key >= end.as_slice(),
Bound::Excluded(end) => key > end.as_slice(),
}
}

@@ -149,7 +151,7 @@ mod tests {
for x in (10..24000i32).step_by(3) {
nums.insert(x);
let x = x.to_be_bytes();
writer.insert(&x, &x).unwrap();
writer.insert(x, x).unwrap();
}

let bytes = writer.into_inner().unwrap();
@@ -186,7 +188,7 @@ mod tests {
for x in (10..24000i32).step_by(3) {
nums.insert(x);
let x = x.to_be_bytes();
writer.insert(&x, &x).unwrap();
writer.insert(x, x).unwrap();
}

let bytes = writer.into_inner().unwrap();
69 changes: 35 additions & 34 deletions src/reader/reader_cursor.rs
Original file line number Diff line number Diff line change
@@ -89,7 +89,7 @@ impl<R: io::Read + io::Seek> ReaderCursor<R> {
}

/// Moves the cursor on the first entry and returns it.
pub fn move_on_first(&mut self) -> Result<Option<(&[u8], &[u8])>, Error> {
pub fn move_on_first(&mut self) -> crate::Result<Option<(&[u8], &[u8])>> {
match self.index_block_cursor.move_on_first(&mut self.reader.reader)? {
Some((_, offset_bytes)) => {
let offset = offset_bytes.try_into().map(u64::from_be_bytes).unwrap();
@@ -109,7 +109,7 @@ impl<R: io::Read + io::Seek> ReaderCursor<R> {
}

/// Moves the cursor on the last entry and returns it.
pub fn move_on_last(&mut self) -> Result<Option<(&[u8], &[u8])>, Error> {
pub fn move_on_last(&mut self) -> crate::Result<Option<(&[u8], &[u8])>> {
match self.index_block_cursor.move_on_last(&mut self.reader.reader)? {
Some((_, offset_bytes)) => {
let offset = offset_bytes.try_into().map(u64::from_be_bytes).unwrap();
@@ -129,7 +129,7 @@ impl<R: io::Read + io::Seek> ReaderCursor<R> {
}

/// Moves the cursor on the entry following the current one and returns it.
pub fn move_on_next(&mut self) -> Result<Option<(&[u8], &[u8])>, Error> {
pub fn move_on_next(&mut self) -> crate::Result<Option<(&[u8], &[u8])>> {
match self.current_cursor.as_mut().map(BlockCursor::move_on_next) {
Some(Some((key, val))) => {
let (key, val) = unsafe { crate::transmute_entry_to_static(key, val) };
@@ -147,7 +147,7 @@ impl<R: io::Read + io::Seek> ReaderCursor<R> {
}

/// Moves the cursor on the entry preceding the current one and returns it.
pub fn move_on_prev(&mut self) -> Result<Option<(&[u8], &[u8])>, Error> {
pub fn move_on_prev(&mut self) -> crate::Result<Option<(&[u8], &[u8])>> {
match self.current_cursor.as_mut().map(BlockCursor::move_on_prev) {
Some(Some((key, val))) => {
let (key, val) = unsafe { crate::transmute_entry_to_static(key, val) };
@@ -169,7 +169,7 @@ impl<R: io::Read + io::Seek> ReaderCursor<R> {
pub fn move_on_key_lower_than_or_equal_to<A: AsRef<[u8]>>(
&mut self,
target_key: A,
) -> Result<Option<(&[u8], &[u8])>, Error> {
) -> crate::Result<Option<(&[u8], &[u8])>> {
let target_key = target_key.as_ref();
match self.move_on_key_greater_than_or_equal_to(target_key)? {
Some((key, val)) if key == target_key => {
@@ -186,7 +186,7 @@ impl<R: io::Read + io::Seek> ReaderCursor<R> {
pub fn move_on_key_greater_than_or_equal_to<A: AsRef<[u8]>>(
&mut self,
key: A,
) -> Result<Option<(&[u8], &[u8])>, Error> {
) -> crate::Result<Option<(&[u8], &[u8])>> {
// We move on the block which has a key greater than or equal to the key we are
// searching for as the key stored in the index block is the last key of the block.
let key = key.as_ref();
@@ -213,7 +213,7 @@ impl<R: io::Read + io::Seek> ReaderCursor<R> {
pub fn move_on_key_equal_to<A: AsRef<[u8]>>(
&mut self,
key: A,
) -> Result<Option<(&[u8], &[u8])>, Error> {
) -> crate::Result<Option<(&[u8], &[u8])>> {
let key = key.as_ref();
self.move_on_key_greater_than_or_equal_to(key).map(|opt| opt.filter(|(k, _)| *k == key))
}
@@ -255,44 +255,44 @@ impl IndexBlockCursor {
fn move_on_first<R: io::Read + io::Seek>(
&mut self,
reader: R,
) -> Result<Option<(&[u8], &[u8])>, Error> {
) -> crate::Result<Option<(&[u8], &[u8])>> {
self.iter_index_blocks(reader, |c| c.move_on_first())
}

fn move_on_last<R: io::Read + io::Seek>(
&mut self,
reader: R,
) -> Result<Option<(&[u8], &[u8])>, Error> {
) -> crate::Result<Option<(&[u8], &[u8])>> {
self.iter_index_blocks(reader, |c| c.move_on_last())
}

fn move_on_next<R: io::Read + io::Seek>(
&mut self,
reader: R,
) -> Result<Option<(&[u8], &[u8])>, Error> {
) -> crate::Result<Option<(&[u8], &[u8])>> {
self.recursive_index_block(reader, |c| c.move_on_next())
}

fn move_on_prev<R: io::Read + io::Seek>(
&mut self,
reader: R,
) -> Result<Option<(&[u8], &[u8])>, Error> {
) -> crate::Result<Option<(&[u8], &[u8])>> {
self.recursive_index_block(reader, |c| c.move_on_prev())
}

fn move_on_key_greater_than_or_equal_to<R: io::Read + io::Seek>(
&mut self,
key: &[u8],
reader: R,
) -> Result<Option<(&[u8], &[u8])>, Error> {
) -> crate::Result<Option<(&[u8], &[u8])>> {
self.iter_index_blocks(reader, |c| c.move_on_key_greater_than_or_equal_to(key))
}

fn iter_index_blocks<R, F>(
&mut self,
mut reader: R,
mut mov: F,
) -> Result<Option<(&[u8], &[u8])>, Error>
) -> crate::Result<Option<(&[u8], &[u8])>>
where
R: io::Read + io::Seek,
F: FnMut(&mut BlockCursor<Block>) -> Option<(&[u8], &[u8])>,
@@ -334,7 +334,7 @@ impl IndexBlockCursor {
&mut self,
mut reader: R,
mut mov: FM,
) -> Result<Option<(&[u8], &[u8])>, Error>
) -> crate::Result<Option<(&[u8], &[u8])>>
where
R: io::Read + io::Seek,
FM: FnMut(&mut BlockCursor<Block>) -> Option<(&[u8], &[u8])>,
@@ -344,7 +344,7 @@ impl IndexBlockCursor {
compression_type: CompressionType,
blocks: &'a mut [(u64, BlockCursor<Block>)],
mov: &mut FN,
) -> Result<Option<(&'a [u8], &'a [u8])>, Error>
) -> crate::Result<Option<(&'a [u8], &'a [u8])>>
where
S: io::Read + io::Seek,
FN: FnMut(&mut BlockCursor<Block>) -> Option<(&[u8], &[u8])>,
@@ -393,11 +393,12 @@ impl IndexBlockCursor {
}

/// Returns the index block cursors by calling the user function to load the blocks.
#[allow(clippy::type_complexity)] // Return type is not THAT complex
fn initial_index_blocks<R, FM>(
&mut self,
mut reader: R,
mut mov: FM,
) -> Result<Option<Vec<(u64, BlockCursor<Block>)>>, Error>
) -> crate::Result<Option<Vec<(u64, BlockCursor<Block>)>>>
where
R: io::Read + io::Seek,
FM: FnMut(&mut BlockCursor<Block>) -> Option<(&[u8], &[u8])>,
@@ -441,7 +442,7 @@ mod tests {
let reader = Reader::new(Cursor::new(bytes.as_slice())).unwrap();

let mut cursor = reader.into_cursor().unwrap();
let result = cursor.move_on_key_greater_than_or_equal_to(&[0, 0, 0, 0]).unwrap();
let result = cursor.move_on_key_greater_than_or_equal_to([0, 0, 0, 0]).unwrap();
assert_eq!(result, None);
}

@@ -453,7 +454,7 @@ mod tests {

for x in 0..2000u32 {
let x = x.to_be_bytes();
writer.insert(&x, &x).unwrap();
writer.insert(x, x).unwrap();
}

let bytes = writer.into_inner().unwrap();
@@ -490,7 +491,7 @@ mod tests {

for x in 0..2000u32 {
let x = x.to_be_bytes();
writer.insert(&x, &x).unwrap();
writer.insert(x, x).unwrap();
}

let bytes = writer.into_inner().unwrap();
@@ -517,7 +518,7 @@ mod tests {
for x in (10..24000i32).step_by(3) {
nums.push(x);
let x = x.to_be_bytes();
writer.insert(&x, &x).unwrap();
writer.insert(x, x).unwrap();
}

let bytes = writer.into_inner().unwrap();
@@ -531,15 +532,15 @@ mod tests {
Ok(i) => {
let n = nums[i];
let (k, _) = cursor
.move_on_key_greater_than_or_equal_to(&n.to_be_bytes())
.move_on_key_greater_than_or_equal_to(n.to_be_bytes())
.unwrap()
.unwrap();
let k = k.try_into().map(i32::from_be_bytes).unwrap();
assert_eq!(k, n);
}
Err(i) => {
let k = cursor
.move_on_key_greater_than_or_equal_to(&n.to_be_bytes())
.move_on_key_greater_than_or_equal_to(n.to_be_bytes())
.unwrap()
.map(|(k, _)| k.try_into().map(i32::from_be_bytes).unwrap());
assert_eq!(k, nums.get(i).copied());
@@ -556,7 +557,7 @@ mod tests {
for x in (10..24000i32).step_by(3) {
nums.push(x);
let x = x.to_be_bytes();
writer.insert(&x, &x).unwrap();
writer.insert(x, x).unwrap();
}

let bytes = writer.into_inner().unwrap();
@@ -569,15 +570,15 @@ mod tests {
Ok(i) => {
let n = nums[i];
let (k, _) = cursor
.move_on_key_lower_than_or_equal_to(&n.to_be_bytes())
.move_on_key_lower_than_or_equal_to(n.to_be_bytes())
.unwrap()
.unwrap();
let k = k.try_into().map(i32::from_be_bytes).unwrap();
assert_eq!(k, n);
}
Err(i) => {
let k = cursor
.move_on_key_lower_than_or_equal_to(&n.to_be_bytes())
.move_on_key_lower_than_or_equal_to(n.to_be_bytes())
.unwrap()
.map(|(k, _)| k.try_into().map(i32::from_be_bytes).unwrap());
let expected = i.checked_sub(1).and_then(|i| nums.get(i)).copied();
@@ -597,7 +598,7 @@ mod tests {
for x in (10..24000i32).step_by(3) {
nums.push(x);
let x = x.to_be_bytes();
writer.insert(&x, &x).unwrap();
writer.insert(x, x).unwrap();
}

let bytes = writer.into_inner().unwrap();
@@ -611,15 +612,15 @@ mod tests {
Ok(i) => {
let n = nums[i];
let (k, _) = cursor
.move_on_key_greater_than_or_equal_to(&n.to_be_bytes())
.move_on_key_greater_than_or_equal_to(n.to_be_bytes())
.unwrap()
.unwrap();
let k = k.try_into().map(i32::from_be_bytes).unwrap();
assert_eq!(k, n);
}
Err(i) => {
let k = cursor
.move_on_key_greater_than_or_equal_to(&n.to_be_bytes())
.move_on_key_greater_than_or_equal_to(n.to_be_bytes())
.unwrap()
.map(|(k, _)| k.try_into().map(i32::from_be_bytes).unwrap());
assert_eq!(k, nums.get(i).copied());
@@ -638,7 +639,7 @@ mod tests {
for x in (10..24000i32).step_by(3) {
nums.push(x);
let x = x.to_be_bytes();
writer.insert(&x, &x).unwrap();
writer.insert(x, x).unwrap();
}

let bytes = writer.into_inner().unwrap();
@@ -651,15 +652,15 @@ mod tests {
Ok(i) => {
let n = nums[i];
let (k, _) = cursor
.move_on_key_lower_than_or_equal_to(&n.to_be_bytes())
.move_on_key_lower_than_or_equal_to(n.to_be_bytes())
.unwrap()
.unwrap();
let k = k.try_into().map(i32::from_be_bytes).unwrap();
assert_eq!(k, n);
}
Err(i) => {
let k = cursor
.move_on_key_lower_than_or_equal_to(&n.to_be_bytes())
.move_on_key_lower_than_or_equal_to(n.to_be_bytes())
.unwrap()
.map(|(k, _)| k.try_into().map(i32::from_be_bytes).unwrap());
let expected = i.checked_sub(1).and_then(|i| nums.get(i)).copied();
@@ -679,7 +680,7 @@ mod tests {
let mut writer = Writer::builder().index_levels(2).memory();
for &x in &nums {
let x = x.to_be_bytes();
writer.insert(&x, &x).unwrap();
writer.insert(x, x).unwrap();
}

let bytes = writer.into_inner().unwrap();
@@ -691,7 +692,7 @@ mod tests {
Ok(i) => {
let q = nums[i];
let (k, _) = cursor
.move_on_key_lower_than_or_equal_to(&q.to_be_bytes())
.move_on_key_lower_than_or_equal_to(q.to_be_bytes())
.unwrap()
.unwrap();
let k = k.try_into().map(u32::from_be_bytes).unwrap();
@@ -701,7 +702,7 @@ mod tests {
}
Err(i) => {
let k = cursor
.move_on_key_lower_than_or_equal_to(&q.to_be_bytes())
.move_on_key_lower_than_or_equal_to(q.to_be_bytes())
.unwrap()
.map(|(k, _)| k.try_into().map(u32::from_be_bytes).unwrap());
let expected = i.checked_sub(1).and_then(|i| nums.get(i)).copied();
140 changes: 95 additions & 45 deletions src/sorter.rs
Original file line number Diff line number Diff line change
@@ -1,11 +1,13 @@
use std::alloc::{alloc, dealloc, Layout};
use std::borrow::Cow;
use std::convert::Infallible;
use std::fmt::Debug;
#[cfg(feature = "tempfile")]
use std::fs::File;
use std::io::{Cursor, Read, Seek, SeekFrom, Write};
use std::mem::{align_of, size_of};
use std::num::NonZeroUsize;
use std::ptr::NonNull;
use std::{cmp, io, ops, slice};

use bytemuck::{cast_slice, cast_slice_mut, Pod, Zeroable};
@@ -20,7 +22,8 @@ const DEFAULT_NB_CHUNKS: usize = 25;
const MIN_NB_CHUNKS: usize = 1;

use crate::{
CompressionType, Error, Merger, MergerIter, Reader, ReaderCursor, Writer, WriterBuilder,
CompressionType, Error, MergeFunction, Merger, MergerIter, Reader, ReaderCursor, Writer,
WriterBuilder,
};

/// The kind of sort algorithm used by the sorter to sort its internal vector.
@@ -194,7 +197,7 @@ impl<MF, CC: ChunkCreator> SorterBuilder<MF, CC> {
chunk_creator: self.chunk_creator,
sort_algorithm: self.sort_algorithm,
sort_in_parallel: self.sort_in_parallel,
merge: self.merge,
merge_function: self.merge,
}
}
}
@@ -238,8 +241,8 @@ impl Entries {
/// Inserts a new entry into the buffer, if there is not
/// enough space for it to be stored, we double the buffer size.
pub fn insert(&mut self, key: &[u8], data: &[u8]) {
assert!(key.len() <= u32::max_value() as usize);
assert!(data.len() <= u32::max_value() as usize);
assert!(key.len() <= u32::MAX as usize);
assert!(data.len() <= u32::MAX as usize);

if self.fits(key, data) {
// We store the key and data bytes one after the other at the back of the buffer.
@@ -374,7 +377,10 @@ struct EntryBound {
}

/// Represents an `EntryBound` aligned buffer.
struct EntryBoundAlignedBuffer(&'static mut [u8]);
struct EntryBoundAlignedBuffer {
data: NonNull<u8>,
len: usize,
}

impl EntryBoundAlignedBuffer {
/// Allocates a new buffer of the given size, it is correctly aligned to store `EntryBound`s.
@@ -383,34 +389,36 @@ impl EntryBoundAlignedBuffer {
let size = (size + entry_bound_size - 1) / entry_bound_size * entry_bound_size;
let layout = Layout::from_size_align(size, align_of::<EntryBound>()).unwrap();
let ptr = unsafe { alloc(layout) };
assert!(
!ptr.is_null(),
"the allocator is unable to allocate that much memory ({} bytes requested)",
size
);
let slice = unsafe { slice::from_raw_parts_mut(ptr, size) };
EntryBoundAlignedBuffer(slice)
let Some(ptr) = NonNull::new(ptr) else {
panic!(
"the allocator is unable to allocate that much memory ({} bytes requested)",
size
);
};

EntryBoundAlignedBuffer { data: ptr, len: size }
}
}

impl ops::Deref for EntryBoundAlignedBuffer {
type Target = [u8];

fn deref(&self) -> &Self::Target {
self.0
unsafe { slice::from_raw_parts(self.data.as_ptr(), self.len) }
}
}

impl ops::DerefMut for EntryBoundAlignedBuffer {
fn deref_mut(&mut self) -> &mut Self::Target {
self.0
unsafe { slice::from_raw_parts_mut(self.data.as_ptr(), self.len) }
}
}

impl Drop for EntryBoundAlignedBuffer {
fn drop(&mut self) {
let layout = Layout::from_size_align(self.0.len(), align_of::<EntryBound>()).unwrap();
unsafe { dealloc(self.0.as_mut_ptr(), layout) }
let layout = Layout::from_size_align(self.len, align_of::<EntryBound>()).unwrap();

unsafe { dealloc(self.data.as_ptr(), layout) }
}
}

@@ -434,7 +442,7 @@ pub struct Sorter<MF, CC: ChunkCreator = DefaultChunkCreator> {
chunk_creator: CC,
sort_algorithm: SortAlgorithm,
sort_in_parallel: bool,
merge: MF,
merge_function: MF,
}

impl<MF> Sorter<MF, DefaultChunkCreator> {
@@ -460,14 +468,14 @@ impl<MF> Sorter<MF, DefaultChunkCreator> {
}
}

impl<MF, CC, U> Sorter<MF, CC>
impl<MF, CC> Sorter<MF, CC>
where
MF: for<'a> Fn(&[u8], &[Cow<'a, [u8]>]) -> Result<Cow<'a, [u8]>, U>,
MF: MergeFunction,
CC: ChunkCreator,
{
/// Insert an entry into the [`Sorter`] making sure that conflicts
/// are resolved by the provided merge function.
pub fn insert<K, V>(&mut self, key: K, val: V) -> Result<(), Error<U>>
pub fn insert<K, V>(&mut self, key: K, val: V) -> crate::Result<(), MF::Error>
where
K: AsRef<[u8]>,
V: AsRef<[u8]>,
@@ -498,7 +506,7 @@ where
///
/// Writes the in-memory entries to disk, using the specify settings
/// to compress the block and entries. It clears the in-memory entries.
fn write_chunk(&mut self) -> Result<u64, Error<U>> {
fn write_chunk(&mut self) -> crate::Result<u64, MF::Error> {
let count_write_chunk = self
.chunk_creator
.create()
@@ -536,7 +544,8 @@ where
None => current = Some((key, vec![Cow::Borrowed(value)])),
Some((current_key, vals)) => {
if current_key != &key {
let merged_val = (self.merge)(current_key, vals).map_err(Error::Merge)?;
let merged_val =
self.merge_function.merge(current_key, vals).map_err(Error::Merge)?;
writer.insert(&current_key, &merged_val)?;
vals.clear();
*current_key = key;
@@ -547,7 +556,7 @@ where
}

if let Some((key, vals)) = current.take() {
let merged_val = (self.merge)(key, &vals).map_err(Error::Merge)?;
let merged_val = self.merge_function.merge(key, &vals).map_err(Error::Merge)?;
writer.insert(key, &merged_val)?;
}

@@ -569,7 +578,7 @@ where
///
/// Merges all of the chunks into a final chunk that replaces them.
/// It uses the user provided merge function to resolve merge conflicts.
fn merge_chunks(&mut self) -> Result<u64, Error<U>> {
fn merge_chunks(&mut self) -> crate::Result<u64, MF::Error> {
let count_write_chunk = self
.chunk_creator
.create()
@@ -595,7 +604,7 @@ where
}
let mut writer = writer_builder.build(count_write_chunk);

let sources: Result<Vec<_>, Error<U>> = self
let sources: crate::Result<Vec<_>, MF::Error> = self
.chunks
.drain(..)
.map(|mut chunk| {
@@ -605,7 +614,7 @@ where
.collect();

// Create a merger to merge all those chunks.
let mut builder = Merger::builder(&self.merge);
let mut builder = Merger::builder(&self.merge_function);
builder.extend(sources?);
let merger = builder.build();

@@ -628,7 +637,7 @@ where
pub fn write_into_stream_writer<W: io::Write>(
self,
writer: &mut Writer<W>,
) -> Result<(), Error<U>> {
) -> crate::Result<(), MF::Error> {
let mut iter = self.into_stream_merger_iter()?;
while let Some((key, val)) = iter.next()? {
writer.insert(key, val)?;
@@ -637,26 +646,27 @@ where
}

/// Consumes this [`Sorter`] and outputs a stream of the merged entries in key-order.
pub fn into_stream_merger_iter(self) -> Result<MergerIter<CC::Chunk, MF>, Error<U>> {
pub fn into_stream_merger_iter(self) -> crate::Result<MergerIter<CC::Chunk, MF>, MF::Error> {
let (sources, merge) = self.extract_reader_cursors_and_merger()?;
let mut builder = Merger::builder(merge);
builder.extend(sources);
builder.build().into_stream_merger_iter().map_err(Error::convert_merge_error)
}

/// Consumes this [`Sorter`] and outputs the list of reader cursors.
pub fn into_reader_cursors(self) -> Result<Vec<ReaderCursor<CC::Chunk>>, Error<U>> {
pub fn into_reader_cursors(self) -> crate::Result<Vec<ReaderCursor<CC::Chunk>>, MF::Error> {
self.extract_reader_cursors_and_merger().map(|(readers, _)| readers)
}

/// A helper function to extract the readers and the merge function.
#[allow(clippy::type_complexity)] // Return type is not THAT complex
fn extract_reader_cursors_and_merger(
mut self,
) -> Result<(Vec<ReaderCursor<CC::Chunk>>, MF), Error<U>> {
) -> crate::Result<(Vec<ReaderCursor<CC::Chunk>>, MF), MF::Error> {
// Flush the pending unordered entries.
self.chunks_total_size = self.write_chunk()?;

let Sorter { chunks, merge, .. } = self;
let Sorter { chunks, merge_function: merge, .. } = self;
let result: Result<Vec<_>, _> = chunks
.into_iter()
.map(|mut chunk| {
@@ -669,6 +679,28 @@ where
}
}

impl<MF, CC: ChunkCreator> Debug for Sorter<MF, CC> {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
f.debug_struct("Sorter")
.field("chunks_count", &self.chunks.len())
.field("remaining_entries", &self.entries.remaining())
.field("chunks_total_size", &self.chunks_total_size)
.field("allow_realloc", &self.allow_realloc)
.field("dump_threshold", &self.dump_threshold)
.field("max_nb_chunks", &self.max_nb_chunks)
.field("chunk_compression_type", &self.chunk_compression_type)
.field("chunk_compression_level", &self.chunk_compression_level)
.field("index_key_interval", &self.index_key_interval)
.field("block_size", &self.block_size)
.field("index_levels", &self.index_levels)
.field("chunk_creator", &"[chunck creator]")
.field("sort_algorithm", &self.sort_algorithm)
.field("sort_in_parallel", &self.sort_in_parallel)
.field("merge", &"[merge function]")
.finish()
}
}

/// A trait that represent a `ChunkCreator`.
pub trait ChunkCreator {
/// The generated chunk by this `ChunkCreator`.
@@ -733,14 +765,25 @@ mod tests {

use super::*;

fn merge<'a>(_key: &[u8], vals: &[Cow<'a, [u8]>]) -> Result<Cow<'a, [u8]>, Infallible> {
Ok(vals.iter().map(AsRef::as_ref).flatten().cloned().collect())
#[derive(Copy, Clone)]
struct ConcatMerger;

impl MergeFunction for ConcatMerger {
type Error = Infallible;

fn merge<'a>(
&self,
_key: &[u8],
values: &[Cow<'a, [u8]>],
) -> std::result::Result<Cow<'a, [u8]>, Self::Error> {
Ok(values.iter().flat_map(AsRef::as_ref).cloned().collect())
}
}

#[test]
#[cfg_attr(miri, ignore)]
fn simple_cursorvec() {
let mut sorter = SorterBuilder::new(merge)
let mut sorter = SorterBuilder::new(ConcatMerger)
.chunk_compression_type(CompressionType::Snappy)
.chunk_creator(CursorVec)
.build();
@@ -769,7 +812,7 @@ mod tests {
#[test]
#[cfg_attr(miri, ignore)]
fn hard_cursorvec() {
let mut sorter = SorterBuilder::new(merge)
let mut sorter = SorterBuilder::new(ConcatMerger)
.dump_threshold(1024) // 1KiB
.allow_realloc(false)
.chunk_compression_type(CompressionType::Snappy)
@@ -803,20 +846,27 @@ mod tests {
use rand::prelude::{SeedableRng, SliceRandom};
use rand::rngs::StdRng;

// This merge function concat bytes in the order they are received.
fn concat_bytes<'a>(
_key: &[u8],
values: &[Cow<'a, [u8]>],
) -> Result<Cow<'a, [u8]>, Infallible> {
let mut output = Vec::new();
for value in values {
output.extend_from_slice(&value);
/// This merge function concat bytes in the order they are received.
struct ConcatBytesMerger;

impl MergeFunction for ConcatBytesMerger {
type Error = Infallible;

fn merge<'a>(
&self,
_key: &[u8],
values: &[Cow<'a, [u8]>],
) -> std::result::Result<Cow<'a, [u8]>, Self::Error> {
let mut output = Vec::new();
for value in values {
output.extend_from_slice(value);
}
Ok(Cow::from(output))
}
Ok(Cow::from(output))
}

// We create a sorter that will sum our u32s when necessary.
let mut sorter = SorterBuilder::new(concat_bytes).chunk_creator(CursorVec).build();
let mut sorter = SorterBuilder::new(ConcatBytesMerger).chunk_creator(CursorVec).build();

// We insert all the possible values of an u8 in ascending order
// but we split them along different keys.
14 changes: 7 additions & 7 deletions src/writer.rs
Original file line number Diff line number Diff line change
@@ -146,9 +146,9 @@ impl Writer<()> {
}
}

impl<W: io::Write> Writer<W> {
impl<W: io::Write> AsRef<W> for Writer<W> {
/// Gets a reference to the underlying writer.
pub fn as_ref(&self) -> &W {
fn as_ref(&self) -> &W {
self.writer.as_ref()
}
}
@@ -330,7 +330,7 @@ mod tests {

for x in 0..2000u32 {
let x = x.to_be_bytes();
writer.insert(&x, &x).unwrap();
writer.insert(x, x).unwrap();
}

let bytes = writer.into_inner().unwrap();
@@ -346,7 +346,7 @@ mod tests {

for x in 0..2000u32 {
let x = x.to_be_bytes();
writer.insert(&x, &x).unwrap();
writer.insert(x, x).unwrap();
}

let bytes = writer.into_inner().unwrap();
@@ -363,7 +363,7 @@ mod tests {

for x in 0..2000u32 {
let x = x.to_be_bytes();
writer.insert(&x, &x).unwrap();
writer.insert(x, x).unwrap();
}

let bytes = writer.into_inner().unwrap();
@@ -378,11 +378,11 @@ mod tests {
.compression_type(grenad_0_4::CompressionType::Snappy)
.memory();

let total: u32 = 156_000;
let total: u32 = 1_500;

for x in 0..total {
let x = x.to_be_bytes();
writer.insert(&x, &x).unwrap();
writer.insert(x, x).unwrap();
}

let bytes = writer.into_inner().unwrap();