@@ -20,7 +20,7 @@ import fr.acinq.bitcoin.scalacompat.Crypto.{PrivateKey, PublicKey}
2020import fr .acinq .bitcoin .scalacompat .{ByteVector32 , Crypto }
2121import fr .acinq .eclair .wire .protocol ._
2222import grizzled .slf4j .Logging
23- import scodec .Attempt
23+ import scodec .{ Attempt , DecodeResult }
2424import scodec .bits .ByteVector
2525
2626import scala .annotation .tailrec
@@ -347,10 +347,10 @@ object Sphinx extends Logging {
347347
348348 def create (sharedSecret : ByteVector32 , failure : FailureMessage , holdTime : FiniteDuration ): ByteVector = {
349349 val failurePayload = FailureMessageCodecs .failureOnionPayload(payloadAndPadLength).encode(failure).require.toByteVector
350- val zeroPayloads = Seq .fill(maxNumHop)(ByteVector .fill (hopPayloadLength)( 0 ))
351- val zeroHmacs = ( maxNumHop.to(1 , - 1 ) ).map(Seq .fill(_)(ByteVector .low(4 )))
350+ val zeroPayloads = Seq .fill(maxNumHop)(ByteVector .low (hopPayloadLength))
351+ val zeroHmacs = maxNumHop.to(1 , - 1 ).map(Seq .fill(_)(ByteVector .low(4 )))
352352 val plainError = attributableErrorCodec(totalLength, hopPayloadLength, maxNumHop).encode(AttributableError (failurePayload, zeroPayloads, zeroHmacs)).require.bytes
353- wrap(plainError, sharedSecret, holdTime, isSource = true ).get
353+ wrap(plainError, sharedSecret, holdTime, isSource = true )
354354 }
355355
356356 private def computeHmacs (mac : Mac32 , failurePayload : ByteVector , hopPayloads : Seq [ByteVector ], hmacs : Seq [Seq [ByteVector ]], minNumHop : Int ): Seq [ByteVector ] = {
@@ -363,9 +363,12 @@ object Sphinx extends Logging {
363363 newHmacs
364364 }
365365
366- def wrap (errorPacket : ByteVector , sharedSecret : ByteVector32 , holdTime : FiniteDuration , isSource : Boolean ): Try [ ByteVector ] = Try {
366+ def wrap (errorPacket : ByteVector , sharedSecret : ByteVector32 , holdTime : FiniteDuration , isSource : Boolean ): ByteVector = {
367367 val um = generateKey(" um" , sharedSecret)
368- val error = attributableErrorCodec(errorPacket.length.toInt, hopPayloadLength, maxNumHop).decode(errorPacket.bits).require.value
368+ val error = attributableErrorCodec(errorPacket.length.toInt, hopPayloadLength, maxNumHop).decode(errorPacket.bits) match {
369+ case Attempt .Successful (DecodeResult (value, _)) => value
370+ case Attempt .Failure (_) => AttributableError .zero(payloadAndPadLength, hopPayloadLength, maxNumHop)
371+ }
369372 val hopPayloads = hopPayloadCodec.encode(HopPayload (isSource, holdTime)).require.bytes +: error.hopPayloads.dropRight(1 )
370373 val hmacs = computeHmacs(Hmac256 (um), error.failurePayload, hopPayloads, error.hmacs.map(_.drop(1 )), 0 ) +: error.hmacs.dropRight(1 ).map(_.drop(1 ))
371374 val newError = attributableErrorCodec(errorPacket.length.toInt, hopPayloadLength, maxNumHop).encode(AttributableError (error.failurePayload, hopPayloads, hmacs)).require.bytes
@@ -374,14 +377,6 @@ object Sphinx extends Logging {
374377 newError xor stream
375378 }
376379
377- def wrapOrCreate (errorPacket : ByteVector , sharedSecret : ByteVector32 , holdTime : FiniteDuration ): ByteVector =
378- wrap(errorPacket, sharedSecret, holdTime, isSource = false ) match {
379- case Failure (_) =>
380- // There is no failure message for this use-case, using TemporaryNodeFailure instead.
381- create(sharedSecret, TemporaryNodeFailure (), holdTime)
382- case Success (value) => value
383- }
384-
385380 private def unwrap (errorPacket : ByteVector , sharedSecret : ByteVector32 , minNumHop : Int ): Try [(ByteVector , HopPayload )] = Try {
386381 val key = generateKey(" ammag" , sharedSecret)
387382 val stream = generateStream(key, errorPacket.length.toInt)
0 commit comments