Skip to content

Commit

Permalink
comments: db_utils.go, inline, felt.Zero
Browse files Browse the repository at this point in the history
  • Loading branch information
rianhughes committed Jan 15, 2025
1 parent 99b8ede commit fd3959d
Show file tree
Hide file tree
Showing 3 changed files with 86 additions and 83 deletions.
63 changes: 63 additions & 0 deletions mempool/db_utils.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,63 @@
package mempool

import (
"errors"
"math/big"

"github.com/NethermindEth/juno/core/felt"
"github.com/NethermindEth/juno/db"
"github.com/NethermindEth/juno/encoder"
)

func headValue(txn db.Transaction, head *felt.Felt) error {
return txn.Get(db.MempoolHead.Key(), func(b []byte) error {
head.SetBytes(b)
return nil
})
}

func tailValue(txn db.Transaction, tail *felt.Felt) error {
return txn.Get(db.MempoolTail.Key(), func(b []byte) error {
tail.SetBytes(b)
return nil
})
}

func updateHead(txn db.Transaction, head *felt.Felt) error {
return txn.Set(db.MempoolHead.Key(), head.Marshal())
}

func updateTail(txn db.Transaction, tail *felt.Felt) error {
return txn.Set(db.MempoolTail.Key(), tail.Marshal())
}

func readDBElem(txn db.Transaction, itemKey *felt.Felt) (dbPoolTxn, error) {
var item dbPoolTxn
keyBytes := itemKey.Bytes()
err := txn.Get(db.MempoolNode.Key(keyBytes[:]), func(b []byte) error {
return encoder.Unmarshal(b, &item)
})
return item, err
}

func setDBElem(txn db.Transaction, item *dbPoolTxn) error {
itemBytes, err := encoder.Marshal(item)
if err != nil {
return err
}
keyBytes := item.Txn.Transaction.Hash().Bytes()
return txn.Set(db.MempoolNode.Key(keyBytes[:]), itemBytes)
}

func lenDB(txn db.Transaction) (int, error) {
var l int
err := txn.Get(db.MempoolLength.Key(), func(b []byte) error {
l = int(new(big.Int).SetBytes(b).Int64())
return nil
})

if err != nil && errors.Is(err, db.ErrKeyNotFound) {
return 0, nil
}
return l, err
}
102 changes: 21 additions & 81 deletions mempool/mempool.go
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,6 @@ import (
"github.com/NethermindEth/juno/core"
"github.com/NethermindEth/juno/core/felt"
"github.com/NethermindEth/juno/db"
"github.com/NethermindEth/juno/encoder"
"github.com/NethermindEth/juno/utils"
)

Expand Down Expand Up @@ -123,26 +122,26 @@ func (p *Pool) dbWriter() {
// LoadFromDB restores the in-memory transaction pool from the database
func (p *Pool) LoadFromDB() error {
return p.db.View(func(txn db.Transaction) error {
headValue := new(felt.Felt)
err := p.headHash(txn, headValue)
headVal := new(felt.Felt)
err := headValue(txn, headVal)
if err != nil {
if errors.Is(err, db.ErrKeyNotFound) {
return nil
}
return err
}
// loop through the persistent pool and push nodes to the in-memory pool
currentHash := headValue
currentHash := headVal
for currentHash != nil {
curDBElem, err := p.readDBElem(txn, currentHash)
curDBElem, err := readDBElem(txn, currentHash)
if err != nil {
return err
}
newMemPoolTxn := &memPoolTxn{
Txn: curDBElem.Txn,
}
if curDBElem.NextHash != nil {
nextDBTxn, err := p.readDBElem(txn, curDBElem.NextHash)
nextDBTxn, err := readDBElem(txn, curDBElem.NextHash)
if err != nil {
return err
}
Expand All @@ -160,41 +159,41 @@ func (p *Pool) LoadFromDB() error {
// writeToDB adds the transaction to the persistent pool db
func (p *Pool) writeToDB(userTxn *BroadcastedTransaction) error {
return p.db.Update(func(dbTxn db.Transaction) error {
tailValue := new(felt.Felt)
if err := p.tailValue(dbTxn, tailValue); err != nil {
tailVal := new(felt.Felt)
if err := tailValue(dbTxn, tailVal); err != nil {
if !errors.Is(err, db.ErrKeyNotFound) {
return err
}
tailValue = nil
tailVal = nil
}
if err := p.setDBElem(dbTxn, &dbPoolTxn{Txn: *userTxn}); err != nil {
if err := setDBElem(dbTxn, &dbPoolTxn{Txn: *userTxn}); err != nil {
return err
}
if tailValue != nil {
if tailVal != nil {
// Update old tail to point to the new item
var oldTailElem dbPoolTxn
oldTailElem, err := p.readDBElem(dbTxn, tailValue)
oldTailElem, err := readDBElem(dbTxn, tailVal)
if err != nil {
return err
}
oldTailElem.NextHash = userTxn.Transaction.Hash()
if err = p.setDBElem(dbTxn, &oldTailElem); err != nil {
if err = setDBElem(dbTxn, &oldTailElem); err != nil {
return err
}
} else {
// Empty list, make new item both the head and the tail
if err := p.updateHead(dbTxn, userTxn.Transaction.Hash()); err != nil {
if err := updateHead(dbTxn, userTxn.Transaction.Hash()); err != nil {
return err
}
}
if err := p.updateTail(dbTxn, userTxn.Transaction.Hash()); err != nil {
if err := updateTail(dbTxn, userTxn.Transaction.Hash()); err != nil {
return err
}
pLen, err := p.lenDB(dbTxn)
pLen, err := lenDB(dbTxn)
if err != nil {
return err
}
return p.updateLen(dbTxn, pLen+1)
return dbTxn.Set(db.MempoolLength.Key(), new(big.Int).SetInt64(int64(pLen+1)).Bytes())
})
}

Expand Down Expand Up @@ -285,78 +284,19 @@ func (p *Pool) Len() int {
return p.memTxnList.len
}

func (p *Pool) Wait() <-chan struct{} {
return p.txPushed
}

// Len returns the number of transactions in the persistent pool
func (p *Pool) LenDB() (int, error) {
p.wg.Add(1)
defer p.wg.Done()
txn, err := p.db.NewTransaction(false)
if err != nil {
return 0, err
}
lenDB, err := p.lenDB(txn)
lenDB, err := lenDB(txn)
if err != nil {
return 0, err
}
return lenDB, txn.Discard()
}

func (p *Pool) lenDB(txn db.Transaction) (int, error) {
var l int
err := txn.Get(db.MempoolLength.Key(), func(b []byte) error {
l = int(new(big.Int).SetBytes(b).Int64())
return nil
})

if err != nil && errors.Is(err, db.ErrKeyNotFound) {
return 0, nil
}
return l, err
}

func (p *Pool) updateLen(txn db.Transaction, l int) error {
return txn.Set(db.MempoolLength.Key(), new(big.Int).SetInt64(int64(l)).Bytes())
}

func (p *Pool) Wait() <-chan struct{} {
return p.txPushed
}

func (p *Pool) headHash(txn db.Transaction, head *felt.Felt) error {
return txn.Get(db.MempoolHead.Key(), func(b []byte) error {
head.SetBytes(b)
return nil
})
}

func (p *Pool) updateHead(txn db.Transaction, head *felt.Felt) error {
return txn.Set(db.MempoolHead.Key(), head.Marshal())
}

func (p *Pool) tailValue(txn db.Transaction, tail *felt.Felt) error {
return txn.Get(db.MempoolTail.Key(), func(b []byte) error {
tail.SetBytes(b)
return nil
})
}

func (p *Pool) updateTail(txn db.Transaction, tail *felt.Felt) error {
return txn.Set(db.MempoolTail.Key(), tail.Marshal())
}

func (p *Pool) readDBElem(txn db.Transaction, itemKey *felt.Felt) (dbPoolTxn, error) {
var item dbPoolTxn
keyBytes := itemKey.Bytes()
err := txn.Get(db.MempoolNode.Key(keyBytes[:]), func(b []byte) error {
return encoder.Unmarshal(b, &item)
})
return item, err
}

func (p *Pool) setDBElem(txn db.Transaction, item *dbPoolTxn) error {
itemBytes, err := encoder.Marshal(item)
if err != nil {
return err
}
keyBytes := item.Txn.Transaction.Hash().Bytes()
return txn.Set(db.MempoolNode.Key(keyBytes[:]), itemBytes)
}
4 changes: 2 additions & 2 deletions mempool/mempool_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -57,7 +57,7 @@ func TestMempool(t *testing.T) {
// push multiple to empty (1,2,3)
for i := uint64(1); i < 4; i++ { //nolint:dupl
senderAddress := new(felt.Felt).SetUint64(i)
state.EXPECT().ContractNonce(senderAddress).Return(new(felt.Felt).SetUint64(0), nil)
state.EXPECT().ContractNonce(senderAddress).Return(&felt.Zero, nil)
require.NoError(t, pool.Push(&mempool.BroadcastedTransaction{
Transaction: &core.InvokeTransaction{
TransactionHash: new(felt.Felt).SetUint64(i),
Expand All @@ -79,7 +79,7 @@ func TestMempool(t *testing.T) {
// push multiple to non empty (push 4,5. now have 3,4,5)
for i := uint64(4); i < 6; i++ {
senderAddress := new(felt.Felt).SetUint64(i)
state.EXPECT().ContractNonce(senderAddress).Return(new(felt.Felt).SetUint64(0), nil)
state.EXPECT().ContractNonce(senderAddress).Return(&felt.Zero, nil)
require.NoError(t, pool.Push(&mempool.BroadcastedTransaction{
Transaction: &core.InvokeTransaction{
TransactionHash: new(felt.Felt).SetUint64(i),
Expand Down

0 comments on commit fd3959d

Please sign in to comment.