diff --git a/internal/db/migrations/2025-01-22.0-add_locked_tx_hash_to_channel_accounts.sql b/internal/db/migrations/2025-01-22.0-add_locked_tx_hash_to_channel_accounts.sql new file mode 100644 index 0000000..44409b0 --- /dev/null +++ b/internal/db/migrations/2025-01-22.0-add_locked_tx_hash_to_channel_accounts.sql @@ -0,0 +1,9 @@ +-- +migrate Up + +ALTER TABLE channel_accounts +ADD COLUMN locked_tx_hash VARCHAR(64); +CREATE INDEX idx_locked_tx_hash ON channel_accounts(locked_tx_hash); +-- +migrate Down +DROP INDEX IF EXISTS idx_locked_tx_hash; +ALTER TABLE channel_accounts +DROP COLUMN locked_tx_hash; diff --git a/internal/serve/serve.go b/internal/serve/serve.go index 438993f..0574b5b 100644 --- a/internal/serve/serve.go +++ b/internal/serve/serve.go @@ -156,6 +156,8 @@ func initHandlerDeps(cfg Configs) (handlerDeps, error) { } go rpcService.TrackRPCServiceHealth(context.Background()) + channelAccountStore := store.NewChannelAccountModel(dbConnectionPool) + accountService, err := services.NewAccountService(models) if err != nil { return handlerDeps{}, fmt.Errorf("instantiating account service: %w", err) @@ -181,8 +183,10 @@ func initHandlerDeps(cfg Configs) (handlerDeps, error) { // TSS setup tssTxService, err := tssservices.NewTransactionService(tssservices.TransactionServiceOptions{ + DB: dbConnectionPool, DistributionAccountSignatureClient: cfg.DistributionAccountSignatureClient, ChannelAccountSignatureClient: cfg.ChannelAccountSignatureClient, + ChannelAccountStore: channelAccountStore, RPCService: rpcService, BaseFee: int64(cfg.BaseFee), }) @@ -228,10 +232,12 @@ func initHandlerDeps(cfg Configs) (handlerDeps, error) { webhookChannel := tsschannel.NewWebhookChannel(tsschannel.WebhookChannelConfigs{ HTTPClient: &httpClient, Store: tssStore, + ChannelAccountStore: channelAccountStore, MaxBufferSize: cfg.WebhookHandlerServiceChannelMaxBufferSize, MaxWorkers: cfg.WebhookHandlerServiceChannelMaxWorkers, MaxRetries: cfg.WebhookHandlerServiceChannelMaxRetries, MinWaitBtwnRetriesMS: cfg.WebhookHandlerServiceChannelMinWaitBtwnRetriesMS, + NetworkPassphrase: cfg.NetworkPassphrase, }) router := tssrouter.NewRouter(tssrouter.RouterConfigs{ diff --git a/internal/signing/channel_account_db_signature_client.go b/internal/signing/channel_account_db_signature_client.go index c8da348..89b6120 100644 --- a/internal/signing/channel_account_db_signature_client.go +++ b/internal/signing/channel_account_db_signature_client.go @@ -43,9 +43,16 @@ func (sc *channelAccountDBSignatureClient) NetworkPassphrase() string { return sc.networkPassphrase } -func (sc *channelAccountDBSignatureClient) GetAccountPublicKey(ctx context.Context) (string, error) { +func (sc *channelAccountDBSignatureClient) GetAccountPublicKey(ctx context.Context, opts ...int) (string, error) { + var lockedUntil time.Duration + if len(opts) > 0 { + lockedUntil = time.Duration(opts[0]) * time.Second + } else { + lockedUntil = time.Minute + } for range store.ChannelAccountWaitTime { - channelAccount, err := sc.channelAccountStore.GetIdleChannelAccount(ctx, time.Minute) + // check to see if the variadic parameter for time exists and if so, use it here + channelAccount, err := sc.channelAccountStore.GetAndLockIdleChannelAccount(ctx, lockedUntil) if err != nil { if errors.Is(err, store.ErrNoIdleChannelAccountAvailable) { log.Ctx(ctx).Warn("All channel accounts are in use. Retry in 1 second.") diff --git a/internal/signing/channel_account_db_signature_client_test.go b/internal/signing/channel_account_db_signature_client_test.go index fac663d..984e6e4 100644 --- a/internal/signing/channel_account_db_signature_client_test.go +++ b/internal/signing/channel_account_db_signature_client_test.go @@ -30,7 +30,7 @@ func TestChannelAccountDBSignatureClientGetAccountPublicKey(t *testing.T) { t.Run("returns_error_when_couldn't_get_an_idle_channel_account", func(t *testing.T) { channelAccountStore. - On("GetIdleChannelAccount", ctx, time.Minute). + On("GetAndLockIdleChannelAccount", ctx, time.Duration(100)*time.Second). Return(nil, store.ErrNoIdleChannelAccountAvailable). Times(6). On("Count", ctx). @@ -40,7 +40,7 @@ func TestChannelAccountDBSignatureClientGetAccountPublicKey(t *testing.T) { getEntries := log.DefaultLogger.StartTest(log.WarnLevel) - publicKey, err := sc.GetAccountPublicKey(ctx) + publicKey, err := sc.GetAccountPublicKey(ctx, 100) assert.ErrorIs(t, err, store.ErrNoIdleChannelAccountAvailable) assert.Empty(t, publicKey) @@ -54,7 +54,7 @@ func TestChannelAccountDBSignatureClientGetAccountPublicKey(t *testing.T) { t.Run("returns_error_when_there's_no_channel_account_configured", func(t *testing.T) { channelAccountStore. - On("GetIdleChannelAccount", ctx, time.Minute). + On("GetAndLockIdleChannelAccount", ctx, time.Minute). Return(nil, store.ErrNoIdleChannelAccountAvailable). Times(6). On("Count", ctx). @@ -79,7 +79,7 @@ func TestChannelAccountDBSignatureClientGetAccountPublicKey(t *testing.T) { t.Run("gets_an_idle_channel_account", func(t *testing.T) { channelAccountPublicKey := keypair.MustRandom().Address() channelAccountStore. - On("GetIdleChannelAccount", ctx, time.Minute). + On("GetAndLockIdleChannelAccount", ctx, time.Minute). Return(&store.ChannelAccount{PublicKey: channelAccountPublicKey}, nil). Once() defer channelAccountStore.AssertExpectations(t) diff --git a/internal/signing/env_signature_client.go b/internal/signing/env_signature_client.go index 587c73d..61c00cb 100644 --- a/internal/signing/env_signature_client.go +++ b/internal/signing/env_signature_client.go @@ -36,7 +36,7 @@ func (sc *envSignatureClient) NetworkPassphrase() string { return sc.networkPassphrase } -func (sc *envSignatureClient) GetAccountPublicKey(ctx context.Context) (string, error) { +func (sc *envSignatureClient) GetAccountPublicKey(ctx context.Context, _ ...int) (string, error) { return sc.distributionAccountFull.Address(), nil } diff --git a/internal/signing/kms_signature_client.go b/internal/signing/kms_signature_client.go index fa853dc..a297f1f 100644 --- a/internal/signing/kms_signature_client.go +++ b/internal/signing/kms_signature_client.go @@ -59,7 +59,7 @@ func NewKMSSignatureClient(publicKey string, networkPassphrase string, keypairSt }, nil } -func (sc *kmsSignatureClient) GetAccountPublicKey(ctx context.Context) (string, error) { +func (sc *kmsSignatureClient) GetAccountPublicKey(ctx context.Context, _ ...int) (string, error) { return sc.distributionAccountPublicKey, nil } diff --git a/internal/signing/mocks.go b/internal/signing/mocks.go index 827f366..e1af79e 100644 --- a/internal/signing/mocks.go +++ b/internal/signing/mocks.go @@ -18,7 +18,7 @@ func (s *SignatureClientMock) NetworkPassphrase() string { return args.String(0) } -func (s *SignatureClientMock) GetAccountPublicKey(ctx context.Context) (string, error) { +func (s *SignatureClientMock) GetAccountPublicKey(ctx context.Context, _ ...int) (string, error) { args := s.Called(ctx) return args.String(0), args.Error(1) } diff --git a/internal/signing/signature_client.go b/internal/signing/signature_client.go index 87f26a8..bc62198 100644 --- a/internal/signing/signature_client.go +++ b/internal/signing/signature_client.go @@ -15,7 +15,7 @@ var ( type SignatureClient interface { NetworkPassphrase() string - GetAccountPublicKey(ctx context.Context) (string, error) + GetAccountPublicKey(ctx context.Context, opts ...int) (string, error) SignStellarTransaction(ctx context.Context, tx *txnbuild.Transaction, stellarAccounts ...string) (*txnbuild.Transaction, error) SignStellarFeeBumpTransaction(ctx context.Context, feeBumpTx *txnbuild.FeeBumpTransaction) (*txnbuild.FeeBumpTransaction, error) } diff --git a/internal/signing/store/channel_accounts_model.go b/internal/signing/store/channel_accounts_model.go index 2b38bce..d3dd1f2 100644 --- a/internal/signing/store/channel_accounts_model.go +++ b/internal/signing/store/channel_accounts_model.go @@ -25,10 +25,11 @@ type ChannelAccountModel struct { var _ ChannelAccountStore = (*ChannelAccountModel)(nil) -func (ca *ChannelAccountModel) GetIdleChannelAccount(ctx context.Context, lockedUntil time.Duration) (*ChannelAccount, error) { +func (ca *ChannelAccountModel) GetAndLockIdleChannelAccount(ctx context.Context, lockedUntil time.Duration) (*ChannelAccount, error) { query := fmt.Sprintf(` UPDATE channel_accounts SET + locked_tx_hash = NULL, locked_at = NOW(), locked_until = NOW() + INTERVAL '%d seconds' WHERE public_key = ( @@ -36,13 +37,12 @@ func (ca *ChannelAccountModel) GetIdleChannelAccount(ctx context.Context, locked public_key FROM channel_accounts WHERE - locked_until IS NULL - OR locked_until < NOW() + (locked_tx_hash IS NULL AND (locked_until IS NULL OR locked_until < NOW())) ORDER BY random() LIMIT 1 FOR UPDATE SKIP LOCKED ) - RETURNING * + RETURNING *; `, int64(lockedUntil.Seconds())) var channelAccount ChannelAccount @@ -84,6 +84,24 @@ func (ca *ChannelAccountModel) GetAllByPublicKey(ctx context.Context, sqlExec db return channelAccounts, nil } +func (ca *ChannelAccountModel) AssignTxToChannelAccount(ctx context.Context, publicKey string, txHash string) error { + const query = `UPDATE channel_accounts SET locked_tx_hash = $1 WHERE public_key = $2` + _, err := ca.DB.ExecContext(ctx, query, txHash, publicKey) + if err != nil { + return fmt.Errorf("assigning channel account: %w", err) + } + return nil +} + +func (ca *ChannelAccountModel) UnassignTxAndUnlockChannelAccount(ctx context.Context, txHash string) error { + const query = `UPDATE channel_accounts SET locked_tx_hash = NULL, locked_at = NULL, locked_until = NULL WHERE locked_tx_hash = $1` + _, err := ca.DB.ExecContext(ctx, query, txHash) + if err != nil { + return fmt.Errorf("unlocking channel account: %w", err) + } + return nil +} + func (ca *ChannelAccountModel) BatchInsert(ctx context.Context, sqlExec db.SQLExecuter, channelAccounts []*ChannelAccount) error { if len(channelAccounts) == 0 { return nil diff --git a/internal/signing/store/channel_accounts_model_test.go b/internal/signing/store/channel_accounts_model_test.go index b64b4af..2cd7f0d 100644 --- a/internal/signing/store/channel_accounts_model_test.go +++ b/internal/signing/store/channel_accounts_model_test.go @@ -25,7 +25,7 @@ func createChannelAccountFixture(t *testing.T, ctx context.Context, dbConnection require.NoError(t, err) } -func TestChannelAccountModelGetIdleChannelAccount(t *testing.T) { +func TestChannelAccountModelGetAndLockIdleChannelAccount(t *testing.T) { dbt := dbtest.Open(t) defer dbt.Close() @@ -46,14 +46,15 @@ func TestChannelAccountModelGetIdleChannelAccount(t *testing.T) { channel_accounts SET locked_at = NOW(), - locked_until = NOW() + '5 minutes'::INTERVAL + locked_until = NOW() + '5 minutes'::INTERVAL, + locked_tx_hash = 'hash' WHERE public_key = ANY($1) ` _, err := dbConnectionPool.ExecContext(ctx, lockChannelAccountQuery, pq.Array([]string{channelAccount1.Address(), channelAccount2.Address()})) require.NoError(t, err) - ca, err := m.GetIdleChannelAccount(ctx, time.Minute) + ca, err := m.GetAndLockIdleChannelAccount(ctx, time.Minute) assert.ErrorIs(t, err, ErrNoIdleChannelAccountAvailable) assert.Nil(t, ca) }) @@ -64,18 +65,19 @@ func TestChannelAccountModelGetIdleChannelAccount(t *testing.T) { createChannelAccountFixture(t, ctx, dbConnectionPool, ChannelAccount{PublicKey: channelAccount1.Address(), EncryptedPrivateKey: channelAccount1.Seed()}, ChannelAccount{PublicKey: channelAccount2.Address(), EncryptedPrivateKey: channelAccount2.Seed()}) const lockChannelAccountQuery = ` - UPDATE - channel_accounts - SET - locked_at = NOW(), - locked_until = NOW() + '5 minutes'::INTERVAL - WHERE - public_key = $1 - ` + UPDATE + channel_accounts + SET + locked_at = NOW(), + locked_until = NOW() + '5 minutes'::INTERVAL, + locked_tx_hash = 'hash' + WHERE + public_key = $1 + ` _, err := dbConnectionPool.ExecContext(ctx, lockChannelAccountQuery, channelAccount1.Address()) require.NoError(t, err) - ca, err := m.GetIdleChannelAccount(ctx, time.Minute) + ca, err := m.GetAndLockIdleChannelAccount(ctx, time.Minute) require.NoError(t, err) assert.Equal(t, ca.PublicKey, channelAccount2.Address()) assert.Equal(t, ca.EncryptedPrivateKey, channelAccount2.Seed()) @@ -128,6 +130,56 @@ func TestChannelAccountModelGetAllByPublicKey(t *testing.T) { assert.Equal(t, channelAccount2.Address(), channelAccounts[1].PublicKey) } +func TestAssignTxToChannelAccount(t *testing.T) { + dbt := dbtest.Open(t) + defer dbt.Close() + + dbConnectionPool, err := db.OpenDBConnectionPool(dbt.DSN) + require.NoError(t, err) + defer dbConnectionPool.Close() + + ctx := context.Background() + m := NewChannelAccountModel(dbConnectionPool) + + channelAccount := keypair.MustRandom() + createChannelAccountFixture(t, ctx, dbConnectionPool, ChannelAccount{PublicKey: channelAccount.Address(), EncryptedPrivateKey: channelAccount.Seed()}) + + err = m.AssignTxToChannelAccount(ctx, channelAccount.Address(), "txhash") + assert.NoError(t, err) + channelAccountFromDB, err := m.Get(ctx, dbConnectionPool, channelAccount.Address()) + assert.NoError(t, err) + assert.Equal(t, "txhash", channelAccountFromDB.LockedTxHash.String) + +} + +func TestUnlockChannelAccountFromTx(t *testing.T) { + dbt := dbtest.Open(t) + defer dbt.Close() + + dbConnectionPool, err := db.OpenDBConnectionPool(dbt.DSN) + require.NoError(t, err) + defer dbConnectionPool.Close() + + ctx := context.Background() + m := NewChannelAccountModel(dbConnectionPool) + + channelAccount := keypair.MustRandom() + createChannelAccountFixture(t, ctx, dbConnectionPool, ChannelAccount{PublicKey: channelAccount.Address(), EncryptedPrivateKey: channelAccount.Seed()}) + err = m.AssignTxToChannelAccount(ctx, channelAccount.Address(), "txhash") + assert.NoError(t, err) + channelAccountFromDB, err := m.Get(ctx, dbConnectionPool, channelAccount.Address()) + assert.NoError(t, err) + assert.Equal(t, "txhash", channelAccountFromDB.LockedTxHash.String) + + err = m.UnassignTxAndUnlockChannelAccount(ctx, "txhash") + assert.NoError(t, err) + channelAccountFromDB, err = m.Get(ctx, dbConnectionPool, channelAccount.Address()) + assert.NoError(t, err) + assert.False(t, channelAccountFromDB.LockedTxHash.Valid) + assert.False(t, channelAccountFromDB.LockedAt.Valid) + assert.False(t, channelAccountFromDB.LockedUntil.Valid) +} + func TestChannelAccountModelBatchInsert(t *testing.T) { dbt := dbtest.Open(t) defer dbt.Close() diff --git a/internal/signing/store/mocks.go b/internal/signing/store/mocks.go index d98e879..f9fac72 100644 --- a/internal/signing/store/mocks.go +++ b/internal/signing/store/mocks.go @@ -14,7 +14,7 @@ type ChannelAccountStoreMock struct { var _ ChannelAccountStore = (*ChannelAccountStoreMock)(nil) -func (s *ChannelAccountStoreMock) GetIdleChannelAccount(ctx context.Context, lockedUntil time.Duration) (*ChannelAccount, error) { +func (s *ChannelAccountStoreMock) GetAndLockIdleChannelAccount(ctx context.Context, lockedUntil time.Duration) (*ChannelAccount, error) { args := s.Called(ctx, lockedUntil) if args.Get(0) == nil { return nil, args.Error(1) @@ -38,6 +38,16 @@ func (s *ChannelAccountStoreMock) GetAllByPublicKey(ctx context.Context, sqlExec return args.Get(0).([]*ChannelAccount), args.Error(1) } +func (s *ChannelAccountStoreMock) AssignTxToChannelAccount(ctx context.Context, publicKey string, txHash string) error { + args := s.Called(ctx, publicKey, txHash) + return args.Error(0) +} + +func (s *ChannelAccountStoreMock) UnassignTxAndUnlockChannelAccount(ctx context.Context, txHash string) error { + args := s.Called(ctx, txHash) + return args.Error(0) +} + func (s *ChannelAccountStoreMock) BatchInsert(ctx context.Context, sqlExec db.SQLExecuter, channelAccounts []*ChannelAccount) error { args := s.Called(ctx, sqlExec, channelAccounts) return args.Error(0) diff --git a/internal/signing/store/types.go b/internal/signing/store/types.go index 06eb270..968e21b 100644 --- a/internal/signing/store/types.go +++ b/internal/signing/store/types.go @@ -9,18 +9,21 @@ import ( ) type ChannelAccount struct { - PublicKey string `db:"public_key"` - EncryptedPrivateKey string `db:"encrypted_private_key"` - UpdatedAt time.Time `db:"updated_at"` - CreatedAt time.Time `db:"created_at"` - LockedAt sql.NullTime `db:"locked_at"` - LockedUntil sql.NullTime `db:"locked_until"` + PublicKey string `db:"public_key"` + EncryptedPrivateKey string `db:"encrypted_private_key"` + UpdatedAt time.Time `db:"updated_at"` + CreatedAt time.Time `db:"created_at"` + LockedAt sql.NullTime `db:"locked_at"` + LockedUntil sql.NullTime `db:"locked_until"` + LockedTxHash sql.NullString `db:"locked_tx_hash"` } type ChannelAccountStore interface { - GetIdleChannelAccount(ctx context.Context, lockedUntil time.Duration) (*ChannelAccount, error) + GetAndLockIdleChannelAccount(ctx context.Context, lockedUntil time.Duration) (*ChannelAccount, error) Get(ctx context.Context, sqlExec db.SQLExecuter, publicKey string) (*ChannelAccount, error) GetAllByPublicKey(ctx context.Context, sqlExec db.SQLExecuter, publicKeys ...string) ([]*ChannelAccount, error) + AssignTxToChannelAccount(ctx context.Context, publicKey string, txHash string) error + UnassignTxAndUnlockChannelAccount(ctx context.Context, txHash string) error BatchInsert(ctx context.Context, sqlExec db.SQLExecuter, channelAccounts []*ChannelAccount) error Count(ctx context.Context) (int64, error) } diff --git a/internal/tss/channels/webhook_channel.go b/internal/tss/channels/webhook_channel.go index 17a9839..b1f2bf8 100644 --- a/internal/tss/channels/webhook_channel.go +++ b/internal/tss/channels/webhook_channel.go @@ -4,11 +4,14 @@ import ( "bytes" "context" "encoding/json" + "fmt" "net/http" "time" "github.com/alitto/pond" "github.com/stellar/go/support/log" + "github.com/stellar/go/txnbuild" + channelAccountStore "github.com/stellar/wallet-backend/internal/signing/store" "github.com/stellar/wallet-backend/internal/tss" "github.com/stellar/wallet-backend/internal/tss/store" tssutils "github.com/stellar/wallet-backend/internal/tss/utils" @@ -16,20 +19,24 @@ import ( ) type WebhookChannelConfigs struct { - HTTPClient utils.HTTPClient Store store.Store - MaxBufferSize int - MaxWorkers int + ChannelAccountStore channelAccountStore.ChannelAccountStore + HTTPClient utils.HTTPClient MaxRetries int MinWaitBtwnRetriesMS int + NetworkPassphrase string + MaxBufferSize int + MaxWorkers int } type webhookPool struct { Pool *pond.WorkerPool Store store.Store + ChannelAccountStore channelAccountStore.ChannelAccountStore HTTPClient utils.HTTPClient MaxRetries int MinWaitBtwnRetriesMS int + NetworkPassphrase string } var WebhookChannelName = "WebhookChannel" @@ -41,9 +48,11 @@ func NewWebhookChannel(cfg WebhookChannelConfigs) *webhookPool { return &webhookPool{ Pool: pool, Store: cfg.Store, + ChannelAccountStore: cfg.ChannelAccountStore, HTTPClient: cfg.HTTPClient, MaxRetries: cfg.MaxRetries, MinWaitBtwnRetriesMS: cfg.MinWaitBtwnRetriesMS, + NetworkPassphrase: cfg.NetworkPassphrase, } } @@ -64,6 +73,10 @@ func (p *webhookPool) Receive(payload tss.Payload) { var i int sent := false ctx := context.Background() + err = p.UnlockChannelAccount(ctx, payload.TransactionXDR) + if err != nil { + log.Errorf("%s: error unlocking channel account from transaction: %e", WebhookChannelName, err) + } for i = 0; i < p.MaxRetries; i++ { httpResp, err := p.HTTPClient.Post(payload.WebhookURL, "application/json", bytes.NewBuffer(jsonData)) if err != nil { @@ -93,6 +106,31 @@ func (p *webhookPool) Receive(payload tss.Payload) { } +func (p *webhookPool) UnlockChannelAccount(ctx context.Context, txXDR string) error { + genericTx, err := txnbuild.TransactionFromXDR(txXDR) + if err != nil { + return fmt.Errorf("bad transaction xdr: %w", err) + } + var tx *txnbuild.Transaction + feeBumpTx, isFeeBumpTx := genericTx.FeeBump() + if isFeeBumpTx { + tx = feeBumpTx.InnerTransaction() + } + simpleTx, isTransaction := genericTx.Transaction() + if isTransaction { + tx = simpleTx + } + txHash, err := tx.HashHex(p.NetworkPassphrase) + if err != nil { + return fmt.Errorf("unable to hashhex transaction: %w", err) + } + err = p.ChannelAccountStore.UnassignTxAndUnlockChannelAccount(ctx, txHash) + if err != nil { + return fmt.Errorf("unable to unlock channel account associated with transaction: %w", err) + } + return nil +} + func (p *webhookPool) Stop() { p.Pool.StopAndWait() } diff --git a/internal/tss/channels/webhook_channel_test.go b/internal/tss/channels/webhook_channel_test.go index fe20572..f8eabae 100644 --- a/internal/tss/channels/webhook_channel_test.go +++ b/internal/tss/channels/webhook_channel_test.go @@ -4,13 +4,17 @@ import ( "bytes" "context" "encoding/json" + "errors" "io" "net/http" "strings" "testing" + "github.com/stellar/go/keypair" + "github.com/stellar/go/txnbuild" "github.com/stellar/wallet-backend/internal/db" "github.com/stellar/wallet-backend/internal/db/dbtest" + channelAccountStore "github.com/stellar/wallet-backend/internal/signing/store" "github.com/stellar/wallet-backend/internal/tss" "github.com/stellar/wallet-backend/internal/tss/store" tssutils "github.com/stellar/wallet-backend/internal/tss/utils" @@ -26,14 +30,17 @@ func TestWebhookHandlerServiceChannel(t *testing.T) { require.NoError(t, err) defer dbConnectionPool.Close() store, _ := store.NewStore(dbConnectionPool) + channelAccountStore := channelAccountStore.ChannelAccountStoreMock{} mockHTTPClient := utils.MockHTTPClient{} cfg := WebhookChannelConfigs{ HTTPClient: &mockHTTPClient, Store: store, + ChannelAccountStore: &channelAccountStore, MaxBufferSize: 1, MaxWorkers: 1, MaxRetries: 3, MinWaitBtwnRetriesMS: 5, + NetworkPassphrase: "networkpassphrase", } channel := NewWebhookChannel(cfg) @@ -72,3 +79,91 @@ func TestWebhookHandlerServiceChannel(t *testing.T) { assert.Equal(t, string(tss.SentStatus), tx.Status) assert.NoError(t, err) } + +func TestUnlockChannelAccount(t *testing.T) { + dbt := dbtest.Open(t) + defer dbt.Close() + dbConnectionPool, err := db.OpenDBConnectionPool(dbt.DSN) + require.NoError(t, err) + defer dbConnectionPool.Close() + store, _ := store.NewStore(dbConnectionPool) + channelAccountStore := channelAccountStore.ChannelAccountStoreMock{} + mockHTTPClient := utils.MockHTTPClient{} + cfg := WebhookChannelConfigs{ + HTTPClient: &mockHTTPClient, + Store: store, + ChannelAccountStore: &channelAccountStore, + MaxBufferSize: 1, + MaxWorkers: 1, + MaxRetries: 3, + MinWaitBtwnRetriesMS: 5, + NetworkPassphrase: "networkpassphrase", + } + channel := NewWebhookChannel(cfg) + account := keypair.MustRandom() + tx, err := txnbuild.NewTransaction(txnbuild.TransactionParams{ + SourceAccount: &txnbuild.SimpleAccount{AccountID: account.Address()}, + IncrementSequenceNum: true, + Operations: []txnbuild.Operation{ + &txnbuild.Payment{ + Destination: keypair.MustRandom().Address(), + Amount: "10", + Asset: txnbuild.NativeAsset{}, + }, + }, + BaseFee: txnbuild.MinBaseFee, + Preconditions: txnbuild.Preconditions{TimeBounds: txnbuild.NewTimeout(30)}, + }) + require.NoError(t, err) + + distributionAccount := keypair.MustRandom() + + feeBumpTx, err := txnbuild.NewFeeBumpTransaction(txnbuild.FeeBumpTransactionParams{ + Inner: tx, + FeeAccount: distributionAccount.Address(), + BaseFee: txnbuild.MinBaseFee, + }) + require.NoError(t, err) + + t.Run("simple_tx", func(t *testing.T) { + txXDR, err := tx.Base64() + assert.NoError(t, err) + txHash, err := tx.HashHex(cfg.NetworkPassphrase) + assert.NoError(t, err) + channelAccountStore. + On("UnassignTxAndUnlockChannelAccount", context.Background(), txHash). + Return(nil). + Once() + + err = channel.UnlockChannelAccount(context.Background(), txXDR) + assert.NoError(t, err) + }) + + t.Run("feebump_tx", func(t *testing.T) { + txXDR, err := feeBumpTx.Base64() + assert.NoError(t, err) + txHash, err := tx.HashHex(cfg.NetworkPassphrase) + assert.NoError(t, err) + channelAccountStore. + On("UnassignTxAndUnlockChannelAccount", context.Background(), txHash). + Return(nil). + Once() + + err = channel.UnlockChannelAccount(context.Background(), txXDR) + assert.NoError(t, err) + }) + + t.Run("unlock_channel_account_from_tx_returns_error", func(t *testing.T) { + txXDR, err := tx.Base64() + assert.NoError(t, err) + txHash, err := tx.HashHex(cfg.NetworkPassphrase) + assert.NoError(t, err) + channelAccountStore. + On("UnassignTxAndUnlockChannelAccount", context.Background(), txHash). + Return(errors.New("unabe to unlock channel account")). + Once() + + err = channel.UnlockChannelAccount(context.Background(), txXDR) + assert.Equal(t, "unable to unlock channel account associated with transaction: unabe to unlock channel account", err.Error()) + }) +} diff --git a/internal/tss/services/transaction_service.go b/internal/tss/services/transaction_service.go index 179584b..ebccabf 100644 --- a/internal/tss/services/transaction_service.go +++ b/internal/tss/services/transaction_service.go @@ -6,8 +6,10 @@ import ( "github.com/stellar/go/txnbuild" + "github.com/stellar/wallet-backend/internal/db" "github.com/stellar/wallet-backend/internal/services" "github.com/stellar/wallet-backend/internal/signing" + "github.com/stellar/wallet-backend/internal/signing/store" ) type TransactionService interface { @@ -17,8 +19,10 @@ type TransactionService interface { } type transactionService struct { + DB db.ConnectionPool DistributionAccountSignatureClient signing.SignatureClient ChannelAccountSignatureClient signing.SignatureClient + ChannelAccountStore store.ChannelAccountStore RPCService services.RPCService BaseFee int64 } @@ -26,13 +30,18 @@ type transactionService struct { var _ TransactionService = (*transactionService)(nil) type TransactionServiceOptions struct { + DB db.ConnectionPool DistributionAccountSignatureClient signing.SignatureClient ChannelAccountSignatureClient signing.SignatureClient + ChannelAccountStore store.ChannelAccountStore RPCService services.RPCService BaseFee int64 } func (o *TransactionServiceOptions) ValidateOptions() error { + if o.DB == nil { + return fmt.Errorf("DB cannot be nil") + } if o.DistributionAccountSignatureClient == nil { return fmt.Errorf("distribution account signature client cannot be nil") } @@ -45,6 +54,10 @@ func (o *TransactionServiceOptions) ValidateOptions() error { return fmt.Errorf("channel account signature client cannot be nil") } + if o.ChannelAccountStore == nil { + return fmt.Errorf("channel account store cannot be nil") + } + if o.BaseFee < int64(txnbuild.MinBaseFee) { return fmt.Errorf("base fee is lower than the minimum network fee") } @@ -57,8 +70,10 @@ func NewTransactionService(opts TransactionServiceOptions) (*transactionService, return nil, err } return &transactionService{ + DB: opts.DB, DistributionAccountSignatureClient: opts.DistributionAccountSignatureClient, ChannelAccountSignatureClient: opts.ChannelAccountSignatureClient, + ChannelAccountStore: opts.ChannelAccountStore, RPCService: opts.RPCService, BaseFee: opts.BaseFee, }, nil @@ -69,30 +84,47 @@ func (t *transactionService) NetworkPassphrase() string { } func (t *transactionService) BuildAndSignTransactionWithChannelAccount(ctx context.Context, operations []txnbuild.Operation, timeoutInSecs int64) (*txnbuild.Transaction, error) { - channelAccountPublicKey, err := t.ChannelAccountSignatureClient.GetAccountPublicKey(ctx) - if err != nil { - return nil, fmt.Errorf("getting channel account public key: %w", err) - } - channelAccountSeq, err := t.RPCService.GetAccountLedgerSequence(channelAccountPublicKey) - if err != nil { - return nil, fmt.Errorf("getting ledger sequence for channel account public key: %s: %w", channelAccountPublicKey, err) - } - tx, err := txnbuild.NewTransaction( - txnbuild.TransactionParams{ - SourceAccount: &txnbuild.SimpleAccount{ - AccountID: channelAccountPublicKey, - Sequence: channelAccountSeq, - }, - Operations: operations, - BaseFee: int64(t.BaseFee), - Preconditions: txnbuild.Preconditions{ - TimeBounds: txnbuild.NewTimeout(timeoutInSecs), + var tx *txnbuild.Transaction + var channelAccountPublicKey string + err := db.RunInTransaction(ctx, t.DB, nil, func(dbTx db.Transaction) error { + var err error + channelAccountPublicKey, err = t.ChannelAccountSignatureClient.GetAccountPublicKey(ctx, int(timeoutInSecs)) + if err != nil { + return fmt.Errorf("getting channel account public key: %w", err) + } + channelAccountSeq, err := t.RPCService.GetAccountLedgerSequence(channelAccountPublicKey) + if err != nil { + return fmt.Errorf("getting ledger sequence for channel account public key: %s: %w", channelAccountPublicKey, err) + } + tx, err = txnbuild.NewTransaction( + txnbuild.TransactionParams{ + SourceAccount: &txnbuild.SimpleAccount{ + AccountID: channelAccountPublicKey, + Sequence: channelAccountSeq, + }, + Operations: operations, + BaseFee: int64(t.BaseFee), + Preconditions: txnbuild.Preconditions{ + TimeBounds: txnbuild.NewTimeout(timeoutInSecs), + }, + IncrementSequenceNum: true, }, - IncrementSequenceNum: true, - }, - ) + ) + if err != nil { + return fmt.Errorf("building transaction: %w", err) + } + txHash, err := tx.HashHex(t.ChannelAccountSignatureClient.NetworkPassphrase()) + if err != nil { + return fmt.Errorf("unable to hashhex transaction: %w", err) + } + err = t.ChannelAccountStore.AssignTxToChannelAccount(ctx, channelAccountPublicKey, txHash) + if err != nil { + return fmt.Errorf("assigning channel account to tx: %w", err) + } + return nil + }) if err != nil { - return nil, fmt.Errorf("building transaction: %w", err) + return nil, err } tx, err = t.ChannelAccountSignatureClient.SignStellarTransaction(ctx, tx, channelAccountPublicKey) if err != nil { diff --git a/internal/tss/services/transaction_service_test.go b/internal/tss/services/transaction_service_test.go index d037d86..18313df 100644 --- a/internal/tss/services/transaction_service_test.go +++ b/internal/tss/services/transaction_service_test.go @@ -10,17 +10,40 @@ import ( "github.com/stellar/go/txnbuild" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/mock" + "github.com/stretchr/testify/require" + "github.com/stellar/wallet-backend/internal/db" + "github.com/stellar/wallet-backend/internal/db/dbtest" "github.com/stellar/wallet-backend/internal/services" "github.com/stellar/wallet-backend/internal/signing" + "github.com/stellar/wallet-backend/internal/signing/store" "github.com/stellar/wallet-backend/internal/tss/utils" ) func TestValidateOptions(t *testing.T) { + dbt := dbtest.Open(t) + defer dbt.Close() + dbConnectionPool, err := db.OpenDBConnectionPool(dbt.DSN) + require.NoError(t, err) + defer dbConnectionPool.Close() + t.Run("return_error_when_db_nil", func(t *testing.T) { + opts := TransactionServiceOptions{ + DistributionAccountSignatureClient: nil, + ChannelAccountSignatureClient: &signing.SignatureClientMock{}, + ChannelAccountStore: &store.ChannelAccountStoreMock{}, + RPCService: &services.RPCServiceMock{}, + BaseFee: 114, + } + err := opts.ValidateOptions() + assert.Equal(t, "DB cannot be nil", err.Error()) + + }) t.Run("return_error_when_distribution_signature_client_nil", func(t *testing.T) { opts := TransactionServiceOptions{ + DB: dbConnectionPool, DistributionAccountSignatureClient: nil, ChannelAccountSignatureClient: &signing.SignatureClientMock{}, + ChannelAccountStore: &store.ChannelAccountStoreMock{}, RPCService: &services.RPCServiceMock{}, BaseFee: 114, } @@ -31,8 +54,10 @@ func TestValidateOptions(t *testing.T) { t.Run("return_error_when_channel_signature_client_nil", func(t *testing.T) { opts := TransactionServiceOptions{ + DB: dbConnectionPool, DistributionAccountSignatureClient: &signing.SignatureClientMock{}, ChannelAccountSignatureClient: nil, + ChannelAccountStore: &store.ChannelAccountStoreMock{}, RPCService: &services.RPCServiceMock{}, BaseFee: 114, } @@ -40,10 +65,25 @@ func TestValidateOptions(t *testing.T) { assert.Equal(t, "channel account signature client cannot be nil", err.Error()) }) + t.Run("return_error_when_channel_account_store_nil", func(t *testing.T) { + opts := TransactionServiceOptions{ + DB: dbConnectionPool, + DistributionAccountSignatureClient: &signing.SignatureClientMock{}, + ChannelAccountSignatureClient: &signing.SignatureClientMock{}, + ChannelAccountStore: nil, + RPCService: &services.RPCServiceMock{}, + BaseFee: 114, + } + err := opts.ValidateOptions() + assert.Equal(t, "channel account store cannot be nil", err.Error()) + }) + t.Run("return_error_when_rpc_client_nil", func(t *testing.T) { opts := TransactionServiceOptions{ + DB: dbConnectionPool, DistributionAccountSignatureClient: &signing.SignatureClientMock{}, ChannelAccountSignatureClient: &signing.SignatureClientMock{}, + ChannelAccountStore: &store.ChannelAccountStoreMock{}, RPCService: nil, BaseFee: 114, } @@ -53,8 +93,10 @@ func TestValidateOptions(t *testing.T) { t.Run("return_error_when_base_fee_too_low", func(t *testing.T) { opts := TransactionServiceOptions{ + DB: dbConnectionPool, DistributionAccountSignatureClient: &signing.SignatureClientMock{}, ChannelAccountSignatureClient: &signing.SignatureClientMock{}, + ChannelAccountStore: &store.ChannelAccountStoreMock{}, RPCService: &services.RPCServiceMock{}, BaseFee: txnbuild.MinBaseFee - 10, } @@ -64,17 +106,24 @@ func TestValidateOptions(t *testing.T) { } func TestBuildAndSignTransactionWithChannelAccount(t *testing.T) { + dbt := dbtest.Open(t) + defer dbt.Close() + dbConnectionPool, err := db.OpenDBConnectionPool(dbt.DSN) + require.NoError(t, err) + defer dbConnectionPool.Close() distributionAccountSignatureClient := signing.SignatureClientMock{} channelAccountSignatureClient := signing.SignatureClientMock{} - defer channelAccountSignatureClient.AssertExpectations(t) + channelAccountStore := store.ChannelAccountStoreMock{} mockRPCService := &services.RPCServiceMock{} txService, _ := NewTransactionService(TransactionServiceOptions{ + DB: dbConnectionPool, DistributionAccountSignatureClient: &distributionAccountSignatureClient, ChannelAccountSignatureClient: &channelAccountSignatureClient, + ChannelAccountStore: &channelAccountStore, RPCService: mockRPCService, BaseFee: 114, }) - + atomicTxErrorPrefix := "running atomic function in RunInTransactionWithResult: " t.Run("channel_account_signature_client_get_account_public_key_err", func(t *testing.T) { channelAccountSignatureClient. On("GetAccountPublicKey", context.Background()). @@ -85,7 +134,7 @@ func TestBuildAndSignTransactionWithChannelAccount(t *testing.T) { channelAccountSignatureClient.AssertExpectations(t) assert.Empty(t, tx) - assert.Equal(t, "getting channel account public key: channel accounts unavailable", err.Error()) + assert.Equal(t, atomicTxErrorPrefix+"getting channel account public key: channel accounts unavailable", err.Error()) }) t.Run("rpc_client_get_account_seq_err", func(t *testing.T) { @@ -106,7 +155,7 @@ func TestBuildAndSignTransactionWithChannelAccount(t *testing.T) { channelAccountSignatureClient.AssertExpectations(t) assert.Empty(t, tx) expectedErr := fmt.Errorf("getting ledger sequence for channel account public key: %s: rpc service down", channelAccount.Address()) - assert.Equal(t, expectedErr.Error(), err.Error()) + assert.Equal(t, atomicTxErrorPrefix+expectedErr.Error(), err.Error()) }) t.Run("build_tx_fails", func(t *testing.T) { @@ -126,20 +175,63 @@ func TestBuildAndSignTransactionWithChannelAccount(t *testing.T) { channelAccountSignatureClient.AssertExpectations(t) assert.Empty(t, tx) - assert.Equal(t, "building transaction: transaction has no operations", err.Error()) + assert.Equal(t, atomicTxErrorPrefix+"building transaction: transaction has no operations", err.Error()) }) + t.Run("lock_channel_account_to_tx_err", func(t *testing.T) { + channelAccount := keypair.MustRandom() + channelAccountSignatureClient. + On("GetAccountPublicKey", context.Background()). + Return(channelAccount.Address(), nil). + Once(). + On("NetworkPassphrase"). + Return("networkpassphrase"). + Once() + + channelAccountStore. + On("AssignTxToChannelAccount", context.Background(), channelAccount.Address(), mock.AnythingOfType("string")). + Return(errors.New("unable to assign channel account to tx")). + Once() + + mockRPCService. + On("GetAccountLedgerSequence", channelAccount.Address()). + Return(int64(1), nil). + Once() + defer mockRPCService.AssertExpectations(t) + + payment := txnbuild.Payment{ + Destination: keypair.MustRandom().Address(), + Amount: "10", + Asset: txnbuild.NativeAsset{}, + SourceAccount: keypair.MustRandom().Address(), + } + tx, err := txService.BuildAndSignTransactionWithChannelAccount(context.Background(), []txnbuild.Operation{&payment}, 30) + + channelAccountSignatureClient.AssertExpectations(t) + channelAccountStore.AssertExpectations(t) + assert.Empty(t, tx) + assert.Equal(t, atomicTxErrorPrefix+"assigning channel account to tx: unable to assign channel account to tx", err.Error()) + }) + t.Run("sign_stellar_transaction_w_channel_account_err", func(t *testing.T) { channelAccount := keypair.MustRandom() channelAccountSignatureClient. On("GetAccountPublicKey", context.Background()). Return(channelAccount.Address(), nil). Once(). + On("NetworkPassphrase"). + Return("networkpassphrase"). + Once(). On("SignStellarTransaction", context.Background(), mock.AnythingOfType("*txnbuild.Transaction"), []string{channelAccount.Address()}). Return(nil, errors.New("unable to sign")). Once() + channelAccountStore. + On("AssignTxToChannelAccount", context.Background(), channelAccount.Address(), mock.AnythingOfType("string")). + Return(nil). + Once() + mockRPCService. On("GetAccountLedgerSequence", channelAccount.Address()). Return(int64(1), nil). @@ -155,6 +247,7 @@ func TestBuildAndSignTransactionWithChannelAccount(t *testing.T) { tx, err := txService.BuildAndSignTransactionWithChannelAccount(context.Background(), []txnbuild.Operation{&payment}, 30) channelAccountSignatureClient.AssertExpectations(t) + channelAccountStore.AssertExpectations(t) assert.Empty(t, tx) assert.Equal(t, "signing transaction with channel account: unable to sign", err.Error()) }) @@ -166,10 +259,17 @@ func TestBuildAndSignTransactionWithChannelAccount(t *testing.T) { On("GetAccountPublicKey", context.Background()). Return(channelAccount.Address(), nil). Once(). + On("NetworkPassphrase"). + Return("networkpassphrase"). On("SignStellarTransaction", context.Background(), mock.AnythingOfType("*txnbuild.Transaction"), []string{channelAccount.Address()}). Return(signedTx, nil). Once() + channelAccountStore. + On("AssignTxToChannelAccount", context.Background(), channelAccount.Address(), mock.AnythingOfType("string")). + Return(nil). + Once() + mockRPCService. On("GetAccountLedgerSequence", channelAccount.Address()). Return(int64(1), nil). @@ -185,18 +285,27 @@ func TestBuildAndSignTransactionWithChannelAccount(t *testing.T) { tx, err := txService.BuildAndSignTransactionWithChannelAccount(context.Background(), []txnbuild.Operation{&payment}, 30) channelAccountSignatureClient.AssertExpectations(t) + channelAccountStore.AssertExpectations(t) assert.Equal(t, signedTx, tx) assert.NoError(t, err) }) } func TestBuildFeeBumpTransaction(t *testing.T) { + dbt := dbtest.Open(t) + defer dbt.Close() + dbConnectionPool, err := db.OpenDBConnectionPool(dbt.DSN) + require.NoError(t, err) + defer dbConnectionPool.Close() distributionAccountSignatureClient := signing.SignatureClientMock{} channelAccountSignatureClient := signing.SignatureClientMock{} + channelAccountStore := store.ChannelAccountStoreMock{} mockRPCService := &services.RPCServiceMock{} txService, _ := NewTransactionService(TransactionServiceOptions{ + DB: dbConnectionPool, DistributionAccountSignatureClient: &distributionAccountSignatureClient, ChannelAccountSignatureClient: &channelAccountSignatureClient, + ChannelAccountStore: &channelAccountStore, RPCService: mockRPCService, BaseFee: 114, })