Skip to content

Commit

Permalink
fix(auth): audit issues with unordered txs (#23392)
Browse files Browse the repository at this point in the history
Co-authored-by: Alex | Interchain Labs <[email protected]>
Co-authored-by: Alexander Peters <[email protected]>
  • Loading branch information
3 people authored Jan 31, 2025
1 parent 8eb6822 commit ddf9e18
Show file tree
Hide file tree
Showing 9 changed files with 296 additions and 35 deletions.
17 changes: 17 additions & 0 deletions types/mempool/mempool_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@ import (
"fmt"
"math/rand"
"testing"
"time"

"github.com/stretchr/testify/require"
"github.com/stretchr/testify/suite"
Expand Down Expand Up @@ -55,6 +56,21 @@ type testTx struct {
address sdk.AccAddress
// useful for debugging
strAddress string
unordered bool
timeout *time.Time
}

// GetTimeoutTimeStamp implements types.TxWithUnordered.
func (tx testTx) GetTimeoutTimeStamp() time.Time {
if tx.timeout == nil {
return time.Time{}
}
return *tx.timeout
}

// GetUnordered implements types.TxWithUnordered.
func (tx testTx) GetUnordered() bool {
return tx.unordered
}

func (tx testTx) GetSigners() ([][]byte, error) { panic("not implemented") }
Expand All @@ -73,6 +89,7 @@ func (tx testTx) GetSignaturesV2() (res []txsigning.SignatureV2, err error) {

var (
_ sdk.Tx = (*testTx)(nil)
_ sdk.TxWithUnordered = (*testTx)(nil)
_ signing.SigVerifiableTx = (*testTx)(nil)
_ cryptotypes.PubKey = (*testPubKey)(nil)
)
Expand Down
20 changes: 10 additions & 10 deletions types/mempool/priority_nonce.go
Original file line number Diff line number Diff line change
Expand Up @@ -224,13 +224,13 @@ func (mp *PriorityNonceMempool[C]) Insert(ctx context.Context, tx sdk.Tx) error
priority := mp.cfg.TxPriority.GetTxPriority(ctx, tx)
nonce := sig.Sequence

// if it's an unordered tx, we use the gas instead of the nonce
// if it's an unordered tx, we use the timeout timestamp instead of the nonce
if unordered, ok := tx.(sdk.TxWithUnordered); ok && unordered.GetUnordered() {
gasLimit, err := unordered.GetGasLimit()
nonce = gasLimit
if err != nil {
return err
timestamp := unordered.GetTimeoutTimeStamp().Unix()
if timestamp < 0 {
return errors.New("invalid timestamp value")
}
nonce = uint64(timestamp)
}

key := txMeta[C]{nonce: nonce, priority: priority, sender: sender}
Expand Down Expand Up @@ -469,13 +469,13 @@ func (mp *PriorityNonceMempool[C]) Remove(tx sdk.Tx) error {
sender := sig.Signer.String()
nonce := sig.Sequence

// if it's an unordered tx, we use the gas instead of the nonce
// if it's an unordered tx, we use the timeout timestamp instead of the nonce
if unordered, ok := tx.(sdk.TxWithUnordered); ok && unordered.GetUnordered() {
gasLimit, err := unordered.GetGasLimit()
nonce = gasLimit
if err != nil {
return err
timestamp := unordered.GetTimeoutTimeStamp().Unix()
if timestamp < 0 {
return errors.New("invalid timestamp value")
}
nonce = uint64(timestamp)
}

scoreKey := txMeta[C]{nonce: nonce, sender: sender}
Expand Down
37 changes: 37 additions & 0 deletions types/mempool/priority_nonce_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -970,3 +970,40 @@ func TestNextSenderTx_TxReplacement(t *testing.T) {
iter := mp.Select(ctx, nil)
require.Equal(t, txs[3], iter.Tx())
}

func TestPriorityNonceMempool_UnorderedTx(t *testing.T) {
ctx := sdk.NewContext(nil, false, log.NewNopLogger())
accounts := simtypes.RandomAccounts(rand.New(rand.NewSource(0)), 2)
sa := accounts[0].Address
sb := accounts[1].Address

mp := mempool.DefaultPriorityMempool()

now := time.Now()
oneHour := now.Add(1 * time.Hour)
thirtyMin := now.Add(30 * time.Minute)
twoHours := now.Add(2 * time.Hour)
fifteenMin := now.Add(15 * time.Minute)

txs := []testTx{
{id: 1, priority: 0, address: sa, timeout: &thirtyMin, unordered: true},
{id: 0, priority: 0, address: sa, timeout: &oneHour, unordered: true},
{id: 3, priority: 0, address: sb, timeout: &fifteenMin, unordered: true},
{id: 2, priority: 0, address: sb, timeout: &twoHours, unordered: true},
}

for _, tx := range txs {
c := ctx.WithPriority(tx.priority)
require.NoError(t, mp.Insert(c, tx))
}

require.Equal(t, 4, mp.CountTx())

orderedTxs := fetchTxs(mp.Select(ctx, nil), 100000)
require.Equal(t, len(txs), len(orderedTxs))

// check order
for i, tx := range orderedTxs {
require.Equal(t, txs[i].id, tx.(testTx).id)
}
}
28 changes: 14 additions & 14 deletions types/mempool/sender_nonce.go
Original file line number Diff line number Diff line change
Expand Up @@ -139,21 +139,21 @@ func (snm *SenderNonceMempool) Insert(_ context.Context, tx sdk.Tx) error {
sender := sdk.AccAddress(sig.PubKey.Address()).String()
nonce := sig.Sequence

// if it's an unordered tx, we use the timeout timestamp instead of the nonce
if unordered, ok := tx.(sdk.TxWithUnordered); ok && unordered.GetUnordered() {
timestamp := unordered.GetTimeoutTimeStamp().Unix()
if timestamp < 0 {
return errors.New("invalid timestamp value")
}
nonce = uint64(timestamp)
}

senderTxs, found := snm.senders[sender]
if !found {
senderTxs = skiplist.New(skiplist.Uint64)
snm.senders[sender] = senderTxs
}

// if it's an unordered tx, we use the gas instead of the nonce
if unordered, ok := tx.(sdk.TxWithUnordered); ok && unordered.GetUnordered() {
gasLimit, err := unordered.GetGasLimit()
nonce = gasLimit
if err != nil {
return err
}
}

senderTxs.Set(nonce, tx)

key := txKey{nonce: nonce, address: sender}
Expand Down Expand Up @@ -236,13 +236,13 @@ func (snm *SenderNonceMempool) Remove(tx sdk.Tx) error {
sender := sdk.AccAddress(sig.PubKey.Address()).String()
nonce := sig.Sequence

// if it's an unordered tx, we use the gas instead of the nonce
// if it's an unordered tx, we use the timeout timestamp instead of the nonce
if unordered, ok := tx.(sdk.TxWithUnordered); ok && unordered.GetUnordered() {
gasLimit, err := unordered.GetGasLimit()
nonce = gasLimit
if err != nil {
return err
timestamp := unordered.GetTimeoutTimeStamp().Unix()
if timestamp < 0 {
return errors.New("invalid timestamp value")
}
nonce = uint64(timestamp)
}

senderTxs, found := snm.senders[sender]
Expand Down
65 changes: 65 additions & 0 deletions types/mempool/sender_nonce_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@ import (
"fmt"
"math/rand"
"testing"
"time"

"github.com/stretchr/testify/require"

Expand Down Expand Up @@ -192,3 +193,67 @@ func (s *MempoolTestSuite) TestTxNotFoundOnSender() {
err = mp.Remove(tx)
require.Equal(t, mempool.ErrTxNotFound, err)
}

func (s *MempoolTestSuite) TestUnorderedTx() {
t := s.T()

ctx := sdk.NewContext(nil, false, log.NewNopLogger())
accounts := simtypes.RandomAccounts(rand.New(rand.NewSource(0)), 2)
sa := accounts[0].Address
sb := accounts[1].Address

mp := mempool.NewSenderNonceMempool(mempool.SenderNonceMaxTxOpt(5000))

now := time.Now()
oneHour := now.Add(1 * time.Hour)
thirtyMin := now.Add(30 * time.Minute)
twoHours := now.Add(2 * time.Hour)
fifteenMin := now.Add(15 * time.Minute)

txs := []testTx{
{id: 0, address: sa, timeout: &oneHour, unordered: true},
{id: 1, address: sa, timeout: &thirtyMin, unordered: true},
{id: 2, address: sb, timeout: &twoHours, unordered: true},
{id: 3, address: sb, timeout: &fifteenMin, unordered: true},
}

for _, tx := range txs {
c := ctx.WithPriority(tx.priority)
require.NoError(t, mp.Insert(c, tx))
}

require.Equal(t, 4, mp.CountTx())

orderedTxs := fetchTxs(mp.Select(ctx, nil), 100000)
require.Equal(t, len(txs), len(orderedTxs))

// Because the sender is selected randomly it can be any of these options
acceptableOptions := [][]int{
{3, 1, 2, 0},
{3, 1, 0, 2},
{3, 2, 1, 0},
{1, 3, 0, 2},
{1, 3, 2, 0},
{1, 0, 3, 2},
}

orderedTxsIds := make([]int, len(orderedTxs))
for i, tx := range orderedTxs {
orderedTxsIds[i] = tx.(testTx).id
}

anyAcceptableOrder := false
for _, option := range acceptableOptions {
for i, tx := range orderedTxs {
if tx.(testTx).id != txs[option[i]].id {
break
}

if i == len(orderedTxs)-1 {
anyAcceptableOrder = true
}
}
}

require.True(t, anyAcceptableOrder, "expected any of %v but got %v", acceptableOptions, orderedTxsIds)
}
32 changes: 32 additions & 0 deletions x/auth/ante/ante_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@ import (
"fmt"
"strings"
"testing"
"time"

"github.com/stretchr/testify/require"
"go.uber.org/mock/gomock"
Expand Down Expand Up @@ -1384,3 +1385,34 @@ func TestAnteHandlerReCheck(t *testing.T) {
_, err = suite.anteHandler(suite.ctx, tx, false)
require.NotNil(t, err, "antehandler on recheck did not fail once feePayer no longer has sufficient funds")
}

func TestAnteHandlerUnorderedTx(t *testing.T) {
suite := SetupTestSuite(t, false)
accs := suite.CreateTestAccounts(1)
msg := testdata.NewTestMsg(accs[0].acc.GetAddress())

// First send a normal sequential tx with sequence 0
suite.bankKeeper.EXPECT().SendCoinsFromAccountToModule(gomock.Any(), accs[0].acc.GetAddress(), authtypes.FeeCollectorName, testdata.NewTestFeeAmount()).Return(nil).AnyTimes()

privs, accNums, accSeqs := []cryptotypes.PrivKey{accs[0].priv}, []uint64{1000}, []uint64{0}
_, err := suite.DeliverMsgs(t, privs, []sdk.Msg{msg}, testdata.NewTestFeeAmount(), testdata.NewTestGasLimit(), accNums, accSeqs, suite.ctx.ChainID(), false)
require.NoError(t, err)

// we try to send another tx with the same sequence, it will fail
_, err = suite.DeliverMsgs(t, privs, []sdk.Msg{msg}, testdata.NewTestFeeAmount(), testdata.NewTestGasLimit(), accNums, accSeqs, suite.ctx.ChainID(), false)
require.Error(t, err)

// now we'll still use the same sequence but because it's unordered, it will be ignored and accepted anyway
msgs := []sdk.Msg{msg}
require.NoError(t, suite.txBuilder.SetMsgs(msgs...))
suite.txBuilder.SetFeeAmount(testdata.NewTestFeeAmount())
suite.txBuilder.SetGasLimit(testdata.NewTestGasLimit())

tx, txErr := suite.CreateTestUnorderedTx(suite.ctx, privs, accNums, accSeqs, suite.ctx.ChainID(), apisigning.SignMode_SIGN_MODE_DIRECT, true, time.Now().Add(time.Minute))
require.NoError(t, txErr)
txBytes, err := suite.clientCtx.TxConfig.TxEncoder()(tx)
bytesCtx := suite.ctx.WithTxBytes(txBytes)
require.NoError(t, err)
_, err = suite.anteHandler(bytesCtx, tx, false)
require.NoError(t, err)
}
22 changes: 14 additions & 8 deletions x/auth/ante/sigverify.go
Original file line number Diff line number Diff line change
Expand Up @@ -320,18 +320,24 @@ func (svd SigVerificationDecorator) consumeSignatureGas(
// verifySig will verify the signature of the provided signer account.
func (svd SigVerificationDecorator) verifySig(ctx context.Context, tx sdk.Tx, acc sdk.AccountI, sig signing.SignatureV2, newlyCreated bool) error {
execMode := svd.ak.GetEnvironment().TransactionService.ExecMode(ctx)
if execMode == transaction.ExecModeCheck {
if sig.Sequence < acc.GetSequence() {
unorderedTx, ok := tx.(sdk.TxWithUnordered)
isUnordered := ok && unorderedTx.GetUnordered()

// only check sequence if the tx is not unordered
if !isUnordered {
if execMode == transaction.ExecModeCheck {
if sig.Sequence < acc.GetSequence() {
return errorsmod.Wrapf(
sdkerrors.ErrWrongSequence,
"account sequence mismatch: expected higher than or equal to %d, got %d", acc.GetSequence(), sig.Sequence,
)
}
} else if sig.Sequence != acc.GetSequence() {
return errorsmod.Wrapf(
sdkerrors.ErrWrongSequence,
"account sequence mismatch, expected higher than or equal to %d, got %d", acc.GetSequence(), sig.Sequence,
"account sequence mismatch: expected %d, got %d", acc.GetSequence(), sig.Sequence,
)
}
} else if sig.Sequence != acc.GetSequence() {
return errorsmod.Wrapf(
sdkerrors.ErrWrongSequence,
"account sequence mismatch: expected %d, got %d", acc.GetSequence(), sig.Sequence,
)
}

// we're in simulation mode, or in ReCheckTx, or context is not
Expand Down
62 changes: 62 additions & 0 deletions x/auth/ante/testutil_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@ package ante_test
import (
"context"
"testing"
"time"

"github.com/stretchr/testify/require"
"go.uber.org/mock/gomock"
Expand Down Expand Up @@ -241,6 +242,67 @@ func (suite *AnteTestSuite) RunTestCase(t *testing.T, tc TestCase, args TestCase
}
}

func (suite *AnteTestSuite) CreateTestUnorderedTx(
ctx sdk.Context, privs []cryptotypes.PrivKey,
accNums, accSeqs []uint64,
chainID string, signMode apisigning.SignMode,
unordered bool, unorderedTimeout time.Time,
) (xauthsigning.Tx, error) {
suite.txBuilder.SetUnordered(unordered)
suite.txBuilder.SetTimeoutTimestamp(unorderedTimeout)

// First round: we gather all the signer infos. We use the "set empty
// signature" hack to do that.
var sigsV2 []signing.SignatureV2
for i, priv := range privs {
sigV2 := signing.SignatureV2{
PubKey: priv.PubKey(),
Data: &signing.SingleSignatureData{
SignMode: signMode,
Signature: nil,
},
Sequence: accSeqs[i],
}

sigsV2 = append(sigsV2, sigV2)
}
err := suite.txBuilder.SetSignatures(sigsV2...)
if err != nil {
return nil, err
}

// Second round: all signer infos are set, so each signer can sign.
sigsV2 = []signing.SignatureV2{}
for i, priv := range privs {
anyPk, err := codectypes.NewAnyWithValue(priv.PubKey())
if err != nil {
return nil, err
}

signerData := txsigning.SignerData{
Address: sdk.AccAddress(priv.PubKey().Address()).String(),
ChainID: chainID,
AccountNumber: accNums[i],
Sequence: accSeqs[i],
PubKey: &anypb.Any{TypeUrl: anyPk.TypeUrl, Value: anyPk.Value},
}
sigV2, err := tx.SignWithPrivKey(
ctx, signMode, signerData,
suite.txBuilder, priv, suite.clientCtx.TxConfig, accSeqs[i])
if err != nil {
return nil, err
}

sigsV2 = append(sigsV2, sigV2)
}
err = suite.txBuilder.SetSignatures(sigsV2...)
if err != nil {
return nil, err
}

return suite.txBuilder.GetTx(), nil
}

// CreateTestTx is a helper function to create a tx given multiple inputs.
func (suite *AnteTestSuite) CreateTestTx(
ctx sdk.Context, privs []cryptotypes.PrivKey,
Expand Down
Loading

0 comments on commit ddf9e18

Please sign in to comment.