diff --git a/db/buckets.go b/db/buckets.go index 2773773f5a..e5037378a3 100644 --- a/db/buckets.go +++ b/db/buckets.go @@ -34,6 +34,10 @@ const ( Temporary // used temporarily for migrations SchemaIntermediateState L1HandlerTxnHashByMsgHash // maps l1 handler msg hash to l1 handler txn hash + MempoolHead // key of the head node + MempoolTail // key of the tail node + MempoolLength // number of transactions + MempoolNode ) // Key flattens a prefix and series of byte arrays into a single []byte. diff --git a/mempool/db_utils.go b/mempool/db_utils.go new file mode 100644 index 0000000000..1638910464 --- /dev/null +++ b/mempool/db_utils.go @@ -0,0 +1,63 @@ +package mempool + +import ( + "errors" + "math/big" + + "github.com/NethermindEth/juno/core/felt" + "github.com/NethermindEth/juno/db" + "github.com/NethermindEth/juno/encoder" +) + +func headValue(txn db.Transaction, head *felt.Felt) error { + return txn.Get(db.MempoolHead.Key(), func(b []byte) error { + head.SetBytes(b) + return nil + }) +} + +func tailValue(txn db.Transaction, tail *felt.Felt) error { + return txn.Get(db.MempoolTail.Key(), func(b []byte) error { + tail.SetBytes(b) + return nil + }) +} + +func updateHead(txn db.Transaction, head *felt.Felt) error { + return txn.Set(db.MempoolHead.Key(), head.Marshal()) +} + +func updateTail(txn db.Transaction, tail *felt.Felt) error { + return txn.Set(db.MempoolTail.Key(), tail.Marshal()) +} + +func readDBElem(txn db.Transaction, itemKey *felt.Felt) (dbPoolTxn, error) { + var item dbPoolTxn + keyBytes := itemKey.Bytes() + err := txn.Get(db.MempoolNode.Key(keyBytes[:]), func(b []byte) error { + return encoder.Unmarshal(b, &item) + }) + return item, err +} + +func setDBElem(txn db.Transaction, item *dbPoolTxn) error { + itemBytes, err := encoder.Marshal(item) + if err != nil { + return err + } + keyBytes := item.Txn.Transaction.Hash().Bytes() + return txn.Set(db.MempoolNode.Key(keyBytes[:]), itemBytes) +} + +func lenDB(txn db.Transaction) (int, error) { + var l int + err := txn.Get(db.MempoolLength.Key(), func(b []byte) error { + l = int(new(big.Int).SetBytes(b).Int64()) + return nil + }) + + if err != nil && errors.Is(err, db.ErrKeyNotFound) { + return 0, nil + } + return l, err +} diff --git a/mempool/mempool.go b/mempool/mempool.go new file mode 100644 index 0000000000..4ebb80f7c8 --- /dev/null +++ b/mempool/mempool.go @@ -0,0 +1,302 @@ +package mempool + +import ( + "errors" + "fmt" + "math/big" + "sync" + + "github.com/NethermindEth/juno/core" + "github.com/NethermindEth/juno/core/felt" + "github.com/NethermindEth/juno/db" + "github.com/NethermindEth/juno/utils" +) + +var ErrTxnPoolFull = errors.New("transaction pool is full") + +type BroadcastedTransaction struct { + Transaction core.Transaction + DeclaredClass core.Class +} + +// runtime mempool txn +type memPoolTxn struct { + Txn BroadcastedTransaction + Next *memPoolTxn +} + +// persistent db txn value +type dbPoolTxn struct { + Txn BroadcastedTransaction + NextHash *felt.Felt +} + +// memTxnList represents a linked list of user transactions at runtime +type memTxnList struct { + head *memPoolTxn + tail *memPoolTxn + len int + mu sync.Mutex +} + +func (t *memTxnList) push(newNode *memPoolTxn) { + t.mu.Lock() + defer t.mu.Unlock() + if t.tail != nil { + t.tail.Next = newNode + t.tail = newNode + } else { + t.head = newNode + t.tail = newNode + } + t.len++ +} + +func (t *memTxnList) pop() (BroadcastedTransaction, error) { + t.mu.Lock() + defer t.mu.Unlock() + + if t.head == nil { + return BroadcastedTransaction{}, errors.New("transaction pool is empty") + } + + headNode := t.head + t.head = headNode.Next + if t.head == nil { + t.tail = nil + } + t.len-- + return headNode.Txn, nil +} + +// Pool represents a blockchain mempool, managing transactions using both an +// in-memory and persistent database. +type Pool struct { + log utils.SimpleLogger + state core.StateReader + db db.DB // to store the persistent mempool + txPushed chan struct{} + memTxnList *memTxnList + maxNumTxns int + dbWriteChan chan *BroadcastedTransaction + wg sync.WaitGroup +} + +// New initialises the Pool and starts the database writer goroutine. +// It is the responsibility of the caller to execute the closer function. +func New(mainDB db.DB, state core.StateReader, maxNumTxns int, log utils.SimpleLogger) (*Pool, func() error) { + pool := &Pool{ + log: log, + state: state, + db: mainDB, // todo: txns should be deleted everytime a new block is stored (builder responsibility) + txPushed: make(chan struct{}, 1), + memTxnList: &memTxnList{}, + maxNumTxns: maxNumTxns, + dbWriteChan: make(chan *BroadcastedTransaction, maxNumTxns), + } + closer := func() error { + close(pool.dbWriteChan) + pool.wg.Wait() + if err := pool.db.Close(); err != nil { + return fmt.Errorf("failed to close mempool database: %v", err) + } + return nil + } + pool.dbWriter() + return pool, closer +} + +func (p *Pool) dbWriter() { + p.wg.Add(1) + go func() { + defer p.wg.Done() + for txn := range p.dbWriteChan { + err := p.writeToDB(txn) + if err != nil { + p.log.Errorw("error in handling user transaction in persistent mempool", "err", err) + } + } + }() +} + +// LoadFromDB restores the in-memory transaction pool from the database +func (p *Pool) LoadFromDB() error { + return p.db.View(func(txn db.Transaction) error { + headVal := new(felt.Felt) + err := headValue(txn, headVal) + if err != nil { + if errors.Is(err, db.ErrKeyNotFound) { + return nil + } + return err + } + // loop through the persistent pool and push nodes to the in-memory pool + currentHash := headVal + for currentHash != nil { + curDBElem, err := readDBElem(txn, currentHash) + if err != nil { + return err + } + newMemPoolTxn := &memPoolTxn{ + Txn: curDBElem.Txn, + } + if curDBElem.NextHash != nil { + nextDBTxn, err := readDBElem(txn, curDBElem.NextHash) + if err != nil { + return err + } + newMemPoolTxn.Next = &memPoolTxn{ + Txn: nextDBTxn.Txn, + } + } + p.memTxnList.push(newMemPoolTxn) + currentHash = curDBElem.NextHash + } + return nil + }) +} + +// writeToDB adds the transaction to the persistent pool db +func (p *Pool) writeToDB(userTxn *BroadcastedTransaction) error { + return p.db.Update(func(dbTxn db.Transaction) error { + tailVal := new(felt.Felt) + if err := tailValue(dbTxn, tailVal); err != nil { + if !errors.Is(err, db.ErrKeyNotFound) { + return err + } + tailVal = nil + } + if err := setDBElem(dbTxn, &dbPoolTxn{Txn: *userTxn}); err != nil { + return err + } + if tailVal != nil { + // Update old tail to point to the new item + var oldTailElem dbPoolTxn + oldTailElem, err := readDBElem(dbTxn, tailVal) + if err != nil { + return err + } + oldTailElem.NextHash = userTxn.Transaction.Hash() + if err = setDBElem(dbTxn, &oldTailElem); err != nil { + return err + } + } else { + // Empty list, make new item both the head and the tail + if err := updateHead(dbTxn, userTxn.Transaction.Hash()); err != nil { + return err + } + } + if err := updateTail(dbTxn, userTxn.Transaction.Hash()); err != nil { + return err + } + pLen, err := lenDB(dbTxn) + if err != nil { + return err + } + return dbTxn.Set(db.MempoolLength.Key(), new(big.Int).SetInt64(int64(pLen+1)).Bytes()) + }) +} + +// Push queues a transaction to the pool +func (p *Pool) Push(userTxn *BroadcastedTransaction) error { + err := p.validate(userTxn) + if err != nil { + return err + } + + select { + case p.dbWriteChan <- userTxn: + default: + select { + case _, ok := <-p.dbWriteChan: + if !ok { + p.log.Errorw("cannot store user transasction in persistent pool, database write channel is closed") + } + p.log.Errorw("cannot store user transasction in persistent pool, database is full") + default: + p.log.Errorw("cannot store user transasction in persistent pool, database is full") + } + } + + newNode := &memPoolTxn{Txn: *userTxn, Next: nil} + p.memTxnList.push(newNode) + + select { + case p.txPushed <- struct{}{}: + default: + } + + return nil +} + +func (p *Pool) validate(userTxn *BroadcastedTransaction) error { + if p.memTxnList.len+1 >= p.maxNumTxns { + return ErrTxnPoolFull + } + + switch t := userTxn.Transaction.(type) { + case *core.DeployTransaction: + return fmt.Errorf("deploy transactions are not supported") + case *core.DeployAccountTransaction: + if !t.Nonce.IsZero() { + return fmt.Errorf("validation failed, received non-zero nonce %s", t.Nonce) + } + case *core.DeclareTransaction: + nonce, err := p.state.ContractNonce(t.SenderAddress) + if err != nil { + return fmt.Errorf("validation failed, error when retrieving nonce, %v", err) + } + if nonce.Cmp(t.Nonce) > 0 { + return fmt.Errorf("validation failed, existing nonce %s, but received nonce %s", nonce, t.Nonce) + } + case *core.InvokeTransaction: + if t.TxVersion().Is(0) { // cant verify nonce since SenderAddress was only added in v1 + return fmt.Errorf("invoke v0 transactions not supported") + } + nonce, err := p.state.ContractNonce(t.SenderAddress) + if err != nil { + return fmt.Errorf("validation failed, error when retrieving nonce, %v", err) + } + if nonce.Cmp(t.Nonce) > 0 { + return fmt.Errorf("validation failed, existing nonce %s, but received nonce %s", nonce, t.Nonce) + } + case *core.L1HandlerTransaction: + // todo: verification of the L1 handler nonce requires checking the + // message nonce on the L1 Core Contract. + } + return nil +} + +// Pop returns the transaction with the highest priority from the in-memory pool +func (p *Pool) Pop() (BroadcastedTransaction, error) { + return p.memTxnList.pop() +} + +// Remove removes a set of transactions from the pool +// todo: should be called by the builder to remove txns from the db everytime a new block is stored. +// todo: in the consensus+p2p world, the txns should also be removed from the in-memory pool. +func (p *Pool) Remove(hash ...*felt.Felt) error { + return errors.New("not implemented") +} + +// Len returns the number of transactions in the in-memory pool +func (p *Pool) Len() int { + return p.memTxnList.len +} + +func (p *Pool) Wait() <-chan struct{} { + return p.txPushed +} + +// Len returns the number of transactions in the persistent pool +func (p *Pool) LenDB() (int, error) { + txn, err := p.db.NewTransaction(false) + if err != nil { + return 0, err + } + lenDB, err := lenDB(txn) + if err != nil { + return 0, err + } + return lenDB, txn.Discard() +} diff --git a/mempool/mempool_test.go b/mempool/mempool_test.go new file mode 100644 index 0000000000..21e9ba89e1 --- /dev/null +++ b/mempool/mempool_test.go @@ -0,0 +1,226 @@ +package mempool_test + +import ( + "os" + "testing" + "time" + + "github.com/NethermindEth/juno/core" + "github.com/NethermindEth/juno/core/felt" + "github.com/NethermindEth/juno/db" + "github.com/NethermindEth/juno/db/pebble" + _ "github.com/NethermindEth/juno/encoder/registry" + "github.com/NethermindEth/juno/mempool" + "github.com/NethermindEth/juno/mocks" + "github.com/NethermindEth/juno/utils" + "github.com/stretchr/testify/require" + "go.uber.org/mock/gomock" +) + +func setupDatabase(dbPath string, dltExisting bool) (db.DB, func(), error) { + if _, err := os.Stat(dbPath); err == nil { + if dltExisting { + if err := os.RemoveAll(dbPath); err != nil { + return nil, nil, err + } + } + } else if !os.IsNotExist(err) { + return nil, nil, err + } + persistentPool, err := pebble.New(dbPath) + if err != nil { + return nil, nil, err + } + closer := func() { + // The db should be closed by the mempool closer function + os.RemoveAll(dbPath) + } + return persistentPool, closer, nil +} + +func TestMempool(t *testing.T) { + testDB, dbCloser, err := setupDatabase("testmempool", true) + log := utils.NewNopZapLogger() + mockCtrl := gomock.NewController(t) + t.Cleanup(mockCtrl.Finish) + state := mocks.NewMockStateHistoryReader(mockCtrl) + require.NoError(t, err) + defer dbCloser() + pool, closer := mempool.New(testDB, state, 4, log) + require.NoError(t, pool.LoadFromDB()) + + require.Equal(t, 0, pool.Len()) + + _, err = pool.Pop() + require.Equal(t, err.Error(), "transaction pool is empty") + + // push multiple to empty (1,2,3) + for i := uint64(1); i < 4; i++ { + senderAddress := new(felt.Felt).SetUint64(i) + state.EXPECT().ContractNonce(senderAddress).Return(&felt.Zero, nil) + require.NoError(t, pool.Push(&mempool.BroadcastedTransaction{ + Transaction: &core.InvokeTransaction{ + TransactionHash: new(felt.Felt).SetUint64(i), + Nonce: new(felt.Felt).SetUint64(1), + SenderAddress: senderAddress, + Version: new(core.TransactionVersion).SetUint64(1), + }, + })) + require.Equal(t, int(i), pool.Len()) + } + // consume some (remove 1,2, keep 3) + for i := uint64(1); i < 3; i++ { + txn, err := pool.Pop() + require.NoError(t, err) + require.Equal(t, i, txn.Transaction.Hash().Uint64()) + require.Equal(t, int(3-i), pool.Len()) + } + + // push multiple to non empty (push 4,5. now have 3,4,5) + for i := uint64(4); i < 6; i++ { + senderAddress := new(felt.Felt).SetUint64(i) + state.EXPECT().ContractNonce(senderAddress).Return(&felt.Zero, nil) + require.NoError(t, pool.Push(&mempool.BroadcastedTransaction{ + Transaction: &core.InvokeTransaction{ + TransactionHash: new(felt.Felt).SetUint64(i), + Nonce: new(felt.Felt).SetUint64(1), + SenderAddress: senderAddress, + Version: new(core.TransactionVersion).SetUint64(1), + }, + })) + require.Equal(t, int(i-2), pool.Len()) + } + + // push more than max + require.ErrorIs(t, pool.Push(&mempool.BroadcastedTransaction{ + Transaction: &core.InvokeTransaction{ + TransactionHash: new(felt.Felt).SetUint64(123), + }, + }), mempool.ErrTxnPoolFull) + + // consume all (remove 3,4,5) + for i := uint64(3); i < 6; i++ { + txn, err := pool.Pop() + require.NoError(t, err) + require.Equal(t, i, txn.Transaction.Hash().Uint64()) + } + require.Equal(t, 0, pool.Len()) + + _, err = pool.Pop() + require.Equal(t, err.Error(), "transaction pool is empty") + require.NoError(t, closer()) +} + +func TestRestoreMempool(t *testing.T) { + log := utils.NewNopZapLogger() + + mockCtrl := gomock.NewController(t) + t.Cleanup(mockCtrl.Finish) + state := mocks.NewMockStateHistoryReader(mockCtrl) + testDB, dbCloser, err := setupDatabase("testrestoremempool", true) + require.NoError(t, err) + defer dbCloser() + + pool, closer := mempool.New(testDB, state, 1024, log) + require.NoError(t, pool.LoadFromDB()) + // Check both pools are empty + lenDB, err := pool.LenDB() + require.NoError(t, err) + require.Equal(t, 0, lenDB) + require.Equal(t, 0, pool.Len()) + + // push multiple transactions to empty mempool (1,2,3) + for i := uint64(1); i < 4; i++ { + senderAddress := new(felt.Felt).SetUint64(i) + state.EXPECT().ContractNonce(senderAddress).Return(new(felt.Felt).SetUint64(0), nil) + require.NoError(t, pool.Push(&mempool.BroadcastedTransaction{ + Transaction: &core.InvokeTransaction{ + TransactionHash: new(felt.Felt).SetUint64(i), + Version: new(core.TransactionVersion).SetUint64(1), + SenderAddress: senderAddress, + Nonce: new(felt.Felt).SetUint64(0), + }, + })) + require.Equal(t, int(i), pool.Len()) + } + // check the db has stored the transactions + time.Sleep(100 * time.Millisecond) + lenDB, err = pool.LenDB() + require.NoError(t, err) + require.Equal(t, 3, lenDB) + // Close the mempool + require.NoError(t, closer()) + + testDB, _, err = setupDatabase("testrestoremempool", false) + require.NoError(t, err) + + poolRestored, closer2 := mempool.New(testDB, state, 1024, log) + time.Sleep(100 * time.Millisecond) + require.NoError(t, poolRestored.LoadFromDB()) + lenDB, err = poolRestored.LenDB() + require.NoError(t, err) + require.Equal(t, 3, lenDB) + require.Equal(t, 3, poolRestored.Len()) + + // Remove transactions + _, err = poolRestored.Pop() + require.NoError(t, err) + _, err = poolRestored.Pop() + require.NoError(t, err) + lenDB, err = poolRestored.LenDB() + require.NoError(t, err) + require.Equal(t, 3, lenDB) + require.Equal(t, 1, poolRestored.Len()) + require.NoError(t, closer2()) +} + +func TestWait(t *testing.T) { + log := utils.NewNopZapLogger() + testDB, dbCloser, err := setupDatabase("testwait", true) + require.NoError(t, err) + defer dbCloser() + mockCtrl := gomock.NewController(t) + t.Cleanup(mockCtrl.Finish) + state := mocks.NewMockStateHistoryReader(mockCtrl) + pool, _ := mempool.New(testDB, state, 1024, log) + require.NoError(t, pool.LoadFromDB()) + + select { + case <-pool.Wait(): + require.Fail(t, "wait channel should not be signalled on empty mempool") + default: + } + + // One transaction. + state.EXPECT().ContractNonce(new(felt.Felt).SetUint64(1)).Return(new(felt.Felt).SetUint64(0), nil) + require.NoError(t, pool.Push(&mempool.BroadcastedTransaction{ + Transaction: &core.InvokeTransaction{ + TransactionHash: new(felt.Felt).SetUint64(1), + Nonce: new(felt.Felt).SetUint64(1), + SenderAddress: new(felt.Felt).SetUint64(1), + Version: new(core.TransactionVersion).SetUint64(1), + }, + })) + <-pool.Wait() + + // Two transactions. + state.EXPECT().ContractNonce(new(felt.Felt).SetUint64(2)).Return(new(felt.Felt).SetUint64(0), nil) + require.NoError(t, pool.Push(&mempool.BroadcastedTransaction{ + Transaction: &core.InvokeTransaction{ + TransactionHash: new(felt.Felt).SetUint64(2), + Nonce: new(felt.Felt).SetUint64(1), + SenderAddress: new(felt.Felt).SetUint64(2), + Version: new(core.TransactionVersion).SetUint64(1), + }, + })) + state.EXPECT().ContractNonce(new(felt.Felt).SetUint64(3)).Return(new(felt.Felt).SetUint64(0), nil) + require.NoError(t, pool.Push(&mempool.BroadcastedTransaction{ + Transaction: &core.InvokeTransaction{ + TransactionHash: new(felt.Felt).SetUint64(3), + Nonce: new(felt.Felt).SetUint64(1), + SenderAddress: new(felt.Felt).SetUint64(3), + Version: new(core.TransactionVersion).SetUint64(1), + }, + })) + <-pool.Wait() +}