Skip to content

Commit b7f6cc2

Browse files
committed
perf: column-major hash layout with direct writes and thread-local encoder pooling
1 parent 7ade96f commit b7f6cc2

2 files changed

Lines changed: 97 additions & 70 deletions

File tree

crates/walrus-core/benches/encoding_phases.rs

Lines changed: 6 additions & 8 deletions
Original file line number | Diff line number | Diff line change
@@ -151,7 +151,7 @@ fn primary_encoding_with_hashing(c: &mut Criterion) {
151151
for (col_index, col) in columns.iter().enumerate() {
152152
let symbols = enc.encode_all_ref(col).unwrap();
153153
for (row_index, symbol) in symbols.to_symbols().enumerate() {
154-
hashes[n_shards * row_index + col_index] =
154+
hashes[col_index * n_shards + row_index] =
155155
leaf_hash::<Blake2b256>(symbol);
156156
}
157157
}
@@ -188,15 +188,13 @@ fn metadata_from_hashes(c: &mut Criterion) {
188188
// Build 2 * n_shards Merkle trees (primary + secondary per sliver pair).
189189
for sliver_index in 0..n_shards {
190190
let _primary = MerkleTree::<Blake2b256>::build_from_leaf_hashes(
191-
hashes[n_shards * sliver_index..n_shards * (sliver_index + 1)]
192-
.iter()
193-
.cloned(),
191+
(0..n_shards).map(|col| hashes[col * n_shards + sliver_index].clone()),
194192
);
193+
let sec_col = n_shards - 1 - sliver_index;
195194
let _secondary = MerkleTree::<Blake2b256>::build_from_leaf_hashes(
196-
(0..n_shards).map(|symbol_index| {
197-
hashes[n_shards * symbol_index + n_shards - 1 - sliver_index]
198-
.clone()
199-
}),
195+
hashes[sec_col * n_shards..(sec_col + 1) * n_shards]
196+
.iter()
197+
.cloned(),
200198
);
201199
}
202200
});

crates/walrus-core/src/encoding/blob_encoding.rs

Lines changed: 91 additions & 62 deletions
Original file line number | Diff line number | Diff line change
@@ -2,7 +2,7 @@
22
// SPDX-License-Identifier: Apache-2.0
33

44
use alloc::{collections::BTreeSet, vec, vec::Vec};
5-
use core::{cmp, marker::PhantomData, num::NonZeroU16, ops::Range, slice::Chunks};
5+
use core::{cell::RefCell, cmp, marker::PhantomData, num::NonZeroU16, ops::Range, slice::Chunks};
66

77
use fastcrypto::hash::Blake2b256;
88
use rayon::prelude::*;
@@ -153,7 +153,8 @@ impl BlobEncoderData {
153153
/// Computes the blob metadata from the provided leaf hashes of all symbols.
154154
///
155155
/// The provided slice *must* be of length `n_shards * n_shards`, where `n_shards` is the number
156-
/// of shards. The slice is interpreted as a matrix in row-major order.
156+
/// of shards. The slice is interpreted as a matrix in column-major order:
157+
/// `symbol_hashes[col * n_shards + row]`.
157158
///
158159
/// # Panics
159160
///
@@ -172,16 +173,19 @@ impl BlobEncoderData {
172173
let metadata: Vec<SliverPairMetadata> = (0..n_shards)
173174
.into_par_iter()
174175
.map(|sliver_index| {
176+
// Column-major: primary tree gathers row `sliver_index` across all columns
177+
// (strided access).
175178
let primary_hash = MerkleTree::<Blake2b256>::build_from_leaf_hashes(
176-
symbol_hashes[n_shards * sliver_index..n_shards * (sliver_index + 1)]
177-
.iter()
178-
.cloned(),
179+
(0..n_shards).map(|col| symbol_hashes[col * n_shards + sliver_index].clone()),
179180
)
180181
.root();
182+
// Column-major: secondary tree reads column `n_shards - 1 - sliver_index`
183+
// as a contiguous slice.
184+
let sec_col = n_shards - 1 - sliver_index;
181185
let secondary_hash = MerkleTree::<Blake2b256>::build_from_leaf_hashes(
182-
(0..n_shards).map(|symbol_index| {
183-
symbol_hashes[n_shards * symbol_index + n_shards - 1 - sliver_index].clone()
184-
}),
186+
symbol_hashes[sec_col * n_shards..(sec_col + 1) * n_shards]
187+
.iter()
188+
.cloned(),
185189
)
186190
.root();
187191
SliverPairMetadata {
@@ -341,47 +345,61 @@ impl<'a> BlobEncoder<'a> {
341345
let n_rows = self.inner.n_rows_usize();
342346
let symbol_usize = self.inner.symbol_usize();
343347

344-
// Parallel phase: each rayon task gets its own encoder to encode a column and hash
345-
// all symbols. Repair symbols for non-systematic primary slivers are collected for
346-
// the sequential scatter phase below.
347-
let column_results: Vec<(usize, Vec<Node>, Option<Vec<u8>>)> = secondary_slivers
348-
.par_iter()
348+
// Thread-local encoder pool: reuse ReedSolomonEncoder across rayon tasks within this
349+
// call, avoiding ~1000 FFT table constructions (reduced to ~8-16, one per thread).
350+
std::thread_local! {
351+
static PRIMARY_ENCODER: RefCell<Option<ReedSolomonEncoder>> =
352+
const { RefCell::new(None) };
353+
}
354+
355+
// Parallel phase: par_chunks_mut on column-major hashes gives each task an exclusive
356+
// &mut [Node] slice, enabling direct writes with no intermediate allocation.
357+
// Only repair symbol data is collected for the sequential scatter below.
358+
let repair_results: Vec<(usize, Vec<u8>)> = symbol_hashes
359+
.par_chunks_mut(n_shards)
360+
.zip(secondary_slivers.par_iter())
349361
.enumerate()
350-
.map(|(col_index, column)| {
351-
let mut encoder = self.inner.get_encoder::<Primary>();
352-
let symbols = encoder
353-
.encode_all(column.symbols.data())
354-
.expect("size has already been checked");
355-
356-
let hashes: Vec<Node> = symbols.to_symbols().map(leaf_hash::<Blake2b256>).collect();
357-
358-
// Collect repair symbols for non-systematic primary slivers.
359-
let repair_data = if col_index < n_columns {
360-
let n_repair = n_shards - n_rows;
361-
let mut data = vec![0u8; n_repair * symbol_usize];
362-
for (i, symbol) in symbols.to_symbols().skip(n_rows).enumerate() {
363-
data[i * symbol_usize..i * symbol_usize + symbol.len()]
364-
.copy_from_slice(symbol);
362+
.filter_map(|(col_index, (hash_col, column))| {
363+
PRIMARY_ENCODER.with(|cell| {
364+
let mut opt = cell.borrow_mut();
365+
let encoder = if opt
366+
.as_ref()
367+
.is_some_and(|e| usize::from(e.symbol_size().get()) == symbol_usize)
368+
{
369+
opt.as_mut().expect("checked above")
370+
} else {
371+
opt.insert(self.inner.get_encoder::<Primary>())
372+
};
373+
let symbols = encoder
374+
.encode_all(column.symbols.data())
375+
.expect("size has already been checked");
376+
377+
// Write hashes directly into the column-major slice.
378+
for (row_index, symbol) in symbols.to_symbols().enumerate() {
379+
hash_col[row_index] = leaf_hash::<Blake2b256>(symbol);
365380
}
366-
Some(data)
367-
} else {
368-
None
369-
};
370381

371-
(col_index, hashes, repair_data)
382+
// Collect repair data only for systematic columns.
383+
if col_index < n_columns {
384+
let n_repair = n_shards - n_rows;
385+
let mut data = vec![0u8; n_repair * symbol_usize];
386+
for (i, symbol) in symbols.to_symbols().skip(n_rows).enumerate() {
387+
data[i * symbol_usize..i * symbol_usize + symbol.len()]
388+
.copy_from_slice(symbol);
389+
}
390+
Some((col_index, data))
391+
} else {
392+
None
393+
}
394+
})
372395
})
373396
.collect();
374397

375-
// Sequential scatter: write hashes and repair symbols to their destinations.
376-
for (col_index, hashes, repair_data) in column_results {
377-
for (row_index, hash) in hashes.into_iter().enumerate() {
378-
symbol_hashes[n_shards * row_index + col_index] = hash;
379-
}
380-
if let Some(data) = repair_data {
381-
for (i, sliver) in primary_slivers.iter_mut().skip(n_rows).enumerate() {
382-
let start = i * symbol_usize;
383-
sliver.copy_symbol_to(col_index, &data[start..start + symbol_usize]);
384-
}
398+
// Sequential scatter for repair symbols only.
399+
for (col_index, data) in repair_results {
400+
for (i, sliver) in primary_slivers.iter_mut().skip(n_rows).enumerate() {
401+
let start = i * symbol_usize;
402+
sliver.copy_symbol_to(col_index, &data[start..start + symbol_usize]);
385403
}
386404
}
387405

@@ -483,26 +501,36 @@ impl<'a> BlobEncoder<'a> {
483501
// Parallel primary encoding + hashing over all secondary slivers.
484502
let mut symbol_hashes = vec![Node::Empty; n_shards * n_shards];
485503

486-
let column_results: Vec<(usize, Vec<Node>)> = secondary_slivers
487-
.par_iter()
488-
.enumerate()
489-
.map(|(col_index, column)| {
490-
let mut encoder = self.inner.get_encoder::<Primary>();
491-
let symbols = encoder
492-
.encode_all(column.symbols.data())
493-
.expect("size has already been checked");
494-
let hashes: Vec<Node> = symbols.to_symbols().map(leaf_hash::<Blake2b256>).collect();
495-
(col_index, hashes)
496-
})
497-
.collect();
504+
let symbol_usize = self.inner.symbol_usize();
498505

499-
// Sequential scatter.
500-
for (col_index, hashes) in column_results {
501-
for (row_index, hash) in hashes.into_iter().enumerate() {
502-
symbol_hashes[n_shards * row_index + col_index] = hash;
503-
}
506+
std::thread_local! {
507+
static PRIMARY_ENCODER: RefCell<Option<ReedSolomonEncoder>> =
508+
const { RefCell::new(None) };
504509
}
505510

511+
symbol_hashes
512+
.par_chunks_mut(n_shards)
513+
.zip(secondary_slivers.par_iter())
514+
.for_each(|(hash_col, column)| {
515+
PRIMARY_ENCODER.with(|cell| {
516+
let mut opt = cell.borrow_mut();
517+
let encoder = if opt
518+
.as_ref()
519+
.is_some_and(|e| usize::from(e.symbol_size().get()) == symbol_usize)
520+
{
521+
opt.as_mut().expect("checked above")
522+
} else {
523+
opt.insert(self.inner.get_encoder::<Primary>())
524+
};
525+
let symbols = encoder
526+
.encode_all(column.symbols.data())
527+
.expect("size has already been checked");
528+
for (row_index, symbol) in symbols.to_symbols().enumerate() {
529+
hash_col[row_index] = leaf_hash::<Blake2b256>(symbol);
530+
}
531+
});
532+
});
533+
506534
BlobEncoderData::compute_metadata_from_symbol_hashes(
507535
self.inner.config,
508536
&symbol_hashes,
@@ -745,8 +773,9 @@ impl<'a> ExpandedMessageMatrix<'a> {
745773

746774
let n_shards = self.config.n_shards_as_usize();
747775
let mut symbol_hashes = Vec::with_capacity(n_shards * n_shards);
748-
for row in 0..n_shards {
749-
for col in 0..n_shards {
776+
// Column-major layout: symbol_hashes[col * n_shards + row].
777+
for col in 0..n_shards {
778+
for row in 0..n_shards {
750779
symbol_hashes.push(leaf_hash::<Blake2b256>(&self.matrix[row][col]));
751780
}
752781
}

0 commit comments

Comments (0)