Skip to content

Commit

Permalink
add nonce validation + tests
Browse files Browse the repository at this point in the history
  • Loading branch information
rianhughes committed Jan 8, 2025
1 parent 17a7759 commit 0cf2454
Show file tree
Hide file tree
Showing 3 changed files with 95 additions and 19 deletions.
5 changes: 5 additions & 0 deletions mempool/init_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
package mempool_test

import (
_ "github.com/NethermindEth/juno/encoder/registry"
)
53 changes: 47 additions & 6 deletions mempool/mempool.go
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,8 @@ type txnList struct {

// Pool stores the transactions in a linked list for its inherent FCFS behaviour
type Pool struct {
db db.DB
state core.StateReader
db db.DB // persistent mempool
txPushed chan struct{}
txnList *txnList // in-memory
maxNumTxns uint16
Expand All @@ -50,8 +51,9 @@ type Pool struct {

// New initializes the Pool and starts the database writer goroutine.

Check failure on line 52 in mempool/mempool.go

View workflow job for this annotation

GitHub Actions / lint

`initializes` is a misspelling of `initialises` (misspell)
// It is the responsibility of the user to call the cancel function if the context is cancelled
func New(db db.DB, maxNumTxns uint16) (*Pool, func() error, error) {
func New(db db.DB, state core.StateReader, maxNumTxns uint16) (*Pool, func() error, error) {

Check failure on line 54 in mempool/mempool.go

View workflow job for this annotation

GitHub Actions / lint

importShadow: shadow of imported from 'github.com/NethermindEth/juno/db' package 'db' (gocritic)
pool := &Pool{
state: state,
db: db, // todo: txns should be deleted everytime a new block is stored (builder responsibility)
txPushed: make(chan struct{}, 1),
txnList: &txnList{},
Expand Down Expand Up @@ -195,12 +197,11 @@ func (p *Pool) handleTransaction(userTxn *BroadcastedTransaction) error {

// Push queues a transaction to the pool and adds it to both the in-memory list and DB
func (p *Pool) Push(userTxn *BroadcastedTransaction) error {
if p.txnList.len >= uint16(p.maxNumTxns) {
return ErrTxnPoolFull
err := p.validate(userTxn)
if err != nil {
return err
}

// todo(rian this PR): validation

// todo: should db overloading block the in-memory mempool??
select {
case p.dbWriteChan <- userTxn:
Expand Down Expand Up @@ -236,6 +237,44 @@ func (p *Pool) Push(userTxn *BroadcastedTransaction) error {
return nil
}

func (p *Pool) validate(userTxn *BroadcastedTransaction) error {
if p.txnList.len+1 >= uint16(p.maxNumTxns) {
return ErrTxnPoolFull
}

switch t := userTxn.Transaction.(type) {
case *core.DeployTransaction:
return fmt.Errorf("deploy transactions are not supported")
case *core.DeployAccountTransaction:
if !t.Nonce.IsZero() {
return fmt.Errorf("validation failed, received non-zero nonce %s", t.Nonce)
}
case *core.DeclareTransaction:
nonce, err := p.state.ContractNonce(t.SenderAddress)
if err != nil {
return fmt.Errorf("validation failed, error when retrieving nonce, %v:", err)
}
if nonce.Cmp(t.Nonce) > 0 {
return fmt.Errorf("validation failed, existing nonce %s, but received nonce %s", nonce, t.Nonce)
}
case *core.InvokeTransaction:
if t.TxVersion().Is(0) { // cant verify nonce since SenderAddress was only added in v1
return fmt.Errorf("invoke v0 transactions not supported")
}
nonce, err := p.state.ContractNonce(t.SenderAddress)
if err != nil {
return fmt.Errorf("validation failed, error when retrieving nonce, %v:", err)
}
if nonce.Cmp(t.Nonce) > 0 {
return fmt.Errorf("validation failed, existing nonce %s, but received nonce %s", nonce, t.Nonce)
}
case *core.L1HandlerTransaction:
// todo: verification of the L1 handler nonce requires checking the
// message nonce on the L1 Core Contract.
}
return nil
}

// Pop returns the transaction with the highest priority from the in-memory pool
func (p *Pool) Pop() (BroadcastedTransaction, error) {
p.txnList.mu.Lock()
Expand All @@ -256,6 +295,8 @@ func (p *Pool) Pop() (BroadcastedTransaction, error) {
}

// Remove removes a set of transactions from the pool
// todo: should be called by the builder to remove txns from the db everytime a new block is stored.
// todo: in the consensus+p2p world, the txns should also be removed from the in-memory pool.
func (p *Pool) Remove(hash ...*felt.Felt) error {
return errors.New("not implemented")
}
Expand Down
56 changes: 43 additions & 13 deletions mempool/mempool_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,14 +5,15 @@ import (
"testing"
"time"

"github.com/NethermindEth/juno/blockchain"
"github.com/NethermindEth/juno/core"
"github.com/NethermindEth/juno/core/felt"
"github.com/NethermindEth/juno/db"
"github.com/NethermindEth/juno/db/pebble"
"github.com/NethermindEth/juno/mempool"
"github.com/NethermindEth/juno/mocks"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"go.uber.org/mock/gomock"
)

func setupDatabase(dltExisting bool) (*db.DB, func(), error) {

Check failure on line 19 in mempool/mempool_test.go

View workflow job for this annotation

GitHub Actions / lint

ptrToRefParam: consider to make non-pointer type for `*db.DB` (gocritic)
Expand All @@ -39,12 +40,14 @@ func setupDatabase(dltExisting bool) (*db.DB, func(), error) {

func TestMempool(t *testing.T) {
testDB, dbCloser, err := setupDatabase(true)
mockCtrl := gomock.NewController(t)
t.Cleanup(mockCtrl.Finish)
state := mocks.NewMockStateHistoryReader(mockCtrl)
require.NoError(t, err)
defer dbCloser()
pool, closer, err := mempool.New(*testDB, 5)
pool, closer, err := mempool.New(*testDB, state, 4)
defer closer()

Check failure on line 49 in mempool/mempool_test.go

View workflow job for this annotation

GitHub Actions / lint

Error return value is not checked (errcheck)
require.NoError(t, err)
blockchain.RegisterCoreTypesToEncoder()

l := pool.Len()
assert.Equal(t, uint16(0), l)
Expand All @@ -54,16 +57,19 @@ func TestMempool(t *testing.T) {

// push multiple to empty (1,2,3)
for i := uint64(1); i < 4; i++ {
senderAddress := new(felt.Felt).SetUint64(i)
state.EXPECT().ContractNonce(senderAddress).Return(new(felt.Felt).SetUint64(0), nil)
assert.NoError(t, pool.Push(&mempool.BroadcastedTransaction{
Transaction: &core.InvokeTransaction{
TransactionHash: new(felt.Felt).SetUint64(i),
Nonce: new(felt.Felt).SetUint64(1),
SenderAddress: senderAddress,
Version: new(core.TransactionVersion).SetUint64(1),
},
}))

l := pool.Len()
assert.Equal(t, uint16(i), l)
}

// consume some (remove 1,2, keep 3)
for i := uint64(1); i < 3; i++ {
txn, err := pool.Pop()
Expand All @@ -76,12 +82,16 @@ func TestMempool(t *testing.T) {

// push multiple to non empty (push 4,5. now have 3,4,5)
for i := uint64(4); i < 6; i++ {
senderAddress := new(felt.Felt).SetUint64(i)
state.EXPECT().ContractNonce(senderAddress).Return(new(felt.Felt).SetUint64(0), nil)
assert.NoError(t, pool.Push(&mempool.BroadcastedTransaction{
Transaction: &core.InvokeTransaction{
TransactionHash: new(felt.Felt).SetUint64(i),
Nonce: new(felt.Felt).SetUint64(1),
SenderAddress: senderAddress,
Version: new(core.TransactionVersion).SetUint64(1),
},
}))

l := pool.Len()
assert.Equal(t, uint16(i-2), l)
}
Expand All @@ -103,15 +113,16 @@ func TestMempool(t *testing.T) {

_, err = pool.Pop()
require.Equal(t, err.Error(), "transaction pool is empty")

}

func TestRestoreMempool(t *testing.T) {
blockchain.RegisterCoreTypesToEncoder()

mockCtrl := gomock.NewController(t)
t.Cleanup(mockCtrl.Finish)
state := mocks.NewMockStateHistoryReader(mockCtrl)
testDB, _, err := setupDatabase(true)
require.NoError(t, err)
pool, closer, err := mempool.New(*testDB, 1024)

pool, closer, err := mempool.New(*testDB, state, 1024)
require.NoError(t, err)

// Check both pools are empty
Expand All @@ -122,9 +133,14 @@ func TestRestoreMempool(t *testing.T) {

// push multiple transactions to empty mempool (1,2,3)
for i := uint64(1); i < 4; i++ {
senderAddress := new(felt.Felt).SetUint64(i)
state.EXPECT().ContractNonce(senderAddress).Return(new(felt.Felt).SetUint64(0), nil)
assert.NoError(t, pool.Push(&mempool.BroadcastedTransaction{
Transaction: &core.InvokeTransaction{
TransactionHash: new(felt.Felt).SetUint64(i),
Version: new(core.TransactionVersion).SetUint64(1),
SenderAddress: senderAddress,
Nonce: new(felt.Felt).SetUint64(0),
},
}))
assert.Equal(t, uint16(i), pool.Len())
Expand All @@ -142,7 +158,7 @@ func TestRestoreMempool(t *testing.T) {
require.NoError(t, err)
defer dbCloser()

poolRestored, closer2, err := mempool.New(*testDB, 1024)
poolRestored, closer2, err := mempool.New(*testDB, state, 1024)
require.NoError(t, err)
lenDB, err = poolRestored.LenDB()
require.NoError(t, err)
Expand All @@ -163,9 +179,11 @@ func TestRestoreMempool(t *testing.T) {

func TestWait(t *testing.T) {
testDB := pebble.NewMemTest(t)
pool, _, err := mempool.New(testDB, 1024)
mockCtrl := gomock.NewController(t)
t.Cleanup(mockCtrl.Finish)
state := mocks.NewMockStateHistoryReader(mockCtrl)
pool, _, err := mempool.New(testDB, state, 1024)
require.NoError(t, err)
blockchain.RegisterCoreTypesToEncoder()

select {
case <-pool.Wait():
Expand All @@ -174,22 +192,34 @@ func TestWait(t *testing.T) {
}

// One transaction.
state.EXPECT().ContractNonce(new(felt.Felt).SetUint64(1)).Return(new(felt.Felt).SetUint64(0), nil)
require.NoError(t, pool.Push(&mempool.BroadcastedTransaction{
Transaction: &core.InvokeTransaction{
TransactionHash: new(felt.Felt).SetUint64(1),
Nonce: new(felt.Felt).SetUint64(1),
SenderAddress: new(felt.Felt).SetUint64(1),
Version: new(core.TransactionVersion).SetUint64(1),
},
}))
<-pool.Wait()

// Two transactions.
state.EXPECT().ContractNonce(new(felt.Felt).SetUint64(2)).Return(new(felt.Felt).SetUint64(0), nil)
require.NoError(t, pool.Push(&mempool.BroadcastedTransaction{
Transaction: &core.InvokeTransaction{
TransactionHash: new(felt.Felt).SetUint64(2),
Nonce: new(felt.Felt).SetUint64(1),
SenderAddress: new(felt.Felt).SetUint64(2),
Version: new(core.TransactionVersion).SetUint64(1),
},
}))
state.EXPECT().ContractNonce(new(felt.Felt).SetUint64(3)).Return(new(felt.Felt).SetUint64(0), nil)
require.NoError(t, pool.Push(&mempool.BroadcastedTransaction{
Transaction: &core.InvokeTransaction{
TransactionHash: new(felt.Felt).SetUint64(3),
Nonce: new(felt.Felt).SetUint64(1),
SenderAddress: new(felt.Felt).SetUint64(3),
Version: new(core.TransactionVersion).SetUint64(1),
},
}))
<-pool.Wait()
Expand Down

0 comments on commit 0cf2454

Please sign in to comment.