Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
76 changes: 33 additions & 43 deletions contracts/src/libs/MerkleTree.sol
Original file line number Diff line number Diff line change
Expand Up @@ -30,15 +30,10 @@ library MerkleTree {
function setup(Tree storage self) internal returns (bytes32 initialRoot) {
initialRoot = SHA256.EMPTY_HASH;

// Store depth in the dynamic array
Arrays.unsafeSetLength(self._zeros, 256);

// Build each root of zero-filled subtrees
bytes32 currentZero = SHA256.EMPTY_HASH;
for (uint256 i = 0; i < 256; ++i) {
Arrays.unsafeAccess(self._zeros, i).value = currentZero;
currentZero = SHA256.hash(currentZero, currentZero);
}
self._zeros.push(currentZero);
self._sides.push(currentZero);

self._nextLeafIndex = 0;
}
Expand All @@ -47,51 +42,46 @@ library MerkleTree {
/// @param self The tree data structure.
/// @param leaf The leaf to add.
/// @return index The index of the leaf.
/// @return newRoot The new root of the tree.
function push(Tree storage self, bytes32 leaf) internal returns (uint256 index, bytes32 newRoot) {
// Cache the tree depth read.
uint256 treeDepth = depth(self);

// Get the next leaf index and increment it after assignment.
index = self._nextLeafIndex++;

// Rebuild the branch from leaf to root.
uint256 currentIndex = index;
bytes32 currentLevelHash = leaf;
for (uint256 i = 0; i < treeDepth; ++i) {
// Compute the next level hash for depth `i+1`.
// Check whether the `currentIndex` node is the left or right child of its parent.
if (isLeftChild(currentIndex)) {
// Store the current hash as the sibling (side) for the current level.
Arrays.unsafeAccess(self._sides, i).value = currentLevelHash;

// Compute the current level hash using the right sibling, which is the zero hash of this level.
currentLevelHash = SHA256.hash(currentLevelHash, Arrays.unsafeAccess(self._zeros, i).value);
/// @return accumulatorNode The new root of the tree.
function push(Tree storage self, bytes32 leaf) internal returns (uint256 index, bytes32 accumulatorNode) {
// If the capacity of the current Merkle tree is exhausted, then expand it
if (self._nextLeafIndex != 0 && (self._nextLeafIndex & (self._nextLeafIndex - 1)) == 0) {
// Compute the next zero for the next level.
bytes32 currentZero = Arrays.unsafeAccess(self._zeros, self._zeros.length - 1).value;
bytes32 nextZero = SHA256.hash(currentZero, currentZero);
self._zeros.push(nextZero);
self._sides.push(nextZero);
}
uint256 height = 0;
bytes32 replacementNode = leaf;
// Propagate a hash update up the Merkle tree until there's space
for (; self._nextLeafIndex & (1 << height) != 0; height++) {
// Compute the replacement of the parent node
replacementNode = SHA256.hash(Arrays.unsafeAccess(self._sides, height).value, replacementNode);
}
accumulatorNode = replacementNode;
// Record where we are going to insert the new node
uint256 insertHeight = height;
// Now let's compute the new root hash starting from the replacement node
for (; height < self._zeros.length - 1; height++) {
if (self._nextLeafIndex & (1 << height) == 0) {
// If no partial tree at current level, then right-pad the accumulator
accumulatorNode = SHA256.hash(accumulatorNode, Arrays.unsafeAccess(self._zeros, height).value);
} else {
// Compute the current level hash using the left sibling (side).
currentLevelHash = SHA256.hash(Arrays.unsafeAccess(self._sides, i).value, currentLevelHash);
// If there's a partial tree, then combine it with the accumulator
accumulatorNode = SHA256.hash(Arrays.unsafeAccess(self._sides, height).value, accumulatorNode);
}

currentIndex >>= 1;
}

// Expand the tree if the capacity is reached.
if (self._nextLeafIndex == capacity(self)) {
// Store the current level hash as the sibling (side) for the current level.
self._sides.push(currentLevelHash);

// Compute the new current level hash.
currentLevelHash = SHA256.hash(currentLevelHash, Arrays.unsafeAccess(self._zeros, treeDepth).value);
}

newRoot = currentLevelHash;
// Finish off the propagation with a final assignment
Arrays.unsafeAccess(self._sides, insertHeight).value = replacementNode;
index = self._nextLeafIndex++;
}

/// @notice Returns the tree depth.
/// @param self The tree data structure.
/// @return treeDepth The depth of the tree.
function depth(Tree storage self) internal view returns (uint8 treeDepth) {
treeDepth = uint8(self._sides.length);
treeDepth = uint8(self._sides.length - 1);
}

/// @notice Returns the number of leaves that have been added to the tree.
Expand Down
43 changes: 12 additions & 31 deletions contracts/test/examples/MerkleTree.e.sol
Original file line number Diff line number Diff line change
Expand Up @@ -51,29 +51,23 @@ contract MerkleTreeExample {

// State 1
{
_leaves[1] = new bytes32[](2);
_leaves[1] = new bytes32[](1);
_leaves[1][0] = bytes32(uint256(1));
_leaves[1][1] = SHA256.EMPTY_HASH;

_heightOneNodes[1] = _calculateNextLevel(_leaves[1]);

_heightTwoNodes[1] = _calculateNextLevel(_heightOneNodes[1]);

_roots[1] = _leaves[1].computeRoot();

_siblings[1] = new bytes32[][](1);

_siblings[1][0] = new bytes32[](1);
_siblings[1][0][0] = _leaves[1][1];
_siblings[1] = new bytes32[][](0);
}

// State 2
{
_leaves[2] = new bytes32[](4);
_leaves[2] = new bytes32[](2);
_leaves[2][0] = _leaves[1][0];
_leaves[2][1] = bytes32(uint256(2));
_leaves[2][2] = SHA256.EMPTY_HASH;
_leaves[2][3] = SHA256.EMPTY_HASH;

_heightOneNodes[2] = _calculateNextLevel(_leaves[2]);

Expand All @@ -83,13 +77,11 @@ contract MerkleTreeExample {

_siblings[2] = new bytes32[][](2);

_siblings[2][0] = new bytes32[](2);
_siblings[2][0] = new bytes32[](1);
_siblings[2][0][0] = _leaves[2][1];
_siblings[2][0][1] = _heightOneNodes[2][1];

_siblings[2][1] = new bytes32[](2);
_siblings[2][1] = new bytes32[](1);
_siblings[2][1][0] = _leaves[2][0];
_siblings[2][1][1] = _heightOneNodes[2][1];
}

// State 3
Expand All @@ -108,7 +100,7 @@ contract MerkleTreeExample {

_siblings[3] = new bytes32[][](3);

_siblings[3][0] = new bytes32[](3);
_siblings[3][0] = new bytes32[](2);
_siblings[3][0][0] = _leaves[3][1];
_siblings[3][0][1] = _heightOneNodes[3][1];

Expand All @@ -123,18 +115,11 @@ contract MerkleTreeExample {

// State 4
{
_leaves[4] = new bytes32[](8);
_leaves[4] = new bytes32[](4);
_leaves[4][0] = _leaves[3][0];
_leaves[4][1] = _leaves[3][1];
_leaves[4][2] = _leaves[3][2];
_leaves[4][3] = bytes32(uint256(4));
_leaves[4][4] = SHA256.EMPTY_HASH;

_leaves[4][5] = SHA256.EMPTY_HASH;

_leaves[4][6] = SHA256.EMPTY_HASH;

_leaves[4][7] = SHA256.EMPTY_HASH;

_heightOneNodes[4] = _calculateNextLevel(_leaves[4]);

Expand All @@ -144,25 +129,21 @@ contract MerkleTreeExample {

_siblings[4] = new bytes32[][](4);

_siblings[4][0] = new bytes32[](3);
_siblings[4][0] = new bytes32[](2);
_siblings[4][0][0] = _leaves[4][1];
_siblings[4][0][1] = _heightOneNodes[4][1];
_siblings[4][0][2] = _heightTwoNodes[4][1];

_siblings[4][1] = new bytes32[](3);
_siblings[4][1] = new bytes32[](2);
_siblings[4][1][0] = _leaves[4][0];
_siblings[4][1][1] = _heightOneNodes[4][1];
_siblings[4][1][2] = _heightTwoNodes[4][1];

_siblings[4][2] = new bytes32[](3);
_siblings[4][2] = new bytes32[](2);
_siblings[4][2][0] = _leaves[4][3];
_siblings[4][2][1] = _heightOneNodes[4][0];
_siblings[4][2][2] = _heightTwoNodes[4][1];

_siblings[4][3] = new bytes32[](3);
_siblings[4][3] = new bytes32[](2);
_siblings[4][3][0] = _leaves[4][2];
_siblings[4][3][1] = _heightOneNodes[4][0];
_siblings[4][3][2] = _heightTwoNodes[4][1];
}

// State 5
Expand Down Expand Up @@ -306,7 +287,7 @@ contract MerkleTreeExample {
_siblings[7][3][2] = _heightTwoNodes[7][1];

_siblings[7][4] = new bytes32[](3);
_siblings[7][4][0] = _leaves[7][4];
_siblings[7][4][0] = _leaves[7][5];
_siblings[7][4][1] = _heightOneNodes[7][3];
_siblings[7][4][2] = _heightTwoNodes[7][0];

Expand Down
10 changes: 7 additions & 3 deletions contracts/test/state/CommitmentTree.t.sol
Original file line number Diff line number Diff line change
Expand Up @@ -80,10 +80,10 @@ contract CommitmentTreeTest is Test, MerkleTreeExample {
function test_should_produce_an_invalid_root_for_a_non_existent_leaf() public {
bytes32 nonExistentCommitment = sha256("NON_EXISTENT");

for (uint256 i = 0; i < _N_LEAVES; ++i) {
for (uint256 i = 1; i < _N_LEAVES; ++i) {
bytes32 root = _cmAcc.addCommitment(_leaves[i + 1][i]);

for (uint256 j = 0; j <= i; ++j) {
for (uint256 j = 0; j < i; ++j) {
bytes32 computedRoot = MerkleTree.processProof({
siblings: _siblings[i + 1][j],
directionBits: _directionBits[_cmAcc.commitmentTreeCapacity()][j],
Expand Down Expand Up @@ -118,7 +118,9 @@ contract CommitmentTreeTest is Test, MerkleTreeExample {
*/

bytes32 commitment = bytes32(uint256(1));
bytes32 newRoot = _cmAcc.addCommitment(commitment);
_cmAcc.addCommitment(commitment);
bytes32 existingCommitment = bytes32(uint256(3));
bytes32 newRoot = _cmAcc.addCommitment(existingCommitment);
_cmAcc.storeCommitmentTreeRoot(newRoot);

bytes32 nonExistingCommitment = bytes32(uint256(2));
Expand Down Expand Up @@ -153,6 +155,8 @@ contract CommitmentTreeTest is Test, MerkleTreeExample {

function test_verifyMerkleProof_reverts_on_wrong_path() public {
bytes32 commitment = sha256("SOMETHING");
bytes32 commitment2 = sha256("ELSE");
_cmAcc.addCommitment(commitment2);
bytes32 newRoot = _cmAcc.addCommitment(commitment);
_cmAcc.storeCommitmentTreeRoot(newRoot);

Expand Down
8 changes: 4 additions & 4 deletions contracts/test/state/MerkleTree.t.sol
Original file line number Diff line number Diff line change
Expand Up @@ -27,19 +27,19 @@ contract MerkleTreeTest is Test, MerkleTreeExample {

_merkleTree.push(_leaves[7][0]);
assertEq(_merkleTree.leafCount(), 1);
assertEq(_merkleTree.depth(), 1);
assertEq(_merkleTree.depth(), 0);

_merkleTree.push(_leaves[7][1]);
assertEq(_merkleTree.leafCount(), 2);
assertEq(_merkleTree.depth(), 2);
assertEq(_merkleTree.depth(), 1);

_merkleTree.push(_leaves[7][2]);
assertEq(_merkleTree.leafCount(), 3);
assertEq(_merkleTree.depth(), 2);

_merkleTree.push(_leaves[7][3]);
assertEq(_merkleTree.leafCount(), 4);
assertEq(_merkleTree.depth(), 3);
assertEq(_merkleTree.depth(), 2);

_merkleTree.push(_leaves[7][4]);
assertEq(_merkleTree.leafCount(), 5);
Expand All @@ -62,7 +62,7 @@ contract MerkleTreeTest is Test, MerkleTreeExample {
// First compute what tree depth is required to store leaves
uint8 treeDepth = 0;
// Essentially compute the logarithm of the leaf count base 2
for (uint256 i = leaves.length; i > 0; i >>= 1) {
for (uint256 i = 1; i < leaves.length; i <<= 1) {
treeDepth++;
}
// Set up a protocol adapter Merkle tree and an OpenZeppelin one
Expand Down
Loading