Skip to content

Commit

Permalink
Merge pull request #29 from quake/quake/mmr-commit
Browse files Browse the repository at this point in the history
refactor: use mut ref in mmr#commit
  • Loading branch information
jjyr authored Mar 7, 2023
2 parents 6494cd2 + fef46e0 commit cb54163
Show file tree
Hide file tree
Showing 5 changed files with 40 additions and 75 deletions.
26 changes: 17 additions & 9 deletions src/mmr.rs
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,22 @@ impl<T, M, S> MMR<T, M, S> {
merge: PhantomData,
}
}

pub fn mmr_size(&self) -> u64 {
self.mmr_size
}

pub fn is_empty(&self) -> bool {
self.mmr_size == 0
}

pub fn batch(&self) -> &MMRBatch<T, S> {
&self.batch
}

pub fn store(&self) -> &S {
self.batch.store()
}
}

impl<T: Clone + PartialEq, M: Merge<Item = T>, S: MMRStoreReadOps<T>> MMR<T, M, S> {
Expand All @@ -42,14 +58,6 @@ impl<T: Clone + PartialEq, M: Merge<Item = T>, S: MMRStoreReadOps<T>> MMR<T, M,
Ok(Cow::Owned(elem))
}

pub fn mmr_size(&self) -> u64 {
self.mmr_size
}

pub fn is_empty(&self) -> bool {
self.mmr_size == 0
}

// push a element and return position
pub fn push(&mut self, elem: T) -> Result<u64> {
let mut elems = vec![elem];
Expand Down Expand Up @@ -216,7 +224,7 @@ impl<T: Clone + PartialEq, M: Merge<Item = T>, S: MMRStoreReadOps<T>> MMR<T, M,
}

impl<T, M, S: MMRStoreWriteOps<T>> MMR<T, M, S> {
pub fn commit(self) -> Result<()> {
pub fn commit(&mut self) -> Result<()> {
self.batch.commit()
}
}
Expand Down
14 changes: 7 additions & 7 deletions src/mmr_store.rs
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,10 @@ impl<Elem, Store> MMRBatch<Elem, Store> {
pub fn append(&mut self, pos: u64, elems: Vec<Elem>) {
self.memory_batch.push((pos, elems));
}

pub fn store(&self) -> &Store {
&self.store
}
}

impl<Elem: Clone, Store: MMRStoreReadOps<Elem>> MMRBatch<Elem, Store> {
Expand All @@ -35,13 +39,9 @@ impl<Elem: Clone, Store: MMRStoreReadOps<Elem>> MMRBatch<Elem, Store> {
}

impl<Elem, Store: MMRStoreWriteOps<Elem>> MMRBatch<Elem, Store> {
pub fn commit(self) -> Result<()> {
let Self {
mut store,
memory_batch,
} = self;
for (pos, elems) in memory_batch {
store.append(pos, elems)?;
pub fn commit(&mut self) -> Result<()> {
for (pos, elems) in self.memory_batch.drain(..) {
self.store.append(pos, elems)?;
}
Ok(())
}
Expand Down
2 changes: 1 addition & 1 deletion src/tests/test_accumulate_headers.rs
Original file line number Diff line number Diff line change
Expand Up @@ -106,7 +106,7 @@ impl Prover {
let mut mmr = MMR::<_, MergeHashWithTD, _>::new(self.positions.len() as u64, &self.store);
// get previous element
let mut previous = if let Some(pos) = self.positions.last() {
MMRStoreReadOps::<_>::get_elem(&&self.store, *pos)?.expect("exists")
mmr.store().get_elem(*pos)?.expect("exists")
} else {
let genesis = Header::default();

Expand Down
22 changes: 13 additions & 9 deletions src/tests/test_mmr.rs
Original file line number Diff line number Diff line change
@@ -1,14 +1,17 @@
use super::{MergeNumberHash, NumberHash};
use crate::{
helper::pos_height_in_tree, leaf_index_to_mmr_size, util::MemStore, Error, MMRStoreReadOps, MMR,
helper::pos_height_in_tree,
leaf_index_to_mmr_size,
util::{MemMMR, MemStore},
Error,
};
use faster_hex::hex_string;
use proptest::prelude::*;
use rand::{seq::SliceRandom, thread_rng};

fn test_mmr(count: u32, proof_elem: Vec<u32>) {
let store = MemStore::default();
let mut mmr = MMR::<_, MergeNumberHash, _>::new(0, &store);
let mut mmr = MemMMR::<_, MergeNumberHash>::new(0, &store);
let positions: Vec<u64> = (0u32..count)
.map(|i| mmr.push(NumberHash::from(i)).unwrap())
.collect();
Expand Down Expand Up @@ -36,7 +39,7 @@ fn test_mmr(count: u32, proof_elem: Vec<u32>) {

fn test_gen_new_root_from_proof(count: u32) {
let store = MemStore::default();
let mut mmr = MMR::<_, MergeNumberHash, _>::new(0, &store);
let mut mmr = MemMMR::<_, MergeNumberHash>::new(0, &store);
let positions: Vec<u64> = (0u32..count)
.map(|i| mmr.push(NumberHash::from(i)).unwrap())
.collect();
Expand All @@ -61,7 +64,7 @@ fn test_gen_new_root_from_proof(count: u32) {
#[test]
fn test_mmr_root() {
let store = MemStore::default();
let mut mmr = MMR::<_, MergeNumberHash, _>::new(0, &store);
let mut mmr = MemMMR::<_, MergeNumberHash>::new(0, &store);
(0u32..11).for_each(|i| {
mmr.push(NumberHash::from(i)).unwrap();
});
Expand All @@ -76,7 +79,7 @@ fn test_mmr_root() {
#[test]
fn test_empty_mmr_root() {
let store = MemStore::<NumberHash>::default();
let mmr = MMR::<_, MergeNumberHash, _>::new(0, &store);
let mmr = MemMMR::<_, MergeNumberHash>::new(0, &store);
assert_eq!(Err(Error::GetRootOnEmpty), mmr.get_root());
}

Expand Down Expand Up @@ -155,7 +158,7 @@ fn test_invalid_proof_verification(
// optionally handroll proof from these positions
handrolled_proof_positions: Option<Vec<u64>>,
) {
use crate::{util::MemMMR, Merge, MerkleProof};
use crate::{Merge, MerkleProof};
use std::fmt::{Debug, Formatter};

// Simple item struct to allow debugging the contents of MMR nodes/peaks
Expand Down Expand Up @@ -184,7 +187,8 @@ fn test_invalid_proof_verification(
}
}

let mut mmr: MemMMR<MyItem, MyMerge> = MemMMR::default();
let store = MemStore::default();
let mut mmr = MemMMR::<_, MyMerge>::new(0, &store);
let mut positions: Vec<u64> = Vec::new();
for i in 0u32..leaf_count {
let pos = mmr.push(MyItem::Number(i)).unwrap();
Expand All @@ -194,7 +198,7 @@ fn test_invalid_proof_verification(

let entries_to_verify: Vec<(u64, MyItem)> = positions_to_verify
.iter()
.map(|pos| (*pos, mmr.store().get_elem(*pos).unwrap().unwrap()))
.map(|pos| (*pos, mmr.batch().get_elem(*pos).unwrap().unwrap()))
.collect();

let mut tampered_entries_to_verify = entries_to_verify.clone();
Expand All @@ -211,7 +215,7 @@ fn test_invalid_proof_verification(
mmr.mmr_size(),
handrolled_proof_positions
.iter()
.map(|pos| mmr.store().get_elem(*pos).unwrap().unwrap())
.map(|pos| mmr.batch().get_elem(*pos).unwrap().unwrap())
.collect(),
)
});
Expand Down
51 changes: 2 additions & 49 deletions src/util.rs
Original file line number Diff line number Diff line change
@@ -1,7 +1,6 @@
use crate::collections::BTreeMap;
use crate::{vec::Vec, MMRStoreReadOps, MMRStoreWriteOps, Merge, MerkleProof, Result, MMR};
use crate::{vec::Vec, MMRStoreReadOps, MMRStoreWriteOps, Result, MMR};
use core::cell::RefCell;
use core::marker::PhantomData;

#[derive(Clone)]
pub struct MemStore<T>(RefCell<BTreeMap<u64, T>>);
Expand Down Expand Up @@ -34,50 +33,4 @@ impl<T> MMRStoreWriteOps<T> for &MemStore<T> {
}
}

pub struct MemMMR<T, M> {
store: MemStore<T>,
mmr_size: u64,
merge: PhantomData<M>,
}

impl<T: Clone + PartialEq, M: Merge<Item = T>> Default for MemMMR<T, M> {
fn default() -> Self {
Self::new(0, Default::default())
}
}

impl<T: Clone + PartialEq, M: Merge<Item = T>> MemMMR<T, M> {
pub fn new(mmr_size: u64, store: MemStore<T>) -> Self {
MemMMR {
mmr_size,
store,
merge: PhantomData,
}
}

pub fn store(&self) -> &MemStore<T> {
&self.store
}

pub fn mmr_size(&self) -> u64 {
self.mmr_size
}

pub fn get_root(&self) -> Result<T> {
let mmr = MMR::<T, M, &MemStore<T>>::new(self.mmr_size, &self.store);
mmr.get_root()
}

pub fn push(&mut self, elem: T) -> Result<u64> {
let mut mmr = MMR::<T, M, &MemStore<T>>::new(self.mmr_size, &self.store);
let pos = mmr.push(elem)?;
self.mmr_size = mmr.mmr_size();
mmr.commit()?;
Ok(pos)
}

pub fn gen_proof(&self, pos_list: Vec<u64>) -> Result<MerkleProof<T, M>> {
let mmr = MMR::<T, M, &MemStore<T>>::new(self.mmr_size, &self.store);
mmr.gen_proof(pos_list)
}
}
pub type MemMMR<'a, T, M> = MMR<T, M, &'a MemStore<T>>;

0 comments on commit cb54163

Please sign in to comment.