From 3d8c57dba7d7a74a410a98262720adde677ac362 Mon Sep 17 00:00:00 2001 From: Keagan McClelland Date: Fri, 27 Oct 2023 12:59:07 -0700 Subject: [PATCH 1/6] fn: introduce option type this commit introduces many of the most common functions you will want to use with the Option type. Not all of them are used immediately in this PR. --- fn/option.go | 149 +++++++++++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 149 insertions(+) create mode 100644 fn/option.go diff --git a/fn/option.go b/fn/option.go new file mode 100644 index 0000000000..a2c3afdc25 --- /dev/null +++ b/fn/option.go @@ -0,0 +1,149 @@ +package fn + +// Option[A] represents a value which may or may not be there. This is very +// often preferable to nil-able pointers. +type Option[A any] struct { + isSome bool + some A +} + +// Some trivially injects a value into an optional context. +// +// Some : A -> Option[A]. +func Some[A any](a A) Option[A] { + return Option[A]{ + isSome: true, + some: a, + } +} + +// None trivially constructs an empty option +// +// None : Option[A]. +func None[A any]() Option[A] { + return Option[A]{} +} + +// ElimOption is the universal Option eliminator. It can be used to safely +// handle all possible values inside the Option by supplying two continuations. +// +// ElimOption : (Option[A], () -> B, A -> B) -> B. +func ElimOption[A, B any](o Option[A], b func() B, f func(A) B) B { + if o.isSome { + return f(o.some) + } + + return b() +} + +// UnwrapOr is used to extract a value from an option, and we supply the default +// value in the case when the Option is empty. +// +// UnwrapOr : (Option[A], A) -> A. +func (o Option[A]) UnwrapOr(a A) A { + if o.isSome { + return o.some + } + + return a +} + +// WhenSome is used to conditionally perform a side-effecting function that +// accepts a value of the type that parameterizes the option. If this function +// performs no side effects, WhenSome is useless. +// +// WhenSome : (Option[A], A -> ()) -> (). +func (o Option[A]) WhenSome(f func(A)) { + if o.isSome { + f(o.some) + } +} + +// IsSome returns true if the Option contains a value +// +// IsSome : Option[A] -> bool. +func (o Option[A]) IsSome() bool { + return o.isSome +} + +// IsNone returns true if the Option is empty +// +// IsNone : Option[A] -> bool. +func (o Option[A]) IsNone() bool { + return !o.isSome +} + +// FlattenOption joins multiple layers of Options together such that if any of +// the layers is None, then the joined value is None. Otherwise the innermost +// Some value is returned. +// +// FlattenOption : Option[Option[A]] -> Option[A]. +func FlattenOption[A any](oo Option[Option[A]]) Option[A] { + if oo.IsNone() { + return None[A]() + } + if oo.some.IsNone() { + return None[A]() + } + + return oo.some +} + +// ChainOption transforms a function A -> Option[B] into one that accepts an +// Option[A] as an argument. +// +// ChainOption : (A -> Option[B]) -> Option[A] -> Option[B]. +func ChainOption[A, B any](f func(A) Option[B]) func(Option[A]) Option[B] { + return func(o Option[A]) Option[B] { + if o.isSome { + return f(o.some) + } + + return None[B]() + } +} + +// MapOption transforms a pure function A -> B into one that will operate +// inside the Option context. +// +// MapOption : (A -> B) -> Option[A] -> Option[B]. +func MapOption[A, B any](f func(A) B) func(Option[A]) Option[B] { + return func(o Option[A]) Option[B] { + if o.isSome { + return Some(f(o.some)) + } + + return None[B]() + } +} + +// LiftA2Option transforms a pure function (A, B) -> C into one that will +// operate in an Option context. For the returned function, if either of its +// arguments are None, then the result will be None. +// +// LiftA2Option : ((A, B) -> C) -> (Option[A], Option[B]) -> Option[C]. +func LiftA2Option[A, B, C any]( + f func(A, B) C, +) func(Option[A], Option[B]) Option[C] { + + return func(o1 Option[A], o2 Option[B]) Option[C] { + if o1.isSome && o2.isSome { + return Some(f(o1.some, o2.some)) + } + + return None[C]() + } +} + +// Alt chooses the left Option if it is full, otherwise it chooses the right +// option. This can be useful in a long chain if you want to choose between +// many different ways of producing the needed value. +// +// Alt : Option[A] -> Option[A] -> Option[A]. +func (o Option[A]) Alt(o2 Option[A]) Option[A] { + if o.isSome { + return o + } + + return o2 +} From 18e06f379474a2b58fb1bb8215cfbb9c902bcca2 Mon Sep 17 00:00:00 2001 From: Keagan McClelland Date: Wed, 1 Nov 2023 20:35:25 -0700 Subject: [PATCH 2/6] fn: add mvar implementation --- fn/mvar.go | 180 ++++++++++++++++++++++++++ fn/mvar_test.go | 327 ++++++++++++++++++++++++++++++++++++++++++++++++ 2 files changed, 507 insertions(+) create mode 100644 fn/mvar.go create mode 100644 fn/mvar_test.go diff --git a/fn/mvar.go b/fn/mvar.go new file mode 100644 index 0000000000..6f5312d00d --- /dev/null +++ b/fn/mvar.go @@ -0,0 +1,180 @@ +package fn + +import ( + "sync/atomic" +) + +// MVar[A any] is a structure that is designed to store a single value in an API +// that dispenses with data races. Think of it as a box for a value. +// +// It has two states: full and empty. +// +// It supports two operations: take and put. +// +// The state transition rules are as follows: +// 1. put while full blocks. +// 2. put while empty sets. +// 3. take while full resets. +// 4. take while empty blocks. +// 5. read while full nops. +// 6. read while empty blocks. +type MVar[A any] struct { + // current is an immediately available copy of whatever is inside the + // value channel that is served to readers. It is updated whenever a + // change to the value channel is successful. + current *atomic.Pointer[A] + // readers is used to wake all blocked readers when a new value is + // written. + readers chan chan A + + // takers is used to wake a single taker when a new value is written. + takers chan chan A + + // value is a bounded channel of size 1 that represents the core state + // oof the channel. + value chan A +} + +// Zero initializes an MVar that has no values in it. In this state, TakeMVar +// will block and PutMVar will immediately succeed. +// +// Zero : () -> MVar[A]. +func Zero[A any]() MVar[A] { + ptr := atomic.Pointer[A]{} + + return MVar[A]{ + current: &ptr, + readers: make(chan chan A), + takers: make(chan chan A), + value: make(chan A, 1), + } +} + +// NewMVar initializes an MVar that has a value in it from the getgo. In this +// state, TakeMVar will succeed immediately and PutMVar will block. +// +// NewMVar : A -> MVar[A]. +func NewMVar[A any](a A) MVar[A] { + z := Zero[A]() + z.value <- a + z.current.Store(&a) + + return z +} + +// Take will wait for a value to be put into the MVar and then immediately +// take it out. +// +// Take : MVar[A] -> A. +func (m *MVar[A]) Take() A { + select { + case v := <-m.value: + m.current.Store(nil) + return v + default: + t := make(chan A) + m.takers <- t + return <-t + } +} + +// TryTake is the non-blocking version of TakeMVar, it will return an +// None() Option if it would have blocked. +// +// TryTake : MVar[A] -> Option[A]. +func (m *MVar[A]) TryTake() Option[A] { + select { + case v := <-m.value: + m.current.Store(nil) + return Some(v) + default: + return None[A]() + } +} + +// Put will wait for a value to be made empty and will immediately replace it +// with the argument. +// +// Put : (MVar[A], A) -> (). +func (m *MVar[A]) Put(a A) { +readLoop: + // Give the newly put value to all of the waiting readers. + for { + select { + case r := <-m.readers: + r <- a + default: + break readLoop + } + } + + // Give the newly put value to a single taker if one exists. If there + // are no available takers, then store it in the MVar. Since the value + // channel is bounded with capacity 1, subsequent put operations will + // block. + select { + case t := <-m.takers: + t <- a + default: + m.value <- a + m.current.Store(&a) + } +} + +// TryPut is the non-blocking version of Put and will return true if the MVar is +// successfully set. +// +// TryPut : (MVar[A], A) -> bool. +func (m *MVar[A]) TryPut(a A) bool { + select { + case m.value <- a: + m.current.Store(&a) + return true + default: + return false + } +} + +// Read will atomically read the contents of the MVar. If the MVar is empty, +// Read will block until a value is put in. Callers of Read are guaranteed to +// be woken up before callers of Take. +// +// Read : MVar[A] -> A. +func (m *MVar[A]) Read() A { + // Check to see if MVar has something in it. + if ptr := m.current.Load(); ptr != nil { + return *ptr + } + + // It's empty so we need to wait. + r := make(chan A) + m.readers <- r + + return <-r +} + +// TryRead will atomically read the contents of the MVar if it is full. +// Otherwise, it will return None. +// +// TryRead : MVar[A] -> Option[A]. +func (m *MVar[A]) TryRead() Option[A] { + if ptr := m.current.Load(); ptr != nil { + return Some(*ptr) + } + + return None[A]() +} + +// IsFull will return true if the MVar currently has a value in it. +// +// IsFull : MVar[A] -> bool. +func (m *MVar[A]) IsFull() bool { + return m.current.Load() != nil +} + +// IsEmpty will return true if the MVar currently does not have a value in it. +// +// IsEmpty : MVar[A] -> bool. +func (m *MVar[A]) IsEmpty() bool { + return m.current.Load() == nil +} diff --git a/fn/mvar_test.go b/fn/mvar_test.go new file mode 100644 index 0000000000..0e06cf4150 --- /dev/null +++ b/fn/mvar_test.go @@ -0,0 +1,327 @@ +package fn + +import ( + "sync" + "sync/atomic" + "testing" + "testing/quick" + "time" + + "github.com/stretchr/testify/require" +) + +// blockTimeout is a parameter that defines all of the waiting periods for the +// tests in this file. Generally there is a tradeoff here where the higher this +// value is, the less flaky the tests will be, at the expense of the tests +// taking longer to execute. +const blockTimeout = time.Millisecond + +// TestTakeZeroBlocks ensures that if we initialize an empty MVar and +// immediately try to Take from it, it will block. +func TestTakeZeroBlocks(t *testing.T) { + t.Parallel() + + m := Zero[uint8]() + require.True(t, blocks(func() { m.Take() })) +} + +// TestTakeNewMVarProceeds ensures that if we initialize an MVar with a value +// in it and immediately try to Take from it, it will succeed. +func TestTakeNewMVarProceeds(t *testing.T) { + t.Parallel() + + m := NewMVar[uint8](0) + require.False(t, blocks(func() { m.Take() })) +} + +// TestPutNewMVarBlocks ensures that if we initialize an MVar with a value in +// it and immediately try to Put a new value into it, it will block. +func TestPutNewMVarBlocks(t *testing.T) { + t.Parallel() + + m := NewMVar[uint8](0) + require.True(t, blocks(func() { m.Put(1) })) +} + +// TestPutZeroProceeds ensures that if we initialize an empty Mvar and then try +// to Put a new value into it, it will succeed. +func TestPutZeroProceeds(t *testing.T) { + t.Parallel() + + m := Zero[uint8]() + require.False(t, blocks(func() { m.Put(1) })) +} + +// TestPutWhenEmptyLeavesFull ensures that we successfully leave the Mvar in a +// full state after executing a Put in an empty state. +func TestPutWhenEmptyLeavesFull(t *testing.T) { + t.Parallel() + + m := Zero[uint8]() + m.Put(0) + if m.IsEmpty() { + t.Fatal("Put left empty") + } +} + +// TestTakeWhenFullLeavesEmpty ensures that we successfully leave the Mvar in an +// empty state after executing a Take in a full state. +func TestTakeWhenFullLeavesEmpty(t *testing.T) { + t.Parallel() + + m := NewMVar[uint8](0) + m.Take() + if m.IsFull() { + t.Fatal("Take left full") + } +} + +// TestReadWhenFullLeavesFull ensures that a Read when in a full state does not +// change the state of the MVar. +func TestReadWhenFullLeavesFull(t *testing.T) { + t.Parallel() + + m := NewMVar[uint8](0) + m.Read() + if m.IsEmpty() { + t.Fatal("Read left empty") + } +} + +// TestTakeAfterTryTakeBlocks ensures that regardless of what state the Mvar +// begins in, if we try to Take immediately after a TryTake, it will always +// block. +func TestTakeAfterTryTakeBlocks(t *testing.T) { + t.Parallel() + + err := quick.Check(func(set bool, n uint8) bool { + m := gen(set, n) + m.TryTake() + return blocks(func() { m.Take() }) + }, nil) + require.NoError(t, err, "Take after TryTake did not block") +} + +// TestPutAfterTryPutBlocks ensures that regardless of what state the MVar +// begins in, if we try to Put immediately after a TryPut, it will always block. +func TestPutAfterTryPutBlocks(t *testing.T) { + t.Parallel() + + err := quick.Check(func(set bool, n uint8) bool { + m := gen(set, n) + m.TryPut(0) + return blocks(func() { m.Put(1) }) + }, nil) + require.NoError(t, err, "Put after TryPut did not block") +} + +// TestTryTakeLeavesEmpty ensures that regardless of what state the MVar begins +// in, if we execute a TryTake, the resulting MVar state is empty. +func TestTryTakeLeavesEmpty(t *testing.T) { + t.Parallel() + + err := quick.Check(func(set bool, n uint8) bool { + m := gen(set, n) + m.TryTake() + return m.IsEmpty() + }, nil) + require.NoError(t, err, "TryTake did not leave empty") +} + +// TestTryPutLeavesFull ensures that regardless of what state the MVar begins +// in, if we execute a TryPut, the resulting MVar state is full. +func TestTryPutLeavesFull(t *testing.T) { + t.Parallel() + + err := quick.Check(func(set bool, n uint8) bool { + m := gen(set, n) + m.TryPut(n) + return m.IsFull() + }, nil) + require.NoError(t, err, "TryPut did not leave full") +} + +// TestReadWhenEmptyBlocks ensures that if an MVar is in an empty state, a +// Read operation will block. +func TestReadWhenEmptyBlocks(t *testing.T) { + t.Parallel() + + err := quick.Check(func(set bool, n uint8) bool { + m := gen(set, n) + return implies(m.IsEmpty(), blocks(func() { m.Read() })) + }, nil) + require.NoError(t, err, "Read did not block when empty") +} + +// TestTryReadNops ensures that a TryRead will not change the state of the MVar. +// It implicitly tests a second property which is that TryRead will never block. +func TestTryReadNops(t *testing.T) { + t.Parallel() + + err := quick.Check(func(set bool, n uint8) bool { + m := gen(set, n) + before := m.IsEmpty() + tryReadBlocked := blocks(func() { m.TryRead() }) + after := m.IsEmpty() + + return before == after && !tryReadBlocked + }, nil) + require.NoError(t, err, "TryRead did not leave state unchanged") +} + +// TestPutWakesAllReaders ensures the property that if we have many blocked +// Read operations that are waiting for the MVar to be filled, all of them are +// woken up when we execute a Put operation. +func TestPutWakesAllReaders(t *testing.T) { + t.Parallel() + + m := Zero[uint8]() + v := uint8(21) + n := uint32(10) + + counter := atomic.Uint32{} + wg := sync.WaitGroup{} + for i := uint32(0); i < n; i++ { + wg.Add(1) + go func() { + x := m.Read() + if x == v { + counter.Add(1) + } + wg.Done() + }() + } + m.Put(v) + wg.Wait() + require.Equal(t, counter.Load(), n, "not all readers given same value") +} + +// TestPutWakesReadersBeforeTaker ensures the property that a waiting taker +// does not preempt any waiting readers. This test construction is a bit +// delicate, using a Sleep to ensure that all goroutines that are set up to +// wait on the Put have gotten to the point where they actually block. +func TestPutWakesReadersBeforeTaker(t *testing.T) { + t.Parallel() + + m := Zero[uint8]() + v := uint8(21) + n := uint32(10) + + counter := atomic.Uint32{} + wg := sync.WaitGroup{} + + // Set up taker first. + wg.Add(1) + go func() { + x := m.Take() + if x == v { + counter.Add(1) + } + wg.Done() + }() + + // Set up readers. + for i := uint32(0); i < n; i++ { + wg.Add(1) + go func() { + x := m.Read() + if x == v { + counter.Add(1) + } + wg.Done() + }() + } + + time.Sleep(blockTimeout) // Forgive me + + m.Put(v) + wg.Wait() + require.Equal( + t, counter.Load(), n+1, "readers did not wake before taker", + ) +} + +// TestPutWakesOneTaker ensures the property that only a single blocked Take +// operation wakes when a Put comes in. This test construction is a bit delicate +// using a Sleep to wait for the counter to be incremented by the Take goroutine +// after waking. +func TestPutWakesOneTaker(t *testing.T) { + t.Parallel() + + m := Zero[uint8]() + v := uint8(21) + n := uint32(10) + + counter := atomic.Uint32{} + wg := sync.WaitGroup{} + for i := uint32(0); i < n; i++ { + wg.Add(1) + go func() { + x := m.Take() + if x == v { + counter.Add(1) + } + wg.Done() + }() + } + m.Put(v) + + time.Sleep(blockTimeout) // Forgive me + + require.Equal( + t, + counter.Load(), + uint32(1), + "put wakes zero or more than one taker ", + ) +} + +// TestTakeWakesPutter ensures that if there is a blocked Put operation due to +// the MVar being full, that it unblocks when a Take operation is executed. This +// is because the Take operation would set the MVar to an empty state, allowing +// the blocked Put to proceed. +func TestTakeWakesPutter(t *testing.T) { + t.Parallel() + + m := NewMVar[uint8](0) + wg := sync.WaitGroup{} + wg.Add(1) + go func() { m.Put(1); wg.Done() }() + m.Take() + wg.Wait() + require.Equal(t, m.Read(), uint8(1)) +} + +// blocks is a helper function to decide if the supplied function blocks or not. +// This is not fool-proof since it does make a judgement call based off of the +// file-global timeout parameter. +func blocks(f func()) bool { + unblocked := make(chan struct{}) + go func() { f(); unblocked <- struct{}{} }() + + select { + case <-unblocked: + return false + case <-time.NewTimer(blockTimeout).C: + return true + } +} + +// implies is a helper function that computes the `=>` operation from boolean +// algebra. It is true when the first argument is false, or if both arguments +// are true. +func implies(b bool, b2 bool) bool { + return !b || b && b2 +} + +// gen is a helper function whose first argument decides if the returned MVar +// should be full or empty, and the second argument decides what value should +// be in the MVar if it is full. We use this because it is substantially easier +// than teaching testing.Quick how to generate random MVar[uint8] values. +func gen(set bool, n uint8) MVar[uint8] { + if set { + return NewMVar[uint8](n) + } + + return Zero[uint8]() +} From 3c56bb29f0c0e0c12461859ec9eb8fc739f3494d Mon Sep 17 00:00:00 2001 From: Keagan McClelland Date: Mon, 6 Nov 2023 15:30:02 -0800 Subject: [PATCH 3/6] htlcswitch+peer: add flush api to channel link --- htlcswitch/interfaces.go | 14 ++++++++++++++ htlcswitch/link.go | 28 ++++++++++++++++++++++++++++ htlcswitch/mock.go | 9 +++++++++ peer/test_utils.go | 13 +++++++++++++ 4 files changed, 64 insertions(+) diff --git a/htlcswitch/interfaces.go b/htlcswitch/interfaces.go index 304f4a5957..b8d868a136 100644 --- a/htlcswitch/interfaces.go +++ b/htlcswitch/interfaces.go @@ -136,6 +136,20 @@ type ChannelUpdateHandler interface { // clean. This can be used with dynamic commitment negotiation or coop // close negotiation which require a clean channel state. ShutdownIfChannelClean() error + + // Flush is a method that disables htlc adds to the channel until it has + // reached an empty state. When we reach zero HTLCs, the supplied + // function will be called. + Flush(func()) error + + // CancelFlush will abort an in-progress flush. If there is no + // current flush operation taking place, then this function will return + // an error. + CancelFlush() error + + // IsFlushing returns true if there is a currently in-progress flush + // operation. + IsFlushing() bool } // ChannelLink is an interface which represents the subsystem for managing the diff --git a/htlcswitch/link.go b/htlcswitch/link.go index 46aee939d1..9f498746be 100644 --- a/htlcswitch/link.go +++ b/htlcswitch/link.go @@ -19,6 +19,7 @@ import ( "github.com/lightningnetwork/lnd/channeldb" "github.com/lightningnetwork/lnd/channeldb/models" "github.com/lightningnetwork/lnd/contractcourt" + "github.com/lightningnetwork/lnd/fn" "github.com/lightningnetwork/lnd/htlcswitch/hodl" "github.com/lightningnetwork/lnd/htlcswitch/hop" "github.com/lightningnetwork/lnd/invoices" @@ -364,6 +365,10 @@ type channelLink struct { // resolving those htlcs when we receive a message on hodlQueue. hodlMap map[models.CircuitKey]hodlHtlc + // flushCont is a function that is called when the channel finishes + // flushing. + flushCont fn.MVar[func()] + // log is a link-specific logging instance. log btclog.Logger @@ -393,6 +398,7 @@ func NewChannelLink(cfg ChannelLinkConfig, hodlQueue: queue.NewConcurrentQueue(10), log: build.NewPrefixLog(logPrefix, log), quit: make(chan struct{}), + flushCont: fn.Zero[func()](), } } @@ -528,6 +534,28 @@ func (l *channelLink) Stop() { } } +func (l *channelLink) Flush(onFlushed func()) error { + if !l.flushCont.TryPut(onFlushed) { + return errors.New( + "can't flush because flush already in progress", + ) + } + + return nil +} + +func (l *channelLink) CancelFlush() error { + if l.flushCont.TryTake().IsNone() { + return errors.New("no flush in progress to cancel") + } + + return nil +} + +func (l *channelLink) IsFlushing() bool { + return l.flushCont.IsFull() +} + // WaitForShutdown blocks until the link finishes shutting down, which includes // termination of all dependent goroutines. func (l *channelLink) WaitForShutdown() { diff --git a/htlcswitch/mock.go b/htlcswitch/mock.go index fe593c9c52..6db376d8aa 100644 --- a/htlcswitch/mock.go +++ b/htlcswitch/mock.go @@ -905,6 +905,15 @@ func (f *mockChannelLink) UpdateShortChanID() (lnwire.ShortChannelID, error) { f.eligible = true return f.shortChanID, nil } +func (f *mockChannelLink) Flush(onFlushed func()) error { + return errors.New("mockChannelLink does not support flush api") +} +func (f *mockChannelLink) CancelFlush() error { + return errors.New("mockChannelLink does not support flush api") +} +func (f *mockChannelLink) IsFlushing() bool { + return false +} var _ ChannelLink = (*mockChannelLink)(nil) diff --git a/peer/test_utils.go b/peer/test_utils.go index add15cf19d..bd3a0d721b 100644 --- a/peer/test_utils.go +++ b/peer/test_utils.go @@ -4,6 +4,7 @@ import ( "bytes" crand "crypto/rand" "encoding/binary" + "errors" "io" "math/rand" "net" @@ -499,6 +500,18 @@ type mockMessageConn struct { curReadMessage []byte } +func (m *mockUpdateHandler) Flush(func()) error { + return errors.New("mockUpdateHandler does not support flush api") +} + +func (m *mockUpdateHandler) CancelFlush() error { + return errors.New("mockUpdateHandler does not support flush api") +} + +func (m *mockUpdateHandler) IsFlushing() bool { + return false +} + func newMockConn(t *testing.T, expectedMessages int) *mockMessageConn { return &mockMessageConn{ t: t, From d2257fdb9900858ec9de32ddea9ac558f1ee0241 Mon Sep 17 00:00:00 2001 From: Keagan McClelland Date: Mon, 6 Nov 2023 16:32:08 -0800 Subject: [PATCH 4/6] htlcswitch: update link logic to drive forward active flushes --- htlcswitch/link.go | 37 +++++++++++++++++++++++++++++++++++-- 1 file changed, 35 insertions(+), 2 deletions(-) diff --git a/htlcswitch/link.go b/htlcswitch/link.go index 9f498746be..a0f3f4fddd 100644 --- a/htlcswitch/link.go +++ b/htlcswitch/link.go @@ -570,7 +570,8 @@ func (l *channelLink) WaitForShutdown() { func (l *channelLink) EligibleToForward() bool { return l.channel.RemoteNextRevocation() != nil && l.ShortChanID() != hop.Source && - l.isReestablished() + l.isReestablished() && + !l.IsFlushing() } // isReestablished returns true if the link has successfully completed the @@ -1302,6 +1303,20 @@ func (l *channelLink) htlcManager() { case <-l.quit: return } + + // After we are finished processing the event, if the link is + // flushing, we check if the channel is clean and invoke the + // post-flush hook if it is. + if l.IsFlushing() && l.channel.IsChannelClean() { + // This will not block since flushCont must be full. + // We Read instead of Take to ensure a new flush + // operation can't be initiated until the continuation + // for the current flush has completed. + l.flushCont.Read()() + + // Reset the flushCont MVar. This will also not block. + l.flushCont.Take() + } } } @@ -2166,7 +2181,6 @@ func (l *channelLink) handleUpstreamMsg(msg lnwire.Message) { default: l.log.Warnf("received unknown message of type %T", msg) } - } // ackDownStreamPackets is responsible for removing htlcs from a link's mailbox @@ -3081,6 +3095,25 @@ func (l *channelLink) processRemoteAdds(fwdPkg *channeldb.FwdPkg, fwdInfo := pld.ForwardingInfo() + // If we are in a flush state we need to cancel back all of the + // net new HTLCs rather than forwarding them. This is the first + // opportunity we have to bounce invalid HTLC adds without + // doing a force-close. + if l.IsFlushing() { + var isReceive bool + switch fwdInfo.NextHop { + case hop.Exit: + isReceive = true + default: + isReceive = false + } + failure := lnwire.NewTemporaryChannelFailure(nil) + l.sendHTLCError( + pd, NewLinkError(failure), obfuscator, + isReceive, + ) + } + switch fwdInfo.NextHop { case hop.Exit: err := l.processExitHop( From 0ca384e655b2482ae082b705ea3c0bd4eee2ddb4 Mon Sep 17 00:00:00 2001 From: Keagan McClelland Date: Mon, 6 Nov 2023 18:15:08 -0800 Subject: [PATCH 5/6] htlcswitch+peer: make shutdown procedure use flush api This commit removes the requirement that the channel state is clean prior to shutdown. Now we invoke the new flush api and make the htlcManager quit and remove the link from the switch when the flush is complete. --- htlcswitch/interfaces.go | 2 +- htlcswitch/link.go | 44 +++------------- htlcswitch/link_test.go | 105 +++++++++++++++++++++------------------ htlcswitch/mock.go | 2 +- peer/brontide.go | 25 +++++----- peer/test_utils.go | 2 +- 6 files changed, 78 insertions(+), 102 deletions(-) diff --git a/htlcswitch/interfaces.go b/htlcswitch/interfaces.go index b8d868a136..3d068dd813 100644 --- a/htlcswitch/interfaces.go +++ b/htlcswitch/interfaces.go @@ -135,7 +135,7 @@ type ChannelUpdateHandler interface { // ShutdownIfChannelClean shuts the link down if the channel state is // clean. This can be used with dynamic commitment negotiation or coop // close negotiation which require a clean channel state. - ShutdownIfChannelClean() error + ShutdownHtlcManager() // Flush is a method that disables htlc adds to the channel until it has // reached an empty state. When we reach zero HTLCs, the supplied diff --git a/htlcswitch/link.go b/htlcswitch/link.go index a0f3f4fddd..eb7efcdc33 100644 --- a/htlcswitch/link.go +++ b/htlcswitch/link.go @@ -274,13 +274,6 @@ type ChannelLinkConfig struct { GetAliases func(base lnwire.ShortChannelID) []lnwire.ShortChannelID } -// shutdownReq contains an error channel that will be used by the channelLink -// to send an error if shutdown failed. If shutdown succeeded, the channel will -// be closed. -type shutdownReq struct { - err chan error -} - // channelLink is the service which drives a channel's commitment update // state-machine. In the event that an HTLC needs to be propagated to another // link, the forward handler from config is used which sends HTLC to the @@ -343,7 +336,7 @@ type channelLink struct { // shutdownRequest is a channel that the channelLink will listen on to // service shutdown requests from ShutdownIfChannelClean calls. - shutdownRequest chan *shutdownReq + shutdownRequest chan struct{} // updateFeeTimer is the timer responsible for updating the link's // commitment fee every time it fires. @@ -393,7 +386,7 @@ func NewChannelLink(cfg ChannelLinkConfig, cfg: cfg, channel: channel, shortChanID: channel.ShortChanID(), - shutdownRequest: make(chan *shutdownReq), + shutdownRequest: make(chan struct{}), hodlMap: make(map[models.CircuitKey]hodlHtlc), hodlQueue: queue.NewConcurrentQueue(10), log: build.NewPrefixLog(logPrefix, log), @@ -1286,19 +1279,8 @@ func (l *channelLink) htlcManager() { ) } - case req := <-l.shutdownRequest: - // If the channel is clean, we send nil on the err chan - // and return to prevent the htlcManager goroutine from - // processing any more updates. The full link shutdown - // will be triggered by RemoveLink in the peer. - if l.channel.IsChannelClean() { - req.err <- nil - return - } - - // Otherwise, the channel has lingering updates, send - // an error and continue. - req.err <- ErrLinkFailedShutdown + case <-l.shutdownRequest: + return case <-l.quit: return @@ -2792,26 +2774,14 @@ func (l *channelLink) HandleChannelUpdate(message lnwire.Message) { l.mailBox.AddMessage(message) } -// ShutdownIfChannelClean triggers a link shutdown if the channel is in a clean +// ShutdownHtlcManager triggers a link shutdown if the channel is in a clean // state and errors if the channel has lingering updates. // // NOTE: Part of the ChannelUpdateHandler interface. -func (l *channelLink) ShutdownIfChannelClean() error { - errChan := make(chan error, 1) - +func (l *channelLink) ShutdownHtlcManager() { select { - case l.shutdownRequest <- &shutdownReq{ - err: errChan, - }: + case l.shutdownRequest <- struct{}{}: case <-l.quit: - return ErrLinkShuttingDown - } - - select { - case err := <-errChan: - return err - case <-l.quit: - return ErrLinkShuttingDown } } diff --git a/htlcswitch/link_test.go b/htlcswitch/link_test.go index 37e306559d..ea103e0d4e 100644 --- a/htlcswitch/link_test.go +++ b/htlcswitch/link_test.go @@ -6490,92 +6490,99 @@ func TestPendingCommitTicker(t *testing.T) { } } -// TestShutdownIfChannelClean tests that a link will exit the htlcManager loop -// if and only if the underlying channel state is clean. -func TestShutdownIfChannelClean(t *testing.T) { +func TestFlushInvokesCallbackWhenDrained(t *testing.T) { t.Parallel() const chanAmt = btcutil.SatoshiPerBitcoin * 5 - const chanReserve = btcutil.SatoshiPerBitcoin * 1 - aliceLink, bobChannel, batchTicker, start, _, err := - newSingleLinkTestHarness(t, chanAmt, chanReserve) + const reserve = btcutil.SatoshiPerBitcoin * 1 + aliceLink, bobChannel, _, start, _, err := + newSingleLinkTestHarness(t, chanAmt, reserve) require.NoError(t, err) - var ( - coreLink = aliceLink.(*channelLink) - aliceMsgs = coreLink.cfg.Peer.(*mockPeer).sentMsgs - ) + coreLink := aliceLink.(*channelLink) + aliceMsgs := coreLink.cfg.Peer.(*mockPeer).sentMsgs - shutdownAssert := func(expectedErr error) { - err = aliceLink.ShutdownIfChannelClean() - if expectedErr != nil { - require.Error(t, err, expectedErr) - } else { - require.NoError(t, err) - } + if err := start(); err != nil { + t.Fatalf("unable to start test harness: %v", err) } - err = start() - require.NoError(t, err) - ctx := linkTestContext{ - t: t, - aliceLink: aliceLink, + t: t, + aliceLink: aliceLink, bobChannel: bobChannel, - aliceMsgs: aliceMsgs, + aliceMsgs: aliceMsgs, } - // First send an HTLC from Bob to Alice and assert that the link can't - // be shutdown while the update is outstanding. + flushFinished := make(chan struct{}) + assertFlushFinished := func(exp bool) { + select { + case <-flushFinished: + if !exp { + t.Fatal("flush callback invoked") + } + default: + if exp { + t.Fatal("flush callback not invoked") + } + } + } + + htlc := generateHtlc(t, coreLink, 0) - // <---add----- + // <-- add --- ctx.sendHtlcBobToAlice(htlc) - // <---sig----- + // <-- sig --- ctx.sendCommitSigBobToAlice(1) - // ----rev----> + // --- rev --> ctx.receiveRevAndAckAliceToBob() - shutdownAssert(ErrLinkFailedShutdown) - // ----sig----> + // put the link into a flush state + aliceLink.Flush(func() { + flushFinished <- struct{}{} + }) + assertFlushFinished(false) + + // --- sig --> ctx.receiveCommitSigAliceToBob(1) - shutdownAssert(ErrLinkFailedShutdown) + assertFlushFinished(false) - // <---rev----- + // <-- rev --- ctx.sendRevAndAckBobToAlice() - shutdownAssert(ErrLinkFailedShutdown) + assertFlushFinished(false) - // ---settle--> + // --- set --> ctx.receiveSettleAliceToBob() - shutdownAssert(ErrLinkFailedShutdown) + assertFlushFinished(false) - // ----sig----> + // --- sig --> ctx.receiveCommitSigAliceToBob(0) - shutdownAssert(ErrLinkFailedShutdown) + assertFlushFinished(false) - // <---rev----- + // <-- rev --- ctx.sendRevAndAckBobToAlice() - shutdownAssert(ErrLinkFailedShutdown) + assertFlushFinished(false) // There is currently no controllable breakpoint between Alice // receiving the CommitSig and her sending out the RevokeAndAck. As // soon as the RevokeAndAck is generated, the channel becomes clean. // This can happen right after the CommitSig is received, so there is // no shutdown assertion here. - // <---sig----- + // <-- sig --- ctx.sendCommitSigBobToAlice(0) - // ----rev----> + // --- rev --> ctx.receiveRevAndAckAliceToBob() - shutdownAssert(nil) + <-flushFinished +} - // Now that the link has exited the htlcManager loop, attempt to - // trigger the batch ticker. It should not be possible. - select { - case batchTicker <- time.Now(): - t.Fatalf("expected batch ticker to be inactive") - case <-time.After(5 * time.Second): - } +func TestFlushBlocksAdds(t *testing.T) { + // TODO +} + + +func TestFlushNotEligibleToFwd(t *testing.T) { + // TODO } // TestPipelineSettle tests that a link should only pipeline a settle if the diff --git a/htlcswitch/mock.go b/htlcswitch/mock.go index 6db376d8aa..3abead528f 100644 --- a/htlcswitch/mock.go +++ b/htlcswitch/mock.go @@ -898,7 +898,7 @@ func (f *mockChannelLink) ChannelPoint() *wire.OutPoint { return func (f *mockChannelLink) Stop() {} func (f *mockChannelLink) EligibleToForward() bool { return f.eligible } func (f *mockChannelLink) MayAddOutgoingHtlc(lnwire.MilliSatoshi) error { return nil } -func (f *mockChannelLink) ShutdownIfChannelClean() error { return nil } +func (f *mockChannelLink) ShutdownHtlcManager() {} func (f *mockChannelLink) setLiveShortChanID(sid lnwire.ShortChannelID) { f.shortChanID = sid } func (f *mockChannelLink) IsUnadvertised() bool { return f.unadvertised } func (f *mockChannelLink) UpdateShortChanID() (lnwire.ShortChannelID, error) { diff --git a/peer/brontide.go b/peer/brontide.go index 4c43d4a474..b766a32304 100644 --- a/peer/brontide.go +++ b/peer/brontide.go @@ -3136,19 +3136,18 @@ func (p *Brontide) tryLinkShutdown(cid lnwire.ChannelID) error { return ErrChannelNotFound } - // Else, the link exists, so attempt to trigger shutdown. If this - // fails, we'll send an error message to the remote peer. - if err := chanLink.ShutdownIfChannelClean(); err != nil { - return err - } - - // Next, we remove the link from the switch to shut down all of the - // link's goroutines and remove it from the switch's internal maps. We - // don't call WipeChannel as the channel must still be in the - // activeChannels map to process coop close messages. - p.cfg.Switch.RemoveLink(cid) - - return nil + return chanLink.Flush(func() { + // Else, the link exists, so attempt to trigger shutdown. If + // this fails, we'll send an error message to the remote peer. + chanLink.ShutdownHtlcManager() + + // Next, we remove the link from the switch to shut down all of + // the link's goroutines and remove it from the switch's + // internal maps. We don't call WipeChannel as the channel must + // still be in the activeChannels map to process coop close + // messages. + p.cfg.Switch.RemoveLink(cid) + }) } // fetchLinkFromKeyAndCid fetches a link from the switch via the remote's diff --git a/peer/test_utils.go b/peer/test_utils.go index bd3a0d721b..a01e0e8b7c 100644 --- a/peer/test_utils.go +++ b/peer/test_utils.go @@ -483,7 +483,7 @@ func (m *mockUpdateHandler) EligibleToForward() bool { return false } func (m *mockUpdateHandler) MayAddOutgoingHtlc(lnwire.MilliSatoshi) error { return nil } // ShutdownIfChannelClean currently returns nil. -func (m *mockUpdateHandler) ShutdownIfChannelClean() error { return nil } +func (m *mockUpdateHandler) ShutdownHtlcManager() {} type mockMessageConn struct { t *testing.T From e6a55b433428dcf0576db75db5bf68905759cc66 Mon Sep 17 00:00:00 2001 From: Keagan McClelland Date: Tue, 7 Nov 2023 10:01:37 -0800 Subject: [PATCH 6/6] htlcswitch+peer: add trivial implementations for flush api to mocks --- htlcswitch/mock.go | 5 +++-- peer/test_utils.go | 7 ++++--- 2 files changed, 7 insertions(+), 5 deletions(-) diff --git a/htlcswitch/mock.go b/htlcswitch/mock.go index 3abead528f..e0ad284321 100644 --- a/htlcswitch/mock.go +++ b/htlcswitch/mock.go @@ -906,10 +906,11 @@ func (f *mockChannelLink) UpdateShortChanID() (lnwire.ShortChannelID, error) { return f.shortChanID, nil } func (f *mockChannelLink) Flush(onFlushed func()) error { - return errors.New("mockChannelLink does not support flush api") + onFlushed() + return nil } func (f *mockChannelLink) CancelFlush() error { - return errors.New("mockChannelLink does not support flush api") + return errors.New("no flush in progress to cancel") } func (f *mockChannelLink) IsFlushing() bool { return false diff --git a/peer/test_utils.go b/peer/test_utils.go index a01e0e8b7c..72a976ebc3 100644 --- a/peer/test_utils.go +++ b/peer/test_utils.go @@ -500,12 +500,13 @@ type mockMessageConn struct { curReadMessage []byte } -func (m *mockUpdateHandler) Flush(func()) error { - return errors.New("mockUpdateHandler does not support flush api") +func (m *mockUpdateHandler) Flush(onFlushed func()) error { + onFlushed() + return nil } func (m *mockUpdateHandler) CancelFlush() error { - return errors.New("mockUpdateHandler does not support flush api") + return errors.New("no flush in progress to cancel") } func (m *mockUpdateHandler) IsFlushing() bool {