Skip to content

Commit 072916a

Browse files
Locking channel accounts to transactions (#116)
* Locking channel accounts to transactions Instead of locking channel accounts for an arbitrary amount of time, lock them for the same amount of time as the transaction time bounds and also explicitly lock them to transactions, unlocking them in the webhook channel * remove fmt.Println statments * changes based on comments * changes based on latest comments * lint error
1 parent d11bfcb commit 072916a

16 files changed

+443
-64
lines changed
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,9 @@
1+
-- +migrate Up
2+
3+
ALTER TABLE channel_accounts
4+
ADD COLUMN locked_tx_hash VARCHAR(64);
5+
CREATE INDEX idx_locked_tx_hash ON channel_accounts(locked_tx_hash);
6+
-- +migrate Down
7+
DROP INDEX IF EXISTS idx_locked_tx_hash;
8+
ALTER TABLE channel_accounts
9+
DROP COLUMN locked_tx_hash;

internal/serve/serve.go

+6
Original file line numberDiff line numberDiff line change
@@ -156,6 +156,8 @@ func initHandlerDeps(cfg Configs) (handlerDeps, error) {
156156
}
157157
go rpcService.TrackRPCServiceHealth(context.Background())
158158

159+
channelAccountStore := store.NewChannelAccountModel(dbConnectionPool)
160+
159161
accountService, err := services.NewAccountService(models)
160162
if err != nil {
161163
return handlerDeps{}, fmt.Errorf("instantiating account service: %w", err)
@@ -181,8 +183,10 @@ func initHandlerDeps(cfg Configs) (handlerDeps, error) {
181183

182184
// TSS setup
183185
tssTxService, err := tssservices.NewTransactionService(tssservices.TransactionServiceOptions{
186+
DB: dbConnectionPool,
184187
DistributionAccountSignatureClient: cfg.DistributionAccountSignatureClient,
185188
ChannelAccountSignatureClient: cfg.ChannelAccountSignatureClient,
189+
ChannelAccountStore: channelAccountStore,
186190
RPCService: rpcService,
187191
BaseFee: int64(cfg.BaseFee),
188192
})
@@ -228,10 +232,12 @@ func initHandlerDeps(cfg Configs) (handlerDeps, error) {
228232
webhookChannel := tsschannel.NewWebhookChannel(tsschannel.WebhookChannelConfigs{
229233
HTTPClient: &httpClient,
230234
Store: tssStore,
235+
ChannelAccountStore: channelAccountStore,
231236
MaxBufferSize: cfg.WebhookHandlerServiceChannelMaxBufferSize,
232237
MaxWorkers: cfg.WebhookHandlerServiceChannelMaxWorkers,
233238
MaxRetries: cfg.WebhookHandlerServiceChannelMaxRetries,
234239
MinWaitBtwnRetriesMS: cfg.WebhookHandlerServiceChannelMinWaitBtwnRetriesMS,
240+
NetworkPassphrase: cfg.NetworkPassphrase,
235241
})
236242

237243
router := tssrouter.NewRouter(tssrouter.RouterConfigs{

internal/signing/channel_account_db_signature_client.go

+9-2
Original file line numberDiff line numberDiff line change
@@ -43,9 +43,16 @@ func (sc *channelAccountDBSignatureClient) NetworkPassphrase() string {
4343
return sc.networkPassphrase
4444
}
4545

46-
func (sc *channelAccountDBSignatureClient) GetAccountPublicKey(ctx context.Context) (string, error) {
46+
func (sc *channelAccountDBSignatureClient) GetAccountPublicKey(ctx context.Context, opts ...int) (string, error) {
47+
var lockedUntil time.Duration
48+
if len(opts) > 0 {
49+
lockedUntil = time.Duration(opts[0]) * time.Second
50+
} else {
51+
lockedUntil = time.Minute
52+
}
4753
for range store.ChannelAccountWaitTime {
48-
channelAccount, err := sc.channelAccountStore.GetIdleChannelAccount(ctx, time.Minute)
54+
// check to see if the variadic parameter for time exists and if so, use it here
55+
channelAccount, err := sc.channelAccountStore.GetAndLockIdleChannelAccount(ctx, lockedUntil)
4956
if err != nil {
5057
if errors.Is(err, store.ErrNoIdleChannelAccountAvailable) {
5158
log.Ctx(ctx).Warn("All channel accounts are in use. Retry in 1 second.")

internal/signing/channel_account_db_signature_client_test.go

+4-4
Original file line numberDiff line numberDiff line change
@@ -30,7 +30,7 @@ func TestChannelAccountDBSignatureClientGetAccountPublicKey(t *testing.T) {
3030

3131
t.Run("returns_error_when_couldn't_get_an_idle_channel_account", func(t *testing.T) {
3232
channelAccountStore.
33-
On("GetIdleChannelAccount", ctx, time.Minute).
33+
On("GetAndLockIdleChannelAccount", ctx, time.Duration(100)*time.Second).
3434
Return(nil, store.ErrNoIdleChannelAccountAvailable).
3535
Times(6).
3636
On("Count", ctx).
@@ -40,7 +40,7 @@ func TestChannelAccountDBSignatureClientGetAccountPublicKey(t *testing.T) {
4040

4141
getEntries := log.DefaultLogger.StartTest(log.WarnLevel)
4242

43-
publicKey, err := sc.GetAccountPublicKey(ctx)
43+
publicKey, err := sc.GetAccountPublicKey(ctx, 100)
4444
assert.ErrorIs(t, err, store.ErrNoIdleChannelAccountAvailable)
4545
assert.Empty(t, publicKey)
4646

@@ -54,7 +54,7 @@ func TestChannelAccountDBSignatureClientGetAccountPublicKey(t *testing.T) {
5454

5555
t.Run("returns_error_when_there's_no_channel_account_configured", func(t *testing.T) {
5656
channelAccountStore.
57-
On("GetIdleChannelAccount", ctx, time.Minute).
57+
On("GetAndLockIdleChannelAccount", ctx, time.Minute).
5858
Return(nil, store.ErrNoIdleChannelAccountAvailable).
5959
Times(6).
6060
On("Count", ctx).
@@ -79,7 +79,7 @@ func TestChannelAccountDBSignatureClientGetAccountPublicKey(t *testing.T) {
7979
t.Run("gets_an_idle_channel_account", func(t *testing.T) {
8080
channelAccountPublicKey := keypair.MustRandom().Address()
8181
channelAccountStore.
82-
On("GetIdleChannelAccount", ctx, time.Minute).
82+
On("GetAndLockIdleChannelAccount", ctx, time.Minute).
8383
Return(&store.ChannelAccount{PublicKey: channelAccountPublicKey}, nil).
8484
Once()
8585
defer channelAccountStore.AssertExpectations(t)

internal/signing/env_signature_client.go

+1-1
Original file line numberDiff line numberDiff line change
@@ -36,7 +36,7 @@ func (sc *envSignatureClient) NetworkPassphrase() string {
3636
return sc.networkPassphrase
3737
}
3838

39-
func (sc *envSignatureClient) GetAccountPublicKey(ctx context.Context) (string, error) {
39+
func (sc *envSignatureClient) GetAccountPublicKey(ctx context.Context, _ ...int) (string, error) {
4040
return sc.distributionAccountFull.Address(), nil
4141
}
4242

internal/signing/kms_signature_client.go

+1-1
Original file line numberDiff line numberDiff line change
@@ -59,7 +59,7 @@ func NewKMSSignatureClient(publicKey string, networkPassphrase string, keypairSt
5959
}, nil
6060
}
6161

62-
func (sc *kmsSignatureClient) GetAccountPublicKey(ctx context.Context) (string, error) {
62+
func (sc *kmsSignatureClient) GetAccountPublicKey(ctx context.Context, _ ...int) (string, error) {
6363
return sc.distributionAccountPublicKey, nil
6464
}
6565

internal/signing/mocks.go

+1-1
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,7 @@ func (s *SignatureClientMock) NetworkPassphrase() string {
1818
return args.String(0)
1919
}
2020

21-
func (s *SignatureClientMock) GetAccountPublicKey(ctx context.Context) (string, error) {
21+
func (s *SignatureClientMock) GetAccountPublicKey(ctx context.Context, _ ...int) (string, error) {
2222
args := s.Called(ctx)
2323
return args.String(0), args.Error(1)
2424
}

internal/signing/signature_client.go

+1-1
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,7 @@ var (
1515

1616
type SignatureClient interface {
1717
NetworkPassphrase() string
18-
GetAccountPublicKey(ctx context.Context) (string, error)
18+
GetAccountPublicKey(ctx context.Context, opts ...int) (string, error)
1919
SignStellarTransaction(ctx context.Context, tx *txnbuild.Transaction, stellarAccounts ...string) (*txnbuild.Transaction, error)
2020
SignStellarFeeBumpTransaction(ctx context.Context, feeBumpTx *txnbuild.FeeBumpTransaction) (*txnbuild.FeeBumpTransaction, error)
2121
}

internal/signing/store/channel_accounts_model.go

+22-4
Original file line numberDiff line numberDiff line change
@@ -25,24 +25,24 @@ type ChannelAccountModel struct {
2525

2626
var _ ChannelAccountStore = (*ChannelAccountModel)(nil)
2727

28-
func (ca *ChannelAccountModel) GetIdleChannelAccount(ctx context.Context, lockedUntil time.Duration) (*ChannelAccount, error) {
28+
func (ca *ChannelAccountModel) GetAndLockIdleChannelAccount(ctx context.Context, lockedUntil time.Duration) (*ChannelAccount, error) {
2929
query := fmt.Sprintf(`
3030
UPDATE channel_accounts
3131
SET
32+
locked_tx_hash = NULL,
3233
locked_at = NOW(),
3334
locked_until = NOW() + INTERVAL '%d seconds'
3435
WHERE public_key = (
3536
SELECT
3637
public_key
3738
FROM channel_accounts
3839
WHERE
39-
locked_until IS NULL
40-
OR locked_until < NOW()
40+
(locked_tx_hash IS NULL AND (locked_until IS NULL OR locked_until < NOW()))
4141
ORDER BY random()
4242
LIMIT 1
4343
FOR UPDATE SKIP LOCKED
4444
)
45-
RETURNING *
45+
RETURNING *;
4646
`, int64(lockedUntil.Seconds()))
4747

4848
var channelAccount ChannelAccount
@@ -84,6 +84,24 @@ func (ca *ChannelAccountModel) GetAllByPublicKey(ctx context.Context, sqlExec db
8484
return channelAccounts, nil
8585
}
8686

87+
func (ca *ChannelAccountModel) AssignTxToChannelAccount(ctx context.Context, publicKey string, txHash string) error {
88+
const query = `UPDATE channel_accounts SET locked_tx_hash = $1 WHERE public_key = $2`
89+
_, err := ca.DB.ExecContext(ctx, query, txHash, publicKey)
90+
if err != nil {
91+
return fmt.Errorf("assigning channel account: %w", err)
92+
}
93+
return nil
94+
}
95+
96+
func (ca *ChannelAccountModel) UnassignTxAndUnlockChannelAccount(ctx context.Context, txHash string) error {
97+
const query = `UPDATE channel_accounts SET locked_tx_hash = NULL, locked_at = NULL, locked_until = NULL WHERE locked_tx_hash = $1`
98+
_, err := ca.DB.ExecContext(ctx, query, txHash)
99+
if err != nil {
100+
return fmt.Errorf("unlocking channel account: %w", err)
101+
}
102+
return nil
103+
}
104+
87105
func (ca *ChannelAccountModel) BatchInsert(ctx context.Context, sqlExec db.SQLExecuter, channelAccounts []*ChannelAccount) error {
88106
if len(channelAccounts) == 0 {
89107
return nil

internal/signing/store/channel_accounts_model_test.go

+64-12
Original file line numberDiff line numberDiff line change
@@ -25,7 +25,7 @@ func createChannelAccountFixture(t *testing.T, ctx context.Context, dbConnection
2525
require.NoError(t, err)
2626
}
2727

28-
func TestChannelAccountModelGetIdleChannelAccount(t *testing.T) {
28+
func TestChannelAccountModelGetAndLockIdleChannelAccount(t *testing.T) {
2929
dbt := dbtest.Open(t)
3030
defer dbt.Close()
3131

@@ -46,14 +46,15 @@ func TestChannelAccountModelGetIdleChannelAccount(t *testing.T) {
4646
channel_accounts
4747
SET
4848
locked_at = NOW(),
49-
locked_until = NOW() + '5 minutes'::INTERVAL
49+
locked_until = NOW() + '5 minutes'::INTERVAL,
50+
locked_tx_hash = 'hash'
5051
WHERE
5152
public_key = ANY($1)
5253
`
5354
_, err := dbConnectionPool.ExecContext(ctx, lockChannelAccountQuery, pq.Array([]string{channelAccount1.Address(), channelAccount2.Address()}))
5455
require.NoError(t, err)
5556

56-
ca, err := m.GetIdleChannelAccount(ctx, time.Minute)
57+
ca, err := m.GetAndLockIdleChannelAccount(ctx, time.Minute)
5758
assert.ErrorIs(t, err, ErrNoIdleChannelAccountAvailable)
5859
assert.Nil(t, ca)
5960
})
@@ -64,18 +65,19 @@ func TestChannelAccountModelGetIdleChannelAccount(t *testing.T) {
6465
createChannelAccountFixture(t, ctx, dbConnectionPool, ChannelAccount{PublicKey: channelAccount1.Address(), EncryptedPrivateKey: channelAccount1.Seed()}, ChannelAccount{PublicKey: channelAccount2.Address(), EncryptedPrivateKey: channelAccount2.Seed()})
6566

6667
const lockChannelAccountQuery = `
67-
UPDATE
68-
channel_accounts
69-
SET
70-
locked_at = NOW(),
71-
locked_until = NOW() + '5 minutes'::INTERVAL
72-
WHERE
73-
public_key = $1
74-
`
68+
UPDATE
69+
channel_accounts
70+
SET
71+
locked_at = NOW(),
72+
locked_until = NOW() + '5 minutes'::INTERVAL,
73+
locked_tx_hash = 'hash'
74+
WHERE
75+
public_key = $1
76+
`
7577
_, err := dbConnectionPool.ExecContext(ctx, lockChannelAccountQuery, channelAccount1.Address())
7678
require.NoError(t, err)
7779

78-
ca, err := m.GetIdleChannelAccount(ctx, time.Minute)
80+
ca, err := m.GetAndLockIdleChannelAccount(ctx, time.Minute)
7981
require.NoError(t, err)
8082
assert.Equal(t, ca.PublicKey, channelAccount2.Address())
8183
assert.Equal(t, ca.EncryptedPrivateKey, channelAccount2.Seed())
@@ -128,6 +130,56 @@ func TestChannelAccountModelGetAllByPublicKey(t *testing.T) {
128130
assert.Equal(t, channelAccount2.Address(), channelAccounts[1].PublicKey)
129131
}
130132

133+
func TestAssignTxToChannelAccount(t *testing.T) {
134+
dbt := dbtest.Open(t)
135+
defer dbt.Close()
136+
137+
dbConnectionPool, err := db.OpenDBConnectionPool(dbt.DSN)
138+
require.NoError(t, err)
139+
defer dbConnectionPool.Close()
140+
141+
ctx := context.Background()
142+
m := NewChannelAccountModel(dbConnectionPool)
143+
144+
channelAccount := keypair.MustRandom()
145+
createChannelAccountFixture(t, ctx, dbConnectionPool, ChannelAccount{PublicKey: channelAccount.Address(), EncryptedPrivateKey: channelAccount.Seed()})
146+
147+
err = m.AssignTxToChannelAccount(ctx, channelAccount.Address(), "txhash")
148+
assert.NoError(t, err)
149+
channelAccountFromDB, err := m.Get(ctx, dbConnectionPool, channelAccount.Address())
150+
assert.NoError(t, err)
151+
assert.Equal(t, "txhash", channelAccountFromDB.LockedTxHash.String)
152+
153+
}
154+
155+
func TestUnlockChannelAccountFromTx(t *testing.T) {
156+
dbt := dbtest.Open(t)
157+
defer dbt.Close()
158+
159+
dbConnectionPool, err := db.OpenDBConnectionPool(dbt.DSN)
160+
require.NoError(t, err)
161+
defer dbConnectionPool.Close()
162+
163+
ctx := context.Background()
164+
m := NewChannelAccountModel(dbConnectionPool)
165+
166+
channelAccount := keypair.MustRandom()
167+
createChannelAccountFixture(t, ctx, dbConnectionPool, ChannelAccount{PublicKey: channelAccount.Address(), EncryptedPrivateKey: channelAccount.Seed()})
168+
err = m.AssignTxToChannelAccount(ctx, channelAccount.Address(), "txhash")
169+
assert.NoError(t, err)
170+
channelAccountFromDB, err := m.Get(ctx, dbConnectionPool, channelAccount.Address())
171+
assert.NoError(t, err)
172+
assert.Equal(t, "txhash", channelAccountFromDB.LockedTxHash.String)
173+
174+
err = m.UnassignTxAndUnlockChannelAccount(ctx, "txhash")
175+
assert.NoError(t, err)
176+
channelAccountFromDB, err = m.Get(ctx, dbConnectionPool, channelAccount.Address())
177+
assert.NoError(t, err)
178+
assert.False(t, channelAccountFromDB.LockedTxHash.Valid)
179+
assert.False(t, channelAccountFromDB.LockedAt.Valid)
180+
assert.False(t, channelAccountFromDB.LockedUntil.Valid)
181+
}
182+
131183
func TestChannelAccountModelBatchInsert(t *testing.T) {
132184
dbt := dbtest.Open(t)
133185
defer dbt.Close()

internal/signing/store/mocks.go

+11-1
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,7 @@ type ChannelAccountStoreMock struct {
1414

1515
var _ ChannelAccountStore = (*ChannelAccountStoreMock)(nil)
1616

17-
func (s *ChannelAccountStoreMock) GetIdleChannelAccount(ctx context.Context, lockedUntil time.Duration) (*ChannelAccount, error) {
17+
func (s *ChannelAccountStoreMock) GetAndLockIdleChannelAccount(ctx context.Context, lockedUntil time.Duration) (*ChannelAccount, error) {
1818
args := s.Called(ctx, lockedUntil)
1919
if args.Get(0) == nil {
2020
return nil, args.Error(1)
@@ -38,6 +38,16 @@ func (s *ChannelAccountStoreMock) GetAllByPublicKey(ctx context.Context, sqlExec
3838
return args.Get(0).([]*ChannelAccount), args.Error(1)
3939
}
4040

41+
func (s *ChannelAccountStoreMock) AssignTxToChannelAccount(ctx context.Context, publicKey string, txHash string) error {
42+
args := s.Called(ctx, publicKey, txHash)
43+
return args.Error(0)
44+
}
45+
46+
func (s *ChannelAccountStoreMock) UnassignTxAndUnlockChannelAccount(ctx context.Context, txHash string) error {
47+
args := s.Called(ctx, txHash)
48+
return args.Error(0)
49+
}
50+
4151
func (s *ChannelAccountStoreMock) BatchInsert(ctx context.Context, sqlExec db.SQLExecuter, channelAccounts []*ChannelAccount) error {
4252
args := s.Called(ctx, sqlExec, channelAccounts)
4353
return args.Error(0)

internal/signing/store/types.go

+10-7
Original file line numberDiff line numberDiff line change
@@ -9,18 +9,21 @@ import (
99
)
1010

1111
type ChannelAccount struct {
12-
PublicKey string `db:"public_key"`
13-
EncryptedPrivateKey string `db:"encrypted_private_key"`
14-
UpdatedAt time.Time `db:"updated_at"`
15-
CreatedAt time.Time `db:"created_at"`
16-
LockedAt sql.NullTime `db:"locked_at"`
17-
LockedUntil sql.NullTime `db:"locked_until"`
12+
PublicKey string `db:"public_key"`
13+
EncryptedPrivateKey string `db:"encrypted_private_key"`
14+
UpdatedAt time.Time `db:"updated_at"`
15+
CreatedAt time.Time `db:"created_at"`
16+
LockedAt sql.NullTime `db:"locked_at"`
17+
LockedUntil sql.NullTime `db:"locked_until"`
18+
LockedTxHash sql.NullString `db:"locked_tx_hash"`
1819
}
1920

2021
type ChannelAccountStore interface {
21-
GetIdleChannelAccount(ctx context.Context, lockedUntil time.Duration) (*ChannelAccount, error)
22+
GetAndLockIdleChannelAccount(ctx context.Context, lockedUntil time.Duration) (*ChannelAccount, error)
2223
Get(ctx context.Context, sqlExec db.SQLExecuter, publicKey string) (*ChannelAccount, error)
2324
GetAllByPublicKey(ctx context.Context, sqlExec db.SQLExecuter, publicKeys ...string) ([]*ChannelAccount, error)
25+
AssignTxToChannelAccount(ctx context.Context, publicKey string, txHash string) error
26+
UnassignTxAndUnlockChannelAccount(ctx context.Context, txHash string) error
2427
BatchInsert(ctx context.Context, sqlExec db.SQLExecuter, channelAccounts []*ChannelAccount) error
2528
Count(ctx context.Context) (int64, error)
2629
}

0 commit comments

Comments
 (0)