Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
116 changes: 67 additions & 49 deletions pkg/services/notary/notary.go
Original file line number Diff line number Diff line change
Expand Up @@ -52,7 +52,8 @@ type (
// started is a status bool to protect from double start/shutdown.
started atomic.Bool

// reqMtx protects requests list.
// reqMtx protects the request list from concurrent requests addition/removal.
// Use per-request locks instead of this one to perform request-changing operations.
reqMtx sync.RWMutex
// requests represents a map of main transactions which needs to be completed
// with the associated fallback transactions grouped by the main transaction hash
Expand Down Expand Up @@ -89,6 +90,7 @@ const defaultTxChannelCapacity = 100
type (
// request represents Notary service request.
request struct {
lock sync.RWMutex
// isSent indicates whether the main transaction was successfully sent to the network.
isSent bool
main *transaction.Transaction
Expand Down Expand Up @@ -117,7 +119,8 @@ type (
)

// isMainCompleted denotes whether all signatures for the main transaction were collected.
func (r request) isMainCompleted() bool {
// The caller must hold the request RLock.
func (r *request) isMainCompleted() bool {
if r.witnessInfo == nil {
return false
}
Expand Down Expand Up @@ -246,20 +249,22 @@ func (n *Notary) OnNewRequest(payload *payload.P2PNotaryRequest) {

nvbFallback := payload.FallbackTransaction.GetAttributes(transaction.NotValidBeforeT)[0].Value.(*transaction.NotValidBefore).Height
nKeys := payload.MainTransaction.GetAttributes(transaction.NotaryAssistedT)[0].Value.(*transaction.NotaryAssisted).NKeys
newInfo, validationErr := n.verifyIncompleteWitnesses(payload.MainTransaction, nKeys)
newInfo, validationErr := verifyIncompleteWitnesses(payload.MainTransaction, nKeys)
if validationErr != nil {
n.Config.Log.Info("verification of main notary transaction failed; fallback transaction will be completed",
zap.String("main hash", payload.MainTransaction.Hash().StringLE()),
zap.String("fallback hash", payload.FallbackTransaction.Hash().StringLE()),
zap.String("verification error", validationErr.Error()))
}
n.reqMtx.Lock()
defer n.reqMtx.Unlock()
r, exists := n.requests[payload.MainTransaction.Hash()]
if exists {
r.lock.Lock() // RLock doesn't fit here since we modify r.minNotValidBefore below.
defer r.lock.Unlock()
if slices.ContainsFunc(r.fallbacks, func(fb *transaction.Transaction) bool {
return fb.Hash().Equals(payload.FallbackTransaction.Hash())
}) {
n.reqMtx.Unlock()
return // then we already have processed this request
}
r.minNotValidBefore = min(r.minNotValidBefore, nvbFallback)
Expand All @@ -270,8 +275,11 @@ func (n *Notary) OnNewRequest(payload *payload.P2PNotaryRequest) {
main: payload.MainTransaction.Copy(),
minNotValidBefore: nvbFallback,
}
r.lock.Lock()
defer r.lock.Unlock()
n.requests[payload.MainTransaction.Hash()] = r
}
n.reqMtx.Unlock()
if r.witnessInfo == nil && validationErr == nil {
r.witnessInfo = newInfo
}
Expand Down Expand Up @@ -347,21 +355,21 @@ func (n *Notary) OnRequestRemoval(pld *payload.P2PNotaryRequest) {
return
}

n.reqMtx.Lock()
defer n.reqMtx.Unlock()
n.reqMtx.RLock()
r, ok := n.requests[pld.MainTransaction.Hash()]
n.reqMtx.RUnlock()
if !ok {
return
}

r.lock.Lock()
for i, fb := range r.fallbacks {
if fb.Hash().Equals(pld.FallbackTransaction.Hash()) {
r.fallbacks = append(r.fallbacks[:i], r.fallbacks[i+1:]...)
break
}
}
if len(r.fallbacks) == 0 {
delete(n.requests, r.main.Hash())
}
r.lock.Unlock()
}

// PostPersist is a callback which is called after a new block event is received.
Expand All @@ -379,12 +387,18 @@ func (n *Notary) PostPersist() {
defer n.reqMtx.Unlock()
currHeight := n.Config.Chain.BlockHeight()
for h, r := range n.requests {
r.lock.Lock()
if len(r.fallbacks) == 0 {
delete(n.requests, r.main.Hash())
continue
}
if !r.isSent && r.isMainCompleted() && r.minNotValidBefore > currHeight {
if err := n.finalize(acc, r.main, h); err != nil {
n.Config.Log.Error("failed to finalize main transaction after PostPersist, waiting for the next block to retry",
zap.String("hash", r.main.Hash().StringLE()),
zap.Error(err))
}
r.lock.Unlock()
continue
}
if r.minNotValidBefore <= currHeight { // then at least one of the fallbacks can already be sent.
Expand All @@ -400,6 +414,7 @@ func (n *Notary) PostPersist() {
}
}
}
r.lock.Unlock()
}
}

Expand Down Expand Up @@ -446,53 +461,56 @@ func (n *Notary) newTxCallbackLoop() {
for {
select {
case tx := <-n.newTxs:
isMain := tx.tx.Hash() == tx.mainHash

n.reqMtx.Lock()
n.reqMtx.RLock()
r, ok := n.requests[tx.mainHash]
if !ok || isMain && (r.isSent || r.minNotValidBefore <= n.Config.Chain.BlockHeight()) {
n.reqMtx.Unlock()
n.reqMtx.RUnlock()
if !ok {
continue
}
if !isMain {
// Ensure that fallback was not already completed.
var isPending = slices.ContainsFunc(r.fallbacks, func(fb *transaction.Transaction) bool {
return fb.Hash() == tx.tx.Hash()
})
if !isPending {
n.reqMtx.Unlock()
continue
}
}
n.handleNewTx(r, tx)
case <-n.stopCh:
return
}
}
}

n.reqMtx.Unlock()
err := n.onTransaction(tx.tx)
if err != nil {
n.Config.Log.Error("new transaction callback finished with error",
zap.Error(err),
zap.Bool("is main", isMain))
continue
}
// handleNewTx tries to send a finalized transaction (either main or fallback) to the network.
func (n *Notary) handleNewTx(r *request, tx txHashPair) {
isMain := tx.tx.Hash() == tx.mainHash

n.reqMtx.Lock()
if isMain {
r.isSent = true
} else {
for i := range r.fallbacks {
if r.fallbacks[i].Hash() == tx.tx.Hash() {
r.fallbacks = append(r.fallbacks[:i], r.fallbacks[i+1:]...)
break
}
}
if len(r.fallbacks) == 0 {
delete(n.requests, tx.mainHash)
}
}
n.reqMtx.Unlock()
case <-n.stopCh:
r.lock.Lock()
defer r.lock.Unlock()
if isMain && (r.isSent || r.minNotValidBefore <= n.Config.Chain.BlockHeight()) {
return
}

if !isMain {
// Ensure that fallback was not already completed.
var isPending = slices.ContainsFunc(r.fallbacks, func(fb *transaction.Transaction) bool {
return fb.Hash() == tx.tx.Hash()
})
if !isPending {
return
}
}

err := n.onTransaction(tx.tx)
if err != nil {
n.Config.Log.Error("new transaction callback finished with error",
zap.Error(err),
zap.Bool("is main", isMain))
return
}
if isMain {
r.isSent = true
} else {
for i := range r.fallbacks {
if r.fallbacks[i].Hash() == tx.tx.Hash() {
r.fallbacks = append(r.fallbacks[:i], r.fallbacks[i+1:]...)
break
}
}
}
}

// updateTxSize returns a transaction with re-calculated size and an error.
Expand All @@ -508,7 +526,7 @@ func updateTxSize(tx *transaction.Transaction) (*transaction.Transaction, error)
// verifyIncompleteWitnesses checks that the tx either doesn't have all witnesses attached (in this case none of them
// can be multisignature) or it only has a partial multisignature. It returns the request type (sig/multisig), the
// number of signatures to be collected, sorted public keys (for multisig request only) and an error.
func (n *Notary) verifyIncompleteWitnesses(tx *transaction.Transaction, nKeysExpected uint8) ([]witnessInfo, error) {
func verifyIncompleteWitnesses(tx *transaction.Transaction, nKeysExpected uint8) ([]witnessInfo, error) {
var nKeysActual uint8
if len(tx.Signers) < 2 {
return nil, errors.New("transaction should have at least 2 signers")
Expand Down
6 changes: 2 additions & 4 deletions pkg/services/notary/notary_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -47,9 +47,7 @@ func TestWallet(t *testing.T) {
}

func TestVerifyIncompleteRequest(t *testing.T) {
bc := fakechain.NewFakeChain()
notaryContractHash := nativehashes.Notary
_, ntr, _ := getTestNotary(t, bc, "./testdata/notary1.json", "one")
sig := append([]byte{byte(opcode.PUSHDATA1), keys.SignatureLen}, make([]byte, keys.SignatureLen)...) // we're not interested in signature correctness
acc1, _ := keys.NewPrivateKey()
acc2, _ := keys.NewPrivateKey()
Expand All @@ -65,7 +63,7 @@ func TestVerifyIncompleteRequest(t *testing.T) {
multisigScriptHash2 := hash.Hash160(multisigScript2)

checkErr := func(t *testing.T, tx *transaction.Transaction, nKeys uint8) {
witnessInfo, err := ntr.verifyIncompleteWitnesses(tx, nKeys)
witnessInfo, err := verifyIncompleteWitnesses(tx, nKeys)
require.Error(t, err)
require.Nil(t, witnessInfo)
}
Expand Down Expand Up @@ -475,7 +473,7 @@ func TestVerifyIncompleteRequest(t *testing.T) {

for name, testCase := range testCases {
t.Run(name, func(t *testing.T) {
actualInfo, err := ntr.verifyIncompleteWitnesses(testCase.tx, testCase.nKeys)
actualInfo, err := verifyIncompleteWitnesses(testCase.tx, testCase.nKeys)
require.NoError(t, err)
require.Equal(t, len(testCase.expectedInfo), len(actualInfo))
for i, expected := range testCase.expectedInfo {
Expand Down
Loading