Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
14 changes: 6 additions & 8 deletions crates/walrus-core/benches/encoding_phases.rs
Original file line number Diff line number Diff line change
Expand Up @@ -151,7 +151,7 @@ fn primary_encoding_with_hashing(c: &mut Criterion) {
for (col_index, col) in columns.iter().enumerate() {
let symbols = enc.encode_all_ref(col).unwrap();
for (row_index, symbol) in symbols.to_symbols().enumerate() {
hashes[n_shards * row_index + col_index] =
hashes[col_index * n_shards + row_index] =
leaf_hash::<Blake2b256>(symbol);
}
}
Expand Down Expand Up @@ -188,15 +188,13 @@ fn metadata_from_hashes(c: &mut Criterion) {
// Build 2 * n_shards Merkle trees (primary + secondary per sliver pair).
for sliver_index in 0..n_shards {
let _primary = MerkleTree::<Blake2b256>::build_from_leaf_hashes(
hashes[n_shards * sliver_index..n_shards * (sliver_index + 1)]
.iter()
.cloned(),
(0..n_shards).map(|col| hashes[col * n_shards + sliver_index].clone()),
);
let sec_col = n_shards - 1 - sliver_index;
let _secondary = MerkleTree::<Blake2b256>::build_from_leaf_hashes(
(0..n_shards).map(|symbol_index| {
hashes[n_shards * symbol_index + n_shards - 1 - sliver_index]
.clone()
}),
hashes[sec_col * n_shards..(sec_col + 1) * n_shards]
.iter()
.cloned(),
);
}
});
Expand Down
153 changes: 91 additions & 62 deletions crates/walrus-core/src/encoding/blob_encoding.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
// SPDX-License-Identifier: Apache-2.0

use alloc::{collections::BTreeSet, vec, vec::Vec};
use core::{cmp, marker::PhantomData, num::NonZeroU16, ops::Range, slice::Chunks};
use core::{cell::RefCell, cmp, marker::PhantomData, num::NonZeroU16, ops::Range, slice::Chunks};

use fastcrypto::hash::Blake2b256;
use rayon::prelude::*;
Expand Down Expand Up @@ -153,7 +153,8 @@ impl BlobEncoderData {
/// Computes the blob metadata from the provided leaf hashes of all symbols.
///
/// The provided slice *must* be of length `n_shards * n_shards`, where `n_shards` is the number
/// of shards. The slice is interpreted as a matrix in row-major order.
/// of shards. The slice is interpreted as a matrix in column-major order:
/// `symbol_hashes[col * n_shards + row]`.
///
/// # Panics
///
Expand All @@ -172,16 +173,19 @@ impl BlobEncoderData {
let metadata: Vec<SliverPairMetadata> = (0..n_shards)
.into_par_iter()
.map(|sliver_index| {
// Column-major: primary tree gathers row `sliver_index` across all columns
// (strided access).
let primary_hash = MerkleTree::<Blake2b256>::build_from_leaf_hashes(
symbol_hashes[n_shards * sliver_index..n_shards * (sliver_index + 1)]
.iter()
.cloned(),
(0..n_shards).map(|col| symbol_hashes[col * n_shards + sliver_index].clone()),
)
.root();
// Column-major: secondary tree reads column `n_shards - 1 - sliver_index`
// as a contiguous slice.
let sec_col = n_shards - 1 - sliver_index;
let secondary_hash = MerkleTree::<Blake2b256>::build_from_leaf_hashes(
(0..n_shards).map(|symbol_index| {
symbol_hashes[n_shards * symbol_index + n_shards - 1 - sliver_index].clone()
}),
symbol_hashes[sec_col * n_shards..(sec_col + 1) * n_shards]
.iter()
.cloned(),
)
.root();
SliverPairMetadata {
Expand Down Expand Up @@ -341,47 +345,61 @@ impl<'a> BlobEncoder<'a> {
let n_rows = self.inner.n_rows_usize();
let symbol_usize = self.inner.symbol_usize();

// Parallel phase: each rayon task gets its own encoder to encode a column and hash
// all symbols. Repair symbols for non-systematic primary slivers are collected for
// the sequential scatter phase below.
let column_results: Vec<(usize, Vec<Node>, Option<Vec<u8>>)> = secondary_slivers
.par_iter()
// Thread-local encoder pool: reuse ReedSolomonEncoder across rayon tasks within this
// call, avoiding ~1000 FFT table constructions (reduced to ~8-16, one per thread).
std::thread_local! {
static PRIMARY_ENCODER: RefCell<Option<ReedSolomonEncoder>> =
const { RefCell::new(None) };
}

// Parallel phase: par_chunks_mut on column-major hashes gives each task an exclusive
// &mut [Node] slice, enabling direct writes with no intermediate allocation.
// Only repair symbol data is collected for the sequential scatter below.
let repair_results: Vec<(usize, Vec<u8>)> = symbol_hashes
.par_chunks_mut(n_shards)
.zip(secondary_slivers.par_iter())
.enumerate()
.map(|(col_index, column)| {
let mut encoder = self.inner.get_encoder::<Primary>();
let symbols = encoder
.encode_all(column.symbols.data())
.expect("size has already been checked");

let hashes: Vec<Node> = symbols.to_symbols().map(leaf_hash::<Blake2b256>).collect();

// Collect repair symbols for non-systematic primary slivers.
let repair_data = if col_index < n_columns {
let n_repair = n_shards - n_rows;
let mut data = vec![0u8; n_repair * symbol_usize];
for (i, symbol) in symbols.to_symbols().skip(n_rows).enumerate() {
data[i * symbol_usize..i * symbol_usize + symbol.len()]
.copy_from_slice(symbol);
.filter_map(|(col_index, (hash_col, column))| {
PRIMARY_ENCODER.with(|cell| {
let mut opt = cell.borrow_mut();
let encoder = if opt
.as_ref()
.is_some_and(|e| usize::from(e.symbol_size().get()) == symbol_usize)
{
opt.as_mut().expect("checked above")
} else {
opt.insert(self.inner.get_encoder::<Primary>())
};
let symbols = encoder
.encode_all(column.symbols.data())
.expect("size has already been checked");

// Write hashes directly into the column-major slice.
for (row_index, symbol) in symbols.to_symbols().enumerate() {
hash_col[row_index] = leaf_hash::<Blake2b256>(symbol);
}
Some(data)
} else {
None
};

(col_index, hashes, repair_data)
// Collect repair data only for systematic columns.
if col_index < n_columns {
let n_repair = n_shards - n_rows;
let mut data = vec![0u8; n_repair * symbol_usize];
for (i, symbol) in symbols.to_symbols().skip(n_rows).enumerate() {
data[i * symbol_usize..i * symbol_usize + symbol.len()]
.copy_from_slice(symbol);
}
Some((col_index, data))
} else {
None
}
})
})
.collect();

// Sequential scatter: write hashes and repair symbols to their destinations.
for (col_index, hashes, repair_data) in column_results {
for (row_index, hash) in hashes.into_iter().enumerate() {
symbol_hashes[n_shards * row_index + col_index] = hash;
}
if let Some(data) = repair_data {
for (i, sliver) in primary_slivers.iter_mut().skip(n_rows).enumerate() {
let start = i * symbol_usize;
sliver.copy_symbol_to(col_index, &data[start..start + symbol_usize]);
}
// Sequential scatter for repair symbols only.
for (col_index, data) in repair_results {
for (i, sliver) in primary_slivers.iter_mut().skip(n_rows).enumerate() {
let start = i * symbol_usize;
sliver.copy_symbol_to(col_index, &data[start..start + symbol_usize]);
}
}

Expand Down Expand Up @@ -483,26 +501,36 @@ impl<'a> BlobEncoder<'a> {
// Parallel primary encoding + hashing over all secondary slivers.
let mut symbol_hashes = vec![Node::Empty; n_shards * n_shards];

let column_results: Vec<(usize, Vec<Node>)> = secondary_slivers
.par_iter()
.enumerate()
.map(|(col_index, column)| {
let mut encoder = self.inner.get_encoder::<Primary>();
let symbols = encoder
.encode_all(column.symbols.data())
.expect("size has already been checked");
let hashes: Vec<Node> = symbols.to_symbols().map(leaf_hash::<Blake2b256>).collect();
(col_index, hashes)
})
.collect();
let symbol_usize = self.inner.symbol_usize();

// Sequential scatter.
for (col_index, hashes) in column_results {
for (row_index, hash) in hashes.into_iter().enumerate() {
symbol_hashes[n_shards * row_index + col_index] = hash;
}
std::thread_local! {
static PRIMARY_ENCODER: RefCell<Option<ReedSolomonEncoder>> =
const { RefCell::new(None) };
}

symbol_hashes
.par_chunks_mut(n_shards)
.zip(secondary_slivers.par_iter())
.for_each(|(hash_col, column)| {
PRIMARY_ENCODER.with(|cell| {
let mut opt = cell.borrow_mut();
let encoder = if opt
.as_ref()
.is_some_and(|e| usize::from(e.symbol_size().get()) == symbol_usize)
{
opt.as_mut().expect("checked above")
} else {
opt.insert(self.inner.get_encoder::<Primary>())
};
let symbols = encoder
.encode_all(column.symbols.data())
.expect("size has already been checked");
for (row_index, symbol) in symbols.to_symbols().enumerate() {
hash_col[row_index] = leaf_hash::<Blake2b256>(symbol);
}
});
});

BlobEncoderData::compute_metadata_from_symbol_hashes(
self.inner.config,
&symbol_hashes,
Expand Down Expand Up @@ -745,8 +773,9 @@ impl<'a> ExpandedMessageMatrix<'a> {

let n_shards = self.config.n_shards_as_usize();
let mut symbol_hashes = Vec::with_capacity(n_shards * n_shards);
for row in 0..n_shards {
for col in 0..n_shards {
// Column-major layout: symbol_hashes[col * n_shards + row].
for col in 0..n_shards {
for row in 0..n_shards {
symbol_hashes.push(leaf_hash::<Blake2b256>(&self.matrix[row][col]));
}
}
Expand Down
Loading