diff --git a/Cargo.toml b/Cargo.toml
index ca692d3..463c505 100644
--- a/Cargo.toml
+++ b/Cargo.toml
@@ -13,6 +13,7 @@ std = []
 
 [dependencies]
 cfg-if = "1.0"
+itertools = {version = "0.10.5", default-features = false, features = ["use_alloc"]}
 
 [dev-dependencies]
 faster-hex = "0.8.0"
diff --git a/src/ancestry_proof.rs b/src/ancestry_proof.rs
new file mode 100644
index 0000000..6d23b03
--- /dev/null
+++ b/src/ancestry_proof.rs
@@ -0,0 +1,342 @@
+use crate::collections::VecDeque;
+use crate::helper::{
+    get_peak_map, get_peaks, is_descendant_pos, leaf_index_to_pos, parent_offset,
+    pos_height_in_tree, sibling_offset,
+};
+use crate::vec::Vec;
+use crate::{Error, Merge, Result};
+use core::fmt::Debug;
+use core::marker::PhantomData;
+use itertools::Itertools;
+
+#[derive(Debug)]
+pub struct NodeMerkleProof<T, M> {
+    mmr_size: u64,
+    proof: Vec<(u64, T)>,
+    merge: PhantomData<M>,
+}
+
+#[derive(Debug)]
+pub struct AncestryProof<T, M> {
+    pub prev_peaks: Vec<T>,
+    pub prev_size: u64,
+    pub proof: NodeMerkleProof<T, M>,
+}
+
+impl<T: Clone + PartialEq, M: Merge<Item = T>> AncestryProof<T, M> {
+    // TODO: restrict roots to be T::Node
+    pub fn verify_ancestor(&self, root: T, prev_root: T) -> Result<bool> {
+        let current_leaves_count = get_peak_map(self.proof.mmr_size);
+        if current_leaves_count <= self.prev_peaks.len() as u64 {
+            return Err(Error::CorruptedProof);
+        }
+        // Test if the previous root is correct.
+        let prev_peaks_positions = {
+            let prev_peaks_positions = get_peaks(self.prev_size);
+            if prev_peaks_positions.len() != self.prev_peaks.len() {
+                return Err(Error::CorruptedProof);
+            }
+            prev_peaks_positions
+        };
+
+        let calculated_prev_root = bagging_peaks_hashes::<T, M>(self.prev_peaks.clone())?;
+        if calculated_prev_root != prev_root {
+            return Ok(false);
+        }
+
+        let nodes = self
+            .prev_peaks
+            .clone()
+            .into_iter()
+            .zip(prev_peaks_positions.iter())
+            .map(|(peak, position)| (*position, peak))
+            .collect();
+
+        self.proof.verify(root, nodes)
+    }
+}
+
+impl<T: Clone + PartialEq, M: Merge<Item = T>> NodeMerkleProof<T, M> {
+    pub fn new(mmr_size: u64, proof: Vec<(u64, T)>) -> Self {
+        NodeMerkleProof {
+            mmr_size,
+            proof,
+            merge: PhantomData,
+        }
+    }
+
+    pub fn mmr_size(&self) -> u64 {
+        self.mmr_size
+    }
+
+    pub fn proof_items(&self) -> &[(u64, T)] {
+        &self.proof
+    }
+
+    pub fn calculate_root(&self, leaves: Vec<(u64, T)>) -> Result<T> {
+        calculate_root::<_, M, _>(leaves, self.mmr_size, self.proof.iter())
+    }
+
+    /// From the merkle proof of leaf n, calculate the merkle root of n + 1 leaves.
+    /// By observing the MMR construction graph we know this is possible.
+    /// https://github.com/jjyr/merkle-mountain-range#construct
+    pub fn calculate_root_with_new_leaf(
+        &self,
+        mut nodes: Vec<(u64, T)>,
+        new_pos: u64,
+        new_elem: T,
+        new_mmr_size: u64,
+    ) -> Result<T> {
+        nodes.push((new_pos, new_elem));
+        calculate_root::<_, M, _>(nodes, new_mmr_size, self.proof.iter())
+    }
+
+    pub fn verify(&self, root: T, nodes: Vec<(u64, T)>) -> Result<bool> {
+        let calculated_root = self.calculate_root(nodes)?;
+        Ok(calculated_root == root)
+    }
+
+    /// Verifies an old root and all incremental leaves.
+    ///
+    /// If this method returns `true`, the following assertions are true:
+    /// - The old root could be generated in the history of the current MMR.
+    /// - All incremental leaves are on the current MMR.
+    /// - The MMR which could generate the old root becomes the current MMR after appending all
+    ///   incremental leaves.
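+    ///
+    /// Illustrative sketch: it assumes `proof` is a `NodeMerkleProof` whose items are exactly
+    /// the peaks of the previous MMR, and that `incremental` holds the leaves appended since
+    /// `prev_root` was produced.
+    ///
+    /// ```ignore
+    /// let ok = proof.verify_incremental(current_root, prev_root, incremental)?;
+    /// assert!(ok);
+    /// ```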
+    pub fn verify_incremental(&self, root: T, prev_root: T, incremental: Vec<T>) -> Result<bool> {
+        let current_leaves_count = get_peak_map(self.mmr_size);
+        if current_leaves_count <= incremental.len() as u64 {
+            return Err(Error::CorruptedProof);
+        }
+        // Test if the previous root is correct.
+        let prev_leaves_count = current_leaves_count - incremental.len() as u64;
+
+        let prev_peaks: Vec<_> = self
+            .proof_items()
+            .iter()
+            .map(|(_, item)| item.clone())
+            .collect();
+
+        let calculated_prev_root = bagging_peaks_hashes::<T, M>(prev_peaks)?;
+        if calculated_prev_root != prev_root {
+            return Ok(false);
+        }
+
+        // Test if the incremental leaves are correct.
+        let leaves = incremental
+            .into_iter()
+            .enumerate()
+            .map(|(index, leaf)| {
+                let pos = leaf_index_to_pos(prev_leaves_count + index as u64);
+                (pos, leaf)
+            })
+            .collect();
+        self.verify(root, leaves)
+    }
+}
+
+fn calculate_peak_root<
+    'a,
+    T: 'a + PartialEq,
+    M: Merge<Item = T>,
+    // I: Iterator<Item = &'a (u64, T)>,
+>(
+    nodes: Vec<(u64, T)>,
+    peak_pos: u64,
+    // proof_iter: &mut I,
+) -> Result<T> {
+    debug_assert!(!nodes.is_empty(), "can't be empty");
+    // (position, hash, height)
+
+    let mut queue: VecDeque<_> = nodes
+        .into_iter()
+        .map(|(pos, item)| (pos, item, pos_height_in_tree(pos)))
+        .collect();
+
+    let mut sibs_processed_from_back = Vec::new();
+
+    // calculate the tree root from each item
+    while let Some((pos, item, height)) = queue.pop_front() {
+        if pos == peak_pos {
+            if queue.is_empty() {
+                // return root once queue is consumed
+                return Ok(item);
+            }
+            if queue
+                .iter()
+                .any(|entry| entry.0 == peak_pos && entry.1 != item)
+            {
+                return Err(Error::CorruptedProof);
+            }
+            if queue
+                .iter()
+                .all(|entry| entry.0 == peak_pos && entry.1 == item && entry.2 == height)
+            {
+                // return root if the remaining queue consists only of duplicate root entries
+                return Ok(item);
+            }
+            // if the queue is not empty, push the peak back to the end
+            queue.push_back((pos, item, height));
+            continue;
+        }
+        // calculate sibling
+        let next_height = pos_height_in_tree(pos + 1);
+        let (parent_pos, parent_item) = {
+            let sibling_offset = sibling_offset(height);
+            if next_height > height {
+                // implies pos is right sibling
+                let (sib_pos, parent_pos) = (pos - sibling_offset, pos + 1);
+                let parent_item = if Some(&sib_pos) == queue.front().map(|(pos, _, _)| pos) {
+                    let sibling_item = queue.pop_front().map(|(_, item, _)| item).unwrap();
+                    M::merge(&sibling_item, &item)?
+                } else if Some(&sib_pos) == queue.back().map(|(pos, _, _)| pos) {
+                    let sibling_item = queue.pop_back().map(|(_, item, _)| item).unwrap();
+                    M::merge(&sibling_item, &item)?
+                }
+                // handle the special case where the next queue item is a descendant of the sibling
+                else if let Some(&(front_pos, ..)) = queue.front() {
+                    if height > 0 && is_descendant_pos(sib_pos, front_pos) {
+                        queue.push_back((pos, item, height));
+                        continue;
+                    } else {
+                        return Err(Error::CorruptedProof);
+                    }
+                } else {
+                    return Err(Error::CorruptedProof);
+                };
+                (parent_pos, parent_item)
+            } else {
+                // pos is left sibling
+                let (sib_pos, parent_pos) = (pos + sibling_offset, pos + parent_offset(height));
+                let parent_item = if Some(&sib_pos) == queue.front().map(|(pos, _, _)| pos) {
+                    let sibling_item = queue.pop_front().map(|(_, item, _)| item).unwrap();
+                    M::merge(&item, &sibling_item)?
+                } else if Some(&sib_pos) == queue.back().map(|(pos, _, _)| pos) {
+                    let sibling_item = queue.pop_back().map(|(_, item, _)| item).unwrap();
+                    let parent = M::merge(&item, &sibling_item)?;
+                    sibs_processed_from_back.push((sib_pos, sibling_item, height));
+                    parent
+                } else if let Some(&(front_pos, ..)) = queue.front() {
+                    if height > 0 && is_descendant_pos(sib_pos, front_pos) {
+                        queue.push_back((pos, item, height));
+                        continue;
+                    } else {
+                        return Err(Error::CorruptedProof);
+                    }
+                } else {
+                    return Err(Error::CorruptedProof);
+                };
+                (parent_pos, parent_item)
+            }
+        };
+
+        if parent_pos <= peak_pos {
+            let parent = (parent_pos, parent_item, height + 1);
+            if peak_pos == parent_pos
+                || queue.front() != Some(&parent)
+                    && !sibs_processed_from_back.iter().any(|item| item == &parent)
+            {
+                queue.push_front(parent)
+            };
+        } else {
+            return Err(Error::CorruptedProof);
+        }
+    }
+    Err(Error::CorruptedProof)
+}
+
+fn calculate_peaks_hashes<
+    'a,
+    T: 'a + PartialEq + Clone,
+    M: Merge<Item = T>,
+    I: Iterator<Item = &'a (u64, T)>,
+>(
+    nodes: Vec<(u64, T)>,
+    mmr_size: u64,
+    proof_iter: I,
+) -> Result<Vec<T>> {
+    // special-case the MMR with only 1 leaf
+    if mmr_size == 1 && nodes.len() == 1 && nodes[0].0 == 0 {
+        return Ok(nodes.into_iter().map(|(_pos, item)| item).collect());
+    }
+
+    // ensure nodes are sorted and unique
+    let mut nodes: Vec<_> = nodes
+        .into_iter()
+        .chain(proof_iter.cloned())
+        .sorted_by_key(|(pos, _)| *pos)
+        .dedup_by(|a, b| a.0 == b.0)
+        .collect();
+
+    let peaks = get_peaks(mmr_size);
+
+    let mut peaks_hashes: Vec<T> = Vec::with_capacity(peaks.len() + 1);
+    for peak_pos in peaks {
+        let mut nodes: Vec<(u64, T)> = take_while_vec(&mut nodes, |(pos, _)| *pos <= peak_pos);
+        let peak_root = if nodes.len() == 1 && nodes[0].0 == peak_pos {
+            // the node is the peak itself
+            nodes.remove(0).1
+        } else if nodes.is_empty() {
+            // if empty, the next proof item is a peak root or a bagged rhs root,
+            // meaning that either all right peaks are bagged or the proof is corrupted,
+            // so we break the loop and check that no items are left
+            break;
+        } else {
+            calculate_peak_root::<_, M>(nodes, peak_pos)?
+        };
+        peaks_hashes.push(peak_root.clone());
+    }
+
+    // ensure nothing is left in nodes
+    if !nodes.is_empty() {
+        return Err(Error::CorruptedProof);
+    }
+
+    // check rhs peaks
+    // if let Some((_, rhs_peaks_hashes)) = proof_iter.next() {
+    //     peaks_hashes.push(rhs_peaks_hashes.clone());
+    // }
+    // ensure nothing left in proof_iter
+    // if proof_iter.next().is_some() {
+    //     return Err(Error::CorruptedProof);
+    // }
+    Ok(peaks_hashes)
+}
+
+pub fn bagging_peaks_hashes<T, M: Merge<Item = T>>(mut peaks_hashes: Vec<T>) -> Result<T> {
+    // bagging peaks
+    // bag from right to left via hash(right, left).
+    while peaks_hashes.len() > 1 {
+        let right_peak = peaks_hashes.pop().expect("pop");
+        let left_peak = peaks_hashes.pop().expect("pop");
+        peaks_hashes.push(M::merge_peaks(&right_peak, &left_peak)?);
+    }
+    peaks_hashes.pop().ok_or(Error::CorruptedProof)
+}
+
+/// merkle proof
+/// 1. sort items by position
+/// 2. calculate the root of each peak
+/// 3. bag the peaks
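+///
+/// For example, with 11 leaves (mmr_size 19) the peaks sit at positions 14, 17 and 18, and
+/// step 3 folds them right to left; a sketch with hypothetical peak hashes `p14`, `p17`, `p18`:
+///
+/// ```ignore
+/// let root = M::merge_peaks(&M::merge_peaks(&p18, &p17)?, &p14)?;
+/// ```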
+fn calculate_root<
+    'a,
+    T: 'a + PartialEq + Clone,
+    M: Merge<Item = T>,
+    I: Iterator<Item = &'a (u64, T)>,
+>(
+    nodes: Vec<(u64, T)>,
+    mmr_size: u64,
+    proof_iter: I,
+) -> Result<T> {
+    let peaks_hashes = calculate_peaks_hashes::<_, M, _>(nodes, mmr_size, proof_iter)?;
+    bagging_peaks_hashes::<_, M>(peaks_hashes)
+}
+
+fn take_while_vec<T, P: Fn(&T) -> bool>(v: &mut Vec<T>, p: P) -> Vec<T> {
+    for i in 0..v.len() {
+        if !p(&v[i]) {
+            return v.drain(..i).collect();
+        }
+    }
+    v.drain(..).collect()
+}
diff --git a/src/error.rs b/src/error.rs
index c1c9276..741c2df 100644
--- a/src/error.rs
+++ b/src/error.rs
@@ -2,15 +2,16 @@ pub type Result<T> = core::result::Result<T, Error>;
 
 #[derive(Debug, PartialEq, Eq, Clone)]
 pub enum Error {
+    AncestorRootNotPredecessor,
     GetRootOnEmpty,
     InconsistentStore,
     StoreError(crate::string::String),
     /// proof items is not enough to build a tree
     CorruptedProof,
-    /// tried to verify proof of a non-leaf
-    NodeProofsNotSupported,
     /// The leaves is an empty list, or beyond the mmr range
     GenProofForInvalidLeaves,
+    /// The nodes are an empty list, or beyond the mmr range
+    GenProofForInvalidNodes,
     /// The two nodes couldn't merge into one.
     MergeError(crate::string::String),
@@ -20,12 +21,13 @@ impl core::fmt::Display for Error {
     fn fmt(&self, f: &mut core::fmt::Formatter) -> core::fmt::Result {
         use Error::*;
         match self {
+            AncestorRootNotPredecessor => write!(f, "Ancestor mmr size exceeds current mmr size")?,
            GetRootOnEmpty => write!(f, "Get root on an empty MMR")?,
            InconsistentStore => write!(f, "Inconsistent store")?,
            StoreError(msg) => write!(f, "Store error {}", msg)?,
            CorruptedProof => write!(f, "Corrupted proof")?,
-            NodeProofsNotSupported => write!(f, "Tried to verify membership of a non-leaf")?,
-            GenProofForInvalidLeaves => write!(f, "Generate proof ofr invalid leaves")?,
+            GenProofForInvalidLeaves => write!(f, "Generate proof for invalid leaves")?,
+            GenProofForInvalidNodes => write!(f, "Generate proof for invalid nodes")?,
            MergeError(msg) => write!(f, "Merge error {}", msg)?,
         }
         Ok(())
diff --git a/src/helper.rs b/src/helper.rs
index 8a24e68..50ca5a0 100644
--- a/src/helper.rs
+++ b/src/helper.rs
@@ -75,6 +75,16 @@ pub fn get_peak_map(mmr_size: u64) -> u64 {
     peak_map
 }
 
+/// Returns whether `descendant_contender` is a descendant of `ancestor_contender` in a tree of the MMR.
+pub fn is_descendant_pos(ancestor_contender: u64, descendant_contender: u64) -> bool {
+    // NOTE: "ancestry" here refers to the hierarchy within an MMR tree, not temporal hierarchy;
+    // the descendant needs to have been added to the mmr prior to the ancestor
+    descendant_contender <= ancestor_contender
+        // the descendant needs to be within the cone of positions descending from the ancestor
+        && descendant_contender
+            >= (ancestor_contender + 1 - sibling_offset(pos_height_in_tree(ancestor_contender)))
+}
+
 /// Returns the pos of the peaks in the mmr.
 /// for example, for a mmr with 11 leaves, the mmr_size is 19, it will return [14, 17, 18].
 /// 14
diff --git a/src/lib.rs b/src/lib.rs
index 96b3fc0..f5e94db 100644
--- a/src/lib.rs
+++ b/src/lib.rs
@@ -1,5 +1,6 @@
 #![cfg_attr(not(feature = "std"), no_std)]
 
+pub mod ancestry_proof;
 mod error;
 pub mod helper;
 mod merge;
@@ -9,6 +10,7 @@ mod mmr_store;
 mod tests;
 pub mod util;
 
+pub use ancestry_proof::{AncestryProof, NodeMerkleProof};
 pub use error::{Error, Result};
 pub use helper::{leaf_index_to_mmr_size, leaf_index_to_pos};
 pub use merge::Merge;
diff --git a/src/mmr.rs b/src/mmr.rs
index fcd32c4..ac5108d 100644
--- a/src/mmr.rs
+++ b/src/mmr.rs
@@ -4,11 +4,12 @@
 //! https://github.com/mimblewimble/grin/blob/master/doc/mmr.md#structure
 //! https://github.com/mimblewimble/grin/blob/0ff6763ee64e5a14e70ddd4642b99789a1648a32/core/src/core/pmmr.rs#L606
 
+use crate::ancestry_proof::{AncestryProof, NodeMerkleProof};
 use crate::borrow::Cow;
 use crate::collections::VecDeque;
 use crate::helper::{
-    get_peak_map, get_peaks, leaf_index_to_mmr_size, leaf_index_to_pos, parent_offset,
-    pos_height_in_tree, sibling_offset,
+    get_peak_map, get_peaks, is_descendant_pos, leaf_index_to_mmr_size, leaf_index_to_pos,
+    parent_offset, pos_height_in_tree, sibling_offset,
 };
 use crate::mmr_store::{MMRBatch, MMRStoreReadOps, MMRStoreWriteOps};
 use crate::vec;
@@ -102,6 +103,38 @@ impl<T: Clone + PartialEq, M: Merge<Item = T>, S: MMRStoreReadOps<T>> MMR<T, M, S>
+    pub fn get_ancestor_peaks_and_root(&self, prev_mmr_size: u64) -> Result<(Vec<T>, T)> {
+        if self.mmr_size == 0 {
+            return Err(Error::GetRootOnEmpty);
+        } else if self.mmr_size == 1 && prev_mmr_size == 1 {
+            let singleton = self.batch.get_elem(0)?.ok_or(Error::InconsistentStore);
+            match singleton {
+                Ok(singleton) => return Ok((vec![singleton.clone()], singleton)),
+                Err(e) => return Err(e),
+            }
+        } else if prev_mmr_size > self.mmr_size {
+            return Err(Error::AncestorRootNotPredecessor);
+        }
+        let peaks: Result<Vec<T>> = get_peaks(prev_mmr_size)
+            .into_iter()
+            .map(|peak_pos| {
+                self.batch
+                    .get_elem(peak_pos)
+                    .and_then(|elem| elem.ok_or(Error::InconsistentStore))
+            })
+            .collect::<Result<Vec<T>>>();
+        match peaks {
+            Ok(peaks) => {
+                let root = self
+                    .bag_rhs_peaks(peaks.clone())?
+                    .ok_or(Error::InconsistentStore)?;
+                return Ok((peaks, root));
+            }
+            Err(e) => Err(e),
+        }
+    }
+
     fn bag_rhs_peaks(&self, mut rhs_peaks: Vec<T>) -> Result<Option<T>> {
         while rhs_peaks.len() > 1 {
             let right_peak = rhs_peaks.pop().expect("pop");
@@ -146,7 +179,7 @@ impl<T: Clone + PartialEq, M: Merge<Item = T>, S: MMRStoreReadOps<T>> MMR<T, M, S>
+    fn gen_node_proof_for_peak(
+        &self,
+        proof: &mut Vec<(u64, T)>,
+        pos_list: Vec<u64>,
+        peak_pos: u64,
+    ) -> Result<()> {
+        // do nothing if the position itself is the peak
+        if pos_list.len() == 1 && pos_list == [peak_pos] {
+            return Ok(());
+        }
+        // take the peak root from the store if no positions need to be proven
+        if pos_list.is_empty() {
+            proof.push((
+                peak_pos,
+                self.batch
+                    .get_elem(peak_pos)?
+                    .ok_or(Error::InconsistentStore)?,
+            ));
+            return Ok(());
+        }
+
+        let mut queue: VecDeque<_> = pos_list
+            .clone()
+            .into_iter()
+            .map(|pos| (pos, pos_height_in_tree(pos)))
+            .collect();
+
+        // Generate sub-tree merkle proof for positions
+        while let Some((pos, height)) = queue.pop_front() {
+            debug_assert!(pos <= peak_pos);
+            if pos == peak_pos {
+                if queue.is_empty() {
+                    break;
+                } else {
+                    continue;
+                }
+            }
+
+            // calculate sibling
+            let (sib_pos, parent_pos) = {
+                let next_height = pos_height_in_tree(pos + 1);
+                let sibling_offset = sibling_offset(height);
+                if next_height > height {
+                    // implies pos is right sibling
+                    (pos - sibling_offset, pos + 1)
+                } else {
+                    // pos is left sibling
+                    (pos + sibling_offset, pos + parent_offset(height))
+                }
+            };
+
+            let queue_front_pos = queue.front().map(|(pos, _)| pos);
+            if Some(&sib_pos) == queue_front_pos {
+                // drop sibling
+                queue.pop_front();
+            } else if queue_front_pos.is_none()
+                || !is_descendant_pos(
+                    sib_pos,
+                    *queue_front_pos.expect("checked queue_front_pos != None"),
+                )
+            // only push a sibling into the proof if either of these cases is satisfied:
+            // 1. the queue is empty
+            // 2. the next item in the queue is not the sibling or a child of it
+            {
+                let sibling = (
+                    sib_pos,
+                    self.batch
+                        .get_elem(sib_pos.clone())?
+                        .ok_or(Error::InconsistentStore)?,
+                );
+
+                // only push the sibling if it's not already a proof item or to be proven,
+                // which can be the case if both a child and its parent are to be proven
+                if height == 0
+                    || !(proof.contains(&sibling)) && pos_list.binary_search(&sib_pos).is_err()
+                {
+                    proof.push(sibling);
+                }
+            }
+            if parent_pos < peak_pos {
+                // save pos to the tree buf
+                queue.push_back((parent_pos, height + 1));
+            }
+        }
+        Ok(())
+    }
+
     /// Generate merkle proof for positions
     /// 1. sort positions
     /// 2. push merkle proof to proof by peak from left to right
@@ -193,7 +319,7 @@ impl<T: Clone + PartialEq, M: Merge<Item = T>, S: MMRStoreReadOps<T>> MMR<T, M, S>
         if pos_list.iter().any(|pos| pos_height_in_tree(*pos) > 0) {
-            return Err(Error::NodeProofsNotSupported);
+            return Err(Error::GenProofForInvalidLeaves);
         }
         // ensure positions are sorted and unique
         pos_list.sort_unstable();
@@ -224,6 +350,120 @@ impl<T: Clone + PartialEq, M: Merge<Item = T>, S: MMRStoreReadOps<T>> MMR<T, M, S>
+    pub fn gen_node_proof(&self, mut pos_list: Vec<u64>) -> Result<NodeMerkleProof<T, M>> {
+        if pos_list.is_empty() {
+            return Err(Error::GenProofForInvalidNodes);
+        }
+        if self.mmr_size == 1 && pos_list == [0] {
+            return Ok(NodeMerkleProof::new(self.mmr_size, Vec::new()));
+        }
+        // ensure positions are sorted and unique
+        pos_list.sort_unstable();
+        pos_list.dedup();
+        let peaks = get_peaks(self.mmr_size);
+        let mut proof: Vec<(u64, T)> = Vec::new();
+        // generate a merkle proof for each peak
+        let mut bagging_track = 0;
+        for peak_pos in peaks {
+            let pos_list: Vec<_> = take_while_vec(&mut pos_list, |&pos| pos <= peak_pos);
+            if pos_list.is_empty() {
+                bagging_track += 1;
+            } else {
+                bagging_track = 0;
+            }
+            self.gen_node_proof_for_peak(&mut proof, pos_list, peak_pos)?;
+        }
+
+        // ensure no positions remain
+        if !pos_list.is_empty() {
+            return Err(Error::GenProofForInvalidNodes);
+        }
+
+        // starting from the rightmost peak, an unbroken sequence of
+        // peaks that don't have descendants to be proven can be bagged
+        // during proof construction already, since during verification
+        // they'll only be utilized during the bagging step anyway
+        if bagging_track > 1 {
+            let rhs_peaks = proof.split_off(proof.len() - bagging_track);
+            proof.push((
+                rhs_peaks[0].0,
+                self.bag_rhs_peaks(rhs_peaks.iter().map(|(_pos, item)| item.clone()).collect())?
+                    .expect("bagging rhs peaks"),
+            ));
+        }
+
+        proof.sort_by_key(|(pos, _)| *pos);
+
+        Ok(NodeMerkleProof::new(self.mmr_size, proof))
+    }
+
+    /// Generate a proof that a prior merkle root r' is an ancestor of the current merkle root r
+    /// 1. calculate the positions of the peaks of the old root r' given mmr size n
+    /// 2. generate a membership proof of those peaks in root r
+    /// 3. calculate r' from peaks(n)
+    /// 4. return (mmr root r', peak hashes, membership proof of peaks(n) in r)
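+    ///
+    /// Illustrative sketch (mirroring `src/tests/test_ancestry.rs`; `mmr`, `root` and
+    /// `prev_root` are assumed to already exist):
+    ///
+    /// ```ignore
+    /// let prev_mmr_size = leaf_index_to_mmr_size(prev_leaf_index);
+    /// let ancestry_proof = mmr.gen_ancestry_proof(prev_mmr_size)?;
+    /// assert!(ancestry_proof.verify_ancestor(root, prev_root)?);
+    /// ```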
+    pub fn gen_ancestry_proof(&self, prev_mmr_size: u64) -> Result<AncestryProof<T, M>> {
+        let mut pos_list = get_peaks(prev_mmr_size);
+        if pos_list.is_empty() {
+            return Err(Error::GenProofForInvalidNodes);
+        }
+        if self.mmr_size == 1 && pos_list == [0] {
+            return Ok(AncestryProof {
+                prev_peaks: Vec::new(),
+                prev_size: self.mmr_size,
+                proof: NodeMerkleProof::new(self.mmr_size(), Vec::new()),
+            });
+        }
+        // ensure positions are sorted and unique
+        pos_list.sort_unstable();
+        pos_list.dedup();
+        let peaks = get_peaks(self.mmr_size);
+        let mut proof: Vec<(u64, T)> = Vec::new();
+        // generate a merkle proof for each peak
+        let mut bagging_track = 0;
+        for peak_pos in peaks {
+            let pos_list: Vec<_> = take_while_vec(&mut pos_list, |&pos| pos <= peak_pos);
+            if pos_list.is_empty() {
+                bagging_track += 1;
+            } else {
+                bagging_track = 0;
+            }
+            self.gen_node_proof_for_peak(&mut proof, pos_list, peak_pos)?;
+        }
+
+        // ensure no positions remain
+        if !pos_list.is_empty() {
+            return Err(Error::GenProofForInvalidNodes);
+        }
+
+        // starting from the rightmost peak, an unbroken sequence of
+        // peaks that don't have descendants to be proven can be bagged
+        // during proof construction already, since during verification
+        // they'll only be utilized during the bagging step anyway
+        if bagging_track > 1 {
+            let rhs_peaks = proof.split_off(proof.len() - bagging_track);
+            proof.push((
+                rhs_peaks[0].0,
+                self.bag_rhs_peaks(rhs_peaks.iter().map(|(_pos, item)| item.clone()).collect())?
+                    .expect("bagging rhs peaks"),
+            ));
+        }
+
+        proof.sort_by_key(|(pos, _)| *pos);
+
+        let (prev_peaks, _prev_root) = self.get_ancestor_peaks_and_root(prev_mmr_size)?;
+
+        Ok(AncestryProof {
+            prev_peaks,
+            prev_size: prev_mmr_size,
+            proof: NodeMerkleProof::new(self.mmr_size, proof),
+        })
+    }
 }
 
 impl<T: Clone + PartialEq, M: Merge<Item = T>, S: MMRStoreReadOps<T> + MMRStoreWriteOps<T>> MMR<T, M, S> {
@@ -419,7 +659,7 @@ fn calculate_peaks_hashes<'a, T: 'a + Clone, M: Merge<Item = T>, I: Iterator<Item = &'a T>>(
 ) -> Result<Vec<T>> {
     if leaves.iter().any(|(pos, _)| pos_height_in_tree(*pos) > 0) {
-        return Err(Error::NodeProofsNotSupported);
+        return Err(Error::GenProofForInvalidLeaves);
     }
 
     // special handle the only 1 leaf MMR
diff --git a/src/tests/mod.rs b/src/tests/mod.rs
index 81fb485..14e6d75 100644
--- a/src/tests/mod.rs
+++ b/src/tests/mod.rs
@@ -1,7 +1,9 @@
 mod test_accumulate_headers;
+mod test_ancestry;
 mod test_helper;
 mod test_incremental;
 mod test_mmr;
+mod test_node_mmr;
 mod test_sequence;
 
 use crate::{Merge, Result};
diff --git a/src/tests/test_ancestry.rs b/src/tests/test_ancestry.rs
new file mode 100644
index 0000000..fc016df
--- /dev/null
+++ b/src/tests/test_ancestry.rs
@@ -0,0 +1,25 @@
+use super::{MergeNumberHash, NumberHash};
+use crate::leaf_index_to_mmr_size;
+use crate::util::{MemMMR, MemStore};
+
+#[test]
+fn test_ancestry() {
+    let store = MemStore::default();
+    let mut mmr = MemMMR::<_, MergeNumberHash>::new(0, &store);
+
+    let mmr_size = 300;
+    let mut prev_roots = Vec::new();
+    for i in 0..mmr_size {
+        mmr.push(NumberHash::from(i)).unwrap();
+        prev_roots.push(mmr.get_root().expect("get root"));
+    }
+
+    let root = mmr.get_root().expect("get root");
+    for i in 0..mmr_size {
+        let prev_size = leaf_index_to_mmr_size(i.into());
+        let ancestry_proof = mmr.gen_ancestry_proof(prev_size).expect("gen proof");
+        assert!(ancestry_proof
+            .verify_ancestor(root.clone(), prev_roots[i as usize].clone())
+            .unwrap());
+    }
+}
diff --git a/src/tests/test_mmr.rs b/src/tests/test_mmr.rs
index 7e4405c..c756875 100644
--- a/src/tests/test_mmr.rs
+++ b/src/tests/test_mmr.rs
@@ -232,7 +232,7 @@ fn
test_invalid_proof_verification( assert!(proof.verify(root.clone(), entries_to_verify).unwrap()); assert!(!proof.verify(root, tampered_entries_to_verify).unwrap()); } - Err(Error::NodeProofsNotSupported) => { + Err(Error::GenProofForInvalidLeaves) => { // if couldn't generate proof, then it contained a non-leaf assert!(positions_to_verify .iter() diff --git a/src/tests/test_node_mmr.rs b/src/tests/test_node_mmr.rs new file mode 100644 index 0000000..8d31593 --- /dev/null +++ b/src/tests/test_node_mmr.rs @@ -0,0 +1,362 @@ +use super::{MergeNumberHash, NumberHash}; +use crate::{ + leaf_index_to_mmr_size, + util::{MemMMR, MemStore}, + Error, +}; +use core::ops::Shl; +use faster_hex::hex_string; +use proptest::prelude::*; +use rand::{seq::SliceRandom, thread_rng}; + +fn test_mmr(count: u32, proof_elem: Vec) { + let store = MemStore::default(); + let mut mmr = MemMMR::<_, MergeNumberHash>::new(0, &store); + let positions: Vec = (0u32..count) + .map(|i| mmr.push(NumberHash::from(i)).unwrap()) + .collect(); + let root = mmr.get_root().expect("get root"); + let proof = mmr + .gen_node_proof( + proof_elem + .iter() + .map(|elem| positions[*elem as usize]) + .collect(), + ) + .expect("gen proof"); + assert!(proof + .proof_items() + .iter() + .zip(proof.proof_items().iter().skip(1)) + .all(|((pos_a, _), (pos_b, _))| pos_a < pos_b)); + mmr.commit().expect("commit changes"); + let result = proof + .verify( + root, + proof_elem + .iter() + .map(|elem| (positions[*elem as usize], NumberHash::from(*elem))) + .collect(), + ) + .unwrap(); + assert!(result); +} + +fn test_gen_new_root_from_proof(count: u32) { + let store = MemStore::default(); + let mut mmr = MemMMR::<_, MergeNumberHash>::new(0, &store); + let positions: Vec = (0u32..count) + .map(|i| mmr.push(NumberHash::from(i)).unwrap()) + .collect(); + let elem = count - 1; + let pos = positions[elem as usize]; + let proof = mmr.gen_proof(vec![pos]).expect("gen proof"); + let new_elem = count; + let new_pos = mmr.push(NumberHash::from(new_elem)).unwrap(); + let root = mmr.get_root().expect("get root"); + mmr.commit().expect("commit changes"); + let calculated_root = proof + .calculate_root_with_new_leaf( + vec![(pos, NumberHash::from(elem))], + new_pos, + NumberHash::from(new_elem), + leaf_index_to_mmr_size(new_elem.into()), + ) + .unwrap(); + assert_eq!(calculated_root, root); +} + +#[test] +fn test_mmr_root() { + let store = MemStore::default(); + let mut mmr = MemMMR::<_, MergeNumberHash>::new(0, &store); + (0u32..11).for_each(|i| { + mmr.push(NumberHash::from(i)).unwrap(); + }); + let root = mmr.get_root().expect("get root"); + let hex_root = hex_string(&root.0); + assert_eq!( + "f6794677f37a57df6a5ec36ce61036e43a36c1a009d05c81c9aa685dde1fd6e3", + hex_root + ); +} + +#[test] +fn test_empty_mmr_root() { + let store = MemStore::::default(); + let mmr = MemMMR::<_, MergeNumberHash>::new(0, &store); + assert_eq!(Err(Error::GetRootOnEmpty), mmr.get_root()); +} + +#[test] +fn test_mmr_3_peaks() { + test_mmr(11, vec![5]); +} + +#[test] +fn test_mmr_2_peaks() { + test_mmr(10, vec![5]); +} + +#[test] +fn test_mmr_1_peak() { + test_mmr(8, vec![5]); +} + +#[test] +fn test_mmr_first_elem_proof() { + test_mmr(11, vec![0]); +} + +#[test] +fn test_mmr_last_elem_proof() { + test_mmr(11, vec![10]); +} + +#[test] +fn test_mmr_1_elem() { + test_mmr(1, vec![0]); +} + +#[test] +fn test_mmr_2_elems() { + test_mmr(2, vec![0]); + test_mmr(2, vec![1]); +} + +#[test] +fn test_mmr_2_leaves_merkle_proof() { + test_mmr(11, vec![3, 7]); + test_mmr(11, vec![3, 4]); +} + +#[test] 
+fn test_mmr_2_sibling_leaves_merkle_proof() { + test_mmr(11, vec![4, 5]); + test_mmr(11, vec![5, 6]); + test_mmr(11, vec![6, 7]); +} + +#[test] +fn test_mmr_3_leaves_merkle_proof() { + test_mmr(11, vec![4, 5, 6]); + test_mmr(11, vec![3, 5, 7]); + test_mmr(11, vec![3, 4, 5]); + test_mmr(100, vec![3, 5, 13]); +} + +#[test] +fn test_gen_root_from_proof() { + test_gen_new_root_from_proof(11); +} + +#[test] +fn test_gen_proof_with_duplicate_leaves() { + test_mmr(10, vec![5, 5]); +} + +fn test_invalid_proof_verification( + leaf_count: u32, + positions_to_verify: Vec, + // positions of entries that should be tampered + tampered_positions: Vec, + // optionally handroll proof from these positions + handrolled_proof_positions: Option>, + // optionally handroll tampered proof from these positions + handrolled_tampered_proof_positions: Option>, +) { + use crate::{ancestry_proof::NodeMerkleProof, Merge}; + use std::fmt::{Debug, Formatter}; + + // Simple item struct to allow debugging the contents of MMR nodes/peaks + #[derive(Clone, PartialEq)] + enum MyItem { + Number(u32), + Merged(Box, Box), + } + + impl Debug for MyItem { + fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result { + match self { + MyItem::Number(x) => f.write_fmt(format_args!("{}", x)), + MyItem::Merged(a, b) => f.write_fmt(format_args!("Merged({:#?}, {:#?})", a, b)), + } + } + } + + #[derive(Debug)] + struct MyMerge; + + impl Merge for MyMerge { + type Item = MyItem; + fn merge(lhs: &Self::Item, rhs: &Self::Item) -> Result { + return Ok(MyItem::Merged(Box::new(lhs.clone()), Box::new(rhs.clone()))); + } + } + + // Let's build a simple MMR with the numbers 0 to 6 + let store = MemStore::default(); + let mut mmr = MemMMR::<_, MyMerge>::new(0, &store); + let mut positions: Vec = Vec::new(); + for i in 0u32..leaf_count { + let pos = mmr.push(MyItem::Number(i)).unwrap(); + positions.push(pos); + } + let root = mmr.get_root().unwrap(); + + let entries_to_verify: Vec<(u64, MyItem)> = positions_to_verify + .iter() + .map(|pos| (*pos, mmr.batch().get_elem(*pos).unwrap().unwrap())) + .collect(); + + let mut tampered_entries_to_verify = entries_to_verify.clone(); + tampered_positions.iter().for_each(|proof_pos| { + tampered_entries_to_verify[*proof_pos] = ( + tampered_entries_to_verify[*proof_pos].0, + MyItem::Number(31337), + ) + }); + + let tampered_proof: Option> = + if let Some(tampered_proof_positions) = handrolled_tampered_proof_positions { + Some(NodeMerkleProof::new( + mmr.mmr_size(), + tampered_proof_positions + .iter() + .map(|pos| (*pos, mmr.batch().get_elem(*pos).unwrap().unwrap())) + .collect(), + )) + } else { + None + }; + + // test with the proof generated by the library itself, or, if provided, a handrolled proof + let proof = if let Some(proof_positions) = handrolled_proof_positions { + NodeMerkleProof::new( + mmr.mmr_size(), + proof_positions + .iter() + .map(|pos| (*pos, mmr.batch().get_elem(*pos).unwrap().unwrap())) + .collect(), + ) + } else { + mmr.gen_node_proof(positions_to_verify.clone()).unwrap() + }; + + // if proof items have been tampered with, the proof verification fails + if let Some(tampered_proof) = tampered_proof { + let tampered_proof_result = + tampered_proof.verify(root.clone(), tampered_entries_to_verify.clone()); + assert!(tampered_proof_result.is_err() || !tampered_proof_result.unwrap()); + } + + // if any nodes to be verified aren't members of the mmr, the proof verification fails + let tampered_entries_result = proof.verify(root.clone(), tampered_entries_to_verify.clone()); + 
assert!(tampered_entries_result.is_err() || !tampered_entries_result.unwrap()); + + let proof_verification = proof.verify(root, entries_to_verify); + // verification of the correct nodes passes + assert!(proof_verification.unwrap()); +} + +#[test] +fn test_generic_proofs() { + // working with proof generation + test_invalid_proof_verification(7, vec![5], vec![0], None, None); + test_invalid_proof_verification(7, vec![1, 2], vec![0], None, None); + test_invalid_proof_verification(7, vec![1, 5], vec![0], None, None); + // original example with proof items [Merged(Merged(0, 1), Merged(2, 3)), Merged(4, 5), 6]: + test_invalid_proof_verification(7, vec![1, 6], vec![0], None, Some(vec![6, 9, 10])); + // original example, but with correct proof items [0, Merged(2, 3), Merged(6, Merged(4, 5))] + test_invalid_proof_verification(7, vec![1, 6], vec![0], None, None); + test_invalid_proof_verification(7, vec![1, 6], vec![0], Some(vec![0, 5, 9, 10]), None); + test_invalid_proof_verification(7, vec![5, 6], vec![0], None, None); + test_invalid_proof_verification(7, vec![1, 5, 6], vec![0], None, None); + test_invalid_proof_verification(7, vec![1, 5, 7], vec![0], None, None); + test_invalid_proof_verification(7, vec![5, 6, 7], vec![0], None, None); + test_invalid_proof_verification(7, vec![5, 6, 7, 8, 9, 10], vec![0], None, None); + test_invalid_proof_verification(7, vec![1, 5, 7, 8, 9, 10], vec![0], None, None); + test_invalid_proof_verification(7, vec![0, 1, 5, 7, 8, 9, 10], vec![0], None, None); + test_invalid_proof_verification(7, vec![0, 1, 5, 6, 7, 8, 9, 10], vec![0], None, None); + test_invalid_proof_verification(7, vec![0, 1, 2, 5, 6, 7, 8, 9, 10], vec![0], None, None); + + test_invalid_proof_verification( + 7, + vec![0, 1, 2, 3, 7, 8, 9, 10], + vec![0], + Some(vec![4]), + None, + ); + test_invalid_proof_verification(7, vec![0, 2, 3, 7, 8, 9, 10], vec![0], None, None); + test_invalid_proof_verification(7, vec![0, 3, 7, 8, 9, 10], vec![0], None, None); + test_invalid_proof_verification(7, vec![0, 2, 3, 7, 8, 9, 10], vec![0], None, None); +} + +prop_compose! { + fn count_elem(count: u32) + (elem in 0..count) + -> (u32, u32) { + (count, elem) + } +} + +fn nodes_subset(subset_index: u128, position_count: u8) -> Vec { + let mut positions = vec![]; + + for index in 0..position_count { + if (1 << index) & subset_index != 0 { + positions.push(index as u64) + } + } + + positions +} + +const MAX_LEAVES_COUNT: u32 = 64; +proptest! { + #![proptest_config(ProptestConfig { + cases: 2000, max_shrink_iters: 2000, .. ProptestConfig::default() + })] + #[test] + fn test_mmr_generic_proof_proptest( + (leaves_count, (positions, tampered_node_position)) in (1..=MAX_LEAVES_COUNT) + .prop_flat_map(|leaves_count| {let mmr_size = leaf_index_to_mmr_size(leaves_count as u64 - 1); + let subset_index = 1u128..1u128.shl(mmr_size as u8); + (Just(leaves_count), + (Just(mmr_size), subset_index).prop_flat_map(|(mmr_size, subset_index)| { + let positions = nodes_subset(subset_index, mmr_size as u8); + (Just(positions.clone()), 0..positions.len()) + }))}) + ) { + test_invalid_proof_verification(leaves_count, positions, vec![tampered_node_position], None, None) + } +} + +const MAX_POS: u8 = 11; +proptest! 
{
+    // for 7 leaves there are 11 nodes, so 2^11 possible subsets of nodes to generate a proof for
+    #[test]
+    fn test_7_leaf_mmr_generic_proof_proptest(
+        positions in (1u128..1u128.shl(MAX_POS)).prop_map(|subset_index| nodes_subset(subset_index, MAX_POS))
+    ) {
+        let leaves_count = 7;
+        test_invalid_proof_verification(leaves_count, positions, vec![0], None, None)
+    }
+}
+
+proptest! {
+    #[test]
+    fn test_random_mmr(count in 10u32..500u32) {
+        let mut leaves: Vec<u32> = (0..count).collect();
+        let mut rng = thread_rng();
+        leaves.shuffle(&mut rng);
+        let leaves_count = rng.gen_range(1..count - 1);
+        leaves.truncate(leaves_count as usize);
+        test_mmr(count, leaves);
+    }
+
+    #[test]
+    fn test_random_gen_root_with_new_leaf(count in 1u32..500u32) {
+        test_gen_new_root_from_proof(count);
+    }
+}
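+
+// Illustrative sketch (assumes the `MemMMR`/`MergeNumberHash` test utilities used above):
+// `gen_node_proof` can prove internal nodes, not just leaves. Position 2 is the parent of the
+// leaves at positions 0 and 1, so a single-node proof for it could look like:
+//
+// let store = MemStore::default();
+// let mut mmr = MemMMR::<_, MergeNumberHash>::new(0, &store);
+// (0u32..11).for_each(|i| { mmr.push(NumberHash::from(i)).unwrap(); });
+// let root = mmr.get_root().expect("get root");
+// let proof = mmr.gen_node_proof(vec![2]).expect("gen proof");
+// let node = mmr.batch().get_elem(2).expect("store").expect("exists");
+// assert!(proof.verify(root, vec![(2, node)]).unwrap());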