Skip to content

Commit

Permalink
tidy
Browse files Browse the repository at this point in the history
  • Loading branch information
rian committed May 7, 2024
1 parent e9d28dd commit c343913
Show file tree
Hide file tree
Showing 3 changed files with 57 additions and 168 deletions.
2 changes: 1 addition & 1 deletion core/trie/key.go
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@ func (k *Key) SubKey(n uint8) *Key {
}

newKey := &Key{len: n}
copy(newKey.bitset[:], k.bitset[len(k.bitset)-int((k.len+7)/8):])
copy(newKey.bitset[:], k.bitset[len(k.bitset)-int((k.len+7)/8):]) //nolint:gomnd

// Shift right by the number of bits that are not needed
shift := k.len - n
Expand Down
158 changes: 56 additions & 102 deletions core/trie/proof.go
Original file line number Diff line number Diff line change
Expand Up @@ -20,8 +20,8 @@ func (pn *ProofNode) Hash() *felt.Felt {
case pn.Binary != nil:
return crypto.Pedersen(pn.Binary.LeftHash, pn.Binary.RightHash)
case pn.Edge != nil:
length := make([]byte, 32)
length[31] = pn.Edge.Path.len
length := make([]byte, len(pn.Edge.Path.bitset))
length[len(pn.Edge.Path.bitset)-1] = pn.Edge.Path.len
pathFelt := pn.Edge.Path.Felt()
lengthFelt := new(felt.Felt).SetBytes(length)
return new(felt.Felt).Add(crypto.Pedersen(pn.Edge.Child, &pathFelt), lengthFelt)
Expand All @@ -31,7 +31,6 @@ func (pn *ProofNode) Hash() *felt.Felt {
}

func (pn *ProofNode) PrettyPrint() {

if pn.Binary != nil {
fmt.Printf(" Binary:\n")
fmt.Printf(" LeftHash: %v\n", pn.Binary.LeftHash)
Expand All @@ -56,6 +55,37 @@ type Edge struct {
Value *felt.Felt
}

func isEdge(sNode storageNode, nodeNumFromRoot int) bool {
sNodeLen := sNode.key.len
leftKey := sNode.node.Left.len
rightKey := sNode.node.Right.len
if nodeNumFromRoot == 0 { // Is root
return sNodeLen != 1
}
if (leftKey-sNodeLen > 1) || (rightKey-sNodeLen > 1) {
return true
}
return false
}

// The binary node uses the hash of children. If the child is an edge, we first need to represent it
// as an edge node, and then take its hash.
func getChildHash(tri *Trie, sNode storageNode, childKey *Key, nodeNumFromRoot int) (*felt.Felt, error) {
childNode, err := tri.GetNodeFromKey(childKey)
if err != nil {
return nil, err
}
leftIsEdgeBool := isEdge(storageNode{node: childNode, key: childKey}, nodeNumFromRoot)
if leftIsEdgeBool {
leftEdge := ProofNode{Edge: &Edge{
Path: sNode.node.Left,
Child: childNode.Value,
}}
return leftEdge.Hash(), nil
}
return childNode.Value, nil
}

// https://github.com/eqlabs/pathfinder/blob/main/crates/merkle-tree/src/tree.rs#L514
func GetProof(leaf *felt.Felt, tri *Trie) ([]ProofNode, error) {
leafKey := tri.feltToKey(leaf)
Expand All @@ -66,17 +96,6 @@ func GetProof(leaf *felt.Felt, tri *Trie) ([]ProofNode, error) {
nodesExcludingLeaf := nodesToLeaf[:len(nodesToLeaf)-1]
proofNodes := []ProofNode{}

getValue := func(key *Key) (*felt.Felt, error) {
node, err := tri.GetNodeFromKey(key)
if err != nil {
return nil, err
}
return node.Value, nil
// return node.Hash(key, crypto.Pedersen), nil
}

height := uint8(0)

// 1. If it's an edge-node in pathfinders impl, we need to expand the node into an edge + binary
// -> Child should be internal node (len<251). Distance between child and parent should be > 1.
// 2. If it's a binary-node, we store binary
Expand All @@ -86,133 +105,67 @@ func GetProof(leaf *felt.Felt, tri *Trie) ([]ProofNode, error) {
// 4. If it's an edge leaf, we store an edge leaf
// -> Child should be leaf (len=251). Distance between child and parent should be > 1.

// Edge nodes are defined as having a child with len greater than 1 from the parent
isEdge := func(sNode *storageNode, nodeNumFromRoot int) (bool, error) {
sNodeLen := sNode.key.len
leftKey := sNode.node.Left.len
rightKey := sNode.node.Right.len
if nodeNumFromRoot == 0 { // Is root
if sNodeLen != 1 {
return true, nil
}
return false, nil
}
if (leftKey-sNodeLen > 1) || (rightKey-sNodeLen > 1) {
return true, nil
}
return false, nil
}

for i, sNode := range nodesExcludingLeaf {

Check failure on line 108 in core/trie/proof.go

View workflow job for this annotation

GitHub Actions / lint

unnecessary leading newline (whitespace)
height += uint8(sNode.key.len)

leftHash, err := getValue(sNode.node.Left)
leftHash, err := getChildHash(tri, sNode, sNode.node.Left, i)
if err != nil {
return nil, err
}

rightHash, err := getValue(sNode.node.Right)
rightHash, err := getChildHash(tri, sNode, sNode.node.Left, i)
if err != nil {
return nil, err
}
fmt.Println("LeftHash", leftHash.String())
fmt.Println("rightHash", rightHash.String())

child := nodesToLeaf[i+1]
parentChildDistance := child.key.len - sNode.key.len
childIsInternal := nodesToLeaf[i+1].key.len < 251
isEdgeBool := isEdge(sNode, i)

isEdgeBool, err := isEdge(&sNode, i)
if err != nil {
return nil, err
}
if child.key.len < 251 && isEdgeBool { // Internal Edge
if childIsInternal && isEdgeBool { // Internal Edge
// Juno node is split into an edge + binary.
proofNodes = append(proofNodes, ProofNode{
Edge: &Edge{
Path: sNode.key, // Todo: Path from that node to the leaf?
Path: sNode.key,
Child: sNode.node.Value,
// Value: value, // Todo: ??
},
})
height += sNode.key.len

// Todo: If the child is an edge, we need to the hash of it's edge form. Both children.
leftNode, err := tri.GetNodeFromKey(sNode.node.Left)
if err != nil {
return nil, err
}
leftIsEdgeBool, err := isEdge(&storageNode{node: leftNode, key: sNode.node.Left}, i)
if err != nil {
return nil, err
}
if leftIsEdgeBool {
leftEdge := ProofNode{Edge: &Edge{
Path: sNode.node.Left,
Child: leftNode.Value,
}}
leftHash = leftEdge.Hash()
}
rightNode, err := tri.GetNodeFromKey(sNode.node.Right)
if err != nil {
return nil, err
}
rightIsEdgeBool, err := isEdge(&storageNode{node: rightNode, key: sNode.node.Right}, i)
if err != nil {
return nil, err
}
if rightIsEdgeBool {
rightEdge := ProofNode{Edge: &Edge{
Path: sNode.node.Right,
Child: rightNode.Value,
}}
rightHash = rightEdge.Hash()
}
},
ProofNode{
Binary: &Binary{
LeftHash: leftHash,
RightHash: rightHash,
},
})
} else if childIsInternal && !isEdgeBool { // Internal Binary
proofNodes = append(proofNodes, ProofNode{
Binary: &Binary{
LeftHash: leftHash,
RightHash: rightHash,
},
})
height++

} else if child.key.len < 251 && parentChildDistance == 1 { // Internal Binary
} else if !childIsInternal && isEdgeBool { // Leaf Edge
proofNodes = append(proofNodes, ProofNode{
Binary: &Binary{
LeftHash: leftHash,
RightHash: rightHash,
Edge: &Edge{
Child: sNode.node.Value,
// Value: value, // Todo: ??
},
})
height++
} else if child.key.len == 251 && parentChildDistance == 1 { // Leaf binary
} else if !childIsInternal && !isEdgeBool { // Leaf binary
proofNodes = append(proofNodes, ProofNode{
Binary: &Binary{
LeftHash: leftHash,
RightHash: rightHash,
},
})
height++
} else if child.key.len == 251 && isEdgeBool { // lead Edge
proofNodes = append(proofNodes, ProofNode{
Edge: &Edge{
// Path: sNode.key, // Todo: Path from that node to the leaf?
Child: sNode.node.Value,
// Value: value, // Todo: ??
},
})
height += sNode.key.len
} else {
return nil, errors.New("unexpected error in GetProof")
}

}

return proofNodes, nil
}

// verifyProof checks if `leafPath` leads from `root` to `leafHash` along the `proofNodes`
// https://github.com/eqlabs/pathfinder/blob/main/crates/merkle-tree/src/tree.rs#L2006
func VerifyProof(root *felt.Felt, key *Key, value *felt.Felt, proofs []ProofNode) bool {

if key.Len() != 251 {
if key.Len() != key.len {
return false
}

Expand All @@ -233,7 +186,8 @@ func VerifyProof(root *felt.Felt, key *Key, value *felt.Felt, proofs []ProofNode
remainingPath.RemoveLastBit()
case proofNode.Edge != nil:
// The next "proofNode.Edge.len" bits must match
if !proofNode.Edge.Path.Equal(remainingPath.SubKey(proofNode.Edge.Path.Len())) { // Todo: Isn't edge.path from root? and remaining from edge to leaf??
// Todo: Isn't edge.path from root? and remaining from edge to leaf??
if !proofNode.Edge.Path.Equal(remainingPath.SubKey(proofNode.Edge.Path.Len())) {
return false
}
expectedHash = proofNode.Edge.Child
Expand Down
65 changes: 0 additions & 65 deletions core/trie/proof_test.go
Original file line number Diff line number Diff line change
@@ -1,7 +1,6 @@
package trie_test

import (
"fmt"
"testing"

"github.com/NethermindEth/juno/core/felt"
Expand Down Expand Up @@ -67,28 +66,6 @@ func buildSimpleDoubleBinaryTrie(t *testing.T) *trie.Trie {
return tempTrie
}

// func getProofNodeBinary(t *testing.T, tri *trie.Trie, node *trie.Node) trie.ProofNode {
// getHash := func(tri *trie.Trie, key *trie.Key) (*felt.Felt, error) {
// keyFelt := key.Felt()
// node2, err := tri.GetNode(&keyFelt)
// if err != nil {
// return nil, err
// }
// return node2.Hash(key, crypto.Pedersen), nil
// }

// left, err := getHash(tri, node.Left)
// require.NoError(t, err)
// right, err := getHash(tri, node.Right)
// require.NoError(t, err)

// return trie.ProofNode{
// Binary: &trie.Binary{
// LeftHash: left, RightHash: right},
// }

// }

func TestGetProofs(t *testing.T) {
t.Run("Simple Trie - simple binary", func(t *testing.T) {
tempTrie := buildSimpleTrie(t)
Expand Down Expand Up @@ -116,20 +93,6 @@ func TestGetProofs(t *testing.T) {
for _, pNode := range proofNodes {
pNode.PrettyPrint()
}
require.Equal(t, len(expectedProofNodes), len(proofNodes))
for i, proof := range expectedProofNodes {
if proof.Binary != nil {
fmt.Println(proof.Binary.LeftHash.String(), expectedProofNodes[i].Binary.LeftHash.String())
fmt.Println(proof.Binary.RightHash.String(), expectedProofNodes[i].Binary.RightHash.String())
require.Equal(t, proof.Binary.LeftHash.String(), expectedProofNodes[i].Binary.LeftHash.String())
require.Equal(t, proof.Binary.RightHash, expectedProofNodes[i].Binary.RightHash)
} else {
fmt.Println(proof.Edge.Child.String(), expectedProofNodes[i].Edge.Child.String())
fmt.Println(proof.Edge.Path.String(), expectedProofNodes[i].Edge.Path.String())
require.Equal(t, proof.Edge.Child.String(), expectedProofNodes[i].Edge.Child.String())
require.Equal(t, proof.Edge.Path, expectedProofNodes[i].Edge.Path)
}
}
require.Equal(t, expectedProofNodes, proofNodes)
})

Expand Down Expand Up @@ -165,20 +128,6 @@ func TestGetProofs(t *testing.T) {
for _, pNode := range proofNodes {
pNode.PrettyPrint()
}
require.Equal(t, len(expectedProofNodes), len(proofNodes))
for i, proof := range expectedProofNodes {
if proof.Binary != nil {
fmt.Println(proof.Binary.LeftHash.String(), expectedProofNodes[i].Binary.LeftHash.String())
fmt.Println(proof.Binary.RightHash.String(), expectedProofNodes[i].Binary.RightHash.String())
require.Equal(t, proof.Binary.LeftHash.String(), expectedProofNodes[i].Binary.LeftHash.String())
require.Equal(t, proof.Binary.RightHash, expectedProofNodes[i].Binary.RightHash)
} else {
fmt.Println(proof.Edge.Child.String(), expectedProofNodes[i].Edge.Child.String())
fmt.Println(proof.Edge.Path.String(), expectedProofNodes[i].Edge.Path.String())
require.Equal(t, proof.Edge.Child.String(), expectedProofNodes[i].Edge.Child.String())
require.Equal(t, proof.Edge.Path, expectedProofNodes[i].Edge.Path)
}
}
require.Equal(t, expectedProofNodes, proofNodes)
})

Expand Down Expand Up @@ -214,20 +163,6 @@ func TestGetProofs(t *testing.T) {
for _, pNode := range proofNodes {
pNode.PrettyPrint()
}
require.Equal(t, len(expectedProofNodes), len(proofNodes))
for i, proof := range expectedProofNodes {
if proof.Binary != nil {
fmt.Println(proof.Binary.LeftHash.String(), expectedProofNodes[i].Binary.LeftHash.String())
fmt.Println(proof.Binary.RightHash.String(), expectedProofNodes[i].Binary.RightHash.String())
require.Equal(t, proof.Binary.LeftHash.String(), expectedProofNodes[i].Binary.LeftHash.String())
require.Equal(t, proof.Binary.RightHash, expectedProofNodes[i].Binary.RightHash)
} else {
fmt.Println(proof.Edge.Child.String(), expectedProofNodes[i].Edge.Child.String())
fmt.Println(proof.Edge.Path.String(), expectedProofNodes[i].Edge.Path.String())
require.Equal(t, proof.Edge.Child.String(), expectedProofNodes[i].Edge.Child.String())
require.Equal(t, proof.Edge.Path, expectedProofNodes[i].Edge.Path)
}
}
require.Equal(t, expectedProofNodes, proofNodes)
})
}
Expand Down

0 comments on commit c343913

Please sign in to comment.