Skip to content

Commit 23ad6b0

Browse files
committed
Fill with zeros if the error to relay is too short
1 parent 17ac650 commit 23ad6b0

File tree

4 files changed

+20
-19
lines changed

4 files changed

+20
-19
lines changed

eclair-core/src/main/scala/fr/acinq/eclair/crypto/Sphinx.scala

Lines changed: 9 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -20,7 +20,7 @@ import fr.acinq.bitcoin.scalacompat.Crypto.{PrivateKey, PublicKey}
2020
import fr.acinq.bitcoin.scalacompat.{ByteVector32, Crypto}
2121
import fr.acinq.eclair.wire.protocol._
2222
import grizzled.slf4j.Logging
23-
import scodec.Attempt
23+
import scodec.{Attempt, DecodeResult}
2424
import scodec.bits.ByteVector
2525

2626
import scala.annotation.tailrec
@@ -347,10 +347,10 @@ object Sphinx extends Logging {
347347

348348
def create(sharedSecret: ByteVector32, failure: FailureMessage, holdTime: FiniteDuration): ByteVector = {
349349
val failurePayload = FailureMessageCodecs.failureOnionPayload(payloadAndPadLength).encode(failure).require.toByteVector
350-
val zeroPayloads = Seq.fill(maxNumHop)(ByteVector.fill(hopPayloadLength)(0))
351-
val zeroHmacs = (maxNumHop.to(1, -1)).map(Seq.fill(_)(ByteVector.low(4)))
350+
val zeroPayloads = Seq.fill(maxNumHop)(ByteVector.low(hopPayloadLength))
351+
val zeroHmacs = maxNumHop.to(1, -1).map(Seq.fill(_)(ByteVector.low(4)))
352352
val plainError = attributableErrorCodec(totalLength, hopPayloadLength, maxNumHop).encode(AttributableError(failurePayload, zeroPayloads, zeroHmacs)).require.bytes
353-
wrap(plainError, sharedSecret, holdTime, isSource = true).get
353+
wrap(plainError, sharedSecret, holdTime, isSource = true)
354354
}
355355

356356
private def computeHmacs(mac: Mac32, failurePayload: ByteVector, hopPayloads: Seq[ByteVector], hmacs: Seq[Seq[ByteVector]], minNumHop: Int): Seq[ByteVector] = {
@@ -363,9 +363,12 @@ object Sphinx extends Logging {
363363
newHmacs
364364
}
365365

366-
def wrap(errorPacket: ByteVector, sharedSecret: ByteVector32, holdTime: FiniteDuration, isSource: Boolean): Try[ByteVector] = Try {
366+
def wrap(errorPacket: ByteVector, sharedSecret: ByteVector32, holdTime: FiniteDuration, isSource: Boolean): ByteVector = {
367367
val um = generateKey("um", sharedSecret)
368-
val error = attributableErrorCodec(errorPacket.length.toInt, hopPayloadLength, maxNumHop).decode(errorPacket.bits).require.value
368+
val error = attributableErrorCodec(errorPacket.length.toInt, hopPayloadLength, maxNumHop).decode(errorPacket.bits) match {
369+
case Attempt.Successful(DecodeResult(value, _)) => value
370+
case Attempt.Failure(_) => AttributableError.zero(payloadAndPadLength, hopPayloadLength, maxNumHop)
371+
}
369372
val hopPayloads = hopPayloadCodec.encode(HopPayload(isSource, holdTime)).require.bytes +: error.hopPayloads.dropRight(1)
370373
val hmacs = computeHmacs(Hmac256(um), error.failurePayload, hopPayloads, error.hmacs.map(_.drop(1)), 0) +: error.hmacs.dropRight(1).map(_.drop(1))
371374
val newError = attributableErrorCodec(errorPacket.length.toInt, hopPayloadLength, maxNumHop).encode(AttributableError(error.failurePayload, hopPayloads, hmacs)).require.bytes
@@ -374,14 +377,6 @@ object Sphinx extends Logging {
374377
newError xor stream
375378
}
376379

377-
def wrapOrCreate(errorPacket: ByteVector, sharedSecret: ByteVector32, holdTime: FiniteDuration): ByteVector =
378-
wrap(errorPacket, sharedSecret, holdTime, isSource = false) match {
379-
case Failure(_) =>
380-
// There is no failure message for this use-case, using TemporaryNodeFailure instead.
381-
create(sharedSecret, TemporaryNodeFailure(), holdTime)
382-
case Success(value) => value
383-
}
384-
385380
private def unwrap(errorPacket: ByteVector, sharedSecret: ByteVector32, minNumHop: Int): Try[(ByteVector, HopPayload)] = Try {
386381
val key = generateKey("ammag", sharedSecret)
387382
val stream = generateStream(key, errorPacket.length.toInt)

eclair-core/src/main/scala/fr/acinq/eclair/payment/PaymentPacket.scala

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -314,7 +314,7 @@ object OutgoingPaymentPacket {
314314
Sphinx.peel(nodeSecret, Some(add.paymentHash), add.onionRoutingPacket) match {
315315
case Right(Sphinx.DecryptedPacket(_, _, sharedSecret)) =>
316316
val encryptedReason = reason match {
317-
case Left(forwarded) if useAttributableErrors => Sphinx.AttributableErrorPacket.wrapOrCreate(forwarded, sharedSecret, holdTime)
317+
case Left(forwarded) if useAttributableErrors => Sphinx.AttributableErrorPacket.wrap(forwarded, sharedSecret, holdTime, isSource = false)
318318
case Right(failure) if useAttributableErrors => Sphinx.AttributableErrorPacket.create(sharedSecret, failure, holdTime)
319319
case Left(forwarded) => Sphinx.FailurePacket.wrap(forwarded, sharedSecret)
320320
case Right(failure) => Sphinx.FailurePacket.create(sharedSecret, failure)

eclair-core/src/main/scala/fr/acinq/eclair/wire/protocol/AttributableError.scala

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -48,4 +48,10 @@ object AttributableError {
4848
(("failure_payload" | bytes(totalLength - metadataLength)) ::
4949
("hop_payloads" | listOfN(provide(maxNumHop), bytes(hopPayloadLength)).xmap[Seq[ByteVector]](_.toSeq, _.toList)) ::
5050
("hmacs" | hmacsCodec(maxNumHop))).as[AttributableError].complete}
51+
52+
def zero(payloadAndPadLength: Int, hopPayloadLength: Int, maxNumHop: Int): AttributableError =
53+
AttributableError(
54+
ByteVector.low(payloadAndPadLength),
55+
Seq.fill(maxNumHop)(ByteVector.low(hopPayloadLength)),
56+
maxNumHop.to(1, -1).map(Seq.fill(_)(ByteVector.low(4))))
5157
}

eclair-core/src/test/scala/fr/acinq/eclair/crypto/SphinxSpec.scala

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -414,13 +414,13 @@ class SphinxSpec extends AnyFunSuite {
414414
val Right(decrypted1) = AttributableErrorPacket.decrypt(packet1, (2 to 4).map(i => (sharedSecrets(i), publicKeys(i))))
415415
assert(decrypted1 == expected)
416416

417-
val Success(packet2) = AttributableErrorPacket.wrap(packet1, sharedSecrets(1), 5 millis, isSource = false)
417+
val packet2 = AttributableErrorPacket.wrap(packet1, sharedSecrets(1), 5 millis, isSource = false)
418418
assert(packet2.length == 1200)
419419

420420
val Right(decrypted2) = AttributableErrorPacket.decrypt(packet2, (1 to 4).map(i => (sharedSecrets(i), publicKeys(i))))
421421
assert(decrypted2 == expected)
422422

423-
val Success(packet3) = AttributableErrorPacket.wrap(packet2, sharedSecrets(0), 9 millis, isSource = false)
423+
val packet3 = AttributableErrorPacket.wrap(packet2, sharedSecrets(0), 9 millis, isSource = false)
424424
assert(packet3.length == 1200)
425425

426426
val Right(decrypted3) = AttributableErrorPacket.decrypt(packet3, (0 to 4).map(i => (sharedSecrets(i), publicKeys(i))))
@@ -440,11 +440,11 @@ class SphinxSpec extends AnyFunSuite {
440440
val packet1 = randomBytes(1200)
441441

442442
val hopPayload2 = AttributableError.HopPayload(isPayloadSource = false, 50 millis)
443-
val Success(packet2) = AttributableErrorPacket.wrap(packet1, sharedSecrets(1), 50 millis, isSource = false)
443+
val packet2 = AttributableErrorPacket.wrap(packet1, sharedSecrets(1), 50 millis, isSource = false)
444444
assert(packet2.length == 1200)
445445

446446
val hopPayload3 = AttributableError.HopPayload(isPayloadSource = false, 100 millis)
447-
val Success(packet3) = AttributableErrorPacket.wrap(packet2, sharedSecrets(0), 100 millis, isSource = false)
447+
val packet3 = AttributableErrorPacket.wrap(packet2, sharedSecrets(0), 100 millis, isSource = false)
448448
assert(packet3.length == 1200)
449449

450450
val Left(decryptionError) = AttributableErrorPacket.decrypt(packet3, (0 to 4).map(i => (sharedSecrets(i), publicKeys(i))))

0 commit comments

Comments
 (0)