diff --git a/contracts/src/libs/MerkleTree.sol b/contracts/src/libs/MerkleTree.sol index ffcf89c5..a78756f8 100644 --- a/contracts/src/libs/MerkleTree.sol +++ b/contracts/src/libs/MerkleTree.sol @@ -30,15 +30,10 @@ library MerkleTree { function setup(Tree storage self) internal returns (bytes32 initialRoot) { initialRoot = SHA256.EMPTY_HASH; - // Store depth in the dynamic array - Arrays.unsafeSetLength(self._zeros, 256); - // Build each root of zero-filled subtrees bytes32 currentZero = SHA256.EMPTY_HASH; - for (uint256 i = 0; i < 256; ++i) { - Arrays.unsafeAccess(self._zeros, i).value = currentZero; - currentZero = SHA256.hash(currentZero, currentZero); - } + self._zeros.push(currentZero); + self._sides.push(currentZero); self._nextLeafIndex = 0; } @@ -47,51 +42,46 @@ library MerkleTree { /// @param self The tree data structure. /// @param leaf The leaf to add. /// @return index The index of the leaf. - /// @return newRoot The new root of the tree. - function push(Tree storage self, bytes32 leaf) internal returns (uint256 index, bytes32 newRoot) { - // Cache the tree depth read. - uint256 treeDepth = depth(self); - - // Get the next leaf index and increment it after assignment. - index = self._nextLeafIndex++; - - // Rebuild the branch from leaf to root. - uint256 currentIndex = index; - bytes32 currentLevelHash = leaf; - for (uint256 i = 0; i < treeDepth; ++i) { - // Compute the next level hash for depth `i+1`. - // Check whether the `currentIndex` node is the left or right child of its parent. - if (isLeftChild(currentIndex)) { - // Store the current hash as the sibling (side) for the current level. - Arrays.unsafeAccess(self._sides, i).value = currentLevelHash; - - // Compute the current level hash using the right sibling, which is the zero hash of this level. - currentLevelHash = SHA256.hash(currentLevelHash, Arrays.unsafeAccess(self._zeros, i).value); + /// @return accumulatorNode The new root of the tree. + function push(Tree storage self, bytes32 leaf) internal returns (uint256 index, bytes32 accumulatorNode) { + // If the capacity of the current Merkle tree is exhausted, then expand it + if (self._nextLeafIndex != 0 && (self._nextLeafIndex & (self._nextLeafIndex - 1)) == 0) { + // Compute the next zero for the next level. + bytes32 currentZero = Arrays.unsafeAccess(self._zeros, self._zeros.length - 1).value; + bytes32 nextZero = SHA256.hash(currentZero, currentZero); + self._zeros.push(nextZero); + self._sides.push(nextZero); + } + uint256 height = 0; + bytes32 replacementNode = leaf; + // Propagate a hash update up the Merkle tree until there's space + for (; self._nextLeafIndex & (1 << height) != 0; height++) { + // Compute the replacement of the parent node + replacementNode = SHA256.hash(Arrays.unsafeAccess(self._sides, height).value, replacementNode); + } + accumulatorNode = replacementNode; + // Record where we are going to insert the new node + uint256 insertHeight = height; + // Now let's compute the new root hash starting from the replacement node + for (; height < self._zeros.length - 1; height++) { + if (self._nextLeafIndex & (1 << height) == 0) { + // If no partial tree at current level, then right-pad the accumulator + accumulatorNode = SHA256.hash(accumulatorNode, Arrays.unsafeAccess(self._zeros, height).value); } else { - // Compute the current level hash using the left sibling (side). - currentLevelHash = SHA256.hash(Arrays.unsafeAccess(self._sides, i).value, currentLevelHash); + // If there's a partial tree, then combine it with the accumulator + accumulatorNode = SHA256.hash(Arrays.unsafeAccess(self._sides, height).value, accumulatorNode); } - - currentIndex >>= 1; } - - // Expand the tree if the capacity is reached. - if (self._nextLeafIndex == capacity(self)) { - // Store the current level hash as the sibling (side) for the current level. - self._sides.push(currentLevelHash); - - // Compute the new current level hash. - currentLevelHash = SHA256.hash(currentLevelHash, Arrays.unsafeAccess(self._zeros, treeDepth).value); - } - - newRoot = currentLevelHash; + // Finish off the propagation with a final assignment + Arrays.unsafeAccess(self._sides, insertHeight).value = replacementNode; + index = self._nextLeafIndex++; } /// @notice Returns the tree depth. /// @param self The tree data structure. /// @return treeDepth The depth of the tree. function depth(Tree storage self) internal view returns (uint8 treeDepth) { - treeDepth = uint8(self._sides.length); + treeDepth = uint8(self._sides.length - 1); } /// @notice Returns the number of leaves that have been added to the tree. diff --git a/contracts/test/examples/MerkleTree.e.sol b/contracts/test/examples/MerkleTree.e.sol index e68eeb6e..c3511e19 100644 --- a/contracts/test/examples/MerkleTree.e.sol +++ b/contracts/test/examples/MerkleTree.e.sol @@ -51,9 +51,8 @@ contract MerkleTreeExample { // State 1 { - _leaves[1] = new bytes32[](2); + _leaves[1] = new bytes32[](1); _leaves[1][0] = bytes32(uint256(1)); - _leaves[1][1] = SHA256.EMPTY_HASH; _heightOneNodes[1] = _calculateNextLevel(_leaves[1]); @@ -61,19 +60,14 @@ contract MerkleTreeExample { _roots[1] = _leaves[1].computeRoot(); - _siblings[1] = new bytes32[][](1); - - _siblings[1][0] = new bytes32[](1); - _siblings[1][0][0] = _leaves[1][1]; + _siblings[1] = new bytes32[][](0); } // State 2 { - _leaves[2] = new bytes32[](4); + _leaves[2] = new bytes32[](2); _leaves[2][0] = _leaves[1][0]; _leaves[2][1] = bytes32(uint256(2)); - _leaves[2][2] = SHA256.EMPTY_HASH; - _leaves[2][3] = SHA256.EMPTY_HASH; _heightOneNodes[2] = _calculateNextLevel(_leaves[2]); @@ -83,13 +77,11 @@ contract MerkleTreeExample { _siblings[2] = new bytes32[][](2); - _siblings[2][0] = new bytes32[](2); + _siblings[2][0] = new bytes32[](1); _siblings[2][0][0] = _leaves[2][1]; - _siblings[2][0][1] = _heightOneNodes[2][1]; - _siblings[2][1] = new bytes32[](2); + _siblings[2][1] = new bytes32[](1); _siblings[2][1][0] = _leaves[2][0]; - _siblings[2][1][1] = _heightOneNodes[2][1]; } // State 3 @@ -108,7 +100,7 @@ contract MerkleTreeExample { _siblings[3] = new bytes32[][](3); - _siblings[3][0] = new bytes32[](3); + _siblings[3][0] = new bytes32[](2); _siblings[3][0][0] = _leaves[3][1]; _siblings[3][0][1] = _heightOneNodes[3][1]; @@ -123,18 +115,11 @@ contract MerkleTreeExample { // State 4 { - _leaves[4] = new bytes32[](8); + _leaves[4] = new bytes32[](4); _leaves[4][0] = _leaves[3][0]; _leaves[4][1] = _leaves[3][1]; _leaves[4][2] = _leaves[3][2]; _leaves[4][3] = bytes32(uint256(4)); - _leaves[4][4] = SHA256.EMPTY_HASH; - - _leaves[4][5] = SHA256.EMPTY_HASH; - - _leaves[4][6] = SHA256.EMPTY_HASH; - - _leaves[4][7] = SHA256.EMPTY_HASH; _heightOneNodes[4] = _calculateNextLevel(_leaves[4]); @@ -144,25 +129,21 @@ contract MerkleTreeExample { _siblings[4] = new bytes32[][](4); - _siblings[4][0] = new bytes32[](3); + _siblings[4][0] = new bytes32[](2); _siblings[4][0][0] = _leaves[4][1]; _siblings[4][0][1] = _heightOneNodes[4][1]; - _siblings[4][0][2] = _heightTwoNodes[4][1]; - _siblings[4][1] = new bytes32[](3); + _siblings[4][1] = new bytes32[](2); _siblings[4][1][0] = _leaves[4][0]; _siblings[4][1][1] = _heightOneNodes[4][1]; - _siblings[4][1][2] = _heightTwoNodes[4][1]; - _siblings[4][2] = new bytes32[](3); + _siblings[4][2] = new bytes32[](2); _siblings[4][2][0] = _leaves[4][3]; _siblings[4][2][1] = _heightOneNodes[4][0]; - _siblings[4][2][2] = _heightTwoNodes[4][1]; - _siblings[4][3] = new bytes32[](3); + _siblings[4][3] = new bytes32[](2); _siblings[4][3][0] = _leaves[4][2]; _siblings[4][3][1] = _heightOneNodes[4][0]; - _siblings[4][3][2] = _heightTwoNodes[4][1]; } // State 5 @@ -306,7 +287,7 @@ contract MerkleTreeExample { _siblings[7][3][2] = _heightTwoNodes[7][1]; _siblings[7][4] = new bytes32[](3); - _siblings[7][4][0] = _leaves[7][4]; + _siblings[7][4][0] = _leaves[7][5]; _siblings[7][4][1] = _heightOneNodes[7][3]; _siblings[7][4][2] = _heightTwoNodes[7][0]; diff --git a/contracts/test/state/CommitmentTree.t.sol b/contracts/test/state/CommitmentTree.t.sol index 2bec4df1..b3cfd585 100644 --- a/contracts/test/state/CommitmentTree.t.sol +++ b/contracts/test/state/CommitmentTree.t.sol @@ -80,10 +80,10 @@ contract CommitmentTreeTest is Test, MerkleTreeExample { function test_should_produce_an_invalid_root_for_a_non_existent_leaf() public { bytes32 nonExistentCommitment = sha256("NON_EXISTENT"); - for (uint256 i = 0; i < _N_LEAVES; ++i) { + for (uint256 i = 1; i < _N_LEAVES; ++i) { bytes32 root = _cmAcc.addCommitment(_leaves[i + 1][i]); - for (uint256 j = 0; j <= i; ++j) { + for (uint256 j = 0; j < i; ++j) { bytes32 computedRoot = MerkleTree.processProof({ siblings: _siblings[i + 1][j], directionBits: _directionBits[_cmAcc.commitmentTreeCapacity()][j], @@ -118,7 +118,9 @@ contract CommitmentTreeTest is Test, MerkleTreeExample { */ bytes32 commitment = bytes32(uint256(1)); - bytes32 newRoot = _cmAcc.addCommitment(commitment); + _cmAcc.addCommitment(commitment); + bytes32 existingCommitment = bytes32(uint256(3)); + bytes32 newRoot = _cmAcc.addCommitment(existingCommitment); _cmAcc.storeCommitmentTreeRoot(newRoot); bytes32 nonExistingCommitment = bytes32(uint256(2)); @@ -153,6 +155,8 @@ contract CommitmentTreeTest is Test, MerkleTreeExample { function test_verifyMerkleProof_reverts_on_wrong_path() public { bytes32 commitment = sha256("SOMETHING"); + bytes32 commitment2 = sha256("ELSE"); + _cmAcc.addCommitment(commitment2); bytes32 newRoot = _cmAcc.addCommitment(commitment); _cmAcc.storeCommitmentTreeRoot(newRoot); diff --git a/contracts/test/state/MerkleTree.t.sol b/contracts/test/state/MerkleTree.t.sol index 3c8336b9..014d3854 100644 --- a/contracts/test/state/MerkleTree.t.sol +++ b/contracts/test/state/MerkleTree.t.sol @@ -27,11 +27,11 @@ contract MerkleTreeTest is Test, MerkleTreeExample { _merkleTree.push(_leaves[7][0]); assertEq(_merkleTree.leafCount(), 1); - assertEq(_merkleTree.depth(), 1); + assertEq(_merkleTree.depth(), 0); _merkleTree.push(_leaves[7][1]); assertEq(_merkleTree.leafCount(), 2); - assertEq(_merkleTree.depth(), 2); + assertEq(_merkleTree.depth(), 1); _merkleTree.push(_leaves[7][2]); assertEq(_merkleTree.leafCount(), 3); @@ -39,7 +39,7 @@ contract MerkleTreeTest is Test, MerkleTreeExample { _merkleTree.push(_leaves[7][3]); assertEq(_merkleTree.leafCount(), 4); - assertEq(_merkleTree.depth(), 3); + assertEq(_merkleTree.depth(), 2); _merkleTree.push(_leaves[7][4]); assertEq(_merkleTree.leafCount(), 5); @@ -62,7 +62,7 @@ contract MerkleTreeTest is Test, MerkleTreeExample { // First compute what tree depth is required to store leaves uint8 treeDepth = 0; // Essentially compute the logarithm of the leaf count base 2 - for (uint256 i = leaves.length; i > 0; i >>= 1) { + for (uint256 i = 1; i < leaves.length; i <<= 1) { treeDepth++; } // Set up a protocol adapter Merkle tree and an OpenZeppelin one