From ec2de42848e22116173b00c98011ca2145e192ad Mon Sep 17 00:00:00 2001 From: Boris Nagaev Date: Fri, 27 Sep 2024 16:04:24 -0300 Subject: [PATCH 1/3] htlcswitch: use fn.GoroutineManager Replaced the use of s.quit and s.wg with s.gm (GoroutineManager). This fixes a race condition between s.wg.Add(1) and s.wg.Wait(). Also added a test which used to fail under `-race` before this commit. --- htlcswitch/switch.go | 131 ++++++++++++++++++++++++-------------- htlcswitch/switch_test.go | 68 ++++++++++++++++++-- 2 files changed, 145 insertions(+), 54 deletions(-) diff --git a/htlcswitch/switch.go b/htlcswitch/switch.go index 720625f2c5a..ed240224b62 100644 --- a/htlcswitch/switch.go +++ b/htlcswitch/switch.go @@ -2,6 +2,7 @@ package htlcswitch import ( "bytes" + "context" "errors" "fmt" "math/rand" @@ -245,8 +246,8 @@ type Switch struct { // This will be retrieved by the registered links atomically. bestHeight uint32 - wg sync.WaitGroup - quit chan struct{} + // gm starts and stops tasks in goroutines and waits for them. + gm *fn.GoroutineManager // cfg is a copy of the configuration struct that the htlc switch // service was initialized with. 
@@ -368,8 +369,11 @@ func New(cfg Config, currentHeight uint32) (*Switch, error) { return nil, err } + gm := fn.NewGoroutineManager() + s := &Switch{ bestHeight: currentHeight, + gm: gm, cfg: &cfg, circuits: circuitMap, linkIndex: make(map[lnwire.ChannelID]ChannelLink), @@ -382,7 +386,6 @@ func New(cfg Config, currentHeight uint32) (*Switch, error) { chanCloseRequests: make(chan *ChanClose), resolutionMsgs: make(chan *resolutionMsg), resMsgStore: resStore, - quit: make(chan struct{}), } s.aliasToReal = make(map[lnwire.ShortChannelID]lnwire.ShortChannelID) @@ -420,14 +423,14 @@ func (s *Switch) ProcessContractResolution(msg contractcourt.ResolutionMsg) erro ResolutionMsg: msg, errChan: errChan, }: - case <-s.quit: + case <-s.gm.Done(): return ErrSwitchExiting } select { case err := <-errChan: return err - case <-s.quit: + case <-s.gm.Done(): return ErrSwitchExiting } } @@ -493,14 +496,11 @@ func (s *Switch) GetAttemptResult(attemptID uint64, paymentHash lntypes.Hash, // Since the attempt was known, we can start a goroutine that can // extract the result when it is available, and pass it on to the // caller. - s.wg.Add(1) - go func() { - defer s.wg.Done() - + ok := s.gm.Go(context.TODO(), func(ctx context.Context) { var n *networkResult select { case n = <-nChan: - case <-s.quit: + case <-ctx.Done(): // We close the result channel to signal a shutdown. We // don't send any result in this case since the HTLC is // still in flight. @@ -524,7 +524,11 @@ func (s *Switch) GetAttemptResult(attemptID uint64, paymentHash lntypes.Hash, return } resultChan <- result - }() + }) + // The switch shutting down is signaled by closing the channel. + if !ok { + close(resultChan) + } return resultChan, nil } @@ -704,12 +708,19 @@ func (s *Switch) ForwardPackets(linkQuit <-chan struct{}, select { case <-linkQuit: return nil - case <-s.quit: + + case <-s.gm.Done(): return nil + default: - // Spawn a goroutine to log the errors returned from failed packets. 
- s.wg.Add(1) - go s.logFwdErrs(&numSent, &wg, fwdChan) + // Spawn a goroutine to log the errors returned from failed + // packets. + ok := s.gm.Go(context.TODO(), func(ctx context.Context) { + s.logFwdErrs(ctx, &numSent, &wg, fwdChan) + }) + if !ok { + return nil + } } // Make a first pass over the packets, forwarding any settles or fails. @@ -820,8 +831,8 @@ func (s *Switch) ForwardPackets(linkQuit <-chan struct{}, } // logFwdErrs logs any errors received on `fwdChan`. -func (s *Switch) logFwdErrs(num *int, wg *sync.WaitGroup, fwdChan chan error) { - defer s.wg.Done() +func (s *Switch) logFwdErrs(ctx context.Context, num *int, wg *sync.WaitGroup, + fwdChan chan error) { // Wait here until the outer function has finished persisting // and routing the packets. This guarantees we don't read from num until @@ -836,7 +847,8 @@ func (s *Switch) logFwdErrs(num *int, wg *sync.WaitGroup, fwdChan chan error) { log.Errorf("Unhandled error while reforwarding htlc "+ "settle/fail over htlcswitch: %v", err) } - case <-s.quit: + + case <-s.gm.Done(): log.Errorf("unable to forward htlc packet " + "htlc switch was stopped") return @@ -862,7 +874,7 @@ func (s *Switch) routeAsync(packet *htlcPacket, errChan chan error, return nil case <-linkQuit: return ErrLinkShuttingDown - case <-s.quit: + case <-s.gm.Done(): return errors.New("htlc switch was stopped") } } @@ -940,8 +952,6 @@ func (s *Switch) getLocalLink(pkt *htlcPacket, htlc *lnwire.UpdateAddHTLC) ( // // NOTE: This method MUST be spawned as a goroutine. func (s *Switch) handleLocalResponse(pkt *htlcPacket) { - defer s.wg.Done() - attemptID := pkt.incomingHTLCID // The error reason will be unencypted in case this a local @@ -1114,7 +1124,9 @@ func (s *Switch) parseFailedPayment(deobfuscator ErrorDecrypter, // handlePacketForward is used in cases when we need forward the htlc update // from one channel link to another and be able to propagate the settle/fail // updates back. 
This behaviour is achieved by creation of payment circuits. -func (s *Switch) handlePacketForward(packet *htlcPacket) error { +func (s *Switch) handlePacketForward(ctx context.Context, + packet *htlcPacket) error { + switch htlc := packet.htlc.(type) { // Channel link forwarded us a new htlc, therefore we initiate the // payment circuit within our internal state so we can properly forward @@ -1123,7 +1135,7 @@ func (s *Switch) handlePacketForward(packet *htlcPacket) error { return s.handlePacketAdd(packet, htlc) case *lnwire.UpdateFulfillHTLC: - return s.handlePacketSettle(packet) + return s.handlePacketSettle(ctx, packet) // Channel link forwarded us an update_fail_htlc message. // @@ -1132,7 +1144,7 @@ func (s *Switch) handlePacketForward(packet *htlcPacket) error { // forward it. Thus there's no need to catch `UpdateFailMalformedHTLC` // here. case *lnwire.UpdateFailHTLC: - return s.handlePacketFail(packet, htlc) + return s.handlePacketFail(ctx, packet, htlc) default: return fmt.Errorf("wrong update type: %T", htlc) @@ -1436,7 +1448,7 @@ func (s *Switch) CloseLink(chanPoint *wire.OutPoint, case s.chanCloseRequests <- command: return updateChan, errChan - case <-s.quit: + case <-s.gm.Done(): errChan <- ErrSwitchExiting close(updateChan) return updateChan, errChan @@ -1453,9 +1465,9 @@ func (s *Switch) CloseLink(chanPoint *wire.OutPoint, // total link capacity. // // NOTE: This MUST be run as a goroutine. -func (s *Switch) htlcForwarder() { - defer s.wg.Done() - +// +//nolint:funlen +func (s *Switch) htlcForwarder(ctx context.Context) { defer func() { s.blockEpochStream.Cancel() @@ -1489,6 +1501,8 @@ func (s *Switch) htlcForwarder() { var wg sync.WaitGroup for _, link := range linksToStop { wg.Add(1) + // Here it is ok to start a goroutine directly bypassing + // s.gm, because we want for them to complete here. go func(l ChannelLink) { defer wg.Done() @@ -1613,7 +1627,7 @@ out: // encounter is due to the circuit already being // closed. 
This is fine, as processing this message is // meant to be idempotent. - err = s.handlePacketForward(pkt) + err = s.handlePacketForward(ctx, pkt) if err != nil { log.Errorf("Unable to forward resolution msg: %v", err) } @@ -1622,21 +1636,22 @@ out: // packet concretely, then either forward it along, or // interpret a return packet to a locally initialized one. case cmd := <-s.htlcPlex: - cmd.err <- s.handlePacketForward(cmd.pkt) + cmd.err <- s.handlePacketForward(ctx, cmd.pkt) // When this time ticks, then it indicates that we should // collect all the forwarding events since the last internal, // and write them out to our log. case <-s.cfg.FwdEventTicker.Ticks(): - s.wg.Add(1) - go func() { - defer s.wg.Done() - - if err := s.FlushForwardingEvents(); err != nil { + // The error of Go is ignored: if it is shutting down, + // the loop will terminate on the next iteration, in + // s.gm.Done case. + _ = s.gm.Go(ctx, func(ctx context.Context) { + err := s.FlushForwardingEvents() + if err != nil { log.Errorf("Unable to flush "+ "forwarding events: %v", err) } - }() + }) // The log ticker has fired, so we'll calculate some forwarding // stats for the last 10 seconds to display within the logs to @@ -1739,7 +1754,7 @@ out: // memory. s.pendingSettleFails = s.pendingSettleFails[:0] - case <-s.quit: + case <-s.gm.Done(): return } } @@ -1760,8 +1775,17 @@ func (s *Switch) Start() error { } s.blockEpochStream = blockEpochStream - s.wg.Add(1) - go s.htlcForwarder() + ok := s.gm.Go(context.TODO(), func(ctx context.Context) { + s.htlcForwarder(ctx) + }) + if !ok { + // We are already stopping so we can ignore the error. 
+ _ = s.Stop() + err = fmt.Errorf("unable to start htlc forwarder: %w", + ErrSwitchExiting) + log.Errorf("%v", err) + return err + } if err := s.reforwardResponses(); err != nil { s.Stop() @@ -1991,9 +2015,8 @@ func (s *Switch) Stop() error { log.Info("HTLC Switch shutting down...") defer log.Debug("HTLC Switch shutdown complete") - close(s.quit) - - s.wg.Wait() + // Ask running goroutines to stop and wait for them. + s.gm.Stop() // Wait until all active goroutines have finished exiting before // stopping the mailboxes, otherwise the mailbox map could still be @@ -2349,7 +2372,7 @@ func (s *Switch) RemoveLink(chanID lnwire.ChannelID) { select { case <-stopChan: return - case <-s.quit: + case <-s.gm.Done(): return } } @@ -2990,7 +3013,9 @@ func (s *Switch) handlePacketAdd(packet *htlcPacket, } // handlePacketSettle handles forwarding a settle packet. -func (s *Switch) handlePacketSettle(packet *htlcPacket) error { +func (s *Switch) handlePacketSettle(ctx context.Context, + packet *htlcPacket) error { + // If the source of this packet has not been set, use the circuit map // to lookup the origin. circuit, err := s.closeCircuit(packet) @@ -3029,8 +3054,12 @@ func (s *Switch) handlePacketSettle(packet *htlcPacket) error { // NOTE: `closeCircuit` modifies the state of `packet`. if localHTLC { // TODO(yy): remove the goroutine and send back the error here. - s.wg.Add(1) - go s.handleLocalResponse(packet) + ok := s.gm.Go(ctx, func(ctx context.Context) { + s.handleLocalResponse(packet) + }) + if !ok { + return ErrSwitchExiting + } // If this is a locally initiated HTLC, there's no need to // forward it so we exit. @@ -3066,7 +3095,7 @@ func (s *Switch) handlePacketSettle(packet *htlcPacket) error { } // handlePacketFail handles forwarding a fail packet. 
-func (s *Switch) handlePacketFail(packet *htlcPacket, +func (s *Switch) handlePacketFail(ctx context.Context, packet *htlcPacket, htlc *lnwire.UpdateFailHTLC) error { // If the source of this packet has not been set, use the circuit map @@ -3085,8 +3114,12 @@ func (s *Switch) handlePacketFail(packet *htlcPacket, // NOTE: `closeCircuit` modifies the state of `packet`. if packet.incomingChanID == hop.Source { // TODO(yy): remove the goroutine and send back the error here. - s.wg.Add(1) - go s.handleLocalResponse(packet) + ok := s.gm.Go(ctx, func(ctx context.Context) { + s.handleLocalResponse(packet) + }) + if !ok { + return ErrSwitchExiting + } // If this is a locally initiated HTLC, there's no need to // forward it so we exit. diff --git a/htlcswitch/switch_test.go b/htlcswitch/switch_test.go index 88093214607..fbd9bf4bfaf 100644 --- a/htlcswitch/switch_test.go +++ b/htlcswitch/switch_test.go @@ -8,6 +8,7 @@ import ( "io" mrand "math/rand" "reflect" + "sync" "testing" "time" @@ -3159,6 +3160,60 @@ func TestSwitchGetAttemptResult(t *testing.T) { } } +// TestSwitchGetAttemptResultStress runs series of GetAttemptResult and Stop in +// parallel to make sure there is no race condition between these actions. +func TestSwitchGetAttemptResultStress(t *testing.T) { + t.Parallel() + + const paymentID = 123 + + s, err := initSwitchWithTempDB(t, testStartingHeight) + require.NoError(t, err, "unable to init switch") + require.NoError(t, s.Start(), "unable to start switch") + + lookup := make(chan *PaymentCircuit, 1) + s.circuits = &mockCircuitMap{ + lookup: lookup, + } + + var wg sync.WaitGroup + + wg.Add(1) + go func() { + defer wg.Done() + + for range 1000 { + // Next let the lookup find the circuit in the circuit + // map. It should subscribe to payment results, and + // return the result when available. 
+ lookup <- &PaymentCircuit{} + _, err := s.GetAttemptResult( + paymentID, lntypes.Hash{}, + newMockDeobfuscator(), + ) + require.NoError(t, err, "unable to get payment result") + } + }() + + // Run s.Stop() in parallel with consecutive GetAttemptResult calls to + // make sure this doesn't result in a race condition. + wg.Add(1) + go func() { + defer wg.Done() + + // Sleep 10ms to let several GetAttemptResult calls happen, so + // s.Stop() happens in the middle of GetAttemptResult series. + // The value 10ms was found empirically - this time is needed + // to expose the race condition (as a crash under -race) in the + // version of Switch before GoroutineManager was added. + time.Sleep(10 * time.Millisecond) + + require.NoError(t, s.Stop()) + }() + + wg.Wait() +} + // TestInvalidFailure tests that the switch returns an unreadable failure error // if the failure cannot be decrypted. func TestInvalidFailure(t *testing.T) { @@ -4953,7 +5008,7 @@ func testSwitchForwardFailAlias(t *testing.T, zeroConf bool) { // Pull packet from Bob's link, and do nothing with it. 
select { case <-bobLink.packets: - case <-s.quit: + case <-s.gm.Done(): t.Fatal("switch shutting down, failed to forward packet") } @@ -5012,7 +5067,8 @@ func testSwitchForwardFailAlias(t *testing.T, zeroConf bool) { failMsg, ok := msg.(*lnwire.FailTemporaryChannelFailure) require.True(t, ok) require.Equal(t, aliceAlias, failMsg.Update.ShortChannelID) - case <-s2.quit: + + case <-s2.gm.Done(): t.Fatal("switch shutting down, failed to forward packet") } } @@ -5193,7 +5249,8 @@ func testSwitchAliasFailAdd(t *testing.T, zeroConf, private, useAlias bool) { failMsg, ok := msg.(*lnwire.FailTemporaryChannelFailure) require.True(t, ok) require.Equal(t, outgoingChanID, failMsg.Update.ShortChannelID) - case <-s.quit: + + case <-s.gm.Done(): t.Fatal("switch shutting down, failed to receive fail packet") } } @@ -5393,7 +5450,8 @@ func testSwitchHandlePacketForward(t *testing.T, zeroConf, private, failMsg, ok := msg.(*lnwire.FailAmountBelowMinimum) require.True(t, ok) require.Equal(t, outgoingChanID, failMsg.Update.ShortChannelID) - case <-s.quit: + + case <-s.gm.Done(): t.Fatal("switch shutting down, failed to receive failure") } } @@ -5549,7 +5607,7 @@ func testSwitchAliasInterceptFail(t *testing.T, zeroConf bool) { isAlias := failScid == aliceAlias || failScid == aliceAlias2 require.True(t, isAlias) - case <-s.quit: + case <-s.gm.Done(): t.Fatalf("switch shutting down, failed to receive failure") } From 91c27f1c1aecc2ca64a6e20a5a9a2d20f95c8344 Mon Sep 17 00:00:00 2001 From: Boris Nagaev Date: Fri, 11 Oct 2024 11:04:35 -0300 Subject: [PATCH 2/3] htlcswitch: fix linter warnings --- htlcswitch/switch.go | 7 ++++++- 1 file changed, 6 insertions(+), 1 deletion(-) diff --git a/htlcswitch/switch.go b/htlcswitch/switch.go index ed240224b62..0d5b061d035 100644 --- a/htlcswitch/switch.go +++ b/htlcswitch/switch.go @@ -1764,6 +1764,7 @@ out: func (s *Switch) Start() error { if !atomic.CompareAndSwapInt32(&s.started, 0, 1) { log.Warn("Htlc Switch already started") + return 
errors.New("htlc switch already started") } @@ -1784,12 +1785,15 @@ func (s *Switch) Start() error { err = fmt.Errorf("unable to start htlc forwarder: %w", ErrSwitchExiting) log.Errorf("%v", err) + return err } if err := s.reforwardResponses(); err != nil { - s.Stop() + // We are already stopping so we can ignore the error. + _ = s.Stop() log.Errorf("unable to reforward responses: %v", err) + return err } @@ -1797,6 +1801,7 @@ func (s *Switch) Start() error { // We are already stopping so we can ignore the error. _ = s.Stop() log.Errorf("unable to reforward resolutions: %v", err) + return err } From c70fcbe51c7e709b908156f7696a800e214d57d3 Mon Sep 17 00:00:00 2001 From: Boris Nagaev Date: Fri, 13 Dec 2024 01:05:55 -0300 Subject: [PATCH 3/3] docs: add release notes entry --- docs/release-notes/release-notes-0.19.0.md | 3 +++ 1 file changed, 3 insertions(+) diff --git a/docs/release-notes/release-notes-0.19.0.md b/docs/release-notes/release-notes-0.19.0.md index 2d4edf1e204..c77ff729312 100644 --- a/docs/release-notes/release-notes-0.19.0.md +++ b/docs/release-notes/release-notes-0.19.0.md @@ -63,6 +63,9 @@ * [Fixed a bug](https://github.com/lightningnetwork/lnd/pull/9322) that caused estimateroutefee to ignore the default payment timeout. +* [Fixed a possible crash of htlcswitch upon shutdown](https://github.com/lightningnetwork/lnd/pull/9140) + caused by a race condition in the goroutine tracking mechanism. + # New Features * [Support](https://github.com/lightningnetwork/lnd/pull/8390) for