Skip to content
Open
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 7 additions & 1 deletion error.go
Original file line number Diff line number Diff line change
@@ -1,6 +1,9 @@
package sphinx

import "fmt"
import (
"errors"
"fmt"
)

var (
// ErrReplayedPacket is an error returned when a packet is rejected
Expand All @@ -24,4 +27,7 @@ var (
// ErrLogEntryNotFound is an error returned when a packet lookup in a replay
// log fails because it is missing.
ErrLogEntryNotFound = fmt.Errorf("sphinx packet is not in log")

// ErrIOReadFull is returned when an io read full operation fails.
ErrIOReadFull = errors.New("io read full error")
)
96 changes: 67 additions & 29 deletions payload.go
Original file line number Diff line number Diff line change
Expand Up @@ -87,48 +87,62 @@ func (hp *HopPayload) Encode(w io.Writer) error {
}

// Decode unpacks an encoded HopPayload from the passed reader into the target
// HopPayload.
func (hp *HopPayload) Decode(r io.Reader) error {
// HopPayload. tlvGuaranteed should be set to true if the caller only wishes to
// accept TLV encoded payloads. By doing so, zero-lengt tlv payloads are
// supported. If set to false, then the function will inspect the first byte to
// determine the type of payload.
func DecodeHopPayload(r io.Reader, tlvGuaranteed bool) (*HopPayload, error) {
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

comment starts with incorrect name

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

think you can condense this to:

func DecodeHopPayload(r io.Reader, tlvGuaranteed bool) (*HopPayload, error) {
	var (
		payloadSize uint16
		payloadType = PayloadTLV
		hmac        [HMACSize]byte
		bufReader   = bufio.NewReader(r)
	)

	peekByte, err := bufReader.Peek(1)
	if err != nil {
		return nil, fmt.Errorf("peek first payload byte: %w", err)
	}

	if !tlvGuaranteed && isLegacyPayloadByte(peekByte[0]) {
		payloadType = PayloadLegacy
		payloadSize = legacyPayloadSize()
	} else {
		payloadSize, err = tlvPayloadSize(bufReader)
		if err != nil {
			return nil, err
		}
	}

	var payload = make([]byte, payloadSize)
	if _, err = io.ReadFull(r, payload); err != nil {
		return nil, fmt.Errorf("%w: %w", ErrIOReadFull, err)
	}

	if _, err = io.ReadFull(r, hmac[:]); err != nil {
		return nil, fmt.Errorf("%w: %w", ErrIOReadFull, err)
	}

	return &HopPayload{
		Type:    payloadType,
		Payload: payload,
		HMAC:    hmac,
	}, nil
}

bufReader := bufio.NewReader(r)

// In order to properly parse the payload, we'll need to check the
// first byte. We'll use a bufio reader to peek at it without consuming
// it from the buffer.
var payloadSize uint16

hopPayload := &HopPayload{}

// If we are not sure if this is a TLV or legacy payload, then we need
// to inspect the first byte to determine the type of payload. The first
// byte is either a realm (legacy) or the beginning of a var-int
// encoding the length of the payload (TLV). We'll use a bufio reader to
// peek at it without consuming it from the buffer.
peekByte, err := bufReader.Peek(1)
if err != nil {
return err
return nil, fmt.Errorf("peek first payload byte: %w", err)
}

var (
legacyPayload = isLegacyPayloadByte(peekByte[0])
payloadSize uint16
)
switch {
case tlvGuaranteed:
// If we're instructed to only accept TLV payloads, then we set
// the type accordingly. This allows us to support zero-length
// TLV payloads.

hopPayload.Type = PayloadTLV

if legacyPayload {
payloadSize = legacyPayloadSize()
hp.Type = PayloadLegacy
} else {
payloadSize, err = tlvPayloadSize(bufReader)
if err != nil {
return err
return nil, err
}

hp.Type = PayloadTLV
}
case isLegacyPayloadByte(peekByte[0]):
// If the first byte indicates that this is a legacy payload,
// then we set the type accordingly.
hopPayload.Type = PayloadLegacy
payloadSize = legacyPayloadSize()

// Now that we know the payload size, we'll create a new buffer to
// read it out in full.
//
// TODO(roasbeef): can avoid all these copies
hp.Payload = make([]byte, payloadSize)
if _, err := io.ReadFull(bufReader, hp.Payload[:]); err != nil {
return err
default:
// Otherwise, we set the type to TLV.
hopPayload.Type = PayloadTLV

payloadSize, err = tlvPayloadSize(bufReader)
if err != nil {
return nil, err
}
}
if _, err := io.ReadFull(bufReader, hp.HMAC[:]); err != nil {
return err

err = readPayloadAndHMAC(hopPayload, bufReader, payloadSize)
if err != nil {
return nil, err
}

return nil
return hopPayload, nil
}

// HopData attempts to extract a set of forwarding instructions from the target
Expand All @@ -146,6 +160,26 @@ func (hp *HopPayload) HopData() (*HopData, error) {
return nil, nil
}

// readPayloadAndHMAC reads the payload and HMAC from the reader into the
// HopPayload.
func readPayloadAndHMAC(hp *HopPayload, r io.Reader, payloadSize uint16) error {
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

this was useful when there were multiple call sites but now there is only one. so let's keep this logic in the single constructor

// Now that we know the payload size, we'll create a new buffer to read
// it out in full.
hp.Payload = make([]byte, payloadSize)

_, err := io.ReadFull(r, hp.Payload)
if err != nil {
return fmt.Errorf("%w: %w", ErrIOReadFull, err)
}

_, err = io.ReadFull(r, hp.HMAC[:])
if err != nil {
return fmt.Errorf("%w: %w", ErrIOReadFull, err)
}

return nil
}

// tlvPayloadSize uses the passed reader to extract the payload length encoded
// as a var-int.
func tlvPayloadSize(r io.Reader) (uint16, error) {
Expand Down Expand Up @@ -314,8 +348,12 @@ func legacyNumBytes() int {
return LegacyHopDataSize
}

// isLegacyPayload returns true if the given byte is equal to the 0x00 byte
// which indicates that the payload should be decoded as a legacy payload.
// isLegacyPayloadByte determines if the first byte of a hop payload indicates
// that it is a legacy payload. The first byte of a legacy payload will always
// be 0x00, as this is the realm. For TLV payloads, the first byte is a
// var-int encoding the length of the payload. A TLV stream can be empty, in
// which case its length is 0, which is also encoded as a 0x00 byte. This
// creates an ambiguity between a legacy payload and an empty TLV payload.
func isLegacyPayloadByte(b byte) bool {
return b == 0x00
}
Expand Down
44 changes: 31 additions & 13 deletions sphinx.go
Original file line number Diff line number Diff line change
Expand Up @@ -510,7 +510,8 @@ func (r *Router) Stop() {
// processOnionCfg is a set of config values that can be used to modify how an
// onion is processed.
type processOnionCfg struct {
blindingPoint *btcec.PublicKey
blindingPoint *btcec.PublicKey
tlvPayloadOnly bool
}

// ProcessOnionOpt defines the signature of a function option that can be used
Expand All @@ -525,6 +526,14 @@ func WithBlindingPoint(point *btcec.PublicKey) ProcessOnionOpt {
}
}

// WithTLVPayloadOnly is a functional option that signals that the onion packet
// being processed is an onion_message_packet.
func WithTLVPayloadOnly() ProcessOnionOpt {
return func(cfg *processOnionCfg) {
cfg.tlvPayloadOnly = true
}
}

// ProcessOnionPacket processes an incoming onion packet which has been forward
// to the target Sphinx router. If the encoded ephemeral key isn't on the
// target Elliptic Curve, then the packet is rejected. Similarly, if the
Expand Down Expand Up @@ -560,7 +569,9 @@ func (r *Router) ProcessOnionPacket(onionPkt *OnionPacket, assocData []byte,
// Continue to optimistically process this packet, deferring replay
// protection until the end to reduce the penalty of multiple IO
// operations.
packet, err := processOnionPacket(onionPkt, &sharedSecret, assocData)
packet, err := processOnionPacket(
onionPkt, &sharedSecret, assocData, cfg.tlvPayloadOnly,
)
if err != nil {
return nil, err
}
Expand Down Expand Up @@ -594,7 +605,9 @@ func (r *Router) ReconstructOnionPacket(onionPkt *OnionPacket, assocData []byte,
return nil, err
}

return processOnionPacket(onionPkt, &sharedSecret, assocData)
return processOnionPacket(
onionPkt, &sharedSecret, assocData, cfg.tlvPayloadOnly,
)
}

// DecryptBlindedHopData uses the router's private key to decrypt data encrypted
Expand Down Expand Up @@ -625,7 +638,8 @@ func (r *Router) OnionPublicKey() *btcec.PublicKey {
// packet. This function returns the next inner onion packet layer, along with
// the hop data extracted from the outer onion packet.
func unwrapPacket(onionPkt *OnionPacket, sharedSecret *Hash256,
assocData []byte) (*OnionPacket, *HopPayload, error) {
assocData []byte, tlvPayloadOnly bool) (*OnionPacket, *HopPayload,
error) {

dhKey := onionPkt.EphemeralKey
routeInfo := onionPkt.RoutingInfo
Expand All @@ -649,8 +663,8 @@ func unwrapPacket(onionPkt *OnionPacket, sharedSecret *Hash256,
zeroBytes := bytes.Repeat([]byte{0}, MaxPayloadSize)
headerWithPadding := append(routeInfo[:], zeroBytes...)

var hopInfo [numStreamBytes]byte
xor(hopInfo[:], headerWithPadding, streamBytes)
hopInfo := make([]byte, numStreamBytes)
xor(hopInfo, headerWithPadding, streamBytes)

// Randomize the DH group element for the next hop using the
// deterministic blinding factor.
Expand All @@ -660,8 +674,10 @@ func unwrapPacket(onionPkt *OnionPacket, sharedSecret *Hash256,
// With the MAC checked, and the payload decrypted, we can now parse
// out the payload so we can derive the specified forwarding
// instructions.
var hopPayload HopPayload
if err := hopPayload.Decode(bytes.NewReader(hopInfo[:])); err != nil {
hopPayload, err := DecodeHopPayload(
bytes.NewReader(hopInfo), tlvPayloadOnly,
)
if err != nil {
return nil, nil, err
}

Expand All @@ -676,14 +692,14 @@ func unwrapPacket(onionPkt *OnionPacket, sharedSecret *Hash256,
HeaderMAC: hopPayload.HMAC,
}

return innerPkt, &hopPayload, nil
return innerPkt, hopPayload, nil
}

// processOnionPacket performs the primary key derivation and handling of onion
// packets. The processed packets returned from this method should only be used
// if the packet was not flagged as a replayed packet.
func processOnionPacket(onionPkt *OnionPacket, sharedSecret *Hash256,
assocData []byte) (*ProcessedPacket, error) {
assocData []byte, tlvPayloadOnly bool) (*ProcessedPacket, error) {

// First, we'll unwrap an initial layer of the onion packet. Typically,
// we'll only have a single layer to unwrap, However, if the sender has
Expand All @@ -693,7 +709,7 @@ func processOnionPacket(onionPkt *OnionPacket, sharedSecret *Hash256,
// they can properly check the HMAC and unwrap a layer for their
// handoff hop.
innerPkt, outerHopPayload, err := unwrapPacket(
onionPkt, sharedSecret, assocData,
onionPkt, sharedSecret, assocData, tlvPayloadOnly,
)
if err != nil {
return nil, err
Expand All @@ -703,7 +719,7 @@ func processOnionPacket(onionPkt *OnionPacket, sharedSecret *Hash256,
// However if the uncovered 'nextMac' is all zeroes, then this
// indicates that we're the final hop in the route.
var action ProcessCode = MoreHops
if bytes.Compare(zeroHMAC[:], outerHopPayload.HMAC[:]) == 0 {
if bytes.Equal(zeroHMAC[:], outerHopPayload.HMAC[:]) {
action = ExitNode
}

Expand Down Expand Up @@ -794,7 +810,9 @@ func (t *Tx) ProcessOnionPacket(seqNum uint16, onionPkt *OnionPacket,
// Continue to optimistically process this packet, deferring replay
// protection until the end to reduce the penalty of multiple IO
// operations.
packet, err := processOnionPacket(onionPkt, &sharedSecret, assocData)
packet, err := processOnionPacket(
onionPkt, &sharedSecret, assocData, cfg.tlvPayloadOnly,
)
if err != nil {
return err
}
Expand Down
56 changes: 56 additions & 0 deletions sphinx_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -299,6 +299,62 @@ func TestTLVPayloadMessagePacket(t *testing.T) {
hex.EncodeToString(finalPacket), hex.EncodeToString(b.Bytes()))
}

// TestProcessOnionMessageZeroLengthPayload tests that we can properly process
// an onion message that has a zero-length payload.
func TestProcessOnionMessageZeroLengthPayload(t *testing.T) {
t.Parallel()

// First, create a router that will be the destination of the onion
// message.
privKey, err := btcec.NewPrivateKey()
require.NoError(t, err)

router := NewRouter(&PrivKeyECDH{privKey}, NewMemoryReplayLog())
err = router.Start()
require.NoError(t, err)

defer router.Stop()

// Next, create a session key for the onion packet.
sessionKey, err := btcec.NewPrivateKey()
require.NoError(t, err)

// We'll create a simple one-hop path.
path := &PaymentPath{
{
NodePub: *privKey.PubKey(),
},
}

// The hop payload will be an empty TLV payload.
payload, err := NewTLVHopPayload(nil)
require.NoError(t, err)
require.Empty(t, payload.Payload)
path[0].HopPayload = payload

// Now, create the onion packet.
onionPacket, err := NewOnionPacket(
path, sessionKey, nil, DeterministicPacketFiller,
)
require.NoError(t, err)

// We'll now process the packet, making sure to indicate that this is
// an onion message.
processedPacket, err := router.ProcessOnionPacket(
onionPacket, nil, 0, WithTLVPayloadOnly(),
)
require.NoError(t, err)

// The packet should be decoded as an exit node.
require.EqualValues(t, ExitNode, processedPacket.Action)

// The payload should be of type TLV.
require.Equal(t, PayloadTLV, processedPacket.Payload.Type)

// And the payload should be empty.
require.Empty(t, processedPacket.Payload.Payload)
}

func TestSphinxCorrectness(t *testing.T) {
nodes, _, hopDatas, fwdMsg, err := newTestRoute(testLegacyRouteNumHops)
if err != nil {
Expand Down