Skip to content

mssmt: add new Copy method and InsertMany to optimize slightly #1467

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 2 commits into from
Apr 22, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
188 changes: 188 additions & 0 deletions mssmt/compacted_tree.go
Original file line number Diff line number Diff line change
Expand Up @@ -392,3 +392,191 @@ func (t *CompactedTree) MerkleProof(ctx context.Context, key [hashSize]byte) (

return NewProof(proof), nil
}

// collectLeavesRecursive walks the compacted MS-SMT sub-tree rooted at node,
// which is assumed to sit at the given depth, and gathers every non-empty leaf
// reachable from it. The returned map is keyed by the key reported by each
// compacted leaf (not by the leaf's node hash).
//
// An error is returned if the children of a branch can't be fetched from the
// store, if a branch shows up at a depth of MaxTreeLevels or deeper, if the
// same key is found in both sub-trees of a branch, or if a node of an
// unexpected type is encountered.
func collectLeavesRecursive(ctx context.Context, tx TreeStoreViewTx, node Node,
	depth int) (map[[hashSize]byte]*LeafNode, error) {

	type leafSet = map[[hashSize]byte]*LeafNode

	switch n := node.(type) {
	// A compacted leaf ends the walk: it contributes either nothing (if
	// the wrapped leaf is empty) or exactly one entry.
	case *CompactedLeafNode:
		if n.LeafNode.IsEmpty() {
			return make(leafSet), nil
		}

		return leafSet{n.Key(): n.LeafNode}, nil

	// A branch is expanded into its two children, which are walked one
	// level deeper.
	case *BranchNode:
		// A branch can't live at or below the leaf level. Bail out
		// before EmptyTree is indexed with an out-of-range depth.
		if depth >= MaxTreeLevels {
			return nil, fmt.Errorf("invalid depth %d for branch "+
				"node", depth)
		}

		// An empty branch has no leaves below it, so there is no need
		// to hit the store.
		if IsEqualNode(n, EmptyTree[depth]) {
			return make(leafSet), nil
		}

		// Any failure to load the children is treated as fatal; no
		// attempt is made to interpret a "not found" style error as
		// an implicitly empty branch.
		leftChild, rightChild, err := tx.GetChildren(
			depth, n.NodeHash(),
		)
		if err != nil {
			return nil, fmt.Errorf("error getting children for "+
				"branch %s at depth %d: %w",
				n.NodeHash(), depth, err)
		}

		merged, err := collectLeavesRecursive(
			ctx, tx, leftChild, depth+1,
		)
		if err != nil {
			return nil, err
		}

		fromRight, err := collectLeavesRecursive(
			ctx, tx, rightChild, depth+1,
		)
		if err != nil {
			return nil, err
		}

		// Fold the right-hand result into the left-hand one. A key
		// present on both sides is reported as an error rather than
		// silently overwritten.
		for leafKey, leaf := range fromRight {
			if _, dup := merged[leafKey]; dup {
				return nil, fmt.Errorf("duplicate key %x "+
					"found during leaf collection", leafKey)
			}
			merged[leafKey] = leaf
		}

		return merged, nil
	}

	// Anything else is only acceptable if it stands for "nothing here": a
	// nil node, the empty leaf, or the empty sub-tree for this depth.
	switch {
	case node == nil || IsEqualNode(node, EmptyLeafNode):
		return make(leafSet), nil

	case depth < MaxTreeLevels && IsEqualNode(node, EmptyTree[depth]):
		return make(leafSet), nil
	}

	return nil, fmt.Errorf("unexpected node type %T encountered "+
		"during leaf collection at depth %d", node, depth)
}

// Copy reads every non-empty leaf of this tree inside a single store view
// transaction and then hands the whole set to targetTree.InsertMany. If the
// root equals the empty tree root, an empty set of leaves is passed on. Errors
// from reading the root, collecting the leaves, or inserting into the target
// are returned to the caller.
func (t *CompactedTree) Copy(ctx context.Context, targetTree Tree) error {
	var collected map[[hashSize]byte]*LeafNode

	// gather fills collected from the source tree. It assigns collected on
	// every successful path, so each run of the closure fully determines
	// the result.
	gather := func(tx TreeStoreViewTx) error {
		rootNode, rootErr := tx.RootNode()
		switch {
		case rootErr != nil:
			return fmt.Errorf("error getting root node: %w",
				rootErr)

		// An empty source tree has nothing to walk.
		case IsEqualNode(rootNode, EmptyTree[0]):
			collected = make(map[[hashSize]byte]*LeafNode)
			return nil
		}

		// Walk the whole tree, starting at the root (depth 0).
		var walkErr error
		collected, walkErr = collectLeavesRecursive(
			ctx, tx, rootNode, 0,
		)
		if walkErr != nil {
			return fmt.Errorf("error collecting leaves: %w",
				walkErr)
		}

		return nil
	}
	if err := t.store.View(ctx, gather); err != nil {
		return err
	}

	// Push everything into the target in one batch call.
	if _, err := targetTree.InsertMany(ctx, collected); err != nil {
		return fmt.Errorf("error inserting leaves into "+
			"target tree: %w", err)
	}

	return nil
}

// InsertMany inserts multiple leaf nodes provided in the leaves map within a
// single database transaction.
func (t *CompactedTree) InsertMany(ctx context.Context,
leaves map[[hashSize]byte]*LeafNode) (Tree, error) {

if len(leaves) == 0 {
return t, nil
}

dbErr := t.store.Update(ctx, func(tx TreeStoreUpdateTx) error {
currentRoot, err := tx.RootNode()
if err != nil {
return err
}
rootBranch := currentRoot.(*BranchNode)
Copy link
Preview

Copilot AI Apr 12, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The direct type assertion of currentRoot to *BranchNode can lead to a runtime panic if the type is not correct. It is recommended to perform a type check to handle unexpected types gracefully.

Suggested change
rootBranch := currentRoot.(*BranchNode)
rootBranch, ok := currentRoot.(*BranchNode)
if !ok {
return fmt.Errorf("expected currentRoot to be of type *BranchNode, got %T", currentRoot)
}

Copilot uses AI. Check for mistakes.


for key, leaf := range leaves {
// Check for potential sum overflow before each
// insertion.
sumRoot := rootBranch.NodeSum()
sumLeaf := leaf.NodeSum()
err = CheckSumOverflowUint64(sumRoot, sumLeaf)
if err != nil {
return fmt.Errorf("compact tree leaf insert "+
"sum overflow, root: %d, leaf: %d; %w",
sumRoot, sumLeaf, err)
}

// Insert the leaf using the internal helper.
newRoot, err := t.insert(
tx, &key, 0, rootBranch, leaf,
)
if err != nil {
return fmt.Errorf("error inserting leaf "+
"with key %x: %w", key, err)
}
rootBranch = newRoot

// Update the root within the transaction for
// consistency, even though the insert logic passes the
// root explicitly.
err = tx.UpdateRoot(rootBranch)
if err != nil {
return fmt.Errorf("error updating root "+
"during InsertMany: %w", err)
}
}

// The root is already updated by the last iteration of the
// loop. No final update needed here, but returning nil error
// signals success.
return nil
})
if dbErr != nil {
return nil, dbErr
}

return t, nil
}
9 changes: 9 additions & 0 deletions mssmt/interface.go
Original file line number Diff line number Diff line change
Expand Up @@ -30,4 +30,13 @@ type Tree interface {
// proof. This is noted by the returned `Proof` containing an empty
// leaf.
MerkleProof(ctx context.Context, key [hashSize]byte) (*Proof, error)

// InsertMany inserts multiple leaf nodes provided in the leaves map
// within a single database transaction.
InsertMany(ctx context.Context, leaves map[[hashSize]byte]*LeafNode) (
Tree, error)

// Copy copies all the key-value pairs from the source tree into the
// target tree.
Copy(ctx context.Context, targetTree Tree) error
}
164 changes: 164 additions & 0 deletions mssmt/tree.go
Original file line number Diff line number Diff line change
Expand Up @@ -97,6 +97,14 @@ func bitIndex(idx uint8, key *[hashSize]byte) byte {
return (byteVal >> (idx % 8)) & 1
}

// setBit returns a copy of key in which bit (depth % 8) of byte (depth / 8) is
// set to 1. The key is passed by value, so the caller's array is left
// untouched. Bits that are already set stay set.
func setBit(key [hashSize]byte, depth int) [hashSize]byte {
	mask := byte(1) << (depth % 8)
	key[depth/8] |= mask

	return key
}

// iterFunc is a type alias for closures to be invoked at every iteration of
// walking through a tree.
type iterFunc = func(height int, current, sibling, parent Node) error
Expand Down Expand Up @@ -333,6 +341,162 @@ func (t *FullTree) MerkleProof(ctx context.Context, key [hashSize]byte) (
return NewProof(proof), nil
}

// findLeaves recursively traverses the tree represented by the given node and
// collects all non-empty leaf nodes along with their reconstructed keys.
//
// keyPrefix carries the key bits accumulated on the path from the root to
// node, and depth is the level of node (the root is at depth 0). Descending
// left leaves the prefix untouched, descending right sets the bit for the
// current depth via setBit. When a leaf is reached, the prefix is used as its
// key in the returned map.
//
// An error is returned if a branch node is found at a depth outside of
// [0, MaxTreeLevels), if the children of a branch can't be fetched from the
// store, or if a node is neither a *LeafNode nor a *BranchNode.
func findLeaves(ctx context.Context, tx TreeStoreViewTx, node Node,
	keyPrefix [hashSize]byte,
	depth int) (map[[hashSize]byte]*LeafNode, error) {

	// Base case: If it's a leaf node.
	if leafNode, ok := node.(*LeafNode); ok {
		if leafNode.IsEmpty() {
			return make(map[[hashSize]byte]*LeafNode), nil
		}
		return map[[hashSize]byte]*LeafNode{keyPrefix: leafNode}, nil
	}

	// Recursive step: If it's a branch node.
	if branchNode, ok := node.(*BranchNode); ok {
		// A branch outside of the valid depth range would index past
		// the bounds of EmptyTree or of the key in setBit below and
		// panic, so report it as an error instead.
		if depth < 0 || depth >= MaxTreeLevels {
			return nil, fmt.Errorf("invalid depth %d for branch "+
				"node", depth)
		}

		// Optimization: if the branch is empty, return early.
		if IsEqualNode(branchNode, EmptyTree[depth]) {
			return make(map[[hashSize]byte]*LeafNode), nil
		}

		left, right, err := tx.GetChildren(depth, branchNode.NodeHash())
		if err != nil {
			return nil, fmt.Errorf("error getting children for "+
				"branch %s at depth %d: %w",
				branchNode.NodeHash(), depth, err)
		}

		// Recursively find leaves in the left subtree. The key prefix
		// remains the same as the 0 bit is implicitly handled by the
		// initial keyPrefix state.
		leftLeaves, err := findLeaves(
			ctx, tx, left, keyPrefix, depth+1,
		)
		if err != nil {
			return nil, err
		}

		// Recursively find leaves in the right subtree. Set the bit
		// corresponding to the current depth in the key prefix.
		rightKeyPrefix := setBit(keyPrefix, depth)

		rightLeaves, err := findLeaves(
			ctx, tx, right, rightKeyPrefix, depth+1,
		)
		if err != nil {
			return nil, err
		}

		// Merge the results. The two sides use prefixes that differ in
		// the bit for this depth, so their keys can't collide.
		for k, v := range rightLeaves {
			leftLeaves[k] = v
		}
		return leftLeaves, nil
	}

	// Handle unexpected node types.
	return nil, fmt.Errorf("unexpected node type %T encountered "+
		"during leaf collection", node)
}

// Copy reads every non-empty leaf of this tree inside a single store view
// transaction and then hands the whole set to targetTree.InsertMany. If the
// root equals the empty tree root, an empty set of leaves is passed on. Errors
// from reading the root, finding the leaves, or inserting into the target are
// returned to the caller.
func (t *FullTree) Copy(ctx context.Context, targetTree Tree) error {
	var collected map[[hashSize]byte]*LeafNode

	// gather fills collected from the source tree. It assigns collected on
	// every successful path, so each run of the closure fully determines
	// the result.
	gather := func(tx TreeStoreViewTx) error {
		rootNode, rootErr := tx.RootNode()
		switch {
		case rootErr != nil:
			return fmt.Errorf("error getting root node: %w",
				rootErr)

		// An empty source tree has nothing to walk.
		case IsEqualNode(rootNode, EmptyTree[0]):
			collected = make(map[[hashSize]byte]*LeafNode)
			return nil
		}

		// Walk the whole tree from the root, starting with an all-zero
		// key prefix at depth 0.
		var walkErr error
		collected, walkErr = findLeaves(
			ctx, tx, rootNode, [hashSize]byte{}, 0,
		)
		if walkErr != nil {
			return fmt.Errorf("error finding leaves: %w", walkErr)
		}

		return nil
	}
	if err := t.store.View(ctx, gather); err != nil {
		return err
	}

	// Push everything into the target in one batch call.
	if _, err := targetTree.InsertMany(ctx, collected); err != nil {
		return fmt.Errorf("error inserting leaves into target "+
			"tree: %w", err)
	}

	return nil
}

// InsertMany inserts every leaf of the leaves map, keyed by leaf key, into the
// tree from within a single call to the store's Update method. Leaves are
// inserted in map iteration order, which Go leaves unspecified.
//
// An empty or nil map is a no-op that returns the tree without touching the
// store. On success the receiver is returned. On failure a nil Tree is
// returned together with the error, which may stem from reading the root, a
// root that is not a *BranchNode, a sum overflow reported by
// CheckSumOverflowUint64, the insert helper, or updating the root.
func (t *FullTree) InsertMany(ctx context.Context,
	leaves map[[hashSize]byte]*LeafNode) (Tree, error) {

	if len(leaves) == 0 {
		return t, nil
	}

	err := t.store.Update(ctx, func(tx TreeStoreUpdateTx) error {
		currentRoot, err := tx.RootNode()
		if err != nil {
			return err
		}

		// Use a checked assertion so that an unexpected root type is
		// surfaced as an error instead of a runtime panic.
		rootBranch, ok := currentRoot.(*BranchNode)
		if !ok {
			return fmt.Errorf("expected root node of type "+
				"*BranchNode, got %T", currentRoot)
		}

		for key, leaf := range leaves {
			// Check for potential sum overflow before each
			// insertion, using the sum of the root as it stands
			// after the previous insertions of this batch.
			sumRoot := rootBranch.NodeSum()
			sumLeaf := leaf.NodeSum()
			err = CheckSumOverflowUint64(sumRoot, sumLeaf)
			if err != nil {
				return fmt.Errorf("full tree leaf insert sum "+
					"overflow, root: %d, leaf: %d; %w",
					sumRoot, sumLeaf, err)
			}

			// Insert the leaf using the internal helper.
			newRoot, err := t.insert(tx, &key, leaf)
			if err != nil {
				return fmt.Errorf("error inserting leaf "+
					"with key %x: %w", key, err)
			}
			rootBranch = newRoot

			// Update the root within the transaction so subsequent
			// inserts in this batch read the correct state.
			err = tx.UpdateRoot(rootBranch)
			if err != nil {
				return fmt.Errorf("error updating root "+
					"during InsertMany: %w", err)
			}
		}

		// The root is already updated by the last iteration of the
		// loop, so no final update is needed here.
		return nil
	})
	if err != nil {
		return nil, err
	}

	return t, nil
}

// VerifyMerkleProof determines whether a merkle proof for the leaf found at the
// given key is valid.
func VerifyMerkleProof(key [hashSize]byte, leaf *LeafNode, proof *Proof,
Expand Down
Loading
Loading