Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Locking channel accounts to transactions #116

Merged
merged 5 commits into from
Feb 11, 2025
Merged
Changes from 4 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
@@ -0,0 +1,9 @@
-- +migrate Up

ALTER TABLE channel_accounts
ADD COLUMN locked_tx_hash VARCHAR(64);
CREATE INDEX idx_locked_tx_hash ON channel_accounts(locked_tx_hash);
-- +migrate Down
DROP INDEX IF EXISTS idx_locked_tx_hash;
ALTER TABLE channel_accounts
DROP COLUMN locked_tx_hash;
6 changes: 6 additions & 0 deletions internal/serve/serve.go
Original file line number Diff line number Diff line change
@@ -156,6 +156,8 @@ func initHandlerDeps(cfg Configs) (handlerDeps, error) {
}
go rpcService.TrackRPCServiceHealth(context.Background())

channelAccountStore := store.NewChannelAccountModel(dbConnectionPool)

accountService, err := services.NewAccountService(models)
if err != nil {
return handlerDeps{}, fmt.Errorf("instantiating account service: %w", err)
@@ -181,8 +183,10 @@ func initHandlerDeps(cfg Configs) (handlerDeps, error) {

// TSS setup
tssTxService, err := tssservices.NewTransactionService(tssservices.TransactionServiceOptions{
DB: dbConnectionPool,
DistributionAccountSignatureClient: cfg.DistributionAccountSignatureClient,
ChannelAccountSignatureClient: cfg.ChannelAccountSignatureClient,
ChannelAccountStore: channelAccountStore,
RPCService: rpcService,
BaseFee: int64(cfg.BaseFee),
})
@@ -228,10 +232,12 @@ func initHandlerDeps(cfg Configs) (handlerDeps, error) {
webhookChannel := tsschannel.NewWebhookChannel(tsschannel.WebhookChannelConfigs{
HTTPClient: &httpClient,
Store: tssStore,
ChannelAccountStore: channelAccountStore,
MaxBufferSize: cfg.WebhookHandlerServiceChannelMaxBufferSize,
MaxWorkers: cfg.WebhookHandlerServiceChannelMaxWorkers,
MaxRetries: cfg.WebhookHandlerServiceChannelMaxRetries,
MinWaitBtwnRetriesMS: cfg.WebhookHandlerServiceChannelMinWaitBtwnRetriesMS,
NetworkPassphrase: cfg.NetworkPassphrase,
})

router := tssrouter.NewRouter(tssrouter.RouterConfigs{
11 changes: 9 additions & 2 deletions internal/signing/channel_account_db_signature_client.go
Original file line number Diff line number Diff line change
@@ -43,9 +43,16 @@ func (sc *channelAccountDBSignatureClient) NetworkPassphrase() string {
return sc.networkPassphrase
}

func (sc *channelAccountDBSignatureClient) GetAccountPublicKey(ctx context.Context) (string, error) {
func (sc *channelAccountDBSignatureClient) GetAccountPublicKey(ctx context.Context, opts ...int) (string, error) {
var lockedUntil time.Duration
if len(opts) > 0 {
lockedUntil = time.Duration(opts[0]) * time.Second
} else {
lockedUntil = time.Minute
}
for range store.ChannelAccountWaitTime {
channelAccount, err := sc.channelAccountStore.GetIdleChannelAccount(ctx, time.Minute)
// check to see if the variadic parameter for time exists and if so, use it here
channelAccount, err := sc.channelAccountStore.GetAndLockIdleChannelAccount(ctx, lockedUntil)
if err != nil {
if errors.Is(err, store.ErrNoIdleChannelAccountAvailable) {
log.Ctx(ctx).Warn("All channel accounts are in use. Retry in 1 second.")
8 changes: 4 additions & 4 deletions internal/signing/channel_account_db_signature_client_test.go
Original file line number Diff line number Diff line change
@@ -30,7 +30,7 @@ func TestChannelAccountDBSignatureClientGetAccountPublicKey(t *testing.T) {

t.Run("returns_error_when_couldn't_get_an_idle_channel_account", func(t *testing.T) {
channelAccountStore.
On("GetIdleChannelAccount", ctx, time.Minute).
On("GetAndLockIdleChannelAccount", ctx, time.Duration(100)*time.Second).
Return(nil, store.ErrNoIdleChannelAccountAvailable).
Times(6).
On("Count", ctx).
@@ -40,7 +40,7 @@ func TestChannelAccountDBSignatureClientGetAccountPublicKey(t *testing.T) {

getEntries := log.DefaultLogger.StartTest(log.WarnLevel)

publicKey, err := sc.GetAccountPublicKey(ctx)
publicKey, err := sc.GetAccountPublicKey(ctx, 100)
assert.ErrorIs(t, err, store.ErrNoIdleChannelAccountAvailable)
assert.Empty(t, publicKey)

@@ -54,7 +54,7 @@ func TestChannelAccountDBSignatureClientGetAccountPublicKey(t *testing.T) {

t.Run("returns_error_when_there's_no_channel_account_configured", func(t *testing.T) {
channelAccountStore.
On("GetIdleChannelAccount", ctx, time.Minute).
On("GetAndLockIdleChannelAccount", ctx, time.Minute).
Return(nil, store.ErrNoIdleChannelAccountAvailable).
Times(6).
On("Count", ctx).
@@ -79,7 +79,7 @@ func TestChannelAccountDBSignatureClientGetAccountPublicKey(t *testing.T) {
t.Run("gets_an_idle_channel_account", func(t *testing.T) {
channelAccountPublicKey := keypair.MustRandom().Address()
channelAccountStore.
On("GetIdleChannelAccount", ctx, time.Minute).
On("GetAndLockIdleChannelAccount", ctx, time.Minute).
Return(&store.ChannelAccount{PublicKey: channelAccountPublicKey}, nil).
Once()
defer channelAccountStore.AssertExpectations(t)
2 changes: 1 addition & 1 deletion internal/signing/env_signature_client.go
Original file line number Diff line number Diff line change
@@ -36,7 +36,7 @@ func (sc *envSignatureClient) NetworkPassphrase() string {
return sc.networkPassphrase
}

func (sc *envSignatureClient) GetAccountPublicKey(ctx context.Context) (string, error) {
func (sc *envSignatureClient) GetAccountPublicKey(ctx context.Context, _ ...int) (string, error) {
return sc.distributionAccountFull.Address(), nil
}

2 changes: 1 addition & 1 deletion internal/signing/kms_signature_client.go
Original file line number Diff line number Diff line change
@@ -59,7 +59,7 @@ func NewKMSSignatureClient(publicKey string, networkPassphrase string, keypairSt
}, nil
}

func (sc *kmsSignatureClient) GetAccountPublicKey(ctx context.Context) (string, error) {
func (sc *kmsSignatureClient) GetAccountPublicKey(ctx context.Context, _ ...int) (string, error) {
return sc.distributionAccountPublicKey, nil
}

2 changes: 1 addition & 1 deletion internal/signing/mocks.go
Original file line number Diff line number Diff line change
@@ -18,7 +18,7 @@ func (s *SignatureClientMock) NetworkPassphrase() string {
return args.String(0)
}

func (s *SignatureClientMock) GetAccountPublicKey(ctx context.Context) (string, error) {
func (s *SignatureClientMock) GetAccountPublicKey(ctx context.Context, opts ...int) (string, error) {
args := s.Called(ctx)
return args.String(0), args.Error(1)
}
2 changes: 1 addition & 1 deletion internal/signing/signature_client.go
Original file line number Diff line number Diff line change
@@ -15,7 +15,7 @@ var (

type SignatureClient interface {
NetworkPassphrase() string
GetAccountPublicKey(ctx context.Context) (string, error)
GetAccountPublicKey(ctx context.Context, opts ...int) (string, error)
SignStellarTransaction(ctx context.Context, tx *txnbuild.Transaction, stellarAccounts ...string) (*txnbuild.Transaction, error)
SignStellarFeeBumpTransaction(ctx context.Context, feeBumpTx *txnbuild.FeeBumpTransaction) (*txnbuild.FeeBumpTransaction, error)
}
26 changes: 22 additions & 4 deletions internal/signing/store/channel_accounts_model.go
Original file line number Diff line number Diff line change
@@ -25,24 +25,24 @@ type ChannelAccountModel struct {

var _ ChannelAccountStore = (*ChannelAccountModel)(nil)

func (ca *ChannelAccountModel) GetIdleChannelAccount(ctx context.Context, lockedUntil time.Duration) (*ChannelAccount, error) {
func (ca *ChannelAccountModel) GetAndLockIdleChannelAccount(ctx context.Context, lockedUntil time.Duration) (*ChannelAccount, error) {
query := fmt.Sprintf(`
UPDATE channel_accounts
SET
locked_tx_hash = NULL,
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think this is unnecessary since UnlockChannelAccountFromTx() sets this to null

Copy link
Contributor Author

@gouthamp-stellar gouthamp-stellar Jan 28, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

That is true in theory, but in case the channel accounts do not get unlocked (they fail to reach the webhook for example), this will act as a fail safe, which is why I added it in there

locked_at = NOW(),
locked_until = NOW() + INTERVAL '%d seconds'
WHERE public_key = (
SELECT
public_key
FROM channel_accounts
WHERE
locked_until IS NULL
OR locked_until < NOW()
(locked_tx_hash IS NULL AND (locked_until IS NULL OR locked_until < NOW()))
ORDER BY random()
LIMIT 1
FOR UPDATE SKIP LOCKED
)
RETURNING *
RETURNING *;
`, int64(lockedUntil.Seconds()))

var channelAccount ChannelAccount
@@ -84,6 +84,24 @@ func (ca *ChannelAccountModel) GetAllByPublicKey(ctx context.Context, sqlExec db
return channelAccounts, nil
}

func (ca *ChannelAccountModel) AssignTxToChannelAccount(ctx context.Context, publicKey string, txHash string) error {
const query = `UPDATE channel_accounts SET locked_tx_hash = $1 WHERE public_key = $2`
_, err := ca.DB.ExecContext(ctx, query, txHash, publicKey)
if err != nil {
return fmt.Errorf("assigning channel account: %w", err)
}
return nil
}

func (ca *ChannelAccountModel) UnassignTxAndUnlockChannelAccount(ctx context.Context, txHash string) error {
const query = `UPDATE channel_accounts SET locked_tx_hash = NULL, locked_at = NULL, locked_until = NULL WHERE locked_tx_hash = $1`
_, err := ca.DB.ExecContext(ctx, query, txHash)
if err != nil {
return fmt.Errorf("unlocking channel account: %w", err)
}
return nil
}

func (ca *ChannelAccountModel) BatchInsert(ctx context.Context, sqlExec db.SQLExecuter, channelAccounts []*ChannelAccount) error {
if len(channelAccounts) == 0 {
return nil
76 changes: 64 additions & 12 deletions internal/signing/store/channel_accounts_model_test.go
Original file line number Diff line number Diff line change
@@ -25,7 +25,7 @@ func createChannelAccountFixture(t *testing.T, ctx context.Context, dbConnection
require.NoError(t, err)
}

func TestChannelAccountModelGetIdleChannelAccount(t *testing.T) {
func TestChannelAccountModelGetAndLockIdleChannelAccount(t *testing.T) {
dbt := dbtest.Open(t)
defer dbt.Close()

@@ -46,14 +46,15 @@ func TestChannelAccountModelGetIdleChannelAccount(t *testing.T) {
channel_accounts
SET
locked_at = NOW(),
locked_until = NOW() + '5 minutes'::INTERVAL
locked_until = NOW() + '5 minutes'::INTERVAL,
locked_tx_hash = 'hash'
WHERE
public_key = ANY($1)
`
_, err := dbConnectionPool.ExecContext(ctx, lockChannelAccountQuery, pq.Array([]string{channelAccount1.Address(), channelAccount2.Address()}))
require.NoError(t, err)

ca, err := m.GetIdleChannelAccount(ctx, time.Minute)
ca, err := m.GetAndLockIdleChannelAccount(ctx, time.Minute)
assert.ErrorIs(t, err, ErrNoIdleChannelAccountAvailable)
assert.Nil(t, ca)
})
@@ -64,18 +65,19 @@ func TestChannelAccountModelGetIdleChannelAccount(t *testing.T) {
createChannelAccountFixture(t, ctx, dbConnectionPool, ChannelAccount{PublicKey: channelAccount1.Address(), EncryptedPrivateKey: channelAccount1.Seed()}, ChannelAccount{PublicKey: channelAccount2.Address(), EncryptedPrivateKey: channelAccount2.Seed()})

const lockChannelAccountQuery = `
UPDATE
channel_accounts
SET
locked_at = NOW(),
locked_until = NOW() + '5 minutes'::INTERVAL
WHERE
public_key = $1
`
UPDATE
channel_accounts
SET
locked_at = NOW(),
locked_until = NOW() + '5 minutes'::INTERVAL,
locked_tx_hash = 'hash'
WHERE
public_key = $1
`
_, err := dbConnectionPool.ExecContext(ctx, lockChannelAccountQuery, channelAccount1.Address())
require.NoError(t, err)

ca, err := m.GetIdleChannelAccount(ctx, time.Minute)
ca, err := m.GetAndLockIdleChannelAccount(ctx, time.Minute)
require.NoError(t, err)
assert.Equal(t, ca.PublicKey, channelAccount2.Address())
assert.Equal(t, ca.EncryptedPrivateKey, channelAccount2.Seed())
@@ -128,6 +130,56 @@ func TestChannelAccountModelGetAllByPublicKey(t *testing.T) {
assert.Equal(t, channelAccount2.Address(), channelAccounts[1].PublicKey)
}

func TestAssignTxToChannelAccount(t *testing.T) {
dbt := dbtest.Open(t)
defer dbt.Close()

dbConnectionPool, err := db.OpenDBConnectionPool(dbt.DSN)
require.NoError(t, err)
defer dbConnectionPool.Close()

ctx := context.Background()
m := NewChannelAccountModel(dbConnectionPool)

channelAccount := keypair.MustRandom()
createChannelAccountFixture(t, ctx, dbConnectionPool, ChannelAccount{PublicKey: channelAccount.Address(), EncryptedPrivateKey: channelAccount.Seed()})

err = m.AssignTxToChannelAccount(ctx, channelAccount.Address(), "txhash")
assert.NoError(t, err)
channelAccountFromDB, err := m.Get(ctx, dbConnectionPool, channelAccount.Address())
assert.NoError(t, err)
assert.Equal(t, "txhash", channelAccountFromDB.LockedTxHash.String)

}

func TestUnlockChannelAccountFromTx(t *testing.T) {
dbt := dbtest.Open(t)
defer dbt.Close()

dbConnectionPool, err := db.OpenDBConnectionPool(dbt.DSN)
require.NoError(t, err)
defer dbConnectionPool.Close()

ctx := context.Background()
m := NewChannelAccountModel(dbConnectionPool)

channelAccount := keypair.MustRandom()
createChannelAccountFixture(t, ctx, dbConnectionPool, ChannelAccount{PublicKey: channelAccount.Address(), EncryptedPrivateKey: channelAccount.Seed()})
err = m.AssignTxToChannelAccount(ctx, channelAccount.Address(), "txhash")
assert.NoError(t, err)
channelAccountFromDB, err := m.Get(ctx, dbConnectionPool, channelAccount.Address())
assert.NoError(t, err)
assert.Equal(t, "txhash", channelAccountFromDB.LockedTxHash.String)

err = m.UnassignTxAndUnlockChannelAccount(ctx, "txhash")
assert.NoError(t, err)
channelAccountFromDB, err = m.Get(ctx, dbConnectionPool, channelAccount.Address())
assert.NoError(t, err)
assert.False(t, channelAccountFromDB.LockedTxHash.Valid)
assert.False(t, channelAccountFromDB.LockedAt.Valid)
assert.False(t, channelAccountFromDB.LockedUntil.Valid)
}

func TestChannelAccountModelBatchInsert(t *testing.T) {
dbt := dbtest.Open(t)
defer dbt.Close()
12 changes: 11 additions & 1 deletion internal/signing/store/mocks.go
Original file line number Diff line number Diff line change
@@ -14,7 +14,7 @@ type ChannelAccountStoreMock struct {

var _ ChannelAccountStore = (*ChannelAccountStoreMock)(nil)

func (s *ChannelAccountStoreMock) GetIdleChannelAccount(ctx context.Context, lockedUntil time.Duration) (*ChannelAccount, error) {
func (s *ChannelAccountStoreMock) GetAndLockIdleChannelAccount(ctx context.Context, lockedUntil time.Duration) (*ChannelAccount, error) {
args := s.Called(ctx, lockedUntil)
if args.Get(0) == nil {
return nil, args.Error(1)
@@ -38,6 +38,16 @@ func (s *ChannelAccountStoreMock) GetAllByPublicKey(ctx context.Context, sqlExec
return args.Get(0).([]*ChannelAccount), args.Error(1)
}

func (s *ChannelAccountStoreMock) AssignTxToChannelAccount(ctx context.Context, publicKey string, txHash string) error {
args := s.Called(ctx, publicKey, txHash)
return args.Error(0)
}

func (s *ChannelAccountStoreMock) UnassignTxAndUnlockChannelAccount(ctx context.Context, txHash string) error {
args := s.Called(ctx, txHash)
return args.Error(0)
}

func (s *ChannelAccountStoreMock) BatchInsert(ctx context.Context, sqlExec db.SQLExecuter, channelAccounts []*ChannelAccount) error {
args := s.Called(ctx, sqlExec, channelAccounts)
return args.Error(0)
17 changes: 10 additions & 7 deletions internal/signing/store/types.go
Original file line number Diff line number Diff line change
@@ -9,18 +9,21 @@ import (
)

type ChannelAccount struct {
PublicKey string `db:"public_key"`
EncryptedPrivateKey string `db:"encrypted_private_key"`
UpdatedAt time.Time `db:"updated_at"`
CreatedAt time.Time `db:"created_at"`
LockedAt sql.NullTime `db:"locked_at"`
LockedUntil sql.NullTime `db:"locked_until"`
PublicKey string `db:"public_key"`
EncryptedPrivateKey string `db:"encrypted_private_key"`
UpdatedAt time.Time `db:"updated_at"`
CreatedAt time.Time `db:"created_at"`
LockedAt sql.NullTime `db:"locked_at"`
LockedUntil sql.NullTime `db:"locked_until"`
LockedTxHash sql.NullString `db:"locked_tx_hash"`
}

type ChannelAccountStore interface {
GetIdleChannelAccount(ctx context.Context, lockedUntil time.Duration) (*ChannelAccount, error)
GetAndLockIdleChannelAccount(ctx context.Context, lockedUntil time.Duration) (*ChannelAccount, error)
Get(ctx context.Context, sqlExec db.SQLExecuter, publicKey string) (*ChannelAccount, error)
GetAllByPublicKey(ctx context.Context, sqlExec db.SQLExecuter, publicKeys ...string) ([]*ChannelAccount, error)
AssignTxToChannelAccount(ctx context.Context, publicKey string, txHash string) error
UnassignTxAndUnlockChannelAccount(ctx context.Context, txHash string) error
BatchInsert(ctx context.Context, sqlExec db.SQLExecuter, channelAccounts []*ChannelAccount) error
Count(ctx context.Context) (int64, error)
}
Loading