diff --git a/core/trie/key.go b/core/trie/key.go index 618698c34f..dc946a5ef6 100644 --- a/core/trie/key.go +++ b/core/trie/key.go @@ -29,7 +29,7 @@ func (k *Key) SubKey(n uint8) *Key { } newKey := &Key{len: n} - copy(newKey.bitset[:], k.bitset[len(k.bitset)-int((k.len+7)/8):]) + copy(newKey.bitset[:], k.bitset[len(k.bitset)-int((k.len+7)/8):]) //nolint:gomnd // Shift right by the number of bits that are not needed shift := k.len - n diff --git a/core/trie/proof.go b/core/trie/proof.go index 026f236b43..854afc394f 100644 --- a/core/trie/proof.go +++ b/core/trie/proof.go @@ -20,8 +20,8 @@ func (pn *ProofNode) Hash() *felt.Felt { case pn.Binary != nil: return crypto.Pedersen(pn.Binary.LeftHash, pn.Binary.RightHash) case pn.Edge != nil: - length := make([]byte, 32) - length[31] = pn.Edge.Path.len + length := make([]byte, len(pn.Edge.Path.bitset)) + length[len(pn.Edge.Path.bitset)-1] = pn.Edge.Path.len pathFelt := pn.Edge.Path.Felt() lengthFelt := new(felt.Felt).SetBytes(length) return new(felt.Felt).Add(crypto.Pedersen(pn.Edge.Child, &pathFelt), lengthFelt) @@ -31,7 +31,6 @@ func (pn *ProofNode) Hash() *felt.Felt { } func (pn *ProofNode) PrettyPrint() { - if pn.Binary != nil { fmt.Printf(" Binary:\n") fmt.Printf(" LeftHash: %v\n", pn.Binary.LeftHash) @@ -56,6 +55,37 @@ type Edge struct { Value *felt.Felt } +func isEdge(sNode storageNode, nodeNumFromRoot int) bool { + sNodeLen := sNode.key.len + leftKey := sNode.node.Left.len + rightKey := sNode.node.Right.len + if nodeNumFromRoot == 0 { // Is root + return sNodeLen != 1 + } + if (leftKey-sNodeLen > 1) || (rightKey-sNodeLen > 1) { + return true + } + return false +} + +// The binary node uses the hash of children. If the child is an edge, we first need to represent it +// as an edge node, and then take its hash. +func getChildHash(tri *Trie, sNode storageNode, childKey *Key, nodeNumFromRoot int) (*felt.Felt, error) { + childNode, err := tri.GetNodeFromKey(childKey) + if err != nil { + return nil, err + } + leftIsEdgeBool := isEdge(storageNode{node: childNode, key: childKey}, nodeNumFromRoot) + if leftIsEdgeBool { + leftEdge := ProofNode{Edge: &Edge{ + Path: sNode.node.Left, + Child: childNode.Value, + }} + return leftEdge.Hash(), nil + } + return childNode.Value, nil +} + // https://github.com/eqlabs/pathfinder/blob/main/crates/merkle-tree/src/tree.rs#L514 func GetProof(leaf *felt.Felt, tri *Trie) ([]ProofNode, error) { leafKey := tri.feltToKey(leaf) @@ -66,17 +96,6 @@ func GetProof(leaf *felt.Felt, tri *Trie) ([]ProofNode, error) { nodesExcludingLeaf := nodesToLeaf[:len(nodesToLeaf)-1] proofNodes := []ProofNode{} - getValue := func(key *Key) (*felt.Felt, error) { - node, err := tri.GetNodeFromKey(key) - if err != nil { - return nil, err - } - return node.Value, nil - // return node.Hash(key, crypto.Pedersen), nil - } - - height := uint8(0) - // 1. If it's an edge-node in pathfinders impl, we need to expand the node into an edge + binary // -> Child should be internal node (len<251). Distance between child and parent should be > 1. // 2. If it's a binary-node, we store binary @@ -86,133 +105,67 @@ func GetProof(leaf *felt.Felt, tri *Trie) ([]ProofNode, error) { // 4. If it's an edge leaf, we store an edge leaf // -> Child should be leaf (len=251). Distance between child and parent should be > 1. - // Edge nodes are defined as having a child with len greater than 1 from the parent - isEdge := func(sNode *storageNode, nodeNumFromRoot int) (bool, error) { - sNodeLen := sNode.key.len - leftKey := sNode.node.Left.len - rightKey := sNode.node.Right.len - if nodeNumFromRoot == 0 { // Is root - if sNodeLen != 1 { - return true, nil - } - return false, nil - } - if (leftKey-sNodeLen > 1) || (rightKey-sNodeLen > 1) { - return true, nil - } - return false, nil - } - for i, sNode := range nodesExcludingLeaf { - height += uint8(sNode.key.len) - leftHash, err := getValue(sNode.node.Left) + leftHash, err := getChildHash(tri, sNode, sNode.node.Left, i) if err != nil { return nil, err } - - rightHash, err := getValue(sNode.node.Right) + rightHash, err := getChildHash(tri, sNode, sNode.node.Left, i) if err != nil { return nil, err } - fmt.Println("LeftHash", leftHash.String()) - fmt.Println("rightHash", rightHash.String()) - child := nodesToLeaf[i+1] - parentChildDistance := child.key.len - sNode.key.len + childIsInternal := nodesToLeaf[i+1].key.len < 251 + isEdgeBool := isEdge(sNode, i) - isEdgeBool, err := isEdge(&sNode, i) - if err != nil { - return nil, err - } - if child.key.len < 251 && isEdgeBool { // Internal Edge + if childIsInternal && isEdgeBool { // Internal Edge + // Juno node is split into an edge + binary. proofNodes = append(proofNodes, ProofNode{ Edge: &Edge{ - Path: sNode.key, // Todo: Path from that node to the leaf? + Path: sNode.key, Child: sNode.node.Value, // Value: value, // Todo: ?? }, - }) - height += sNode.key.len - - // Todo: If the child is an edge, we need to the hash of it's edge form. Both children. - leftNode, err := tri.GetNodeFromKey(sNode.node.Left) - if err != nil { - return nil, err - } - leftIsEdgeBool, err := isEdge(&storageNode{node: leftNode, key: sNode.node.Left}, i) - if err != nil { - return nil, err - } - if leftIsEdgeBool { - leftEdge := ProofNode{Edge: &Edge{ - Path: sNode.node.Left, - Child: leftNode.Value, - }} - leftHash = leftEdge.Hash() - } - rightNode, err := tri.GetNodeFromKey(sNode.node.Right) - if err != nil { - return nil, err - } - rightIsEdgeBool, err := isEdge(&storageNode{node: rightNode, key: sNode.node.Right}, i) - if err != nil { - return nil, err - } - if rightIsEdgeBool { - rightEdge := ProofNode{Edge: &Edge{ - Path: sNode.node.Right, - Child: rightNode.Value, - }} - rightHash = rightEdge.Hash() - } + }, + ProofNode{ + Binary: &Binary{ + LeftHash: leftHash, + RightHash: rightHash, + }, + }) + } else if childIsInternal && !isEdgeBool { // Internal Binary proofNodes = append(proofNodes, ProofNode{ Binary: &Binary{ LeftHash: leftHash, RightHash: rightHash, }, }) - height++ - - } else if child.key.len < 251 && parentChildDistance == 1 { // Internal Binary + } else if !childIsInternal && isEdgeBool { // Leaf Edge proofNodes = append(proofNodes, ProofNode{ - Binary: &Binary{ - LeftHash: leftHash, - RightHash: rightHash, + Edge: &Edge{ + Child: sNode.node.Value, + // Value: value, // Todo: ?? }, }) - height++ - } else if child.key.len == 251 && parentChildDistance == 1 { // Leaf binary + } else if !childIsInternal && !isEdgeBool { // Leaf binary proofNodes = append(proofNodes, ProofNode{ Binary: &Binary{ LeftHash: leftHash, RightHash: rightHash, }, }) - height++ - } else if child.key.len == 251 && isEdgeBool { // lead Edge - proofNodes = append(proofNodes, ProofNode{ - Edge: &Edge{ - // Path: sNode.key, // Todo: Path from that node to the leaf? - Child: sNode.node.Value, - // Value: value, // Todo: ?? - }, - }) - height += sNode.key.len } else { return nil, errors.New("unexpected error in GetProof") } - } - return proofNodes, nil } // verifyProof checks if `leafPath` leads from `root` to `leafHash` along the `proofNodes` // https://github.com/eqlabs/pathfinder/blob/main/crates/merkle-tree/src/tree.rs#L2006 func VerifyProof(root *felt.Felt, key *Key, value *felt.Felt, proofs []ProofNode) bool { - - if key.Len() != 251 { + if key.Len() != key.len { return false } @@ -233,7 +186,8 @@ func VerifyProof(root *felt.Felt, key *Key, value *felt.Felt, proofs []ProofNode remainingPath.RemoveLastBit() case proofNode.Edge != nil: // The next "proofNode.Edge.len" bits must match - if !proofNode.Edge.Path.Equal(remainingPath.SubKey(proofNode.Edge.Path.Len())) { // Todo: Isn't edge.path from root? and remaining from edge to leaf?? + // Todo: Isn't edge.path from root? and remaining from edge to leaf?? + if !proofNode.Edge.Path.Equal(remainingPath.SubKey(proofNode.Edge.Path.Len())) { return false } expectedHash = proofNode.Edge.Child diff --git a/core/trie/proof_test.go b/core/trie/proof_test.go index e697628fde..88d8f44dcb 100644 --- a/core/trie/proof_test.go +++ b/core/trie/proof_test.go @@ -1,7 +1,6 @@ package trie_test import ( - "fmt" "testing" "github.com/NethermindEth/juno/core/felt" @@ -67,28 +66,6 @@ func buildSimpleDoubleBinaryTrie(t *testing.T) *trie.Trie { return tempTrie } -// func getProofNodeBinary(t *testing.T, tri *trie.Trie, node *trie.Node) trie.ProofNode { -// getHash := func(tri *trie.Trie, key *trie.Key) (*felt.Felt, error) { -// keyFelt := key.Felt() -// node2, err := tri.GetNode(&keyFelt) -// if err != nil { -// return nil, err -// } -// return node2.Hash(key, crypto.Pedersen), nil -// } - -// left, err := getHash(tri, node.Left) -// require.NoError(t, err) -// right, err := getHash(tri, node.Right) -// require.NoError(t, err) - -// return trie.ProofNode{ -// Binary: &trie.Binary{ -// LeftHash: left, RightHash: right}, -// } - -// } - func TestGetProofs(t *testing.T) { t.Run("Simple Trie - simple binary", func(t *testing.T) { tempTrie := buildSimpleTrie(t) @@ -116,20 +93,6 @@ func TestGetProofs(t *testing.T) { for _, pNode := range proofNodes { pNode.PrettyPrint() } - require.Equal(t, len(expectedProofNodes), len(proofNodes)) - for i, proof := range expectedProofNodes { - if proof.Binary != nil { - fmt.Println(proof.Binary.LeftHash.String(), expectedProofNodes[i].Binary.LeftHash.String()) - fmt.Println(proof.Binary.RightHash.String(), expectedProofNodes[i].Binary.RightHash.String()) - require.Equal(t, proof.Binary.LeftHash.String(), expectedProofNodes[i].Binary.LeftHash.String()) - require.Equal(t, proof.Binary.RightHash, expectedProofNodes[i].Binary.RightHash) - } else { - fmt.Println(proof.Edge.Child.String(), expectedProofNodes[i].Edge.Child.String()) - fmt.Println(proof.Edge.Path.String(), expectedProofNodes[i].Edge.Path.String()) - require.Equal(t, proof.Edge.Child.String(), expectedProofNodes[i].Edge.Child.String()) - require.Equal(t, proof.Edge.Path, expectedProofNodes[i].Edge.Path) - } - } require.Equal(t, expectedProofNodes, proofNodes) }) @@ -165,20 +128,6 @@ func TestGetProofs(t *testing.T) { for _, pNode := range proofNodes { pNode.PrettyPrint() } - require.Equal(t, len(expectedProofNodes), len(proofNodes)) - for i, proof := range expectedProofNodes { - if proof.Binary != nil { - fmt.Println(proof.Binary.LeftHash.String(), expectedProofNodes[i].Binary.LeftHash.String()) - fmt.Println(proof.Binary.RightHash.String(), expectedProofNodes[i].Binary.RightHash.String()) - require.Equal(t, proof.Binary.LeftHash.String(), expectedProofNodes[i].Binary.LeftHash.String()) - require.Equal(t, proof.Binary.RightHash, expectedProofNodes[i].Binary.RightHash) - } else { - fmt.Println(proof.Edge.Child.String(), expectedProofNodes[i].Edge.Child.String()) - fmt.Println(proof.Edge.Path.String(), expectedProofNodes[i].Edge.Path.String()) - require.Equal(t, proof.Edge.Child.String(), expectedProofNodes[i].Edge.Child.String()) - require.Equal(t, proof.Edge.Path, expectedProofNodes[i].Edge.Path) - } - } require.Equal(t, expectedProofNodes, proofNodes) }) @@ -214,20 +163,6 @@ func TestGetProofs(t *testing.T) { for _, pNode := range proofNodes { pNode.PrettyPrint() } - require.Equal(t, len(expectedProofNodes), len(proofNodes)) - for i, proof := range expectedProofNodes { - if proof.Binary != nil { - fmt.Println(proof.Binary.LeftHash.String(), expectedProofNodes[i].Binary.LeftHash.String()) - fmt.Println(proof.Binary.RightHash.String(), expectedProofNodes[i].Binary.RightHash.String()) - require.Equal(t, proof.Binary.LeftHash.String(), expectedProofNodes[i].Binary.LeftHash.String()) - require.Equal(t, proof.Binary.RightHash, expectedProofNodes[i].Binary.RightHash) - } else { - fmt.Println(proof.Edge.Child.String(), expectedProofNodes[i].Edge.Child.String()) - fmt.Println(proof.Edge.Path.String(), expectedProofNodes[i].Edge.Path.String()) - require.Equal(t, proof.Edge.Child.String(), expectedProofNodes[i].Edge.Child.String()) - require.Equal(t, proof.Edge.Path, expectedProofNodes[i].Edge.Path) - } - } require.Equal(t, expectedProofNodes, proofNodes) }) }