Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Locking channel accounts to transactions #116

Merged
merged 5 commits into from
Feb 11, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
@@ -0,0 +1,9 @@
-- +migrate Up

ALTER TABLE channel_accounts
ADD COLUMN locked_tx_hash VARCHAR(64);
CREATE INDEX idx_locked_tx_hash ON channel_accounts(locked_tx_hash);
-- +migrate Down
DROP INDEX IF EXISTS idx_locked_tx_hash;
ALTER TABLE channel_accounts
DROP COLUMN locked_tx_hash;
6 changes: 6 additions & 0 deletions internal/serve/serve.go
Original file line number Diff line number Diff line change
Expand Up @@ -156,6 +156,8 @@ func initHandlerDeps(cfg Configs) (handlerDeps, error) {
}
go rpcService.TrackRPCServiceHealth(context.Background())

channelAccountStore := store.NewChannelAccountModel(dbConnectionPool)

accountService, err := services.NewAccountService(models)
if err != nil {
return handlerDeps{}, fmt.Errorf("instantiating account service: %w", err)
Expand All @@ -181,8 +183,10 @@ func initHandlerDeps(cfg Configs) (handlerDeps, error) {

// TSS setup
tssTxService, err := tssservices.NewTransactionService(tssservices.TransactionServiceOptions{
DB: dbConnectionPool,
DistributionAccountSignatureClient: cfg.DistributionAccountSignatureClient,
ChannelAccountSignatureClient: cfg.ChannelAccountSignatureClient,
ChannelAccountStore: channelAccountStore,
RPCService: rpcService,
BaseFee: int64(cfg.BaseFee),
})
Expand Down Expand Up @@ -228,10 +232,12 @@ func initHandlerDeps(cfg Configs) (handlerDeps, error) {
webhookChannel := tsschannel.NewWebhookChannel(tsschannel.WebhookChannelConfigs{
HTTPClient: &httpClient,
Store: tssStore,
ChannelAccountStore: channelAccountStore,
MaxBufferSize: cfg.WebhookHandlerServiceChannelMaxBufferSize,
MaxWorkers: cfg.WebhookHandlerServiceChannelMaxWorkers,
MaxRetries: cfg.WebhookHandlerServiceChannelMaxRetries,
MinWaitBtwnRetriesMS: cfg.WebhookHandlerServiceChannelMinWaitBtwnRetriesMS,
NetworkPassphrase: cfg.NetworkPassphrase,
})

router := tssrouter.NewRouter(tssrouter.RouterConfigs{
Expand Down
11 changes: 9 additions & 2 deletions internal/signing/channel_account_db_signature_client.go
Original file line number Diff line number Diff line change
Expand Up @@ -43,9 +43,16 @@ func (sc *channelAccountDBSignatureClient) NetworkPassphrase() string {
return sc.networkPassphrase
}

func (sc *channelAccountDBSignatureClient) GetAccountPublicKey(ctx context.Context) (string, error) {
func (sc *channelAccountDBSignatureClient) GetAccountPublicKey(ctx context.Context, opts ...int) (string, error) {
var lockedUntil time.Duration
if len(opts) > 0 {
lockedUntil = time.Duration(opts[0]) * time.Second
} else {
lockedUntil = time.Minute
}
for range store.ChannelAccountWaitTime {
channelAccount, err := sc.channelAccountStore.GetIdleChannelAccount(ctx, time.Minute)
// check to see if the variadic parameter for time exists and if so, use it here
channelAccount, err := sc.channelAccountStore.GetAndLockIdleChannelAccount(ctx, lockedUntil)
if err != nil {
if errors.Is(err, store.ErrNoIdleChannelAccountAvailable) {
log.Ctx(ctx).Warn("All channel accounts are in use. Retry in 1 second.")
Expand Down
8 changes: 4 additions & 4 deletions internal/signing/channel_account_db_signature_client_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,7 @@ func TestChannelAccountDBSignatureClientGetAccountPublicKey(t *testing.T) {

t.Run("returns_error_when_couldn't_get_an_idle_channel_account", func(t *testing.T) {
channelAccountStore.
On("GetIdleChannelAccount", ctx, time.Minute).
On("GetAndLockIdleChannelAccount", ctx, time.Duration(100)*time.Second).
Return(nil, store.ErrNoIdleChannelAccountAvailable).
Times(6).
On("Count", ctx).
Expand All @@ -40,7 +40,7 @@ func TestChannelAccountDBSignatureClientGetAccountPublicKey(t *testing.T) {

getEntries := log.DefaultLogger.StartTest(log.WarnLevel)

publicKey, err := sc.GetAccountPublicKey(ctx)
publicKey, err := sc.GetAccountPublicKey(ctx, 100)
assert.ErrorIs(t, err, store.ErrNoIdleChannelAccountAvailable)
assert.Empty(t, publicKey)

Expand All @@ -54,7 +54,7 @@ func TestChannelAccountDBSignatureClientGetAccountPublicKey(t *testing.T) {

t.Run("returns_error_when_there's_no_channel_account_configured", func(t *testing.T) {
channelAccountStore.
On("GetIdleChannelAccount", ctx, time.Minute).
On("GetAndLockIdleChannelAccount", ctx, time.Minute).
Return(nil, store.ErrNoIdleChannelAccountAvailable).
Times(6).
On("Count", ctx).
Expand All @@ -79,7 +79,7 @@ func TestChannelAccountDBSignatureClientGetAccountPublicKey(t *testing.T) {
t.Run("gets_an_idle_channel_account", func(t *testing.T) {
channelAccountPublicKey := keypair.MustRandom().Address()
channelAccountStore.
On("GetIdleChannelAccount", ctx, time.Minute).
On("GetAndLockIdleChannelAccount", ctx, time.Minute).
Return(&store.ChannelAccount{PublicKey: channelAccountPublicKey}, nil).
Once()
defer channelAccountStore.AssertExpectations(t)
Expand Down
2 changes: 1 addition & 1 deletion internal/signing/env_signature_client.go
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,7 @@ func (sc *envSignatureClient) NetworkPassphrase() string {
return sc.networkPassphrase
}

func (sc *envSignatureClient) GetAccountPublicKey(ctx context.Context) (string, error) {
func (sc *envSignatureClient) GetAccountPublicKey(ctx context.Context, _ ...int) (string, error) {
return sc.distributionAccountFull.Address(), nil
}

Expand Down
2 changes: 1 addition & 1 deletion internal/signing/kms_signature_client.go
Original file line number Diff line number Diff line change
Expand Up @@ -59,7 +59,7 @@ func NewKMSSignatureClient(publicKey string, networkPassphrase string, keypairSt
}, nil
}

func (sc *kmsSignatureClient) GetAccountPublicKey(ctx context.Context) (string, error) {
func (sc *kmsSignatureClient) GetAccountPublicKey(ctx context.Context, _ ...int) (string, error) {
return sc.distributionAccountPublicKey, nil
}

Expand Down
2 changes: 1 addition & 1 deletion internal/signing/mocks.go
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@ func (s *SignatureClientMock) NetworkPassphrase() string {
return args.String(0)
}

func (s *SignatureClientMock) GetAccountPublicKey(ctx context.Context) (string, error) {
func (s *SignatureClientMock) GetAccountPublicKey(ctx context.Context, _ ...int) (string, error) {
args := s.Called(ctx)
return args.String(0), args.Error(1)
}
Expand Down
2 changes: 1 addition & 1 deletion internal/signing/signature_client.go
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@ var (

type SignatureClient interface {
NetworkPassphrase() string
GetAccountPublicKey(ctx context.Context) (string, error)
GetAccountPublicKey(ctx context.Context, opts ...int) (string, error)
SignStellarTransaction(ctx context.Context, tx *txnbuild.Transaction, stellarAccounts ...string) (*txnbuild.Transaction, error)
SignStellarFeeBumpTransaction(ctx context.Context, feeBumpTx *txnbuild.FeeBumpTransaction) (*txnbuild.FeeBumpTransaction, error)
}
Expand Down
26 changes: 22 additions & 4 deletions internal/signing/store/channel_accounts_model.go
Original file line number Diff line number Diff line change
Expand Up @@ -25,24 +25,24 @@ type ChannelAccountModel struct {

var _ ChannelAccountStore = (*ChannelAccountModel)(nil)

func (ca *ChannelAccountModel) GetIdleChannelAccount(ctx context.Context, lockedUntil time.Duration) (*ChannelAccount, error) {
func (ca *ChannelAccountModel) GetAndLockIdleChannelAccount(ctx context.Context, lockedUntil time.Duration) (*ChannelAccount, error) {
query := fmt.Sprintf(`
UPDATE channel_accounts
SET
locked_tx_hash = NULL,
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think this is unnecessary since UnlockChannelAccountFromTx() sets this to null

Copy link
Contributor Author

@gouthamp-stellar gouthamp-stellar Jan 28, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

That is true in theory, but in case the channel accounts do not get unlocked (they fail to reach the webhook for example), this will act as a fail safe, which is why I added it in there

locked_at = NOW(),
locked_until = NOW() + INTERVAL '%d seconds'
WHERE public_key = (
SELECT
public_key
FROM channel_accounts
WHERE
locked_until IS NULL
OR locked_until < NOW()
(locked_tx_hash IS NULL AND (locked_until IS NULL OR locked_until < NOW()))
ORDER BY random()
LIMIT 1
FOR UPDATE SKIP LOCKED
)
RETURNING *
RETURNING *;
`, int64(lockedUntil.Seconds()))

var channelAccount ChannelAccount
Expand Down Expand Up @@ -84,6 +84,24 @@ func (ca *ChannelAccountModel) GetAllByPublicKey(ctx context.Context, sqlExec db
return channelAccounts, nil
}

func (ca *ChannelAccountModel) AssignTxToChannelAccount(ctx context.Context, publicKey string, txHash string) error {
const query = `UPDATE channel_accounts SET locked_tx_hash = $1 WHERE public_key = $2`
_, err := ca.DB.ExecContext(ctx, query, txHash, publicKey)
if err != nil {
return fmt.Errorf("assigning channel account: %w", err)
}
return nil
}

func (ca *ChannelAccountModel) UnassignTxAndUnlockChannelAccount(ctx context.Context, txHash string) error {
const query = `UPDATE channel_accounts SET locked_tx_hash = NULL, locked_at = NULL, locked_until = NULL WHERE locked_tx_hash = $1`
_, err := ca.DB.ExecContext(ctx, query, txHash)
if err != nil {
return fmt.Errorf("unlocking channel account: %w", err)
}
return nil
}

func (ca *ChannelAccountModel) BatchInsert(ctx context.Context, sqlExec db.SQLExecuter, channelAccounts []*ChannelAccount) error {
if len(channelAccounts) == 0 {
return nil
Expand Down
76 changes: 64 additions & 12 deletions internal/signing/store/channel_accounts_model_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@ func createChannelAccountFixture(t *testing.T, ctx context.Context, dbConnection
require.NoError(t, err)
}

func TestChannelAccountModelGetIdleChannelAccount(t *testing.T) {
func TestChannelAccountModelGetAndLockIdleChannelAccount(t *testing.T) {
dbt := dbtest.Open(t)
defer dbt.Close()

Expand All @@ -46,14 +46,15 @@ func TestChannelAccountModelGetIdleChannelAccount(t *testing.T) {
channel_accounts
SET
locked_at = NOW(),
locked_until = NOW() + '5 minutes'::INTERVAL
locked_until = NOW() + '5 minutes'::INTERVAL,
locked_tx_hash = 'hash'
WHERE
public_key = ANY($1)
`
_, err := dbConnectionPool.ExecContext(ctx, lockChannelAccountQuery, pq.Array([]string{channelAccount1.Address(), channelAccount2.Address()}))
require.NoError(t, err)

ca, err := m.GetIdleChannelAccount(ctx, time.Minute)
ca, err := m.GetAndLockIdleChannelAccount(ctx, time.Minute)
assert.ErrorIs(t, err, ErrNoIdleChannelAccountAvailable)
assert.Nil(t, ca)
})
Expand All @@ -64,18 +65,19 @@ func TestChannelAccountModelGetIdleChannelAccount(t *testing.T) {
createChannelAccountFixture(t, ctx, dbConnectionPool, ChannelAccount{PublicKey: channelAccount1.Address(), EncryptedPrivateKey: channelAccount1.Seed()}, ChannelAccount{PublicKey: channelAccount2.Address(), EncryptedPrivateKey: channelAccount2.Seed()})

const lockChannelAccountQuery = `
UPDATE
channel_accounts
SET
locked_at = NOW(),
locked_until = NOW() + '5 minutes'::INTERVAL
WHERE
public_key = $1
`
UPDATE
channel_accounts
SET
locked_at = NOW(),
locked_until = NOW() + '5 minutes'::INTERVAL,
locked_tx_hash = 'hash'
WHERE
public_key = $1
`
_, err := dbConnectionPool.ExecContext(ctx, lockChannelAccountQuery, channelAccount1.Address())
require.NoError(t, err)

ca, err := m.GetIdleChannelAccount(ctx, time.Minute)
ca, err := m.GetAndLockIdleChannelAccount(ctx, time.Minute)
require.NoError(t, err)
assert.Equal(t, ca.PublicKey, channelAccount2.Address())
assert.Equal(t, ca.EncryptedPrivateKey, channelAccount2.Seed())
Expand Down Expand Up @@ -128,6 +130,56 @@ func TestChannelAccountModelGetAllByPublicKey(t *testing.T) {
assert.Equal(t, channelAccount2.Address(), channelAccounts[1].PublicKey)
}

func TestAssignTxToChannelAccount(t *testing.T) {
dbt := dbtest.Open(t)
defer dbt.Close()

dbConnectionPool, err := db.OpenDBConnectionPool(dbt.DSN)
require.NoError(t, err)
defer dbConnectionPool.Close()

ctx := context.Background()
m := NewChannelAccountModel(dbConnectionPool)

channelAccount := keypair.MustRandom()
createChannelAccountFixture(t, ctx, dbConnectionPool, ChannelAccount{PublicKey: channelAccount.Address(), EncryptedPrivateKey: channelAccount.Seed()})

err = m.AssignTxToChannelAccount(ctx, channelAccount.Address(), "txhash")
assert.NoError(t, err)
channelAccountFromDB, err := m.Get(ctx, dbConnectionPool, channelAccount.Address())
assert.NoError(t, err)
assert.Equal(t, "txhash", channelAccountFromDB.LockedTxHash.String)

}

func TestUnlockChannelAccountFromTx(t *testing.T) {
dbt := dbtest.Open(t)
defer dbt.Close()

dbConnectionPool, err := db.OpenDBConnectionPool(dbt.DSN)
require.NoError(t, err)
defer dbConnectionPool.Close()

ctx := context.Background()
m := NewChannelAccountModel(dbConnectionPool)

channelAccount := keypair.MustRandom()
createChannelAccountFixture(t, ctx, dbConnectionPool, ChannelAccount{PublicKey: channelAccount.Address(), EncryptedPrivateKey: channelAccount.Seed()})
err = m.AssignTxToChannelAccount(ctx, channelAccount.Address(), "txhash")
assert.NoError(t, err)
channelAccountFromDB, err := m.Get(ctx, dbConnectionPool, channelAccount.Address())
assert.NoError(t, err)
assert.Equal(t, "txhash", channelAccountFromDB.LockedTxHash.String)

err = m.UnassignTxAndUnlockChannelAccount(ctx, "txhash")
assert.NoError(t, err)
channelAccountFromDB, err = m.Get(ctx, dbConnectionPool, channelAccount.Address())
assert.NoError(t, err)
assert.False(t, channelAccountFromDB.LockedTxHash.Valid)
assert.False(t, channelAccountFromDB.LockedAt.Valid)
assert.False(t, channelAccountFromDB.LockedUntil.Valid)
}

func TestChannelAccountModelBatchInsert(t *testing.T) {
dbt := dbtest.Open(t)
defer dbt.Close()
Expand Down
12 changes: 11 additions & 1 deletion internal/signing/store/mocks.go
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@ type ChannelAccountStoreMock struct {

var _ ChannelAccountStore = (*ChannelAccountStoreMock)(nil)

func (s *ChannelAccountStoreMock) GetIdleChannelAccount(ctx context.Context, lockedUntil time.Duration) (*ChannelAccount, error) {
func (s *ChannelAccountStoreMock) GetAndLockIdleChannelAccount(ctx context.Context, lockedUntil time.Duration) (*ChannelAccount, error) {
args := s.Called(ctx, lockedUntil)
if args.Get(0) == nil {
return nil, args.Error(1)
Expand All @@ -38,6 +38,16 @@ func (s *ChannelAccountStoreMock) GetAllByPublicKey(ctx context.Context, sqlExec
return args.Get(0).([]*ChannelAccount), args.Error(1)
}

func (s *ChannelAccountStoreMock) AssignTxToChannelAccount(ctx context.Context, publicKey string, txHash string) error {
args := s.Called(ctx, publicKey, txHash)
return args.Error(0)
}

func (s *ChannelAccountStoreMock) UnassignTxAndUnlockChannelAccount(ctx context.Context, txHash string) error {
args := s.Called(ctx, txHash)
return args.Error(0)
}

func (s *ChannelAccountStoreMock) BatchInsert(ctx context.Context, sqlExec db.SQLExecuter, channelAccounts []*ChannelAccount) error {
args := s.Called(ctx, sqlExec, channelAccounts)
return args.Error(0)
Expand Down
17 changes: 10 additions & 7 deletions internal/signing/store/types.go
Original file line number Diff line number Diff line change
Expand Up @@ -9,18 +9,21 @@ import (
)

type ChannelAccount struct {
PublicKey string `db:"public_key"`
EncryptedPrivateKey string `db:"encrypted_private_key"`
UpdatedAt time.Time `db:"updated_at"`
CreatedAt time.Time `db:"created_at"`
LockedAt sql.NullTime `db:"locked_at"`
LockedUntil sql.NullTime `db:"locked_until"`
PublicKey string `db:"public_key"`
EncryptedPrivateKey string `db:"encrypted_private_key"`
UpdatedAt time.Time `db:"updated_at"`
CreatedAt time.Time `db:"created_at"`
LockedAt sql.NullTime `db:"locked_at"`
LockedUntil sql.NullTime `db:"locked_until"`
LockedTxHash sql.NullString `db:"locked_tx_hash"`
}

type ChannelAccountStore interface {
GetIdleChannelAccount(ctx context.Context, lockedUntil time.Duration) (*ChannelAccount, error)
GetAndLockIdleChannelAccount(ctx context.Context, lockedUntil time.Duration) (*ChannelAccount, error)
Get(ctx context.Context, sqlExec db.SQLExecuter, publicKey string) (*ChannelAccount, error)
GetAllByPublicKey(ctx context.Context, sqlExec db.SQLExecuter, publicKeys ...string) ([]*ChannelAccount, error)
AssignTxToChannelAccount(ctx context.Context, publicKey string, txHash string) error
UnassignTxAndUnlockChannelAccount(ctx context.Context, txHash string) error
BatchInsert(ctx context.Context, sqlExec db.SQLExecuter, channelAccounts []*ChannelAccount) error
Count(ctx context.Context) (int64, error)
}
Expand Down
Loading
Loading